In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as functional
import numpy as np
from torchvision import datasets, transforms
from torch.autograd import Variable
import time
USE_GPU = True

In [2]:
def squash(x):

  norm_squared = (x ** 2).sum(-1, keepdim=True)
  part1 = norm_squared / (1 +  norm_squared)
  part2 = x / torch.sqrt(norm_squared)

  output = part1 * part2 
  return output

In [3]:
class ConvLayer(nn.Module):
  def __init__(self, 
               in_channels=1, 
               out_channels=256, 
               kernel_size=9):
    super(ConvLayer, self).__init__()
    
    self.conv = nn.Conv2d(in_channels=in_channels,
                          out_channels=out_channels,
                          kernel_size=kernel_size,
                          stride=1)
  def forward(self, x):
    output = self.conv(x)
    output = functional.relu(output)
    return output

In [4]:
class PrimaryCapules(nn.Module):
  
  def __init__(self, 
               num_capsules=8, 
               in_channels=256, 
               out_channels=32, 
               kernel_size=9):
    super(PrimaryCapules, self).__init__()
    self.capsules = nn.ModuleList([
      nn.Conv2d(in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=2,
                padding=0) for i in range(num_capsules)
    ])
  
  def forward(self, x):
    output = [caps(x) for caps in self.capsules]
    output = torch.stack(output, dim=1)
    output = output.view(x.size(0), 32*6*6, -1)
    
    return squash(output)
  
  # The squash function specified in Dynamic Routing Between Capsules
  # x: input tensor 


In [169]:
class ClassCapsules(nn.Module):
  
  def __init__(self, 
               num_capsules=10,
               num_routes = 32*6*6,
               in_channels=8,
               out_channels=16,
               routing_iterations=3):
    super(ClassCapsules, self).__init__()
    
    self.in_channels = in_channels
    self.num_routes = num_routes
    self.num_capsules = num_capsules
    self.routing_iterations = routing_iterations
    self.W = nn.Parameter(torch.rand(1,
                                     num_routes,
                                     num_capsules,
                                     out_channels,
                                     in_channels))
  
  def forward(self, x):
    batch_size = x.size(0)
    x = torch.stack([x] * self.num_capsules, dim=2).unsqueeze(4)
    
    W = torch.cat([self.W] * batch_size, dim=0)
    u_hat = torch.matmul(W, x)
    
    b_ij = Variable(torch.zeros(1, self.num_routes, self.num_capsules, 1))
    
    if USE_GPU:
      b_ij = b_ij.cuda()
    
    for it in range(self.routing_iterations):
      c_ij = functional.softmax(b_ij, dim=1) # Not sure if it should be dim=1
      c_ij = torch.cat([c_ij] * batch_size, dim=0).unsqueeze(4)
      
      s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)
      v_j = squash(s_j)
      
      if it < self.routing_iterations - 1: 
        uhatv_product = torch.matmul(u_hat.transpose(3,4),
                            torch.cat([v_j] * self.num_routes, dim=1))
        uhatv_product = uhatv_product.squeeze(4).mean(dim=0, keepdim=True)
        b_ij = b_ij + uhatv_product
      
    return v_j.squeeze(1)

  
    
    

In [170]:
class ReconstructionModule(nn.Module):
  def __init__(self, capsule_size=16, num_capsules=10):
    super(ReconstructionModule, self).__init__()
    
    self.num_capsules = num_capsules
    self.capsule_size = capsule_size
    
    self.decoder = nn.Sequential(
      nn.Linear(capsule_size*num_capsules, 512),
      nn.ReLU(),
      nn.Linear(512, 1024),
      nn.ReLU(),
      nn.Linear(1024, 784),
      nn.Sigmoid()
    )
  
  def forward(self, x, data, target=None):
    batch_size = x.size(0)
    if target is None:
      classes = torch.sqrt((x **2).sum(2))
      classes = functional.softmax(classes, dim=1)

      _, max_length_indices = classes.max(dim=1)
    else:
      max_length_indices = target.max(dim=1)[1].reshape(-1,1)
    masked = Variable(torch.eye(self.num_capsules))
    
    if USE_GPU:
      masked  = masked.cuda()
    masked = masked.index_select(dim=0, index=max_length_indices.squeeze(1).data)
    decoder_input = (x * masked[:, :, None, None]).view(batch_size, -1)

    reconstructions = self.decoder(decoder_input)
    reconstructions = reconstructions.view(-1, 1, 28, 28)
    
    return reconstructions, masked

    
  
  

