<a href = "https://zhuanlan.zhihu.com/p/32156167">Capsule Network Tutorial</a>

In [None]:
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision import datasets,transforms

# Pick the GPU when one is present.  `torch.cuda.is_available` must be *called*:
# the bare function object is always truthy, which forced the CUDA branch (and
# the CUDA default tensor type) even on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    # Some cells below create tensors without an explicit `device=`; making CUDA
    # float tensors the default keeps those tensors on the GPU alongside the model.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

In [None]:
class Mnist:
    """Holds shuffled MNIST train/test DataLoaders.

    Both splits are downloaded into '../data' on first use.

    Args:
        batch_size: number of images per batch for both loaders.
    """
    def __init__(self, batch_size):
        # Normalize with mean 0 and std 1 is the identity ((x - 0) / 1); images
        # therefore stay in the [0, 1] range produced by ToTensor.
        pipeline = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.,), (1.0,))])

        # Build the train split first, then the test split (same order as before).
        splits = [
            datasets.MNIST('../data', train=is_train, download=True, transform=pipeline)
            for is_train in (True, False)
        ]
        self.train_loader, self.test_loader = (
            torch.utils.data.DataLoader(split, batch_size=batch_size, shuffle=True)
            for split in splits
        )

In [None]:
class ConvLayer(nn.Module):
    """First CapsNet stage: one 2-D convolution followed by a ReLU.

    With the defaults (1 -> 256 channels, 9x9 kernel, stride 1, no padding) a
    28x28 single-channel image becomes a 256 x 20 x 20 feature map.
    """
    def __init__(self, in_channel=1, out_channel=256, kernel_size=9, stride=1, padding=0):
        super().__init__()
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size, stride=stride, padding=padding)
        self.relu = nn.ReLU()

    def forward(self, x):
        features = self.conv(x)
        return self.relu(features)

In [None]:
class PrimaryCapsule(nn.Module):
    """Primary capsule layer built from `num_capsule` parallel convolutions.

    Every convolution supplies one component of each capsule: a capsule is the
    vector of the `num_capsule` conv responses at the same (channel, row, col).
    With the defaults and a 256 x 20 x 20 input each conv yields 32 x 6 x 6
    maps, i.e. 1152 capsules of dimension 8.
    """
    def __init__(self, num_capsule=8, in_channel=256, out_channel=32, kernel_size=9, stride=2):
        super(PrimaryCapsule, self).__init__()
        self.num_capsule = num_capsule
        self.capsules = nn.ModuleList([
            nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=kernel_size, stride=stride)
            for _ in range(num_capsule)
        ])

    def forward(self, x):
        """Return squashed capsules of shape (batch, num_routes, num_capsule)."""
        # Stack along a new LAST axis -> (batch, out_channel, H, W, num_capsule),
        # so that flattening groups the num_capsule conv outputs of one location
        # into one capsule.  Stacking on dim=1 and then viewing as (..., num_capsule)
        # instead grouped num_capsule neighbouring spatial values of a single conv.
        out = torch.stack([capsule(x) for capsule in self.capsules], dim=-1)
        out = out.reshape(out.size(0), -1, self.num_capsule)
        return self.squash(out)

    def squash(self, x, eps=1e-8):
        """Squash along the last axis: v = |s|^2 / (1 + |s|^2) * s / |s|.

        `eps` keeps zero vectors at zero (instead of 0/0 = NaN) and avoids the
        infinite gradient of sqrt at 0.
        """
        squared_norm = (x ** 2).sum(-1, keepdim=True)
        return (squared_norm * x) / ((1 + squared_norm) * torch.sqrt(squared_norm + eps))

