In [2]:
import os

# Set the GPUs you want to use (e.g., GPUs 4, 5, 6, 7)
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3,4,5,6,7"

In [3]:
import contextlib
import requests
import torch
import transformer_lens
import transformers
from tempfile import TemporaryDirectory
from PIL import Image
import circuitsvis as cv

es = contextlib.ExitStack()
es.enter_context(torch.inference_mode())

model_name = "llava-hf/llava-1.5-7b-hf"
model = transformers.AutoModelForImageTextToText.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map=0,
)
processor = transformers.AutoProcessor.from_pretrained(model_name)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Some kwargs in processor config are unused and will not have any effect: num_additional_image_tokens. 


In [4]:
def get_hooked_model(model, tokenizer):
    with TemporaryDirectory() as model_name:
        model.config.save_pretrained(model_name)
        cfg = transformer_lens.loading.get_pretrained_model_config(
            model_name,
            device=1,
            dtype=model.dtype,
        )
        state_dict = transformer_lens.loading.get_pretrained_state_dict(
            model_name,
            cfg,
            model,
        )
    hooked_model = transformer_lens.HookedTransformer(cfg, tokenizer)
    hooked_model.load_and_process_state_dict(state_dict)
    return hooked_model

def get_input_embeds(input_ids, pixel_values):
    input_embeds = model.get_input_embeddings()(input_ids)
    image_features = model.get_image_features(
        pixel_values,
        model.config.vision_feature_layer,
        model.config.vision_feature_select_strategy,
    )
    # Ensure input_ids is on the same device as input_embeds
    input_ids = input_ids.to(input_embeds.device)
    # Replace image_token_index (=32000) with image_feature tokens
    input_embeds[input_ids == model.config.image_token_index] = image_features
    return input_embeds

def run_model(prompt, image):
    inp = processor(image, prompt, return_tensors="pt").to(model.device)
    input_embeds = get_input_embeds(inp.input_ids, inp.pixel_values)
    return hooked_model.run_with_cache(input_embeds, start_at_layer=0)

In [6]:
# Verify text generation
conv = [
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": "describe the image"},
        ]},
    ]
prompt = processor.apply_chat_template(conv, add_generation_prompt=True)

image = requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True)
image = Image.open(image.raw)
inp = processor(image, prompt, return_tensors="pt").to(model.device)

output = model(
    **inp,
    output_hidden_states=True,
    num_logits_to_keep=1,
    use_cache=False
)

In [20]:
hooked_model = get_hooked_model(model.language_model, processor.tokenizer)



In [21]:
conversation = [{"role": "user", "content": [
    {"type": "image"},
    {"type": "text", "text": "describe the image"},
]}]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

image = requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True)
image = Image.open(image.raw)

logits, activations = run_model(prompt, image)

In [None]:
print(activations)

# Test hooked model

In [5]:
import numpy as np
from typing import Tuple, Dict

def test_llava_setup(
    model, 
    processor, 
    hooked_model, 
    test_image_url: str = "https://raw.githubusercontent.com/llava-forge/llava-2/main/images/llava2_logo.png",
    test_prompt: str = "What do you see in this image?"
) -> Tuple[bool, Dict]:
    """
    Comprehensive test suite for LLaVA + TransformerLens setup.
    Returns (success_flag, diagnostics_dict)
    """
    diagnostics = {}

    conversation = [{"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": test_prompt},
    ]}]
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    
    # 1. Test image processing
    try:
        image = Image.open(requests.get(test_image_url, stream=True).raw)
        inp = processor(image, prompt, return_tensors="pt").to(model.device)
        diagnostics['image_processing'] = {
            'success': True,
            'input_shape': inp.pixel_values.shape,
            'device': inp.pixel_values.device
        }
    except Exception as e:
        diagnostics['image_processing'] = {'success': False, 'error': str(e)}
        return False, diagnostics

    # 2. Test embedding generation
    try:
        input_embeds = get_input_embeds(inp.input_ids, inp.pixel_values)
        diagnostics['embeddings'] = {
            'success': True,
            'shape': input_embeds.shape,
            'dtype': input_embeds.dtype,
            'device': input_embeds.device,
            'non_zero': torch.any(input_embeds != 0).item(),
            'mean': input_embeds.mean().item(),
            'std': input_embeds.std().item()
        }
    except Exception as e:
        diagnostics['embeddings'] = {'success': False, 'error': str(e)}
        return False, diagnostics

    # 3. Test hooked model forward pass
    try:
        logits, cache = hooked_model.run_with_cache(
            input_embeds,
            start_at_layer=0,
            return_type='logits'
        )
        diagnostics['hooked_forward'] = {
            'success': True,
            'logits_shape': logits.shape,
            'cache_keys': list(cache.keys()),
            'num_layers': len([k for k in cache.keys() if 'pattern' in k])
        }
    except Exception as e:
        diagnostics['hooked_forward'] = {'success': False, 'error': str(e)}
        return False, diagnostics

    # 4. Basic sanity checks
    checks = {
        'embedding_dim_match': input_embeds.shape[-1] == hooked_model.cfg.d_model,
        'layer_count_match': len([k for k in cache.keys() if 'pattern' in k]) == hooked_model.cfg.n_layers,
        'output_vocab_size': logits.shape[-1] == hooked_model.cfg.d_vocab
    }
    diagnostics['sanity_checks'] = checks

    success = all(diagnostics['sanity_checks'].values())
    
    return success, diagnostics

