In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append("./..")

# standard lib
import shutil
from pathlib import Path
import pickle

# external imports
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as T
import torch.nn.functional as F
from torch import optim
from torch.utils.tensorboard import SummaryWriter
from dotted_dict import DottedDict
import pprint
from tqdm import tqdm
import scipy as sp
import pandas as pd
#
from torchvision.datasets import CIFAR10

# local imports
from datasets import AffNIST
from effcn.layers import FCCaps, Squash
from effcn.functions import margin_loss, max_norm_masking
from effcn.utils import count_parameters
from misc.optimizer import get_optimizer, get_scheduler
from misc.utils import get_sting_timestamp, mkdir_directories
from misc.plot_utils import plot_couplings, plot_capsules, plot_mat, plot_mat2
from misc.metrics import dynamics, mean_max, adjusted_dynamics, adjusted_mean_max
from misc.metrics2 import *

In [None]:
pd.options.display.float_format = '{:,.2f}'.format

In [None]:
# Select the compute device.
# The original hard-coded "cuda:1", which fails on machines that expose only
# one GPU. Keep preferring GPU 1 when it exists, otherwise use GPU 0, then CPU.
if torch.cuda.is_available():
    dev = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
else:
    dev = "cpu"
device = torch.device(dev)

# Data

In [None]:
# Training-time augmentation: random rotation, shear and rescaling.
transform_train = T.Compose([
    T.RandomAffine(degrees=(-8, 8),
                   shear=(-15, 15),
                   scale=(0.9, 1.1)
                  )
])
# No transform is applied to the validation splits.
# NOTE(review): the earlier comment here claimed a [0,255] -> [0,1] conversion, but None is
# passed -- presumably AffNIST does that scaling internally; confirm in datasets.AffNIST.
transform_valid = None

# NOTE(review): absolute, machine-specific path -- adjust before running elsewhere.
p_data = '/home/matthias/projects/EfficientCN/data'

# Three splits of the same dataset class: augmented MNIST for training, plain MNIST and
# affNIST for validation.
ds_mnist_train = AffNIST(p_root=p_data, split="mnist_train", download=True, transform=transform_train, target_transform=None)
ds_mnist_valid = AffNIST(p_root=p_data, split="mnist_valid", download=True, transform=transform_valid, target_transform=None)
ds_affnist_valid = AffNIST(p_root=p_data, split="affnist_valid", download=True, transform=transform_valid, target_transform=None)

In [None]:
bs = 512
# All three loaders share the same settings; only the dataset differs.
# Validation loaders are shuffled too, so next(iter(...)) yields a random batch.
_loader_kwargs = dict(batch_size=bs, shuffle=True, pin_memory=True, num_workers=4)
dl_mnist_train = torch.utils.data.DataLoader(ds_mnist_train, **_loader_kwargs)
dl_mnist_valid = torch.utils.data.DataLoader(ds_mnist_valid, **_loader_kwargs)
dl_affnist_valid = torch.utils.data.DataLoader(ds_affnist_valid, **_loader_kwargs)

In [None]:
# Draw one batch from each loader and keep its first 32 images for visual inspection.
x, _ = next(iter(dl_mnist_train))
x_vis_train = x[:32]

x, _ = next(iter(dl_mnist_valid))
x_vis_mnist_valid = x[:32]

x, _ = next(iter(dl_affnist_valid))
x_vis_affnist_valid = x[:32]

In [None]:
# One image grid per split: train, MNIST validation, affNIST validation.
for x_vis in (x_vis_train, x_vis_mnist_valid, x_vis_affnist_valid):
    grid = torchvision.utils.make_grid(x_vis)
    # make_grid returns channels-first; imshow expects channels-last
    plt.imshow(grid.permute(1, 2, 0))
    plt.show()

# Backbone

In [None]:
class CustomBB(nn.Module):
    """
    Small convolutional backbone followed by a linear classification head.

    Two strided 3x3 conv + BatchNorm + ReLU stages (128, then 256 channels) are
    followed by a 9x9 convolution with groups=256 (one filter per channel) and
    no padding. The result is flattened and fed to `fc`.

    Parameters
    ----------
    ch_in     ... number of input image channels
    n_classes ... output size of the linear head
    """

    def __init__(self, ch_in=3, n_classes=10):
        super().__init__()
        self.ch_in = ch_in
        self.n_classes = n_classes

        stages = [
            nn.Conv2d(ch_in, 128, kernel_size=3, stride=2, padding=0),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=0),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=9, groups=256, stride=1, padding="valid"),
        ]
        self.convs = nn.Sequential(*stages)
        # the head expects exactly 256 flattened features
        self.fc = nn.Linear(256, n_classes)

    def forward(self, x):
        """Return fc(flatten(convs(x))); the flattened features must have size 256 per sample."""
        feats = self.convs(x)
        # for 40x40 inputs the spatial grid is 1x1 here, so this yields (b, 256)
        feats = torch.flatten(feats, 1)
        return self.fc(feats)

In [None]:
# Smoke test: push a random batch of 40x40 single-channel images through the backbone
# and display the output shape.
model = CustomBB(ch_in=1)
y = model(torch.rand(128, 1, 40, 40))
y.shape

In [None]:
count_parameters(model)

