In [16]:
import os
import csv
import random
import zipfile
import requests
from io import BytesIO
from types import MethodType
from collections import defaultdict

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

import clip
from PIL import Image

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns


In [17]:
def set_all_seeds(seed: int = 42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Also switches cuDNN to deterministic mode and turns its benchmark mode off.

    Args:
        seed: value handed to every seeding function (default 42).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,  # covers every CUDA device, not just the current one
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

In [18]:
# Seed the Python, NumPy and PyTorch RNGs with 42 via the helper defined above.
set_all_seeds(42)

In [None]:
def wrap_vit_blocks(model):
    """Instrument every visual transformer block so its output is recorded.

    Each block in ``model.visual.transformer.resblocks`` gets its ``forward``
    replaced by a thin wrapper that calls the original forward, stores a clone
    of the result under ``"layer_<i>"`` in the returned dict, and passes the
    result through unchanged.

    The function is safe to call repeatedly on the same model: the pristine
    ``forward`` is remembered on the block the first time it is wrapped, and
    later calls re-wrap that pristine callable instead of the previous wrapper.
    Without this, wrappers would stack up and dicts returned by earlier calls
    would keep being written (and keep tensors alive).

    Args:
        model: object exposing ``visual.transformer.resblocks`` as an iterable
            of blocks, each with a ``forward(x)`` method.

    Returns:
        dict that is refilled with ``{"layer_<i>": output.clone()}`` on every
        forward pass through the wrapped blocks.
    """
    activations = {}

    def make_custom_forward(orig_forward, layer_name):
        # Factory so each wrapper closes over its own forward/name pair
        # rather than the loop variables.
        def custom_forward(self, x):
            out = orig_forward(x)
            activations[layer_name] = out.clone()
            return out
        return custom_forward

    for i, block in enumerate(model.visual.transformer.resblocks):
        # Reuse the un-instrumented forward saved by an earlier call, if any.
        orig_forward = getattr(block, "_logit_lens_orig_forward", None)
        if orig_forward is None:
            orig_forward = block.forward
            block._logit_lens_orig_forward = orig_forward

        block.forward = MethodType(make_custom_forward(orig_forward, f"layer_{i}"), block)

    return activations


def logit_lens_analysis(activations, projection_head, ln_post, final_output, text_features, dictionary, temperature, top_k=1):
    """Score every captured layer as if it were the final layer ("logit lens").

    For each entry of ``activations`` the token at sequence position 0 is
    passed through ``ln_post`` and ``projection_head``, L2-normalised, and then
    (a) compared with ``text_features`` to obtain class probabilities and
    (b) compared with ``final_output`` by cosine similarity.

    The whole computation runs under ``torch.no_grad()``: only Python floats
    and labels are returned, so building an autograd graph through the
    trainable ``ln_post`` / ``projection_head`` would only waste memory.

    Args:
        activations: mapping ``layer name -> tensor``. Each tensor is permuted
            with ``(1, 0, 2)`` before indexing, i.e. it is read as
            (seq_len, batch, dim).
        projection_head: matrix the normalised CLS token is right-multiplied by.
        ln_post: callable applied to the CLS token before projection.
        final_output: reference embedding used for the cosine similarity.
        text_features: matrix with one row per class; logits are
            ``temperature * projected @ text_features.T``.
        dictionary: mapping ``class index -> label`` used to name predictions.
        temperature: scalar multiplier applied to the logits.
        top_k: number of highest-probability classes kept per layer.

    Returns:
        ``(distances, predictions)`` where ``distances[name]`` is the cosine
        similarity (float) and ``predictions[name]`` is a list of
        ``(label, probability)`` pairs. Only batch element 0 is reported.
    """
    distances = {}
    predictions = {}

    with torch.no_grad():
        for name, x in activations.items():
            x = x.permute(1, 0, 2)  # -> (batch, seq_len, dim)
            cls_token = x[:, 0, :]  # take CLS token

            cls_token = ln_post(cls_token)

            # [1] Project the CLS token and L2-normalise it
            projected = cls_token @ projection_head
            projected = F.normalize(projected, dim=-1)

            # [2] Compute the logits the same way CLIP does
            logits = temperature * projected @ text_features.T

            # [3] Softmax and top-k
            probs = F.softmax(logits, dim=-1)
            top_k_probs, top_k_indices = torch.topk(probs, k=top_k)

            # [4] Store the predictions ([0] because only the first batch element is reported)
            predictions[name] = [
                (dictionary[idx.item()], prob.item())
                for idx, prob in zip(top_k_indices[0], top_k_probs[0])
            ]

            # [5] Additionally: cosine similarity to the final output embedding
            similarity = F.cosine_similarity(projected, final_output, dim=-1)
            distances[name] = similarity.detach().cpu().numpy().tolist()[0]

    return distances, predictions

def load_tiny_imagenet_labels(path="tiny-imagenet-200/words.txt"):
    """Read a tab-separated ``wnid<TAB>label`` file into a dict.

    Blank lines are skipped, and only the first tab on a line separates the
    WordNet id from the label, so a label that itself contains a tab is kept
    intact.

    Args:
        path: location of the mapping file.

    Returns:
        dict mapping each WordNet id to its label text.

    Raises:
        ValueError: if a non-blank line contains no tab separator.
    """
    wnid_to_label = {}
    # Explicit encoding so the result does not depend on the platform locale.
    with open(path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue  # tolerate empty / trailing lines
            wnid, sep, label = line.partition("\t")
            if not sep:
                raise ValueError(f"{path}:{line_no}: expected '<wnid>\\t<label>', got {line!r}")
            wnid_to_label[wnid.strip()] = label.strip()
    return wnid_to_label

def perform_logit_lens_analysis(model, dataset, device, cosine_path = "logit_lens_results/cosine_similarity.csv", preds_path = "logit_lens_results/predictions.csv"):
    """Run the logit-lens analysis image by image and append the results to two CSV files.

    For every ``(image, label)`` pair in ``dataset`` the image is encoded with
    ``model.encode_image`` while the per-block activations are captured, then
    ``logit_lens_analysis`` scores each block. One row per image is appended to
    each output file:

    * ``cosine_path``: ``Image, True_WNID, True_Label, layer_0, layer_1, ...``
    * ``preds_path``:  ``Image, True_WNID, True_Label, <layer>_label..., <layer>_prob...``

    A header row is written only when the target file is missing or empty, so
    re-running appends to existing results.

    Side effects: puts ``model`` in eval mode and replaces the ``forward`` of
    its visual transformer blocks (via ``wrap_vit_blocks``); creates the parent
    directories of both output paths.

    Args:
        model: CLIP model exposing ``visual``, ``encode_image``, ``encode_text``
            and ``logit_scale``.
        dataset: iterable of ``(image_tensor, class_index)``. Either a wrapper
            with a ``.dataset`` attribute that carries ``class_to_idx`` (as the
            ``Subset`` used in this notebook) or an object carrying
            ``class_to_idx`` itself.
        device: device the image and text tensors are moved to.
        cosine_path: CSV file receiving per-layer cosine similarities.
        preds_path: CSV file receiving per-layer top-1 labels and probabilities.
    """
    # A Subset keeps the class map on the wrapped dataset; a bare dataset has it directly.
    base_dataset = getattr(dataset, "dataset", dataset)
    idx_to_class = {v: k for k, v in base_dataset.class_to_idx.items()}
    model.eval()

    # Create the output directories once. dirname is "" for a bare file name,
    # and os.makedirs("") raises, so that case is skipped.
    for out_path in (cosine_path, preds_path):
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    activations = wrap_vit_blocks(model)

    headers = [f"layer_{i}" for i in range(len(model.visual.transformer.resblocks))]

    # Column layouts are identical for every row, so build them once.
    meta_header = ['Image', 'True_WNID', 'True_Label']
    cosine_header = meta_header + headers
    pred_header = meta_header + \
                [f"{layer}_label" for layer in headers] + \
                [f"{layer}_prob" for layer in headers]

    wnid_to_label = load_tiny_imagenet_labels()
    all_classes = [f"a photo of {wnid_to_label[idx_to_class[i]]}" for i in range(len(idx_to_class))]

    text_tokens_all = clip.tokenize(all_classes).to(device)

    with torch.no_grad():
        all_text_features = model.encode_text(text_tokens_all)
        all_text_features = all_text_features / all_text_features.norm(dim=1, keepdim=True)
        temperature = model.logit_scale.exp()

    def needs_header(path):
        # Only a missing or empty file gets a header row.
        return not os.path.exists(path) or os.path.getsize(path) == 0

    write_cosine_header = needs_header(cosine_path)
    write_preds_header = needs_header(preds_path)

    with open(cosine_path, 'a', newline='') as cosine_file, \
         open(preds_path, 'a', newline='') as preds_file:
        cosine_writer = csv.writer(cosine_file)
        preds_writer = csv.writer(preds_file)

        if write_cosine_header:
            cosine_writer.writerow(cosine_header)
        if write_preds_header:
            preds_writer.writerow(pred_header)

        for image_idx, (image, label) in enumerate(dataset):
            image = image.unsqueeze(0).to(device)

            # Nothing below needs gradients; keep the whole step out of autograd.
            with torch.no_grad():
                image_features = model.encode_image(image)
                final_output = F.normalize(image_features, dim=-1)

                distances, predictions = logit_lens_analysis(
                    activations,
                    model.visual.proj,
                    model.visual.ln_post,
                    final_output,
                    all_text_features,
                    idx_to_class,
                    temperature
                )

            true_wnid = idx_to_class[label]
            true_label = wnid_to_label.get(true_wnid, "")
            row_prefix = [f"Image_{image_idx + 1}", true_wnid, true_label]

            # --- COSINE FILE ---
            cosine_writer.writerow(row_prefix + [distances[layer] for layer in headers])

            # --- PREDICTIONS FILE ---
            # logit_lens_analysis names predictions by WNID (idx_to_class is its
            # dictionary); translate to readable labels, falling back to the WNID.
            pred_wnids = [predictions[layer][0][0] for layer in headers]
            pred_labels = [wnid_to_label.get(wnid, wnid) for wnid in pred_wnids]
            pred_probs = [predictions[layer][0][1] for layer in headers]
            preds_writer.writerow(row_prefix + pred_labels + pred_probs)

            # Flush per image so an interrupted run keeps every finished row on disk.
            cosine_file.flush()
            preds_file.flush()



In [20]:

# --- Sampling configuration ---
RANDOM_SEED = 42
SUBSET_FRACTION = 0.1  # share of each class to keep (0.05 -> 5%, 0.1 -> 10%)
BATCH_SIZE = 64

# --- Preprocessing: resize to 224x224, convert to a tensor, normalise per channel ---
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.48145466, 0.4578275, 0.40821073],
        std=[0.26862954, 0.26130258, 0.27577711],
    ),
])

