<a href="https://colab.research.google.com/github/kartikdutt18/Capsule-Net-on-MNIST/blob/master/CapsuleNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Importing Libraries**

In [2]:
!pip install torch
!pip install torchvision
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
from torchvision import datasets, transforms
from numpy import prod
from time import time


Collecting torch
[?25l  Downloading https://files.pythonhosted.org/packages/7e/60/66415660aa46b23b5e1b72bc762e816736ce8d7260213e22365af51e8f9c/torch-1.0.0-cp36-cp36m-manylinux1_x86_64.whl (591.8MB)
[K    100% |████████████████████████████████| 591.8MB 25kB/s 
tcmalloc: large alloc 1073750016 bytes == 0x62386000 @  0x7f007d0d42a4 0x591a07 0x5b5d56 0x502e9a 0x506859 0x502209 0x502f3d 0x506859 0x504c28 0x502540 0x502f3d 0x506859 0x504c28 0x502540 0x502f3d 0x506859 0x504c28 0x502540 0x502f3d 0x507641 0x502209 0x502f3d 0x506859 0x504c28 0x502540 0x502f3d 0x507641 0x504c28 0x502540 0x502f3d 0x507641
[?25hInstalling collected packages: torch
Successfully installed torch-1.0.0
Collecting torchvision
[?25l  Downloading https://files.pythonhosted.org/packages/ca/0d/f00b2885711e08bd71242ebe7b96561e6f6d01fdb4b9dcf4d37e2e13c5e1/torchvision-0.2.1-py2.py3-none-any.whl (54kB)
[K    100% |████████████████████████████████| 61kB 4.0MB/s 
Collecting pillow>=4.1.1 (from torchvision)
[?25l  Downloadin

# **Defining The Layers**

### **Defining the Non-Linear function (Squash)**

In [0]:
def squash(vec,dim=-1):
  squared_normal=torch.sum(vec**2,dim=dim,keepdim=True)
  fn=squared_normal / (1 + squared_normal) * vec / (torch.sqrt(squared_normal) + 1e-8)
  return fn  

### **Defining Primary Capsules**

In [0]:
class PrimaryCapsules(nn.Module):
  def __init__(self,in_channels, out_channels, dim_caps,kernel_size=9, stride=2, padding=0):
    super(PrimaryCapsules,self).__init__()
    self.dim_caps=dim_caps
    self._caps_channel = int(out_channels / dim_caps)
    self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
    
  def forward(self,x):
    out=self.conv(x)
    out = out.view(out.size(0), self._caps_channel, out.size(2), out.size(3), self.dim_caps)
    out = out.view(out.size(0), -1, self.dim_caps)
    out=squash(out)
    return out

### **Defining routing between Capsules**

In [0]:
class Router(nn.Module):
  def __init__(self,in_dim, in_caps, num_caps, dim_caps, num_routing):
      super(Router,self).__init__()
      self.in_dim = in_dim
      self.in_caps = in_caps
      self.num_caps = num_caps
      self.dim_caps = dim_caps
      self.num_routing = num_routing
      self.W = nn.Parameter( 0.01 * torch.randn(1, num_caps, in_caps, dim_caps, in_dim ))
    
  def __repr__(self):
      
      tab= '  '
      line = '\n'
      next = ' -> '
      res = self.__class__.__name__ + '('
      res = res + line + tab + '(' + str(0) + '): ' + 'CapsuleLinear('
      res = res + str(self.in_dim) + ', ' + str(self.dim_caps) + ')'
      res = res + line + tab + '(' + str(1) + '): ' + 'Routing('
      res = res + 'Routing No =' + str(self.num_routing) + ')'
      res = res + line + ')'
      return res
  
  def forward(self,x):
    batch_size = x.size(0)
    x = x.unsqueeze(1).unsqueeze(4)
    # W @ x =(1, num_caps, in_caps, dim_caps, in_dim) @ (batch_size, 1, in_caps, in_dim, 1) =(batch_size, num_caps, in_caps, dim_caps, 1)
    u_hat = torch.matmul(self.W, x)
    u_hat = u_hat.squeeze(-1)
    #Prevent flow of Gradients
    temp_u_hat = u_hat.detach()
    b = torch.zeros(batch_size, self.num_caps, self.in_caps, 1)
    
    for route_iter in range(self.num_routing-1):
      sc = F.softmax(b, dim=1)
      # (batch_size, num_caps, in_caps, 1) * (batch_size, in_caps, num_caps, dim_caps) =(batch_size, num_caps, in_caps, dim_caps) sum across in_caps ->(batch_size, num_caps, dim_caps)
      vec = (sc * temp_u_hat).sum(dim=2)
      v = squash(vec)
      uv = torch.matmul(temp_u_hat, v.unsqueeze(-1))
      b += uv
      
    sc = F.softmax(b, dim=1)
    vec = (sc * u_hat).sum(dim=2)
    v = squash(vec)
    return v

#**Defining the Network**

In [0]:
class CapsuleNet(nn.Module):
	def __init__(self, img_shape, channels, primary_dim, num_classes, out_dim, num_routing, kernel_size=9):
		super(CapsuleNet,self).__init__()
		self.img_shape = img_shape
		self.num_classes = num_classes

		self.conv1 = nn.Conv2d(img_shape[0], channels, kernel_size, stride=1, bias=True)
		self.relu = nn.ReLU(inplace=True)

		self.primary = PrimaryCapsules(channels, channels, primary_dim, kernel_size)
		
		primary_caps = int(channels / primary_dim * ( img_shape[1] - 2*(kernel_size-1) ) * ( img_shape[2] - 2*(kernel_size-1) ) / 4)
		self.digits = Router(primary_dim, primary_caps, num_classes, out_dim, num_routing)

		self.decoder = nn.Sequential(
			nn.Linear(out_dim * num_classes, 512),
			nn.ReLU(inplace=True),
			nn.Linear(512, 1024),
			nn.ReLU(inplace=True),
			nn.Linear(1024, int(prod(img_shape)) ),
			nn.Sigmoid()
		)

	def forward(self, x):
		out = self.conv1(x)
		out = self.relu(out)
		out = self.primary(out)
		out = self.digits(out)
		preds = torch.norm(out, dim=-1)

		# Reconstruct the *predicted* image
		_, max_length_idx = preds.max(dim=1)	
		y = torch.eye(self.num_classes)
		y = y.index_select(dim=0, index=max_length_idx).unsqueeze(2)

		reconstructions = self.decoder( (out*y).view(out.size(0), -1) )
		reconstructions = reconstructions.view(-1, *self.img_shape)

		return preds, reconstructions

# **Defining the Losses**

In [0]:
class MarginLoss(nn.Module):
    def __init__(self, size_average=False, loss_lambda=0.5):
            super(MarginLoss,self).__init__()
            self.size_average = size_average
            self.m_plus = 0.9
            self.m_minus = 0.1
            self.loss_lambda = loss_lambda
    
    def forward(self,inputs,labels):
          L_k = labels * F.relu(self.m_plus - inputs)**2 + self.loss_lambda * (1 - labels) * F.relu(inputs - self.m_minus)**2
          L_k.sum(dim=1)

          if self.size_average:return L_k.mean()
          else:return L_k.sum()
    
    

In [0]:
class CapsuleLoss(nn.Module):
	def __init__(self,loss_lambda=0.5, recon_loss_scale=5e-4, size_average=False):

		super(CapsuleLoss,self).__init__()
		self.size_average = size_average
		self.margin_loss = MarginLoss(size_average=size_average, loss_lambda=loss_lambda)
		self.reconstruction_loss = nn.MSELoss(size_average=size_average)
		self.recon_loss_scale = recon_loss_scale

	def forward(self,inputs, labels, images, reconstructions):
		margin_loss = self.margin_loss(inputs, labels)
		reconstruction_loss = self.reconstruction_loss(reconstructions, images)
		caps_loss = (margin_loss + self.recon_loss_scale * reconstruction_loss)

		return caps_loss

# **Making the Trainer**

In [0]:
import os
if not os.path.exists('Checkpoints/'):os.mkdir('Checkpoints/')
SAVE_MODEL_PATH='Checkpoints/'
class Trainer:
  def __init__(self,loaders,net,epochs,batch_size,learning_rate,num_routing=3,lr_decay=0.8):
    self.loaders=loaders
    self.img_shape=self.loaders['train'].dataset[0][0].numpy().shape
    self.epochs=epochs
    self.net=net
    
    self.criterion=CapsuleLoss()
    self.optimizer = optim.Adam(self.net.parameters(), lr=learning_rate)
    self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=lr_decay)
    print('Num params:', sum([prod(p.size()) for p in self.net.parameters()]))
    
  def __repr__(self):
    return repr(self.net)
  
  def run(self):
    classes=list(range(10))
    print(8*'#','RUN STARTED',8*'#')
    eye = torch.eye(len(classes))
    
    for epoch in range(1,self.epochs):
      for phase in ['train', 'test']:
        print(f'{phase}ing...'.upper())
        if phase=='train':
          self.net.train()
        else: self.net.eval()
        
        t0=time()
        running_loss=0.0
        correct=0
        total=0
        
        for i, (images, labels) in enumerate(self.loaders[phase]):
          t1=time()
          labels = eye[labels]
          self.optimizer.zero_grad()
          outputs, reconstructions = self.net(images)
          loss = self.criterion(outputs, labels, images, reconstructions)
          if phase == 'train':
            loss.backward()
            self.optimizer.step()
            
          running_loss+=loss.item()
          _, predicted = torch.max(outputs, 1)
          
          _,labels=torch.max(outputs,1)
          total += labels.size(0)
          
          correct += (predicted == labels).sum()
          print(' correct: ',correct,'Total: ',total)
          accuracy=float(correct)/float(total)
          if phase == 'train':
            print(f'Epoch {epoch}, Batch {i+1}, Loss {running_loss/(i+1)}',f'Accuracy {accuracy} Time {round(time()-t1, 3)}s')
        print(f'{phase.upper()} Epoch {epoch}, Loss {running_loss/(i+1)}',f'Accuracy {accuracy} Time {round(time()-t0, 3)}s')
      self.scheduler.step()
    error_rate = round((1-accuracy)*100, 2)
    t2=str(time()).replace(" ", "-")
    torch.save(self.net.state_dict(), os.path.join(SAVE_MODEL_PATH, f'{error_rate}_{t2}.pth.tar'))
    class_correct = list(0. for _ in classes)
    class_total = list(0. for _ in classes)
    for images, labels in self.loaders['test']:
      outputs, reconstructions = self.net(images)
      _, predicted = torch.max(outputs, 1)
      c = (predicted == labels).squeeze()
      for i in range(labels.size(0)):
        label = labels[i]
        class_correct[label] += c[i].item()
        class_total[label] += 1
    
    for i in range(len(classes)):
        print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))
        

        


# **MNIST Datasets**

In [0]:
import torchvision
BatchSize=batch_size=64
size = 28
classes = list(range(10))
mean, std = ( ( 0.1307,), ( 0.3081,) )
loader={}
trainset = torchvision.datasets.MNIST(root='./MNIST', train=True,download=True, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean,std)]))
loader['train'] = torch.utils.data.DataLoader(trainset, batch_size=BatchSize,shuffle=True) 
testset = torchvision.datasets.MNIST(root='./MNIST', train=False,download=True, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean,std)]))
loader['test']=testloader = torch.utils.data.DataLoader(testset, batch_size=BatchSize,shuffle=False)

