# Analysing our trainable prototypes

In [1]:
import torch
from utils.sup_finetuning import Classifier
import einops

In [2]:
c = Classifier(512, 5)

In [3]:
x = torch.randn(5, 5, 512)

In [49]:
protos = x.mean(1)

In [5]:
x = einops.rearrange(x, "c s e -> (c s) e")
x.shape

torch.Size([25, 512])

In [6]:
c.init_params_from_prototypes(x, 5, 5)

In [7]:
c.fc.weight

Parameter containing:
tensor([[ 0.5491,  1.0450,  0.0804,  ...,  0.5097, -0.7651,  0.5821],
        [-2.0117, -1.7803,  0.5387,  ...,  0.9011,  1.0330,  0.6310],
        [-1.0230, -2.1573,  0.2695,  ...,  0.6123,  1.1428, -0.6075],
        [-0.4493,  0.5694,  1.0883,  ..., -0.2787,  0.0188,  0.2655],
        [ 0.0794, -2.2574, -1.1109,  ..., -0.2563, -0.1068, -0.9108]],
       requires_grad=True)

In [8]:
einops.rearrange(x, "(c s) e -> c s e", c=5, s=5).mean(1) * 2

tensor([[ 0.5491,  1.0450,  0.0804,  ...,  0.5097, -0.7651,  0.5821],
        [-2.0117, -1.7803,  0.5387,  ...,  0.9011,  1.0330,  0.6310],
        [-1.0230, -2.1573,  0.2695,  ...,  0.6123,  1.1428, -0.6075],
        [-0.4493,  0.5694,  1.0883,  ..., -0.2787,  0.0188,  0.2655],
        [ 0.0794, -2.2574, -1.1109,  ..., -0.2563, -0.1068, -0.9108]])

In [11]:
m = torch.nn.utils.weight_norm(c.fc)

In [12]:
m.weight_g

Parameter containing:
tensor([[18.8429],
        [20.0234],
        [19.0432],
        [20.5469],
        [22.1594]], requires_grad=True)

In [15]:
m.weight_v

Parameter containing:
tensor([[ 0.5491,  1.0450,  0.0804,  ...,  0.5097, -0.7651,  0.5821],
        [-2.0117, -1.7803,  0.5387,  ...,  0.9011,  1.0330,  0.6310],
        [-1.0230, -2.1573,  0.2695,  ...,  0.6123,  1.1428, -0.6075],
        [-0.4493,  0.5694,  1.0883,  ..., -0.2787,  0.0188,  0.2655],
        [ 0.0794, -2.2574, -1.1109,  ..., -0.2563, -0.1068, -0.9108]],
       requires_grad=True)

In [27]:
import torch
import pytorch_lightning as pl
from protoclr_obow import PCLROBoW
from clr_gat import CLRGAT
from feature_extractors import feature_extractor
from omegaconf import OmegaConf
from dataloaders import UnlabelledDataset, UnlabelledDataModule
import seaborn as sns
import matplotlib.pyplot as plt
import umap

%matplotlib inline
%load_ext autoreload

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
OmegaConf.register_new_resolver("uuid", lambda : "123")

In [17]:
pl.seed_everything(72)

Global seed set to 72


72

In [30]:
sys.modules["bow.feature_extractor"] = feature_extractor

In [31]:
ckp = torch.load("ckpts/cnn.ckpt", map_location="cpu")

In [32]:
state_keys = list(ckp["state_dict"].keys())

In [33]:
for _, k in enumerate(state_keys):
    newkey = f"model.{k}"
    ckp["state_dict"][newkey] = ckp["state_dict"].pop(k)

In [34]:
from clr_gat import CLRGAT

In [35]:
from omegaconf import OmegaConf
OmegaConf.register_new_resolver("uuid", lambda : "123")

ValueError: resolver 'uuid' is already registered

In [36]:
ckp["hyper_parameters"].pop("inner_lr")

0.001

In [37]:
ckp["hyper_parameters"]["mpnn_dev"] = "cpu"
model = CLRGAT(**ckp["hyper_parameters"])

