In [8]:
import torch
import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import os

from model.net import *
from utils.training import *
from data.data import *

## Experiment Config

In [10]:
# Directory that holds the pretrained checkpoint: the notebook's working directory.
model_path = os.getcwd()

# Experiment hyper-parameters, read later through `args[...]` lookups
# (e.g. args['USE_CUDA'] when placing the network, and `dataset(args)`).
# USE_CUDA is simply whether torch can see a CUDA device.
args = dict(
    USE_CUDA=torch.cuda.is_available(),
    BATCH_SIZE=32,
    N_EPOCHS=30,
    LEARNING_RATE=0.01,
    MOMENTUM=0.9,
    DATASET_NAME='mnist',
)

## Model Loading

In [13]:
# Modified decoder: reconstructs the image from *all* capsule outputs instead of
# masking out every capsule except the predicted class (the masking step is disabled).

class Decoder(nn.Module):
    """Fully-connected decoder mapping the flattened capsule vectors to an image.

    Unlike the standard CapsNet decoder, no class mask is applied: every capsule's
    vector is fed to the reconstruction MLP.

    Args:
        input_width: width of the reconstructed image in pixels.
        input_height: height of the reconstructed image in pixels.
        input_channel: number of channels of the reconstructed image.
    """

    def __init__(self, input_width=28, input_height=28, input_channel=1):
        super().__init__()
        self.input_width = input_width
        self.input_height = input_height
        self.input_channel = input_channel
        # Attribute name (typo included) is kept so existing state_dict keys still load.
        self.reconstraction_layers = nn.Sequential(
            nn.Linear(16 * 10, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            # One output per pixel: width * height * channels
            # (the previous version multiplied height by itself).
            nn.Linear(1024, self.input_width * self.input_height * self.input_channel),
            nn.Sigmoid()
        )

    def forward(self, x, data):
        """Reconstruct images from capsule outputs.

        Args:
            x: capsule outputs holding 16 * 10 = 160 values per sample.
               NOTE(review): presumably (batch, 10, 16, 1) from the digit-capsule
               layer — confirm against CapsNet.
            data: unused; kept for interface compatibility with the original decoder.

        Returns:
            (reconstructions, None) where reconstructions has shape
            (batch, input_channel, input_height, input_width) with sigmoid values.
            The second element is always None because no mask is produced.
        """
        # nn.Linear only acts on the last dimension, so collapse all capsule vectors
        # of a sample into a single 160-d feature vector first.
        flat = x.reshape(x.size(0), -1)
        reconstructions = self.reconstraction_layers(flat)
        # NCHW layout: height comes before width.
        reconstructions = reconstructions.view(
            x.size(0), self.input_channel, self.input_height, self.input_width
        )
        return reconstructions, None

['data',
 'model',
 'utils',
 'LICENSE',
 '.gitignore',
 '.git',
 'Capsule Network Train.ipynb',
 'Pretrain_Capsule.ipynb',
 'L2_recon.ipynb',
 '.ipynb_checkpoints',
 'README.md',
 'CapsNetMNIS.pth ']

In [15]:
# Config for the layout with 49 primary-capsule vectors of 16 dimensions each.
# NOTE(review): the original note says the softmax dimension must be set to 0 for this
# layout; that softmax is not in this notebook (presumably in model.net) — confirm where.
class Config:
    """Layer hyper-parameters handed to ``CapsNet(config)``.

    Attribute prefixes: ``cnn_`` = convolutional stem, ``pc_`` = primary capsules,
    ``dc_`` = first digit-capsule layer, ``dc_2_`` = second digit-capsule layer;
    ``input_width`` / ``input_height`` are the decoder's 28 x 28 image size.
    """

    def __init__(self):
        routes = 1 * 7 * 7  # 49, the route count used by all three capsule layers

        # Convolutional stem
        self.cnn_in_channels, self.cnn_out_channels, self.cnn_kernel_size = 1, 12, 15

        # Primary capsules
        (self.pc_num_capsules, self.pc_in_channels, self.pc_out_channels,
         self.pc_kernel_size, self.pc_num_routes) = 16, 12, 1, 8, routes

        # First digit-capsule layer
        (self.dc_num_capsules, self.dc_num_routes,
         self.dc_in_channels, self.dc_out_channels) = 49, routes, 16, 16

        # Second digit-capsule layer
        (self.dc_2_num_capsules, self.dc_2_num_routes,
         self.dc_2_in_channels, self.dc_2_out_channels) = 10, routes, 16, 16

        # Decoder image size
        self.input_width, self.input_height = 28, 28

# Fix the global torch RNG seed before the network is built.
torch.manual_seed(1)
config = Config()

net = CapsNet(config)
if args['USE_CUDA']:
    net = net.cuda()

# The checkpoint is read onto the CPU; load_state_dict then copies the values into
# net's existing parameters, whichever device they live on.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
checkpoint_file = os.path.join(model_path, 'CapsNetMNIST.pth')
net.load_state_dict(torch.load(checkpoint_file, map_location='cpu'))

<All keys matched successfully>

## Loading Dataset

In [None]:
# NOTE(review): `dataset` comes from one of the star imports (presumably data.data);
# it is unpacked here as a (train, test) pair of loaders built from `args`.
(trainloader,
 testloader) = dataset(args)

## Histogram of L2 distances between each input and its reconstruction, using the correct capsule vs. the other capsules, on real MNIST images

In [None]:
# torch.save(net.state_dict(), "./CapsNetMNIST.pth")  # NOTE: would overwrite the checkpoint loaded above

In [None]:
#Config for 16 1d vectors in Capsule Layer. Set the Softmax Dimension to 1 in this case
# class Config:
#     def __init__(self, dataset='mnist'):
#         # CNN (cnn)
#         self.cnn_in_channels = 1
#         self.cnn_out_channels = 12
#         self.cnn_kernel_size = 15

#         # Primary Capsule (pc)
#         self.pc_num_capsules = 1
#         self.pc_in_channels = 12
#         self.pc_out_channels = 16
#         self.pc_kernel_size = 8
#         self.pc_num_routes = 16 * 7 * 7

#         # Digit Capsule 1 (dc)
#         self.dc_num_capsules = 49
#         self.dc_num_routes = 16 * 7 * 7
#         self.dc_in_channels = 1
#         self.dc_out_channels = 1 #16
        
#         # Digit Capsule 2 (dc)
#         self.dc_2_num_capsules = 10
#         self.dc_2_num_routes = 7 * 7
#         self.dc_2_in_channels = 1 #16
#         self.dc_2_out_channels = 16

#         # Decoder
#         self.input_width = 28
#         self.input_height = 28