In [None]:
#hide
#colab
# attach gdrive holding repo
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
#default_exp multi_core.learner

# Multi Core XLA Learner extensions

<a href="https://colab.research.google.com/github/butchland/fastai_xla_extensions/blob/master/nbs/03b_multi_core.learner.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

> Patches `Learner` with `xla_`-prefixed methods (e.g. `xla_fit`, `xla_fit_one_cycle`) that run `fit` and related operations on multiple TPU cores.

> These provide an alternative way to run multi-core operations with minimal changes to existing fastai notebooks.

In [None]:
#hide
#colab
!pip install -Uqq cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp36-cp36m-linux_x86_64.whl

[K     |████████████████████████████████| 133.6MB 63kB/s 
[K     |████████████████████████████████| 61kB 3.4MB/s 
[?25h

In [None]:
#hide
#colab
# !pip install -Uqq git+https://github.com/fastai/fastai.git 
!pip install -Uqq fastai --upgrade

[K     |████████████████████████████████| 194kB 4.9MB/s 
[K     |████████████████████████████████| 61kB 5.6MB/s 
[?25h

In [None]:
#hide
#colab
!pip install -qqq nbdev

[?25l[K     |███████▏                        | 10kB 20.0MB/s eta 0:00:01[K     |██████████████▎                 | 20kB 15.3MB/s eta 0:00:01[K     |█████████████████████▍          | 30kB 9.1MB/s eta 0:00:01[K     |████████████████████████████▌   | 40kB 7.9MB/s eta 0:00:01[K     |████████████████████████████████| 51kB 3.2MB/s 
[?25h

In [None]:
#hide
#colab
!pip install -Uqq git+https://github.com/butchland/my_timesaver_utils.git

  Building wheel for my-timesaver-utils (setup.py) ... [?25l[?25hdone


In [None]:
#hide
#colab
!curl -s https://course19.fast.ai/setup/colab | bash

Updating fastai...
Done.


In [None]:
#hide
!pip freeze | grep torch
!pip freeze | grep fast

torch==1.7.0+cu101
torch-xla==1.7
torchsummary==1.5.1
torchtext==0.3.1
torchvision==0.8.1+cu101
fastai==2.2.5
fastcore==1.3.19
fastdtw==0.3.4
fastprogress==1.0.0
fastrlock==0.5


In [None]:
#hide
#colab
%cd /content
!ln -s /content/drive/MyDrive/fastai_xla_extensions fastai_xla_extensions

/content


In [None]:
#hide
# Start of kernel

In [None]:
#hide
#colab
%cd /content/fastai_xla_extensions

/content/drive/MyDrive/fastai_xla_extensions


In [None]:
#exporti
from fastai_xla_extensions.utils import xla_imported



In [None]:
#exporti
from fastai_xla_extensions.multi_core.base import *
from fastai_xla_extensions.misc_utils import *

In [None]:
#exporti
# import sys
# def xla_imported():
#     return 'torch_xla' in sys.modules

In [None]:
#exporti
try:
    import torch_xla
except ImportError:
    pass

In [None]:
#hide

if not xla_imported():
    # torch_xla is not available: replace the torch_xla modules used in this
    # notebook with fake, single-process equivalents so the cells can still run.
    from types import SimpleNamespace
    torch_xla = SimpleNamespace (
    )
    from typing import Union,BinaryIO
    import os
    import pickle
    import torch.cuda

    def fake_opt_step(opt,barrier=False):
        "Stand-in for `xm.optimizer_step`: step `opt`; `barrier` is ignored."
        opt.step()

    def fake_device(n=None, devkind=None):
        "Stand-in for `xm.xla_device`: current CUDA device if available, else CPU."
        gpu_available = torch.cuda.is_available()
        if gpu_available:
            return torch.device(torch.cuda.current_device())
        return torch.device('cpu')

    def fake_save(obj, f: Union[str, os.PathLike, BinaryIO],
                master_only=True, global_master=False):
        "Stand-in for `xm.save`: plain `torch.save`; the master flags are ignored."
        return torch.save(obj,f,pickle_module=pickle,
                        pickle_protocol=2,
                        _use_new_zipfile_serialization=True)

    def fake_rate():
        "Constant placeholder rate."
        return 230.20

    def fake_global_rate():
        "Constant placeholder global rate."
        return 830.10

    def fake_add(*args,**kwargs):
        "Accept and discard anything."
        pass

    def fake_RateTracker():
        "Stand-in for `xm.RateTracker`: namespace exposing `rate`, `global_rate` and `add`."
        return SimpleNamespace(
            rate = fake_rate,
            global_rate = fake_global_rate,
            add = fake_add
        )

    def fake_xrt_world_size():
        "Stand-in for `xm.xrt_world_size`: always a single process."
        return 1

    def fake_get_ordinal():
        "Stand-in for `xm.get_ordinal`: always rank 0."
        return 0

    def fake_mark_step():
        "Stand-in for `xm.mark_step`: no-op (`xla_run_method` calls `xm.mark_step()`)."
        pass

    xm = SimpleNamespace(
        optimizer_step = fake_opt_step,
        xla_device = fake_device,
        save = fake_save,
        RateTracker = fake_RateTracker,
        master_print = print,
        xrt_world_size = fake_xrt_world_size,
        get_ordinal = fake_get_ordinal,
        mark_step = fake_mark_step
    )

    def fake_metrics_report():
        "Placeholder metrics report text."
        return "Fake Metrics Report \n\n\n\n"
    met = SimpleNamespace (
        metrics_report = fake_metrics_report
    )

    class FakeParallelLoader:
        "Stand-in for `pl.ParallelLoader`: hands back the wrapped loader unchanged."
        def __init__(self, loader, *args):
            self.loader = loader
        def per_device_loader(self,device):
            return self.loader

    pl = SimpleNamespace(
        ParallelLoader = FakeParallelLoader
    )

    def fake_MpModelWrapper(o):
        "Stand-in for `xmp.MpModelWrapper`: return the model itself."
        return o

    def fake_run(f,*args, **kwargs):
        "Call `f` directly with the given arguments."
        return f(*args,**kwargs)

    def fake_MpSerialExecutor():
        "Stand-in for `xmp.MpSerialExecutor`: namespace whose `run` calls the function inline."
        return SimpleNamespace(
            run = fake_run
        )

    def fake_spawn(f, args=None, nprocs=0, start_method=None):
        "Stand-in for `xmp.spawn`: run `f` once in this process with index 0."
        # `args` defaults to None; unpacking None would raise TypeError,
        # so treat a missing `args` as an empty tuple.
        return f(0,*(args or ()))

    xmp = SimpleNamespace (
        MpModelWrapper = fake_MpModelWrapper,
        MpSerialExecutor = fake_MpSerialExecutor,
        spawn = fake_spawn
    )

    xu = SimpleNamespace (
    )


In [None]:
#exporti
if xla_imported():
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp
    

## Utility methods to implement XLA `fit` methods

In [None]:
#export

from fastai.callback.progress import ProgressCallback

from fastai.learner import Learner
def make_xla_child_learner(rank, sync_valid,learner_args, add_args, ctrl_args):
    """Create the `Learner` used inside one spawned process.

    rank: index of this process, forwarded to `build_distributed_dataloaders`
        and `to_multi_xla`.
    sync_valid: forwarded as `sync_valid=` to both of those calls.
    learner_args: dict of `Learner` keyword arguments that must also contain
        'base_dls' and 'wrapped_model'; it is NOT modified.
    add_args: extra entries merged into the new learner's `__stored_args__`.
    ctrl_args: dict with a 'use_progress' flag; when falsy, any callback named
        'progress' is removed from the new learner.

    Returns the configured learner.
    """
    # Pop from a shallow copy so the caller's dict keeps 'base_dls' and
    # 'wrapped_model' and can be reused.
    learner_args = dict(learner_args)
    device = xm.xla_device()
    world_size = xm.xrt_world_size()
    dls = build_distributed_dataloaders(learner_args.pop('base_dls'),
                                       rank, world_size, sync_valid=sync_valid)

    model = learner_args.pop('wrapped_model').to(device)

    # whatever is left in learner_args is passed straight to the Learner constructor
    learner = Learner(dls, model,**learner_args)
    learner.__stored_args__ = {**learner.__stored_args__, **add_args}
    learner.to_multi_xla(device, rank, sync_valid=sync_valid)
    if not ctrl_args['use_progress'] and 'progress' in L(learner.cbs).attrgot('name'):
        learner.remove_cbs(ProgressCallback)

    return learner

In [None]:
#export
def xla_run_method(rank, fit_method, learner_args, add_args, fit_args, ctrl_args):
    """Per-process entry point: build a child learner, run `fit_method` on it, save it.

    `fit_method` is called as `fit_method(learner, **fit_args)`; the result is
    saved under the name '_xla_tmp_model' and `xm.mark_step()` is called last.
    """
    # the child learner is always built with sync_valid=True
    child = make_xla_child_learner(rank, True, learner_args, add_args, ctrl_args)
    fit_method(child, **fit_args)
    child.save('_xla_tmp_model')
    xm.mark_step()

In [None]:
#export
from fastcore.basics import defaults, patch_to, patch

# keys moved out of the Learner kwargs into the separate `add_args` dict
_extra_args = ['concat_pool', 'arch', 'n_out', 'pretrained','normalize']

@patch
def pack_learner_args(self:Learner):
    """Collect what a spawned process needs to rebuild this learner.

    Returns `(learner_args, add_args)`: `learner_args` is a copy of
    `__stored_args__` plus 'wrapped_model', 'base_dls' and 'cbs';
    `add_args` holds any `_extra_args` keys taken out of it.
    """
    packed = dict(self.__stored_args__)
    packed['wrapped_model'] = xmp.MpModelWrapper(self.model)
    packed['base_dls'] = self.dls

    # ensure ProgressCallback is listed in the defaults before comparing names
    if ProgressCallback not in defaults.callbacks:
        defaults.callbacks.append(ProgressCallback)

    # keep only callbacks whose `name` is not among the `name` attributes
    # of the entries in `defaults.callbacks`
    default_names = L(defaults.callbacks).attrgot('name')
    kept_cbs = []
    for cb in self.cbs:
        if cb.name not in default_names:
            kept_cbs.append(cb)
    packed['cbs'] = kept_cbs

    extras = {key: packed.pop(key) for key in _extra_args if key in packed}
    return packed, extras

In [None]:
#export
@patch
def reload_child_model(self:Learner):
    """Reload the model saved by the spawned processes, then delete the temp file.

    Looks for `<path>/<model_dir>/_xla_tmp_model.pth`; does nothing if it is
    missing. Otherwise loads the weights (without optimizer state), removes
    the file and recreates the optimizer.
    """
    # approach borrowed from fastai LRFinder after_fit
    tmp_f = self.path/self.model_dir/'_xla_tmp_model.pth'
    if not tmp_f.exists():
        return
    # The optimizer may not exist yet if this learner has never been fit;
    # only clear gradients when there is one.
    if getattr(self, 'opt', None) is not None:
        self.opt.zero_grad()
    self.load('_xla_tmp_model', with_opt=False)
    # tmp_f is a path object (it has .exists()), so delete through it rather
    # than depending on `os` being imported in this module.
    tmp_f.unlink()
    self.create_opt()

In [None]:
#export

from fastcore.foundation import L

@patch
def pre_xla_fit(self:Learner, ctrl_args=None):
    """Prepare the learner before spawning processes.

    Removes `ProgressCallback` if a callback named 'progress' is attached and
    records whether it did so in `ctrl_args['use_progress']`.

    ctrl_args: optional dict to update in place; a new dict is created when
        omitted.

    Returns the (updated) `ctrl_args` dict.
    """
    # A `{}` default would be shared and mutated across calls and learners,
    # so create a fresh dict per call instead.
    if ctrl_args is None:
        ctrl_args = {}
    progress_removed = False
    if 'progress' in L(self.cbs).attrgot('name'):
        self.remove_cbs(ProgressCallback)
        progress_removed = True
    ctrl_args['use_progress'] = progress_removed
    return ctrl_args

@patch
def post_xla_fit(self:Learner, ctrl_args):
    "Restore the learner after the spawned processes finish: re-add `ProgressCallback` when `ctrl_args['use_progress']` is set."
    progress_was_removed = ctrl_args['use_progress']
    if not progress_was_removed:
        return
    self.add_cbs(ProgressCallback)

## XLA fit methods

In [None]:
#export

from fastcore.meta import delegates

@patch
@delegates(Learner.fit, but='num_cores,start_method')
def xla_fit(self:Learner, n_epoch, num_cores=8, start_method='fork', **kwargs):
    """Call `Learner.fit` in spawned processes (multi-core TPU environment).

    n_epoch: number of epochs, forwarded to `fit` as `n_epoch`.
    num_cores: passed to `xmp.spawn` as `nprocs`.
    start_method: passed to `xmp.spawn` as `start_method`.
    kwargs: remaining keyword arguments forwarded to `fit`.

    After the spawned run, the saved child model is reloaded into this learner.
    """
    ctrl_args = self.pre_xla_fit()
    try:
        learner_args, add_args = self.pack_learner_args()
        fit_args = {**kwargs, 'n_epoch': n_epoch}
        xmp.spawn(xla_run_method,
                  args=(Learner.fit, learner_args, add_args, fit_args, ctrl_args),
                  nprocs=num_cores,
                  start_method=start_method)
        self.reload_child_model()
    finally:
        # always undo pre_xla_fit (re-attach ProgressCallback), even if
        # packing, spawning or reloading raised
        self.post_xla_fit(ctrl_args)

In [None]:
#export
from fastai.learner import Learner
from fastai.callback.schedule import *
@patch
@delegates(Learner.fit_one_cycle, but='num_cores,start_method')
def xla_fit_one_cycle(self:Learner, n_epoch, num_cores=8, start_method='fork', **kwargs):
    """Call `Learner.fit_one_cycle` in spawned processes (multi-core TPU environment).

    n_epoch: number of epochs, forwarded to `fit_one_cycle` as `n_epoch`.
    num_cores: passed to `xmp.spawn` as `nprocs`.
    start_method: passed to `xmp.spawn` as `start_method`.
    kwargs: remaining keyword arguments forwarded to `fit_one_cycle`.

    After the spawned run, the saved child model is reloaded into this learner.
    """
    ctrl_args = self.pre_xla_fit()
    try:
        learner_args, add_args = self.pack_learner_args()
        fit_args = {**kwargs, 'n_epoch': n_epoch}
        xmp.spawn(xla_run_method,
                  args=(Learner.fit_one_cycle, learner_args, add_args, fit_args, ctrl_args),
                  nprocs=num_cores,
                  start_method=start_method)
        self.reload_child_model()
    finally:
        # always undo pre_xla_fit (re-attach ProgressCallback), even if
        # packing, spawning or reloading raised
        self.post_xla_fit(ctrl_args)

## Example: Train MNIST

In [None]:
#hide
#colab
from fastai.vision.all import *

In [None]:
#hide
#colab
%cd /content

/content


In [None]:
#colab
path = untar_data(URLs.MNIST_TINY)

In [None]:
#colab
data = DataBlock(
    blocks=(ImageBlock,CategoryBlock),
    get_items=get_image_files,
    get_y=parent_label,
    splitter=GrandparentSplitter(),
    item_tfms=Resize(28),
    batch_tfms=[]
)

In [None]:
#colab
dls = data.dataloaders(path, bs=16)

In [None]:
#colab
# concat_pool must be false due to a TPU bug that is triggered if using fastai AdaptivePool
from fastai.vision.learner import cnn_learner
from torchvision.models.resnet import resnet18
learner = cnn_learner(dls, resnet18, metrics=accuracy, concat_pool=False)

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth


HBox(children=(FloatProgress(value=0.0, max=46827520.0), HTML(value='')))




In [None]:
#colab
#hide
assert hasattr(learner,'xla_fit')

In [None]:
#colab
%%time
learner.xla_fit_one_cycle(5,lr_max=slice(2e-3))


start fit


epoch,train_loss,valid_loss,accuracy,time
0,0.213612,0.471034,0.819602,00:20
1,0.176937,0.35657,0.84233,00:05
2,0.185203,0.30291,0.879261,00:05
3,0.200545,0.315121,0.87358,00:05
4,0.215351,0.336673,0.857955,00:05


CPU times: user 116 ms, sys: 166 ms, total: 282 ms
Wall time: 1min 2s


In [None]:
#colab
res = learner.get_preds()
print(accuracy(*res))

TensorBase(0.8627)


In [None]:
#hide
#colab
learner.summary()

Sequential (Input shape: 16)
Layer (type)         Output Shape         Param #    Trainable 
                     16 x 64 x 14 x 14   
Conv2d                                    9408       False     
BatchNorm2d                               128        True      
ReLU                                                           
MaxPool2d                                                      
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      False     
BatchNorm2d                      

In [None]:
#colab
learner.unfreeze()

In [None]:
#hide
#colab
learner.summary()

Sequential (Input shape: 16)
Layer (type)         Output Shape         Param #    Trainable 
                     16 x 64 x 14 x 14   
Conv2d                                    9408       True      
BatchNorm2d                               128        True      
ReLU                                                           
MaxPool2d                                                      
Conv2d                                    36864      True      
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      True      
BatchNorm2d                               128        True      
Conv2d                                    36864      True      
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      True      
BatchNorm2d                      

In [None]:
#colab
learner.xla_fit(n_epoch=5, lr=2e-3)

start fit


epoch,train_loss,valid_loss,accuracy,time
0,0.121027,11.478638,0.504261,00:20
1,0.075965,2.694753,0.772727,00:05
2,0.085595,0.126528,0.981534,00:05
3,0.084676,0.206319,0.934659,00:05
4,0.082818,0.065421,0.984375,00:05


In [None]:
#colab
learner.validate()

(#2) [0.0529780350625515,0.9856938719749451]