<a href="https://colab.research.google.com/github/butchland/fastai_xla_extensions/blob/master/explore_nbs/test_tpu_pets_slow.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install the TPU client (pinned to 0.10) together with the torch_xla 1.7 wheel.
# The wheel file name is a cp36 / linux_x86_64 build, so it targets a CPython 3.6 runtime.
!pip install -Uqq cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp36-cp36m-linux_x86_64.whl

[K     |████████████████████████████████| 133.6MB 75kB/s 
[K     |████████████████████████████████| 61kB 3.3MB/s 
[?25h

In [2]:
# Alternative: install fastai from the GitHub default branch instead of the PyPI release.
# !pip install -Uqq git+https://github.com/fastai/fastai.git 
# Pin fastai to the release this notebook was executed with (see the `pip freeze`
# output further down) so it stays compatible with the pinned torch_xla 1.7 wheel.
# `-U` already means `--upgrade`, so the duplicate long flag is dropped.
!pip install -Uqq fastai==2.2.5

[K     |████████████████████████████████| 194kB 5.3MB/s 
[K     |████████████████████████████████| 61kB 7.2MB/s 
[?25h

In [3]:
# Install fastai_xla_extensions straight from its GitHub repository (no tag or commit is pinned).
!pip install -Uqq git+https://github.com/butchland/fastai_xla_extensions.git

  Building wheel for fastai-xla-extensions (setup.py) ... [?25l[?25hdone


In [4]:
# Install my_timesaver_utils straight from its GitHub repository (no tag or commit is pinned).
# NOTE(review): none of the cells shown here import my_timesaver_utils by name — confirm it is still needed.
!pip install -Uqq git+https://github.com/butchland/my_timesaver_utils.git

  Building wheel for my-timesaver-utils (setup.py) ... [?25l[?25hdone


In [5]:
# Download the fast.ai Colab setup script and pipe it directly into bash.
# NOTE(review): this runs a remote script without inspecting it, and its recorded output
# says "Updating fastai..." — confirm it is still wanted after the explicit fastai install above.
!curl -s https://course19.fast.ai/setup/colab | bash

Updating fastai...
Done.


In [6]:
# Record the installed versions of every package whose name contains "torch" or "fast",
# so the environment this run used is visible in the notebook output.
!pip freeze | grep torch
!pip freeze | grep fast

torch==1.7.0+cu101
torch-xla==1.7
torchsummary==1.5.1
torchtext==0.3.1
torchvision==0.8.1+cu101
fastai==2.2.5
fastai-xla-extensions==0.0.7
fastcore==1.3.19
fastdtw==0.3.4
fastprogress==1.0.0
fastrlock==0.5


Start of kernel

In [7]:
from fastai.vision.all import *
from fastai_xla_extensions.multi_core import *
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp



In [8]:
import os
# Run configuration shared with every spawned training process.
FLAGS = {
    'image_size': 224,
    'batch_size': 64,
    'num_workers': 4,
    'num_epochs': 5,
    # Use 8 processes when the TPU_NAME environment variable is set, otherwise a single process.
    'num_procs': 8 if os.environ.get('TPU_NAME', False) else 1,
    # Base rate of 5e-3 times a fixed factor of 8 — presumably the TPU core count; note the
    # factor is applied even when num_procs is 1 (TODO confirm that is intended).
    'learning_rate': 5e-3 * 8,
    'weight_decay': 5e-4,
    'momentum': (0.9, 0.85, 0.9),
    'sync_valid': True,
}


In [9]:
# Module-level training choices consumed by the model-building cell and by train_model.
OPT_FUNC = Adam                    # optimizer factory handed to the Learner
LOSS_FUNC = nn.CrossEntropyLoss()  # single loss instance, created once here
ARCH = resnet34                    # backbone passed to create_cnn_model

In [10]:
# Download/extract the Pets dataset and point at its 'images' sub-folder.
PATH = untar_data(URLs.PETS)/'images'
# The label is everything before the trailing "_<digits>.jpg" in the file name.
# The dot is escaped so only a literal ".jpg" extension matches (an unescaped "."
# would accept any character in that position).
pat = r'(.+)_\d+\.jpg$'
fname_labeller = FileNamePatternLabeller(pat)
image_size = FLAGS['image_size']
# DataBlock describing image -> category samples; the split is seeded so it is the
# same on every run. Only a per-item resize is applied, no batch transforms.
DATA = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    get_y=fname_labeller,
    splitter=RandomSplitter(seed=42),
    item_tfms=[Resize(image_size),],
    batch_tfms=[]
)


In [11]:
# Build the label vocabulary by applying the file-name labeller to every image file under PATH.
vocab = CategoryMap(get_image_files(PATH).map(fname_labeller))
# Size of that vocabulary; used as the output count when the model is created.
N_OUT = len(vocab)

In [12]:
# Create the model once at notebook level: pretrained ARCH weights, N_OUT outputs,
# with concat pooling explicitly turned off.
custom_model = create_cnn_model(ARCH, N_OUT, pretrained=True, concat_pool=False)

Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to /root/.cache/torch/hub/checkpoints/resnet34-333f7ec4.pth


HBox(children=(FloatProgress(value=0.0, max=87306240.0), HTML(value='')))




In [13]:
# Wrap the model in torch_xla's MpModelWrapper; train_model calls .to(device) on this
# wrapper inside each spawned process to obtain the model on its device.
WRAPPED_MODEL = xmp.MpModelWrapper(custom_model)

In [14]:
def train_model(rank, flags):
    """Fit the frozen model with a one-cycle schedule on this process's XLA device.

    Passed to `xmp.spawn` as the per-process entry point. Relies on the
    notebook-level DATA, PATH, WRAPPED_MODEL, LOSS_FUNC and OPT_FUNC.

    Args:
        rank: process index supplied by `xmp.spawn`; forwarded to the
            dataloader factory and to `Learner.to_xla`.
        flags: dict of run settings. Keys read here: 'batch_size',
            'sync_valid', 'weight_decay', 'momentum', 'learning_rate'
            and 'num_epochs'.

    Returns:
        None. The trained learner is saved under the name 'stage-1'.
    """
    xm.master_print('start train_model')
    n_replicas = xm.xrt_world_size()
    xla_dev = xm.xla_device()

    batch_size, sync_valid = flags['batch_size'], flags['sync_valid']
    loaders = make_fastai_dataloaders(
        DATA, PATH,
        rank=rank, world_size=n_replicas,
        sync_valid=sync_valid, bs=batch_size,
    )

    # Move the shared wrapped model onto this process's device.
    net = WRAPPED_MODEL.to(xla_dev)

    learn = Learner(
        loaders, net,
        loss_func=LOSS_FUNC,
        opt_func=OPT_FUNC,
        splitter=default_split,
        wd=flags['weight_decay'],
        moms=flags['momentum'],
        metrics=accuracy,
    )
    learn.to_xla(xla_dev, rank=rank, sync_valid=sync_valid)

    # Train with the learner frozen, then save the result.
    learn.freeze()
    max_lr, n_epochs = flags['learning_rate'], flags['num_epochs']
    learn.fit_one_cycle(n_epochs, lr_max=slice(max_lr))
    learn.save('stage-1')

In [15]:
# %%time
# Launch train_model in FLAGS['num_procs'] processes using the 'fork' start method.
# FLAGS is passed through as the `flags` argument of train_model.
xmp.spawn(train_model, args=(FLAGS,), nprocs=FLAGS['num_procs'], start_method='fork')

start train_model
start fit


epoch,train_loss,valid_loss,accuracy,time
0,1.870138,0.831724,0.815541,01:37
1,1.202719,0.82839,0.827703,01:19
2,0.937062,0.421658,0.898649,01:19
3,0.706549,0.285118,0.92973,01:18
4,0.57468,0.261792,0.933784,01:19
