In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append("./../")

In [None]:
from pathlib import Path
import math
import pickle
#
import torch
import torchvision
from torchvision import utils
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
#
from misc.plot_utils import plot_mat, imshow
from effcn.models import AffnistEffCapsNetWOBias
from datasets import AffNIST

In [None]:
# Root directory that holds one sub-directory per grid-search run.
p_experiments = Path("/mnt/experiments/effcn/affnist/grid_search/")

# Keep using the second GPU when it exists, but fall back gracefully so the
# notebook still runs on single-GPU or CPU-only machines.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Load (config, stats) of every run found below `p_experiments` and split the
# runs by whether the last recorded training loss is NaN.
runs_nan = []      # runs whose final training loss is NaN
runs_success = []  # all other runs that have both pickle files

# sorted() makes the resulting order independent of the directory listing order.
for p_experiment in sorted(p_experiments.iterdir()):
    p_config = p_experiment / "config.pkl"
    p_stats = p_experiment / "stats.pkl"
    # Skip entries that lack either file (e.g. runs that never finished).
    if not (p_config.exists() and p_stats.exists()):
        continue
    # NOTE: pickle.load can execute arbitrary code - only load trusted experiment files.
    with open(p_config, "rb") as file:
        config = pickle.load(file)
    with open(p_stats, "rb") as file:
        stats = pickle.load(file)
    train_losses = stats["train"]["loss"]
    if not train_losses:
        # No loss was recorded, so the run cannot be classified.
        continue
    if math.isnan(train_losses[-1]):
        runs_nan.append((config, stats))
    else:
        runs_success.append((config, stats))

In [None]:
# List every non-NaN run whose best affNIST validation accuracy is at least
# 0.87, together with its hyper-parameters and its model directory.
for config, stats in runs_success:
    accs_an, accs_mn = (stats["valid"][split]["acc"] for split in ("affnist", "mnist"))
    max_acc_an, max_acc_mn = max(accs_an), max(accs_mn)
    if max_acc_an < 0.87:
        continue
    summary = "bs={:4d}, lr={:.5f}, wd={:.5f} rec_weights={:8.5f}  max_acc_aff={:.4f} max_acc_mn={:.4f}".format(
        config.train.batch_size,
        config.optimizer_args.lr,
        config.optimizer_args.weight_decay,
        config.loss.rec.weight,
        max_acc_an,
        max_acc_mn,
    )
    print(summary)
    print(p_experiments / config.names.model_dir)

In [None]:
# Display the `names` section of `config`.
# NOTE(review): `config` here is whatever the loops in the cells above left
# behind, not a specifically chosen run.
config.names

In [None]:
# Pick one experiment for closer inspection and load its config and stats.
# Alternative run kept for reference:
#p_experiment = Path("/mnt/experiments/effcn/affnist/grid_search/effcn_affnist_2021_12_08_15_58_05")
p_experiment = Path("/mnt/experiments/effcn/affnist/grid_search_wob/effcn_affnist_2021_12_13_19_32_19")
p_config = p_experiment / "config.pkl"
p_stats = p_experiment / "stats.pkl"
p_ckpts = p_experiment / "ckpts"
# NOTE: pickle.load can execute arbitrary code - only load trusted experiment files.
with open(p_config, "rb") as file:
    config = pickle.load(file)
with open(p_stats, "rb") as file:
    stats = pickle.load(file)
# Read the data path only AFTER loading this experiment's config; reading it
# earlier would silently use the `config` left over from a previous cell.
p_data = config.paths.data

# `model_file` is a format template; 150 selects the checkpoint to load
# (presumably the epoch index - confirm against the training script).
p_model = p_ckpts / config.names.model_file.format(150)
p_model.exists()

In [None]:
# Load the checkpoint for inspection only. map_location="cpu" keeps this
# independent of the GPU the checkpoint was saved from.
state = torch.load(p_model, map_location="cpu")

In [None]:
# Show the entry names stored in the loaded checkpoint.
list(state.keys())

In [None]:
# Inspect the checkpoint entry stored under the key "fcncaps.B".
state["fcncaps.B"]

In [None]:
# Build the network, restore the checkpoint and switch to inference mode.
model = AffnistEffCapsNetWOBias()
# map_location remaps the stored tensors onto `device`, so loading does not
# depend on the GPU index the checkpoint was saved from.
model.load_state_dict(torch.load(p_model, map_location=device))
model = model.to(device)
model.eval()

In [None]:
# Instantiate the three AffNIST splits used below; every split shares the same
# root directory and applies no transforms.
_ds_common = dict(p_root=p_data, download=True, transform=None, target_transform=None)
ds_mnist_train = AffNIST(split="mnist_train", **_ds_common)
ds_mnist_valid = AffNIST(split="mnist_valid", **_ds_common)
ds_affnist_valid = AffNIST(split="affnist_valid", **_ds_common)


