In [None]:
# !curl -s https://course19.fast.ai/setup/colab | bash

In [None]:
pip install -Uqq fastai fastcore fastprogress --upgrade

In [None]:
!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.6-cp36-cp36m-linux_x86_64.whl

In [None]:
# fastai's star import; it supplies the names used throughout this notebook without explicit imports
# (e.g. `Path`, `untar_data`, `DataBlock`, `Learner`, `nn`, `Module`, `L`, `accuracy`).
from fastai.vision.all import *

In [None]:
# Fetch and extract the MNIST_TINY sample dataset; `path` is the extracted dataset folder.
path = untar_data(URLs.MNIST_TINY)

In [None]:
# Presumably makes fastai display paths relative to the dataset root (shorter `ls()` output).
Path.BASE_PATH = path

In [None]:
path.ls()

In [None]:
# Inspect the `test/` folder; its images feed `test_dl` further down.
(path/'test').ls()

In [None]:
test_images = get_image_files(path/'test')

In [None]:
# Image -> category pipeline over the whole dataset folder.
# NOTE(review): `GrandparentSplitter()` uses its default train/valid folder names (presumably
# `train`/`valid`); images under `test/` match neither — confirm they are excluded from both splits.
datablock = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,          # gather the image files under the path given to `dataloaders`
    get_y=parent_label,                 # label = name of each file's parent folder
    splitter=GrandparentSplitter(),     # train/valid assignment by grandparent folder name
    item_tfms=Resize(28),               # 28-pixel inputs; Lenet2's 400-feature flatten depends on this size
    batch_tfms=[]                       # no batch-level augmentation
)

In [None]:
dls = datablock.dataloaders(path)

In [None]:
# Visual sanity check of a few training images with their labels.
dls.show_batch()

In [None]:
# PyTorch/XLA's device and optimizer helpers (`xla_device`, `optimizer_step`) live in `torch_xla.core.xla_model`.
import torch_xla.core.xla_model as xm

class XLAOptimProxy:
    """Proxy optimizer whose `step` calls PyTorch/XLA's `xm.optimizer_step` instead of `opt.step`.

    Every attribute not defined on the proxy itself is forwarded to the wrapped optimizer `opt`.

    Args:
        opt: the optimizer to wrap.
        barrier: passed as `barrier=` to `xm.optimizer_step` on every step.
    """
    def __init__(self, opt, barrier=True):
        self.opt = opt
        self._barrier = barrier

    def xla_step(self):
        "Run one optimizer step through `xm.optimizer_step` (syncs on the gradient update)."
        return xm.optimizer_step(self.opt, barrier=self._barrier)

    # A real class attribute: normal lookup finds `step`, so it never reaches `__getattr__`
    # and is never forwarded to the wrapped optimizer's own `step`.
    step = xla_step

    def __getattr__(self, name):
        "Forward attributes missing from the proxy to the wrapped optimizer."
        # `__getattr__` only runs when normal lookup fails. If the proxy's own attributes are
        # missing (e.g. an instance created without `__init__`, as `copy`/`pickle` do), forwarding
        # would evaluate `self.opt` -> `__getattr__('opt')` -> ... and recurse forever.
        if name in ('opt', '_barrier', 'barrier'):
            raise AttributeError(name)
        return getattr(self.opt, name)

    @property
    def barrier(self): return self._barrier

    @barrier.setter
    def barrier(self, v): self._barrier = v

In [None]:
from fastai.callback.core import Callback
from fastai.data.core import DataLoaders
from fastai.vision.all import to_device

class XLAOptCallback(Callback):
    """Callback that runs a fit on the XLA (TPU) device.

    Before fitting it moves the `DataLoaders` and the model to `xm.xla_device()` and wraps
    `learn.opt` in an `XLAOptimProxy`, so each optimizer step goes through `xm.optimizer_step`.
    After fitting it puts the original optimizer back.

    Args:
        barrier: forwarded to `XLAOptimProxy` (and from there to `xm.optimizer_step`).
    """
    def __init__(self, barrier=True):
        self._barrier = barrier

    def before_fit(self):
        "Move dls and model to the XLA device and replace `learn.opt` with an `XLAOptimProxy`."
        device = xm.xla_device()
        # fastai's `to_device` moves tensors; a `DataLoaders` object is presumably passed through
        # unchanged, which left `dls.device` (and hence the model) off the XLA device. Move both explicitly.
        self.dls.to(device)
        self.model.to(device)
        opt = self.learn.opt
        # Wrap at most once: the proxy may already be in place.
        if opt is not None and not isinstance(opt, XLAOptimProxy):
            self.learn.opt = XLAOptimProxy(opt, barrier=self._barrier)

    def after_fit(self):
        "Restore the original optimizer wrapped by the proxy."
        if isinstance(self.learn.opt, XLAOptimProxy):
            self.learn.opt = self.learn.opt.opt
        # NOTE(review): dls and model are intentionally left on the XLA device after fitting.

    @property
    def barrier(self): return self._barrier

In [None]:
def myconv(conv_args=None, maxpool_args=None):
    "Build `nn.Sequential(Conv2d(*conv_args), ReLU(), MaxPool2d(*maxpool_args))`, omitting any stage whose args are `None`."
    stages = []
    if conv_args is not None:
        stages.append(nn.Conv2d(*conv_args))
    stages.append(nn.ReLU())  # the activation is always present
    if maxpool_args is not None:
        stages.append(nn.MaxPool2d(*maxpool_args))
    return nn.Sequential(*stages)

