In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append("./..")

# standard lib
import shutil
from pathlib import Path
import pickle

# external imports
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as T
import torch.nn.functional as F
from torch import optim
from torch.utils.tensorboard import SummaryWriter
from dotted_dict import DottedDict
import pprint
from tqdm import tqdm
#
from torchvision.datasets import CIFAR10

# local imports
from datasets import AffNIST
from effcn.layers import FCCaps, Squash
from effcn.functions import margin_loss, max_norm_masking
from effcn.utils import count_parameters
from misc.optimizer import get_optimizer, get_scheduler
from misc.utils import get_sting_timestamp, mkdir_directories
from misc.plot_utils import plot_couplings, plot_capsules

In [None]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)

## Data

In [None]:
# Training-time augmentation: random horizontal flip and a padded random crop
# back to 32x32, then conversion to a tensor.
# Normalize with mean 0 / std 1 leaves the values unchanged; it is kept as the
# place where dataset statistics can be plugged in.
transform_train = T.Compose([
    T.RandomHorizontalFlip(p=0.5),
    T.RandomCrop(32, padding=4),
    T.ToTensor(),
    T.Normalize([0, 0, 0], [1, 1, 1])
])

# Validation images are only converted to tensors (no augmentation).
transform_valid = T.Compose([
    T.ToTensor(),
    T.Normalize([0, 0, 0], [1, 1, 1])
])

# Root directory handed to torchvision for the CIFAR-10 files.
p_data = '/mnt/data/pytorch'

# download=True is passed for both splits.
ds_train = CIFAR10(root=p_data, train=True, download=True, transform=transform_train, target_transform=None)
ds_valid = CIFAR10(root=p_data, train=False, download=True, transform=transform_valid, target_transform=None)

In [None]:
bs = 512

# Training batches are reshuffled every epoch.
dl_train = torch.utils.data.DataLoader(
    ds_train,
    batch_size=bs,
    shuffle=True,
    pin_memory=True,
    num_workers=4)

# Validation data is iterated in a fixed order so that the "first batch" used
# by the visualisation cells is the same on every run.
dl_valid = torch.utils.data.DataLoader(
    ds_valid,
    batch_size=bs,
    shuffle=False,
    pin_memory=True,
    num_workers=4)

In [None]:
# Take one batch from each loader and keep its first 32 images for display.
x = next(iter(dl_train))[0]
x_vis_train = x[:32]

x = next(iter(dl_valid))[0]
x_vis_valid = x[:32]

In [None]:
# Show the training samples first, then the validation samples, as image grids.
for x_vis in (x_vis_train, x_vis_valid):
    grid = torchvision.utils.make_grid(x_vis)
    # move the channel axis last, as expected by imshow
    plt.imshow(grid.permute(1, 2, 0))
    plt.show()

# Backbone

In [None]:
class CustomBB(nn.Module):
    """
    Small convolutional backbone followed by a linear classification head.

    Two strided 3x3 conv + BatchNorm + ReLU stages are followed by an
    unpadded depthwise 7x7 convolution. For 32x32 inputs this leaves a
    1x1 spatial grid with 256 channels, which is flattened and fed to ``fc``.

    Attributes
    ----------
    ch_in     ... number of input image channels
    n_classes ... number of outputs of the linear head
    convs     ... convolutional feature extractor (nn.Sequential)
    fc        ... linear layer mapping the 256 features to n_classes outputs
    """

    def __init__(self, ch_in=3, n_classes=10):
        super().__init__()
        self.ch_in = ch_in
        self.n_classes = n_classes

        # Layers are created in the same order as they are applied.
        stage_1 = [
            nn.Conv2d(ch_in, 128, kernel_size=3, stride=2, padding=0),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        ]
        stage_2 = [
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=0),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        ]
        # depthwise conv (one 7x7 filter per channel), no norm / activation after it
        depthwise = nn.Conv2d(256, 256, kernel_size=7, groups=256, stride=1, padding="valid")

        self.convs = nn.Sequential(*stage_1, *stage_2, depthwise)
        self.fc = nn.Linear(256, n_classes)

    def forward(self, x):
        """Return the head output for an image batch ``x``."""
        features = self.convs(x)
        # (b, 256, 1, 1) -> (b, 256): drop the 1x1 grid
        features = torch.flatten(features, 1)
        return self.fc(features)

