In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

In [2]:
import warnings
warnings.filterwarnings("ignore")

import argparse 
import torch

from datasets import load_dataset
from transformers import AutoImageProcessor
from lens import HookedVisionTransformer

VisionTransformer

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# load a model (e.g. ViT Base)
model_name_or_path = "google/vit-base-patch16-224"
model = HookedVisionTransformer.from_pretrained(model_name_or_path, device)
image_processor = AutoImageProcessor.from_pretrained(model_name_or_path)
dataset = load_dataset(
    "cifar10",
    cache_dir="./../cache",
    task="image-classification"
)

# extract image features
features = image_processor(images=dataset['test']['image'][0], return_tensors="pt")
features.to(device)

# run the model and get outputs and activations
outputs, activations = model.run_with_cache(**features)

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.weight', 'vit.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
import math
import plotly.express as px
import torch.nn as nn


def transpose_for_scores(x: torch.Tensor) -> torch.Tensor:
    attention_head_size = int(model.config.hidden_size / model.config.num_attention_heads)
    new_x_shape = x.size()[:-1] + (model.config.num_attention_heads, attention_head_size)
    x = x.view(new_x_shape)
    return x.permute(0, 2, 1, 3)

def compute_attention_patterns(activations: dict, layer: int):

    key_layer = transpose_for_scores(activations[f"encoder.layer.{layer}.attention.attention.key"])
    query_layer = transpose_for_scores(activations[f"encoder.layer.{layer}.attention.attention.query"])

    # take the dot product between "query" and "key" to get the raw attention scores
    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    attention_scores = attention_scores / math.sqrt(int(model.config.hidden_size / model.config.num_attention_heads))

    # normalize the attention scores to probabilities
    attention_probs = nn.functional.softmax(attention_scores, dim=-1)
    return attention_probs.cpu()


def display_attention_head(layer:int, head: int):

    attention_patterns = compute_attention_patterns(activations, layer)  # shape: (batch_size, n_heads, seq_len, seq_len)

    attention_pattern = attention_patterns[0, head, :, :]
    fig = px.imshow(
        attention_pattern,
        title=f"Attention Pattern of Head {head}",
        color_continuous_scale='RdBu_r',
        zmin=-1, zmax=1,
        height=600, width=600
    )
    fig.show()


display_attention_head(0, 0)

In [5]:
layer_to_ablate = 0
head_index_to_ablate = 0
def head_ablation_hook(value, hook):
    
    # convert (batch_size, seq_len, d_model) to (batch_size, n_heads, seq_len, d_head)
    value = transpose_for_scores(value)

    # zero-ablate an attention head
    value[:, head_index_to_ablate, :, :] = 0.

    # convert (batch_size, n_heads, sew_len, d_head) to (batch_size, seq_len, d_model)
    value = value.permute(0, 2, 1, 3).flatten(2, 3)
    return value


original_outputs = model(**features, return_dict=True)
ablated_outputs = model.run_with_hooks(
    **features,
    fwd_hooks=[(
        "encoder.layer.0.attention.attention.value",
        head_ablation_hook
        )]
    )
print(f"Original Outputs: {original_outputs[0]}")
print(f"Ablated Outputs: {ablated_outputs[0]}")

Original Outputs: tensor([[[-0.3164,  0.8203, -0.1163,  ..., -1.7754, -0.2287,  2.1335],
         [-0.3456,  0.3930,  0.2273,  ..., -1.1242,  0.0411,  0.2292],
         [-0.4601,  0.6536,  1.0274,  ..., -1.4169, -0.7038, -0.1002],
         ...,
         [ 0.1290, -0.5134,  0.2571,  ..., -1.0091,  0.7674,  0.8107],
         [ 0.4043,  0.1345,  0.9750,  ..., -1.1067,  1.2420,  0.6697],
         [-0.2196,  0.9205, -0.4363,  ..., -0.6910, -0.2315,  1.6317]]],
       device='cuda:0', grad_fn=<NativeLayerNormBackward0>)
