## Imports

In [None]:
import logging

import pandas as pd
import torch
import torch.nn.functional as F
from nn_core.common import PROJECT_ROOT
import random

from pathlib import Path
from la.utils.utils import MyDatasetDict

try:
    # be ready for 3.10 when it drops
    from enum import StrEnum
except ImportError:
    from backports.strenum import StrEnum
from pytorch_lightning import seed_everything
import matplotlib.pyplot as plt
import random
from collections import namedtuple
import timm
from transformers import AutoModel, AutoProcessor
from typing import Sequence, List
from PIL.Image import Image
from tqdm import tqdm
import functools
from timm.data import resolve_data_config
from datasets import load_dataset, load_from_disk, Dataset, DatasetDict
import torchvision
import torch
from timm.data import create_transform
from la.utils.cka import CKA

import logging
from typing import Any, Sequence, Tuple, Union, Dict
import seaborn as sns
import plotly
import hydra
import torch
from hydra.utils import instantiate
from nn_core.model_logging import NNLogger
from omegaconf import DictConfig
from torch.optim import Optimizer
import pytorch_lightning as pl
import torch.nn as nn

from la.pl_modules.pl_module import MyLightningModule

pylogger = logging.getLogger(__name__)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.utils.data import DataLoader
from la.utils.utils import add_tensor_column
from datasets import concatenate_datasets

## Data loading

In [None]:
data: MyDatasetDict = MyDatasetDict.load_from_disk("../data/cifar100/partitioned")

metadata = data["metadata"]
num_tasks = metadata["num_tasks"]

# Every task (task_0 .. task_{num_tasks}) is paired with the same anchor set.
shared_anchors = data["anchors"]
for task_ind in range(num_tasks + 1):
    data[f"task_{task_ind}_anchors"] = shared_anchors

