In [None]:
# Automatically reload edited modules (e.g. the local `effcn` package) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import sys

# Make the repository root importable so the local `effcn` package resolves.
# "./.." is relative to the kernel's working directory (typically the notebook's folder).
# Guard the append so re-running this cell does not keep growing sys.path.
if "./.." not in sys.path:
    sys.path.append("./..")

In [None]:
import torch
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as T
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import numpy as np

# local imports
from effcn.models import MnistEcnBackbone, MnistEcnDecoder
from effcn.layers import PrimaryCaps, FCCaps
from effcn.functions import margin_loss, max_norm_masking

In [None]:
# Prefer the first CUDA GPU when one is visible, otherwise fall back to the CPU.
# NOTE(review): `device` is not used by any of the cells shown; the models and batches stay on the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)

In [None]:
# Load the MNIST train and test splits (downloaded into ./data on first use).
# Each sample goes through ToTensor, which turns the image into a float tensor.
ds_train, ds_valid = (
    datasets.MNIST(root="./data", train=use_train_split, download=True, transform=T.ToTensor())
    for use_train_split in (True, False)
)

In [None]:
# Sanity check: show the first stored training image (via `.data`, i.e. without the
# ToTensor transform) titled with its integer label.
fig, ax = plt.subplots()
ax.imshow(ds_train.data[0], cmap='gray')
ax.set_title('%i' % ds_train.targets[0])
plt.show()

In [None]:
# Mini-batch size shared by the training and validation loaders.
BATCH_SIZE = 32

# Training batches are reshuffled on every pass over the data.
dl_train = torch.utils.data.DataLoader(
    ds_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)

# Validation does not benefit from shuffling; a fixed order keeps evaluation
# (and the batch drawn for the smoke test in a later cell) reproducible.
dl_valid = torch.utils.data.DataLoader(
    ds_valid,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
)

In [None]:
# Capsule sizes taken from the paper; treat them as fixed.
#   n_l, d_l: number and dimension of the primary (low-level) capsules
#   n_h, d_h: number and dimension of the output (high-level) capsules
n_l, d_l = 16, 8
n_h, d_h = 10, 16

In [None]:
# Take the first validation batch as smoke-test input (x: images, y: integer labels).
# The temporary iterator is dropped right away, so its worker processes are not kept alive.
x, y = next(iter(dl_valid))

In [None]:
# Fix the RNG so the untrained weights (and therefore the smoke-test loss) are reproducible.
torch.manual_seed(0)

model_backbone = MnistEcnBackbone()
# PrimaryCaps requires F = n_l * d_l. Derive F instead of hard-coding 128 so that
# changing n_l / d_l keeps the layer consistent.
# NOTE(review): F presumably must also match the backbone's output channels; confirm.
model_primary = PrimaryCaps(F=n_l * d_l, K=9, N=n_l, D=d_l)
model_fcncaps = FCCaps(n_l, n_h, d_l, d_h)
model_decoder = MnistEcnDecoder()

In [None]:
# Smoke test: one forward pass of the untrained model plus the combined loss.
x_bb = model_backbone(x)      # backbone features
u_l = model_primary(x_bb)     # primary (low-level) capsules
u_h = model_fcncaps(u_l)      # output (high-level) capsules

# Margin loss against one-hot targets, with one class per output capsule.
y_one_hot = F.one_hot(y, num_classes=n_h)
loss_margin = margin_loss(u_h, y_one_hot)

# Reconstruction branch: mask the output capsules, flatten per sample, decode.
# NOTE(review): max_norm_masking presumably keeps only the longest capsule. The original
# CapsNet masks with the true label during training and by max length only at
# inference; confirm which is intended here.
u_h_masked = max_norm_masking(u_h)
u_h_masked_flat = torch.flatten(u_h_masked, start_dim=1)
x_rec = model_decoder(u_h_masked_flat)

# F.mse_loss takes (input, target): the reconstruction is the prediction, x is the target.
# Assumes the decoder output has the same shape as x (TODO confirm); otherwise mse_loss broadcasts.
loss_rec = F.mse_loss(x_rec, x)

# Reconstruction acts as a small regularizer next to the margin loss.
# NOTE(review): 0.0005 is the CapsNet weight for a *summed* squared error. With
# mean-reduced MSE over MNIST's 784 pixels, the equivalent weight is 0.0005 * 784 = 0.392.
# Confirm which scale is intended.
REC_LOSS_WEIGHT = 0.0005
loss = loss_margin + REC_LOSS_WEIGHT * loss_rec

In [None]:
# Display the combined (margin + weighted reconstruction) loss tensor from the smoke test.
loss