In [None]:
#hide
#colab
# attach gdrive holding repo
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
#default_exp multi_core.learner

# Multi Core XLA Learner extensions

<a href="https://colab.research.google.com/github/butchland/fastai_xla_extensions/blob/master/nbs/03b_multi_core.learner.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

> `Learner` method patches, prefixed with `xla_`, that run `fit` and related training methods on multiple TPU cores.

> These provide an alternate way to run multi core operations with minimal changes to existing fastai notebooks.

In [None]:
#hide
#colab
!pip install -Uqq cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp37-cp37m-linux_x86_64.whl

[K     |████████████████████████████████| 133.6MB 93kB/s 
[K     |████████████████████████████████| 61kB 2.9MB/s 
[?25h

In [None]:
#hide
#colab
# !pip install -Uqq git+https://github.com/fastai/fastai.git 
!pip install -Uqq fastai --upgrade

[K     |████████████████████████████████| 194kB 4.7MB/s 
[K     |████████████████████████████████| 61kB 3.4MB/s 
[?25h

In [None]:
#hide
#colab
!pip install -qqq nbdev

[K     |████████████████████████████████| 51kB 2.0MB/s 
[K     |████████████████████████████████| 51kB 2.9MB/s 
[?25h

In [None]:
#hide
#colab
!pip install -Uqq git+https://github.com/butchland/my_timesaver_utils.git

  Building wheel for my-timesaver-utils (setup.py) ... [?25l[?25hdone


In [None]:
#hide
#colab
# NOTE(review): this pipes a remote setup script from the fastai v1 course (course19)
# straight into bash; its output ("Updating fastai...") shows it touches the fastai
# install done above — confirm it is still needed and doesn't change the fastai version
!curl -s https://course19.fast.ai/setup/colab | bash

Updating fastai...
Done.


In [None]:
#hide
!pip freeze | grep torch
!pip freeze | grep fast

torch==1.7.1+cu101
torch-xla==1.7
torchsummary==1.5.1
torchtext==0.3.1
torchvision==0.8.2+cu101
fastai==2.2.7
fastcore==1.3.19
fastdtw==0.3.4
fastprogress==1.0.0
fastrelease==0.1.11
fastrlock==0.5


In [None]:
#hide
#colab
%cd /content
!ln -s /content/drive/MyDrive/fastai_xla_extensions fastai_xla_extensions

/content


In [None]:
#hide
# Start of kernel

In [None]:
#hide
#colab
%cd /content/fastai_xla_extensions

/content/drive/MyDrive/fastai_xla_extensions


In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#exporti
try:
    import torch_xla
except ImportError:
    pass

In [None]:
#export
from fastai_xla_extensions.utils import xla_imported



In [None]:
#export
from fastai_xla_extensions.multi_core.base import *
from fastai_xla_extensions.misc_utils import *
from fastai_xla_extensions.multi_core.callback import *

In [None]:
#exporti
# import sys
# def xla_imported():
#     return 'torch_xla' in sys.modules

In [None]:
#hide
#colab
%cd /content

/content


In [None]:
#hide

if not xla_imported():
    # torch_xla is not installed: build minimal stand-ins for the torch_xla
    # modules used here (`xm`, `met`, `pl`, `xmp`, `xu`) so the cells in this
    # notebook can still run in a single process on CPU/GPU.
    from types import SimpleNamespace
    torch_xla = SimpleNamespace (
    )
    from typing import Union,BinaryIO
    import os
    import pickle
    import torch.cuda

    def fake_opt_step(opt,barrier=False):
        "stand-in for `xm.optimizer_step`: a plain `opt.step()`"
        opt.step()

    def fake_device(n=None, devkind=None):
        "stand-in for `xm.xla_device`: current CUDA device if available, else CPU"
        gpu_available = torch.cuda.is_available()
        if gpu_available:
            return torch.device(torch.cuda.current_device())
        return torch.device('cpu')

    def fake_save(obj, f: Union[str, os.PathLike, BinaryIO],
                master_only=True, global_master=False):
        "stand-in for `xm.save`: plain `torch.save` (`master_only`/`global_master` are ignored)"
        return torch.save(obj,f,pickle_module=pickle,
                        pickle_protocol=2,
                        _use_new_zipfile_serialization=True)

    # dummy throughput numbers reported by the fake RateTracker
    def fake_rate():
        return 230.20

    def fake_global_rate():
        return 830.10

    def fake_add(*args,**kwargs):
        pass

    def fake_RateTracker():
        "stand-in for `xm.RateTracker` reporting fixed rates"
        return SimpleNamespace(
            rate = fake_rate,
            global_rate = fake_global_rate,
            add = fake_add
        )

    def fake_xrt_world_size():
        "a single process, so the world size is always 1"
        return 1

    def fake_get_ordinal():
        "a single process, so its ordinal is always 0"
        return 0

    def fake_rendezvous(tag, payload=b'', replicas=None):
        "stand-in for `xm.rendezvous`: with one process there is nothing to wait for"
        return [payload]

    def fake_mark_step():
        "stand-in for `xm.mark_step`: nothing to do without XLA"
        pass

    xm = SimpleNamespace(
        optimizer_step = fake_opt_step,
        xla_device = fake_device,
        save = fake_save,
        RateTracker = fake_RateTracker,
        master_print = print,
        xrt_world_size = fake_xrt_world_size,
        get_ordinal = fake_get_ordinal,
        # `xla_run_method` calls both of these; without them the fake path
        # fails with AttributeError
        rendezvous = fake_rendezvous,
        mark_step = fake_mark_step
    )

    def fake_metrics_report():
        return "Fake Metrics Report \n\n\n\n"
    met = SimpleNamespace (
        metrics_report = fake_metrics_report
    )

    class FakeParallelLoader:
        "stand-in for `pl.ParallelLoader`: hands back the wrapped loader unchanged"
        def __init__(self, loader, *args, **kwargs):
            self.loader = loader
        def per_device_loader(self,device):
            return self.loader

    pl = SimpleNamespace(
        ParallelLoader = FakeParallelLoader
    )

    def fake_MpModelWrapper(o):
        "stand-in for `xmp.MpModelWrapper`: returns the model itself"
        return o

    def fake_run(f,*args, **kwargs):
        return f(*args,**kwargs)

    def fake_MpSerialExecutor():
        "stand-in for `xmp.MpSerialExecutor`: `run` just calls the function"
        return SimpleNamespace(
            run = fake_run
        )

    def fake_spawn(f, args=None, nprocs=0, start_method=None):
        "stand-in for `xmp.spawn`: call `f` once, in-process, as ordinal 0 (`nprocs` is ignored)"
        # `args` defaults to None, which cannot be star-unpacked
        return f(0, *(args or ()))

    xmp = SimpleNamespace (
        MpModelWrapper = fake_MpModelWrapper,
        MpSerialExecutor = fake_MpSerialExecutor,
        spawn = fake_spawn
    )

    xu = SimpleNamespace (
    )


In [None]:
#exporti
if xla_imported():
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp

## Add master_cbs property to Learner 

> Master callbacks are callbacks that are executed only on the master ordinal (the rank 0 process).

This means existing fastai notebooks should be checked to see whether any additional callbacks they use
could conflict when run in several processes at the same time.

Note that for default callbacks (`TrainEvalCallback`, `Recorder`, `ProgressCallback`) only `ProgressCallback` causes this problem. 

However, the `fastai_xla_extensions.multi_core.base` module already handles this:
when it is used (which it is, by default), the `ProgressCallback` is attached only on the master ordinal.

Moreover, the `Recorder` callback is also handled: `fastai_xla_extensions.multi_core.base.SyncRecorderCallback` collates the validation losses and metrics so that they are reported correctly at the end of each epoch.


In [None]:
#exporti
from fastcore.basics import patch
from fastai.learner import Learner
from fastcore.meta import delegates
from fastcore.foundation import L


In [None]:
#export 
from fastai.learner import Learner
from fastcore.basics import patch
@patch(as_prop=True)
def master_cbs(self:Learner):
    "`L` of callbacks that run only on the master ordinal (rank 0)"
    try:
        return self._master_cbs
    except AttributeError:
        # first access: create the (initially empty) backing collection
        self._master_cbs = L()
        return self._master_cbs


In [None]:
#hide_input
from fastai.learner import Learner
show_doc(Learner.master_cbs)

<h4 id="Learner.master_cbs" class="doc_header"><code>Learner.master_cbs</code><a href="" class="source_link" style="float:right">[source]</a></h4>

list all cbs to be run on the master ordinal thread

In [None]:
#export
@patch
def add_master_cb(self:Learner, cb):
    "add a master callback (a callback class is instantiated first); returns the learner"
    if isinstance(cb, type): cb = cb()
    # go through the `master_cbs` property so the backing `L` is created lazily
    self.master_cbs.append(cb)
    return self

@patch
def add_master_cbs(self:Learner, cbs):
    "add master callbacks; returns the learner (like `Learner.add_cbs`)"
    L(cbs).map(self.add_master_cb)
    return self



In [None]:
#hide_input
show_doc(Learner.add_master_cb)
show_doc(Learner.add_master_cbs)

<h4 id="Learner.add_master_cb" class="doc_header"><code>Learner.add_master_cb</code><a href="__main__.py#L2" class="source_link" style="float:right">[source]</a></h4>

> <code>Learner.add_master_cb</code>(**`cb`**)

add a master callback

<h4 id="Learner.add_master_cbs" class="doc_header"><code>Learner.add_master_cbs</code><a href="__main__.py#L12" class="source_link" style="float:right">[source]</a></h4>

> <code>Learner.add_master_cbs</code>(**`cbs`**)

add master callbacks

In [None]:
#export

@patch
def grab_master_cbs(self:Learner, cb_cls):
    "find instances of `cb_cls` in master_cbs"
    # use the `master_cbs` property (not `_master_cbs`) so this also works
    # before any master callback has been added
    return L(cb for cb in self.master_cbs if isinstance(cb, cb_cls))

@patch
def remove_master_cb(self:Learner, cb):
    "remove a cb (or, if `cb` is a class, every instance of it) from master callbacks"
    if isinstance(cb, type): self.remove_master_cbs(self.grab_master_cbs(cb))
    elif cb in self.master_cbs: self.master_cbs.remove(cb)
    return self

@patch
def remove_master_cbs(self:Learner, cbs):
    "remove callbacks from master callbacks"
    L(cbs).map(self.remove_master_cb)
    return self

In [None]:
#hide_input
show_doc(Learner.grab_master_cbs)
show_doc(Learner.remove_master_cbs)
show_doc(Learner.remove_master_cb)

<h4 id="Learner.grab_master_cbs" class="doc_header"><code>Learner.grab_master_cbs</code><a href="__main__.py#L3" class="source_link" style="float:right">[source]</a></h4>

> <code>Learner.grab_master_cbs</code>(**`cb_cls`**)

find instance of `cb_cls` in master_cbs

<h4 id="Learner.remove_master_cbs" class="doc_header"><code>Learner.remove_master_cbs</code><a href="__main__.py#L18" class="source_link" style="float:right">[source]</a></h4>

> <code>Learner.remove_master_cbs</code>(**`cbs`**)

remove callbacks from master callbacks

<h4 id="Learner.remove_master_cb" class="doc_header"><code>Learner.remove_master_cb</code><a href="__main__.py#L8" class="source_link" style="float:right">[source]</a></h4>

> <code>Learner.remove_master_cb</code>(**`cb`**)

remove a cb from master callbacks

## Utility methods to implement XLA `fit` methods

In [None]:
#export

from fastai.callback.progress import ProgressCallback
from fastai.learner import Learner

def make_xla_child_learner(rank, sync_valid, learner_args, add_args, ctrl_args):
    """create a learner using passed parameters

    Intended to run inside each spawned process. `learner_args` is the dict
    built by `Learner.pack_learner_args`; it is copied before its special
    entries (`base_dls`, `wrapped_model`, `master_cbs`) are popped, so the
    caller's dict is left unchanged. `add_args` are merged into the new
    learner's `__stored_args__`. Master callbacks are attached only on rank 0.
    """
    learner_args = {**learner_args}
    device = xm.xla_device()
    world_size = xm.xrt_world_size()
    dls = build_distributed_dataloaders(learner_args.pop('base_dls'),
                                       rank, world_size, sync_valid=sync_valid)

    model = learner_args.pop('wrapped_model').to(device)
    master_cbs = learner_args.pop('master_cbs', None)
    if master_cbs is None:
        master_cbs = L()
    # the remaining entries are plain `Learner.__init__` arguments
    learner = Learner(dls, model, **learner_args)
    learner.__stored_args__ = {**learner.__stored_args__, **add_args}

    learner.to_multi_xla(device, rank, sync_valid=sync_valid)

    # drop the progress bar unless the parent learner had one
    if not ctrl_args['use_progress'] and 'progress' in L(learner.cbs).attrgot('name'):
        learner.remove_cbs(ProgressCallback)

    if rank == 0:
        learner.add_cbs(master_cbs)

    return learner

In [None]:
#export
def setup_fit_cbs(rank, fit_args):
    """add master cbs to cbs fit args if rank 0

    Returns a new dict: `master_cbs` is always removed; on rank 0 the master
    callbacks are appended after any `cbs`. `cbs` is only set in the result
    when there is at least one callback to pass. The input dict is not modified.
    """
    fit_args = {**fit_args}
    master_cbs = L(fit_args.pop('master_cbs', None))
    if rank != 0:
        master_cbs = L()
    cbs = L(fit_args.pop('cbs', None))
    if len(master_cbs) > 0 or len(cbs) > 0:
        fit_args['cbs'] = [*cbs, *master_cbs]
    return fit_args

In [None]:
#export
def xla_run_method(rank, fit_method, learner_args, add_args, fit_args, ctrl_args):
    "run `fit_method` on a child learner in a spawned process, then save its model"
    child = make_xla_child_learner(rank, True, learner_args, add_args, ctrl_args)
    fit_method(child, **setup_fit_cbs(rank, fit_args))
    # wait for all processes to finish fitting before saving the model
    xm.rendezvous('xla_run_method')
    child.save('_xla_tmp_model', rendezvous=False)
    xm.mark_step()
    

In [None]:
#export
from fastcore.basics import defaults, patch_to, patch

_extra_args = ['concat_pool', 'arch', 'n_out', 'pretrained','normalize']

@patch
def pack_learner_args(self:Learner):
    """pack learner args into dict to pass to spawned process

    Returns `(learner_args, add_args)`: `learner_args` holds the stored init
    args plus the wrapped model, base dls, non-default cbs and master cbs;
    `add_args` holds the `_extra_args` entries of `__stored_args__` that are
    not `Learner.__init__` arguments.
    """
    learner_args = {**self.__stored_args__}
    learner_args['wrapped_model'] = xmp.MpModelWrapper(self.model)
    learner_args['base_dls'] = self.dls
    # fetch only cbs not in defaults (ProgressCallback counts as a default).
    # Build a local list rather than appending to the global
    # `defaults.callbacks`, which would affect every Learner created later.
    default_cb_classes = list(defaults.callbacks)
    if ProgressCallback not in default_cb_classes:
        default_cb_classes.append(ProgressCallback)
    default_names = {cls().name for cls in default_cb_classes}
    learner_args['cbs'] = [cb for cb in self.cbs if cb.name not in default_names]

    learner_args['master_cbs'] = self.master_cbs

    # remove extra args from learner args (in __stored_args__ but not in init args)
    add_args = {arg: learner_args.pop(arg) for arg in _extra_args if arg in learner_args}
    return learner_args, add_args

In [None]:
#export
import os

@patch
def reload_child_model(self:Learner):
    """reload model built by spawned processes

    Loads `<path>/<model_dir>/_xla_tmp_model.pth` (written by the spawned
    processes) without optimizer state, deletes the file, and creates a fresh
    optimizer. Does nothing if the file does not exist.
    """
    # blatantly stolen from fastai LRFinder after_fit :)
    tmp_f = self.path/self.model_dir/'_xla_tmp_model.pth'
    if not tmp_f.exists(): return
    # `self.opt` can be None if no optimizer has been created yet
    if self.opt is not None: self.opt.zero_grad()
    self.load('_xla_tmp_model', with_opt=False)
    tmp_f.unlink()
    self.create_opt()

In [None]:
#export

from fastcore.foundation import L
from pathlib import Path

tmp_files = ['_paramsched_hps.pkl', '_rec_attr.pkl']
@patch
def delete_tmp_files(self:Learner):
    '''delete leftover temp files created by a spawned process, before a
    new run potentially recreates them (paths are relative to the cwd)'''
    stale = [p for p in map(Path, tmp_files) if p.is_file()]
    for p in stale:
        p.unlink()


@patch
def pre_xla_fit(self:Learner, ctrl_args=None):
    """prepare learner for running spawned processes

    Removes `ProgressCallback` if present, records whether it was removed in
    `ctrl_args['use_progress']`, deletes stale temp files, and returns
    `ctrl_args`. A dict passed in is updated in place; with no argument a new
    dict is created for every call.
    """
    # a `{}` default would be one dict shared (and mutated) across all calls
    if ctrl_args is None: ctrl_args = {}
    progress_removed = 'progress' in L(self.cbs).attrgot('name')
    if progress_removed:
        self.remove_cbs(ProgressCallback)
    ctrl_args['use_progress'] = progress_removed
    self.delete_tmp_files()
    return ctrl_args

@patch
def post_xla_fit(self:Learner, ctrl_args):
    "restore learner state after the spawned processes have finished"
    # reload recorder attributes and hyper-params (reload_* come from the
    # multi_core modules; presumably they read what the children saved)
    self.recorder.reload_attrs()
    self.recorder.reload_hps()
    # put back the progress bar that `pre_xla_fit` removed, if any
    if not ctrl_args['use_progress']: return
    self.add_cbs(ProgressCallback)

In [None]:
#export
def prep_fit_args(n_epoch, master_cbs, **kwargs):
    "build the fit kwargs dict (extra kwargs, then `master_cbs` and `n_epoch`) for spawned processes"
    return {**kwargs, 'master_cbs': master_cbs, 'n_epoch': n_epoch}

## XLA fit methods

In [None]:
#export

from fastcore.meta import delegates

# `but` must be a list: fastcore checks `name not in but`, so a comma-joined
# string does substring matching and hid `cbs` (inside 'master_cbs') from the
# delegated signature.
@patch
@delegates(Learner.fit, but=['num_cores','start_method','master_cbs'])
def xla_fit(self:Learner, n_epoch, num_cores=8,
            start_method='fork', master_cbs=None, **kwargs):
    """call fit in a multicore tpu environment

    Spawns `num_cores` processes that each build a child learner and run
    `Learner.fit` with `kwargs`; the model saved by the children is then
    loaded back into this learner. `master_cbs` run on rank 0 only.
    """
    ctrl_args = self.pre_xla_fit()
    learner_args, add_args = self.pack_learner_args()

    fit_args = prep_fit_args(n_epoch, master_cbs, **kwargs)

    xmp.spawn(xla_run_method,
              args=(Learner.fit, learner_args, add_args, fit_args, ctrl_args),
              nprocs=num_cores,
              start_method=start_method)

    self.reload_child_model()
    self.post_xla_fit(ctrl_args)

In [None]:
#export
from fastai.learner import Learner
from fastai.callback.schedule import *
# `but` must be a list: fastcore checks `name not in but`, so a comma-joined
# string does substring matching and hid `cbs` (inside 'master_cbs') from the
# delegated signature.
@patch
@delegates(Learner.fit_one_cycle, but=['num_cores','start_method','master_cbs'])
def xla_fit_one_cycle(self:Learner, n_epoch, num_cores=8,
                      start_method='fork', master_cbs=None, **kwargs):
    """call fit_one_cycle in a multicore tpu environment

    Spawns `num_cores` processes that each build a child learner and run
    `Learner.fit_one_cycle` with `kwargs`; the model saved by the children is
    then loaded back into this learner. `master_cbs` run on rank 0 only.
    """
    ctrl_args = self.pre_xla_fit()
    learner_args, add_args = self.pack_learner_args()

    fit_args = prep_fit_args(n_epoch, master_cbs, **kwargs)

    xmp.spawn(xla_run_method,
              args=(Learner.fit_one_cycle, learner_args, add_args, fit_args, ctrl_args),
              nprocs=num_cores,
              start_method=start_method)

    self.reload_child_model()
    self.post_xla_fit(ctrl_args)

In [None]:
#export
from fastai.learner import Learner
from fastai.callback.schedule import *
# `but` must be a list: fastcore checks `name not in but`, so a comma-joined
# string does substring matching and hid `cbs` (inside 'master_cbs') from the
# delegated signature.
@patch
@delegates(Learner.fit_flat_cos, but=['num_cores','start_method','master_cbs'])
def xla_fit_flat_cos(self:Learner, n_epoch, num_cores=8,
                      start_method='fork', master_cbs=None, **kwargs):
    """call fit_flat_cos in a multicore tpu environment

    Spawns `num_cores` processes that each build a child learner and run
    `Learner.fit_flat_cos` with `kwargs`; the model saved by the children is
    then loaded back into this learner. `master_cbs` run on rank 0 only.
    """
    ctrl_args = self.pre_xla_fit()
    learner_args, add_args = self.pack_learner_args()

    fit_args = prep_fit_args(n_epoch, master_cbs, **kwargs)

    xmp.spawn(xla_run_method,
              args=(Learner.fit_flat_cos, learner_args, add_args, fit_args, ctrl_args),
              nprocs=num_cores,
              start_method=start_method)

    self.reload_child_model()
    self.post_xla_fit(ctrl_args)

In [None]:
#export
from fastai.learner import Learner
from fastai.callback.schedule import *

def prep_fit_sgdr_args(n_cycles, cycle_len, master_cbs, **kwargs):
    "build the fit_sgdr kwargs dict (extra kwargs, then `master_cbs`, `n_cycles`, `cycle_len`) for spawned processes"
    args = dict(kwargs)
    args.update(master_cbs=master_cbs, n_cycles=n_cycles, cycle_len=cycle_len)
    return args

# `but` must be a list: fastcore checks `name not in but`, so a comma-joined
# string does substring matching and hid `cbs` (inside 'master_cbs') from the
# delegated signature.
@patch
@delegates(Learner.fit_sgdr, but=['num_cores','start_method','master_cbs'])
def xla_fit_sgdr(self:Learner, n_cycles, cycle_len, num_cores=8,
                      start_method='fork', master_cbs=None, **kwargs):
    """call fit_sgdr in multicore tpu environment

    Spawns `num_cores` processes that each build a child learner and run
    `Learner.fit_sgdr` with `kwargs`; the model saved by the children is then
    loaded back into this learner. `master_cbs` run on rank 0 only.
    """
    ctrl_args = self.pre_xla_fit()
    learner_args, add_args = self.pack_learner_args()
    fit_args = prep_fit_sgdr_args(n_cycles, cycle_len, master_cbs, **kwargs)

    xmp.spawn(xla_run_method,
              args=(Learner.fit_sgdr, learner_args, add_args, fit_args, ctrl_args),
              nprocs=num_cores,
              start_method=start_method)

    self.reload_child_model()
    self.post_xla_fit(ctrl_args)

In [None]:
#export
from fastai.learner import Learner
from fastai.callback.schedule import *

def prep_finetune_args(epochs, master_cbs, **kwargs):
    "build the fine_tune kwargs dict (extra kwargs, then `master_cbs` and `epochs`) for spawned processes"
    return dict(kwargs, master_cbs=master_cbs, epochs=epochs)

# `but` must be a list: fastcore checks `name not in but`, so a comma-joined
# string does substring matching rather than excluding whole names.
@patch
@delegates(Learner.fine_tune, but=['num_cores','start_method','master_cbs'])
def xla_fine_tune(self:Learner, epochs, num_cores=8,
                      start_method='fork', master_cbs=None, **kwargs):
    """call fine_tune in multicore tpu environment

    Spawns `num_cores` processes that each build a child learner and run
    `Learner.fine_tune` with `kwargs`; the model saved by the children is
    then loaded back into this learner. `master_cbs` run on rank 0 only.
    """
    ctrl_args = self.pre_xla_fit()
    learner_args, add_args = self.pack_learner_args()

    fit_args = prep_finetune_args(epochs, master_cbs, **kwargs)

    xmp.spawn(xla_run_method,
              args=(Learner.fine_tune, learner_args, add_args, fit_args, ctrl_args),
              nprocs=num_cores,
              start_method=start_method)

    self.reload_child_model()
    self.post_xla_fit(ctrl_args)

## Example: Train MNIST

In [None]:
#hide
#colab
from fastai.vision.all import *

In [None]:
#hide
#colab
%cd /content

/content


In [None]:
#colab
# fetch the small MNIST sample dataset used in this example
path = untar_data(URLs.MNIST_TINY)

In [None]:
#colab
data = DataBlock(
    blocks=(ImageBlock,CategoryBlock),   # image inputs, single-label category targets
    get_items=get_image_files,
    get_y=parent_label,                  # label = name of the image's parent folder
    splitter=GrandparentSplitter(),      # train/valid split by grandparent folder name
    item_tfms=Resize(28),
    batch_tfms=[]                        # no batch-level transforms/augmentation
)

In [None]:
#colab
# train/valid DataLoaders with batch size 16
dls = data.dataloaders(path, bs=16)

In [None]:
#colab
# concat_pool must be false due to a TPU bug that is triggered if using fastai AdaptivePool
from fastai.vision.learner import cnn_learner
from torchvision.models.resnet import resnet18


In [None]:
#colab
# concat_pool=False avoids the TPU bug triggered by fastai's adaptive (concat) pooling
learner = cnn_learner(dls, resnet18, metrics=accuracy, concat_pool=False)

In [None]:
#colab
# register SaveModelCallback as a master callback so only rank 0 writes 'best_model'
learner.add_master_cbs([SaveModelCallback(fname='best_model')])

In [None]:
#colab
#hide
# the xla_* patches defined above should now be available on Learner
assert hasattr(learner,'xla_fit')

In [None]:
#colab
class PrintValuesCallback(Callback):
    "debugging aid: after each epoch, print what the recorder and sync recorder hold"
    order = 56 # after recorder, sync recorder, before save model callback
    def after_epoch(self):
        print(f'final record: {self.learn.final_record}')
        values = self.recorder.values
        cb_names = L(self.cbs).attrgot('name')
        n_values = len(values)
        print(f'values len: {n_values}')
        if n_values > 0:
            latest = values[-1]
            n_latest = len(latest)
            print(f'values last idx len: {n_latest}')
            print(f'last idx: {latest}')
            if 'save_model' in cb_names:
                monitor_idx = self.save_model.idx
                print(f'save_model idx: {monitor_idx}')
                if monitor_idx < n_latest:
                    print(f'best_value: {values[-1][monitor_idx]}')
        if 'sync_recorder' not in cb_names: return
        log = self.sync_recorder.sync_log
        n_log = len(log)
        print(f'sync rec sync_log len: {n_log}')
        print(f'sync rec sync_log: {log}')
        if n_log > 0:
            print(f'sync rec sync_log[1:]: {log[1:]}')


In [None]:
#colab
# cbs = [PrintValuesCallback(), SaveModelCallback(fname='best_model')]
# NOTE(review): `cbs` is not passed to any of the fit calls below, so
# PrintValuesCallback is never actually used in this notebook
cbs = [PrintValuesCallback()]

In [None]:
#colab
# multi-core one-cycle training; master cbs (SaveModelCallback) run on rank 0 only
learner.xla_fit_one_cycle(5,lr_max=slice(2e-3))


start fit


epoch,train_loss,valid_loss,accuracy,time
0,0.248156,0.515241,0.81392,00:18
1,0.200013,0.743561,0.538352,00:03
2,0.24693,0.687081,0.596591,00:04
3,0.261705,0.880469,0.538352,00:04
4,0.295985,1.625116,0.511364,00:04


Better model found at epoch 0 with valid_loss value: 0.5152414441108704.


In [None]:
#colab
# validation accuracy using the weights reloaded from the spawned processes
res = learner.get_preds()
print(accuracy(*res))

TensorBase(0.7568)


In [None]:
#colab
# load the checkpoint written by SaveModelCallback (the warning below says it has no optimizer state)
learner.load('best_model')

  elif with_opt: warn("Saved filed doesn't contain an optimizer state.")


<fastai.learner.Learner at 0x7f70b9c63e50>

In [None]:
#colab
# re-check validation accuracy with the best checkpoint loaded
res = learner.get_preds()
print(accuracy(*res))

TensorBase(0.7568)


In [None]:
#hide
#colab
# inspect trainable layers: the pretrained conv layers are still frozen
learner.summary()

Sequential (Input shape: 16)
Layer (type)         Output Shape         Param #    Trainable 
                     16 x 64 x 14 x 14   
Conv2d                                    9408       False     
BatchNorm2d                               128        True      
ReLU                                                           
MaxPool2d                                                      
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      False     
BatchNorm2d                      

In [None]:
#colab
# make every layer trainable for the next fit
learner.unfreeze()

In [None]:
#hide
#colab
# confirm every layer is now trainable
learner.summary()

Sequential (Input shape: 16)
Layer (type)         Output Shape         Param #    Trainable 
                     16 x 64 x 14 x 14   
Conv2d                                    9408       True      
BatchNorm2d                               128        True      
ReLU                                                           
MaxPool2d                                                      
Conv2d                                    36864      True      
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      True      
BatchNorm2d                               128        True      
Conv2d                                    36864      True      
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      True      
BatchNorm2d                      

In [None]:
#colab
# the parent learner's model stays on the CPU; only the spawned processes move their copy to the XLA device
one_param(learner.model).device

device(type='cpu')

In [None]:
#colab
# multi-core `fit` on the unfrozen model
learner.xla_fit(n_epoch=5, lr=2e-3)

start fit


epoch,train_loss,valid_loss,accuracy,time
0,0.050473,10.583854,0.492898,00:21
1,0.047812,2.253268,0.72017,00:03
2,0.056125,0.019719,0.991477,00:03
3,0.051844,0.034788,0.990057,00:03
4,0.047623,0.010792,0.997159,00:04


Better model found at epoch 0 with valid_loss value: 10.583853721618652.
Better model found at epoch 1 with valid_loss value: 2.253268003463745.
Better model found at epoch 2 with valid_loss value: 0.01971946842968464.
Better model found at epoch 4 with valid_loss value: 0.010792052373290062.


In [None]:
#colab
# evaluate (loss, accuracy) on the validation set in this parent process
learner.validate()

(#2) [0.012101042084395885,0.9942775368690491]

## Train using torch datasets and dataloaders


In [None]:
from pathlib import Path
# settings for the plain-PyTorch CIFAR-10 example below
FLAGS = {}
FLAGS['batch_size']  = 64
FLAGS['num_workers'] = 4    # DataLoader worker processes
FLAGS['data_dir'] = Path('/content/data/cifar')   # download/extract location for CIFAR-10

In [None]:
from torchvision import datasets, transforms

In [None]:
def get_dataset(data_dir=None):
    """Build the CIFAR-10 train and test datasets, downloading them if needed.

    `data_dir` is the download/extract location; it defaults to
    `FLAGS['data_dir']`. The train split gets random-crop + horizontal-flip
    augmentation; both splits are converted to tensors and normalized with the
    same per-channel mean/std. Returns `(train_dataset, test_dataset)`.
    """
    if data_dir is None: data_dir = FLAGS['data_dir']
    norm = transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        norm,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        norm,
    ])
    train_dataset = datasets.CIFAR10(
        root=data_dir,
        train=True,
        download=True,
        transform=transform_train)
    test_dataset = datasets.CIFAR10(
        root=data_dir,
        train=False,
        download=True,
        transform=transform_test)

    return train_dataset, test_dataset


In [None]:
#colab
# downloads CIFAR-10 into FLAGS['data_dir'] on first run
train_dataset, test_dataset = get_dataset()


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /content/data/cifar/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /content/data/cifar/cifar-10-python.tar.gz to /content/data/cifar
Files already downloaded and verified


In [None]:
#colab
# NOTE(review): drop_last=True keeps every batch the same size (presumably to avoid
# XLA recompilation on a ragged final batch); for test_loader it also means the last
# partial batch is left out of validation metrics — confirm that is intended
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=FLAGS['batch_size'],
#   sampler=train_sampler,
    shuffle=True,
    num_workers=FLAGS['num_workers'],
    drop_last=True)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=FLAGS['batch_size'],
    shuffle=False,
    num_workers=FLAGS['num_workers'],
    drop_last=True)





In [None]:
#colab
# fastai dls using torch dataloaders
# wrap the plain PyTorch loaders as fastai DataLoaders (train, valid)
dls = DataLoaders(train_loader, test_loader)

In [None]:
#colab
# n_out and loss_func are given explicitly, presumably because the plain torch
# loaders carry no fastai vocab/target type to infer them from (CIFAR-10 has 10 classes)
learner = cnn_learner(dls, resnet18, metrics=accuracy, 
                      n_out=10, 
                      loss_func=nn.CrossEntropyLoss(),
                      concat_pool=False 
                      )

In [None]:
#colab
# multi-core training on CIFAR-10 using the torch-loader-backed DataLoaders
learner.xla_fit(5,lr=2e-2)

start fit


epoch,train_loss,valid_loss,accuracy,time
0,1.27558,1.169747,0.596968,01:30
1,1.121143,0.998941,0.652231,01:17
2,1.024472,1.067758,0.639523,01:16
3,0.993983,0.907262,0.687425,01:20
4,0.973521,0.885192,0.695187,01:23
