In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append("./../..")

# standard lib
import shutil
from pathlib import Path

# external imports
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as T
import torch.nn.functional as F
from torch import optim
from tqdm import tqdm
import scipy as sp
import pandas as pd
pd.options.display.float_format = '{:,.2f}'.format
import pickle
from torch.utils.data import DataLoader
import torch.nn as nn

# local imports
from effcn.functions import masking
from datasets.csprites import ClassificationDataset
from effcn.layers import Squash
from effcn.functions import margin_loss, max_norm_masking
from misc.utils import count_parameters
from misc.plot_utils import plot_couplings, plot_capsules, plot_mat, plot_mat2
from misc.metrics import *
from misc.utils import normalize_transform, inverse_normalize_transform

In [None]:
# Prefer the first CUDA device when one is visible, otherwise run on the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)

# Data

In [None]:
# Dataset root: csprites rendered on a constant (black) background.
p_data = '/mnt/data/csprites/single_csprites_32x32_n7_c24_a12_p6_s2_bg_1_constant_color_145152'

# Alternative dataset with a structured background:
# p_data = '/mnt/data/csprites/single_csprites_32x32_n7_c24_a12_p6_s2_bg_inf_random_function_145152'

# The dataset ships its own metadata (class names, class counts, channel stats).
# NOTE(review): pickle.load can execute arbitrary code; only point this at trusted files.
p_ds_config = Path(p_data).joinpath("config.pkl")
with p_ds_config.open("rb") as file:
    ds_config = pickle.load(file)

# Pick the label component used as classification target.
target_variable = "shape"
matching = [pos for pos, name in enumerate(ds_config["classes"]) if name == target_variable]
target_idx = matching[0]
n_classes = ds_config["n_classes"][target_variable]

# Normalisation built from the dataset statistics, plus its inverse for plotting.
norm_transform = normalize_transform(ds_config["means"], ds_config["stds"])
inverse_norm_transform = inverse_normalize_transform(ds_config["means"], ds_config["stds"])


def target_transform(x):
    """Keep only the entry of the label that belongs to `target_variable`."""
    return x[target_idx]


transform = T.Compose([T.ToTensor(), norm_transform])

In [None]:
# Loader settings shared by the train and validation splits.
batch_size = 512
num_workers = 4

dataset_kwargs = dict(
    p_data=p_data,
    transform=transform,
    target_transform=target_transform,
)
# Both loaders shuffle, so `next(iter(dl_valid))` yields a random validation batch.
loader_kwargs = dict(
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=False,
)

# TRAIN
ds_train = ClassificationDataset(split="train", **dataset_kwargs)
dl_train = DataLoader(ds_train, **loader_kwargs)

# VALID
ds_valid = ClassificationDataset(split="valid", **dataset_kwargs)
dl_valid = DataLoader(ds_valid, **loader_kwargs)

In [None]:
# Show the first n_vis images of one training batch as a square grid.
n_vis = 64
x, y = next(iter(dl_train))
x, y = x[:n_vis], y[:n_vis]

# Undo the normalisation so the colours are displayable.
x = inverse_norm_transform(x)

grid_img = torchvision.utils.make_grid(x, nrow=int(np.sqrt(n_vis)))
# make_grid returns channels-first; imshow wants channels-last.
plt.imshow(grid_img.permute(1, 2, 0))

# Models

### Decoder

In [None]:
class Decoder(nn.Module):
    """Reconstruct a 3-channel 32x32 image from flattened class capsules.

    Parameters
    ----------
    n_classes : number of output capsules fed to the decoder
    d_out     : dimension of each output capsule

    The input of ``forward`` has ``n_classes * d_out`` features per sample.
    A linear layer maps it to 100 values that are viewed as a 1x10x10 map and
    then grown by the convolutional stack:
    10 -> (x2) 20 -> (3x3 valid) 18 -> (x2) 36 -> (3x3 valid) 34 -> (3x3 valid) 32.
    The final sigmoid keeps the output in [0, 1].
    """

    def __init__(self, n_classes, d_out):
        super().__init__()
        latent_dim = n_classes * d_out
        # Projection to the 10x10 seed map (reshaped in forward()).
        self.layer1 = nn.Sequential(nn.Linear(latent_dim, 10 * 10))

        conv_stack = [
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(1, 32, kernel_size=(3, 3), padding="valid"),
            nn.LeakyReLU(0.3, inplace=True),
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(32, 64, kernel_size=(3, 3), padding="valid"),
            nn.LeakyReLU(0.3, inplace=True),
            nn.Conv2d(64, 3, kernel_size=(3, 3), padding="valid"),
            nn.Sigmoid(),
        ]
        self.layer2 = nn.Sequential(*conv_stack)

    def forward(self, x):
        """Map (b, n_classes * d_out) capsule features to (b, 3, 32, 32) images."""
        seed_map = self.layer1(x).view(-1, 1, 10, 10)
        return self.layer2(seed_map)


In [None]:
# Smoke test: decode one random capsule vector and report output shape and size.
model = Decoder(7, 16)
dummy_caps = torch.rand(1, 7 * 16)
y = model(dummy_caps)
print(y.shape)
print(count_parameters(model))

