In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append("./../../")

In [None]:
from pathlib import Path
import math
import pickle
#
import torch
import torchvision
from torchvision import utils
import torchvision.transforms as T
#
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
#
from misc.plot_utils import plot_mat, imshow
from effcn.functions import max_norm_masking
from effcn.models_affnist import EffCapsNet
from datasets import AffNIST

### Preprocessing

In [None]:
# Use the GPU when available, otherwise fall back to CPU so the notebook still runs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

In [None]:
# Experiment run to analyse and which checkpoint epoch to load from it.
EXPERIMENT_DIR = Path("/mnt/data/experiments/EfficientCN/affnist/effcn_affnist_2022_01_18_18_58_59")
CKPT_EPOCH = 150

p_experiment = EXPERIMENT_DIR
p_config = p_experiment / "config.pkl"
p_stats = p_experiment / "stats.pkl"
p_ckpts = p_experiment / "ckpts"

# NOTE: pickle.load can execute arbitrary code -- only load experiment files you trust.
with open(p_config, "rb") as file:
    config = pickle.load(file)
with open(p_stats, "rb") as file:
    stats = pickle.load(file)

p_data = config.paths.data
p_model = p_ckpts / config.names.model_file.format(CKPT_EPOCH)
# Fail here with a clear message rather than later inside torch.load.
assert p_model.exists(), f"checkpoint not found: {p_model}"
p_model

In [None]:
# Dataset root as recorded in the experiment config (shown for reference).
config.paths.data



In [None]:
model = EffCapsNet()
# map_location makes a checkpoint saved on GPU loadable on whatever `device` is.
model.load_state_dict(torch.load(p_model, map_location=device))
model = model.to(device)
model.eval()

In [None]:
# All three splits share the same root and use no transforms.
_ds_kwargs = dict(p_root=p_data, download=True, transform=None, target_transform=None)
ds_mnist_train = AffNIST(split="mnist_train", **_ds_kwargs)
ds_mnist_valid = AffNIST(split="mnist_valid", **_ds_kwargs)
ds_affnist_valid = AffNIST(split="affnist_valid", **_ds_kwargs)
  

In [None]:
idx = 0

# Size of the training split.
print(len(ds_mnist_train.data))

# Fetch one sample and prepend a batch dimension to both image and label.
x, y = ds_mnist_train[idx]
x, y = x.unsqueeze(dim=0), y.unsqueeze(dim=0)

print(x.shape)
print(y.shape, y)

### Generate capsules for influenced reconstruction

In [None]:
x = x.to(device)
# Inference only: call the module (not .forward, so hooks run) without building a graph.
with torch.no_grad():
    uh, _ = model(x)

In [None]:
print(uh.shape)  # capsule tensor; presumably (batch, n_classes, caps_dim) -- see cov_uh_trans

In [None]:
# Mask the capsules (max_norm_masking presumably keeps only the longest one),
# flatten each sample and decode it back to an image.
m_uh = torch.flatten(max_norm_masking(uh), start_dim=1)
x_rec = model.decoder(m_uh)

In [None]:
def imshow(img, cmap="gray", vmin=None, vmax=None):
    """Show a single (C, H, W) image tensor with matplotlib.

    Note: this definition shadows the ``imshow`` imported from
    ``misc.plot_utils`` at the top of the notebook.

    Parameters
    ----------
    img : torch.Tensor
        3-D tensor in (C, H, W) order; it is detached and moved to the CPU.
    cmap : str
        Colormap; matplotlib only applies it to single-channel images.
    vmin, vmax : float or None
        Intensity range for the colormap; ``None`` lets matplotlib autoscale.
    """
    npimg = np.transpose(img.detach().cpu().numpy(), (1, 2, 0))
    if npimg.shape[-1] == 1:
        # Drop the singleton channel so the image is 2-D and `cmap` is honoured.
        npimg = npimg[..., 0]
    # Forward the caller's vmin/vmax instead of discarding them.
    plt.imshow(npimg, cmap=cmap, vmin=vmin, vmax=vmax)
    plt.show()

In [None]:
imshow(x_rec.squeeze(dim=0))  # drop the batch dim before plotting