In [None]:
# Smoke test: push a random single-channel 32x32 batch through the backbone
# and display the output shape.
model = CustomBB(ch_in=1)
dummy_batch = torch.rand(128, 1, 32, 32)
y = model(dummy_batch)
y.shape

In [None]:
# Display the result of the project helper for the backbone built above
# (presumably its number of parameters — confirm in effcn.utils).
count_parameters(model)

In [None]:
class FCCaps(nn.Module):
    """
    Fully connected capsule layer with attention-based routing (no bias term).

    NOTE: defining this class in the notebook shadows the ``FCCaps`` that is
    imported from ``effcn.layers`` in the import cell.

        Attributes
        ----------
        n_l ... number of lower layer capsules
        d_l ... dimension of lower layer capsules
        n_h ... number of higher layer capsules
        d_h ... dimension of higher layer capsules

        W   (n_l, n_h, d_l, d_h) ... weight tensor
        attention_scaling        ... sqrt(d_l), divisor applied to the attention scores

        A bias tensor B (n_l, n_h) is currently disabled.
    """

    def __init__(self, n_l, n_h, d_l, d_h):
        super().__init__()
        self.n_l = n_l
        self.d_l = d_l
        self.n_h = n_h
        self.d_h = d_h

        # Allocated uninitialised; every value is set by the Kaiming init below.
        self.W = torch.nn.Parameter(torch.empty(
            n_l, n_h, d_l, d_h), requires_grad=True)
        self.squash = Squash(eps=1e-20)

        # init custom weights
        # NOTE(review): the original author was unsure whether this
        # initialization scheme makes sense here; it was chosen to follow the paper.
        torch.nn.init.kaiming_normal_(
            self.W, a=0, mode='fan_in', nonlinearity='leaky_relu')

        self.attention_scaling = np.sqrt(self.d_l)

    def forward(self, U_l):
        """
        Route lower layer capsules to higher layer capsules.

        IN:  U_l (..., n_l, d_l) ... lower layer capsules
        OUT: U_h (..., n_h, d_h) ... squashed higher layer capsules

        Same computation as ``forward_debug`` without returning the
        coupling coefficients.
        """
        U_h, _ = self.forward_debug(U_l)
        return U_h

    def forward_debug(self, U_l):
        """
        Route lower layer capsules and also return the coupling coefficients.

        einsum conventions:
          n_l = i | h
          d_l = j
          n_h = k
          d_h = l

        Data tensors (leading batch dimensions are written as "..."):
            IN:  U_l ... lower layer capsules
            OUT: U_h ... higher layer capsules
                 C   ... coupling coefficients
            DIMS:
                U_l (n_l, d_l)
                U_h (n_h, d_h)
                W   (n_l, n_h, d_l, d_h)
                A   (n_l, n_l, n_h)
                C   (n_l, n_h)
        """
        # predictions of every lower capsule for every higher capsule
        U_hat = torch.einsum('...ij,ikjl->...ikl', U_l, self.W)
        # pairwise dot products between the predictions, per higher capsule
        A = torch.einsum("...ikl, ...hkl -> ...hik", U_hat, U_hat)
        A = A / self.attention_scaling
        # sum the scores over the second lower-capsule axis
        A_sum = torch.einsum("...hij->...hj", A)
        # coupling coefficients: softmax over the higher capsules
        C = torch.softmax(A_sum, dim=-1)
        # weighted sum of the predictions over the lower capsules
        U_h = torch.einsum('...ikl,...ik->...kl', U_hat, C)
        return self.squash(U_h), C

