# SFC Encoder Architecture Walkthrough

This notebook provides a **self-contained, step-by-step walkthrough** of the Space-Filling Curve (SFC) encoder architecture for sparse-conditioned image generation.

## Table of Contents
1. [Space-Filling Curves: Hilbert vs Z-order](#1-space-filling-curves)
2. [SFC Tokenization: Sparse Pixels → Tokens](#2-sfc-tokenization)
3. [Coordinate Embeddings (Option A: Unified Coords)](#3-coordinate-embeddings)
4. [Cross-Attention with Spatial Bias (Option B)](#4-cross-attention)
5. [Full Forward Pass](#5-full-forward-pass)
6. [Ablation Comparison](#6-ablation-comparison)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, Optional, Literal
import math

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

---
## 1. Space-Filling Curves <a name="1-space-filling-curves"></a>

Space-filling curves map 2D coordinates to a 1D index while preserving spatial locality.

### Why use SFC for tokenization?
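- **Locality preservation**: Pixels that are close in the 1D sequence tend to be close in 2D, so a contiguous run of the sequence covers a compact region
- **Efficient for sparse data**: Only observed pixels need to be tokenized
- **Better than raster scan**: Raster ordering jumps back across the full image width at the end of every row, so a contiguous run can span distant pixels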

In [None]:
# ============================================================================
# SPACE-FILLING CURVE IMPLEMENTATIONS
# ============================================================================

def xy_to_zorder(x: int, y: int, bits: int = 5) -> int:
    """
    Convert (x, y) to Z-order (Morton code) index.
    Z-order interleaves the bits of x and y.
    """
    z = 0
    for i in range(bits):
        z |= ((x & (1 << i)) << i) | ((y & (1 << i)) << (i + 1))
    return z


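def xy_to_hilbert(x: int, y: int, order: int = 5) -> int:
    """
    Convert (x, y) to Hilbert curve index on a 2^order x 2^order grid.
    Hilbert has better locality: consecutive indices are always adjacent in 2D.
    """
    n = 1 << order
    d = 0
    s = n >> 1
    # Walk from the coarsest quadrant (s = n/2) down to single pixels (s = 1)
    while s > 0:
        rx = 1 if (x & s) > 0 else 0
        ry = 1 if (y & s) > 0 else 0
        # Quadrant visit order: (rx, ry) = (0,0), (0,1), (1,1), (1,0)
        d += s * s * ((3 * rx) ^ ry)
        # Rotate/flip so the sub-quadrant's curve has the canonical orientation
        if ry == 0:
            if rx == 1:
                x = s - 1 - x
                y = s - 1 - y
            x, y = y, x
        s >>= 1
    return d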


def visualize_sfc(size: int = 8, curve: str = "hilbert"):
    """Visualize the space-filling curve ordering."""
    bits = int(np.log2(size))
    
    # Compute SFC index for each pixel
    sfc_indices = np.zeros((size, size), dtype=int)
    for y in range(size):
        for x in range(size):
            if curve == "hilbert":
                sfc_indices[y, x] = xy_to_hilbert(x, y, order=bits)
            else:
                sfc_indices[y, x] = xy_to_zorder(x, y, bits=bits)
    
    # Create path coordinates
    path = [(0, 0)] * (size * size)
    for y in range(size):
        for x in range(size):
            idx = sfc_indices[y, x]
            path[idx] = (x + 0.5, y + 0.5)
    
    # Plot
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    
    # Left: Index grid
    ax = axes[0]
    im = ax.imshow(sfc_indices, cmap='viridis')
    for y in range(size):
        for x in range(size):
            ax.text(x, y, str(sfc_indices[y, x]), ha='center', va='center', 
                   fontsize=8, color='white' if sfc_indices[y, x] < size*size/2 else 'black')
    ax.set_title(f'{curve.capitalize()} Curve Index Grid ({size}x{size})')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    plt.colorbar(im, ax=ax, label='SFC Index')
    
    # Right: Path visualization
    ax = axes[1]
    path_x, path_y = zip(*path)
    ax.plot(path_x, path_y, 'b-', linewidth=1, alpha=0.7)
    ax.scatter(path_x, path_y, c=range(len(path)), cmap='viridis', s=50, zorder=5)
    ax.set_xlim(0, size)
    ax.set_ylim(0, size)
    ax.invert_yaxis()
    ax.set_aspect('equal')
    ax.set_title(f'{curve.capitalize()} Curve Path')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    return sfc_indices

In [None]:
# Visualize Hilbert curve
print("HILBERT CURVE - Better locality (consecutive indices always adjacent)")
hilbert_indices = visualize_sfc(size=8, curve="hilbert")

In [None]:
# Visualize Z-order curve
print("Z-ORDER CURVE - Simpler but has jumps (e.g., index 7 to 8)")
zorder_indices = visualize_sfc(size=8, curve="zorder")

### Key Observation

Notice how:
- **Hilbert**: Every consecutive pair of indices is spatially adjacent
- **Z-order**: Has "jumps" (e.g., 7→8 moves from pixel (3, 1) to pixel (0, 2))

For attention mechanisms, Hilbert's better locality means tokens grouped together are more likely to be spatially related.
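
To make the adjacency claim concrete, here is a minimal sketch in plain Python (no notebook dependencies) that walks each curve on an 8×8 grid and reports the largest Manhattan step between consecutive indices. The helpers `hilbert_d2xy`, `zorder_d2xy`, and `max_step` are illustrative names that are not part of the notebook. `hilbert_d2xy` uses the common iterative index-to-coordinate conversion, and its orientation is assumed (not guaranteed) to match `xy_to_hilbert` above.

```python
from typing import List, Tuple


def hilbert_d2xy(order: int, d: int) -> Tuple[int, int]:
    """Map a Hilbert index d to (x, y) on a 2^order x 2^order grid."""
    x = y = 0
    t = d
    s = 1
    while s < (1 << order):
        rx = 1 & (t // 2)
        ry = 1 & (t ^ rx)
        if ry == 0:  # rotate/flip the current sub-square
            if rx == 1:
                x, y = s - 1 - x, s - 1 - y
            x, y = y, x
        x += s * rx
        y += s * ry
        t //= 4
        s *= 2
    return x, y


def zorder_d2xy(bits: int, z: int) -> Tuple[int, int]:
    """De-interleave a Morton code: even bits go to x, odd bits go to y."""
    x = y = 0
    for i in range(bits):
        x |= ((z >> (2 * i)) & 1) << i
        y |= ((z >> (2 * i + 1)) & 1) << i
    return x, y


def max_step(points: List[Tuple[int, int]]) -> int:
    """Largest Manhattan distance between consecutive points on a path."""
    return max(abs(x1 - x0) + abs(y1 - y0)
               for (x0, y0), (x1, y1) in zip(points, points[1:]))


order = 3  # 8x8 grid
n = 1 << order
hilbert_path = [hilbert_d2xy(order, d) for d in range(n * n)]
zorder_path = [zorder_d2xy(order, z) for z in range(n * n)]
raster_path = [(i % n, i // n) for i in range(n * n)]

print("Hilbert max step:", max_step(hilbert_path))
print("Z-order max step:", max_step(zorder_path))
print("Raster  max step:", max_step(raster_path))
print("Z-order 7 -> 8:", zorder_path[7], "->", zorder_path[8])
```

The Hilbert path never moves more than one pixel at a time, while both Z-order and raster order contain longer jumps.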

---
## 2. SFC Tokenization: Sparse Pixels → Tokens <a name="2-sfc-tokenization"></a>

The SFC tokenizer converts sparse pixels into a sequence of tokens:

1. **Order all pixels** by SFC (Hilbert/Z-order)
2. **Select observed pixels** (where `cond_mask=1`)
3. **Group consecutive pixels** into groups of `g` (default 8; the demos in this notebook use `g = 4`)
4. **Concatenate features** within each group
5. **Project to hidden_size**
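
Steps 1 to 3 above can be sketched without any tensors. The following plain-Python example is only an illustration: `morton_key` and `group_observed` are hypothetical helpers, Z-order is used instead of Hilbert purely to keep the sketch short, and the `-1` padding sentinel is an assumption of this sketch (the encoder later in the notebook zero-pads features instead).

```python
from math import ceil
from typing import List, Sequence


def morton_key(x: int, y: int, bits: int) -> int:
    """Interleave the bits of x and y (Z-order index)."""
    z = 0
    for i in range(bits):
        z |= (((x >> i) & 1) << (2 * i)) | (((y >> i) & 1) << (2 * i + 1))
    return z


def group_observed(mask: Sequence[Sequence[int]], g: int) -> List[List[int]]:
    """Order pixels along the curve, keep observed ones, chunk into groups of g.

    Each group holds flat pixel indices (y * size + x). The last group is
    padded with -1 so that every group has exactly g entries.
    """
    size = len(mask)
    bits = size.bit_length() - 1  # size is assumed to be a power of two
    # Step 1: order all pixels along the curve
    ordered = sorted(
        ((x, y) for y in range(size) for x in range(size)),
        key=lambda p: morton_key(p[0], p[1], bits),
    )
    # Step 2: keep only observed pixels, still in curve order
    observed = [y * size + x for x, y in ordered if mask[y][x]]
    # Step 3: chunk into groups of g, padding the tail
    n_tokens = ceil(len(observed) / g)
    padded = observed + [-1] * (n_tokens * g - len(observed))
    return [padded[i * g:(i + 1) * g] for i in range(n_tokens)]


mask = [
    [1, 0, 0, 1],
    [0, 1, 0, 0],
    [0, 0, 1, 1],
    [1, 0, 0, 0],
]
groups = group_observed(mask, g=4)
print(groups)  # [[0, 5, 3, 12], [10, 11, -1, -1]]
```

Steps 4 and 5 then amount to flattening each group's per-pixel features into one vector and passing it through a linear layer.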

In [None]:
def precompute_sfc_order(size: int, curve: str = "hilbert"):
    """
    Precompute SFC ordering and coordinate grids.
    
    Returns:
        order_flat: (HW,) indices sorted by SFC order
        coords_flat: (HW, 2) normalized coordinates in [-1, 1]
        sfc_pos_flat: (HW, 1) normalized SFC position in [0, 1]
    """
    bits = int(np.log2(size))
    
    # Compute (sfc_idx, flat_idx) pairs
    pairs = []
    for y in range(size):
        for x in range(size):
            if curve == "hilbert":
                sfc = xy_to_hilbert(x, y, order=bits)
            else:
                sfc = xy_to_zorder(x, y, bits=bits)
            flat = y * size + x
            pairs.append((sfc, flat))
    
    pairs.sort(key=lambda t: t[0])
    order_flat = torch.tensor([flat for _, flat in pairs], dtype=torch.long)
    
    # Normalized coordinates in [-1, 1]
    xs = torch.arange(size, dtype=torch.float32)
    ys = torch.arange(size, dtype=torch.float32)
    yy, xx = torch.meshgrid(ys, xs, indexing="ij")
    x_norm = (xx / (size - 1)) * 2.0 - 1.0
    y_norm = (yy / (size - 1)) * 2.0 - 1.0
    coords_flat = torch.stack([x_norm, y_norm], dim=-1).reshape(-1, 2)
    
    # Normalized SFC position in [0, 1]
    inv = torch.empty_like(order_flat)
    inv[order_flat] = torch.arange(order_flat.numel(), dtype=torch.long)
    sfc_pos_flat = inv.float().unsqueeze(-1) / max(order_flat.numel() - 1, 1)
    
    return order_flat, coords_flat, sfc_pos_flat


# Demonstrate SFC ordering
size = 8
order_flat, coords_flat, sfc_pos_flat = precompute_sfc_order(size, "hilbert")

print(f"Image size: {size}x{size} = {size*size} pixels")
print(f"\nSFC order (flat indices sorted by Hilbert index):")
print(f"First 16: {order_flat[:16].tolist()}")
print(f"\nCoordinates (first 4 in SFC order):")
for i in range(4):
    flat_idx = order_flat[i].item()
    x, y = flat_idx % size, flat_idx // size
    coord = coords_flat[flat_idx]
    print(f"  SFC idx {i}: pixel ({x}, {y}), normalized coord ({coord[0]:.2f}, {coord[1]:.2f})")

In [None]:
def demonstrate_sparse_tokenization(size=8, sparsity=0.4, group_size=4):
    """
    Demonstrate how sparse pixels are tokenized via SFC.
    """
    # Create random sparse mask
    mask = torch.rand(size, size) < sparsity
    num_observed = mask.sum().item()
    
    # Get SFC ordering
    order_flat, coords_flat, sfc_pos_flat = precompute_sfc_order(size, "hilbert")
    
    # Order mask by SFC
    mask_flat = mask.reshape(-1)
    mask_ord = mask_flat[order_flat]
    
    # Select observed pixels in SFC order
    observed_indices = torch.nonzero(mask_ord, as_tuple=False).squeeze(1)
    
    # Group into tokens
    num_tokens = int(np.ceil(len(observed_indices) / group_size))
    
    # Visualization
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    # Left: Sparse mask
    ax = axes[0]
    ax.imshow(mask.float(), cmap='gray', vmin=0, vmax=1)
    ax.set_title(f'Sparse Mask ({sparsity*100:.0f}% observed)\n{num_observed} pixels')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    
    # Middle: SFC ordering with observed pixels highlighted
    ax = axes[1]
    sfc_indices = np.zeros((size, size))
    bits = int(np.log2(size))
    for y in range(size):
        for x in range(size):
            sfc_indices[y, x] = xy_to_hilbert(x, y, order=bits)
    
    ax.imshow(sfc_indices, cmap='viridis', alpha=0.3)
    # Highlight observed pixels
    for y in range(size):
        for x in range(size):
            if mask[y, x]:
                ax.scatter(x, y, c='red', s=100, marker='o', edgecolors='white', linewidths=2)
                ax.text(x, y, str(int(sfc_indices[y, x])), ha='center', va='center', 
                       fontsize=7, color='white', fontweight='bold')
    ax.set_title('Observed Pixels with SFC Index')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    
    # Right: Token grouping
    ax = axes[2]
    token_colors = plt.cm.tab10(np.linspace(0, 1, num_tokens))
    
    for token_idx in range(num_tokens):
        start = token_idx * group_size
        end = min(start + group_size, len(observed_indices))
        
        for i in range(start, end):
            sfc_order_idx = observed_indices[i].item()
            flat_idx = order_flat[sfc_order_idx].item()
            x, y = flat_idx % size, flat_idx // size
            ax.scatter(x, y, c=[token_colors[token_idx]], s=150, marker='s', 
                      edgecolors='black', linewidths=1)
            ax.text(x, y, f'T{token_idx}', ha='center', va='center', 
                   fontsize=6, color='white', fontweight='bold')
    
    ax.set_xlim(-0.5, size-0.5)
    ax.set_ylim(-0.5, size-0.5)
    ax.invert_yaxis()
    ax.set_aspect('equal')
    ax.set_title(f'Token Grouping (g={group_size})\n{num_tokens} tokens')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print(f"\nTokenization Summary:")
    print(f"  Input: {size}x{size} = {size*size} pixels")
    print(f"  Observed: {num_observed} pixels ({num_observed/size/size*100:.1f}%)")
    print(f"  Group size: {group_size}")
    print(f"  Output: {num_tokens} tokens")
    print(f"  Compression: {size*size} → {num_tokens} ({num_tokens/size/size*100:.1f}% of dense)")

demonstrate_sparse_tokenization(size=8, sparsity=0.4, group_size=4)

### Key Insight

Notice how spatially nearby observed pixels tend to be in the same token (same color). This is because:
1. Hilbert curve preserves locality
2. Consecutive observed pixels (in SFC order) are grouped together

This means each token contains **spatially coherent** information!

---
## 3. Coordinate Embeddings (Option A: Unified Coords) <a name="3-coordinate-embeddings"></a>

Both tokens and queries need positional information. The question is: should they use the **same** embedding function?

### Baseline: Separate Embeddings
- **Tokens**: `LearnableFourierMLP` (Fourier features → 2-layer MLP)
- **Queries**: `FourierFeatures` (Fourier features → single linear layer)

### Option A: Unified Embeddings
- **Both**: Share the same `LearnableFourierMLP`
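
Both embedders start from the same Fourier stage, which can be sketched in plain Python. This is a simplified illustration, not the notebook's module: `fourier_features` is a hypothetical helper, and the frequencies `pi * 2^k` are taken from the classes defined in the next cell.

```python
import math
from typing import List


def fourier_features(x: float, y: float, n_bands: int = 8) -> List[float]:
    """[sin(x f), cos(x f), sin(y f), cos(y f)] for f = pi * 2^k, k < n_bands."""
    freqs = [math.pi * 2.0 ** k for k in range(n_bands)]
    return (
        [math.sin(x * f) for f in freqs] + [math.cos(x * f) for f in freqs]
        + [math.sin(y * f) for f in freqs] + [math.cos(y * f) for f in freqs]
    )


center = fourier_features(0.0, 0.0)
top_left = fourier_features(-1.0, -1.0)
bottom_right = fourier_features(1.0, 1.0)

print(len(center))  # 4 * n_bands = 32
# Largest feature difference between the two opposite corners
print(max(abs(a - b) for a, b in zip(top_left, bottom_right)))
```

One property worth knowing about this frequency choice: the lowest frequency `pi` has period 2, which is exactly the width of `[-1, 1]`, so the features at `-1` and `+1` coincide up to floating-point error. Opposite image edges therefore receive nearly identical Fourier features before any learned projection is applied.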

In [None]:
class FourierFeatures(nn.Module):
    """Fixed Fourier features with single linear projection (used for queries in baseline)."""
    def __init__(self, in_features: int, out_features: int, n_bands: int = 16):
        super().__init__()
        self.n_bands = n_bands
        freqs = (2.0 ** torch.arange(n_bands)) * np.pi
        self.register_buffer("freqs", freqs)
        
        fourier_dim = 4 * n_bands  # sin/cos for x and y
        self.proj = nn.Linear(fourier_dim, out_features)
    
    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        # coords: (N, 2)
        x, y = coords[:, 0:1], coords[:, 1:2]  # (N, 1)
        xw = x * self.freqs  # (N, n_bands)
        yw = y * self.freqs
        fourier = torch.cat([torch.sin(xw), torch.cos(xw), 
                            torch.sin(yw), torch.cos(yw)], dim=-1)  # (N, 4*n_bands)
        return self.proj(fourier)


class LearnableFourierMLP(nn.Module):
    """Fourier features with learnable MLP projection (used for tokens, and queries in Option A)."""
    def __init__(self, out_features: int, n_bands: int = 16, hidden_dim: Optional[int] = None, mlp_layers: int = 2):
        super().__init__()
        self.n_bands = n_bands
        # Frequencies are a fixed buffer; only the MLP below is learned
        freqs = (2.0 ** torch.arange(n_bands)) * np.pi
        self.register_buffer("freqs", freqs)
        
        fourier_dim = 4 * n_bands
        if hidden_dim is None:
            hidden_dim = max(out_features, fourier_dim)
        
        # 2-layer MLP (mlp_layers is accepted but not used in this simplified version)
        self.mlp = nn.Sequential(
            nn.Linear(fourier_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, out_features),
        )
    
    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        # coords: (N, 2)
        proj = coords.unsqueeze(-1) * self.freqs  # (N, 2, n_bands)
        fourier = torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)  # (N, 2, 2*n_bands)
        fourier = fourier.reshape(coords.shape[0], -1)  # (N, 4*n_bands)
        return self.mlp(fourier)


# Compare the two embedding types
hidden_size = 64
n_bands = 8

fourier_simple = FourierFeatures(2, hidden_size, n_bands=n_bands)
fourier_mlp = LearnableFourierMLP(hidden_size, n_bands=n_bands)

# Test coordinates
test_coords = torch.tensor([
    [-1.0, -1.0],  # top-left
    [0.0, 0.0],    # center
    [1.0, 1.0],    # bottom-right
    [0.5, -0.5],   # somewhere
])

with torch.no_grad():
    embed_simple = fourier_simple(test_coords)
    embed_mlp = fourier_mlp(test_coords)

print("Coordinate Embedding Comparison")
print("=" * 50)
print(f"\nInput coordinates shape: {test_coords.shape}")
print(f"FourierFeatures output shape: {embed_simple.shape}")
print(f"LearnableFourierMLP output shape: {embed_mlp.shape}")

# Parameter count
params_simple = sum(p.numel() for p in fourier_simple.parameters())
params_mlp = sum(p.numel() for p in fourier_mlp.parameters())
print(f"\nFourierFeatures parameters: {params_simple:,}")
print(f"LearnableFourierMLP parameters: {params_mlp:,}")

In [None]:
def visualize_embedding_similarity(embed_fn, title, grid_size=16):
    """
    Visualize how coordinate embeddings relate to each other.
    Shows cosine similarity between center point and all other points.
    """
    # Create coordinate grid
    xs = torch.linspace(-1, 1, grid_size)
    ys = torch.linspace(-1, 1, grid_size)
    yy, xx = torch.meshgrid(ys, xs, indexing='ij')
    coords = torch.stack([xx, yy], dim=-1).reshape(-1, 2)  # (grid_size^2, 2)
    
    with torch.no_grad():
        embeddings = embed_fn(coords)  # (grid_size^2, hidden_size)
        embeddings = F.normalize(embeddings, dim=-1)  # Normalize for cosine similarity
    
    # Compute similarity to center point
    center_idx = grid_size * grid_size // 2 + grid_size // 2
    center_embed = embeddings[center_idx:center_idx+1]  # (1, hidden_size)
    
    similarity = (embeddings @ center_embed.T).squeeze()  # (grid_size^2,)
    similarity_grid = similarity.reshape(grid_size, grid_size)
    
    # Plot
    fig, ax = plt.subplots(figsize=(6, 5))
    im = ax.imshow(similarity_grid, cmap='RdBu_r', vmin=-1, vmax=1)
    ax.scatter(grid_size//2, grid_size//2, c='green', s=100, marker='*', 
              edgecolors='white', linewidths=2, label='Reference point')
    ax.set_title(f'{title}\nCosine Similarity to Center')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    plt.colorbar(im, ax=ax, label='Cosine Similarity')
    ax.legend()
    plt.tight_layout()
    plt.show()

print("FourierFeatures (single linear projection):")
visualize_embedding_similarity(fourier_simple, "FourierFeatures")

print("\nLearnableFourierMLP (2-layer MLP):")
visualize_embedding_similarity(fourier_mlp, "LearnableFourierMLP")

### Option A: Why Unified Coordinates Matter

In the **baseline**, cross-attention computes:
```
Q = W_q @ FourierFeatures(query_coords)      # Different embedding
K = W_k @ LearnableFourierMLP(token_coords)  # Different embedding
attention = softmax(Q @ K.T)
```

The Q/K projections must learn to **align two different coordinate representations**.

With **Option A** (unified):
```
shared_embed = LearnableFourierMLP
Q = W_q @ shared_embed(query_coords)   # Same embedding
K = W_k @ shared_embed(token_coords)   # Same embedding
```

Now both operate in the **same representation space**, which should make alignment easier.
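
A small worked example shows what a shared representation buys. Under two simplifying assumptions that are not the notebook's model (identity projections `W_q = W_k = I`, and raw Fourier features in place of the MLP output), the score between a query and a token depends only on their coordinate difference, because `sin(a)sin(b) + cos(a)cos(b) = cos(a - b)`. The helper names below are hypothetical.

```python
import math


def fourier_features(x, y, n_bands=4):
    """Return (features, freqs) with per-axis [sin, cos] blocks, f = pi * 2^k."""
    freqs = [math.pi * 2.0 ** k for k in range(n_bands)]
    feats = []
    for c in (x, y):
        feats += [math.sin(c * f) for f in freqs]
        feats += [math.cos(c * f) for f in freqs]
    return feats, freqs


def dot(a, b):
    return sum(u * v for u, v in zip(a, b))


query = (0.25, -0.5)
token = (0.30, -0.4)

q_feat, freqs = fourier_features(*query)
k_feat, _ = fourier_features(*token)

# Score computed from the shared embedding
score = dot(q_feat, k_feat)

# Closed form: a function of the offset (query - token) only
closed_form = sum(
    math.cos(f * (query[0] - token[0])) + math.cos(f * (query[1] - token[1]))
    for f in freqs
)
print(score, closed_form)
```

With two different embedders this identity does not hold out of the box, so the Q/K projections have to learn the alignment themselves.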

---
## 4. Cross-Attention with Spatial Bias (Option B) <a name="4-cross-attention"></a>

Cross-attention allows patch queries to gather information from sparse tokens.

### Baseline Cross-Attention
```
attention_scores = Q @ K.T / sqrt(d)
attention_weights = softmax(attention_scores)
output = attention_weights @ V
```

The model must **learn** that spatially nearby tokens are more relevant.

### Option B: Spatial Bias
```
dist = cdist(query_coords, token_coords)  # Pairwise distances
spatial_bias = -scale * dist + offset     # Learnable per head
attention_scores = Q @ K.T / sqrt(d) + spatial_bias  # Add bias!
```

With `scale > 0`, **closer tokens automatically get a higher attention score** than distant tokens with the same content score.
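
The effect of the bias can be checked by hand for one query and one head. The sketch below is plain Python with hypothetical helper names (`softmax`, `biased_weights`); content scores are set to zero so that only the distance term matters.

```python
import math
from typing import List, Sequence, Tuple


def softmax(scores: Sequence[float]) -> List[float]:
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def biased_weights(
    query: Tuple[float, float],
    tokens: Sequence[Tuple[float, float]],
    content_scores: Sequence[float],
    scale: float,
    offset: float = 0.0,
) -> List[float]:
    """Attention weights for one query and one head with a distance bias."""
    dists = [math.dist(query, t) for t in tokens]
    scores = [c - scale * d + offset for c, d in zip(content_scores, dists)]
    return softmax(scores)


query = (0.0, 0.0)
tokens = [(0.1, 0.0), (0.5, 0.0), (1.0, 0.0)]  # near, mid, far
content = [0.0, 0.0, 0.0]                      # no content preference

baseline = softmax(content)
biased = biased_weights(query, tokens, content, scale=3.0)
shifted = biased_weights(query, tokens, content, scale=3.0, offset=5.0)

print("baseline:", [round(w, 3) for w in baseline])
print("biased:  ", [round(w, 3) for w in biased])
print("shifted: ", [round(w, 3) for w in shifted])
```

The baseline weights are uniform, the biased weights decrease with distance, and adding an offset leaves the weights unchanged because a softmax is invariant to a constant shift of all scores.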

In [None]:
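def demonstrate_spatial_bias(num_queries=4, num_tokens=12, num_heads=4):
    """
    Demonstrate how spatial bias affects attention patterns.

    Note: query_coords and scale below are hard-coded for 4 queries and 4 heads,
    so num_queries and num_heads should stay at their defaults.
    """
    hidden_size = 64
    head_dim = hidden_size // num_heads
    
    # Fixed query positions (patch centers, one per image quadrant)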
    query_coords = torch.tensor([
        [-0.5, -0.5],  # top-left patch
        [0.5, -0.5],   # top-right patch
        [-0.5, 0.5],   # bottom-left patch
        [0.5, 0.5],    # bottom-right patch
    ])
    
    # Random token positions (sparse observed pixels)
    torch.manual_seed(42)
    token_coords = torch.rand(num_tokens, 2) * 2 - 1  # Random in [-1, 1]
    
    # Compute pairwise distances
    dist = torch.cdist(query_coords, token_coords, p=2)  # (Q, T)
    
    # Simulate Q, K with random values
    Q = torch.randn(num_queries, num_heads, head_dim)
    K = torch.randn(num_tokens, num_heads, head_dim)
    
    # Baseline attention (no spatial bias)
    attn_scores_baseline = torch.einsum('qhd,thd->hqt', Q, K) / np.sqrt(head_dim)
    attn_weights_baseline = F.softmax(attn_scores_baseline, dim=-1)
    
    # Option B: With spatial bias
    scale = torch.tensor([1.0, 2.0, 3.0, 5.0])  # Different scales per head
    offset = torch.zeros(num_heads)
    
    # spatial_bias: (num_heads, Q, T)
    spatial_bias = -scale.view(-1, 1, 1) * dist.unsqueeze(0) + offset.view(-1, 1, 1)
    
    attn_scores_with_bias = attn_scores_baseline + spatial_bias
    attn_weights_with_bias = F.softmax(attn_scores_with_bias, dim=-1)
    
    # Visualization
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    
    # Top row: Spatial layout and distance matrix
    ax = axes[0, 0]
    ax.scatter(token_coords[:, 0], token_coords[:, 1], c='blue', s=100, 
              marker='o', label='Tokens', edgecolors='black')
    ax.scatter(query_coords[:, 0], query_coords[:, 1], c='red', s=200, 
              marker='s', label='Query patches', edgecolors='black')
    for i, (x, y) in enumerate(query_coords):
        ax.annotate(f'Q{i}', (x, y), fontsize=10, ha='center', va='center', color='white', fontweight='bold')
    for i, (x, y) in enumerate(token_coords):
        ax.annotate(f'T{i}', (x, y), fontsize=8, ha='center', va='center', color='white')
    ax.set_xlim(-1.2, 1.2)
    ax.set_ylim(-1.2, 1.2)
    ax.set_aspect('equal')
    ax.set_title('Spatial Layout')
    ax.legend(loc='upper right')
    ax.grid(True, alpha=0.3)
    
    ax = axes[0, 1]
    im = ax.imshow(dist, cmap='viridis')
    ax.set_title('Distance Matrix (Q × T)')
    ax.set_xlabel('Token')
    ax.set_ylabel('Query')
    ax.set_xticks(range(num_tokens))
    ax.set_yticks(range(num_queries))
    ax.set_xticklabels([f'T{i}' for i in range(num_tokens)], fontsize=8)
    ax.set_yticklabels([f'Q{i}' for i in range(num_queries)])
    plt.colorbar(im, ax=ax, label='Euclidean Distance')
    
    ax = axes[0, 2]
    im = ax.imshow(spatial_bias[2], cmap='RdBu_r')  # Show head 2
    ax.set_title(f'Spatial Bias (Head 2, scale={scale[2]:.1f})')
    ax.set_xlabel('Token')
    ax.set_ylabel('Query')
    ax.set_xticks(range(num_tokens))
    ax.set_yticks(range(num_queries))
    ax.set_xticklabels([f'T{i}' for i in range(num_tokens)], fontsize=8)
    ax.set_yticklabels([f'Q{i}' for i in range(num_queries)])
    plt.colorbar(im, ax=ax, label='Bias Value')
    
    # Bottom row: Attention weights comparison
    ax = axes[1, 0]
    im = ax.imshow(attn_weights_baseline[2], cmap='hot', vmin=0, vmax=0.3)
    ax.set_title('Baseline Attention (Head 2)\nNo Spatial Bias')
    ax.set_xlabel('Token')
    ax.set_ylabel('Query')
    ax.set_xticks(range(num_tokens))
    ax.set_yticks(range(num_queries))
    ax.set_xticklabels([f'T{i}' for i in range(num_tokens)], fontsize=8)
    ax.set_yticklabels([f'Q{i}' for i in range(num_queries)])
    plt.colorbar(im, ax=ax, label='Attention Weight')
    
    ax = axes[1, 1]
    im = ax.imshow(attn_weights_with_bias[2], cmap='hot', vmin=0, vmax=0.3)
    ax.set_title('With Spatial Bias (Head 2)\nCloser tokens get more attention')
    ax.set_xlabel('Token')
    ax.set_ylabel('Query')
    ax.set_xticks(range(num_tokens))
    ax.set_yticks(range(num_queries))
    ax.set_xticklabels([f'T{i}' for i in range(num_tokens)], fontsize=8)
    ax.set_yticklabels([f'Q{i}' for i in range(num_queries)])
    plt.colorbar(im, ax=ax, label='Attention Weight')
    
    # Show difference
    ax = axes[1, 2]
    diff = attn_weights_with_bias[2] - attn_weights_baseline[2]
    im = ax.imshow(diff, cmap='RdBu_r', vmin=-0.2, vmax=0.2)
    ax.set_title('Difference (With Bias - Baseline)\nBlue=decreased, Red=increased')
    ax.set_xlabel('Token')
    ax.set_ylabel('Query')
    ax.set_xticks(range(num_tokens))
    ax.set_yticks(range(num_queries))
    ax.set_xticklabels([f'T{i}' for i in range(num_tokens)], fontsize=8)
    ax.set_yticklabels([f'Q{i}' for i in range(num_queries)])
    plt.colorbar(im, ax=ax, label='Weight Difference')
    
    plt.tight_layout()
    plt.show()
    
    return query_coords, token_coords, dist

query_coords, token_coords, dist = demonstrate_spatial_bias()

### Key Observation

With spatial bias:
- Queries attend **more strongly** to nearby tokens (red in difference plot)
- Queries attend **less** to distant tokens (blue in difference plot)
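- This is an **inductive bias** intended to help the model pick up locality faster

The bias is **learnable per head**, so some heads can focus locally (large `scale`) while others attend globally (small `scale`). Note that `offset` adds the same constant to every token's score within a head, so it cancels in the softmax and leaves the attention weights unchanged; only `scale` shapes the attention pattern.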
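
How strongly a head localizes can be summarized by the entropy of its attention weights. The sketch below uses plain Python with made-up distances and the same per-head scales as the demo above, plus a scale of 0 for reference; `attention_from_distance` and `entropy` are hypothetical helpers, and content scores are assumed to be zero.

```python
import math
from typing import List, Sequence


def attention_from_distance(distances: Sequence[float], scale: float) -> List[float]:
    """Softmax of -scale * distance (content scores assumed to be zero)."""
    logits = [-scale * d for d in distances]
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(p: Sequence[float]) -> float:
    """Shannon entropy in nats; lower means more concentrated attention."""
    return -sum(w * math.log(w) for w in p if w > 0.0)


distances = [0.1, 0.4, 0.8, 1.2, 1.6]    # one query to five tokens
head_scales = [0.0, 1.0, 2.0, 3.0, 5.0]  # scale 0 = no spatial preference

for s in head_scales:
    weights = attention_from_distance(distances, s)
    print(f"scale={s:.1f}  nearest-token weight={weights[0]:.3f}  "
          f"entropy={entropy(weights):.3f}")
```

Entropy falls as the scale grows, which is the sense in which large-scale heads act locally and small-scale heads act globally.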

---
## 5. Full Forward Pass <a name="5-full-forward-pass"></a>

Now let's trace through the complete forward pass of the SFC encoder.
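
One detail of the forward pass deserves care: padding tokens are excluded from attention through `token_mask`, and a sample with no observed pixels has every token masked. A softmax over scores that are all `-inf` is NaN in PyTorch, so such rows need an explicit guard. The plain-Python sketch below shows the intended behavior; `masked_softmax` is a hypothetical helper, not part of the encoder.

```python
import math
from typing import List, Sequence


def masked_softmax(scores: Sequence[float], keep: Sequence[bool]) -> List[float]:
    """Softmax over entries where keep is True; masked entries get weight 0.

    If nothing is kept, return all zeros instead of dividing by zero.
    """
    kept = [s for s, k in zip(scores, keep) if k]
    if not kept:
        return [0.0] * len(scores)
    m = max(kept)  # subtract the max for numerical stability
    exps = [math.exp(s - m) if k else 0.0 for s, k in zip(scores, keep)]
    total = sum(exps)
    return [e / total for e in exps]


print(masked_softmax([1.0, 2.0, 3.0], [True, True, False]))
print(masked_softmax([1.0, 2.0], [False, False]))  # [0.0, 0.0]
```

Returning zeros for a fully masked row means the corresponding queries receive no information from the tokens, which is a reasonable default when nothing was observed.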

In [None]:
class SimpleSFCEncoder(nn.Module):
    """
    Simplified SFC encoder for demonstration.
    Shows the key components without full model complexity.
    """
    def __init__(
        self,
        in_channels: int = 3,
        hidden_size: int = 128,
        patch_size: int = 4,
        group_size: int = 4,
        num_heads: int = 4,
        # Ablation options
        unified_coords: bool = False,
        spatial_bias: bool = False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_size = hidden_size
        self.patch_size = patch_size
        self.group_size = group_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.unified_coords = unified_coords
        self.spatial_bias = spatial_bias
        
        # Coordinate embeddings
        self.token_coord_embed = LearnableFourierMLP(hidden_size, n_bands=8)
        
        if unified_coords:
            # Option A: Share the same embedder
            self.query_coord_embed = self.token_coord_embed
        else:
            # Baseline: Separate embedder
            self.query_coord_embed = FourierFeatures(2, hidden_size, n_bands=8)
        
        # Token projection: (group_size * (in_channels + hidden_size)) -> hidden_size
        per_point_dim = in_channels + hidden_size  # pixel values + coord embedding
        self.token_proj = nn.Linear(group_size * per_point_dim, hidden_size)
        self.token_norm = nn.LayerNorm(hidden_size)
        
        # Cross-attention
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)
        
        # Option B: Spatial bias parameters
        if spatial_bias:
            self.spatial_scale = nn.Parameter(torch.ones(num_heads))
            self.spatial_offset = nn.Parameter(torch.zeros(num_heads))
    
    def tokenize(self, x: torch.Tensor, cond_mask: torch.Tensor):
        """
        Convert sparse pixels to SFC tokens.
        
        Args:
            x: (B, C, H, W) input image
            cond_mask: (B, 1, H, W) binary mask of observed pixels
        
        Returns:
            tokens: (B, T, D) token embeddings
            token_mask: (B, T) boolean mask (True = real token)
            token_coords: (B, T, 2) mean coordinate per token
        """
        B, C, H, W = x.shape
        device = x.device
        
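        # Get SFC ordering (assumes a square image whose side is a power of two;
        # recomputed on every call here for simplicity)
        order_flat, coords_flat, _ = precompute_sfc_order(H, "hilbert")
        order_flat = order_flat.to(device)
        coords_flat = coords_flat.to(device)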
        
        # Flatten and reorder by SFC
        x_flat = x.permute(0, 2, 3, 1).reshape(B, H*W, C)  # (B, HW, C)
        m_flat = cond_mask[:, 0].reshape(B, H*W) > 0.5     # (B, HW)
        
        x_ord = x_flat[:, order_flat, :]  # (B, HW, C)
        m_ord = m_flat[:, order_flat]     # (B, HW)
        coords_ord = coords_flat[order_flat, :]  # (HW, 2)
        
        # Process each batch item
        tokens_list, masks_list, coords_list = [], [], []
        
        for b in range(B):
            # Select observed pixels
            idx = torch.nonzero(m_ord[b], as_tuple=False).squeeze(1)
            if idx.numel() == 0:
                # No observed pixels - create dummy token
                tokens_list.append(torch.zeros(1, self.hidden_size, device=device))
                masks_list.append(torch.zeros(1, dtype=torch.bool, device=device))
                coords_list.append(torch.zeros(1, 2, device=device))
                continue
            
            vals = x_ord[b, idx, :]  # (K, C)
            xy = coords_ord[idx, :].to(device)   # (K, 2)
            
            # Add coordinate embedding
            coord_embed = self.token_coord_embed(xy)  # (K, D)
            feats = torch.cat([vals, coord_embed], dim=-1)  # (K, C+D)
            
            K = feats.shape[0]
            
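            # Pad to a multiple of group_size: features are zero-padded, while
            # coordinates repeat the last observed pixel so the per-token mean
            # stays near the observed pixels (weighted toward the last one)
            total = int(np.ceil(K / self.group_size) * self.group_size)
            pad = total - K
            if pad > 0:
                feats = torch.cat([feats, feats.new_zeros(pad, feats.shape[-1])], dim=0)
                xy_padded = torch.cat([xy, xy[-1:].expand(pad, -1)], dim=0)
            else:
                xy_padded = xy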
            
            # Group and flatten
            T = total // self.group_size
            feats = feats.view(T, self.group_size, -1).reshape(T, -1)  # (T, g*(C+D))
            
            # Compute mean coordinate per token
            tcoords = xy_padded.view(T, self.group_size, 2).mean(dim=1)  # (T, 2)
            
            # Project to hidden size
            tok = self.token_proj(feats)  # (T, D)
            tok = self.token_norm(tok)
            
            # Create mask: every token here holds at least one real pixel, so all
            # are valid; padding tokens only appear when batching (below)
            tmask = torch.ones(T, dtype=torch.bool, device=device)
            
            tokens_list.append(tok)
            masks_list.append(tmask)
            coords_list.append(tcoords)
        
        # Pad across batch
        T_max = max(t.shape[0] for t in tokens_list)
        tokens = torch.zeros(B, T_max, self.hidden_size, device=device)
        token_mask = torch.zeros(B, T_max, dtype=torch.bool, device=device)
        token_coords = torch.zeros(B, T_max, 2, device=device)
        
        for b in range(B):
            T = tokens_list[b].shape[0]
            tokens[b, :T] = tokens_list[b]
            token_mask[b, :T] = masks_list[b]
            token_coords[b, :T] = coords_list[b]
        
        return tokens, token_mask, token_coords
    
    def cross_attention(self, queries, tokens, token_mask, query_coords, token_coords):
        """
        Cross-attention: queries attend to tokens.
        
        Args:
            queries: (B, L, D) patch queries
            tokens: (B, T, D) SFC tokens
            token_mask: (B, T) boolean mask
            query_coords: (B, L, 2) query coordinates
            token_coords: (B, T, 2) token coordinates
        
        Returns:
            out: (B, L, D) attended features
            attn_weights: (B, heads, L, T) attention weights
        """
        B, L, D = queries.shape
        T = tokens.shape[1]
        
        # Project Q, K, V
        Q = self.q_proj(queries).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(tokens).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(tokens).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        # Q, K, V: (B, heads, seq, head_dim)
        
        # Attention scores
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.head_dim)  # (B, heads, L, T)
        
        # Option B: Add spatial bias
        if self.spatial_bias:
            dist = torch.cdist(query_coords.float(), token_coords.float(), p=2)  # (B, L, T)
            scale = self.spatial_scale.view(1, -1, 1, 1)  # (1, heads, 1, 1)
            offset = self.spatial_offset.view(1, -1, 1, 1)
            spatial_bias = -scale * dist.unsqueeze(1) + offset  # (B, heads, L, T)
            attn_scores = attn_scores + spatial_bias
        
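        # Apply mask (True in token_mask = real token, so invert for masked_fill)
        if token_mask is not None:
            mask = ~token_mask.unsqueeze(1).unsqueeze(2)  # (B, 1, 1, T)
            attn_scores = attn_scores.masked_fill(mask, float('-inf'))
        
        # Softmax and apply to values
        attn_weights = F.softmax(attn_scores, dim=-1)
        # A sample with no observed pixels has every token masked, which makes the
        # softmax NaN; map those rows to zero attention instead
        attn_weights = torch.nan_to_num(attn_weights, nan=0.0)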
        out = torch.matmul(attn_weights, V)  # (B, heads, L, head_dim)
        out = out.transpose(1, 2).reshape(B, L, D)  # (B, L, D)
        out = self.out_proj(out)
        
        return out, attn_weights
    
    def forward(self, x, cond_mask):
        """
        Full forward pass.
        
        Args:
            x: (B, C, H, W) input image
            cond_mask: (B, 1, H, W) binary mask
        
        Returns:
            output: (B, L, D) patch embeddings, where L = (H/patch_size) * (W/patch_size)
            info: dict with 'tokens', 'token_mask', 'token_coords', 'query_coords', 'attn_weights'
        """
        B, C, H, W = x.shape
        device = x.device
        
        # Step 1: Tokenize sparse pixels
        tokens, token_mask, token_coords = self.tokenize(x, cond_mask)
        
        # Step 2: Create patch queries
        ph, pw = H // self.patch_size, W // self.patch_size
        L = ph * pw
        
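        # Patch center coordinates in pixel-index units (pixel i is centered at index i)
        ys = (torch.arange(ph, device=device).float() + 0.5) * self.patch_size - 0.5
        xs = (torch.arange(pw, device=device).float() + 0.5) * self.patch_size - 0.5
        # Normalize to [-1, 1] with the same convention as the token coordinates in
        # precompute_sfc_order (index / (size - 1)), so query-token distances and
        # shared embeddings are consistent
        ys = (ys / (H - 1)) * 2 - 1
        xs = (xs / (W - 1)) * 2 - 1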
        yy, xx = torch.meshgrid(ys, xs, indexing='ij')
        query_coords = torch.stack([xx, yy], dim=-1).view(1, L, 2).expand(B, -1, -1)
        
        # Embed query coordinates
        queries = self.query_coord_embed(query_coords.reshape(-1, 2)).view(B, L, -1)
        
        # Step 3: Cross-attention
        output, attn_weights = self.cross_attention(
            queries, tokens, token_mask, query_coords, token_coords
        )
        
        return output, {
            'tokens': tokens,
            'token_mask': token_mask,
            'token_coords': token_coords,
            'query_coords': query_coords,
            'attn_weights': attn_weights,
        }
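
The masked softmax pattern used in `cross_attention` has one edge case worth knowing about: if a sample ends up with no valid tokens, every score in its rows is `-inf` and `F.softmax` returns NaN for those rows. The sketch below reproduces the effect on toy tensors and shows one possible guard. The helper name `safe_masked_softmax` is hypothetical and not part of the encoder above.

```python
import torch
import torch.nn.functional as F

def safe_masked_softmax(scores: torch.Tensor, valid: torch.Tensor) -> torch.Tensor:
    """Softmax over the last dim that returns zeros for rows with no valid entry.

    scores: (B, heads, L, T) attention logits
    valid:  (B, T) boolean, True for real tokens
    """
    mask = ~valid[:, None, None, :]                       # (B, 1, 1, T)
    weights = F.softmax(scores.masked_fill(mask, float('-inf')), dim=-1)
    # Rows where every token is masked come out as NaN; replace them with zeros.
    return torch.nan_to_num(weights, nan=0.0)

scores = torch.zeros(2, 1, 3, 4)
valid = torch.tensor([[True, True, False, False],
                      [False, False, False, False]])      # second sample: no tokens
weights = safe_masked_softmax(scores, valid)
print(weights[0, 0, 0])   # uniform over the two valid tokens
print(weights[1, 0, 0])   # all zeros instead of NaN
```

Whether zero weights are the right fallback depends on how downstream blocks treat an all-zero patch embedding, so this is a design choice rather than a required fix.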

In [None]:
def demonstrate_full_forward_pass():
    """
    Walk through the full forward pass with visualizations.
    """
    # Setup
    B, C, H, W = 1, 3, 16, 16
    patch_size = 4
    group_size = 4
    hidden_size = 64
    sparsity = 0.3
    
    # Create model
    model = SimpleSFCEncoder(
        in_channels=C,
        hidden_size=hidden_size,
        patch_size=patch_size,
        group_size=group_size,
        unified_coords=True,   # Option A
        spatial_bias=True,     # Option B
    )
    
    # Create input
    torch.manual_seed(123)
    x = torch.randn(B, C, H, W)
    cond_mask = (torch.rand(B, 1, H, W) < sparsity).float()
    
    # Forward pass
    with torch.no_grad():
        output, info = model(x, cond_mask)
    
    # Extract info
    tokens = info['tokens']
    token_mask = info['token_mask']
    token_coords = info['token_coords']
    query_coords = info['query_coords']
    attn_weights = info['attn_weights']
    
    num_observed = cond_mask.sum().int().item()
    num_tokens = token_mask[0].sum().item()
    num_patches = output.shape[1]
    
    print("=" * 60)
    print("FULL FORWARD PASS WALKTHROUGH")
    print("=" * 60)
    print(f"\nInput:")
    print(f"  Image: {tuple(x.shape)} = (B, C, H, W)")
    print(f"  Mask: {tuple(cond_mask.shape)} with {num_observed} observed pixels ({num_observed/H/W*100:.1f}%)")
    print(f"\nTokenization:")
    print(f"  Group size: {group_size}")
    print(f"  Tokens: {tuple(tokens.shape)} with {int(num_tokens)} real tokens")
    print(f"  Compression: {num_observed} observed pixels (of {H*W}) → {int(num_tokens)} tokens")
    print(f"\nQueries:")
    print(f"  Patch size: {patch_size}x{patch_size}")
    print(f"  Patches: {H//patch_size}x{W//patch_size} = {num_patches}")
    print(f"\nCross-Attention:")
    print(f"  Q: ({num_patches} queries) × K: ({int(num_tokens)} tokens)")
    print(f"  Attention weights: {tuple(attn_weights.shape)} = (B, heads, L, T)")
    print(f"\nOutput:")
    print(f"  Patch embeddings: {tuple(output.shape)} = (B, L, D)")
    
    # Visualization
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    
    # Input image and mask
    ax = axes[0, 0]
    # Show first channel of input
    ax.imshow(x[0, 0].numpy(), cmap='gray')
    ax.set_title('Input Image (Channel 0)')
    ax.axis('off')
    
    ax = axes[0, 1]
    ax.imshow(cond_mask[0, 0].numpy(), cmap='gray')
    ax.set_title(f'Conditioning Mask\n{num_observed} pixels observed')
    ax.axis('off')
    
    # Token positions
    ax = axes[0, 2]
    tc = token_coords[0, :int(num_tokens)].numpy()
    qc = query_coords[0].numpy()
    
    ax.scatter(qc[:, 0], qc[:, 1], c='red', s=200, marker='s', 
              label=f'Query patches ({num_patches})', alpha=0.7, edgecolors='black')
    ax.scatter(tc[:, 0], tc[:, 1], c='blue', s=100, marker='o',
              label=f'Tokens ({int(num_tokens)})', alpha=0.7, edgecolors='black')
    ax.set_xlim(-1.2, 1.2)
    ax.set_ylim(-1.2, 1.2)
    ax.set_aspect('equal')
    ax.invert_yaxis()
    ax.set_title('Spatial Layout')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # Attention patterns for up to three heads (one per bottom-row panel)
    for head_idx, ax in enumerate(axes[1, :3]):
        if head_idx < attn_weights.shape[1]:
            attn = attn_weights[0, head_idx, :, :int(num_tokens)].numpy()
            im = ax.imshow(attn, cmap='hot', aspect='auto')
            ax.set_title(f'Attention (Head {head_idx})')
            ax.set_xlabel('Token')
            ax.set_ylabel('Query Patch')
            plt.colorbar(im, ax=ax)
        else:
            ax.axis('off')  # fewer than three heads: hide the unused panel
    
    plt.tight_layout()
    plt.show()
    
    return model, x, cond_mask, output, info

model, x, cond_mask, output, info = demonstrate_full_forward_pass()
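
For the settings used above (`H = W = 16`, `patch_size = 4`), the query positions in the "Spatial Layout" panel can be checked by hand. The sketch below repeats the patch-center arithmetic from `forward` outside the model. `patch_centers` is a hypothetical helper introduced only for this check.

```python
import torch

def patch_centers(size: int, patch_size: int) -> torch.Tensor:
    """Normalized patch-center coordinates along one axis, in [-1, 1]."""
    n = size // patch_size
    centers_px = (torch.arange(n).float() + 0.5) * patch_size   # pixel units: 2, 6, 10, 14
    return centers_px / size * 2 - 1

centers = patch_centers(16, 4)
print(centers)             # values: -0.75, -0.25, 0.25, 0.75

# Build the (L, 2) grid in (x, y) order, row by row, as forward() does.
yy, xx = torch.meshgrid(centers, centers, indexing='ij')
query_coords = torch.stack([xx, yy], dim=-1).view(-1, 2)
print(query_coords.shape)  # torch.Size([16, 2])
```

The outermost centers sit at ±0.75 rather than ±1, so tokens near the image border can lie outside the convex hull of the query positions.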

---
## 6. Ablation Comparison <a name="6-ablation-comparison"></a>

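Let's compare the 4 ablation configurations side by side. Each model is freshly initialized and untrained, so the plots show the attention pattern each configuration starts from, not learned behavior.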

In [None]:
def compare_ablations():
    """
    Compare attention patterns across the 4 ablation configurations.
    """
    # Setup
    B, C, H, W = 1, 3, 16, 16
    patch_size = 4
    
    # Create input
    torch.manual_seed(42)
    x = torch.randn(B, C, H, W)
    cond_mask = (torch.rand(B, 1, H, W) < 0.3).float()
    
    # Create 4 model variants
    configs = [
        ('Baseline', False, False),
        ('Option A (Unified Coords)', True, False),
        ('Option B (Spatial Bias)', False, True),
        ('Option A+B (Both)', True, True),
    ]
    
    fig, axes = plt.subplots(2, 4, figsize=(16, 8))
    
    for col, (name, unified, spatial) in enumerate(configs):
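        # Re-seed so every variant is initialized from the same RNG state;
        # otherwise differences in the plots also reflect different random weights
        torch.manual_seed(0)
        # Create model with specific config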
        model = SimpleSFCEncoder(
            in_channels=C,
            hidden_size=64,
            patch_size=patch_size,
            group_size=4,
            unified_coords=unified,
            spatial_bias=spatial,
        )
        
        # Forward pass
        with torch.no_grad():
            output, info = model(x, cond_mask)
        
        attn = info['attn_weights'][0]  # (heads, L, T)
        num_tokens = int(info['token_mask'][0].sum().item())
        
        # Top row: Head 0 attention
        ax = axes[0, col]
        im = ax.imshow(attn[0, :, :num_tokens].numpy(), cmap='hot', aspect='auto')
        ax.set_title(f'{name}\nHead 0')
        ax.set_xlabel('Token')
        if col == 0:
            ax.set_ylabel('Query')
        
        # Bottom row: Mean attention across heads
        ax = axes[1, col]
        mean_attn = attn[:, :, :num_tokens].mean(dim=0).numpy()
        im = ax.imshow(mean_attn, cmap='hot', aspect='auto')
        ax.set_title('Mean across heads')
        ax.set_xlabel('Token')
        if col == 0:
            ax.set_ylabel('Query')
    
    plt.suptitle('Attention Patterns Comparison\n(Same input, different configurations)', fontsize=14)
    plt.tight_layout()
    plt.show()
    
    # Print summary
    print("\nAblation Summary:")
    print("=" * 70)
    print(f"{'Configuration':<35} {'Unified Coords':<15} {'Spatial Bias':<15}")
    print("-" * 70)
    for name, unified, spatial in configs:
        print(f"{name:<35} {str(unified):<15} {str(spatial):<15}")
    print("=" * 70)
    print("\nKey Differences:")
    print("- Baseline: Attention must learn spatial relevance from scratch")
    print("- Option A: Tokens and queries share one coordinate embedder")
    print("- Option B: A distance-based bias on the attention logits favors closer tokens")
    print("- Option A+B: Both benefits combined")

compare_ablations()
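
Heatmaps like these are hard to compare by eye. One possible scalar summary is the attention-weighted mean distance between each query and the tokens it attends to, where lower values mean more local attention. The sketch below is not part of the notebook's code: `mean_attended_distance` is a hypothetical name, and it assumes tensors with the shapes returned in `info`, with `token_mask` as a boolean tensor.

```python
import torch

def mean_attended_distance(attn_weights, query_coords, token_coords, token_mask):
    """Attention-weighted mean Euclidean query-token distance, averaged over B, heads, L.

    attn_weights: (B, heads, L, T), rows sum to 1 over valid tokens
    query_coords: (B, L, 2), token_coords: (B, T, 2), token_mask: (B, T) bool
    """
    dist = torch.cdist(query_coords, token_coords)            # (B, L, T)
    dist = dist.masked_fill(~token_mask[:, None, :], 0.0)     # ignore padded tokens
    expected = (attn_weights * dist[:, None]).sum(dim=-1)     # (B, heads, L)
    return expected.mean().item()

# Toy check: one query at the origin, two tokens at distances 1 and 2.
q = torch.tensor([[[0.0, 0.0]]])
t = torch.tensor([[[1.0, 0.0], [0.0, 2.0]]])
m = torch.tensor([[True, True]])
uniform = torch.tensor([[[[0.5, 0.5]]]])
local = torch.tensor([[[[0.9, 0.1]]]])
print(mean_attended_distance(uniform, q, t, m))  # about 1.5
print(mean_attended_distance(local, q, t, m))    # about 1.1
```

Inside `compare_ablations`, the function could be called once per configuration with `info['attn_weights']`, `info['query_coords']`, `info['token_coords']` and `info['token_mask']`. With untrained models the numbers describe behavior at initialization only.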

---
## Summary

### Architecture Flow

```
Input Image (B, C, H, W) + Sparse Mask (B, 1, H, W)
                    │
                    ▼
┌─────────────────────────────────────────────────────────────┐
│                   SFC TOKENIZER                              │
│  1. Order pixels by Hilbert/Z-order curve                   │
│  2. Select observed pixels (mask=1)                         │
│  3. Add coordinate embeddings (Option A: unified)           │
│  4. Group consecutive g pixels → flatten → project          │
│  Output: tokens (B,T,D), token_coords (B,T,2)               │
└─────────────────────────────────────────────────────────────┘
                    │
                    ▼
┌─────────────────────────────────────────────────────────────┐
│                   CROSS-ATTENTION                            │
│  Queries: patch center embeddings (B,L,D)                   │
│  Keys/Values: SFC tokens (B,T,D)                            │
│  Option B: Add spatial bias based on coordinate distance    │
│  Output: patch embeddings (B,L,D)                           │
└─────────────────────────────────────────────────────────────┘
                    │
                    ▼
              DiT Encoder Blocks
                    │
                    ▼
              Heavy Decoder (Super-Resolution)
                    │
                    ▼
              Output: Denoised/Generated Image
```
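
Step 3 of the tokenizer box is where Option A enters. The sketch below illustrates the idea under the assumption that the coordinate embedder is a small MLP on normalized `(x, y)`; the embedder in `SimpleSFCEncoder` may differ. With a shared module, a token that sits exactly at a patch center gets the same positional embedding as that patch's query, whereas two independently initialized modules give unrelated embeddings. `CoordMLP` is a hypothetical name.

```python
import torch
import torch.nn as nn

class CoordMLP(nn.Module):
    """Hypothetical coordinate embedder: (x, y) in [-1, 1] -> hidden_size vector."""
    def __init__(self, hidden_size: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2, hidden_size), nn.SiLU(), nn.Linear(hidden_size, hidden_size)
        )

    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        return self.net(coords)

torch.manual_seed(0)
hidden_size = 64
shared = CoordMLP(hidden_size)         # Option A: one embedder for both roles
token_embed = CoordMLP(hidden_size)    # baseline: separate embedder for tokens
query_embed = CoordMLP(hidden_size)    # baseline: separate embedder for queries

coord = torch.tensor([[0.25, -0.75]])  # a token located at a patch center

with torch.no_grad():
    same = torch.allclose(shared(coord), shared(coord))
    different = torch.allclose(token_embed(coord), query_embed(coord))

print(same)       # True: token and query get the identical embedding
print(different)  # False: independent initializations do not agree
```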

### Ablation Options

| Option | What it does | Why it helps |
|--------|--------------|---------------|
| **A: Unified Coords** | Shares one coordinate embedder between tokens and queries | Token and query positions live in the same embedding space, so aligning them is easier |
| **B: Spatial Bias** | Adds a distance-based bias to the attention scores | Nearby tokens are favored without the model having to learn this from data |
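
To make the Option B row concrete, here is a minimal sketch of one possible distance-based bias: a negative squared Euclidean distance scaled by a positive per-head slope. This form is an assumption chosen for illustration, and the bias used in `SimpleSFCEncoder` may differ. `distance_bias` and `slopes` are hypothetical names.

```python
import torch
import torch.nn.functional as F

def distance_bias(query_coords, token_coords, slopes):
    """Additive attention bias that decays with squared query-token distance.

    query_coords: (B, L, 2), token_coords: (B, T, 2), slopes: (heads,) positive
    Returns: (B, heads, L, T), always <= 0 and largest for the nearest token.
    """
    sq_dist = torch.cdist(query_coords, token_coords) ** 2       # (B, L, T)
    return -slopes.view(1, -1, 1, 1) * sq_dist[:, None]

query_coords = torch.tensor([[[0.0, 0.0]]])                       # one query
token_coords = torch.tensor([[[0.1, 0.0], [0.5, 0.5], [1.0, -1.0]]])
slopes = torch.tensor([1.0, 4.0])                                 # two heads

# With zero content scores, the attention is determined by the bias alone.
scores = torch.zeros(1, 2, 1, 3) + distance_bias(query_coords, token_coords, slopes)
weights = F.softmax(scores, dim=-1)
print(weights[0, :, 0])  # each row decreases with distance; head 1 is sharper
```

A larger slope concentrates attention on the nearest tokens, so per-head slopes let different heads operate at different spatial scales.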

### Running Experiments

```bash
# Baseline
python train_cifar10.py --encoder_type sfc

# Option A only
python train_cifar10.py --encoder_type sfc --sfc_unified_coords

# Option B only  
python train_cifar10.py --encoder_type sfc --sfc_spatial_bias

# Both options
python train_cifar10.py --encoder_type sfc --sfc_unified_coords --sfc_spatial_bias
```