In [1]:
import torch
import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

## Model Loading

In [None]:
USE_CUDA = True if torch.cuda.is_available() else False

class ConvLayer(nn.Module):
    def __init__(self, in_channels=1, out_channels=256, kernel_size=9):
        super(ConvLayer, self).__init__()

        self.conv = nn.Conv2d(in_channels=in_channels,
                              out_channels=out_channels,
                              kernel_size=kernel_size,
                              stride=1
                              )

    def forward(self, x):
        return F.relu(self.conv(x))


class PrimaryCaps(nn.Module):
    def __init__(self, num_capsules=8, in_channels=256, out_channels=32, kernel_size=9, num_routes=32 * 6 * 6):
        super(PrimaryCaps, self).__init__()
        self.num_routes = num_routes
        self.capsules = nn.ModuleList([
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=1, padding=0)
            for _ in range(num_capsules)])

    def forward(self, x):
        u = [capsule(x) for capsule in self.capsules]        
        u = torch.stack(u, dim=1)
        u = u.view(x.size(0), self.num_routes, -1)
        return self.squash(u)

    def squash(self, input_tensor, epsilon=1e-7):
        squared_norm = (input_tensor ** 2 + epsilon).sum(-1, keepdim=True)
        output_tensor = squared_norm * input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
        return output_tensor


class DigitCaps(nn.Module):
    def __init__(self, num_capsules=10, num_routes=32 * 6 * 6, in_channels=8, out_channels=16):
        super(DigitCaps, self).__init__()

        self.in_channels = in_channels
        self.num_routes = num_routes
        self.num_capsules = num_capsules

        self.W = nn.Parameter(torch.randn(1, num_routes, num_capsules, out_channels, in_channels))

    def forward(self, x):
        batch_size = x.size(0)
        x = torch.stack([x] * self.num_capsules, dim=2).unsqueeze(4)

        W = torch.cat([self.W] * batch_size, dim=0)
        u_hat = torch.matmul(W, x)

        b_ij = torch.zeros(1, self.num_routes, self.num_capsules, 1)
        if USE_CUDA:
            b_ij = b_ij.cuda()

        num_iterations = 3
        for iteration in range(num_iterations):
            c_ij = F.softmax(b_ij, dim=1)
            c_ij = torch.cat([c_ij] * batch_size, dim=0).unsqueeze(4)

            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)
            v_j = self.squash(s_j)

            if iteration < num_iterations - 1:
                a_ij = torch.matmul(u_hat.transpose(3, 4), torch.cat([v_j] * self.num_routes, dim=1))
                b_ij = b_ij + a_ij.squeeze(4).mean(dim=0, keepdim=True)

        return v_j.squeeze(1)

    def squash(self, input_tensor, epsilon=1e-7):
        squared_norm = (input_tensor ** 2 + epsilon).sum(-1, keepdim=True)
        output_tensor = squared_norm * input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
        return output_tensor


class Decoder(nn.Module):
    def __init__(self, input_width=28, input_height=28, input_channel=1):
        super(Decoder, self).__init__()
        self.input_width = input_width
        self.input_height = input_height
        self.input_channel = input_channel
        self.reconstraction_layers = nn.Sequential(
            nn.Linear(16 * 10, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, self.input_height * self.input_height * self.input_channel),
            nn.Sigmoid()
        )

    def forward(self, x, data):
        classes = torch.sqrt((x ** 2).sum(2))
        classes = F.softmax(classes.squeeze(), dim=1)

        _, max_length_indices = classes.max(dim=1)
        masked = torch.sparse.torch.eye(10)
        if USE_CUDA:
            masked = masked.cuda()
        masked = masked.index_select(dim=0, index=max_length_indices.squeeze().data)
        t = (x * masked[:, :, None, None]).view(x.size(0), -1)
        reconstructions = self.reconstraction_layers(t)
        reconstructions = reconstructions.view(-1, self.input_channel, self.input_width, self.input_height)
        return reconstructions, masked


