# Understanding the AnyMAL Architecture

This notebook provides a deep dive into how AnyMAL converts images into tokens that a language model can understand.

## Overview

AnyMAL follows a simple but powerful paradigm:
1. **Encode** images with a frozen vision encoder (CLIP)
2. **Project** vision features to the LLM's embedding space
3. **Generate** text using the frozen LLM

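The key insight is that only the projection layer needs to be trained. The vision encoder and the LLM can stay frozen (the architecture diagram later in this notebook notes that LoRA adapters are added to the LLM in Stage 2).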
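To make "frozen" concrete, here is a minimal sketch of how freezing typically looks in PyTorch. The three `nn.Linear` modules are tiny hypothetical stand-ins, not the real CLIP, projector, or LLaMA modules, and the sizes are arbitrary.

```python
import torch
import torch.nn as nn

# Tiny stand-ins for the three components (hypothetical sizes, not the real models)
vision_encoder = nn.Linear(8, 16)
projector = nn.Linear(16, 32)
llm = nn.Linear(32, 100)

# Freeze the encoder and the LLM: no gradients are stored for their weights
for module in (vision_encoder, llm):
    module.requires_grad_(False)
    module.eval()

# Only the projector's parameters are handed to the optimizer
trainable = [p for p in projector.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-4)


def count_params(module: nn.Module, trainable_only: bool = False) -> int:
    """Count parameters, optionally only those that receive gradients."""
    return sum(
        p.numel() for p in module.parameters()
        if p.requires_grad or not trainable_only
    )


for name, module in [("vision_encoder", vision_encoder),
                     ("projector", projector),
                     ("llm", llm)]:
    print(f"{name}: {count_params(module, trainable_only=True)} trainable "
          f"of {count_params(module)} total")
```

Gradients still flow *through* the frozen LLM back to the projector during the backward pass; the frozen weights are simply never updated.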

In [None]:
import sys
sys.path.insert(0, '..')

import torch
import torch.nn as nn
from PIL import Image
import matplotlib.pyplot as plt

# For visualization
%matplotlib inline

## Step 1: Image Encoding with CLIP

CLIP's Vision Transformer (ViT) converts an image into a sequence of patch embeddings.

For ViT-L/14:
- Input: 224×224 RGB image
- Patches: 14×14 pixels each
- Number of patches: (224/14)² = 256
- Plus 1 CLS token = 257 total tokens
- Hidden dimension: 1024
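The token count above is plain arithmetic, which the following sketch reproduces. `vit_sequence_length` is a hypothetical helper written for this notebook, not part of CLIP or of the repository.

```python
def vit_sequence_length(image_size: int, patch_size: int,
                        has_cls_token: bool = True) -> int:
    """Number of tokens a ViT produces for a square image (hypothetical helper)."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size   # 224 // 14 = 16
    num_patches = patches_per_side ** 2           # 16 * 16 = 256
    return num_patches + (1 if has_cls_token else 0)


print(vit_sequence_length(224, 14))                       # 257
print(vit_sequence_length(224, 14, has_cls_token=False))  # 256
```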

In [None]:
from models.encoders.image_encoder import ImageEncoder

# Create encoder (downloads weights on first use)
encoder = ImageEncoder(
    model_name="ViT-L-14",
    pretrained="openai",
    freeze=True,
)

print(f"Output dimension: {encoder.get_output_dim()}")
print(f"Number of patches: {encoder.get_num_patches()}")

In [None]:
# Process a sample image
from data.data_utils import get_image_transform

transform = get_image_transform(image_size=224, is_train=False)

# Create a random test image
# (randint's upper bound is exclusive, so 256 is needed to include pixel value 255)
import numpy as np
test_image = Image.fromarray(np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))
image_tensor = transform(test_image).unsqueeze(0)  # Add batch dimension

print(f"Input shape: {image_tensor.shape}")

# Encode
with torch.no_grad():
    features = encoder(image_tensor)

print(f"Output shape: {features.shape}")
print(f"  - Batch size: {features.shape[0]}")
print(f"  - Sequence length: {features.shape[1]} (256 patches + 1 CLS)")
print(f"  - Hidden dim: {features.shape[2]}")

## Step 2: Projecting to LLM Space

The Perceiver Resampler:
1. Takes 257 tokens from CLIP (dim=1024)
2. Compresses to 64 tokens (dim=4096)

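This is done via cross-attention with learnable queries: a fixed set of 64 latent vectors acts as the queries and attends over the 257 CLIP tokens, so the output always has 64 tokens.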
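The mechanism can be sketched with a single cross-attention layer. This is an illustration only: `TinyLatentCrossAttention` is a hypothetical class, not the repository's `PerceiverResampler`, which stacks several layers and may differ in its details. Small dimensions are used so the example runs quickly.

```python
import torch
import torch.nn as nn


class TinyLatentCrossAttention(nn.Module):
    """One cross-attention layer with learnable queries (illustrative sketch)."""

    def __init__(self, input_dim: int, output_dim: int,
                 num_latents: int, num_heads: int):
        super().__init__()
        # Learnable queries: one row per output token
        self.latents = nn.Parameter(torch.randn(num_latents, output_dim) * 0.02)
        # Map encoder features to the latent width so keys/values match the queries
        self.input_proj = nn.Linear(input_dim, output_dim)
        self.attn = nn.MultiheadAttention(output_dim, num_heads, batch_first=True)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        batch_size = features.shape[0]
        # Share the same latents across the batch: [B, num_latents, output_dim]
        queries = self.latents.unsqueeze(0).expand(batch_size, -1, -1)
        keys_values = self.input_proj(features)
        out, _ = self.attn(queries, keys_values, keys_values)
        return out


# Toy sizes (the notebook's real sizes are 1024 -> 4096 with 64 latents)
layer = TinyLatentCrossAttention(input_dim=32, output_dim=64,
                                 num_latents=4, num_heads=4)
print(layer(torch.randn(2, 10, 32)).shape)  # 10 input tokens -> 4 output tokens
print(layer(torch.randn(2, 17, 32)).shape)  # 17 input tokens -> still 4
```

Because the number of queries is fixed, the output length does not depend on how many tokens the encoder produces.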

In [None]:
from models.projectors.perceiver_resampler import PerceiverResampler

projector = PerceiverResampler(
    input_dim=1024,    # CLIP output dim
    output_dim=4096,   # LLaMA hidden dim
    num_latents=64,    # Output tokens
    num_layers=6,
)

print(f"Projector parameters: {projector.get_num_params():,}")

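# Project features (no gradients are needed for this shape check)
with torch.no_grad():
    image_tokens = projector(features)
print(f"\nProjection: {features.shape} -> {image_tokens.shape}")
print(f"  - Compressed from {features.shape[1]} to {image_tokens.shape[1]} tokens")
print(f"  - Dimension changed from {features.shape[2]} to {image_tokens.shape[2]}")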

## Step 3: The Full Forward Pass

The complete AnyMAL forward pass:
1. Encode image → [B, 257, 1024]
2. Project → [B, 64, 4096]
3. Embed text → [B, text_len, 4096]
4. Concatenate → [B, 64 + text_len, 4096]
5. LLM forward → logits [B, 64 + text_len, vocab]
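Steps 3 to 5 can be sketched with toy tensors. Several things here are assumptions for illustration: a single linear layer stands in for the LLM, the sizes are tiny, and image positions are masked with the label `-100` (the default `ignore_index` of PyTorch's cross-entropy) so that the loss covers the text portion only.

```python
import torch
import torch.nn.functional as F

# Toy sizes (the real model uses 64 image tokens and hidden size 4096)
batch_size, num_image_tokens, text_len = 2, 4, 5
hidden_dim, vocab_size = 16, 50

image_tokens = torch.randn(batch_size, num_image_tokens, hidden_dim)  # projector output
text_ids = torch.randint(0, vocab_size, (batch_size, text_len))
text_embeds = torch.randn(batch_size, text_len, hidden_dim)           # embedded text

# Step 4: concatenate along the sequence dimension
inputs_embeds = torch.cat([image_tokens, text_embeds], dim=1)

# Labels: -100 at image positions so they are ignored by the loss
ignore = torch.full((batch_size, num_image_tokens), -100, dtype=torch.long)
labels = torch.cat([ignore, text_ids], dim=1)

# Step 5: stand-in for the LLM (a real LLM would apply its transformer layers here)
lm_head = torch.nn.Linear(hidden_dim, vocab_size)
logits = lm_head(inputs_embeds)

# Next-token prediction: position t predicts the label at position t + 1
shift_logits = logits[:, :-1, :]
shift_labels = labels[:, 1:]
loss = F.cross_entropy(
    shift_logits.reshape(-1, vocab_size),
    shift_labels.reshape(-1),
    ignore_index=-100,
)

print(f"inputs_embeds: {tuple(inputs_embeds.shape)}")
print(f"logits:        {tuple(logits.shape)}")
print(f"loss:          {loss.item():.4f}")
```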

In [None]:
# Visualize the architecture
print("""
AnyMAL Architecture
══════════════════════════════════════════════════════════════

                        Input Image
                        (224 × 224)
                             │
                             ▼
                ┌────────────────────────┐
                │    CLIP ViT-L/14       │  ← FROZEN
                │    (Vision Encoder)    │
                └────────────────────────┘
                             │
                    [B, 257, 1024]
                             │
                             ▼
                ┌────────────────────────┐
                │  Perceiver Resampler   │  ← TRAINABLE
                │   (Cross-Attention)    │
                └────────────────────────┘
                             │
                     [B, 64, 4096]
                             │
        ┌────────────────────┴────────────────────┐
        │                                         │
        ▼                                         ▼
   Image Tokens                             Text Tokens
   [B, 64, 4096]                           [B, T, 4096]
        │                                         │
        └───────────────┬─────────────────────────┘
                        │
                        ▼
               Concatenate along seq
                [B, 64+T, 4096]
                        │
                        ▼
                ┌────────────────────────┐
                │     LLaMA-3 8B         │  ← FROZEN (+ LoRA in Stage 2)
                │   (Language Model)     │
                └────────────────────────┘
                        │
                        ▼
                     Logits
                 [B, 64+T, vocab]
                        │
                        ▼
             Cross-Entropy Loss
            (on text portion only)
""")

## Key Insights

### Why freeze the encoder and the LLM?
- CLIP already has excellent visual representations
- LLaMA already understands language well
- We just need to bridge the gap between them

### Why use cross-attention projection?
- Learnable compression of visual information
- Fixed output size (64 tokens) regardless of how many tokens the encoder produces
- Each latent can specialize in different image aspects

### Why 64 tokens?
- Good balance between information and efficiency
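- Passing all 257 tokens would be expensive for the LLM: every image token lengthens the sequence, and self-attention cost grows quadratically with sequence length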
- Too few tokens might lose important details
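
As a rough back-of-the-envelope check on the efficiency argument, the sketch below compares the size of one self-attention score matrix with and without compression. The text length of 128 is a hypothetical value, and the estimate ignores causal masking and the parts of the LLM whose cost grows only linearly with sequence length.

```python
def attention_pairs(num_image_tokens: int, text_len: int) -> int:
    """Entries in one self-attention score matrix for the concatenated sequence."""
    seq_len = num_image_tokens + text_len
    return seq_len * seq_len


text_len = 128  # hypothetical prompt length

full = attention_pairs(257, text_len)        # all CLIP tokens passed through
compressed = attention_pairs(64, text_len)   # after the resampler

print(f"257 image tokens: {full:,} query-key pairs")
print(f" 64 image tokens: {compressed:,} query-key pairs")
print(f"ratio: {full / compressed:.2f}x")
```

With this prompt length, compressing to 64 tokens cuts the attention matrix to about a quarter of its size. The ratio shrinks as the text gets longer, since the text tokens then dominate the sequence.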