In [1]:
#https://github.com/ageron/handson-ml/blob/master/extra_capsnets.ipynb

In [2]:
import torch
import numpy as np
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torchvision
import torch.nn as nn

In [3]:
torch.set_default_tensor_type('torch.FloatTensor')

In [4]:
def squash(s, dim=-1, epsilon=1e-3):
#     print(s.size())
    norm_s = torch.sqrt(torch.norm(s) + epsilon)
#     print(norm_s)
    v = (norm_s / (1 + norm_s)) * (s / norm_s)
    return v

In [5]:
# https://discuss.pytorch.org/t/why-softmax-function-cant-specify-the-dimension-to-operate/2637
def softmax(input, axis=1):
    input_size = input.size()
    
    trans_input = input.transpose(axis, len(input_size)-1)
    trans_size = trans_input.size()

    input_2d = trans_input.contiguous().view(-1, trans_size[-1])
    
    soft_max_2d = nn.functional.softmax(input_2d)
    
    soft_max_nd = soft_max_2d.view(*trans_size)
    return soft_max_nd.transpose(axis, len(input_size)-1)

In [6]:
class CapsuleLayer(nn.Module):
    def __init__(self, in_channels, capsule_dimension, num_capsules, kernel_size=None, routing=False, num_iterations=0, stride=1):
        
        super(CapsuleLayer, self).__init__()
        
        self.in_channels = in_channels
        self.capsule_dimension = capsule_dimension
        self.num_capsules = num_capsules
        self.kernel_size = kernel_size
        self.routing = routing
        self.num_iterations = num_iterations
        self.stride = stride
        
        if not self.routing:
            
            self.conv = nn.Conv2d(self.in_channels, self.capsule_dimension * self.num_capsules,\
                                  self.kernel_size, self.stride)      
            
        
        else:
            
            # todo: to be calculated later
            self.width = 1152
            
            self.weights = torch.autograd.Variable(torch.from_numpy(0.01 * np.random.randn(1, self.width,\
                                                                                           self.num_capsules, self.capsule_dimension,\
                                                                                           self.in_channels)), requires_grad=True)
            
            

        
    def forward(self, x):
        
        if not self.routing:
            
            self.conv_out = self.conv(x)
            
            # todo: to be calculated later
            width = 6 # (W - F + 2 * P) / S + 1
            
            conv2 = self.conv_out.view(-1, width * width * self.num_capsules, self.capsule_dimension)
            
            squash_conv = squash(conv2)
            
#             print("squash conv", type(squash_conv))
            
            return squash_conv
                
                
        else:
            # x -> [batch size , 1152, 8]
             
            batch_size = x.size()[0]
            
#             print("x", x.size())
#             print("w", self.weights.size())
            
            u = x.unsqueeze(2).expand(x.size()[0], x.size()[1], self.num_capsules, x.size()[2]).unsqueeze(4).type(torch.FloatTensor)
            
            self.weights = self.weights.expand(batch_size, self.weights.size()[1], self.weights.size()[2],\
                                              self.weights.size()[3], self.weights.size()[4]).type(torch.FloatTensor)
    
#             print("w with batch", self.weights.size())
#             print("u", u.size())
            
#             print("w type", type(self.weights))
#             print("u type", type(u))
                    
            u_hat = torch.matmul(self.weights, u)
        
#             print("u_hat", u_hat.size())
            
            v = None
        
            b = Variable(torch.zeros(batch_size, x.size()[1], self.num_capsules, 1, 1).type(torch.FloatTensor), requires_grad=True)
            
            for r in range(self.num_iterations):
                
#                 print("b", b.size())
                
                c = softmax(b, axis=2)
        
#                 print(c.requires_grad)
                
#                 c = c.expand(c.size()[0], c.size()[1], c.size()[2], u_hat.size()[3], c.size()[4])
                
#                 print("c", c.size())
                
#                 print("c numpy", c.data.cpu().numpy().shape)
#                 print("u_hat numpy", u_hat.data.cpu().numpy().shape)

