In [None]:
import os
# Fail fast with a readable hint when COLAB_TPU_ADDR is unset or empty.
# `.get` is used so that a missing variable triggers the assertion message
# below instead of a bare KeyError from the mapping lookup.
assert os.environ.get(
    'COLAB_TPU_ADDR'
), 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'

In [None]:
# Version string handed to the env-setup script below. The trailing `#@param`
# annotation lists the alternative values (presumably rendered as a Colab form
# field — keep its syntax intact).
VERSION = "1.5"  #@param ["1.5" , "20200325", "nightly"]
# Download the env-setup script from the pytorch/xla repository and save it
# locally as pytorch-xla-env-setup.py.
# NOTE(review): the URL tracks the `master` branch, so the script contents may
# change between runs — consider pinning to a commit for reproducibility.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
# Run the downloaded script with the selected version ($VERSION is expanded by
# IPython from the Python variable above).
!python pytorch-xla-env-setup.py --version $VERSION

In [None]:
# Fetch only utils.py from the cnn-xla repository via a git sparse checkout.
# All git commands below run inside /content.
os.chdir('/content')
# Create an empty repository in the current directory.
! git init
# Register the remote; -f fetches it immediately after it is added.
! git remote add -f origin https://github.com/fengredrum/cnn-xla.git
# Enable sparse checkout so only the listed paths are materialised.
! git config core.sparsecheckout true
# NOTE: `>>` appends, so re-running this cell adds a duplicate utils.py entry.
! echo utils.py >> .git/info/sparse-checkout
# Populate the working tree from the master branch (utils.py only).
! git pull origin master

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch_xla
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu

from utils import train_model_xla
from utils import Mish, Swish

In [None]:
class AlexNet(nn.Module):
    """AlexNet-style CNN adapted to 32x32 RGB images (e.g. CIFAR).

    The network is a five-convolution feature extractor (`self.conv`) followed
    by a three-layer fully connected classifier (`self.fc`). A single
    activation module instance (`self.activation`) is reused at every
    activation position in both parts.

    Args:
        activation: Name of the non-linearity: 'relu', 'mish' or 'swish'.
        num_classes: Number of output logits produced by the last layer.

    Raises:
        NotImplementedError: If `activation` is not one of the supported names.
    """

    def __init__(self, activation='relu', num_classes=10):
        super().__init__()

        # The if/elif chain keeps Mish/Swish lookups lazy: they are only
        # resolved when that activation is actually requested.
        if activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'mish':
            self.activation = Mish()
        elif activation == 'swish':
            self.activation = Swish()
        else:
            # Same exception type as before, now naming the rejected value
            # and the accepted ones.
            raise NotImplementedError(
                "Unsupported activation {!r}; expected 'relu', 'mish' or "
                "'swish'.".format(activation))

        # Convolutional part.
        # This differs from the original AlexNet because CIFAR images are
        # only 32x32. Spatial size for a 32x32 input is noted per layer.
        self.conv = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=5, stride=1),  # 32 -> 28
            self.activation,
            nn.MaxPool2d(kernel_size=3, stride=1),  # 28 -> 26
            nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2),  # 26
            self.activation,
            nn.MaxPool2d(kernel_size=3, stride=2),  # 26 -> 12
            nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1),  # 12
            self.activation,
            nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1),  # 12
            self.activation,
            nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1),  # 12
            self.activation,
            nn.MaxPool2d(kernel_size=3, stride=2),  # 12 -> 5
        )
        # Fully connected part; 256 * 5 * 5 is the flattened size of the
        # final feature map for a 32x32 input.
        self.fc = nn.Sequential(
            nn.Linear(256 * 5 * 5, 4096),
            self.activation,
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            self.activation,
            nn.Dropout(0.5),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        """Return class logits for a batch of images.

        Args:
            x: Input batch; the classifier's input width assumes
                (N, 3, 32, 32) tensors.

        Returns:
            Tensor of shape (N, num_classes).
        """
        features = self.conv(x)
        # Flatten everything except the batch dimension.
        flat = torch.flatten(features, 1)
        return self.fc(flat)

In [None]:
# Training hyper-parameters and the number of cores to spawn on.
FLAGS = dict(
    batch_size=256,
    lr=0.02,
    num_epochs=20,
    num_cores=8,
)

In [None]:
# Start training processes
def _mp_fn(rank, flags):
    """Per-process entry point: build a Mish AlexNet and train it.

    Args:
        rank: First positional argument supplied by the spawner; not used
            inside this function.
        flags: Mapping that provides the 'batch_size', 'lr' and 'num_epochs'
            entries read below.
    """
    global FLAGS
    # Rebind the module-level FLAGS to the mapping received by this process.
    FLAGS = flags
    torch.set_default_tensor_type('torch.FloatTensor')
    model = AlexNet(activation='mish')
    # Positional order expected by the call: batch size, learning rate, epochs.
    hyper_params = [FLAGS[key] for key in ('batch_size', 'lr', 'num_epochs')]
    # The four results are unpacked here but not used afterwards.
    accuracy, data, pred, target = train_model_xla(model, *hyper_params)


# Launch FLAGS['num_cores'] processes running `_mp_fn`, using the 'fork'
# start method; FLAGS is forwarded to each one as its `flags` argument.
xmp.spawn(
    _mp_fn,
    args=(FLAGS,),
    nprocs=FLAGS['num_cores'],
    start_method='fork',
)