class DeepCapsNet(nn.Module):
    """
    CNN backbone followed by a stack of fully connected capsule layers.

    Attributes
    ----------
    ns ... number of capsules per capsule layer (ns[0] = primary capsules)
    ds ... capsule dimension per capsule layer (same length as ns)

    The backbone's linear head is replaced by an identity, so its flattened
    features are reshaped into ns[0] primary capsules of dimension ds[0].
    """

    def __init__(self, ns, ds):
        super().__init__()
        if len(ns) != len(ds):
            raise ValueError(
                "ns and ds must have the same length, got {} and {}".format(len(ns), len(ds)))
        self.ns = ns
        self.ds = ds

        self.backbone = CustomBB(ch_in=3)
        # drop the classification head, only the features are needed
        self.backbone.fc = nn.Identity()

        self.squash = Squash(eps=1e-20)
        # one FCCaps layer per pair of consecutive entries in ns / ds
        layers = [
            FCCaps(n_l, n_h, d_l, d_h)
            for n_l, n_h, d_l, d_h in zip(ns[:-1], ns[1:], ds[:-1], ds[1:])
        ]
        self.layers = nn.Sequential(*layers)

    def _primary_caps(self, x):
        """Run the backbone and reshape its features into squashed primary capsules."""
        feats = self.backbone(x)
        # Keep the batch size explicit: a feature size that does not match
        # ns[0] * ds[0] raises here instead of silently changing the batch dimension.
        return self.squash(feats.view(feats.shape[0], self.ns[0], self.ds[0]))

    def forward(self, x):
        """Return the capsules of the last layer, shape (b, ns[-1], ds[-1])."""
        x = self._primary_caps(x)

        for layer in self.layers:
            x = layer(x)
        return x

    def forward_debug(self, x):
        """
        Same as ``forward`` but additionally returns, detached from autograd,
        the coupling coefficients of every capsule layer (``cc``) and the
        capsules of every layer including the primary capsules (``us``).
        """
        x = self._primary_caps(x)

        us = [x.detach().clone()]
        cc = []
        # fccaps
        for layer in self.layers:
            x, c = layer.forward_debug(x)
            cc.append(c.detach())
            us.append(x.detach().clone())
        return x, cc, us

In [None]:
# Capsule counts and capsule dimensions per layer (primary capsules first).
ns = [32, 32, 16, 10]
ds = [8, 8, 8, 8]

model = DeepCapsNet(ns=ns, ds=ds)

# Print the helper's result for the whole model and for the backbone alone.
for label, module in (("tot Model ", model), ("Backbone  ", model.backbone)):
    print(label, count_parameters(module))

model = model.to(device)
model

In [None]:
# Adam with a small weight decay; each lr_scheduler.step() multiplies the
# learning rate by 0.96.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3, weight_decay=2e-5)
lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.96)

In [None]:
num_epochs = 11
#
for epoch_idx in range(num_epochs):
    # ####################
    # TRAIN
    # ####################
    model.train()
    desc = "Train [{:3}/{:3}]:".format(epoch_idx, num_epochs)
    pbar = tqdm(dl_train, bar_format=desc + '{bar:10}{r_bar}{bar:-10b}')

    for x, y_true in pbar:
        x = x.to(device)
        y_true = y_true.to(device)
        optimizer.zero_grad()

        # call the module itself (not .forward) so that registered hooks run
        u_h = model(x)

        # LOSS
        y_one_hot = F.one_hot(y_true, num_classes=10)
        loss = margin_loss(u_h, y_one_hot)

        loss.backward()
        optimizer.step()

        # The batch accuracy is only a progress metric; keep it out of autograd.
        # The predicted class is the output capsule with the largest norm.
        with torch.no_grad():
            y_pred = torch.argmax(torch.norm(u_h, dim=2), dim=1)
            acc = (y_true == y_pred).float().mean()

        pbar.set_postfix(
                {'loss': loss.item(),
                 'acc': acc.item()
                 }
        )
    # decay the learning rate once per epoch
    lr_scheduler.step()
    #
    # ####################
    # VALID (only every 5th epoch)
    # ####################
    if epoch_idx % 5 != 0:
        continue

    model.eval()

    total_correct = 0
    total = 0

    with torch.no_grad():
        for x, y_true in dl_valid:
            x = x.to(device)
            y_true = y_true.to(device)

            u_h = model(x)

            y_pred = torch.argmax(torch.norm(u_h, dim=2), dim=1)
            total_correct += (y_true == y_pred).sum().item()
            total += y_true.shape[0]
    # the validation loader holds CIFAR-10, so label the metric accordingly
    print("   cifar10 acc_valid: {:.3f}".format(total_correct / total))

# Visualize and Analyze

### Show parse tree and activations for individual samples

In [None]:
# Run the debug forward pass on (up to) the first 128 samples of one validation batch.
x, y = next(iter(dl_valid))
x, y = x[:128], y[:128]

model.eval()
with torch.no_grad():
    u_h, CC, US = model.forward_debug(x.to(device))

# predicted class = output capsule with the largest norm
capsule_norms = torch.norm(u_h, dim=2)
y_pred = torch.argmax(capsule_norms, dim=1).detach().cpu().numpy()

# numpy copies of the per-layer capsules and coupling coefficients
US = [u.cpu().numpy() for u in US]
CS = [c.cpu().numpy() for c in CC]

Y_true = y.cpu().numpy()
Y_pred = y_pred