### Backbone

In [None]:
class CustomBB(nn.Module):
    """Strided-convolution backbone with batch norm.

    Note: later classes with the same name in this cell replace this
    definition once the cell has run.

    For 32x32 inputs the spatial size shrinks 32 -> 15 -> 7 -> 1, so the
    convolutional stack ends in a 512-dimensional feature vector that ``fc``
    maps to ``n_classes`` logits.
    """

    def __init__(self, ch_in=3, n_classes=10):
        super().__init__()
        self.ch_in = ch_in
        self.n_classes = n_classes

        stem = [
            nn.Conv2d(in_channels=ch_in, out_channels=32, kernel_size=3, stride=2, padding=0),
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(),
        ]
        body = [
            nn.Conv2d(in_channels=32, out_channels=256, groups=32, kernel_size=3, stride=2, padding=0),
            nn.BatchNorm2d(num_features=256),
            nn.ReLU(),
        ]
        # 7x7 valid convolution collapses the remaining 7x7 grid to 1x1.
        head = [nn.Conv2d(256, 512, kernel_size=7, groups=32, stride=1, padding="valid")]
        self.convs = nn.Sequential(*stem, *body, *head)
        self.fc = nn.Linear(512, n_classes)

    def forward(self, x):
        """Return ``fc`` applied to the flattened (b, 512) convolutional features."""
        features = torch.flatten(self.convs(x), 1)  # (b, 512, 1, 1) -> (b, 512)
        return self.fc(features)

class CustomBB(nn.Module):
    """Strided-convolution backbone without batch norm and without conv biases.

    Note: later classes with the same name in this cell replace this
    definition once the cell has run.

    Parameters
    ----------
    ch_in     : number of input image channels (used by the first convolution)
    n_classes : number of logits produced by ``fc``

    For 32x32 inputs the spatial size shrinks 32 -> 15 -> 7 -> 1, so the
    convolutional stack ends in a 512-dimensional feature vector.
    """

    def __init__(self, ch_in=3, n_classes=10):
        super().__init__()
        self.ch_in = ch_in
        self.n_classes = n_classes

        self.convs = nn.Sequential(
            # The first conv honours ch_in instead of a hard-coded 3 channels.
            nn.Conv2d(ch_in, 128, kernel_size=3, groups=1, stride=2, padding=0, bias=False),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, groups=32, stride=2, padding=0, bias=False),
            nn.ReLU(),
            # 7x7 valid convolution collapses the remaining 7x7 grid to 1x1.
            nn.Conv2d(256, 512, kernel_size=7, groups=32, stride=1, padding="valid", bias=False),
        )
        self.fc = nn.Linear(512, n_classes)

    def forward(self, x):
        """Return ``fc`` applied to the flattened (b, 512) convolutional features."""
        x = self.convs(x)
        x = torch.flatten(x, 1)  # (b, 512, 1, 1) -> (b, 512)
        x = self.fc(x)
        return x

class CustomBB(nn.Module):
    """Strided-convolution backbone (32-channel stem) without batch norm or conv biases.

    Note: a later class with the same name in this cell replaces this
    definition once the cell has run.

    Parameters
    ----------
    ch_in     : number of input image channels (used by the first convolution)
    n_classes : number of logits produced by ``fc``

    For 32x32 inputs the spatial size shrinks 32 -> 15 -> 7 -> 1, so the
    convolutional stack ends in a 512-dimensional feature vector.
    """

    def __init__(self, ch_in=3, n_classes=10):
        super().__init__()
        self.ch_in = ch_in
        self.n_classes = n_classes

        self.convs = nn.Sequential(
            # The first conv honours ch_in instead of a hard-coded 3 channels.
            nn.Conv2d(ch_in, 32, kernel_size=3, groups=1, stride=2, padding=0, bias=False),
            nn.ReLU(),
            nn.Conv2d(32, 256, kernel_size=3, groups=32, stride=2, padding=0, bias=False),
            nn.ReLU(),
            # 7x7 valid convolution collapses the remaining 7x7 grid to 1x1.
            nn.Conv2d(256, 512, kernel_size=7, groups=32, stride=1, padding="valid", bias=False),
        )
        self.fc = nn.Linear(512, n_classes)

    def forward(self, x):
        """Return ``fc`` applied to the flattened (b, 512) convolutional features."""
        x = self.convs(x)
        x = torch.flatten(x, 1)  # (b, 512, 1, 1) -> (b, 512)
        x = self.fc(x)
        return x

