In [1]:
import sys
# Append the directory two levels up (relative to the current working
# directory) to the module search path. Presumably this is the repository root
# that holds the project-local `misc` and `effcn` packages imported in the
# next cell -- TODO confirm.
sys.path.append("./../..")


In [2]:
import math
from tqdm import tqdm
#
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torch import optim
import torchvision.datasets as datasets
import numpy as np
import matplotlib.pyplot as plt
#
from misc.utils import count_parameters
from effcn.em_git import CapsNet, SpreadLoss

### Train model

In [3]:
# MNIST splits, downloaded into ../../data when missing. Each image is
# converted to a tensor by ToTensor(); no other transform is applied.
ds_train = datasets.MNIST(
    root="../../data",
    train=True,
    download=True,
    transform=T.ToTensor(),
)
ds_valid = datasets.MNIST(
    root="../../data",
    train=False,
    download=True,
    transform=T.ToTensor(),
)

In [4]:
# Loader settings shared by both splits, named once so they cannot drift apart.
BATCH_SIZE = 8
NUM_WORKERS = 4

# Training batches are reshuffled at the start of every epoch.
dl_train = torch.utils.data.DataLoader(
    ds_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

# Validation only averages metrics over the whole split, so sample order is
# irrelevant; a fixed order avoids pointless shuffling work.
dl_valid = torch.utils.data.DataLoader(
    ds_valid,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

In [5]:
# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [6]:
# Instantiate the capsule network with its default arguments and move it to
# the selected device in one step.
model = CapsNet().to(device)

# Kept as the trailing expression so the notebook displays the returned value.
count_parameters(model)

319028

In [7]:
# Spread loss configured for 10 classes with margin bounds m_min / m_max.
loss_func = SpreadLoss(
    num_class=10,
    m_min=0.2,
    m_max=0.9,
)

# NOTE: the training loop calls exp_lr_decay before every optimizer.step(),
# which overwrites each param group's 'lr'; the value passed here is therefore
# only the optimizer's nominal starting rate.
optimizer = optim.Adam(
    model.parameters(),
    lr=0.01,
    weight_decay=2e-7,
)

In [8]:
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: 2-D tensor of per-class scores, one row per sample; the top-k
            entries are taken along dimension 1.
        target: tensor of ground-truth class indices, one per sample.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one single-element float tensor per entry of ``topk``,
        holding the percentage (0-100) of samples whose target is among the
        k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Indices of the maxk highest scores per sample, transposed to
    # (maxk, batch) so that row i holds every sample's i-th best guess.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` inherits the transposed (non-contiguous) layout of `pred`,
        # so slicing more than one row cannot be flattened with .view();
        # .reshape() copies when necessary and works for every k.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

def exp_lr_decay(optimizer, global_step, init_lr=3e-3, decay_steps=20000,
                 decay_rate=0.96, lr_clip=3e-3, staircase=False):
    """Set every param group's learning rate to an exponentially decayed value.

    lr = init_lr * decay_rate ^ (global_step / decay_steps)

    With ``staircase=True`` the exponent uses integer division, so the rate
    only changes once every ``decay_steps`` steps.

    Args:
        optimizer: object exposing ``param_groups``, an iterable of dicts
            whose ``'lr'`` entry is overwritten in place.
        global_step: number of optimisation steps taken so far.
        init_lr: learning rate at step 0.
        decay_steps: number of steps per application of ``decay_rate``.
        decay_rate: multiplicative decay factor.
        lr_clip: accepted but not used by this function.
        staircase: select the stepwise instead of the continuous schedule.

    Returns:
        None; the optimizer is modified in place.
    """
    exponent = (global_step // decay_steps) if staircase else (global_step / decay_steps)
    new_lr = init_lr * decay_rate ** exponent

    for group in optimizer.param_groups:
        group['lr'] = new_lr

In [9]:
num_epochs = 2
train_len = len(dl_train)  # number of training batches per epoch
#
for epoch_idx in range(num_epochs):
    # ####################
    # TRAIN
    # ####################
    model.train()
    desc = "Train [{:3}/{:3}]:".format(epoch_idx, num_epochs)
    pbar = tqdm(dl_train, bar_format=desc + '{bar:10}{r_bar}{bar:-10b}')

    # Running sum of per-batch top-1 accuracy for this epoch.
    epoch_acc = 0.0

    for idx, (x, y_true) in enumerate(pbar):
        x = x.to(device)
        y_true = y_true.to(device)
        optimizer.zero_grad()

        y_pred = model(x)

        # Fraction of the whole training run completed so far, in [0, 1).
        # epoch_idx is 0-based, so it must not be shifted by -1 (that made r
        # negative throughout the first epoch). Presumably SpreadLoss uses r
        # to move its margin between m_min and m_max -- TODO confirm.
        r = (idx + epoch_idx * train_len) / (num_epochs * train_len)
        loss = loss_func(y_pred, y_true, r)
        acc = accuracy(y_pred, y_true)

        # 1-based count of optimisation steps across all epochs; drives the
        # exponential learning-rate schedule, applied before each update.
        global_step = (idx + 1) + epoch_idx * train_len
        exp_lr_decay(optimizer=optimizer, global_step=global_step)

        loss.backward()
        optimizer.step()

        epoch_acc += acc[0].item()

        pbar.set_postfix(
                {'loss': loss.item(),
                 'acc': acc[0].item()
                 }
        )

    print("   acc_train: {:.3f}".format(epoch_acc / train_len))

    # The learning rate is already updated per step by exp_lr_decay, so no
    # per-epoch scheduler step is needed here.
    #
    # ####################
    # VALID
    # ####################
    model.eval()

    test_loss = 0
    acc = 0
    test_len = len(dl_valid)  # number of validation batches

    with torch.no_grad():
        for x, y_true in dl_valid:
            x = x.to(device)
            y_true = y_true.to(device)

            y_pred = model(x)

            # Evaluate with r fixed at 1, passed positionally exactly like
            # the training call above.
            test_loss += loss_func(y_pred, y_true, 1).item()
            acc += accuracy(y_pred, y_true)[0].item()

    # Per-batch averages over the validation split.
    test_loss /= test_len
    acc /= test_len
    print("   loss_valid: {:.4f}".format(test_loss))
    print("   acc_valid: {:.3f}".format(acc))

Train [  0/  2]:▌         | 433/7500 [01:52<30:44,  3.83it/s, loss=-.0164, acc=[tensor([12.5000], device='cuda:0')]]


KeyboardInterrupt: 