In [1]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
%cd drive/My Drive/

Mounted at /content/drive
/content/drive/My Drive


In [2]:
VERSION = "20200516"  #@param ["1.5" , "20200516", "nightly"]
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version $VERSION

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0100  4139  100  4139    0     0  53753      0 --:--:-- --:--:-- --:--:-- 53753
Updating TPU and VM. This may take around 2 minutes.
Updating TPU runtime to pytorch-dev20200516 ...
Collecting cloud-tpu-client
  Downloading https://files.pythonhosted.org/packages/56/9f/7b1958c2886db06feb5de5b2c191096f9e619914b6c31fdf93999fdbbd8b/cloud_tpu_client-0.10-py3-none-any.whl
Collecting google-api-python-client==1.8.0
[?25l  Downloading https://files.pythonhosted.org/packages/9a/b4/a955f393b838bc47cbb6ae4643b9d0f90333d3b4db4dc1e819f36aad18cc/google_api_python_client-1.8.0-py3-none-any.whl (57kB)
[K     |████████████████████████████████| 61kB 2.5MB/s 
Uninstalling torch-1.5.1+cu101:
Installing collected packages: google-api-python-client, cloud-tpu-client
  Foun

In [3]:
%cd Lightweighted/

/content/drive/.shortcut-targets-by-id/1ri6ldikC2f57UGoFJbYCFGVeHLB2yECf/Lightweighted


# Utility Functions

In [4]:
# Result Visualization Helper
import math
from matplotlib import pyplot as plt

M, N = 4, 6
RESULT_IMG_PATH = '/test_result.png'

def plot_results(images, labels, preds):
  images, labels, preds = images[:M*N], labels[:M*N], preds[:M*N]
  inv_norm = transforms.Normalize((-0.1307/0.3081,), (1/0.3081,))

  num_images = images.shape[0]
  fig, axes = plt.subplots(M, N, figsize=(11, 9))
  fig.suptitle('Correct / Predicted Labels (Red text for incorrect ones)')

  for i, ax in enumerate(fig.axes):
    ax.axis('off')
    if i >= num_images:
      continue
    img, label, prediction = images[i], labels[i], preds[i]
    img = inv_norm(img)
    img = img.squeeze() # [1,Y,X] -> [Y,X]
    label, prediction = label.item(), prediction.item()
    if label == prediction:
      ax.set_title(u'\u2713', color='blue', fontsize=22)
    else:
      ax.set_title(
          'X {}/{}'.format(label, prediction), color='red')
    ax.imshow(img)
  plt.savefig(RESULT_IMG_PATH, transparent=True)

##Imports

In [5]:
import gc
import logging
import numpy as np
import os
import pandas as pd
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu
from gtsrb_dataset import GTSRB
from torch.autograd import Variable as var
from torchvision import datasets, transforms

SERIAL_EXEC = xmp.MpSerialExecutor()

In [6]:
losses_df = pd.DataFrame(columns=['Epoch', 'Loss', 'Accuracy'])
logging.basicConfig(filename='./training_teacher_tpu.log', filemode='w', format='%(levelname)s - %(message)s')

## Save Model State

In [7]:
def saveModel(epoch, model, optimizer, loss, path):
  torch.save({
              'epoch': epoch,
              'model_state_dict': model.state_dict(),
              'optimizer_state_dict': optimizer.state_dict(),
              'loss': loss
              }, path)

## Load Model State

In [8]:
def loadModel(model, path):
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']

    print('Epoch: ',epoch,'Loss: ',loss)
    return model, epoch, loss;

# Defining the Hyper-parameters for Teacher

In [9]:
PARAMETERS = {}
PARAMETERS['saved_models'] = "saved_models/teacher_model_tpu_z.pth"
PARAMETERS['learning_rate'] = 0.001
PARAMETERS['epochs'] = 300
PARAMETERS['weight_decay'] = 0.0001
PARAMETERS['batch_size'] = 128 
PARAMETERS['growth_rate'] = 128 ## Growth rate and batch size
PARAMETERS['num_workers'] = 4
PARAMETERS['num_cores'] = 8
PARAMETERS['log_steps'] = 20
PARAMETERS['load_from_saved'] = False
PARAMETERS['start_epoch'] = 1

## Loading GTSRB Data

In [11]:
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.3403, 0.3121, 0.3214),
                         (0.2724, 0.2608, 0.2669))
])

train_data = GTSRB(
    root_dir = './data/', train=True,  transform=transform)
test_data = GTSRB(
    root_dir = './data/', train=False,  transform=transform)

train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_data,
    num_replicas = xm.xrt_world_size(),
    rank = xm.get_ordinal(),
    shuffle = True)

