In [90]:
import clip
import torch
import torch.nn.functional as F
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from types import MethodType
import os
import zipfile
import requests
from io import BytesIO
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


In [110]:
def wrap_vit_blocks(model):
    """
    Patch every residual block of the vision transformer so that each forward
    pass records a copy of the block's output.

    Each block in ``model.visual.transformer.resblocks`` gets its ``forward``
    replaced by a thin wrapper that calls the original forward, stores
    ``out.clone()`` under the key ``"layer_<i>"`` and returns ``out`` unchanged.

    The pristine forward is remembered on the block itself, so calling this
    function again (e.g. when a notebook cell is re-run) replaces the previous
    wrapper instead of stacking a new one on top of it. Only the dictionary
    returned by the most recent call keeps being updated.

    Args:
        model: object exposing ``model.visual.transformer.resblocks``, an
            iterable of modules whose ``forward`` is called with a single
            argument ``x``.

    Returns:
        dict: initially empty; after a forward pass it maps ``"layer_<i>"`` to
        a clone of block ``i``'s most recent output.
    """
    activations = {}

    def make_custom_forward(orig_forward, layer_name):
        # Factory so that every wrapper closes over its own forward / key
        # instead of the loop variables.
        def custom_forward(self, x):
            out = orig_forward(x)
            activations[layer_name] = out.clone()
            return out
        return custom_forward

    for i, block in enumerate(model.visual.transformer.resblocks):
        # Re-use the unwrapped forward if this block was patched before;
        # otherwise wrappers would nest and each call would clone the
        # activation once per earlier wrap_vit_blocks() invocation.
        orig_forward = getattr(block, "_logit_lens_orig_forward", None)
        if orig_forward is None:
            orig_forward = block.forward
            block._logit_lens_orig_forward = orig_forward

        block.forward = MethodType(make_custom_forward(orig_forward, f"layer_{i}"), block)

    return activations


def logit_lens_analysis(activations, projection_head, ln_post, final_output, text_features, dictionary, temperature, top_k=5):
    '''
    Perform logit lens analysis on the recorded per-layer activations.

    For every entry of `activations` the first token of each sequence is passed
    through `ln_post`, multiplied by `projection_head`, L2-normalised and then
    compared (cosine similarity) with `final_output` and with every row of
    `text_features`.

    Written for a single image per call: only batch element 0 is reported in
    `distances`, and the comparison against `text_features` relies on
    broadcasting a one-row projection against all text rows.

    Args:
        activations: dict mapping layer name -> tensor, assumed to be laid out
            as (seq_len, batch, dim) with the CLS token at sequence position 0
            (per the original author's note - confirm for other models).
        projection_head: matrix applied to the normalised CLS token.
        ln_post: callable applied to the CLS token before projection.
        final_output: normalised final image embedding to compare against.
        text_features: one embedding row per candidate label.
        dictionary: mapping from row index in `text_features` to label.
        temperature: factor multiplied into the text similarities before softmax.
        top_k: number of predictions to keep per layer; clamped to the number
            of candidate labels so that small label sets do not raise.

    Returns:
        - distances: {layer name: cosine similarity to `final_output` (float)}
        - predictions: {layer name: list of up to `top_k`
          (predicted_label, probability) tuples, most probable first}
    '''
    distances = {}
    predictions = {}

    # Nothing returned from here carries gradients (only floats and labels),
    # so skip building an autograd graph through ln_post / projection_head.
    with torch.no_grad():
        for name, x in activations.items():
            # x assumed (seq_len, batch, dim) -> (batch, seq_len, dim)
            x = x.permute(1, 0, 2)
            cls_token = x[:, 0, :]  # take CLS token

            # Apply final layer norm
            cls_token = ln_post(cls_token)

            # Project using CLIP's final projection matrix -> (batch, embed_dim)
            projected = cls_token @ projection_head
            projected = F.normalize(projected, dim=-1)

            # Cosine similarity with final output (first batch element only)
            similarity = F.cosine_similarity(projected, final_output, dim=-1)
            distances[name] = similarity.detach().cpu().numpy().tolist()[0]

            # Cosine similarity with all text features, scaled and turned into
            # a probability distribution over the candidate labels
            text_similarity = F.cosine_similarity(projected, text_features, dim=-1)
            text_similarity = text_similarity * temperature
            all_probs = F.softmax(text_similarity, dim=-1)

            # torch.topk raises if k exceeds the number of candidates
            k = min(top_k, all_probs.shape[-1])
            top_k_probs, top_k_indices = torch.topk(all_probs, k=k)

            # Collect the top-k labels together with their probabilities
            predictions[name] = [
                (dictionary[idx.item()], prob.item())
                for idx, prob in zip(top_k_indices, top_k_probs)
            ]

    return distances, predictions


