In [1]:
import os
import zipfile
import random
import csv
from io import BytesIO
from types import MethodType
from collections import defaultdict

import requests
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

import clip



In [2]:
def set_all_seeds(seed: int = 42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and PyTorch
    (the default generator and all CUDA devices), then configures cuDNN for
    deterministic kernel selection.

    Args:
        seed: Value passed unchanged to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,  # covers multi-GPU setups
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Trade cuDNN autotuning for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

In [3]:
# Seed the Python, NumPy and PyTorch RNGs up front so that every later cell
# starts from the same random state.
set_all_seeds(42)

In [4]:
def wrap_vit_blocks(model):
    """Instrument every vision-transformer block so its output is recorded.

    Each module in ``model.visual.transformer.resblocks`` receives an
    instance-level ``forward`` that calls the block's original forward, stores
    a clone of the result under ``"layer_<i>"`` and returns the result
    unchanged.

    The function is safe to call repeatedly on the same model: the original
    forward is remembered on the block the first time it is wrapped and is the
    one re-wrapped on later calls. Wrappers therefore never stack, and only the
    most recently returned dictionary is written to.

    Args:
        model: Object exposing ``visual.transformer.resblocks`` as an iterable
            of modules whose ``forward`` takes one positional argument.

    Returns:
        dict: Mapping ``"layer_<i>"`` to the latest output of block ``i``.
        It starts empty and is overwritten in place on every forward pass.
    """
    activations = {}

    def make_recording_forward(original_forward, layer_name):
        # Factory so each wrapper closes over its own forward / layer name
        # instead of the loop variables.
        def recording_forward(self, x):
            out = original_forward(x)
            activations[layer_name] = out.clone()
            return out
        return recording_forward

    for i, block in enumerate(model.visual.transformer.resblocks):
        # Re-use the pristine forward saved by an earlier call, if any, so a
        # second call replaces the previous wrapper rather than nesting in it.
        original_forward = getattr(block, "_logit_lens_original_forward", None)
        if original_forward is None:
            original_forward = block.forward
            block._logit_lens_original_forward = original_forward

        block.forward = MethodType(
            make_recording_forward(original_forward, f"layer_{i}"), block
        )

    return activations


def logit_lens_analysis(activations, projection_head, ln_post, final_output, text_features, dictionary, temperature, top_k=1, true_class_idx=None):
    """Decode each recorded layer through the model's output head ("logit lens").

    For every entry of ``activations`` the CLS token is passed through
    ``ln_post`` and ``projection_head``, L2-normalised, and scored against
    ``text_features`` to obtain a class distribution. The distribution of the
    last recorded layer serves as the reference for the tracked classes and
    for the KL divergence.

    The function assumes a batch size of 1: only batch element 0 is read.

    Args:
        activations: Ordered mapping ``layer name -> tensor``. Tensors are
            treated as ``(seq_len, batch, dim)`` and permuted to batch-first
            before the CLS token (sequence position 0) is taken.
        projection_head: Matrix applied to the normalised CLS token (``cls @ projection_head``).
        ln_post: Callable applied to the CLS token before projection.
        final_output: Reference embedding used for the per-layer cosine similarity.
        text_features: Class embeddings, one row per class.
        dictionary: Mapping ``class index -> label`` used for ``predictions``.
        temperature: Scale applied to the similarities before the softmax.
        top_k: Number of top predictions stored per layer.
        true_class_idx: Index of the ground-truth class, or ``None`` to record
            ``None`` for every layer instead of a probability.

    Returns:
        A 12-tuple, in this order:
        ``distances`` (layer -> cosine similarity to ``final_output``),
        ``predictions`` (layer -> list of ``(label, prob)``),
        ``true_class_probs``, ``first_top_class_probs``,
        ``last_layer_top_class_probs``, ``last_layer_second_top_class_probs``,
        ``last_layer_third_top_class_probs``, ``kl_divergence``
        (all per-layer lists in ``activations`` order),
        ``first_class_idx``, ``last_layer_top_class``,
        ``last_layer_second_top_class``, ``last_layer_third_top_class``
        (class indices; the latter two are ``None`` when fewer classes exist).
    """
    def project_cls(layer_activation):
        # Same pipeline for the reference layer and for every probed layer, so
        # both are guaranteed to look at the CLS token of batch element 0.
        tokens = layer_activation.permute(1, 0, 2)  # -> (batch, seq_len, dim)
        cls_token = ln_post(tokens[:, 0, :])
        projected = F.normalize(cls_token @ projection_head, dim=-1)
        # Compute logits the way CLIP does: scaled similarity to the class embeddings.
        logits = temperature * projected @ text_features.T
        return projected, logits

    def prob_of(probs, class_idx):
        # None marks "no such class to track" for this run.
        return None if class_idx is None else float(probs[0, class_idx].item())

    distances = {}
    predictions = {}
    true_class_probs = []
    first_top_class_probs = []
    last_layer_top_class_probs = []
    last_layer_second_top_class_probs = []
    last_layer_third_top_class_probs = []
    kl_divergence = []
    first_class_idx = None

    layer_names = list(activations.keys())
    first_layer_name, last_layer_name = layer_names[0], layer_names[-1]

    # Only Python scalars leave this function, so no autograd graph is needed.
    with torch.no_grad():
        # Reference distribution: the last recorded layer.
        _, last_logits = project_cls(activations[last_layer_name])
        last_probs = F.softmax(last_logits, dim=-1)

        # Top-3 classes of the reference distribution (padded with None when
        # there are fewer than three classes).
        n_reference = min(3, last_probs.shape[-1])
        reference_classes = torch.topk(last_probs, k=n_reference, dim=-1).indices[0].tolist()
        reference_classes += [None] * (3 - n_reference)
        last_layer_top_class, last_layer_second_top_class, last_layer_third_top_class = reference_classes

        for name, x in activations.items():
            # [1]-[2] Project the CLS token and score it against the classes.
            projected, logits = project_cls(x)

            # [3] Softmax and top-k.
            probs = F.softmax(logits, dim=-1)
            top_k_probs, top_k_indices = torch.topk(probs, k=top_k)

            # KL(last layer || this layer); defined as 0 for the last layer itself.
            # log_softmax avoids the -inf that probs.log() yields on underflow.
            if name == last_layer_name:
                kl_value = 0.0
            else:
                log_probs = F.log_softmax(logits, dim=-1)
                kl_value = float(
                    F.kl_div(log_probs, last_probs, reduction='none').sum(dim=-1).mean().item()
                )
            kl_divergence.append(kl_value)

            # [4] Store the predictions ([0] because batch size is 1).
            top_k_predictions = [
                (dictionary[idx.item()], prob.item())
                for idx, prob in zip(top_k_indices[0], top_k_probs[0])
            ]
            predictions[name] = top_k_predictions

            # The first layer's top-1 class is tracked through all later layers.
            if name == first_layer_name and top_k_predictions:
                first_class_idx = top_k_indices[0][0].item()

            true_class_probs.append(prob_of(probs, true_class_idx))
            first_top_class_probs.append(prob_of(probs, first_class_idx))
            last_layer_top_class_probs.append(prob_of(probs, last_layer_top_class))
            last_layer_second_top_class_probs.append(prob_of(probs, last_layer_second_top_class))
            last_layer_third_top_class_probs.append(prob_of(probs, last_layer_third_top_class))

            # [5] Additionally: cosine similarity of this layer's embedding to final_output.
            similarity = F.cosine_similarity(projected, final_output, dim=-1)
            distances[name] = float(similarity[0].item())

    return distances, predictions, true_class_probs, first_top_class_probs, last_layer_top_class_probs, last_layer_second_top_class_probs, last_layer_third_top_class_probs, kl_divergence, first_class_idx, last_layer_top_class, last_layer_second_top_class, last_layer_third_top_class


def load_tiny_imagenet_labels(path="tiny-imagenet-200/words.txt"):
    """Read a tab-separated ``wnid<TAB>label`` file into a dictionary.

    Blank lines are skipped. Only the first tab splits a line, so a label that
    itself contains a tab is kept whole; a line with no tab maps the WNID to an
    empty label.

    Args:
        path: Location of the mapping file.

    Returns:
        dict: ``wnid -> human-readable label``, both stripped of surrounding
        whitespace.
    """
    wnid_to_label = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            wnid, _, label = line.partition("\t")
            wnid_to_label[wnid.strip()] = label.strip()
    return wnid_to_label

def _append_csv_row(path, header, row):
    """Append ``row`` to the CSV at ``path``, writing ``header`` first if the file is new or empty."""
    needs_header = not os.path.exists(path) or os.path.getsize(path) == 0
    with open(path, 'a', newline='') as f:
        writer = csv.writer(f)
        if needs_header:
            writer.writerow(header)
        writer.writerow(row)


def perform_logit_lens_analysis(model, dataset, device,
                                cosine_path = "logit_lens_results/cosine_similarity.csv", 
                                preds_path = "logit_lens_results/predictions.csv",
                                true_probs_path = "logit_lens_results/true_class_probs.csv",
                                first_top_probs_path = "logit_lens_results/first_top_class_probs.csv",
                                last_layer_probs_path = "logit_lens_results/last_layer_top_class_probs.csv",
                                second_last_layer_probs_path = "logit_lens_results/second_last_layer_top_class_probs.csv",
                                third_last_layer_probs_path = "logit_lens_results/third_last_layer_top_class_probs.csv",
                                kl_divergence_path = "logit_lens_results/kl_divergence.csv"):
    """Run the logit-lens analysis on every image of ``dataset`` and log it to CSV.

    Images are processed one at a time. For each image one row is appended to
    each of the eight output files; a header row is written when a file is new
    or empty. Because files are opened in append mode, re-running with the same
    paths adds further rows to the existing files.

    The vision blocks are instrumented only for the duration of the call; their
    previous ``forward`` is restored afterwards, also if an error occurs.

    Args:
        model: CLIP model exposing ``visual.transformer.resblocks``,
            ``visual.proj``, ``visual.ln_post``, ``logit_scale``,
            ``encode_image`` and ``encode_text``.
        dataset: Yields ``(image_tensor, label)`` pairs. ``class_to_idx`` is
            read from ``dataset.dataset`` when that attribute exists (e.g. a
            ``Subset``), otherwise from ``dataset`` itself.
        device: Device the image and text tokens are moved to.
        cosine_path: Per-layer cosine similarity to the final image embedding.
        preds_path: Per-layer top-1 label and probability.
        true_probs_path: Per-layer probability of the ground-truth class.
        first_top_probs_path: Per-layer probability of the first layer's top-1 class.
        last_layer_probs_path: Per-layer probability of the last layer's top-1 class.
        second_last_layer_probs_path: Per-layer probability of the class ranked
            second at the last layer.
        third_last_layer_probs_path: Per-layer probability of the class ranked
            third at the last layer.
        kl_divergence_path: Per-layer KL divergence as returned by
            ``logit_lens_analysis``.
    """
    # Class metadata lives on the wrapped dataset when a Subset is passed.
    base_dataset = getattr(dataset, "dataset", dataset)
    idx_to_class = {v: k for k, v in base_dataset.class_to_idx.items()}
    model.eval()

    # Create the output directories once, up front. A bare filename has an
    # empty dirname, which os.makedirs would reject.
    output_paths = (
        cosine_path, preds_path, true_probs_path, first_top_probs_path,
        last_layer_probs_path, second_last_layer_probs_path,
        third_last_layer_probs_path, kl_divergence_path,
    )
    for path in output_paths:
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    resblocks = model.visual.transformer.resblocks
    headers = [f"layer_{i}" for i in range(len(resblocks))]
    base_header = ['Image', 'True_WNID', 'True_Label']

    wnid_to_label = load_tiny_imagenet_labels()
    all_classes = [f"a photo of {wnid_to_label[idx_to_class[i]]}" for i in range(len(idx_to_class))]

    text_tokens_all = clip.tokenize(all_classes).to(device)

    with torch.no_grad():
        all_text_features = model.encode_text(text_tokens_all)
        all_text_features = all_text_features / all_text_features.norm(dim=1, keepdim=True)

    # Remember any instance-level forward so the instrumentation can be undone;
    # None means the block was using its class-level forward.
    saved_forwards = [vars(block).get("forward") for block in resblocks]
    activations = wrap_vit_blocks(model)

    try:
        for image_idx, (image, label) in enumerate(dataset):
            image = image.unsqueeze(0).to(device)

            true_class_idx = label.item() if hasattr(label, 'item') else int(label)
            true_wnid = idx_to_class[true_class_idx]
            true_label = wnid_to_label.get(true_wnid, "")

            with torch.no_grad():
                # The forward pass refreshes `activations` for this image.
                image_features = model.encode_image(image)
                final_output = F.normalize(image_features, dim=-1)

                (distances, predictions, true_class_probs, first_top_probs,
                 last_layer_probs, second_last_layer_probs, third_last_layer_probs,
                 kl_div, _first_class_idx, _last_layer_top_class_idx,
                 last_layer_second_top_class_idx,
                 last_layer_third_top_class_idx) = logit_lens_analysis(
                    activations,
                    model.visual.proj,
                    model.visual.ln_post,
                    final_output,
                    all_text_features,
                    idx_to_class,
                    model.logit_scale.exp(),
                    true_class_idx=true_class_idx
                )

            row_prefix = [f"Image_{image_idx + 1}", true_wnid, true_label]

            # --- KL DIVERGENCE FILE ---
            _append_csv_row(kl_divergence_path, base_header + headers, row_prefix + kl_div)

            # --- COSINE FILE ---
            _append_csv_row(
                cosine_path,
                base_header + headers,
                row_prefix + [distances[layer] for layer in headers],
            )

            # --- PREDICTIONS FILE ---
            # `predictions` carries WNIDs (idx_to_class was the label dictionary);
            # translate them to readable labels for the CSV.
            pred_wnids = [predictions[layer][0][0] for layer in headers]
            pred_labels = [wnid_to_label.get(wnid, wnid) for wnid in pred_wnids]
            pred_probs = [predictions[layer][0][1] for layer in headers]
            _append_csv_row(
                preds_path,
                base_header + [f"{layer}_label" for layer in headers] + [f"{layer}_prob" for layer in headers],
                row_prefix + pred_labels + pred_probs,
            )

            # --- TRUE CLASS PROBS FILE ---
            _append_csv_row(true_probs_path, base_header + headers, row_prefix + true_class_probs)

            # --- FIRST TOP CLASS PROBS FILE ---
            first_top = predictions.get(f"{headers[0]}", None)
            first_top_wnid = first_top[0][0] if first_top else "N/A"
            first_top_label = wnid_to_label.get(first_top_wnid, "N/A")
            _append_csv_row(
                first_top_probs_path,
                base_header + ['First_Top1_WNID', 'First_Top1_Label'] + headers,
                row_prefix + [first_top_wnid, first_top_label] + first_top_probs,
            )

            # --- LAST LAYER TOP CLASS PROBS FILE ---
            last_top = predictions.get(f"{headers[-1]}", None)
            last_top_wnid = last_top[0][0] if last_top else "N/A"
            last_label = wnid_to_label.get(last_top_wnid, "N/A")
            _append_csv_row(
                last_layer_probs_path,
                base_header + ['Last_Top1_WNID', 'Last_Top1_Label'] + headers,
                row_prefix + [last_top_wnid, last_label] + last_layer_probs,
            )

            # --- SECOND LAST LAYER TOP CLASS PROBS FILE ---
            # Despite the name, this tracks the class ranked second at the last layer.
            second_last_wnid = idx_to_class.get(last_layer_second_top_class_idx, "N/A")
            second_last_label = wnid_to_label.get(second_last_wnid, "N/A")
            _append_csv_row(
                second_last_layer_probs_path,
                base_header + ['Second_Last_Top1_WNID', 'Second_Last_Top1_Label'] + headers,
                row_prefix + [second_last_wnid, second_last_label] + second_last_layer_probs,
            )

            # --- THIRD LAST LAYER TOP CLASS PROBS FILE ---
            # Likewise: the class ranked third at the last layer.
            third_last_wnid = idx_to_class.get(last_layer_third_top_class_idx, "N/A")
            third_last_label = wnid_to_label.get(third_last_wnid, "N/A")
            _append_csv_row(
                third_last_layer_probs_path,
                base_header + ['Third_Last_Top1_WNID', 'Third_Last_Top1_Label'] + headers,
                row_prefix + [third_last_wnid, third_last_label] + third_last_layer_probs,
            )
    finally:
        # Undo the instrumentation so later forward passes stop cloning
        # activations and a subsequent run starts from unwrapped blocks.
        for block, previous_forward in zip(resblocks, saved_forwards):
            if previous_forward is None:
                vars(block).pop("forward", None)
            else:
                block.forward = previous_forward



In [5]:
# --- Experiment configuration ---
RANDOM_SEED = 42
SUBSET_FRACTION = 0.1  # share of each class to keep (0.05 -> 5%, 0.1 -> 10%)
BATCH_SIZE = 64

# --- Image preprocessing: fixed 224x224 resize, tensor conversion, per-channel normalisation ---
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.48145466, 0.4578275, 0.40821073],
        std=[0.26862954, 0.26130258, 0.27577711],
    ),
])

