# Word Attribution Analysis

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jtooates/blind_lm/blob/main/word_attribution_analysis.ipynb)

This notebook analyzes which parts of the RGB latent are responsible for generating each word.

**Method**: Gradient-based attribution. For each word, compute ∂(logit)/∂(latent), the gradient of that word's logit with respect to the latent, to see which pixels most influence its prediction.
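
Written out, one common form of this attribution is given below. It assumes the latent is $x \in \mathbb{R}^{3 \times H \times W}$, that $z_{t,v}$ is the decoder logit for vocabulary item $v$ at position $t$, and that $w_t$ is the token actually at position $t$. Summing absolute gradients over the three channels is an assumption made here for illustration; this notebook does not specify the channel reduction it uses.

```latex
A_t(i, j) \;=\; \sum_{c=1}^{3} \left| \frac{\partial z_{t,\, w_t}}{\partial x_{c,\, i,\, j}} \right|
```

Under this definition, $A_t$ is an $H \times W$ map with one non-negative score per latent pixel for each word position $t$.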

## 1. Environment Setup

In [None]:
# Check GPU availability
import torch
print("="*70)
print("GPU Check")
print("="*70)
print(f"GPU Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"CUDA Version: {torch.version.cuda}")
else:
    print("⚠️  Note: No GPU found. Running on CPU (slower but will work).")
print("="*70)

In [None]:
# Install dependencies
print("Installing dependencies...")
!pip install -q transformers torch matplotlib numpy

# Suppress tokenizer warning
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

print("✓ Dependencies installed")

In [None]:
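# Clone repository (for Colab)
# IN_COLAB is not set until the Drive-mount cell below runs, so detect Colab here too
import sys
IN_COLAB = 'google.colab' in sys.modules
if IN_COLAB: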
    import os
    repo_dir = 'blind_lm'
    repo_url = 'https://github.com/jtooates/blind_lm.git'
    
    if os.path.exists(repo_dir):
        print("Repository already exists. Pulling latest changes...")
        %cd blind_lm
        !git pull origin main
        print("✓ Repository updated")
    else:
        print("Cloning repository...")
        !git clone {repo_url}
        %cd blind_lm
        print("✓ Repository cloned")
else:
    print("✓ Skipping clone (running locally)")

In [None]:
# Mount Google Drive (for Colab)
try:
    from google.colab import drive
    drive.mount('/content/drive')
    IN_COLAB = True
    print("✓ Google Drive mounted")
    print("✓ Checkpoints location: /content/drive/MyDrive/blind_lm_outputs/")
except ImportError:  # google.colab is only importable inside Colab
    IN_COLAB = False
    print("✓ Running locally")

## 2. Load Model

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import json
import sys
import os

# Add phase1 to path (handle both Colab and local)
if IN_COLAB:
    # In Colab, we're already in /content/blind_lm after cloning
    sys.path.insert(0, '/content/blind_lm/phase1')
    os.chdir('/content/blind_lm')  # Ensure we're in repo root
else:
    # Local: assume we're running from repo root
    sys.path.insert(0, 'phase1')

from model import create_model
from decoder_nonar import create_decoder
from transformers import AutoTokenizer

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Determine checkpoint location based on environment
if IN_COLAB:
    checkpoint_dir = Path('/content/drive/MyDrive/blind_lm_outputs/phase1_rgb_infonce')
else:
    checkpoint_dir = Path('outputs/phase1_rgb_infonce')

checkpoint_path = checkpoint_dir / 'checkpoint_latest.pt'
config_path = checkpoint_dir / 'config.json'

# Check if files exist
if not checkpoint_path.exists():
    print(f"❌ Checkpoint not found at {checkpoint_path}")
    print("\nPlease ensure you have a trained model.")
    if IN_COLAB:
        print("Expected location: /content/drive/MyDrive/blind_lm_outputs/phase1_rgb_infonce/")
        print("You can train a model using phase1_colab_training.ipynb")
    else:
        print("Expected location: outputs/phase1_rgb_infonce/")
    raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

if not config_path.exists():
    print(f"❌ Config not found at {config_path}")
    raise FileNotFoundError(f"Config not found: {config_path}")

print(f"✓ Found checkpoint: {checkpoint_path}")
print(f"✓ Found config: {config_path}")

# Load config
with open(config_path) as f:
    config = json.load(f)

print("\nModel configuration:")
print(f"  Channels: {config['model']['num_channels']} (RGB)")
print(f"  Grid size: {config['model']['grid_size']}x{config['model']['grid_size']}")
print(f"  Hidden size: {config['model']['hidden_size']}")

# Create models
encoder = create_model(config['model']).to(device)
decoder = create_decoder(config['decoder']).to(device)

# Load checkpoint
checkpoint = torch.load(checkpoint_path, map_location=device)
encoder.load_state_dict(checkpoint['encoder_state_dict'])
decoder.load_state_dict(checkpoint['decoder_state_dict'])
encoder.eval()
decoder.eval()

print(f"\n✓ Loaded checkpoint from step {checkpoint['step']}")

# Create tokenizer
tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

print("✓ Models loaded and ready!")

## 3. Attribution Functions
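
The later cells call `visualize_word_attributions`, whose definition is not shown in this section. The sketch below is not that function. It is a minimal illustration of gradient-based attribution under stated assumptions: a hypothetical `ToyDecoder` stands in for the real decoder, the latent has shape `[1, 3, H, W]`, the decoder returns logits of shape `[1, T, V]`, and the per-pixel score is the absolute gradient summed over the RGB channels. All names here (`ToyDecoder`, `token_attribution_heatmaps`, `toy_latent`) are made up for the example.

```python
import torch
import torch.nn as nn


class ToyDecoder(nn.Module):
    """Hypothetical stand-in: maps a [B, 3, H, W] latent to [B, T, V] logits."""

    def __init__(self, grid_size=8, seq_len=4, vocab_size=16):
        super().__init__()
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        self.proj = nn.Linear(3 * grid_size * grid_size, seq_len * vocab_size)

    def forward(self, latent):
        flat = latent.flatten(start_dim=1)
        return self.proj(flat).view(-1, self.seq_len, self.vocab_size)


def token_attribution_heatmaps(decoder, latent, token_ids):
    """Return a [T, H, W] tensor: per token, sum over channels of |d logit / d latent|."""
    latent = latent.detach().clone().requires_grad_(True)
    logits = decoder(latent)  # [1, T, V]
    heatmaps = []
    for position, token_id in enumerate(token_ids):
        # Gradient of this token's logit with respect to every latent pixel.
        (grad,) = torch.autograd.grad(
            logits[0, position, token_id], latent, retain_graph=True
        )
        heatmaps.append(grad[0].abs().sum(dim=0))  # collapse RGB -> [H, W]
    return torch.stack(heatmaps)


torch.manual_seed(0)
toy_decoder = ToyDecoder().eval()
toy_latent = torch.rand(1, 3, 8, 8)   # random stand-in for an encoder output
toy_token_ids = [3, 7, 1, 5]          # arbitrary target token ids
toy_heatmaps = token_attribution_heatmaps(toy_decoder, toy_latent, toy_token_ids)
print(toy_heatmaps.shape)  # torch.Size([4, 8, 8])
```

The variables are prefixed with `toy_` so that running the sketch does not overwrite the `encoder` and `decoder` loaded above. Because the stand-in decoder is a single linear layer, its gradients do not depend on the latent values; a real decoder would give input-dependent maps.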

## 4. Interactive Analysis

Enter a sentence to analyze which parts of the latent are responsible for each word.

In [None]:
# Prompt for sentence
sentence = input("Enter a sentence to analyze: ")

if not sentence.strip():
    sentence = "the red cube is under the yellow block"  # Default example
    print(f"Using default: {sentence}")

# Generate visualization
print("\n" + "="*70)
fig, tokens, heatmaps = visualize_word_attributions(encoder, decoder, tokenizer, sentence, device)
print("="*70)

plt.show()

## 5. Analysis Tips

**What to look for:**

1. **Color words** (red, blue, yellow): Do they activate specific colored regions?
2. **Spatial words** (under, right, left): Do they activate specific spatial patterns?
3. **Object words** (cube, block, box): Do they have consistent activation patterns?
4. **Function words** (the, is): These should typically show low or diffuse activation

**Heatmap interpretation:**
- **Bright (yellow/white)**: High importance. The word's logit is highly sensitive to this pixel.
- **Dark (red/black)**: Low importance. Changing this pixel has little effect on the word's logit.
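
The color scale described above matches matplotlib's `hot` colormap (black, red, yellow, white). The sketch below shows one way such a heatmap could be rendered, assuming each map is min-max normalized to [0, 1] before plotting. Both the colormap choice and the normalization are assumptions for illustration, and `normalize_heatmap` and `plot_heatmap` are hypothetical helpers, not functions from this notebook.

```python
import numpy as np
import matplotlib.pyplot as plt


def normalize_heatmap(heatmap, eps=1e-8):
    """Rescale a heatmap to [0, 1]; a completely flat map becomes all zeros."""
    heatmap = np.asarray(heatmap, dtype=np.float64)
    shifted = heatmap - heatmap.min()
    return shifted / (shifted.max() + eps)


def plot_heatmap(heatmap, title):
    """Draw one heatmap with the 'hot' colormap (black -> red -> yellow -> white)."""
    fig, ax = plt.subplots(figsize=(3, 3))
    image = ax.imshow(normalize_heatmap(heatmap), cmap="hot", vmin=0.0, vmax=1.0)
    ax.set_title(title)
    ax.axis("off")
    fig.colorbar(image, ax=ax, fraction=0.046)
    return fig


demo_map = np.array([[0.0, 1.0], [2.0, 4.0]])  # made-up attribution values
print(normalize_heatmap(demo_map).round(2))
demo_fig = plot_heatmap(demo_map, "demo")
```

If each word's map is normalized independently like this, brightness is only comparable within a single heatmap, not across words. To compare magnitudes between words, normalize all maps by a shared maximum instead.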

**Questions to explore:**
- Do different occurrences of "the" activate different regions?
- Do color words consistently activate the same colored blobs?
- Do spatial relations show positional patterns in the latent?
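
These questions can be made quantitative. The sketch below uses two simple measures as an illustration: cosine similarity between two heatmaps (do two occurrences of a word attend to the same pixels?) and an attribution-weighted center of mass (where in the grid does a word's attribution sit?). It assumes heatmaps are non-negative 2D arrays with a non-zero sum. The helper names and the toy arrays are made up for the example.

```python
import numpy as np


def cosine_similarity(a, b, eps=1e-8):
    """Cosine similarity between two heatmaps, treated as flat vectors."""
    a = np.ravel(a).astype(float)
    b = np.ravel(b).astype(float)
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + eps))


def center_of_mass(heatmap):
    """Attribution-weighted mean (row, col) of a non-negative, non-zero heatmap."""
    heatmap = np.asarray(heatmap, dtype=float)
    rows, cols = np.indices(heatmap.shape)
    total = heatmap.sum()
    return (
        float((rows * heatmap).sum() / total),
        float((cols * heatmap).sum() / total),
    )


# Made-up 4x4 heatmaps standing in for two occurrences of "the".
first_the = np.zeros((4, 4))
first_the[0, 0] = 1.0
second_the = np.zeros((4, 4))
second_the[3, 3] = 1.0

print(cosine_similarity(first_the, first_the))   # close to 1: identical maps
print(cosine_similarity(first_the, second_the))  # 0.0: no overlapping pixels
print(center_of_mass(second_the))                # (3.0, 3.0)
```

A center of mass that shifts systematically between "left" and "right" sentences would be one sign of positional structure in the latent.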

## 6. Try More Sentences

Run the cell below multiple times with different sentences to explore patterns.

In [None]:
# Example sentences to try:
examples = [
    "the red cube is under the yellow block",
    "the blue box is right of the green sphere",
    "the yellow block is on the red cube",
    "the green cube is left of the blue box",
]

print("Example sentences you can try:")
for i, ex in enumerate(examples, 1):
    print(f"{i}. {ex}")

print("\n" + "="*70)
sentence = input("Enter a sentence (or leave blank for random example): ")

if not sentence.strip():
    import random
    sentence = random.choice(examples)
    print(f"Using random example: {sentence}")

fig, tokens, heatmaps = visualize_word_attributions(encoder, decoder, tokenizer, sentence, device)
print("="*70)
plt.show()
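
To look for patterns across several sentences rather than one at a time, the per-word heatmaps can be averaged. The sketch below is a hypothetical helper, not part of this notebook. It assumes each result is a `(tokens, heatmaps)` pair in which `tokens` is a list of strings and `heatmaps` is array-like with shape `[T, H, W]`; the actual return types of `visualize_word_attributions` may differ. Leading spaces and the GPT-2 `Ġ` marker are stripped so that "the" and " the" are grouped together.

```python
from collections import defaultdict

import numpy as np


def mean_heatmap_per_word(results):
    """Average heatmaps by word across many (tokens, heatmaps) results."""
    grouped = defaultdict(list)
    for tokens, heatmaps in results:
        for token, heatmap in zip(tokens, heatmaps):
            word = token.lstrip("Ġ ").lower()
            grouped[word].append(np.asarray(heatmap, dtype=float))
    return {word: np.mean(maps, axis=0) for word, maps in grouped.items()}


# Made-up results standing in for two analyzed sentences.
fake_results = [
    (["the", " red", " cube"], np.ones((3, 2, 2))),
    (["the", " red", " box"], 3 * np.ones((3, 2, 2))),
]
averaged = mean_heatmap_per_word(fake_results)
print(sorted(averaged))        # ['box', 'cube', 'red', 'the']
print(averaged["red"][0, 0])   # 2.0
```

Averaging hides variation between occurrences, so it is best read alongside the individual heatmaps rather than in place of them.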