In [None]:
#hide
#colab
# attach gdrive holding repo
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
#default_exp multi_core.base

# Multi Core XLA Base 

<a href="https://colab.research.google.com/github/butchland/fastai_xla_extensions/blob/master/nbs/03_multi_core.base.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

> Base module for Multi TPU Core implementation

Multi-core TPU implementation is enabled by importing this module.
```
from fastai_xla_extensions.multi_core.base import *
```

In [None]:
#hide
#colab
!pip install -Uqq cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp36-cp36m-linux_x86_64.whl

[K     |████████████████████████████████| 133.6MB 90kB/s 
[K     |████████████████████████████████| 61kB 3.0MB/s 
[?25h

In [None]:
#hide
#colab
# !pip install -Uqq git+https://github.com/fastai/fastai.git 
!pip install -Uqq fastai --upgrade

[K     |████████████████████████████████| 194kB 2.9MB/s 
[K     |████████████████████████████████| 61kB 2.9MB/s 
[?25h

In [None]:
#hide
#colab
!pip install -Uqq git+https://github.com/butchland/my_timesaver_utils.git

  Building wheel for my-timesaver-utils (setup.py) ... [?25l[?25hdone


In [None]:
#hide
#colab
!pip install -qqq nbdev --upgrade

[?25l[K     |███████▏                        | 10kB 20.2MB/s eta 0:00:01[K     |██████████████▎                 | 20kB 23.2MB/s eta 0:00:01[K     |█████████████████████▍          | 30kB 11.0MB/s eta 0:00:01[K     |████████████████████████████▌   | 40kB 9.0MB/s eta 0:00:01[K     |████████████████████████████████| 51kB 3.0MB/s 
[?25h

In [None]:
#hide
#colab
!curl -s https://course19.fast.ai/setup/colab | bash

Updating fastai...
Done.


In [None]:
#hide
!pip freeze | grep torch
!pip freeze | grep fast
!pip freeze | grep timesaver
!pip freeze | grep nbdev

torch==1.7.0+cu101
torch-xla==1.7
torchsummary==1.5.1
torchtext==0.3.1
torchvision==0.8.1+cu101
fastai==2.2.5
fastcore==1.3.19
fastdtw==0.3.4
fastprogress==1.0.0
fastrlock==0.5
my-timesaver-utils==0.0.2
nbdev==1.1.12


In [None]:
#hide
#colab
# link repo to work dir
%cd /content
!ln -s /content/drive/MyDrive/fastai_xla_extensions fastai_xla_extensions

/content


In [None]:
#hide
# <!-- Start of kernel -->

In [None]:
#hide
#colab
%cd /content/fastai_xla_extensions

/content/drive/MyDrive/fastai_xla_extensions


In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#exporti

#from fastai.vision.all import *
from fastai_xla_extensions.utils import xla_imported
from fastai_xla_extensions.misc_utils import *
from fastai_xla_extensions.core import XLAOptCallback



In [None]:
#exporti
try:
    import torch_xla
except ImportError:
    pass

In [None]:
#hide

# fake out torch_xla modules if not running on xla supported envs
if not xla_imported():
    # replace torch xla modules with fake equivalents so the notebook cells
    # can still be executed in a single process when torch_xla is missing
    from types import SimpleNamespace
    torch_xla = SimpleNamespace (
    )
    from typing import Union,BinaryIO
    import os
    import pickle
    import torch.cuda

    def fake_opt_step(opt,barrier=False):
        "Stand-in for `xm.optimizer_step`: plain `opt.step()` (`barrier` is ignored)"
        opt.step()

    def fake_device(n=None, devkind=None):
        "Stand-in for `xm.xla_device`: current CUDA device if available, else CPU"
        gpu_available = torch.cuda.is_available()
        if gpu_available:
            return torch.device(torch.cuda.current_device())
        return torch.device('cpu')

    def fake_save(obj, f: Union[str, os.PathLike, BinaryIO],
                master_only=True, global_master=False):
        "Stand-in for `xm.save`: delegates to `torch.save` (`master_only`/`global_master` are ignored)"
        return torch.save(obj,f,pickle_module=pickle,
                        pickle_protocol=2,
                        _use_new_zipfile_serialization=True)

    def fake_rate():
        "Constant placeholder rate"
        return 230.20

    def fake_global_rate():
        "Constant placeholder global rate"
        return 830.10

    def fake_add(*args,**kwargs):
        "Accept and discard anything"
        pass

    def fake_RateTracker():
        "Stand-in for `xm.RateTracker` exposing `rate`, `global_rate` and `add`"
        return SimpleNamespace(
            rate = fake_rate,
            global_rate = fake_global_rate,
            add = fake_add
        )

    def fake_xrt_world_size():
        "Single process, so the world size is always 1"
        return 1

    def fake_get_ordinal():
        "Single process, so the ordinal is always 0"
        return 0

    xm = SimpleNamespace(
        optimizer_step = fake_opt_step,
        xla_device = fake_device,
        save = fake_save,
        RateTracker = fake_RateTracker,
        master_print = print,
        xrt_world_size = fake_xrt_world_size,
        get_ordinal = fake_get_ordinal
    )

    def fake_metrics_report():
        "Placeholder metrics report text"
        return "Fake Metrics Report \n\n\n\n"

    met = SimpleNamespace (
        metrics_report = fake_metrics_report
    )

    class FakeParallelLoader:
        "Stand-in for `pl.ParallelLoader` that hands back the wrapped loader unchanged"
        def __init__(self, loader, *args):
            self.loader = loader
        def per_device_loader(self,device):
            return self.loader

    pl = SimpleNamespace(
        ParallelLoader = FakeParallelLoader
    )

    def fake_MpModelWrapper(o):
        "Stand-in for `xmp.MpModelWrapper`: returns the model itself"
        return o

    def fake_run(f,*args, **kwargs):
        "Call `f` directly with the given arguments"
        return f(*args,**kwargs)

    def fake_MpSerialExecutor():
        "Stand-in for `xmp.MpSerialExecutor` exposing only `run`"
        return SimpleNamespace(
            run = fake_run
        )

    def fake_spawn(f, args=None, nprocs=0, start_method=None):
        "Stand-in for `xmp.spawn`: runs `f` once, in-process, as rank 0"
        # `args` defaults to None, which cannot be star-unpacked; treat it as no extra args
        return f(0,*(args or ()))

    xmp = SimpleNamespace (
        MpModelWrapper = fake_MpModelWrapper,
        MpSerialExecutor = fake_MpSerialExecutor,
        spawn = fake_spawn
    )

    xu = SimpleNamespace (
    )


In [None]:
#exporti

if xla_imported():
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl

In [None]:
#exporti

import time
import torch
from fastcore.foundation import L
from fastai.data.core import DataLoaders
import math
from fastcore.basics import store_attr
from operator import attrgetter
from fastai.data.load import _FakeLoader

from fastai.torch_core import TensorBase
import random
from fastcore.basics import patch

In [None]:
#export

def revert_tensor(o):
    """Remove tensor subclass and revert to `torch.Tensor`

    The class of `o` is reassigned in place and the same object is returned.
    Raises `RuntimeError` (chained to the underlying error) if the class of
    `o` cannot be reassigned.
    """
    try:
        o.__class__ = torch.Tensor
    except Exception as e:
        # only trap real errors (not KeyboardInterrupt/SystemExit) and keep the cause
        raise RuntimeError(f'could not convert {o} to torch.Tensor') from e
    return o

def recast2tensor(o):
    "Turn a `fastai.torch_core.TensorBase` (or subclass) instance into a plain `torch.Tensor`; pass anything else through"
    if not isinstance(o,TensorBase):
        return o
    # hand back a plain tensor since pl.ParallelLoader doesn't
    # seem to work with tensor subclasses
    # TODO: recreate bug in notebook gist to file bug to torch_xla team
    return revert_tensor(o)

def round_to_multiple(number,multiple):
    "Round `number` up to the nearest multiple of `multiple` (used to fill every core with samples)"
    n_groups = math.ceil(number/multiple)
    return int(n_groups*multiple)

In [None]:
#export

from fastai.data.core import TfmdDL

class TPUDistributedDL(TfmdDL):
    """A `TfmdDL` which splits a batch into equal size pieces for each TPU core
       It also recasts the output of a batch from a TensorBase subclass to
       a regular tensor since the XLA Parallel loader doesn't seem to be compatible
       with it.
       Code implementation was based on @tmabraham's `TPUDistributedDL` implementation
       here: https://github.com/tmabraham/fastai_tpu/blob/master/fastai_v2/tpu_distributed_dl.py
    """
    _default = 'dl'
    def __init__(self,dl,rank,world_size, seed=42):
        "Wrap `dl` so that this instance only serves the slice of indices belonging to `rank`"
        store_attr()
        # mirror the settings of the wrapped dataloader on this wrapper
        self.bs,self.device,self.num_workers, \
        self.drop_last,self.dataset,self.offs,fake, self.shuffle = \
            attrgetter('bs','device','num_workers',
                       'drop_last','dataset','offs','fake_l', 'shuffle')(dl)
        # rebuild the fake loader so that it is bound to this wrapper instead of `dl`
        self.fake_l = _FakeLoader(self, fake.pin_memory, fake.num_workers, fake.timeout,
                                  persistent_workers=fake.persistent_workers)
        self.epoch = 0
        random.seed(self.seed)
        # setting inner dl rng
        self.dl.rng = random.Random(random.randint(0,2**32-1))
        self.reset_rng()

    def reset_rng(self):
        "Re-seed the outer rng from `seed + epoch`"
        random.seed(self.seed + self.epoch)
        # setting outer dl rng
        self.rng = random.Random(random.randint(0,2**32-1))

    def __len__(self):
        "Batch count of the wrapped `dl`, rounded up to a multiple of `world_size`, per rank"
        return round_to_multiple(len(self.dl),self.world_size)//self.world_size

    def set_epoch(self, epoch):
        "Record the current epoch; it is used by `reset_rng` to seed the shuffle"
        self.epoch = epoch

    def get_idxs(self):
        "Return the indices this rank should process, padding so the total divides evenly by `world_size`"
        idxs = self.dl.get_idxs()
        # do your own shuffling which factors in self.epoch + self.seed in
        # generating a random sequence (underlying self.dl does not)
        if self.shuffle:
            idxs = self.shuffle_fn(idxs)
        self.n = len(idxs)
        # we assumed n was dl.n but we really care about number of idxs
        # add extra samples to make it evenly divisible
        # (same helper as `__len__`; `_round_to_multiple` is not defined in this module)
        self.n_padded = round_to_multiple(self.n,self.world_size)
        n_extra = self.n_padded - self.n
        if n_extra > 0:
            # idx needs to be repeated when n_padded>>n
            # (skipped when nothing has to be added, which also avoids dividing by n == 0)
            idxs += (idxs * (self.n_padded//self.n))[:n_extra]
        # slice padded idxs so that each rank gets self.n_padded//self.world_size tensors
        start_pos = self.rank*self.n_padded//self.world_size
        end_pos = (self.rank+1)*self.n_padded//self.world_size
        return idxs[start_pos:end_pos]

    def before_iter(self):
        "Forward to the wrapped dataloader"
        self.dl.before_iter()

    def randomize(self):
        "Re-seed the outer rng, then let the wrapped dataloader randomize itself"
        self.reset_rng()
        self.dl.randomize()

    def after_batch(self,b):
        "Run the wrapped dataloader's `after_batch`, then strip tensor subclasses from the batch"
        b = self.dl.after_batch(b)
        # recast tensor subclasses to plain tensors
        # undoing work of self.retain()
        tb = [recast2tensor(o) for o in b]
        b = tuple(tb)
        return b

    def after_iter(self):
        "Forward to the wrapped dataloader"
        self.dl.after_iter()

    def create_batches(self,samps):
        "Forward to the wrapped dataloader"
        return self.dl.create_batches(samps)

    def to(self, device):
        "Set `device` on both the wrapper and the wrapped dataloader; returns `self`"
        self.dl.device = device
        self.device = device
        return self

    def one_batch(self):
        "Return one batch from the wrapped dataloader (not restricted to this rank's slice)"
        return self.dl.one_batch()

In [None]:
#hide_input
#colab
show_doc(TPUDistributedDL)

<h2 id="TPUDistributedDL" class="doc_header"><code>class</code> <code>TPUDistributedDL</code><a href="" class="source_link" style="float:right">[source]</a></h2>

> <code>TPUDistributedDL</code>(**`dl`**, **`rank`**, **`world_size`**, **`seed`**=*`42`*) :: `TfmdDL`

A `TfmdDL` which splits a batch into equal size pieces for each TPU core
It also recasts the output of a batch from a TensorBase subclass to
a regular tensor since the XLA Parallel loader doesn't seem to be compatible
to it.
Code implementation was based on @tmabraham's [`TPUDistributedDL`](/fastai_xla_extensions/multi_core.base.html#TPUDistributedDL) implementation
here: https://github.com/tmabraham/fastai_tpu/blob/master/fastai_v2/tpu_distributed_dl.py

In [None]:
#hide
#colab
#TODO: add tests for distrib tpu dl


In [None]:
#colab
from fastai.torch_core import TensorBase, TensorImage, TensorCategory
from fastai.data.core import TfmdDL

# sizes are chosen so that the item count divides evenly into
# `n_batches` batches of `bs` items for each of `world_size` ranks
n_batches = 10
bs = 6
world_size = 8
# setup a dataloader as base dl for tpu 
items = [(TensorImage(torch.tensor(i).float()), TensorCategory(i)) for i in range(n_batches * bs * world_size)]
dl = TfmdDL(items, bs=bs, shuffle=True)
# the base dl holds the batches of all ranks
assert len(dl) == n_batches * world_size
b0 = next(iter(dl))
# the base dl still yields fastai tensor subclasses
assert isinstance(b0[0], TensorImage)
assert isinstance(b0[1],TensorCategory)


In [None]:
#colab
# wrap the base dl for rank 0 out of `world_size` ranks
tpu_dl = TPUDistributedDL(dl, rank=0, world_size=world_size)
# the batches for dl for each rank is divided across all ranks
assert len(tpu_dl) == n_batches
tpu_b0 = next(iter(tpu_dl))
# the types of each batch (x,y) have been reverted to torch tensors
# and are no longer Tensor subclasses (e.g. TensorBase)
assert isinstance(tpu_b0[0], torch.Tensor)
assert isinstance(tpu_b0[1], torch.Tensor)
assert not isinstance(tpu_b0[0], TensorBase)
assert not isinstance(tpu_b0[1], TensorBase)


In [None]:
#colab
# add tests to make sure all items are retrieved per epoch
# create dl for each rank across all ranks
# one wrapper per rank, all sharing the same base `dl`
tpu_dls = [TPUDistributedDL(dl, rank=rank, world_size=world_size) for rank in range(world_size)]
# materialize every batch of every rank (no assertions are made on them yet)
rank_batches = [list(tpu_dl) for tpu_dl in tpu_dls]
# TODO: check that each rank dont contain common items
# TODO: check that all items in dl are accounted for in the tpu_dls across all ranks

In [None]:
#export


def build_distributed_dataloaders(dls, rank, world_size, sync_valid=False):
    """Wrap every loader in `dls` with a `TPUDistributedDL` and return them as new `DataLoaders`"""
    def _wrap(idx, loader):
        # the first (train) loader is always split across ranks; the others
        # only when `sync_valid` is set, otherwise they act as rank 0 of 1
        split = idx == 0 or sync_valid
        return TPUDistributedDL(loader,
                                rank=rank if split else 0,
                                world_size=world_size if split else 1)

    wrapped = [_wrap(idx, loader) for idx,loader in enumerate(dls.loaders)]
    return DataLoaders(*wrapped, path=dls.path, device=dls.device)

In [None]:
#hide_input
show_doc(build_distributed_dataloaders)

<h4 id="build_distributed_dataloaders" class="doc_header"><code>build_distributed_dataloaders</code><a href="__main__.py#L5" class="source_link" style="float:right">[source]</a></h4>

> <code>build_distributed_dataloaders</code>(**`dls`**, **`rank`**, **`world_size`**, **`sync_valid`**=*`False`*)

Wrap dataloaders with distributed TPU aware dataloader 

In [None]:
#export
from fastcore.meta import delegates
from fastai.data.block import DataBlock

@delegates(DataBlock.dataloaders,but=['datablock','rank','world_size','sync_valid','device'])
def make_fastai_dataloaders(datablock, source, rank, world_size, device=None, path='.', sync_valid=False, verbose=False,**kwargs):
    """create fastai-based dataloaders from a datablock and wrap a tpu distributed dataloader around them

    `source`, `path`, `device`, `verbose` and any extra `kwargs` are passed to
    `datablock.dataloaders`; `rank`, `world_size` and `sync_valid` are passed to
    `build_distributed_dataloaders`.
    """
    # `verbose` is now forwarded instead of being silently dropped
    dls = datablock.dataloaders(source=source, path=path, device=device, verbose=verbose, **kwargs)
    distrib_dls = build_distributed_dataloaders(dls, rank, world_size, sync_valid=sync_valid)
    return distrib_dls

In [None]:
#hide_input
show_doc(make_fastai_dataloaders)

<h4 id="make_fastai_dataloaders" class="doc_header"><code>make_fastai_dataloaders</code><a href="__main__.py#L5" class="source_link" style="float:right">[source]</a></h4>

> <code>make_fastai_dataloaders</code>(**`datablock`**, **`source`**, **`rank`**, **`world_size`**, **`device`**=*`None`*, **`path`**=*`'.'`*, **`sync_valid`**=*`False`*, **`verbose`**=*`False`*)

create fastai-based dataloaders from a datablock and wrap a tpu distributed dataloader around them

In [None]:
#export
def wrap_parallel_loader(loader, device):
    "Wrap `loader` (a tpu distributed loader or a torch dataloader with distributed sampler) in an xla parallel loader and return its per-device loader for `device`"
    return pl.ParallelLoader(loader, [device]).per_device_loader(device)

In [None]:
#hide_input
show_doc(wrap_parallel_loader)

<h4 id="wrap_parallel_loader" class="doc_header"><code>wrap_parallel_loader</code><a href="__main__.py#L2" class="source_link" style="float:right">[source]</a></h4>

> <code>wrap_parallel_loader</code>(**`loader`**, **`device`**)

wraps a tpu distributed loader or a torch dataloader (with distributed sampler) with xla parallel loader

In [None]:
#exporti

from fastai.learner import Recorder
from fastai.callback.core import Callback
from fastai.learner import CancelValidException

In [None]:
#export

class XLATrainingCallback(Callback):
    "Callback that drives training inside a spawned process on multi-core TPUs"
    run_before = Recorder
    run_valid = False
    order = -5 # after TrainEvalCallback
    def __init__(self, device, rank=0, sync_valid=False):
        self.pdevice,self.rank,self.sync_valid = device,rank,sync_valid

    def before_fit(self):
        "Announce the start of fit through `xm.master_print`"
        xm.master_print('start fit')

    def before_epoch(self):
        "Push the current epoch to the train loader (and to the valid loader if `sync_valid`)"
        def _push_epoch(dl):
            # a loader with a sampler gets the epoch set on the sampler (when the
            # sampler supports it); otherwise it is set on the loader itself
            if hasattr(dl,'sampler'):
                if hasattr(dl.sampler,'set_epoch'):
                    dl.sampler.set_epoch(self.learn.epoch)
            elif hasattr(dl,'set_epoch'):
                dl.set_epoch(self.learn.epoch)

        # set the epoch on train to make sure shuffle produces same seq
        # across all ranks
        _push_epoch(self.learn.dls.train)
        if self.sync_valid: # update epoch on valid if sync_valid
            _push_epoch(self.learn.dls.valid)

    def before_train(self):
        "Use the train loader wrapped in an xla parallel loader for this training run"
        self.learn.dl = wrap_parallel_loader(self.dls.train, self.pdevice)

    def before_validate(self):
        "Skip validation on non-zero ranks unless `sync_valid`; otherwise wrap the valid loader"
        if not (self.rank == 0 or self.sync_valid):
            # no need to compute valid loss/ metric if not master if not sync valid
            raise CancelValidException()
        self.learn.dl = wrap_parallel_loader(self.dls.valid, self.pdevice)

The `XLATrainingCallback` is responsible for the following functions:
   * sets the `epoch` on either the torch dataloader sampler or the TPU distributed DL before each epoch. This ensures that, for each epoch, the shuffled order of the samples is the same across all ranks, while each rank picks only its own subset of that order.

   The `TPUDistributedDL` (and the torch distributed sampler) ensures that all the samples (with some duplication if the number of samples is not exactly divisible by the number of ranks) are seen by one of the dataloaders across the ranks at least once per epoch.
   * wraps the dataloader (either training or validation) with the XLA Parallel Loader (`torch_xla.distributed.parallel_loader.ParallelLoader`) before each training or validation run.
   * sidesteps the call to `opt.step` and instead calls `xm.optimizer_step(opt)` to sync the model gradients across all the ranks.


In [None]:
#exporti

import copy
from fastai.learner import _maybe_item
from fastprogress.fastprogress import format_time

In [None]:
#export
def pack_metric(metrics):
    "Collect the `count` of every metric followed by the `total` of every metric (0 where missing)"
    return metrics.attrgot('count',0) + metrics.attrgot('total',0)

def make_tensor(o, device):
    "Return `o` (scalar or tensor) as a float tensor on `device`"
    tensor = o if isinstance(o, torch.Tensor) else torch.tensor(o)
    return tensor.float().to(device)

def pack_metrics(all_metrics, device):
    "Flatten the train and valid metrics of `all_metrics` into a list of float tensors on `device`"
    train_part = pack_metric(all_metrics['train_mets'])
    valid_part = pack_metric(all_metrics['valid_mets'])
    packed = []
    for item in train_part + valid_part:
        packed.append(make_tensor(item,device))
    return packed

def restore_metrics(reduced_metrics, all_metrics):
    "Write the reduced counts and totals in `reduced_metrics` back onto the train and valid metrics"
    train_mets, valid_mets = all_metrics['train_mets'], all_metrics['valid_mets']
    n_train, n_valid = len(train_mets), len(valid_mets)
    valid_start = n_train*2
    # layout of `reduced_metrics`:
    # [train counts | train totals | valid counts | valid totals]
    groups = (
        (train_mets,
         reduced_metrics[:n_train],
         reduced_metrics[n_train: valid_start]),
        (valid_mets,
         reduced_metrics[valid_start: valid_start + n_valid],
         reduced_metrics[valid_start + n_valid:]),
    )
    for mets, counts, totals in groups:
        for i,metric in enumerate(mets):
            # only overwrite attributes the metric already has
            if hasattr(metric,'count'):
                metric.count = counts[i].clone().detach().long()
            if hasattr(metric,'total'):
                metric.total = totals[i].clone().detach()
    return all_metrics

In [None]:
#export
class SyncRecorderCallback(Callback):
    """A `Callback` to sync the metrics from each rank and update statistics
       accordingly so it will display correctly in the progress callback
    """
    order  = 55 # after Recorder, before ProgressCallback

    def before_fit(self):
        "On the master ordinal, choose the function used to write the synced stats row"
        if not xm.is_master_ordinal():
            return
        if 'progress' in self.learn.cbs.attrgot('name',None):
            self._sync_stats_log = self.progress._write_stats
        else:
            self._sync_stats_log = self.learn.logger

    def before_epoch(self):
        "Start the synced log row from a copy of the recorder's log"
        self.sync_log = copy.copy(self.recorder.log)

    def after_epoch(self):
        "Sum metric counts/totals across ranks; on the master ordinal replace the recorder's row with the synced one"
        if 'recorder' not in self.learn.cbs.attrgot('name'):
            all_metrics = {
                'train_mets': L([]),
                'valid_mets': L([]),
            }
        else:
            all_metrics = {
                'train_mets': self.recorder._train_mets,
                'valid_mets': self.recorder._valid_mets,
            }
        # send metrics data to sync ranks across spawned processes
        device = self.learn.xla_training.pdevice
        packed_metrics = pack_metrics(all_metrics, device) # convert metrics to tensor list on TPU
        reduced_metrics = xm.all_reduce(xm.REDUCE_SUM, packed_metrics)
        xm.mark_step()
        if xm.is_master_ordinal():
            all_metrics = restore_metrics(reduced_metrics, all_metrics) # convert list to metric objects
            for m in self.recorder._train_mets:
                self.sync_log += _maybe_item(m)

            for m in self.recorder._valid_mets:
                self.sync_log += _maybe_item(m)

            # the first log entry is the epoch number; the recorded values are
            # everything after it (fastai's Recorder stores `log[1:]` as well)
            self.learn.final_record = self.sync_log[1:].copy()
            del self.recorder.values[-1] # remove last entry added by recorder
            self.recorder.values.append(self.learn.final_record) # add updated metrics
            if self.recorder.add_time:
                updated_time = (time.time() - self.recorder.start_epoch)
                self.sync_log.append(format_time(updated_time))
            self.recorder.log = self.sync_log
            self._sync_stats_log(self.sync_log) # write_stats to output
            self.learn.logger = self.orig_logger # restore orig logger after skipping recorder.logger(log)

    def after_validate(self):
        "On the master ordinal, silence the learner's logger until `after_epoch` restores it"
        if xm.is_master_ordinal():
            self.orig_logger = self.learn.logger
            self.learn.logger = noop # write to logger disabled so calling recorder.logger(log) wont print

In [None]:
#export
from fastcore.imports import noop
#from fastcore.basics import patch
from fastai.learner import Learner
from fastai.callback.progress import ProgressCallback
from fastcore.xtras import join_path_file
#from fastai.torch_core import get_model

In [None]:
#export

@patch
@delegates(Learner.save)
def save(self:Learner, file, **kwargs):
    """Save the model (and optimizer state) to `self.path/self.model_dir/file` using `xm.save`

    The optimizer state is included unless `with_opt=False` is passed or the
    learner has no optimizer. Other keyword arguments are accepted for
    signature compatibility but are not used. Returns the path written to.
    """
    file = join_path_file(file, self.path/self.model_dir, ext='.pth')
    # honor an explicit `with_opt=False` (it used to be silently ignored)
    with_opt = kwargs.get('with_opt', True) and self.opt is not None
    state = self.model.state_dict()
    if with_opt:
        # add opt state to state to be saved
        opt_state = self.opt.state_dict()
        state = {'model': state, 'opt':opt_state}
    xm.save(state, file) # use xm.save instead of torch.save
    return file

In [None]:
#hide_input
#colab
show_doc(Learner.save)

<h4 id="Learner.save" class="doc_header"><code>Learner.save</code><a href="__main__.py#L3" class="source_link" style="float:right">[source]</a></h4>

> <code>Learner.save</code>(**`file`**, **`with_opt`**=*`True`*, **`pickle_protocol`**=*`2`*)



The `Learner.save` has been patched to use the torch xla method `xm.save` which will save the model weights for the model on the TPU device. Moreover, `xm.save` only saves the weights on the master ordinal rank process by default, ensuring that only one copy of the model is written to a file. _Which is fine, since the `xm.optimizer_step` done on each training batch synchronizes the weights across all ranks anyway._

In [None]:
#export
@patch
def to_multi_xla(self:Learner,device, rank, sync_valid=False):
    "Prepare the learner inside a spawned process for multi core TPU training"
    if 'xla_training' in self.cbs.attrgot('name'):
        # callback already attached: just refresh its settings
        xla_cb = self.xla_training
        xla_cb.pdevice = device
        xla_cb.rank = rank
        xla_cb.sync_valid = sync_valid
    else:
        self.dls.device = None
        self.add_cbs([XLATrainingCallback(device, rank, sync_valid=sync_valid),
                      XLAOptCallback()])
        self.opt = None # drop any existing optimizer

    # the sync recorder is only wanted when validation is synced across ranks
    if not sync_valid:
        self.remove_cbs(SyncRecorderCallback)
    elif 'sync_recorder' not in self.cbs.attrgot('name'):
        self.add_cbs(SyncRecorderCallback)

    if rank != 0: # progress bar only for rank 0
        self.remove_cbs(ProgressCallback)
    self.logger = xm.master_print

In [None]:
#hide_input
#colab
show_doc(Learner.to_multi_xla)

<h4 id="Learner.to_multi_xla" class="doc_header"><code>Learner.to_multi_xla</code><a href="__main__.py#L2" class="source_link" style="float:right">[source]</a></h4>

> <code>Learner.to_multi_xla</code>(**`device`**, **`rank`**, **`sync_valid`**=*`False`*)

Sets up the learner on the spawned process for multi core TPU training

In [None]:
#export
# for testing
def do_one_loop(dl, rank, world_size, device, wrap_parallel=True):
    """Print the type, shape, device and dtype of the first two batches of `dl`.

    `rank` and `world_size` are only used to label the printed lines. When
    `wrap_parallel` is true, `dl` is first wrapped with `wrap_parallel_loader`
    for `device`; otherwise it is iterated as is. Returns nothing.
    """
    n_batches = len(dl)
    print(f'xla: {rank} world_size: {world_size} n_batches:{n_batches}')

    if wrap_parallel:
        print(f'xla: {rank} wrapping ploader')
        pdl = wrap_parallel_loader(dl, device=device)
    else:
        pdl = dl
    for i,b in enumerate(pdl):
        # only the first two batches are inspected
        if i > 1:
            break
        xb, yb = b
        print(f'xla: {rank} iter:{i} xb type {type(xb)} yb type: {type(yb)}')
        print(f'xla: {rank} iter:{i} xb.shape {xb.shape} yb.shape: {yb.shape}')
        print(f'xla: {rank} iter:{i} xb.device {xb.device} yb.device: {yb.device}')
        # label fixed: this line reports dtypes (it used to say `yb.device:`)
        print(f'xla: {rank} iter:{i} xb.dtype {xb.dtype} yb.dtype: {yb.dtype}')

## Test out the code


In [None]:
#hide
#colab
%cd /content

/content


In [None]:
#hide
#colab
from functools import partial
from fastai.metrics import accuracy
from fastai.optimizer import SGD, Adam

from fastcore.basics import first
from fastai.callback.schedule import *
from fastai.test_utils import VerboseCallback
from my_timesaver_utils.profiling import *
from my_timesaver_utils.profiling_callback import *

In [None]:
#colab
def run_dataloader_loop(rank):
    """Per-process smoke test: build the dataloaders and print info on the first batches for `rank`.

    Reads its settings from the notebook globals `FLAGS`, `DATA`, `PATH` and
    `WRAPPED_MODEL`. No training is done here.
    """
    torch.manual_seed(1)
    print(f'xla {rank} start run_dataloader_loop')
    xm.rendezvous('start_run_dataloader_loop')
    # Scale learning rate to num cores
    # (computed but not used further in this function)
    learning_rate = FLAGS['learning_rate'] * xm.xrt_world_size()
    SYNC_VALID = FLAGS['sync_valid']
    IS_PROFILING = FLAGS['is_profiling']
    # Get loss function, optimizer, and model
    device = xm.xla_device()
    # the model is moved to the device but not used further in this function
    model = WRAPPED_MODEL.to(device)
    bs = FLAGS['batch_size']
    world_size = xm.xrt_world_size()
    if IS_PROFILING:
        # `rec_name` only exists when profiling; it is only read under the same flag below
        rec_name = 'rank' + str(rank) + '_dataloader_build'
        print(f'start {rec_name}')
        start_record(rec_name)

    # dls = make_fastai_dataloaders(
    #                         DATA, 
    #                         PATH, 
    #                         rank=rank, 
    #                         world_size=world_size, 
    #                         sync_valid=SYNC_VALID,
    #                         bs=bs,)
    dls = DATA.dataloaders(PATH, bs=bs)
    # distrib_dls = build_distributed_dataloaders(dls, rank, world_size, 
    #                                            sync_valid=True)
    dl = dls.train
    # wrap only the train loader for this rank
    tpu_dl = TPUDistributedDL(dl,rank=rank,world_size=world_size)
    print(f'xla: {rank} fake_l.num_workers {tpu_dl.fake_l.num_workers}')
    # iterate the distributed loader directly, without the xla parallel loader
    do_one_loop(tpu_dl, rank, world_size, device, wrap_parallel=False)
    if IS_PROFILING:
        end_record(rec_name)
        print_prof_data(rec_name)
        print(f'finished {rec_name}')

    xm.mark_step()
    print(f'xla {rank} completed run_dataloader_loop')
    # print_prof_data()

In [None]:
#colab
def train_model(rank):
    """Per-process training entry point: build dataloaders and learner, fit, then save.

    Reads its settings from the notebook globals `FLAGS`, `DATA`, `PATH`,
    `WRAPPED_MODEL`, `LOSS_FUNC` and `OPT_FUNC`. When `FLAGS['is_profiling']`
    is set, the dataloader build and the fit are timed and reported.
    """
    torch.manual_seed(1)
    xm.rendezvous('start_train_model')
    print(f'xla {rank} start train model')

    
    SYNC_VALID = FLAGS['sync_valid']
    IS_PROFILING = FLAGS['is_profiling']
    # Get loss function, optimizer, and model
    device = xm.xla_device()

    bs = FLAGS['batch_size']
    world_size = xm.xrt_world_size()
    if IS_PROFILING:
        # `rec_name` only exists when profiling; it is only read under the same flag below
        rec_name = 'rank' + str(rank) + '_dataloader_build'
        print(f'start {rec_name}')
        start_record(rec_name)

    dls = make_fastai_dataloaders(
                            DATA, 
                            PATH, 
                            rank=rank, 
                            world_size=world_size, 
                            sync_valid=SYNC_VALID,
                            bs=bs,)
    if IS_PROFILING:
        end_record(rec_name)
        print_prof_data(rec_name)
        print(f'finished {rec_name}')
    model = WRAPPED_MODEL.to(device)
    # the same momentum value is used for all three entries of `moms`
    moms =(FLAGS['momentum'],FLAGS['momentum'],FLAGS['momentum'])
    wd = FLAGS['weight_decay']

    xm.master_print('build learner')
    learner = Learner(dls, model, 
                      loss_func=LOSS_FUNC, 
                      opt_func=OPT_FUNC, 
                      metrics=accuracy, 
                      wd=wd,
                      moms=moms)
                      
    # note: the rank passed here comes from `xm.get_ordinal()`, not the `rank` argument
    learner.to_multi_xla(device, rank=xm.get_ordinal(), sync_valid=SYNC_VALID)
    if IS_PROFILING and rank == 0:
        learner.to_my_profile()

    # Scale learning rate to num cores
    learning_rate = FLAGS['learning_rate'] * xm.xrt_world_size()
                               
    epochs = FLAGS['num_epochs']
    xm.master_print('start running fit')
    learner.unfreeze()
    if IS_PROFILING:
        rec_name3 = 'rank' + str(rank) + '_run_fit'
        print(f'start {rec_name3}')
        start_record(rec_name3)

    # the maximum learning rate passed to fit is a tenth of the scaled rate
    learner.fit_one_cycle(epochs, lr_max=slice(learning_rate/10))
    if IS_PROFILING:
        end_record(rec_name3)
        print_prof_data(rec_name3)
        print(f'finished {rec_name3}')

    # called on every rank; uses the `Learner.save` patched earlier in this notebook
    learner.save('stage-1')
    if rank == 0 and IS_PROFILING :
        learner.my_profile.print_stats()
    xm.mark_step()  
    


This is the main method that runs the training. 

It includes some profiling code to measure the building of the `dataloaders` and running of the `fit` methods. 

At the end of the spawned processes, the master ordinal process saves the model to a temporary file. (see `Learner.save` patch above)

The saved model will then be loaded by the main process so that it will now contain the trained weights updated by the spawned training processes.

In [None]:
#colab
# Start training processes
def _mp_fn(rank, flags):
    "Spawn target: publish `flags` as the global `FLAGS`, then run `train_model` for `rank`"
    global FLAGS
    FLAGS = flags
    train_model(rank)


In [None]:
#colab
# Start dataloader processes
def _mp_fn2(rank, flags):
    "Spawn target: publish `flags` as the global `FLAGS`, then run `run_dataloader_loop` for `rank`"
    global FLAGS
    FLAGS = flags
    run_dataloader_loop(rank)


In [None]:
import torch
from fastcore.transform import DisplayedTransform, Transform
from fastcore.basics import store_attr
from fastai.vision.core import PILImage, PILBase, image2tensor
from fastai.data.block import TransformBlock

In [None]:
from fastai.data.transforms import get_c
# from fastai.vision.all import *
from fastai.data.block import DataBlock, CategoryBlock
from fastai.vision.data import ImageBlock
from fastai.data.transforms import get_image_files, parent_label, GrandparentSplitter
from fastai.vision.augment import Resize, aug_transforms
from fastai.data.external import untar_data, URLs
from fastai.data.transforms import Normalize
from fastai.vision.core import imagenet_stats
from fastcore.basics import using_attr
from fastai.data.transforms import RegexLabeller, CategoryMap

In [None]:
import torch.nn as nn
LOSS_FUNC = nn.CrossEntropyLoss()

In [None]:
from fastai.optimizer import Adam
OPT_FUNC = Adam

In [None]:
from fastai.data.transforms import RandomSplitter

In [None]:
from fastai.vision.learner import create_cnn_model
from fastai.vision.models import resnet34

In [None]:
import os
# Define Parameters
# (shared settings handed to every spawned process through `xmp.spawn`)
FLAGS = {}
# FLAGS['batch_size'] = 1024
FLAGS['sync_valid'] = True      # passed as `sync_valid` to the dataloaders and the learner
FLAGS['is_profiling'] = True    # turns on the timing/profiling output
FLAGS['batch_size'] = 64        # passed as `bs` when building the dataloaders
FLAGS['num_workers'] = 4
FLAGS['learning_rate'] = 1e-3   # multiplied by the world size before use
FLAGS['image_size'] = 224       # size given to `Resize` in the DataBlock
FLAGS['momentum'] = 0.85
FLAGS['weight_decay'] = 2e-3
FLAGS['num_epochs'] = 5
# 8 processes when the TPU_NAME environment variable is set, otherwise 1
FLAGS['num_cores'] = 8 if os.environ.get('TPU_NAME', None) else 1
# FLAGS['num_cores'] = 1 
ARCH = resnet34

In [None]:
from pathlib import Path
from fastcore.xtras import *
import torch_xla.distributed.xla_multiprocessing as xmp

In [None]:
#colab
# Fetch the PETS dataset with untar_data and work from its `images` sub-folder
PATH = untar_data(URLs.PETS)/'images'
# Alternative datasets (disabled):
# PATH = untar_data(URLs.MNIST)
# PATH = untar_data(URLs.MNIST_TINY)


In [None]:
#colab

# File names look like `<label>_<number>.jpg` (e.g. newfoundland_143.jpg);
# the label is everything before the trailing `_<number>.jpg`.
# The dot is escaped so that only a literal '.' before `jpg` matches
# (an unescaped '.' would accept any character in that position).
pat = r'(.+)_\d+\.jpg$'
# Apply the regex to the `name` attribute of each item (the file name, not the full path)
fname_labeller = using_attr(RegexLabeller(pat), 'name')
# Fixed seed so the train/valid split is the same on every run
splitter = RandomSplitter(seed=42)
DATA = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    get_y=fname_labeller,
    splitter=splitter,
    item_tfms=[Resize(FLAGS['image_size']),],
    batch_tfms=[]
)
# Label vocabulary built from every image file under PATH; its length is the
# number of model outputs.
vocab = CategoryMap(get_image_files(PATH).map(fname_labeller))
N_OUT = len(vocab)


In [None]:
#colab
# Sanity check: N_OUT is used as the number of model outputs, so it must be positive
assert N_OUT is not None and N_OUT > 0,f'N_OUT {N_OUT} should be > 0'

The model is created by the main process and wrapped by the `xmp.MpModelWrapper`. This is to reduce the memory usage by not having multiple copies of the model in the spawned processes.

In [None]:
#colab
# Build the model once, here in the main process: ARCH backbone, N_OUT outputs,
# pretrained weights requested, concat pooling disabled.
custom_model = create_cnn_model(ARCH, N_OUT, 
                                pretrained=True,
                                concat_pool=False)


In [None]:
#colab
# Only instantiate model weights once in memory: wrap the model built in the
# main process so the spawned processes do not each hold their own copy.
WRAPPED_MODEL = xmp.MpModelWrapper(custom_model)

In [None]:
#hide_output
#colab
%%time
# Fork FLAGS['num_cores'] processes, each running _mp_fn2 -> run_dataloader_loop(rank)
xmp.spawn(_mp_fn2, args=(FLAGS,), nprocs=FLAGS['num_cores'],
        start_method='fork')

xla 0 start run_dataloader_loop
xla 6 start run_dataloader_loop
xla 4 start run_dataloader_loop
xla 5 start run_dataloader_loop
xla 7 start run_dataloader_loop
xla 3 start run_dataloader_loop
xla 1 start run_dataloader_loop
xla 2 start run_dataloader_loop
start rank7_dataloader_build
xla: 7 fake_l.num_workers 2
xla: 7 world_size: 8 n_batches:12
start rank1_dataloader_build
xla: 7 iter:0 xb type <class 'torch.Tensor'> yb type: <class 'torch.Tensor'>
xla: 7 iter:0 xb.shape torch.Size([64, 3, 224, 224]) yb.shape: torch.Size([64])
xla: 7 iter:0 xb.device cpu yb.device: cpu
xla: 7 iter:0 xb.dtype torch.float32 yb.device: torch.int64
xla: 7 iter:1 xb type <class 'torch.Tensor'> yb type: <class 'torch.Tensor'>
xla: 7 iter:1 xb.shape torch.Size([64, 3, 224, 224]) yb.shape: torch.Size([64])
xla: 7 iter:1 xb.device cpu yb.device: cpu
xla: 7 iter:1 xb.dtype torch.float32 yb.device: torch.int64
start rank0_dataloader_build
xla: 1 fake_l.num_workers 2
xla: 1 world_size: 8 n_batches:12
start rank5_d

In [None]:
#colab
%%time
# Switch profiling off for the actual training run (updates the shared FLAGS dict in place)
FLAGS['is_profiling'] = False
# Fork FLAGS['num_cores'] processes, each running _mp_fn -> train_model(rank)
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS['num_cores'],
        start_method='fork')

xla 0 start train model
xla 5 start train model
xla 3 start train model
xla 2 start train model
xla 1 start train model
xla 7 start train model
xla 6 start train model
xla 4 start train model
build learner
start running fit
start fit


epoch,train_loss,valid_loss,accuracy,time
0,0.75931,1.098083,0.662162,01:55
1,0.637405,1.350417,0.69054,01:30
2,0.578879,0.578579,0.842568,01:19
3,0.505099,0.347587,0.893243,01:20
4,0.429199,0.24874,0.926351,01:20


CPU times: user 122 ms, sys: 122 ms, total: 244 ms
Wall time: 7min 51s


In [None]:
#hide
#colab
# Print a step-by-step trace of how the DataBlock turns an item under PATH into a sample
DATA.summary(PATH)

Setting-up type transforms pipelines
Collecting items from /root/.fastai/data/oxford-iiit-pet/images
Found 7390 items
2 datasets of sizes 5912,1478
Setting up Pipeline: PILBase.create
Setting up Pipeline: partial -> Categorize -- {'vocab': None, 'sort': True, 'add_na': False}

Building one sample
  Pipeline: PILBase.create
    starting from
      /root/.fastai/data/oxford-iiit-pet/images/newfoundland_143.jpg
    applying PILBase.create gives
      PILImage mode=RGB size=500x375
  Pipeline: partial -> Categorize -- {'vocab': None, 'sort': True, 'add_na': False}
    starting from
      /root/.fastai/data/oxford-iiit-pet/images/newfoundland_143.jpg
    applying partial gives
      newfoundland
    applying Categorize -- {'vocab': None, 'sort': True, 'add_na': False} gives
      TensorCategory(27)

Final sample: (PILImage mode=RGB size=500x375, TensorCategory(27))


Collecting items from /root/.fastai/data/oxford-iiit-pet/images
Found 7390 items
2 datasets of sizes 5912,1478
Setting up Pip

In [None]:
#colab
# Dataloaders built in the main process; used by the Learner that validates the trained weights
mdls = DATA.dataloaders(PATH, bs=FLAGS['batch_size'])

In [None]:
#colab
# Learner created in the main process around the model object built earlier
mlearner = Learner(
    mdls,
    custom_model,
    loss_func=LOSS_FUNC,
    opt_func=OPT_FUNC,
    metrics=accuracy,
    wd=FLAGS['weight_decay'],
    # the same momentum value for all three entries of `moms`
    moms=(FLAGS['momentum'],) * 3,
)
# load trained weights from multi core tpu training
mlearner.load('stage-1')

<fastai.learner.Learner at 0x7ff01f35aba8>

In [None]:
#colab
# Show which device the main-process dataloaders are on
mlearner.dls.device

device(type='cpu')

In [None]:
from fastai.torch_core import one_param

In [None]:
#colab
# Show which device the model parameters are on (checked via one parameter)
one_param(mlearner.model).device

device(type='cpu')

In [None]:
#colab
%%time
# Validate the loaded weights in the main process; with `metrics=accuracy`
# the printed list is [validation loss, accuracy].
valid_metrics = mlearner.validate()
print(valid_metrics)

[0.24909608066082, 0.9282814860343933]
CPU times: user 3min 30s, sys: 3.24 s, total: 3min 33s
Wall time: 3min 35s
