In [None]:
import os

# Fetch only utils.py from the cnn-xla repository via a git sparse checkout.
# All git commands below run in /content (presumably the Colab workspace root).
os.chdir('/content')
# Create an empty repository in the current directory.
! git init
# Register the upstream repository as `origin`; -f fetches from it right away.
! git remote add -f origin https://github.com/fengredrum/cnn-xla.git
# Enable sparse checkout so only the listed paths appear in the working tree.
! git config core.sparsecheckout true
# Request just utils.py. NOTE(review): `>>` appends, so re-running this cell
# adds the same pattern again.
! echo utils.py >> .git/info/sparse-checkout
# Pull the master branch; with the pattern above only utils.py is checked out.
! git pull origin master

In [5]:
import torch
import torch.nn.functional as F
from torch import nn, optim

from utils import load_data_cifar_10, train_model
from utils import Mish, Swish

# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    # Report the name of GPU 0.
    print(torch.cuda.get_device_name(0))

In [58]:
class AlexNet(nn.Module):
    """AlexNet-style CNN sized for 3-channel 32x32 images (e.g. CIFAR-10).

    Args:
        activation: Name of the non-linearity applied after every convolution
            and hidden linear layer. One of ``'relu'``, ``'mish'``, ``'swish'``.
        num_classes: Size of the final linear layer's output.

    Raises:
        NotImplementedError: If ``activation`` is not a supported name.
    """

    # Names accepted by the ``activation`` constructor argument.
    SUPPORTED_ACTIVATIONS = ('relu', 'mish', 'swish')

    def __init__(self, activation='relu', num_classes=10):
        super().__init__()

        if activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'mish':
            self.activation = Mish()
        elif activation == 'swish':
            self.activation = Swish()
        else:
            # Same exception type as before, but say what was rejected and
            # what is accepted so a typo is easy to diagnose.
            raise NotImplementedError(
                f"Unsupported activation {activation!r}; "
                f"expected one of {self.SUPPORTED_ACTIVATIONS}.")

        # Convolutional part.
        # This differs from the original AlexNet because CIFAR images are only
        # 32x32. The single `self.activation` instance is reused at every
        # position below.
        # Spatial size for a 32x32 input:
        #   32 -> 28 (conv 5x5) -> 26 (pool 3/1) -> 26 (conv 5x5, pad 2)
        #      -> 12 (pool 3/2) -> 12 (three 3x3 convs, pad 1) -> 5 (pool 3/2)
        self.conv = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=5, stride=1), self.activation,
            nn.MaxPool2d(kernel_size=3, stride=1),
            nn.Conv2d(96, 256, kernel_size=5, stride=1,
                      padding=2), self.activation,
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(256, 384, kernel_size=3, stride=1,
                      padding=1), self.activation,
            nn.Conv2d(384, 384, kernel_size=3, stride=1,
                      padding=1), self.activation,
            nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1),
            self.activation, nn.MaxPool2d(kernel_size=3, stride=2))
        # Fully connected part.
        # 256 * 5 * 5 is the flattened conv output (256 channels of 5x5) for a
        # 32x32 input; other input sizes will not match this layer.
        self.fc = nn.Sequential(nn.Linear(256 * 5 * 5, 4096), self.activation,
                                nn.Dropout(0.5), nn.Linear(4096, 4096),
                                self.activation, nn.Dropout(0.5),
                                nn.Linear(4096, num_classes))

    def forward(self, x):
        """Return class scores of shape ``(batch, num_classes)`` for ``x``.

        ``x`` is expected to be a ``(batch, 3, 32, 32)`` tensor so that the
        flattened feature size matches the first linear layer.
        """
        out = self.conv(x)
        # Keep the batch dimension and flatten channels and spatial dims.
        out = torch.flatten(out, 1)
        out = self.fc(out)
        return out


# Build the network with the Mish non-linearity (num_classes keeps its default).
net = AlexNet('mish')

In [None]:
# Training hyper-parameters.
batch_size = 256
lr = 0.01
num_epochs = 20

# CIFAR-10 iterators from the helper in utils.py.
train_iter, test_iter = load_data_cifar_10(batch_size)

# Adam over every parameter of `net`.
optimizer = optim.Adam(net.parameters(), lr=lr)
# StepLR multiplies the learning rate by 0.1 every 10 scheduler steps.
# NOTE(review): presumably stepped once per epoch inside train_model — confirm.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

train_model(
    net, train_iter, test_iter, batch_size,
    optimizer, scheduler, device, num_epochs,
)

In [None]:
# Load the TensorBoard notebook extension and open it inline, reading event
# files from the relative `runs/` directory.
# NOTE(review): presumably train_model writes its logs there — confirm in utils.py.
%load_ext tensorboard
%tensorboard --logdir runs/