In [None]:
class DigitCapsule(nn.Module):
    """Digit-capsule layer with dynamic routing-by-agreement.

    Maps `num_route` input capsules of size `input_channel` to `num_capsule`
    output capsules of size `output_channel`.  Every (input, output) capsule
    pair owns an `output_channel x input_channel` matrix; they are stored in
    `self.w` with shape (1, num_route, num_capsule, output_channel, input_channel).
    """
    def __init__(self, num_route=1152, num_capsule=10, input_channel=8, output_channel=16, routing_iter=3):
        super(DigitCapsule, self).__init__()
        if routing_iter < 1:
            # forward() needs at least one routing pass to produce an output.
            raise ValueError("routing_iter must be at least 1")
        self.input_channel = input_channel
        self.output_channel = output_channel
        self.num_capsule = num_capsule
        self.num_route = num_route

        self.routing_iter = routing_iter
        # Coupling coefficients are normalised over the OUTPUT capsules (dim=2 of
        # the routing logits, shape (1, num_route, num_capsule, 1)), so each input
        # capsule distributes a total weight of 1 among its parent capsules.  The
        # former implicit-dim Softmax picked dim=1 (the input capsules) instead.
        self.softmax = nn.Softmax(dim=2)

        self.w = nn.Parameter(torch.randn(1, num_route, num_capsule, output_channel, input_channel))

    def forward(self, x):
        """Route input capsules to output capsules.

        Args:
            x: input capsules, shape (batch, num_route, input_channel).

        Returns:
            Output capsules v_j, shape (batch, num_capsule, output_channel, 1).
        """
        # Prediction vectors u_hat[b, i, j] = W[i, j] @ x[b, i].  Like an ordinary
        # linear layer, but every (input capsule, output capsule) pair has its own
        # matrix.  x is viewed as (batch, num_route, 1, input_channel, 1) and W is
        # (1, num_route, num_capsule, output_channel, input_channel); torch.matmul
        # broadcasts the leading (batch, route, capsule) axes, so every output
        # capsule sees the full set of input capsules without explicit copies.
        u_hat = torch.matmul(self.w, x[:, :, None, :, None])
        # u_hat: (batch, num_route, num_capsule, output_channel, 1)

        # Routing logits b_ij (one per input/output capsule pair), shared across
        # the batch.  The trailing 1 lets them broadcast against u_hat.  Created
        # on x's device/dtype instead of relying on a global default tensor type.
        b_ij = x.new_zeros(1, self.num_route, self.num_capsule, 1)

        for iteration in range(self.routing_iter):
            c_ij = self.softmax(b_ij).unsqueeze(4)          # (1, R, C, 1, 1)
            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)   # (B, 1, C, O, 1)
            v_j = self.squash(s_j.squeeze(4)).unsqueeze(4)  # (B, 1, C, O, 1)
            if iteration < self.routing_iter - 1:
                # Agreement <u_hat, v_j>, averaged over the batch, raises the
                # logits of pairs whose prediction matches the output capsule.
                agreement = torch.matmul(u_hat.transpose(3, 4), v_j)  # (B, R, C, 1, 1)
                b_ij = b_ij + agreement.mean(dim=0, keepdim=True).squeeze(4)
        return v_j.squeeze(1)

    def squash(self, x, eps=1e-8):
        """Squash along the last axis: v = |s|^2 / (1 + |s|^2) * s / |s|.

        `eps` keeps zero vectors at zero (instead of NaN) and avoids the infinite
        gradient of sqrt at 0.
        """
        squared_norm = (x ** 2).sum(dim=-1, keepdim=True)
        return (squared_norm * x) / ((1 + squared_norm) * torch.sqrt(squared_norm + eps))

