
TPU functionality in PyTorch

TPU functionality in PyTorch is provided by the PyTorch XLA package (developed jointly by Google and Facebook). See the documentation and publicly available code (like this tutorial) for example usage.

In summary, a TPU is composed of multiple cores (usually 8). Models can be trained on an individual core or on all 8 cores. Each core is treated as an XLA device, analogous to a CUDA or CPU device in PyTorch. Training on a single TPU core is therefore as simple as initializing an XLA device and putting the batch of data and the model onto that device; the optimizer step is additionally wrapped with the XLA-specific xm.optimizer_step function. Training on 8 cores is more complicated and is akin to multi-GPU training. It uses a multiprocessing interface, where a training loop function is spawned as a process on each core and the data is divided equally among the cores; gradients and backpropagation are synced between the cores. It also requires the special ParallelLoader class to load data properly on each core. Here is example code from the documentation:

import torch.nn as nn
import torch.optim as optim

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

# train_loader, MNIST, lr, and momentum are defined earlier in the full docs example.
def _mp_fn(index):
  # Each spawned process sees one TPU core as its XLA device.
  device = xm.xla_device()
  # ParallelLoader preloads batches onto the XLA device in background threads.
  para_loader = pl.ParallelLoader(train_loader, [device])

  model = MNIST().train().to(device)
  loss_fn = nn.NLLLoss()
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

  for data, target in para_loader.per_device_loader(device):
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    # optimizer_step all-reduces the gradients across cores before updating weights.
    xm.optimizer_step(optimizer)

if __name__ == '__main__':
  # Spawn one training process per TPU core.
  xmp.spawn(_mp_fn, args=())
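
For comparison, single-core training needs no spawning or ParallelLoader. Below is a minimal sketch of the single-core path described above, assuming the same MNIST model, loss, optimizer settings, and train_loader as in the docs example:

import torch_xla.core.xla_model as xm

device = xm.xla_device()  # the single TPU core, exposed as an XLA device
model = MNIST().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

for data, target in train_loader:
  data, target = data.to(device), target.to(device)  # move the batch onto the core
  optimizer.zero_grad()
  loss = loss_fn(model(data), target)
  loss.backward()
  xm.optimizer_step(optimizer, barrier=True)  # barrier=True is needed when not using ParallelLoader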

TPU support in fastai

There have been multiple efforts to establish TPU support in the fastai library (ordered chronologically):

  • Efforts by @ilovescience starting as early as Sept. 2019. See the forum threads here and here. While functional callbacks for single-core and multi-core TPU training were developed for fastai v1, they had suboptimal performance compared to multi-GPU setups. @ilovescience is now working on fastai v2 support for multi-core TPU training; see this repository. Much of this work was supported by @sgugger (Sylvain), @TomB, and the PyTorch XLA developers.
  • Code by @kcturgutlu in May 2020. In this Kaggle Kernel, a TPU Learner was created to support multi-core TPU training in fastai v1.
  • Efforts by @butchland and @tyoc213 starting in Jun. 2020, originally proposed for the Global PyTorch Summer Hackathon. This team has focused mainly on developing tools for single-core TPU training in fastai v2; see this repository. While they have a minimal working example, they are still fixing bugs and performance issues.

We plan to consolidate our efforts in the near future. Our progress is documented in the TPU support thread, which has more detailed discussion of specific issues, design choices, and other considerations.

Proposed workflow:

Ideally, TPU training would be as simple as adding a flag:

learn = Learner(...).to_tpu()

where to_tpu would add a callback to the Learner enabling TPU training. Regarding single- vs. multi-core training, there could be two callbacks for the two situations, or one callback that handles both. Either way, from the user's perspective it should be as simple as specifying the number of cores (with 8 cores likely being the default):

learn = Learner(...).to_tpu(n_cores=8)
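
As a rough illustration of the intended API (TPUCallback and to_tpu are hypothetical placeholders here, not existing fastai functionality), the flag could be implemented as a fastai-style patch that simply attaches a callback:

from fastai.learner import Learner
from fastai.callback.core import Callback
from fastcore.all import patch

class TPUCallback(Callback):
    "Placeholder for the callback that would set up the XLA device(s) and data loading."
    def __init__(self, n_cores=8): self.n_cores = n_cores

@patch
def to_tpu(self:Learner, n_cores=8):
    "Attach the (hypothetical) TPU callback and return the Learner for chaining."
    self.add_cb(TPUCallback(n_cores=n_cores))
    return self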

For 8-core TPU training, it's important to note that the training loop must be put into a separate function like so:

def train_loop(index):
    train_df = ...
    food = DataBlock(...)
    dls = food.dataloaders(train_df)
    learn = cnn_learner(dls, model, metrics=metrics).to_tpu_distributed()  # adds the TPU callback
    learn.fit(3)

xmp.spawn(train_loop, nprocs=8, args=())

This is likely problematic from a user perspective. We are not sure whether the DataLoaders and Learner objects have to be defined inside the spawned function; this issue suggests they do not, though that approach places restrictions on the pickling of the dataset (which should be fine for most cases).

In that case, the callback must cancel the original training loop, spawn 8 processes, and call the training loop again in each spawned process. Once the current draft of the fastai v2 callback for multi-core TPU training (over here) works, this can be fleshed out further and tested.
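
One way this control flow might look is sketched below. This is only an illustration, not the actual draft callback: TPUDistributedLauncher and _spawned_train are hypothetical names, and the Learner is assumed to be rebuilt inside each worker.

import torch_xla.distributed.xla_multiprocessing as xmp
from fastai.callback.core import Callback, CancelFitException

def _spawned_train(index):
    # Rebuild the DataLoaders and Learner here (as in train_loop above) and train.
    ...

class TPUDistributedLauncher(Callback):
    "Hypothetical callback: hand training off to 8 spawned worker processes."
    def before_fit(self):
        xmp.spawn(_spawned_train, nprocs=8, args=())
        raise CancelFitException()  # stop the parent process from running its own loop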

Future Challenges:

  1. Progress bars with multi-core TPUs: Currently, on multi-core TPUs, fastai v2 shows multiple progress bars, since there are 8 processes. We need a single progress bar that is synced over all 8 cores (see the sketch after this list).
  2. Batch transforms: As far as we are aware, batch transforms have not been done with PyTorch XLA before. Batch transforms for TensorFlow TPU training have been done, though; see the Flower Classification competition kernels for inspiration. There are apparently even TPU batch transforms for MixUp and CutMix (this kernel). Since TPU training can be greatly affected by I/O bottlenecks, batch transforms are likely very important for achieving large speed-ups. @butchland currently has an open issue in the PyTorch XLA repository about performance problems with batch transforms.
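
For the progress-bar challenge, one possible direction is to keep fastai's ProgressCallback only on the master process. This is only a sketch: it assumes a spawned training function like the ones above, and it does not produce a truly synced bar, it just avoids 8 duplicate ones.

import torch_xla.core.xla_model as xm
from fastai.callback.progress import ProgressCallback

def _spawned_train(index):
    learn = ...  # build the DataLoaders and Learner here, as in the examples above
    if not xm.is_master_ordinal():
        learn.remove_cb(ProgressCallback)  # only the master process draws a progress bar
    learn.fit(3)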

We may run into other challenges along the way as well, and this wiki will be updated accordingly.
