In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append("./..")

# standard lib
import os
import shutil
from pathlib import Path
import pickle

# external imports
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as T
import torch.nn.functional as F
from torch import optim
from torch.utils.tensorboard import SummaryWriter
from dotted_dict import DottedDict
import pprint
from tqdm import tqdm
#
from torchvision.datasets import CIFAR10

# local imports
from datasets import AffNIST
from effcn.layers import FCCaps, Squash
from effcn.functions import margin_loss, max_norm_masking
from effcn.utils import count_parameters
from misc.optimizer import get_optimizer, get_scheduler
from misc.utils import get_sting_timestamp, mkdir_directories
from misc.plot_utils import plot_couplings, plot_capsules, plot_mat

In [None]:
# Use the first GPU when one is visible, otherwise fall back to the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)

# Data

In [None]:
# Training-time augmentation: small random rotations, shears and rescalings.
transform_train = T.Compose([
    T.RandomAffine(degrees=(-8, 8),
                   shear=(-15, 15),
                   scale=(0.9, 1.1)
                  )
])
# No extra transform for the validation splits.
# NOTE(review): `None` itself applies no transform; the [0,255] -> [0,1]
# scaling presumably happens inside AffNIST -- confirm in datasets.AffNIST.
transform_valid = None

# Data root. Set the EFFCN_DATA environment variable to point elsewhere
# instead of editing the notebook; the default is the original location.
p_data = os.environ.get("EFFCN_DATA", '/home/matthias/projects/EfficientCN/data')

ds_mnist_train = AffNIST(p_root=p_data, split="mnist_train", download=True, transform=transform_train, target_transform=None)
ds_mnist_valid = AffNIST(p_root=p_data, split="mnist_valid", download=True, transform=transform_valid, target_transform=None)
ds_affnist_valid = AffNIST(p_root=p_data, split="affnist_valid", download=True, transform=transform_valid, target_transform=None)

In [None]:
bs = 512
# Only the training split is shuffled. Validation accuracy does not depend on
# order, and a fixed order makes the cells below that take the "first batch"
# of a validation loader reproducible.
dl_mnist_train = torch.utils.data.DataLoader(
    ds_mnist_train,
    batch_size=bs,
    shuffle=True,
    pin_memory=True,
    num_workers=4)
dl_mnist_valid = torch.utils.data.DataLoader(
    ds_mnist_valid,
    batch_size=bs,
    shuffle=False,
    pin_memory=True,
    num_workers=4)
dl_affnist_valid = torch.utils.data.DataLoader(
    ds_affnist_valid,
    batch_size=bs,
    shuffle=False,
    pin_memory=True,
    num_workers=4)

In [None]:
# Keep the first 32 images of one batch per split for quick visual checks.
vis_batches = []
for loader in (dl_mnist_train, dl_mnist_valid, dl_affnist_valid):
    x, _ = next(iter(loader))
    vis_batches.append(x[:32])
x_vis_train, x_vis_mnist_valid, x_vis_affnist_valid = vis_batches

In [None]:
# One image grid per split: train (augmented), MNIST valid, affNIST valid.
for x_vis in (x_vis_train, x_vis_mnist_valid, x_vis_affnist_valid):
    grid = torchvision.utils.make_grid(x_vis)   # channels-first image grid
    plt.imshow(grid.permute(1, 2, 0))           # imshow expects channels last
    plt.show()

# Backbone

In [None]:
class CustomBB(nn.Module):
    """Small convolutional backbone.

    Three convolutions followed by a flatten and a linear layer. For
    40x40 inputs the convolution stack yields a 256 x 1 x 1 map, so the
    flattened feature vector has 256 entries.

    Args:
        ch_in: number of input channels.
        n_classes: output size of the final linear layer ``fc``.
    """

    def __init__(self, ch_in=3, n_classes=10):
        super().__init__()
        self.ch_in = ch_in
        self.n_classes = n_classes

        # NOTE: module names and order define the state_dict keys; keep them stable.
        conv_stack = [
            # stage 1: stride-2 3x3 conv (40x40 -> 19x19)
            nn.Conv2d(in_channels=ch_in, out_channels=128, kernel_size=3, stride=2, padding=0),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(),
            # stage 2: stride-2 3x3 conv (19x19 -> 9x9)
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=0),
            nn.BatchNorm2d(num_features=256),
            nn.ReLU(),
            # stage 3: grouped 9x9 "valid" conv (9x9 -> 1x1), no norm / activation
            nn.Conv2d(256, 256, kernel_size=9, groups=32, stride=1, padding="valid"),
        ]
        self.convs = nn.Sequential(*conv_stack)
        self.fc = nn.Linear(256, n_classes)

    def forward(self, x):
        # drop the 1x1 spatial grid -> (b, 256), then project
        features = torch.flatten(self.convs(x), start_dim=1)
        return self.fc(features)

