In [1]:
import os

# Expose only physical GPUs 2-7 to this process. This must run before torch
# initialises CUDA. The visible devices are renumbered from 0, so "cuda:0" below
# refers to physical GPU 2 and "cuda:1" to physical GPU 3.
# Later cells use cuda:0 (HF model, device_map=0) and cuda:1 (hooked model, device=1).
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3,4,5,6,7"

In [2]:
import contextlib
import requests
import torch
import transformer_lens
import transformers
from tempfile import TemporaryDirectory
from PIL import Image
import circuitsvis as cv

# Turn on torch.inference_mode() for the rest of the session. `es` is not closed
# anywhere in this notebook, so autograd stays disabled in every later cell.
es = contextlib.ExitStack()
es.enter_context(torch.inference_mode())

# LLaVA 1.5 (7B) in fp16, with the whole model placed on the first visible GPU (cuda:0).
model_name = "llava-hf/llava-1.5-7b-hf"
model = transformers.AutoModelForImageTextToText.from_pretrained(
    model_name, device_map=0, torch_dtype=torch.float16
)
processor = transformers.AutoProcessor.from_pretrained(model_name)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Some kwargs in processor config are unused and will not have any effect: num_additional_image_tokens. 


In [15]:
def get_hooked_model(model, tokenizer, device=1):
    """Build a TransformerLens ``HookedTransformer`` mirroring a Hugging Face language model.

    The HF config is written to a throwaway directory, and that directory path is
    handed to TransformerLens in place of a checkpoint name so it can build its
    config. The in-memory ``model`` is passed to ``get_pretrained_state_dict`` as the
    source of the weights.

    Args:
        model: Hugging Face language model. Its ``config`` and ``dtype`` are reused
            (this notebook passes ``model.language_model`` of LLaVA).
        tokenizer: tokenizer attached to the hooked model.
        device: device for the hooked copy, passed to TransformerLens. Defaults to 1,
            i.e. a different visible GPU from the HF model loaded with ``device_map=0``.

    Returns:
        A ``HookedTransformer`` with the processed state dict loaded.
    """
    # Named config_dir (not model_name) so it doesn't shadow the global checkpoint name.
    with TemporaryDirectory() as config_dir:
        model.config.save_pretrained(config_dir)
        cfg = transformer_lens.loading.get_pretrained_model_config(
            config_dir,
            device=device,
            dtype=model.dtype,
        )
        # Fetched inside the `with` so the saved config still exists while it is read.
        state_dict = transformer_lens.loading.get_pretrained_state_dict(
            config_dir,
            cfg,
            model,
        )
    hooked_model = transformer_lens.HookedTransformer(cfg, tokenizer)
    hooked_model.load_and_process_state_dict(state_dict)
    return hooked_model

def get_input_embeds(input_ids, pixel_values):
    """Build LLaVA input embeddings with the image features spliced in.

    Every token is embedded with the language model's embedding table. Each
    position whose id equals ``model.config.image_token_index`` is then overwritten,
    in order, by one row of ``model.get_image_features``. Reads the global ``model``.

    Args:
        input_ids: token ids from the processor. They must contain exactly one
            image-token position per image-feature row.
        pixel_values: preprocessed image tensor from the processor.

    Returns:
        The embedding tensor, i.e. ``input_ids``' shape plus a trailing hidden dimension.

    Raises:
        ValueError: if the number of image-token positions differs from the number
            of image-feature rows.
    """
    input_embeds = model.get_input_embeddings()(input_ids)
    image_features = model.get_image_features(
        pixel_values,
        model.config.vision_feature_layer,
        model.config.vision_feature_select_strategy,
    )
    # Some transformers versions return one tensor per image. Flatten everything to
    # (total_patches, hidden) so rows line up 1:1 with image-token positions.
    if isinstance(image_features, (list, tuple)):
        image_features = torch.cat(
            [feats.reshape(-1, feats.shape[-1]) for feats in image_features]
        )
    hidden_size = input_embeds.shape[-1]
    image_features = image_features.reshape(-1, hidden_size)

    image_mask = input_ids == model.config.image_token_index
    n_image_positions = int(image_mask.sum())
    if n_image_positions != image_features.shape[0]:
        raise ValueError(
            f"input_ids contain {n_image_positions} image-token positions but "
            f"get_image_features returned {image_features.shape[0]} feature rows"
        )
    input_embeds[image_mask] = image_features.to(
        device=input_embeds.device, dtype=input_embeds.dtype
    )
    return input_embeds

