In [1]:
from __future__ import print_function
import argparse
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from model import capsules
from loss import SpreadLoss
from datasets import smallNORB


In [2]:
batch_size = 128
test_batch_size = 128
test_intvl = 1
epochs = 10
lr = 3e-3
weight_decay = 2e-7
cuda = True
seed = 1
log_interval = 10
em_iters = 2
snapshot_folder = './snapshots'
data_folder = './data'
dataset = 'mnist'

In [3]:
def get_setting():
    kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}
    path = os.path.join(data_folder, dataset)
    if dataset == 'mnist':
        num_class = 10
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(path, train=True, download=True,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ])),
            batch_size=batch_size, shuffle=True, **kwargs)
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(path, train=False,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ])),
            batch_size=test_batch_size, shuffle=True, **kwargs)
    elif dataset == 'smallNORB':
        num_class = 5
        train_loader = torch.utils.data.DataLoader(
            smallNORB(path, train=True, download=True,
                      transform=transforms.Compose([
                          transforms.Resize(48),
                          transforms.RandomCrop(32),
                          transforms.ColorJitter(brightness=32./255, contrast=0.5),
                          transforms.ToTensor()
                      ])),
            batch_size=batch_size, shuffle=True, **kwargs)
        test_loader = torch.utils.data.DataLoader(
            smallNORB(path, train=False,
                      transform=transforms.Compose([
                          transforms.Resize(48),
                          transforms.CenterCrop(32),
                          transforms.ToTensor()
                      ])),
            batch_size=test_batch_size, shuffle=True, **kwargs)
    else:
        raise NameError('Undefined dataset {}'.format(dataset))
    return num_class, train_loader, test_loader


In [4]:
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


In [5]:
def exp_lr_decay(optimizer, global_step, init_lr = 3e-3, decay_steps = 20000,
                                        decay_rate = 0.96, lr_clip = 3e-3 ,staircase=False):
    
    ''' decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)  '''
    
    if staircase:
        lr = (init_lr * decay_rate**(global_step // decay_steps)) 
    else:
        lr = (init_lr * decay_rate**(global_step / decay_steps)) 
    
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


In [6]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


In [7]:
def train(train_loader, model, criterion, optimizer, epoch, device):
# def train(train_loader, model, criterion, optimizer, scheduler, epoch, device):
    batch_time = AverageMeter()
    data_time = AverageMeter()

    model.train()
    train_len = len(train_loader)
    epoch_acc = 0
    end = time.time()

    for batch_idx, (data, target) in enumerate(train_loader):
        data_time.update(time.time() - end)

        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        r = (1.*batch_idx + (epoch-1)*train_len) / (epochs*train_len)
        loss = criterion(output, target, r)
        acc = accuracy(output, target)
        
        global_step = (batch_idx+1) + (epoch - 1) * len(train_loader) 
        exp_lr_decay(optimizer = optimizer, init_lr=lr ,global_step = global_step) # moein - change the learning rate exponentially
        
        
        loss.backward()
        optimizer.step()
        # scheduler.step()

        batch_time.update(time.time() - end)
        end = time.time()

        epoch_acc += acc[0].item()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {}\t[{}/{} ({:.0f}%)]\t'
                  'Loss: {:.6f}\tAccuracy: {:.6f}\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                  epoch, batch_idx * len(data), len(train_loader.dataset),
                  100. * batch_idx / len(train_loader),
                  loss.item(), acc[0].item(),
                  batch_time=batch_time, data_time=data_time))
    return epoch_acc


In [8]:
def snapshot(model, folder, epoch):
    path = os.path.join(folder, 'model_{}.pth'.format(epoch))
    if not os.path.exists(os.path.dirname(path)):
        os.makedirs(os.path.dirname(path))
    print('saving model to {}'.format(path))
    torch.save(model.state_dict(), path)


In [9]:
def test(test_loader, model, criterion, device):
    model.eval()
    test_loss = 0
    acc = 0
    test_len = len(test_loader)
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target, r=1).item()
            acc += accuracy(output, target)[0].item()

    test_loss /= test_len
    acc /= test_len
    print('\nTest set: Average loss: {:.6f}, Accuracy: {:.6f} \n'.format(
        test_loss, acc))
    return acc


In [10]:
torch.manual_seed(seed)
if cuda:
    torch.cuda.manual_seed(seed)

device = torch.device("cuda:0" if cuda else "cpu")

# datasets
num_class, train_loader, test_loader = get_setting()
print(device)

# model
torch.cuda.empty_cache() 
A, B, C, D = 64, 8, 16, 16
# A, B, C, D = 32, 32, 32, 32
model = capsules(A=A, B=B, C=C, D=D, E=num_class,
                 iters=em_iters).to(device)

print(model)

criterion = SpreadLoss(num_class=num_class, m_min=0.2, m_max=0.9)
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.96)

best_acc = test(test_loader, model, criterion, device)
torch.cuda.empty_cache()
for epoch in range(1, epochs + 1):
    print('lr: {}'.format(optimizer.param_groups[0]['lr']))
    acc = train(train_loader, model, criterion, optimizer, epoch, device)
#     acc = train(train_loader, model, criterion, optimizer, scheduler, epoch, device)
#     scheduler.step()
    acc /= len(train_loader)
    if epoch % test_intvl == 0:
        best_acc = max(best_acc, test(test_loader, model, criterion, device))
best_acc = max(best_acc, test(test_loader, model, criterion, device))
print('best test accuracy: {:.6f}'.format(best_acc))

snapshot(model, snapshot_folder, epochs)


cuda:0
CapsNet(
  (conv1): Conv2d(1, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
  (bn1): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  (relu1): ReLU()
  (primary_caps): PrimaryCaps(
    (pose): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
    (a): Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1))
    (sigmoid): Sigmoid()
  )
  (conv_caps1): ConvCaps(
    (sigmoid): Sigmoid()
    (softmax): Softmax(dim=2)
  )
  (conv_caps2): ConvCaps(
    (sigmoid): Sigmoid()
    (softmax): Softmax(dim=2)
  )
  (class_caps): ConvCaps(
    (sigmoid): Sigmoid()
    (softmax): Softmax(dim=2)
  )
)

Test set: Average loss: 7.312526, Accuracy: 11.164953 

lr: 0.003

Test set: Average loss: 3.141815, Accuracy: 96.578323 

lr: 0.0029971295468123466



Test set: Average loss: 2.725735, Accuracy: 97.715585 

lr: 0.0029942618401251936

Test set: Average loss: 2.433813, Accuracy: 98.417722 

lr: 0.002991396877310641



Test set: Average loss: 2.124062, Accuracy: 98.486946 

lr: 0.0029885346557433036



Test set: Average loss: 1.887410, Accuracy: 98.655063 

lr: 0.0029856751728003063

Test set: Average loss: 1.751949, Accuracy: 98.585839 

lr: 0.002982818425861285



Test set: Average loss: 1.670467, Accuracy: 98.615506 

lr: 0.0029799644123083835



Test set: Average loss: 1.605779, Accuracy: 98.457278 

lr: 0.002977113129526248

Test set: Average loss: 1.474082, Accuracy: 98.842959 

lr: 0.00297426457490203



Test set: Average loss: 1.392701, Accuracy: 98.714399 


Test set: Average loss: 1.389521, Accuracy: 98.783623 

best test accuracy: 98.842959
saving model to ./snapshots\model_10.pth


In [11]:
# !python train.py --lr=0.01 --weight-decay=0. --epochs=30 --batch-size=16
# %run -i train --lr=0.01 --weight_decay=0. --epochs=30 --batch-size=64