In [None]:
# Index of the largest entry inside each capsule.
print(torch.argmax(uh, dim=2))

# Length of every class capsule for the (first) sample.
caps_norms = torch.norm(uh[0], dim=1)
plt.bar(np.arange(len(caps_norms)), caps_norms.detach().cpu().numpy())
plt.show()

# Winning capsule = longest one; computed per sample so it stays valid for batch > 1.
i_cap = int(torch.argmax(caps_norms))
n_caps_dims = uh.shape[2]
ref = uh[:, i_cap, :]

# Perturbation 1: +1 on a single capsule dimension.
DELTA_DIM = 11
one_hot = torch.zeros(n_caps_dims, dtype=uh.dtype, device=uh.device)
one_hot[DELTA_DIM] = 1.0
uh_delta = uh.clone()
uh_delta[:, i_cap, :] += one_hot
delta = uh_delta[:, i_cap, :]

# Perturbation 2: small seeded uniform noise in [-0.05, 0.05) on every dimension.
torch.manual_seed(42)
uh_delta2 = uh.clone()
uh_delta2[:, i_cap, :] += (torch.rand(n_caps_dims).to(uh.device) - 0.5) * 0.1
delta2 = uh_delta2[:, i_cap, :]

# Overlay the reference (blue) and the two perturbed capsules (red, green).
for vec, marker in ((ref, "b*"), (delta, "r*"), (delta2, "g*")):
    vec = torch.squeeze(vec, dim=0)
    print(vec)
    plt.plot(vec.detach().cpu().numpy(), marker)
plt.show()

# Stack original + perturbed capsules along the batch dim for joint decoding.
uh_n = torch.cat((uh, uh_delta, uh_delta2), dim=0)
uh_n.size()


In [None]:
# Same masking + decoding as for the single sample, now for original and perturbed capsules.
m_uh_n = max_norm_masking(uh_n).flatten(start_dim=1)
x_rec_n = model.decoder(m_uh_n)

In [None]:
print(x_rec_n.size())
# Detach: the decoder output can carry autograd history, and matplotlib's
# numpy conversion fails on tensors that require grad.
rec = x_rec_n.detach().cpu()
# Min-max normalise the whole batch jointly to [0, 1].
scal = lambda x: (x - x.min()) / (x.max() - x.min())
img = torchvision.utils.make_grid(scal(rec), nrow=rec.shape[0])
plt.imshow(img.permute(1, 2, 0))
plt.show()

### Prepare the affine-transformation embedding

In [None]:
# Sanity check: the identity affine transform (no rotation/shift/shear, unit scale).
x_aff = T.functional.affine(x, angle=0, translate=[0, 0], scale=1.0, shear=0)
imshow(x_aff.squeeze(dim=0))

In [None]:
def _inclusive_steps(rng):
    """Return np.arange over [start, stop] with `stop` included, for rng = (start, stop, step)."""
    start, stop, step = rng
    return np.arange(start, stop + step, step)


def _affine_sweep(img, target, steps, make_params):
    """Apply one affine transform of `img` per value in `steps`.

    Parameters
    ----------
    img : torch.Tensor
        Image batch indexed as [B, C, H, W]; each transformed result is written
        into one slot of the output (so B is expected to be 1).
    target : label assigned to every generated image.
    steps : sequence of step values.
    make_params : callable
        Maps a step value to the ``(angle, translate)`` pair for that step.

    Returns
    -------
    (torch.Tensor, torch.Tensor)
        Stack of shape [len(steps), C, H, W] and labels of shape [len(steps)],
        both created on the CPU with torch's default dtype.
    """
    out = torch.zeros([len(steps), img.shape[1], img.shape[2], img.shape[3]])
    labels = torch.zeros(len(steps))
    for i, v in enumerate(steps):
        angle, translate = make_params(float(v))
        # Transform the `img` argument (not a notebook global).
        out[i] = T.functional.affine(img=img, angle=angle, translate=translate, scale=1., shear=0)
        labels[i] = target
    return out, labels