In [None]:
class FCCaps(nn.Module):
    """
        Fully connected capsule layer with a single-pass, attention-style routing.

        NOTE: this notebook-local class shadows the FCCaps imported from
        effcn.layers at the top of the notebook.

        Attributes
        ----------
        n_l ... number of lower layer capsules
        d_l ... dimension of lower layer capsules
        n_h ... number of higher layer capsules
        d_h ... dimension of higher layer capsules

        W   (n_l, n_h, d_l, d_h) ... weight tensor
        dp  ... dropout (probability `do`) applied to the routing logits
    """

    def __init__(self, n_l, n_h, d_l, d_h, do=0.0):
        super().__init__()
        self.n_l = n_l
        self.d_l = d_l
        self.n_h = n_h
        self.d_h = d_h

        self.W = torch.nn.Parameter(torch.rand(
            n_l, n_h, d_l, d_h), requires_grad=True)
        self.squash = Squash(eps=1e-20)

        # Custom weight init; this overwrites the uniform values drawn above.
        # The original author was unsure whether this scheme fits this layer,
        # but followed the paper.
        torch.nn.init.kaiming_normal_(
            self.W, a=0, mode='fan_in', nonlinearity='leaky_relu')

        # kept for the (currently disabled) scaling of the attention logits
        self.attention_scaling = np.sqrt(self.d_l)

        self.dp = nn.Dropout(p=do)

    def forward(self, U_l):
        """
        Route lower capsules U_l (..., n_l, d_l) to upper capsules (..., n_h, d_h).

        Delegates to `forward_debug` so that both entry points always run the
        same computation, and drops the coupling coefficients.
        """
        U_h, _ = self.forward_debug(U_l)
        return U_h

    def forward_debug(self, U_l):
        """
        Same computation as `forward`, additionally returning the couplings.

        einsum conventions:
          n_l = i | h
          d_l = j
          n_h = k
          d_h = l

        Data tensors:
            IN:  U_l ... lower layer capsules
            OUT: U_h ... higher layer capsules (squashed), C ... couplings
            DIMS (leading batch dims omitted):
                U_l (n_l, d_l)
                U_h (n_h, d_h)
                W   (n_l, n_h, d_l, d_h)
                A   (n_l, n_l, n_h)
                C   (n_l, n_h)

        Dropout is applied to the routing logits here as well. Previously only
        `forward` did so, which made the two methods disagree in train mode;
        in eval mode dropout is the identity and nothing changes.
        """
        # votes of every lower capsule i for every upper capsule k
        U_hat = torch.einsum('...ij,ikjl->...ikl', U_l, self.W)
        # pairwise agreement (dot product) between the votes of lower capsules h and i
        A = torch.einsum("...ikl, ...hkl -> ...hik", U_hat, U_hat)
        # weight each agreement by the norm of the lower capsule on the middle axis
        U_norm = torch.norm(U_l, dim=-1)
        A = torch.einsum("...ijk,...j->...ijk", A, U_norm)
        # sum over the middle lower-capsule axis -> routing logits (n_l, n_h)
        A_sum = torch.einsum("...hij->...hj", A)
        A_sum = self.dp(A_sum)
        # couplings: softmax over the upper capsules
        C = torch.softmax(A_sum, dim=-1)
        # coupling-weighted sum of the votes
        U_h = torch.einsum('...ikl,...ik->...kl', U_hat, C)
        return self.squash(U_h), C

class DeepCapsNet(nn.Module):
    """
    CustomBB backbone (classification head removed) followed by a stack of FCCaps layers.

    Parameters
    ----------
    ns ... number of capsules per layer; ns[0] are the primary capsules
    ds ... capsule dimension per layer, same length as ns

    The backbone features are reshaped to (-1, ns[0], ds[0]), so ns[0] * ds[0]
    must equal the backbone feature size.
    """

    def __init__(self, ns, ds):
        super().__init__()
        self.ns = ns
        self.ds = ds

        self.backbone = CustomBB(ch_in=1)
        # drop the linear head: the capsule layers consume the raw features
        self.backbone.fc = nn.Identity()

        self.squash = Squash(eps=1e-20)
        layers = []
        for idx in range(1, len(ns), 1):
            # only the first capsule layer uses dropout on its routing logits
            do = 0.4 if idx == 1 else 0
            layers.append(FCCaps(ns[idx - 1], ns[idx], ds[idx - 1], ds[idx], do=do))
        self.layers = nn.Sequential(*layers)

    def _primary_caps(self, x):
        """Backbone features reshaped to (-1, ns[0], ds[0]) and squashed."""
        x = self.backbone(x)
        return self.squash(x.view(-1, self.ns[0], self.ds[0]))

    def forward(self, x):
        """Return the capsules of the last layer."""
        x = self._primary_caps(x)
        for layer in self.layers:
            x = layer(x)
        return x

    def forward_debug(self, x):
        """
        Like `forward`, but also returns per-layer diagnostics.

        Returns
        -------
        x  ... capsules of the last layer (still attached to the graph)
        cc ... list of detached coupling tensors, one per capsule layer
        us ... list of detached capsule tensors: primary capsules first, then
               the output of every capsule layer
        """
        x = self._primary_caps(x)

        # Detach the primary capsules too. Previously only this entry stayed
        # attached to the graph, unlike every other element of `us`.
        us = [torch.clone(x).detach()]
        cc = []
        for layer in self.layers:
            x, c = layer.forward_debug(x)
            cc.append(c.detach())
            us.append(torch.clone(x).detach())
        return x, cc, us

In [None]:
# Capsule configurations tried so far: ns[i] = number of capsules in layer i,
# ds[i] = their dimension. Only the last pair of assignments takes effect.
# ns = [32, 32, 32, 10]
# ds = [8, 16, 16, 16]

ns = [32, 10, 10, 10, 10, 10]  # does not work
ds = [8, 16, 16, 16, 16, 16]

ns = [32, 10, 10, 10, 10]  # works
ds = [8, 16, 16, 16, 16]


ns = [32, 32, 32, 10]
ds = [8, 16, 16, 16]

model = DeepCapsNet(ns=ns, ds=ds)
#
print("tot Model ", count_parameters(model))
print("Backbone  ", count_parameters(model.backbone))
#
model = model.to(device)
model

In [None]:
model.layers

In [None]:
# Adam with a small weight decay; the learning rate is multiplied by 0.96 on every
# lr_scheduler.step() call.
optimizer = optim.Adam(model.parameters(), lr = 1e-3, weight_decay=1e-5)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.96)

In [None]:
num_epochs = 51
# (report format, loader) pairs evaluated in this order on every fifth epoch
valid_sets = (
    ("   mnist acc_valid: {:.3f}", dl_mnist_valid),
    ("   affnist acc_valid: {:.3f}", dl_affnist_valid),
)
#
for epoch_idx in range(num_epochs):
    # ---------------- train ----------------
    model.train()
    desc = "Train [{:3}/{:3}]:".format(epoch_idx, num_epochs)
    pbar = tqdm(dl_mnist_train, bar_format=desc + '{bar:10}{r_bar}{bar:-10b}')

    for x, y_true in pbar:
        x, y_true = x.to(device), y_true.to(device)
        optimizer.zero_grad()

        u_h = model.forward(x)

        # margin loss on the class capsules against one-hot targets
        y_one_hot = F.one_hot(y_true, num_classes=10)
        loss = margin_loss(u_h, y_one_hot)
        loss.backward()
        optimizer.step()

        # predicted class = capsule with the largest norm
        y_pred = torch.argmax(torch.norm(u_h, dim=2), dim=1)
        acc = (y_true == y_pred).sum() / y_true.shape[0]

        pbar.set_postfix({'loss': loss.item(), 'acc': acc.item()})
    lr_scheduler.step()

    # ---------------- validate (epochs 0, 5, 10, ...) ----------------
    if epoch_idx % 5 == 0:
        for report_fmt, loader in valid_sets:
            model.eval()

            total_correct = 0
            total = 0

            for x, y_true in loader:
                x, y_true = x.to(device), y_true.to(device)

                with torch.no_grad():
                    u_h = model.forward(x)

                    y_pred = torch.argmax(torch.norm(u_h, dim=2), dim=1)
                    total_correct += (y_true == y_pred).sum()
                    total += y_true.shape[0]
            print(report_fmt.format(total_correct / total))

