## Notebook for computing the average pairwise distance of high-attention tokens from different layers in CLIP

In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.notebook import tqdm
from transformers import CLIPProcessor, CLIPModel
from scipy.spatial.distance import pdist

In [None]:
# Folder holding the COCO val2017 images. The archive can be downloaded from
# http://images.cocodataset.org/zips/val2017.zip
COCO_IMAGE_DIR = "/media/daniel/Data/Datasets/val2017"

# Fraction of the non-CLS tokens kept as "high-attention" tokens for each image
# (0.05, 0.1 and 0.25 are the values used in our paper).
TOP_PERCENT = 0.05

# Number of images pushed through the model in one forward pass.
BATCH_SIZE = 16

# Use the GPU whenever torch can see one, otherwise stay on the CPU.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"

In [None]:
# Identifier of the CLIP checkpoint used for both the model and its processor.
clip_path = "openai/clip-vit-large-patch14-336"
# Load the CLIP model from that checkpoint and move its weights to DEVICE.
model = CLIPModel.from_pretrained(clip_path).to(DEVICE)
# Matching processor; it turns PIL images into the tensors the model consumes.
processor = CLIPProcessor.from_pretrained(clip_path)

In [None]:
def get_image_paths(folder):
    """Return the paths of the image files located directly inside *folder*.

    A directory entry is kept when its extension (compared case-insensitively)
    is one of .jpg, .jpeg, .png or .bmp and it is a regular file. Sub-folders
    are not searched.

    Args:
        folder: Path of the directory to scan.

    Returns:
        A list of ``os.path.join(folder, name)`` paths, sorted by file name so
        that the processing order is reproducible across runs and machines
        (``os.listdir`` itself returns entries in arbitrary order).

    Raises:
        OSError: Propagated from ``os.listdir`` when *folder* cannot be read.
    """
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp'}
    candidates = (
        os.path.join(folder, name)
        for name in sorted(os.listdir(folder))
        if os.path.splitext(name)[1].lower() in image_extensions
    )
    # Skip directories that merely carry an image-like suffix; opening them
    # as images later would fail.
    return [path for path in candidates if os.path.isfile(path)]

def get_attention_and_last_features(image, layer_idx):
    """Run the CLIP vision tower and return one layer's attention plus the final hidden states.

    Args:
        image: Image (or list of images) handed to the module-level ``processor``.
        layer_idx: Index into the per-layer attention tuple returned by the model.

    Returns:
        A pair ``(attentions, last_features)`` where ``attentions`` is the
        attention tensor of layer ``layer_idx``, shaped
        (batch, heads, seq_len, seq_len), and ``last_features`` is the last
        entry of the hidden-state tuple, shaped (batch, seq_len, hidden_size).
    """
    encoded = processor(images=image, return_tensors="pt", padding=True)
    # Every tensor produced by the processor has to live on the model's device.
    model_inputs = {name: tensor.to(DEVICE) for name, tensor in encoded.items()}

    # Pure inference: gradients are never needed here.
    with torch.no_grad():
        vision_out = model.vision_model(
            output_attentions=True,
            output_hidden_states=True,
            **model_inputs,
        )

    return vision_out.attentions[layer_idx], vision_out.hidden_states[-1]

def _mean_topk_pairwise_distance(attn, feats):
    """Return the mean pairwise Euclidean distance between one image's top-attended tokens.

    ``attn`` is a single image's attention tensor (heads first) and ``feats``
    the same image's token features from the last hidden state.
    """
    # Average over heads, then sum over the first remaining axis to get one
    # score per token; position 0 (the CLS token) is left out of the ranking.
    patch_scores = attn.mean(dim=0).sum(0)[1:]

    # Keep TOP_PERCENT of the ranked tokens, but never fewer than one.
    k = max(1, int(len(patch_scores) * TOP_PERCENT))
    # The +1 maps positions in the CLS-free score vector back to positions
    # in the full token sequence.
    selected = torch.topk(patch_scores, k).indices + 1

    selected_feats = feats[selected].cpu().numpy()
    # A single selected token has no pairs; report zero distance for it.
    if len(selected_feats) <= 1:
        return 0.0
    return np.mean(pdist(selected_feats, 'euclidean'))


def process_batch(images, layer_idx):
    """Compute, for every image in the batch, the average distance between its top-attended tokens.

    Args:
        images: Batch of images passed to ``get_attention_and_last_features``.
        layer_idx: Layer whose attention is used to rank the tokens.

    Returns:
        A list with one average pairwise distance per image, in batch order.
    """
    attentions, last_features = get_attention_and_last_features(images, layer_idx)
    return [
        _mean_topk_pairwise_distance(attn, feats)
        for _, attn, feats in zip(images, attentions, last_features)
    ]

In [None]:
image_paths = get_image_paths(COCO_IMAGE_DIR)
print(f"Found {len(image_paths)} images in {COCO_IMAGE_DIR}")

# Take the number of vision layers from the loaded checkpoint instead of
# hard-coding 24, so that pointing clip_path at a model of a different depth
# neither raises IndexError nor silently skips layers.
num_layers = model.config.vision_config.num_hidden_layers

# Mean pairwise distance of every layer, in layer order, kept so the values
# remain available after the loop instead of only being printed.
layer_avg_distances = []

# NOTE: each process_batch call runs a full forward pass, so the whole
# dataset is re-encoded once per layer.
for layer_idx in range(num_layers):
    # Per-image distances for the current layer only; reset for every layer.
    all_distances = []

    for start in tqdm(range(0, len(image_paths), BATCH_SIZE)):
        batch_paths = image_paths[start:start + BATCH_SIZE]
        batch_images = [Image.open(path) for path in batch_paths]
        all_distances.extend(process_batch(batch_images, layer_idx))

    # With no images there is nothing to average; report NaN without
    # triggering numpy's empty-mean warning.
    layer_avg = np.mean(all_distances) if all_distances else float("nan")
    layer_avg_distances.append(layer_avg)
    print(f"Avg pairwise dist for Layer {layer_idx+1}: {layer_avg:.3f}")

In [None]:
# Open a wide, short figure and plot the values currently stored in all_distances.
# NOTE(review): all_distances is re-created at the start of every layer
# iteration in the loop above, so at this point it holds the per-image
# distances of the last processed layer only, not one value per layer.
# Confirm that this is the plot that is wanted.
plt.figure(figsize=(12, 3))
plt.plot(all_distances)