import os
import csv

def perform_logit_lense_analysis(model, dataset, device):
    """
    Run the logit-lens analysis over every image of `dataset` and store the
    per-layer results as two CSV files under ``logit_lens_results/``.

    Files written (both are recreated on every call, header row first):
        - cosine_similarity.csv: ``Image, layer_0, layer_1, ...`` - cosine
          similarity between each layer's projected CLS token and the final
          image embedding.
        - predictions.csv: ``Image, layer_<i>_label..., layer_<i>_prob...`` -
          top-1 label and its probability for each layer.

    Args:
        model: CLIP model; its vision transformer blocks are patched by
            `wrap_vit_blocks` as a side effect, and it is put in eval mode.
        dataset: iterable of ``(image_tensor, label)`` pairs that also exposes
            ``class_to_idx`` (class name -> index).
        device: device the images and text tokens are moved to.
    """
    idx_to_class = {v: k for k, v in dataset.class_to_idx.items()}
    model.eval()

    # Create the output directory
    os.makedirs("logit_lens_results", exist_ok=True)

    # Patch the transformer blocks so their outputs are recorded
    activations = wrap_vit_blocks(model)
    headers = [f"layer_{i}" for i in range(len(model.visual.transformer.resblocks))]

    cosine_path = "logit_lens_results/cosine_similarity.csv"
    preds_path = "logit_lens_results/predictions.csv"

    # NOTE(review): the prompts use the dataset's class names verbatim; for
    # Tiny ImageNet these are presumably WordNet ids rather than readable
    # names - confirm whether a wnid -> name mapping should be applied.
    all_classes = [f"a photo of a {idx_to_class[i]}" for i in range(len(idx_to_class))]
    text_tokens_all = clip.tokenize(all_classes).to(device)
    with torch.no_grad():
        all_text_features = model.encode_text(text_tokens_all)
        all_text_features = F.normalize(all_text_features, dim=-1)

    # Open each file once ('w' so a re-run replaces old rows instead of
    # appending duplicates) and write the column names the plotting code
    # looks up: "layer_<i>", "layer_<i>_label" and "layer_<i>_prob".
    with open(cosine_path, 'w', newline='') as cosine_file, \
            open(preds_path, 'w', newline='') as preds_file:
        cosine_writer = csv.writer(cosine_file)
        preds_writer = csv.writer(preds_file)

        cosine_writer.writerow(["Image"] + headers)
        preds_writer.writerow(
            ["Image"]
            + [f"{layer}_label" for layer in headers]
            + [f"{layer}_prob" for layer in headers]
        )

        for image_idx, (image, label) in enumerate(dataset):
            image = image.unsqueeze(0).to(device)

            with torch.no_grad():
                # The forward pass fills `activations` through the patched blocks
                image_features = model.encode_image(image)
                final_output = F.normalize(image_features, dim=-1)

                distances, predictions = logit_lens_analysis(
                    activations,
                    model.visual.proj,
                    model.visual.ln_post,
                    final_output,
                    all_text_features,
                    idx_to_class,
                    model.logit_scale.exp()
                )

            # --- Cosine similarity row ---
            cosine_writer.writerow(
                [f"Image_{image_idx + 1}"] + [distances[layer] for layer in headers]
            )

            # --- Top-1 label and probability per layer ---
            pred_labels = [predictions[layer][0][0] for layer in headers]
            pred_probs = [predictions[layer][0][1] for layer in headers]
            preds_writer.writerow([f"Image_{image_idx + 1}"] + pred_labels + pred_probs)


