In [1]:
from __future__ import print_function
%load_ext autoreload
%autoreload 2
import argparse
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from loss import SpreadLoss
from datasets import smallNORB

from model.capsules_op_report import CapsNet as CapsNet_op_report

In [2]:
# --- Run configuration (read as globals by the cells below) ---
batch_size = 1           # batch size of the training DataLoader
test_batch_size = 1      # batch size of the test DataLoader
test_intvl = 1
epochs = 30
lr = 3e-3
weight_decay = 2e-7
# Enable CUDA only when a device is actually present; a hard-coded True makes
# device creation and pinned-memory loading fail on CPU-only machines.
cuda = torch.cuda.is_available()
seed = 1                 # seed passed to torch.manual_seed / torch.cuda.manual_seed
log_interval = 10
em_iters = 2             # forwarded as `iters` to the CapsNet constructor
snapshot_folder = './snapshots'
data_folder = './data'   # root under which each dataset gets its own sub-folder
dataset = 'mnist'        # 'mnist' or 'smallNORB' (see get_setting)

In [3]:
def get_setting():
    """Create the train and test data loaders for the configured dataset.

    Reads the notebook-level settings ``cuda``, ``data_folder``, ``dataset``,
    ``batch_size`` and ``test_batch_size``.

    Returns:
        tuple: ``(num_class, train_loader, test_loader)``.

    Raises:
        NameError: if ``dataset`` is neither ``'mnist'`` nor ``'smallNORB'``.
    """
    # Worker / pinned-memory options are only passed when CUDA is enabled.
    loader_opts = {}
    if cuda:
        loader_opts = {'num_workers': 1, 'pin_memory': True}
    root = os.path.join(data_folder, dataset)

    if dataset == 'mnist':
        num_class = 10

        def mnist_transform():
            # Same tensor conversion + normalization for both splits.
            return transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])

        train_set = datasets.MNIST(root, train=True, download=True,
                                   transform=mnist_transform())
        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=batch_size, shuffle=True, **loader_opts)

        test_set = datasets.MNIST(root, train=False,
                                  transform=mnist_transform())
        test_loader = torch.utils.data.DataLoader(
            test_set, batch_size=test_batch_size, shuffle=True, **loader_opts)

    elif dataset == 'smallNORB':
        num_class = 5

        # Training split: resize, random crop and colour jitter.
        train_steps = [
            transforms.Resize(48),
            transforms.RandomCrop(32),
            transforms.ColorJitter(brightness=32./255, contrast=0.5),
            transforms.ToTensor(),
        ]
        train_set = smallNORB(root, train=True, download=True,
                              transform=transforms.Compose(train_steps))
        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=batch_size, shuffle=True, **loader_opts)

        # Test split: resize and centre crop, no jitter.
        test_steps = [
            transforms.Resize(48),
            transforms.CenterCrop(32),
            transforms.ToTensor(),
        ]
        test_set = smallNORB(root, train=False,
                             transform=transforms.Compose(test_steps))
        test_loader = torch.utils.data.DataLoader(
            test_set, batch_size=test_batch_size, shuffle=True, **loader_opts)

    else:
        raise NameError('Undefined dataset {}'.format(dataset))

    return num_class, train_loader, test_loader