def affine_xtrans(img, target, range=(-5., 5., 1)):
    """Horizontally translate `img` by every shift in `range` = (start, stop, step), stop inclusive.

    Returns the stacked shifted images and a label tensor filled with `target`.
    (`range` shadows the builtin; the name is kept for backward compatibility.)
    """
    return _affine_sweep(img, target, _inclusive_steps(range), lambda v: (0, [v, 0]))


def affine_ytrans(img, target, range=(-5., 5., 1)):
    """Vertically translate `img` by every shift in `range` = (start, stop, step), stop inclusive.

    Returns the stacked shifted images and a label tensor filled with `target`.
    """
    return _affine_sweep(img, target, _inclusive_steps(range), lambda v: (0, [0, v]))


def affine_rot(img, target, range=(-25., 25., 1)):
    """Rotate `img` by every angle (degrees) in `range` = (start, stop, step), stop inclusive.

    Returns the stacked rotated images and a label tensor filled with `target`.
    """
    return _affine_sweep(img, target, _inclusive_steps(range), lambda v: (v, [0, 0]))

In [None]:
# Preview the vertical-translation sweep of the current sample.
imgs_ytrans, labels_ytrans = affine_ytrans(x, y)
print(imgs_ytrans.shape)
print(labels_ytrans)

imgs_ytrans = imgs_ytrans.cpu()
# Min-max normalise the whole sweep jointly to [0, 1] and tile it in one row.
scal = lambda x: (x - x.min()) / (x.max() - x.min())
grid = torchvision.utils.make_grid(scal(imgs_ytrans), nrow=imgs_ytrans.shape[0])
plt.imshow(grid.permute(1, 2, 0))
plt.show()

imshow(imgs_ytrans[0])


In [None]:
# Build the three transformation sweeps of sample x (the repeated labels are not needed).
x_trans = affine_xtrans(x, y)[0]
y_trans = affine_ytrans(x, y)[0]
rot = affine_rot(x, y)[0]

### ATE on the selected transformation sweep (currently y_trans)

In [None]:
# Which affine sweep to analyse: "x_trans", "y_trans" or "rot".
TRANSFORM = "y_trans"
x_aff = {"x_trans": x_trans, "y_trans": y_trans, "rot": rot}[TRANSFORM].to(device)

# Inference only: call the module (hooks run) and skip graph construction.
with torch.no_grad():
    uh_aff, _ = model(x_aff)

In [None]:
def cov_uh_trans(uh):
    """Covariance of capsule values over a batch of transformed images.

    Parameters
    ----------
    uh : torch.Tensor, shape [k, n, m]
        k -> number of transformed images
        n -> number of output classes
        m -> number of capsule values

    Returns
    -------
    torch.Tensor, shape [m, m]
        (1/k) * sum over images and classes of z^T z, where z is `uh`
        centred by its mean over the image axis.
    """
    # Subtract the per-(class, value) mean taken across the k images.
    centered = uh - uh.mean(dim=0, keepdim=True)
    # Per-image outer products summed over classes -> [k, m, m].
    per_image = torch.einsum('...ij, ...ik -> ...jk', centered, centered)
    return per_image.sum(dim=0) / per_image.shape[0]

In [None]:
c = cov_uh_trans(uh=uh_aff)  # capsule-value covariance across the selected sweep


In [None]:
def covshow(c, fg_color="white"):
    """Plot a 2-D matrix (e.g. a capsule covariance) as a heat map with a colour bar.

    Parameters
    ----------
    c : torch.Tensor
        2-D tensor; it is moved to the CPU and detached before plotting.
    fg_color : str, optional
        Colour of the tick marks, tick labels and colour-bar outline
        (default "white").
    """
    # Layout follows the matplotlib "annotated heatmap" gallery example:
    # https://matplotlib.org/3.1.1/gallery/images_contours_and_fields/image_annotated_heatmap.html
    data = c.cpu().detach().numpy()

    fig, ax = plt.subplots(figsize=(8, 8))
    im = ax.imshow(data)
    im.axes.tick_params(color=fg_color, labelcolor=fg_color)

    cb = fig.colorbar(im)
    cb.ax.yaxis.set_tick_params(color=fg_color)
    cb.outline.set_edgecolor(fg_color)
    plt.setp(plt.getp(cb.ax.axes, 'yticklabels'), color=fg_color)

    fig.tight_layout()
    plt.show()

