In [28]:
import torch
from torch.autograd import Variable
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

In [29]:
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from torch.optim import lr_scheduler
from torch.autograd import Variable

# Number of capsule channels emitted by the primary capsule layer: CapsNet passes it
# to PrimaryCapsLayer as `output_caps` and uses it to size the routing layer's input.
NUM_CAPS = 8


def squash(x, dim=2, eps=1e-8):
    """Apply the capsule "squash" non-linearity along one axis.

    Every vector taken along ``dim`` keeps its direction but is rescaled so
    that its length becomes ``|v|^2 / (1 + |v|^2)``, i.e. always below 1.

    Args:
        x: tensor holding capsule vectors along ``dim`` (the callers in this
            file pass ``batch x capsules x capsule_dim`` tensors).
        dim: axis that holds the capsule vector components (default 2).
        eps: small constant added under the square root so that an all-zero
            vector yields zeros (and finite gradients) instead of NaN from a
            division by zero.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # keepdim=True lets the scale broadcast against x for any number of
    # dimensions, instead of hard-coding a 3-D view.
    lengths2 = x.pow(2).sum(dim=dim, keepdim=True)
    lengths = (lengths2 + eps).sqrt()
    return x * (lengths2 / (1. + lengths2) / lengths)


class AgreementRouting(nn.Module):
    """Routing-by-agreement between an input and an output capsule layer.

    Holds one learnable routing logit per (input capsule, output capsule)
    pair and refines the coupling for each sample over ``n_iterations``
    rounds, based on how well each prediction agrees with the current output.

    Args:
        input_caps: number of input capsules.
        output_caps: number of output capsules.
        n_iterations: number of refinement rounds; 0 uses only the learned
            logits.
    """

    def __init__(self, input_caps, output_caps, n_iterations):
        super(AgreementRouting, self).__init__()
        self.n_iterations = n_iterations
        # Routing logits shared by every sample in the batch.
        self.b = nn.Parameter(torch.zeros((input_caps, output_caps)))

    def forward(self, u_predict):
        """Route predictions to output capsules.

        Args:
            u_predict: tensor of shape
                (batch, input_caps, output_caps, output_dim) with each input
                capsule's prediction for each output capsule.

        Returns:
            Tensor of shape (batch, output_caps, output_dim) with the squashed
            output capsules.
        """
        batch_size, input_caps, output_caps, _ = u_predict.size()

        # Coupling coefficients: softmax over the output capsules. The axis is
        # stated explicitly; relying on softmax's implicit dim is deprecated.
        c = F.softmax(self.b, dim=1)
        s = (c.unsqueeze(2) * u_predict).sum(dim=1)
        v = squash(s)

        if self.n_iterations > 0:
            b_batch = self.b.expand((batch_size, input_caps, output_caps))
            for _ in range(self.n_iterations):
                # Agreement = dot product between each prediction and the
                # current output capsule; it is accumulated into the logits.
                v = v.unsqueeze(1)
                b_batch = b_batch + (u_predict * v).sum(-1)

                c = F.softmax(b_batch, dim=2).unsqueeze(3)
                s = (c * u_predict).sum(dim=1)
                v = squash(s)

        return v


class CapsLayer(nn.Module):
    """Capsule layer: per-capsule linear predictions followed by routing.

    Each input capsule owns an ``input_dim x (output_caps * output_dim)``
    matrix that maps it to one prediction vector per output capsule. The
    predictions are handed to ``routing_module``, whose result is returned.
    """

    def __init__(self, input_caps, input_dim, output_caps, output_dim, routing_module):
        super(CapsLayer, self).__init__()
        self.input_caps, self.input_dim = input_caps, input_dim
        self.output_caps, self.output_dim = output_caps, output_dim
        # Allocated uninitialised; reset_parameters() fills it right below.
        self.weights = nn.Parameter(torch.Tensor(input_caps, input_dim, output_caps * output_dim))
        self.routing_module = routing_module
        self.reset_parameters()

    def reset_parameters(self):
        """Fill the weights uniformly in +/- 1/sqrt(input_caps)."""
        bound = 1. / math.sqrt(self.input_caps)
        self.weights.data.uniform_(-bound, bound)

    def forward(self, caps_output):
        """Map (batch, input_caps, input_dim) capsules through the layer.

        Returns whatever ``routing_module`` produces for the
        (batch, input_caps, output_caps, output_dim) prediction tensor.
        """
        # Insert a singleton row so matmul applies each capsule's own matrix.
        projected = torch.matmul(caps_output.unsqueeze(2), self.weights)
        u_predict = projected.view(
            projected.size(0), self.input_caps, self.output_caps, self.output_dim)
        return self.routing_module(u_predict)


class PrimaryCapsLayer(nn.Module):
    """Convolutional layer whose feature maps are regrouped into capsules.

    The convolution produces ``output_caps * output_dim`` channels. Every
    spatial position of every capsule channel becomes one ``output_dim``
    vector, which is then squashed.
    """

    def __init__(self, input_channels, output_caps, output_dim, kernel_size, stride):
        super(PrimaryCapsLayer, self).__init__()
        self.conv = nn.Conv2d(
            input_channels,
            output_caps * output_dim,
            kernel_size=kernel_size,
            stride=stride,
        )
        self.input_channels = input_channels
        self.output_caps = output_caps
        self.output_dim = output_dim

    def forward(self, input):
        """Return squashed capsules of shape (N, output_caps * H * W, output_dim)."""
        feature_maps = self.conv(input)
        batch, _, height, width = feature_maps.size()
        grouped = feature_maps.view(batch, self.output_caps, self.output_dim, height, width)

        # Move the capsule components to the last axis, then merge the capsule
        # channel and the spatial axes into a single "capsule index" axis.
        capsules = grouped.permute(0, 1, 3, 4, 2).contiguous()
        capsules = capsules.view(batch, -1, self.output_dim)
        return squash(capsules)


class CapsNet(nn.Module):
    """Capsule network: convolution -> primary capsules -> class capsules.

    Args:
        routing_iterations: number of routing rounds used by the class
            capsule layer.
        n_classes: number of output (class) capsules.
    """

    def __init__(self, routing_iterations, n_classes=200):
        super(CapsNet, self).__init__()

        self.conv1 = nn.Conv2d(3, 256, kernel_size=26, stride=2)

        # With 64x64 inputs the two convolutions shrink the map 64 -> 20 -> 6,
        # so each of the NUM_CAPS capsule channels contributes a 6x6 grid.
        self.primaryCaps = PrimaryCapsLayer(256, NUM_CAPS, 8, kernel_size=9, stride=2)
        self.num_primaryCaps = NUM_CAPS * 6 * 6

        self.digitCaps = CapsLayer(
            self.num_primaryCaps, 8, n_classes, 8,
            AgreementRouting(self.num_primaryCaps, n_classes, routing_iterations),
        )

    def forward(self, input):
        """Run the network on a channels-last batch.

        Args:
            input: tensor laid out as (N, H, W, C); it is permuted to
                (N, C, H, W) before the first convolution.

        Returns:
            Tuple ``(capsules, probs)``: the class capsule vectors and their
            lengths (one value per class).
        """
        features = F.relu(self.conv1(input.permute(0, 3, 1, 2)))
        capsules = self.digitCaps(self.primaryCaps(features))
        lengths = capsules.pow(2).sum(dim=2).sqrt()
        return capsules, lengths


class ReconstructionNet(nn.Module):
    """Fully connected decoder that reconstructs an input from class capsules.

    Only the capsule of the target class is kept; all others are zeroed
    before the three-layer MLP.

    NOTE(review): the output size is fixed at 784 values per sample, which
    does not match the 64x64x3 images produced by the dataset in this file —
    confirm before enabling reconstruction.

    Args:
        n_dim: dimensionality of each class capsule.
        n_classes: number of class capsules.
    """

    def __init__(self, n_dim=8, n_classes=200):
        super(ReconstructionNet, self).__init__()
        self.fc1 = nn.Linear(n_dim * n_classes, 512)
        self.fc2 = nn.Linear(512, 1024)
        self.fc3 = nn.Linear(1024, 784)
        self.n_dim = n_dim
        self.n_classes = n_classes

    def forward(self, x, target):
        """Decode the target-class capsule.

        Args:
            x: capsule tensor of shape (batch, n_classes, n_dim).
            target: integer class indices, one per sample (any shape that can
                be viewed as (batch, 1)).

        Returns:
            Tensor of shape (batch, 784) with values in [0, 1].
        """
        # Build the one-hot mask directly with x's dtype and device, so it
        # works on any GPU (not only the current default one) and in any
        # floating-point precision.
        mask = torch.zeros(x.size(0), self.n_classes, dtype=x.dtype, device=x.device)
        mask.scatter_(1, target.view(-1, 1).long(), 1.)
        x = x * mask.unsqueeze(2)
        x = x.view(-1, self.n_dim * self.n_classes)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # torch.sigmoid replaces the deprecated F.sigmoid.
        return torch.sigmoid(self.fc3(x))


class CapsNetWithReconstruction(nn.Module):
    """Pairs a capsule network with a decoder fed by its capsule output."""

    def __init__(self, capsnet, reconstruction_net):
        super(CapsNetWithReconstruction, self).__init__()
        self.capsnet = capsnet
        self.reconstruction_net = reconstruction_net

    def forward(self, x, target):
        """Return ``(reconstruction, probs)`` for a batch and its targets.

        ``probs`` comes straight from the capsule network; the reconstruction
        is what the decoder produces from the capsules and ``target``.
        """
        capsules, probs = self.capsnet(x)
        return self.reconstruction_net(capsules, target), probs


class MarginLoss(nn.Module):
    """Margin loss on capsule lengths.

    For every class ``k`` the loss term is
    ``T_k * max(0, m_pos - len_k)^2 + lambda_ * (1 - T_k) * max(0, len_k - m_neg)^2``
    where ``T_k`` is 1 for the target class and 0 otherwise.

    Args:
        m_pos: margin the target-class length should exceed.
        m_neg: margin the other classes' lengths should stay below.
        lambda_: weight of the non-target term.
    """

    def __init__(self, m_pos, m_neg, lambda_):
        super(MarginLoss, self).__init__()
        self.m_pos = m_pos
        self.m_neg = m_neg
        self.lambda_ = lambda_

    def forward(self, lengths, targets, size_average=True):
        """Compute the loss.

        Args:
            lengths: (batch, n_classes) tensor of capsule lengths.
            targets: integer class indices, one per sample (any shape that can
                be viewed as (batch, 1)).
            size_average: if True return the mean over all batch x class
                terms, otherwise their sum.

        Returns:
            Scalar tensor.
        """
        # One-hot encode on the same device/dtype as `lengths`; a bare .cuda()
        # would pick the current default GPU rather than the inputs' device.
        one_hot = torch.zeros_like(lengths)
        one_hot.scatter_(1, targets.view(-1, 1).long(), 1.)

        losses = one_hot * F.relu(self.m_pos - lengths).pow(2) + \
                 self.lambda_ * (1. - one_hot) * F.relu(lengths - self.m_neg).pow(2)
        return losses.mean() if size_average else losses.sum()


In [30]:
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

from torch.utils.data import Dataset, DataLoader
import os
from scipy.misc import imread
import numpy as np
import torch

# Dataset root directory. It can be overridden with the TINY_IMAGENET_ROOT
# environment variable so the notebook is not tied to one machine's layout;
# the default is the path the notebook was originally written against.
ROOT = os.environ.get('TINY_IMAGENET_ROOT', '/datasets/Tiny-ImageNet/tiny-imagenet-200')
# Side length (pixels) and channel count the dataset class reshapes every image to.
IMG_DIM = 64
IMG_CH = 3

"""
Define dataset
"""
class tinetDataset(Dataset):
    """Dataset that reads images from disk on demand.

    Args:
        dataroot: directory prefix prepended to every entry of ``datapaths``.
        datapaths: sequence of image paths relative to ``dataroot``.
        labels: NumPy array of class ids aligned with ``datapaths``.
        training: stored on the instance; not consulted by this class.
    """

    def __init__(self, dataroot, datapaths, labels, training=False):
        self.dataroot, self.datapaths = dataroot, datapaths
        self.labels = labels
        self.training = training

    def __len__(self):
        """Number of samples (one per data path)."""
        return len(self.datapaths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        ``image`` is a FloatTensor of shape (IMG_DIM, IMG_DIM, IMG_CH) holding
        the pixel values as read from disk; ``label`` is a one-element
        LongTensor.
        """
        # NOTE(review): `imread` is scipy.misc.imread, which newer SciPy
        # releases no longer ship — confirm the pinned SciPy version.
        pixels = imread(self.dataroot + self.datapaths[idx])
        if pixels.shape == (IMG_DIM, IMG_DIM):
            # Single-channel image: replicate it into three identical channels.
            pixels = np.stack([pixels] * 3, axis=-1)

        image = torch.from_numpy(np.asarray(pixels.flatten()))
        image = image.view(IMG_DIM, IMG_DIM, IMG_CH).type(torch.FloatTensor)

        label = torch.from_numpy(np.asarray(self.labels[idx].flatten()))
        label = label.type(torch.LongTensor)
        return image, label

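# Training settings -- legacy argparse version, kept commented out for reference only.
# The notebook uses the module-level constants defined after this block instead.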
# parser = argparse.ArgumentParser(description='CapsNet with MNIST')
# parser.add_argument('--batch-size', type=int, default=128, metavar='N',
#                     help='input batch size for training (default: 64)')
# parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
#                     help='input batch size for testing (default: 1000)')
# parser.add_argument('--epochs', type=int, default=250, metavar='N',
#                     help='number of epochs to train (default: 10)')
# parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
#                     help='learning rate (default: 0.01)')
# parser.add_argument('--no-cuda', action='store_true', default=False,
#                     help='disables CUDA training')
# parser.add_argument('--seed', type=int, default=1, metavar='S',
#                     help='random seed (default: 1)')
# parser.add_argument('--log-interval', type=int, default=10, metavar='N',
#                     help='how many batches to wait before logging training status')
# parser.add_argument('--routing_iterations', type=int, default=3)
# parser.add_argument('--with_reconstruction', action='store_true', default=False)
# args = parser.parse_args()
# args.cuda = not args.no_cuda and torch.cuda.is_available()


# Run configuration.
BATCH_SIZE = 50
TEST_BATCH_SIZE = 50
EPOCHS = 250
LEARNING_RATE = 0.001
CUDA = torch.cuda.is_available()
SEED = 1
LOG_INT = 10 # how many batches to wait before logging training status
ROUTING_ITERS = 3
RECONSTRUCTION = False

# Extra DataLoader options that only make sense when training on a GPU.
kwargs = {'num_workers': 1, 'pin_memory': True} if CUDA else {}

# Apply SEED so that weight initialisation and other PyTorch random draws are
# repeatable from run to run; without this the constant has no effect.
torch.manual_seed(SEED)
if CUDA:
    torch.cuda.manual_seed_all(SEED)

In [31]:
# create training and val set
# Each sub-directory of <ROOT>/train is one class. The names are sorted because
# os.walk / os.listdir return entries in arbitrary order; sorting keeps the
# class-name -> integer id mapping identical across machines and sessions
# (which matters when checkpoints are reloaded later).
labels = sorted(next(os.walk(ROOT + '/train'))[1])
label_id = dict(zip(labels, range(len(labels))))

# generate the training set and labels
# Lists are grown as files are found rather than pre-sized to a fixed
# images-per-class count, so a partial or extended dataset neither overflows
# nor leaves empty path entries behind.
train_datapaths = []
train_label_ids = []

curr_dir = ROOT + '/train'
for label in labels:
    img_subdir = '/' + label + '/images/'
    for image in sorted(os.listdir(curr_dir + img_subdir)):
        train_datapaths.append(img_subdir + image)
        train_label_ids.append(label_id[label])

# Kept as a float64 array (what np.zeros produced before); the dataset class
# converts each entry to a LongTensor.
train_labels = np.asarray(train_label_ids, dtype=np.float64)

# get the validation set and labels
# val_annotations.txt is tab-separated: image file name, class name, then
# further columns that are not needed here. It is parsed as plain text so the
# fields are always `str`, independent of NumPy's text-decoding behaviour.
val_datapaths = []
val_label_ids = []

curr_dir = ROOT + '/val'
with open(curr_dir + '/val_annotations.txt') as annotations:
    for line in annotations:
        fields = line.rstrip('\r\n').split('\t')
        if len(fields) < 2:
            continue  # skip blank or malformed lines
        val_datapaths.append('/images/' + fields[0])
        val_label_ids.append(label_id[fields[1]])

val_labels = np.asarray(val_label_ids, dtype=np.float64)

In [32]:
# define training and validation data loaders
# `kwargs` carries the worker / pinned-memory options prepared for CUDA runs;
# it has to be forwarded here, otherwise those settings are silently ignored.

train_loader = torch.utils.data.DataLoader(
    tinetDataset(ROOT + '/train', train_datapaths, train_labels, training = True),
    batch_size=BATCH_SIZE, shuffle=True, **kwargs)

test_loader = torch.utils.data.DataLoader(
    tinetDataset(ROOT + '/val', val_datapaths, val_labels, training = False),
    batch_size=TEST_BATCH_SIZE, shuffle=False, **kwargs)

model = CapsNet(ROUTING_ITERS)

if RECONSTRUCTION:
    # NOTE(review): ReconstructionNet emits 784 values per sample while the
    # images loaded here are 64x64x3 — confirm before enabling RECONSTRUCTION.
    reconstruction_model = ReconstructionNet(8, 200)
    reconstruction_alpha = 0.0005
    model = CapsNetWithReconstruction(model, reconstruction_model)

if CUDA:
    model.cuda()

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Lower the learning rate when the monitored validation loss stops improving.
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, patience=15, min_lr=1e-6)

loss_fn = MarginLoss(0.9, 0.1, 0.5)

In [None]:
def train(epoch):
    """Run one training epoch over ``train_loader``.

    Uses the module-level ``model``, ``optimizer`` and ``loss_fn`` and prints
    the running loss every ``LOG_INT`` batches.

    NOTE(review): the reconstruction branch flattens inputs to 784 values per
    row, while the dataset in this file yields 64x64x3 images — confirm before
    enabling RECONSTRUCTION.

    Args:
        epoch: epoch number, used only in the log line.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if CUDA:
            data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        if RECONSTRUCTION:
            output, probs = model(data, target)
            reconstruction_loss = F.mse_loss(output, data.view(-1, 784))
            margin_loss = loss_fn(probs, target)
            loss = reconstruction_alpha * reconstruction_loss + margin_loss
        else:
            output, probs = model(data)
            loss = loss_fn(probs, target)
        loss.backward()
        optimizer.step()
        if batch_idx % LOG_INT == 0:
            # .item() extracts the Python float; indexing a 0-dim loss tensor
            # with [0] raises on current PyTorch.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))

def test():
    """Evaluate the module-level ``model`` on ``test_loader``.

    Prints the average loss and the accuracy, and returns the average loss
    (summed loss divided by the number of samples).

    NOTE(review): the reconstruction branch flattens inputs to 784 values per
    row, while the dataset in this file yields 64x64x3 images — confirm before
    enabling RECONSTRUCTION.
    """
    model.eval()
    test_loss = 0
    correct = 0
    # no_grad() replaces the removed `volatile=True` flag: without it autograd
    # would record a graph for every evaluation batch.
    with torch.no_grad():
        for data, target in test_loader:
            if CUDA:
                data, target = data.cuda(), target.cuda()

            if RECONSTRUCTION:
                output, probs = model(data, target)
                reconstruction_loss = F.mse_loss(output, data.view(-1, 784), reduction='sum').item()
                test_loss += loss_fn(probs, target, size_average=False).item()
                test_loss += reconstruction_alpha * reconstruction_loss
            else:
                output, probs = model(data)
                test_loss += loss_fn(probs, target, size_average=False).item()

            pred = probs.max(1, keepdim=True)[1]  # get the index of the max probability
            # .item() keeps `correct` a plain int so the percentage below is
            # ordinary float arithmetic and prints as a number.
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return test_loss


# Main loop: train for one epoch, evaluate, let the scheduler react to the
# validation loss, then write a checkpoint named after the epoch and settings.
for epoch in range(1, EPOCHS + 1):
    train(epoch)
    test_loss = test()
    scheduler.step(test_loss)

    checkpoint_name = '{:03d}_model_dict_{}routing_reconstruction{}.pth'.format(
        epoch, ROUTING_ITERS, RECONSTRUCTION)
    torch.save(model.state_dict(), checkpoint_name)




Test set: Average loss: 0.6251, Accuracy: 474/10000 (5%)




Test set: Average loss: 0.6167, Accuracy: 712/10000 (7%)




Test set: Average loss: 0.6112, Accuracy: 896/10000 (9%)




Test set: Average loss: 0.6038, Accuracy: 1129/10000 (11%)






Test set: Average loss: 0.6001, Accuracy: 1249/10000 (12%)

