In [None]:
#default_exp multi_core

# Multi Core XLA extensions

## Setup torch XLA


This is the official way to install PyTorch-XLA 1.7; see the [instructions here](https://colab.research.google.com/github/pytorch/xla/blob/master/contrib/colab/getting-started.ipynb#scrollTo=CHzziBW5AoZH).

In [None]:
#colab
!pip install -Uqq cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp36-cp36m-linux_x86_64.whl

[K     |████████████████████████████████| 133.6MB 78kB/s 
[K     |████████████████████████████████| 61kB 3.6MB/s 
[?25h

## Install fastai

Use latest fastai and fastcore versions

In [None]:
#colab
# !pip install -Uqq git+https://github.com/fastai/fastai.git 
!pip install -Uqq fastai --upgrade

[K     |████████████████████████████████| 194kB 5.8MB/s 
[K     |████████████████████████████████| 61kB 5.2MB/s 
[?25h

In [None]:
#hide
#colab
!pip install -Uqq git+https://github.com/butchland/my_timesaver_utils.git

  Building wheel for my-timesaver-utils (setup.py) ... [?25l[?25hdone


In [None]:
#hide
#colab
!curl -s https://course19.fast.ai/setup/colab | bash

Updating fastai...
Done.


In [None]:
#hide
!pip freeze | grep torch
!pip freeze | grep fast

torch==1.7.0+cu101
torch-xla==1.7
torchsummary==1.5.1
torchtext==0.3.1
torchvision==0.8.1+cu101
fastai==2.2.5
fastcore==1.3.19
fastdtw==0.3.4
fastprogress==1.0.0
fastrlock==0.5


Attach repo

In [None]:
#hide
#colab
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
#hide
#colab
%cd /content
!ln -s /content/drive/MyDrive/fastai_xla_extensions fastai_xla_extensions

In [None]:
%cd /content/fastai_xla_extensions/fastai_xla_extensions

Start of kernel

In [None]:
#exporti
from fastai_xla_extensions.utils import xla_imported
from fastai_xla_extensions.misc_utils import *

In [None]:
#exporti
try:
    import torch_xla
except ImportError:
    pass



In [None]:
#hide
#local
if not xla_imported():
    # torch_xla is not available: replace the torch xla modules used in this
    # notebook (xm, met, pl, xmp, xu) with single-process fake equivalents so
    # the exported code can be exercised locally.
    from types import SimpleNamespace
    torch_xla = SimpleNamespace (
    )
    from typing import Union,BinaryIO
    import os
    import pickle
    import torch.cuda

    def fake_opt_step(opt,barrier=False):
        "Stand-in for `xm.optimizer_step`: just steps the optimizer"
        opt.step()

    def fake_device(n=None, devkind=None):
        "Stand-in for `xm.xla_device`: current CUDA device if available, else CPU"
        gpu_available = torch.cuda.is_available()
        if gpu_available:
            return torch.device(torch.cuda.current_device())
        return torch.device('cpu')

    def fake_save(obj, f: Union[str, os.PathLike, BinaryIO],
                master_only=True, global_master=False):
        "Stand-in for `xm.save`: plain `torch.save` (master flags are ignored)"
        return torch.save(obj,f,pickle_module=pickle,
                        pickle_protocol=2,
                        _use_new_zipfile_serialization=True)

    def fake_rate():
        return 230.20

    def fake_global_rate():
        return 830.10

    def fake_add(*args,**kwargs):
        pass

    def fake_RateTracker():
        "Stand-in for `xm.RateTracker` returning fixed rates"
        return SimpleNamespace(
            rate = fake_rate,
            global_rate = fake_global_rate,
            add = fake_add
        )

    def fake_xrt_world_size():
        return 1

    def fake_get_ordinal():
        return 0

    # The fakes below cover the remaining `xm` members called elsewhere in this
    # notebook (SyncRecorderCallback, train_model, _xla_run_fit*); without them
    # the local shim raised AttributeError as soon as that code ran.
    def fake_is_master_ordinal(local=True):
        "Single fake process is always the master"
        return True

    def fake_mark_step():
        "Nothing to flush without an XLA graph"
        pass

    def fake_rendezvous(tag, payload=b'', replicas=None):
        "Nothing to wait for with a single process; echo back our own payload"
        return [payload]

    def fake_all_reduce(reduce_type, inputs, scale=1.0, groups=None):
        "With a world size of 1 the reduced value is the input itself (`scale` is ignored)"
        return inputs

    xm = SimpleNamespace(
        optimizer_step = fake_opt_step,
        xla_device = fake_device,
        save = fake_save,
        RateTracker = fake_RateTracker,
        master_print = print,
        xrt_world_size = fake_xrt_world_size,
        get_ordinal = fake_get_ordinal,
        is_master_ordinal = fake_is_master_ordinal,
        mark_step = fake_mark_step,
        rendezvous = fake_rendezvous,
        all_reduce = fake_all_reduce,
        REDUCE_SUM = 'sum'
    )

    def fake_metrics_report():
        return "Fake Metrics Report \n\n\n\n"
    met = SimpleNamespace (
        metrics_report = fake_metrics_report
    )

    class FakeParallelLoader:
        "Stand-in for `pl.ParallelLoader` that hands back the wrapped loader unchanged"
        def __init__(self, loader, *args, **kwargs):
            self.loader = loader
        def per_device_loader(self,device):
            return self.loader

    pl = SimpleNamespace(
        ParallelLoader = FakeParallelLoader
    )

    def fake_MpModelWrapper(o):
        return o

    def fake_run(f,*args, **kwargs):
        return f(*args,**kwargs)

    def fake_MpSerialExecutor():
        return SimpleNamespace(
            run = fake_run
        )

    def fake_spawn(f, args=None, nprocs=0, start_method=None):
        "Stand-in for `xmp.spawn`: run `f` once, in-process, as rank 0"
        # `args` defaults to None here, which cannot be star-unpacked
        return f(0,*(args or ()))

    xmp = SimpleNamespace (
        MpModelWrapper = fake_MpModelWrapper,
        MpSerialExecutor = fake_MpSerialExecutor,
        spawn = fake_spawn
    )

    xu = SimpleNamespace (
    )


In [None]:
#exporti
if xla_imported():
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl
    import torch_xla.distributed.xla_multiprocessing as xmp
    import torch_xla.utils.utils as xu

In [None]:
#exporti
from fastcore.basics import patch_to
from fastai.optimizer import _BaseOptimizer
import numpy as np
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torch.utils.data as th_data
from fastcore.foundation import L
from pathlib import Path
from fastcore.xtras import *
from fastcore.transform import Pipeline
from fastai.data.core import DataLoaders
from functools import partial
import torch.utils.data.distributed as torch_distrib
from pathlib import Path
import fastcore.xtras
import math
from fastcore.basics import store_attr
from operator import attrgetter
from fastai.data.load import _FakeLoader
from fastai.data.core import TfmdDL
from fastai.torch_core import find_bs, TensorBase
import random
import torch
from fastai.data.load import _loaders
from fastai.torch_core import to_device
from fastcore.basics import first, patch_to, patch

In [None]:
#export
import torch
from fastai.torch_core import find_bs, TensorBase
import math

def revert_tensor(o):
    """Remove tensor subclass info from class and revert back to Tensor

    The class of `o` is reassigned in place and the same object is returned.
    Raises `RuntimeError` (chained to the underlying error) when the class of
    `o` cannot be reassigned to `torch.Tensor`.
    """
    try:
        o.__class__ = torch.Tensor
    except Exception as e:
        # `except Exception` rather than a bare `except:` so KeyboardInterrupt
        # and SystemExit are not converted into a RuntimeError
        raise RuntimeError(f'could not convert {o} to torch.Tensor') from e
    return o

def _recast2tensor(o):
    """Return `o` as a plain tensor if it is a `TensorBase`, otherwise return `o` untouched"""
    # pl.parallelloader doesn't seem to work with tensor subclasses,
    # so strip the subclass via `revert_tensor`
    return revert_tensor(o) if isinstance(o,TensorBase) else o

def _round_to_multiple(number,multiple):
    return int(math.ceil(number/multiple)*multiple)

In [None]:
#export
class TPUDistributedDL(TfmdDL):
    """A `TfmdDL` wrapping an existing dataloader `dl` so that each rank
       (TPU core) iterates over its own equal-sized slice of the sample indices.
       It also recasts the output of a batch from a `TensorBase` subclass to
       a regular tensor since the XLA Parallel loader doesn't seem compatible
       with tensor subclasses.

       `rank` is the index of this process, `world_size` the number of
       processes sharing the data, and `seed` the base value used (together
       with the epoch) to derive this loader's shuffling RNG.
    """
    # attributes not found on this object are presumably looked up on `self.dl`
    # (fastcore `GetAttr` convention) -- TODO confirm
    _default = 'dl'
    def __init__(self,dl,rank,world_size, seed=42):
        store_attr()
        # mirror the wrapped loader's settings on this object
        self.bs,self.device,self.num_workers,self.drop_last,self.dataset,self.offs,fake = \
            attrgetter('bs','device','num_workers','drop_last','dataset','offs','fake_l')(dl)
        self.fake_l = _FakeLoader(self, fake.pin_memory, fake.num_workers, fake.timeout,
                                  persistent_workers=fake.persistent_workers)
        self.epoch = 0
        # NOTE(review): this reseeds the global `random` module, not just this loader
        random.seed(self.seed)
        self.dl.rng = random.Random(random.randint(0,2**32-1))
        self.reset_rng()

    def reset_rng(self):
        "Recreate `self.rng` from `seed + epoch`"
        # NOTE(review): also reseeds the global `random` module
        random.seed(self.seed + self.epoch)
        self.rng = random.Random(random.randint(0,2**32-1))

    def __len__(self):
        "Length of the wrapped loader divided by `world_size`, rounded up"
        return _round_to_multiple(len(self.dl),self.world_size)//self.world_size

    def set_epoch(self, epoch):
        "Record the current epoch; it is used by `reset_rng`"
        self.epoch = epoch

    def get_idxs(self):
        "Return the slice of (padded) sample indices belonging to `self.rank`"
        # copy into a fresh list: the padding below extends the list in place,
        # which must not touch (or depend on the type of) the wrapped loader's result
        idxs = list(self.dl.get_idxs())
        # do your own shuffling which factors in self.epoch + self.seed in
        # generating a random sequence (underlying self.dl does not)
        if self.shuffle:
            idxs = list(self.shuffle_fn(idxs))
        self.n = len(idxs)
        # we assumed n was dl.n but we really care about number of idxs
        # add extra samples to make it evenly divisible
        self.n_padded = _round_to_multiple(self.n,self.world_size)
        if self.n > 0:
            # idx needs to be repeated when n_padded>>n
            # (guarded: with no indices there is nothing to repeat and
            # n_padded//n would divide by zero)
            idxs += (idxs * (self.n_padded//self.n))[:self.n_padded-self.n]
        # slice padded idxs so that each rank gets self.n_padded//self.world_size tensors
        start_pos = self.rank*self.n_padded//self.world_size
        end_pos = (self.rank+1)*self.n_padded//self.world_size
        return idxs[start_pos:end_pos]

    def before_iter(self):
        self.dl.before_iter()

    def randomize(self):
        "Refresh this loader's RNG, then let the wrapped loader randomize"
        self.reset_rng()
        self.dl.randomize()

    def after_batch(self,b):
        "Run the wrapped loader's `after_batch`, then return the batch as a tuple of plain tensors"
        b = self.dl.after_batch(b)
        # recast tensor subclasses to plain tensors
        # undoing work of self.retain()
        tb = [_recast2tensor(o) for o in b]
        b = tuple(tb)
        return b

    def after_iter(self):
        self.dl.after_iter()

    def create_batches(self,samps):
        return self.dl.create_batches(samps)

    def to(self, device):
        "Set `device` on both this loader and the wrapped one; returns `self`"
        self.dl.device = device
        self.device = device
        return self

    def one_batch(self):
        "Delegate to the wrapped loader (its output is not recast to plain tensors here)"
        return self.dl.one_batch()

In [None]:
#exporti
from fastai.torch_core import default_device, apply
import torch
from fastcore.xtras import is_listy
import torch
import torch.utils.hooks
from fastcore.basics import patch
from fastai.torch_core import TensorBase
from collections import OrderedDict

In [None]:
#export
def make_distributed_dataloaders(dls, rank, world_size, sync_valid=False):
    """Wrap each loader of `dls` in a `TPUDistributedDL` and return new `DataLoaders`.

    The first loader is always split across `world_size` ranks. The remaining
    loaders are split only when `sync_valid` is true; otherwise they are wrapped
    as rank 0 of a world of size 1.
    """
    def _wrap(position, loader):
        distributed = position == 0 or sync_valid
        return TPUDistributedDL(loader,
                                rank=rank if distributed else 0,
                                world_size=world_size if distributed else 1)

    wrapped = [_wrap(position, loader) for position, loader in enumerate(dls.loaders)]
    return DataLoaders(*wrapped, path=dls.path, device=dls.device)

In [None]:
#export
# def DataBlock.dataloaders(self, source, path='.', verbose=False, **kwargs):
def make_fastai_dataloaders(datablock, source, rank, world_size, device=None, path='.', sync_valid=False, verbose=False,**kwargs):
    """Build dataloaders from `datablock` and wrap them for distributed TPU use.

    `source`, `path`, `verbose`, `device` and any extra `kwargs` are passed to
    `datablock.dataloaders`; `rank`, `world_size` and `sync_valid` are passed to
    `make_distributed_dataloaders`, whose result is returned.
    """
    # `verbose` was previously accepted but never forwarded, so it had no effect
    dls = datablock.dataloaders(source=source, path=path, verbose=verbose, device=device, **kwargs)
    distrib_dls = make_distributed_dataloaders(dls, rank, world_size, sync_valid=sync_valid)
    return distrib_dls

In [None]:
#export
def wrap_parallel_loader(loader, device):
    """Wrap `loader` in a `pl.ParallelLoader` for `device` and return its per-device loader"""
    return pl.ParallelLoader(loader, [device]).per_device_loader(device)

In [None]:
#exporti
from fastai.callback.core import TrainEvalCallback
from fastai.learner import Recorder
from fastai.torch_core import one_param
import torch
from fastai.callback.core import Callback
from fastai.learner import CancelTrainException, CancelValidException, CancelStepException
from fastai.torch_core import tensor, TensorCategory

In [None]:
#export
class XLATrainingCallback(Callback):
    """Callback that feeds batches through an XLA parallel loader on `device`
       and replaces the regular optimizer step with `xm.optimizer_step`
    """
    run_before = Recorder
    run_valid = False
    order = -5 # after TrainEvalCallback
    def __init__(self, device, rank=0, sync_valid=False):
        self.pdevice = device
        self.rank = rank
        self.sync_valid = sync_valid

    def before_fit(self):
        xm.master_print('start fit')

    def before_epoch(self):
        "Pass the current epoch to the train loader (and to the valid loader if `sync_valid`)"
        learn = self.learn

        def _push_epoch(loader):
            # prefer the loader's sampler when it has one, else the loader itself
            if hasattr(loader,'sampler'):
                if hasattr(loader.sampler,'set_epoch'):
                    loader.sampler.set_epoch(learn.epoch)
            elif hasattr(loader,'set_epoch'):
                loader.set_epoch(learn.epoch)

        # set the epoch on train to make sure shuffle produces same seq
        # across all ranks
        _push_epoch(learn.dls.train)
        # valid only needs the epoch when validation is synced across ranks
        if self.sync_valid:
            _push_epoch(learn.dls.valid)

    def before_train(self):
        "Iterate the training set through a parallel loader"
        self.learn.dl = wrap_parallel_loader(self.dls.train, self.pdevice)

    def before_validate(self):
        "Skip validation on non-master ranks unless `sync_valid`; otherwise use a parallel loader"
        skip_validation = self.rank != 0 and not self.sync_valid
        if skip_validation:
            # no need to compute valid loss/ metric if not master if not sync valid
            raise CancelValidException()
        self.learn.dl = wrap_parallel_loader(self.dls.valid, self.pdevice)

    def before_step(self):
        "Cancel the regular step; `after_cancel_step` performs the XLA one"
        raise CancelStepException()

    def after_cancel_step(self):
        xm.optimizer_step(self.learn.opt)

In [None]:
#exporti
import copy
from fastcore.foundation import L
import torch

In [None]:
#exporti
from fastai.learner import _maybe_item
from fastprogress.fastprogress import format_time
import time

In [None]:
#export
def pack_metric(metrics):
    """Return the `count` attributes of `metrics` followed by their `total` attributes.

    A metric missing either attribute contributes 0 for it.
    """
    return metrics.attrgot('count',0) + metrics.attrgot('total',0)

def make_tensor(o, device):
    """Return `o` as a float tensor on `device`, converting non-tensors with `torch.tensor`"""
    as_tensor = o if isinstance(o, torch.Tensor) else torch.tensor(o)
    return as_tensor.float().to(device)
    
def pack_metrics(all_metrics, device):
    """Flatten the train and valid metrics of `all_metrics` into a list of float tensors on `device`.

    Layout: train counts, train totals, valid counts, valid totals.
    """
    packed = pack_metric(all_metrics['train_mets'])
    packed = packed + pack_metric(all_metrics['valid_mets'])
    return [make_tensor(value,device) for value in packed]

def restore_metrics(reduced_metrics, all_metrics):
    """Write the values of `reduced_metrics` back onto the metric objects in `all_metrics`.

    `reduced_metrics` is read in the layout: train counts, train totals,
    valid counts, valid totals. Only attributes a metric already has are
    overwritten (`count` as a long tensor, `total` as is). Returns `all_metrics`.
    """
    train_mets, valid_mets = all_metrics['train_mets'], all_metrics['valid_mets']
    n_train, n_valid = len(train_mets), len(valid_mets)
    valid_start = n_train*2
    groups = (
        (train_mets,
         reduced_metrics[:n_train],
         reduced_metrics[n_train:valid_start]),
        (valid_mets,
         reduced_metrics[valid_start:valid_start + n_valid],
         reduced_metrics[valid_start + n_valid:]),
    )
    for mets, counts, totals in groups:
        for pos, met in enumerate(mets):
            if hasattr(met,'count'):
                met.count = counts[pos].clone().detach().long()
            if hasattr(met,'total'):
                met.total = totals[pos].clone().detach()
    return all_metrics


In [None]:
#export
class SyncRecorderCallback(Callback):
    """Sync metrics across the spawned processes and update the recorder's
       statistics accordingly so they display correctly in the progress callback.

       After each epoch the metric counts/totals of every process are summed
       with `xm.all_reduce`; the master process then rebuilds the epoch's log
       row from the summed metrics and writes it out itself.
    """
    order  = 55 # after Recorder, before ProgressCallback
    def __init__(self):
        pass

    def before_fit(self):
        "On the master process, pick where the synced log row will be written"
        if not xm.is_master_ordinal():
            return
        if 'progress' in self.learn.cbs.attrgot('name',None):
            self._sync_stats_log = self.progress._write_stats
        else:
            self._sync_stats_log = self.learn.logger

    def after_fit(self):
        pass
        # xm.rendezvous('sync recorder after_fit')

    def before_epoch(self):
        "Start this epoch's synced log row from a copy of the recorder's current log"
        self.sync_log = copy.copy(self.recorder.log)

    def after_epoch(self):
        "Sum metrics over all processes and, on the master, replace the recorder's row for this epoch"
        if 'recorder' not in self.learn.cbs.attrgot('name'):
            all_metrics = {
                'train_mets': L([]),
                'valid_mets': L([]),
            }
        else:
            all_metrics = {
                'train_mets': self.recorder._train_mets,
                'valid_mets': self.recorder._valid_mets,
            }
        # send metrics data to sync ranks across spawned processes
        device = self.learn.xla_training.pdevice
        packed_metrics = pack_metrics(all_metrics, device) # convert metrics to tensor list on TPU
        reduced_metrics = xm.all_reduce(xm.REDUCE_SUM, packed_metrics)
        xm.mark_step()
        if xm.is_master_ordinal():
            all_metrics = restore_metrics(reduced_metrics, all_metrics) # convert list to metric objects
            for m in self.recorder._train_mets:
                self.sync_log += _maybe_item(m)

            for m in self.recorder._valid_mets:
                self.sync_log += _maybe_item(m)

            # `sync_log` starts from the recorder's log captured in `before_epoch`,
            # so the metric values are everything after its leading entries. This
            # mirrors fastai's Recorder, which stores `log[1:]`; the previous
            # `[:1]` slice kept only the first entry and dropped all the metrics.
            self.learn.final_record = self.sync_log[1:].copy()
            del self.recorder.values[-1] # remove last entry added by recorder
            self.recorder.values.append(self.learn.final_record) # add updated metrics
            if self.recorder.add_time:
                updated_time = (time.time() - self.recorder.start_epoch)
                self.sync_log.append(format_time(updated_time))
            self.recorder.log = self.sync_log
            self._sync_stats_log(self.sync_log) # write_stats to output
            self.learn.logger = self.orig_logger # restore orig logger after skipping recorder.logger(log)

    def before_validate(self):
        pass

    def after_validate(self):
        "On the master, silence the learner's logger until `after_epoch` restores it"
        if xm.is_master_ordinal():
            self.orig_logger = self.learn.logger
            self.learn.logger = noop # write to logger disabled so calling recorder.logger(log) wont print
        pass

    def before_batch(self):
        pass


In [None]:
#exporti
from fastcore.imports import noop
from fastcore.basics import patch
from fastai.learner import Learner
from fastai.callback.progress import ProgressCallback
from fastcore.xtras import join_path_file
from fastai.torch_core import get_model

In [None]:
#exporti
@patch
def save(self:Learner, file, **kwargs):
    """Save the model (and optionally the optimizer state) to `file` with `xm.save`.

    `file` is resolved against `self.path/self.model_dir` with a `.pth`
    extension and the resolved path is returned. Pass `with_opt=False` to save
    the model state dict only; other keyword arguments are accepted for
    compatibility and ignored.
    """
    file = join_path_file(file, self.path/self.model_dir, ext='.pth')
    # honour an explicit `with_opt=False` (it used to be silently ignored);
    # the optimizer can only be saved when one exists
    with_opt = kwargs.get('with_opt', True) and self.opt is not None
    state = self.model.state_dict()
    if with_opt:
        # add opt state to state to be saved
        opt_state = self.opt.state_dict()
        state = {'model': state, 'opt':opt_state}
    xm.save(state, file) # use xm.save instead of torch.save
    return file

In [None]:
#exporti
@patch
def to_xla(self:Learner,device, rank, sync_valid=False):
    """Configure this learner's callbacks and logger for XLA training on `device` as process `rank`"""
    if 'xla_training' in self.cbs.attrgot('name'):
        # already set up once: just refresh the existing callback's settings
        self.xla_training.pdevice = device
        self.xla_training.rank = rank
        self.xla_training.sync_valid = sync_valid
    else:
        self.dls.device = None
        self.add_cbs(XLATrainingCallback(device, rank, sync_valid=sync_valid))

    # the sync recorder is only wanted when validation is synced across ranks
    if not sync_valid:
        self.remove_cbs(SyncRecorderCallback)
    elif 'sync_recorder' not in self.cbs.attrgot('name'):
        self.add_cbs(SyncRecorderCallback)

    if rank != 0: # progress bar only for rank 0
        self.remove_cbs(ProgressCallback)
    self.logger = xm.master_print

In [None]:
#export
# for testing
def do_one_loop(dls, rank, world_size, device, sync_valid, is_train=True):
    """Print shape, device and dtype of the first two batches of a loader (for testing).

    Uses `dls.train` when `is_train` else `dls.valid`. The loader is wrapped
    with `wrap_parallel_loader` only when `sync_valid`, `is_train` or
    `rank == 0`; any other rank returns without iterating.
    """
    if is_train:
        dl = dls.train
    else:
        dl = dls.valid

    n_batches = len(dl)
    print(f'xla: {rank} world_size: {world_size} train:{is_train} n_batches:{n_batches} sync_valid: {sync_valid}')

    if not (sync_valid or is_train or rank == 0):
        # nothing gets wrapped for this rank, so there is nothing to iterate
        # (previously this fell through to the loop with `pdl` unbound)
        return
    print(f'xla: {rank} wrapping ploader')
    pdl = wrap_parallel_loader(dl, device=device)
    for i,b in enumerate(pdl):
        if i > 1:
            break
        xb, yb = b
        print(f'xla: {rank} iter:{i} xb.shape {xb.shape} yb.shape: {yb.shape}')
        print(f'xla: {rank} iter:{i} xb.device {xb.device} yb.device: {yb.device}')
        # label corrected: this line reports dtypes, not devices
        print(f'xla: {rank} iter:{i} xb.dtype {xb.dtype} yb.dtype: {yb.dtype}')

## Test out the code


In [None]:
#colab
from functools import partial
from fastai.metrics import accuracy
from fastai.optimizer import SGD, Adam

from fastcore.basics import first
from fastai.callback.schedule import *
from fastai.test_utils import VerboseCallback
from my_timesaver_utils.profiling import *
from my_timesaver_utils.profiling_callback import *

In [None]:
#colab
def train_model(rank):
    """Per-process training routine: build dataloaders and a learner, fit, then save.

    `rank` is the index of this process. Settings are read from the global
    `FLAGS` dict; the data, model, loss and optimizer come from the globals
    `DATA`, `PATH`, `WRAPPED_MODEL`, `LOSS_FUNC` and `OPT_FUNC`.
    """
    torch.manual_seed(1)
    # presumably blocks until every process reaches this point -- confirm in torch_xla docs
    xm.rendezvous('start_train_model')
    print(f'xla {rank} start train model')

    # Scale learning rate to num cores
    learning_rate = FLAGS['learning_rate'] * xm.xrt_world_size()
    SYNC_VALID = FLAGS['sync_valid']
    IS_PROFILING = FLAGS['is_profiling']
    # Get loss function, optimizer, and model
    device = xm.xla_device()
    model = WRAPPED_MODEL.to(device)
    bs = FLAGS['batch_size']
    # the same momentum value is used for all three entries of `moms`
    moms =(FLAGS['momentum'],FLAGS['momentum'],FLAGS['momentum'])
    wd = FLAGS['weight_decay']

    world_size = xm.xrt_world_size()
    # optional profiling record around the dataloader build
    if IS_PROFILING:
        rec_name = 'rank' + str(rank) + '_dataloader_build'
        print(f'start {rec_name}')
        start_record(rec_name)

    dls = make_fastai_dataloaders(
                            DATA, 
                            PATH, 
                            rank=rank, 
                            world_size=world_size, 
                            sync_valid=SYNC_VALID,
                            bs=bs,)
    if IS_PROFILING:
        end_record(rec_name)
        print_prof_data(rec_name)
        print(f'finished {rec_name}')

    xm.master_print('build learner')
    learner = Learner(dls, model, 
                      loss_func=LOSS_FUNC, 
                      opt_func=OPT_FUNC, 
                      metrics=accuracy, 
                      wd=wd,
                      moms=moms)
                      
    learner.to_xla(device, rank=xm.get_ordinal(), sync_valid=SYNC_VALID)
    # the profiling callback is attached on rank 0 only
    if rank == 0:
        learner.to_my_profile()
                               
    epochs = FLAGS['num_epochs']
    xm.master_print('start running fit')
    learner.unfreeze()
    # optional profiling record around the fit
    if IS_PROFILING:
        rec_name3 = 'rank' + str(rank) + '_run_fit'
        print(f'start {rec_name3}')
        start_record(rec_name3)

    learner.fit_one_cycle(epochs, lr_max=slice(learning_rate/10))
    if IS_PROFILING:
        end_record(rec_name3)
        print_prof_data(rec_name3)
        print(f'finished {rec_name3}')

    # every rank calls save; the patched `Learner.save` writes through `xm.save`
    learner.save('stage-1')
    if rank == 0:
        learner.my_profile.print_stats()
    xm.mark_step()  
    xm.rendezvous('end_train_model')
    if IS_PROFILING:
        clear_prof_data() 


In [None]:
#colab
# Start training processes
def _mp_fn(rank, flags):
    """Entry point handed to `xmp.spawn`: publish `flags` as the global `FLAGS`, then train."""
    # `train_model` reads its settings from the module-level FLAGS
    global FLAGS
    FLAGS = flags
    train_model(rank)


In [None]:
import torch
from fastcore.transform import DisplayedTransform, Transform
from fastcore.basics import store_attr
from fastai.vision.core import PILImage, PILBase, image2tensor
from fastai.data.block import TransformBlock

In [None]:
from fastai.data.transforms import get_c
# from fastai.vision.all import *
from fastai.data.block import DataBlock, CategoryBlock
from fastai.vision.data import ImageBlock
from fastai.data.transforms import get_image_files, parent_label, GrandparentSplitter
from fastai.vision.augment import Resize, aug_transforms
from fastai.data.external import untar_data, URLs
from fastai.data.transforms import Normalize
from fastai.vision.core import imagenet_stats

In [None]:
import torch.nn as nn
LOSS_FUNC = nn.CrossEntropyLoss()

In [None]:
from fastai.optimizer import Adam
OPT_FUNC = Adam

In [None]:
from fastai.data.transforms import RandomSplitter

In [None]:
from fastai.vision.learner import create_cnn_model
from fastai.vision.models import resnet34

In [None]:
import os
# Define Parameters
# Settings read by `train_model` through the global FLAGS dict
FLAGS = {}
# FLAGS['batch_size'] = 1024
FLAGS['sync_valid'] = True
FLAGS['is_profiling'] = False
FLAGS['batch_size'] = 64
# NOTE(review): 'num_workers' is not read by `train_model` in this notebook
FLAGS['num_workers'] = 4
# base learning rate; `train_model` multiplies it by the XLA world size
FLAGS['learning_rate'] = 1e-3
FLAGS['image_size'] = 32
FLAGS['momentum'] = 0.85
FLAGS['weight_decay'] = 2e-4
FLAGS['num_epochs'] = 5
# 8 processes when the TPU_NAME environment variable is set, otherwise 1
FLAGS['num_cores'] = 8 if os.environ.get('TPU_NAME', None) else 1
# FLAGS['num_cores'] = 1 
ARCH = resnet34

In [None]:
from pathlib import Path
from fastcore.xtras import *


In [None]:
#colab
PATH = untar_data(URLs.PETS)/'images'
# PATH = untar_data(URLs.MNIST)
# PATH = untar_data(URLs.MNIST_TINY)


In [None]:
#colab

# label = the part of the file name before the trailing "_<digits>.jpg"
pat = r'(.+)_\d+.jpg$'
# apply the regex labeller to each item's `name` attribute
fname_labeller = using_attr(RegexLabeller(pat),'name') 
splitter=RandomSplitter(seed=42)
# image -> category datablock; items are resized to FLAGS['image_size'], no batch transforms
DATA = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    get_y=fname_labeller,
    splitter=splitter,
    item_tfms=[Resize(FLAGS['image_size']),],
    batch_tfms=[]
)
# build the label vocabulary from all image files; N_OUT is the number of labels
vocab = CategoryMap(get_image_files(PATH).map(fname_labeller))
N_OUT = len(vocab)


In [None]:
#colab
assert N_OUT is not None and N_OUT > 0,f'N_OUT {N_OUT} should be > 0'

In [None]:
#colab
custom_model = create_cnn_model(ARCH, N_OUT, 
                                pretrained=True,
                                concat_pool=False)


In [None]:
#colab
# Only instantiate model weights once in memory.
WRAPPED_MODEL = xmp.MpModelWrapper(custom_model)

In [None]:
#colab
%%time
# !rm -f /content/models/stage-1.pth
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS['num_cores'],
        start_method='fork')


build learner
start running fit
start fit


epoch,train_loss,valid_loss,accuracy,time
0,1.20713,4.118487,0.047852,00:30
1,1.725443,7.102302,0.06543,00:43
2,2.201487,4.366318,0.119141,00:40
3,2.506165,3.342862,0.15332,00:41
4,2.727757,3.203462,0.166992,00:42


fit  called 1 times. max: 199.537 avg: 199.537
   epoch  called 5 times. max: 43.851 avg: 39.726
      train  called 5 times. max: 37.078 avg: 33.036
         train_batch  called 55 times. max: 6.104 avg: 1.449
            train_pred  called 55 times. max: 5.079 avg: 0.705
            train_loss  called 55 times. max: 0.041 avg: 0.005
            train_backward  called 55 times. max: 0.245 avg: 0.042
            train_step  called 55 times. max: 1.463 avg: 0.628
            train_zero_grad  called 55 times. max: 0.145 avg: 0.066
      valid  called 5 times. max: 6.849 avg: 6.639
         valid_batch  called 10 times. max: 0.087 avg: 0.049
            valid_pred  called 10 times. max: 0.073 avg: 0.040
            valid_loss  called 10 times. max: 0.023 avg: 0.009
CPU times: user 111 ms, sys: 126 ms, total: 237 ms
Wall time: 4min 4s


In [None]:
#colab
DATA.summary(PATH)

In [None]:
#colab
mdls = DATA.dataloaders(PATH, bs=FLAGS['batch_size'])

In [None]:
#colab
mlearner = Learner(mdls, custom_model, 
                    loss_func=LOSS_FUNC, 
                    opt_func=OPT_FUNC, 
                    metrics=accuracy, 
                    wd=FLAGS['weight_decay'],
                    moms=(FLAGS['momentum'],FLAGS['momentum'],FLAGS['momentum']))
mlearner.load('stage-1')

<fastai.learner.Learner at 0x7f0e778ba198>

In [None]:
#colab
mlearner.dls.device

In [None]:
from fastai.torch_core import one_param

In [None]:
#colab
one_param(mlearner.model).device

device(type='cpu')

In [None]:
#colab
%%time
valid_metrics = mlearner.validate();print(valid_metrics)

[3.198190450668335, 0.16915760934352875]
CPU times: user 9.33 s, sys: 198 ms, total: 9.53 s
Wall time: 14 s


## Expose xla fit methods on learner to simplify usage

In [None]:
#exporti
# from fastai.vision.all import *
# import torch_xla.core.xla_model as xm
# import torch_xla.distributed.xla_multiprocessing as xmp
# import torch

In [None]:
#export
def _make_xla_child_learner(rank, sync_valid,learner_args):
    """Build the learner used inside a spawned process.

    `learner_args` must contain `base_dls` (the dataloaders to distribute) and
    `wrapped_model` (moved to this process's XLA device); every other entry is
    passed to `Learner` as a keyword argument. `sync_valid` is forwarded to the
    distributed dataloaders and to `to_xla`. Returns the configured learner.
    """
    # `sync_valid` used to be overwritten with True here, so the argument was
    # ignored; callers in this file already pass True, so they are unaffected.
    # Work on a copy so the caller's dict keeps its `base_dls`/`wrapped_model`.
    learner_args = dict(learner_args)
    device = xm.xla_device()
    world_size = xm.xrt_world_size()
    dls = make_distributed_dataloaders(learner_args.pop('base_dls'),
                                       rank, world_size, sync_valid=sync_valid)

    model = learner_args.pop('wrapped_model').to(device)

    learner = Learner(dls, model,**learner_args)
    learner.to_xla(device, rank, sync_valid=sync_valid)
    return learner
    


In [None]:
#export
def _xla_run_fit(rank, learner_args, fit_args):
    """Spawned-process body for `xla_fit`: build the child learner, run `fit`, save the result."""
    child = _make_xla_child_learner(rank, True, learner_args)
    child.fit(**fit_args)
    # saved under a fixed temporary name so the parent can reload it afterwards
    child.save('_xla_tmp_model')
    xm.mark_step()

In [None]:
#export
def _xla_run_fit_one_cycle(rank, learner_args, fit_args):
    """Spawned-process body for `xla_fit_one_cycle`: build the child learner, run `fit_one_cycle`, save the result."""
    child = _make_xla_child_learner(rank, True, learner_args)
    child.fit_one_cycle(**fit_args)
    # saved under a fixed temporary name so the parent can reload it afterwards
    child.save('_xla_tmp_model')
    xm.mark_step()

In [None]:
#export
from fastcore.basics import defaults, patch_to, patch
from fastai.learner import Learner
from fastai.callback.progress import ProgressCallback
@patch_to(Learner)
def pack_learner_args(self):
    """Collect what a spawned process needs to rebuild this learner.

    Returns a dict holding the model wrapped in `xmp.MpModelWrapper`
    (`wrapped_model`), the dataloaders (`base_dls`), the callbacks that are not
    default callbacks (`cbs`) and the remaining `Learner` settings.
    """
    learner_args = {}
    learner_args['wrapped_model'] =  xmp.MpModelWrapper(self.model)
    learner_args['base_dls'] = self.dls
    learner_args['opt_func'] = self.opt_func
    learner_args['loss_func'] = self.loss_func
    learner_args['metrics'] = self.metrics
    # fetch only cbs not in defaults
    if ProgressCallback not in defaults.callbacks:
        defaults.callbacks.append(ProgressCallback)
    # `defaults.callbacks` holds callback classes, so compare by type. The
    # previous name comparison read `name` off the classes, which (per fastai's
    # `Callback.name` property) yields a property object rather than a string,
    # so no default callback was ever filtered out and the child learner ended
    # up with duplicates of them.
    default_cb_types = {cb if isinstance(cb, type) else type(cb)
                        for cb in defaults.callbacks}
    learner_args['cbs'] = [cb for cb in self.cbs
                           if type(cb) not in default_cb_types]

    learner_args['wd'] = self.wd
    learner_args['moms'] = self.moms
    learner_args['lr'] = self.lr
    learner_args['splitter'] = self.splitter
    learner_args['path'] = self.path
    learner_args['model_dir'] = self.model_dir
    learner_args['wd_bn_bias'] = self.wd_bn_bias
    learner_args['train_bn'] = self.train_bn
    return learner_args

In [None]:
#export
@patch_to(Learner)
def reload_child_model(self):
    """Load the temporary model saved by the spawned processes, if present.

    When `_xla_tmp_model.pth` exists under `self.path/self.model_dir`, its
    weights are loaded (without optimizer state), the file is removed and a
    fresh optimizer is created. Does nothing when the file is absent.
    """
    # blatantly stolen from fastai LRFinder after_fit :)
    tmp_f = self.path/self.model_dir/'_xla_tmp_model.pth'
    if not tmp_f.exists():
        return
    # a learner that has not been fit in this process may have no optimizer
    # yet; calling zero_grad on None used to raise AttributeError
    if getattr(self, 'opt', None) is not None:
        self.opt.zero_grad()
    self.load('_xla_tmp_model', with_opt=False)
    os.remove(tmp_f)
    self.create_opt()
In [None]:
#export

from fastcore.meta import delegates
@patch
@delegates(Learner.fit, but='num_cores')
def xla_fit(self:Learner, n_epoch, num_cores=8, **kwargs):
    """Run `fit` in `num_cores` spawned processes, then reload the trained weights into this learner"""
    packed_learner = self.pack_learner_args()
    # `n_epoch` travels with the other fit keyword arguments
    fit_kwargs = dict(kwargs, n_epoch=n_epoch)
    xmp.spawn(_xla_run_fit,
              args=(packed_learner, fit_kwargs),
              nprocs=num_cores,
              start_method='fork')
    self.reload_child_model()


In [None]:
#export
from fastai.learner import Learner
from fastai.callback.schedule import *
@patch
@delegates(Learner.fit_one_cycle, but='num_cores')
def xla_fit_one_cycle(self:Learner, n_epoch, num_cores=8, **kwargs):
    """Run `fit_one_cycle` in `num_cores` spawned processes, then reload the trained weights into this learner"""
    packed_learner = self.pack_learner_args()
    # `n_epoch` travels with the other fit keyword arguments
    fit_kwargs = dict(kwargs, n_epoch=n_epoch)
    xmp.spawn(_xla_run_fit_one_cycle,
              args=(packed_learner, fit_kwargs),
              nprocs=num_cores,
              start_method='fork')
    self.reload_child_model()


In [None]:
#colab
path = untar_data(URLs.MNIST_TINY)

In [None]:
#colab
data = DataBlock(
    blocks=(ImageBlock,CategoryBlock),
    get_items=get_image_files,
    get_y=parent_label,
    splitter=GrandparentSplitter(),
    item_tfms=Resize(28),
    batch_tfms=[]
)

In [None]:
#colab
dls = data.dataloaders(path, bs=16)

In [None]:
#colab
# concat_pool must be false due to a TPU bug that is triggered if using fastai AdaptivePool
from fastai.vision.learner import cnn_learner
from torchvision.models.resnet import resnet18
learner = cnn_learner(dls, resnet18, metrics=accuracy, concat_pool=False)

In [None]:
#colab
#hide
assert hasattr(learner,'xla_fit')

In [None]:
#colab
%%time
learner.xla_fit_one_cycle(5,lr_max=slice(2e-3))


start fit


epoch,train_loss,valid_loss,accuracy,time


epoch,train_loss,valid_loss,accuracy,time
0,0.204781,0.384089,0.927557,00:14
1,0.184112,0.361853,0.828125,00:05
2,0.212741,0.367217,0.846591,00:05
3,0.243097,0.301363,0.879261,00:05
4,0.253019,0.289174,0.901989,00:05


CPU times: user 164 ms, sys: 229 ms, total: 392 ms
Wall time: 59.6 s


In [None]:
#colab
res = learner.get_preds()
print(accuracy(*res))

TensorBase(0.8984)


In [None]:
#colab
#hide
learner.summary()

Sequential (Input shape: 16)
Layer (type)         Output Shape         Param #    Trainable 
                     16 x 64 x 14 x 14   
Conv2d                                    9408       False     
BatchNorm2d                               128        True      
ReLU                                                           
MaxPool2d                                                      
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      False     
BatchNorm2d                      

In [None]:
#colab
learner.unfreeze()

In [None]:
#colab
#hide
learner.summary()

Sequential (Input shape: 16)
Layer (type)         Output Shape         Param #    Trainable 
                     16 x 64 x 14 x 14   
Conv2d                                    9408       True      
BatchNorm2d                               128        True      
ReLU                                                           
MaxPool2d                                                      
Conv2d                                    36864      True      
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      True      
BatchNorm2d                               128        True      
Conv2d                                    36864      True      
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      True      
BatchNorm2d                      

In [None]:
#colab
learner.xla_fit(n_epoch=5, lr=2e-3)

start fit


epoch,train_loss,valid_loss,accuracy,time


epoch,train_loss,valid_loss,accuracy,time
0,0.073536,0.028147,0.990057,00:22
1,0.082206,0.767614,0.946023,00:05
2,0.073394,1.637862,0.84375,00:05
3,0.071335,0.842607,0.901989,00:05
4,0.075052,0.100458,0.974432,00:05


In [None]:
#colab
learner.validate()

(#2) [0.04073491320014,0.9871244430541992]