In [36]:
#https://github.com/ageron/handson-ml/blob/master/extra_capsnets.ipynb

In [1]:
import torch
import numpy as np
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torchvision
import torch.nn as nn

In [2]:
torch.set_default_tensor_type('torch.FloatTensor')

In [3]:
def squash(s, dim=-1, epsilon=1e-3):
#     print(s.size())
    norm_s = torch.sqrt(torch.norm(s) + epsilon)
#     print(norm_s)
    v = (norm_s / (1 + norm_s)) * (s / norm_s)
    return v

In [4]:
# https://discuss.pytorch.org/t/why-softmax-function-cant-specify-the-dimension-to-operate/2637
def softmax(input, axis=1):
    input_size = input.size()
    
    trans_input = input.transpose(axis, len(input_size)-1)
    trans_size = trans_input.size()

    input_2d = trans_input.contiguous().view(-1, trans_size[-1])
    
    soft_max_2d = nn.functional.softmax(input_2d)
    
    soft_max_nd = soft_max_2d.view(*trans_size)
    return soft_max_nd.transpose(axis, len(input_size)-1)

In [27]:
#https://github.com/cedrickchee/capsule-net-pytorch/blob/master/model.py
def loss(out_digit_caps, target, size_average=True):
        """Custom loss function
        Args:
            image: [batch_size, 1, 28, 28] MNIST samples.
            out_digit_caps: [batch_size, 10, 16, 1] The output from `DigitCaps` layer.
            target: [batch_size, 10] One-hot MNIST dataset labels.
            size_average: A boolean to enable mean loss (average loss over batch size).
        Returns:
            total_loss: A scalar Variable of total loss.
            m_loss: A scalar of margin loss.
            recon_loss: A scalar of reconstruction loss.
        """
        m_loss = margin_loss(out_digit_caps, target)
        if size_average:
            m_loss = m_loss.mean()

        total_loss = m_loss

        return total_loss

def margin_loss(input, target):
        """
        Class loss
        Implement equation 4 in section 3 'Margin loss for digit existence' in the paper.
        Args:
            input: [batch_size, 10, 16, 1] The output from `DigitCaps` layer.
            target: target: [batch_size, 10] One-hot MNIST labels.
        Returns:
            l_c: A scalar of class loss or also know as margin loss.
        """
        batch_size = input.size(0)

        # ||vc|| also known as norm.
        v_c = torch.sqrt((input**2).sum(dim=2, keepdim=True))

        # Calculate left and right max() terms.
        zero = Variable(torch.zeros(1))

        m_plus = 0.9
        m_minus = 0.1
        loss_lambda = 0.5
        max_left = torch.max(m_plus - v_c, zero).view(batch_size, -1)**2
        max_right = torch.max(v_c - m_minus, zero).view(batch_size, -1)**2
        t_c = target
        # Lc is margin loss for each digit of class c
        l_c = t_c * max_left + loss_lambda * (1.0 - t_c) * max_right
        l_c = l_c.sum(dim=1)

        return l_c

In [6]:
class CapsuleLayer(nn.Module):
    def __init__(self, in_channels, capsule_dimension, num_capsules, kernel_size=None, routing=False, num_iterations=0, stride=1):
        
        super(CapsuleLayer, self).__init__()
        
        self.in_channels = in_channels
        self.capsule_dimension = capsule_dimension
        self.num_capsules = num_capsules
        self.kernel_size = kernel_size
        self.routing = routing
        self.num_iterations = num_iterations
        self.stride = stride
        
        if not self.routing:
            
            self.conv = nn.Conv2d(self.in_channels, self.capsule_dimension * self.num_capsules,\
                                  self.kernel_size, self.stride)      
            
        
        else:
            
            # todo: to be calculated later
            self.width = 1152
            
            self.weights = torch.autograd.Variable(torch.from_numpy(0.01 * np.random.randn(1, self.width,\
                                                                                           self.num_capsules, self.capsule_dimension,\
                                                                                           self.in_channels)), requires_grad=True)
            
            

        
    def forward(self, x):
        
        if not self.routing:
            
            self.conv_out = self.conv(x)
            
            # todo: to be calculated later
            width = 6 # (W - F + 2 * P) / S + 1
            
            conv2 = self.conv_out.view(-1, width * width * self.num_capsules, self.capsule_dimension)
            
            squash_conv = squash(conv2)
            
#             print("squash conv", type(squash_conv))
            
            return squash_conv
                
                
        else:
            # x -> [batch size , 1152, 8]
             
            batch_size = x.size()[0]
            