In [None]:
# Class to display; set to None to show every one of the first 32 samples.
cl = 1
for idx in range(32):
    if cl is None or Y_true[idx] == cl:
        # couplings and capsule norms of this sample, one entry per layer
        cs = [c[idx] for c in CS]
        us = [u[idx] for u in US]
        u_norms = [np.linalg.norm(u, axis=1) for u in us]

        fig, axes = plt.subplots(1, 2, figsize=(8, 4))
        title = "exp={} a={}".format(y[idx], y_pred[idx])
        plot_couplings(cs, title=title, ax=axes[0], show=False)
        plot_capsules(u_norms, title=title, ax=axes[1], show=False)
        plt.show()

# Collect Statistics

In [None]:
# Collect labels, coupling coefficients and capsules for the whole validation set.
model.eval()

YY = []
CC = [[] for _ in range(len(ns) - 1)]   # one list per capsule layer transition
US = [[] for _ in range(len(ns))]       # one list per capsule layer

with torch.no_grad():
    for x, y_true in dl_valid:
        x = x.to(device)
        _, cc, us = model.forward_debug(x)
        for bucket, c in zip(CC, cc):
            bucket.append(c.detach().cpu().numpy())
        for bucket, u in zip(US, us):
            bucket.append(u.detach().cpu().numpy())
        YY.append(y_true.numpy())

# stack the per-batch arrays along the sample axis
YY = np.concatenate(YY)
CC = [np.concatenate(c) for c in CC]
US = [np.concatenate(u) for u in US]

### Mean parse tree and mean activation for dataset

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Mean parse tree: couplings averaged over all samples
cc = [C.mean(axis=0) for C in CC]
plot_couplings(cc, ax=axes[0], show=False, title="mean parse tree")

# Mean and std of the capsule norms per layer
norms = [np.linalg.norm(u, axis=-1) for u in US]
us_mean = [n.mean(axis=0) for n in norms]
us_std = [n.std(axis=0) for n in norms]
plot_capsules(us_mean, scale_factor=1, ax=axes[1], show=False, title="mean activation")
plot_capsules(us_std, scale_factor=1, ax=axes[2], show=False, title="std activation")
plt.suptitle("dataset")
plt.show()

### Class-wise mean parse tree and mean activation

In [None]:
# Per class: mean parse tree plus mean / std of the capsule activations
for cls in range(10):
    idcs = np.where(YY == cls)[0]

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Mean parse tree over the samples of this class
    cc = [np.mean(C[idcs], axis=0) for C in CC]
    plot_couplings(cc, ax=axes[0], show=False, title="mean parse tree")

    # Mean and std of the capsule norms over the samples of this class
    norms = [np.linalg.norm(u[idcs], axis=-1) for u in US]
    us_mean = [n.mean(axis=0) for n in norms]
    us_std = [n.std(axis=0) for n in norms]
    plot_capsules(us_mean, scale_factor=1, ax=axes[1], show=False, title="mean activation")
    plot_capsules(us_std, scale_factor=1, ax=axes[2], show=False, title="std activation")
    plt.suptitle("class {}".format(cls))
    plt.show()

In [None]:
# Mean and standard deviation of the couplings over the dataset, one figure per layer
for C in CC:
    C_mean, C_std = C.mean(axis=0), C.std(axis=0)
    fig, (ax_mean, ax_std) = plt.subplots(1, 2, figsize=(4, 2))
    ax_mean.imshow(C_mean, cmap="gray", vmin=0., vmax=1.)
    ax_std.imshow(C_std, cmap="gray", vmin=0.)
    plt.show()

In [None]:
# Mean and standard deviation of the capsule norms over the dataset, one figure per layer
for U in US:
    norms = np.linalg.norm(U, axis=2)
    u_mean, u_std = norms.mean(axis=0), norms.std(axis=0)
    positions = range(len(u_mean))
    fig, axes = plt.subplots(1, 2, figsize=(4, 2))
    for ax, heights in zip(axes, (u_mean, u_std)):
        ax.bar(positions, heights)
        ax.set_ylim(0, 1)
    plt.show()

In [None]:
# Mean and standard deviation of the couplings, separately for every class
for cls in range(10):
    print("#" * 100, "\n{}\n".format(cls), "#" * 100)
    idcs = np.where(YY == cls)[0]
    for C in CC:
        C_cls = C[idcs]
        C_mean, C_std = C_cls.mean(axis=0), C_cls.std(axis=0)
        fig, (ax_mean, ax_std) = plt.subplots(1, 2, figsize=(4, 2))
        ax_mean.imshow(C_mean, cmap="gray", vmin=0., vmax=1.)
        ax_std.imshow(C_std, cmap="gray", vmin=0.)
        plt.show()