def run_model(prompt, image):
    """Run one prompt/image pair through the hooked language model.

    The processor output is moved to the HF model's device. Image features are
    merged into the text embeddings via ``get_input_embeds``. The hooked model is
    then run from layer 0 on those embeddings.

    Returns:
        Whatever ``hooked_model.run_with_cache`` returns (unpacked in this notebook as
        ``(logits, activations)``).
    """
    batch = processor(image, prompt, return_tensors="pt")
    batch = batch.to(model.device)
    merged_embeds = get_input_embeds(batch.input_ids, batch.pixel_values)
    return hooked_model.run_with_cache(merged_embeds, start_at_layer=0)

In [4]:
# Build the TransformerLens copy of LLaVA's language-model half only. The vision tower
# is not part of it, so images enter via get_input_embeds and run_model starts at layer 0.
# NOTE(review): in some transformers versions `model.language_model` is the bare decoder
# without an lm_head. Confirm it is the causal-LM module for your installed version
# (the run below produced logits of vocab size 32064).
hooked_model = get_hooked_model(model.language_model, processor.tokenizer)



In [16]:
DEMO_IMAGE_URL = "https://llava-vl.github.io/static/images/view.jpg"

conversation = [{"role": "user", "content": [
    {"type": "image"},
    {"type": "text", "text": "describe the image"},
]}]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

# Fail fast on network problems instead of hanging indefinitely or handing an
# HTTP error page to PIL.
response = requests.get(DEMO_IMAGE_URL, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)

logits, activations = run_model(prompt, image)

# Test hooked model

In [17]:
import numpy as np
from typing import Tuple, Dict

def test_llava_setup(
    model, 
    processor, 
    hooked_model, 
    test_image_url: str = "https://raw.githubusercontent.com/llava-forge/llava-2/main/images/llava2_logo.png",
    test_prompt: str = "What do you see in this image?"
) -> Tuple[bool, Dict]:
    """
    Comprehensive test suite for LLaVA + TransformerLens setup.
    Returns (success_flag, diagnostics_dict)
    """
    diagnostics = {}

    conversation = [{"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": test_prompt},
    ]}]
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    
    # 1. Test image processing
    try:
        image = Image.open(requests.get(test_image_url, stream=True).raw)
        inp = processor(image, prompt, return_tensors="pt").to(model.device)
        diagnostics['image_processing'] = {
            'success': True,
            'input_shape': inp.pixel_values.shape,
            'device': inp.pixel_values.device
        }
    except Exception as e:
        diagnostics['image_processing'] = {'success': False, 'error': str(e)}
        return False, diagnostics

    # 2. Test embedding generation
    try:
        input_embeds = get_input_embeds(inp.input_ids, inp.pixel_values)
        diagnostics['embeddings'] = {
            'success': True,
            'shape': input_embeds.shape,
            'dtype': input_embeds.dtype,
            'device': input_embeds.device,
            'non_zero': torch.any(input_embeds != 0).item(),
            'mean': input_embeds.mean().item(),
            'std': input_embeds.std().item()
        }
    except Exception as e:
        diagnostics['embeddings'] = {'success': False, 'error': str(e)}
        return False, diagnostics

    # 3. Test hooked model forward pass
    try:
        logits, cache = hooked_model.run_with_cache(
            input_embeds,
            start_at_layer=0,
            return_type='logits'
        )
        diagnostics['hooked_forward'] = {
            'success': True,
            'logits_shape': logits.shape,
            'cache_keys': list(cache.keys()),
            'num_layers': len([k for k in cache.keys() if 'pattern' in k])
        }
    except Exception as e:
        diagnostics['hooked_forward'] = {'success': False, 'error': str(e)}
        return False, diagnostics

    # 4. Basic sanity checks
    checks = {
        'embedding_dim_match': input_embeds.shape[-1] == hooked_model.cfg.d_model,
        'layer_count_match': len([k for k in cache.keys() if 'pattern' in k]) == hooked_model.cfg.n_layers,
        'output_vocab_size': logits.shape[-1] == hooked_model.cfg.d_vocab
    }
    diagnostics['sanity_checks'] = checks

    success = all(diagnostics['sanity_checks'].values())
    
    return success, diagnostics