class CapsNet(nn.Module):
    def __init__(self, config=None):
        super(CapsNet, self).__init__()
        if config:
            self.conv_layer = ConvLayer(config.cnn_in_channels, config.cnn_out_channels, config.cnn_kernel_size)
            self.primary_capsules = PrimaryCaps(config.pc_num_capsules, config.pc_in_channels, config.pc_out_channels,
                                                config.pc_kernel_size, config.pc_num_routes)
            self.digit_capsules_1 = DigitCaps(config.dc_num_capsules, config.dc_num_routes, config.dc_in_channels,
                                            config.dc_out_channels)
            self.digit_capsules_2 = DigitCaps(config.dc_2_num_capsules, config.dc_2_num_routes, config.dc_2_in_channels,
                                            config.dc_2_out_channels)
            self.decoder = Decoder(config.input_width, config.input_height, config.cnn_in_channels)
        else:
            self.conv_layer = ConvLayer()
            self.primary_capsules = PrimaryCaps()
            self.digit_capsules_1 = DigitCaps()
            self.decoder = Decoder()

        self.mse_loss = nn.MSELoss()

    def forward(self, data):
        output = self.digit_capsules_1(self.primary_capsules(self.conv_layer(data))).squeeze(-1)
        output = self.digit_capsules_2(output)
        reconstructions, masked = self.decoder(output, data)
        return output, reconstructions, masked

    def loss(self, data, x, target, reconstructions):
        return self.margin_loss(x, target) + self.reconstruction_loss(data, reconstructions)

    def margin_loss(self, x, labels, size_average=True):
        batch_size = x.size(0)

        v_c = torch.sqrt((x ** 2).sum(dim=2, keepdim=True))

        left = F.relu(0.9 - v_c).view(batch_size, -1)
        right = F.relu(v_c - 0.1).view(batch_size, -1)

        loss = labels * left + 0.5 * (1.0 - labels) * right
        loss = loss.sum(dim=1).mean()

        return loss

    def reconstruction_loss(self, data, reconstructions):
        loss = self.mse_loss(reconstructions.view(reconstructions.size(0), -1), data.view(reconstructions.size(0), -1))
        return loss * 0.0005

In [3]:
#Config for 49 16d vectors in the Primary Capsule. Set Softmax dimension to 0 in this case
class Config:
    def __init__(self, dataset='mnist'):
        # CNN (cnn)
        self.cnn_in_channels = 1
        self.cnn_out_channels = 12
        self.cnn_kernel_size = 15

        # Primary Capsule (pc)
        self.pc_num_capsules = 16
        self.pc_in_channels = 12
        self.pc_out_channels = 1
        self.pc_kernel_size = 8
        self.pc_num_routes = 1 * 7 * 7

        # Digit Capsule 1 (dc)
        self.dc_num_capsules = 49
        self.dc_num_routes = 1 * 7 * 7
        self.dc_in_channels = 16
        self.dc_out_channels = 16
        
        # Digit Capsule 2 (dc)
        self.dc_2_num_capsules = 10
        self.dc_2_num_routes = 1 * 7 * 7
        self.dc_2_in_channels = 16
        self.dc_2_out_channels = 16

        # Decoder
        self.input_width = 28
        self.input_height = 28

## Training CapsuleNet

In [4]:
USE_CUDA = True if torch.cuda.is_available() else False
BATCH_SIZE = 256
N_EPOCHS = 30
LEARNING_RATE = 0.01
MOMENTUM = 0.9

def train(model, optimizer, train_loader, epoch):
    capsule_net = model
    capsule_net.train()
    n_batch = len(train_loader)
    total_loss = 0
    for batch_id, (data, target) in enumerate(tqdm(train_loader)):
        target = torch.sparse.torch.eye(10).index_select(dim=0, index=target)
        if(USE_CUDA):
            data, target = data.cuda(), target.cuda()

        optimizer.zero_grad()
        output, reconstructions, masked = capsule_net(data)
        loss = capsule_net.loss(data, output, target, reconstructions)
        loss.backward()
        optimizer.step()
        correct = torch.sum(torch.argmax(masked, 1) == torch.argmax(target, 1))
        train_loss = loss.item()
        total_loss += train_loss
        if batch_id % 100 == 0:
            tqdm.write("Epoch: [{}/{}], Batch: [{}/{}], train accuracy: {:.6f}, loss: {:.6f}".format(
                epoch,
                N_EPOCHS,
                batch_id + 1,
                n_batch,
                correct / float(BATCH_SIZE),
                train_loss / float(BATCH_SIZE)
                ))
    tqdm.write('Epoch: [{}/{}], train loss: {:.6f}'.format(epoch,N_EPOCHS,total_loss / len(train_loader.dataset)))

## Loading Dataset

In [5]:
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.1307,), (0.3081,))])