In [4]:
def accuracy(output, target, topk=(1,)):
    """Compute the precision@k for the specified values of k.

    Args:
        output: 2-D score tensor; the top-k classes are taken along dim 1.
        target: tensor of class indices with one entry per row of ``output``.
        topk: iterable of ``k`` values to evaluate.

    Returns:
        list: one 1-element float tensor per ``k`` in ``topk`` holding the
        percentage (0-100) of samples whose target is among the top-k scores.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # pred: (maxk, batch) after the transpose, best class first.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape instead of view: `correct` is derived from a transposed
        # (non-contiguous) tensor, so view(-1) raises for k > 1.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


In [5]:
class AverageMeter(object):
    """Running tracker holding the latest value and the weighted mean so far.

    Attributes:
        val: most recent value passed to ``update``.
        sum: accumulated ``val * n`` over all updates.
        count: accumulated ``n`` over all updates.
        avg: ``sum / count`` after the latest update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean."""
        self.val = val
        weighted = val * n
        self.sum = self.sum + weighted
        self.count = self.count + n
        self.avg = self.sum / self.count


In [6]:
def test(test_loader, model, criterion, device, max_batches=1):
    """Evaluate ``model`` on the first batches of ``test_loader``.

    Args:
        test_loader: iterable of ``(data, target)`` batches supporting ``len``.
        model: module to evaluate; switched to eval mode here.
        criterion: loss callable invoked as ``criterion(output, target, r=1)``.
        device: device the batches are moved to before the forward pass.
        max_batches: maximum number of batches to evaluate. Defaults to 1
            (the previous hard-coded limit); ``None`` evaluates every batch.

    Returns:
        tuple: ``(acc, buf1, ..., buf7)`` where ``acc`` is the mean top-1
        accuracy (percent) over the evaluated batches. The seven ``buf``
        slots are always ``None``; they are kept so the return shape stays
        the same for existing callers.
    """
    model.eval()
    test_loss = 0
    acc = 0
    test_len = len(test_loader)

    limit = test_len if max_batches is None else min(max_batches, test_len)
    evaluated = 0
    buf1, buf2, buf3, buf4, buf5, buf6, buf7 = None, None, None, None, None, None, None
    with torch.no_grad():
        for data, target in test_loader:
            # Stop as soon as the limit is hit instead of draining the loader.
            if evaluated >= limit:
                break
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target, r=1).item()
            acc += accuracy(output, target)[0].item()
            evaluated += 1

    # Average over the batches actually evaluated (not the loader length),
    # and skip the division entirely when nothing was evaluated.
    if evaluated:
        test_loss /= evaluated
        acc /= evaluated
    print('\nTest set: Average loss: {:.6f}, Accuracy: {:.6f} \n'.format(
        test_loss, acc))
    return acc, buf1, buf2, buf3, buf4, buf5, buf6, buf7


In [7]:
# Seed the RNGs before any data loader or model is created.
torch.manual_seed(seed)
if cuda:
    torch.cuda.manual_seed(seed)

# Target device for the model and the batches.
device = torch.device("cuda:0") if cuda else torch.device("cpu")

# datasets
num_class, train_loader, test_loader = get_setting()
print(device)

# model
torch.cuda.empty_cache()
# Values passed as the A/B/C/D arguments of the CapsNet constructor.
# Alternative setting: A, B, C, D = 32, 32, 32, 32
A = 64
B = 8
C = 16
D = 16

criterion = SpreadLoss(num_class=num_class, m_min=0.2, m_max=0.9)

cuda:0


In [8]:
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 
# Test set: Average loss: 1.259673, Accuracy: 98.783623 
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 
# Build the op-report CapsNet and restore the saved 'model_10.pth' weights.
model_test = CapsNet_op_report(A=A, B=B, C=C, D=D, E=num_class,
                 iters=em_iters).to(device)
# Build the path from the configured snapshot folder: the previous literal
# used a backslash separator, which is an invalid string escape ("\m") and
# only resolves on Windows.
snapshot_path = os.path.join(snapshot_folder, 'model_10.pth')
# map_location lets a checkpoint saved on a GPU load onto the current device.
model_test.load_state_dict(torch.load(snapshot_path, map_location=device))
model_test.eval()
test_acc = test(test_loader, model_test, criterion, device)

self.conv1 Conv2d(1, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
self.conv1 in torch.Size([1, 1, 28, 28])
self.conv1 out torch.Size([1, 64, 14, 14])

batch norm torch.Size([1, 64, 14, 14])

PrimaryCaps: self.pose conv Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
PrimaryCaps: self.pose conv in torch.Size([1, 64, 14, 14])
PrimaryCaps: self.pose conv out torch.Size([1, 128, 14, 14])

PrimaryCaps: self.a conv Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1))
PrimaryCaps: self.a conv in torch.Size([1, 64, 14, 14])
PrimaryCaps: self.a conv out torch.Size([1, 8, 14, 14])

PrimaryCaps: self.sigmoid torch.Size([1, 8, 14, 14])

transform_view matmul torch.Size([36, 72, 16, 4, 4]) torch.Size([36, 72, 16, 4, 4])

m-step element mul torch.Size([36, 72, 16])

m-step div mul torch.Size([36, 72, 16])
m-step div torch.Size([36, 72, 1])

m-step div mul torch.Size([36, 72, 16])
m-step div torch.Size([36, 1, 16])

m-step element mul torch.Size([36, 72, 16, 16])

m-step element square mul t