In [None]:
# Smoke test: push a random batch of 128 single-channel 40x40 images through the backbone.
model = CustomBB(ch_in=1)
dummy_batch = torch.rand(128, 1, 40, 40)
y = model(dummy_batch)
y.shape

In [None]:
# Parameter count of the stand-alone backbone, as reported by effcn.utils.count_parameters.
count_parameters(model)

In [None]:
class FCCaps(nn.Module):
    """
        Fully-connected capsule layer with self-attention routing.

        NOTE: this notebook-local class shadows ``effcn.layers.FCCaps``
        imported at the top of the notebook.

        Attributes
        ----------
        n_l ... number of lower layer capsules
        d_l ... dimension of lower layer capsules
        n_h ... number of higher layer capsules
        d_h ... dimension of higher layer capsules

        W   (n_l, n_h, d_l, d_h) ... weight tensor
    """

    def __init__(self, n_l, n_h, d_l, d_h):
        super().__init__()
        self.n_l = n_l
        self.d_l = d_l
        self.n_h = n_h
        self.d_h = d_h

        self.W = torch.nn.Parameter(torch.rand(
            n_l, n_h, d_l, d_h), requires_grad=True)
        # NOTE: a per-coupling bias B (n_l, n_h) is currently disabled:
        #self.B = torch.nn.Parameter(torch.rand(n_l, n_h), requires_grad=True)
        self.squash = Squash(eps=1e-20)

        # Kaiming-normal init of W (overwrites the uniform values above).
        # The paper prescribes it; it is unclear whether it suits our case.
        torch.nn.init.kaiming_normal_(
            self.W, a=0, mode='fan_in', nonlinearity='leaky_relu')

        # Scaled dot-product factor.
        # NOTE(review): the dot products in A run over the d_h axis, so
        # sqrt(d_h) may be intended; identical while d_l == d_h.
        self.attention_scaling = np.sqrt(self.d_l)

    def _route(self, U_l):
        """Shared routing computation; returns ``(squash(U_h), C)``.

        einsum conventions:
          n_l = i | h
          d_l = j
          n_h = k
          d_h = l

        Data tensors (leading batch dimensions elided as ``...``):
            U_l   (n_l, d_l)        lower layer capsules (input)
            U_hat (n_l, n_h, d_h)   predictions of each lower for each higher capsule
            A     (n_l, n_l, n_h)   scaled agreement between predictions
            C     (n_l, n_h)        coupling coefficients (softmax over n_h)
            U_h   (n_h, d_h)        higher layer capsules (output, before squash)
        """
        # U_hat[..., i, k, :] = U_l[..., i, :] @ W[i, k]
        U_hat = torch.einsum('...ij,ikjl->...ikl', U_l, self.W)
        # agreement of the predictions of lower capsules h and i for higher capsule k
        A = torch.einsum("...ikl, ...hkl -> ...hik", U_hat, U_hat)
        A = A / self.attention_scaling
        # total agreement of lower capsule h with all lower capsules, per higher capsule
        A_sum = torch.einsum("...hij->...hj", A)
        # each lower capsule distributes its coupling over the higher capsules
        C = torch.softmax(A_sum, dim=-1)
        # coupling-weighted sum of predictions over the lower capsules
        U_h = torch.einsum('...ikl,...ik->...kl', U_hat, C)
        return self.squash(U_h), C

    def forward(self, U_l):
        """Route lower capsules U_l (..., n_l, d_l) to squashed higher capsules (..., n_h, d_h)."""
        U_h, _ = self._route(U_l)
        return U_h

    def forward_debug(self, U_l):
        """Like ``forward``, but also return the coupling coefficients C (..., n_l, n_h)."""
        return self._route(U_l)