trainset = torchvision.datasets.MNIST(root='./data/mnist', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.MNIST(root='./data/mnist', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=2)


In [6]:
torch.manual_seed(1)


config = Config()

capsule_net = CapsNet(config)
# capsule_net = torch.nn.DataParallel(capsule_net)
if USE_CUDA:
    capsule_net = capsule_net.cuda()

optimizer = torch.optim.Adam(capsule_net.parameters())

for e in range(1, N_EPOCHS + 1):
    train(capsule_net, optimizer, trainloader, e)
#     test(capsule_net, mnist.test_loader, e)

  0%|          | 1/235 [00:01<04:35,  1.18s/it]

Epoch: [1/30], Batch: [1/235], train accuracy: 0.117188, loss: 0.003517


 43%|████▎     | 101/235 [00:57<01:18,  1.70it/s]

Epoch: [1/30], Batch: [101/235], train accuracy: 0.628906, loss: 0.003080


 86%|████████▌ | 201/235 [01:56<00:19,  1.70it/s]

Epoch: [1/30], Batch: [201/235], train accuracy: 0.753906, loss: 0.002670


100%|██████████| 235/235 [02:16<00:00,  1.73it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [1/30], train loss: 0.002966


  0%|          | 1/235 [00:00<02:42,  1.44it/s]

Epoch: [2/30], Batch: [1/235], train accuracy: 0.699219, loss: 0.002398


 43%|████▎     | 101/235 [00:59<01:19,  1.70it/s]

Epoch: [2/30], Batch: [101/235], train accuracy: 0.820312, loss: 0.001477


 86%|████████▌ | 201/235 [01:58<00:20,  1.70it/s]

Epoch: [2/30], Batch: [201/235], train accuracy: 0.890625, loss: 0.001005


100%|██████████| 235/235 [02:18<00:00,  1.70it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [2/30], train loss: 0.001490


  0%|          | 1/235 [00:00<02:42,  1.44it/s]

Epoch: [3/30], Batch: [1/235], train accuracy: 0.906250, loss: 0.000893


 43%|████▎     | 101/235 [00:59<01:18,  1.70it/s]

Epoch: [3/30], Batch: [101/235], train accuracy: 0.906250, loss: 0.000708


 86%|████████▌ | 201/235 [01:58<00:20,  1.70it/s]

Epoch: [3/30], Batch: [201/235], train accuracy: 0.945312, loss: 0.000540


100%|██████████| 235/235 [02:18<00:00,  1.70it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [3/30], train loss: 0.000635


  0%|          | 1/235 [00:00<02:43,  1.44it/s]

Epoch: [4/30], Batch: [1/235], train accuracy: 0.949219, loss: 0.000414


 43%|████▎     | 101/235 [00:59<01:18,  1.70it/s]

Epoch: [4/30], Batch: [101/235], train accuracy: 0.964844, loss: 0.000363


 86%|████████▌ | 201/235 [01:58<00:20,  1.69it/s]

Epoch: [4/30], Batch: [201/235], train accuracy: 0.968750, loss: 0.000313


100%|██████████| 235/235 [02:18<00:00,  1.70it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [4/30], train loss: 0.000427


  0%|          | 1/235 [00:00<02:41,  1.45it/s]

Epoch: [5/30], Batch: [1/235], train accuracy: 0.960938, loss: 0.000361


 43%|████▎     | 101/235 [00:59<01:18,  1.70it/s]

Epoch: [5/30], Batch: [101/235], train accuracy: 0.976562, loss: 0.000261


 86%|████████▌ | 201/235 [01:58<00:19,  1.70it/s]

Epoch: [5/30], Batch: [201/235], train accuracy: 0.972656, loss: 0.000330


100%|██████████| 235/235 [02:17<00:00,  1.70it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [5/30], train loss: 0.000355


  0%|          | 1/235 [00:00<02:42,  1.44it/s]

Epoch: [6/30], Batch: [1/235], train accuracy: 0.957031, loss: 0.000341


 43%|████▎     | 101/235 [00:59<01:18,  1.70it/s]

Epoch: [6/30], Batch: [101/235], train accuracy: 0.945312, loss: 0.000370


 86%|████████▌ | 201/235 [01:58<00:20,  1.69it/s]

Epoch: [6/30], Batch: [201/235], train accuracy: 0.957031, loss: 0.000312


100%|██████████| 235/235 [02:18<00:00,  1.70it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [6/30], train loss: 0.000319


  0%|          | 1/235 [00:00<02:44,  1.42it/s]

Epoch: [7/30], Batch: [1/235], train accuracy: 0.964844, loss: 0.000275


 43%|████▎     | 101/235 [00:59<01:19,  1.69it/s]

Epoch: [7/30], Batch: [101/235], train accuracy: 0.964844, loss: 0.000298


 86%|████████▌ | 201/235 [01:58<00:20,  1.70it/s]

Epoch: [7/30], Batch: [201/235], train accuracy: 0.976562, loss: 0.000213


100%|██████████| 235/235 [02:18<00:00,  1.70it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [7/30], train loss: 0.000274


  0%|          | 1/235 [00:00<02:42,  1.44it/s]

Epoch: [8/30], Batch: [1/235], train accuracy: 0.964844, loss: 0.000259


 43%|████▎     | 101/235 [00:59<01:18,  1.70it/s]

Epoch: [8/30], Batch: [101/235], train accuracy: 0.980469, loss: 0.000183


 86%|████████▌ | 201/235 [01:58<00:20,  1.70it/s]

Epoch: [8/30], Batch: [201/235], train accuracy: 0.976562, loss: 0.000169


100%|██████████| 235/235 [02:18<00:00,  1.70it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [8/30], train loss: 0.000256


  0%|          | 1/235 [00:00<02:42,  1.44it/s]

Epoch: [9/30], Batch: [1/235], train accuracy: 0.972656, loss: 0.000199


 43%|████▎     | 101/235 [00:59<01:18,  1.70it/s]

Epoch: [9/30], Batch: [101/235], train accuracy: 0.976562, loss: 0.000265


 86%|████████▌ | 201/235 [01:58<00:19,  1.70it/s]

Epoch: [9/30], Batch: [201/235], train accuracy: 0.957031, loss: 0.000324


100%|██████████| 235/235 [02:17<00:00,  1.70it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [9/30], train loss: 0.000229


  0%|          | 1/235 [00:00<02:42,  1.44it/s]

Epoch: [10/30], Batch: [1/235], train accuracy: 0.980469, loss: 0.000177


 43%|████▎     | 101/235 [00:59<01:18,  1.70it/s]

Epoch: [10/30], Batch: [101/235], train accuracy: 0.964844, loss: 0.000245


 86%|████████▌ | 201/235 [01:58<00:20,  1.70it/s]

Epoch: [10/30], Batch: [201/235], train accuracy: 0.992188, loss: 0.000115


100%|██████████| 235/235 [02:17<00:00,  1.71it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [10/30], train loss: 0.000209


  0%|          | 1/235 [00:00<02:41,  1.45it/s]

Epoch: [11/30], Batch: [1/235], train accuracy: 0.980469, loss: 0.000174


 43%|████▎     | 101/235 [00:59<01:18,  1.70it/s]

Epoch: [11/30], Batch: [101/235], train accuracy: 0.972656, loss: 0.000198


 86%|████████▌ | 201/235 [01:58<00:20,  1.70it/s]

Epoch: [11/30], Batch: [201/235], train accuracy: 0.957031, loss: 0.000221


100%|██████████| 235/235 [02:17<00:00,  1.71it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [11/30], train loss: 0.000201


  0%|          | 1/235 [00:00<02:42,  1.44it/s]

Epoch: [12/30], Batch: [1/235], train accuracy: 0.984375, loss: 0.000164


 43%|████▎     | 101/235 [00:59<01:18,  1.70it/s]

Epoch: [12/30], Batch: [101/235], train accuracy: 0.972656, loss: 0.000178


 86%|████████▌ | 201/235 [01:58<00:19,  1.70it/s]

Epoch: [12/30], Batch: [201/235], train accuracy: 0.949219, loss: 0.000270


100%|██████████| 235/235 [02:17<00:00,  1.71it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [12/30], train loss: 0.000188


  0%|          | 1/235 [00:00<02:42,  1.44it/s]

Epoch: [13/30], Batch: [1/235], train accuracy: 0.976562, loss: 0.000219


 43%|████▎     | 101/235 [00:59<01:18,  1.70it/s]

Epoch: [13/30], Batch: [101/235], train accuracy: 0.968750, loss: 0.000251


 86%|████████▌ | 201/235 [01:58<00:20,  1.70it/s]

Epoch: [13/30], Batch: [201/235], train accuracy: 0.988281, loss: 0.000142


100%|██████████| 235/235 [02:17<00:00,  1.71it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [13/30], train loss: 0.000178


  0%|          | 1/235 [00:00<02:44,  1.43it/s]

Epoch: [14/30], Batch: [1/235], train accuracy: 0.980469, loss: 0.000177


 43%|████▎     | 101/235 [00:59<01:18,  1.70it/s]

Epoch: [14/30], Batch: [101/235], train accuracy: 0.972656, loss: 0.000169


 86%|████████▌ | 201/235 [01:57<00:19,  1.70it/s]

Epoch: [14/30], Batch: [201/235], train accuracy: 0.988281, loss: 0.000133


100%|██████████| 235/235 [02:17<00:00,  1.71it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [14/30], train loss: 0.000173


  0%|          | 1/235 [00:00<02:44,  1.42it/s]

Epoch: [15/30], Batch: [1/235], train accuracy: 0.984375, loss: 0.000149


 43%|████▎     | 101/235 [00:59<01:18,  1.70it/s]

Epoch: [15/30], Batch: [101/235], train accuracy: 0.976562, loss: 0.000152


 86%|████████▌ | 201/235 [01:58<00:20,  1.69it/s]

Epoch: [15/30], Batch: [201/235], train accuracy: 0.988281, loss: 0.000113


100%|██████████| 235/235 [02:17<00:00,  1.70it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [15/30], train loss: 0.000164


  0%|          | 1/235 [00:00<02:43,  1.43it/s]

Epoch: [16/30], Batch: [1/235], train accuracy: 0.976562, loss: 0.000148


 43%|████▎     | 101/235 [00:59<01:18,  1.70it/s]

Epoch: [16/30], Batch: [101/235], train accuracy: 0.972656, loss: 0.000218


 86%|████████▌ | 201/235 [01:58<00:19,  1.70it/s]

Epoch: [16/30], Batch: [201/235], train accuracy: 0.964844, loss: 0.000249


100%|██████████| 235/235 [02:17<00:00,  1.71it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [16/30], train loss: 0.000160


  0%|          | 1/235 [00:00<02:42,  1.44it/s]

Epoch: [17/30], Batch: [1/235], train accuracy: 0.980469, loss: 0.000126


 43%|████▎     | 101/235 [00:59<01:18,  1.71it/s]

Epoch: [17/30], Batch: [101/235], train accuracy: 0.976562, loss: 0.000159


 86%|████████▌ | 201/235 [01:57<00:19,  1.71it/s]

Epoch: [17/30], Batch: [201/235], train accuracy: 0.968750, loss: 0.000247


100%|██████████| 235/235 [02:17<00:00,  1.71it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [17/30], train loss: 0.000152


  0%|          | 1/235 [00:00<02:39,  1.46it/s]

Epoch: [18/30], Batch: [1/235], train accuracy: 0.988281, loss: 0.000106


 43%|████▎     | 101/235 [00:59<01:18,  1.71it/s]

Epoch: [18/30], Batch: [101/235], train accuracy: 0.968750, loss: 0.000189


 86%|████████▌ | 201/235 [01:57<00:19,  1.71it/s]

Epoch: [18/30], Batch: [201/235], train accuracy: 0.996094, loss: 0.000093


100%|██████████| 235/235 [02:17<00:00,  1.71it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [18/30], train loss: 0.000144


  0%|          | 1/235 [00:00<02:40,  1.45it/s]

Epoch: [19/30], Batch: [1/235], train accuracy: 0.964844, loss: 0.000195


 43%|████▎     | 101/235 [00:59<01:18,  1.71it/s]

Epoch: [19/30], Batch: [101/235], train accuracy: 0.980469, loss: 0.000130


 86%|████████▌ | 201/235 [01:57<00:19,  1.71it/s]

Epoch: [19/30], Batch: [201/235], train accuracy: 0.980469, loss: 0.000130


100%|██████████| 235/235 [02:17<00:00,  1.71it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [19/30], train loss: 0.000154


  0%|          | 1/235 [00:00<02:41,  1.45it/s]

Epoch: [20/30], Batch: [1/235], train accuracy: 0.988281, loss: 0.000089


 43%|████▎     | 101/235 [00:59<01:18,  1.70it/s]

Epoch: [20/30], Batch: [101/235], train accuracy: 0.976562, loss: 0.000165


 86%|████████▌ | 201/235 [01:58<00:20,  1.70it/s]

Epoch: [20/30], Batch: [201/235], train accuracy: 0.972656, loss: 0.000190


100%|██████████| 235/235 [02:17<00:00,  1.70it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [20/30], train loss: 0.000137


  0%|          | 1/235 [00:00<02:41,  1.45it/s]

Epoch: [21/30], Batch: [1/235], train accuracy: 0.988281, loss: 0.000133


 43%|████▎     | 101/235 [00:59<01:18,  1.70it/s]

Epoch: [21/30], Batch: [101/235], train accuracy: 0.976562, loss: 0.000142


 86%|████████▌ | 201/235 [01:58<00:20,  1.70it/s]

Epoch: [21/30], Batch: [201/235], train accuracy: 0.988281, loss: 0.000124


100%|██████████| 235/235 [02:17<00:00,  1.70it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [21/30], train loss: 0.000131


  0%|          | 1/235 [00:00<02:41,  1.45it/s]

Epoch: [22/30], Batch: [1/235], train accuracy: 0.968750, loss: 0.000192


 43%|████▎     | 101/235 [00:59<01:19,  1.69it/s]

Epoch: [22/30], Batch: [101/235], train accuracy: 0.976562, loss: 0.000155


 86%|████████▌ | 201/235 [01:58<00:20,  1.69it/s]

Epoch: [22/30], Batch: [201/235], train accuracy: 0.980469, loss: 0.000182


100%|██████████| 235/235 [02:18<00:00,  1.70it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [22/30], train loss: 0.000129


  0%|          | 1/235 [00:00<02:43,  1.43it/s]

Epoch: [23/30], Batch: [1/235], train accuracy: 0.976562, loss: 0.000162


 43%|████▎     | 101/235 [00:59<01:18,  1.70it/s]

Epoch: [23/30], Batch: [101/235], train accuracy: 0.957031, loss: 0.000224


 86%|████████▌ | 201/235 [01:58<00:20,  1.70it/s]

Epoch: [23/30], Batch: [201/235], train accuracy: 0.972656, loss: 0.000174


100%|██████████| 235/235 [02:17<00:00,  1.70it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [23/30], train loss: 0.000126


  0%|          | 1/235 [00:00<02:42,  1.44it/s]

Epoch: [24/30], Batch: [1/235], train accuracy: 0.980469, loss: 0.000145


 43%|████▎     | 101/235 [00:59<01:18,  1.70it/s]

Epoch: [24/30], Batch: [101/235], train accuracy: 0.984375, loss: 0.000104


 86%|████████▌ | 201/235 [01:58<00:19,  1.71it/s]

Epoch: [24/30], Batch: [201/235], train accuracy: 0.988281, loss: 0.000082


100%|██████████| 235/235 [02:17<00:00,  1.71it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [24/30], train loss: 0.000121


  0%|          | 1/235 [00:00<02:41,  1.45it/s]

Epoch: [25/30], Batch: [1/235], train accuracy: 0.976562, loss: 0.000168


 43%|████▎     | 101/235 [00:59<01:19,  1.69it/s]

Epoch: [25/30], Batch: [101/235], train accuracy: 1.000000, loss: 0.000055


 86%|████████▌ | 201/235 [01:58<00:20,  1.70it/s]

Epoch: [25/30], Batch: [201/235], train accuracy: 0.980469, loss: 0.000190


100%|██████████| 235/235 [02:18<00:00,  1.70it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [25/30], train loss: 0.000121


  0%|          | 1/235 [00:00<02:42,  1.44it/s]

Epoch: [26/30], Batch: [1/235], train accuracy: 0.972656, loss: 0.000151


 43%|████▎     | 101/235 [00:59<01:19,  1.70it/s]

Epoch: [26/30], Batch: [101/235], train accuracy: 0.984375, loss: 0.000123


 86%|████████▌ | 201/235 [01:58<00:20,  1.70it/s]

Epoch: [26/30], Batch: [201/235], train accuracy: 0.988281, loss: 0.000101


100%|██████████| 235/235 [02:18<00:00,  1.70it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [26/30], train loss: 0.000117


  0%|          | 1/235 [00:00<02:43,  1.43it/s]

Epoch: [27/30], Batch: [1/235], train accuracy: 0.976562, loss: 0.000165


 43%|████▎     | 101/235 [00:59<01:18,  1.70it/s]

Epoch: [27/30], Batch: [101/235], train accuracy: 0.980469, loss: 0.000148


 86%|████████▌ | 201/235 [01:58<00:19,  1.70it/s]

Epoch: [27/30], Batch: [201/235], train accuracy: 0.988281, loss: 0.000092


100%|██████████| 235/235 [02:17<00:00,  1.70it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [27/30], train loss: 0.000113


  0%|          | 1/235 [00:00<02:42,  1.44it/s]

Epoch: [28/30], Batch: [1/235], train accuracy: 0.968750, loss: 0.000180


 43%|████▎     | 101/235 [00:59<01:18,  1.70it/s]

Epoch: [28/30], Batch: [101/235], train accuracy: 0.964844, loss: 0.000205


 86%|████████▌ | 201/235 [01:58<00:20,  1.70it/s]

Epoch: [28/30], Batch: [201/235], train accuracy: 0.988281, loss: 0.000092


100%|██████████| 235/235 [02:17<00:00,  1.70it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [28/30], train loss: 0.000113


  0%|          | 1/235 [00:00<02:42,  1.44it/s]

Epoch: [29/30], Batch: [1/235], train accuracy: 0.988281, loss: 0.000119


 43%|████▎     | 101/235 [00:59<01:19,  1.69it/s]

Epoch: [29/30], Batch: [101/235], train accuracy: 0.980469, loss: 0.000140


 86%|████████▌ | 201/235 [01:58<00:20,  1.69it/s]

Epoch: [29/30], Batch: [201/235], train accuracy: 0.968750, loss: 0.000168


100%|██████████| 235/235 [02:18<00:00,  1.70it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [29/30], train loss: 0.000111


  0%|          | 1/235 [00:00<02:43,  1.43it/s]

Epoch: [30/30], Batch: [1/235], train accuracy: 0.988281, loss: 0.000117


 43%|████▎     | 101/235 [00:59<01:19,  1.69it/s]

Epoch: [30/30], Batch: [101/235], train accuracy: 0.984375, loss: 0.000102


 86%|████████▌ | 201/235 [01:58<00:19,  1.70it/s]

Epoch: [30/30], Batch: [201/235], train accuracy: 0.992188, loss: 0.000062


100%|██████████| 235/235 [02:18<00:00,  1.70it/s]

Epoch: [30/30], train loss: 0.000106





In [8]:
torch.save(capsule_net.state_dict(), "./CapsNetMNIS.pth ")

In [13]:
def test(capsule_net, test_loader, epoch):
    capsule_net.eval()
    test_loss = 0
    correct = 0
    for batch_id, (data, target) in enumerate(test_loader):

        target = torch.sparse.torch.eye(10).index_select(dim=0, index=target)
        if USE_CUDA:
            data, target = data.cuda(), target.cuda()

        output, reconstructions, masked = capsule_net(data)
        loss = capsule_net.loss(data, output, target, reconstructions)

        test_loss += loss.item()
        correct += torch.sum(torch.argmax(masked, 1) == torch.argmax(target, 1))

    tqdm.write(
        "Epoch: [{}/{}], test accuracy: {:.6f}, loss: {:.6f}".format(epoch, N_EPOCHS, correct / len(test_loader.dataset),
                                                                  test_loss / len(test_loader)))



In [14]:
test(capsule_net, testloader, 30)

Epoch: [30/30], test accuracy: 0.985100, loss: 0.030051


In [7]:
#Config for 16 1d vectors in Capsule Layer. Set the Softmax Dimension to 1 in this case
# class Config:
#     def __init__(self, dataset='mnist'):
#         # CNN (cnn)
#         self.cnn_in_channels = 1
#         self.cnn_out_channels = 12
#         self.cnn_kernel_size = 15

#         # Primary Capsule (pc)
#         self.pc_num_capsules = 1
#         self.pc_in_channels = 12
#         self.pc_out_channels = 16
#         self.pc_kernel_size = 8
#         self.pc_num_routes = 16 * 7 * 7

#         # Digit Capsule 1 (dc)
#         self.dc_num_capsules = 49
#         self.dc_num_routes = 16 * 7 * 7
#         self.dc_in_channels = 1
#         self.dc_out_channels = 1 #16
        
#         # Digit Capsule 2 (dc)
#         self.dc_2_num_capsules = 10
#         self.dc_2_num_routes = 7 * 7
#         self.dc_2_in_channels = 1 #16
#         self.dc_2_out_channels = 16

#         # Decoder
#         self.input_width = 28
#         self.input_height = 28