<a href="https://colab.research.google.com/github/butchland/fastai_xla_extensions/blob/master/nbs/03_multi_core.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#default_exp multi_core

# Multi Core XLA extensions

## Setup torch XLA


This is the official way to install PyTorch-XLA 1.7 (see the [instructions here](https://colab.research.google.com/github/pytorch/xla/blob/master/contrib/colab/getting-started.ipynb#scrollTo=CHzziBW5AoZH)).

In [2]:
#colab
!pip install -Uqq cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp36-cp36m-linux_x86_64.whl

[K     |████████████████████████████████| 133.6MB 29kB/s 
[K     |████████████████████████████████| 61kB 3.4MB/s 
[?25h

## Install fastai

Use latest fastai and fastcore versions

In [3]:
#colab
# !pip install -Uqq git+https://github.com/fastai/fastai.git 
!pip install -Uqq fastai --upgrade

[K     |████████████████████████████████| 194kB 4.9MB/s 
[K     |████████████████████████████████| 61kB 5.4MB/s 
[?25h

In [4]:
#colab
!pip install -Uqq git+https://github.com/butchland/my_timesaver_utils.git

  Building wheel for my-timesaver-utils (setup.py) ... [?25l[?25hdone


In [5]:
#hide
#colab
!curl -s https://course19.fast.ai/setup/colab | bash

Updating fastai...
Done.


In [6]:
#hide
!pip freeze | grep torch
!pip freeze | grep fast

torch==1.7.0+cu101
torch-xla==1.7
torchsummary==1.5.1
torchtext==0.3.1
torchvision==0.8.1+cu101
fastai==2.2.5
fastcore==1.3.19
fastdtw==0.3.4
fastprogress==1.0.0
fastrlock==0.5


Start of kernel

In [1]:
#export
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu



In [2]:
#exporti
from fastcore.basics import patch_to
from fastai.optimizer import _BaseOptimizer
import numpy as np
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torch.utils.data as th_data
from fastcore.foundation import L
from pathlib import Path
from fastcore.xtras import *
from fastcore.transform import Pipeline
from fastai.data.core import DataLoaders
from functools import partial
import torch.utils.data.distributed as torch_distrib
from pathlib import Path
import fastcore.xtras
import math
from fastcore.basics import store_attr
from operator import attrgetter
from fastai.data.load import _FakeLoader
from fastai.data.core import TfmdDL
from fastai.torch_core import find_bs, TensorBase
import random
import torch
from fastai.data.load import _loaders
from fastai.torch_core import to_device
from fastcore.basics import first


## Patching BaseOptimizer to be Picklable
Patch the `_BaseOptimizer` methods `__getstate__` and `__setstate__`, which are used when pickling the optimizer.
This should fix a bug when running the learner on multiple TPU cores in XLA, where `def _fetch_gradients(optimizer)` fails
at `for param_group in optimizer.__getstate__()['param_groups']:`. The patch fixes the "copy constructor" so that the pickled state includes the `param_groups`.

In [3]:
#export
@patch_to(_BaseOptimizer)
def __getstate__(self):
    """Return the state used for pickling: the optimizer's `state_dict()`, its
    `param_groups` and, when the optimizer has them, its `defaults`."""
    pickled = dict(state=self.state_dict(), param_groups=self.param_groups)
    if hasattr(self, 'defaults'):
        pickled['defaults'] = self.defaults
    return pickled

@patch_to(_BaseOptimizer)
def __setstate__(self, data):
    """Restore an optimizer from the dict produced by `__getstate__`."""
    has_defaults = 'defaults' in data
    if has_defaults:
        self.defaults = data['defaults']
    # load the state first, then overwrite param_groups with the pickled ones
    self.load_state_dict(data['state'])
    self.param_groups = data['param_groups']

In [4]:
#export
def _recast2tensor(o):
    """Return `o` as a plain `torch.Tensor` when it is a `TensorBase`; otherwise return `o` untouched."""
    if not isinstance(o, TensorBase):
        return o
    # pl.ParallelLoader doesn't seem to work with tensor subclasses,
    # so rebuild a plain tensor from the numpy data
    as_array = o.numpy()
    return torch.tensor(as_array)

def _round_to_multiple(number,multiple): 
    return int(math.ceil(number/multiple)*multiple)


In [5]:
#export
class TPUDistributedDL(TfmdDL):
    """A `TfmdDL` which splits a batch into equal size pieces for each TPU core
       It also recasts the output of a batch from a Tensorbase subclass to 
       a regular tensor since the XLA Parallel loader doesnt seem compatible
       to it.

       `dl` is the wrapped dataloader, `rank` selects which slice of the
       (padded) indices this instance yields, `world_size` is the number of
       slices, and `seed` (together with the current epoch) seeds the shuffle.
    """
    # presumably the fastcore `GetAttr` delegation target, so attributes not set
    # here (e.g. `shuffle`) would be looked up on `self.dl` -- TODO confirm
    _default = 'dl'
    def __init__(self,dl,rank,world_size, seed=42):
        # fastcore `store_attr`; the rest of the class relies on it having saved
        # `dl`, `rank`, `world_size` and `seed` on self
        store_attr()
        # mirror the wrapped loader's main settings on this wrapper
        self.bs,self.device,self.num_workers,self.drop_last,self.dataset,self.offs,fake = \
            attrgetter('bs','device','num_workers','drop_last','dataset','offs','fake_l')(dl)
        # build our own fake loader with the same settings as the wrapped loader's
        self.fake_l = _FakeLoader(self, fake.pin_memory, fake.num_workers, fake.timeout, 
                                  persistent_workers=fake.persistent_workers)
        self.epoch = 0
        # NOTE: this reseeds the *global* `random` module
        random.seed(self.seed)
        self.dl.rng = random.Random(random.randint(0,2**32-1))
        self.reset_rng()

    def reset_rng(self):
        """Create `self.rng` from `seed + epoch` (reseeds the global `random` module on the way)."""
        random.seed(self.seed + self.epoch)
        self.rng = random.Random(random.randint(0,2**32-1))

    def __len__(self): 
        """Number of batches for one rank: `len(self.dl)` divided by `world_size`, rounded up."""
        return _round_to_multiple(len(self.dl),self.world_size)//self.world_size

    def set_epoch(self, epoch):
        """Record the current epoch; it is used by `reset_rng` to vary the seed."""
        self.epoch = epoch

    def get_idxs(self):
        """Return the slice of (optionally shuffled, padded) indices that belongs to `self.rank`."""
        idxs = self.dl.get_idxs()
        # do your own shuffling which factors in self.epoch + self.seed in
        # generating a random sequence (underlying self.dl does not)
        if self.shuffle: 
            idxs = self.shuffle_fn(idxs)
        self.n = len(idxs)              
        # we assumed n was dl.n but we really care about number of idxs
        # add extra samples to make it evenly divisible
        self.n_padded = _round_to_multiple(self.n,self.world_size)
        # NOTE(review): `self.n_padded//self.n` raises ZeroDivisionError when there are no idxs
        idxs += (idxs * (self.n_padded//self.n))[:self.n_padded-self.n] 
        # idx needs to be repeated when n_padded>>n
        # slice padded idxs so that each rank gets self.n_padded//self.world_size tensors
        start_pos = self.rank*self.n_padded//self.world_size
        end_pos = (self.rank+1)*self.n_padded//self.world_size
        return idxs[start_pos:end_pos]

    def before_iter(self):
        """Forward to the wrapped dataloader."""
        self.dl.before_iter()

    def randomize(self): 
        """Refresh this wrapper's rng, then let the wrapped dataloader randomize itself."""
        self.reset_rng()
        self.dl.randomize()

    def after_batch(self,b):
        """Run the wrapped loader's `after_batch`, then return the batch as a tuple of recast items."""
        b = self.dl.after_batch(b)
        # recast tensor subclasses to plain tensors
        # undoing work of self.retain()
        tb = [_recast2tensor(o) for o in b]
        b = tuple(tb)
        return b

    def after_iter(self): 
        """Forward to the wrapped dataloader."""
        self.dl.after_iter()

    def create_batches(self,samps): 
        """Forward batch creation to the wrapped dataloader."""
        return self.dl.create_batches(samps)

    def to(self, device):
        """Set `device` on both the wrapped dataloader and this wrapper; returns self."""
        self.dl.device = device
        self.device = device
        return self



In [6]:
#exporti
from fastai.torch_core import default_device, apply
import torch 
from fastcore.xtras import is_listy
import torch
import torch.utils.hooks
from fastcore.basics import patch
from fastai.torch_core import TensorBase
from collections import OrderedDict

In [7]:
#exporti
from fastcore.basics import patch_to
import torch.utils.data.distributed as th_distrib
import torch.utils.data as th_data

In [8]:
#export
class TfmdTorchDS(th_data.Dataset):
    """Map-style torch dataset that derives `(x, y)` from each item via two optional transforms.

    items: indexable collection of raw items (it must support `len` and `[]`).
    x_tfm: callable applied to an item to produce `x`; `None` passes the raw item through.
    y_tfm: callable applied to an item to produce `y`; `None` passes the raw item through.
    """
    def __init__(self, items, x_tfm=None, y_tfm=None):
        self.items = items
        self.x_tfm = x_tfm
        self.y_tfm = y_tfm

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        "Return the `(x, y)` pair built from `self.items[index]`"
        item = self.items[index]
        # a missing transform means "use the raw item"; the previous code read the
        # not-yet-assigned locals `x`/`y` here and raised UnboundLocalError
        x = item if self.x_tfm is None else self.x_tfm(item)
        y = item if self.y_tfm is None else self.y_tfm(item)
        return (x,y)

In [9]:
#exporti
from fastcore.xtras import is_listy
import torchvision as thv
from operator import itemgetter
from fastcore.imports import noop
from fastcore.foundation import L

In [10]:
#export
def to_list(o):
    """Normalize `o` to a list-like: `None` -> `[]`, a listy `o` is returned as is, anything else -> `[o]`."""
    if o is None:
        return []
    if is_listy(o):
        return o
    return [o]

def has_setup(tfms):
    """returns last index if at least 1 `tfm` in `tfms` has a method `setup` else return -1"""
    # indexes of the tfms whose `setup` attribute exists and is truthy
    setup_idxs = L(tfms).attrgot('setup',None).argwhere(noop)
    if len(setup_idxs) == 0:
        return -1
    return setup_idxs[-1]

def run_setups(tfms, items):
    """Call `setup(items)` on every tfm that has one, feeding each tfm the items
    already transformed by the tfms before it (up to the last tfm with a setup)"""
    last_setup = has_setup(tfms)
    if last_setup != -1:
        for pos, tfm in enumerate(tfms):
            if hasattr(tfm, 'setup'):
                tfm.setup(items)
            if pos >= last_setup:
                # nothing after the last setup needs transformed items
                continue
            # tfm items to be fed into next tfm
            items = [tfm(item) for item in items]

In [11]:
#export
class TorchDatasetBuilder:
    """Build a train and a test `TfmdTorchDS` from a data source.

    source: passed to `get_items` to list the raw items.
    get_items: callable returning the items for `source`.
    splitter: callable returning `(train_idxs, test_idxs)` for the items.
    x_tfms / y_tfms: transform(s) producing x / y, shared by train and test.
    x_type_tfms: transform(s) applied to x before any other x transform.
    x_train_tfms / x_test_tfms: x transform(s) used only for the train / test set,
        applied between `x_type_tfms` and `x_tfms`.
    do_setup: whether transform setups still have to be run (see `setup`).
    """
    def __init__(self, source, get_items, splitter,
                x_tfms, y_tfms,
                x_type_tfms=None,
                x_train_tfms=None, x_test_tfms=None,
                do_setup=False):
        self.source = source
        self.get_items = get_items
        self.splitter = splitter
        self.do_setup = do_setup
        self.x_tfms = to_list(x_tfms)
        self.y_tfms = to_list(y_tfms)
        self.x_type_tfms = to_list(x_type_tfms)
        self.x_train_tfms = to_list(x_train_tfms)
        self.x_test_tfms = to_list(x_test_tfms)

    def setup(self, items, do_setup=None, setup_x=False):
        """Run the transform setups on `items` once.

        `do_setup` overrides the stored flag when given. The y transforms are always
        set up; the x (type + train + shared) transforms only when `setup_x` is true.
        The flag is cleared afterwards so later calls do nothing unless re-enabled.
        """
        self.do_setup = do_setup if do_setup is not None else self.do_setup
        if self.do_setup:
            all_x_tfms = [*self.x_type_tfms, *self.x_train_tfms, *self.x_tfms]
            if setup_x:
                run_setups(all_x_tfms, items)
            run_setups(self.y_tfms, items)
            self.do_setup = False

    @staticmethod
    def _select(items, idxs):
        "Return the `items` at `idxs` as a tuple, also when `idxs` has 0 or 1 entries"
        # `itemgetter(*idxs)(items)` returns a bare item for a single index and
        # raises TypeError for no index at all, so index explicitly instead
        return tuple(items[i] for i in idxs)

    def get_datasets(self, do_setup=None):
        """Return `(train_ds, test_ds)`; pending setups are run on the train items only."""
        self.do_setup = do_setup if do_setup is not None else self.do_setup
        items = self.get_items(self.source)
        train_idxs, test_idxs = self.splitter(items)

        train_items = self._select(items, train_idxs)
        test_items = self._select(items, test_idxs)
        self.setup(train_items)
        allx_test_tfms = [*self.x_type_tfms, *self.x_test_tfms, *self.x_tfms]
        allx_train_tfms = [*self.x_type_tfms, *self.x_train_tfms, *self.x_tfms]
        train_x_tfm = thv.transforms.Compose(allx_train_tfms)
        test_x_tfm = thv.transforms.Compose(allx_test_tfms)
        # the same y pipeline is shared by both datasets
        y_tfm = thv.transforms.Compose(self.y_tfms)
        train_ds = TfmdTorchDS(train_items, x_tfm=train_x_tfm, y_tfm=y_tfm)
        test_ds = TfmdTorchDS(test_items, x_tfm=test_x_tfm, y_tfm=y_tfm)
        return train_ds, test_ds

In [12]:
#export
from fastai.data.transforms import CategoryMap

class VocabularyMapper:
    """A simplified version of the fastai Categorize Transform"""
    def __init__(self, vocab=None):
        self.vocab, self.c = vocab, 0

    def setup(self, items):
        "Build the vocab from `items` and record its size in `c`"
        self.vocab = CategoryMap(items)
        self.c = len(self.vocab)

    def __call__(self, o):
        "Map label `o` to its vocab index as a tensor; pass `o` through while no vocab is set"
        if self.vocab is None:
            return o
        try:
            idx = self.vocab.o2i[o]
        except KeyError as e:
            raise KeyError(f"Label '{o}' was not included in the training dataset") from e
        return torch.tensor(idx)


In [13]:
# Smoke test: build train/test datasets for MNIST_TINY with TorchDatasetBuilder
import torchvision as thv

# x transforms, applied after the image has been opened
pil2tensor = thv.transforms.ToTensor()
resize28 = thv.transforms.Resize(28)
# 3-channel normalization stats (same values as `cifar_norm` used later in this notebook)
norm = thv.transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))

from fastai.vision.core import PILImage
from fastai.data.transforms import get_image_files, GrandparentSplitter, parent_label
from fastai.data.external import untar_data, URLs

path = untar_data(URLs.MNIST_TINY)
# y pipeline: label each file, then map the label to its vocab index
mnist_dset_builder =  TorchDatasetBuilder(
                source=path, 
                get_items=get_image_files, 
                splitter=GrandparentSplitter(),
                x_tfms=[resize28,pil2tensor,norm,], 
                y_tfms=[parent_label,VocabularyMapper(),],
                x_type_tfms=PILImage.create)

from fastcore.test import test_eq

# do_setup=True so the VocabularyMapper builds its vocab from the train items
train_ds, test_ds = mnist_dset_builder.get_datasets(do_setup=True)

# expected split sizes and vocab
test_eq(len(train_ds),709)
test_eq(len(test_ds),699)
test_eq(mnist_dset_builder.y_tfms[1].vocab, ('3','7'))
test_eq(mnist_dset_builder.y_tfms[1].c, 2)


In [14]:
#export
@patch_to(th_data.DataLoader)
def to(self, device):
    """Record `device` on a plain torch `DataLoader` so it can be used where a `.to(device)` method is expected.

    Only the `device` attribute is set; no data is moved here.
    Returns `self`, like `TPUDistributedDL.to`, so calls can be chained.
    """
    self.device = device
    return self

In [15]:
#export
def make_torch_dataloaders(train_dataset, test_dataset,
                     rank,
                     world_size,
                     bs,
                     num_workers=4,
                     distrib=True,
                     sync_valid=False):
    """Wrap a train and a test dataset in torch `DataLoader`s and return them as fastai `DataLoaders`.

    With `distrib`, the train loader uses a shuffling `DistributedSampler` for
    (`rank`, `world_size`); the test loader gets a non-shuffling one only when
    `sync_valid` is also set. Otherwise plain loaders are used (train shuffled,
    test not). Every loader uses `bs`, `num_workers` and `drop_last=True`.
    """
    def _sampler(dataset, shuffle):
        return th_distrib.DistributedSampler(
            dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=shuffle)

    def _loader(dataset, **how):
        # `how` is either a sampler or a shuffle flag
        return th_data.DataLoader(
            dataset,
            batch_size=bs,
            num_workers=num_workers,
            drop_last=True,
            **how)

    if distrib:
        train_loader = _loader(train_dataset, sampler=_sampler(train_dataset, shuffle=True))
    else:
        train_loader = _loader(train_dataset, shuffle=True)

    if distrib and sync_valid:
        test_loader = _loader(test_dataset, sampler=_sampler(test_dataset, shuffle=False))
    else:
        test_loader = _loader(test_dataset, shuffle=False)

    return DataLoaders(train_loader, test_loader, device=None)

In [16]:
#exporti
import re

In [17]:
#export
class FileNamePatternLabeller:
    "Delayed action version of fastai RegexLabeller with file name selection"
    def __init__(self, pat_str, match=False):
        self.pat_str, self.match = pat_str, match
        # the pattern is compiled on first use, not at construction time
        self.matcher = self.pat = None

    def __call__(self, f):
        "Return group 1 of the pattern applied to the file name of `f` (a str or a path object)"
        o = (Path(f) if isinstance(f, str) else f).name
        if self.pat is None:
            self.pat = re.compile(self.pat_str)
            self.matcher = self.pat.match if self.match else self.pat.search
        res = self.matcher(o)
        assert res, f'Failed to find "{self.pat}" in {o}'
        return res.group(1)
            

In [18]:
#export
def make_distributed_dataloaders(dls, rank, world_size, sync_valid=False):
    """Wrap dataloaders with distributed TPU aware dataloader """
    def _wrap(pos, dl):
        # the first loader is always split across ranks; the others only when
        # sync_valid is set, otherwise they are wrapped as a single full "rank 0 of 1"
        split = pos == 0 or sync_valid
        return TPUDistributedDL(dl,
                                rank=rank if split else 0,
                                world_size=world_size if split else 1)

    wrapped = [_wrap(pos, dl) for pos, dl in enumerate(dls.loaders)]
    return DataLoaders(*wrapped, path=dls.path, device=dls.device)

In [19]:
#export
# def DataBlock.dataloaders(self, source, path='.', verbose=False, **kwargs):
def make_fastai_dataloaders(datablock, source, rank, world_size, device=None, path='.', sync_valid=False, verbose=False,**kwargs):
    """Build dataloaders from `datablock` and wrap them for distributed TPU use.

    datablock: object whose `dataloaders(...)` method builds the dataloaders.
    source, path, device, verbose, **kwargs: forwarded to `datablock.dataloaders`.
    rank, world_size, sync_valid: forwarded to `make_distributed_dataloaders`.
    Returns the wrapped dataloaders.
    """
    # `verbose` used to be accepted but silently dropped; forward it as well
    dls = datablock.dataloaders(source=source, path=path, verbose=verbose, device=device, **kwargs)
    distrib_dls = make_distributed_dataloaders(dls, rank, world_size, sync_valid=sync_valid)
    return distrib_dls


In [20]:
#export
def wrap_parallel_loader(loader, device):
    """Wrap `loader` in a `pl.ParallelLoader` for `device` and return its per-device loader."""
    return pl.ParallelLoader(loader, [device]).per_device_loader(device)

In [21]:
#exporti
from fastai.callback.core import TrainEvalCallback
from fastai.learner import Recorder
from fastai.torch_core import one_param
import torch
from fastai.callback.core import Callback
from fastai.learner import CancelTrainException, CancelValidException, CancelStepException
from fastai.torch_core import tensor, TensorCategory

In [22]:
#export
class XLATrainingCallback(Callback):
    """Callback that feeds the learner XLA parallel loaders and steps the optimizer through `xm.optimizer_step`."""
    run_before = Recorder
    run_valid = False
    order = -5 # after TrainEvalCallback 
    def __init__(self, device, rank=0, sync_valid=False):
        self.pdevice, self.rank, self.sync_valid = device, rank, sync_valid

    @staticmethod
    def _set_epoch(dl, epoch):
        "Call `set_epoch(epoch)` on `dl.sampler` when `dl` has a sampler, else on `dl` itself, if available"
        target = dl.sampler if hasattr(dl, 'sampler') else dl
        if hasattr(target, 'set_epoch'):
            target.set_epoch(epoch)

    def before_fit(self):
        xm.master_print('start fit')

    def before_epoch(self):
        # set the epoch on train only to make sure shuffle produces same seq 
        # across all ranks
        self._set_epoch(self.learn.dls.train, self.learn.epoch)
        # update epoch on valid if sync_valid
        if self.sync_valid:
            self._set_epoch(self.learn.dls.valid, self.learn.epoch)

    def before_train(self):
        "Iterate the train loader through an XLA parallel loader"
        self.learn.dl = wrap_parallel_loader(self.dls.train, self.pdevice)

    def before_validate(self):
        "Set the model in validation mode"
        is_master = self.rank == 0
        if not (is_master or self.sync_valid):
            # no need to compute valid loss/ metric if not master if not sync valid
            raise CancelValidException()
        self.learn.dl = wrap_parallel_loader(self.dls.valid, self.pdevice)

    def before_step(self):
        # cancel the step here; the optimizer is stepped in after_cancel_step
        raise CancelStepException()

    def after_cancel_step(self):
        xm.optimizer_step(self.learn.opt)


In [23]:
#exporti
import copy
from fastcore.imports import noop
from fastcore.foundation import L
from fastai.learner import Metric, AvgMetric, AvgLoss, AvgSmoothLoss
import torch
import pickle
from fastai.torch_core import find_bs, to_detach

In [24]:
#export

@patch
def update_metric(self:Metric, other_metrics):
    "No-op default: dunno how to handle updates for metrics other than AvgMetric, AvgLoss"

@patch
def update_metric(self:(AvgMetric,AvgLoss), other_metrics):
    "Replace this metric's `total` and `count` with the sums over `other_metrics`"
    peers = L(other_metrics)
    # other metrics must also be AvgMetric or AvgLoss
    wrong_type = peers.map(lambda o: not isinstance(o, (AvgLoss,AvgMetric))).argwhere(noop)
    assert len(wrong_type) == 0
    # other metrics must have same name
    wrong_name = peers.attrgot('name').map(lambda o: o != self.name).argwhere(noop)
    assert len(wrong_name) == 0
    totals, counts = peers.attrgot('total'), peers.attrgot('count')
    self.total = totals.sum()
    self.count = counts.sum()

def unpack_sync(res):
    """Unpickle every payload in `res` and return the resulting objects as a list."""
    # payloads are expected to be the pickled metrics exchanged between the training processes
    return list(map(pickle.loads, res))
        

In [25]:
#exporti
from fastai.learner import _maybe_item
from fastprogress.fastprogress import format_time
import time


In [26]:
#export
class SyncRecorderCallback(Callback):
    """Sync metrics from each spawned process update statistics 
       accordingly so it will display correctly in the progress callback
    """
    order  = 55 # after Recorder, before ProgressCallback
    def __init__(self):
        pass

    def before_fit(self):
        "On the master ordinal, choose the function that prints the synced epoch stats"
        if not xm.is_master_ordinal():
            return
        if 'progress' in self.learn.cbs.attrgot('name',None):
            self._sync_stats_log = self.progress._write_stats
        else:
            self._sync_stats_log = self.learn.logger

    def after_fit(self):
        "Wait for all processes before leaving fit"
        xm.rendezvous('sync recorder after_fit')

    def before_epoch(self):
        "Start the synced log from a copy of the recorder's log as it is at the start of the epoch"
        self.sync_log = copy.copy(self.recorder.log)

    def after_epoch(self):
        "Exchange metrics across processes; on the master, rewrite the recorder's last record with the combined values"
        if 'recorder' not in self.learn.cbs.attrgot('name'):
            all_metrics = {
                'train_mets': L([]),
                'valid_mets': L([]),
            }
        else:
            all_metrics = {
                'train_mets': self.recorder._train_mets,
                'valid_mets': self.recorder._valid_mets,
            }
        # send metrics data to sync ranks across spawned processes
        sync_tag = f'sync_recorder_after_epoch{self.learn.epoch}'
        res = xm.rendezvous(sync_tag, pickle.dumps(all_metrics))

        if xm.is_master_ordinal():
            all_metrics = unpack_sync(res)
            self._sync_log(all_metrics) # use metrics across ranks to update log

            # sync_log[0] is the value the log held before the metrics were appended
            # (the epoch number in fastai's Recorder); the record stored in
            # `values` is the metrics only, so skip it. The previous `[:1]` kept
            # only that first entry and dropped every metric.
            self.learn.final_record = self.sync_log[1:].copy()
            del self.recorder.values[-1] # remove last entry added by recorder
            self.recorder.values.append(self.learn.final_record) # add updated metrics
            if self.recorder.add_time:
                updated_time = (time.time() - self.recorder.start_epoch)
                self.sync_log.append(format_time(updated_time))
            self.recorder.log = self.sync_log
            self._sync_stats_log(self.sync_log) # write_stats to output
            self.learn.logger = self.orig_logger # restore orig logger after skipping recorder.logger(log)

    def before_validate(self):
        pass

    def after_validate(self):
        "On the master, silence the learner's logger until `after_epoch` prints the synced stats"
        if xm.is_master_ordinal():
            self.orig_logger = self.learn.logger
            self.learn.logger = noop # write to logger disabled so calling recorder.logger(log) wont print

    def before_batch(self):
        pass

    def _sync_log(self, all_metrics):
        "Update each recorder metric from the matching metrics of all ranks and append its value to `sync_log`"
        all_metrics = L(all_metrics)

        for i,m in enumerate(self.recorder._train_mets):
            m.update_metric(all_metrics.attrgot('train_mets').itemgot(i))
            self.sync_log += _maybe_item(m)

        for i,m in enumerate(self.recorder._valid_mets):
            m.update_metric(all_metrics.attrgot('valid_mets').itemgot(i))
            self.sync_log += _maybe_item(m)
        


In [27]:
#exporti
from fastcore.imports import noop
from fastcore.basics import patch
from fastai.learner import Learner
from fastai.callback.progress import ProgressCallback
from fastcore.xtras import join_path_file
from fastai.torch_core import get_model


In [28]:
#export
@patch
def save(self:Learner, file, **kwargs):
    """Save the model state (plus the optimizer state when an optimizer exists) with `xm.save`; return the file path."""
    target = join_path_file(file, self.path/self.model_dir, ext='.pth')
    payload = self.model.state_dict()
    if self.opt is not None:
        # add opt state to state to be saved
        payload = {'model': payload, 'opt': self.opt.state_dict()}
    # use xm.save instead of torch.save
    xm.save(payload, target)
    return target


In [29]:
#export
@patch
def to_xla(self:Learner,device, rank, sync_valid=False):
    """Configure the learner's callbacks and logger for XLA training on `device` as process `rank`."""
    if 'xla_training' in self.cbs.attrgot('name'):
        # already set up: just refresh the existing callback's settings
        xla_cb = self.xla_training
        xla_cb.pdevice = device
        xla_cb.rank = rank
        xla_cb.sync_valid = sync_valid
    else:
        self.dls.device = None
        self.add_cbs(XLATrainingCallback(device, rank, sync_valid=sync_valid))

    if not sync_valid:
        self.remove_cbs(SyncRecorderCallback)
    elif 'sync_recorder' not in self.cbs.attrgot('name'):
        self.add_cbs(SyncRecorderCallback)

    # progress bar only for rank 0
    if rank != 0:
        self.remove_cbs(ProgressCallback)
    self.logger = xm.master_print

In [30]:
#export
# for testing
def do_one_loop(dls, rank, world_size, device, sync_valid, is_train=True):
    """Debug helper: print shape, device and dtype of the first two batches of a loader.

    dls: object with `train` and `valid` loaders; `is_train` selects which one.
    rank, world_size: only used for the printed messages and to decide whether to iterate.
    device: device passed to `wrap_parallel_loader`.
    sync_valid: when false, only rank 0 iterates the validation loader.
    Returns None.
    """
    dl = dls.train if is_train else dls.valid

    n_batches = len(dl)
    print(f'xla: {rank} world_size: {world_size} train:{is_train} n_batches:{n_batches} sync_valid: {sync_valid}')

    if not (sync_valid or is_train or rank == 0):
        # this rank does not iterate the loader; previously execution fell through
        # to the loop and failed with NameError on the never-assigned `pdl`
        return

    print(f'xla: {rank} wrapping ploader')
    pdl = wrap_parallel_loader(dl, device=device)
    for i,b in enumerate(pdl):
        if i > 1:
            break
        xb, yb = b
        print(f'xla: {rank} iter:{i} xb.shape {xb.shape} yb.shape: {yb.shape}')
        print(f'xla: {rank} iter:{i} xb.device {xb.device} yb.device: {yb.device}')
        # label fixed: this line prints dtypes, not devices
        print(f'xla: {rank} iter:{i} xb.dtype {xb.dtype} yb.dtype: {yb.dtype}')
        


## Test out the code


In [31]:
from functools import partial
from fastai.metrics import accuracy
from fastai.optimizer import SGD, Adam

from fastcore.basics import first
from fastai.callback.schedule import *
from fastai.test_utils import VerboseCallback
from my_timesaver_utils.profiling import *
from my_timesaver_utils.profiling_callback import *

In [32]:
def train_torch_model(rank):
    """Per-process training entry point for the plain torch dataset/dataloader pipeline.

    Reads the notebook-level globals FLAGS, WRAPPED_MODEL, DSET_BUILDER, LOSS_FUNC and
    OPT_FUNC. DSET_BUILDER is only defined by the cell guarded by `not USE_DBLOCK`.
    `rank` is used for dataloader sharding, profiling record names and rank-0-only steps.
    """
    torch.manual_seed(1)
    xm.rendezvous('start_train_torch_model')
    # Scale learning rate to num cores
    learning_rate = FLAGS['learning_rate'] * xm.xrt_world_size()
    IS_PROFILING = FLAGS['is_profiling']
    SYNC_VALID = FLAGS['sync_valid']

    # Get loss function, optimizer, and model
    device = xm.xla_device()
    model = WRAPPED_MODEL.to(device)
    bs = FLAGS['batch_size']
    world_size = xm.xrt_world_size()
    # the same momentum value is used for all three entries of `moms`
    moms =(FLAGS['momentum'],FLAGS['momentum'],FLAGS['momentum'])
    wd = FLAGS['weight_decay']
    num_workers = FLAGS['num_workers']

    # --- build the datasets (optionally timed) ---
    if IS_PROFILING:
        rec_name = 'rank' + str(rank) + '_dset_build'
        print(f'start {rec_name}')
        start_record(rec_name)
    dsets = DSET_BUILDER.get_datasets()
    if IS_PROFILING:
        end_record(rec_name)
        print_prof_data(rec_name)
        print(f'finished {rec_name}')

    # --- build the dataloaders (optionally timed) ---
    if IS_PROFILING:
        rec_name2 = 'rank' + str(rank) + '_dataloader_build'
        print(f'start {rec_name2}')
        start_record(rec_name2)
    dls = make_torch_dataloaders(*dsets, 
                                  rank=rank, 
                                  world_size=world_size, 
                                  bs=bs,
                                  num_workers=num_workers,
                                  sync_valid=SYNC_VALID,
                                 )

    if IS_PROFILING:
        end_record(rec_name2)
        print_prof_data(rec_name2)
        print(f'finished {rec_name2}')

    # do_one_loop(dls,rank,world_size,device,sync_valid=SYNC_VALID, is_train=True)
    # do_one_loop(dls,rank,world_size,device,sync_valid=SYNC_VALID, is_train=False)

    xm.master_print('build learner')
    learner = Learner(dls, model, 
                      loss_func=LOSS_FUNC, 
                      opt_func=OPT_FUNC, 
                      metrics=accuracy, 
                      wd=wd,
                      moms=moms
                      )
                      
    # note: the rank given to to_xla comes from xm.get_ordinal(), not the `rank` argument
    learner.to_xla(device, rank=xm.get_ordinal(), sync_valid=SYNC_VALID)
    # the profiling callback is attached on rank 0 only
    if rank == 0:
        learner.to_my_profile()
                               
    epochs = FLAGS['num_epochs']
    xm.master_print('start running fit')
    learner.unfreeze()

    # --- fit (optionally timed) ---
    if IS_PROFILING:
        rec_name3 = 'rank' + str(rank) + '_run_fit'
        print(f'start {rec_name3}')
        start_record(rec_name3)
    learner.fit_one_cycle(epochs, lr_max=slice(learning_rate/10))

    if IS_PROFILING:
        end_record(rec_name3)
        print_prof_data(rec_name3)
        print(f'finished {rec_name3}')

    learner.save('stage-1')
    if rank == 0:
        learner.my_profile.print_stats()
    xm.mark_step() 
    # wait for every process before returning
    xm.rendezvous('end_train_torch_model')
    if IS_PROFILING:
        clear_prof_data() 


In [33]:
def train_model(rank):
    """Per-process training entry point for the fastai DataBlock pipeline.

    Reads the notebook-level globals FLAGS, WRAPPED_MODEL, DATA, PATH, LOSS_FUNC and
    OPT_FUNC. DATA is only defined by the cell guarded by `USE_DBLOCK`.
    `rank` is used for dataloader sharding, profiling record names and rank-0-only steps.
    """
    torch.manual_seed(1)
    xm.rendezvous('start_train_model')
    print(f'xla {rank} start train model')

    # Scale learning rate to num cores
    learning_rate = FLAGS['learning_rate'] * xm.xrt_world_size()
    SYNC_VALID = FLAGS['sync_valid']
    IS_PROFILING = FLAGS['is_profiling']
    # Get loss function, optimizer, and model
    device = xm.xla_device()
    model = WRAPPED_MODEL.to(device)
    bs = FLAGS['batch_size']
    # the same momentum value is used for all three entries of `moms`
    moms =(FLAGS['momentum'],FLAGS['momentum'],FLAGS['momentum'])
    wd = FLAGS['weight_decay']

    world_size = xm.xrt_world_size()
    # --- build the distributed fastai dataloaders (optionally timed) ---
    if IS_PROFILING:
        rec_name = 'rank' + str(rank) + '_dataloader_build'
        print(f'start {rec_name}')
        start_record(rec_name)

    dls = make_fastai_dataloaders(
                            DATA, 
                            PATH, 
                            rank=rank, 
                            world_size=world_size, 
                            sync_valid=SYNC_VALID,
                            bs=bs,)
    if IS_PROFILING:
        end_record(rec_name)
        print_prof_data(rec_name)
        print(f'finished {rec_name}')

    xm.master_print('build learner')
    learner = Learner(dls, model, 
                      loss_func=LOSS_FUNC, 
                      opt_func=OPT_FUNC, 
                      metrics=accuracy, 
                      wd=wd,
                      moms=moms)
                      
    # note: the rank given to to_xla comes from xm.get_ordinal(), not the `rank` argument
    learner.to_xla(device, rank=xm.get_ordinal(), sync_valid=SYNC_VALID)
    # the profiling callback is attached on rank 0 only
    if rank == 0:
        learner.to_my_profile()
                               
    epochs = FLAGS['num_epochs']
    xm.master_print('start running fit')
    learner.unfreeze()
    # --- fit (optionally timed) ---
    if IS_PROFILING:
        rec_name3 = 'rank' + str(rank) + '_run_fit'
        print(f'start {rec_name3}')
        start_record(rec_name3)

    learner.fit_one_cycle(epochs, lr_max=slice(learning_rate/10))
    if IS_PROFILING:
        end_record(rec_name3)
        print_prof_data(rec_name3)
        print(f'finished {rec_name3}')

    learner.save('stage-1')
    if rank == 0:
        learner.my_profile.print_stats()
    xm.mark_step()  
    # wait for every process before returning
    xm.rendezvous('end_train_model')
    if IS_PROFILING:
        clear_prof_data() 


In [34]:
# Start training processes
def _mp_fn(rank, flags):
    """Process entry point: publish `flags` as the module-level FLAGS, then run `train_model(rank)`."""
    global FLAGS
    FLAGS = flags
    train_model(rank)

def _mp_fn2(rank, flags):
    """Process entry point: publish `flags` as the module-level FLAGS, then run `train_torch_model(rank)`."""
    global FLAGS
    FLAGS = flags
    train_torch_model(rank)



In [35]:
import torch
from fastcore.transform import DisplayedTransform, Transform
from fastcore.basics import store_attr
from fastai.vision.core import PILImage, PILBase, image2tensor
from fastai.data.block import TransformBlock

In [36]:
from fastai.data.transforms import get_c
# from fastai.vision.all import *
from fastai.data.block import DataBlock, CategoryBlock
from fastai.vision.data import ImageBlock
from fastai.data.transforms import get_image_files, parent_label, GrandparentSplitter
from fastai.vision.augment import Resize, aug_transforms
from fastai.data.external import untar_data, URLs
from fastai.data.transforms import Normalize
from fastai.vision.core import imagenet_stats

In [37]:
# loss function handed to the Learner in both training entry points
LOSS_FUNC = nn.CrossEntropyLoss()

In [38]:
# optimizer factory handed to the Learner as `opt_func` in both training entry points
OPT_FUNC = Adam

In [39]:
from fastai.data.transforms import RandomSplitter

In [40]:
from fastai.vision.learner import create_cnn_model
from fastai.vision.models import resnet34

In [41]:
import os
# Define Parameters
# FLAGS is passed to the spawned processes and read by train_model / train_torch_model
FLAGS = {}
# FLAGS['batch_size'] = 1024
FLAGS['sync_valid'] = True        # shard and sync validation across ranks
FLAGS['is_profiling'] = False     # time the build/fit stages with the profiling helpers
FLAGS['batch_size'] = 64          # batch size used when building the dataloaders
FLAGS['num_workers'] = 4          # dataloader workers (torch pipeline only)
FLAGS['learning_rate'] = 5e-3     # multiplied by the XLA world size in the train functions
FLAGS['image_size'] = 224         # side length used by the resize transforms
FLAGS['momentum'] = 0.85          # used for all three entries of `moms`
FLAGS['weight_decay'] = 2e-4
FLAGS['num_epochs'] = 20
# 8 cores when the TPU_NAME environment variable is set, otherwise a single core
FLAGS['num_cores'] = 8 if os.environ.get('TPU_NAME', None) else 1
# FLAGS['num_cores'] = 1 
ARCH = resnet34      # architecture passed to create_cnn_model
USE_DBLOCK = False   # True: fastai DataBlock pipeline (DATA); False: TorchDatasetBuilder pipeline (DSET_BUILDER)

In [42]:
from pathlib import Path
from fastcore.xtras import *


In [43]:
# Dataset root: the `images` sub-folder of the URLs.PETS archive.
PATH = untar_data(URLs.PETS)/'images'
# Alternative datasets, currently disabled:
# PATH = untar_data(URLs.MNIST)
# PATH = untar_data(URLs.MNIST_TINY)


In [44]:
if USE_DBLOCK:
    # Label = everything before the trailing "_<digits>.jpg". The dot is
    # escaped so it matches only a literal '.', not any character.
    pat = r'(.+)_\d+\.jpg$'
    fname_labeller = FileNamePatternLabeller(pat)
    # Fixed seed so the train/valid split is the same on every run.
    splitter = RandomSplitter(seed=42)
    DATA = DataBlock(
        blocks=(ImageBlock, CategoryBlock),
        get_items=get_image_files,
        get_y=fname_labeller,
        splitter=splitter,
        item_tfms=[Resize(FLAGS['image_size'])],
        batch_tfms=[],
    )
    # Build the label vocabulary from the files under PATH so the number of
    # model outputs is known before the model is created.
    vocab = CategoryMap(get_image_files(PATH).map(fname_labeller))
    N_OUT = len(vocab)


In [45]:
if not USE_DBLOCK:
    # Per-channel normalisation transforms. Only imagenet_norm is used in
    # this cell; cifar_norm is kept so the name stays defined.
    imagenet_norm = thv.transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225))

    cifar_norm = thv.transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465),
        std=(0.2023, 0.1994, 0.2010))

    image_size = FLAGS['image_size']
    # Fixed seed so the train/valid split is the same on every run.
    splitter = RandomSplitter(seed=42)
    # Label = everything before the trailing "_<digits>.jpg". The dot is
    # escaped so it matches only a literal '.', not any character.
    pat = r'(.+)_\d+\.jpg$'
    fname_labeller = FileNamePatternLabeller(pat)

    DSET_BUILDER = TorchDatasetBuilder(
        PATH,
        get_items=get_image_files,
        splitter=splitter,
        x_tfms=[
            thv.transforms.Resize((image_size, image_size)),
            thv.transforms.ToTensor(),
            imagenet_norm,
        ],
        # Order matters: N_OUT below reads `.c` from the transform at index 1.
        y_tfms=[fname_labeller, VocabularyMapper()],
        x_type_tfms=PILImage.create,
    )
    # Time the one-off setup over the files under PATH, print the timing,
    # then clear the profiling data.
    start_record('master_vocab_setup')
    DSET_BUILDER.setup(get_image_files(PATH), do_setup=True)
    end_record('master_vocab_setup')
    print_prof_data('master_vocab_setup')
    clear_prof_data()
    # Number of model outputs, read from the VocabularyMapper after setup.
    N_OUT = DSET_BUILDER.y_tfms[1].c

Function master_vocab_setup called 1 times.
Execution time max: 0.131, average: 0.131


In [46]:
# Fail here if N_OUT is None or not positive, before the model is built from it.
assert N_OUT is not None and N_OUT > 0,f'N_OUT {N_OUT} should be > 0'

In [47]:

# Build the CNN from ARCH with N_OUT outputs, requesting pretrained weights
# and disabling concat pooling.
custom_model = create_cnn_model(ARCH, N_OUT, 
                                pretrained=True,
                                concat_pool=False)


In [48]:
# Only instantiate model weights once in memory.
# NOTE(review): WRAPPED_MODEL is not referenced in the cells visible here;
# presumably the training functions use it - confirm.
WRAPPED_MODEL = xmp.MpModelWrapper(custom_model)

In [49]:
# torch_xla multiprocessing serial executor.
# NOTE(review): not referenced in the cells visible here; presumably the
# training functions use it - confirm.
SERIAL_EXEC = xmp.MpSerialExecutor()

In [50]:
%%time
# !rm -f /content/models/stage-1.pth
# Launch FLAGS['num_cores'] forked processes. The entry point is _mp_fn on the
# DataBlock path and _mp_fn2 otherwise.
xmp.spawn(
    _mp_fn if USE_DBLOCK else _mp_fn2,
    args=(FLAGS,),
    nprocs=FLAGS['num_cores'],
    start_method='fork',
)


build learner
start running fit
start fit


epoch,train_loss,valid_loss,accuracy,time
0,2.415562,1.13098,0.697266,00:19
1,1.404932,2.204097,0.494141,00:37
2,1.089429,4.24343,0.316406,00:37
3,1.030001,14.848968,0.139648,00:37
4,1.033107,5.174943,0.225586,00:37
5,0.942981,3.685695,0.236328,00:37
6,0.843764,2.449688,0.486328,00:37
7,0.729406,1.739083,0.605469,00:37
8,0.615905,1.985155,0.603516,00:37
9,0.516536,1.263756,0.707031,00:37


fit  called 1 times. max: 733.510 avg: 733.510
   epoch  called 20 times. max: 37.731 avg: 35.775
      train  called 20 times. max: 31.429 avg: 30.357
         train_batch  called 220 times. max: 12.297 avg: 1.382
            train_pred  called 220 times. max: 12.121 avg: 1.030
            train_loss  called 220 times. max: 0.001 avg: 0.000
            train_backward  called 220 times. max: 0.024 avg: 0.005
            train_step  called 220 times. max: 12.137 avg: 0.343
            train_zero_grad  called 220 times. max: 0.005 avg: 0.003
      valid  called 20 times. max: 6.395 avg: 5.417
         valid_batch  called 40 times. max: 0.011 avg: 0.005
            valid_pred  called 40 times. max: 0.010 avg: 0.004
            valid_loss  called 40 times. max: 0.001 avg: 0.000
CPU times: user 204 ms, sys: 237 ms, total: 440 ms
Wall time: 13min 4s


In [51]:
# DATA only exists on the DataBlock path.
if USE_DBLOCK:
    DATA.summary(PATH)

In [52]:
# Dataloaders for the main kernel: a single process, so rank 0 of world size 1.
if not USE_DBLOCK:
    mdsets = DSET_BUILDER.get_datasets()
    mdls = make_torch_dataloaders(
        *mdsets,
        rank=0,
        world_size=1,
        bs=FLAGS['batch_size'],
        num_workers=FLAGS['num_workers'],
    )
else:
    mdls = DATA.dataloaders(PATH, bs=FLAGS['batch_size'])

In [53]:
# Learner for the main kernel, built from the single-process dataloaders.
mlearner = Learner(
    mdls,
    custom_model,
    loss_func=LOSS_FUNC,
    opt_func=OPT_FUNC,
    metrics=accuracy,
    wd=FLAGS['weight_decay'],
    # The same momentum value is used for all three entries.
    moms=(FLAGS['momentum'],) * 3,
)
# Load the 'stage-1' checkpoint. This stays the last expression so the cell
# displays the returned value.
mlearner.load('stage-1')

<fastai.learner.Learner at 0x7f71605a9668>

In [54]:
# Inspect the device attribute of the learner's dataloaders.
mlearner.dls.device

In [55]:
from fastai.torch_core import one_param

In [56]:
# Inspect which device one of the model's parameters is on.
one_param(mlearner.model).device

device(type='cpu')

In [63]:
%%time
# Run validation in the main kernel and print the returned values.
valid_metrics = mlearner.validate()
print(valid_metrics)

[0.6023236513137817, 0.8620923757553101]
CPU times: user 7min 52s, sys: 6.23 s, total: 7min 58s
Wall time: 25.7 s


In [58]:
# master_device = xm.xla_device()

In [59]:
# mlearner.dls.device = master_device
# mlearner.model.to(master_device)
# mlearner.opt = None
# mlearner.create_opt()

In [60]:
# %%time
# valid_metrics = mlearner.validate(); valid_metrics

In [61]:
# mlearner.dls.device

In [62]:
# one_param(mlearner.model).device