baseline:
- normal routing: 80% acc 

# Visualize and Analyze

### Show parse tree and activations for individual samples

In [None]:
# Take the first 128 samples of one affNIST validation batch and run the debug forward
# pass to get the class capsules, the couplings (CC) and the capsules of every layer (US).
x, y = next(iter(dl_affnist_valid))
x = x[:128]
y = y[:128]
#
model.eval()
with torch.no_grad():
    u_h, CC, US = model.forward_debug(x.to(device))
# predicted class = capsule with the largest norm
y_pred = torch.argmax(torch.norm(u_h, dim=2), dim=1)
y_pred = y_pred.detach().cpu().numpy()
#
# move everything to numpy for plotting
US = [u.cpu().numpy() for u in US]
CS = [c.cpu().numpy() for c in CC]
#
Y_true = y.cpu().numpy()
Y_pred = y_pred

In [None]:
# Show couplings and capsule norms for individual samples of class `cl`
# (set cl = None to show every sample).
cl = 9
# never index past the number of samples that were actually collected
for idx in range(min(32, len(Y_true))):
    if cl is not None and Y_true[idx] != cl:
        continue
    cs = [c[idx] for c in CS]
    us = [u[idx] for u in US]
    u_norms = [np.linalg.norm(u, axis=1) for u in us]

    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    # Use the numpy labels: formatting the tensor y[idx] printed "tensor(9)"
    # instead of "9" in the title.
    title = "exp={} a={}".format(Y_true[idx], Y_pred[idx])
    #
    plot_couplings(cs, title=title, ax=axes[0], show=False)
    #
    plot_capsules(u_norms, title=title , ax=axes[1], show=False)
    plt.show()

# Collect Statistics

In [None]:
model.eval()

# Per-batch buffers: labels, couplings of every capsule layer, capsules of every layer
# (primary capsules included, hence one more entry than CC).
YY = []
CC = [[] for _ in range(len(ns) - 1)]
US = [[] for _ in range(len(ns))]

for x, y_true in dl_affnist_valid:
    x = x.to(device)

    with torch.no_grad():
        _, cc, us = model.forward_debug(x)
        for idx, c in enumerate(cc):
            CC[idx].append(c.detach().cpu().numpy())
        for idx, u in enumerate(us):
            US[idx].append(u.detach().cpu().numpy())
        YY.append(y_true.numpy())

# stack the batches along the sample axis
YY = np.concatenate(YY)
CC = [np.concatenate(c) for c in CC]
US = [np.concatenate(u) for u in US]

### Mean parse tree and mean activation for dataset

In [None]:
fig, axes = plt.subplots(1, 4, figsize=(16, 4))

# mean / std of the coupling matrices over all samples, per layer
cc_mean = [np.mean(c, axis=0) for c in CC]
cc_std = [np.std(c, axis=0) for c in CC]
for ax, mats, title in ((axes[0], cc_mean, "mean couplings"),
                        (axes[1], cc_std, "std couplings")):
    plot_couplings(mats, ax=ax, show=False, title=title)

# mean / std of the capsule norms over all samples, per layer
us_mean = [np.linalg.norm(u, axis=-1).mean(axis=0) for u in US]
us_std = [np.linalg.norm(u, axis=-1).std(axis=0) for u in US]
for ax, norms, title in ((axes[2], us_mean, "mean activation"),
                         (axes[3], us_std, "std activation")):
    plot_capsules(norms, scale_factor=1, ax=ax, show=False, title=title)

plt.suptitle("dataset")
plt.show()

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(4 * len(CC), 4))

# normalize_couplings returns a pair per layer, unpacked below as (C, pr)
# NOTE(review): pr is presumably the per-capsule rate of inactive samples -- confirm in misc.metrics2
CNS = [normalize_couplings(C) for C in CC]

# mean / max / std of the normalised couplings per layer
CNS_MAN = [ma_couplings_n(C, pr) for C, pr in CNS]
CNS_MAX = [C.max(axis=0) for C, pr in CNS]
CNS_STD = [stda_couplings_n(C, pr) for C, pr in CNS]

plot_couplings(CNS_MAN, ax=axes[0], show=False, title="mean")
plot_couplings(CNS_STD, ax=axes[1], show=False, title="std")
plot_couplings(CNS_MAX, ax=axes[2], show=False, title="max")
plt.show()

### classwise mean parse tree and mean activation

In [None]:
# mean and variance activation
for cls in range(10):
    idcs = np.where(YY == cls)[0]
    
    fig, axes = plt.subplots(1, 6, figsize=(24, 4))
    
    cc = [C[idcs] for C in CC]
    CNS = [normalize_couplings(C, eps_rate=0.5) for C in cc]
    
    CNS_MAN = [ma_couplings_n(C, pr) for C, pr in CNS]
    CNS_MAX = [C.max(axis=0) for C, pr in CNS]
    CNS_STD = [stda_couplings_n(C, pr) for C, pr in CNS]

    plot_couplings(CNS_MAN, ax=axes[0], show=False, title="mean")
    plot_couplings(CNS_STD, ax=axes[1], show=False, title="std")
    plot_couplings(CNS_MAX, ax=axes[2], show=False, title="max")
    
    # mean and std capsule activation
    us = [u[idcs] for u in US]
    us_mean = [np.linalg.norm(u, axis=-1).mean(axis=0) for u in us]
    us_std = [np.linalg.norm(u, axis=-1).std(axis=0) for u in us]
    us_max = [np.linalg.norm(u, axis=-1).max(axis=0) for u in us]
    
    plot_capsules(us_mean, scale_factor=1, ax=axes[3], show=False, title="mean activation")
    plot_capsules(us_std, scale_factor=1, ax=axes[4], show=False, title="std activation")
    plot_capsules(us_max, scale_factor=1, ax=axes[5], show=False, title="max activation")
    plt.suptitle("class {}".format(cls))
    plt.show()