Ablated Outputs: tensor([[[-0.3466,  0.6985, -0.3183,  ..., -2.4621,  0.0420,  2.2638],
         [-0.2161,  0.2667, -0.1684,  ..., -1.5102,  0.2997,  0.3439],
         [-0.6365,  0.7776,  0.9260,  ..., -1.9624, -0.2277,  0.3949],
         ...,
         [ 0.1574, -0.7405,  0.1987,  ..., -1.4594,  0.9176,  1.1872],
         [ 0.2932, -0.1248,  0.8330,  ..., -1.6230,  1.3223,  0.7320],
         [-0.1743,  0.5063, -0.3566,  ..., -1.1652, -0.3060,  1.5869]]],
       device='cuda:0

CLIP

In [6]:
from PIL import Image
import requests

from transformers import CLIPProcessor


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# load a model (e.g. ViT Base)
model_name_or_path = "openai/clip-vit-base-patch32"
model = HookedVisionTransformer.from_pretrained(model_name_or_path, device)
processor = CLIPProcessor.from_pretrained(model_name_or_path)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# extract image features
inputs = processor(text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True)
inputs.to(device)

# run the model and get outputs and activations
outputs, activations = model.run_with_cache(**inputs)

In [52]:
modules = {}
for n, m in model.named_modules():
    modules[n] = m
print(modules) 

{'': HookedVisionTransformer(
  (model): CLIPModel(
    (text_model): CLIPTextTransformer(
      (embeddings): CLIPTextEmbeddings(
        (token_embedding): Embedding(49408, 512)
        (position_embedding): Embedding(77, 512)
      )
      (encoder): CLIPEncoder(
        (layers): ModuleList(
          (0-11): 12 x CLIPEncoderLayer(
            (self_attn): CLIPAttention(
              (k_proj): Linear(in_features=512, out_features=512, bias=True)
              (v_proj): Linear(in_features=512, out_features=512, bias=True)
              (q_proj): Linear(in_features=512, out_features=512, bias=True)
              (out_proj): Linear(in_features=512, out_features=512, bias=True)
            )
            (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
            (mlp): CLIPMLP(
              (activation_fn): QuickGELUActivation()
              (fc1): Linear(in_features=512, out_features=2048, bias=True)
              (fc2): Linear(in_features=2048, out_features=51

In [5]:
activations.keys()

dict_keys(['vision_embeddings', 'vision_model.ln1', 'vision_model.0.ln1', 'vision_model.0.attn_q', 'vision_model.0.attn_k', 'vision_model.0.attn_v', 'vision_model.0.attn_z', 'vision_model.0.ln2', 'vision_model.0.mlp.fc1', 'vision_model.0.mlp.act_fn', 'vision_model.0.mlp.fc2', 'vision_model.1.ln1', 'vision_model.1.attn_q', 'vision_model.1.attn_k', 'vision_model.1.attn_v', 'vision_model.1.attn_z', 'vision_model.1.ln2', 'vision_model.1.mlp.fc1', 'vision_model.1.mlp.act_fn', 'vision_model.1.mlp.fc2', 'vision_model.2.ln1', 'vision_model.2.attn_q', 'vision_model.2.attn_k', 'vision_model.2.attn_v', 'vision_model.2.attn_z', 'vision_model.2.ln2', 'vision_model.2.mlp.fc1', 'vision_model.2.mlp.act_fn', 'vision_model.2.mlp.fc2', 'vision_model.3.ln1', 'vision_model.3.attn_q', 'vision_model.3.attn_k', 'vision_model.3.attn_v', 'vision_model.3.attn_z', 'vision_model.3.ln2', 'vision_model.3.mlp.fc1', 'vision_model.3.mlp.act_fn', 'vision_model.3.mlp.fc2', 'vision_model.4.ln1', 'vision_model.4.attn_q', '