def plot_results(distances, predictions):
    '''
    Draw two stacked one-row heatmaps for a single image: the per-layer cosine
    similarity (top) and the per-layer prediction probability (bottom), with
    the predicted label written above each cell of the bottom heatmap.

    `distances` maps "layer_<i>" to a similarity value. `predictions` is
    queried with `.get` for "layer_<i>_prob" (NaN when missing) and
    "layer_<i>_label" (empty string when missing).
    '''
    def layer_index(layer_name):
        # "layer_7" -> 7, so layers are ordered numerically, not lexically
        return int(layer_name.split('_')[1])

    layer_names = list(distances.keys())
    layer_names.sort(key=layer_index)

    similarity_values = []
    prob_values = []
    predicted_labels = []
    for layer in layer_names:
        similarity_values.append(distances[layer])
        prob_values.append(predictions.get(f"{layer}_prob", np.nan))  # fallback if missing
        predicted_labels.append(predictions.get(f"{layer}_label", ""))

    fig, (similarity_ax, prob_ax) = plt.subplots(
        2, 1,
        figsize=(12, 6),
        gridspec_kw={'height_ratios': [1, 1]},
    )

    # Top panel: similarity of each layer's projection to the final embedding
    similarity_row = np.array(similarity_values).reshape(1, -1)
    sns.heatmap(
        similarity_row,
        annot=True,
        cmap="viridis",
        xticklabels=layer_names,
        yticklabels=["Cosine Similarity"],
        cbar=True,
        ax=similarity_ax,
        cbar_kws={'label': 'Cosine Similarity'},
    )

    # Bottom panel: probability of the top prediction at each layer
    prob_row = np.array(prob_values).reshape(1, -1)
    sns.heatmap(
        prob_row,
        annot=True,
        cmap="magma",
        xticklabels=layer_names,
        yticklabels=["Prediction Prob."],
        cbar=True,
        ax=prob_ax,
        cbar_kws={'label': 'Prediction Probability'},
    )

    # Place each predicted label at the horizontal centre of its cell
    for column, label in enumerate(predicted_labels):
        prob_ax.text(
            column + 0.5, -0.3, label,
            ha='center', va='center',
            color='black', fontsize=9, rotation=90,
            transform=prob_ax.transData,
        )

    plt.suptitle("Cosine Similarity & Prediction Probability per Layer", fontsize=14)
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

In [92]:
# Download URL and the directory the archive is expected to unpack into
url = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"
extract_path = "./tiny-imagenet-200"
train_dir = os.path.join(extract_path, "train")

# Only fetch the archive when the training split is not on disk yet, so that
# re-running the notebook does not download and unpack everything again.
if os.path.isdir(train_dir):
    print("Tiny ImageNet already present, skipping download.")
else:
    print("Downloading Tiny ImageNet...")
    response = requests.get(url, timeout=60)
    # Fail here on an HTTP error instead of later with an opaque BadZipFile
    response.raise_for_status()
    with zipfile.ZipFile(BytesIO(response.content)) as zip_ref:
        zip_ref.extractall(".")

    print("Download and extraction complete.")


# Resize to 224x224 and normalise per channel.
# NOTE(review): the mean/std values are presumably CLIP's preprocessing
# statistics - confirm against the `preprocess` returned by clip.load.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                        std=[0.26862954, 0.26130258, 0.27577711]),
])
train_dataset = datasets.ImageFolder(train_dir, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)



Downloading Tiny ImageNet...
Download and extraction complete.


In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the CLIP ViT-B/32 model together with its preprocessing pipeline
model, preprocess = clip.load("ViT-B/32", device=device)
# Put the model in eval mode; as the cell's last expression the model is also displayed
model.eval()

In [111]:
# Run the logit-lens pass over every image in train_dataset; the per-layer
# results are written to the CSV files under logit_lens_results/.
perform_logit_lense_analysis(model=model, dataset=train_dataset, device=device)

In [98]:
import pandas as pd

# Load the per-image results written by perform_logit_lense_analysis.
# read_csv is called with its defaults, so the first line of each file is
# used as the column names: both files must start with a header row for the
# "layer_<i>" / "layer_<i>_label" / "layer_<i>_prob" lookups in plot_results.
distances = pd.read_csv("logit_lens_results/cosine_similarity.csv")
predictions = pd.read_csv("logit_lens_results/predictions.csv")

In [None]:
# 0-based row of the image whose per-layer profile is plotted; kept in one
# place so both result tables are always indexed with the same row.
image_row = 100001

# Fail with an actionable message when the result files are shorter than
# expected (e.g. after a single pass over the dataset).
if image_row >= len(distances) or image_row >= len(predictions):
    raise IndexError(
        f"image_row={image_row} is out of range: cosine_similarity.csv has "
        f"{len(distances)} rows and predictions.csv has {len(predictions)} rows"
    )

# Drop the image-name column (if present) so only per-layer entries remain
plot_results(
    distances.iloc[image_row, :].drop("Image", errors="ignore"),
    predictions.iloc[image_row, :].drop("Image", errors="ignore"),
)