class CustomBB(nn.Module):
    """Conv + max-pool backbone; the last ``CustomBB`` defined in this cell.

    Parameters
    ----------
    ch_in     : number of input image channels (used by the first convolution)
    n_classes : number of logits produced by ``fc``

    For 32x32 inputs the spatial size evolves as
    32 -> pool 16 -> pool 8 -> conv 6 -> pool 3 -> conv 1,
    so the convolutional stack ends in a 512-dimensional feature vector.
    All convolutions are bias-free.
    """

    def __init__(self, ch_in=3, n_classes=10):
        super().__init__()
        self.ch_in = ch_in
        self.n_classes = n_classes

        self.convs = nn.Sequential(
            # The first conv honours ch_in instead of a hard-coded 3 channels.
            nn.Conv2d(ch_in, 32, kernel_size=3, groups=1, stride=1, padding="same", bias=False),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 256, kernel_size=3, groups=32, stride=1, padding="same", bias=False),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(256, 512, kernel_size=3, groups=32, stride=1, padding=0, bias=False),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # Unpadded 3x3 convolution collapses the remaining 3x3 grid to 1x1.
            nn.Conv2d(512, 512, kernel_size=3, groups=32, stride=1, padding=0, bias=False),
        )
        self.fc = nn.Linear(512, n_classes)

    def forward(self, x):
        """Return ``fc`` applied to the flattened (b, 512) convolutional features."""
        x = self.convs(x)
        x = torch.flatten(x, 1)  # (b, 512, 1, 1) -> (b, 512)
        x = self.fc(x)
        return x

In [None]:
# Smoke test: push a random batch of 32x32 RGB images through the backbone.
model = CustomBB(ch_in=3)
dummy_images = torch.rand(128, 3, 32, 32)
y = model(dummy_images)
print(count_parameters(model))
print(y.shape)

In [None]:
class CrossAttention(nn.Module):
    """Routing layer in which learned queries attend over the input capsules.

    Parameters
    ----------
    n_z   : number of learned latent vectors (= number of output capsules)
    d_z   : dimension of each latent vector
    d_x   : dimension of each input capsule
    d_i   : dimension of queries, keys and values (= output capsule dimension)
    scale : divisor applied to the similarity scores before the softmax
    dim   : axis of the score tensor the softmax normalises over
    """

    def __init__(self, n_z, d_z, d_x, d_i, scale=0.01, dim=1):
        super().__init__()
        # Learned latent array, initialised uniformly in [0, 1).
        self.z = torch.nn.Parameter(torch.rand(n_z, d_z), requires_grad=True)
        self.to_q = nn.Linear(d_z, d_i, bias=False)
        # Keys and values share one projection and are split in forward_debug().
        self.to_kv = nn.Linear(d_x, d_i * 2, bias=False)
        self.scale = scale
        self.dim = dim
        # Squash is imported from effcn.layers; presumably the capsule squashing
        # non-linearity -- confirm there.
        self.squash = Squash(eps=1e-20)

    def forward(self, x):
        """Return only the output capsules of :meth:`forward_debug`."""
        return self.forward_debug(x)[0]

    def forward_debug(self, x):
        """Return ``(capsules, couplings)`` for input capsules ``x``.

        For x of shape (b, n_x, d_x) the scores have shape (b, n_z, n_x); with
        the default ``dim=1`` the softmax therefore normalises over the n_z
        axis. The couplings are returned transposed to (b, n_x, n_z); the
        ``permute(0, 2, 1)`` call requires a 3-D score tensor.
        """
        queries = self.to_q(self.z)
        keys, values = torch.chunk(self.to_kv(x), 2, dim=-1)

        # scores[..., i, j] = <query_i, key_j>
        scores = torch.einsum("id, ...jd -> ...ij", queries, keys)
        couplings = (scores / self.scale).softmax(dim=self.dim)

        # (b, n_z, n_x) x (b, n_x, d_i) -> (b, n_z, d_i)
        routed = torch.einsum("...ij, ...jk -> ...ik", couplings, values)
        return self.squash(routed), couplings.permute(0, 2, 1)

In [None]:
# Tiny example: route 4 lower capsules (dim 5) to 3 higher capsules (dim 2).
n_l, n_h = 4, 3
d_l, d_h = 5, 2

model = CrossAttention(n_z=n_h, d_z=d_h, d_x=d_l, d_i=d_h)
x = torch.rand(1, n_l, d_l)
# forward_debug also returns the coupling matrix C.
y, C = model.forward_debug(x)

In [None]:
# Display the coupling matrix returned by forward_debug for the toy example.
C

