In [None]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
from sklearn.metrics import precision_score,recall_score,f1_score,accuracy_score

In [None]:
USE_CUDA =  True if torch.cuda.is_available() else False

In [None]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.1307,), (0.3081,))])

trainset = torchvision.datasets.FashionMNIST('/FashionMNIST/', train=True,download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset,shuffle=True, num_workers=2, batch_size = 10)

testset = torchvision.datasets.FashionMNIST('/FashionMNIST/', train=False,download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset,shuffle=False, num_workers=2, batch_size = 100)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to /FashionMNIST/FashionMNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /FashionMNIST/FashionMNIST/raw/train-images-idx3-ubyte.gz to /FashionMNIST/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to /FashionMNIST/FashionMNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /FashionMNIST/FashionMNIST/raw/train-labels-idx1-ubyte.gz to /FashionMNIST/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to /FashionMNIST/FashionMNIST/raw/t10k-images-idx3-ubyte.gz



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /FashionMNIST/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to /FashionMNIST/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to /FashionMNIST/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /FashionMNIST/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to /FashionMNIST/FashionMNIST/raw
Processing...
Done!




In [None]:
class ConvLayer(nn.Module):
    def __init__(self, in_channels=1, out_channels=256, kernel_size=9):
        super(ConvLayer, self).__init__()

        self.conv = nn.Conv2d(in_channels=in_channels,
                              out_channels=out_channels,
                              kernel_size=kernel_size,
                              stride=1
                              )

    def forward(self, x):
        return F.relu(self.conv(x))


class PrimaryCaps(nn.Module):
    def __init__(self, num_capsules=8, in_channels=256, out_channels=32, kernel_size=9, num_routes=32 * 6 * 6):
        super(PrimaryCaps, self).__init__()
        self.num_routes = num_routes
        self.capsules = nn.ModuleList([
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=2, padding=0)
            for _ in range(num_capsules)])

    def forward(self, x):
        u = [capsule(x) for capsule in self.capsules]
        u = torch.stack(u, dim=1)

        u = u.view(x.size(0), self.num_routes, -1)
        return self.squash(u)

    def squash(self, input_tensor):
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm * input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
        return output_tensor


class DigitCaps(nn.Module):
    def __init__(self, num_capsules=10, num_routes=32 * 6 * 6, in_channels=8, out_channels=16):
        super(DigitCaps, self).__init__()

        self.in_channels = in_channels
        self.num_routes = num_routes
        self.num_capsules = num_capsules

        self.W = nn.Parameter(torch.randn(1, num_routes, num_capsules, out_channels, in_channels))

    def forward(self, x):
        batch_size = x.size(0)
        x = torch.stack([x] * self.num_capsules, dim=2).unsqueeze(4)

        W = torch.cat([self.W] * batch_size, dim=0)
        u_hat = torch.matmul(W, x)

        b_ij = torch.zeros(1, self.num_routes, self.num_capsules, 1)
        if USE_CUDA:
            b_ij = b_ij.cuda()

        num_iterations = 3
        for iteration in range(num_iterations):
            c_ij = F.softmax(b_ij, dim=1)
            c_ij = torch.cat([c_ij] * batch_size, dim=0).unsqueeze(4)

            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)
            v_j = self.squash(s_j)

            if iteration < num_iterations - 1:
                a_ij = torch.matmul(u_hat.transpose(3, 4), torch.cat([v_j] * self.num_routes, dim=1))
                b_ij = b_ij + a_ij.squeeze(4).mean(dim=0, keepdim=True)

        return v_j.squeeze(1)

    def squash(self, input_tensor):
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm * input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
        return output_tensor


class CapsNet(nn.Module):
    def __init__(self, config=None):
        super(CapsNet, self).__init__()
        if config:
            self.conv_layer = ConvLayer(config.cnn_in_channels, config.cnn_out_channels, config.cnn_kernel_size)
            self.primary_capsules = PrimaryCaps(config.pc_num_capsules, config.pc_in_channels, config.pc_out_channels,
                                                config.pc_kernel_size, config.pc_num_routes)
            self.digit_capsules = DigitCaps(config.dc_num_capsules, config.dc_num_routes, config.dc_in_channels,
                                            config.dc_out_channels)
            #self.decoder = Decoder(config.input_width, config.input_height, config.cnn_in_channels)
        else:
            self.conv_layer = ConvLayer()
            self.primary_capsules = PrimaryCaps()
            self.digit_capsules = DigitCaps()
            #self.decoder = Decoder()

        self.mse_loss = nn.MSELoss()

    def forward(self, data):
        output = self.digit_capsules(self.primary_capsules(self.conv_layer(data)))
        return output




