In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys

# Make the directory one level above the current working directory importable
# (presumably where the local `misc` / `effcn` packages live). The membership
# check keeps the cell idempotent: re-running it does not append duplicates.
_REPO_ROOT = "./.."
if _REPO_ROOT not in sys.path:
    sys.path.append(_REPO_ROOT)

In [None]:
import torch
#
import numpy as np
import torch
from torchvision import datasets
from torch.utils.data import DataLoader
import torchvision.transforms as T
import torch.nn.functional as F
import torchvision
import torch.nn as nn
from torch import optim
#
import matplotlib.pyplot as plt
from einops import rearrange
from tqdm import tqdm
#

from misc.plot_utils import plot_mat, imshow
from effcn.layers import FCCaps, FCCapsWOBias, Squash
from effcn.utils import count_parameters
from effcn.functions import margin_loss

In [None]:
# Prefer the first CUDA GPU when one is visible to PyTorch, otherwise use the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)

In [None]:
def pos_tanh_embedding(h, w, t_freq = 2, t_symm = 0.5):
    """Fixed 4-channel positional encoding built from tanh ramps.

    Each ramp is ``0.5 * (1 - tanh(t_freq * (t - t_symm)))`` for a normalised
    coordinate ``t`` in [0, 1], so values lie in (0, 1) and equal 0.5 at
    ``t == t_symm``.

    Args:
        h: number of rows (image height).
        w: number of columns (image width).
        t_freq: steepness of the tanh ramp.
        t_symm: normalised coordinate where the ramp crosses 0.5.

    Returns:
        Tensor of shape (4, h, w):
          * channel 0: ramp over the rows, t = 0 -> 1 from top to bottom
            (constant along each row);
          * channel 1: same ramp with t = 1 -> 0;
          * channel 2: ramp over the columns, t = 0 -> 1 from left to right
            (constant along each column);
          * channel 3: same ramp with t = 1 -> 0.
    """
    def _ramp(n, reverse):
        # 1-D tanh profile of length n sampled on an evenly spaced grid.
        t = torch.linspace(1, 0, n) if reverse else torch.linspace(0, 1, n)
        return (1 - torch.tanh(t_freq * (t - t_symm))) * 0.5

    pe = torch.zeros(4, h, w)
    # Broadcasting each 1-D profile along the other axis always yields an
    # (h, w) slice, so non-square sizes work; for h == w the values match the
    # original square layout exactly.
    pe[0] = _ramp(h, reverse=False).unsqueeze(1).expand(h, w)
    pe[1] = _ramp(h, reverse=True).unsqueeze(1).expand(h, w)
    pe[2] = _ramp(w, reverse=False).unsqueeze(0).expand(h, w)
    pe[3] = _ramp(w, reverse=True).unsqueeze(0).expand(h, w)
    return pe

In [None]:
# Training-time augmentation. Exactly one random rotation (±30°) is applied:
# stacking a rotation-only RandomAffine on top of RandomRotation compounds the
# range to ±60° and resamples the image twice.
transform_train = T.Compose([
    T.RandomRotation(degrees=(-30, 30)),
    T.RandomResizedCrop(
        28,
        scale=(0.8, 1.0),
        ratio=(1, 1),
    ),
    # For translation jitter, add e.g. T.RandomAffine(degrees=0, translate=(0.1, 0.1)).
    T.ToTensor()
])
# Validation images are only converted to tensors (no augmentation).
transform_valid = T.Compose([
    T.ToTensor()
])

In [None]:
DATA_ROOT = "./../data"
BATCH_SIZE = 8
NUM_WORKERS = 4

ds_train = datasets.MNIST(root=DATA_ROOT, train=True, download=True, transform=transform_train)
ds_valid = datasets.MNIST(root=DATA_ROOT, train=False, download=True, transform=transform_valid)
#
dl_train = DataLoader(ds_train,
                      batch_size=BATCH_SIZE,
                      shuffle=True,
                      num_workers=NUM_WORKERS)
