In [1]:
! pip install surgeon-pytorch

Collecting surgeon-pytorch
  Downloading surgeon_pytorch-0.0.4-py3-none-any.whl.metadata (649 bytes)
Collecting data-science-types>=0.2 (from surgeon-pytorch)
  Downloading data_science_types-0.2.23-py3-none-any.whl.metadata (5.4 kB)
Downloading surgeon_pytorch-0.0.4-py3-none-any.whl (6.1 kB)
Downloading data_science_types-0.2.23-py3-none-any.whl (42 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.7/42.7 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: data-science-types, surgeon-pytorch
Successfully installed data-science-types-0.2.23 surgeon-pytorch-0.0.4


In [None]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
#from surgeon_pytorch import Inspect, get_layers
import os, random, pathlib, warnings, itertools, math
from torchvision.transforms import Resize, ToTensor
from torch.utils.data import DataLoader, Subset
from PIL import Image
from transformers import ViTModel, ViTImageProcessor
from transformers import AutoModel, AutoFeatureExtractor
import seaborn as sns


Download Data

In [None]:
import kagglehub

# Download latest version of the dataset into the kagglehub cache.
# NOTE(review): `path` is only printed here; the analysis reads the copy that the
# Kaggle CLI (next cell) unzips into /content/data. One of the two downloads is
# redundant — consider deriving `data_dir` from `path` and dropping the CLI route.
path = kagglehub.dataset_download("khalidboussaroual/2d-geometric-shapes-17-shapes")

print("Path to dataset files:", path)

# Colab-only: interactively upload kaggle.json, which the next cell moves into
# ~/.kaggle so the Kaggle CLI can authenticate.
from google.colab import files
files.upload()  # This will prompt you to upload the `kaggle.json` file

In [None]:
# Setup Kaggle authentication
!mkdir -p ~/.kaggle
!mv kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

# Download the dataset
!kaggle datasets download -d khalidboussaroual/2d-geometric-shapes-17-shapes -p /content/data --unzip


In [None]:
# Root of the unzipped dataset; presumably one sub-folder per shape class.
data_dir = os.path.join('/content/data', '2D_Geometric_Shapes_Dataset')
categories = list(os.listdir(data_dir))
categories

Define functions

In [3]:
# Function to count classes and their samples
def count_classes_and_samples(data_path):
    """Count the entries inside every sub-folder of ``data_path``.

    Plain files directly under ``data_path`` are ignored.

    Returns:
        (class_names, class_counts): the sub-folder names in ``os.listdir``
        order, and a dict mapping each name to the number of entries in it.
    """
    per_class = {}
    for entry in os.listdir(data_path):
        entry_path = os.path.join(data_path, entry)
        if not os.path.isdir(entry_path):
            continue
        per_class[entry] = len(os.listdir(entry_path))
    return list(per_class), per_class

def plot_class_samples(data_path, num_samples=3, class_names=None, save_path="dataset.png"):
    """Plot a grid of example images (one row per class) and save it to disk.

    Args:
        data_path: Dataset root; each class is a sub-folder of image files.
        num_samples: Number of images shown per class (grid columns).
        class_names: Sub-folders to plot, one grid row each. Defaults to the
            five shape classes analysed in this notebook.
        save_path: File the figure is written to (300 dpi).
    """
    if class_names is None:
        class_names = ['square', 'circle', 'triangle', 'star', 'trapezoid']

    # Count the files actually present rather than printing a placeholder value.
    class_counts = {
        name: len(os.listdir(os.path.join(data_path, name))) for name in class_names
    }
    print(f"Number of classes: {len(class_names)}")
    print(f"Class names and sample counts: {class_counts}")

    # Small figure: thumbnails are only meant as a visual overview.
    plt.figure(figsize=(6, 6))

    for row, class_name in enumerate(class_names):
        class_path = os.path.join(data_path, class_name)
        # Sorted so the same example images are shown on every run.
        images = sorted(os.listdir(class_path))[:num_samples]
        title_col = len(images) // 2  # middle thumbnail of this row

        for col, image_name in enumerate(images):
            image_path = os.path.join(class_path, image_name)
            try:
                # Context manager closes the file once the pixels are converted.
                with Image.open(image_path) as img:
                    image = img.convert("RGB")

                plt.subplot(len(class_names), num_samples, row * num_samples + col + 1)
                plt.imshow(image)
                plt.xticks([])
                plt.yticks([])
                if col == title_col:  # centre the class name over the row
                    plt.title(class_name, fontsize=8)
            except Exception as e:
                # Best effort: report unreadable files and keep plotting the rest.
                print(f"Error loading image {image_path}: {e}")

    plt.tight_layout()
    plt.savefig(save_path, dpi=300)
    plt.show()

# define CKA
def center_gram(gram):
    """Double-center a square Gram matrix, i.e. compute H @ gram @ H with H = I - 1/n.

    Implemented with row/column means instead of three dense n x n products.

    Args:
        gram: (n, n) NumPy array or torch tensor. Tensors are detached, moved to
            the CPU and cast to float64 first, so GPU tensors and tensors that
            require grad are accepted.

    Returns:
        np.ndarray of dtype float64 and shape (n, n).
    """
    if isinstance(gram, torch.Tensor):
        gram = gram.detach().cpu().double().numpy()
    gram = np.asarray(gram, dtype=np.float64)
    if gram.size == 0:
        return gram

    column_means = gram.mean(axis=0, keepdims=True)  # == ones @ gram / n
    row_means = gram.mean(axis=1, keepdims=True)     # == gram @ ones / n
    return gram - column_means - row_means + gram.mean()


def linear_CKA(X, Y):
    """Linear centered kernel alignment (CKA) between two representations.

    Args:
        X: (n_samples, n_features_x) NumPy array or torch tensor.
        Y: (n_samples, n_features_y) NumPy array or torch tensor; row k of Y
            must describe the same sample as row k of X.

    Returns:
        float in [0, 1]. It is 1.0 when the representations agree up to an
        orthogonal transform and isotropic scaling. Returns 0.0 when either
        centered Gram matrix is all zeros (e.g. every sample is identical).
    """
    if not isinstance(X, torch.Tensor):
        X = np.asarray(X, dtype=np.float64)
    if not isinstance(Y, torch.Tensor):
        Y = np.asarray(Y, dtype=np.float64)

    # Center the gram matrices (center_gram returns float64 NumPy arrays).
    X_centered = center_gram(X @ X.T)
    Y_centered = center_gram(Y @ Y.T)

    # trace(A @ B) == sum(A * B.T): O(n^2) instead of an O(n^3) matrix product.
    numerator = np.sum(X_centered * Y_centered.T)
    denominator = np.sqrt(
        np.sum(X_centered * X_centered.T) * np.sum(Y_centered * Y_centered.T)
    )
    return float(numerator / denominator) if denominator != 0 else 0.0


# Define a function to process a batch
def collate_fn(batch, image_processor=None):
    """Collate (image_path, label) pairs into model-ready tensors.

    Args:
        batch: Sequence of (image_path, label) tuples, as produced by the
            path/label list built in "Data preparation".
        image_processor: Hugging Face image processor used to resize and
            normalise the images. Defaults to the notebook-global
            ``processor``, so ``DataLoader(..., collate_fn=collate_fn)`` keeps
            working unchanged.

    Returns:
        (pixel_values, labels): the processor's ``"pixel_values"`` tensor for
        the batch and a 1-D tensor of the integer labels.
    """
    if image_processor is None:
        image_processor = processor  # global, rebound for each model in the CKA loop

    image_paths, labels = zip(*batch)

    images = []
    for image_path in image_paths:
        # Context manager closes the file; convert("RGB") loads the pixels first.
        with Image.open(image_path) as img:
            images.append(img.convert("RGB"))

    pixel_values = image_processor(images=images, return_tensors="pt")["pixel_values"]
    return pixel_values, torch.tensor(labels)


# Extract Representations from Selected Layers
def extract_representations(model, layers, dataloader):
    """Record the forward outputs of named sub-modules over a whole dataloader.

    Args:
        model: torch module; called as ``model(images)`` for every batch.
        layers: Names as reported by ``model.named_modules()``.
        dataloader: Iterable of ``(images, labels)`` batches; labels are ignored.

    Returns:
        dict mapping each layer name to the detached outputs of every batch,
        concatenated along dim 0 (batch order is preserved).

    Raises:
        ValueError: if a name in ``layers`` is not a sub-module of ``model``.
    """
    layer_outputs = {layer: [] for layer in layers}

    modules = dict(model.named_modules())
    missing = [layer for layer in layer_outputs if layer not in modules]
    if missing:
        raise ValueError(f"Layers not found in model: {missing}")

    def make_hook(layer_name):
        # Bind the layer name in a closure instead of tagging the module with an
        # attribute, so the model is left unmodified.
        def hook_fn(module, inputs, output):
            layer_outputs[layer_name].append(output.detach())
        return hook_fn

    hooks = [modules[name].register_forward_hook(make_hook(name)) for name in layer_outputs]
    try:
        with torch.no_grad():
            for images, _labels in dataloader:
                model(images)
    finally:
        # Always detach the hooks: a failed run must not leave them attached,
        # otherwise later calls would record outputs twice.
        for hook in hooks:
            hook.remove()

    return {layer: torch.cat(layer_outputs[layer]) for layer in layers}




Data preparation

In [None]:
# Classes analysed throughout the notebook; the list position is the integer label.
result = ['square', 'circle', 'triangle', 'star', 'trapezoid']

# Keep only images whose file name ends in _1 ... _15 (e.g. "<name>_7.png").
IMAGE_SUFFIXES = tuple(
    f"_{i}.{ext}" for i in range(1, 16) for ext in ("png", "jpg", "jpeg")
)

data = []
labels = []
for class_idx, category in enumerate(result):
    category_path = os.path.join(data_dir, category)
    if not os.path.isdir(category_path):
        continue
    # Sorted so the dataset order is reproducible across filesystems.
    selected_images = [
        os.path.join(category_path, img)
        for img in sorted(os.listdir(category_path))
        if img.endswith(IMAGE_SUFFIXES)
    ]
    data.extend(selected_images)
    labels.extend([class_idx] * len(selected_images))

# Wrap data into a list of (image_path, label) pairs, consumed by `collate_fn`.
dataset = list(zip(data, labels))
assert dataset, f"No matching images found under {data_dir}"



Plot sample images from each class

In [None]:
# Show five example images for each shape class.
plot_class_samples(data_path=data_dir, num_samples=5)

CKA analysis of the Hugging Face models (ViT, MAE, BEiT)

In [None]:
# Hugging Face checkpoints compared in this section (one CKA matrix each).
hf_model_names = ["google/vit-large-patch16-224",
                  "facebook/vit-mae-large",
                  "microsoft/beit-large-patch16-512"]

cka_similarities = []

for model_name in hf_model_names:
    # Step 1: image processor; `collate_fn` reads this global `processor`.
    processor = ViTImageProcessor.from_pretrained(model_name)

    # Step 2: DataLoader over the (path, label) list. CKA is invariant to a
    # shared permutation of samples, so shuffling is unnecessary.
    dataloader = DataLoader(dataset, batch_size=8, shuffle=False, collate_fn=collate_fn)

    # Step 3: load the backbone. Activations are captured by hooks, so
    # output_hidden_states (extra memory) is not needed.
    model = AutoModel.from_pretrained(model_name)
    model.eval()

    # Step 4: output projection of encoder blocks 1..23.
    # NOTE(review): block 0 is skipped — confirm this is intended.
    layers_to_analyze = [f'encoder.layer.{i}.output.dense' for i in range(1, 24)]

    # Step 5: collect activations. NOTE(review): the MAE checkpoint's forward
    # pass is expected to apply random patch masking (config.mask_ratio), so
    # its activations would cover only the kept tokens; the seed at least makes
    # that reproducible — consider disabling masking for this analysis.
    torch.manual_seed(0)
    representations = extract_representations(model, layers_to_analyze, dataloader)

    # Step 6: CKA between every pair of layers. CKA is symmetric, so only the
    # upper triangle is computed and mirrored.
    layer_keys = list(representations.keys())
    flat = [representations[key].reshape(representations[key].shape[0], -1) for key in layer_keys]
    n_layers = len(layer_keys)
    cka_similarity = torch.zeros((n_layers, n_layers))
    for i in range(n_layers):
        for j in range(i, n_layers):
            cka = linear_CKA(flat[i], flat[j])
            cka_similarity[i, j] = cka
            cka_similarity[j, i] = cka

    cka_similarities.append(cka_similarity)

    # Release the large model and activations before loading the next checkpoint.
    del model, representations, flat




CKA analysis of facebook/dinov2-large-imagenet1k-1-layer


In [None]:
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, Normalize, Resize, ToTensor
from torch.utils.data import DataLoader, Subset

# NOTE(review): ImageNet mean/std, intended to match DINOv2's published
# image-processor config — confirm against the checkpoint. The HF models above
# are normalised by their processors; without this DINOv2 sees raw [0, 1] pixels.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
# NOTE(review): the HF section uses images _1.._15 of each class; this section
# (and the timm one) uses the first N in ImageFolder order, so stimuli differ.
N_IMAGES_PER_CLASS = 10

# Define transformations
transform = Compose([
    Resize((224, 224)),  # Resize to model input size
    ToTensor(),          # Convert image to a float tensor in [0, 1]
    Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Load the dataset using ImageFolder (one class per sub-folder of data_dir).
dataset_fordino = ImageFolder(root=data_dir, transform=transform)
class_to_idx = dataset_fordino.class_to_idx

# Select indices *into dataset_fordino*. ImageFolder numbers classes by its own
# (alphabetical) mapping and orders samples itself, so its `targets` must be
# used, not the (path, label) list `dataset` built earlier. Reading `targets`
# also avoids decoding every image just to get its label.
filtered_indices = []
for target_class in result:
    class_index = class_to_idx[target_class]
    class_indices = [i for i, label in enumerate(dataset_fordino.targets) if label == class_index]
    filtered_indices.extend(class_indices[:N_IMAGES_PER_CLASS])

filtered_dataset = Subset(dataset_fordino, filtered_indices)
# CKA is invariant to a shared sample permutation, so no shuffling is needed.
dataloader_fordino = DataLoader(filtered_dataset, batch_size=8, shuffle=False)

model_name = "facebook/dinov2-large-imagenet1k-1-layer"

# Load the backbone; hooks capture activations, so no hidden-state outputs are requested.
model = AutoModel.from_pretrained(model_name)
model.eval()

# MLP output projection of encoder blocks 1..23 (block 0 skipped, as for the other models).
layers_to_analyze = [f'encoder.layer.{i}.mlp.fc2' for i in range(1, 24)]
representations = extract_representations(model, layers_to_analyze, dataloader_fordino)

# CKA between every pair of layers; symmetric, so mirror the upper triangle.
layer_keys = list(representations.keys())
flat = [representations[key].reshape(representations[key].shape[0], -1) for key in layer_keys]
n_layers = len(layer_keys)
cka_similarity = torch.zeros((n_layers, n_layers))
for i in range(n_layers):
    for j in range(i, n_layers):
        cka = linear_CKA(flat[i], flat[j])
        cka_similarity[i, j] = cka
        cka_similarity[j, i] = cka

cka_similarities.append(cka_similarity)

# Release the large model and activations.
del model, representations, flat


CKA analysis of the timm model `vit_large_patch16_224.augreg_in21k`


In [None]:
pip install timm

In [None]:
import timm
import torch
from timm.data import resolve_data_config
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, Normalize, Resize, ToTensor

model_name = "vit_large_patch16_224.augreg_in21k"
model = timm.create_model(model_name, pretrained=True)
model.eval()

# Normalise with the mean/std from the checkpoint's pretrained config rather
# than feeding raw [0, 1] pixels.
data_config = resolve_data_config({}, model=model)
transform = Compose([
    Resize((224, 224)),  # Resize to match model input size
    ToTensor(),          # Convert image to PyTorch tensor
    Normalize(mean=data_config["mean"], std=data_config["std"]),
])

# Same root as the DINOv2 section, so the `filtered_indices` selected there
# refer to the same images in this ImageFolder.
dataset_fortimm = ImageFolder(root=data_dir, transform=transform)
filtered_dataset = Subset(dataset_fortimm, filtered_indices)
# CKA is invariant to a shared sample permutation, so no shuffling is needed.
dataloader_fortimm = DataLoader(filtered_dataset, batch_size=8, shuffle=False)

# MLP output projection of blocks 1..23 (block 0 skipped, as for the other models).
layers_to_analyze = [f'blocks.{i}.mlp.fc2' for i in range(1, 24)]
representations = extract_representations(model, layers_to_analyze, dataloader_fortimm)

# CKA between every pair of layers; symmetric, so mirror the upper triangle.
layer_keys = list(representations.keys())
flat = [representations[key].reshape(representations[key].shape[0], -1) for key in layer_keys]
n_layers = len(layer_keys)
cka_similarity = torch.zeros((n_layers, n_layers))
for i in range(n_layers):
    for j in range(i, n_layers):
        cka = linear_CKA(flat[i], flat[j])
        cka_similarity[i, j] = cka
        cka_similarity[j, i] = cka

cka_similarities.append(cka_similarity)

# Release the large model and activations.
del model, representations, flat


Plot the CKA heatmaps

In [None]:
# plot
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Short display labels, in the order the CKA matrices were appended above.
model_names = ["vit", "mae", "beit", 'dinov2', "vit_augreg"]
assert len(model_names) == len(cka_similarities), (
    "cka_similarities does not match model_names; re-run the CKA cells in order"
)

fig, axes = plt.subplots(2, 3, figsize=(15, 10))  # 2x3 grid: one panel per model
axes = axes.flatten()

for ax, name, cka_matrix in zip(axes, model_names, cka_similarities):
    # Shared colour range: linear CKA lies in [0, 1], so panels are comparable.
    sns.heatmap(np.asarray(cka_matrix), cmap='coolwarm', vmin=0, vmax=1, ax=ax, cbar=True,
                cbar_kws={'orientation': 'horizontal', 'pad': 0.1})
    ax.set_title(name)
    ax.set_xlabel("Layer (position in layers_to_analyze)")
    ax.set_ylabel("Layer (position in layers_to_analyze)")

# Hide the panels left without a model.
for ax in axes[len(model_names):]:
    ax.axis('off')

fig.suptitle('Comparison of CKA Similarities Across Models', fontsize=16)

# Leave room at the top for the suptitle.
fig.tight_layout(rect=(0, 0, 1, 0.96))
fig.savefig("_plots.png", dpi=300)
plt.show()

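Compute the pairwise similarity between the five CKA heatmaps (Pearson correlation, Spearman rank correlation, and cosine similarity of the flattened matrices) to quantify how alike the models' layer-similarity structures are.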

In [None]:
from itertools import product
from scipy.stats import spearmanr
from sklearn.metrics.pairwise import cosine_similarity

def _flat(matrix):
    """Flatten a NumPy array or CPU torch tensor into a 1-D float64 array."""
    return np.asarray(matrix, dtype=np.float64).ravel()


def compute_correlation(matrix1, matrix2):
    """Pearson correlation between the flattened matrices."""
    return np.corrcoef(_flat(matrix1), _flat(matrix2))[0, 1]


def compute_spearman(matrix1, matrix2):
    """Spearman rank correlation between the flattened matrices."""
    spearman_corr, _ = spearmanr(_flat(matrix1), _flat(matrix2))
    return spearman_corr


# Backward-compatible name: this function has always computed Spearman's rho,
# not Pearson's r (Pearson is `compute_correlation`).
compute_pearson = compute_spearman


def compute_cosine(matrix1, matrix2):
    """Cosine similarity between the flattened matrices (0.0 if either is all zeros)."""
    vec1, vec2 = _flat(matrix1), _flat(matrix2)
    norm_product = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    return float(vec1 @ vec2 / norm_product) if norm_product else 0.0

# Compare every CKA heatmap with every other one under three metrics.
n_models = len(cka_similarities)
corr_matrix, pear_matrix, cosine_matrix = (np.zeros((n_models, n_models)) for _ in range(3))

for row in range(n_models):
    for col in range(n_models):
        first, second = cka_similarities[row], cka_similarities[col]
        corr_matrix[row, col] = compute_correlation(first, second)
        pear_matrix[row, col] = compute_pearson(first, second)
        cosine_matrix[row, col] = compute_cosine(first, second)

# Label rows and columns with the short model names used in the plots.
corr_df, pear_df, cosine_df = (
    pd.DataFrame(metric, columns=model_names, index=model_names)
    for metric in (corr_matrix, pear_matrix, cosine_matrix)
)





In [None]:
# One row of three heatmaps, one per similarity metric.
fig, axs = plt.subplots(1, 3, figsize=(30, 10))

panels = [
    (corr_df, "Pairwise Pearson correlation between models"),
    # pear_df is filled by compute_pearson, which calls scipy.stats.spearmanr.
    (pear_df, "Pairwise Spearman correlation between models"),
    (cosine_df, "Pairwise cosine similarity between models"),
]
for ax, (frame, title) in zip(axs, panels):
    # Shared [-1, 1] colour range so the three panels are directly comparable.
    sns.heatmap(frame, annot=True, cmap="coolwarm", fmt=".2f", vmin=-1, vmax=1, ax=ax)
    ax.set_title(title)

# Adjust spacing between subplots
plt.tight_layout()
plt.show()