In [None]:
class Config_1:
    def __init__(self, in_channels):
            # CNN (cnn)
            self.cnn_in_channels = in_channels
            self.cnn_out_channels = 256
            self.cnn_kernel_size = 3

            # Primary Capsule (pc)
            self.pc_num_capsules = 8
            self.pc_in_channels = 256
            self.pc_out_channels = 32
            self.pc_kernel_size = 3
            self.pc_num_routes = 32 * 6 * 6

            # Digit Capsule (dc)
            self.dc_num_capsules = 10
            self.dc_num_routes = 32 * 6 * 6
            self.dc_in_channels = 18
            self.dc_out_channels = 6

class Config_2:
    def __init__(self, in_channels):
            # CNN (cnn)
            self.cnn_in_channels = in_channels
            self.cnn_out_channels = 256
            self.cnn_kernel_size = 3

            # Primary Capsule (pc)
            self.pc_num_capsules = 8
            self.pc_in_channels = 256
            self.pc_out_channels = 32
            self.pc_kernel_size = 3
            self.pc_num_routes = 32 * 6 * 6

            # Digit Capsule (dc)
            self.dc_num_capsules = 10
            self.dc_num_routes = 32 * 6 * 6
            self.dc_in_channels = 8
            self.dc_out_channels = 6

class Config_3:
    def __init__(self, in_channels):
            # CNN (cnn)
            self.cnn_in_channels = in_channels
            self.cnn_out_channels = 256
            self.cnn_kernel_size = 3

            # Primary Capsule (pc)
            self.pc_num_capsules = 8
            self.pc_in_channels = 256
            self.pc_out_channels = 32
            self.pc_kernel_size = 3
            self.pc_num_routes = 32 * 6 * 6

            # Digit Capsule (dc)
            self.dc_num_capsules = 10
            self.dc_num_routes = 32 * 6 * 6
            self.dc_in_channels = 2
            self.dc_out_channels = 6

In [None]:
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.conv1 = nn.Conv2d(1,32,kernel_size = (3,3))
        self.conv2 = nn.Conv2d(32,48,kernel_size = (3,3))
        self.conv3 = nn.Conv2d(48,64,kernel_size = (3,3))
        self.conv4 = nn.Conv2d(64,80,kernel_size = (3,3))
        self.conv5 = nn.Conv2d(80,96,kernel_size = (3,3))
        self.conv6 = nn.Conv2d(96,112,kernel_size = (3,3))
        self.conv7 = nn.Conv2d(112,128,kernel_size = (3,3))
        self.conv8 = nn.Conv2d(128,144,kernel_size = (3,3))
        self.conv9 = nn.Conv2d(144,160,kernel_size = (3,3))
        self.caps_a = CapsNet(Config_1(64))
        self.caps_b = CapsNet(Config_2(112))
        self.caps_c = CapsNet(Config_3(160))
        self.merge_weight1= nn.Parameter(torch.randn(1))
        self.merge_weight2= nn.Parameter(torch.randn(1))
        self.merge_weight3= nn.Parameter(torch.randn(1))
        self.relu = nn.ReLU()
        self.batchnorm_branch1 = nn.BatchNorm2d(64)
        self.batchnorm_branch2 = nn.BatchNorm2d(112)
        self.batchnorm_branch3 = nn.BatchNorm2d(160)
        self.upsample_conv1 = nn.Conv2d(1,1,kernel_size=(5,5))
        self.upsample_conv2 = nn.Conv2d(48,80,kernel_size=(5,5))
        self.upsample_conv3 = nn.Conv2d(80,112,kernel_size=(5,5))
        self.upsample_conv4 = nn.Conv2d(112,144,kernel_size=(5,5))

        self.k1= nn.Parameter(torch.randn(1))
        self.k2= nn.Parameter(torch.randn(1))
        self.k3= nn.Parameter(torch.randn(1))
        self.k4= nn.Parameter(torch.randn(1))
        
    def forward(self,x):
        branch1 = self.conv1(x)
        branch1 = self.relu(branch1)
        branch1 = self.conv2(branch1)
        branch1 = self.relu(branch1)

        branch1 += self.k1 * self.upsample_conv1(x)
        x1 = branch1.clone()
        
        branch1 = self.conv3(branch1)
        branch1 = self.relu(branch1)
        branch2 = self.conv4(branch1)
        branch2 = self.relu(branch2)

        branch2 += self.k2 * self.upsample_conv2(x1)
        x1 = branch2.clone()

        branch2 = self.conv5(branch2)
        branch2 = self.relu(branch2)
        branch2 = self.conv6(branch2)
        branch2 = self.relu(branch2)

        branch2 += self.k3 * self.upsample_conv3(x1)
        x1 = branch2.clone()

        branch3 = self.conv7(branch2)
        branch3 = self.relu(branch3)
        branch3 = self.conv8(branch3)
        branch3 = self.relu(branch3)

        branch3 += self.k4 * self.upsample_conv4(x1)

        branch3 = self.conv9(branch3)
        branch3 = self.relu(branch3)

        #Batch Normalizing 
        branch1 = self.batchnorm_branch1(branch1)
        branch2 = self.batchnorm_branch2(branch2)
        branch3 = self.batchnorm_branch3(branch3)

        branch1_out = self.caps_a(branch1)
        branch2_out = self.caps_b(branch2)
        branch3_out = self.caps_c(branch3)

        branch1_out = torch.sqrt((branch1_out ** 2).sum(dim=2, keepdim=True)).view(-1,10)
        branch2_out = torch.sqrt((branch2_out ** 2).sum(dim=2, keepdim=True)).view(-1,10)
        branch3_out = torch.sqrt((branch3_out ** 2).sum(dim=2, keepdim=True)).view(-1,10)
        
        out1 = self.merge_weight1 * branch1_out
        out2 = self.merge_weight1 * branch2_out
        out3 = self.merge_weight1 * branch3_out
        stack = torch.stack([out1,out3,out3],dim = 0)
        summed = torch.sum(stack,dim = 0)
        return summed