# Capsules - Dead and Alive

In [None]:
# A capsule counts as permanently dead when both the mean and the std of its norm
# over the dataset stay below these thresholds.
c_th_mu = 1e-2
c_th_sd = 1e-2

# squeeze=False keeps `axes` two-dimensional even for a single row
fig, axes = plt.subplots(len(US), 2, figsize=(6, 3 * len(US)), squeeze=False)
#
US_alive = []
for idx in range(len(US)):
    U = US[idx]
    U_norm = np.linalg.norm(U, axis=2)
    U_norm_mu = U_norm.mean(axis=0)
    U_norm_sd = U_norm.std(axis=0)
    #
    # use the configured thresholds (previously a literal 1e-2 was repeated here)
    U_dead = (U_norm_sd < c_th_sd) & (U_norm_mu < c_th_mu)
    #
    xx = range(len(U_norm_mu))
    # red background bars mark the dead capsules
    axes[idx][0].set_title("mu(norm(U))")
    axes[idx][0].bar(xx, U_dead, color="red",alpha=0.1)
    axes[idx][0].bar(xx, U_norm_mu)
    axes[idx][0].set_ylim(0, 1)
    axes[idx][1].set_title("sd(norm(U))")
    axes[idx][1].bar(xx, U_norm_sd)
    axes[idx][1].bar(xx, U_dead, color="red",alpha=0.1)
    axes[idx][1].set_ylim(0, 1)
    # 0/1 integer mask; later cells compare it with True/False and sum it
    U_alive = 1 - U_dead
    US_alive.append(U_alive)
plt.show()

# Determine what ACTIVE capsules are

### via capsule norms

In [None]:
# Fraction of (sample, capsule) pairs whose norm exceeds each threshold:
# one row per threshold, one column per layer.
ths = [0.6, 0.4, 0.2, 0.1, 0.05, 0.01, 0.001]
ARS = np.array([
    [(np.linalg.norm(U, axis=2) > th).mean() for U in US]
    for th in ths
])
#
plot_mat2(ARS, scale_factor=1, row_names=ths)

In [None]:
#
# Seems not to work
#
# Mean capsule norm per layer (mean over samples, then over capsules).
for U in US:
    U_norm = np.linalg.norm(U, axis=2)
    print("{:.3f}".format(U_norm.mean(axis=0).mean()))

### via coupling coefficients

In [None]:
#
# This seems to be consistent with capsule norms with extremely small threshold
#
# Per layer: mean of 1 - pr, where pr is the second value returned by normalize_couplings.
for C in CC:
    Cn, pr = normalize_couplings(C)
    ar = 1 - pr
    print(ar.mean())

# Metrics

In [None]:
from misc.metrics2 import *

In [None]:
def mma_capsules_n(C, pr):
    """
    Mean over samples of the maximum coupling of each lower capsule,
    divided by (1 - pr + EPS).

    C  ... couplings, 3-D array whose middle axis has length n_l
    pr ... array with n_l entries
    Returns an array with one value per lower capsule.
    """
    n_samples, n_l, n_h = C.shape
    assert len(pr) == n_l
    mean_of_max = C.max(axis=-1).mean(axis=0)
    return mean_of_max / (1 - pr.flatten() + EPS)

def mma_layer(C, pr):
    """Layer-level value: per-capsule mma_capsules_n(C, pr) averaged with weights proportional to (1 - pr)."""
    per_capsule = mma_capsules_n(C, pr)
    ws = (1 - pr) / (1 - pr + EPS).sum()
    return (ws * per_capsule).sum()

def mm_layer(C, pr):
    """Layer-level value: mean_max(C) averaged over capsules with weights proportional to (1 - pr)."""
    per_capsule = mean_max(C)
    ws = (1 - pr) / (1 - pr + EPS).sum()
    return (ws * per_capsule).sum()

# How Strong are the couplings

In [None]:
#
# Using all capsules
#
print_str = "mma {:.3f} lmma: {:.3f} mm {:.3f} lmm {:.3f} | ar {:.2f}"
for idx, C in enumerate(CC):
    Cn, pr = normalize_couplings(C, eps_rate=0.5)
    # per-capsule values
    mma = mma_capsules_n(Cn, pr)
    mm = mm_capsules(C)
    # layer-level, activity-weighted values
    lmma = mma_layer(Cn, pr)
    lmm = mm_layer(C, pr)
    ar = (1 - pr).mean()
    print(print_str.format(mma.mean(), lmma,  mm.mean(), lmm, ar))

In [None]:
#
# Removing permanently dead capsules
#
print_str = "mma {:.3f} lmma: {:.3f} mm {:.3f} lmm {:.3f} | dead {:.2f} alive {:.2f} art {:.2f}  ara {:.2f}"
for idx, C in enumerate(CC):
    U_alive = US_alive[idx]
    alive = U_alive == True
    dr = 1 - U_alive.mean()

    # normalise on all capsules first; `art` is the activity rate over all of them
    Cn, pr = normalize_couplings(C, eps_rate=0.5)
    art = (1 - pr).mean()

    # then keep only the capsules marked alive
    pr = pr[alive]
    Cn = Cn[:, alive, :]
    C = C[:, alive, :]

    mma = mma_capsules_n(Cn, pr)
    mm = mm_capsules(C)
    #
    lmma = mma_layer(Cn, pr)
    lmm = mm_layer(C, pr)
    ar = (1 - pr).mean()
    print(print_str.format(mma.mean(), lmma,  mm.mean(), lmm, dr, 1-dr, art, ar))

In [None]:
#
# Coupling Strength Removing permanently dead capsulesClasswise
#
vals = []