class DeepCapsNet(nn.Module):
    """CustomBB backbone followed by a stack of FCCaps layers.

    Args:
        ns: number of capsules per level; ns[0] primary capsules, ns[-1] output capsules.
        ds: capsule dimension per level; must have the same length as ``ns``.

    The backbone feature vector (256 entries for 40x40 inputs) is reshaped into
    ns[0] primary capsules of size ds[0], so ns[0] * ds[0] must match it.

    Raises:
        ValueError: if ``ns`` and ``ds`` differ in length.
    """

    def __init__(self, ns, ds):
        super().__init__()
        if len(ns) != len(ds):
            raise ValueError(
                "ns and ds must have the same length, got {} and {}".format(len(ns), len(ds)))
        self.ns = ns
        self.ds = ds

        # reuse the CNN backbone as feature extractor (drop its classifier)
        self.backbone = CustomBB(ch_in=1)
        self.backbone.fc = nn.Identity()

        self.squash = Squash(eps=1e-20)
        # one FCCaps layer per consecutive pair of capsule levels
        layers = [
            FCCaps(ns[idx - 1], ns[idx], ds[idx - 1], ds[idx])
            for idx in range(1, len(ns))
        ]
        self.layers = nn.Sequential(*layers)

    def _primary_caps(self, x):
        """Backbone features -> squashed primary capsules of shape (b, ns[0], ds[0])."""
        features = self.backbone(x)
        return self.squash(features.view(-1, self.ns[0], self.ds[0]))

    def forward(self, x):
        x = self._primary_caps(x)
        for layer in self.layers:
            x = layer(x)
        return x

    def forward_debug(self, x):
        """Forward pass that also records intermediate results.

        Returns:
            x:  output capsules of the last FCCaps layer.
            cc: list of detached coupling tensors, one per FCCaps layer.
            us: list of detached capsule tensors for every level, primary capsules first.
        """
        x = self._primary_caps(x)
        # every recorded tensor is detached (the primary-capsule copy used to stay attached)
        us = [x.detach().clone()]
        cc = []
        for layer in self.layers:
            x, c = layer.forward_debug(x)
            cc.append(c.detach())
            us.append(x.detach().clone())
        return x, cc, us

In [None]:
# capsules per level (primary -> ... -> class capsules) and their dimensions
ns = [32, 32, 16, 10]
ds = [8, 8, 8, 8]

model = DeepCapsNet(ns=ns, ds=ds)
for label, module in (("tot Model ", model), ("Backbone  ", model.backbone)):
    print(label, count_parameters(module))
model = model.to(device)
model

In [None]:
# Adam with light weight decay; the scheduler multiplies the LR by 0.96 on every step() call.
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=2e-5,
)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.96)

In [None]:
num_epochs = 101
valid_every = 5  # validate on epochs whose index is a multiple of this


def capsule_accuracy(model, loader):
    """Accuracy of a capsule model on ``loader``.

    The predicted class is the output capsule with the largest norm (norm
    over dim 2 of the model output). Returns NaN for an empty loader.
    """
    model.eval()
    total_correct = 0
    total = 0
    with torch.no_grad():
        for x, y_true in loader:
            x = x.to(device)
            y_true = y_true.to(device)
            u_h = model(x)
            y_pred = torch.argmax(torch.norm(u_h, dim=2), dim=1)
            total_correct += (y_true == y_pred).sum().item()
            total += y_true.shape[0]
    return total_correct / total if total else float("nan")