def visualize_attention(cache, layer: int = 0, head: int = 0):
    """
    Visualize attention patterns for a specific layer and head
    """
    attention_pattern = cache[f'pattern.{layer}.{head}'][0]  # Get first batch item
    return cv.attention.attention_patterns(
        tokens=hooked_model.to_str_tokens(attention_pattern),
        attention=attention_pattern
    )

def run_verification_test(image_url: str, prompt: str):
    """
    Run the full verification suite and print results
    """
    success, diagnostics = test_llava_setup(model, processor, hooked_model, image_url, prompt)
    
    print(f"🔍 LLaVA + TransformerLens Verification Results:")
    print(f"Overall Success: {'✅' if success else '❌'}\n")
    
    for stage, results in diagnostics.items():
        print(f"\n{stage.upper()}:")
        if isinstance(results, dict):
            for key, value in results.items():
                print(f"  {key}: {value}")
        else:
            print(f"  {results}")
            
    return success, diagnostics

In [6]:
success, diagnostics = run_verification_test(
    "https://llava-vl.github.io/static/images/view.jpg",
    "What do you see in this image?"
)

🔍 LLaVA + TransformerLens Verification Results:
Overall Success: ✅


IMAGE_PROCESSING:
  success: True
  input_shape: torch.Size([1, 3, 336, 336])
  device: cuda:0

EMBEDDINGS:
  success: True
  shape: torch.Size([1, 596, 4096])
  dtype: torch.float16
  device: cuda:0
  non_zero: True
  mean: -0.00444793701171875
  std: 0.9052734375

