### Necessary imports

In [1]:
import os
import torch
import numpy as np
from torch import nn
from time import time
from capsules import *
from torch.optim import Adam
import pytorch_lightning as pl
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.autograd import Variable
from torchvision.datasets import MNIST
from sklearn.metrics import accuracy_score
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

### Definition of the CapsNet model

In [2]:
class CapsNet4MNIST(pl.LightningModule):
    """Capsule network for MNIST: conv -> primary capsules -> digit capsules,
    plus a decoder that reconstructs the image from the winning capsule.

    ``PrimaryCapsuleLayer``, ``SecondaryCapsuleLayer``, ``RegularizingDecoder``,
    ``CapsuleLoss`` and ``make_y`` are presumably provided by
    ``from capsules import *``; they are not defined in this notebook.
    """

    # Number of digit classes used to one-hot encode the labels.
    N_CLASSES = 10

    def __init__(self, lr=1e-3):
        """Build the layers.

        Args:
            lr: learning rate for the Adam optimizer (default 1e-3).
        """
        super(CapsNet4MNIST, self).__init__()
        self.lr = lr
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=256,
                kernel_size=9,
                stride=1
            ),
            nn.ReLU(inplace=True)
        )
        self.primcaps = PrimaryCapsuleLayer()
        self.digicaps = SecondaryCapsuleLayer()
        self.decoder = RegularizingDecoder()

    def forward(self, x):
        """Compute capsule outputs, pick the longest capsule, reconstruct the images.

        Args:
            x: input image batch; presumably (batch, 1, 28, 28) -- TODO confirm.

        Returns:
            Tuple ``(internal, reconstruction, lengths, max_caps_index)``:
            digit-capsule outputs, decoder reconstruction, softmax over the
            capsule vector lengths, and the index of the longest capsule
            per sample.
        """
        u = self.primcaps(self.conv(x))
        internal = self.digicaps(u)
        # Euclidean length of each capsule vector, normalized over capsules.
        lengths = F.softmax(
            (internal ** 2).sum(dim=-1) ** 0.5, dim=-1
        )
        _, max_caps_index = lengths.max(dim=-1)
        # One-hot mask for the winning capsule. It is built on the activations'
        # device/dtype rather than assuming CUDA is used whenever it exists.
        eye = torch.eye(
            lengths.size(-1), device=internal.device, dtype=internal.dtype
        )
        masked = eye.index_select(dim=0, index=max_caps_index)
        # Zero out every capsule except the winner before decoding.
        reconstruction = self.decoder(
            (internal * masked[:, :, None]).reshape(x.size(0), -1)
        )
        return internal, reconstruction, lengths, max_caps_index

    def configure_optimizers(self):
        """Adam over all parameters with the learning rate given to ``__init__``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def cross_entropy_loss(self, logits, labels):
        """Negative log-likelihood loss (expects log-probabilities in ``logits``)."""
        return F.nll_loss(logits, labels)

    def capsule_loss(self, real_class, x, classes, reconstruction):
        """Apply ``CapsuleLoss`` to one batch.

        Args:
            real_class: one-hot targets produced by ``make_y``.
            x: input images; flattened to (batch, 28*28) for the reconstruction term.
            classes: capsule lengths returned by ``forward``.
            reconstruction: decoder output returned by ``forward``.
        """
        # Put the criterion on the batch's device instead of hard-coding CUDA.
        criterion = CapsuleLoss().to(x.device)
        return criterion(
            real_class, x.view(x.size(0), 28 * 28),
            classes, reconstruction
        )

    def _shared_step(self, batch):
        """Compute the capsule loss for one ``(images, labels)`` batch."""
        x, y = batch
        # Keep the labels on their current device. This assumes make_y builds
        # its output on the input's device -- TODO confirm in capsules.
        real_class = make_y(y.long(), self.N_CLASSES)
        _, reconstruction, classes, _ = self(x)
        return self.capsule_loss(real_class, x, classes, reconstruction)

    def training_step(self, train_batch, batch_idx):
        """Compute and log the training capsule loss."""
        loss = self._shared_step(train_batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Compute and log the validation capsule loss."""
        loss = self._shared_step(val_batch)
        self.log("val_loss", loss)
        return loss

In [3]:
class MNISTDataModule(pl.LightningDataModule):
    """MNIST data module.

    The official training set is split into train/validation subsets with a
    seeded ``random_split``. The official test set is used for testing.
    Images are converted to tensors and normalized with mean 0.1307 and
    std 0.3081.
    """

    # Number of official training images held out for validation.
    VAL_SIZE = 5000

    def __init__(self, batch_size=32, data_dir=None, seed=42):
        """
        Args:
            batch_size: mini-batch size for every dataloader.
            data_dir: download/cache directory. ``None`` means the current
                working directory at the time the data is accessed.
            seed: seed for the train/validation split, so it is reproducible
                across ``setup`` calls and runs. ``None`` uses torch's global
                RNG, which gives an unseeded split.
        """
        super().__init__()
        self.batch_size = batch_size
        self.data_dir = data_dir
        self.seed = seed

    def _root(self):
        """Directory where MNIST is stored."""
        return os.getcwd() if self.data_dir is None else self.data_dir

    def prepare_data(self):
        """Download the MNIST train and test sets if they are missing."""
        MNIST(self._root(), train=True, download=True)
        MNIST(self._root(), train=False, download=True)

    def setup(self, stage=None):
        """Load the datasets and build the train/validation split."""
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ]
        )
        mnist_full = MNIST(
            self._root(), train=True, download=False,
            transform=transform
        )
        self.mnist_test = MNIST(
            self._root(), train=False, download=False,
            transform=transform
        )
        generator = (
            torch.Generator().manual_seed(self.seed)
            if self.seed is not None
            else torch.default_generator
        )
        n_val = self.VAL_SIZE
        self.mnist_train, self.mnist_val = random_split(
            mnist_full, (len(mnist_full) - n_val, n_val),
            generator=generator
        )

    def train_dataloader(self):
        """Shuffled training loader, so each epoch sees a new sample order."""
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Validation loader (fixed order)."""
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        """Test loader (fixed order)."""
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

In [4]:
# Lightning moves the model to the trainer's device; no manual .cuda() needed.
cn = CapsNet4MNIST()

In [5]:
# Explicit device list: a string "0" is ambiguous across Lightning versions
# ("GPU index 0" vs "zero GPUs"). Fall back to CPU when CUDA is unavailable.
trainer = pl.Trainer(gpus=[0] if torch.cuda.is_available() else None)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [6]:
data_module = MNISTDataModule(batch_size=32)  # MNIST loaders, 32 images per batch

In [7]:
trainer.fit(cn, datamodule=data_module)  # train CapsNet on the MNIST data module

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])

  | Name     | Type                  | Params
---------------------------------------------------
0 | conv     | Sequential            | 20 K  
1 | primcaps | PrimaryCapsuleLayer   | 5 M   
2 | digicaps | SecondaryCapsuleLayer | 1 M   
3 | decoder  | RegularizingDecoder   | 1 M   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Exception ignored in: <function Image.__del__ at 0x7f483e1da440>
Traceback (most recent call last):
  File "/home/bakirillov/anaconda3/envs/lapki/lib/python3.7/site-packages/PIL/Image.py", line 630, in __del__
    self.__exit__()
  File "/home/bakirillov/anaconda3/envs/lapki/lib/python3.7/site-packages/PIL/Image.py", line 597, in __exit__
    self.fp = None
KeyboardInterrupt







1