#
for cls in range(10):
    idcs = np.where(YY == cls)[0]
    for idx in range(len(CC)):
        U_alive = US_alive[idx]
        alive = U_alive == True
        dr = 1 - U_alive.mean()
        # couplings of this layer for the samples of this class
        C = CC[idx][idcs]

        # normalise on all capsules: activity rate (art) and active count (nat) over all of them
        Cn, pr = normalize_couplings(C, eps_rate=0.5)
        art = (1 - pr).mean()
        nat = art * C.shape[1]

        # restrict to the capsules marked alive
        pr = pr[alive]
        Cn = Cn[:, alive, :]
        C = C[:, alive, :]

        mma = mma_capsules_n(Cn, pr)
        mm = mm_capsules(C)
        #
        lmma = mma_layer(Cn, pr)
        lmm = mm_layer(C, pr)
        # activity rate (ara) and active count (naa) among the alive capsules
        ara = (1 - pr).mean()
        naa = ara * C.shape[1]
        #
        row = (cls, idx, mma.mean(), lmma, mm.mean(), lmm, dr, 1-dr,ara, art, (1-dr)*ara, nat, naa)
        vals.append(row)

In [None]:
# One row per (class, layer); column order matches the tuples appended to `vals`.
df = pd.DataFrame(data=vals, columns=["class", "layer", "mma", "lmma", "mm", "lmm", "dead", "alive", "ara", "art", "ara*alive", "nat", "naa"])
#
# print one table per layer
for idx in range(len(CC)):
    sdf = df[df["layer"] == idx].drop(columns=["layer"])
    print(sdf)

### HOW DYNAMIC ARE THE COUPLINGS

In [None]:
EPS = 1e-9  # added to (1 - pr) and other denominators to avoid division by zero


def _uniform_prior_std(C):
    """sqrt(p * (1 - p)) with p = 1 / n_h, where n_h is the last axis of the 3-D array C."""
    n_samples, n_l, n_h = C.shape
    return np.sqrt(1/n_h * (1 - 1/n_h))


def _activity_weighted_sum(values, pr):
    """Sum of `values` weighted by (1 - pr), the weights normalised by sum(1 - pr + EPS)."""
    ws = (1 - pr) / (1 - pr + EPS).sum()
    return (ws * values).sum()


def ma_couplings_n(C, pr):
    """Mean of C over samples, divided per lower capsule by (1 - pr + EPS). Shape (n_l, n_h)."""
    n_samples, n_l, n_h = C.shape
    active = 1 - pr.reshape(n_l, 1) + EPS
    return C.mean(axis=0) / active


def stda_couplings_n(C, pr):
    """
    Std counterpart of ma_couplings_n, shape (n_l, n_h).

    The mean squared deviation from ma_couplings_n is divided by (1 - pr + EPS),
    then ma**2 * pr / (1 - pr + EPS) is subtracted; negative values are
    clipped to 0 before the square root.
    """
    n_samples, n_l, n_h = C.shape
    pr = pr.reshape(n_l, 1)
    active = 1 - pr + EPS
    ma = ma_couplings_n(C, pr)
    scaled_sq_dev = ((C - ma)**2).mean(axis=0) / active
    correction = ma**2 * pr / active
    return np.sqrt(np.maximum(0, scaled_sq_dev - correction))


def lmstda_capsules(C, pr):
    """Activity-weighted sum over capsules of the mean (over upper capsules) stda_couplings_n."""
    per_capsule = stda_couplings_n(C, pr).mean(axis=1)
    return _activity_weighted_sum(per_capsule, pr)


def dyc_capsules_n1(C, pr):
    """Per-capsule mean stda_couplings_n divided by the uniform-prior std."""
    std_pr = _uniform_prior_std(C)
    return stda_couplings_n(C, pr).mean(axis=1) / std_pr


def dyc_capsules_n2(C, pr):
    """Like dyc_capsules_n1, with the denominator scaled by each capsule's maximum coupling."""
    std_pr = _uniform_prior_std(C)
    masd = stda_couplings_n(C, pr).mean(axis=1)
    peak = C.max(axis=(0,2))
    return masd / (std_pr * peak + EPS)


def dyc_capsules_n3(C, pr):
    """NOTE(review): unfinished -- computes intermediates but returns None."""
    std_pr = _uniform_prior_std(C)
    stda = stda_couplings_n(C, pr)
    mstda = stda.mean(axis=1)
    return None


def dycm_capsules(C, pr):
    """Per-capsule mean stda_couplings_n divided by mma_capsules_n (+ EPS)."""
    mstda = stda_couplings_n(C, pr).mean(axis=1)
    mma = mma_capsules_n(C, pr)
    assert np.all(mstda <= mma)
    return mstda / (mma + EPS)


def ldycm_capsules(C, pr):
    """Activity-weighted sum of dycm_capsules."""
    return _activity_weighted_sum(dycm_capsules(C, pr), pr)


def dycpr_capsules(C, pr):
    """Per-capsule mean stda_couplings_n divided by the uniform-prior std."""
    std_pr = _uniform_prior_std(C)
    return stda_couplings_n(C, pr).mean(axis=1) / std_pr


def ldycpr_capsules(C, pr):
    """Activity-weighted sum of dycpr_capsules."""
    return _activity_weighted_sum(dycpr_capsules(C, pr), pr)


def dycmpr_capsules(C, pr):
    """Per-capsule mean stda_couplings_n divided by (mma_capsules_n * uniform-prior std + EPS)."""
    n_samples, n_l, n_h = C.shape
    mstda = stda_couplings_n(C, pr).mean(axis=1)
    mma = mma_capsules_n(C, pr)
    std_pr = _uniform_prior_std(C)
    assert np.all(mstda <= mma)
    return mstda / (mma * std_pr + EPS)


def ldycmpr_capsules(C, pr):
    """Activity-weighted sum of dycmpr_capsules."""
    return _activity_weighted_sum(dycmpr_capsules(C, pr), pr)


def dycmxpr_capsules(C, pr):
    """Per-capsule mean stda_couplings_n divided by (max coupling * uniform-prior std + EPS)."""
    n_samples, n_l, n_h = C.shape
    mstda = stda_couplings_n(C, pr).mean(axis=1)
    cmx = C.max(axis=(0,-1))
    std_pr = _uniform_prior_std(C)
    assert np.all(mstda <= cmx)
    return mstda / (cmx * std_pr + EPS)


def ldycmxpr_capsules(C, pr):
    """Activity-weighted sum of mstda / (max coupling + EPS), divided by the uniform-prior std."""
    n_samples, n_l, n_h = C.shape
    mstda = stda_couplings_n(C, pr).mean(axis=1)
    cmx = C.max(axis=(0,-1))
    std_pr = _uniform_prior_std(C)
    assert np.all(mstda <= cmx)
    ratio = mstda / (cmx + EPS)
    return _activity_weighted_sum(ratio, pr) / std_pr