train_set = torch.utils.data.DataLoader(
    train_data, 
    batch_size = PARAMETERS['batch_size'],
    sampler = train_sampler,
    num_workers = PARAMETERS['num_workers'], 
    pin_memory = True)

test_set = torch.utils.data.DataLoader(
    test_data, 
    batch_size = PARAMETERS['batch_size'],
    shuffle = False,
    num_workers = PARAMETERS['num_workers'],
    pin_memory = True)

# Defining the Teacher Network
The following sub-sections define the various parts of the Teacher Network.
We start by defining the Cell block, followed by the Stage module and finally the complete teacher model as defined in the paper \\
[Lightweight deep network for traffic sign classification, Zhang et. al (2019)](https://rdcu.be/b5aTv)

Before coding the network, we show the visual description of how the network looks with images taken from the above paper

## Cell Block

*The 1 × 1 kernels and the 3 × 3 kernels execute convolution
operations in parallel and splice all output results*

![Cell Block](https://i.imgur.com/RWMjelN.png)

Please note, the numbers 64 on each on the convolution operations are, as per our interpretation, used to denote that each of the convolution operations see exactly half of the input(as the batch size mentioned in the paper is 128)

In [12]:
class Cell(nn.Module):
  def __init__(self,cell_in_channels,cell_out_channels):
    super(Cell, self).__init__()

    self.activation_function = nn.ReLU()
    self.batch_norm = nn.BatchNorm2d(cell_out_channels)

    ## Reflect padding is used for the 3 x 3 convolution as it creates a 
    ## feature map of size 30 X 30, and needs to be padded to 32 x 32
    ## in order to concatenate with the 1 x 1 conv tensor

    self.cnn3 = nn.Conv2d(in_channels=int(cell_in_channels/2),
                          out_channels=int(cell_out_channels/2),
                          kernel_size=3,padding=1,padding_mode='reflect', 
                          stride=1)
    
    self.cnn1 = nn.Conv2d(in_channels=int(cell_in_channels/2), 
                          out_channels=int(cell_out_channels/2),
                          kernel_size=1, stride=1)
    
    
    '''
    I had initially thought about directly using grouped convolution feature, but could not find an implemented way of using different sized kernels for the parallel groups
    '''
    # self.grouped_conv = nn.Conv2d(in_channels=int(cell_in_channels/2), 
    #                               out_channels=int(cell_out_channels/2),
    #                               kernel_size=1, stride=1,groups=2)


    ## Used to split the input tensor in half in order to run parallel convolution
    self.split_size = int(cell_in_channels/2)


  def forward(self,x):

    (path1,path2) = torch.split(x,split_size_or_sections=[self.split_size,self.split_size],dim=1)
    path1 = self.cnn1(path1)

    path2 = self.cnn3(path2)

    x = torch.cat([path1,path2],1)
    x = self.batch_norm(x)
    x = self.activation_function(x)

    return x

## Stage Module
*Six cells are
used to establish the direct connection between different
layers, making full use of the feature maps of each layer*

![stage](https://i.imgur.com/szNZQ9Cm.jpg)

The outputs from each of the cells, as well as the 1 x 1 convolution are accumulated into the input for the next cell

The two 1 x 1 convolutions are used to reduce the number of feature maps when connecting between the two stages

In [13]:
class Stage(nn.Module):
  def __init__(self,cell_connections_in, cell_connections_out,stage_in,stage_out):
    super(Stage,self).__init__()

    self.activation_function = nn.ReLU()
    self.batch_norm = nn.BatchNorm2d(k)

    self.cnn1 = nn.Conv2d(in_channels=stage_in,
                          out_channels=k,kernel_size=1,
                          stride=1)
    
    self.cnn2 = nn.Conv2d(in_channels=7*k,
                          out_channels=stage_out,kernel_size=1,
                          stride=1)
    
    ## Densely connected six cell blocks
    self.cells = nn.ModuleList([
                                Cell(cell_connections_in[i],
                                     cell_connections_out[i]) for i in range(6)
                                     ])      


  def forward(self,x):
    cell_results = []
    x = self.cnn1(x)
    x = self.batch_norm(x)
    x = self.activation_function(x)

    cell_results.append(x)
    for i in range(6):
      x = torch.cat(cell_results,1)
      x = self.cells[i](x)
      cell_results.append(x);
      
    x = torch.cat(cell_results,1)

    x = self.cnn2(x)
    x = self.batch_norm(x)
    x = self.activation_function(x)

    return x

## Teacher Network
![teacher](https://i.imgur.com/bTH8KSCm.jpg)

Finally, we define the teacher network which consists of 4 stage modules connected in a dense fashion, with each stage producing a 'k' feature maps where 'k' is the growth rate of the network.

Stage 0 takes the input tensor which has 3 x H X W tensor and outputs a k x H x W tensor. The remaining Stages take 'k' feature maps as input and output 'k' feature maps

Finally, the Stage 3 output is pooled using a 3 x 3 max pooling with stride of 2
and finally a fully connected linear layer which produces the probability vector for classification.

In [14]:
class TeacherNetwork(nn.Module):
  def __init__(self,cell_onnections_in,cell_connections_out,stage_connections_in,stage_connections_out):
    super(TeacherNetwork, self).__init__()

    self.stages = nn.ModuleList([Stage(cell_onnections_in,cell_connections_out,stage_connections_in[i],stage_connections_out[i]) for i in range(4)])
    self.max_pool = torch.nn.MaxPool2d(kernel_size=2,stride=2)
    self.activation_function = nn.ReLU()
    self.linear = nn.Linear(in_features=131072,out_features=43)


  def forward(self,x):
    stage_results = []
    for i in range(4):
      if i != 0:
        x = torch.cat(stage_results,1)
        x = self.stages[i](x)
        stage_results.append(x);

      else:
        x = self.stages[0](x)
        stage_results.append(x)
    
    x = torch.cat(stage_results,1)
    x = self.max_pool(x)
    x = x.view(x.size(0),-1)
    x = self.linear(x)
    return x;

# Validation Function

Validates the model against the test data

In [15]:
def validate(model,data):
  total = 0
  correct = 0
  
  with torch.no_grad():
    for i,(images,labels) in enumerate(data):
      images = var(images)
      x = model(images)
      value,pred = torch.max(x,1)
      pred = pred.data.cpu()
      total += x.size(0)
      correct += torch.sum(pred == labels)
    return correct*100./total

# Defining the Teacher Model

In [16]:
k = PARAMETERS['batch_size']
'''
The cells use feature maps from every preceding output in the subsequent cells,
increasing the number of feature maps for every next cell by k * 2^(i-1) where 
i is the cell number. Therefore, the first cell inputs 'k' feature maps and 
outputs k feature maps, and the last cell(6th) inputs 2^5 * k feature maps and 
outputs the same number
'''
cell_connections_in = [k,2*k,3*k,4*k,5*k,6*k]
cell_connections_out = [k] * 6

'''
The stages also use feature maps from every preceding output in the subsequent 
cells, increasing the number of feature maps for every next stage linearly. 
This is due to the fact that the 1 x 1 convolution at the end of every stage 
reduces the output feature maps to size 'k'
'''

stage_connections_in = [3,k,2*k,3*k]
stage_connections_out = [k] * 4

# Training the Model

In [17]:
def train_teacher_model():

  # Scale learning rate to world size
  lr = PARAMETERS['learning_rate'] * xm.xrt_world_size()

  device = xm.xla_device()  

  WRAPPED_MODEL = xmp.MpModelWrapper(TeacherNetwork(cell_connections_in,cell_connections_out,stage_connections_in,stage_connections_out))

  if(PARAMETERS['load_from_saved']):
    teacher = TeacherNetwork(cell_connections_in,cell_connections_out,stage_connections_in,stage_connections_out)
    teacher, PARAMETERS['start_epoch'], loss = loadModel(teacher, PARAMETERS['saved_models'])
    PARAMETERS['start_epoch'] += 2
    PARAMETERS['load_from_saved'] = False
    print("Loaded model loss", loss)
    WRAPPED_MODEL = xmp.MpModelWrapper(teacher)

  # Only instantiate model weights once in memory.
  teacher = WRAPPED_MODEL.to(device)

  loss_function = nn.CrossEntropyLoss()
  optimizer = optim.Adam(teacher.parameters(), lr, weight_decay = PARAMETERS['weight_decay'])

  def training_loop(data):
    epoch_loss = 0.0
    running_loss = 0.0
    tracker = xm.RateTracker()
    teacher.train()

    for i, (images, labels) in enumerate(data):
      optimizer.zero_grad()

      prediction = teacher(images)
      loss = loss_function(prediction, labels)
      loss.backward()

      xm.optimizer_step(optimizer)

      epoch_loss += prediction.shape[0] * loss.item()
      running_loss += loss.item()

      tracker.add(PARAMETERS['batch_size'])
      if(i % 100 == 0):
        print('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format(
              xm.get_ordinal(), i, loss.item(), tracker.rate(),
              tracker.global_rate(), time.asctime()), flush=True)
      
    return epoch_loss, running_loss
    
  def testing_loop(data):
    total = 0
    correct = 0
    validation_loss = 0
    teacher.eval()
    images, labels, pred = None, None, None
    with torch.no_grad():
      for images, labels in data:
        output = teacher(images)
        pred = output.max(1, keepdim=True)[1]
        correct += pred.eq(labels.view_as(pred)).sum().item()
        total += images.size()[0]
        # validation_loss += loss_function(correct, labels)
        # print(validation_loss)
        # pred = pred.cpu()
        # total += output.size(0)
        # print("testing loop 6")
        # correct += torch.sum(pred == labels)
        # print("testing loop 7")
    
    accuracy = correct * 100.0 / total
    
    print('[xla:{}] Accuracy={:.2f}%'.format(
        xm.get_ordinal(), accuracy), flush=True)
    return accuracy, images, pred, labels, validation_loss
    

  accuracy = 0.0
  data, pred, target = None, None, None

  for epoch in range(PARAMETERS['start_epoch'], PARAMETERS['epochs'] + 1):
    para_loader = pl.ParallelLoader(train_set, [device])
    epoch_loss, running_loss = training_loop(para_loader.per_device_loader(device))
    xm.master_print("Finished training epoch {}".format(epoch))

    xm.save({
              'epoch': epoch,
              'model_state_dict': teacher.state_dict(),
              'optimizer_state_dict': optimizer.state_dict(),
              'loss': (epoch_loss/len(train_set))
              }, PARAMETERS['saved_models'])
    
    para_loader = pl.ParallelLoader(test_set, [device])
    accuracy, data, pred, target, validation_loss = testing_loop(para_loader.per_device_loader(device))
    print('Epoch: ', epoch + 1, 'Loss: ', (epoch_loss/len(train_set)),'Accuracy: ',accuracy,'%', 'Validation Loss ', validation_loss)

    #logging.INFO('Epoch %s || Loss %s', str(epoch+1), str(epoch_loss/len(train_set)))
    losses_df.append({'Epoch': epoch + 1, 'Loss': float(epoch_loss/len(train_set)), 'Accuracy': accuracy}, ignore_index=True)
    losses_df.to_csv("./training_teacher_tpu.log", encoding='utf-8', index=False)

  return accuracy, data, pred, target, teacher

In [None]:
def start_training(rank, parameters):
  global PARAMETERS
  PARAMETERS = parameters
  torch.set_default_tensor_type('torch.FloatTensor')
  accuracy, data, pred, target, trained_teacher = train_teacher_model()
  print("Final accuracy = ", accuracy)

  torch.save(teacher, 'saved_models_final_teacher_tpu.pth')
  if rank == 0:
    # Retrieve tensors that are on TPU core 0 and plot.
    plot_results(data.cpu(), pred.cpu(), target.cpu())

PARAMETERS['load_from_saved'] = True

xmp.spawn(start_training, args=(PARAMETERS, ), nprocs = PARAMETERS['num_cores'],
          start_method='fork')

Epoch:  135 Loss:  2.5766902464875763
Loaded model loss 2.5766902464875763
Epoch:  135 Loss:  2.5766902464875763
Loaded model loss 2.5766902464875763
Epoch:  135 Loss:  2.5766902464875763
Loaded model loss 2.5766902464875763
Epoch:  135 Loss:  2.5766902464875763
Loaded model loss 2.5766902464875763
Epoch:  135 Loss:  2.5766902464875763
Loaded model loss 2.5766902464875763
Epoch:  135 Loss:  2.5766902464875763
Loaded model loss 2.5766902464875763
Epoch:  135 Loss:  2.5766902464875763
Loaded model loss 2.5766902464875763
Epoch:  135 Loss:  2.5766902464875763
Loaded model loss 2.5766902464875763
[xla:7](0) Loss=0.05618 Rate=0.80 GlobalRate=0.80 Time=Wed Jul  1 12:35:32 2020
[xla:6](0) Loss=0.05618 Rate=0.79 GlobalRate=0.79 Time=Wed Jul  1 12:35:32 2020
[xla:1](0) Loss=0.05618 Rate=0.81 GlobalRate=0.81 Time=Wed Jul  1 12:35:32 2020
[xla:4](0) Loss=0.05618 Rate=0.82 GlobalRate=0.82 Time=Wed Jul  1 12:35:32 2020
[xla:3](0) Loss=0.05618 Rate=0.80 GlobalRate=0.80 Time=Wed Jul  1 12:35:32 2020


# Evaluating the trained model

In [None]:
accuracy = float(validate(teacher,test_set))
print('Accuracy = ',accuracy)

# Saving Model

In [None]:
torch.save(teacher,"./saved_models/trained_teacher_model_final_tpu.pth")