<a href="https://colab.research.google.com/github/butchland/fastai_xla_extensions/blob/master/samples/minimal_fastai_pytorch_tpu_sample.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Minimal fastai torch tpu training example

> Train plain PyTorch models, with plain PyTorch datasets and dataloaders, using the fastai training loop on TPUs.

Using PyTorch datasets and dataloaders, we train plain PyTorch models with fastai's training loop on TPUs, via `torch-xla` and the `fastai_xla_extensions` package.

Inspired by Zach Mueller's minimal fastai example from the [fastai-minima package](https://pypi.org/project/fastai-minima/) and the [Pytorchtofastai blog post](https://muellerzr.github.io/fastblog/2021/02/14/Pytorchtofastai.html).



Assumptions:
 * Python 3.7 installation (the Google Colab default)

## Installation and Setup

Install torch 1.7.1

In [1]:
!pip install -qqq --no-cache-dir torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchtext==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html


(Optional) Link fastai data and model dirs to content dir

In [3]:
!curl -s https://course19.fast.ai/setup/colab | bash

Updating fastai...
Done.


Install fastai  

In [2]:
!pip install -Uqq fastai==2.3.0


Install fastai_xla_extensions

In [4]:
!pip install -Uqq fastai_xla_extensions
# !pip install -Uqq git+https://github.com/butchland/fastai_xla_extensions.git


Install torch-xla 1.7

In [5]:
!pip install -Uqq cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp37-cp37m-linux_x86_64.whl
# VERSION = "1.7" #@param ["1.5" , "20200707", "20200325", "nightly", "1.7"]
# !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py > /dev/null
# !python pytorch-xla-env-setup.py --version $VERSION > /dev/null

ERROR: earthengine-api 0.1.258 has requirement google-api-python-client<2,>=1.12.1, but you'll have google-api-python-client 1.8.0 which is incompatible.

Document package versions

In [6]:
!pip freeze | grep torch
!pip freeze | grep fast

torch==1.7.1+cu101
torch-xla==1.7
torchsummary==1.5.1
torchtext==0.8.0
torchvision==0.8.2+cu101
fastai==2.3.0
fastai-xla-extensions==0.0.11
fastcore==1.3.19
fastdtw==0.3.4
fastprogress==1.0.0
fastrlock==0.6


## Model Training

Import `fastai` and `fastai_xla_extensions` packages

In [1]:
from fastai.vision.all import *
from fastai_xla_extensions.all import *



Use plain pytorch datasets and dataloaders

In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
norm = transforms.Normalize(
    mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))

transform = transforms.Compose(
    [transforms.ToTensor(),
     norm])

dset_train = torchvision.datasets.CIFAR10(root='/content/data', train=True,
                                        download=True, transform=transform)

dset_test = torchvision.datasets.CIFAR10(root='/content/data', train=False,
                                       download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(dset_train, batch_size=64,
                                          shuffle=True, num_workers=4)
testloader = torch.utils.data.DataLoader(dset_test, batch_size=64,
                                         shuffle=False, num_workers=4)

Files already downloaded and verified
Files already downloaded and verified
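
`transforms.Normalize` applies `(x - mean) / std` to each channel independently, after `ToTensor` has scaled pixel values to the range [0, 1]. The following is a minimal pure-Python sketch of that arithmetic for a single pixel. The `normalize_pixel` helper is hypothetical and only illustrates the formula; it is not part of torchvision.

```python
# Per-channel statistics copied from the cell above.
MEAN = (0.4914, 0.4822, 0.4465)
STD = (0.2023, 0.1994, 0.2010)

def normalize_pixel(rgb, mean=MEAN, std=STD):
    """Apply (x - mean) / std to each channel of one RGB pixel in [0, 1]."""
    return tuple((x - m) / s for x, m, s in zip(rgb, mean, std))

# A pixel equal to the channel means maps to zero in every channel.
print(normalize_pixel(MEAN))
# A pure white pixel lands between about 2.5 and 2.8 standard deviations
# above the mean, depending on the channel.
print(normalize_pixel((1.0, 1.0, 1.0)))
```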


Use plain pytorch model

In [3]:
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
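
The `16 * 5 * 5` input size of `fc1` follows from the 32x32 CIFAR-10 images: each unpadded 5x5 convolution shrinks the side length by 4, and each 2x2 max-pool halves it. A short sketch that traces this arithmetic (the `conv_out` helper is a hypothetical name used only for illustration):

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1

side = 32                                   # CIFAR-10 images are 32x32
side = conv_out(side, kernel=5)             # conv1: 32 -> 28
side = conv_out(side, kernel=2, stride=2)   # pool:  28 -> 14
side = conv_out(side, kernel=5)             # conv2: 14 -> 10
side = conv_out(side, kernel=2, stride=2)   # pool:  10 -> 5

flat_features = 16 * side * side            # conv2 has 16 output channels
print(side, flat_features)
```

If the input resolution or kernel sizes change, the first argument of `fc1` and the `view` call in `forward` have to change with them.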

Use plain pytorch loss functions

In [4]:
criterion = nn.CrossEntropyLoss()
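
`nn.CrossEntropyLoss` takes raw logits and integer class labels. For a single sample it computes the negative log of the softmax probability assigned to the true class. A minimal pure-Python sketch of that formula for one sample follows; the `cross_entropy` helper is illustrative and is not PyTorch's implementation.

```python
import math

def cross_entropy(logits, target):
    """Negative log-softmax of the target class, computed stably."""
    m = max(logits)  # subtract the max before exponentiating for stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

# With 10 equally likely classes the loss equals ln(10), which is roughly
# where an untrained 10-class classifier starts.
uniform_loss = cross_entropy([0.0] * 10, target=3)
print(uniform_loss)
```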

Comment out fastai-minima code, as this example relies on fastai code directly.

In [5]:
# from torch import optim
# from fastai_minima.optimizer import OptimWrapper
# # from fastai_minima.learner import Learner, DataLoaders
# from fastai_minima.callback.training import CudaCallback, ProgressCallback

Wrap the pytorch SGD optimizer with fastai's `OptimWrapper`. 

In [6]:
from torch import optim

# def opt_func(params, **kwargs): 
#     return OptimWrapper(optim.SGD,params, **kwargs)
opt_func = partial(OptimWrapper, opt=optim.SGD)
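
fastai's `Learner` creates the optimizer lazily by calling `opt_func` with the model parameters and the learning rate, so `opt_func` has to be a callable rather than an optimizer instance. The sketch below shows how `partial` binds the `opt` keyword up front and leaves the remaining arguments open. `fake_wrapper` is a hypothetical stand-in used for illustration; it is not fastai's `OptimWrapper`.

```python
from functools import partial

def fake_wrapper(params, opt, **kwargs):
    """Hypothetical stand-in: record which optimizer would be built, and how."""
    return {"opt": opt, "n_params": len(params), "hyper": kwargs}

# Bind the optimizer choice now; parameters and hyperparameters come later.
opt_func = partial(fake_wrapper, opt="SGD")

# Later, the training loop supplies the parameters and the learning rate.
built = opt_func([0.1, 0.2, 0.3], lr=0.02)
print(built)
```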

Wrap the pytorch train and test dataloaders with fastai's `DataLoaders` class.

In [7]:
dls = DataLoaders(trainloader, testloader)

Create a fastai `Learner` which ties together the dataloaders, model, loss function and optimizer.

Also add in a fastai metrics function to monitor performance during training.


In [8]:
learn = Learner(dls, Net(), loss_func=criterion, opt_func=opt_func, metrics=accuracy)

# To use the GPU, do 
# learn = Learner(dls, Net(), loss_func=criterion, opt_func=opt_func, cbs=[CudaCallback()])

You can use the `xla_`-prefixed methods from `fastai_xla_extensions` to train on the TPU. By default, training runs on all 8 TPU cores.

Notice the number of batches per epoch is divided by 8 (as the batches shown are per TPU core). 
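
As a rough check of that division, the arithmetic can be sketched as follows, assuming the 50,000 CIFAR-10 training images are split evenly across the 8 cores and each core keeps the batch size of 64. The exact sharding and rounding used by `fastai_xla_extensions` may differ, so treat the numbers as an approximation.

```python
import math

n_train = 50_000     # CIFAR-10 training images
batch_size = 64      # as set in the DataLoader above
n_cores = 8          # default number of TPU cores

# Single-device training: every batch is processed by one device.
batches_single = math.ceil(n_train / batch_size)

# Assumed even split: each core sees 1/8 of the samples per epoch.
samples_per_core = math.ceil(n_train / n_cores)
batches_per_core = math.ceil(samples_per_core / batch_size)

print(batches_single, batches_per_core)
```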

Instead of fastai's plain `fit` method, we use `fit_one_cycle` (here through `xla_fit_one_cycle`) for one-cycle training, which varies the learning rate during training and often improves convergence.
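
A minimal sketch of a one-cycle schedule: the learning rate warms up from a low value to `lr_max` and then anneals to a much lower value, with each phase following a cosine curve. The default values for `div`, `div_final` and `pct_start` below are assumptions chosen to resemble fastai's defaults and should be checked against the fastai documentation. The functions are illustrative, not fastai's implementation.

```python
import math

def cos_anneal(start, end, pos):
    """Cosine interpolation from start (pos=0) to end (pos=1)."""
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2

def one_cycle_lr(pct, lr_max, div=25.0, div_final=1e5, pct_start=0.25):
    """Learning rate at training progress pct in [0, 1] (assumed defaults)."""
    if pct < pct_start:
        # Warm-up phase: lr_max / div -> lr_max
        return cos_anneal(lr_max / div, lr_max, pct / pct_start)
    # Annealing phase: lr_max -> lr_max / div_final
    return cos_anneal(lr_max, lr_max / div_final,
                      (pct - pct_start) / (1 - pct_start))

for pct in (0.0, 0.25, 0.5, 1.0):
    print(pct, one_cycle_lr(pct, lr_max=0.2))
```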


In [9]:
# learn.fit(2, lr=0.001)
# learn.xla_fit(20, lr=0.02)

We also include fastai's `SaveModelCallback` which will save the best performing model during training.

Note that the `SaveModelCallback` is passed through `master_cbs` so that it runs only on the master ordinal process. If it ran on every process, the processes would try to write the same model file at the same time and overwrite each other.
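
The idea can be sketched in plain Python: track the best validation loss seen so far, and record a checkpoint only when it improves and only on the process with ordinal 0. The `simulate` helper, its `ordinal` argument and the loss values are hypothetical and used for illustration only; the real logic lives in fastai's `SaveModelCallback` and in the `master_cbs` handling of `fastai_xla_extensions`.

```python
def simulate(ordinal, losses):
    """Return (best_loss, epochs_saved) for one simulated process."""
    best, saved = float("inf"), []
    for epoch, loss in enumerate(losses):
        if loss < best:
            best = loss
            if ordinal == 0:      # only the master process writes the file
                saved.append(epoch)
    return best, saved

losses = [2.25, 1.93, 2.00, 1.66]   # illustrative values, not real results

for ordinal in (0, 1):              # simulate two of the worker processes
    print(ordinal, simulate(ordinal, losses))
```

Every process tracks the same best loss, but only the master records a save, and epochs that do not improve on the best loss are skipped.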

In [10]:
# learn.fit(2, lr=0.001)
learn.xla_fit_one_cycle(20, lr_max=slice(2e-1), master_cbs=[SaveModelCallback()])

start fit


epoch,train_loss,valid_loss,accuracy,time
0,1.968986,2.249635,0.1815,00:24
1,2.031315,1.932846,0.3037,00:19
2,1.799929,1.664288,0.3947,00:18
3,1.62438,1.545714,0.4403,00:18
4,1.492614,1.428614,0.4855,00:18
5,1.397967,1.369284,0.5175,00:19
6,1.325682,1.306306,0.533,00:19
7,1.270271,1.303401,0.537,00:18
8,1.223189,1.248857,0.5556,00:18
9,1.167863,1.229344,0.5665,00:18


Better model found at epoch 0 with valid_loss value: 2.2496345043182373.
Better model found at epoch 1 with valid_loss value: 1.9328464269638062.
Better model found at epoch 2 with valid_loss value: 1.6642876863479614.
Better model found at epoch 3 with valid_loss value: 1.5457141399383545.
Better model found at epoch 4 with valid_loss value: 1.428613543510437.
Better model found at epoch 5 with valid_loss value: 1.3692837953567505.
Better model found at epoch 6 with valid_loss value: 1.3063063621520996.
Better model found at epoch 7 with valid_loss value: 1.3034014701843262.
Better model found at epoch 8 with valid_loss value: 1.2488573789596558.
Better model found at epoch 9 with valid_loss value: 1.2293442487716675.
Better model found at epoch 10 with valid_loss value: 1.2190930843353271.
Better model found at epoch 11 with valid_loss value: 1.1849007606506348.
Better model found at epoch 12 with valid_loss value: 1.143892526626587.
Better model found at epoch 13 with valid_loss val

## Model Checkpointing and Performance Evaluation

We can check that the best-performing model (stored in `model.pth` by the `SaveModelCallback`) is the one left in the learner after training, even if it was not produced in the final epoch. To do so, we validate the learner, load `model.pth`, and validate again.


In [11]:
learn.validate()

(#2) [1.0895700454711914,0.6273000240325928]

In [12]:
learn.save('stage-1')

Path('models/stage-1.pth')

In [13]:
learn.load('model')

  elif with_opt: warn("Saved filed doesn't contain an optimizer state.")


<fastai.learner.Learner at 0x7f311809be50>

The validation results after loading the best model should match the results obtained right after training.


In [14]:
learn.validate()

(#2) [1.0895700454711914,0.6273000240325928]