#                 s_numpy = np.matrix.dot(c.data.cpu().numpy(), u_hat.data.cpu().numpy())
                
#                 print("s_numpy", s_numpy.shape)
                
                s = c * u_hat
                
#                 print("s", s.size())
                
                s = torch.sum(s, dim=1)
                
#                 print("s", s.size())
                
#                 s = s.unsqueeze(2)
                
#                 print("s", s.size())
                
                v = squash(s, dim=-2)
                
#                 print("v", v.size())
                                
                temp1 = v.unsqueeze(1).expand(v.size()[0], u_hat.size()[1], v.size()[1], v.size()[2], v.size()[3])
            
                v = v.unsqueeze(1)
            
#                 print("v", v.size())
            
                temp = torch.matmul(u_hat.transpose(-1, -2), temp1)
                
#                 print("temp", temp.size())
                
#                 print("temp type", type(temp))
                
#                 print("b type", type(b))
                
                b = b + temp
                
#                 print("b", b.size())
    
#             print("v", type(v))

#             print(v.requires_grad)
                
            return v
                
            
        

In [47]:
class Net(nn.Module):
    def __init__(self):
        
        super(Net, self).__init__()
        
        self.conv1 = nn.Conv2d(1, 256, 9)
        
        self.primary_capsule = CapsuleLayer(256, 8, 32, kernel_size=9, routing=False, num_iterations=0, stride=2)
        
        self.digital_capsule = CapsuleLayer(8, 16, 10, kernel_size=None, routing=True, num_iterations=5, stride=1)
        
        self.criterion = torch.nn.MSELoss(size_average=True)
        
        
    def forward(self, x, target):
        
        x = self.conv1(x)
        
        x = nn.functional.relu(x)
        
        x = self.primary_capsule(x)
        
        x = self.digital_capsule(x)
        
        x = x.squeeze()
        
#         print(x.size())
        
#         y_predicted = torch.squeeze(x, dim=1)
        
        
#         start
                
        x = torch.sqrt(torch.norm(x, dim=2) + 1e-3)
        
        x = x.squeeze()
        
        
#         sof = nn.functional.softmax(x)
                
#         values, _ = torch.max(sof, dim=1)
        
#         aa = values.data.cpu().numpy()
#         aa = aa.reshape(aa.shape[0], 1)
        
#         vvv = x.data.cpu().numpy() == aa
        
#         y_predicted_temp = vvv + 0
        
#         y_predicted = Variable(torch.from_numpy(y_predicted_temp), requires_grad=True).type(torch.FloatTensor)
        
        
        y_predicted = x
    
    
    
    
    
    
#         end
        
#         y_predicted = nn.Softmax()(x)
        
#         print(y_predicted)
        
#         print("y", y_predicted.size())
        
        #start
#         n_labels = 10
        
#         targets = target.data.cpu().numpy()

#         ohm = np.zeros((targets.shape[0], n_labels))
#         #empty one-hot matrix
#         ohm[np.arange(targets.shape[0]), targets] = 1
        
#         target = Variable(torch.from_numpy(ohm), requires_grad=False).type(torch.LongTensor)
        #end
        
#         print("target", target.size())
        
#         print(target.requires_grad)
#         print(y_predicted.requires_grad)
        
        
        
        
        
        
        
#         _, y_predicted = torch.max(x, dim=1)
        
#         print(y_predicted.requires_grad)
        
#         temp3 = np.eye(10) * (y_predicted * 10)
        
#         print(temp3)
        
        
#         print(list(zip(y_predicted, target)))

#         print(y_predicted == target)
                
        
#         print(y_predicted.size(), target.size())
                
#         loss = torch.sum(y_predicted != target)

#         loss = self.criterion(y_predicted, target)

#         loss = mse_loss(y_predicted, target)

#         loss = nn.functional.nll_loss(nn.functional.log_softmax(y_predicted), target)