In [None]:
net = Net()
if USE_CUDA:
  net = net.cuda()
criterion = nn.CrossEntropyLoss()
lr = 0.001
optimizer = torch.optim.Adam(net.parameters(),lr = lr)

In [None]:
print("The number of trainable parameters in Convnet :",sum(p.numel() for p in net.parameters() if p.requires_grad))

The number of trainable parameters in Convnet : 5961617


In [None]:
for epoch in range(5):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        if USE_CUDA:
          inputs = inputs.cuda()
          labels = labels.cuda()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        if torch.isnan(loss):
          print("nan Loss")
          running_loss = 0.0
          continue
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 100 == 99:  
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0

print('Finished Training')

[1,   100] loss: 2.006
[1,   200] loss: 1.686
[1,   300] loss: 1.353
[1,   400] loss: 0.972
[1,   500] loss: 0.820
[1,   600] loss: 0.734
[1,   700] loss: 0.608
[1,   800] loss: 0.631
[1,   900] loss: 0.565
[1,  1000] loss: 0.569
[1,  1100] loss: 0.589
[1,  1200] loss: 0.540
[1,  1300] loss: 0.525
[1,  1400] loss: 0.500
[1,  1500] loss: 0.526
[1,  1600] loss: 0.470
[1,  1700] loss: 0.464
[1,  1800] loss: 0.485
[1,  1900] loss: 0.508
[1,  2000] loss: 0.426
[1,  2100] loss: 0.475
[1,  2200] loss: 0.412
[1,  2300] loss: 0.390
[1,  2400] loss: 0.437
[1,  2500] loss: 0.414
[1,  2600] loss: 0.400
[1,  2700] loss: 0.404
[1,  2800] loss: 0.347
[1,  2900] loss: 0.383
[1,  3000] loss: 0.370
[1,  3100] loss: 0.395
[1,  3200] loss: 0.319
[1,  3300] loss: 0.328
[1,  3400] loss: 0.354
[1,  3500] loss: 0.355
[1,  3600] loss: 0.338
[1,  3700] loss: 0.339
[1,  3800] loss: 0.335
[1,  3900] loss: 0.332
[1,  4000] loss: 0.384
[1,  4100] loss: 0.368
[1,  4200] loss: 0.329
[1,  4300] loss: 0.349
[1,  4400] 

In [None]:
predict = []
true = []
device = 'cpu'
if USE_CUDA:
  device = 'cuda:0'
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images = torch.tensor(images,device = device)
        capsnet = net.cuda()
        outputs = capsnet(images)
       # pred = torch.sqrt((outputs[0] ** 2).sum(dim=2, keepdim=True))
        predicted = torch.argmax(outputs.data, 1)
        predict.extend(predicted.tolist())
        true.extend(labels.tolist())

  if __name__ == '__main__':


In [None]:
print("Accuracy : ",accuracy_score(true,predict))
print("Precision Score : ",np.mean(precision_score(true,predict,average=None)))
print("Recall Score : ",np.mean(recall_score(true,predict,average=None)))
print("F1 Score : ",np.mean(f1_score(true,predict,average=None)))

Accuracy :  0.916
Precision Score :  0.9165823579745856
Recall Score :  0.9159999999999998
F1 Score :  0.9158978603839701


In [None]:
torch.save(net,'modified_capsnet_fashion.pt')

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


In [None]:
from google.colab import files
files.download('modified_capsnet_fashion.pt')