# VLM vs LLM Attention: Understanding Vision-Language Model Attention

### Problem Statement

Compare how attention works in **Vision-Language Models (VLMs)** versus **pure Language Models (LLMs)**. This notebook focuses on a **LLaVA-style architecture** (vision encoder, projection layer, decoder-only LLM), a widely used design for modern VLMs.

### Background: Why VLM Attention is Different

In pure LLMs:
- Input: Text tokens only
- Attention: Causal (each token attends only to previous tokens)
- Mask: Lower triangular

In Vision-Language Models:
- Input: Image patches + text tokens
- Attention: Mixed (bidirectional for image, causal for text)
- Mask: NOT simply lower triangular

### Learning Path

1. **Part 1**: LLM Attention Review - Quick recap of causal self-attention
2. **Part 2**: Vision Transformer (ViT) - Patch embedding and bidirectional attention
3. **Part 3**: LLaVA-Style VLM - How image and text tokens interact
4. **Part 4**: Attention Visualization - Real patterns from pretrained models
5. **Part 5**: Side-by-Side Comparison - LLM, ViT, and VLM attention on the same scores
6. **Part 6**: Interview Questions

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from typing import Optional, Tuple

# For pretrained models
try:
    from transformers import CLIPModel, CLIPProcessor
    HAS_TRANSFORMERS = True
except ImportError:
    HAS_TRANSFORMERS = False
    print("transformers not installed - pretrained visualization will be skipped")

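try:
    # Only Pillow is needed below; the test image is generated locally
    from PIL import Image
    HAS_PIL = True
except ImportError:
    HAS_PIL = False
    print("Pillow not installed - pretrained visualization will be skipped")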

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## Part 1: LLM Attention Review

In pure language models (GPT, LLaMA, etc.), attention is **causal**: each token can only attend to itself and previous tokens. This is enforced with a lower-triangular mask.

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + M\right)V$$

where $M$ is the causal mask: $-\infty$ strictly above the diagonal (future positions) and $0$ everywhere else.
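
As a small illustration of the formula, the sketch below builds the additive mask $M$ explicitly and applies it before the softmax. The sizes and variable names are arbitrary choices for the example, not part of the exercises that follow.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d_k = 4, 8
q = torch.randn(seq_len, d_k)
k = torch.randn(seq_len, d_k)

# Scaled scores: Q K^T / sqrt(d_k)
scores = q @ k.T / d_k ** 0.5

# Additive mask M: 0 on and below the diagonal, -inf strictly above it
M = torch.full((seq_len, seq_len), float("-inf")).triu(diagonal=1)

# softmax(scores + M): exp(-inf) = 0, so future positions get zero weight
weights = F.softmax(scores + M, dim=-1)
print(weights)
```

The first row has a single allowed position, so its weight on itself is 1. The boolean-mask convention used in the exercises (`True` = masked, filled with `-inf`) produces the same result as adding $M$.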

In [None]:
def create_causal_mask(seq_len: int, device=None) -> torch.Tensor:
    """
    Create a causal (lower-triangular) attention mask.
    
    Returns:
        mask: Boolean tensor (seq_len, seq_len)
              True = masked (cannot attend), False = can attend
    
    Hint: Use torch.triu() with diagonal=1 to create upper triangular mask
    """
    # TODO: Implement causal mask
    # The mask should have True in the upper triangle (positions that cannot be attended to)
    ...


def scaled_dot_product_attention(
    q: torch.Tensor, 
    k: torch.Tensor, 
    v: torch.Tensor, 
    mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute scaled dot-product attention.
    
    Args:
        q: Query (batch, seq_q, d_k)
        k: Key (batch, seq_k, d_k)
        v: Value (batch, seq_k, d_v)
        mask: Boolean mask where True = masked
    
    Returns:
        output: (batch, seq_q, d_v)
        attn_weights: (batch, seq_q, seq_k)
    """
    d_k = q.shape[-1]
    
    # TODO: Compute attention scores: Q @ K^T / sqrt(d_k)
    scores = ...
    
    # TODO: Apply mask if provided (set masked positions to -inf)
    if mask is not None:
        scores = ...
    
    # TODO: Apply softmax and compute output
    attn_weights = ...
    output = ...
    
    return output, attn_weights
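
One possible way to fill in the two TODO functions above is sketched here. The `_sketch` names are hypothetical and chosen so they do not shadow the exercise functions; treat this as one reasonable solution rather than the only one.

```python
import math
from typing import Optional, Tuple

import torch
import torch.nn.functional as F


def causal_mask_sketch(seq_len: int, device=None) -> torch.Tensor:
    """Boolean mask, True strictly above the diagonal (future positions)."""
    ones = torch.ones(seq_len, seq_len, device=device)
    return torch.triu(ones, diagonal=1).bool()


def sdpa_sketch(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Scaled dot-product attention with an optional boolean mask (True = masked)."""
    d_k = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
    if mask is not None:
        # The (seq, seq) mask broadcasts over the batch dimension
        scores = scores.masked_fill(mask, float("-inf"))
    attn_weights = F.softmax(scores, dim=-1)
    return attn_weights @ v, attn_weights


torch.manual_seed(0)
x = torch.randn(1, 5, 16)
out, w = sdpa_sketch(x, x, x, mask=causal_mask_sketch(5))
print(out.shape, w.shape)
```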

In [None]:
# Test causal mask
seq_len = 8
causal_mask = create_causal_mask(seq_len)

print("Causal Mask for LLM (True = cannot attend):")
print(causal_mask.int())

# Verify
assert causal_mask.shape == (seq_len, seq_len), f"Wrong shape: {causal_mask.shape}"
assert causal_mask[0, 0] == False, "Position (0,0) should be False (can attend to self)"
assert causal_mask[0, 1] == True, "Position (0,1) should be True (cannot attend to future)"
assert causal_mask[7, 0] == False, "Position (7,0) should be False (can attend to past)"
print("\n✓ Causal mask test passed!")

In [None]:
# Demonstrate causal attention pattern
torch.manual_seed(42)
batch_size = 1
d_model = 64

q = torch.randn(batch_size, seq_len, d_model)
k = torch.randn(batch_size, seq_len, d_model)
v = torch.randn(batch_size, seq_len, d_model)

_, attn_weights_causal = scaled_dot_product_attention(q, k, v, mask=causal_mask)

# Verify upper triangle is zero
upper_triangle = attn_weights_causal[0].triu(diagonal=1)
assert torch.allclose(upper_triangle, torch.zeros_like(upper_triangle), atol=1e-6), \
    "Upper triangle should be zero!"

# Plot
fig, ax = plt.subplots(figsize=(6, 5))
im = ax.imshow(attn_weights_causal[0].detach().numpy(), cmap='Blues', vmin=0, vmax=0.5)
ax.set_title('LLM: Causal Self-Attention', fontsize=12)
ax.set_xlabel('Key Position (past → future)')
ax.set_ylabel('Query Position')
plt.colorbar(im, ax=ax, shrink=0.8)
plt.tight_layout()
plt.show()

print("\nKey insight: Upper triangle is zero (cannot attend to future tokens)")

## Part 2: Vision Transformer (ViT) Basics

In Vision Transformers, images are converted to a sequence of **patch embeddings**, then processed with **bidirectional** self-attention (no causal mask).

### Key Steps:
1. **Patch Embedding**: Split the image into patches, flatten each patch, and project it to the embedding dimension
2. **CLS Token**: Prepend a learnable classification token
3. **Position Embedding**: Add learned position embeddings to every token (the standard ViT uses 1D embeddings over the flattened patch grid)
4. **Bidirectional Attention**: All tokens attend to all tokens
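
The CLS-token and position steps are not implemented elsewhere in this notebook, so here is a minimal sketch of them. It assumes learned 1D position embeddings over the flattened patch grid, as in the standard ViT; the class name `TokenPrepSketch` and the small sizes are illustrative only.

```python
import torch
import torch.nn as nn


class TokenPrepSketch(nn.Module):
    """Prepend a CLS token and add learned position embeddings."""

    def __init__(self, n_patches: int = 16, embed_dim: int = 32):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One position embedding per token, including the CLS slot
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, patches: torch.Tensor) -> torch.Tensor:
        # patches: (batch, n_patches, embed_dim)
        cls = self.cls_token.expand(patches.shape[0], -1, -1)
        tokens = torch.cat([cls, patches], dim=1)  # (batch, n_patches + 1, embed_dim)
        return tokens + self.pos_embed


prep = TokenPrepSketch(n_patches=16, embed_dim=32)
tokens = prep(torch.randn(2, 16, 32))
print(tokens.shape)
```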

In [None]:
class PatchEmbedding(nn.Module):
    """
    Convert an image into a sequence of patch embeddings.
    
    Image (B, C, H, W) -> Patches (B, N_patches, embed_dim)
    
    For a 224x224 image with 16x16 patches:
    - N_patches = (224/16) * (224/16) = 14 * 14 = 196
    """
    
    def __init__(
        self, 
        img_size: int = 224, 
        patch_size: int = 16, 
        in_channels: int = 3, 
        embed_dim: int = 768
    ):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        
        # TODO: Create a Conv2d layer that acts as patch extraction and projection
        # Hint: Use kernel_size=patch_size and stride=patch_size for non-overlapping patches
        # This is more efficient than manually splitting and projecting
        self.proj = ...
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Image tensor (batch, channels, height, width)
        Returns:
            patches: (batch, n_patches, embed_dim)
        """
        # TODO: Apply projection, flatten spatial dims, transpose
        # 1. Apply conv: (B, C, H, W) -> (B, embed_dim, H/P, W/P)
        # 2. Flatten: (B, embed_dim, H/P, W/P) -> (B, embed_dim, N_patches)
        # 3. Transpose: (B, embed_dim, N_patches) -> (B, N_patches, embed_dim)
        ...
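
A possible completion of the `PatchEmbedding` exercise is sketched below, following the hint about `kernel_size=patch_size` and `stride=patch_size`. The class name is hypothetical so it does not replace the exercise class.

```python
import torch
import torch.nn as nn


class PatchEmbeddingSketch(nn.Module):
    """Image (B, C, H, W) -> patch embeddings (B, N_patches, embed_dim)."""

    def __init__(self, img_size: int = 224, patch_size: int = 16,
                 in_channels: int = 3, embed_dim: int = 768):
        super().__init__()
        self.n_patches = (img_size // patch_size) ** 2
        # A strided conv extracts non-overlapping patches and projects them in one step
        self.proj = nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)          # (B, embed_dim, H/P, W/P)
        x = x.flatten(2)          # (B, embed_dim, N_patches)
        return x.transpose(1, 2)  # (B, N_patches, embed_dim)


embed = PatchEmbeddingSketch()
patches = embed(torch.randn(1, 3, 224, 224))
print(patches.shape)
```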

In [None]:
# Test patch embedding
torch.manual_seed(42)

patch_embed = PatchEmbedding(img_size=224, patch_size=16, embed_dim=768)

# Simulate an image
dummy_image = torch.randn(1, 3, 224, 224)
patches = patch_embed(dummy_image)

print(f"Input image shape: {dummy_image.shape}")
print(f"Output patches shape: {patches.shape}")
print(f"Number of patches: {patches.shape[1]} = 14 x 14 grid")
print(f"Each patch embedding dim: {patches.shape[2]}")

assert patches.shape == (1, 196, 768), f"Wrong output shape: {patches.shape}"
print("\n✓ Patch embedding test passed!")

In [None]:
class ViTAttention(nn.Module):
    """
    Multi-head self-attention for Vision Transformer.
    
    Key difference from LLM attention: NO CAUSAL MASK
    All patches can attend to all other patches (bidirectional).
    """
    
    def __init__(self, embed_dim: int = 768, num_heads: int = 12):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        
        # Combined Q, K, V projection for efficiency
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)
    
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: (batch, seq_len, embed_dim)
        Returns:
            output: (batch, seq_len, embed_dim)
            attn_weights: (batch, num_heads, seq_len, seq_len)
        """
        B, N, C = x.shape
        
        # TODO: Compute Q, K, V from single projection
        # 1. Apply qkv projection: (B, N, C) -> (B, N, 3*C)
        # 2. Reshape to (B, N, 3, num_heads, head_dim)
        # 3. Permute to (3, B, num_heads, N, head_dim)
        # 4. Split into q, k, v
        qkv = ...
        q, k, v = ...
        
        # TODO: Compute scaled dot-product attention (NO MASK!)
        # This is the key difference from LLM attention - bidirectional
        scores = ...
        attn_weights = ...
        
        # TODO: Apply attention to values and reshape back
        out = ...
        out = self.proj(out)
        
        return out, attn_weights
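
The TODOs in `ViTAttention` could be filled in roughly as follows. This is a sketch under the shapes listed in the comments above; the class name `ViTAttentionSketch` is hypothetical.

```python
import torch
import torch.nn as nn


class ViTAttentionSketch(nn.Module):
    """Multi-head self-attention without any mask (bidirectional)."""

    def __init__(self, embed_dim: int = 768, num_heads: int = 12):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x: torch.Tensor):
        B, N, C = x.shape
        # (B, N, 3*C) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        # No mask: every token attends to every token
        scores = q @ k.transpose(-2, -1) / self.head_dim ** 0.5
        attn_weights = scores.softmax(dim=-1)

        # (B, heads, N, head_dim) -> (B, N, C)
        out = (attn_weights @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(out), attn_weights


torch.manual_seed(0)
layer = ViTAttentionSketch(embed_dim=64, num_heads=4)
out, attn = layer(torch.randn(1, 16, 64))
print(out.shape, attn.shape)
```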

In [None]:
# Demonstrate ViT bidirectional attention
torch.manual_seed(42)

# Use smaller dimensions for visualization
n_patches = 16  # 4x4 grid for easy visualization
embed_dim = 64
num_heads = 4

vit_attn = ViTAttention(embed_dim=embed_dim, num_heads=num_heads)
patch_tokens = torch.randn(1, n_patches, embed_dim)

_, vit_attn_weights = vit_attn(patch_tokens)

# Plot attention from head 0
fig, ax = plt.subplots(figsize=(6, 5))
im = ax.imshow(vit_attn_weights[0, 0].detach().numpy(), cmap='Blues', vmin=0, vmax=0.3)
ax.set_title('ViT: Bidirectional Attention (Head 0)', fontsize=12)
ax.set_xlabel('Key Patch')
ax.set_ylabel('Query Patch')
plt.colorbar(im, ax=ax, shrink=0.8)
plt.tight_layout()
plt.show()

print("\nKey insight: NO zeros in the matrix - all patches attend to all patches!")
print("This is bidirectional attention (no causal mask).")

## Part 3: LLaVA-Style VLM Architecture

### How LLaVA Works

LLaVA (Large Language and Vision Assistant) uses a simple but effective approach:

1. **Vision Encoder**: A pretrained ViT (the CLIP vision encoder in LLaVA) extracts visual features
2. **Projection Layer**: Maps visual features to LLM embedding space
3. **Concatenation**: `[Image Tokens] + [Text Tokens]`
4. **LLM Processing**: Single decoder-only transformer processes both
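
The four steps can be traced with dummy tensors to see how the shapes line up. In this sketch the vision encoder output is replaced by random features, the projector is a single `nn.Linear`, and all sizes are made-up small values; none of these reflect a real checkpoint.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vision_dim, llm_dim, vocab_size = 32, 64, 100  # illustrative sizes only
n_patches, n_text = 9, 5

# Step 1 (stand-in): features a vision encoder would produce, one per patch
vision_features = torch.randn(1, n_patches, vision_dim)

# Step 2: project visual features into the LLM embedding space
projector = nn.Linear(vision_dim, llm_dim)
image_tokens = projector(vision_features)          # (1, 9, 64)

# Text side: ordinary token embeddings
text_ids = torch.randint(0, vocab_size, (1, n_text))
text_embed = nn.Embedding(vocab_size, llm_dim)
text_tokens = text_embed(text_ids)                 # (1, 5, 64)

# Step 3: concatenate, image tokens first
llm_input = torch.cat([image_tokens, text_tokens], dim=1)
print(llm_input.shape)

# Step 4 would feed llm_input to the decoder-only transformer
```

After projection, image tokens and text tokens share one embedding dimension, which is what lets a single transformer process both.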

### The Key Insight: VLM Attention Mask

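The attention mask built in this notebook is **NOT simply causal**. It follows the prefix-LM pattern (used, for example, by PaliGemma), treating the image tokens as a bidirectional prefix while text tokens stay causal. Note that the reference LLaVA implementation runs the LLM with its ordinary causal mask over the whole sequence, so there the image tokens are bidirectional only inside the vision encoder. The block structure used here is: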

```
                    Image Tokens    Text Tokens
                    [I0 I1 I2 I3]   [T0 T1 T2 T3]
Image   I0          [ ✓  ✓  ✓  ✓     ✗  ✗  ✗  ✗ ]
Tokens  I1          [ ✓  ✓  ✓  ✓     ✗  ✗  ✗  ✗ ]
        I2          [ ✓  ✓  ✓  ✓     ✗  ✗  ✗  ✗ ]
        I3          [ ✓  ✓  ✓  ✓     ✗  ✗  ✗  ✗ ]
Text    T0          [ ✓  ✓  ✓  ✓     ✓  ✗  ✗  ✗ ]
Tokens  T1          [ ✓  ✓  ✓  ✓     ✓  ✓  ✗  ✗ ]
        T2          [ ✓  ✓  ✓  ✓     ✓  ✓  ✓  ✗ ]
        T3          [ ✓  ✓  ✓  ✓     ✓  ✓  ✓  ✓ ]
```

- **Image-to-Image**: Bidirectional (all can see all)
- **Image-to-Text**: Cannot attend (image comes first, doesn't "see" future text)
- **Text-to-Image**: Can attend (text can see all image tokens)
- **Text-to-Text**: Causal (each text token sees only past text)

In [None]:
def create_vlm_attention_mask(n_image: int, n_text: int, device=None) -> torch.Tensor:
    """
    Create attention mask for Vision-Language Model (LLaVA-style).
    
    The mask has a specific structure:
    - Image tokens: bidirectional among themselves (can see all image tokens)
    - Image tokens: CANNOT attend to text tokens (they come first in sequence)
    - Text tokens: CAN attend to all image tokens
    - Text tokens: causal among themselves (can only see past text)
    
    Args:
        n_image: Number of image tokens (patches)
        n_text: Number of text tokens
        device: Device to create tensor on
    
    Returns:
        mask: Boolean tensor (n_image + n_text, n_image + n_text)
              True = masked (cannot attend), False = can attend
    
    The mask should look like:
    
        Image    Text
    Image [0 0 0 | 1 1 1]   <- Image sees image (0), not text (1)
          [0 0 0 | 1 1 1]
          ------+------
    Text  [0 0 0 | 0 1 1]   <- Text sees image (0), causal text
          [0 0 0 | 0 0 1]
          [0 0 0 | 0 0 0]
    """
    total = n_image + n_text
    
    # TODO: Create the VLM attention mask
    # Start with all zeros (can attend)
    mask = torch.zeros(total, total, dtype=torch.bool, device=device)
    
    # TODO: Image-to-text quadrant: MASKED (top-right)
    # Image tokens cannot see text tokens (they come before text)
    ...
    
    # TODO: Text-to-text quadrant: CAUSAL (bottom-right)
    # Text tokens have causal attention among themselves
    ...
    
    # Note: Image-to-image (top-left) stays zeros (bidirectional)
    # Note: Text-to-image (bottom-left) stays zeros (can attend)
    
    return mask
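
One way to complete `create_vlm_attention_mask` is sketched below, together with a comparison against a plain causal mask of the same size. The function name `vlm_mask_sketch` is hypothetical.

```python
import torch


def vlm_mask_sketch(n_image: int, n_text: int, device=None) -> torch.Tensor:
    """Boolean block mask, True = cannot attend."""
    total = n_image + n_text
    mask = torch.zeros(total, total, dtype=torch.bool, device=device)

    # Top-right: image queries cannot see text keys
    mask[:n_image, n_image:] = True

    # Bottom-right: causal among text tokens
    text_causal = torch.triu(torch.ones(n_text, n_text, device=device), diagonal=1).bool()
    mask[n_image:, n_image:] = text_causal

    # Top-left and bottom-left stay False (can attend)
    return mask


mask = vlm_mask_sketch(3, 3)
print(mask.int())

# Compare with a plain causal mask over the same 6 positions
causal = torch.triu(torch.ones(6, 6), diagonal=1).bool()
diff = mask ^ causal
print(diff.int())
```

The two masks differ only in the image-to-image quadrant: the plain causal mask hides the strictly upper triangle there, while the block mask leaves it open. The top-right quadrant is masked in both.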

In [None]:
# Test VLM attention mask
n_image = 4
n_text = 4

vlm_mask = create_vlm_attention_mask(n_image, n_text)

print("VLM Attention Mask (1 = CANNOT attend):")
print(vlm_mask.int())

# Verify structure
# Top-left (image-to-image): all zeros
assert vlm_mask[:n_image, :n_image].sum() == 0, "Image-to-image should be all zeros (bidirectional)"

# Top-right (image-to-text): all ones
assert vlm_mask[:n_image, n_image:].sum() == n_image * n_text, "Image-to-text should be all ones (masked)"

# Bottom-left (text-to-image): all zeros
assert vlm_mask[n_image:, :n_image].sum() == 0, "Text-to-image should be all zeros (can attend)"

# Bottom-right (text-to-text): causal (upper triangle)
expected_causal = n_text * (n_text - 1) // 2  # Upper triangle count
assert vlm_mask[n_image:, n_image:].sum() == expected_causal, "Text-to-text should be causal"

print("\n✓ VLM mask test passed!")

In [None]:
# Visualize the mask structure
fig, ax = plt.subplots(figsize=(7, 6))

# Convert to float for visualization (1 = masked/red, 0 = can attend/green)
mask_viz = vlm_mask.float().numpy()

im = ax.imshow(mask_viz, cmap='RdYlGn_r', vmin=0, vmax=1)
ax.set_title('VLM Attention Mask Structure', fontsize=12)
ax.set_xlabel('Key Position')
ax.set_ylabel('Query Position')

# Add grid lines to separate image and text regions
ax.axhline(y=n_image - 0.5, color='black', linewidth=2)
ax.axvline(x=n_image - 0.5, color='black', linewidth=2)

# Add labels
ax.text(n_image/2 - 0.5, -0.8, 'Image', ha='center', fontsize=10, fontweight='bold')
ax.text(n_image + n_text/2 - 0.5, -0.8, 'Text', ha='center', fontsize=10, fontweight='bold')
ax.text(-1.2, n_image/2 - 0.5, 'Image', va='center', fontsize=10, fontweight='bold', rotation=90)
ax.text(-1.2, n_image + n_text/2 - 0.5, 'Text', va='center', fontsize=10, fontweight='bold', rotation=90)

plt.colorbar(im, ax=ax, shrink=0.8, label='Masked')
plt.tight_layout()
plt.show()

In [None]:
class VisionProjection(nn.Module):
    """
    Project visual features from ViT to LLM embedding space.
    
    LLaVA-1.5 uses a two-layer MLP with GELU (the original LLaVA used a
    single linear layer):
    vision_dim -> hidden_dim -> llm_dim
    """
    
    def __init__(self, vision_dim: int, llm_dim: int, hidden_dim: int = None):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = llm_dim
        
        # TODO: Create a 2-layer MLP with GELU activation
        # vision_dim -> hidden_dim -> GELU -> llm_dim
        self.proj = ...
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project vision features to LLM space."""
        return self.proj(x)
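
A possible completion of the projection TODO, following the `vision_dim -> hidden_dim -> GELU -> llm_dim` comment, is sketched here. The class name and the example sizes are illustrative assumptions.

```python
from typing import Optional

import torch
import torch.nn as nn


class VisionProjectionSketch(nn.Module):
    """Two-layer MLP mapping vision features to the LLM embedding size."""

    def __init__(self, vision_dim: int, llm_dim: int, hidden_dim: Optional[int] = None):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = llm_dim
        self.proj = nn.Sequential(
            nn.Linear(vision_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, llm_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Applied independently to every patch token
        return self.proj(x)


projector = VisionProjectionSketch(vision_dim=32, llm_dim=64)
projected = projector(torch.randn(2, 9, 32))
print(projected.shape)
```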

In [None]:
class VLMAttention(nn.Module):
    """
    Vision-Language Model attention with proper masking.
    
    Processes concatenated [image_tokens, text_tokens] with:
    - Bidirectional attention among image tokens
    - Image tokens cannot attend to text tokens
    - Text tokens can attend to all image tokens
    - Causal attention among text tokens
    """
    
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
    
    def forward(
        self, 
        x: torch.Tensor, 
        n_image_tokens: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Concatenated [image_tokens, text_tokens] (batch, seq_len, d_model)
            n_image_tokens: Number of image tokens (to construct proper mask)
        
        Returns:
            output: (batch, seq_len, d_model)
            attn_weights: (batch, num_heads, seq_len, seq_len)
        """
        B, N, _ = x.shape
        n_text_tokens = N - n_image_tokens
        
        # TODO: Project to Q, K, V and reshape for multi-head attention
        Q = ...
        K = ...
        V = ...
        
        # TODO: Compute attention scores
        scores = ...
        
        # TODO: Create and apply VLM mask
        mask = create_vlm_attention_mask(n_image_tokens, n_text_tokens, device=x.device)
        scores = ...
        
        # TODO: Apply softmax and compute output
        attn_weights = ...
        out = ...
        out = self.W_o(out)
        
        return out, attn_weights
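
Putting the pieces together, the `VLMAttention` TODOs could be completed along these lines. To keep the sketch self-contained it includes its own small mask helper (`block_mask`, a hypothetical name) instead of calling `create_vlm_attention_mask`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def block_mask(n_image: int, n_text: int, device=None) -> torch.Tensor:
    """True = masked. Image prefix is bidirectional, text is causal."""
    total = n_image + n_text
    mask = torch.zeros(total, total, dtype=torch.bool, device=device)
    mask[:n_image, n_image:] = True
    mask[n_image:, n_image:] = torch.triu(
        torch.ones(n_text, n_text, device=device), diagonal=1).bool()
    return mask


class VLMAttentionSketch(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        self.d_model, self.num_heads = d_model, num_heads
        self.head_dim = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def _heads(self, t: torch.Tensor) -> torch.Tensor:
        # (B, N, d_model) -> (B, heads, N, head_dim)
        B, N, _ = t.shape
        return t.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x: torch.Tensor, n_image_tokens: int):
        B, N, _ = x.shape
        Q, K, V = self._heads(self.W_q(x)), self._heads(self.W_k(x)), self._heads(self.W_v(x))

        scores = Q @ K.transpose(-2, -1) / self.head_dim ** 0.5
        mask = block_mask(n_image_tokens, N - n_image_tokens, device=x.device)
        scores = scores.masked_fill(mask, float("-inf"))  # broadcasts over B and heads

        attn_weights = F.softmax(scores, dim=-1)
        out = (attn_weights @ V).transpose(1, 2).reshape(B, N, self.d_model)
        return self.W_o(out), attn_weights


torch.manual_seed(0)
n_image, n_text = 9, 6
layer = VLMAttentionSketch(d_model=64, num_heads=4)
out, attn = layer(torch.randn(1, n_image + n_text, 64), n_image_tokens=n_image)
print(out.shape, attn.shape)
```

Every query row keeps at least one unmasked key, so the softmax never sees a row of all `-inf`.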

In [None]:
# Test VLM attention
torch.manual_seed(42)

d_model = 64
num_heads = 4
n_image = 9  # 3x3 patch grid
n_text = 6

vlm_attn = VLMAttention(d_model, num_heads)

# Simulate concatenated image + text tokens
image_tokens = torch.randn(1, n_image, d_model)
text_tokens = torch.randn(1, n_text, d_model)
combined = torch.cat([image_tokens, text_tokens], dim=1)

print(f"Image tokens: {n_image}")
print(f"Text tokens: {n_text}")
print(f"Combined sequence: {combined.shape}")

output, attn_weights = vlm_attn(combined, n_image_tokens=n_image)

print(f"\nOutput shape: {output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")

assert output.shape == combined.shape, f"Output shape mismatch: {output.shape}"
print("\n✓ VLM attention test passed!")

In [None]:
# Visualize VLM attention pattern
fig, ax = plt.subplots(figsize=(8, 7))

# Average across heads
attn_avg = attn_weights[0].mean(dim=0).detach().numpy()

im = ax.imshow(attn_avg, cmap='Blues', vmin=0, vmax=0.3)
ax.set_title('VLM Attention Pattern (LLaVA-style)', fontsize=12)
ax.set_xlabel('Key Position')
ax.set_ylabel('Query Position')

# Add grid lines
ax.axhline(y=n_image - 0.5, color='red', linewidth=2, linestyle='--')
ax.axvline(x=n_image - 0.5, color='red', linewidth=2, linestyle='--')

# Add labels
# Add labels (subtract 0.5 so each label is centered over its block of cells)
ax.text(n_image/2 - 0.5, -1, 'Image Tokens', ha='center', fontsize=10)
ax.text(n_image + n_text/2 - 0.5, -1, 'Text Tokens', ha='center', fontsize=10)

plt.colorbar(im, ax=ax, shrink=0.8, label='Attention Weight')
plt.tight_layout()
plt.show()

print("Observations:")
print("- Top-left (Image→Image): Dense attention, all patches see each other")
print("- Top-right (Image→Text): All zeros (image can't see future text)")
print("- Bottom-left (Text→Image): Dense attention, text sees all image patches")
print("- Bottom-right (Text→Text): Causal pattern (lower triangular)")

## Part 4: Attention Visualization with Pretrained Model

Let's use a pretrained CLIP model to visualize real ViT attention patterns.

In [None]:
if HAS_TRANSFORMERS and HAS_PIL:
    print("Loading CLIP model for attention visualization...")
    
    # Load a small CLIP model
    model_name = "openai/clip-vit-base-patch16"
    model = CLIPModel.from_pretrained(model_name)
    processor = CLIPProcessor.from_pretrained(model_name)
    
    # Create a simple test image (gradient for easy visualization)
    size = 224
    x = np.linspace(0, 1, size)
    y = np.linspace(0, 1, size)
    xx, yy = np.meshgrid(x, y)
    
    # Create RGB image with gradient patterns
    r = (xx * 255).astype(np.uint8)
    g = (yy * 255).astype(np.uint8)
    b = ((xx + yy) / 2 * 255).astype(np.uint8)
    test_image = np.stack([r, g, b], axis=-1)
    pil_image = Image.fromarray(test_image)
    
    print(f"Test image size: {test_image.shape}")
    print(f"Model: {model_name}")
    
    # Process and get attention
    inputs = processor(images=pil_image, return_tensors="pt")
    
    with torch.no_grad():
        outputs = model.vision_model(
            inputs['pixel_values'], 
            output_attentions=True
        )
    
    attentions = outputs.attentions
    print(f"\nNumber of layers: {len(attentions)}")
    print(f"Attention shape per layer: {attentions[0].shape}")
else:
    print("Skipping pretrained visualization (transformers or PIL not available)")
    attentions = None

In [None]:
if attentions is not None:
    # Visualize attention from different layers
    fig, axes = plt.subplots(2, 3, figsize=(14, 10))
    
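    # CLIP ViT-B/16 has 12 layers (indices 0-11); index 11 and -1 are the same
    # layer, so pick six distinct ones
    layers_to_show = [0, 2, 4, 6, 8, 11]
    layer_names = ['Layer 1', 'Layer 3', 'Layer 5', 'Layer 7', 'Layer 9', 'Layer 12 (last)']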
    
    for idx, (layer_idx, name) in enumerate(zip(layers_to_show, layer_names)):
        ax = axes[idx // 3, idx % 3]
        
        # Get attention for this layer, average across heads
        attn = attentions[layer_idx][0].mean(dim=0)
        
        im = ax.imshow(attn.numpy(), cmap='Blues', vmin=0, vmax=0.1)
        ax.set_title(f'{name}', fontsize=11)
        ax.set_xlabel('Key')
        ax.set_ylabel('Query')
    
    plt.suptitle('ViT Attention Patterns Across Layers (CLIP)', fontsize=13)
    plt.tight_layout()
    plt.show()
    
    print("\nObservations:")
    print("- Early layers: attention tends to be more uniform or local")
    print("- Later layers: attention tends to be more specialized and global")
    print("- CLS token (row/col 0): Aggregates information from all patches")
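
To look at what the CLS token attends to, its row of the attention matrix can be reshaped back onto the patch grid. The sketch below uses a random softmax-normalized tensor as a stand-in for one layer of `attentions`, assuming the ViT-B/16 layout of 1 CLS token followed by 14 x 14 = 196 patch tokens.

```python
import torch

torch.manual_seed(0)
num_heads, grid = 12, 14
n_tokens = 1 + grid * grid  # CLS + 196 patches = 197

# Stand-in for a single layer's attention: (batch, heads, tokens, tokens)
layer_attn = torch.softmax(torch.randn(1, num_heads, n_tokens, n_tokens), dim=-1)

# Row 0 is the CLS query; drop column 0 (CLS attending to itself)
cls_to_patches = layer_attn[0, :, 0, 1:].mean(dim=0)  # (196,), averaged over heads

# Map the flat patch index back to the 14 x 14 grid
heatmap = cls_to_patches.reshape(grid, grid)
print(heatmap.shape)
```

With real attention weights, `heatmap` can be passed to `plt.imshow` to see which image regions the CLS token draws from.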

## Part 5: Side-by-Side Comparison

Compare all three attention patterns.

In [None]:
# Create comparison visualization
torch.manual_seed(42)

seq_len = 12
d_model = 64

q = torch.randn(1, seq_len, d_model)
k = torch.randn(1, seq_len, d_model)
v = torch.randn(1, seq_len, d_model)

raw_scores = (q @ k.transpose(-2, -1)) / (d_model ** 0.5)

# 1. LLM (Causal) attention
llm_mask = create_causal_mask(seq_len)
llm_scores = raw_scores.clone().masked_fill(llm_mask, float('-inf'))
llm_attn = F.softmax(llm_scores, dim=-1)[0]

# 2. ViT (Bidirectional) attention
vit_attn = F.softmax(raw_scores, dim=-1)[0]

# 3. VLM attention (6 image + 6 text tokens)
n_img, n_txt = 6, 6
vlm_mask = create_vlm_attention_mask(n_img, n_txt)
vlm_scores = raw_scores.clone().masked_fill(vlm_mask, float('-inf'))
vlm_attn = F.softmax(vlm_scores, dim=-1)[0]

# Plot side-by-side
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

im0 = axes[0].imshow(llm_attn.detach().numpy(), cmap='Blues', vmin=0, vmax=0.5)
axes[0].set_title('LLM: Causal Attention', fontsize=12, fontweight='bold')
axes[0].set_xlabel('Key')
axes[0].set_ylabel('Query')
plt.colorbar(im0, ax=axes[0], shrink=0.8)

im1 = axes[1].imshow(vit_attn.detach().numpy(), cmap='Blues', vmin=0, vmax=0.5)
axes[1].set_title('ViT: Bidirectional Attention', fontsize=12, fontweight='bold')
axes[1].set_xlabel('Key')
axes[1].set_ylabel('Query')
plt.colorbar(im1, ax=axes[1], shrink=0.8)

im2 = axes[2].imshow(vlm_attn.detach().numpy(), cmap='Blues', vmin=0, vmax=0.5)
axes[2].set_title('VLM: Mixed Attention (LLaVA-style)', fontsize=12, fontweight='bold')
axes[2].set_xlabel('Key')
axes[2].set_ylabel('Query')
axes[2].axhline(y=n_img - 0.5, color='red', linewidth=2, linestyle='--')
axes[2].axvline(x=n_img - 0.5, color='red', linewidth=2, linestyle='--')
plt.colorbar(im2, ax=axes[2], shrink=0.8)

plt.suptitle('Attention Pattern Comparison', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

In [None]:
# Summary table
print("="*70)
print("ATTENTION PATTERN COMPARISON")
print("="*70)
print(f"{'Property':<25} {'LLM':<15} {'ViT':<15} {'VLM (LLaVA)'}")
print("-"*70)
print(f"{'Mask Type':<25} {'Causal':<15} {'None':<15} {'Mixed'}")
print(f"{'Upper Triangle':<25} {'Zeros':<15} {'Non-zero':<15} {'Partial zeros'}")
print(f"{'Token-to-Token':<25} {'Past only':<15} {'All':<15} {'Depends on type'}")
print(f"{'Cross-Modal':<25} {'N/A':<15} {'N/A':<15} {'Text→Image: Yes'}")
print(f"{'Use Case':<25} {'Text gen':<15} {'Image cls':<15} {'Multimodal'}")
print("="*70)

## Part 6: Interview Questions

### Q1: How does attention differ between LLMs and VLMs?

**Answer:**
- **LLMs** use causal self-attention where each token can only attend to itself and previous tokens. This is enforced with a lower-triangular mask.
- **VLMs** use mixed attention patterns:
  - Image tokens use bidirectional attention (all patches see all patches)
  - Text tokens use causal attention among themselves
  - Text tokens can attend to all image tokens (cross-modal attention)
  - Image tokens cannot attend to text tokens (they come first in sequence)

---

### Q2: How does LLaVA handle image inputs?

**Answer:**
1. **Vision Encoder**: A pretrained ViT encodes the image into patch embeddings
2. **Projection Layer**: A simple MLP projects visual features to LLM embedding dimension
3. **Concatenation**: Visual tokens are prepended to text tokens
4. **Processing**: The LLM processes the combined sequence with appropriate attention masking

---

### Q3: What's the attention mask structure in a VLM?

**Answer:**
For a sequence with N image tokens and M text tokens:

```
                Image (N)     Text (M)
Image (N)    [  All zeros     All ones  ]  <- Image sees image, not text
Text (M)     [  All zeros     Causal    ]  <- Text sees image + causal text
```

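Key insight: It's NOT simply a causal mask! The difference is the top-left quadrant: image tokens attend to each other bidirectionally, so it is all zeros instead of upper-triangular. The top-right quadrant is all ones (masked), just as it would be under a plain causal mask.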

---

### Q4: Why use a pretrained ViT instead of training from scratch?

**Answer:**
1. **Transfer Learning**: Visual representations generalize well
2. **Compute Efficiency**: Training vision encoders requires massive compute
3. **Data Efficiency**: Pretrained ViT needs less vision-language data
4. **Stability**: Starting from pretrained weights usually makes optimization more stable than starting from random initialization

---

### Q5: What's the difference between cross-attention (Flamingo) and concatenation (LLaVA)?

**Answer:**

| Aspect | Cross-Attention (Flamingo) | Concatenation (LLaVA) |
|--------|---------------------------|----------------------|
| Architecture | Separate cross-attn layers | Single self-attn over concat |
| Q, K, V | Q from text, K/V from image | All from same sequence |
| Sequence Length | Text length only | Image + text length |
| Complexity | More complex | Simpler |
| Memory | Less (separate streams) | More (longer sequence) |
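
The "Q, K, V" and "Sequence Length" rows can be made concrete by comparing score-matrix shapes. This sketch omits the learned projections and uses arbitrary small sizes, so it only illustrates the shapes, not either model's actual layers.

```python
import torch

torch.manual_seed(0)
n_image, n_text, d = 9, 6, 16
image = torch.randn(1, n_image, d)
text = torch.randn(1, n_text, d)

# Cross-attention style: queries from text, keys from image
cross_scores = text @ image.transpose(-2, -1) / d ** 0.5            # (1, 6, 9)

# Concatenation style: one self-attention over [image, text]
combined = torch.cat([image, text], dim=1)
self_scores = combined @ combined.transpose(-2, -1) / d ** 0.5      # (1, 15, 15)

print(cross_scores.shape, self_scores.shape)

# Without projections, the text-to-image block of the concatenated scores
# equals the cross-attention scores
text_to_image = self_scores[:, n_image:, :n_image]
print(torch.allclose(text_to_image, cross_scores))
```

The concatenated score matrix grows with (image + text) squared, which is where the extra memory cost in the table comes from.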

In [None]:
print("\n" + "="*60)
print("KEY TAKEAWAYS")
print("="*60)
print("""
1. LLMs use CAUSAL attention (lower triangular mask)
   - Each token sees only past tokens

2. ViTs use BIDIRECTIONAL attention (no mask)
   - All patches see all other patches

3. VLMs (LLaVA-style) use MIXED attention
   - Image tokens: bidirectional among themselves
   - Text tokens: causal + can see all image tokens
   - The mask is NOT simply lower triangular!

4. The key implementation insight:
   create_vlm_attention_mask() must handle 4 quadrants:
   - Image→Image: bidirectional (no mask)
   - Image→Text: masked (can't see future)
   - Text→Image: can attend (no mask)
   - Text→Text: causal mask
""")
print("="*60)