In [None]:
# Dynamics metrics for the first layer only (see the `break` at the end).
# Later scratch cells read the `Cn` and `pr` left behind by this loop.
for idx in range(len(CC)):
    U_alive = US_alive[idx]
    dr = 1 - U_alive.mean()
    C = CC[idx]

    # CN
    Cn, pr = normalize_couplings(C, eps_rate=0.5)
    #
    # activity rate / count over all capsules
    art = (1 - pr).mean()
    nat = art * C.shape[1]
    #
    # activity rate among the capsules marked alive
    pr = pr[U_alive == True]
    Cn = Cn[:, U_alive==True, :]
    ara = (1 - pr).mean()
    
    # the filtered Cn / pr above are discarded here: the metrics below use all capsules
    Cn, pr = normalize_couplings(C, eps_rate=0.5) # this should not have an influence
    
    dycm = dycm_capsules(Cn, pr)
    ldycm = ldycm_capsules(Cn, pr)
    dycpr = dycpr_capsules(Cn, pr)
    ldycpr = ldycpr_capsules(Cn, pr)
    dycmpr = dycmpr_capsules(Cn, pr)
    ldycmpr = ldycmpr_capsules(Cn, pr)
    dycmxpr = dycmxpr_capsules(Cn, pr)
    ldycmxpr = ldycmxpr_capsules(Cn, pr)
    #
    print("ldycm {:.3f} ldycpr: {:.3f} ldycmpr {:.3f} ldycmxpr: {:.3f} | art {:.3f} ara {:.3f}".format(
        ldycm, ldycpr, ldycmpr, ldycmxpr, art, ara))
    # only the first layer is inspected
    break

In [None]:
c = Cn[:,0,:]

In [None]:
c.std(axis=0)

In [None]:
n_h = 10
s = torch.randint(0, n_h, (1000000,))
c = torch.nn.functional.one_hot(s).float()
c.std(axis=0)

In [None]:
std_pr = np.sqrt(1/n_h * (1 - 1/n_h))
std_pr

In [None]:
c = torch.nn.functional.one_hot(s).float().numpy()
c = np.maximum(0, c - 0.3)
c.std(axis=0)

In [None]:
std_pr = np.sqrt(1/n_h * (1 - 1/n_h)) * 0.7
std_pr

In [None]:
# Dynamics metrics per layer, restricted to the capsules marked alive.
for idx in range(len(CC)):
    U_alive = US_alive[idx]
    alive = U_alive == True
    dr = 1 - U_alive.mean()
    C = CC[idx]

    # normalise on all capsules; `art` is the activity rate over all of them
    Cn, pr = normalize_couplings(C, eps_rate=0.5)
    art = (1 - pr).mean()

    # Filter pr and C together with Cn. Previously only Cn was filtered, so the
    # metric functions received a pr whose length no longer matched the capsule
    # axis and failed as soon as a layer contained dead capsules.
    pr = pr[alive]
    Cn = Cn[:, alive, :]
    C = C[:, alive, :]
    # activity rate among alive capsules (was computed before the loop from a stale pr)
    ara = (1 - pr).mean()

    dycm = dycm_capsules(Cn, pr)
    ldycm = ldycm_capsules(Cn, pr)
    dycpr = dycpr_capsules(Cn, pr)
    ldycpr = ldycpr_capsules(Cn, pr)
    dycmpr = dycmpr_capsules(Cn, pr)
    ldycmpr = ldycmpr_capsules(Cn, pr)
    dycmxpr = dycmxpr_capsules(C, pr)
    ldycmxpr = ldycmxpr_capsules(C, pr)
    #
    print("ldycm {:.3f} ldycpr: {:.3f} ldycmpr {:.3f} ldycmxpr: {:.3f} | art {:.3f}".format(
        ldycm, ldycpr, ldycmpr, ldycmxpr, art))

In [None]:
# Dynamics metrics per layer using all capsules (U_alive and dr are computed but not used).
for idx in range(len(CC)):
    U_alive = US_alive[idx]
    dr = 1 - U_alive.mean()
    C = CC[idx]
    
    # CN
    Cn, pr = normalize_couplings(C, eps_rate=0.5)
    dycm = dycm_capsules(Cn, pr)
    ldycm = ldycm_capsules(Cn, pr)
    dycpr = dycpr_capsules(Cn, pr)
    ldycpr = ldycpr_capsules(Cn, pr)
    dycmpr = dycmpr_capsules(Cn, pr)
    ldycmpr = ldycmpr_capsules(Cn, pr)
    # the *mxpr variants are fed the raw couplings C, not the normalised Cn
    dycmxpr = dycmxpr_capsules(C, pr)
    ldycmxpr = ldycmxpr_capsules(C, pr)
    #
    print("dycm {:.3f} ldycm {:.3f} dycpr: {:.3f} ldycpr: {:.3f} dycmpr {:.3f} ldycmpr {:.3f} dycmxpr: {:.3f} ldycmxpr: {:.3f}".format(
        dycm.mean(), ldycm, dycpr.mean(), ldycpr, dycmpr.mean(), ldycmpr, dycmxpr.mean(), ldycmxpr))

In [None]:
# Dynamics metrics per layer; dead capsules are removed BEFORE normalisation here.
# The cells below read the `Cn` and `pr` left behind by the last iteration.
for idx in range(len(CC)):
    U_alive = US_alive[idx]
    dr = 1 - U_alive.mean()
    C = CC[idx]
    # keep only the capsules marked alive
    C = C[:, U_alive == True, :]
    
    # CN
    Cn, pr = normalize_couplings(C, eps_rate=0.5)
    dycm = dycm_capsules(Cn, pr)
    ldycm = ldycm_capsules(Cn, pr)
    dycpr = dycpr_capsules(Cn, pr)
    ldycpr = ldycpr_capsules(Cn, pr)
    dycmpr = dycmpr_capsules(Cn, pr)
    ldycmpr = ldycmpr_capsules(Cn, pr)
    # the *mxpr variants are fed the raw couplings C, not the normalised Cn
    dycmxpr = dycmxpr_capsules(C, pr)
    ldycmxpr = ldycmxpr_capsules(C, pr)
    #
    print("dycm {:.3f} ldycm {:.3f} dycpr: {:.3f} ldycpr: {:.3f} dycmpr {:.3f} ldycmpr {:.3f} dycmxpr: {:.3f} ldycmxpr: {:.3f}".format(
        dycm.mean(), ldycm, dycpr.mean(), ldycpr, dycmpr.mean(), ldycmpr, dycmxpr.mean(), ldycmxpr))