# --- Full training set ---
train_dir = os.path.join("./tiny-imagenet-200", "train")
full_dataset = datasets.ImageFolder(train_dir, transform=transform)

# --- Bucket sample indices by class label ---
class_indices = defaultdict(list)
for idx, sample in enumerate(full_dataset.samples):
    label = sample[1]
    class_indices[label].append(idx)

# --- Stratified sampling: shuffle inside each class, keep the leading fraction ---
rng = random.Random(RANDOM_SEED)
subset_indices = []
for indices in class_indices.values():
    rng.shuffle(indices)
    # Never drop a class entirely, even for tiny fractions.
    n_select = max(int(SUBSET_FRACTION * len(indices)), 1)
    subset_indices += indices[:n_select]

# Ascending order gives a stable image order when the subset is iterated directly.
subset_indices.sort()

# --- Subset view and a shuffling loader over it ---
subset_dataset = Subset(full_dataset, subset_indices)
train_loader = DataLoader(subset_dataset, shuffle=True, batch_size=BATCH_SIZE)


In [6]:

# Use the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the CLIP ViT-L/14 model onto the selected device.
# NOTE(review): `preprocess` is returned here, but the dataset cell builds its own `transform`.
model, preprocess = clip.load("ViT-L/14", device=device)

# Put the model in eval mode; as the cell's last expression this also displays the module tree.
model.eval()

