In [None]:
import os

! git clone https://github.com/fengredrum/cnn-xla.git
os.chdir('cnn-xla')

In [None]:
import torch
import torch.nn.functional as F
from torch import nn, optim

from utils import load_data_cifar_10, train_model
from models.alexnet import alexnet
from models.vgg import vgg11
from models.resnet import resnet18
from models.densenet import densenet121
from models.se_resnet import se_resnet_50
from models.mobilenet_v1 import mobilenet_v1
from models.mobilenet_v2 import mobilenet_v2


# Pick the compute device once: CUDA when PyTorch reports it as available,
# otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# When running on CUDA, report the name of device 0 so the output records
# which accelerator was used.
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))

In [None]:
# Run configuration. The trailing `#@param` annotations appear to be Colab
# form markers, so each one is kept on the same line as its assignment.
model_name = "alexnet"  #@param ["alexnet", "vgg", "resnet", "densenet", "se-resnet", "mobilenet-v1", "mobilenet-v2"]

activation = "mish"  #@param ["relu", "relu6", "swish", "mish"]

batch_size = 256  #@param {type:"integer"}

lr = 0.001  #@param {type:"number"}

num_epochs = 20  #@param {type:"integer"}

# One zero-argument constructor per selectable model name. The lambdas defer
# construction, so only the chosen network is actually instantiated.
model_builders = {
    'alexnet': lambda: alexnet(activation=activation),
    'vgg': lambda: vgg11(activation=activation),
    'resnet': lambda: resnet18(activation=activation),
    'densenet': lambda: densenet121(activation=activation),
    'se-resnet': lambda: se_resnet_50(activation=activation),
    'mobilenet-v1': lambda: mobilenet_v1(activation=activation, width_multiplier=1.),
    'mobilenet-v2': lambda: mobilenet_v2(activation=activation, width_multiplier=1.4),
}

# An unrecognised name is rejected before anything is built.
if model_name not in model_builders:
    raise NotImplementedError

net = model_builders[model_name]()

In [None]:
# Build the CIFAR-10 iterators for the configured batch size.
train_iter, test_iter = load_data_cifar_10(batch_size)

# Adam over all network parameters; StepLR multiplies the learning rate by
# 0.1 every 10 scheduler steps.
optimizer = optim.Adam(net.parameters(), lr=lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

# Train the network. The comment string tags the run with the chosen model
# and activation, e.g. "-alexnet-mish".
train_model(
    net,
    train_iter,
    test_iter,
    batch_size,
    optimizer,
    scheduler,
    device,
    num_epochs,
    comment=f'-{model_name}-{activation}',
)

In [None]:
# Load the TensorBoard notebook extension and open TensorBoard inline,
# reading logs from `runs/` relative to the current working directory.
# NOTE(review): presumably `train_model` writes its event files under
# `runs/` -- confirm against utils.
%load_ext tensorboard
%tensorboard --logdir runs/