In [18]:
epochs=5
net=CapsuleNet(img_shape=(1,28,28),channels=256,primary_dim=8,num_classes=10,out_dim=16,num_routing=3)
classe = list(range(10))
learning_rate=1e-3
lr_decay=0.95
caps_net=Trainer(loader,net=net,batch_size=128,learning_rate=learning_rate,num_routing=3, lr_decay=lr_decay,epochs=5)
caps_net.run()



Num params: 8215568
######## RUN STARTED ########
TRAINING...
 correct:  tensor(64) Total:  64
Epoch 1, Batch 1, Loss 82.69861602783203 Accuracy 1.0 Time 3.792s
 correct:  tensor(128) Total:  128
Epoch 1, Batch 2, Loss 72.00513458251953 Accuracy 1.0 Time 3.257s
 correct:  tensor(192) Total:  192
Epoch 1, Batch 3, Loss 73.6968510945638 Accuracy 1.0 Time 3.125s
 correct:  tensor(256) Total:  256
Epoch 1, Batch 4, Loss 71.00421905517578 Accuracy 1.0 Time 3.107s
 correct:  tensor(320) Total:  320
Epoch 1, Batch 5, Loss 69.04557495117187 Accuracy 1.0 Time 3.085s
 correct:  tensor(384) Total:  384