In [None]:
# PIL image -> float tensor in [0, 1] -> per-channel standardisation.
# NOTE(review): these mean/std values are the ImageNet statistics, not CIFAR-100's;
# confirm this is intentional.
transform_func = torchvision.transforms.Compose(
    transforms=[
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
# Apply `transform_func` to the "x" column of every split of every task and expose
# "x"/"y" as torch tensors.
for mode in {"train", "val", "test", "anchors"}:
    for task_ind in range(num_tasks + 1):
        key = f"task_{task_ind}_{mode}"
        transformed = data[key].map(lambda x: {"x": transform_func(x["x"])}, batched=False)
        data[key] = transformed
        data[key].set_format("torch", columns=["x", "y"])

## Model definition

In [None]:
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from nn_core.model_logging import NNLogger
from kornia.augmentation import (
    ColorJiggle,
    RandomChannelShuffle,
    RandomHorizontalFlip,
    RandomThinPlateSpline,
    RandomRotation,
    RandomCrop,
    Normalize,
)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


class ShakeShake(torch.autograd.Function):
    """Shake-Shake mixing of two residual branches.

    Forward: ``alpha * x1 + (1 - alpha) * x2``. When ``training`` is true, ``alpha`` is
    drawn from U(0, 1) once per sample (leading dimension). Otherwise ``alpha = 0.5``.

    Backward: in training, the incoming gradient is split with an independently drawn
    per-sample ``beta ~ U(0, 1)``. In eval mode it is split 0.5 / 0.5, which is the
    exact gradient of the deterministic eval-mode forward.

    Random coefficients are created on the device and in the dtype of the input. The op
    therefore runs on CPU as well as CUDA, and does not up-cast half-precision tensors.
    """

    @staticmethod
    def forward(ctx, x1, x2, training=True):
        ctx.is_training = training
        alpha = ShakeShake._per_sample_uniform(x1) if training else 0.5
        return alpha * x1 + (1 - alpha) * x2

    @staticmethod
    def backward(ctx, grad_output):
        beta = ShakeShake._per_sample_uniform(grad_output) if ctx.is_training else 0.5
        # The third input (`training` flag) is not a tensor and gets no gradient.
        return beta * grad_output, (1 - beta) * grad_output, None

    @staticmethod
    def _per_sample_uniform(ref):
        """Return U(0, 1) noise with one value per leading-dim sample, broadcast to ``ref``'s shape."""
        coeff = torch.rand(ref.size(0), device=ref.device, dtype=ref.dtype)
        return coeff.view(-1, *([1] * (ref.dim() - 1))).expand_as(ref)


class Shortcut(nn.Module):
    """Projection shortcut for ShakeBlocks whose width and/or resolution changes.

    The ReLU'd input goes through two 1x1 convolutions, each producing ``out_ch // 2``
    channels, after subsampling with ``stride``. The second path sees the input
    shifted by one pixel. The two halves are concatenated along channels and
    batch-normalised.
    """

    def __init__(self, in_ch, out_ch, stride):
        super().__init__()
        self.stride = stride
        half_ch = out_ch // 2
        self.conv1 = nn.Conv2d(in_ch, half_ch, 1, stride=1, padding=0, bias=False)
        self.conv2 = nn.Conv2d(in_ch, half_ch, 1, stride=1, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        act = F.relu(x)
        # Padding (-1, 1, -1, 1) drops the first column/row and zero-pads one at the
        # right/bottom, i.e. shifts the feature map by one pixel towards the top-left.
        shifted = F.pad(act, (-1, 1, -1, 1))
        # avg_pool2d with kernel 1 just subsamples by `stride`.
        halves = [
            self.conv1(F.avg_pool2d(act, 1, self.stride)),
            self.conv2(F.avg_pool2d(shifted, 1, self.stride)),
        ]
        return self.bn(torch.cat(halves, 1))

In [None]:
import math


class ShakeBlock(nn.Module):
    """Shake-Shake residual block: two identical conv branches mixed by ``ShakeShake``.

    Args:
        in_ch: input channels.
        out_ch: output channels of both branches (and of the block).
        stride: stride of the first conv in each branch and of the projection shortcut.
    """

    def __init__(self, in_ch, out_ch, stride=1):
        super(ShakeBlock, self).__init__()
        # An identity shortcut is only shape-compatible when neither the channel count
        # nor the spatial resolution changes.
        self.equal_io = in_ch == out_ch and stride == 1
        # Only allocate the projection when forward actually uses it (the idiom
        # `cond and None or X` always yields X, because None is falsy).
        self.shortcut = None if self.equal_io else Shortcut(in_ch, out_ch, stride=stride)

        self.branch1 = self._make_branch(in_ch, out_ch, stride)
        self.branch2 = self._make_branch(in_ch, out_ch, stride)

    def forward(self, x):
        h1 = self.branch1(x)
        h2 = self.branch2(x)
        h = ShakeShake.apply(h1, h2, self.training)
        h0 = x if self.equal_io else self.shortcut(x)
        return h + h0

    def _make_branch(self, in_ch, out_ch, stride=1):
        """Build one ReLU-conv3x3(stride)-BN-ReLU-conv3x3-BN branch."""
        return nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(in_ch, out_ch, 3, padding=1, stride=stride, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=False),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(out_ch),
        )


class ShakeResNet(nn.Module):
    """Shake-Shake ResNet: a 3x3 conv stem followed by three stages of ShakeBlocks.

    Args:
        depth: network depth; each stage gets ``int((depth - 2) / 6)`` ShakeBlocks.
        w_base: width of the first stage; the stages have ``w_base``, ``2 * w_base``
            and ``4 * w_base`` channels, and stages 2 and 3 halve the resolution.
        label: number of output classes.

    ``forward`` returns a dict with:
        ``"logits"``: classifier output computed from ``relu(embeds)``;
        ``"embeds"``: globally average-pooled stage-3 features (before the ReLU);
        ``"outputs"``: ``[input, stem, stage1, stage2, stage3, logits]`` activations,
        the same key the ResNetModule backbone exposes for layer-wise analysis.
    """

    def __init__(self, depth, w_base, label):
        super(ShakeResNet, self).__init__()
        n_units = (depth - 2) / 6

        in_chs = [16, w_base, w_base * 2, w_base * 4]
        self.in_chs = in_chs

        self.c_in = nn.Conv2d(3, in_chs[0], 3, padding=1)
        self.layer1 = self._make_layer(n_units, in_chs[0], in_chs[1])
        self.layer2 = self._make_layer(n_units, in_chs[1], in_chs[2], 2)
        self.layer3 = self._make_layer(n_units, in_chs[2], in_chs[3], 2)
        self.fc_out = nn.Linear(in_chs[3], label)

        # Initialise parameters: conv weights ~ N(0, 2 / (k*k*out_channels)) (He, fan-out),
        # BatchNorm to identity, Linear bias to zero (Linear weight keeps its default init).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x):
        outputs = [x]
        h = self.c_in(x)
        outputs.append(h)
        for layer in (self.layer1, self.layer2, self.layer3):
            h = layer(h)
            outputs.append(h)

        # Global average pooling over whatever resolution remains (8x8 for 32x32
        # inputs). A fixed 8x8 kernel followed by view(-1, C) would silently move
        # spatial positions into the batch dimension for any other input size.
        embeds = torch.flatten(F.adaptive_avg_pool2d(h, 1), 1)

        logits = self.fc_out(F.relu(embeds))
        outputs.append(logits)
        return {"logits": logits, "embeds": embeds, "outputs": outputs}

    def _make_layer(self, n_units, in_ch, out_ch, stride=1):
        """Stack ``int(n_units)`` ShakeBlocks; only the first one changes width/stride."""
        layers = []
        for _ in range(int(n_units)):
            layers.append(ShakeBlock(in_ch, out_ch, stride=stride))
            in_ch, stride = out_ch, 1
        return nn.Sequential(*layers)

In [None]:
class DataAugmentation(nn.Module):
    """Module to perform data augmentation using Kornia on batched torch tensors.

    Pipeline: horizontal flip (p=0.5), random rotation (up to 30 degrees),
    random crop back to ``input_dim`` x ``input_dim``, colour jitter (p=0.5).

    Args:
        input_dim: side length of the square input images. The crop outputs this size.
        crop_padding: pixels of zero padding added on each side before the random
            crop. Cropping an unpadded ``input_dim`` image to ``input_dim`` leaves only
            one possible offset, i.e. the crop is a no-op. Pass ``None`` to reproduce
            that crop-free behaviour.
    """

    def __init__(self, input_dim, crop_padding: Union[int, None] = 4) -> None:
        super().__init__()

        self.transforms = nn.Sequential(
            RandomHorizontalFlip(p=0.5),
            RandomRotation(degrees=30),
            RandomCrop((input_dim, input_dim), padding=crop_padding),
            ColorJiggle(0.2, 0.2, 0.2, 0.2, p=0.5),
        )

    def forward(self, x: Tensor) -> Tensor:
        x_out = self.transforms(x)  # BxCxHxW
        return x_out

In [None]:
class ResNet(MyLightningModule):
    """LightningModule wrapper around an image backbone passed in as ``model``.

    ``forward`` returns whatever the backbone returns, unchanged. In this notebook the
    backbones return dicts with "logits" and "embeds" entries.
    Optimisation: SGD (lr 0.1, momentum 0.9, weight decay 1e-4) with cosine annealing
    (``T_max=200``).

    NOTE(review): `train_model` trains for at most 100 epochs. If the scheduler steps
    per epoch, the cosine schedule only reaches its midpoint. Confirm T_max is intended.
    """

    logger: NNLogger

    def __init__(self, num_classes, input_dim, model, *args, **kwargs) -> None:
        super().__init__(*args, num_classes=num_classes, input_dim=input_dim, **kwargs)

        # Record constructor arguments in self.hparams (not sent to the logger),
        # excluding any "metadata" argument.
        self.save_hyperparameters(logger=False, ignore=("metadata",))

        self.model = model
        # Kornia augmentation pipeline for `input_dim`-sized images; presumably used by
        # MyLightningModule's step logic (not visible here).
        self.data_augm = DataAugmentation(input_dim)

    def forward(self, x: torch.Tensor) -> Dict:
        """Run the wrapped backbone on ``x``.

        'training_step', 'validation_step' and 'test_step' should call this method to
        obtain the predictions used for the loss.

        Returns:
            The backbone's output, unchanged (e.g. a dict of logits/embeddings).
        """
        return self.model(x)

    def configure_optimizers(
        self,
    ) -> Union[Optimizer, Tuple[Sequence[Optimizer], Sequence[Any]]]:
        """Create the optimizer and learning-rate scheduler used by Lightning.

        Returns:
            A dict with an "optimizer" (SGD) and an "lr_scheduler" (CosineAnnealingLR).
        """
        sgd = torch.optim.SGD(self.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
        cosine = torch.optim.lr_scheduler.CosineAnnealingLR(sgd, T_max=200)
        return dict(optimizer=sgd, lr_scheduler=cosine)

In [None]:
import copy


def _weights_init(m):
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)


class LambdaLayer(nn.Module):
    """Wrap an arbitrary callable so it can be used as an ``nn.Module`` layer."""

    def __init__(self, lambd):
        super().__init__()
        # The wrapped callable; applied as-is in forward.
        self.lambd = lambd

    def forward(self, x):
        fn = self.lambd
        return fn(x)


class BasicBlock(nn.Module):
    """Residual block conv3x3-BN-ReLU-conv3x3-BN plus shortcut, recording activations.

    Args:
        in_planes: input channels.
        planes: output channels.
        stride: stride of the first conv and of the projection shortcut.
        option: unused; kept for signature compatibility.
        input_dim: unused; kept for signature compatibility (it sized a LayerNorm
            alternative to BatchNorm).
        outputs: optional list shared with the parent network. If given, forward
            appends ``[bn1 output (pre-ReLU), block output]`` to it. ``None``
            disables recording.

    Note: no ReLU is applied after the residual addition inside the block.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option="A", input_dim=32, outputs=None):
        super(BasicBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.outputs = outputs

        # Identity shortcut unless the resolution or channel count changes, in which
        # case a strided 1x1 conv + BN projects the input.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes),
            )

    def forward(self, x):
        conv_out1 = self.bn1(self.conv1(x))
        out1 = F.relu(conv_out1)

        conv_out2 = self.bn2(self.conv2(out1))
        out = conv_out2 + self.shortcut(x)

        # Recording is optional: a block built with the default outputs=None must
        # still be usable on its own.
        if self.outputs is not None:
            self.outputs.extend([conv_out1, out])

        return out


class ResNetModule(nn.Module):
    """CIFAR-style ResNet backbone (stem + three stages of ``block``) recording activations.

    Args:
        block: residual block class, constructed as
            ``block(in_planes, planes, stride, input_dim=..., outputs=...)``. It must
            expose an ``expansion`` attribute.
        num_blocks: number of blocks in each of the three stages.
        num_classes: size of the classifier output.

    ``forward`` returns a dict with:
        ``"embeds"``: globally average-pooled stage-3 features;
        ``"logits"``: classifier output computed from ``relu(embeds)``;
        ``"outputs"``: a snapshot list containing the input, the stem output,
        everything the blocks appended during this call, and the logits.
    """

    def __init__(self, block, num_blocks, num_classes=100):
        super(ResNetModule, self).__init__()
        self.in_planes = 16

        self.layer0 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(16),
        )

        # Shared with every block, which append their activations to it during forward.
        self.outputs = []

        self.layer1 = self._make_layer(block, planes=16, num_blocks=num_blocks[0], stride=1, input_dim=32)
        self.layer2 = self._make_layer(block, planes=32, num_blocks=num_blocks[1], stride=2, input_dim=16)
        self.layer3 = self._make_layer(block, planes=64, num_blocks=num_blocks[2], stride=2, input_dim=8)

        # After _make_layer, in_planes is the stage-3 width (64 * block.expansion).
        self.linear = nn.Linear(self.in_planes, num_classes)

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride, input_dim):
        """Build one stage; only its first block uses ``stride``."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride, input_dim=input_dim, outputs=self.outputs))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        # clear() (not rebinding) keeps the list object the blocks hold a reference to.
        self.outputs.clear()

        # (B, C, H, W)
        self.outputs.append(x)

        out = self.layer0(x)
        self.outputs.append(out)
        out = F.relu(out)

        out = F.relu(self.layer1(out))
        out = F.relu(self.layer2(out))
        out = self.layer3(out)

        # Global average pooling (also correct for non-square feature maps).
        embeds = torch.flatten(F.adaptive_avg_pool2d(out, 1), 1)

        # The classifier consumes the rectified embedding, as in ShakeResNet.
        logits = self.linear(F.relu(embeds))
        self.outputs.append(logits)

        # Return a copy: self.outputs is cleared and refilled by the next forward call,
        # which would otherwise silently change previously returned results.
        return {"embeds": embeds, "logits": logits, "outputs": list(self.outputs)}


def resnet20(num_classes=100):
    """Build a ResNet-20 backbone: stem conv, three stages of three BasicBlocks, linear head.

    Returns the plain ``nn.Module`` backbone (``ResNetModule``). The Lightning wrapper
    ``ResNet`` takes ``(num_classes, input_dim, model)``, so it cannot be built from
    ``(block, num_blocks)`` directly. Wrap the result with
    ``ResNet(num_classes=..., input_dim=..., model=...)`` to train it.
    """
    return ResNetModule(BasicBlock, [3, 3, 3], num_classes=num_classes)

## Training, embedding and relative-projection helpers

In [None]:
from pytorch_lightning.callbacks.early_stopping import EarlyStopping


def train_model(train_data, val_data, anchors, test_data, seed, *, max_epochs=100, batch_size=128, num_workers=8):
    """Train a ShakeResNet-20 (CIFAR-100 head) on ``train_data`` plus ``anchors`` and evaluate it.

    Args:
        train_data: training split (HF dataset with torch-formatted "x"/"y").
        val_data: validation split, monitored by early stopping on "loss/val".
        anchors: anchor samples. They are concatenated to the training data, so the
            model is trained on them too.
        test_data: split passed to ``trainer.test`` after fitting.
        seed: global seed set via ``seed_everything`` before building the model.
        max_epochs: upper bound on training epochs (early stopping may end sooner).
        batch_size: batch size for every DataLoader.
        num_workers: worker processes for every DataLoader.

    Returns:
        The trained ``ResNet`` LightningModule, switched to eval mode and moved to CUDA.
    """
    seed_everything(seed)

    # Alternative backbone: ResNetModule(BasicBlock, [3, 3, 3], num_classes=100)
    module = ShakeResNet(depth=20, w_base=16, label=100)
    model = ResNet(num_classes=100, model=module, input_dim=32, transform_func=None)

    def _loader(dataset, shuffle):
        """DataLoader with this run's batch size and worker count."""
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

    val_dataloader = _loader(val_data, shuffle=False)
    train_dataloader = _loader(concatenate_datasets([train_data, anchors]), shuffle=True)

    trainer = pl.Trainer(
        gpus=1,
        max_epochs=max_epochs,
        precision=32,
        callbacks=[EarlyStopping(monitor="loss/val", patience=10, mode="min")],
    )

    trainer.fit(model, train_dataloader, val_dataloader)
    trainer.test(model, _loader(test_data, shuffle=False))

    model.eval()
    return model.cuda()

In [None]:
def embed_samples(model, test_data, anchors, *, batch_size=128, num_workers=8):
    """Embed the test set and the anchor set with ``model``.

    Args:
        model: module whose forward returns a dict with an ``"embeds"`` tensor.
        test_data: dataset yielding dicts with an ``"x"`` tensor.
        anchors: dataset yielding dicts with an ``"x"`` tensor.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker processes.

    Returns:
        ``(test_embeds, anchor_embeds)``: detached embeddings concatenated along
        dim 0, in dataset order (no shuffling).

    The model's current train/eval mode is used as-is. Inputs are moved to the
    device of the model's parameters, falling back to CUDA if that cannot be
    determined.
    """
    test_embeds = _embed_dataset(model, test_data, batch_size, num_workers)
    anchor_embeds = _embed_dataset(model, anchors, batch_size, num_workers)
    return test_embeds, anchor_embeds


def _embed_dataset(model, dataset, batch_size, num_workers):
    """Run ``model`` over ``dataset`` without autograd and concatenate the ``"embeds"`` outputs."""
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda")

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    chunks = []
    # Inference only: no_grad avoids building (and holding) an autograd graph per batch.
    with torch.no_grad():
        for batch in loader:
            chunks.append(model(batch["x"].to(device))["embeds"].detach())
    return torch.cat(chunks)

In [None]:
def compute_relatives(test_embeds, anchor_embeds, num_anchors, seed):
    """Project embeddings into relative space: cosine similarity to each anchor.

    Only the first ``num_anchors`` rows of ``anchor_embeds`` are used. Both inputs are
    L2-normalised along the last dimension, and the result is
    ``normalised(test_embeds) @ normalised(anchors).T``. For 2-D inputs that is a
    (num_test, num_anchors) matrix.

    Side effect: calls ``seed_everything(seed)``, which reseeds the global RNGs. The
    computation itself involves no randomness.
    """
    seed_everything(seed)

    unit_anchors = F.normalize(anchor_embeds[:num_anchors], p=2, dim=-1)
    unit_samples = F.normalize(test_embeds, p=2, dim=-1)
    return torch.matmul(unit_samples, unit_anchors.T)

## Same data, different seed

In [None]:
model1_same = train_model(
    train_data=data["task_1_train"],
    val_data=data["task_1_val"],
    anchors=data["task_1_anchors"],
    test_data=data["task_0_test"],
    seed=0,
)

In [None]:
# Same data as model1_same; only the seed differs.
model2_same = train_model(
    train_data=data["task_1_train"],
    val_data=data["task_1_val"],
    anchors=data["task_1_anchors"],
    test_data=data["task_0_test"],
    seed=42,
)

In [None]:
# Number of anchors spanning the relative space, and the seed passed to compute_relatives.
num_anchors, seed = 256, 0

In [None]:
# Relative representation of the shared test set under model1_same.
test_embeds, anchor_embeds = embed_samples(
    model=model1_same,
    test_data=data["task_0_test"],
    anchors=data["task_1_anchors"],
)
rel_space = compute_relatives(
    test_embeds=test_embeds, anchor_embeds=anchor_embeds, num_anchors=num_anchors, seed=seed
)
test_data1 = add_tensor_column(data["task_0_test"], "relative_embeddings", rel_space)

In [None]:
# Relative representation of the shared test set under model2_same.
test_embeds, anchor_embeds = embed_samples(
    model=model2_same,
    test_data=data["task_0_test"],
    anchors=data["task_1_anchors"],
)
rel_space = compute_relatives(
    test_embeds=test_embeds, anchor_embeds=anchor_embeds, num_anchors=num_anchors, seed=seed
)
test_data2 = add_tensor_column(data["task_0_test"], "relative_embeddings", rel_space)

In [None]:
# Expose the relative embeddings and ids as tensors, then order both datasets by id
# so that row k of each refers to the same test sample.
for ds in (test_data1, test_data2):
    ds.set_format("torch", columns=["relative_embeddings", "id"])
test_data1, test_data2 = test_data1.sort("id"), test_data2.sort("id")
assert torch.all(test_data1["id"] == test_data2["id"])

In [None]:
# Linear CKA between the two models' relative spaces (computed on the GPU).
cka = CKA(mode="linear", device="cuda")
rel_a, rel_b = test_data1["relative_embeddings"], test_data2["relative_embeddings"]
cka_score = cka(rel_a, rel_b)
print(cka_score)

## Investigate block structure

In [None]:
# Shuffled loader over the test split; one batch of 512 serves as the CKA probe set.
dataloader = DataLoader(dataset=data["task_0_test"], batch_size=512, shuffle=True, num_workers=0)

In [None]:
import numpy as np


def CKA_matrix(embeds1, embeds2, cka_fn=None):
    """Pairwise CKA similarity between two lists of layer activations.

    Args:
        embeds1: sequence of tensors, one per layer, with the sample axis first.
            Each is flattened to (num_samples, features).
        embeds2: same for the second network.
        cka_fn: callable ``(a, b) -> scalar`` similarity. Defaults to the
            notebook-global ``cka`` object.

    Returns:
        ``np.ndarray`` of shape ``(len(embeds1), len(embeds2))``. Rows are in reverse
        layer order: row ``n1 - 1 - i`` holds layer ``i`` of ``embeds1``, so layer 0
        ends up at the bottom of a heatmap. Column ``j`` holds layer ``j`` of
        ``embeds2``.
    """
    if cka_fn is None:
        cka_fn = cka

    n1, n2 = len(embeds1), len(embeds2)
    # Rows index embeds1 and columns index embeds2, so n1 != n2 is handled correctly.
    cka_score_matrix = np.zeros((n1, n2))

    # Flatten each activation once rather than once per pair.
    flat1 = [e.reshape(e.size(0), -1) for e in embeds1]
    flat2 = [e.reshape(e.size(0), -1) for e in embeds2]

    # Pure analysis: no autograd graph needed for the similarity computations.
    with torch.no_grad():
        for i, layer_i_embed in enumerate(flat1):
            for j, layer_j_embed in enumerate(flat2):
                cka_score_matrix[n1 - 1 - i, j] = float(cka_fn(layer_i_embed, layer_j_embed))

    return cka_score_matrix

In [None]:
# make a pretty heatmap with colorbar
import seaborn as sns
import matplotlib.pyplot as plt


def plot_heatmap(heatmap, with_values=False):
    """Show a layer-by-layer CKA matrix as a seaborn heatmap.

    Args:
        heatmap: 2-D array-like of scores laid out like ``CKA_matrix``'s output.
            Row 0 is the highest layer index, so the y axis is labelled from
            ``n_rows - 1`` (top) down to ``0`` (bottom).
        with_values: if true, write each score (2 decimals) inside its cell.
    """
    heatmap = np.asarray(heatmap)
    n_rows, n_cols = heatmap.shape

    fig, ax = plt.subplots(figsize=(10, 10))
    # Give seaborn the tick labels up front instead of overriding them afterwards.
    # With its default "auto" labels, seaborn may thin out ticks on larger matrices.
    # A later set_yticklabels with one label per row would then mislabel rows, or
    # raise a tick/label count mismatch in recent matplotlib.
    sns.heatmap(
        heatmap,
        ax=ax,
        cmap="rocket",
        square=True,
        cbar_kws={"label": "CKA Score"},
        xticklabels=[str(j) for j in range(n_cols)],
        yticklabels=[str(i) for i in range(n_rows - 1, -1, -1)],
    )
    if with_values:
        for i in range(n_rows):
            for j in range(n_cols):
                ax.text(j + 0.5, i + 0.5, f"{heatmap[i, j]:.2f}", ha="center", va="center", color="w")
    ax.set_xlabel("Layer")
    ax.set_ylabel("Layer")
    ax.set_title("CKA Score Matrix")
    plt.show()

In [None]:
# Draw a single shuffled batch and move it to the GPU; it is the probe set for the
# layer-wise CKA comparisons below.
batch = next(iter(dataloader))
x, y = batch["x"].cuda(), batch["y"].cuda()

### Model with itself, then against the same-data model trained with a different seed

In [None]:
# Layer-vs-layer CKA of model1_same with itself.
# NOTE(review): requires a backbone whose output dict has an "outputs" list of
# per-layer activations.
embeds1 = model1_same(x)["outputs"]
cka_score_matrix = CKA_matrix(embeds1, embeds2=embeds1)
print(cka_score_matrix)

In [None]:
plot_heatmap(heatmap=cka_score_matrix)  # with_values defaults to False

In [None]:
# Layer-vs-layer CKA between the two same-data models (different seeds).
embeds1, embeds2 = (net(x)["outputs"] for net in (model1_same, model2_same))

cka_score_matrix = CKA_matrix(embeds1, embeds2)
print(cka_score_matrix)

In [None]:
plot_heatmap(heatmap=cka_score_matrix)  # with_values defaults to False

## Different splits (same seed, different training tasks)

In [None]:
# train_model requires a validation split (used for early stopping) before the anchors.
model1 = train_model(
    data["task_1_train"], data["task_1_val"], data["task_1_anchors"], data["task_0_test"], seed=42
)

In [None]:
# train_model requires a validation split (used for early stopping) before the anchors.
model2 = train_model(
    data["task_2_train"], data["task_2_val"], data["task_2_anchors"], data["task_0_test"], seed=42
)

In [None]:
# Relative representation of the shared test set under model1 (trained on task 1).
test_embeds, anchor_embeds = embed_samples(
    model=model1,
    test_data=data["task_0_test"],
    anchors=data["task_1_anchors"],
)
rel_space = compute_relatives(
    test_embeds=test_embeds, anchor_embeds=anchor_embeds, num_anchors=num_anchors, seed=seed
)
test_data1 = add_tensor_column(data["task_0_test"], "relative_embeddings", rel_space)

In [None]:
# Relative representation of the shared test set under model2 (trained on task 2).
test_embeds, anchor_embeds = embed_samples(
    model=model2,
    test_data=data["task_0_test"],
    anchors=data["task_2_anchors"],
)
rel_space = compute_relatives(
    test_embeds=test_embeds, anchor_embeds=anchor_embeds, num_anchors=num_anchors, seed=seed
)
test_data2 = add_tensor_column(data["task_0_test"], "relative_embeddings", rel_space)

In [None]:
# Keep "id" in the torch format so the row alignment of the two datasets can be
# checked after sorting; CKA compares rows pairwise, so misalignment would
# silently corrupt the score.
test_data1.set_format("torch", columns=["relative_embeddings", "id"])
test_data2.set_format("torch", columns=["relative_embeddings", "id"])
test_data1 = test_data1.sort("id")
test_data2 = test_data2.sort("id")
assert torch.all(test_data1["id"] == test_data2["id"])

In [None]:
from la.utils.cka import CKA

# Linear CKA between the relative spaces of the models trained on different tasks.
cka = CKA(mode="linear", device="cuda")
rel_a, rel_b = test_data1["relative_embeddings"], test_data2["relative_embeddings"]
cka_score = cka(rel_a, rel_b)
print(cka_score)

In [None]:
# Repeat the comparison after explicitly mean-centring each feature.
# NOTE(review): linear CKA normally centres its inputs internally, in which case this
# should reproduce the previous score. Confirm against la.utils.cka.CKA.
centered1, centered2 = (
    emb - emb.mean(dim=0, keepdim=True)
    for emb in (test_data1["relative_embeddings"], test_data2["relative_embeddings"])
)

cka_score = cka(centered1, centered2)
print(cka_score)