In [None]:
def _make_loader(dataset, loader_cfg):
    """Build a shuffling DataLoader for `dataset` from one config section.

    `loader_cfg` must provide `batch_size`, `pin_memory` and `num_workers`
    (here: `config.train` or `config.valid`).
    """
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=loader_cfg.batch_size,
        shuffle=True,
        pin_memory=loader_cfg.pin_memory,
        num_workers=loader_cfg.num_workers,
        # DataLoader raises ValueError for persistent_workers=True when
        # num_workers == 0, so only enable it when workers exist.
        persistent_workers=loader_cfg.num_workers > 0,
    )


dl_mnist_train = _make_loader(ds_mnist_train, config.train)
dl_mnist_valid = _make_loader(ds_mnist_valid, config.valid)
dl_affnist_valid = _make_loader(ds_affnist_valid, config.valid)

# Grab one batch from each loader for quick inspection.
x_train, y_train = next(iter(dl_mnist_train))
x_mn_valid, y_mn_valid = next(iter(dl_mnist_valid))
x_an_valid, y_an_valid = next(iter(dl_affnist_valid))

In [None]:
# Display the module structure (useful for looking up sub-module names).
model

In [None]:
# Work on a copy of the `primcaps.dw_conv2d` kernels so the model itself stays
# untouched. Another layer can be inspected instead, e.g.
# model.backbone.layers[9].weight
weights = model.primcaps.dw_conv2d.weight.clone()
# Print min / max / mean of the kernel values as plain Python numbers.
print(*(reduce().item() for reduce in (weights.min, weights.max, weights.mean)))

In [None]:
def imshow(img, cmap="gray", vmin=None, vmax=None):
    """Show a channel-first image tensor with matplotlib.

    NOTE: this definition shadows the `imshow` imported from
    `misc.plot_utils` at the top of the notebook.

    Args:
        img: object exposing `.detach().cpu().numpy()`; axis 0 is moved to the
            last position before plotting, i.e. (C, H, W) -> (H, W, C).
        cmap: colormap forwarded to `plt.imshow`.
        vmin, vmax: color range limits forwarded to `plt.imshow`
            (previously accepted but silently ignored).
    """
    npimg = img.detach().cpu().numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)), cmap=cmap, vmin=vmin, vmax=vmax)
    plt.show()