HOOKED_FORWARD:
  success: True
  logits_shape: torch.Size([1, 596, 32064])
  cache_keys: ['blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_rot_q', 'blocks.0.attn.hook_rot_k', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_pre_linear', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pr

# Exploratory Data Analysis

In [12]:
# Clear GPU memory
torch.cuda.empty_cache()

In [23]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Dict, Optional

def analyze_attention_patterns(
    cache,
    hooked_model,
    input_ids,
    num_image_tokens: int,
    layer: Optional[int] = None,
    head: Optional[int] = None
) -> Dict:
    """
    Analyze attention patterns between text and image tokens.
    Adapted for LLaVA's cache structure.
    """
    # Get token strings for visualization
    tokens = hooked_model.to_str_tokens(input_ids[0])
    
    # Find image token positions (usually at the start for LLaVA)
    image_token_positions = torch.where(input_ids[0] == model.config.image_token_index)[0]
    text_positions = [i for i in range(len(tokens)) if i not in image_token_positions]
    
    layers = [layer] if layer is not None else range(hooked_model.cfg.n_layers)
    heads = [head] if head is not None else range(hooked_model.cfg.n_heads)
    
    attention_stats = {
        'text_to_image': torch.zeros(len(layers), len(heads)).to(cache['blocks.0.attn.hook_pattern'].device),
        'image_to_text': torch.zeros(len(layers), len(heads)).to(cache['blocks.0.attn.hook_pattern'].device),
        'self_attention': torch.zeros(len(layers), len(heads)).to(cache['blocks.0.attn.hook_pattern'].device)
    }
    
    for l in layers:
        for h in heads:
            # Access pattern using LLaVA's cache structure
            pattern = cache[f'blocks.{l}.attn.hook_pattern'][0]  # [seq_len, seq_len]
            
            # Calculate average attention scores
            # Move indicies to same device as pattern
            text_positions = torch.tensor(text_positions).to(pattern.device)
            image_token_positions = torch.tensor(image_token_positions).to(pattern.device)
            text_to_image = pattern[text_positions][:, image_token_positions].mean()
            image_to_text = pattern[image_token_positions][:, text_positions].mean()
            self_attention = pattern.diagonal().mean()
            
            attention_stats['text_to_image'][l, h] = text_to_image
            attention_stats['image_to_text'][l, h] = image_to_text
            attention_stats['self_attention'][l, h] = self_attention
    
    return attention_stats

def run_attention_analysis(prompt: str, image, model, processor, hooked_model):
    """
    Run complete attention analysis pipeline with LLaVA's conversation template
    """
    # Apply LLaVA's conversation template
    conversation = [{"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": prompt},
    ]}]
    formatted_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    
    # Process input
    inp = processor(image, formatted_prompt, return_tensors="pt").to(model.device)
    input_embeds = get_input_embeds(inp.input_ids, inp.pixel_values)
    
    # Run model with cache
    _, cache = hooked_model.run_with_cache(
        input_embeds,
        start_at_layer=0,
        return_type='logits'
    )
    
    # Count image tokens
    num_image_tokens = (inp.input_ids[0] == model.config.image_token_index).sum().item()
    
    # Analyze attention patterns
    attention_stats = analyze_attention_patterns(
        cache,
        hooked_model,
        inp.input_ids,
        num_image_tokens
    )
    
    # Visualizations
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    
    metrics = ['text_to_image', 'image_to_text', 'self_attention']
    for idx, metric in enumerate(metrics):
        sns.heatmap(
            attention_stats[metric],
            ax=axes[idx],
            cmap='viridis',
            xticklabels=[f'H{i}' for i in range(attention_stats[metric].shape[1])],
            yticklabels=[f'L{i}' for i in range(attention_stats[metric].shape[0])]
        )
        axes[idx].set_title(f'{metric.replace("_", " ").title()}')
        axes[idx].set_xlabel('Heads')
        axes[idx].set_ylabel('Layers')
    
    plt.tight_layout()
    plt.show()
    
    return cache, attention_stats

def visualize_layer_attention(
    cache,
    hooked_model,
    input_ids,
    layer: int,
    head: int,
    tokens: Optional[List[str]] = None
):
    """
    Visualize attention patterns for a specific layer and head
    """
    if tokens is None:
        tokens = hooked_model.to_str_tokens(input_ids[0])
    
    pattern = cache[f'blocks.{layer}.attn.hook_pattern'][0]
    
    plt.figure(figsize=(12, 8))
    sns.heatmap(
        pattern.cpu(),
        xticklabels=tokens,
        yticklabels=tokens,
        cmap='viridis'
    )
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.title(f'Attention Pattern (Layer {layer}, Head {head})')
    plt.tight_layout()
    plt.show()

In [18]:
torch.cuda.empty_cache()

import gc

# Delete all variables that are using CUDA memory
for obj in gc.get_objects():
    try:
        if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
            del obj
    except Exception:
        pass

# Run garbage collector
gc.collect()
torch.cuda.empty_cache()

  return isinstance(obj, torch.Tensor)
  if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):


In [24]:
image_url = "https://llava-vl.github.io/static/images/view.jpg"
image = Image.open(requests.get(image_url, stream=True).raw)

# Run the analysis
prompt = "What do you see in this image?"
cache, attention_stats = run_attention_analysis(prompt, image, model, processor, hooked_model)

  image_token_positions = torch.tensor(image_token_positions).to(pattern.device)
  text_positions = torch.tensor(text_positions).to(pattern.device)
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [3630,0,0], thread: [32,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [3630,0,0], thread: [33,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [3630,0,0], thread: [34,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [3630,0,0], thread: [35,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [3630,0,0], thread: [36,0,0] Assertion `-sizes[i] <= index && ind

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
