In [None]:
#hide
#colab
# attach gdrive holding repo
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
#default_exp multi_core.callback    

# Multi Core Callback XLA Extensions

> Patches to Recorder and ParamScheduler Callbacks
to support Multi Core XLA Training

Modifications to the existing callbacks `Recorder`, `ParamScheduler` and `LRFinder` are needed in order to store extra attributes in a temporary file after the training has run in spawned processes.  

In [None]:
#hide
#colab
# !pip install -Uqq git+https://github.com/fastai/fastai.git 
!pip install -Uqq fastai --upgrade

[K     |████████████████████████████████| 194kB 5.0MB/s 
[K     |████████████████████████████████| 61kB 4.7MB/s 
[?25h

In [None]:
#hide
#colab
!pip install -qqq nbdev

[K     |████████████████████████████████| 51kB 3.0MB/s 
[K     |████████████████████████████████| 51kB 4.1MB/s 
[?25h

In [None]:
#hide
#colab
!pip install -Uqq git+https://github.com/butchland/fastai_xla_extensions.git

  Building wheel for fastai-xla-extensions (setup.py) ... [?25l[?25hdone


In [None]:
#hide
#colab
!pip install -Uqq git+https://github.com/butchland/my_timesaver_utils.git

  Building wheel for my-timesaver-utils (setup.py) ... [?25l[?25hdone


In [None]:
#hide
#colab
!curl -s https://course19.fast.ai/setup/colab | bash

Updating fastai...
Done.


In [None]:
#hide
#colab
# !pip install -Uqq cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp36-cp36m-linux_x86_64.whl

In [None]:
VERSION = "nightly"  #@param ["1.5", "1.7" , "20200325", "nightly"]
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version $VERSION

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0100  5116  100  5116    0     0   108k      0 --:--:-- --:--:-- --:--:--  108k
Updating... This may take around 2 minutes.
Updating TPU runtime to pytorch-nightly ...
Collecting cloud-tpu-client
  Downloading https://files.pythonhosted.org/packages/56/9f/7b1958c2886db06feb5de5b2c191096f9e619914b6c31fdf93999fdbbd8b/cloud_tpu_client-0.10-py3-none-any.whl
Collecting google-api-python-client==1.8.0
[?25l  Downloading https://files.pythonhosted.org/packages/9a/b4/a955f393b838bc47cbb6ae4643b9d0f90333d3b4db4dc1e819f36aad18cc/google_api_python_client-1.8.0-py3-none-any.whl (57kB)
[K     |████████████████████████████████| 61kB 3.3MB/s 
Uninstalling torch-1.7.0+cu101:
Installing collected packages: google-api-python-client, cloud-tpu-client
  Found existing in

In [None]:
#hide
#colab
# %cd /content
# !ln -s /content/drive/MyDrive/fastai_xla_extensions  fastai_xla_extensions

In [None]:
#hide
!pip freeze | grep torch
!pip freeze | grep fast
!pip freeze | grep timesaver
!pip freeze | grep nbdev

torch==1.9.0a0+64847c7
torch-xla==1.9+7671584
torchsummary==1.5.1
torchtext==0.3.1
torchvision==0.9.0a0+b7f3c81
fastai==2.2.7
fastai-xla-extensions==0.0.8
fastcore==1.3.19
fastdtw==0.3.4
fastprogress==1.0.0
fastrelease==0.1.11
fastrlock==0.5
my-timesaver-utils==0.0.2
nbdev==1.1.13


In [None]:
# hide
# start of kernel

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#hide
#colab
# %cd /content/fastai_xla_extensions

In [None]:
import torch_xla

In [None]:
#exporti
from fastai_xla_extensions.utils import xla_imported
from fastai_xla_extensions.misc_utils import *
from fastai_xla_extensions.multi_core.base import *
from fastai_xla_extensions.multi_core.learner import *

In [None]:
import torch_xla

In [None]:
#hide
#colab
# %cd /content

In [None]:
#exporti
try:
    import torch_xla
except Exception:
    # best-effort import: environments without a working torch_xla install
    # carry on without it, but KeyboardInterrupt/SystemExit still propagate
    pass

In [None]:
#exporti
if xla_imported():
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp

In [None]:
#hide
#local
# fake out torch_xla modules if not running on xla supported envs
if not xla_imported():
    # replace torch xla modules with fake equivalents so the rest of the
    # notebook can run (single process) on environments without XLA
    from types import SimpleNamespace
    torch_xla = SimpleNamespace (
    )
    from typing import Union,BinaryIO
    import os
    import pickle
    import torch.cuda

    def fake_opt_step(opt,barrier=False):
        'stand-in for `xm.optimizer_step`: just step the optimizer'
        opt.step()

    def fake_device(n=None, devkind=None):
        'stand-in for `xm.xla_device`: current cuda device if available, else cpu'
        gpu_available = torch.cuda.is_available()
        if gpu_available:
            return torch.device(torch.cuda.current_device())
        return torch.device('cpu')

    def fake_save(obj, f: Union[str, os.PathLike, BinaryIO],
                master_only=True, global_master=False):
        'stand-in for `xm.save`: plain `torch.save` (master flags are ignored)'
        return torch.save(obj,f,pickle_module=pickle,
                        pickle_protocol=2,
                        _use_new_zipfile_serialization=True)

    def fake_rate():
        return 230.20

    def fake_global_rate():
        return 830.10

    def fake_add(*args,**kwargs):
        pass

    def fake_RateTracker():
        'stand-in for `xm.RateTracker` returning fixed rates'
        return SimpleNamespace(
            rate = fake_rate,
            global_rate = fake_global_rate,
            add = fake_add
        )

    def fake_xrt_world_size():
        return 1

    def fake_get_ordinal():
        return 0

    def fake_is_master_ordinal(local=True):
        'stand-in for `xm.is_master_ordinal`: the only (fake) process is the master'
        return True

    def fake_rendezvous(tag, payload=b'', replicas=None):
        'stand-in for `xm.rendezvous`: nothing to wait for with a single process'
        return [payload]

    def fake_mark_step():
        'stand-in for `xm.mark_step`: no-op'
        pass

    xm = SimpleNamespace(
        optimizer_step = fake_opt_step,
        xla_device = fake_device,
        save = fake_save,
        RateTracker = fake_RateTracker,
        master_print = print,
        xrt_world_size = fake_xrt_world_size,
        get_ordinal = fake_get_ordinal,
        # the callbacks and lr_find runner in this notebook call these three
        is_master_ordinal = fake_is_master_ordinal,
        rendezvous = fake_rendezvous,
        mark_step = fake_mark_step
    )

    def fake_metrics_report():
        return "Fake Metrics Report \n\n\n\n"
    met = SimpleNamespace (
        metrics_report = fake_metrics_report
    )

    class FakeParallelLoader:
        'stand-in for `pl.ParallelLoader` that hands back the wrapped loader'
        def __init__(self, loader, *args):
            self.loader = loader
        def per_device_loader(self,device):
            return self.loader

    pl = SimpleNamespace(
        ParallelLoader = FakeParallelLoader
    )

    def fake_MpModelWrapper(o):
        return o

    def fake_run(f,*args, **kwargs):
        return f(*args,**kwargs)

    def fake_MpSerialExecutor():
        return SimpleNamespace(
            run = fake_run
        )

    def fake_spawn(f, args=None, nprocs=0, start_method=None):
        'stand-in for `xmp.spawn`: call `f` once, in-process, as rank 0'
        # `args` defaults to None, which cannot be unpacked with `*`
        return f(0,*(args or ()))

    xmp = SimpleNamespace (
        MpModelWrapper = fake_MpModelWrapper,
        MpSerialExecutor = fake_MpSerialExecutor,
        spawn = fake_spawn
    )

    xu = SimpleNamespace (
    )


In [None]:
#exporti
from fastai.vision.all import *
# from fastai_xla_extensions.all import *


In [None]:
#export
import torch
def maybe_item(o):
    '''Convert tensors to plain python values via `.item()`.
    Listy containers and dicts are rebuilt with every element/value
    converted recursively; anything else is returned untouched.'''
    if isinstance(o, torch.Tensor):
        return o.item()
    if is_listy(o):
        # rebuild the same container class from the converted elements
        container_cls = o.__class__
        return container_cls([maybe_item(elem) for elem in o])
    if isinstance(o, dict):
        return dict((key, maybe_item(val)) for key, val in o.items())
    # not a tensor or a container of tensors: hand back as is
    return o


In [None]:
# from fastcore.test import *
# t1 = torch.tensor(5.)
# test_eq(maybe_item(t1), 5.)
# test_eq(maybe_item(float(5)),5.)

Given a tensor, `maybe_item` converts it to a scalar. If the given value is not a tensor (e.g. it is already a scalar), it just returns that value.

In [None]:
# from fastcore.test import *
# tl1 = [tensor(2.)] * 5
# test_eq(maybe_item(tl1), [2.] * 5)
# dt1 = { 'd1': tensor(3.),
#         'd2': [tensor(1.)] * 3}
# df1 = { 'd1': 3.,
#         'd2': [1.] * 3}
# test_eq(maybe_item(dt1), df1)

`maybe_item` should also work for lists of tensors and for dicts of tensors
and/or lists of tensors.

In [None]:
#export
@patch
def get_extra_attrs(self:Recorder):
    'Collect the state attrs of Recorder that are currently set into a dict (suitable for pickling)'
    # `_stateattrs` lists the attribute names to export; names not set on
    # this instance are skipped, values are passed through `maybe_item`
    return {name: maybe_item(getattr(self, name))
            for name in self._stateattrs
            if hasattr(self, name)}


In [None]:
#hide_output
# from fastai.test_utils import *
# learner = synth_learner()
# learner.fit(5)


In [None]:
#hide
# setup checks
# assert hasattr(learner,'recorder')
# assert len(learner.recorder.lrs)  == 5 * 10
# assert len(learner.recorder.losses) == 5 * 10
# assert len(learner.recorder.iters) == 5
# assert len(learner.recorder.values) == 5

In [None]:
# extra_attrs = learner.recorder.get_extra_attrs()
# test_eq(extra_attrs['lrs'], learner.recorder.lrs)
# test_eq(extra_attrs['losses'], learner.recorder.losses)
# test_eq(extra_attrs['iters'], learner.recorder.iters)
# test_eq(extra_attrs['values'], learner.recorder.values)


`Recorder.get_extra_attrs` should copy the state attrs (`lrs`,`losses`,`iters` and `values`) into
a dict.

In [None]:
#export
import pickle
@patch
def dump_attrs(self:Recorder, fn='_rec_attr.pkl'):
    'Pickle the dict returned by `get_extra_attrs` into file `fn`'
    # build the dict first so the file is only opened once there is data to write
    state = self.get_extra_attrs()
    with open(fn, mode='wb') as out_file:
        pickle.dump(state, out_file)


In [None]:
#export
import pickle
@patch
def reload_attrs(self:Recorder, fn='_rec_attr.pkl'):
    'Set the attrs stored in pickle file `fn` on self; do nothing if `fn` is not a file'
    src = Path(fn) if isinstance(fn, str) else fn
    if src.is_file():
        with open(src, 'rb') as in_file:
            saved = pickle.load(in_file)
        # every key of the pickled dict becomes an attribute
        for name, value in saved.items():
            setattr(self, name, value)


In [None]:
# test_fn = 'test_rec_attrs.pkl'
# !rm -f {test_fn}
# learner.recorder.dump_attrs(fn=test_fn)
# f = Path(test_fn)
# assert f.is_file()


In [None]:
# delattr(learner.recorder,'lrs')
# delattr(learner.recorder,'losses')
# delattr(learner.recorder,'iters')
# delattr(learner.recorder,'values')
# assert not hasattr(learner.recorder,'lrs')
# assert not hasattr(learner.recorder,'losses')
# assert not hasattr(learner.recorder,'iters')
# assert not hasattr(learner.recorder,'values')


In [None]:

# learner.recorder.reload_attrs(fn=test_fn)
# assert hasattr(learner.recorder,'lrs')
# assert hasattr(learner.recorder,'losses')
# assert hasattr(learner.recorder,'iters')
# assert hasattr(learner.recorder,'values')
# !rm -f {test_fn}

In [None]:
#export
@patch
def after_fit(self: Recorder):
    'When fit ends, dump the extra attrs to file (master ordinal only)'
    if not xm.is_master_ordinal():
        return
    self.dump_attrs()


In [None]:
#export
@patch
def dump_hps(self:ParamScheduler, fn='_paramsched_hps.pkl'):
    'Pickle `self.hps` (passed through `maybe_item`) into file `fn`; do nothing if `hps` is not set'
    if hasattr(self, 'hps'):
        target = Path(fn) if isinstance(fn, str) else fn
        converted = maybe_item(self.hps)
        with open(target, 'wb') as out_file:
            pickle.dump(converted, out_file)


In [None]:
#export
@patch
def reload_hps(self:Recorder, fn='_paramsched_hps.pkl'):
    'Load the hyperparameters saved by ParamScheduler in `fn` into `self.hps`'
    src = Path(fn) if isinstance(fn, str) else fn
    if not src.is_file():
        # nothing was saved: leave the recorder untouched
        return
    with open(src, 'rb') as in_file:
        self.hps = pickle.load(in_file)

In [None]:
#export
@patch
def after_fit(self:ParamScheduler):
    "copy hps to the recorder (if the learner has one) and save hps to file"
    if hasattr(self, 'hps'):
        if hasattr(self.learn, 'recorder'):
            self.recorder.hps = self.hps
        # only the master ordinal writes the file
        if xm.is_master_ordinal():
            self.dump_hps()


In [None]:
#hide_input
show_doc(ParamScheduler.after_fit)

<h4 id="ParamScheduler.after_fit" class="doc_header"><code>ParamScheduler.after_fit</code><a href="__main__.py#L2" class="source_link" style="float:right">[source]</a></h4>

> <code>ParamScheduler.after_fit</code>()

save hps to file

In [None]:
#hide_output
# param_fn = '_paramsched_hps.pkl'
# !rm -f {param_fn}
# learner.fit_one_cycle(3)

In [None]:
# param_f = Path(param_fn)
# assert param_f.is_file()


In [None]:
# delattr(learner.recorder,'hps')
# assert not hasattr(learner.recorder,'hps')
# learner.recorder.reload_hps()
# assert hasattr(learner.recorder,'hps')
# !rm -f {param_fn}
# !rm -f _rec_attr.pkl

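Test `ParamScheduler` (`fit_one_cycle` uses `ParamScheduler`), which means the hyperparameters should be saved to `_paramsched_hps.pkl` after fitting and can be reloaded into the recorder with `reload_hps`.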

In [None]:
class SkipValidationCallback(Callback):
    "Cancel validation by raising `CancelValidException` in `before_validate`"
    order = -9
    run_valid = False

    def before_validate(self):
        # raise CancelValidException before XLATrainingCallback.before_validate
        # to prevent call to wrap_parallel_loader on before_validate
        raise CancelValidException()

    def after_cancel_validate(self):
        xm.mark_step()


In [None]:
class XLALRFinder(ParamScheduler):
    "Training with exponentially growing learning rate"
    def __init__(self, start_lr=1e-7, end_lr=10, num_it=100, stop_div=True):
        if is_listy(start_lr):
            self.scheds = {'lr': [SchedExp(s, e) for (s,e) in zip(start_lr,end_lr)]}
        else: self.scheds = {'lr': SchedExp(start_lr, end_lr)}
        self.num_it,self.stop_div = num_it,stop_div
        # once set, the remaining batches are iterated without updating the lr
        self.skip_batch = False
        # number of recorder losses seen when the last active batch finished
        self.num_loss = 0

    def before_fit(self):
        "Reset the finder state and silence logging on the master ordinal"
        super().before_fit()
        # no need to save orig weights
        # since learner instances are transient on spawned procs
        # self.learn.save('_tmp')
        self.best_loss = float('inf')
        self.skip_batch = False
        self.num_loss = 0
        # dont report losses while running lrfind (override sync_recorder)
        # run after sync_recorder.before_fit (sync_recorder.order == 55)
        # while param scheduler order == 60
        if xm.is_master_ordinal() and hasattr(self.learn, 'sync_recorder'):
            self.learn.logger = noop
            self.learn.sync_recorder._sync_stats_log = noop

    def before_batch(self):
        "Set the lr for this batch unless the finder has already stopped"
        if self.skip_batch:
            return
        self._update_val(self.train_iter/self.num_it)

    def after_batch(self):
        "Track the best loss and stop updating once diverged or `num_it` is reached"
        if self.skip_batch:
            return
        super().after_batch()
        smooth_loss = self.smooth_loss.item() # move xla tensor to cpu
        self.num_loss = len(self.recorder.losses)
        if smooth_loss < self.best_loss:
            self.best_loss = smooth_loss

        # instead of cancelling the fit, flag the remaining batches to be
        # skipped so batch iteration continues until all batches are exhausted
        diverged = smooth_loss > 4*self.best_loss and self.stop_div
        if diverged or self.train_iter >= self.num_it:
            self.skip_batch = True

    def after_fit(self):
        "On the master ordinal, trim the recorder attrs to the active iterations and dump them"
        # no need to load old weights since these will be transient
        # self.learn.opt.zero_grad() #Need to zero the gradients of the model before detaching the optimizer for future fits
        # tmp_f = self.path/self.model_dir/'_tmp.pth'
        # if tmp_f.exists():
        #     self.learn.load('_tmp', with_opt=True)
        #     os.remove(tmp_f)
        if not xm.is_master_ordinal(): return

        if not self.skip_batch: # completed without ever stopping: keep everything recorded
            self.num_loss = len(self.recorder.losses)

        self.recorder.losses = self.recorder.losses[: self.num_loss]
        self.recorder.lrs = self.recorder.lrs[: self.num_loss]
        # keep the epoch-level entries up to and including the first one at or past `num_it`
        num_iters = len(self.recorder.iters)
        for i, it in enumerate(self.recorder.iters):
            if it >= self.num_it:
                num_iters = i + 1
                break
        self.recorder.iters = self.recorder.iters[:num_iters]
        self.recorder.values = self.recorder.values[:num_iters]
        self.recorder.dump_attrs() # rewrite updated attrs



In [None]:
#export

def xla_run_lr_find(rank, learner_args, add_args, lr_find_args, ctrl_args):
    'run xla lr_find on spawned processes'
    xm.rendezvous('start_run_lrfind')
    sync_valid = True
    learner = make_xla_child_learner(rank, sync_valid, learner_args, add_args, ctrl_args)

    # one epoch more than the whole number of epochs that fit in `num_it` batches
    total_iters = lr_find_args['num_it']
    n_epoch = 1 + total_iters // len(learner.dls.train)

    # `lr_find_args` carries start_lr, end_lr, num_it and stop_div for the finder
    finder_cb = XLALRFinder(**lr_find_args)
    no_valid_cb = SkipValidationCallback()
    callbacks = [finder_cb, no_valid_cb]

    with learner.no_logging():
        learner.fit(n_epoch, cbs=callbacks)
    


In [None]:
@patch
def get_suggested_lrs(self:Learner, num_it):
    '''Suggest learning rates from the lrs/losses held by the recorder.
    Returns `SuggestedLRs(lr_min/10, lr_steep)`, or `None` when fewer than
    two points remain after dropping the first `num_it//10` and last 5 entries.'''
    start = num_it//10
    lrs = tensor(self.recorder.lrs[start:-5])
    losses = tensor(self.recorder.losses[start:-5])
    # a slope needs two points: with a single point `grads` below would be
    # empty and `argmin` would raise
    if len(losses) < 2: return
    lr_min = lrs[losses.argmin()].item()
    # loss change per unit of log(lr) between consecutive points
    grads = (losses[1:]-losses[:-1]) / (lrs[1:].log()-lrs[:-1].log())
    lr_steep = lrs[grads.argmin()].item()
    return SuggestedLRs(lr_min/10.,lr_steep)


In [None]:
@patch
@delegates(Learner.lr_find, but=['num_cores','start_method'])
def xla_lr_find(self:Learner, num_cores=8, start_method='fork', **kwargs):
    '''Run `xla_run_lr_find` on `num_cores` spawned processes, then reload the
    recorder attrs and hps they saved to file.
    `show_plot` (default True) plots the result; `suggestions` (default True)
    returns `get_suggested_lrs`. Remaining kwargs override the finder defaults
    (`start_lr`, `end_lr`, `num_it`, `stop_div`).'''
    # show_plot and suggestions are handled here, not by the finder callback
    show_plot = kwargs.pop('show_plot', True)
    suggestions = kwargs.pop('suggestions', True)

    # default params for lr_find, overridden by the remaining kwargs
    lr_find_args = {
        'start_lr': 1e-7,
        'end_lr': 10.,
        'num_it': 100,
        'stop_div': True,
        **kwargs
    }

    ctrl_args = self.pre_xla_fit()
    learner_args, add_args = self.pack_learner_args()
    xmp.spawn(xla_run_lr_find,
              args=(learner_args, add_args, lr_find_args, ctrl_args),
              nprocs=num_cores,
              start_method=start_method)

    # the results are read back from the files written by the callbacks
    self.recorder.reload_attrs()
    self.recorder.reload_hps()

    if show_plot:
        self.recorder.plot_lr_find()
    if suggestions:
        return self.get_suggested_lrs(lr_find_args['num_it'])

In [None]:
# path = untar_data(URLs.MNIST_TINY)
path = untar_data(URLs.MNIST)

In [None]:
# Image -> category datablock: items are image files labelled by their parent
# folder name and split into train/valid by the `training`/`testing`
# grandparent folders; each item is resized to 28, with no batch transforms
data = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    get_y=parent_label,
    splitter=GrandparentSplitter(train_name='training', valid_name='testing'),
    # splitter=GrandparentSplitter(),
    item_tfms=Resize(28),
    batch_tfms=[]
)

In [None]:
dls = data.dataloaders(path, bs=64)

In [None]:
learner = cnn_learner(dls, resnet18, metrics=accuracy, concat_pool=False)

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth


HBox(children=(FloatProgress(value=0.0, max=46827520.0), HTML(value='')))




In [None]:
# %%time
# spawn_lrfind(learner)
learner.xla_lr_find(stop_div=True,end_lr=100, num_it=400)

start fit


Exception in device=TPU:0: Resource exhausted: From /job:tpu_worker/replica:0/task:0:
2 root error(s) found.
  (0) Resource exhausted: Failed to allocate request for 4.0KiB (4096B) on device ordinal 0
	 [[{{node XRTExecute}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.

  (1) Resource exhausted: Failed to allocate request for 4.0KiB (4096B) on device ordinal 0
	 [[{{node XRTExecute}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.

	 [[XRTExecute_G15]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.

0 successful operations.
0 derived errors ignored.
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch_xla/distributed/xla_multiprocessi

ProcessExitedException: ignored

In [None]:
learner.xla_fit_one_cycle(5, lr_max=slice(0.026))