# --- Full training split, one sub-folder per class ---
train_dir = os.path.join("./tiny-imagenet-200", "train")
full_dataset = datasets.ImageFolder(train_dir, transform=transform)

# --- Bucket sample positions by class index ---
class_indices = defaultdict(list)
for idx, (_, label) in enumerate(full_dataset.samples):
    class_indices[label].append(idx)

# --- Stratified sampling: the same fraction from every class, at least one sample each ---
rng = random.Random(RANDOM_SEED)  # private RNG, independent of the global random state
subset_indices = []
for label, indices in class_indices.items():
    rng.shuffle(indices)  # randomise the order inside the class
    n_select = max(int(SUBSET_FRACTION * len(indices)), 1)
    subset_indices += indices[:n_select]

# Ascending order so the subset is always traversed in the same sequence
subset_indices.sort()

# --- Subset view and a shuffling batch loader over it ---
subset_dataset = Subset(full_dataset, subset_indices)
train_loader = DataLoader(subset_dataset, batch_size=BATCH_SIZE, shuffle=True)


In [21]:
# Use the GPU when PyTorch reports CUDA as available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the CLIP "ViT-B/32" model onto `device`; clip.load also returns a second value, kept here as `preprocess`.
model, preprocess = clip.load("ViT-B/32", device=device)
# Switch to evaluation mode. As the last expression of the cell, the returned module is also displayed.
model.eval()

CLIP(
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
    (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): Sequential(
        (0): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): QuickGELU()
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
        (1): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          

In [None]:
# Run the analysis over the stratified subset; rows are appended to the two CSV files under logit_lens_results/CLIP/.
perform_logit_lens_analysis(model=model, dataset=subset_dataset, device=device, cosine_path="logit_lens_results/CLIP/cosine_similarity.csv", preds_path="logit_lens_results/CLIP/predictions.csv")

KeyboardInterrupt: 

In [8]:

# Load the CSV files written by perform_logit_lens_analysis back into DataFrames.
distances = pd.read_csv("logit_lens_results/CLIP/cosine_similarity.csv")
predictions = pd.read_csv("logit_lens_results/CLIP/predictions.csv")