for epoch_idx in range(num_epochs):
    # ####################
    # TRAIN
    # ####################
    model.train()
    desc = "Train [{:3}/{:3}]:".format(epoch_idx, num_epochs)
    pbar = tqdm(dl_mnist_train, bar_format=desc + '{bar:10}{r_bar}{bar:-10b}')

    for x, y_true in pbar:
        x = x.to(device)
        y_true = y_true.to(device)
        optimizer.zero_grad()

        # call the module (not .forward) so registered nn.Module hooks run
        u_h = model(x)

        # LOSS: margin loss against one-hot targets
        y_one_hot = F.one_hot(y_true, num_classes=10)
        loss = margin_loss(u_h, y_one_hot)
        loss.backward()
        optimizer.step()

        # predicted class = output capsule with the largest norm
        y_pred = torch.argmax(torch.norm(u_h, dim=2), dim=1)
        acc = (y_true == y_pred).sum() / y_true.shape[0]
        pbar.set_postfix({'loss': loss.item(), 'acc': acc.item()})
    lr_scheduler.step()  # exponential LR decay once per epoch

    # ####################
    # VALID
    # ####################
    if epoch_idx % valid_every != 0:
        continue
    print("   mnist acc_valid: {:.3f}".format(capsule_accuracy(model, dl_mnist_valid)))
    print("   affnist acc_valid: {:.3f}".format(capsule_accuracy(model, dl_affnist_valid)))

# Visualize and Analyze

### Show parse tree and activations for individual samples

In [None]:
# Take up to 128 samples of one affNIST validation batch for per-sample inspection.
x, y = next(iter(dl_affnist_valid))
x, y = x[:128], y[:128]

model.eval()
with torch.no_grad():
    u_h, CC, US = model.forward_debug(x.to(device))
# predicted class = output capsule with the largest norm
y_pred = torch.norm(u_h, dim=2).argmax(dim=1).detach().cpu().numpy()

# per-level capsules and per-layer couplings as numpy arrays for plotting
US = [u.cpu().numpy() for u in US]
CS = [c.cpu().numpy() for c in CC]

Y_true = y.cpu().numpy()
Y_pred = y_pred

In [None]:
cl = 3  # only show samples of this true class; set to None to show all
# iterate over the samples actually available instead of assuming 128
for idx in range(len(Y_true)):
    if cl is not None and Y_true[idx] != cl:
        continue
    cs = [c[idx] for c in CS]
    us = [u[idx] for u in US]
    u_norms = [np.linalg.norm(u, axis=1) for u in us]  # capsule lengths per level

    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    # numpy labels: the title reads "exp=3", not "exp=tensor(3)"
    title = "exp={} a={}".format(Y_true[idx], Y_pred[idx])
    plot_couplings(cs, title=title, ax=axes[0], show=False)
    plot_capsules(u_norms, title=title, ax=axes[1], show=False)
    plt.show()

# Collect Statistics

In [None]:
model.eval()

YY = []                                   # true labels, one array per batch
CC = [[] for _ in range(len(ns) - 1)]     # couplings, one list per FCCaps layer
US = [[] for _ in range(len(ns))]         # capsules, one list per level

with torch.no_grad():
    for x, y_true in dl_affnist_valid:
        x = x.to(device)
        _, cc, us = model.forward_debug(x)
        for store, c in zip(CC, cc):
            store.append(c.detach().cpu().numpy())
        for store, u in zip(US, us):
            store.append(u.detach().cpu().numpy())
        YY.append(y_true.numpy())

YY = np.concatenate(YY)
CC = [np.concatenate(c) for c in CC]
US = [np.concatenate(u) for u in US]

### Mean parse tree and mean activation for dataset

In [None]:
fig, axes = plt.subplots(1, 4, figsize=(16, 4))

# parse tree over the whole dataset: per-layer mean / std of the couplings
cc_mean = [c.mean(axis=0) for c in CC]
cc_std = [c.std(axis=0) for c in CC]
for ax, mats, label in ((axes[0], cc_mean, "mean couplings"), (axes[1], cc_std, "std couplings")):
    plot_couplings(mats, ax=ax, show=False, title=label)

# capsule activation (= capsule length) per level: mean / std over the dataset
norms = [np.linalg.norm(u, axis=-1) for u in US]
us_mean = [n.mean(axis=0) for n in norms]
us_std = [n.std(axis=0) for n in norms]
for ax, vals, label in ((axes[2], us_mean, "mean activation"), (axes[3], us_std, "std activation")):
    plot_capsules(vals, scale_factor=1, ax=ax, show=False, title=label)
plt.suptitle("dataset")
plt.show()

### classwise mean parse tree and mean activation

