Running on TPUs
===============

All TensorFlow-compatible rerankers support training and inference on Google TPUs. Capreolus has been tested with both v2-8 and v3-8 TPUs.

To use a TPU with a TensorFlow-compatible reranker (i.e., a reranker that depends on the ``tensorflow`` Trainer module), set the following config options:

- ``tpuname``: the name of your TPU, such as ``mytpu1``. If you are using a TPU VM, set this to ``LOCAL``.
- ``storage``: the path to a GCS bucket where data should be stored, such as ``gs://your-bucket/abc/``. If you are using a TPU VM, this can also be a path on the TPU VM itself.
- ``tpuzone``: the cloud zone your TPU is in, such as ``us-central1-f``.

You should also set ``usecache=True`` for both the trainer and the extractor.
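
The options above can be collected in one place before launching a run. The Python sketch below is a hypothetical helper, not part of Capreolus: it checks the rules described above (a remote TPU needs a ``gs://`` storage path and a zone, while ``tpuname=LOCAL`` on a TPU VM may use a local path) and returns the settings as dotted ``key=value`` overrides. The option paths ``reranker.trainer.*`` and ``reranker.extractor.*`` are assumptions about where the trainer and extractor sit in your pipeline's config, and treating the zone as optional on a TPU VM is also an assumption; adjust both to match your setup.

```python
from typing import Dict, Optional

# ASSUMPTION: where the trainer and extractor live in your config tree.
TRAINER = "reranker.trainer"
EXTRACTOR = "reranker.extractor"


def tpu_overrides(tpuname: str, storage: str, tpuzone: Optional[str] = None) -> Dict[str, object]:
    """Build dotted config overrides for running a TF reranker on a TPU."""
    on_tpu_vm = tpuname == "LOCAL"
    if not on_tpu_vm:
        # A remote TPU reads its input from GCS, so storage must be a bucket path.
        if not storage.startswith("gs://"):
            raise ValueError("storage must be a gs:// path unless tpuname is LOCAL")
        if not tpuzone:
            raise ValueError("tpuzone is required for a remote TPU")
    overrides: Dict[str, object] = {
        f"{TRAINER}.tpuname": tpuname,
        f"{TRAINER}.storage": storage,
        # Recommended on this page: enable caching in both trainer and extractor.
        f"{TRAINER}.usecache": True,
        f"{EXTRACTOR}.usecache": True,
    }
    # ASSUMPTION: on a TPU VM the zone may be omitted.
    if tpuzone:
        overrides[f"{TRAINER}.tpuzone"] = tpuzone
    return overrides


def as_cli_args(overrides: Dict[str, object]) -> str:
    """Render overrides as space-separated key=value pairs."""
    return " ".join(f"{key}={value}" for key, value in overrides.items())
```

For example, ``as_cli_args(tpu_overrides("mytpu1", "gs://your-bucket/abc/", "us-central1-f"))`` produces a string beginning ``reranker.trainer.tpuname=mytpu1 reranker.trainer.storage=gs://your-bucket/abc/``.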

After setting these options, you can run Capreolus as normal. Watch for INFO logging messages at the beginning of training to confirm the TPU is being used.
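
Independently of Capreolus's log messages, you can check that TensorFlow can reach the TPU before starting a long run. The sketch below uses TensorFlow's own TPU APIs (``TPUClusterResolver``, ``tf.config.experimental_connect_to_cluster``, ``tf.tpu.experimental.initialize_tpu_system``). The mapping from the ``tpuname``/``tpuzone`` options to resolver arguments, including translating ``LOCAL`` to TensorFlow's ``"local"`` TPU VM name, is an assumption and may not match what Capreolus does internally.

```python
from typing import Dict, Optional


def resolver_kwargs(tpuname: str, tpuzone: Optional[str] = None) -> Dict[str, str]:
    """Map tpuname/tpuzone values to TPUClusterResolver keyword arguments."""
    if tpuname == "LOCAL":
        # ASSUMPTION: on a TPU VM, TensorFlow's resolver uses the name "local".
        return {"tpu": "local"}
    kwargs = {"tpu": tpuname}
    if tpuzone:
        kwargs["zone"] = tpuzone
    return kwargs


def list_tpu_devices(tpuname: str, tpuzone: Optional[str] = None):
    """Connect to the TPU and return its logical TPU devices (requires TensorFlow)."""
    import tensorflow as tf  # imported lazily so resolver_kwargs works without TF

    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        **resolver_kwargs(tpuname, tpuzone)
    )
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    return tf.config.list_logical_devices("TPU")
```

On a v2-8 or v3-8 TPU, ``list_tpu_devices("mytpu1", "us-central1-f")`` should return eight logical TPU devices, one per core.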

.. note:: While any TensorFlow-compatible ``Reranker`` can be used with TPUs, doing so will actually slow down small models such as KNRM. TPUs are most useful with large Transformer-based models.

.. warning:: TPUs may read their input data from a Google Cloud Storage (GCS) bucket rather than from the local machine (i.e., the machine running Capreolus). Capreolus automatically preprocesses the data and uploads it to this bucket. However, GCS is not free, and you are responsible for manually deleting this data once it is no longer needed.
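
One way to clean up afterwards is with the ``google-cloud-storage`` client library. The sketch below is a hypothetical helper, not part of Capreolus: it lists every object under the ``storage`` path you configured and deletes them only when ``dry_run=False``. It assumes that everything under that prefix is safe to remove, so review the dry-run listing first.

```python
from typing import List, Tuple


def split_gcs_path(path: str) -> Tuple[str, str]:
    """Split 'gs://bucket/some/prefix/' into ('bucket', 'some/prefix/')."""
    if not path.startswith("gs://"):
        raise ValueError(f"not a GCS path: {path!r}")
    bucket, _, prefix = path[len("gs://"):].partition("/")
    if not bucket:
        raise ValueError(f"missing bucket name in {path!r}")
    return bucket, prefix


def delete_storage(path: str, dry_run: bool = True) -> List[str]:
    """Return the names of all objects under `path`; delete them unless dry_run."""
    bucket_name, prefix = split_gcs_path(path)
    if not prefix:
        # Refuse to empty an entire bucket by accident.
        raise ValueError("refusing to delete a whole bucket; include a prefix")
    if not prefix.endswith("/"):
        # Without a trailing slash, "abc" would also match "abcdef/...".
        prefix += "/"

    from google.cloud import storage  # pip install google-cloud-storage

    client = storage.Client()  # uses Application Default Credentials
    # Materialize the listing first so deletions don't interleave with paging.
    blobs = list(client.list_blobs(bucket_name, prefix=prefix))
    if not dry_run:
        for blob in blobs:
            blob.delete()
    return [blob.name for blob in blobs]
```

For example, call ``delete_storage("gs://your-bucket/abc/")`` to see what would be removed, then repeat with ``dry_run=False``.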


The following models are good candidates for running on TPUs:

.. autoapiclass:: capreolus.reranker.TFBERTMaxP.TFBERTMaxP
   .. autoapiattribute:: module_name
.. autoapiclass:: capreolus.reranker.TFCEDRKNRM.TFCEDRKNRM
   .. autoapiattribute:: module_name
.. autoapiclass:: capreolus.reranker.parade.TFParade
   .. autoapiattribute:: module_name