Epoch 1, Batch 6, Loss 67.22044118245442 Accuracy 1.0 Time 3.103s
 correct:  tensor(448) Total:  448
Epoch 1, Batch 7, Loss 65.4186156136649 Accuracy 1.0 Time 3.057s
 correct:  tensor(512) Total:  512
Epoch 1, Batch 8, Loss 63.480666160583496 Accuracy 1.0 Time 3.163s
 correct:  tensor(576) Total:  576
Epoch 1, Batch 9, Loss 61.81080288357205 Accuracy 1.0 Time 3.118s
 correct:  tensor(640) Total:  6

# **Saving the Complete Model**

In [22]:
import torch
os.mkdir('./Model')
PATH='/Model'
torch.save(net,PATH)

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


In [24]:
model=torch.load(PATH)
print(model)

CapsuleNet(
  (conv1): Conv2d(1, 256, kernel_size=(9, 9), stride=(1, 1))
  (relu): ReLU(inplace)
  (primary): PrimaryCapsules(
    (conv): Conv2d(256, 256, kernel_size=(9, 9), stride=(2, 2))
  )
  (digits): Router(
    (0): CapsuleLinear(8, 16)
    (1): Routing(Routing No =3)
  )
  (decoder): Sequential(
    (0): Linear(in_features=160, out_features=512, bias=True)
    (1): ReLU(inplace)
    (2): Linear(in_features=512, out_features=1024, bias=True)
    (3): ReLU(inplace)
    (4): Linear(in_features=1024, out_features=784, bias=True)
    (5): Sigmoid()
  )
)