In [None]:
class Decoder(nn.Module):
    """Reconstruct the input image from the capsule of the predicted class.

    All capsules except the longest one (the predicted class) are zeroed, and
    the flattened result passes through a 3-layer MLP ending in a sigmoid.
    The defaults (10 capsules of size 16, 28x28 images) match MNIST.
    """
    def __init__(self, num_capsule=10, capsule_dim=16, image_size=28):
        super(Decoder, self).__init__()
        self.image_size = image_size
        self.reconstruction_layer = nn.Sequential(
            nn.Linear(capsule_dim * num_capsule, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, image_size * image_size),
            nn.Sigmoid(),
        )
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x, data):
        """Reconstruct images and return the one-hot mask of predicted classes.

        Args:
            x: digit capsules, shape (batch, num_capsule, capsule_dim, 1).
            data: unused; kept for interface compatibility with callers.

        Returns:
            (reconstruction of shape (batch, 1, image_size, image_size),
             one-hot mask of shape (batch, num_capsule)).
        """
        batch_size, num_classes = x.size(0), x.size(1)
        # Length of every class capsule: (batch, num_capsule, 1).
        lengths = torch.sqrt((x ** 2).sum(2))
        classes = self.softmax(lengths)
        _, index = classes.max(dim=1)  # (batch, 1)

        # One-hot rows of the predicted classes, on the same device/dtype as x.
        masked = torch.eye(num_classes, device=x.device, dtype=x.dtype).index_select(dim=0, index=index.view(-1))

        flat = (x * masked[:, :, None, None]).reshape(batch_size, -1)
        reconstruction = self.reconstruction_layer(flat)
        reconstruction = reconstruction.view(batch_size, 1, self.image_size, self.image_size)

        return reconstruction, masked
                

In [None]:
class CapsNet(nn.Module):
    """Capsule network: conv -> primary capsules -> digit capsules -> decoder.

    The training loss follows the CapsNet paper (Sabour et al., 2017): a margin
    loss on the digit-capsule lengths plus a reconstruction loss down-weighted
    by `self.discount`.
    """
    def __init__(self):
        super(CapsNet, self).__init__()
        self.conv = ConvLayer()
        self.pricap = PrimaryCapsule()
        self.digcap = DigitCapsule()
        self.decoder = Decoder()
        # Retained for backward compatibility; reconstruction_loss now computes
        # the summed squared error directly.
        self.MSELoss = nn.MSELoss()
        # Weight of the reconstruction term (0.0005 in the paper).
        self.discount = 5e-4

    def forward(self, x):
        conv_out = self.conv(x)
        pricap_out = self.pricap(conv_out)
        digcap_out = self.digcap(pricap_out)
        reconstruction, masked = self.decoder(digcap_out, x)

        return digcap_out, reconstruction, masked

    def loss(self, data, x, target, reconstructions):
        """Total loss: margin loss on capsules `x` + weighted reconstruction loss."""
        return self.margin_loss(x, target) + self.reconstruction_loss(data, reconstructions)

    def margin_loss(self, digcap_out, label, m_pos=0.9, m_neg=0.1, lambda_=0.5, eps=1e-8):
        """Margin loss, summed over classes and averaged over the batch.

        L_k = T_k * max(0, m_pos - |v_k|)^2 + lambda_ * (1 - T_k) * max(0, |v_k| - m_neg)^2

        Args:
            digcap_out: digit capsules (batch, num_classes, capsule_dim, 1).
            label: one-hot targets (batch, num_classes).
        """
        batch_size = digcap_out.size(0)
        # eps avoids the infinite gradient of sqrt at an all-zero capsule.
        vector_length = torch.sqrt((digcap_out ** 2).sum(2) + eps)  # (batch, num_classes, 1)

        # The paper squares both hinge terms.
        present = (m_pos - vector_length).clamp(min=0).view(batch_size, -1) ** 2
        absent = (vector_length - m_neg).clamp(min=0).view(batch_size, -1) ** 2

        loss = label * present + lambda_ * (1.0 - label) * absent
        return loss.sum(dim=1).mean()

    def reconstruction_loss(self, predict_x, x):
        """Per-sample sum of squared errors, averaged over the batch, times `discount`.

        The argument order is symmetric: `loss` passes the input images first and
        the reconstructions second.
        """
        batch_size = predict_x.size(0)
        squared_error = (predict_x.reshape(batch_size, -1) - x.reshape(x.size(0), -1)) ** 2
        return self.discount * squared_error.sum(dim=1).mean()

In [None]:
# Module.to moves the parameters in place and returns the same module object.
capsnet = CapsNet().to(device)
capsnet