In [None]:
class FCCaps(nn.Module):
    """Fully connected capsule layer with self-attention style routing.

    Attributes
    ----------
    n_l ... number of lower layer capsules
    d_l ... dimension of lower layer capsules
    n_h ... number of higher layer capsules
    d_h ... dimension of higher layer capsules

    W   (n_l, n_h, d_l, d_h) ... weight tensor
    B   (n_l, n_h)           ... bias tensor (a parameter, but not used in the
                                 forward pass)
    """

    def __init__(self, n_l, n_h, d_l, d_h):
        super().__init__()
        self.n_l, self.d_l = n_l, d_l
        self.n_h, self.d_h = n_h, d_h

        self.W = torch.nn.Parameter(
            torch.rand(n_l, n_h, d_l, d_h), requires_grad=True)
        self.B = torch.nn.Parameter(torch.rand(n_l, n_h), requires_grad=True)
        # Squash is imported from effcn.layers; presumably the capsule squashing
        # non-linearity -- confirm there.
        self.squash = Squash(eps=1e-20)

        # Custom weight init. The original author was unsure whether this
        # scheme makes sense here, but followed the paper.
        for param in (self.W, self.B):
            torch.nn.init.kaiming_normal_(
                param, a=0, mode='fan_in', nonlinearity='leaky_relu')

        # Kept as an attribute; the division by it is disabled in forward_debug().
        self.attention_scaling = np.sqrt(self.d_l)
        # Divisor applied to the summed agreements before the softmax.
        self.scaling = 0.01

    def forward(self, U_l):
        """Return only the higher layer capsules of :meth:`forward_debug`."""
        return self.forward_debug(U_l)[0]

    def forward_debug(self, U_l):
        """Same as forward() but also returns the couplings to analyze routing.

        einsum conventions:
          n_l = i | h
          d_l = j
          n_h = k
          d_h = l

        Data tensors:
            IN:  U_l ... lower layer capsules
            OUT: U_h ... higher layer capsules
            DIMS:
                U_l (n_l, d_l)
                U_h (n_h, d_h)
                W   (n_l, n_h, d_l, d_h)
                B   (n_l, n_h)
                A   (n_l, n_l, n_h)
                C   (n_l, n_h)
        """
        # Votes of every lower capsule for every higher capsule.
        votes = torch.einsum('...ij,ikjl->...ikl', U_l, self.W)
        # Pairwise agreement between the votes of all lower capsules.
        # The division by attention_scaling was removed to get stronger couplings.
        agreement = torch.einsum("...ikl, ...hkl -> ...hik", votes, votes)

        pooled = torch.einsum("...hij->...hj", agreement)
        couplings = torch.softmax(pooled / self.scaling, dim=-1)

        # The bias term B is intentionally not added to the couplings.
        U_h = torch.einsum('...ikl,...ik->...kl', votes, couplings)
        return self.squash(U_h), couplings

class DeepCapsNet(nn.Module):
    """Capsule network: CNN backbone, stacked CrossAttention routing, decoder.

    Parameters
    ----------
    ns : capsule counts per layer; ns[0] are the primary capsules, ns[-1] the
         output (class) capsules
    ds : capsule dimensions per layer, same length as ``ns``

    The backbone's ``fc`` is replaced by an identity, so its flattened features
    are viewed as (ns[0], ds[0]) primary capsules; ns[0] * ds[0] has to match
    the backbone feature width for that view to keep the batch size.
    """

    def __init__(self, ns, ds):
        super().__init__()
        self.ns = ns
        self.ds = ds

        self.backbone = CustomBB(ch_in=3)
        self.backbone.fc = nn.Identity()
        self.decoder = Decoder(ns[-1], ds[-1])
        self.squash = Squash(eps=1e-20)

        # One routing layer per consecutive pair of capsule layers; every layer
        # uses the same softmax axis and score divisor.
        routing = [
            CrossAttention(n_z=ns[i], d_z=ds[i], d_x=ds[i - 1], d_i=ds[i],
                           dim=1, scale=0.001)
            for i in range(1, len(ns))
        ]
        self.layers = nn.Sequential(*routing)

    def _primary_caps(self, x):
        """Backbone features reshaped to squashed (b, ns[0], ds[0]) capsules."""
        features = self.backbone(x)
        return self.squash(features.view(-1, self.ns[0], self.ds[0]))

    def forward(self, x, y_true=None):
        """Return ``(output capsules, reconstruction)``.

        ``y_true`` is handed to ``masking`` (from effcn.functions); presumably
        the capsule of the given class is kept, or the strongest one when it is
        None -- confirm in effcn.functions.
        """
        caps = self._primary_caps(x)
        for layer in self.layers:
            caps = layer(caps)

        x_rec = self.decoder(masking(caps, y_true))
        return caps, x_rec

    def forward_debug(self, x, y_true=None):
        """Like forward() but also returns per-layer couplings and capsules.

        Returns ``(output capsules, cc, us, reconstruction)`` where ``cc`` holds
        the detached couplings of every routing layer and ``us`` the capsules of
        every layer, starting with the primary capsules.
        """
        caps = self._primary_caps(x)

        us = [caps.clone()]
        cc = []
        for layer in self.layers:
            caps, coupling = layer.forward_debug(caps)
            cc.append(coupling.detach())
            us.append(caps.clone().detach())

        x_rec = self.decoder(masking(caps, y_true))
        return caps, cc, us, x_rec

In [None]:
# Capsule layout: ns = capsules per layer, ds = capsule dimension per layer.
# Previously tried layouts:
#   ns = [32, 26, 20, 14, n_classes], ds = [16, 16, 16, 16, 16]
#   ns = [32, 32, 32, n_classes],     ds = [8, 8, 8, 8]
ns = [32, n_classes]
ds = [16, 16]

model = DeepCapsNet(ns=ns, ds=ds)

# Parameter counts of the whole model and of its parts.
print("tot Model ", count_parameters(model))
for label, part in (("Backbone  ", model.backbone),
                    ("Decoder  ", model.decoder),
                    ("Routing  ", model.layers)):
    print(label, count_parameters(part))

model = model.to(device)

In [None]:
# Adam with light weight decay; the learning rate decays by 2% per scheduler step.
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)

# Reconstruction loss between de-normalised input and decoder output.
func_rec_loss = nn.MSELoss()

In [None]:
# Fixed validation samples used to visualise reconstructions during training.
n_vis = 8
x_vis, y_vis = (part[:n_vis].to(device) for part in next(iter(dl_valid)))