In [0]:
from google.colab import files
files.download(PATH)

In [26]:
files.upload()


Saving test.csv to test.csv


In [0]:
import pandas as pd
import numpy
testdata=pd.read_csv('test.csv')
dataset=testdata.iloc[:,0:]


In [45]:
files.upload()

Saving sample_submission.csv to sample_submission.csv


{'sample_submission.csv': b'ImageId,Label\r\n1,0\r\n2,0\r\n3,0\r\n4,0\r\n5,0\r\n6,0\r\n7,0\r\n8,0\r\n9,0\r\n10,0\r\n11,0\r\n12,0\r\n13,0\r\n14,0\r\n15,0\r\n16,0\r\n17,0\r\n18,0\r\n19,0\r\n20,0\r\n21,0\r\n22,0\r\n23,0\r\n24,0\r\n25,0\r\n26,0\r\n27,0\r\n28,0\r\n29,0\r\n30,0\r\n31,0\r\n32,0\r\n33,0\r\n34,0\r\n35,0\r\n36,0\r\n37,0\r\n38,0\r\n39,0\r\n40,0\r\n41,0\r\n42,0\r\n43,0\r\n44,0\r\n45,0\r\n46,0\r\n47,0\r\n48,0\r\n49,0\r\n50,0\r\n51,0\r\n52,0\r\n53,0\r\n54,0\r\n55,0\r\n56,0\r\n57,0\r\n58,0\r\n59,0\r\n60,0\r\n61,0\r\n62,0\r\n63,0\r\n64,0\r\n65,0\r\n66,0\r\n67,0\r\n68,0\r\n69,0\r\n70,0\r\n71,0\r\n72,0\r\n73,0\r\n74,0\r\n75,0\r\n76,0\r\n77,0\r\n78,0\r\n79,0\r\n80,0\r\n81,0\r\n82,0\r\n83,0\r\n84,0\r\n85,0\r\n86,0\r\n87,0\r\n88,0\r\n89,0\r\n90,0\r\n91,0\r\n92,0\r\n93,0\r\n94,0\r\n95,0\r\n96,0\r\n97,0\r\n98,0\r\n99,0\r\n100,0\r\n101,0\r\n102,0\r\n103,0\r\n104,0\r\n105,0\r\n106,0\r\n107,0\r\n108,0\r\n109,0\r\n110,0\r\n111,0\r\n112,0\r\n113,0\r\n114,0\r\n115,0\r\n116,0\r\n117,0\r\n118,0\r\n1

In [0]:
submit=pd.read_csv('sample_submission.csv')
fin=submit.iloc[:,:]

In [85]:
net.eval()
print(8*'#','Started Predicting',8*'#')
for i in range(len(testdata)):
  b=dataset.iloc[i,:]
  a=numpy.array(b)
  a=torch.from_numpy(a)
  a=a.type(torch.FloatTensor)
  a=a.reshape(1,1,28,28)
  a=a/256
  label,_=net(a)
  _, predicted = torch.max(label, 1)
  predicted=predicted.data.numpy()
  fin.iloc[i,1]=predicted
  z=(i+1)/len(testdata)*100
  if z%1==0:
    print(4*'#',z,' % Complete',4*'#')
  
  
  
  

######## Started Predicting ########
#### 1.0  % Complete ####
#### 2.0  % Complete ####
#### 3.0  % Complete ####
#### 4.0  % Complete ####
#### 5.0  % Complete ####
#### 6.0  % Complete ####
#### 8.0  % Complete ####
#### 9.0  % Complete ####
#### 10.0  % Complete ####
#### 11.0  % Complete ####
#### 12.0  % Complete ####
#### 13.0  % Complete ####
#### 15.0  % Complete ####
#### 16.0  % Complete ####
#### 17.0  % Complete ####
#### 18.0  % Complete ####
#### 19.0  % Complete ####
#### 20.0  % Complete ####
#### 21.0  % Complete ####
#### 22.0  % Complete ####
#### 23.0  % Complete ####
#### 24.0  % Complete ####
#### 25.0  % Complete ####
#### 26.0  % Complete ####
#### 27.0  % Complete ####
#### 30.0  % Complete ####
#### 31.0  % Complete ####
#### 32.0  % Complete ####
#### 33.0  % Complete ####
#### 34.0  % Complete ####
#### 35.0  % Complete ####
#### 36.0  % Complete ####
#### 37.0  % Complete ####
#### 38.0  % Complete ####
#### 39.0  % Complete ####
#### 40.0  % Complete ####

In [0]:
fin.to_csv('Final.csv')

In [0]:
files.download('Final.csv')