MultiHeadDotProduct


In [39]:
model.load_state_dict(ckp["state_dict"])

<All keys matched successfully>

In [41]:
from dataloaders import get_episode_loader

In [42]:
dl = get_episode_loader("miniimagenet", "../data", 5, 5, 15, 1,"train",)

Supervised data loader for miniimagenet:train.


In [43]:
xs = next(iter(dl))

In [44]:
x = xs["train"][0].squeeze(0)
y = xs["train"][1].squeeze(0)

In [45]:
x.shape, y.shape

(torch.Size([25, 3, 84, 84]), torch.Size([25]))

In [47]:
from utils.sk_finetuning import sinkhorned_finetuning, euclidean_distance, SK

In [65]:
z = model(x)

In [67]:
c = Classifier(1600, 5)

In [68]:
c.init_params_from_prototypes(z, 5, 5)

In [69]:
m = torch.nn.utils.weight_norm(c.fc)

In [70]:
m.weight_v

Parameter containing:
tensor([[-0.0213, -0.0219, -0.0050,  ...,  0.0241, -0.0722,  0.0190],
        [ 0.0060,  0.0049, -0.0036,  ..., -0.0104, -0.0042,  0.0059],
        [ 0.0266, -0.0354,  0.0170,  ...,  0.0504, -0.0120, -0.0091],
        [-0.0247,  0.0256, -0.0163,  ...,  0.0220, -0.0342, -0.0278],
        [ 0.0060,  0.0018, -0.0174,  ..., -0.0229, -0.0622,  0.0200]],
       requires_grad=True)

In [75]:
protos = m.weight_v
protos.shape

torch.Size([5, 1600])

In [77]:
protos = einops.rearrange(protos, "p e -> 1 p e")
z = z.unsqueeze(0)

In [78]:
dists = euclidean_distance(protos, z).squeeze(0)

In [79]:
sk = SK()

In [81]:
scores = sk(dists.t())
scores, scores.shape

(tensor([[2.6501e-09, 4.9425e-04, 9.9902e-01, 1.1603e-04, 3.7000e-04],
         [5.6937e-11, 2.3512e-01, 2.8041e-03, 7.3721e-01, 2.4865e-02],
         [2.1959e-07, 1.3729e-01, 2.1495e-01, 2.8295e-02, 6.1947e-01],
         [5.1008e-08, 5.3235e-01, 4.5724e-01, 4.0467e-06, 1.0399e-02],
         [2.2126e-12, 8.7157e-01, 7.3341e-02, 5.5048e-02, 3.7577e-05],
         [7.5121e-02, 1.7411e-08, 3.1844e-02, 8.9302e-01, 1.6658e-05],
         [9.8494e-01, 1.2671e-07, 1.3343e-02, 5.5661e-04, 1.1648e-03],
         [2.8721e-01, 1.5998e-07, 2.1192e-01, 5.0087e-01, 1.8731e-06],
         [5.0585e-01, 5.4965e-08, 8.5483e-02, 4.3914e-04, 4.0823e-01],
         [5.1849e-03, 6.5397e-09, 4.2641e-02, 1.2462e-06, 9.5217e-01],
         [2.3823e-01, 3.1445e-01, 2.4392e-14, 1.6782e-01, 2.7950e-01],
         [7.7682e-01, 3.6804e-02, 8.6866e-09, 1.8519e-01, 1.1866e-03],
         [2.0936e-01, 7.0634e-01, 2.3060e-08, 1.3723e-04, 8.4169e-02],
         [9.5291e-01, 6.6465e-03, 8.6761e-09, 2.9142e-02, 1.1304e-02],
      

In [83]:
_, preds = torch.max(scores, dim=1)
preds

tensor([2, 3, 4, 1, 1, 3, 0, 3, 0, 4, 1, 0, 1, 0, 4, 4, 1, 2, 4, 1, 2, 2, 0, 3,
        3])

In [87]:
torch.unique(preds, return_counts=True)

(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 4, 5, 5]))

In [90]:
y

tensor([2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 4, 4, 4, 4,
        4])