In [None]:
#hide
#colab
# attach gdrive holding repo
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
#default_exp multi_core.torch_compat

# Torch Compatible Utilities

> Classes and functions compatible with torch `Dataset` and `DataLoader`, for multi-core TPU training

<a href="https://colab.research.google.com/github/butchland/fastai_xla_extensions/blob/master/nbs/03a_multi_core.torch_compat.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#hide
#colab
!pip install -Uqq cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp36-cp36m-linux_x86_64.whl

[K     |████████████████████████████████| 133.6MB 43kB/s 
[K     |████████████████████████████████| 61kB 2.7MB/s 
[?25h

In [None]:
#hide
#colab
# !pip install -Uqq git+https://github.com/fastai/fastai.git 
!pip install -Uqq fastai --upgrade

[K     |████████████████████████████████| 194kB 6.8MB/s 
[K     |████████████████████████████████| 61kB 5.8MB/s 
[?25h

In [None]:
#hide
#colab
# get profiling utils and callback
!pip install -Uqq git+https://github.com/butchland/my_timesaver_utils.git


  Building wheel for my-timesaver-utils (setup.py) ... [?25l[?25hdone


In [None]:
#hide
#colab
!pip install -qqq nbdev

[?25l[K     |███████▏                        | 10kB 18.5MB/s eta 0:00:01[K     |██████████████▎                 | 20kB 10.5MB/s eta 0:00:01[K     |█████████████████████▍          | 30kB 10.7MB/s eta 0:00:01[K     |████████████████████████████▌   | 40kB 8.3MB/s eta 0:00:01[K     |████████████████████████████████| 51kB 3.0MB/s 
[?25h

In [None]:
#hide
#colab
!curl -s https://course19.fast.ai/setup/colab | bash

Updating fastai...
Done.


In [None]:
#hide
!pip freeze | grep torch
!pip freeze | grep fast
!pip freeze | grep timesaver
!pip freeze | grep nbdev

torch==1.7.0+cu101
torch-xla==1.7
torchsummary==1.5.1
torchtext==0.3.1
torchvision==0.8.1+cu101
fastai==2.2.5
fastcore==1.3.19
fastdtw==0.3.4
fastprogress==1.0.0
fastrlock==0.5
my-timesaver-utils==0.0.2
nbdev==1.1.12


In [None]:
#hide
#colab
%cd /content
!ln -s /content/drive/MyDrive/fastai_xla_extensions fastai_xla_extensions

/content


In [None]:
#hide
# Start of kernel

In [None]:
#hide
#colab
%cd /content/fastai_xla_extensions

/content/drive/MyDrive/fastai_xla_extensions


In [None]:
#exporti
from fastai_xla_extensions.utils import xla_imported
from fastai_xla_extensions.multi_core.base import *
from fastai_xla_extensions.misc_utils import *




In [None]:
#hide
#colab
from nbdev.showdoc import *

In [None]:
#exporti
try:
    import torch_xla
except ImportError:
    pass

In [None]:
#hide

if not xla_imported():
    # torch_xla is not available: replace the torch_xla modules used in this notebook
    # (xm, met, pl, xmp, xu) with lightweight single-process stand-ins so the cells
    # below can still run on a CPU/GPU runtime.
    from types import SimpleNamespace
    torch_xla = SimpleNamespace(
    )
    from typing import Union, BinaryIO
    import os
    import pickle
    import torch.cuda

    def fake_opt_step(opt, barrier=False):
        "Stand-in for `xm.optimizer_step`: just step the optimizer"
        opt.step()

    def fake_device(n=None, devkind=None):
        "Stand-in for `xm.xla_device`: the current CUDA device if one is available, else cpu"
        gpu_available = torch.cuda.is_available()
        if gpu_available:
            return torch.device(torch.cuda.current_device())
        return torch.device('cpu')

    def fake_save(obj, f: Union[str, os.PathLike, BinaryIO],
                master_only=True, global_master=False):
        "Stand-in for `xm.save`: a plain `torch.save` (`master_only`/`global_master` are ignored)"
        return torch.save(obj, f, pickle_module=pickle,
                        pickle_protocol=2,
                        _use_new_zipfile_serialization=True)

    def fake_rate():
        return 230.20

    def fake_global_rate():
        return 830.10

    def fake_add(*args, **kwargs):
        pass

    def fake_RateTracker():
        "Stand-in for `xm.RateTracker` reporting fixed dummy rates"
        return SimpleNamespace(
            rate = fake_rate,
            global_rate = fake_global_rate,
            add = fake_add
        )

    def fake_xrt_world_size():
        return 1

    def fake_get_ordinal():
        return 0

    def fake_rendezvous(tag, payload=b'', replicas=None):
        "Stand-in for `xm.rendezvous`: with a world size of 1 there is nobody to wait for"
        return [payload]

    def fake_mark_step(*args, **kwargs):
        "Stand-in for `xm.mark_step`: nothing to flush outside of XLA"
        pass

    xm = SimpleNamespace(
        optimizer_step = fake_opt_step,
        xla_device = fake_device,
        save = fake_save,
        RateTracker = fake_RateTracker,
        master_print = print,
        xrt_world_size = fake_xrt_world_size,
        get_ordinal = fake_get_ordinal,
        # used by `train_torch_model`; missing before, which broke non-TPU runs
        rendezvous = fake_rendezvous,
        mark_step = fake_mark_step,
    )

    def fake_metrics_report():
        return "Fake Metrics Report \n\n\n\n"
    met = SimpleNamespace (
        metrics_report = fake_metrics_report
    )

    class FakeParallelLoader:
        "Stand-in for `pl.ParallelLoader`: `per_device_loader` returns the wrapped loader unchanged"
        def __init__(self, loader, *args):
            self.loader = loader
        def per_device_loader(self, device):
            return self.loader

    pl = SimpleNamespace(
        ParallelLoader = FakeParallelLoader
    )

    def fake_MpModelWrapper(o):
        return o

    def fake_run(f, *args, **kwargs):
        return f(*args, **kwargs)

    def fake_MpSerialExecutor():
        return SimpleNamespace(
            run = fake_run
        )

    def fake_spawn(f, args=None, nprocs=0, start_method=None):
        "Stand-in for `xmp.spawn`: call `f(0, *args)` in the current process (ordinal 0)"
        # `args` defaults to None; unpacking None raised TypeError, so treat it as no extra args
        return f(0, *(args or ()))

    xmp = SimpleNamespace (
        MpModelWrapper = fake_MpModelWrapper,
        MpSerialExecutor = fake_MpSerialExecutor,
        spawn = fake_spawn
    )

    xu = SimpleNamespace (
    )


In [None]:
#exporti
if xla_imported():
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp

In [None]:
#exporti
from fastcore.basics import patch_to
import torch
import torch.utils.data as th_data
from fastcore.foundation import L
from pathlib import Path
from fastcore.transform import Pipeline
from fastai.data.core import DataLoaders
from pathlib import Path
from fastai.torch_core import find_bs, TensorBase
from fastai.torch_core import TensorBase
from fastcore.xtras import is_listy
import torch.utils.hooks
import torch.utils.data.distributed as th_distrib

In [None]:
#export
class TfmdTorchDS(th_data.Dataset):
    "Torch `Dataset` over `items` whose samples are `(x_tfm(item), y_tfm(item))`; a `None` tfm is the identity"
    def __init__(self, items, x_tfm=None, y_tfm=None):
        self.items, self.x_tfm, self.y_tfm = items, x_tfm, y_tfm

    def __len__(self):
        return len(self.items)

    @staticmethod
    def _apply(tfm, item):
        # identity when no transform was supplied
        return item if tfm is None else tfm(item)

    def __getitem__(self, index):
        raw = self.items[index]
        return (self._apply(self.x_tfm, raw), self._apply(self.y_tfm, raw))

In [None]:
from fastcore.test import test_eq

def neg_tfm(o): return -o
def double_tfm(o): return 2*o

# sample i should come back as (x, y) == (-i, 2*i)
items = [*range(10)]
ds1 = TfmdTorchDS(items, x_tfm=neg_tfm, y_tfm=double_tfm)
test_eq(ds1[5], (-5, 10))


In [None]:
#exporti
import torchvision as thv
from operator import itemgetter
from fastcore.imports import noop

In [None]:
#export
def to_list(o):
    "Coerce `o` to a list: `None` gives `[]`, a listy `o` is returned unchanged, anything else is wrapped as `[o]`"
    if o is None:
        return []
    return o if is_listy(o) else [o]

def has_setup(tfms):
    """Return the index of the last tfm in `tfms` with a truthy `setup` attribute, or -1 if there is none"""
    setup_attrs = L(tfms).attrgot('setup', None)  # `None` where a tfm has no `setup`
    positions = setup_attrs.argwhere(noop)        # indexes whose `setup` attribute is truthy
    return positions[-1] if len(positions) else -1

def run_setups(tfms, items):
    """Run tfm setups, feeding each tfm the items as transformed by all preceding tfms.

    `tfms` may be None, a single transform or a list of transforms. Every tfm with a
    (truthy) `setup` attribute gets `setup(items)` called. Items are only transformed as
    far as the last tfm that has a `setup`, since later tfms need no set-up input.
    Nothing happens when no tfm has a `setup`.
    """
    # same normalisation as `has_setup`, so a single (non-list) tfm is accepted too
    tfms = L(tfms)
    last = has_setup(tfms)
    if last == -1:  # no setup found
        return

    for i, tfm in enumerate(tfms):
        # same truthiness test as `has_setup`: a `setup = None` attribute is skipped, not called
        setup = getattr(tfm, 'setup', None)
        if setup:
            setup(items)
        if i == last:
            break  # no later tfm needs set-up, so stop transforming items
        # tfm items to be fed into next tfm
        items = [tfm(item) for item in items]

In [None]:
#export
class TorchDatasetBuilder:
    """Build torch compatible train and test datasets with transforms.

    Items are `get_items(source)`, or `source` itself when `get_items` is None.
    `splitter(items)` returns the train and test index collections. The x pipeline is
    `x_type_tfms` + (`x_train_tfms` or `x_test_tfms`) + `x_tfms`, and the y pipeline is
    `y_tfms`. Each pipeline is composed with `torchvision.transforms.Compose`, and every tfm
    argument may be None, a single callable or a list. When `do_setup` is set, tfms with a
    `setup` method are set up on the training items (y tfms always, x tfms only with
    `setup_x=True`), after which `do_setup` is cleared.
    """
    def __init__(self, source, get_items, splitter,
                x_tfms, y_tfms,
                x_type_tfms=None,
                x_train_tfms=None, x_test_tfms=None,
                do_setup=False):
        self.source = source
        self.get_items = get_items
        self.splitter = splitter
        self.do_setup = do_setup
        self.x_tfms = to_list(x_tfms)
        self.y_tfms = to_list(y_tfms)
        self.x_type_tfms = to_list(x_type_tfms)
        self.x_train_tfms = to_list(x_train_tfms)
        self.x_test_tfms = to_list(x_test_tfms)

    def setup(self, items, do_setup=None, setup_x=False):
        "Set up the tfms on `items` if `do_setup` (the argument, else the stored flag) is set; clears the flag afterwards"
        self.do_setup = do_setup if do_setup is not None else self.do_setup
        if self.do_setup:
            all_x_tfms = [*self.x_type_tfms, *self.x_train_tfms, *self.x_tfms]
            if setup_x:
                run_setups(all_x_tfms, items)
            run_setups(self.y_tfms, items)
            self.do_setup = False

    @staticmethod
    def _select(items, idxs):
        "Items at positions `idxs`, always returned as a tuple"
        # `itemgetter(*idxs)(items)` returned a bare item for a single index (which
        # TfmdTorchDS would then index into) and raised TypeError for an empty split
        return tuple(items[i] for i in idxs)

    def get_datasets(self, do_setup=None):
        "Return `(train_ds, test_ds)` as `TfmdTorchDS`, running tfm setups on the train items if requested"
        self.do_setup = do_setup if do_setup is not None else self.do_setup

        items = self.get_items(self.source) if self.get_items is not None else self.source

        train_idxs, test_idxs = self.splitter(items)

        train_items = self._select(items, train_idxs)
        test_items = self._select(items, test_idxs)
        self.setup(train_items)
        allx_test_tfms = [*self.x_type_tfms, *self.x_test_tfms, *self.x_tfms]
        allx_train_tfms = [*self.x_type_tfms, *self.x_train_tfms, *self.x_tfms]
        train_x_tfm = thv.transforms.Compose(allx_train_tfms)
        test_x_tfm = thv.transforms.Compose(allx_test_tfms)
        y_tfm = thv.transforms.Compose(self.y_tfms)
        train_ds = TfmdTorchDS(train_items, x_tfm=train_x_tfm, y_tfm=y_tfm)
        test_ds = TfmdTorchDS(test_items, x_tfm=test_x_tfm, y_tfm=y_tfm)
        return train_ds, test_ds

In [None]:
#export
from fastai.data.transforms import CategoryMap

class VocabularyMapper:
    """A simplified version of the fastai Categorize Transform.

    Maps a label `o` to `torch.tensor(vocab.o2i[o])`. `setup(items)` builds the vocab as a
    fastai `CategoryMap` of `items`. While no vocab is set, labels pass through unchanged.
    `c` is the number of classes (0 while there is no vocab).
    """
    def __init__(self, vocab=None):
        self.vocab = vocab
        # keep `c` consistent with a vocab supplied up front (it used to stay 0 until `setup`)
        self.c = len(vocab) if vocab is not None else 0
    def setup(self, items):
        self.vocab = CategoryMap(items)
        self.c = len(self.vocab)
    def __call__(self, o):
        if self.vocab is None: return o
        try:
            return torch.tensor(self.vocab.o2i[o])
        except KeyError as e:
            raise KeyError(f"Label '{o}' was not included in the training dataset") from e

In [None]:
import torchvision as thv

# torchvision transforms used for the x pipeline of this test
pil2tensor = thv.transforms.ToTensor()
resize28 = thv.transforms.Resize(28)
norm = thv.transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))

from fastai.vision.core import PILImage
from fastai.data.transforms import get_image_files, GrandparentSplitter, parent_label
from fastai.data.external import untar_data, URLs

# MNIST_TINY: split by grandparent folder, labelled by parent folder ('3' / '7')
path = untar_data(URLs.MNIST_TINY)
mnist_dset_builder =  TorchDatasetBuilder(
                source=path, 
                get_items=get_image_files, 
                splitter=GrandparentSplitter(),
                x_tfms=[resize28,pil2tensor,norm,], 
                y_tfms=[parent_label,VocabularyMapper(),],
                x_type_tfms=PILImage.create)

from fastcore.test import test_eq

# do_setup=True builds the VocabularyMapper vocab from the (parent_label-ed) training items
train_ds, test_ds = mnist_dset_builder.get_datasets(do_setup=True)

test_eq(len(train_ds),709)
test_eq(len(test_ds),699)
test_eq(mnist_dset_builder.y_tfms[1].vocab, ('3','7'))
test_eq(mnist_dset_builder.y_tfms[1].c, 2)
# dataset samples must equal the manually composed y and x pipelines
test_eq(train_ds[0][1],mnist_dset_builder.y_tfms[1](parent_label(train_ds.items[0])))
test_eq(train_ds[0][0],norm(pil2tensor(resize28(PILImage.create(train_ds.items[0])))))

In [None]:
#export
import torch.utils.data as th_data
from fastcore.basics import patch_to
@patch_to(th_data.DataLoader)
def to(self, device):
    """move torch dataloader to device (for compatibility with fastai dataloader)

    Only records `device` as `self.device`; batches are not moved here. Returns `self`,
    mirroring fastai's `DataLoader.to`, so calls can be chained.
    """
    self.device = device
    return self

In [None]:
#hide_input
#colab
show_doc(th_data.DataLoader.to)

<h4 id="DataLoader.to" class="doc_header"><code>DataLoader.to</code><a href="__main__.py#L4" class="source_link" style="float:right">[source]</a></h4>

> <code>DataLoader.to</code>(**`device`**)

move torch dataloader to device (for compatibility with fastai dataloader)

In [None]:
#export
def make_torch_dataloaders(train_dataset, test_dataset,
                     rank,
                     world_size,
                     bs,
                     num_workers=4,
                     distrib=True,
                     sync_valid=False):
    """make torch-based distributed dataloaders from torch compatible datasets

    With `distrib` the training set is sharded over `world_size` replicas (this one being
    `rank`) by a shuffling `DistributedSampler`. The test set is sharded the same way, but
    without shuffling, only when `sync_valid` is set; otherwise each replica iterates the
    whole test set sequentially. Without `distrib` the training loader simply shuffles.
    Every loader uses `batch_size=bs`, `num_workers` workers and `drop_last=True`
    (presumably to keep batch shapes fixed on XLA; confirm before changing). The result is
    wrapped in fastai `DataLoaders` with `device=None`.
    """
    def _loader(dataset, shuffle, shard):
        # `shard` selects a DistributedSampler; otherwise the DataLoader (un)shuffles itself
        if not shard:
            return th_data.DataLoader(dataset, batch_size=bs, shuffle=shuffle,
                                      num_workers=num_workers, drop_last=True)
        sampler = th_distrib.DistributedSampler(dataset, num_replicas=world_size,
                                                rank=rank, shuffle=shuffle)
        return th_data.DataLoader(dataset, batch_size=bs, sampler=sampler,
                                  num_workers=num_workers, drop_last=True)

    train_loader = _loader(train_dataset, shuffle=True, shard=distrib)
    test_loader = _loader(test_dataset, shuffle=False, shard=distrib and sync_valid)
    return DataLoaders(train_loader, test_loader, device=None)

In [None]:
#exporti
import re

In [None]:
#export
class FileNamePatternLabeller:
    """Delayed action version of fastai RegexLabeller with file name selection.

    Calling an instance on a path (`str` or `Path`) applies the regex `pat_str` to the file
    *name* only and returns its first capture group. With `match=True` the pattern must match
    at the start of the name (`re.match`); otherwise it may match anywhere (`re.search`).
    The regex is compiled lazily, on the first call.
    """
    def __init__(self, pat_str, match=False):
        self.pat_str = pat_str
        self.match = match
        self.matcher = None
        self.pat = None
    def __call__(self, f):
        if isinstance(f, str):
            f = Path(f)
        o = f.name
        if self.pat is None:
            self.pat = re.compile(self.pat_str)
            self.matcher = self.pat.match if self.match else self.pat.search
        res = self.matcher(o)
        # explicit check instead of `assert`: asserts are stripped under `python -O`, which
        # turned a non-matching name into an opaque AttributeError on `None.group`
        if res is None:
            raise AssertionError(f'Failed to find "{self.pat}" in {o}')
        return res.group(1)

## Test Model Training using Torch Dataloaders

In [None]:
#colab
from fastai.vision.all import *
from fastai_xla_extensions.multi_core.base import *
from fastai_xla_extensions.misc_utils import * # patch _BaseOptimizer.__get_state__ and __setstate__
from my_timesaver_utils.profiling import *
from my_timesaver_utils.profiling_callback import *

In [None]:
#hide
#colab
%cd /content

/content


In [None]:
#colab
from fastai.learner import Learner
from fastai.metrics import accuracy

def train_torch_model(rank):
    """Per-replica training entry point (run through `_mp_fn2` by `xmp.spawn`).

    Hyper-parameters come from the global `FLAGS` dict; the model, dataset builder, loss and
    optimizer come from the globals `WRAPPED_MODEL`, `DSET_BUILDER`, `LOSS_FUNC` and `OPT_FUNC`.
    The function builds datasets and dataloaders for replica `rank` and a fastai `Learner`
    moved to the XLA device. It then runs `fit_one_cycle` and saves the weights as 'stage-1'.
    With `FLAGS['is_profiling']`, dataset building, dataloader building and fitting are each
    timed with `start_record`/`end_record`.
    """
    torch.manual_seed(1)
    xm.rendezvous('start_train_torch_model')
    cfg = FLAGS
    # learning rate is scaled linearly with the number of cores
    lr = cfg['learning_rate'] * xm.xrt_world_size()
    profiling = cfg['is_profiling']
    sync_valid = cfg['sync_valid']

    device = xm.xla_device()
    model = WRAPPED_MODEL.to(device)
    batch_size = cfg['batch_size']
    n_replicas = xm.xrt_world_size()
    mom = cfg['momentum']
    weight_decay = cfg['weight_decay']
    workers = cfg['num_workers']

    def profiled(stage, action):
        # run `action()`; when profiling, bracket it with a named timing record and report it
        if not profiling:
            return action()
        rec = f'rank{rank}_{stage}'
        print(f'start {rec}')
        start_record(rec)
        result = action()
        end_record(rec)
        print_prof_data(rec)
        print(f'finished {rec}')
        return result

    dsets = profiled('dset_build', lambda: DSET_BUILDER.get_datasets())
    dls = profiled('dataloader_build', lambda: make_torch_dataloaders(
        *dsets,
        rank=rank,
        world_size=n_replicas,
        bs=batch_size,
        num_workers=workers,
        sync_valid=sync_valid,
    ))

    xm.master_print('build learner')
    learner = Learner(dls, model,
                      loss_func=LOSS_FUNC,
                      opt_func=OPT_FUNC,
                      metrics=accuracy,
                      wd=weight_decay,
                      moms=(mom, mom, mom))
    learner.to_multi_xla(device, rank=xm.get_ordinal(), sync_valid=sync_valid)
    # only the first replica collects the per-callback profile
    profile_here = profiling and rank == 0
    if profile_here:
        learner.to_my_profile()

    n_epochs = cfg['num_epochs']
    xm.master_print('start running fit')
    learner.unfreeze()
    profiled('run_fit', lambda: learner.fit_one_cycle(n_epochs, lr_max=slice(lr/10)))

    learner.save('stage-1')
    if profile_here:
        learner.my_profile.print_stats()
    xm.mark_step()



In [None]:
#colab
# Start training processes
def _mp_fn2(rank, flags):
    "`xmp.spawn` target: publish `flags` as the module-level `FLAGS` read by `train_torch_model`, then train replica `rank`"
    globals()['FLAGS'] = flags
    train_torch_model(rank)



In [None]:
import torch
from fastcore.transform import DisplayedTransform, Transform
from fastcore.basics import store_attr
from fastai.vision.core import PILImage, PILBase, image2tensor
from fastai.data.block import TransformBlock

In [None]:
from fastai.data.transforms import get_c
# from fastai.vision.all import *
from fastai.data.block import DataBlock, CategoryBlock
from fastai.vision.data import ImageBlock
from fastai.data.transforms import get_image_files, parent_label, GrandparentSplitter
from fastai.vision.augment import Resize, aug_transforms
from fastai.data.external import untar_data, URLs
from fastai.data.transforms import Normalize
from fastai.vision.core import imagenet_stats

In [None]:
import torch.nn as nn
LOSS_FUNC = nn.CrossEntropyLoss()

In [None]:
from fastai.optimizer import Adam
OPT_FUNC = Adam

In [None]:
from fastai.data.transforms import RandomSplitter

In [None]:
from fastai.vision.learner import create_cnn_model
from fastai.vision.models import resnet34

In [None]:
import os
# Define Parameters (read by `train_torch_model` through `_mp_fn2`)
FLAGS = dict(
    sync_valid=True,
    is_profiling=True,
    batch_size=64,          # batch size of each replica's dataloaders
    num_workers=4,
    learning_rate=1e-3,     # base lr; multiplied by the core count in `train_torch_model`
    image_size=224,
    momentum=0.85,
    weight_decay=2e-3,
    num_epochs=5,
    # 8 cores when the TPU_NAME env var is set (assumed to mean a TPU runtime), else 1 process
    num_cores=8 if os.environ.get('TPU_NAME', None) else 1,
)

ARCH = resnet34
USE_DBLOCK = False

In [None]:
from pathlib import Path
from fastcore.xtras import *


In [None]:
#colab
PATH = untar_data(URLs.PETS)/'images'
# PATH = untar_data(URLs.MNIST)
# PATH = untar_data(URLs.MNIST_TINY)


In [None]:
#colab
imagenet_norm = thv.transforms.Normalize(
    mean=(0.485, 0.456, 0.406), 
    std=(0.229, 0.224, 0.225))

cifar_norm = thv.transforms.Normalize(
    mean=(0.4914, 0.4822, 0.4465), 
    std=(0.2023, 0.1994, 0.2010))

image_size = FLAGS['image_size']
splitter = RandomSplitter(seed=42)
# Pets file names are expected to look like `<breed>_<number>.jpg`; the breed is group 1.
# The dot is escaped so it matches a literal '.' only (it previously matched any character).
pat = r'(.+)_\d+\.jpg$'
fname_labeller = FileNamePatternLabeller(pat)

DSET_BUILDER = TorchDatasetBuilder(
    PATH, 
    get_items=get_image_files,
    splitter=splitter,
    x_tfms=[thv.transforms.Resize((image_size,image_size)), thv.transforms.ToTensor(), imagenet_norm],
    y_tfms=[fname_labeller, VocabularyMapper(),],
    x_type_tfms=PILImage.create,
) 
# Build the label vocab once here, from all image files, before spawning. `setup` clears
# `do_setup`, so replicas forked from this process reuse this vocab instead of rebuilding it.
start_record('master_vocab_setup')
DSET_BUILDER.setup(get_image_files(PATH),do_setup=True)
end_record('master_vocab_setup')
print_prof_data('master_vocab_setup')
clear_prof_data()
N_OUT = DSET_BUILDER.y_tfms[1].c     

Function master_vocab_setup called 1 times.
Execution time max: 0.055, average: 0.055


In [None]:
#colab
assert N_OUT is not None and N_OUT > 0,f'N_OUT {N_OUT} should be > 0'

In [None]:
#colab
custom_model = create_cnn_model(ARCH, N_OUT, 
                                pretrained=True,
                                concat_pool=False)


Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to /root/.cache/torch/hub/checkpoints/resnet34-333f7ec4.pth


HBox(children=(FloatProgress(value=0.0, max=87306240.0), HTML(value='')))




In [None]:
#colab
# Only instantiate model weights once in memory.
WRAPPED_MODEL = xmp.MpModelWrapper(custom_model)

In [None]:
#colab
%%time
FLAGS['is_profiling'] = False
# !rm -f /content/models/stage-1.pth
xmp.spawn(_mp_fn2, args=(FLAGS,), nprocs=FLAGS['num_cores'],
        start_method='fork')


build learner
start running fit
start fit


epoch,train_loss,valid_loss,accuracy,time
0,0.790961,1.388766,0.628906,01:28
1,0.661789,1.147588,0.726562,01:19
2,0.602246,0.583778,0.830078,01:21
3,0.505505,0.392288,0.873047,01:22
4,0.42194,0.340927,0.892578,01:25


CPU times: user 112 ms, sys: 131 ms, total: 243 ms
Wall time: 8min 2s


In [None]:
#colab
# Rebuild datasets and dataloaders in this single process to evaluate the saved weights.
# `distrib` defaults to True, so with world_size=1 / rank=0 the DistributedSampler covers the
# whole training set; `sync_valid` is left False, so validation uses a plain sequential loader.
mdsets = DSET_BUILDER.get_datasets()
mdls = make_torch_dataloaders(*mdsets,
                                rank=0,
                                world_size=1,
                                bs=FLAGS['batch_size'],
                                num_workers=FLAGS['num_workers']
                                )

In [None]:
#colab
# Single-process Learner around the same `custom_model` instance, with the same loss,
# optimizer and hyper-parameters as the multi-core run; it loads the 'stage-1' weights saved there
mlearner = Learner(mdls, custom_model, 
                    loss_func=LOSS_FUNC, 
                    opt_func=OPT_FUNC, 
                    metrics=accuracy, 
                    wd=FLAGS['weight_decay'],
                    moms=(FLAGS['momentum'],FLAGS['momentum'],FLAGS['momentum']))
mlearner.load('stage-1');

In [None]:
#colab
mlearner.dls.device

In [None]:
from fastai.torch_core import one_param

In [None]:
#colab
one_param(mlearner.model).device

device(type='cpu')

In [None]:
#colab
%%time
valid_metrics = mlearner.validate();print(valid_metrics)

[0.2993701994419098, 0.901494562625885]
CPU times: user 3min 30s, sys: 3.18 s, total: 3min 33s
Wall time: 3min 38s