CapsNet(
  (conv): ConvLayer(
    (conv): Conv2d(1, 256, kernel_size=(9, 9), stride=(1, 1))
    (relu): ReLU()
  )
  (pricap): PrimaryCapsule(
    (capsules): ModuleList(
      (0): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
      (1): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
      (2): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
      (3): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
      (4): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
      (5): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
      (6): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
      (7): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
    )
  )
  (digcap): DigitCapsule(
    (softmax): Softmax()
  )
  (decoder): Decoder(
    (reconstruction_layer): Sequential(
      (0): Linear(in_features=160, out_features=512, bias=True)
      (1): ReLU()
      (2): Linear(in_features=512, out_features=256, bias=True)
      (3): ReLU()
      (4): Linear(in_features=256, out_fea

In [None]:
def plot_result(reconstruction_images, images, num_images=5):
    """Show a random sample of input images above their reconstructions.

    Args:
        reconstruction_images: numpy array of decoder outputs, e.g. (N, 1, 28, 28).
        images: numpy array of the matching inputs (same leading dimension N).
        num_images: how many pairs to display (capped at N).
    """
    count = min(num_images, len(reconstruction_images))
    if count == 0:
        return
    # Sample without replacement so no image is shown twice.
    picked = np.random.choice(len(reconstruction_images), count, replace=False)

    for title, batch in (('Real', images[picked]), ('Predict', reconstruction_images[picked])):
        fig, axes = plt.subplots(1, count, squeeze=False)
        # suptitle labels the whole row; plt.title only labelled the last subplot.
        fig.suptitle(title)
        for ax, image in zip(axes[0], batch):
            # Squeeze each image individually so a single selected image keeps
            # its 2-D (H, W) layout instead of being iterated row by row.
            ax.imshow(np.squeeze(image), cmap='gray')
            ax.axis('off')
        plt.show()

In [None]:
# Adam over every CapsNet parameter, with the library's default hyperparameters.
adam = optim.Adam(params=capsnet.parameters())
epochs = 1  # passes over the training set

In [None]:
mnist = Mnist(100)
log_every = 10  # iterations between loss/accuracy logs and reconstruction plots
for epoch in range(epochs):
    capsnet.train()
    for i, (data, labels) in enumerate(mnist.train_loader):
        data = data.to(device)
        labels = labels.to(device)
        # One-hot targets of shape (batch, 10), created on the training device.
        target = torch.eye(10, device=device).index_select(dim=0, index=labels)

        digcap_out, reconstruction, masked = capsnet(data)
        loss = capsnet.loss(data, digcap_out, target, reconstruction)

        adam.zero_grad()
        loss.backward()
        adam.step()

        if i % log_every == 0:
            # `masked` is the decoder's one-hot prediction; compare class indices
            # row-wise and divide by the real batch size (the old code took the
            # argmax of the flattened target and divided by a hard-coded 30).
            accuracy = (masked.argmax(dim=1) == labels).float().mean().item()
            print(f"epoch {epoch} iter {i}: loss={loss.item():.4f}")
            print("train accuracy:", accuracy)
            plot_result(reconstruction.cpu().detach().numpy(), data.cpu().detach().numpy())
            

In [None]:
tensor = torch.randn(1, 5)  # one row of five standard-normal samples
print(tensor.shape)

In [None]:
# keepdim=True keeps the reduced axis (result shape (1, 1)); without it the shape is (1,).
print(torch.sum(tensor, dim=-1, keepdim=True))
print(torch.sum(tensor, dim=-1))

In [None]:
# Divide by the row sum, first with the reduced axis kept, then with it dropped.
print(torch.div(tensor, tensor.sum(dim=-1, keepdim=True)))
print(torch.div(tensor, tensor.sum(dim=-1)))

tensor([[ 2.2307, -1.9394,  1.1864,  1.1580, -1.6357]])
tensor([[ 2.2307, -1.9394,  1.1864,  1.1580, -1.6357]])


In [None]:
test_tensor = torch.zeros(1, 10, 2, 1)
# Softmax without `dim` is deprecated and warns; for a 4-D input the implicit
# choice is dim=1, so passing it explicitly keeps the same result warning-free.
softmax = nn.Softmax(dim=1)
output_tensor = softmax(test_tensor)
output_tensor2 = F.softmax(test_tensor, dim=1)

print(output_tensor.size())
print(output_tensor)
print(output_tensor2.size())
print(output_tensor2)

In [None]:
tensor = torch.tensor([[[2.0], [3.0]]])  # shape (1, 2, 1)
print(tensor.shape)

torch.Size([1, 2, 1])


In [None]:
test_tensor = torch.ones((1, 2, 5))

In [None]:
print(torch.mul(tensor, test_tensor))  # size-1 trailing axis broadcasts against 5

tensor([[[2., 2., 2., 2., 2.],
         [3., 3., 3., 3., 3.]]])


In [None]:
def squash(x, eps=1e-8):
    """Capsule squash non-linearity along the last axis.

    v = |s|^2 / (1 + |s|^2) * s / |s|: keeps the direction of s and maps its
    length into [0, 1).  The previous version omitted the `1 +`, which reduced
    every vector to unit length.  `eps` keeps zero vectors at zero instead of NaN.
    """
    square_x = (x ** 2).sum(-1, keepdim=True)
    return (square_x * x) / ((1 + square_x) * torch.sqrt(square_x + eps))

In [None]:
test_tensor = torch.full((1, 2, 3, 2), 2.0)
# L2 norm over the last axis, kept as a size-1 axis so it broadcasts back: (1, 2, 3, 1).
square_test = (test_tensor ** 2).sum(dim=-1, keepdim=True).sqrt()
print(test_tensor)
print(square_test)
print((test_tensor - square_test).sum())

tensor([[[[2., 2.],
          [2., 2.],
          [2., 2.]],

         [[2., 2.],
          [2., 2.],
          [2., 2.]]]])
tensor([[[[2.8284],
          [2.8284],
          [2.8284]],

         [[2.8284],
          [2.8284],
          [2.8284]]]])
tensor(-9.9411)


In [None]:
print(test_tensor)
print(test_tensor.pow(2))
print(test_tensor.pow(2).sqrt())  # sqrt(x^2) = |x|

In [None]:
test_tensor = torch.ones(2, 2, 3, 1)
output = torch.sum(test_tensor, dim=2)  # collapses the size-3 axis -> (2, 2, 1)
print(output)
print(output.shape)

tensor([[[3.],
         [3.]],

        [[3.],
         [3.]]])
torch.Size([2, 2, 1])


In [None]:
test_tensor = torch.ones(2, 10, 1)
# Index of the max along dim 1 for each sample: shape (2, 1).
value, index = torch.max(test_tensor, dim=1)
print(index.shape)
# Pick one identity row per sample to build a one-hot mask of shape (2, 10).
masked = torch.eye(10).index_select(0, index.squeeze(1))
print(masked.shape)
print(masked)

torch.Size([2, 1])
torch.Size([2, 10])
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])


In [None]:
print(torch.eye(n=10))  # 10x10 identity matrix

tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])


In [None]:
array = torch.ones(1, 2)
array2 = torch.ones(1, 2, 3)
# Broadcasting aligns TRAILING axes: (1, 2, 3) vs (1, 2) compares 3 with 2,
# which is incompatible, so the product raises.  Catch and display the error
# rather than leaving a failing cell in the notebook.
try:
    print(array2 * array[:, :])
except RuntimeError as err:
    print('RuntimeError:', err)
# Appending a trailing singleton axis gives (1, 2, 1), which broadcasts fine.
print(array2 * array[:, :, None])


RuntimeError: ignored

In [None]:
def change_value(value):
    """Rebind the parameter `value`; the caller's variable is left untouched."""
    print('Before changing, value is ', value)
    value = 20  # rebinding only affects this function's local name


value = 10
change_value(value)
print('After changing value is ', value)

Before changing, value is  10
After changing value is  10