# Validation data is not shuffled so evaluation order and previews are reproducible.
dl_valid = DataLoader(ds_valid,
                      batch_size=BATCH_SIZE,
                      shuffle=False,
                      num_workers=NUM_WORKERS)

In [None]:
def show_batch(loader, title, n_images=64, nrow=8):
    """Plot (up to) the first ``n_images`` images of one batch from ``loader`` as a grid.

    Only as many images as the batch actually contains are shown.
    """
    images, _ = next(iter(loader))
    grid = torchvision.utils.make_grid(images[:n_images], nrow=nrow)
    fig, ax = plt.subplots()
    # make_grid returns (C, H, W); imshow expects (H, W, C).
    ax.imshow(grid.permute(1, 2, 0))
    ax.set_title(title)
    ax.axis("off")
    plt.show()


show_batch(dl_train, "Train samples (augmented)")
show_batch(dl_valid, "Validation samples")

In [None]:
from einops.layers.torch import Rearrange
from einops import rearrange, repeat

class PreNorm(nn.Module):
    """Apply ``nn.LayerNorm(dim)`` to the input, then call the wrapped ``fn`` on it.

    Any keyword arguments given to ``forward`` are passed straight to ``fn``.
    """

    def __init__(self, dim, fn):
        super(PreNorm, self).__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)

class FeedForward(nn.Module):
    """Position-wise MLP used inside each transformer block.

    Layout of ``self.net``: Linear(dim -> hidden_dim), GELU, Dropout,
    Linear(hidden_dim -> dim), Dropout.
    """

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super(FeedForward, self).__init__()
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        projected = self.net(x)
        return projected

class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: token feature size (input and output).
        heads: number of attention heads.
        dim_head: per-head width; q/k/v each have ``heads * dim_head`` features.
        dropout: dropout applied after the output projection.
    """

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super(Attention, self).__init__()
        inner_dim = heads * dim_head
        # A single head whose width already equals ``dim`` needs no output projection.
        needs_projection = heads != 1 or dim_head != dim

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, 3 * inner_dim, bias = False)

        if needs_projection:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout)
            )
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        def split_heads(t):
            return rearrange(t, 'b n (h d) -> b h n d', h = self.heads)

        q, k, v = [split_heads(part) for part in self.to_qkv(x).chunk(3, dim = -1)]

        scores = (q @ k.transpose(-1, -2)) * self.scale
        weights = self.attend(scores)

        merged = rearrange(weights @ v, 'b h n d -> b n (h d)')
        return self.to_out(merged)

class Transformer(nn.Module):
    """Stack of ``depth`` pre-norm blocks.

    Each block applies self-attention, then a feed-forward MLP. Both are wrapped
    in ``PreNorm`` and added back to their input as a residual.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super(Transformer, self).__init__()
        blocks = []
        for _ in range(depth):
            attention = PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))
            feed_forward = PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            blocks.append(nn.ModuleList([attention, feed_forward]))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        hidden = x
        for attention, feed_forward in self.layers:
            hidden = hidden + attention(hidden)
            hidden = hidden + feed_forward(hidden)
        return hidden