def visualize_attention(cache, layer: int = 0, head: int = 0, tokens=None):
    """Render the attention pattern of a single head with circuitsvis.

    Args:
        cache: TransformerLens ``ActivationCache`` (or any mapping) that holds
            ``blocks.{layer}.attn.hook_pattern``. Its shape is assumed to be
            [batch, n_heads, query_pos, key_pos] (TransformerLens convention).
        layer: transformer block index.
        head: attention-head index within ``layer``.
        tokens: optional string labels, one per position. This notebook runs the
            hooked model from input embeddings (image patches included), so there
            are no token ids to decode. When omitted, positions are labelled
            "0", "1", ...

    Returns:
        The circuitsvis render object for that head (first batch item).

    Raises:
        ValueError: if ``head`` is out of range or ``tokens`` has the wrong length.
    """
    pattern = cache[f"blocks.{layer}.attn.hook_pattern"][0]  # first batch item
    n_heads, n_positions = pattern.shape[0], pattern.shape[-1]
    if not 0 <= head < n_heads:
        raise ValueError(f"head must be in [0, {n_heads}), got {head}")
    if tokens is None:
        tokens = [str(pos) for pos in range(n_positions)]
    elif len(tokens) != n_positions:
        raise ValueError(f"expected {n_positions} token labels, got {len(tokens)}")
    # circuitsvis takes [n_heads, dest, src], so keep a length-1 head axis. Convert to
    # float32 on CPU before handing the tensor over.
    head_pattern = pattern[head:head + 1].float().cpu()
    return cv.attention.attention_patterns(
        tokens=list(tokens),
        attention=head_pattern,
        attention_head_names=[f"L{layer}H{head}"],
    )

def run_verification_test(image_url: str, prompt: str):
    """Run ``test_llava_setup`` on the notebook's global model objects and print a report.

    Prints an overall status line, then one upper-cased section per diagnostics
    entry. Dict entries are listed as ``key: value`` lines; anything else is
    printed as-is.

    Returns:
        The ``(success, diagnostics)`` pair from ``test_llava_setup``.
    """
    outcome = test_llava_setup(model, processor, hooked_model, image_url, prompt)
    passed, report = outcome
    status_icon = '✅' if passed else '❌'

    print("🔍 LLaVA + TransformerLens Verification Results:")
    print(f"Overall Success: {status_icon}\n")

    for section_name, section in report.items():
        print(f"\n{section_name.upper()}:")
        if not isinstance(section, dict):
            print(f"  {section}")
            continue
        for field, field_value in section.items():
            print(f"  {field}: {field_value}")

    return passed, report

In [18]:
# Run the end-to-end check on the same image the demo cell used above.
success, diagnostics = run_verification_test(
    image_url="https://llava-vl.github.io/static/images/view.jpg",
    prompt="What do you see in this image?",
)

🔍 LLaVA + TransformerLens Verification Results:
Overall Success: ✅


IMAGE_PROCESSING:
  success: True
  input_shape: torch.Size([1, 3, 336, 336])
  device: cuda:0

EMBEDDINGS:
  success: True
  shape: torch.Size([1, 596, 4096])
  dtype: torch.float16
  device: cuda:0
  non_zero: True
  mean: -0.00444793701171875
  std: 0.9052734375

HOOKED_FORWARD:
  success: True
  logits_shape: torch.Size([1, 596, 32064])
  cache_keys: ['blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_rot_q', 'blocks.0.attn.hook_rot_k', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_pre_linear', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pr

# 