In [None]:
num_epochs = 51

# Debug switch: when True every step trains on the same 10 samples, which are
# also used for the reconstruction plots.
do_overfit = False
if do_overfit:
    x_of, y_of = next(iter(dl_train))
    x_of = x_of[:10]
    y_of = y_of[:10]
    n_vis = 10
    x_vis = x_of.to(device)
    y_vis = y_of.to(device)

for epoch_idx in range(num_epochs):
    # ####################
    # TRAIN
    # ####################
    model.train()
    desc = "Train [{:3}/{:3}]:".format(epoch_idx, num_epochs)
    pbar = tqdm(dl_train, bar_format=desc + '{bar:10}{r_bar}{bar:-10b}')

    for x, y_true in pbar:
        if do_overfit:
            x = x_of
            y_true = y_of
        x = x.to(device)
        y_true = y_true.to(device)

        optimizer.zero_grad()

        # Call the module (not .forward) so registered hooks are honoured.
        u_h, x_rec = model(x, y_true)

        # LOSS: margin loss on the capsule norms plus weighted reconstruction loss.
        y_one_hot = F.one_hot(y_true, num_classes=n_classes)
        loss_mar = margin_loss(u_h, y_one_hot)
        # The decoder ends in a sigmoid, so it is compared to the de-normalised input.
        loss_rec = func_rec_loss(inverse_norm_transform(x), x_rec) * 10
        loss = loss_rec + loss_mar

        loss.backward()
        optimizer.step()

        # Prediction = output capsule with the largest norm.
        y_pred = torch.argmax(torch.norm(u_h, dim=2), dim=1)
        acc = (y_true == y_pred).sum() / y_true.shape[0]

        pbar.set_postfix(
            {'loss': loss.item(),
             'rec': loss_rec.item(),
             'mar': loss_mar.item(),
             'acc': acc.item()}
        )
    lr_scheduler.step()

    # ####################
    # VALID (every 5th epoch only)
    # ####################
    if epoch_idx % 5 != 0:
        continue

    model.eval()

    total_correct = 0
    total = 0
    with torch.no_grad():
        for x, y_true in dl_valid:
            x = x.to(device)
            y_true = y_true.to(device)

            # No labels are passed at validation time.
            u_h, x_rec = model(x)

            y_pred = torch.argmax(torch.norm(u_h, dim=2), dim=1)
            total_correct += (y_true == y_pred).sum().item()
            total += y_true.shape[0]
    # The data is csprites, so the former "mnist" label was dropped.
    print("   acc_valid: {:.3f}".format(total_correct / total))

    # Reconstructions of the fixed samples, without and with label masking.
    with torch.no_grad():
        _, x_rec_vis = model(x_vis.to(device))
        _, x_rec_vis_mask = model(x_vis.to(device), y_vis.to(device))

    # Rows: input, reconstruction without labels, reconstruction with labels.
    grid_img = torch.cat([inverse_norm_transform(x_vis), x_rec_vis, x_rec_vis_mask], dim=0)
    grid_img = torchvision.utils.make_grid(grid_img, nrow=n_vis)
    plt.imshow(grid_img.permute(1, 2, 0).cpu())
    plt.show()

### results
epoch = 50 0.994 0.883, groups=256
epoch = 50 0.992 0.923, groups=32


# Visualize and Analyze

### Show parse tree and activations for individual samples

In [None]:
# Run one validation batch through the debug forward pass and move all
# results to numpy for plotting.
x, y = next(iter(dl_valid))

model.eval()
with torch.no_grad():
    u_h, CC, US, x_rec = model.forward_debug(x.to(device))

# Prediction = output capsule with the largest norm.
capsule_norms = torch.norm(u_h, dim=2)
y_pred = torch.argmax(capsule_norms, dim=1).detach().cpu().numpy()

US = [u.cpu().numpy() for u in US]
CS = [c.cpu().numpy() for c in CC]

Y_true = y.cpu().numpy()
Y_pred = y_pred

In [None]:
# Plot couplings and capsule norms for the first vis_max samples of the batch.
# Set vis_class to a class index to show only samples of that class.
vis_class = None
vis_max = 4
for idx in range(vis_max):
    skip_sample = vis_class is not None and Y_true[idx] != vis_class
    if skip_sample:
        continue

    # Per-sample couplings and capsules of every layer.
    cs = [c[idx] for c in CS]
    us = [u[idx] for u in US]
    u_norms = [np.linalg.norm(u, axis=1) for u in us]

    # Left: couplings, right: capsule activations.
    fig, (ax_couplings, ax_capsules) = plt.subplots(1, 2, figsize=(8, 4))
    title = "exp={} a={}".format(y[idx], y_pred[idx])
    plot_couplings(cs, title=title, ax=ax_couplings, show=False)
    plot_capsules(u_norms, title=title, ax=ax_capsules, show=False)
    plt.show()

# Statistics For Further Evaluation and Visualization

In [None]:
# Collect labels, couplings and capsules of every layer over the whole
# validation set.
model.eval()

YY = []
CC = [[] for _ in range(len(ns) - 1)]   # one list per routing layer
US = [[] for _ in range(len(ns))]       # one list per capsule layer