In [None]:
covshow(c=c)  # heat map of the transformation covariance

### Via PCA (from the EffCN paper)

In [None]:
# `c` is symmetric (a sum of z^T z terms), so eigh gives real eigenvalues
# directly -- eig would return complex values whose imaginary part .float() drops.
cov = c.detach().cpu()
eig, v_eig = torch.linalg.eigh(cov)  # ascending order
# Sort descending so the plot reads like a scree plot (largest component first).
eig, order = torch.sort(eig, descending=True)
v_eig = v_eig[:, order]
# Fraction of the total variance carried by each principal direction.
sig = eig / eig.sum()

print(sig)
plt.plot(sig.numpy(), ".")
plt.tick_params(colors="w")



### Via KL divergence (from the Munich paper)

In [None]:
### !!! Not sure about this !!!

# https://mail.python.org/pipermail/scipy-user/2011-May/029521.html

def KLdivergence(x, y):
    """Estimate the Kullback-Leibler divergence D(P||Q) from two multivariate samples.

    Nearest-neighbour estimator adapted from
    https://mail.python.org/pipermail/scipy-user/2011-May/029521.html

    Parameters
    ----------
    x : 2D array (n, d)
        Samples from distribution P, which typically represents the true
        distribution.
    y : 2D array (m, d)
        Samples from distribution Q, which typically represents the
        approximate distribution.

    Returns
    -------
    out : float
        The estimated Kullback-Leibler divergence D(P||Q).

    References
    ----------
    Pérez-Cruz, F. Kullback-Leibler divergence estimation of continuous
    distributions. IEEE International Symposium on Information Theory, 2008.
    """
    from scipy.spatial import cKDTree

    p_samples = np.atleast_2d(x)
    q_samples = np.atleast_2d(y)
    n, d = p_samples.shape
    m, d_q = q_samples.shape
    # Both samples must live in the same d-dimensional space.
    assert d == d_q

    # Distance from each x_i to its nearest *other* x: k=2 because the first
    # hit is x_i itself at distance 0.
    r = cKDTree(p_samples).query(p_samples, k=2, eps=.01, p=2)[0][:, 1]
    # Distance from each x_i to its nearest y.
    s = cKDTree(q_samples).query(p_samples, k=1, eps=.01, p=2)[0]

    # Eq. 14 of the paper, with the missing minus sign on the first
    # right-hand-side term restored.
    return -np.log(r / s).sum() * d / n + np.log(m / (n - 1.))

In [None]:
def cov_uh_kl(uh):
    """Uncentered second-moment matrix of capsule values.

    For `uh` of shape [k, n, m] this returns the [m, m] matrix
    (1/k) * sum over the k and n axes of uh[k, n, :]^T uh[k, n, :].
    Unlike cov_uh_trans, the mean is NOT subtracted.
    """
    outer = torch.einsum('...ij, ...ik -> ...jk', uh, uh)
    return outer.sum(dim=0) / outer.shape[0]

In [None]:
# Uncentered second-moment matrices of the original and transformed capsules.
c_org = cov_uh_kl(uh.detach().cpu())
c_aff = cov_uh_kl(uh_aff.detach().cpu())

# Determinants
c_org_det = torch.linalg.det(c_org)
c_aff_det = torch.linalg.det(c_aff)

# Traces (kept under their own names so the determinants are not overwritten)
c_org_tr = torch.trace(c_org)
c_aff_tr = torch.trace(c_aff)

print(f"det:   org={c_org_det.item():.4g}  aff={c_aff_det.item():.4g}")
print(f"trace: org={c_org_tr.item():.4g}  aff={c_aff_tr.item():.4g}")

# NOTE(review): KLdivergence treats every ROW of these square matrices as one
# sample; confirm this is the intended use of the estimator.
kl_div = KLdivergence(c_org.numpy(), c_aff.numpy())

print(kl_div)
covshow(c_org)
covshow(c_aff)