In [None]:
# per-class parse tree and capsule activation (same panels as the dataset-level figure)
for cls in range(10):
    idcs = np.flatnonzero(YY == cls)
    cc_cls = [C[idcs] for C in CC]
    us_cls = [u[idcs] for u in US]

    fig, axes = plt.subplots(1, 4, figsize=(16, 4))

    # couplings: mean / std within the class
    plot_couplings([c.mean(axis=0) for c in cc_cls], ax=axes[0], show=False, title="mean couplings")
    plot_couplings([c.std(axis=0) for c in cc_cls], ax=axes[1], show=False, title="std couplings")

    # activations (capsule lengths): mean / std within the class
    norms = [np.linalg.norm(u, axis=-1) for u in us_cls]
    plot_capsules([n.mean(axis=0) for n in norms], scale_factor=1, ax=axes[2], show=False, title="mean activation")
    plot_capsules([n.std(axis=0) for n in norms], scale_factor=1, ax=axes[3], show=False, title="std activation")
    plt.suptitle("class {}".format(cls))
    plt.show()

In [None]:
# mean and std of the couplings of each FCCaps layer over the dataset
for C in CC:
    C_mean = C.mean(axis=0)
    C_std = C.std(axis=0)   # standard deviation (previously misnamed C_var)
    fig, axes = plt.subplots(1, 2, figsize=(4, 2))
    axes[0].imshow(C_mean, cmap="gray", vmin=0., vmax=1.)
    axes[0].set_title("mean")
    axes[1].imshow(C_std, cmap="gray", vmin=0.)
    axes[1].set_title("std")
    plt.show()

In [None]:
# mean and std of the capsule activations (capsule lengths) of each level
for U in US:
    act = np.linalg.norm(U, axis=2)
    positions = range(act.shape[1])
    fig, axes = plt.subplots(1, 2, figsize=(4, 2))
    for ax, values in zip(axes, (act.mean(axis=0), act.std(axis=0))):
        ax.bar(positions, values)
        ax.set_ylim(0, 1)
    plt.show()

In [None]:
for cls in range(10):
    print("#" * 100, "\n{}\n".format(cls), "#" * 100)
    idcs = np.where(YY == cls)[0]
    # per-layer mean and std of the couplings, restricted to class `cls`
    for C in CC:
        C_cls = C[idcs]
        C_mean = C_cls.mean(axis=0)
        C_std = C_cls.std(axis=0)   # standard deviation (previously misnamed C_var)
        fig, axes = plt.subplots(1, 2, figsize=(4, 2))
        axes[0].imshow(C_mean, cmap="gray", vmin=0., vmax=1.)
        axes[0].set_title("mean")
        axes[1].imshow(C_std, cmap="gray", vmin=0.)
        axes[1].set_title("std")
        plt.show()

In [None]:
# per-class mean and std of the capsule activations (capsule lengths) of each level
for cls in range(10):
    print("#" * 100, "\n{}\n".format(cls), "#" * 100)
    idcs = np.where(YY == cls)[0]
    for U in US:
        act = np.linalg.norm(U[idcs], axis=2)
        positions = range(act.shape[1])
        fig, axes = plt.subplots(1, 2, figsize=(4, 2))
        for ax, values in zip(axes, (act.mean(axis=0), act.std(axis=0))):
            ax.bar(positions, values)
            ax.set_ylim(0, 1)
        plt.show()

# Metrics

In [None]:
def mean_max(C):
    """Mean, over samples and lower capsules, of each lower capsule's largest coupling.

    C: couplings indexed (sample, lower capsule, higher capsule).
    """
    return np.max(C, axis=2).mean()


def max_std_dev(C):
    """For each lower capsule, the largest std (over samples) of its couplings,
    averaged over the lower capsules.

    (The notebook previously defined this twice with equivalent bodies; one is kept.)
    """
    return np.max(C.std(axis=0), axis=1).mean()