for x, y_true in dl_valid:
    x = x.to(device)

    with torch.no_grad():
        _, cc, us, _ = model.forward_debug(x)
        for idx, coupling in enumerate(cc):
            CC[idx].append(coupling.detach().cpu().numpy())
        for idx, capsules in enumerate(us):
            US[idx].append(capsules.detach().cpu().numpy())
        YY.append(y_true.numpy())

# Stack the per-batch arrays along the sample axis.
YY = np.concatenate(YY)                  # dataset labels
CC = [np.concatenate(c) for c in CC]     # coupling matrices per layer
US = [np.concatenate(u) for u in US]     # capsules per layer

### Mean parse tree and mean activation for dataset

In [None]:
fig, axes = plt.subplots(1, 5, figsize=(20, 4))

# Mean and std of the couplings over all samples.
cc_mean = [np.mean(c, axis=0) for c in CC]
cc_std = [np.std(c, axis=0) for c in CC]
plot_couplings(cc_mean, ax=axes[0], show=False, title="mean couplings")
plot_couplings(cc_std, ax=axes[1], show=False, title="std couplings")

# Mean, std and max of the capsule activations (capsule norms) over all samples.
activations = [np.linalg.norm(u, axis=-1) for u in US]
us_mean = [a.mean(axis=0) for a in activations]
us_std = [a.std(axis=0) for a in activations]
us_max = [a.max(axis=0) for a in activations]
for ax, stat, title in ((axes[2], us_mean, "mean activation"),
                        (axes[3], us_std, "std activation"),
                        (axes[4], us_max, "max activation")):
    plot_capsules(stat, scale_factor=1, ax=ax, show=False, title=title)
plt.suptitle("dataset")
plt.show()

### Parse tree from normalized Couplings

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(4 * len(CC), 4))

# normalize_couplings returns a pair per layer; the second item is passed on
# to ma_couplings / stda_couplings (all three come from misc.metrics).
CNS = [normalize_couplings(C) for C in CC]

CNS_MAN = [ma_couplings(C, pr) for C, pr in CNS]
CNS_MAX = [C.max(axis=0) for C, pr in CNS]
CNS_STD = [stda_couplings(C, pr) for C, pr in CNS]

for ax, stat, title in zip(axes, (CNS_MAN, CNS_STD, CNS_MAX), ("mean", "std", "max")):
    plot_couplings(stat, ax=ax, show=False, title=title)
plt.show()

### Classwise mean parse tree and mean activation

In [None]:
# Per class: statistics of the normalised couplings and of the capsule activations.
for cls in range(n_classes):
    idcs = np.where(YY == cls)[0]

    fig, axes = plt.subplots(1, 6, figsize=(24, 4))

    # Couplings restricted to the samples of this class.
    cc = [C[idcs] for C in CC]
    CNS = [normalize_couplings(C, eps_rate=0.5) for C in cc]

    CNS_MAN = [ma_couplings(C, pr) for C, pr in CNS]
    CNS_MAX = [C.max(axis=0) for C, pr in CNS]
    CNS_STD = [stda_couplings(C, pr) for C, pr in CNS]

    for ax, stat, title in zip(axes[:3], (CNS_MAN, CNS_STD, CNS_MAX), ("mean", "std", "max")):
        plot_couplings(stat, ax=ax, show=False, title=title)

    # Capsule activations (norms) restricted to the samples of this class.
    us = [u[idcs] for u in US]
    activations = [np.linalg.norm(u, axis=-1) for u in us]
    us_mean = [a.mean(axis=0) for a in activations]
    us_std = [a.std(axis=0) for a in activations]
    us_max = [a.max(axis=0) for a in activations]

    for ax, stat, title in zip(axes[3:],
                               (us_mean, us_std, us_max),
                               ("mean activation", "std activation", "max activation")):
        plot_capsules(stat, scale_factor=1, ax=ax, show=False, title=title)
    plt.suptitle("class {}".format(cls))
    plt.show()

# Capsules: Dead vs Alive

In [None]:
# A capsule counts as dead when both the mean and the std of its norm over
# the dataset stay below these thresholds.
c_th_mu = 1e-2
c_th_sd = 1e-2

# squeeze=False keeps `axes` two-dimensional even for a single row.
fig, axes = plt.subplots(len(US), 2, figsize=(6, 3 * len(US)), squeeze=False)

# Per layer: array with 1 for alive and 0 for dead capsules.
US_alive = []
for idx in range(len(US)):
    U = US[idx]
    U_norm = np.linalg.norm(U, axis=2)
    U_norm_mu = U_norm.mean(axis=0)
    U_norm_sd = U_norm.std(axis=0)

    # Use the configured thresholds instead of repeating the literal 1e-2.
    U_dead = (U_norm_sd < c_th_sd) * (U_norm_mu < c_th_mu)

    # Left: mean norm, right: std of the norm; dead capsules are shaded red.
    xx = range(len(U_norm_mu))
    axes[idx][0].set_title("mu(norm(U))")
    axes[idx][0].bar(xx, U_dead, color="red", alpha=0.1)
    axes[idx][0].bar(xx, U_norm_mu)
    axes[idx][0].set_ylim(0, 1)
    axes[idx][1].set_title("sd(norm(U))")
    axes[idx][1].bar(xx, U_norm_sd)
    axes[idx][1].bar(xx, U_dead, color="red", alpha=0.1)
    axes[idx][1].set_ylim(0, 1)

    U_alive = 1 - U_dead
    US_alive.append(U_alive)