In [None]:
C = Cn
n_samples, n_l, n_h = C.shape
std_pr = np.sqrt(1/n_h * (1 - 1/n_h))
std_pr

In [None]:
cmax = C.max(axis=(0, -1))

In [None]:
mma = mma_capsules_n(C, pr)
mma

In [None]:
mstda = stda_couplings_n(C, pr).mean(axis=1)
mstda

In [None]:
(mstda / (mma * std_pr)).mean()

In [None]:
(mstda / (cmax * std_pr)).mean()

In [None]:
(mstda / (mma * std_pr)).mean()

In [None]:
# NOTE(review): `ws` and `dyc` are never assigned at notebook top level (only as locals inside
# the metric functions), so this cell raises NameError on a fresh Restart & Run All.
(ws * dyc).sum() / std_pr

In [None]:
# Classwise activity rate, mean-max and dynamics per layer.
# NOTE(review): dyc_capsules_n and dyc_capsules are not defined in this notebook (only
# dyc_capsules_n1/n2/n3 are) -- presumably they come from `misc.metrics2`; confirm.
vals = []

#
for cls in range(10):
    idcs = np.where(YY == cls)[0]
    for idx in range(len(CC)):
        # couplings of this layer for the samples of this class
        C = CC[idx][idcs]
        #
        Cn, pr = normalize_couplings(C)
        dyna = dyc_capsules_n(Cn, pr)
        mma = mma_capsules_n(Cn, pr)
        #
        # divide the capsule means by the mean activity rate (1 - pr)
        mmaa = mma.mean() / (1 - pr.mean() + 1e-9)
        dynaa = dyna.mean() / (1 - pr.mean() + 1e-9)
        #
        dyn = dyc_capsules(C)
        mm = mm_capsules(C)
        vals.append((cls, idx, 1 - pr.mean(), mmaa, mm.mean(), dynaa, dyn.mean()))

In [None]:
# One row per (class, layer); print one table per layer.
df = pd.DataFrame(data=vals, columns=["class", "layer", "ar", "mma", "mm", "dyna", "dyn"])
for idx in range(len(CC)):
    print(df[df["layer"] == idx])

In [None]:
separator = "#" * 20
for idx in range(len(CC)):
    U = US[idx]
    C = CC[idx]
    #
    U_alive = US_alive[idx]
    n_alive, n_caps = U_alive.sum(), U_alive.shape[0]
    rate_alive_perm = n_alive / n_caps
    rate_dead_perm = 1 - rate_alive_perm
    #
    # activity over all capsules: share of samples with norm > 0.1
    ar_total = (np.linalg.norm(U, axis=2) > 0.1).mean(axis=0)
    pr_total = 1 - ar_total
    #
    # placeholders: the adjusted metrics are not evaluated on all capsules
    mma_total = np.array([0.0])  # adjusted_mean_max(C, pr_total)
    dya_total = np.array([0.0])  # only defined for pr<1 adjusted_dynamics(C, pr_total)
    #
    # restrict couplings and capsules to the ones marked alive
    alive_idcs = np.where(U_alive == True)[0]
    C = C[:, alive_idcs, :]
    U = U[:, alive_idcs, :]
    #
    # activity among alive capsules: share of samples with norm > 0.01
    ar_alive = (np.linalg.norm(U, axis=2) > 0.01).mean(axis=0)
    pr_alive = 1 - ar_alive
    #
    mma_alive = adjusted_mean_max(C, pr_alive)
    dya_alive = adjusted_dynamics(C, pr_alive)

    print("LAYER {}".format(idx))
    print(separator)
    print("Permanently Alive:     {:.2f}   ({:3d}/{:3d})".format(rate_alive_perm, n_alive, n_caps))
    print("Activity Rate (Alive): {:.2f}".format(ar_alive.mean()))
    print("Activity Rate (Total): {:.2f}".format(ar_total.mean()))
    print(separator)
    print("Couplings")
    print(separator)
    print("    total  alive")
    print("mma {:.3f} {:.3f}".format(mma_total.mean(), mma_alive.mean()))
    print("dya {:.3f} {:.3f}".format(dya_total.mean(), dya_alive.mean()))
    #
    print(separator)

In [None]:
# Classwise activity rates, mean-max and dynamics per layer (alive capsules only for the metrics).
vals = []

for cls in range(10):
    idcs = np.where(YY == cls)[0]
    for idx in range(len(CC)):
        # capsules / couplings of this layer for the samples of this class
        U = US[idx][idcs]
        C = CC[idx][idcs]
        #
        U_alive = US_alive[idx]
        rate_alive_perm = U_alive.sum() / U_alive.shape[0]
        rate_dead_perm = 1 - rate_alive_perm
        #
        # activity over all capsules: share of samples with norm > 0.1
        ar_total = (np.linalg.norm(U, axis=2) > 0.1).mean(axis=0)
        #
        alive_idcs = np.where(U_alive == True)[0]
        C = C[:, alive_idcs, :]
        U = U[:, alive_idcs, :]
        #
        # activity among alive capsules: share of samples with norm > 0.01
        ar_alive = (np.linalg.norm(U, axis=2) > 0.01).mean(axis=0)
        pr_alive = 1 - ar_alive
        #
        mma = adjusted_mean_max(C, pr_alive)
        mmn = mean_max(C)
        dya = adjusted_dynamics(C, pr_alive)
        dyn = dynamics(C)
        row = (cls, idx, ar_total.mean(), ar_alive.mean(), mma.mean(), mmn.mean(), dya.mean(), dyn.mean())
        vals.append(row)

In [None]:
# One row per (class, layer); column order matches the tuples appended to `vals`.
df = pd.DataFrame(data=vals, columns=["class", "layer", "art", "ara", "mma", "mmn", "dya", "dyn"])

In [None]:
# Print one table per layer.
for idx in range(len(CC)):
    print(df[df["layer"] == idx])

# Couplings

#### Couplings FROM DEAD Capsules