def visTensor(tensor, ch=0, allkernels=False, nrow=8, padding=1):
    """Plot the kernels of a 4-D weight tensor as one image grid.

    Args:
        tensor: 4-D tensor; its shape is unpacked as (n, c, w, h).
        ch: channel to show when `allkernels` is False and `c != 3`.
        allkernels: if True, every (kernel, channel) pair becomes its own
            single-channel tile.
        nrow: number of tiles per grid row (forwarded to `make_grid`).
        padding: pixels between tiles (forwarded to `make_grid`).
    """
    # detach() so tensors that require grad (e.g. clones of parameters) can be
    # converted to numpy later on.
    tensor = tensor.detach().cpu()
    n, c, w, h = tensor.shape

    if allkernels:
        # reshape instead of view: also works for non-contiguous tensors
        tensor = tensor.reshape(n * c, -1, w, h)
    elif c != 3:
        tensor = tensor[:, ch, :, :].unsqueeze(dim=1)

    # Rows actually occupied by the grid (ceil division), clamped to [1, 64];
    # the former `// nrow + 1` added an empty row for exact multiples of nrow.
    rows = min(max(-(-tensor.shape[0] // nrow), 1), 64)
    grid = utils.make_grid(tensor, nrow=nrow, normalize=True, padding=padding)
    plt.figure(figsize=(nrow, rows))
    plt.imshow(grid.numpy().transpose((1, 2, 0)))

In [None]:
# Plot every (kernel, channel) slice of `weights` as its own tile in one grid.
visTensor(weights, allkernels=True, padding=1)

In [None]:
# Show channel 0 of every kernel in `weights` as a separate figure.
for weight in weights:
    kernel = weight[0].detach().cpu()
    plt.imshow(kernel)
    plt.show()

In [None]:
# Run the whole MNIST validation loader through the model stage by stage
# (backbone -> primcaps -> fcncaps.forward_debug) and collect labels and
# intermediate outputs as numpy arrays, concatenated over all batches.
def _to_numpy(t):
    """Detach a tensor, move it to the CPU and convert it to a numpy array."""
    return t.detach().cpu().numpy()


YY, CC, CCB = [], [], []
UUH, UUHSQ, UUHFIN, UUL = [], [], [], []

# Pure inference: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    for X, Y in dl_mnist_valid:
        x_bb = model.backbone(X.to(device))
        U_l = model.primcaps(x_bb)
        U_hat, A, A_scaled, A_sum, C, CB, U_h_fin, U_h_sq = model.fcncaps.forward_debug(U_l)

        YY.append(Y.numpy())
        UUL.append(_to_numpy(U_l))
        UUH.append(_to_numpy(U_hat))
        CC.append(_to_numpy(C))
        CCB.append(_to_numpy(CB))
        # Converted to numpy like the other outputs (previously stored as torch tensors).
        UUHSQ.append(_to_numpy(U_h_sq))
        UUHFIN.append(_to_numpy(U_h_fin))

YY = np.concatenate(YY)
CC = np.concatenate(CC)
CCB = np.concatenate(CCB)
UUHSQ = np.concatenate(UUHSQ)
UUHFIN = np.concatenate(UUHFIN)
UUH = np.concatenate(UUH)
UUL = np.concatenate(UUL)

# Shapes of the concatenated labels / coupling outputs ...
print(YY.shape)
print(CC.shape)
# ... and of every intermediate tensor of the LAST processed batch.
for _t in (x_bb, U_l, U_hat, A, A_scaled, A_sum, C, CB, U_h_fin, U_h_sq):
    print(_t.shape)

### Analyse routing and results

In [None]:
# Restrict the collected arrays to the samples whose label equals 1.
idcs = np.nonzero(YY == 1)
Y, C, CB, UH, UHS, UHF = (
    collected[idcs] for collected in (YY, CC, CCB, UUH, UUHSQ, UUHFIN)
)

In [None]:
# Pick a single sample from the label-filtered arrays and show the shapes of
# its per-sample matrices.
idx = 0
y, c, cb, uh, uf = Y[idx], C[idx], CB[idx], UH[idx], UHF[idx]
#
for _arr in (c, cb, uh, uf):
    print(_arr.shape)

In [None]:
# Recompute one output by hand: combine the `uh` slice for index i, weighted
# by column i of `cb` (compare with `uf[i]`).
i = 0
uh[:, i, :].T @ cb[:, i]

In [None]:
# Row i of `uf` (the U_h_fin output for the selected sample), for comparison
# with the manual product computed in the previous cell.
uf[i]

In [None]:
 # Plot the matrix `c` (the `C` output of forward_debug for the selected sample).
 plot_mat(c, scale_factor=0.4)

In [None]:
 # Plot the matrix `cb` (the `CB` output of forward_debug for the selected sample).
 plot_mat(cb, scale_factor=0.4)

In [None]:
# Plot the slice of `uh` for each of the first 10 indices along axis 1.
# NOTE(review): 10 presumably equals the number of higher-level capsules
# (classes) - confirm, or derive it from uh.shape[1].
for gh in range(10):
    plot_mat(uh[:, gh,:], scale_factor=0.4)

In [None]:
# Overview plots for individual samples of the selected label.
# Y, C, CB, UHS and UHF were already filtered with `idcs`; filter UUL the same
# way so that every panel refers to the SAME sample. (Indexing the unfiltered
# CCB / UUL here mixed in data from a different sample.)
UL = UUL[idcs]
for idx in range(1):
    print("#"*100)
    y = Y[idx]
    c = C[idx]
    cb = CB[idx]
    uhs = UHS[idx]
    uhf = UHF[idx]
    ul = UL[idx]
    plot_mat(ul, scale_factor=0.4, title="U_l = lower level capsules")
    plot_mat(c, scale_factor=0.4, title="C")
    plot_mat(cb, scale_factor=0.4, title="C+B")
    plot_mat(uhf, scale_factor=0.4, title="U_h, upper layer capsules w/o squash")
    plot_mat(uhs, scale_factor=0.4, title="squash(U_h)")
    # One bar chart per quantity, in left-to-right panel order; the bar
    # positions follow the data length instead of hard-coded 10 / 16.
    panels = [
        ("C", c.mean(axis=0)),
        ("CB", cb.mean(axis=0)),
        ("U_h without Squash", np.linalg.norm(uhf, axis=1)),
        ("sqash(U_h)", np.linalg.norm(uhs, axis=1)),
        ("U_l", np.linalg.norm(ul, axis=1)),
    ]
    fig, axes = plt.subplots(nrows=1, ncols=len(panels), figsize=(10, 2))
    for ax, (panel_title, heights) in zip(axes, panels):
        ax.bar(range(len(heights)), heights)
        ax.set_title(panel_title)
    plt.show()

In [None]:
# Euclidean norm of every row of `ul` (left over from the loop above).
np.linalg.norm(ul, axis=1)

In [None]:
# Shape of the `ul` array of the last plotted sample.
ul.shape

In [None]:
# Print the norm of each row of `uhs` (the U_h_sq output for the last plotted sample).
for u in uhs:
    print(np.linalg.norm(u))