def calc_norm_entropy(C):
    """Normalised entropy of the sample-mean coupling distribution of each lower capsule.

    C: couplings indexed (sample, lower capsule, higher capsule).
    Returns one value per lower capsule: 0 for one-hot routing, 1 for uniform
    routing (when each row is a probability distribution). Zero couplings
    contribute 0 (0 * log 0 := 0) instead of producing NaN, and a single
    higher capsule yields 0 instead of dividing by log(1) == 0.
    """
    Cm = C.mean(axis=0)
    n_h = Cm.shape[1]
    if n_h < 2:
        return np.zeros(Cm.shape[0])
    with np.errstate(divide="ignore", invalid="ignore"):
        plogp = np.where(Cm > 0, Cm * np.log(Cm), 0.0)
    return -plogp.sum(axis=1) / np.log(n_h)

In [None]:
for C in CC:
    scores = (mean_max(C), max_std_dev(C), calc_norm_entropy(C).mean())
    print("{:.3f}   {:.3f}  {:.3f}".format(*scores))

In [None]:
# Uniform routing reference: every lower capsule couples 1 / n_h to each higher capsule.
CC_uni = [np.full(C.shape, 1.0 / C.shape[2]) for C in CC]
for C_uni in CC_uni:
    row = (mean_max(C_uni), max_std_dev(C_uni), calc_norm_entropy(C_uni).mean())
    print("{:.3f}   {:.3f}  {:.3f}".format(*row))

In [None]:
# Uniform routing reference (mean max and max std only).
CC_uni = [np.ones(C.shape) / C.shape[2] for C in CC]
for C_uni in CC_uni:
    print("{:.3f}   {:.3f}".format(mean_max(C_uni), max_std_dev(C_uni)))

In [None]:
# Random routing reference: softmax over uniformly drawn logits in [0, 10).
# A fixed seed makes the printed reference numbers reproducible.
rng_rand = np.random.default_rng(0)
CC_rand = []
for C in CC:
    logits = rng_rand.random(C.shape) * 10
    CC_rand.append(torch.softmax(torch.Tensor(logits), dim=-1).numpy())
for C in CC_rand:
    print("{:.3f}   {:.3f}  {:.3f}".format(mean_max(C), max_std_dev(C), calc_norm_entropy(C).mean()))

In [None]:
sf = 3  # size (inches) of each subplot
# squeeze=False keeps a 2-D axes array even when there is only one layer
fig, axes = plt.subplots(3, len(CC), figsize=(sf * len(CC), sf * 3), squeeze=False)
for idx, C in enumerate(CC):
    C_max = C.max(axis=2)            # largest coupling of each lower capsule, per sample
    xx = range(C.shape[1])
    axes[0, idx].bar(xx, C_max.mean(axis=0))
    axes[0, idx].set_title("mean MAX coupling")
    axes[0, idx].set_ylim(0, 1)
    axes[1, idx].bar(xx, C_max.std(axis=0))
    axes[1, idx].set_title("std MAX coupling")
    axes[1, idx].set_ylim(0, 1)
    axes[2, idx].bar(xx, calc_norm_entropy(C))
    axes[2, idx].set_ylim(0, 1)
    axes[2, idx].set_title("entropy coupling")
plt.show()

# Correlation of lower-level capsule activation and maximum coupling coefficient

In [None]:
from scipy.stats import pearsonr

In [None]:
# For each FCCaps layer: does a lower capsule's activation (length) correlate
# with the largest coupling it sends to any higher capsule?
for lvl in range(len(US) - 1):
    UL = US[lvl]                          # lower-level capsules
    C = CC[lvl]                           # couplings lower -> higher

    U_norm = np.linalg.norm(UL, axis=2)   # (sample, lower capsule) activations
    C_max = np.max(C, axis=2)             # (sample, lower capsule) largest coupling

    # correlation pooled over all capsules
    corr, _ = pearsonr(U_norm.flatten(), C_max.flatten())
    print("layer {}: pooled r = {}".format(lvl, corr))

    # correlation per lower capsule (NaN when either series is constant)
    corrs = [pearsonr(U_norm[:, cap], C_max[:, cap])[0] for cap in range(U_norm.shape[1])]
    nans = [int(np.isnan(r)) for r in corrs]
    xx = range(len(corrs))

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    axes[0].bar(xx, corrs)
    axes[0].bar(xx, nans)                 # bar of height 1 marks an undefined correlation
    axes[0].set_xticks(xx)
    axes[0].set_title("pearson r per capsule")

    # capsule activation mean & std
    axes[1].bar(xx, U_norm.mean(axis=0))
    axes[1].set_ylim(0, 1)
    axes[1].set_xticks(xx)
    axes[1].set_title("mean activation")
    axes[2].bar(xx, U_norm.std(axis=0), alpha=0.5)
    axes[2].set_ylim(0, 1)
    axes[2].set_xticks(xx)
    axes[2].set_title("std activation")
    plt.show()


