# Install

In [None]:
!pip install einops datasets jaxtyping better_abc fancy_einsum wandb netcal

# Setup

In [None]:
import sys
from google.colab import drive
# Mount Google Drive inside the Colab runtime (force_remount refreshes an existing mount).
drive.mount('/content/drive', force_remount=True)
# Project root on Drive; it is appended to sys.path, presumably so the
# project-local `reprshift` imports further down resolve -- confirm the package lives here.
path_to_root = '/content/drive/My Drive/Colab Notebooks/BatuEl_Dissertation'
sys.path.append(path_to_root)
print("Drive mounted.")

# Directory under the project root that holds the datasets.
data_path = path_to_root + '/data'

In [None]:
import torch
import tqdm
from reprshift.learning.algorithms import ERM
from reprshift.models.hparams import hparams_f
from reprshift.dataset.datasets import MultiNLI, CivilComments
from reprshift.dataset.dataloaders import InfiniteDataLoader, FastDataLoader

from reprshift.models.model_param_maps import ERM_to_HookedEncoder, load_focal, load_groupdro, load_jtt, load_lff
from reprshift.models.HookedEncoderConfig import bert_config

from transformer_lens2 import HookedEncoder, HookedTransformerConfig
import numpy as np

# Dataset

In [None]:
DATASET = 'CivilComments'  # 'CivilComments' , 'MultiNLI'

# Select the per-dataset constants and the Drive folders that hold the trained
# models and the extracted representations for the chosen dataset.
if DATASET == 'MultiNLI':
    NUM_CLASSES = 3
    NUM_ATTRIBUTES = 2
    # train_dataset = MultiNLI(data_path, 'tr', hparams)
    # val_dataset = MultiNLI(data_path, 'va', hparams=hparams_f('ERM'))
    # te_dataset = MultiNLI(data_path, 'te', hparams=hparams_f('ERM'))
    models_path = path_to_root + '/models/models_mnli'
    representations_path = path_to_root + '/representations/representations_mnli'
    print(DATASET)
elif DATASET == 'CivilComments':
    NUM_CLASSES = 2
    NUM_ATTRIBUTES = 8
    # train_dataset = CivilComments(data_path, 'tr', hparams, granularity="fine")
    # val_dataset = CivilComments(data_path, 'va', hparams=hparams_f('ERM'))
    # te_dataset = CivilComments(data_path, 'te', hparams=hparams_f('ERM'))
    models_path = path_to_root + '/models/models_civilcomments'
    representations_path = path_to_root + '/representations/representations_civilcomments'
    print(DATASET)
else:
    # Fail here rather than letting later cells hit a NameError on the
    # paths/constants that would otherwise be left undefined.
    raise ValueError(f'Dataset Not Implemented: {DATASET!r}')

# Load Reprs

In [None]:
# Seed of the run whose saved representations are loaded.
SEED = 2
# Algorithm keys that are looked up in REPRS when the representations are concatenated.
algorithm_names = [
    'random',
    'randominit',
    'pretrained',
    'erm',
    'groupdro',
    'focal',
    'jtt',
    'lff',
]
reprs_file = f'{representations_path}/seed{SEED}'+'_reprs'
REPRS = torch.load(reprs_file)

In [None]:
# CAT_REPRS[algorithm][layer] is a single tensor: every REPRS[algorithm][layer][y][a]
# entry concatenated with torch.cat, in the dicts' iteration order.
CAT_REPRS = {}

for algo in algorithm_names:
    reprs_by_layer = REPRS[algo]
    merged_by_layer = {}
    CAT_REPRS[algo] = merged_by_layer
    for layer_name in tqdm.tqdm(reprs_by_layer.keys()):
        # Flatten the two inner levels (first the y keys, then the a keys) into one list.
        pieces = [
            group_reprs
            for by_a in reprs_by_layer[layer_name].values()
            for group_reprs in by_a.values()
        ]
        merged_by_layer[layer_name] = torch.cat(pieces)

In [None]:
# Sanity check: display the shape of the concatenated 'erm' representations at 'layer1'.
CAT_REPRS['erm']['layer1'].shape

# CKA Cuda Implementation

In [None]:
# From https://github.com/jayroxis/CKA-similarity/blob/main - modified

import math