CLIP(
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
    (ln_pre): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): Sequential(
        (0): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True)
          )
          (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=1024, out_features=4096, bias=True)
            (gelu): QuickGELU()
            (c_proj): Linear(in_features=4096, out_features=1024, bias=True)
          )
          (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        )
        (1): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True)


In [7]:
# Output files for the ViT-L/14 run, keyed by the keyword argument they are passed as.
clip_l_output_paths = {
    "cosine_path": "logit_lens_results/CLIP_L/cosine_similarity.csv",
    "preds_path": "logit_lens_results/CLIP_L/predictions.csv",
    "true_probs_path": "logit_lens_results/CLIP_L/true_class_probs.csv",
    "first_top_probs_path": "logit_lens_results/CLIP_L/first_top_class_probs.csv",
    "last_layer_probs_path": "logit_lens_results/CLIP_L/last_layer_top_class_probs.csv",
    "second_last_layer_probs_path": "logit_lens_results/CLIP_L/second_last_layer_top_class_probs.csv",
    "third_last_layer_probs_path": "logit_lens_results/CLIP_L/third_last_layer_top_class_probs.csv",
    "kl_divergence_path": "logit_lens_results/CLIP_L/kl_divergence.csv",
}

# Run the analysis over the stratified subset and append the results to the files above.
perform_logit_lens_analysis(
    model=model,
    dataset=subset_dataset,
    device=device,
    **clip_l_output_paths,
)