In [1]:
from copy import deepcopy

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.models import resnet50, vit_b_16, vit_b_32
from torchvision.transforms import Compose, Normalize, Resize, ToTensor
# tqdm.auto selects the notebook widget when available, without the experimental-mode
# warning that importing tqdm.autonotebook directly emits.
from tqdm.auto import tqdm

from cka import CKACalculator  # local module providing the CKA hooks/computation

# Default figure size for the square CKA heatmaps below.
plt.rcParams['figure.figsize'] = (7, 7)

  from tqdm.autonotebook import tqdm


## Setup DataLoader and Models 

An important detail is that although we are using the `CIFAR10` test split (`train=False`) rather than the training set, we **shuffle** it and drop the last incomplete batch. This ensures that 1) the batches differ from epoch to epoch, and 2) every iteration has the same batch size.

In [2]:
# ImageNet channel statistics, matching the normalization the pretrained torchvision weights expect.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
DATA_ROOT = '../data/'
BATCH_SIZE = 1024
SEED = 0

# CIFAR10 images are 32x32; Resize(224) upsamples them to the 224x224 input size the models use.
transforms = Compose([
    Resize(224),
    ToTensor(),
    Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

dataset = CIFAR10(root=DATA_ROOT, train=False, download=True, transform=transforms)

# A seeded generator makes the shuffled batch order (and hence the CKA estimate)
# reproducible across kernel restarts. Each epoch still gets a different order,
# because the generator's state advances.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=4,
    pin_memory=True,
    generator=torch.Generator().manual_seed(SEED),
)

In [3]:
# Load ImageNet-pretrained weights through the torchvision `weights` API.
# 'IMAGENET1K_V1' is the checkpoint the deprecated `pretrained=True` flag selected.
model1 = resnet50(weights='IMAGENET1K_V1').cuda()
model1.eval()
model2 = vit_b_32(weights='IMAGENET1K_V1').cuda()
model2.eval()
print('Pretrained models created')



Dummy models created


## Compute CKA 

### Basic Usage 

Initializing the `CKACalculator` object will add forward hooks to both `model1` and `model2`. 
The default modules that are hooked are: `Bottleneck`, `BasicBlock`, `Conv2d`, `AdaptiveAvgPool2d`, `MaxPool2d`, and all instances of `BatchNorm`. 
Note that `Bottleneck` and `BasicBlock` are from the `torchvision` implementation, and will not add hooks to any custom implementations of `Bottleneck/BasicBlock`.

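For the ResNet50 used here, 126 hooks are added, and 87 for the ViT-B/32 (see the output below). The module types listed above would match only the ViT's patch-embedding `conv_proj`, so the default set evidently includes further module types; see `CKACalculator` for the full list.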

By default, the intermediate features are flattened with `flatten_hook_fn` and 10 epochs are run.

In [4]:
# Build the calculator: this registers forward hooks on both models using the
# default hook layer types and flatten_hook_fn (see the log printed below).
calculator = CKACalculator(
    model1=model1,
    model2=model2,
    dataloader=dataloader,
)

No hook function provided. Using flatten_hook_fn.
126 Hooks registered. Total hooks: 126
No hook function provided. Using flatten_hook_fn.
87 Hooks registered. Total hooks: 87


In [5]:
# Display the ResNet50 module tree to see which submodules the hooks attach to.
model1

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [6]:
# Display the ViT-B/32 module tree to see which submodules the hooks attach to.
model2

VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_a

Now we can calculate the CKA matrix 

In [None]:
# Iterate over the dataloader (see the per-epoch progress bar) and collect the
# CKA matrix between the hooked layers of the two models.
cka_output = calculator.calculate_cka_matrix()
print("CKA output size: " + str(cka_output.shape))

Epoch 0:  11%|█         | 1/9 [00:20<02:40, 20.06s/it]

#### Visualize the output

By default, `imshow` draws element `[0, 0]` of the returned matrix at the top left, whereas most papers draw the CKA matrix with its origin at the bottom left. Instead of flipping the matrix itself, the plot below passes `origin='lower'` to `imshow` to get the conventional orientation.

In [None]:
# origin='lower' puts matrix element [0, 0] at the bottom left, as in most CKA papers.
fig, ax = plt.subplots()
im = ax.imshow(cka_output.cpu().numpy(), cmap='inferno', origin='lower')
ax.set_title('CKA: ResNet50 vs ViT-B/32')
# Without a colorbar the heatmap colors cannot be read as CKA values.
fig.colorbar(im, ax=ax, label='CKA')
plt.show()

### Advanced Usage 

We can customize other parameters of the `CKACalculator`. 
Most importantly, we can select which modules to hook. 

Before instantiating a new `CKACalculator` on the same models, first call the `reset` method of the existing one. 
This removes all hooks it registered in the models.

In [None]:
# Reset calculator to clear hooks
calculator.reset()
torch.cuda.empty_cache()

In [None]:
import torch.nn as nn

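Let's consider outputs of `Conv2d` and `BatchNorm2d` only. For ResNet50 this hooks all of its convolution and batch-norm layers. The ViT contributes only its patch-embedding convolution `conv_proj`, since it normalizes with `LayerNorm` rather than `BatchNorm2d`.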

For custom layers, add the custom modules in the same manner as shown below.

In [None]:
# Module types that the next CKACalculator should hook (passed as hook_layer_types).
layers = (
    nn.Conv2d,
    nn.BatchNorm2d,
)

In [None]:
# Rebuild the calculator, hooking only the module types listed in `layers`.
calculator = CKACalculator(
    model1=model1,
    model2=model2,
    dataloader=dataloader,
    hook_layer_types=layers,
)

In [None]:
# Recompute the CKA matrix using only the Conv2d/BatchNorm2d hooks selected above.
cka_output = calculator.calculate_cka_matrix()

#### Visualize output 

In [None]:
# Use the same orientation as the first plot (origin='lower') so the figures are comparable.
fig, ax = plt.subplots()
im = ax.imshow(cka_output.cpu().numpy(), cmap='inferno', origin='lower')
ax.set_title('CKA (Conv2d/BatchNorm2d hooks): ResNet50 vs ViT-B/32')
fig.colorbar(im, ax=ax, label='CKA')
plt.show()

#### Extract module names 

In [None]:
# List the hooked module names recorded on the calculator's "X" side
# (presumably model1's, since it is passed as model1 — confirm in CKACalculator).
for idx, module_name in enumerate(calculator.module_names_X):
    print("Layer {}: \t{}".format(idx, module_name))

In [9]:
# Reset calculator to clear hooks
calculator.reset()
torch.cuda.empty_cache()

126 handles removed.
87 handles removed.


In [None]:
# Swap model2 for an ImageNet-pretrained ViT-B/16. 'IMAGENET1K_V1' is the checkpoint
# the deprecated `pretrained=True` flag selected.
model2 = vit_b_16(weights='IMAGENET1K_V1').cuda()
model2.eval()
print("vit_b_16 created")

In [None]:
# Display the ViT-B/16 module tree to see which submodules the hooks attach to.
model2

In [None]:
# Hook ResNet50 and the new ViT-B/16 with the default hook layer types.
calculator = CKACalculator(
    model1=model1,
    model2=model2,
    dataloader=dataloader,
)

Now we can calculate the CKA matrix 

In [None]:
# Compute the ResNet50 vs ViT-B/16 CKA matrix and report its dimensions.
cka_output = calculator.calculate_cka_matrix()
print("CKA output size: " + str(cka_output.shape))

In [None]:
# Copy the CKA matrix to host memory as a NumPy array for matplotlib.
value = cka_output.to("cpu").numpy()

In [None]:
# origin='lower' puts matrix element [0, 0] at the bottom left, matching the earlier plots.
fig, ax = plt.subplots()
im = ax.imshow(value, cmap="inferno", origin='lower')
ax.set_title('CKA: ResNet50 vs ViT-B/16')
fig.colorbar(im, ax=ax, label='CKA')
plt.show()