def pair(t):
    """Return ``t`` unchanged if it is already a tuple, otherwise ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)
    
class ViT(nn.Module):
    """Pixel-level Vision Transformer classifier with a fixed tanh positional encoding.

    Every pixel becomes one token: its ``channels`` values concatenated with
    the 4 channels of ``pos_tanh_embedding(image_height, image_width)``. Hence
    ``dim`` must equal ``channels + 4``.

    With ``pool='cls'`` a learned class token is prepended to the pixel tokens,
    and its final state feeds the classification head. With ``pool='mean'``
    the pixel tokens are averaged.

    ``to_emb`` (a patch embedding) and ``pos_embedding`` are created but not
    used by ``forward``. They are kept so the parameter set and ``state_dict``
    keys stay unchanged; ``patch_size`` only affects those two members.
    """
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
        # emb() concatenates 4 positional channels to the pixel channels, and the
        # transformer's LayerNorm expects exactly `dim` features per token.
        assert dim == channels + 4, 'dim must equal channels + 4 (pixel channels + 4 positional-encoding channels)'

        # Not used by forward(); kept for state_dict / parameter compatibility.
        self.to_emb = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )

        # Not used by forward(); kept for state_dict / parameter compatibility.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        # Learned token prepended to the sequence when pool == 'cls'.
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        # Dropout on the token sequence before the transformer.
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )
        # Fixed (non-trainable) positional encoding of shape (4, H, W).
        self.pe = torch.nn.Parameter(pos_tanh_embedding(image_height, image_width), requires_grad=False)

    def emb(self, x):
        """Turn images (b, c, h, w) into per-pixel tokens (b, h*w, c + 4)."""
        b, _, h, w = x.shape
        pe = self.pe.unsqueeze(0).expand(b, -1, -1, -1)
        x = torch.cat([x, pe], dim=1)
        return x.permute(0, 2, 3, 1).reshape(b, h * w, -1)

    def forward(self, img):
        """Classify a batch of images (b, channels, H, W); returns logits (b, num_classes)."""
        x = self.emb(img)
        if self.pool == 'cls':
            # Without this token, x[:, 0] below would just be the top-left pixel.
            cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
            x = torch.cat([cls_tokens, x], dim=1)
        x = self.dropout(x)
        x = self.transformer(x)
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)

In [None]:
class UniversalPrimeCaps(nn.Module):
    """Fully-connected capsule layer whose weights are shared by all lower capsules.

    ``W`` has shape (n_h, d_l, d_h) with no per-lower-capsule axis, so the
    layer accepts any number of lower capsules.

    Args:
        n_l: number of lower-level capsules (stored; not used to size weights).
        n_h: number of higher-level capsules.
        d_l: dimension of each lower-level capsule.
        d_h: dimension of each higher-level capsule.
    """
    def __init__(self, n_l, n_h, d_l, d_h):
        super().__init__()
        self.n_l = n_l

        self.n_h = n_h
        self.d_l = d_l
        self.d_h = d_h

        self.W = torch.nn.Parameter(torch.rand(n_h, d_l, d_h), requires_grad=True)
        self.squash = Squash(eps=1e-20)

        # Replaces the torch.rand values above with Kaiming-normal samples.
        torch.nn.init.kaiming_normal_(
            self.W, a=0, mode='fan_in', nonlinearity='leaky_relu')

        # Divisor applied to the agreement logits.
        self.attention_scaling = np.sqrt(d_l)

    def forward(self, U_l):
        """
            In:  U_l
            Out: U_h (after self.squash)
            DIMS:
                U_l  (..., n_l, d_l)
                U_h  (..., n_h, d_h)
                W    (n_h, d_l, d_h)
        """
        # Prediction of every lower capsule for every higher capsule: (..., n_l, n_h, d_h).
        U_hat = torch.einsum("...ij, ...kjl -> ...ikl", U_l, self.W)
        # Agreement logit of lower capsule h with higher capsule k:
        #   sum_i <U_hat[i, k], U_hat[h, k]>  ==  <sum_i U_hat[i, k], U_hat[h, k]>.
        # Summing over lower capsules first avoids materialising the
        # (..., n_l, n_l, n_h) pairwise tensor (O(n_l) instead of O(n_l^2) memory).
        # Results match the pairwise form up to floating-point rounding.
        U_hat_total = U_hat.sum(dim=-3, keepdim=True)                          # (..., 1, n_h, d_h)
        A_sum = (U_hat * U_hat_total).sum(dim=-1) / self.attention_scaling     # (..., n_l, n_h)
        # Each lower capsule distributes its vote over the higher capsules.
        C = torch.softmax(A_sum, dim=-1)
        U_h = torch.einsum('...ikl,...ik->...kl', U_hat, C)
        return self.squash(U_h)

class MnistCaps(nn.Module):
    """Per-pixel transformer followed by two capsule layers, sized for 28x28 single-channel inputs.

    Pipeline: pixel tokens (pixel value + 4 fixed positional channels) ->
    Transformer(dim=5) -> UniversalPrimeCaps(784 -> 16 capsules of size 16)
    -> FCCaps(16 -> 10 capsules of size 16).
    """
    def __init__(self):
        super(MnistCaps, self).__init__()
        self.transformer = Transformer(dim=5, depth=2, heads=1, dim_head=32, mlp_dim=64, dropout=0.)
        # Fixed positional encoding (requires_grad=False), shape (4, 28, 28).
        self.pe = torch.nn.Parameter(pos_tanh_embedding(28, 28), requires_grad=False)
        #
        self.caps_prime = UniversalPrimeCaps(n_l=28*28, n_h=16, d_l=5, d_h=16)
        self.caps_digits = FCCaps(n_l=16, n_h=10, d_l=16, d_h=16)

    def emb(self, x):
        """Concatenate the positional channels and flatten pixels into 28*28 tokens."""
        batch, _, _, _ = x.shape
        tiled_pe = self.pe[None].expand(batch, -1, -1, -1)
        tokens = torch.cat((x, tiled_pe), dim=1).permute(0, 2, 3, 1)
        return tokens.reshape(batch, 28 * 28, -1)

    def forward(self, x):
        tokens = self.emb(x)
        hidden = self.transformer(tokens)
        primary = self.caps_prime(hidden)
        return self.caps_digits(primary)

In [None]:
# Smoke test: push one training image through the freshly built model (on the
# CPU, before moving it to `device`) and check the logits shape.
x, _ = ds_train[0]
x = x.unsqueeze(0)
#
model = ViT(
    image_size = 28,
    patch_size = 1,
    num_classes = 10,
    dim = 5,
    depth = 6,
    heads = 16,
    mlp_dim = 128,
    dropout = 0.1,
    emb_dropout = 0.1,
    channels = 1,
)
#
with torch.no_grad():  # shape check only; no autograd graph needed
    y = model(x)
assert y.shape == (1, 10), y.shape

model = model.to(device)
# Last expression so Jupyter displays it.
y.shape

In [None]:
# Report model size via the project helper effcn.utils.count_parameters
# (presumably the number of (trainable?) parameters — confirm in effcn.utils).
count_parameters(model)

In [None]:
# Adam with L2 weight decay; each call to lr_scheduler.step() multiplies the
# learning rate by LR_GAMMA.
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5
LR_GAMMA = 0.96

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=LR_GAMMA)
#
criterion = nn.CrossEntropyLoss()

In [None]:
num_epochs = 11
#
for epoch_idx in range(num_epochs):
    # ####################
    # TRAIN
    # ####################
    model.train()
    # 1-based counter so the last epoch reads "[ 11/ 11]".
    desc = "Train [{:3}/{:3}]:".format(epoch_idx + 1, num_epochs)
    pbar = tqdm(dl_train, bar_format=desc + '{bar:10}{r_bar}{bar:-10b}')

    for x, y_true in pbar:
        x = x.to(device)
        y_true = y_true.to(device)
        optimizer.zero_grad()

        # Call the module (not .forward) so registered hooks run.
        out = model(x)
        loss = criterion(out, y_true)

        loss.backward()

        optimizer.step()

        # Batch accuracy from the logits: predicted class = argmax over classes.
        y_pred = torch.argmax(out, dim=1)
        acc = (y_pred == y_true).float().mean()

        pbar.set_postfix(
                {'loss': loss.item(),
                 'acc': acc.item()
                 }
        )
    # Decay the learning rate once per epoch.
    lr_scheduler.step()