plt.show()

# Metrics

### Vibrance

In [None]:
# Mean of rate_dead_capsules_norm (from misc.metrics) for every capsule layer.
for U in US:
    pr = rate_dead_capsules_norm(U)
    dead_rate = pr.mean()
    print("#Permanently Dead: {:.3f}".format(dead_rate))

In [None]:
# Mean of rate_inactive_capsules (from misc.metrics) for every routing layer.
for C in CC:
    pr = rate_inactive_capsules(C)
    inactive_rate = pr.mean()
    print("Rate inactive capsules {:.3f}".format(inactive_rate))

In [None]:
# sanity check: vibrance of each routing layer, computed from the couplings
# and the capsules that feed that layer
for idx, C in enumerate(CC):
    U = US[idx]
    rnd, rac, racnd = get_vibrance(U, C)
    report = "rate alive: {:.3f} rate active {:.3f} rate active of alive {:.3f}"
    print(report.format(rnd, rac, racnd))

### Bonding

In [None]:
# Bonding strength of every routing layer.
print_str = "bonding strength: {:.3f}"
for idx, C in enumerate(CC):
    b = get_bonding(C)
    print(print_str.format(b))

### Dynamics

In [None]:
# Dynamics of every routing layer.
for idx, C in enumerate(CC):
    dyc = get_dynamics(C)
    print("dynamics: {:.3f}".format(dyc))

### Correlation Capsule Activation and Max Coupling

In [None]:
# Correlation between the couplings of a routing layer and the capsules
# feeding it (US has one more entry than CC, so zip stops at len(CC)).
for idx, (C, U) in enumerate(zip(CC, US)):
    corr = activation_coupling_corr(C, U)
    print("corr: {:.3f}".format(corr))

### metrics for whole dataset

In [None]:
# One row of metrics per routing layer, computed over the whole dataset.
vals = []
for idx, (C, U) in enumerate(zip(CC, US)):
    rnd, rac, racnd = get_vibrance(U, C)
    b = get_bonding(C)
    dyc = get_dynamics(C)
    cor = activation_coupling_corr(C, U)
    row = (idx, rnd, rac, racnd, b, dyc, cor)
    vals.append(row)

In [None]:
# Column names in the order of the tuples collected in `vals`.
cols = ["layer"]
cols += ["alive rate", "active rate", "active of alive rate"]
cols += ["bonding str.", "dynamics", "cor"]
df = pd.DataFrame(data=vals, columns=cols)
df

### metrics for whole dataset, but classwise

In [None]:

# One row of metrics per (class, routing layer).
vals = []

# Iterate over the classes of the dataset rather than a hard-coded 10;
# class indices without samples would otherwise yield empty selections.
for cls in range(n_classes):
    idcs = np.where(YY == cls)[0]
    for idx in range(len(CC)):
        # Couplings and feeding capsules restricted to samples of this class.
        C = CC[idx][idcs]
        U = US[idx][idcs]

        rnd, rac, racnd = get_vibrance(U, C)
        b = get_bonding(C)
        dyc = get_dynamics(C)
        cor = activation_coupling_corr(C, U)
        vals.append((cls, idx,
                     rnd, rac, racnd,
                     b, dyc, cor))

In [None]:
# Column names in the order of the tuples collected in `vals`.
cols = ["class", "layer"]
cols += ["alive rate", "active rate", "active of alive rate"]
cols += ["bonding str.", "dynamics", "cor"]
df = pd.DataFrame(data=vals, columns=cols)
#

In [None]:
# Print one table per routing layer; the constant layer column is dropped.
for idx, _ in enumerate(CC):
    layer_rows = df["layer"] == idx
    sdf = df[layer_rows].drop(columns=["layer"])
    print(sdf)

# Couplings Visualizations

#### Couplings FROM DEAD Capsules

In [None]:
# Mean / std / max couplings that start at dead lower-layer capsules.
for idx, C in enumerate(CC):
    Ul_alive = US_alive[idx]
    dead_lower = np.where(Ul_alive == False)[0]
    C = C[:, dead_lower, :]

    if C.size < 1:
        print("No dead capsules for layer {}".format(idx))
        continue

    C_mu, C_sd, C_mx = C.mean(axis=0), C.std(axis=0), C.max(axis=0)
    fig, axes = plt.subplots(1, 3, figsize=(33, 11))
    for ax, mat in zip(axes, (C_mu, C_sd, C_mx)):
        plot_mat2(mat, ax=ax, vmin=0, vmax=0.5)
plt.show()

#### Couplings FROM Alive Capsules