In [171]:
class CapsNet(nn.Module):
  
  def __init__(self,
               alpha=0.0005 # Alpha from the loss function 
              ):
    super(CapsNet, self).__init__()
    
    self.conv_layer = ConvLayer()
    self.primary_capsules = PrimaryCapules()
    self.digit_caps = ClassCapsules()
    self.decoder = ReconstructionModule()
    
    self.mse_loss = nn.MSELoss()
    self.alpha = alpha
  
  def forward(self, x, target=None):
    output = self.conv_layer(x)
    output = self.primary_capsules(output)
    output = self.digit_caps(output)
    reconstruction, masked = self.decoder(output, x, target)
    return output, reconstruction, masked
  
  def loss(self, images,labels, capsule_output,  reconstruction):
    marg_loss = self.margin_loss(capsule_output, labels)
    rec_loss = self.reconstruction_loss(data, reconstruction)
    return marg_loss + self.alpha*rec_loss
  
  def margin_loss(self, x, labels):
    batch_size = x.size(0)
    
    v_c = torch.sqrt((x**2).sum(dim=2, keepdim=True))
    
    left = functional.relu(0.9 - v_c).view(batch_size, -1)
    right = functional.relu(v_c - 0.1).view(batch_size, -1)
    
    loss = labels * left + 0.5 *(1-labels)*right
    loss = loss.sum(dim=1).mean()
    return loss
  
  def reconstruction_loss(self, data, reconstructions):
    batch_size = reconstructions.size(0)
    loss = self.mse_loss(reconstructions.view(batch_size, -1),
                         data.view(batch_size, -1))
    return loss
  
  

In [183]:
class Test(nn.DataParallel):
  
  def __init__(self, capsnet, device_ids):
    super(Test, self).__init__(capsnet, device_ids=device_ids)
    self.capsnet = capsnet
    
  def loss(self, images,labels, capsule_output,  reconstruction): 
    return self.capsnet.loss(images, labels, capsule_output, reconstruction)
  
  def initialize_weights(self, initializer):
    self.capsnet.conv_layer.conv.apply(initializer)
    self.capsnet.primary_capsules.apply(initializer)
    self.capsnet.decoder.apply(initializer)

In [184]:
def weights_init_xavier(m):
    classname = m.__class__.__name__
    # print(classname)
    if classname.find('Conv') != -1:
        nn.init.xavier_normal_(m.weight.data, gain=0.02)
    elif classname.find('Linear') != -1:
        nn.init.xavier_normal_(m.weight.data, gain=0.02)
    elif classname.find('BatchNorm2d') != -1:
        nn.init.normal(m.weight.data, 1.0, 0.02)
        nn.init.constant(m.bias.data, 0.0)

In [189]:

capsnet = Test(CapsNet(), device_ids=[0,1])
if USE_GPU:
  capsnet.cuda()
optimizer = torch.optim.Adam(capsnet.parameters())

In [190]:
capsnet.initialize_weights(weights_init_xavier)

In [191]:
dataset_transform = transforms.Compose([
               transforms.ToTensor(),
               transforms.Normalize((0.1307,), (0.3081,))
           ])
"""Hyperparameters"""
max_epochs = 50
batch_size = 128
train_dataset = datasets.MNIST('../data', 
                               train=True, 
                               download=True, 
                               transform=dataset_transform)
test_dataset = datasets.MNIST('../data', 
                               train=False, 
                               download=True, 
                               transform=dataset_transform)


