<a href="https://colab.research.google.com/github/butchland/fastai_xla_extensions/blob/master/samples/torch_dataloader_multicore_pets_tpu.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -Uqq cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp36-cp36m-linux_x86_64.whl

[K     |████████████████████████████████| 133.6MB 32kB/s 
[K     |████████████████████████████████| 61kB 3.7MB/s 
[?25h

In [2]:
# !pip install -Uqq git+https://github.com/fastai/fastai.git 
!pip install -Uqq fastai --upgrade

[K     |████████████████████████████████| 194kB 5.1MB/s 
[K     |████████████████████████████████| 61kB 5.7MB/s 
[?25h

In [3]:
!pip install -Uqq git+https://github.com/butchland/fastai_xla_extensions.git

  Building wheel for fastai-xla-extensions (setup.py) ... [?25l[?25hdone


In [4]:
!pip install -Uqq git+https://github.com/butchland/my_timesaver_utils.git

  Building wheel for my-timesaver-utils (setup.py) ... [?25l[?25hdone


In [5]:
!curl -s https://course19.fast.ai/setup/colab | bash

Updating fastai...
Done.


In [6]:
!pip freeze | grep torch
!pip freeze | grep fast

torch==1.7.0+cu101
torch-xla==1.7
torchsummary==1.5.1
torchtext==0.3.1
torchvision==0.8.1+cu101
fastai==2.2.5
fastai-xla-extensions==0.0.7
fastcore==1.3.19
fastdtw==0.3.4
fastprogress==1.0.0
fastrlock==0.5


In [7]:
# a = []
# while(1):
#     a.append('1')

Start of kernel

In [8]:
from fastai.vision.all import *
from fastai_xla_extensions.multi_core import *
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torchvision as thv



In [9]:
FLAGS = {}
FLAGS['image_size'] = 224
FLAGS['batch_size'] = 16
FLAGS['freeze_epochs'] = 1
FLAGS['epochs'] = 9
FLAGS['moms'] = (0.9,0.95,0.9)
FLAGS['weight_decay'] = 5e-4
FLAGS['learning_rate'] = 2e-3
FLAGS['num_workers'] = 4
FLAGS['nprocs'] = 8
FLAGS['sync_valid'] = True

In [10]:
path = untar_data(URLs.PETS)/'images'

In [11]:
imagenet_norm = thv.transforms.Normalize(
    mean=(0.485, 0.456, 0.406), 
    std=(0.229, 0.224, 0.225))

image_size = FLAGS['image_size']
splitter = RandomSplitter(seed=42)
pat = r'(.+)_\d+.jpg$'
fname_labeller = FileNamePatternLabeller(pat)

dset_builder = TorchDatasetBuilder(
    path, 
    get_items=get_image_files,
    splitter=splitter,
    x_tfms=[thv.transforms.Resize((image_size,image_size)), thv.transforms.ToTensor(), imagenet_norm],
    y_tfms=[fname_labeller, VocabularyMapper(),],
    x_type_tfms=PILImage.create,
) 

dset_builder.setup(get_image_files(path),do_setup=True)
n_out = dset_builder.y_tfms[1].c     

custom_model = create_cnn_model(resnet34, n_out,
                                pretrained=True,
                                concat_pool=False)

wrapped_model = xmp.MpModelWrapper(custom_model)

Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to /root/.cache/torch/hub/checkpoints/resnet34-333f7ec4.pth


HBox(children=(FloatProgress(value=0.0, max=87306240.0), HTML(value='')))




In [12]:
def train_torch_model(rank, flags):
    xm.master_print('start training')
    world_size = xm.xrt_world_size()
    device = xm.xla_device()
    dsets = dset_builder.get_datasets()
    sync_valid = flags['sync_valid']
    dls = make_torch_dataloaders(
            *dsets, 
            rank=rank, 
            world_size=world_size,
            sync_valid=sync_valid,
            bs=flags['batch_size'],
            num_workers=flags['num_workers'])
    model = wrapped_model.to(device)
    learner = Learner(
            dls,
            model,
            loss_func=nn.CrossEntropyLoss(),
            opt_func=Adam,
            wd=flags['weight_decay'],
            moms=flags['moms'],
            metrics=accuracy
            )
    learner.to_xla(device,rank=rank, sync_valid=sync_valid)
    lr = flags['learning_rate'] * world_size
    learner.freeze()
    freeze_epochs = flags['freeze_epochs']
    learner.fit_one_cycle(freeze_epochs,lr_max=slice(lr/10.))
    learner.unfreeze()
    epochs = flags['epochs']
    learner.fit_one_cycle(epochs, lr_max=slice(lr/50, lr/20.))
    learner.save('stage-1')
    xm.mark_step()
    xm.rendezvous('end training')

In [13]:
%%time
xmp.spawn(train_torch_model, args=(FLAGS,), nprocs=FLAGS['nprocs'],start_method='fork')

start training
start fit


epoch,train_loss,valid_loss,accuracy,time
0,2.863358,1.573918,0.527699,01:23


start fit


epoch,train_loss,valid_loss,accuracy,time
0,1.841882,1.737729,0.492188,00:55
1,1.549578,1.458188,0.600852,00:39
2,1.323139,1.163835,0.659091,00:39
3,0.999466,1.197091,0.668324,00:39
4,0.725341,0.714214,0.790483,00:39
5,0.50427,0.608165,0.825994,00:39
6,0.300993,0.524123,0.848011,00:39
7,0.16728,0.46508,0.868608,00:39
8,0.110655,0.460137,0.870739,00:39


CPU times: user 431 ms, sys: 310 ms, total: 741 ms
Wall time: 7min 50s
