<a href="https://colab.research.google.com/github/butchland/fastai_xla_extensions/blob/master/nbs/03b_multi_core.learner.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#default_exp multi_core.learner

# Multi-Core XLA Learner Extensions

## Set up PyTorch XLA


This is the official way to install PyTorch-XLA 1.7 (see the [instructions here](https://colab.research.google.com/github/pytorch/xla/blob/master/contrib/colab/getting-started.ipynb#scrollTo=CHzziBW5AoZH)).

In [None]:
#colab
!pip install -Uqq cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp36-cp36m-linux_x86_64.whl

[K     |████████████████████████████████| 133.6MB 78kB/s 
[K     |████████████████████████████████| 61kB 3.6MB/s 
[?25h

## Install fastai

Use the latest fastai and fastcore versions.

In [None]:
#colab
# !pip install -Uqq git+https://github.com/fastai/fastai.git 
!pip install -Uqq fastai --upgrade

[K     |████████████████████████████████| 194kB 5.8MB/s 
[K     |████████████████████████████████| 61kB 5.2MB/s 
[?25h

In [None]:
#hide
#colab
!pip install -Uqq git+https://github.com/butchland/my_timesaver_utils.git

  Building wheel for my-timesaver-utils (setup.py) ... [?25l[?25hdone


In [None]:
#hide
#colab
!curl -s https://course19.fast.ai/setup/colab | bash

Updating fastai...
Done.


In [None]:
#hide
!pip freeze | grep torch
!pip freeze | grep fast

torch==1.7.0+cu101
torch-xla==1.7
torchsummary==1.5.1
torchtext==0.3.1
torchvision==0.8.1+cu101
fastai==2.2.5
fastcore==1.3.19
fastdtw==0.3.4
fastprogress==1.0.0
fastrlock==0.5


In [None]:
#hide
#colab
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
#hide
#colab
%cd /content
!ln -s /content/drive/MyDrive/fastai_xla_extensions fastai_xla_extensions

In [None]:
%cd /content/fastai_xla_extensions/fastai_xla_extensions

Start of kernel

In [None]:
#exporti
from fastai_xla_extensions.utils import xla_imported

In [None]:
#exporti
# import sys
# def xla_imported():
#     return 'torch_xla' in sys.modules

In [None]:
#exporti
try:
    import torch_xla
except ImportError:
    pass



In [None]:
#hide
#local
if not xla_imported():
    # Replace the torch_xla modules with minimal fake equivalents so the rest of
    # this notebook can run in a single process (CPU or GPU) without torch_xla.
    from types import SimpleNamespace
    torch_xla = SimpleNamespace (
    )
    from typing import Union,BinaryIO
    import os
    import pickle
    import torch.cuda

    def fake_opt_step(opt,barrier=False):
        "Stand-in for `xm.optimizer_step`: just step the optimizer (`barrier` is ignored)."
        opt.step()

    def fake_mark_step():
        "Stand-in for `xm.mark_step`: there is no pending XLA graph to execute."
        pass

    def fake_device(n=None, devkind=None):
        "Stand-in for `xm.xla_device`: the current CUDA device if available, else the CPU."
        if torch.cuda.is_available():
            return torch.device(torch.cuda.current_device())
        return torch.device('cpu')

    def fake_save(obj, f: Union[str, os.PathLike, BinaryIO],
                master_only=True, global_master=False):
        "Stand-in for `xm.save`: a plain `torch.save` (`master_only`/`global_master` are ignored)."
        return torch.save(obj,f,pickle_module=pickle,
                        pickle_protocol=2,
                        _use_new_zipfile_serialization=True)

    # Fixed dummy throughput numbers reported by the fake RateTracker
    def fake_rate():
        return 230.20

    def fake_global_rate():
        return 830.10

    def fake_add(*args,**kwargs):
        pass

    def fake_RateTracker(*args, **kwargs):
        "Stand-in for `xm.RateTracker`: reports fixed dummy rates and ignores `add` calls."
        return SimpleNamespace(
            rate = fake_rate,
            global_rate = fake_global_rate,
            add = fake_add
        )

    # A single process is the whole "world" and is its own master (ordinal 0)
    def fake_xrt_world_size():
        return 1
    def fake_get_ordinal():
        return 0

    xm = SimpleNamespace(
        optimizer_step = fake_opt_step,
        mark_step = fake_mark_step,  # called after training in the spawned fit functions
        xla_device = fake_device,
        save = fake_save,
        RateTracker = fake_RateTracker,
        master_print = print,
        xrt_world_size = fake_xrt_world_size,
        get_ordinal = fake_get_ordinal
    )

    def fake_metrics_report():
        "Stand-in for `met.metrics_report`: returns a placeholder report string."
        return "Fake Metrics Report \n\n\n\n"
    met = SimpleNamespace (
        metrics_report = fake_metrics_report
    )

    class FakeParallelLoader:
        "Stand-in for `pl.ParallelLoader`: hands back the wrapped loader unchanged."
        def __init__(self, loader, *args, **kwargs):
            self.loader = loader
        def per_device_loader(self,device):
            return self.loader

    pl = SimpleNamespace(
        ParallelLoader = FakeParallelLoader
    )

    def fake_MpModelWrapper(o):
        "Stand-in for `xmp.MpModelWrapper`: returns the model itself."
        return o

    def fake_run(f,*args, **kwargs):
        return f(*args,**kwargs)

    def fake_MpSerialExecutor():
        "Stand-in for `xmp.MpSerialExecutor`: `run(f, ...)` simply calls `f(...)`."
        return SimpleNamespace(
            run = fake_run
        )
    def fake_spawn(f, args=None, nprocs=0, start_method=None):
        "Stand-in for `xmp.spawn`: call `f` once, in-process, as rank 0."
        if args is None:
            args = ()  # allow spawn(f) with no extra args instead of failing on `*None`
        return f(0,*args)

    xmp = SimpleNamespace (
        MpModelWrapper = fake_MpModelWrapper,
        MpSerialExecutor = fake_MpSerialExecutor,
        spawn = fake_spawn
    )

    xu = SimpleNamespace (
    )


## Expose XLA fit methods on the Learner to simplify usage

In [None]:
#exporti
from fastai_xla_extensions.multi_core import *
from fastai_xla_extensions.misc_utils import *

In [None]:
#export
def _make_xla_child_learner(rank, sync_valid,learner_args):
    """Build the `Learner` for spawned process `rank` from a `pack_learner_args` dict.

    `learner_args` must contain `base_dls` and `wrapped_model` in addition to the
    regular `Learner` keyword arguments; `sync_valid` is forwarded to both
    `make_distributed_dataloaders` and `Learner.to_xla`.
    """
    # Work on a copy so the caller's dict is not consumed by the pops below
    # (matters when the spawner runs `f` in-process instead of in forked children).
    learner_args = dict(learner_args)
    device = xm.xla_device()
    world_size = xm.xrt_world_size()
    # Rebuild the dataloaders for this rank / world size
    dls = make_distributed_dataloaders(learner_args.pop('base_dls'),
                                       rank, world_size, sync_valid=sync_valid)

    model = learner_args.pop('wrapped_model').to(device)

    learner = Learner(dls, model,**learner_args)
    learner.to_xla(device, rank, sync_valid=sync_valid)
    return learner
    


In [None]:
#export
def _xla_run_fit(rank, learner_args, fit_args):
    """Per-process target for `xmp.spawn`: build the child learner, run `fit`, save the weights.

    The weights go to `_xla_tmp_model` so the parent can pick them up with
    `Learner.reload_child_model`.
    """
    child = _make_xla_child_learner(rank, sync_valid=True, learner_args=learner_args)
    child.fit(**fit_args)
    child.save('_xla_tmp_model')
    xm.mark_step()

In [None]:
#export
def _xla_run_fit_one_cycle(rank, learner_args, fit_args):
    """Per-process target for `xmp.spawn`: build the child learner, run `fit_one_cycle`, save the weights.

    The weights go to `_xla_tmp_model` so the parent can pick them up with
    `Learner.reload_child_model`.
    """
    child = _make_xla_child_learner(rank, sync_valid=True, learner_args=learner_args)
    child.fit_one_cycle(**fit_args)
    child.save('_xla_tmp_model')
    xm.mark_step()

In [None]:
#export
from fastcore.basics import defaults, patch_to, patch
from fastai.learner import Learner
from fastai.callback.progress import ProgressCallback
@patch_to(Learner)
def pack_learner_args(self):
    """Collect what is needed to rebuild this learner inside each spawned XLA process.

    Returns a dict of `Learner` keyword arguments plus two extra keys,
    `wrapped_model` (this model wrapped with `xmp.MpModelWrapper`) and
    `base_dls` (this learner's dataloaders), which `_make_xla_child_learner`
    pops off before constructing the child `Learner`.
    """
    learner_args = {}
    learner_args['wrapped_model'] =  xmp.MpModelWrapper(self.model)
    learner_args['base_dls'] = self.dls
    learner_args['opt_func'] = self.opt_func
    learner_args['loss_func'] = self.loss_func
    learner_args['metrics'] = self.metrics
    # Make ProgressCallback part of the defaults. Child learners then create their
    # own, and it is not forwarded below. Note: this mutates fastai's global
    # `defaults.callbacks`.
    if ProgressCallback not in defaults.callbacks:
        defaults.callbacks.append(ProgressCallback)
    # Forward only the callbacks the child `Learner` will not create from the defaults.
    # `defaults.callbacks` holds callback *classes*, and fastai's `Callback.name` is an
    # instance property, so reading `name` off the classes does not give name strings.
    # Compare by exact type instead.
    default_cb_types = tuple(defaults.callbacks)
    learner_args['cbs'] = [cb for cb in self.cbs if type(cb) not in default_cb_types]

    learner_args['wd'] = self.wd
    learner_args['moms'] = self.moms
    learner_args['lr'] = self.lr
    learner_args['splitter'] = self.splitter
    learner_args['path'] = self.path
    learner_args['model_dir'] = self.model_dir
    learner_args['wd_bn_bias'] = self.wd_bn_bias
    learner_args['train_bn'] = self.train_bn
    return learner_args

In [None]:
#export
@patch_to(Learner)
def reload_child_model(self):
    """Load the weights saved by the spawned processes, if any, then delete the temp file.

    Does nothing if `<path>/<model_dir>/_xla_tmp_model.pth` does not exist.
    Otherwise the optimizer is recreated after loading.
    """
    # adapted from fastai's `LRFinder.after_fit`
    tmp_f = self.path/self.model_dir/'_xla_tmp_model.pth'
    if not tmp_f.exists():
        return
    # `opt` is still None if this learner has never created an optimizer
    if self.opt is not None:
        self.opt.zero_grad()
    self.load('_xla_tmp_model', with_opt=False)
    tmp_f.unlink()  # Path method: does not rely on `os` being in scope
    self.create_opt()

In [None]:
#export

from fastcore.meta import delegates
@patch
@delegates(Learner.fit, but=['num_cores'])
def xla_fit(self:Learner, n_epoch, num_cores=8, **kwargs):
    """call fit in multicore tpu environment

    Spawns `num_cores` forked processes that each rebuild this learner and run
    `fit(n_epoch, **kwargs)`, then loads the resulting weights back into `self`.
    """
    # Detach the parent's progress bar while the child processes train; restore it
    # afterwards (even on failure), but only if the learner had one to begin with.
    had_progress = any(isinstance(cb, ProgressCallback) for cb in self.cbs)
    self.remove_cbs(ProgressCallback)
    try:
        learner_args = self.pack_learner_args()
        fit_args = {**kwargs, 'n_epoch': n_epoch}
        xmp.spawn(_xla_run_fit,
                  args=(learner_args, fit_args,),
                  nprocs=num_cores,
                  start_method='fork')
        self.reload_child_model()
    finally:
        if had_progress:
            self.add_cbs(ProgressCallback)


In [None]:
#export
from fastai.learner import Learner
from fastai.callback.schedule import *
@patch
@delegates(Learner.fit_one_cycle, but=['num_cores'])
def xla_fit_one_cycle(self:Learner, n_epoch, num_cores=8, **kwargs):
    """call fit_one_cycle in multicore tpu environment

    Spawns `num_cores` forked processes that each rebuild this learner and run
    `fit_one_cycle(n_epoch, **kwargs)`, then loads the resulting weights back into `self`.
    """
    # Detach the parent's progress bar while the child processes train; restore it
    # afterwards (even on failure), but only if the learner had one to begin with.
    had_progress = any(isinstance(cb, ProgressCallback) for cb in self.cbs)
    self.remove_cbs(ProgressCallback)
    try:
        learner_args = self.pack_learner_args()
        fit_args = {**kwargs, 'n_epoch': n_epoch}
        xmp.spawn(_xla_run_fit_one_cycle,
                  args=(learner_args, fit_args,),
                  nprocs=num_cores,
                  start_method='fork')
        self.reload_child_model()
    finally:
        if had_progress:
            self.add_cbs(ProgressCallback)

In [None]:
#colab
path = untar_data(URLs.MNIST_TINY)

In [None]:
#colab
# Image -> category pipeline: labels come from each image's parent folder, and
# the train/valid split comes from its grandparent folder.
data = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    splitter=GrandparentSplitter(),
    get_y=parent_label,
    item_tfms=Resize(28),
    batch_tfms=[],
)

In [None]:
#colab
dls = data.dataloaders(path, bs=16)

In [None]:
#colab
# concat_pool must be False: fastai's AdaptiveConcatPool2d (used when concat_pool=True) triggers a TPU bug
from fastai.vision.learner import cnn_learner
from torchvision.models.resnet import resnet18
learner = cnn_learner(dls, resnet18, metrics=accuracy, concat_pool=False)

In [None]:
#colab
#hide
# both patched entry points must be available on the learner
assert hasattr(learner,'xla_fit')
assert hasattr(learner,'xla_fit_one_cycle')

In [None]:
#colab
%%time
learner.xla_fit_one_cycle(5,lr_max=slice(2e-3))


start fit


epoch,train_loss,valid_loss,accuracy,time


epoch,train_loss,valid_loss,accuracy,time
0,0.204781,0.384089,0.927557,00:14
1,0.184112,0.361853,0.828125,00:05
2,0.212741,0.367217,0.846591,00:05
3,0.243097,0.301363,0.879261,00:05
4,0.253019,0.289174,0.901989,00:05


CPU times: user 164 ms, sys: 229 ms, total: 392 ms
Wall time: 59.6 s


In [None]:
#colab
# accuracy of the reloaded model on the validation set
preds, targs = learner.get_preds()
print(accuracy(preds, targs))

TensorBase(0.8984)


In [None]:
#colab
#hide
learner.summary()

Sequential (Input shape: 16)
Layer (type)         Output Shape         Param #    Trainable 
                     16 x 64 x 14 x 14   
Conv2d                                    9408       False     
BatchNorm2d                               128        True      
ReLU                                                           
MaxPool2d                                                      
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      False     
BatchNorm2d                      

In [None]:
#colab
learner.unfreeze()

In [None]:
#colab
#hide
learner.summary()

Sequential (Input shape: 16)
Layer (type)         Output Shape         Param #    Trainable 
                     16 x 64 x 14 x 14   
Conv2d                                    9408       True      
BatchNorm2d                               128        True      
ReLU                                                           
MaxPool2d                                                      
Conv2d                                    36864      True      
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      True      
BatchNorm2d                               128        True      
Conv2d                                    36864      True      
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      True      
BatchNorm2d                      

In [None]:
#colab
learner.xla_fit(n_epoch=5, lr=2e-3)

start fit


epoch,train_loss,valid_loss,accuracy,time


epoch,train_loss,valid_loss,accuracy,time
0,0.073536,0.028147,0.990057,00:22
1,0.082206,0.767614,0.946023,00:05
2,0.073394,1.637862,0.84375,00:05
3,0.071335,0.842607,0.901989,00:05
4,0.075052,0.100458,0.974432,00:05


In [None]:
#colab
learner.validate()

(#2) [0.04073491320014,0.9871244430541992]