#         loss = self.criterion(y_predicted, target)

        target = target.type(torch.LongTensor)
    
#         print(target.size(), y_predicted.size())
#         print(target.data.type())
#         print(y_predicted.data.type())

        loss_fn = nn.CrossEntropyLoss() 
        loss = loss_fn(y_predicted, target)

    
        return x, loss
        

In [48]:
net = Net()

In [49]:
net

Net (
  (conv1): Conv2d(1, 256, kernel_size=(9, 9), stride=(1, 1))
  (primary_capsule): CapsuleLayer (
    (conv): Conv2d(256, 256, kernel_size=(9, 9), stride=(2, 2))
  )
  (digital_capsule): CapsuleLayer (
  )
  (criterion): MSELoss (
  )
)

In [10]:
import torchvision.datasets as dset
import torchvision.transforms as transforms
## load mnist dataset
root = '../data'
download = False
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
train_set = dset.MNIST(root=root, train=True, transform=trans, download=download)
test_set = dset.MNIST(root=root, train=False, transform=trans)

batch_size = 128
# kwargs = {'num_workers': 1}
train_loader = torch.utils.data.DataLoader(
                 dataset=train_set,
                 batch_size=batch_size,
                 shuffle=True)
test_loader = torch.utils.data.DataLoader(
                dataset=test_set,
                batch_size=batch_size,
                shuffle=False)


In [11]:
import torch.optim as optim
import torch.nn.functional as F

In [None]:
optimizer = optim.Adam(net.parameters(), lr=0.01)
for epoch in range(10):
    # trainning
    for batch_idx, (x, target) in enumerate(train_loader):
        optimizer.zero_grad()
        x, target = Variable(x), Variable(target)
        _, loss = net(x, target)
        loss.backward()
        optimizer.step()
#         if batch_idx % 100 == 0:
        print('==>>> epoch: {}, batch index: {}, train loss: {:.6f}'.format(epoch, batch_idx, loss.data[0]))
    # testing
    correct_cnt, ave_loss = 0, 0
    for batch_idx, (x, target) in enumerate(test_loader):
        x, target = Variable(x, volatile=True), Variable(target, volatile=True)
        score, loss = net(x, target)
        _, pred_label = torch.max(score.data, 1)
        correct_cnt += (pred_label == target.data).sum()
        ave_loss += loss.data[0]
    accuracy = correct_cnt*1.0/len(test_loader)/batch_size
    ave_loss /= len(test_loader)
    print('==>>> epoch: {}, test loss: {:.6f}, accuracy: {:.4f}'.format(epoch, ave_loss, accuracy))

torch.save(net.state_dict(), 'capsule_net')

==>>> epoch: 0, batch index: 0, train loss: 2.289733
==>>> epoch: 0, batch index: 1, train loss: 2.253983
==>>> epoch: 0, batch index: 2, train loss: 2.242553
==>>> epoch: 0, batch index: 3, train loss: 2.219036
==>>> epoch: 0, batch index: 4, train loss: 2.213325
==>>> epoch: 0, batch index: 5, train loss: 2.206645
==>>> epoch: 0, batch index: 6, train loss: 2.198196
==>>> epoch: 0, batch index: 7, train loss: 2.204408
==>>> epoch: 0, batch index: 8, train loss: 2.184974
==>>> epoch: 0, batch index: 9, train loss: 2.172639
==>>> epoch: 0, batch index: 10, train loss: 2.165656
==>>> epoch: 0, batch index: 11, train loss: 2.163977
==>>> epoch: 0, batch index: 12, train loss: 2.158755
==>>> epoch: 0, batch index: 13, train loss: 2.164740
==>>> epoch: 0, batch index: 14, train loss: 2.164521
==>>> epoch: 0, batch index: 15, train loss: 2.148664
==>>> epoch: 0, batch index: 16, train loss: 2.136095
==>>> epoch: 0, batch index: 17, train loss: 2.137916
==>>> epoch: 0, batch index: 18, train