# Correlation of HIGHER-level capsule activation and the maximum coupling it receives from lower-level capsules

In [None]:
# For each FCCaps layer: does a higher capsule's activation (length) correlate
# with the largest coupling it receives from any lower capsule?
for lvl in range(len(US) - 1):
    UH = US[lvl + 1]                      # higher-level capsules
    C = CC[lvl]                           # couplings lower -> higher

    U_norm = np.linalg.norm(UH, axis=2)   # (sample, higher capsule) activations
    C_max = np.max(C, axis=1)             # (sample, higher capsule) largest incoming coupling

    # correlation pooled over all capsules
    corr, _ = pearsonr(U_norm.flatten(), C_max.flatten())
    print("layer {}: pooled r = {}".format(lvl, corr))

    # correlation per higher capsule (NaN when either series is constant)
    corrs = [pearsonr(U_norm[:, cap], C_max[:, cap])[0] for cap in range(U_norm.shape[1])]
    nans = [int(np.isnan(r)) for r in corrs]
    xx = range(len(corrs))

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    axes[0].bar(xx, corrs)
    axes[0].bar(xx, nans)                 # bar of height 1 marks an undefined correlation
    axes[0].set_xticks(xx)
    axes[0].set_title("pearson r per capsule")

    # capsule activation mean & std
    axes[1].bar(xx, U_norm.mean(axis=0))
    axes[1].set_ylim(0, 1)
    axes[1].set_xticks(xx)
    axes[1].set_title("mean activation")
    axes[2].bar(xx, U_norm.std(axis=0), alpha=0.5)
    axes[2].set_ylim(0, 1)
    axes[2].set_xticks(xx)
    axes[2].set_title("std activation")
    plt.show()


# How dynamic is the Parse Tree?

- For lower-level capsules: how much their routing to the upper-level capsules varies across samples
- For higher-level capsules: how much the set of lower-level capsules routing to them varies across samples

In [None]:
def cc_matrix(x1, x2):
    """Cross-correlation matrix ``x1^T x2 / N`` of two (already normalised) torch tensors.

    x1 has N rows (or is 1-D of length N), likewise x2; 1-D inputs are treated
    as single columns. Returns a (D1, D2) tensor.
    """
    x1 = reshape_(x1)
    x2 = reshape_(x2)
    # out-of-place division (same values as the former in-place div_)
    return (x1.T @ x2) / x1.shape[0]


def reshape_(x):
    """Promote a 1-D tensor of length N to an (N, 1) column; other shapes pass through."""
    if len(x.shape) == 1:
        x = x.reshape((-1, 1))
    return x


def cc_bn(x1, x2, debug=False, eps=1e-5):
    """Cross-correlation after per-column standardisation with BatchNorm1d (affine=False).

    Each input is normalised by its own BatchNorm1d, sized to that input's
    column count (previously both used x1's, which fails when D1 != D2).
    """
    x1 = reshape_(x1)
    x2 = reshape_(x2)

    x1 = torch.nn.BatchNorm1d(x1.shape[1], affine=False, eps=eps)(x1)
    x2 = torch.nn.BatchNorm1d(x2.shape[1], affine=False, eps=eps)(x2)
    if debug:
        print("bn(X1)", x1.mean(axis=0), x1.var(axis=0))
        print("bn(X2)", x2.mean(axis=0), x2.var(axis=0))
    return cc_matrix(x1, x2)


def _standardise(x, eps):
    """Center each column; scale by sqrt(var + eps) if eps > 0, else by the column std."""
    x = x - x.mean(dim=0)
    if eps > 0:
        return x / torch.sqrt(x.var(dim=0) + eps)
    # per-column std, consistent with the eps > 0 branch (was a single global std)
    return x / x.std(dim=0)