In [None]:
# Mean / std / max couplings that start at alive lower-layer capsules.
for idx, C in enumerate(CC):
    Ul_alive = US_alive[idx]
    alive_lower = np.where(Ul_alive == True)[0]
    C = C[:, alive_lower, :]

    C_mu, C_sd, C_mx = C.mean(axis=0), C.std(axis=0), C.max(axis=0)
    fig, axes = plt.subplots(1, 3, figsize=(42, 14))
    for ax, mat in zip(axes, (C_mu, C_sd, C_mx)):
        plot_mat2(mat, ax=ax, vmin=0, vmax=0.5)
plt.show()

TODO: per sample, record the maximum coupling and use it to find out whether the couplings get lower in general, or whether only their average drops because the capsules are loosely connected.


### Couplings FROM ALIVE to DEAD

In [None]:
# Mean / std / max couplings from alive lower capsules to dead upper capsules.
for idx, C in enumerate(CC):
    Ul_alive = US_alive[idx]
    Uh_alive = US_alive[idx + 1]

    n_dead_upper = (1 - Uh_alive).sum()
    if n_dead_upper < 1:
        print("{} No dead capsules for upper layer {}".format(idx, idx + 1))
        continue

    # Select alive rows first, then dead columns.
    alive_lower = np.where(Ul_alive == True)[0]
    dead_upper = np.where(Uh_alive == False)[0]
    C = C[:, alive_lower, :]
    C = C[:, :, dead_upper]

    if C.size < 1:
        print("No dead capsules for layer {}".format(idx))
        continue

    C_mu, C_sd, C_mx = C.mean(axis=0), C.std(axis=0), C.max(axis=0)
    fig, axes = plt.subplots(1, 3, figsize=(33, 11))
    for ax, mat in zip(axes, (C_mu, C_sd, C_mx)):
        plot_mat2(mat, ax=ax, vmin=0, vmax=0.5)
plt.show()

### Couplings FROM ALIVE TO ALIVE

In [None]:
# Mean / std / max couplings from alive lower capsules to alive upper capsules.
for idx in range(len(CC)):
    C = CC[idx]
    Ul_alive = US_alive[idx]
    Uh_alive = US_alive[idx + 1]

    # This branch is taken when the upper layer has no ALIVE capsule, so the
    # message says so (it used to claim there were no dead capsules).
    if (Uh_alive).sum() < 1:
        print("No alive capsules for upper layer {}".format(idx + 1))
        continue

    # Select alive rows first, then alive columns.
    C = C[:, np.where(Ul_alive == True)[0], :][:, :, np.where(Uh_alive == True)[0]]

    # An empty selection here means there is no alive-to-alive coupling.
    if len(C.flatten()) < 1:
        print("No alive-to-alive couplings for layer {}".format(idx))
        continue

    C_mu = C.mean(axis=0)
    C_sd = C.std(axis=0)
    C_mx = C.max(axis=0)
    fig, axes = plt.subplots(1, 3, figsize=(33, 11))
    plot_mat2(C_mu, ax=axes[0], vmin=0, vmax=0.5)
    plot_mat2(C_sd, ax=axes[1], vmin=0, vmax=0.5)
    plot_mat2(C_mx, ax=axes[2], vmin=0, vmax=0.5)
plt.show()

# CNN Only Baseline

In [None]:
"""
epoch acc
010   96.0 85.5
020   99.6 87.0
050   1.00 88.4
100   1.09 88.6
"""

In [None]:
# CNN-only baseline: the backbone with its linear classification head.
model = CustomBB(ch_in=3, n_classes=n_classes).to(device)

# Note: this rebinds `model`, `optimizer` and `lr_scheduler` from the capsule run.
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=2e-5)
lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.96)

criterion = nn.CrossEntropyLoss()

In [None]:
# Value reported by count_parameters (from misc.utils) for the baseline model.
print(count_parameters(model))

In [None]:
# The original first line (", acc = num_epochs = 51") was a syntax error.
num_epochs = 51

for epoch_idx in range(num_epochs):
    # ####################
    # TRAIN
    # ####################
    model.train()
    desc = "Train [{:3}/{:3}]:".format(epoch_idx, num_epochs)
    pbar = tqdm(dl_train, bar_format=desc + '{bar:10}{r_bar}{bar:-10b}')

    for x, y_true in pbar:
        x = x.to(device)
        y_true = y_true.to(device)
        optimizer.zero_grad()

        # Call the module (not .forward) so registered hooks are honoured.
        logits = model(x)
        loss = criterion(logits, y_true)

        loss.backward()
        optimizer.step()

        y_pred = torch.argmax(logits, dim=1)
        acc = (y_true == y_pred).sum() / y_true.shape[0]

        pbar.set_postfix(
            {'loss': loss.item(),
             'acc': acc.item()}
        )
    lr_scheduler.step()

    # ####################
    # VALID (every 5th epoch only)
    # ####################
    if epoch_idx % 5 != 0:
        continue

    model.eval()

    total_correct = 0
    total = 0
    with torch.no_grad():
        for x, y_true in dl_valid:
            x = x.to(device)
            y_true = y_true.to(device)

            logits = model(x)

            y_pred = torch.argmax(logits, dim=1)
            total_correct += (y_true == y_pred).sum().item()
            total += y_true.shape[0]
    # The data is csprites, so the former "mnist" label was dropped.
    print("   acc_valid: {:.3f}".format(total_correct / total))