def mylinear(lin_args=None):
    """Build an MLP as `nn.Sequential`: one `nn.Linear(*args)` per entry of `lin_args`.

    An `nn.ReLU()` is inserted between consecutive Linear layers (none after the last one).
    `lin_args` may be any iterable of Linear argument tuples; `None` (the default) or an
    empty iterable yields an empty `nn.Sequential`.
    """
    # The `None` default used to crash in `enumerate(None)`; treat it like an empty spec.
    specs = [] if lin_args is None else list(lin_args)
    layers = []
    for i, args in enumerate(specs):
        if i > 0:  # ReLU only *between* Linear layers
            layers.append(nn.ReLU())
        layers.append(nn.Linear(*args))
    return nn.Sequential(*layers)
  
    
class Lenet2(Module):
    """LeNet-style CNN: two conv/ReLU/max-pool stages followed by a three-layer MLP with 2 outputs.

    The head expects 400 flattened features; with 3x3 convolutions and 2x2 pooling a 28x28,
    3-channel input gives 16 * 5 * 5 = 400.
    """
    def __init__(self):
        # No explicit super().__init__() — presumably fastai's `Module` base takes care of it.
        # Sub-modules are assigned in this order on purpose: `lenet_split` groups them as conv1, conv2, lin1.
        self.conv1 = myconv(conv_args=(3, 6, 3), maxpool_args=(2, 2))
        self.conv2 = myconv(conv_args=(6, 16, 3), maxpool_args=(2,))
        self.lin1 = mylinear(lin_args=[(400, 120), (120, 84), (84, 2)])

    def forward(self, x):
        features = self.conv2(self.conv1(x))
        # Flatten everything except the batch dimension before the MLP head.
        flat = features.view(-1, self.num_flat_features(features))
        return self.lin1(flat)

    def num_flat_features(self, x):
        "Product of all dimensions of `x` except the first (batch) one."
        total = 1
        for dim in x.size()[1:]:
            total *= dim
        return total


In [None]:
def lenet_split(model):
    """Splitter for `Lenet2`: one parameter group per sub-network.

    Returns an `L` of parameter lists in the order conv1, conv2, lin1.
    """
    # Refer to the sub-networks by name rather than by position in `model.modules()`: the old
    # indices (1, 5, 9) silently pointed at the wrong modules as soon as the number of layers
    # inside a `myconv`/`mylinear` block changed.
    return L([model.conv1, model.conv2, model.lin1]).map(params)

In [None]:
# Fresh, untrained model instance.
lenet2 = Lenet2()

In [None]:
# NOTE(review): no `cbs=XLAOptCallback()` is passed, so the XLA callback defined earlier is not used
# by this learner — confirm whether TPU training was intended.
learn = Learner(dls, lenet2, splitter=lenet_split, metrics=accuracy)

In [None]:
learn.summary()

In [None]:
# Baseline: validation-set predictions of the untrained model.
valid_preds,valid_targs = learn.get_preds(dl=dls.valid)

In [None]:
len(valid_preds)

In [None]:
len(valid_targs)

In [None]:
# Validation accuracy before any training.
valid_acc = accuracy(valid_preds,valid_targs);valid_acc

Make a test dataloader and get predictions from the untrained learner.

In [None]:
# Dataloader over the `test/` images built from the learner's `dls`; `with_labels=True` also produces targets.
# NOTE(review): targets come from `parent_label`; if the test images sit directly in `test/`
# rather than in per-class folders, every label would be 'test' — confirm the folder layout.
test_dl = learn.dls.test_dl(test_images,with_labels=True)

In [None]:
preds,targs = learn.get_preds(dl=test_dl)

Accuracy should be no better than random guessing (about 0.5 for the two classes).

In [None]:
# Test-set accuracy of the untrained model.
test_acc = accuracy(preds, targs);test_acc

Train the model (or load previously saved weights instead).

In [None]:
# Train for 10 epochs with fastai's one-cycle schedule.
learn.fit_one_cycle(10)

In [None]:
# Save the trained weights under the name 'lenet2-stage1a'.
learn.save('lenet2-stage1a')

In [None]:
# Alternative to training: uncomment to restore the weights saved above instead.
# learn.load('lenet2-stage1a')

In [None]:
# Validation-set predictions after training.
valid_preds2,valid_targs2 = learn.get_preds(dl=dls.valid)

In [None]:
valid_acc2 = accuracy(valid_preds2,valid_targs2);valid_acc2

In [None]:
# Freeze the earlier parameter groups from `lenet_split` (presumably leaving only the last, lin1, trainable).
learn.freeze()

In [None]:
# Re-display the summary to see which layers are now trainable.
learn.summary()

Get predictions from the trained learner.

In [None]:
# Test-set predictions from the trained model, reusing the `test_dl` built earlier.
preds2,targs2 = learn.get_preds(dl=test_dl)

Expect high accuracy from the trained learner.

In [None]:
# Test-set accuracy after training.
test_acc2 = accuracy(preds2, targs2);test_acc2

In [None]:
# Number of optimizer parameter groups; presumably 3, one per group returned by `lenet_split`.
len(learn.opt.param_groups)

In [None]:
# NOTE(review): `model` is not used by any later cell in this notebook.
model = learn.model

In [None]:
# One more epoch via fastai's `fine_tune`, starting from the already-trained weights.
learn.fine_tune(1)