In [None]:
# Mean and standard deviation of the capsule norms, separately for every class
for cls in range(10):
    print("#" * 100, "\n{}\n".format(cls), "#" * 100)
    idcs = np.where(YY == cls)[0]
    for U in US:
        norms = np.linalg.norm(U[idcs], axis=2)
        u_mean, u_std = norms.mean(axis=0), norms.std(axis=0)
        positions = range(len(u_mean))
        fig, axes = plt.subplots(1, 2, figsize=(4, 2))
        for ax, heights in zip(axes, (u_mean, u_std)):
            ax.bar(positions, heights)
            ax.set_ylim(0, 1)
        plt.show()

## Metrics

In [None]:
def mean_max(C):
    """Take the maximum of ``C`` along axis 2 and return the mean of those maxima."""
    largest = np.max(C, axis=2)
    return np.mean(largest)

def max_std_dev(C):
    """
    Standard deviation of ``C`` over axis 0, then the maximum of that along
    axis 1, averaged over the remaining axis.
    """
    spread = C.std(axis=0)
    return np.mean(np.max(spread, axis=1))

In [None]:
# One row per capsule layer: mean_max followed by max_std_dev of its couplings
for C in CC:
    row = (mean_max(C), max_std_dev(C))
    print("{:.3f}   {:.3f}".format(*row))

In [None]:
# Uniform routing reference: every coupling equals 1 / (size of axis 2)
CC_uni = [np.ones(C.shape) / C.shape[2] for C in CC]
for C in CC_uni:
    row = (mean_max(C), max_std_dev(C))
    print("{:.3f}   {:.3f}".format(*row))

In [None]:
# Random routing reference: softmax over the last axis of uniform noise in [0, 10)
CC_rand = []
for C in CC:
    noise = np.random.rand(*C.shape) * 10
    CC_rand.append(torch.softmax(torch.Tensor(noise), dim=-1).numpy())
for C in CC_rand:
    row = (mean_max(C), max_std_dev(C))
    print("{:.3f}   {:.3f}".format(*row))

# CNN Only Baseline

In [None]:
# CNN-only baseline: the backbone with its linear head, using the same
# optimizer and scheduler settings as the capsule model cell.
model = CustomBB(ch_in=3, n_classes=10).to(device)

optimizer = optim.Adam(params=model.parameters(), lr=1e-3, weight_decay=2e-5)
lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.96)

criterion = nn.CrossEntropyLoss()

In [None]:
# Print the project helper's result for the baseline model
# (presumably its number of parameters — confirm in effcn.utils).
print(count_parameters(model))

In [None]:
num_epochs = 11
#
for epoch_idx in range(num_epochs):
    # ####################
    # TRAIN
    # ####################
    model.train()
    desc = "Train [{:3}/{:3}]:".format(epoch_idx, num_epochs)
    pbar = tqdm(dl_train, bar_format=desc + '{bar:10}{r_bar}{bar:-10b}')

    for x, y_true in pbar:
        x = x.to(device)
        y_true = y_true.to(device)
        optimizer.zero_grad()

        # call the module itself (not .forward) so that registered hooks run
        logits = model(x)
        loss = criterion(logits, y_true)

        loss.backward()
        optimizer.step()

        # The batch accuracy is only a progress metric; keep it out of autograd.
        with torch.no_grad():
            y_pred = torch.argmax(logits, dim=1)
            acc = (y_true == y_pred).float().mean()

        pbar.set_postfix(
                {'loss': loss.item(),
                 'acc': acc.item()
                 }
        )
    # decay the learning rate once per epoch
    lr_scheduler.step()
    #
    # ####################
    # VALID (only every 5th epoch)
    # ####################
    if epoch_idx % 5 != 0:
        continue

    model.eval()

    total_correct = 0
    total = 0

    with torch.no_grad():
        for x, y_true in dl_valid:
            x = x.to(device)
            y_true = y_true.to(device)

            logits = model(x)

            y_pred = torch.argmax(logits, dim=1)
            total_correct += (y_true == y_pred).sum().item()
            total += y_true.shape[0]
    # the validation loader holds CIFAR-10, so label the metric accordingly
    print("   cifar10 acc_valid: {:.3f}".format(total_correct / total))