#             print("x", x.size())
#             print("w", self.weights.size())
            
            u = x.unsqueeze(2).expand(x.size()[0], x.size()[1], self.num_capsules, x.size()[2]).unsqueeze(4).type(torch.FloatTensor)
            
            self.weights = self.weights.expand(batch_size, self.weights.size()[1], self.weights.size()[2],\
                                              self.weights.size()[3], self.weights.size()[4]).type(torch.FloatTensor)
    
#             print("w with batch", self.weights.size())
#             print("u", u.size())
            
#             print("w type", type(self.weights))
#             print("u type", type(u))
                    
            u_hat = torch.matmul(self.weights, u)
        
#             print("u_hat", u_hat.size())
            
            v = None
        
            b = Variable(torch.zeros(batch_size, x.size()[1], self.num_capsules, 1, 1).type(torch.FloatTensor), requires_grad=True)
            
            for r in range(self.num_iterations):
                
#                 print("b", b.size())
                
                c = softmax(b, axis=2)
        
#                 print(c.requires_grad)
                
#                 c = c.expand(c.size()[0], c.size()[1], c.size()[2], u_hat.size()[3], c.size()[4])
                
#                 print("c", c.size())
                
#                 print("c numpy", c.data.cpu().numpy().shape)
#                 print("u_hat numpy", u_hat.data.cpu().numpy().shape)

#                 s_numpy = np.matrix.dot(c.data.cpu().numpy(), u_hat.data.cpu().numpy())
                
#                 print("s_numpy", s_numpy.shape)
                
                s = c * u_hat
                
#                 print("s", s.size())
                
                s = torch.sum(s, dim=1)
                
#                 print("s", s.size())
                
#                 s = s.unsqueeze(2)
                
#                 print("s", s.size())
                
                v = squash(s, dim=-2)
                
#                 print("v", v.size())
                                
                temp1 = v.unsqueeze(1).expand(v.size()[0], u_hat.size()[1], v.size()[1], v.size()[2], v.size()[3])
            
                v = v.unsqueeze(1)
            
#                 print("v", v.size())
            
                temp = torch.matmul(u_hat.transpose(-1, -2), temp1)
                
#                 print("temp", temp.size())
                
#                 print("temp type", type(temp))
                
#                 print("b type", type(b))
                
                b = b + temp
                
#                 print("b", b.size())
    
#             print("v", type(v))

#             print(v.requires_grad)
                
            return v
                
            
        

In [32]:
class Net(nn.Module):
    def __init__(self):
        
        super(Net, self).__init__()
        
        self.conv1 = nn.Conv2d(1, 256, 9)
        
        self.primary_capsule = CapsuleLayer(256, 8, 32, kernel_size=9, routing=False, num_iterations=0, stride=2)
        
        self.digital_capsule = CapsuleLayer(8, 16, 10, kernel_size=None, routing=True, num_iterations=5, stride=1)
        
        self.criterion = torch.nn.MSELoss(size_average=True)
        
        
    def forward(self, x, target):
        
        x = self.conv1(x)
        
        x = nn.functional.relu(x)
        
        x = self.primary_capsule(x)
        
        x = self.digital_capsule(x)
        
        y_predicted = torch.squeeze(x, dim=1)
        
        
        #start
                
#         x = torch.sqrt(torch.norm(x, dim=2) + 1e-3)
        
#         x = x.squeeze()
        
#         sof = nn.functional.softmax(x)
        
#         values, _ = torch.max(sof, dim=1)
        
#         aa = values.data.cpu().numpy()
#         aa = aa.reshape(aa.shape[0], 1)
        
#         vvv = x.data.cpu().numpy() == aa
        
#         y_predicted_temp = vvv + 0
        
#         y_predicted = torch.from_numpy(y_predicted_temp)
        
#         print(y_predicted.size())
        #end
        
#         y_predicted = nn.Softmax()(x)
        
#         print(y_predicted)
        
#         print("y", y_predicted.size())
        
        #start
        n_labels = 10
        
        targets = target.data.cpu().numpy()

        ohm = np.zeros((targets.shape[0], n_labels))
        #empty one-hot matrix
        ohm[np.arange(targets.shape[0]), targets] = 1
        
        target = Variable(torch.from_numpy(ohm), requires_grad=True)
        #end
        
#         print(target.size())
        
#         print(target.requires_grad)
#         print(y_predicted.requires_grad)
        
        
        
        
        
        
        
#         _, y_predicted = torch.max(x, dim=1)
        
#         print(y_predicted.requires_grad)
        
#         temp3 = np.eye(10) * (y_predicted * 10)
        
