# 1. Connect to Google Drive

In [None]:
from google.colab import drive

# Mount Google Drive; the dataset (section 4) and checkpoints (section 7) live under `gdrive_root`.
mount_point = '/gdrive'
drive.mount(mount_point)
gdrive_root = mount_point + '/My Drive'

# 2. Import modules

In [None]:
import os

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

!pip install pyngrok
# import TensorBoardColab
!pip install -U tensorboardcolab
from tensorboardcolab import TensorBoardColab

# Seed torch's CPU and CUDA random generators for reproducible runs.
SEED = 470
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

# 3. Configure the experiments

In [None]:
# training & optimization hyper-parameters
max_epoch = 10
learning_rate = 0.0001
batch_size = 200
# Use the GPU when one is available, otherwise fall back to the CPU so the notebook still runs.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# model hyper-parameters
input_dim = 784  # 28x28=784
hidden_dim = 512
output_dim = 10

# initialize tensorboard for visualization
# Note: TensorBoard logging is currently disabled. To enable it, uncomment the line below and the
#       `tbc.*` calls in the training cell, then click the TensorBoard link to see the
#       training/testing results.
# tbc = TensorBoardColab()

# 4. Construct data pipeline

In [None]:
data_dir = os.path.join(gdrive_root, 'my_data')
transform = transforms.ToTensor()  # images -> tensors

# MNIST train/test splits stored under `data_dir`.
train_dataset = MNIST(root=data_dir, train=True, download=True, transform=transform)
test_dataset = MNIST(root=data_dir, train=False, download=True, transform=transform)

# Training batches are shuffled and a trailing partial batch is dropped;
# evaluation batches keep a fixed order and cover every test sample.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

# 5. Define the neural network

In [None]:
class MyClassifier(nn.Module):
  """Fully connected classifier: three ReLU hidden layers followed by a linear output layer.

  Args:
    input_dim: number of features per sample after flattening (784 for 28x28 images).
    hidden_dim: width of every hidden layer.
    output_dim: number of output logits.
  """
  def __init__(self, input_dim=784, hidden_dim=512, output_dim=10):
    super().__init__()
    # Linear layers sit at Sequential indices 0, 2, 4, 6. Those indices become the
    # state_dict keys, so this ordering must stay fixed for saved checkpoints to load.
    widths = [input_dim] + [hidden_dim] * 3
    blocks = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
      blocks.extend([nn.Linear(fan_in, fan_out), nn.ReLU()])
    blocks.append(nn.Linear(hidden_dim, output_dim))
    self.layers = nn.Sequential(*blocks)

  def forward(self, x):
    """Flatten every sample in the batch and return raw (pre-softmax) logits."""
    flat = x.view(x.size(0), -1)
    return self.layers(flat)

# 6. Initialize the model and optimizer

In [None]:
# Build the classifier directly on the configured device, then attach an Adam optimizer to its parameters.
my_classifier = MyClassifier(input_dim, hidden_dim, output_dim).to(device)
optimizer = optim.Adam(params=my_classifier.parameters(), lr=learning_rate)


# 7. Load pre-trained weights if they exist

In [None]:
ckpt_dir = os.path.join(gdrive_root, 'checkpoints')
os.makedirs(ckpt_dir, exist_ok=True)

# Best test accuracy seen so far; the training cell only overwrites the checkpoint when it beats this.
best_acc = 0.
# NOTE: the file name 'lastest.pt' (sic) is kept so previously saved checkpoints still load.
#       The training cell writes it only when test accuracy improves, so it holds the best
#       model so far, not necessarily the most recent one.
ckpt_path = os.path.join(ckpt_dir, 'lastest.pt')
if os.path.exists(ckpt_path):
  # map_location lets a checkpoint saved on a GPU be restored onto the configured `device`.
  ckpt = torch.load(ckpt_path, map_location=device)
  try:
    my_classifier.load_state_dict(ckpt['my_classifier'])
    optimizer.load_state_dict(ckpt['optimizer'])
    # The stored value may be a 0-dim tensor (the training cell saved a tensor); keep a plain float.
    best_acc = float(ckpt['best_acc'])
  except (RuntimeError, KeyError, ValueError) as e:
    # Incompatible or incomplete checkpoint: report why and keep best_acc at 0.
    # NOTE(review): if the model weights loaded before the failure, they remain loaded.
    print('wrong checkpoint: {}'.format(e))
  else:
    print('checkpoint is loaded !')
    print('current best accuracy : %.2f' % best_acc)

