In [111]:
import os
import csv
import random
import zipfile
import requests
from io import BytesIO
from types import MethodType
from collections import defaultdict

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

import clip
from PIL import Image

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns


In [112]:
def set_all_seeds(seed: int = 42):
    """Seed the RNGs used in this notebook and pin the cuDNN flags.

    Args:
        seed: Value passed to Python's ``random``, NumPy's global RNG and
            torch (``manual_seed`` plus ``cuda.manual_seed_all``).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,  # for multi-GPU setups
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Turn cuDNN's deterministic flag on and its benchmark flag off.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [113]:
# Seed the global RNGs once, before any other cell draws random numbers.
set_all_seeds(seed=42)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from types import MethodType
import os
import csv
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns


def wrap_vit_blocks_dino(model):
    """Patch every block in ``model.blocks`` so its output is recorded.

    Each block's ``forward`` is replaced by a thin wrapper that calls the
    original forward, stores the detached result under ``"layer_{i}"`` and
    returns the result unchanged.

    Calling this function again on the same model is safe: the original
    forward is remembered on the block (``_logit_lens_orig_forward``) and the
    new recorder wraps that, instead of stacking on top of a previous
    recorder. Only the most recently returned dict is filled afterwards.

    Args:
        model: Object exposing an iterable ``blocks`` whose items have a
            ``forward`` method returning a tensor.

    Returns:
        dict: Filled on every forward pass with ``"layer_{i}" -> tensor``
        (insertion order follows block execution order).
    """
    activations = {}

    def make_recording_forward(orig_forward, layer_name):
        # *args/**kwargs so blocks that take extra inputs (masks, scales, ...)
        # keep working when called through the recorder.
        def recording_forward(self, *args, **kwargs):
            out = orig_forward(*args, **kwargs)
            activations[layer_name] = out.detach()
            return out
        return recording_forward

    for i, block in enumerate(model.blocks):
        orig_forward = getattr(block, "_logit_lens_orig_forward", None)
        if orig_forward is None:
            # First time this block is wrapped: remember the real forward.
            orig_forward = block.forward
            block._logit_lens_orig_forward = orig_forward

        block.forward = MethodType(make_recording_forward(orig_forward, f"layer_{i}"), block)

    return activations

def logit_lens_analysis_dino(activations, model, final_cls_token, temperature=1.0, true_class_idx=None):
    """Project every layer's CLS token through the final norm and head.

    For each entry of ``activations`` the token at position 0 is passed
    through ``model.norm`` and ``model.head``; the resulting distribution is
    compared against the one obtained from the last entry.

    Per-class probabilities and predictions are reported for the first sample
    of the batch only (index 0). The cosine similarity is read with
    ``.item()``, so it requires a single-sample batch.

    Args:
        activations: Ordered mapping ``layer name -> tensor`` indexed as
            ``x[:, 0, :]``. The last entry is treated as the final layer.
        model: Object exposing callables ``norm`` and ``head``.
        final_cls_token: Reference embedding for the cosine similarity.
        temperature: Divisor applied to the logits before the softmax.
        true_class_idx: Optional class index. When ``None`` the returned
            ``true_class_probs`` holds ``None`` for every layer.

    Returns:
        A 12-tuple:
        ``(distances, predictions, true_class_probs, first_top_class_probs,
        last_layer_top_class_probs, last_layer_second_top_class_probs,
        last_layer_third_top_class_probs, kl_divergence,
        first_layer_top_class, last_layer_top_class,
        last_layer_second_top_class, last_layer_third_top_class)``.
        ``distances`` and ``predictions`` are dicts keyed by layer name; the
        ``*_probs`` and ``kl_divergence`` values are lists with one entry per
        layer. The rank-2 / rank-3 class (and their probability lists) are
        ``None`` when the head has fewer than 2 / 3 outputs.

    Raises:
        IndexError: If ``activations`` is empty.
    """
    def class_prob(sample_probs, class_idx):
        # Probability of one class for the first sample, or None if untracked.
        if class_idx is None:
            return None
        return float(sample_probs[class_idx].item())

    distances = {}
    predictions = {}
    true_class_probs = []
    first_top_class_probs = []
    last_layer_top_class_probs = []
    last_layer_second_top_class_probs = []
    last_layer_third_top_class_probs = []
    kl_divergence = []

    first_layer_top_class = None

    layer_names = list(activations.keys())
    last_layer_name = layer_names[-1]

    # Reference distribution: the last recorded layer through norm + head.
    last_cls_token = activations[last_layer_name][:, 0, :]
    last_logits = model.head(model.norm(last_cls_token))
    last_probs = F.softmax(last_logits / temperature, dim=-1)

    # Top-3 classes of the last layer; padded with None for tiny heads.
    k = min(3, last_probs.shape[-1])
    _, topk_indices = torch.topk(last_probs, k=k, dim=-1)
    top_classes = [int(idx) for idx in topk_indices[0].tolist()]
    top_classes += [None] * (3 - len(top_classes))
    last_layer_top_class, last_layer_second_top_class, last_layer_third_top_class = top_classes

    for i, name in enumerate(layer_names):
        cls_token = activations[name][:, 0, :]
        normed = model.norm(cls_token)

        similarity = F.cosine_similarity(normed, final_cls_token, dim=-1)
        distances[name] = similarity.detach().cpu().item()

        scaled_logits = model.head(normed) / temperature
        probs = F.softmax(scaled_logits, dim=-1)

        if i == len(layer_names) - 1:
            # The last layer is the reference itself.
            kl_value = 0.0
        else:
            # KL(last || layer). log_softmax avoids log(0) = -inf when a
            # probability underflows, which probs.log() would produce.
            log_probs = F.log_softmax(scaled_logits, dim=-1)
            kl_per_sample = F.kl_div(log_probs, last_probs, reduction='none').sum(dim=-1)
            kl_value = float(kl_per_sample.mean().item())
        kl_divergence.append(kl_value)

        sample_probs = probs[0]
        top_prob, top_class = torch.max(sample_probs, dim=-1)
        predictions[f"{name}_label"] = int(top_class.item())
        predictions[f"{name}_prob"] = float(top_prob.item())

        if i == 0:
            first_layer_top_class = int(top_class.item())

        true_class_probs.append(class_prob(sample_probs, true_class_idx))
        first_top_class_probs.append(class_prob(sample_probs, first_layer_top_class))
        last_layer_top_class_probs.append(class_prob(sample_probs, last_layer_top_class))
        last_layer_second_top_class_probs.append(class_prob(sample_probs, last_layer_second_top_class))
        last_layer_third_top_class_probs.append(class_prob(sample_probs, last_layer_third_top_class))

    return distances, predictions, true_class_probs, first_top_class_probs, last_layer_top_class_probs, last_layer_second_top_class_probs, last_layer_third_top_class_probs, kl_divergence, first_layer_top_class, last_layer_top_class, last_layer_second_top_class, last_layer_third_top_class

def load_tiny_imagenet_labels(path="tiny-imagenet-200/words.txt"):
    """Read a ``wnid -> label`` mapping from a tab-separated text file.

    Each non-blank line must look like ``<wnid>\\t<label>``. Only the first
    tab separates the two fields, so a label may itself contain tabs. Blank
    lines are skipped. If a wnid appears more than once, the last line wins.

    Args:
        path: Location of the tab-separated file.

    Returns:
        dict mapping each wnid string to its label string.

    Raises:
        ValueError: If a non-blank line contains no tab.
        OSError: If the file cannot be opened.
    """
    wnid_to_label = {}
    with open(path, "r", encoding="utf-8") as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                continue  # tolerate blank / trailing empty lines
            wnid, sep, label = line.partition("\t")
            if not sep:
                raise ValueError(
                    f"{path}:{line_no}: expected '<wnid>\\t<label>', got {line!r}"
                )
            wnid_to_label[wnid] = label
    return wnid_to_label

def _append_csv_row(path, header, row):
    """Append ``row`` to the CSV at ``path``; write ``header`` first if the file is new or empty."""
    out_dir = os.path.dirname(path)
    if out_dir:  # os.makedirs("") raises, so skip it for bare filenames
        os.makedirs(out_dir, exist_ok=True)
    write_header = not os.path.exists(path) or os.path.getsize(path) == 0
    with open(path, 'a', newline='') as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(header)
        writer.writerow(row)


def perform_logit_lens_analysis(model, dataset, device,
                                cosine_path = "logit_lens_results/cosine_similarity.csv", 
                                preds_path = "logit_lens_results/predictions.csv",
                                true_probs_path = "logit_lens_results/true_class_probs.csv",
                                first_top_probs_path = "logit_lens_results/first_top_class_probs.csv",
                                last_layer_probs_path = "logit_lens_results/last_layer_top_class_probs.csv",
                                second_last_layer_probs_path = "logit_lens_results/second_last_layer_top_class_probs.csv",
                                third_last_layer_probs_path = "logit_lens_results/third_last_layer_top_class_probs.csv",
                                kl_divergence_path = "logit_lens_results/kl_divergence.csv"):
    """Run the logit-lens analysis on every sample and append the results to CSV files.

    For each ``(image, label)`` in ``dataset`` the model is run once, the
    per-block activations recorded by ``wrap_vit_blocks_dino`` are analysed
    with ``logit_lens_analysis_dino`` and one row per image is appended to
    each output file. Files are opened in append mode, so running the
    function again adds further rows to existing files.

    Args:
        model: Model exposing ``blocks``, ``norm``, ``head`` and
            ``forward_features``. It is moved to ``device`` and put in eval
            mode.
        dataset: Iterable of ``(image tensor, label)`` pairs. ``class_to_idx``
            is read from ``dataset.dataset`` when that attribute exists (as on
            a ``Subset``), otherwise from ``dataset`` itself.
        device: Device the model and each image are moved to.
        cosine_path: CSV of per-layer cosine similarity to the final CLS embedding.
        preds_path: CSV of per-layer top-1 label and probability.
        true_probs_path: CSV of per-layer probability of the true class.
        first_top_probs_path: CSV tracking the first layer's top-1 class.
        last_layer_probs_path: CSV tracking the last layer's top-1 class.
        second_last_layer_probs_path: CSV tracking the class ranked second at
            the LAST layer (not the second-to-last layer, despite the name).
        third_last_layer_probs_path: CSV tracking the class ranked third at
            the last layer.
        kl_divergence_path: CSV of per-layer KL divergence to the last layer.
    """
    # Images are sent to `device` below, so the model has to live there too.
    model.to(device)
    model.eval()

    base_dataset = getattr(dataset, "dataset", dataset)
    idx_to_wnid = {v: k for k, v in base_dataset.class_to_idx.items()}

    activations = wrap_vit_blocks_dino(model)  
    headers = [f"layer_{i}" for i in range(len(model.blocks))]
    wnid_to_label = load_tiny_imagenet_labels()

    base_header = ['Image', 'True_WNID', 'True_Label']

    for image_idx, (image, label) in enumerate(dataset):
        image = image.unsqueeze(0).to(device)

        # Normalise the label once so tensor labels work as dict keys too.
        true_class_idx = label.item() if hasattr(label, 'item') else int(label)
        true_wnid = idx_to_wnid[true_class_idx]
        true_label = wnid_to_label.get(true_wnid, "")

        with torch.no_grad():
            # Drop the previous image's activations; this is the dict the
            # recorders actually write into.
            activations.clear()
            model.forward_features(image)  # run only to fill `activations`

            # Final CLS embedding = final norm applied ONCE to the last block's
            # CLS token. timm's forward_features already applies model.norm,
            # so normalising its output again would apply LayerNorm twice.
            final_cls_token = model.norm(activations[headers[-1]][:, 0, :])

            (distances, predictions, true_class_probs, first_top_probs,
             last_layer_probs, second_last_layer_probs, third_last_layer_probs,
             kl_div, first_class_idx, last_layer_top_class_idx,
             last_layer_second_top_class_idx,
             last_layer_third_top_class_idx) = logit_lens_analysis_dino(
                activations, model, final_cls_token, true_class_idx=true_class_idx)

        base_row = [f"Image_{image_idx + 1}", true_wnid, true_label]

        # --- KL DIVERGENCE FILE ---
        _append_csv_row(kl_divergence_path, base_header + headers, base_row + kl_div)

        # --- COSINE FILE ---
        _append_csv_row(cosine_path, base_header + headers,
                        base_row + [distances[layer] for layer in headers])

        # --- PREDICTIONS FILE ---
        pred_header = base_header + \
            [f"{layer}_label" for layer in headers] + \
            [f"{layer}_prob" for layer in headers]
        pred_probs = [predictions[f"{layer}_prob"] for layer in headers]
        pred_indices = [predictions[f"{layer}_label"] for layer in headers]
        pred_wnids = [idx_to_wnid[int(idx)] for idx in pred_indices]
        pred_labels = [wnid_to_label.get(wnid, wnid) for wnid in pred_wnids]
        _append_csv_row(preds_path, pred_header, base_row + pred_labels + pred_probs)

        # --- TRUE CLASS PROBS FILE ---
        _append_csv_row(true_probs_path, base_header + headers, base_row + true_class_probs)

        # --- TRACKED-CLASS FILES ---
        # One file per tracked class: the first layer's top-1 and the last
        # layer's top-1 / top-2 / top-3.
        tracked = [
            (first_top_probs_path, 'First_Top1', first_class_idx, first_top_probs),
            (last_layer_probs_path, 'Last_Top1', last_layer_top_class_idx, last_layer_probs),
            (second_last_layer_probs_path, 'Second_Last_Top1', last_layer_second_top_class_idx, second_last_layer_probs),
            (third_last_layer_probs_path, 'Third_Last_Top1', last_layer_third_top_class_idx, third_last_layer_probs),
        ]
        for path, prefix, class_idx, layer_probs in tracked:
            tracked_wnid = idx_to_wnid.get(class_idx, "N/A")
            tracked_label = wnid_to_label.get(tracked_wnid, "N/A")
            _append_csv_row(
                path,
                base_header + [f"{prefix}_WNID", f"{prefix}_Label"] + headers,
                base_row + [tracked_wnid, tracked_label] + layer_probs,
            )


In [None]:

# Configuration for the stratified train/val subsets built in this cell.
RANDOM_SEED = 42        # seeds the private `random.Random` used for per-class shuffling
SUBSET_FRACTION = 0.1  # share of each class taken for training: 0.05 for 5%, 0.1 for 10%
BATCH_SIZE = 64
VAL_FRACTION = 0.1      # share of each class taken for validation

# Image preprocessing: resize to 224x224, convert to tensor, normalise per channel.
# NOTE(review): these mean/std values look like CLIP's preprocessing constants
# rather than the usual ImageNet statistics — confirm they match what the
# checkpoint loaded later was trained with.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                         std=[0.26862954, 0.26130258, 0.27577711]),
])

train_dir = os.path.join("./tiny-imagenet-200", "train")
full_dataset = datasets.ImageFolder(train_dir, transform=transform)

# NOTE(review): `full_test_dataset` is not referenced again in the visible cells.
test_dir = os.path.join("./tiny-imagenet-200", "test")
full_test_dataset = datasets.ImageFolder(test_dir, transform=transform)

# Group sample positions by class label so each class can be sampled separately.
class_indices = defaultdict(list)
for idx, (_, label) in enumerate(full_dataset.samples):
    class_indices[label].append(idx)

# Stratified sampling: every class contributes its own train and val slice.
# A dedicated RNG keeps the split independent of the global `random` state.
rng = random.Random(RANDOM_SEED)
train_indices = []
val_indices = []

for label, indices in class_indices.items():
    rng.shuffle(indices)  # shuffles the list stored in `class_indices` in place
    
    n_total = len(indices)
    n_train = int(SUBSET_FRACTION * n_total)
    n_val = int(VAL_FRACTION * n_total)

    # Take at least one sample per class for each split.
    n_train = max(n_train, 1)
    n_val = max(n_val, 1)

    # Train and val come from adjacent, non-overlapping slices of the shuffled
    # list. If a class has fewer than n_train + n_val samples, val gets fewer
    # (possibly zero) entries.
    available = indices[:n_train + n_val]
    train_indices.extend(available[:n_train])
    val_indices.extend(available[n_train:n_train + n_val])


# Sort so the subsets iterate in dataset order.
train_indices = sorted(train_indices)
val_indices = sorted(val_indices)

# Both subsets are drawn from `full_dataset` (the train directory).
train_subset = Subset(full_dataset, train_indices)
val_subset = Subset(full_dataset, val_indices)

train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False)


In [None]:
# import torch
# import torch.nn as nn
# import torch.optim as optim
# from torch.utils.data import DataLoader
# import timm
# from torchvision import datasets, transforms

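# MODEL TRAINING (head-only fine-tuning; kept commented out as a record of how
# "vit_dino_finetuned.pth" was produced)
# NOTE(review): the loop below iterates `val_loader`, not `train_loader`, and the
# `train_transform` / `val_transform` defined here are never attached to a
# dataset — confirm this matches how the saved checkpoint was actually trained.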

# device = "cuda" if torch.cuda.is_available() else "cpu"
# num_classes = 200
# batch_size = 32
# epochs = 10
# lr = 0.001


# train_transform = transforms.Compose([
#     transforms.Resize(256),
#     transforms.CenterCrop(224),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
# ])

# val_transform = transforms.Compose([
#     transforms.Resize(256),
#     transforms.CenterCrop(224),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
# ])


# model = timm.create_model('vit_small_patch16_224.dino', pretrained=True)

# for param in model.parameters():
#     param.requires_grad = False

# num_features = model.num_features 

# model.head = nn.Linear(model.num_features, num_classes)
# for param in model.head.parameters():
#     param.requires_grad = True

# model = model.to(device)

# criterion = nn.CrossEntropyLoss()
# optimizer = optim.Adam(model.head.parameters(), lr=lr)


# for epoch in range(epochs):
#     model.train()
#     total_loss = 0
#     correct = 0
#     total = 0
    
#     for images, labels in val_loader:
#         images, labels = images.to(device), labels.to(device)
        
#         outputs = model(images)
#         loss = criterion(outputs, labels)
        
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
        
#         total_loss += loss.item() * images.size(0)
#         _, predicted = outputs.max(1)
#         correct += predicted.eq(labels).sum().item()
#         total += labels.size(0)

#     train_acc = correct / total
#     train_loss = total_loss / total
#     print(f"Epoch {epoch+1}/{epochs} - Loss: {train_loss:.4f} - Acc: {train_acc:.4f}")

# model.eval()

# model_path = "vit_dino_finetuned.pth"
# torch.save(model.state_dict(), model_path)
# print(f"Model saved to {model_path}")

In [116]:
import timm

num_classes = 200
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the architecture only (pretrained=False); all weights, including the
# replaced classification head, come from the local checkpoint below.
model = timm.create_model('vit_small_patch16_224.dino', pretrained=False) 
num_features = model.num_features 
model.head = nn.Linear(num_features, num_classes)

# NOTE(review): torch.load unpickles the file, so only load checkpoints you
# trust (or pass weights_only=True on torch versions that support it).
state_dict = torch.load("vit_dino_finetuned.pth", map_location="cpu") 

# perform_logit_lens_analysis sends every image to `device`; the model must be
# on the same device, and in eval mode, before it is analysed.
model = model.to(device)
model.eval()

# Kept as the last expression so the cell still displays the key-matching result.
model.load_state_dict(state_dict)

<All keys matched successfully>

In [117]:
# Run the analysis over the training subset; every CSV is written under
# logit_lens_results/DINO_lp/.
perform_logit_lens_analysis(
    model=model,
    dataset=train_subset,
    device=device,
    kl_divergence_path="logit_lens_results/DINO_lp/kl_divergence.csv",
    cosine_path="logit_lens_results/DINO_lp/cosine_similarity.csv",
    preds_path="logit_lens_results/DINO_lp/predictions.csv",
    true_probs_path="logit_lens_results/DINO_lp/true_class_probs.csv",
    first_top_probs_path="logit_lens_results/DINO_lp/first_top_class_probs.csv",
    last_layer_probs_path="logit_lens_results/DINO_lp/last_layer_top_class_probs.csv",
    second_last_layer_probs_path="logit_lens_results/DINO_lp/second_last_layer_top_class_probs.csv",
    third_last_layer_probs_path="logit_lens_results/DINO_lp/third_last_layer_top_class_probs.csv",
)