
### TPU training in Composer

Composer provides beta support for single-core training on TPUs.
In this tutorial, we walk through how to train ResNet-20 on CIFAR10 with only minimal changes to a regular Composer training script.





As a prerequisite, first install `torch_xla` and Composer (distributed on PyPI as `mosaicml`).

In [None]:
!pip install cloud-tpu-client==0.10 torch==1.12.0 https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-1.12-cp37-cp37m-linux_x86_64.whl
%pip install mosaicml

from composer import Trainer
from composer import models


Define the model, import `torch_xla`, and move the model to the XLA device. This step must happen before the optimizer is constructed, so that the optimizer is built from the parameters that live on the TPU.

In [None]:
import torch
import torch_xla.core.xla_model as xm

model = models.composer_resnet_cifar(model_name='resnet_20', num_classes=10)
model = model.to(xm.xla_device())
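Why does the order matter? PyTorch's `torch.optim` documentation notes that moving a model between devices can replace its parameters with new objects, while an optimizer keeps references to the exact objects it was constructed with. The toy sketch below illustrates this in plain Python; `ToyParam`, `ToyModel`, and `ToyOptimizer` are hypothetical stand-ins, not PyTorch classes, and they assume the move always creates new parameter objects.

```python
# Hypothetical toy classes (not PyTorch APIs) that mimic how a cross-device
# move can replace parameter objects while an optimizer keeps old references.

class ToyParam:
    def __init__(self, value, device):
        self.value = value
        self.device = device


class ToyModel:
    def __init__(self):
        self.params = [ToyParam(1.0, "cpu"), ToyParam(2.0, "cpu")]

    def to(self, device):
        # Assumption: moving creates *new* parameter objects on the target device.
        self.params = [ToyParam(p.value, device) for p in self.params]
        return self


class ToyOptimizer:
    def __init__(self, params):
        # The optimizer stores references to the objects it is given.
        self.params = list(params)


def optimizer_tracks_model(model, optimizer):
    """True if every parameter the optimizer holds is one the model actually uses."""
    return all(any(p is q for q in model.params) for p in optimizer.params)


# Wrong order: the optimizer is built before the move and keeps the stale CPU objects.
bad_model = ToyModel()
bad_opt = ToyOptimizer(bad_model.params)
bad_model.to("xla")

# Right order: move the model first, then build the optimizer.
good_model = ToyModel().to("xla")
good_opt = ToyOptimizer(good_model.params)

print(optimizer_tracks_model(bad_model, bad_opt))
print(optimizer_tracks_model(good_model, good_opt))
```

In the wrong order, the optimizer would keep updating the stale CPU copies while the model runs with its XLA parameters.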

Now, load the CIFAR10 dataset, define the transforms, and create the dataloaders, just as you would for any PyTorch model.

In [None]:
from torchvision import datasets, transforms

data_directory = "../data"

# Normalization constants
mean = (0.507, 0.487, 0.441)
std = (0.267, 0.256, 0.276)

batch_size = 1024

cifar10_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])

train_dataset = datasets.CIFAR10(data_directory, train=True, download=True, transform=cifar10_transforms)
test_dataset = datasets.CIFAR10(data_directory, train=False, download=True, transform=cifar10_transforms)

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# The evaluation order does not affect the metrics, so the test set is not shuffled.
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
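The two transforms work per pixel: `ToTensor` converts a `uint8` image in [0, 255] to floats in [0.0, 1.0], and `Normalize` subtracts the per-channel mean and divides by the per-channel standard deviation. Here is a plain-Python sketch of that arithmetic on a single hypothetical RGB pixel; the helper names `to_tensor_scale` and `normalize` are illustrative only, not torchvision functions.

```python
# Per-channel CIFAR10 statistics, copied from the cell above.
MEAN = (0.507, 0.487, 0.441)
STD = (0.267, 0.256, 0.276)


def to_tensor_scale(pixel_uint8):
    """Mimic ToTensor's scaling of one RGB pixel from [0, 255] to [0.0, 1.0]."""
    return tuple(c / 255.0 for c in pixel_uint8)


def normalize(pixel, mean=MEAN, std=STD):
    """Mimic Normalize: out[c] = (in[c] - mean[c]) / std[c] for each channel."""
    return tuple((x - m) / s for x, m, s in zip(pixel, mean, std))


pixel = (255, 0, 128)  # a hypothetical RGB pixel
scaled = to_tensor_scale(pixel)
normed = normalize(scaled)
print([round(v, 3) for v in normed])
```

A pixel equal to the channel means maps to zero in every channel, which is the point of normalizing with dataset statistics.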

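Next, define the optimizer exactly as you would in plain PyTorch. Because the model is already on the XLA device, `model.parameters()` yields the TPU-resident parameters.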

In [None]:
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.02,
    momentum=0.9)
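With `momentum=0.9`, PyTorch's SGD keeps a running buffer `b` per parameter, updated as `b = momentum * b + grad` (the first step sets `b = grad`), and then applies `p = p - lr * b`. The scalar sketch below reproduces that update on a hypothetical quadratic loss, assuming the remaining SGD arguments (`dampening`, `nesterov`, `weight_decay`) keep their defaults as in the cell above; `sgd_momentum` is an illustrative helper, not part of `torch.optim`.

```python
def sgd_momentum(grad_fn, p, lr=0.02, momentum=0.9, steps=3):
    """Scalar sketch of torch.optim.SGD with momentum
    (dampening=0, nesterov=False, weight_decay=0, i.e. the defaults)."""
    buf = None
    history = [p]
    for _ in range(steps):
        g = grad_fn(p)
        # The first step initializes the momentum buffer with the raw gradient.
        buf = g if buf is None else momentum * buf + g
        p = p - lr * buf
        history.append(p)
    return history


# Minimize the hypothetical loss f(p) = p**2, whose gradient is 2p, starting at p = 1.0.
trajectory = sgd_momentum(lambda p: 2 * p, p=1.0)
print(trajectory)
```

Each step moves further than plain SGD would, because the buffer accumulates past gradients that point in the same direction.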


Now we are ready to define the Composer trainer. The only TPU-specific change is passing `device="tpu"`.

Calling `trainer.fit()` then trains the model on a single TPU core. Stay tuned for the next Composer release for alpha support for multi-core TPUs.

In [None]:
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    device="tpu",
    eval_dataloader=test_dataloader,
    optimizers=optimizer,
    max_duration='20ep',
    eval_interval=1,
)

trainer.fit()
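To make `max_duration='20ep'` and `eval_interval=1` concrete, the sketch below counts the optimizer steps and evaluation batches this run performs. It assumes CIFAR10's standard split (50,000 training and 10,000 test images), the default `drop_last=False` for both dataloaders, and that `eval_interval=1` means one evaluation per epoch; `parse_epochs` is a hypothetical helper, not Composer's own duration parser.

```python
import math

# CIFAR10 split sizes and the batch size used in this notebook.
TRAIN_SIZE = 50_000
TEST_SIZE = 10_000
BATCH_SIZE = 1024


def batches_per_epoch(num_samples, batch_size):
    """Batches a DataLoader yields per epoch with drop_last=False (last batch may be partial)."""
    return math.ceil(num_samples / batch_size)


def parse_epochs(duration):
    """Hypothetical helper: parse an epoch string such as '20ep' into an int."""
    if not duration.endswith("ep"):
        raise ValueError(f"expected an epoch duration like '20ep', got {duration!r}")
    return int(duration[:-2])


epochs = parse_epochs("20ep")
train_steps = epochs * batches_per_epoch(TRAIN_SIZE, BATCH_SIZE)
# Assumes eval_interval=1 triggers one full pass over the test set per epoch.
eval_batches = epochs * batches_per_epoch(TEST_SIZE, BATCH_SIZE)

print(train_steps, eval_batches)
```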