# 8. Train

In [None]:
it = 0
train_losses = []   # one entry per epoch: mean training loss over that epoch's batches
test_losses = []    # one entry per epoch: mean per-sample test loss
log_interval = 200  # print training statistics every `log_interval` iterations
for epoch in range(max_epoch):
  # train phase
  # Note: Behaviours of some layers/modules, such as dropout, batchnorm, etc., are different (or should be treated differently) depending on whether the phase is in train or test
  #       For example, dropout modules turn off some activations with probability p in training time, but not in test time.
  #       However, our network "my_classifier" does not know which phase is under-going, and we need to give the network a signal to handle this issue.
  #       Fortunately, PyTorch provides us the utility functions for this, which are `.train()` and `.eval()`
  my_classifier.train()
  epoch_loss_sum = 0.
  n_batches = 0
  for inputs, labels in train_dataloader:
    it += 1

    # load data to the configured device.
    inputs = inputs.to(device)
    labels = labels.to(device)

    # feed data into the network and get outputs.
    logits = my_classifier(inputs)

    # calculate loss
    # Note: `F.cross_entropy` function receives logits, or pre-softmax outputs, rather than final probability scores.
    loss = F.cross_entropy(logits, labels)

    # Note: You should flush out gradients computed at the previous step before computing gradients at the current step.
    #       Otherwise, gradients will accumulate.
    optimizer.zero_grad()

    # backpropagate loss.
    loss.backward()

    # update the weights in the network.
    optimizer.step()

    # Accumulate a detached copy: it holds no autograd graph and avoids a
    # device-to-host sync on every iteration.
    epoch_loss_sum += loss.detach()
    n_batches += 1

    if it % log_interval == 0:
      # calculate accuracy (only needed when it is printed).
      acc = (logits.argmax(dim=1) == labels).float().mean()
      # tbc.save_value('Loss', 'train_loss', it, loss.item())
      print('[epoch:{}, iteration:{}] train loss : {:.4f} train accuracy : {:.4f}'.format(epoch, it, loss.item(), acc.item()))

  # Save the mean loss over all batches of the epoch (not just the final batch) so the
  # curve is comparable to the epoch-averaged test loss recorded below.
  train_losses.append(float(epoch_loss_sum) / n_batches if n_batches else float('nan'))

  # test phase
  my_classifier.eval()
  n = 0
  test_loss = 0.
  test_acc = 0.
  # Gradients are not needed for evaluation. Without no_grad, summing per-batch loss tensors
  # into `test_loss` would keep the autograd graph of every test batch alive.
  with torch.no_grad():
    for test_inputs, test_labels in test_dataloader:
      test_inputs = test_inputs.to(device)
      test_labels = test_labels.to(device)

      logits = my_classifier(test_inputs)
      test_loss += F.cross_entropy(logits, test_labels, reduction='sum').item()
      test_acc += (logits.argmax(dim=1) == test_labels).sum().item()
      n += test_inputs.size(0)

  # Per-sample averages as plain Python floats; max() guards against an empty test set.
  test_loss /= max(n, 1)
  test_acc /= max(n, 1)
  test_losses.append(test_loss)
  # tbc.save_value('Loss', 'test_loss', it, test_loss)
  print('[epoch:{}, iteration:{}] test_loss : {:.4f} test accuracy : {:.4f}'.format(epoch, it, test_loss, test_acc))

  # tbc.flush_line('train_loss')
  # tbc.flush_line('test_loss')

  # save checkpoint whenever there is improvement in performance
  if test_acc > best_acc:
    best_acc = test_acc  # plain float, so the checkpoint does not hold a device tensor
    # Note: optimizer also has states ! don't forget to save them as well.
    ckpt = {'my_classifier':my_classifier.state_dict(),
            'optimizer':optimizer.state_dict(),
            'best_acc':best_acc}
    torch.save(ckpt, ckpt_path)
    print('checkpoint is saved !')

# tbc.close()

# 9. Visualize results

In [None]:
import matplotlib.pyplot as plt

# Both lists get one value appended per epoch in the training cell, so the x-axis is the epoch index.
fig, ax = plt.subplots()
ax.plot(train_losses, label='train loss')
ax.plot(test_losses, label='test loss')
ax.set_xlabel('epoch')
ax.set_ylabel('cross-entropy loss')
ax.set_title('MNIST classifier: loss per epoch')
ax.legend()
plt.show()