class CudaCKA(object):
    """CKA similarity (linear and RBF-kernel variants) computed with torch.

    ``X`` and ``Y`` are used as 2-D tensors whose rows are samples: Gram
    matrices are built as ``X @ X.T``, so both inputs need the same number of
    rows for the element-wise HSIC product. Inputs are not moved between
    devices; they are expected to already live on ``device``.
    """

    def __init__(self, device):
        """Store the device on which the centering matrices are allocated."""
        self.device = device

    def centering(self, K):
        """Return ``H @ K @ H`` where ``H = I - 1/n`` and ``n = K.shape[0]``."""
        n = K.shape[0]
        # Allocate with K's dtype: with the default (float32) dtype the matmuls
        # below raise a dtype-mismatch error for float64 or half inputs.
        unit = torch.ones([n, n], dtype=K.dtype, device=self.device)
        identity = torch.eye(n, dtype=K.dtype, device=self.device)
        H = identity - unit / n
        return torch.matmul(torch.matmul(H, K), H)

    def rbf(self, X, sigma=None):
        """Return the RBF kernel matrix ``exp(-d2 / (2 * sigma**2))`` of the rows of ``X``.

        ``d2`` holds the squared Euclidean distances between rows. When
        ``sigma`` is None it is set to the square root of the median of the
        non-zero entries of ``d2``.
        """
        GX = torch.matmul(X, X.T)
        # d2[i, j] = G[i, i] + G[j, j] - 2 * G[i, j]
        KX = torch.diag(GX) - GX + (torch.diag(GX) - GX).T
        if sigma is None:
            mdist = torch.median(KX[KX != 0])
            sigma = math.sqrt(mdist)
        KX *= - 0.5 / (sigma * sigma)
        KX = torch.exp(KX)
        return KX

    def kernel_HSIC(self, X, Y, sigma):
        """Sum of the element-wise product of the centered RBF kernels of ``X`` and ``Y``.

        With ``sigma=None`` each input gets its own median-based bandwidth.
        """
        return torch.sum(self.centering(self.rbf(X, sigma)) * self.centering(self.rbf(Y, sigma)))

    def linear_HSIC(self, X, Y):
        """Sum of the element-wise product of the centered linear Gram matrices."""
        L_X = torch.matmul(X, X.T)
        L_Y = torch.matmul(Y, Y.T)
        return torch.sum(self.centering(L_X) * self.centering(L_Y))

    def linear_CKA(self, X, Y):
        """Return ``HSIC(X, Y) / sqrt(HSIC(X, X) * HSIC(Y, Y))`` with linear kernels (0-d tensor)."""
        hsic = self.linear_HSIC(X, Y)
        var1 = torch.sqrt(self.linear_HSIC(X, X))
        var2 = torch.sqrt(self.linear_HSIC(Y, Y))
        return hsic / (var1 * var2)

    def kernel_CKA(self, X, Y, sigma=None):
        """Return the same normalized HSIC ratio as ``linear_CKA`` using RBF kernels (0-d tensor)."""
        hsic = self.kernel_HSIC(X, Y, sigma)
        var1 = torch.sqrt(self.kernel_HSIC(X, X, sigma))
        var2 = torch.sqrt(self.kernel_HSIC(Y, Y, sigma))
        return hsic / (var1 * var2)

# Build the CKA helper bound to the CUDA device.
device = torch.device('cuda')
cka = CudaCKA(device=device)

# Run CKA

In [None]:
import numpy as np
from sklearn.metrics.pairwise import linear_kernel, rbf_kernel

algorithms = list(CAT_REPRS.keys())
layers = list(CAT_REPRS[algorithms[0]].keys())

# One symmetric (algorithm x algorithm) matrix of linear-CKA scores per layer;
# rows/columns follow the order of `algorithms`.
cka_results = {layer: np.zeros((len(algorithms), len(algorithms))) for layer in layers}

for layer in tqdm.tqdm(layers):
    # Move each algorithm's representations to the GPU once per layer instead
    # of twice per pair inside the inner loop.
    layer_reprs = [CAT_REPRS[algo][layer].cuda() for algo in algorithms]
    for i, reprs_i in enumerate(layer_reprs):
        # Only the upper triangle (including the diagonal) is computed; the
        # score is mirrored into the lower triangle.
        for j in range(i, len(layer_reprs)):
            # .item() turns the 0-d GPU tensor into a Python float explicitly
            # before it is stored in the NumPy array.
            score = cka.linear_CKA(reprs_i, layer_reprs[j]).item()
            cka_results[layer][i, j] = score
            cka_results[layer][j, i] = score

for layer in layers:
    print(f"CKA Similarity for Layer {layer}:")
    print(cka_results[layer])

In [None]:
# Persist this seed's per-layer CKA matrices under the project's results folder.
cka_out_file = path_to_root + f'/results/CKA/{DATASET}_seed{SEED}.pth'
torch.save(cka_results, cka_out_file)

# Create Figure

In [None]:
# Seeds whose saved CKA results are averaged; one results file per seed must exist.
seeds = [0, 1, 2]
layer_names = [f'layer{layer}' for layer in range(0, 12)]

# The saved objects are dicts of NumPy arrays, which torch.load refuses to
# unpickle under its weights-only default (PyTorch >= 2.6). The files are
# written by this notebook itself, so full unpickling is acceptable here.
per_seed_results = [
    torch.load(path_to_root + f'/results/CKA/{DATASET}_seed{seed}.pth', weights_only=False)
    for seed in seeds
]

cka_mean = {}
for name in layer_names:
    # Size the accumulator from the data rather than hard-coding 8x8.
    cka_mean[name] = np.zeros(per_seed_results[0][name].shape)
    for results in per_seed_results:
        curr = results[name]
        # there are results that are very close to zero but negative due to numerical issues;
        # clamp them to 0 without mutating the loaded array.
        cka_mean[name] += np.where(curr < 0, 0, curr) / len(seeds)
cka_results = cka_mean