def cc_norm(x1, x2, debug=False, eps=1e-5):
    """Cross-correlation after centering each column and scaling it to unit variance."""
    x1 = _standardise(reshape_(x1), eps)
    x2 = _standardise(reshape_(x2), eps)
    if debug:
        print("mv(X1)", x1.mean(axis=0), x1.var(axis=0))
        print("mv(X2)", x2.mean(axis=0), x2.var(axis=0))
    return cc_matrix(x1, x2)


def cc_norm2(x1, x2, debug=False, eps=1e-5):
    """Same normalisation as ``cc_norm`` (the two were equivalent); kept for compatibility."""
    return cc_norm(x1, x2, debug=debug, eps=eps)

import scipy.stats as stats

In [None]:
layer_idx = 0
C = CC[layer_idx]                    # couplings of the selected FCCaps layer
U = US[layer_idx]                    # its lower-level (input) capsules
U_norm = np.linalg.norm(U, axis=2)   # capsule activations (lengths)
C_max = np.max(C, axis=2)            # largest outgoing coupling of each lower capsule

In [None]:
# One scatter plot per lower capsule: activation vs. its largest outgoing coupling.
for cap in range(U_norm.shape[1]):
    fig, ax = plt.subplots()
    ax.scatter(U_norm[:, cap], C_max[:, cap])
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_xlabel("U_norm")
    ax.set_ylabel("C_max")
    # title identifies the figure among the many produced by this loop
    ax.set_title("layer {} capsule {}".format(layer_idx, cap))
    plt.show()

In [None]:
# TODO: further analyses of the routing
# - max coupling
# - min coupling
# - mean coupling
# - std coupling
# - to how many capsules is a capsule connected?
# - how strong is the connection?
# - visualize routing together with activation

# CNN Only Baseline

In [None]:
"""
epochs = 050, acc = 98,0, 73.0
epochs = 101, acc = 98.5, 74.1
epochs = 201, acc = 
"""

In [None]:
# CNN-only baseline: the same backbone, this time with its linear classifier head.
model = CustomBB(ch_in=1, n_classes=10).to(device)

# same optimisation setup as for the capsule network
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=2e-5)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.96)

criterion = nn.CrossEntropyLoss()

In [None]:
# Parameter count of the baseline model, as reported by effcn.utils.count_parameters.
print(count_parameters(model))

In [None]:
num_epochs = 101
valid_every = 5  # validate on epochs whose index is a multiple of this


def logit_accuracy(model, loader):
    """Top-1 accuracy of a logit-producing model on ``loader`` (NaN for an empty loader)."""
    model.eval()
    total_correct = 0
    total = 0
    with torch.no_grad():
        for x, y_true in loader:
            x = x.to(device)
            y_true = y_true.to(device)
            y_pred = torch.argmax(model(x), dim=1)
            total_correct += (y_true == y_pred).sum().item()
            total += y_true.shape[0]
    return total_correct / total if total else float("nan")


for epoch_idx in range(num_epochs):
    # ####################
    # TRAIN
    # ####################
    model.train()
    desc = "Train [{:3}/{:3}]:".format(epoch_idx, num_epochs)
    pbar = tqdm(dl_mnist_train, bar_format=desc + '{bar:10}{r_bar}{bar:-10b}')

    for x, y_true in pbar:
        x = x.to(device)
        y_true = y_true.to(device)
        optimizer.zero_grad()

        # call the module (not .forward) so registered nn.Module hooks run
        logits = model(x)
        loss = criterion(logits, y_true)
        loss.backward()
        optimizer.step()

        y_pred = torch.argmax(logits, dim=1)
        acc = (y_true == y_pred).sum() / y_true.shape[0]
        pbar.set_postfix({'loss': loss.item(), 'acc': acc.item()})
    lr_scheduler.step()  # exponential LR decay once per epoch

    # ####################
    # VALID
    # ####################
    if epoch_idx % valid_every != 0:
        continue
    print("   mnist acc_valid: {:.3f}".format(logit_accuracy(model, dl_mnist_valid)))
    print("   affnist acc_valid: {:.3f}".format(logit_accuracy(model, dl_affnist_valid)))