#         print(temp3)
        
        
#         print(list(zip(y_predicted, target)))

#         print(y_predicted == target)
                
        
#         print(y_predicted.size(), target.size())
                
#         loss = torch.sum(y_predicted != target)

#         loss = self.criterion(y_predicted, target)

#         loss = mse_loss(y_predicted, target)

#         loss = nn.functional.nll_loss(nn.functional.log_softmax(y_predicted), target)

#         loss = self.criterion(y_predicted, target)


        print(y_predicted.size(), target.size())

        loss_value = loss(y_predicted, target)
    
        return x, loss_value
        

In [33]:
net = Net()

In [34]:
net

Net (
  (conv1): Conv2d(1, 256, kernel_size=(9, 9), stride=(1, 1))
  (primary_capsule): CapsuleLayer (
    (conv): Conv2d(256, 256, kernel_size=(9, 9), stride=(2, 2))
  )
  (digital_capsule): CapsuleLayer (
  )
  (criterion): MSELoss (
  )
)

In [10]:
import torchvision.datasets as dset
import torchvision.transforms as transforms
## load mnist dataset
root = '../data'
download = False
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
train_set = dset.MNIST(root=root, train=True, transform=trans, download=download)
test_set = dset.MNIST(root=root, train=False, transform=trans)

batch_size = 128
# kwargs = {'num_workers': 1}
train_loader = torch.utils.data.DataLoader(
                 dataset=train_set,
                 batch_size=batch_size,
                 shuffle=True)
test_loader = torch.utils.data.DataLoader(
                dataset=test_set,
                batch_size=batch_size,
                shuffle=False)


In [11]:
import torch.optim as optim
import torch.nn.functional as F

In [35]:
optimizer = optim.Adam(net.parameters(), lr=0.01)
for epoch in range(10):
    # trainning
    for batch_idx, (x, target) in enumerate(train_loader):
        optimizer.zero_grad()
        x, target = Variable(x), Variable(target)
        _, loss = net(x, target)
        loss.backward()
        optimizer.step()
#         if batch_idx % 100 == 0:
        print('==>>> epoch: {}, batch index: {}, train loss: {:.6f}'.format(epoch, batch_idx, loss.data[0]))
#     # testing
#     correct_cnt, ave_loss = 0, 0
#     for batch_idx, (x, target) in enumerate(test_loader):
#         x, target = Variable(x, volatile=True), Variable(target, volatile=True)
#         score, loss = net(x, target)
#         _, pred_label = torch.max(score.data, 1)
#         correct_cnt += (pred_label == target.data).sum()
#         ave_loss += loss.data[0]
#     accuracy = correct_cnt*1.0/len(test_loader)/batch_size
#     ave_loss /= len(test_loader)
#     print('==>>> epoch: {}, test loss: {:.6f}, accuracy: {:.4f}'.format(epoch, ave_loss, accuracy))

torch.save(net.state_dict(), 'capsule_net')

torch.Size([128, 10, 16, 1]) torch.Size([128, 10])


TypeError: mul received an invalid combination of arguments - got (torch.FloatTensor), but expected one of:
 * (float value)
      didn't match because some of the arguments have invalid types: ([31;1mtorch.FloatTensor[0m)
 * (torch.DoubleTensor other)
      didn't match because some of the arguments have invalid types: ([31;1mtorch.FloatTensor[0m)


In [45]:
a = torch.Tensor([[1, 2, 3], [4, 2, 1]])
a


 1  2  3
 4  2  1
[torch.FloatTensor of size 2x3]

In [46]:
torch.max(a, dim=1)

(
  3
  4
 [torch.FloatTensor of size 2], 
  2
  0
 [torch.LongTensor of size 2])

In [55]:
x = nn.functional.softmax(a)

In [56]:
x

Variable containing:
 0.0900  0.2447  0.6652
 0.8438  0.1142  0.0420
[torch.FloatTensor of size 2x3]

In [59]:
z, v = torch.max(x, dim=1)
z, v

(Variable containing:
  0.6652
  0.8438
 [torch.FloatTensor of size 2], Variable containing:
  2
  0
 [torch.LongTensor of size 2])

In [72]:
aa = z.data.cpu().numpy()
aa = aa.reshape(aa.shape[0], 1)
aa

array([[ 0.66524094],
       [ 0.84379476]], dtype=float32)

In [75]:
vvv = x.data.cpu().numpy() == aa
vvv

array([[False, False,  True],
       [ True, False, False]], dtype=bool)

In [77]:
vvv + 0

array([[0, 0, 1],
       [1, 0, 0]])