In [None]:
#hide
#colab
# attach gdrive holding repo
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
#default_exp multi_core.inference

# Multi Core XLA Inference 



<a href="https://colab.research.google.com/github/butchland/fastai_xla_extensions/blob/master/nbs/03e_multi_core.inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

> Multi Core XLA Extensions for inference

Multi-core TPU implementation for inference is enabled by importing this module.
```
from fastai_xla_extensions.multi_core.inference import *
```

In [None]:
#hide
#colab
!pip install -Uqq cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp37-cp37m-linux_x86_64.whl

[K     |████████████████████████████████| 133.6MB 80kB/s 
[K     |████████████████████████████████| 61kB 405kB/s 
[31mERROR: earthengine-api 0.1.254 has requirement google-api-python-client>=1.12.1, but you'll have google-api-python-client 1.8.0 which is incompatible.[0m
[?25h

In [None]:
#hide
#colab
# !pip install -Uqq git+https://github.com/fastai/fastai.git 
!pip install -Uqq fastai --upgrade

[K     |████████████████████████████████| 194kB 5.4MB/s 
[K     |████████████████████████████████| 61kB 3.8MB/s 
[K     |████████████████████████████████| 776.8MB 19kB/s 
[K     |████████████████████████████████| 12.8MB 44.7MB/s 
[31mERROR: torchtext 0.9.0 has requirement torch==1.8.0, but you'll have torch 1.7.1 which is incompatible.[0m
[?25h

In [None]:
#hide
#colab
!pip install -Uqq git+https://github.com/butchland/my_timesaver_utils.git

  Building wheel for my-timesaver-utils (setup.py) ... [?25l[?25hdone


In [None]:
#hide
#colab
!pip install -qqq nbdev --upgrade

[K     |████████████████████████████████| 51kB 3.0MB/s 
[K     |████████████████████████████████| 51kB 3.9MB/s 
[?25h

In [None]:
#hide
#colab
!curl -s https://course19.fast.ai/setup/colab | bash

Updating fastai...
Done.


In [None]:
#hide
!pip freeze | grep torch
!pip freeze | grep fast
!pip freeze | grep timesaver
!pip freeze | grep nbdev

torch==1.7.1
torch-xla==1.7
torchsummary==1.5.1
torchtext==0.9.0
torchvision==0.8.2
fastai==2.2.7
fastcore==1.3.19
fastdtw==0.3.4
fastprogress==1.0.0
fastrelease==0.1.11
fastrlock==0.5
my-timesaver-utils==0.0.2
nbdev==1.1.13


In [None]:
#hide
#colab
# link repo to work dir
%cd /content
!ln -s /content/drive/MyDrive/fastai_xla_extensions fastai_xla_extensions

/content


In [None]:
#hide
# <!-- Start of kernel -->

In [None]:
#hide
#colab
%cd /content/fastai_xla_extensions

/content/drive/MyDrive/fastai_xla_extensions


In [None]:
#hide

from nbdev.showdoc import *

In [None]:
#export
try:
    import torch_xla
except ImportError:
    pass

In [None]:
#export

#from fastai.vision.all import *
from fastai_xla_extensions.utils import xla_imported
from fastai_xla_extensions.misc_utils import *
from fastai_xla_extensions.core import XLAOptCallback
from fastai_xla_extensions.multi_core.base import *
from fastai_xla_extensions.multi_core.learner import *
from fastai_xla_extensions.multi_core.callback import *



In [None]:
#hide
#local

# fake out torch_xla modules if not running on xla supported envs
if not xla_imported():
    # torch_xla is not importable here: install lightweight stand-ins for the
    # names the notebook uses (torch_xla, xm, met, pl, xmp, xu) so the cells
    # can still be executed on a CPU/GPU-only machine.
    from types import SimpleNamespace
    torch_xla = SimpleNamespace (
    )
    from typing import Union,BinaryIO
    import os
    import pickle
    import torch.cuda

    def fake_opt_step(opt,barrier=False):
        "Stand-in for `xm.optimizer_step`: a plain `opt.step()`; `barrier` is ignored"
        opt.step()

    def fake_device(n=None, devkind=None):
        "Stand-in for `xm.xla_device`: the current CUDA device if available, else the CPU"
        gpu_available = torch.cuda.is_available()
        if gpu_available:
            return torch.device(torch.cuda.current_device())
        return torch.device('cpu')

    def fake_save(obj, f: Union[str, os.PathLike, BinaryIO],
                master_only=True, global_master=False):
        "Stand-in for `xm.save`: delegates to `torch.save`; the master flags are ignored"
        return torch.save(obj,f,pickle_module=pickle,
                        pickle_protocol=2,
                        _use_new_zipfile_serialization=True)

    def fake_rate():
        "Fixed placeholder value for `RateTracker.rate`"
        return 230.20

    def fake_global_rate():
        "Fixed placeholder value for `RateTracker.global_rate`"
        return 830.10

    def fake_add(*args,**kwargs):
        "No-op replacement for `RateTracker.add`"
        pass

    def fake_RateTracker():
        "Build a namespace exposing `rate`, `global_rate` and `add` like a rate tracker"
        return SimpleNamespace(
            rate = fake_rate,
            global_rate = fake_global_rate,
            add = fake_add
        )

    def fake_xrt_world_size():
        "Single-process world: always 1"
        return 1

    def fake_get_ordinal():
        "Single-process world: the only ordinal is 0"
        return 0

    def fake_is_master_ordinal(*args,**kwargs):
        "The single fake process is always the master"
        return True

    def fake_maybe_convert_to_cpu(data,*args,**kwargs):
        "Return `data` unchanged"
        return data

    def fake_rendezvous(tag, payload=b'', replicas=None):
        "Nothing to synchronise with in a single process; return just this process' payload"
        return [payload]

    def fake_mark_step(*args,**kwargs):
        "No-op replacement for `xm.mark_step`"
        pass

    xm = SimpleNamespace(
        optimizer_step = fake_opt_step,
        xla_device = fake_device,
        save = fake_save,
        RateTracker = fake_RateTracker,
        master_print = print,
        xrt_world_size = fake_xrt_world_size,
        get_ordinal = fake_get_ordinal,
        is_master_ordinal = fake_is_master_ordinal,
        _maybe_convert_to_cpu = fake_maybe_convert_to_cpu,
        # `xla_run_inference` calls both of these, so the fake must provide them
        rendezvous = fake_rendezvous,
        mark_step = fake_mark_step
    )

    def fake_metrics_report():
        "Placeholder text for `met.metrics_report`"
        return "Fake Metrics Report \n\n\n\n"
    met = SimpleNamespace (
        metrics_report = fake_metrics_report
    )

    class FakePerDeviceLoader:
        "Placeholder type so `isinstance(dl, pl.PerDeviceLoader)` checks still work"
        def __init__(self, *args):
            pass
        def close(self):
            pass

    class FakeParallelLoader:
        "Holds the wrapped loader and hands it back unchanged for any device"
        def __init__(self, loader, *args):
            self.loader = loader
        def per_device_loader(self,device):
            return self.loader

    pl = SimpleNamespace(
        ParallelLoader = FakeParallelLoader,
        PerDeviceLoader = FakePerDeviceLoader

    )

    def fake_MpModelWrapper(o):
        "Return the model unchanged"
        return o

    def fake_run(f,*args, **kwargs):
        "Call `f` immediately in the current process"
        return f(*args,**kwargs)

    def fake_MpSerialExecutor():
        "Build a namespace whose `run` executes the function directly"
        return SimpleNamespace(
            run = fake_run
        )

    def fake_spawn(f, args=None, nprocs=0, start_method=None):
        "Run `f` once in the current process as rank 0; `nprocs` and `start_method` are ignored"
        # `args` defaults to None, which cannot be unpacked: treat it as "no extra arguments"
        return f(0,*(args or ()))

    xmp = SimpleNamespace (
        MpModelWrapper = fake_MpModelWrapper,
        MpSerialExecutor = fake_MpSerialExecutor,
        spawn = fake_spawn
    )

    xu = SimpleNamespace (
    )


In [None]:
#export

if xla_imported():
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    import torch_xla.distributed.xla_multiprocessing as xmp

In [None]:
#export
from fastai.vision.all import *


## Implement Multi Core TPU Inference

In [None]:
#export
from fastai.learner import _ConstantFunc
# from fastcore.basics import patch
# from fastai.learner import Learner

@patch
def inner_get_preds(self:Learner, ds_idx=1, dl=None, with_input=False, with_decoded=False, with_loss=False, act=None,
                inner=False, reorder=True, cbs=None, **kwargs):
    """Per-process counterpart of `Learner.get_preds`, meant to run inside a spawned XLA process.

    Returns `None` immediately when the learner has no `xla_rank` attribute.
    Otherwise validates over `dl` (or `self.dls[ds_idx]` when `dl` is None) and returns
    a tuple built from `GatherPredsCallback.all_tensors()`: inputs first when `with_input`
    is set, activations applied to the predictions, and decoded predictions inserted when
    `with_decoded` is set.

    Extra keyword arguments `world_size` and `dl_seed` configure the `TPUDistributedDL`
    wrapper; all other `kwargs` are forwarded to `GatherPredsCallback`.

    Raises `TypeError` when `dl` does not support `len()`.
    """
    xla_rank = getattr(self,'xla_rank',None)
    if xla_rank is None:
        return

    # Always strip the loader-only options so they can never leak into
    # `GatherPredsCallback(**kwargs)` when no wrapping is required.
    world_size = kwargs.pop('world_size', None)
    seed = kwargs.pop('dl_seed',42)

    if dl is None:
        dl = self.dls[ds_idx].new(shuffled=False, drop_last=False)
    else:
        try: len(dl)
        except TypeError as e:
            raise TypeError("`dl` is something other than a single `DataLoader` object") from e
        if not isinstance(dl, TPUDistributedDL):
            if world_size is None:
                world_size = xm.xrt_world_size()
            dl = TPUDistributedDL(dl, xla_rank, world_size=world_size, seed=seed)

    do_reorder = reorder and hasattr(dl, 'get_idxs')
    if do_reorder:
        # freeze the index order of the wrapped loader so the new loader iterates
        # in a known order, then remember the indices seen by this rank
        idxs = dl.dl.get_idxs()
        dl = dl.new(get_idxs = _ConstantFunc(idxs))
        rank_idxs = dl.get_idxs()
        rank_idxs_len = len(rank_idxs)

    cb = GatherPredsCallback(with_input=with_input, with_loss=with_loss, **kwargs)
    ctx_mgrs = self.validation_context(cbs=L(cbs)+[cb], inner=inner)
    if with_loss:
        ctx_mgrs.append(self.loss_not_reduced())

    try:
        with ContextManagers(ctx_mgrs):
            self._do_epoch_validate(dl=dl)

            if act is None:
                act = getattr(self.loss_func, 'activation', noop)

            res = cb.all_tensors()

            pred_i = 1 if with_input else 0
            if res[pred_i] is not None:
                if act != noop:
                    # compute activation on tpu device and detach after
                    tmp_pred = res[pred_i].to(xm.xla_device())
                    tmp_res = act(tmp_pred)
                    res[pred_i] = self.to_detach(tmp_res)

                if with_decoded:
                    res.insert(pred_i+2, getattr(self.loss_func, 'decodes', noop)(res[pred_i]))

            if do_reorder:
                # shift this rank's indices so they start at zero before sorting
                # NOTE(review): assumes each rank holds a contiguous block of
                # `rank_idxs_len` indices starting at `xla_rank * rank_idxs_len` - confirm
                t_idxs = tensor(rank_idxs)
                start_idx = xla_rank * rank_idxs_len
                t_idxs = t_idxs - tensor(start_idx) # broadcast
                sorted_idxs = t_idxs.argsort()
                res = nested_reorder(res, sorted_idxs )

            return tuple(res)
    finally:
        # previously placed after the `return` and therefore never executed;
        # run it once the context managers have exited
        self._end_cleanup()


In [None]:
#export

from fastai.learner import CancelValidException

@patch
def before_validate(self:XLATrainingCallback):
    "Cancel validation on non-master ranks unless `sync_valid`, else make sure `learn.dl` is wrapped"
    spawned = getattr(self.learn,'inner_xla',False)
    if not spawned:
        return # nothing to do outside a spawned process

    # only rank 0 computes the valid loss/metrics, unless validation is synced across ranks
    if not (self.rank == 0 or self.sync_valid):
        raise CancelValidException()

    current_dl = self.learn.dl
    if isinstance(current_dl, pl.PerDeviceLoader):
        return # already a per-device loader
    self.learn.dl = wrap_parallel_loader(current_dl, self.pdevice)



In [None]:
#export
@patch
def new(self:TPUDistributedDL, dataset=None, cls=None, **kwargs):
    "Wrap `self.dl.new(...)` in a fresh `TPUDistributedDL` that keeps this loader's rank, world size and seed"
    inner_dl = self.dl.new(dataset=dataset, cls=cls, **kwargs)
    return TPUDistributedDL(inner_dl,
                            rank=self.rank,
                            world_size=self.world_size,
                            seed=self.seed)

In [None]:
#export

def setup_inference_args(rank, inference_args):
    """Split the `master_cbs` entry out of `inference_args`.

    Returns `(pred_args, master_cbs)` where `pred_args` is a copy of
    `inference_args` without the `master_cbs` key and `master_cbs` is the
    removed value, or an empty list when the key is absent or None.
    `rank` is currently unused.
    """
    # work on a shallow copy so the caller's dict is not modified
    pred_args = dict(inference_args)
    master_cbs = pred_args.pop('master_cbs', None)
    if master_cbs is None:
        master_cbs = []
    return pred_args, master_cbs


In [None]:
#export

import pickle
def save_pred_results(rank, results):
    "Pickle `results` to `preds{rank}.pkl` in the current working directory"
    out_path = Path(f'preds{rank}.pkl')
    with out_path.open('wb') as out_file:
        pickle.dump(results, out_file)

In [None]:
#export

def xla_run_inference(rank, learner_args, add_args, inference_args, ctrl_args):
    "Function run in each spawned process: build the child learner, get this rank's preds and save them to disk"
    is_master = rank == 0
    # the child learner is always built with sync_valid=True for inference
    child = make_xla_child_learner(rank, True, learner_args, add_args, ctrl_args)
    pred_args, master_cbs = setup_inference_args(rank, inference_args)

    # extra callbacks are attached on the master process only
    if is_master and len(master_cbs) > 0:
        child.add_cbs(master_cbs)

    child.synced_cancel.before_fit()

    if is_master:
        child.sync_recorder.orig_logger = child.logger

    preds = child.inner_get_preds(**pred_args)
    xm.rendezvous('xla_run_inference')

    save_pred_results(rank, preds)
    xm.mark_step()
    

In [None]:
#export
from fastcore.foundation import L

def reload_pred_results(num_files, n_samples):
    """Load the per-rank `preds{rank}.pkl` files, merge them and delete the files.

    `num_files` is the number of rank files to read (ranks `0..num_files-1`) and
    `n_samples` is the number of leading entries to keep from each merged result.

    Returns a list with one entry per item saved by each rank. Tensor items are
    concatenated across ranks with `torch.cat`, listy items become a list of the
    per-rank values, and anything else is kept as the collection of per-rank values.
    Each entry is then sliced to its first `n_samples` elements.

    Raises `RuntimeError` when the file of any rank is missing; in that case no
    file is deleted.
    """
    pred_paths = [Path(f'preds{rank}.pkl') for rank in range(num_files)]

    all_preds = L()
    for rank, fn in enumerate(pred_paths):
        if not fn.is_file():
            # was `RuntimeException`, which is not a Python builtin and raised NameError
            raise RuntimeError(f'Missing preds file for rank {rank}')
        # these files are written by `save_pred_results` in the spawned processes;
        # pickle must not be used on files from an untrusted source
        with open(fn,'rb') as f:
            all_preds.append(pickle.load(f))

    # remove the files only after every rank has been loaded successfully
    for fn in pred_paths:
        fn.unlink()

    n_items = len(all_preds[0]) # num items per preds

    all_res = []
    for i in range(n_items):
        items = all_preds.itemgot(i)
        first = items[0]
        if isinstance(first, torch.Tensor):
            merged = torch.cat(tuple(items))
        elif is_listy(first):
            merged = [*items]
        else:
            merged = items
        all_res.append(merged)

    # keep only the first `n_samples` entries of each merged result
    # NOTE(review): presumably drops padding added to even out the ranks - confirm
    return [merged[:n_samples] for merged in all_res]
        


In [None]:
#export

@patch
def pre_xla_inference(self:Learner):
    ctrl_args = {}
    progress_removed = False
    if 'progress' in L(self.cbs).attrgot('name'):
        self.remove_cbs(ProgressCallback)
        progress_removed = True
    ctrl_args['use_progress'] = progress_removed
    return ctrl_args

In [None]:
#export

@patch
def post_xla_inference(self:Learner, ctrl_args):
    if ctrl_args['use_progress']:
        self.add_cbs(ProgressCallback)
    self.recorder.reload_attrs()

In [None]:
#export

def prep_inference_args(**kwargs):
    "Pack the given keyword arguments into a single dict of inference arguments"
    inference_args = kwargs
    return inference_args

In [None]:

#export

@patch
@delegates(Learner.get_preds, but='num_cores,start_method,master_cbs')
def xla_get_preds(self:Learner, ds_idx=1, dl=None, 
                  with_input=False, with_decoded=False, 
                  with_loss=False, act=None, inner=False, 
                  reorder=True, cbs=None, num_cores=8, 
                  start_method='fork', master_cbs=None,**kwargs):
    """Get predictions by spawning `num_cores` processes that each run `xla_run_inference`.

    The prediction arguments mirror `Learner.get_preds`. `num_cores` and
    `start_method` are passed to `xmp.spawn`, and `master_cbs` are extra
    callbacks added on the rank 0 process only. Every process saves its
    results to disk; they are reloaded and merged here.

    Returns the list produced by `reload_pred_results`, limited to the number
    of items in the dataset of `dl` (or of `self.dls.loaders[ds_idx]`).
    """
    ctrl_args = self.pre_xla_inference()
    learner_args, add_args = self.pack_learner_args()

    inference_args = prep_inference_args(ds_idx=ds_idx, dl=dl, 
                                         with_input=with_input, with_decoded=with_decoded, 
                                         with_loss=with_loss,
                                         act=act, inner=inner, 
                                         reorder=reorder, 
                                         cbs=cbs, master_cbs=master_cbs, **kwargs)
    # Compare against None explicitly: truth-testing a dataloader goes through
    # `len(dl)`, so an empty loader would silently fall back to `self.dls`.
    source_dl = dl if dl is not None else self.dls.loaders[ds_idx]
    n_results = len(source_dl.dataset)

    xmp.spawn(xla_run_inference,
              args=(learner_args, add_args, inference_args, ctrl_args),
              nprocs=num_cores,
              start_method=start_method)

    all_results = reload_pred_results(num_cores, n_results)
    self.post_xla_inference(ctrl_args)
    return all_results
    

## Test out the code

In [None]:
#colab
path = untar_data(URLs.MNIST)
# path = untar_data(URLs.PETS)/'images'

In [None]:
#colab
data = DataBlock(
    # image inputs with a single category target
    blocks=(ImageBlock,CategoryBlock),
    get_items=get_image_files,
    # the label of each image is taken from its parent folder
    get_y=parent_label,
    # train/valid split follows the 'training' and 'testing' grandparent folders
    splitter=GrandparentSplitter(train_name='training', valid_name='testing'),
    item_tfms=Resize(28),
    # normalised with the imagenet statistics
    batch_tfms=[Normalize.from_stats(*imagenet_stats)]
)
# pat = r'(.+)_\d+.jpg$'
# data = DataBlock(
#     blocks=(ImageBlock,CategoryBlock),
#     get_items=get_image_files,
#     get_y=using_attr(RegexLabeller(pat),'name'),
#     splitter=RandomSplitter(seed=42),
#     item_tfms=Resize(224),
#     batch_tfms=[Normalize.from_stats(*imagenet_stats)]
# )

In [None]:
#colab
dls = data.dataloaders(path, bs=64)

In [None]:
#colab
# loss_func=nn.CrossEntropyLoss()
loss_func=CrossEntropyLossFlat()

In [None]:
#colab
learner = cnn_learner(dls, resnet18, metrics=accuracy, loss_func=loss_func, concat_pool=False)
# learner = cnn_learner(dls, resnet34, metrics=accuracy, loss_func=loss_func, concat_pool=False)

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth


HBox(children=(FloatProgress(value=0.0, max=46827520.0), HTML(value='')))




In [None]:
#colab
learner.xla_fit_one_cycle(3, lr_max=slice(3e-2))

start fit


epoch,train_loss,valid_loss,accuracy,time
0,0.557177,0.198589,0.9401,02:26
1,0.229188,0.099969,0.9685,02:14
2,0.141549,0.079136,0.9745,02:20


In [None]:
#colab
learner.unfreeze()

In [None]:
# learner.load('pets-stage-3')

In [None]:
# learner.save('pets-stage-2')

In [None]:
#colab
learner.xla_fit_one_cycle(5,lr_max=slice(1e-6,2e-4))

start fit


epoch,train_loss,valid_loss,accuracy,time
0,0.08769,0.068158,0.977,02:28
1,0.058692,0.069429,0.9772,02:48
2,0.080852,0.059883,0.9807,03:07
3,0.086271,0.059273,0.9817,03:24
4,0.077349,0.055652,0.9817,03:08


In [None]:
# learner.save('pets-stage-3')

In [None]:
# %%time
# learner.validate()

In [None]:
#colab
%%time
res = learner.get_preds()

CPU times: user 38.2 s, sys: 1.09 s, total: 39.3 s
Wall time: 41.7 s


In [None]:
#colab
print(len(res))
print(res[0].shape, res[1].shape)

2
torch.Size([10000, 10]) torch.Size([10000])


In [None]:
#colab
print(accuracy(*res))

TensorBase(0.9813)


In [None]:
#colab
res[1][:10]

TensorCategory([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

In [None]:
#colab
%%time
res2 = learner.get_preds(reorder=False)

CPU times: user 37.8 s, sys: 1.02 s, total: 38.8 s
Wall time: 40.6 s


In [None]:
#colab
print(len(res2))
print(res2[0].shape, res2[1].shape)

2
torch.Size([10000, 10]) torch.Size([10000])


In [None]:
#colab
res2[1][:10]

TensorCategory([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

In [None]:
#colab
print(accuracy(*res2))

TensorBase(0.9813)


In [None]:
#colab
%%time
xla_res = learner.xla_get_preds(reorder=False)

start fit
CPU times: user 187 ms, sys: 153 ms, total: 340 ms
Wall time: 42.6 s


In [None]:
#colab
print(len(xla_res))

2


In [None]:
#colab
(xla_res[0].shape, xla_res[1].shape)

(torch.Size([10000, 10]), torch.Size([10000]))

In [None]:
#colab
xla_res[1][:10]

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

In [None]:
#colab
xla_res[0][:10]

tensor([[7.3227e-05, 9.9777e-01, 1.8993e-05, 1.4147e-05, 4.4221e-04, 4.3196e-04,
         1.2528e-04, 1.8559e-04, 7.8385e-04, 1.5749e-04],
        [3.7930e-06, 9.9990e-01, 1.0723e-06, 1.7454e-06, 1.3211e-05, 9.3661e-06,
         2.8669e-06, 2.6495e-05, 4.4178e-05, 1.5767e-06],
        [1.5091e-04, 9.9920e-01, 1.1045e-04, 5.1943e-05, 4.7102e-05, 1.9531e-05,
         2.2440e-04, 1.1227e-04, 5.6824e-05, 2.4695e-05],
        [1.0292e-05, 9.9979e-01, 1.9649e-05, 1.1793e-06, 1.8352e-05, 3.1917e-06,
         7.2904e-06, 1.9972e-05, 1.1706e-04, 8.7981e-06],
        [7.8981e-06, 9.9990e-01, 9.7631e-06, 2.2887e-06, 1.3191e-05, 2.5938e-06,
         2.3175e-05, 2.3824e-05, 1.0166e-05, 2.3397e-06],
        [3.0717e-05, 9.9962e-01, 4.3642e-05, 1.5713e-05, 3.4024e-05, 3.8474e-05,
         1.2209e-04, 6.3268e-05, 1.3420e-05, 1.9892e-05],
        [1.3539e-04, 9.9956e-01, 4.1930e-06, 1.8683e-06, 1.1204e-05, 2.9642e-06,
         2.3243e-04, 1.3482e-05, 3.8045e-05, 3.7877e-06],
        [1.4129e-06, 9.9996

In [None]:
#colab
print(accuracy(*xla_res))

TensorBase(0.9825)


In [None]:
#colab
%%time
xla_res2 = learner.xla_get_preds(reorder=True)

start fit
CPU times: user 160 ms, sys: 272 ms, total: 431 ms
Wall time: 31.8 s


In [None]:
#colab
print(len(xla_res2))

2


In [None]:
#colab
(xla_res2[0].shape, xla_res2[1].shape)

(torch.Size([10000, 10]), torch.Size([10000]))

In [None]:
#colab
xla_res2[1][:10]

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

In [None]:
#colab
xla_res2[0][:10]

tensor([[7.3227e-05, 9.9777e-01, 1.8993e-05, 1.4147e-05, 4.4221e-04, 4.3196e-04,
         1.2528e-04, 1.8559e-04, 7.8385e-04, 1.5749e-04],
        [3.7930e-06, 9.9990e-01, 1.0723e-06, 1.7454e-06, 1.3211e-05, 9.3661e-06,
         2.8669e-06, 2.6495e-05, 4.4178e-05, 1.5767e-06],
        [1.5091e-04, 9.9920e-01, 1.1045e-04, 5.1943e-05, 4.7102e-05, 1.9531e-05,
         2.2440e-04, 1.1227e-04, 5.6824e-05, 2.4695e-05],
        [1.0292e-05, 9.9979e-01, 1.9649e-05, 1.1793e-06, 1.8352e-05, 3.1917e-06,
         7.2904e-06, 1.9972e-05, 1.1706e-04, 8.7981e-06],
        [7.8981e-06, 9.9990e-01, 9.7631e-06, 2.2887e-06, 1.3191e-05, 2.5938e-06,
         2.3175e-05, 2.3824e-05, 1.0166e-05, 2.3397e-06],
        [3.0717e-05, 9.9962e-01, 4.3642e-05, 1.5713e-05, 3.4024e-05, 3.8474e-05,
         1.2209e-04, 6.3268e-05, 1.3420e-05, 1.9892e-05],
        [1.3539e-04, 9.9956e-01, 4.1930e-06, 1.8683e-06, 1.1204e-05, 2.9642e-06,
         2.3243e-04, 1.3482e-05, 3.8045e-05, 3.7877e-06],
        [1.4129e-06, 9.9996

In [None]:
#colab
print(accuracy(*xla_res2))

TensorBase(0.9825)