train_loader = torch.utils.data.DataLoader(train_dataset, 
                                           batch_size=batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, 
                                           batch_size=batch_size,
                                           shuffle=False)

In [None]:
t = time.time()
Te_LOSS = []
Tr_LOSS = []
for epoch in range(max_epochs):
  capsnet.train()
  train_loss = 0
  for batch, (data, target) in enumerate(train_loader):
    target = torch.eye(10).index_select(dim=0, index=target)
    data, target = Variable(data), Variable(target)
    
    if USE_GPU:
      data, target = data.cuda(), target.cuda()
    
    optimizer.zero_grad()
    
    output, reconstructions, masked = capsnet(data, target)
    loss = capsnet.loss(data, target, output, reconstructions)
    loss.backward()
    optimizer.step()
    
    train_loss += loss.data.item()
  
  capsnet.eval()
  test_loss = 0
  test_correct = 0
  test_total = 0
  for batch_id, (data, target) in enumerate(test_loader):
    target = torch.eye(10).index_select(dim=0, index=target)
    data, target = Variable(data), Variable(target)
    
    if USE_GPU:
      data,target = data.cuda(), target.cuda()
    
    output, reconstruction, masked = capsnet(data)
    loss = capsnet.loss(data, target, output, reconstruction)
    
    test_loss += loss.data.item()
    test_total += data.size(0)
    test_correct += sum(np.argmax(masked.data.cpu().numpy(),1 ) == np.argmax(target.data.cpu().numpy(), 1))
  
  acc = test_correct / test_total
  Te_LOSS.append(test_loss / len(test_loader))
  Tr_LOSS.append(train_loss / len(train_loader))
  test_loss /= len(test_loader)
  train_loss /= len(train_loader)
  time_spent = time.time() - t
  t = time.time()
  print("Epoch: {:3.0f} \t Time: {:3.0f} \t Test: {:.3f} \t Train: {:.3f} \t Accuracy: {:3.4f}".format(epoch, time_spent,test_loss, train_loss, acc*100))
  

Epoch:   0 	 Time: 106 	 Test: 0.819 	 Train: 0.859 	 Accuracy: 36.6900
Epoch:   1 	 Time: 103 	 Test: 0.756 	 Train: 0.779 	 Accuracy: 49.2800
Epoch:   2 	 Time: 105 	 Test: 0.367 	 Train: 0.547 	 Accuracy: 81.0200
Epoch:   3 	 Time: 106 	 Test: 0.188 	 Train: 0.253 	 Accuracy: 93.0800
Epoch:   4 	 Time: 106 	 Test: 0.111 	 Train: 0.143 	 Accuracy: 95.8100
Epoch:   5 	 Time: 105 	 Test: 0.098 	 Train: 0.104 	 Accuracy: 96.7800
Epoch:   6 	 Time: 106 	 Test: 0.073 	 Train: 0.084 	 Accuracy: 97.6100
Epoch:   7 	 Time: 106 	 Test: 0.072 	 Train: 0.072 	 Accuracy: 97.7700
Epoch:   8 	 Time: 105 	 Test: 0.059 	 Train: 0.063 	 Accuracy: 98.0500
Epoch:   9 	 Time: 103 	 Test: 0.050 	 Train: 0.056 	 Accuracy: 98.2300
Epoch:  10 	 Time: 106 	 Test: 0.047 	 Train: 0.050 	 Accuracy: 98.4700
Epoch:  11 	 Time: 103 	 Test: 0.043 	 Train: 0.046 	 Accuracy: 98.6500
Epoch:  12 	 Time: 103 	 Test: 0.041 	 Train: 0.042 	 Accuracy: 98.6500
Epoch:  13 	 Time: 104 	 Test: 0.041 	 Train: 0.040 	 Accuracy: 

In [None]:
[36.69, 49.28, 81.02, 93.08, 95.81, 96.78, 97.61, 97.77, 98.05, 98.23, 98.47, 98.6, 98.65, 98.82, 98.56, 98.96, 98.99,99.05,98.96]