In [None]:
for idx in range(len(CC)):
    Ul_alive = US_alive[idx]
    # couplings that start at lower capsules marked dead
    dead_idcs = np.where(Ul_alive == False)[0]
    C = CC[idx][:, dead_idcs, :]

    if len(C.flatten()) < 1:
        print("No dead capsules for layer {}".format(idx))
        continue

    # mean / std / max over samples, one panel each
    C_mu = C.mean(axis=0)
    C_sd = C.std(axis=0)
    C_mx = C.max(axis=0)
    fig, axes = plt.subplots(1, 3, figsize=(33, 11))
    for ax, mat in zip(axes, (C_mu, C_sd, C_mx)):
        plot_mat2(mat, ax=ax, vmin=0, vmax=0.5)
plt.show()

#### Couplings FROM Alive Capsules

In [None]:
for idx in range(len(CC)):
    Ul_alive = US_alive[idx]
    # couplings that start at lower capsules marked alive
    alive_idcs = np.where(Ul_alive == True)[0]
    C = CC[idx][:, alive_idcs, :]

    # mean / std / max over samples, one panel each
    C_mu = C.mean(axis=0)
    C_sd = C.std(axis=0)
    C_mx = C.max(axis=0)
    fig, axes = plt.subplots(1, 3, figsize=(42, 14))
    for ax, mat in zip(axes, (C_mu, C_sd, C_mx)):
        plot_mat2(mat, ax=ax, vmin=0, vmax=0.5)
plt.show()

TODO: per sample, count the max coupling and use that max to find out whether the couplings get lower in general, or whether only the average drops because the capsules are loosely connected.


### Couplings FROM ALIVE to DEAD

In [None]:
# Couplings from alive lower capsules to dead upper capsules, per layer.
for idx in range(len(CC)):
    C = CC[idx]
    Ul_alive = US_alive[idx]
    Uh_alive = US_alive[idx + 1]
    if (1 - Uh_alive).sum() < 1:
        print("{} No dead capsules for upper layer {}".format(idx ,idx + 1))
        continue

    C = C[:,np.where(Ul_alive == True)[0],:][:,:,np.where(Uh_alive == False)[0]]

    if len(C.flatten()) < 1:
        # The upper layer has dead capsules (checked above), so an empty
        # selection means no alive lower capsule was found. The old message
        # wrongly reported missing dead capsules.
        print("No alive capsules for layer {}".format(idx))
        continue

    # mean / std / max over samples, one panel each
    C_mu = C.mean(axis=0)
    C_sd = C.std(axis=0)
    C_mx = C.max(axis=0)
    fig, axes = plt.subplots(1, 3, figsize=(33, 11))
    plot_mat2(C_mu, ax=axes[0], vmin=0, vmax=0.5)
    plot_mat2(C_sd, ax=axes[1], vmin=0, vmax=0.5)
    plot_mat2(C_mx, ax=axes[2], vmin=0, vmax=0.5)
plt.show()

### Couplings FROM ALIVE TO ALIVE

In [None]:
# Couplings from alive lower capsules to alive upper capsules, per layer.
for idx in range(len(CC)):
    C = CC[idx]
    Ul_alive = US_alive[idx]
    Uh_alive = US_alive[idx + 1]
    if (Uh_alive).sum() < 1:
        # this branch fires when NO upper capsule is alive; the message was
        # copy-pasted from the "dead" cells and said the opposite
        print("No alive capsules for upper layer {}".format(idx + 1))
        continue

    C = C[:,np.where(Ul_alive == True)[0],:][:,:,np.where(Uh_alive == True)[0]]

    if len(C.flatten()) < 1:
        # upper layer has alive capsules, so the lower layer has none
        print("No alive capsules for layer {}".format(idx))
        continue

    # mean / std / max over samples, one panel each
    C_mu = C.mean(axis=0)
    C_sd = C.std(axis=0)
    C_mx = C.max(axis=0)
    fig, axes = plt.subplots(1, 3, figsize=(33, 11))
    plot_mat2(C_mu, ax=axes[0], vmin=0, vmax=0.5)
    plot_mat2(C_sd, ax=axes[1], vmin=0, vmax=0.5)
    plot_mat2(C_mx, ax=axes[2], vmin=0, vmax=0.5)
plt.show()

# CNN Only Baseline

In [None]:
"""
epochs = 050, acc = 98,0, 73.0
epochs = 101, acc = 98.5, 74.1
epochs = 201, acc = 
"""

In [None]:
# CNN-only baseline: the backbone with its linear head, trained with cross-entropy.
# NOTE: this rebinds `model`, `optimizer` and `lr_scheduler`, replacing the capsule network above.
model = CustomBB(ch_in=1, n_classes=10)
#
model = model.to(device)
#backbone
optimizer = optim.Adam(model.parameters(), lr = 1e-3, weight_decay=2e-5)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.96)
#
criterion = nn.CrossEntropyLoss()

In [None]:
print(count_parameters(model))

In [None]:
num_epochs = 101
# (report format, loader) pairs evaluated in this order on every fifth epoch
valid_sets = (
    ("   mnist acc_valid: {:.3f}", dl_mnist_valid),
    ("   affnist acc_valid: {:.3f}", dl_affnist_valid),
)
#
for epoch_idx in range(num_epochs):
    # ---------------- train ----------------
    model.train()
    desc = "Train [{:3}/{:3}]:".format(epoch_idx, num_epochs)
    pbar = tqdm(dl_mnist_train, bar_format=desc + '{bar:10}{r_bar}{bar:-10b}')

    for x, y_true in pbar:
        x, y_true = x.to(device), y_true.to(device)
        optimizer.zero_grad()

        logits = model.forward(x)
        loss = criterion(logits, y_true)
        loss.backward()
        optimizer.step()

        y_pred = torch.argmax(logits, dim=1)
        acc = (y_true == y_pred).sum() / y_true.shape[0]

        pbar.set_postfix({'loss': loss.item(), 'acc': acc.item()})
    lr_scheduler.step()

    # ---------------- validate (epochs 0, 5, 10, ...) ----------------
    if epoch_idx % 5 == 0:
        for report_fmt, loader in valid_sets:
            model.eval()

            total_correct = 0
            total = 0

            for x, y_true in loader:
                x, y_true = x.to(device), y_true.to(device)

                with torch.no_grad():
                    logits = model.forward(x)

                    y_pred = torch.argmax(logits, dim=1)
                    total_correct += (y_true == y_pred).sum()
                    total += y_true.shape[0]
            print(report_fmt.format(total_correct / total))