# VQ-VAE v6 Training - Robust Residual FSQ (RFSQ)

## Changes from v5.1

| Change | v5.1 | v6 |
|--------|------|----|
| Quantization | Single-stage FSQ (8 dims) | **RFSQ: 2-stage residual FSQ (4 dims)** |
| Conditioning | None | **LayerNorm (prevents residual decay)** |
| Implicit codes | 390,625 | 390,625 (625 × 625) |
| New metrics | error_similarity | **per-stage usage, residual norms** |

## Why RFSQ?

v5.1 reached 45.6% building accuracy, and its mistakes were largely random rather than confusions between similar materials.
RFSQ offers:
1. **Multi-stage residual quantization**: Stage 1 captures coarse structure, Stage 2 refines details
2. **LayerNorm conditioning**: Prevents residual magnitude decay across stages
3. **Finer granularity**: each stage quantizes with a small 625-code codebook, aiming for higher effective code usage than one flat codebook

## Reference
- Paper: "Improving Finite Scalar Quantization via Progressive Training"
- GitHub: https://github.com/zhuxiaoxuhit/robust_rfsq

## Goals

| Metric | v5.1 Result | v6 Target |
|--------|-------------|----------|
| Building Accuracy | 45.6% | **>55%** |
| Building Recall | 84.7% | >85% |
| Rare Block Recall | ~0% | **>20%** |
| Stage 1 Usage | N/A | >30% |
| Stage 2 Usage | N/A | >30% |

## Setup - Mount Google Drive

In [None]:
# Mount Google Drive so run artifacts are written to Drive storage.
from google.colab import drive
import os

drive.mount('/content/drive')

# All v6 outputs go under this folder; create it up front so later saves do
# not fail on a missing directory.
OUTPUT_DIR = '/content/drive/MyDrive/minecraft_ai/vqvae_v6'
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"Output will be saved to: {OUTPUT_DIR}")

## Cell 1: Imports

In [None]:
import json
import random
import time
from pathlib import Path
from typing import Dict, List, Tuple, Any, Set
from collections import Counter

import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm

# Use the first CUDA device when PyTorch can see one, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

if device == "cuda":
    gpu_name = torch.cuda.get_device_name(0)
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {gpu_name}")
    print(f"Memory: {gpu_props.total_memory / 1e9:.1f} GB")

## Cell 2: Configuration

**IMPORTANT**: Update the data paths below to match your Google Drive structure!

In [None]:
# === Data Paths (UPDATE THESE FOR YOUR DRIVE) ===
DRIVE_BASE = '/content/drive/MyDrive/minecraft_ai'

DATA_DIR = f"{DRIVE_BASE}/splits/train"      # Training H5 files
VAL_DIR = f"{DRIVE_BASE}/splits/val"         # Validation H5 files
VOCAB_PATH = f"{DRIVE_BASE}/vocabulary/tok2block.json"  # Token to block mapping
V3_EMBEDDINGS_PATH = f"{DRIVE_BASE}/embeddings/block_embeddings_v3.npy"  # V3 embeddings

OUTPUT_DIR = f"{DRIVE_BASE}/vqvae_v6"        # Output directory (v6)

# === V6 RFSQ Configuration (CHANGED from v5.1) ===
# Edge length of the input voxel grids. The encoder/decoder docstrings assume
# 32^3 structures (TODO confirm against the H5 data).
INPUT_SIZE = 32
# Encoder widths. Each entry adds one stride-2 downsampling conv, so two
# entries give 32 -> 16 -> 8 (same as v5.1). The decoder uses them reversed.
HIDDEN_DIMS = [96, 192]
RFSQ_LEVELS_PER_STAGE = [5, 5, 5, 5]  # 4 dims × 5 levels = 625 codes per stage
NUM_STAGES = 2  # 2 residual stages -> 625 × 625 = 390,625 total codes
DROPOUT = 0.1

# === Structure weights (unchanged) ===
# NOTE(review): VQVAEv6.compute_loss in this notebook section accepts
# structure_weight but does not use it, and the other constants below are not
# referenced in the visible cells. Confirm whether they are still needed.
STRUCTURE_WEIGHT = 50.0
FALSE_AIR_WEIGHT = 5.0
VOLUME_WEIGHT = 2.0
STRUCTURE_TO_AIR_WEIGHT = 10.0
USE_SHAPE_LOSS = True
USE_ASYMMETRIC_LOSS = True

# === TERRAIN SETTINGS (unchanged) ===
TERRAIN_WEIGHT = 0.2  # Lower weight for terrain blocks
BUILDING_WEIGHT = 1.0  # Full weight for building blocks
AIR_WEIGHT = 0.1  # Very low weight for air

# === Training ===
TOTAL_EPOCHS = 25
BATCH_SIZE = 4
BASE_LR = 3e-4
USE_AMP = True
GRAD_ACCUM_STEPS = 4  # Effective batch = BATCH_SIZE * GRAD_ACCUM_STEPS

SEED = 42
NUM_WORKERS = 2

# === Derived sizes ===
# Spatial size of the latent grid: halved once per encoder downsampling layer.
LATENT_SIZE = INPUT_SIZE // (2 ** len(HIDDEN_DIMS))
# Implicit codebook size: product of levels per stage, raised to the number of
# residual stages.
CODES_PER_STAGE = int(np.prod(RFSQ_LEVELS_PER_STAGE))
TOTAL_IMPLICIT_CODES = CODES_PER_STAGE ** NUM_STAGES

print("VQ-VAE v6 (RFSQ) Configuration:")
print(f"  Latent grid: {LATENT_SIZE}x{LATENT_SIZE}x{LATENT_SIZE}")
print(f"  RFSQ levels per stage: {RFSQ_LEVELS_PER_STAGE}")
print(f"  Number of stages: {NUM_STAGES}")
print(f"  Codes per stage: {CODES_PER_STAGE:,}")
print(f"  Total implicit codes: {TOTAL_IMPLICIT_CODES:,}")
print(f"  Hidden dims: {HIDDEN_DIMS}")
print(f"  Epochs: {TOTAL_EPOCHS}")
print("\nTerrain-Aware Training:")
print(f"  Terrain weight: {TERRAIN_WEIGHT}")
print(f"  Building weight: {BUILDING_WEIGHT}")
print(f"  Air weight: {AIR_WEIGHT}")
print("\nKEY CHANGE: RFSQ with LayerNorm conditioning!")

## Cell 3: Load Vocabulary and Embeddings

In [None]:
# Token id -> block-state string (JSON keys are strings, so convert to int).
with open(VOCAB_PATH, 'r') as f:
    tok2block = {int(k): v for k, v in json.load(f).items()}

block2tok = {v: k for k, v in tok2block.items()}
VOCAB_SIZE = len(tok2block)
print(f"Vocabulary size: {VOCAB_SIZE}")

# Find air tokens. "stair" also contains "air", so stairs are excluded explicitly.
AIR_TOKENS: Set[int] = set()
for tok, block in tok2block.items():
    lowered = block.lower()
    if 'air' in lowered and 'stair' not in lowered:
        AIR_TOKENS.add(tok)
        print(f"  Air token: {tok} = {block}")

AIR_TOKENS_LIST = sorted(AIR_TOKENS)
AIR_TOKENS_TENSOR = torch.tensor(AIR_TOKENS_LIST, dtype=torch.long)

# Load V3 embeddings: a 2-D table with one row per vocabulary token.
v3_embeddings = np.load(V3_EMBEDDINGS_PATH).astype(np.float32)
# Fail fast with a clear message instead of a later, less specific error (or a
# silent broadcast) when the file does not match the vocabulary.
if v3_embeddings.ndim != 2 or v3_embeddings.shape[0] != VOCAB_SIZE:
    raise ValueError(
        f"Expected V3 embeddings of shape ({VOCAB_SIZE}, dim) to match the "
        f"vocabulary, got {v3_embeddings.shape} from {V3_EMBEDDINGS_PATH}"
    )
EMBEDDING_DIM = v3_embeddings.shape[1]
print(f"V3 embeddings: {v3_embeddings.shape} (dim={EMBEDDING_DIM})")

## Cell 4: Terrain Detection

In [None]:
# Base block ids (no "[...]" block-state suffix) treated as natural terrain.
TERRAIN_BLOCKS: Set[str] = {
    # Dirt family
    'minecraft:dirt', 'minecraft:grass_block', 'minecraft:coarse_dirt',
    'minecraft:podzol', 'minecraft:mycelium', 'minecraft:rooted_dirt',
    'minecraft:dirt_path', 'minecraft:farmland', 'minecraft:mud',

    # Stone family
    'minecraft:stone', 'minecraft:cobblestone', 'minecraft:mossy_cobblestone',
    'minecraft:bedrock', 'minecraft:deepslate', 'minecraft:tuff',
    'minecraft:granite', 'minecraft:diorite', 'minecraft:andesite',

    # Sand family
    'minecraft:sand', 'minecraft:red_sand', 'minecraft:gravel', 'minecraft:clay',

    # Fluids
    'minecraft:water', 'minecraft:lava',

    # Terracotta
    'minecraft:terracotta', 'minecraft:white_terracotta', 'minecraft:orange_terracotta',
    'minecraft:brown_terracotta', 'minecraft:red_terracotta',

    # Nether
    'minecraft:netherrack', 'minecraft:soul_sand', 'minecraft:soul_soil',

    # End
    'minecraft:end_stone',

    # Snow/Ice
    'minecraft:snow_block', 'minecraft:ice', 'minecraft:packed_ice',
}

# Tokens whose base block id is terrain. split('[', 1)[0] drops any
# block-state suffix and returns the whole string when there is none.
TERRAIN_TOKENS: Set[int] = {
    tok
    for tok, block in tok2block.items()
    if block.split('[', 1)[0] in TERRAIN_BLOCKS
}

TERRAIN_TOKENS_TENSOR = torch.tensor(sorted(TERRAIN_TOKENS), dtype=torch.long)
print(f"Terrain tokens: {len(TERRAIN_TOKENS)}")

def detect_terrain(block_ids: torch.Tensor, device: torch.device) -> torch.Tensor:
    """Flag voxels that hold a terrain token which is not also an air token.

    Args:
        block_ids: Integer tensor of block tokens (any shape).
        device: Device on which ``block_ids`` lives; the module-level token
            tables are moved there before comparison.

    Returns:
        Boolean tensor shaped like ``block_ids``.
    """
    not_air = ~torch.isin(block_ids, AIR_TOKENS_TENSOR.to(device))
    return torch.isin(block_ids, TERRAIN_TOKENS_TENSOR.to(device)) & not_air

## Cell 5: Dataset

In [None]:
class VQVAEDataset(Dataset):
    """Map-style dataset with one sample per ``*.h5`` file in a directory.

    Each sample is the first dataset stored in its file, converted to an
    int64 tensor (presumably a voxel grid of block tokens; TODO confirm).
    """

    def __init__(self, data_dir: str):
        self.data_dir = Path(data_dir)
        self.h5_files = sorted(self.data_dir.glob("*.h5"))
        if len(self.h5_files) == 0:
            raise ValueError(f"No H5 files in {data_dir}")
        print(f"Found {len(self.h5_files)} structures in {data_dir}")

    def __len__(self):
        return len(self.h5_files)

    def __getitem__(self, idx):
        path = self.h5_files[idx]
        # Open per access so no file handle outlives the read.
        with h5py.File(path, 'r') as h5:
            first_key = list(h5.keys())[0]
            voxels = h5[first_key][:].astype(np.int64)
        return torch.from_numpy(voxels).long()

# Index both splits; train is built first, then validation.
train_dataset, val_dataset = VQVAEDataset(DATA_DIR), VQVAEDataset(VAL_DIR)

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")

## Cell 6: FSQ Base Module

In [None]:
class FSQ(nn.Module):
    """
    Finite Scalar Quantization - base module used by each RFSQ stage.

    Each of the ``len(levels)`` latent dims is squashed with tanh and rounded
    onto ``levels[i]`` evenly spaced values in [-1, 1]. The flat code index is
    the mixed-radix number formed by the per-dim level indices (first dim is
    most significant). Code usage is accumulated in ``_usage`` for metrics.
    """

    def __init__(self, levels: List[int], eps: float = 1e-3):
        """
        Args:
            levels: Number of quantization levels for each latent dim.
            eps: Stored for API compatibility; not used by the current forward.
        """
        super().__init__()
        self.levels = levels
        self.dim = len(levels)
        self.eps = eps
        self.codebook_size = int(np.prod(levels))

        self.register_buffer('_levels', torch.tensor(levels, dtype=torch.float32))

        # Mixed-radix place values: basis[i] = prod(levels[i+1:]).
        basis = []
        acc = 1
        for L in reversed(levels):
            basis.append(acc)
            acc *= L
        self.register_buffer('_basis', torch.tensor(list(reversed(basis)), dtype=torch.long))

        half_levels = [(L - 1) / 2 for L in levels]
        self.register_buffer('_half_levels', torch.tensor(half_levels, dtype=torch.float32))

        # Track usage for metrics (count of assignments per flat code).
        self.register_buffer('_usage', torch.zeros(self.codebook_size))

    def reset_usage(self):
        self._usage.zero_()

    def get_usage_stats(self) -> Tuple[float, float]:
        """Return (usage_fraction, perplexity).

        usage_fraction is the share of codes hit at least once since the last
        reset; perplexity is exp(entropy) of the empirical code distribution
        (0.0 when nothing has been recorded).
        """
        usage = (self._usage > 0).float().mean().item()

        if self._usage.sum() == 0:
            return usage, 0.0

        probs = self._usage / self._usage.sum()
        probs = probs[probs > 0]
        entropy = -(probs * probs.log()).sum()
        perplexity = entropy.exp().item()

        return usage, perplexity

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Quantize latent vectors.

        Args:
            z: Shape [..., dim] continuous latent vectors

        Returns:
            z_q: Quantized vectors, same shape, values in [-1, 1]
                (straight-through gradient w.r.t. tanh(z))
            indices: Integer indices [...], each in [0, codebook_size)
        """
        # Bound to (-1, 1)
        z_bounded = torch.tanh(z)

        # Per-dim constants, cast so fp16 inputs (AMP) stay in fp16.
        half = self._half_levels.to(z_bounded.dtype)
        # Odd level counts sit on an integer grid {-h..h}. Even counts need a
        # half-integer grid (e.g. L=4 -> {-1.5, -0.5, 0.5, 1.5}); rounding
        # directly would reach only 3 values and make indices collide.
        offset = ((1 - torch.remainder(self._levels, 2)) * 0.5).to(z_bounded.dtype)

        # Quantize all dims at once (offset is 0 for odd levels).
        z_grid = torch.round(z_bounded * half - offset) + offset
        z_grid = torch.maximum(torch.minimum(z_grid, half), -half)
        z_q = z_grid / half

        # Straight-through estimator
        z_q = z_bounded + (z_q - z_bounded).detach()

        # Per-dim level index in [0, L-1], then flatten with the mixed-radix basis.
        level_idx = torch.round(z_q * half + half).long()
        level_idx = torch.minimum(level_idx.clamp(min=0), self._levels.long() - 1)
        indices = (level_idx * self._basis).sum(dim=-1)

        # Track usage with a single bincount instead of one host sync per code.
        with torch.no_grad():
            counts = torch.bincount(indices.reshape(-1), minlength=self.codebook_size)
            self._usage += counts[: self.codebook_size].to(self._usage.dtype)

        return z_q, indices

print("FSQ base module defined")

## Cell 7: RFSQ Module (Robust Residual FSQ)

In [None]:
class InvertibleLayerNorm(nn.Module):
    """
    Normalizer whose transform can be undone with the statistics it cached.

    ``forward`` standardizes each channel of each sample over the three spatial
    axes (dims 1, 2, 3 of a channels-last tensor), applies a learnable
    per-channel affine, and remembers the mean/std. ``inverse`` replays those
    statistics to map values back to the input scale. RFSQ uses this to
    quantize residuals at a normalized scale; per the RFSQ design notes, this
    keeps residual magnitudes from decaying across stages.
    """

    def __init__(self, num_features: int, eps: float = 1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps

        # Per-channel affine applied after standardization.
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

        # Filled by forward(), read by inverse().
        self.stored_mean = None
        self.stored_std = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Standardize ``x`` and cache the statistics used.

        Args:
            x: Input tensor [B, X, Y, Z, C] (channels last)
        """
        spatial_dims = (1, 2, 3)
        mu = x.mean(dim=spatial_dims, keepdim=True)
        # Unbiased std (torch default), offset by eps to avoid dividing by 0.
        sigma = x.std(dim=spatial_dims, keepdim=True) + self.eps
        self.stored_mean, self.stored_std = mu, sigma
        return (x - mu) / sigma * self.weight + self.bias

    def inverse(self, x_norm: torch.Tensor) -> torch.Tensor:
        """Undo the affine, then the standardization, using cached statistics."""
        if self.stored_mean is None:
            raise RuntimeError("Must call forward() before inverse()")
        standardized = (x_norm - self.bias) / self.weight
        return standardized * self.stored_std + self.stored_mean


class RFSQStage(nn.Module):
    """One residual step: normalize the residual, FSQ-quantize, de-normalize."""

    def __init__(self, levels: List[int]):
        super().__init__()
        self.levels = levels
        self.fsq = FSQ(levels)
        self.layernorm = InvertibleLayerNorm(len(levels))

    @property
    def codebook_size(self) -> int:
        return self.fsq.codebook_size

    def reset_usage(self):
        self.fsq.reset_usage()

    def get_usage_stats(self) -> Tuple[float, float]:
        return self.fsq.get_usage_stats()

    def forward(self, residual: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Quantize ``residual`` at a normalized scale.

        Returns:
            z_q: Reconstruction mapped back to the residual's original scale
            new_residual: residual - z_q (input for the next stage)
            indices: FSQ code indices for this stage
        """
        normalized = self.layernorm(residual)
        quantized_normalized, codes = self.fsq(normalized)
        reconstruction = self.layernorm.inverse(quantized_normalized)
        return reconstruction, residual - reconstruction, codes


class RFSQ(nn.Module):
    """Robust Residual FSQ with multiple stages.

    Stage k quantizes the residual left over by stages 0..k-1; the output is
    the sum of every stage's reconstruction. The implicit codebook size is
    ``prod(levels_per_stage) ** num_stages``.
    """

    def __init__(self, levels_per_stage: List[int], num_stages: int = 2):
        super().__init__()
        if num_stages < 1:
            raise ValueError(f"num_stages must be >= 1, got {num_stages}")
        self.levels_per_stage = levels_per_stage
        self.num_stages = num_stages
        self.dim = len(levels_per_stage)

        self.stages = nn.ModuleList([
            RFSQStage(levels_per_stage) for _ in range(num_stages)
        ])

        codes_per_stage = int(np.prod(levels_per_stage))
        self.codebook_size = codes_per_stage ** num_stages
        self.codes_per_stage = codes_per_stage

    def reset_usage(self):
        for stage in self.stages:
            stage.reset_usage()

    def get_usage_stats(self) -> Dict[str, Tuple[float, float]]:
        """Return per-stage (usage, perplexity), keyed 'stage0', 'stage1', ..."""
        return {f'stage{i}': stage.get_usage_stats()
                for i, stage in enumerate(self.stages)}

    def _quantize(self, z: torch.Tensor, track_norms: bool
                  ) -> Tuple[torch.Tensor, List[torch.Tensor], List[float]]:
        """Run all stages; optionally record the residual norm before each stage and after the last."""
        residual = z
        z_q_sum = torch.zeros_like(z)
        all_indices = []
        residual_norms = []

        for stage in self.stages:
            if track_norms:
                # .item() forces a host sync, so only do it when requested.
                residual_norms.append(residual.norm().item())
            z_q, residual, indices = stage(residual)
            z_q_sum = z_q_sum + z_q
            all_indices.append(indices)

        if track_norms:
            residual_norms.append(residual.norm().item())  # Final residual
        return z_q_sum, all_indices, residual_norms

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Multi-stage residual quantization.

        Args:
            z: Encoder output [B, X, Y, Z, C]

        Returns:
            z_q: Quantized sum of all stages
            all_indices: List of indices from each stage
        """
        z_q_sum, all_indices, _ = self._quantize(z, track_norms=False)
        return z_q_sum, all_indices

    def forward_with_norms(self, z: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor], List[float]]:
        """Forward that also returns residual norms (num_stages + 1 values) for monitoring."""
        return self._quantize(z, track_norms=True)


print("RFSQ module defined:", f"{NUM_STAGES} stages × {CODES_PER_STAGE:,} codes = {TOTAL_IMPLICIT_CODES:,} total")

## Cell 8: VQ-VAE v6 Architecture with RFSQ

In [None]:
class ResidualBlock3D(nn.Module):
    """Two 3x3x3 conv + BatchNorm layers with an identity skip, then ReLU.

    Padding 1 keeps the spatial size, so input and output shapes match.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.conv1 = nn.Conv3d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv3d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm3d(channels)
        self.bn2 = nn.BatchNorm3d(channels)

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + x)


class EncoderV6(nn.Module):
    """32x32x32 -> 8x8x8 encoder, outputs RFSQ dimension (4 instead of 8).

    Each entry of ``hidden_dims`` adds one stride-2 conv, halving the grid, so
    two entries map 32^3 to 8^3. Output channels = ``rfsq_dim``.
    """

    def __init__(self, in_channels: int, hidden_dims: List[int], rfsq_dim: int, dropout: float = 0.1):
        super().__init__()
        modules: List[nn.Module] = []
        channels_in = in_channels

        for width in hidden_dims:
            modules += [
                nn.Conv3d(channels_in, width, 4, stride=2, padding=1),
                nn.BatchNorm3d(width),
                nn.ReLU(inplace=True),
                nn.Dropout3d(dropout),
                ResidualBlock3D(width),
            ]
            channels_in = width

        # Two extra residual blocks, then project to the quantizer's channel count.
        modules += [
            ResidualBlock3D(channels_in),
            ResidualBlock3D(channels_in),
            nn.Conv3d(channels_in, rfsq_dim, 3, padding=1),
        ]

        self.encoder = nn.Sequential(*modules)

    def forward(self, x):
        return self.encoder(x)


class DecoderV6(nn.Module):
    """8x8x8 -> 32x32x32 decoder, takes RFSQ dimension as input.

    There is one stride-2 transposed conv per entry of ``hidden_dims[1:]``,
    plus a final one. With two hidden dims this doubles the grid twice
    (8 -> 16 -> 32). Output channels = ``num_blocks`` (per-voxel logits).
    """

    def __init__(self, rfsq_dim: int, hidden_dims: List[int], num_blocks: int, dropout: float = 0.1):
        super().__init__()
        width = hidden_dims[0]

        # Lift the latent channels to the first hidden width.
        modules: List[nn.Module] = [
            nn.Conv3d(rfsq_dim, width, 3, padding=1),
            nn.BatchNorm3d(width),
            nn.ReLU(inplace=True),
            ResidualBlock3D(width),
            ResidualBlock3D(width),
        ]

        for next_width in hidden_dims[1:]:
            modules += [
                ResidualBlock3D(width),
                nn.ConvTranspose3d(width, next_width, 4, stride=2, padding=1),
                nn.BatchNorm3d(next_width),
                nn.ReLU(inplace=True),
                nn.Dropout3d(dropout),
            ]
            width = next_width

        # Final upsample, then a 3x3x3 conv to per-voxel class logits.
        modules += [
            ResidualBlock3D(width),
            nn.ConvTranspose3d(width, width, 4, stride=2, padding=1),
            nn.BatchNorm3d(width),
            nn.ReLU(inplace=True),
            nn.Conv3d(width, num_blocks, 3, padding=1),
        ]

        self.decoder = nn.Sequential(*modules)

    def forward(self, z_q):
        return self.decoder(z_q)


class TerrainWeightedLoss(nn.Module):
    """Cross-entropy with lower weight for terrain blocks.

    Each voxel's CE term is weighted by its class: air, terrain or building
    (everything else). The result is the weighted mean.
    """

    def __init__(self, terrain_weight: float = 0.2, building_weight: float = 1.0, air_weight: float = 0.1):
        super().__init__()
        self.terrain_weight = terrain_weight
        self.building_weight = building_weight
        self.air_weight = air_weight

    def forward(self, logits: torch.Tensor, targets: torch.Tensor,
                terrain_mask: torch.Tensor, air_mask: torch.Tensor) -> torch.Tensor:
        """
        Args:
            logits: [N, num_classes] unnormalized scores.
            targets: [N] class indices.
            terrain_mask, air_mask: [N] boolean masks. Where both are True,
                the air weight wins because it is assigned last.

        Returns:
            Scalar weighted-mean cross-entropy. This is 0 (not NaN) when every
            selected weight is 0, e.g. ``air_weight=0`` on an all-air batch.
        """
        ce_loss = F.cross_entropy(logits, targets, reduction='none')

        weights = torch.full_like(ce_loss, self.building_weight)
        weights[terrain_mask] = self.terrain_weight
        weights[air_mask] = self.air_weight

        # Guard the denominator so zero total weight yields 0 instead of 0/0.
        denom = weights.sum().clamp_min(torch.finfo(weights.dtype).eps)
        return (ce_loss * weights).sum() / denom


class VQVAEv6(nn.Module):
    """VQ-VAE v6 with Robust Residual FSQ.

    Pipeline: frozen block embedding -> EncoderV6 -> RFSQ -> DecoderV6 ->
    per-voxel logits over the block vocabulary.
    """

    def __init__(self, vocab_size: int, emb_dim: int, hidden_dims: List[int],
                 rfsq_levels: List[int], num_stages: int, pretrained_emb: np.ndarray,
                 terrain_weight: float = 0.2, building_weight: float = 1.0,
                 air_weight: float = 0.1, dropout: float = 0.1):
        """
        Args:
            vocab_size: Number of block tokens (rows of the embedding table).
            emb_dim: Embedding width (encoder input channels).
            hidden_dims: Encoder widths; the decoder uses them reversed.
            rfsq_levels: FSQ levels per latent channel (its length = latent channels).
            num_stages: Number of residual quantization stages.
            pretrained_emb: Array copied into the frozen embedding table; expected
                shape (vocab_size, emb_dim).
            terrain_weight, building_weight, air_weight: TerrainWeightedLoss weights.
            dropout: Dropout3d probability used in encoder and decoder.
        """
        super().__init__()

        self.vocab_size = vocab_size
        self.emb_dim = emb_dim
        self.rfsq_dim = len(rfsq_levels)
        self.num_stages = num_stages

        # Embeddings (frozen)
        self.block_emb = nn.Embedding(vocab_size, emb_dim)
        self.block_emb.weight.data.copy_(torch.from_numpy(pretrained_emb))
        self.block_emb.weight.requires_grad = False

        # Encoder: 32x32x32 -> 8x8x8 x rfsq_dim
        self.encoder = EncoderV6(emb_dim, hidden_dims, self.rfsq_dim, dropout)

        # RFSQ: multi-stage residual quantization
        self.rfsq = RFSQ(rfsq_levels, num_stages)

        # Decoder: 8x8x8 x rfsq_dim -> 32x32x32 x vocab_size
        self.decoder = DecoderV6(self.rfsq_dim, list(reversed(hidden_dims)), vocab_size, dropout)

        # Loss
        self.terrain_loss = TerrainWeightedLoss(terrain_weight, building_weight, air_weight)

    def forward(self, block_ids: torch.Tensor, return_norms: bool = False) -> Dict[str, Any]:
        """Encode, quantize and decode a batch of block-token grids.

        Args:
            block_ids: [B, X, Y, Z] integer block tokens.
            return_norms: Also report per-stage residual norms (forces host syncs).

        Returns:
            Dict with 'logits' [B, vocab_size, X, Y, Z], 'all_indices' (one
            index tensor per RFSQ stage), 'z_e' (channels-LAST [B, x, y, z, C]),
            'z_q' (channels-FIRST [B, C, x, y, z]) and optionally
            'residual_norms'. Note that z_e and z_q use different layouts.
        """
        # Embed blocks
        x = self.block_emb(block_ids)  # [B, X, Y, Z, emb_dim]
        x = x.permute(0, 4, 1, 2, 3).contiguous()  # [B, emb_dim, X, Y, Z]

        # Encode
        z_e = self.encoder(x)  # [B, rfsq_dim, 8, 8, 8]

        # Permute for RFSQ: [B, 8, 8, 8, rfsq_dim]
        z_e = z_e.permute(0, 2, 3, 4, 1).contiguous()

        # Quantize with RFSQ
        if return_norms:
            z_q, all_indices, residual_norms = self.rfsq.forward_with_norms(z_e)
        else:
            z_q, all_indices = self.rfsq(z_e)
            residual_norms = None

        # Permute back: [B, rfsq_dim, 8, 8, 8]
        z_q = z_q.permute(0, 4, 1, 2, 3).contiguous()

        # Decode
        logits = self.decoder(z_q)  # [B, vocab_size, X, Y, Z]

        result = {
            'logits': logits,
            'all_indices': all_indices,
            'z_e': z_e,
            'z_q': z_q,
        }
        if residual_norms is not None:
            result['residual_norms'] = residual_norms
        return result

    @staticmethod
    def _masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Mean of ``values[mask]``, or 0.0 when the mask selects nothing."""
        if mask.any():
            return values[mask].mean()
        return torch.tensor(0.0, device=values.device)

    @staticmethod
    def _masked_rate(flags: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Fraction of ``mask`` positions where ``flags`` is also True (0.0 if mask is empty)."""
        if mask.any():
            return (mask & flags).sum().float() / mask.sum().float()
        return torch.tensor(0.0, device=mask.device)

    def _error_similarity(self, preds: torch.Tensor, targets: torch.Tensor,
                          wrong: torch.Tensor) -> torch.Tensor:
        """Mean cosine similarity between predicted and true block embeddings on ``wrong`` voxels."""
        if not wrong.any():
            return torch.tensor(0.0, device=preds.device)
        emb = self.block_emb.weight
        return F.cosine_similarity(emb[preds[wrong]], emb[targets[wrong]], dim=-1).mean()

    def compute_loss(self, block_ids: torch.Tensor,
                     air_tokens: torch.Tensor,
                     terrain_tokens: torch.Tensor,
                     structure_weight: float = 50.0,
                     return_norms: bool = False) -> Dict[str, Any]:
        """Compute loss with terrain-aware weighting, plus accuracy metrics.

        Args:
            block_ids: [B, X, Y, Z] integer block tokens (also the targets).
            air_tokens, terrain_tokens: 1-D token-id tensors defining the classes.
            structure_weight: Accepted for API compatibility; not used by this loss.
            return_norms: Forwarded to ``forward``; adds 'residual_norms' to the result.

        Returns:
            Dict with 'loss' (differentiable) and detached scalar metrics.
            Metrics over an empty voxel class are reported as 0.0.
        """
        out = self(block_ids, return_norms=return_norms)
        logits = out['logits']

        # Flatten to one row per voxel: [N, vocab] logits and [N] targets.
        num_classes = logits.shape[1]
        logits_flat = logits.permute(0, 2, 3, 4, 1).reshape(-1, num_classes)
        targets_flat = block_ids.reshape(-1)

        device = targets_flat.device
        air_dev = air_tokens.to(device)
        terrain_dev = terrain_tokens.to(device)

        # Disjoint voxel classes: air, terrain (non-air), building (the rest).
        is_air = torch.isin(targets_flat, air_dev)
        is_terrain = torch.isin(targets_flat, terrain_dev) & ~is_air
        is_building = ~is_air & ~is_terrain
        is_struct = ~is_air

        # Primary loss
        loss = self.terrain_loss(logits_flat, targets_flat, is_terrain, is_air)

        # Metrics
        with torch.no_grad():
            preds = logits_flat.argmax(dim=1)
            is_air_pred = torch.isin(preds, air_dev)
            is_wrong = preds != targets_flat
            correct = (~is_wrong).float()

            # Volume ratio: predicted non-air voxels / true non-air voxels.
            orig_vol = is_struct.sum().float()
            pred_vol = (~is_air_pred).sum().float()
            vol_ratio = pred_vol / orig_vol if orig_vol > 0 else torch.tensor(1.0, device=device)

            metrics = {
                'overall_acc': correct.mean(),
                'terrain_acc': self._masked_mean(correct, is_terrain),
                'building_acc': self._masked_mean(correct, is_building),
                # Building voxels predicted as any non-air block.
                'building_recall': self._masked_rate(~is_air_pred, is_building),
                # Building voxels erased to air.
                'building_false_air': self._masked_rate(is_air_pred, is_building),
                'struct_acc': self._masked_mean(correct, is_struct),
                'struct_recall': self._masked_rate(~is_air_pred, is_struct),
                'vol_ratio': vol_ratio,
                # How "close" wrong predictions are in embedding space.
                'error_similarity': self._error_similarity(
                    preds, targets_flat, is_building & is_wrong),
                'terrain_error_similarity': self._error_similarity(
                    preds, targets_flat, is_terrain & is_wrong),
            }

        result = {'loss': loss, **metrics}
        if 'residual_norms' in out:
            result['residual_norms'] = out['residual_norms']
        return result


print("VQ-VAE v6 architecture defined!", f"RFSQ: {NUM_STAGES} stages with LayerNorm conditioning", sep="\n")

## Cell 9: Training Functions

In [None]:
def train_epoch(model, loader, optimizer, scaler, device,
                air_tokens, terrain_tokens, structure_weight):
    model.train()
    model.rfsq.reset_usage()

    metrics = {k: 0.0 for k in [
        'loss', 'overall_acc', 'terrain_acc', 'building_acc',
        'building_recall', 'building_false_air', 'struct_acc',
        'struct_recall', 'vol_ratio', 'error_similarity',
        'terrain_error_similarity',
    ]}
    grad_norms = []
    all_residual_norms = []  # Track residual decay
    n = 0
    optimizer.zero_grad()

    for batch_idx, batch in enumerate(tqdm(loader, desc="Train", leave=False)):
        batch = batch.to(device)

        with torch.amp.autocast('cuda', enabled=USE_AMP):
            # Get residual norms every 100 batches
            return_norms = (batch_idx % 100 == 0)
            out = model.compute_loss(batch, air_tokens, terrain_tokens, 
                                    structure_weight, return_norms=return_norms)
            loss = out['loss'] / GRAD_ACCUM_STEPS

        if return_norms and 'residual_norms' in out:
            all_residual_norms.append(out['residual_norms'])

        scaler.scale(loss).backward()

        if (batch_idx + 1) % GRAD_ACCUM_STEPS == 0:
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            grad_norms.append(grad_norm.item())
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        for k in metrics:
            if k in out:
                metrics[k] += out[k].item() if torch.is_tensor(out[k]) else out[k]
        n += 1

    # Per-stage RFSQ usage
    stage_stats = model.rfsq.get_usage_stats()
    for stage_name, (usage, perp) in stage_stats.items():
        metrics[f'{stage_name}_usage'] = usage
        metrics[f'{stage_name}_perplexity'] = perp

    metrics['grad_norm'] = sum(grad_norms) / len(grad_norms) if grad_norms else 0.0

    # Compute average residual decay
    if all_residual_norms:
        avg_norms = np.mean(all_residual_norms, axis=0)
        if len(avg_norms) >= 2 and avg_norms[0] > 0:
            metrics['residual_decay'] = avg_norms[-1] / avg_norms[0]  # Final/Initial
        else:
            metrics['residual_decay'] = 1.0
    else:
        metrics['residual_decay'] = 1.0

    return {k: v/n if k not in ['stage0_usage', 'stage0_perplexity', 
                                 'stage1_usage', 'stage1_perplexity',
                                 'grad_norm', 'residual_decay'] else v 
            for k, v in metrics.items()}


def compute_residual_decay(residual_norms):
    """Return the ratio of the last to the first averaged residual norm.

    ``residual_norms`` is a list of per-batch sequences taken from
    ``out['residual_norms']`` (presumably one norm per RFSQ stage -- confirm
    against ``compute_loss``). They are averaged element-wise across batches
    with ``np.mean(..., axis=0)``, then ``avg[-1] / avg[0]`` is returned.

    Returns 1.0 when nothing was collected, fewer than two entries remain
    after averaging, or the first averaged norm is not positive.
    """
    if not residual_norms:
        return 1.0
    avg_norms = np.mean(residual_norms, axis=0)
    if len(avg_norms) >= 2 and avg_norms[0] > 0:
        return avg_norms[-1] / avg_norms[0]  # Final/Initial
    return 1.0


@torch.no_grad()
def validate(model, loader, device, air_tokens, terrain_tokens, structure_weight):
    """Evaluate ``model`` on ``loader`` and return epoch-level metrics.

    Batch-level metrics reported by ``model.compute_loss`` are averaged over
    the number of batches. Per-stage RFSQ usage/perplexity (from
    ``model.rfsq.get_usage_stats()``) and ``residual_decay`` are already
    epoch-level and are returned as-is, for however many stages the model has.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    model.rfsq.reset_usage()

    sums = {k: 0.0 for k in [
        'loss', 'overall_acc', 'terrain_acc', 'building_acc',
        'building_recall', 'building_false_air', 'struct_acc',
        'struct_recall', 'vol_ratio', 'error_similarity',
        'terrain_error_similarity',
    ]}
    all_residual_norms = []
    n = 0

    for batch_idx, batch in enumerate(tqdm(loader, desc="Val", leave=False)):
        batch = batch.to(device)
        # Residual norms are only requested on a subset of batches (every 50th).
        return_norms = (batch_idx % 50 == 0)

        with torch.amp.autocast('cuda', enabled=USE_AMP):
            out = model.compute_loss(batch, air_tokens, terrain_tokens,
                                     structure_weight, return_norms=return_norms)

        if return_norms and 'residual_norms' in out:
            all_residual_norms.append(out['residual_norms'])

        for k in sums:
            if k in out:
                sums[k] += out[k].item() if torch.is_tensor(out[k]) else out[k]
        n += 1

    if n == 0:
        raise ValueError("validate() received an empty loader")

    # Average the batch-summed metrics first; stage stats and residual decay
    # are epoch-level and must not be divided by the batch count (this also
    # keeps any stage beyond stage0/stage1 correct).
    metrics = {k: v / n for k, v in sums.items()}

    for stage_name, (usage, perp) in model.rfsq.get_usage_stats().items():
        metrics[f'{stage_name}_usage'] = usage
        metrics[f'{stage_name}_perplexity'] = perp

    metrics['residual_decay'] = compute_residual_decay(all_residual_norms)
    return metrics


# Announce that the training/validation helpers are ready, one message per line.
print("Training functions defined!",
      "NEW metrics: per-stage usage/perplexity, residual_decay",
      sep="\n")

## Cell 10: Create Model and Optimizer

In [None]:
# Seed PyTorch, NumPy and Python's `random` (in that order) for reproducibility.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.empty_cache()

# Data loaders share batching/worker settings; only the training loader
# shuffles, driven by its own seeded generator.
common_loader_args = {
    'batch_size': BATCH_SIZE,
    'num_workers': NUM_WORKERS,
    'pin_memory': True,
}
g = torch.Generator().manual_seed(SEED)
train_loader = DataLoader(train_dataset, shuffle=True, generator=g, **common_loader_args)
val_loader = DataLoader(val_dataset, shuffle=False, **common_loader_args)

# Build the RFSQ VQ-VAE from the notebook-level config and move it to `device`.
model_config = {
    'vocab_size': VOCAB_SIZE,
    'emb_dim': EMBEDDING_DIM,
    'hidden_dims': HIDDEN_DIMS,
    'rfsq_levels': RFSQ_LEVELS_PER_STAGE,
    'num_stages': NUM_STAGES,
    'pretrained_emb': v3_embeddings,
    'terrain_weight': TERRAIN_WEIGHT,
    'building_weight': BUILDING_WEIGHT,
    'air_weight': AIR_WEIGHT,
    'dropout': DROPOUT,
}
model = VQVAEv6(**model_config).to(device)

# Parameter bookkeeping: the trainable subset is what the optimizer receives.
all_params = list(model.parameters())
trainable = [p for p in all_params if p.requires_grad]
total_params = sum(p.numel() for p in all_params)
trainable_params = sum(p.numel() for p in trainable)
print(f"Total params: {total_params:,}")
print(f"Trainable params: {trainable_params:,}")
print(f"RFSQ stages: {model.num_stages}")
print(f"RFSQ total codes: {model.rfsq.codebook_size:,}")

# AdamW over trainable parameters only, plus a (possibly disabled) AMP scaler.
optimizer = optim.AdamW(trainable, lr=BASE_LR, weight_decay=1e-5)
scaler = torch.amp.GradScaler('cuda', enabled=USE_AMP)

print(f"\nOptimizer: AdamW, LR={BASE_LR}")

## Cell 11: Training Loop

In [None]:
print("="*70)
print("VQ-VAE V6 TRAINING - ROBUST RESIDUAL FSQ (RFSQ)")
print("="*70)
print(f"Key changes from v5.1:")
print(f"  - RFSQ: {NUM_STAGES}-stage residual quantization")
print(f"  - LayerNorm conditioning (prevents residual decay)")
print(f"  - Per-stage metrics tracking")
print()

# Every key is '<split>_<metric>' with split in {'train', 'val'}; the loop
# below fills each series from the matching epoch metrics dict.
history = {
    # Core metrics
    'train_loss': [], 'train_building_acc': [], 'train_building_recall': [],
    'train_terrain_acc': [], 'train_struct_recall': [],
    'val_loss': [], 'val_building_acc': [], 'val_building_recall': [],
    'val_terrain_acc': [], 'val_struct_recall': [],
    # Per-stage RFSQ metrics
    'train_stage0_usage': [], 'train_stage0_perplexity': [],
    'train_stage1_usage': [], 'train_stage1_perplexity': [],
    'val_stage0_usage': [], 'val_stage0_perplexity': [],
    'val_stage1_usage': [], 'val_stage1_perplexity': [],
    # Diagnostic metrics
    'train_building_false_air': [], 'val_building_false_air': [],
    'train_vol_ratio': [], 'val_vol_ratio': [],
    'train_error_similarity': [], 'val_error_similarity': [],
    'train_terrain_error_similarity': [], 'val_terrain_error_similarity': [],
    'train_grad_norm': [],
    # NEW: Residual decay tracking
    'train_residual_decay': [], 'val_residual_decay': [],
}


def record_epoch_metrics(history, train_m, val_m):
    """Append one epoch's value to every series in ``history``.

    The split prefix of each key selects ``train_m`` or ``val_m``. Per-stage
    RFSQ metrics default to 0 and ``residual_decay`` to 1.0 when absent; any
    other missing metric raises KeyError. Recording every key keeps all
    series the same length for plotting.
    """
    for hist_key, series in history.items():
        split, metric = hist_key.split('_', 1)
        source = train_m if split == 'train' else val_m
        if metric == 'residual_decay':
            series.append(source.get(metric, 1.0))
        elif metric.startswith('stage'):
            series.append(source.get(metric, 0))
        else:
            series.append(source[metric])


best_building_acc = 0
best_epoch = 0
start_time = time.time()

for epoch in range(TOTAL_EPOCHS):
    train_m = train_epoch(model, train_loader, optimizer, scaler, device,
                          AIR_TOKENS_TENSOR, TERRAIN_TOKENS_TENSOR, STRUCTURE_WEIGHT)

    val_m = validate(model, val_loader, device,
                     AIR_TOKENS_TENSOR, TERRAIN_TOKENS_TENSOR, STRUCTURE_WEIGHT)

    record_epoch_metrics(history, train_m, val_m)

    # Best model. Always save on the first epoch (best_epoch == 0) so the file
    # reflects THIS run even if val building accuracy stays at 0 -- OUTPUT_DIR
    # is on Drive and may still hold a checkpoint from an earlier run.
    if best_epoch == 0 or val_m['building_acc'] > best_building_acc:
        best_building_acc = val_m['building_acc']
        best_epoch = epoch + 1
        torch.save(model.state_dict(), f"{OUTPUT_DIR}/vqvae_v6_best.pt")

    # Log
    print(f"Epoch {epoch+1:2d} | "
          f"Build: {train_m['building_acc']:.1%}/{val_m['building_acc']:.1%} | "
          f"S0: {val_m.get('stage0_perplexity', 0):.0f} S1: {val_m.get('stage1_perplexity', 0):.0f} | "
          f"Decay: {val_m.get('residual_decay', 1.0):.2f} | "
          f"ErrSim: {val_m['error_similarity']:.2f}")

train_time = time.time() - start_time
print(f"\nTraining complete in {train_time/60:.1f} minutes")
print(f"Best val building accuracy: {best_building_acc:.1%} at epoch {best_epoch}")

## Cell 12: Plot Training Curves

In [None]:
# Derive the x-axis from what was actually recorded so plotting still works
# when training stopped before TOTAL_EPOCHS (e.g. an interrupted run).
epochs = range(1, len(history['train_loss']) + 1)


def plot_train_val(ax, metric, title, *, bold=False, hline=None,
                   hline_label=None, train_label='Train', val_label='Val'):
    """Plot history['train_<metric>'] (solid blue) and history['val_<metric>'] (dashed red).

    ``hline`` is an optional ``(y, color)`` dashed reference line;
    ``hline_label`` additionally adds it to the legend.
    """
    ax.plot(epochs, history[f'train_{metric}'], 'b-', label=train_label)
    ax.plot(epochs, history[f'val_{metric}'], 'r--', label=val_label)
    if hline is not None:
        y, color = hline
        line_kw = {'label': hline_label} if hline_label is not None else {}
        ax.axhline(y=y, color=color, linestyle='--', alpha=0.5, **line_kw)
    title_kw = {'fontweight': 'bold'} if bold else {}
    ax.set_title(title, **title_kw)
    ax.legend()
    ax.grid(True, alpha=0.3)


fig, axes = plt.subplots(4, 4, figsize=(20, 16))

# Row 1: Core metrics
plot_train_val(axes[0, 0], 'building_acc', 'Building Accuracy (KEY)', bold=True)
plot_train_val(axes[0, 1], 'building_recall', 'Building Recall', hline=(0.85, 'g'))
plot_train_val(axes[0, 2], 'terrain_acc', 'Terrain Accuracy')
plot_train_val(axes[0, 3], 'struct_recall', 'Structure Recall')

# Row 2: Per-stage RFSQ metrics
plot_train_val(axes[1, 0], 'stage0_perplexity', 'Stage 0 Perplexity', bold=True,
               train_label='Train S0', val_label='Val S0')
plot_train_val(axes[1, 1], 'stage1_perplexity', 'Stage 1 Perplexity', bold=True,
               train_label='Train S1', val_label='Val S1')

ax = axes[1, 2]
ax.plot(epochs, history['train_stage0_usage'], 'b-', label='Train S0')
ax.plot(epochs, history['val_stage0_usage'], 'r--', label='Val S0')
ax.plot(epochs, history['train_stage1_usage'], 'g-', label='Train S1')
ax.plot(epochs, history['val_stage1_usage'], 'm--', label='Val S1')
ax.axhline(y=0.3, color='orange', linestyle='--', alpha=0.5, label='Target 30%')
ax.set_title('Stage Usage')
ax.legend()
ax.grid(True, alpha=0.3)

plot_train_val(axes[1, 3], 'residual_decay', 'Residual Decay (Final/Initial)', bold=True,
               hline=(0.5, 'orange'), hline_label='<50% target')

# Row 3: Loss and diagnostics
plot_train_val(axes[2, 0], 'loss', 'Loss')
plot_train_val(axes[2, 1], 'building_false_air', 'Building False Air', hline=(0.1, 'orange'))
plot_train_val(axes[2, 2], 'vol_ratio', 'Volume Ratio', hline=(1.0, 'g'))

ax = axes[2, 3]
ax.plot(epochs, history['train_grad_norm'], 'g-')
ax.axhline(y=1.0, color='r', linestyle='--', alpha=0.5)
ax.set_title('Gradient Norm')
ax.grid(True, alpha=0.3)

# Row 4: Error similarity and final metrics
ax = axes[3, 0]
ax.plot(epochs, history['val_error_similarity'], 'b-', label='Building')
ax.plot(epochs, history['val_terrain_error_similarity'], 'g-', label='Terrain')
ax.axhline(y=0.5, color='orange', linestyle='--', alpha=0.5)
ax.set_title('Error Similarity (Val)')
ax.legend()
ax.grid(True, alpha=0.3)

# Comparison: v5.1 vs v6
ax = axes[3, 1]
v51_baseline = {'Build Acc': 0.456, 'Build Recall': 0.847, 'Err Sim': 0.24}
v6_final = {
    'Build Acc': history['val_building_acc'][-1],
    'Build Recall': history['val_building_recall'][-1],
    'Err Sim': history['val_error_similarity'][-1],
}
x = np.arange(len(v51_baseline))
width = 0.35
ax.bar(x - width/2, list(v51_baseline.values()), width, label='v5.1', color='gray')
ax.bar(x + width/2, list(v6_final.values()), width, label='v6 (RFSQ)', color='green')
ax.set_xticks(x)
ax.set_xticklabels(list(v51_baseline.keys()))
ax.set_title('v5.1 vs v6 Comparison')
ax.legend()
ax.grid(True, alpha=0.3)

# Final metrics
ax = axes[3, 2]
final = {
    'Build\nAcc': history['val_building_acc'][-1],
    'Build\nRecall': history['val_building_recall'][-1],
    'S0\nUsage': history['val_stage0_usage'][-1],
    'S1\nUsage': history['val_stage1_usage'][-1],
    'Res\nDecay': history['val_residual_decay'][-1],
}
colors = ['green', 'orange', 'blue', 'purple', 'red']
bars = ax.bar(list(final.keys()), list(final.values()), color=colors)
ax.set_title('Final Val Metrics')
# Residual decay is a final/initial ratio and is not bounded by 1, so widen
# the axis (with headroom for the value labels) instead of clipping the bar.
ax.set_ylim(0, max(1.0, max(final.values()) * 1.15))
for bar, val in zip(bars, final.values()):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
            f'{val:.2f}', ha='center', fontsize=9)

# Summary text
ax = axes[3, 3]
ax.axis('off')
summary = f"""VQ-VAE v6 (RFSQ) Results
─────────────────────────
Best Building Acc: {best_building_acc:.1%} (epoch {best_epoch})
Final Building Recall: {history['val_building_recall'][-1]:.1%}

RFSQ Stage Perplexity:
  Stage 0: {history['val_stage0_perplexity'][-1]:.0f}
  Stage 1: {history['val_stage1_perplexity'][-1]:.0f}

Residual Decay: {history['val_residual_decay'][-1]:.2f}
  (<0.5 = LayerNorm working)

Error Similarity: {history['val_error_similarity'][-1]:.2f}
  (<0.4 = random errors)

Training Time: {train_time/60:.1f} min"""
ax.text(0.1, 0.9, summary, transform=ax.transAxes, fontsize=11,
        verticalalignment='top', fontfamily='monospace')

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/vqvae_v6_training.png", dpi=150, bbox_inches='tight')
plt.show()

# Print analysis
print("\n" + "="*70)
print("RFSQ ANALYSIS")
print("="*70)

# NOTE(review): residual_decay is final/initial norm, so a LOWER value means
# the residual shrank MORE across stages -- yet LayerNorm is described as
# *preventing* residual decay, while this check treats <0.5 as GOOD. Confirm
# which direction compute_loss's 'residual_norms' should be read in.
print(f"\nResidual Decay: {history['val_residual_decay'][-1]:.3f}")
if history['val_residual_decay'][-1] < 0.5:
    print("  GOOD: LayerNorm preventing residual magnitude decay")
else:
    print("  WARNING: Residual decay too high, LayerNorm may not be effective")

print(f"\nStage 0 Perplexity: {history['val_stage0_perplexity'][-1]:.0f}")
print(f"Stage 1 Perplexity: {history['val_stage1_perplexity'][-1]:.0f}")
if history['val_stage1_perplexity'][-1] > 100:
    print("  GOOD: Stage 1 capturing meaningful residual information")
else:
    print("  WARNING: Stage 1 may not be useful (low perplexity)")

## Cell 13: Save Results

In [None]:
results = {
    'config': {
        'version': 'v6',
        'changes_from_v51': [
            f'RFSQ: {NUM_STAGES}-stage residual quantization',
            'LayerNorm conditioning (prevents residual decay)',
            'Per-stage metrics tracking',
            'Residual decay monitoring',
        ],
        'hidden_dims': HIDDEN_DIMS,
        'rfsq_levels_per_stage': RFSQ_LEVELS_PER_STAGE,
        'num_stages': NUM_STAGES,
        'codes_per_stage': CODES_PER_STAGE,
        'total_implicit_codes': TOTAL_IMPLICIT_CODES,
        'total_epochs': TOTAL_EPOCHS,
        'batch_size': BATCH_SIZE,
        'base_lr': BASE_LR,
        'terrain_weight': TERRAIN_WEIGHT,
        'building_weight': BUILDING_WEIGHT,
        'air_weight': AIR_WEIGHT,
        'seed': SEED,
    },
    'results': {
        'best_building_acc': float(best_building_acc),
        'best_epoch': best_epoch,
        'final_building_acc': float(history['val_building_acc'][-1]),
        'final_building_recall': float(history['val_building_recall'][-1]),
        'final_terrain_acc': float(history['val_terrain_acc'][-1]),
        'final_struct_recall': float(history['val_struct_recall'][-1]),
        'final_stage0_perplexity': float(history['val_stage0_perplexity'][-1]),
        'final_stage1_perplexity': float(history['val_stage1_perplexity'][-1]),
        'final_stage0_usage': float(history['val_stage0_usage'][-1]),
        'final_stage1_usage': float(history['val_stage1_usage'][-1]),
        'final_residual_decay': float(history['val_residual_decay'][-1]),
        'final_error_similarity': float(history['val_error_similarity'][-1]),
        'final_terrain_error_similarity': float(history['val_terrain_error_similarity'][-1]),
        'training_time_min': float(train_time / 60),
    },
    'history': {k: [float(x) for x in v] for k, v in history.items()},
}

with open(f"{OUTPUT_DIR}/vqvae_v6_results.json", 'w') as f:
    json.dump(results, f, indent=2)

# The "best" checkpoint must carry the weights from best_epoch, not whatever
# `model` holds after the last epoch. The training loop writes those weights
# to vqvae_v6_best.pt; best_epoch > 0 means this run actually wrote it, so a
# stale file from an earlier run on Drive is never picked up.
best_weights_path = f"{OUTPUT_DIR}/vqvae_v6_best.pt"
if best_epoch > 0 and os.path.exists(best_weights_path):
    best_state_dict = torch.load(best_weights_path, map_location='cpu', weights_only=True)
else:
    # No best snapshot was recorded in this run; fall back to the final weights.
    best_state_dict = model.state_dict()

# Save checkpoint
checkpoint = {
    'model_state_dict': best_state_dict,
    'config': {
        'version': 'v6',
        'vocab_size': VOCAB_SIZE,
        'emb_dim': EMBEDDING_DIM,
        'hidden_dims': HIDDEN_DIMS,
        'rfsq_levels': RFSQ_LEVELS_PER_STAGE,
        'num_stages': NUM_STAGES,
        'terrain_weight': TERRAIN_WEIGHT,
        'building_weight': BUILDING_WEIGHT,
        'air_weight': AIR_WEIGHT,
        'dropout': DROPOUT,
    },
    'air_tokens': AIR_TOKENS_LIST,
    'terrain_tokens': sorted(TERRAIN_TOKENS),
    'best_building_acc': float(best_building_acc),
    'best_epoch': best_epoch,
}

torch.save(checkpoint, f"{OUTPUT_DIR}/vqvae_v6_best_checkpoint.pt")
torch.save(model.state_dict(), f"{OUTPUT_DIR}/vqvae_v6_final.pt")

print("\nResults saved:")
print(f"  - {OUTPUT_DIR}/vqvae_v6_results.json")
print(f"  - {OUTPUT_DIR}/vqvae_v6_best_checkpoint.pt")
print(f"  - {OUTPUT_DIR}/vqvae_v6_final.pt")
print(f"  - {OUTPUT_DIR}/vqvae_v6_training.png")

print("\n" + "="*70)
print("FINAL RESULTS - VQ-VAE v6 (RFSQ)")
print("="*70)
print(f"Best building accuracy: {best_building_acc:.1%} at epoch {best_epoch}")
print(f"Final building recall:  {history['val_building_recall'][-1]:.1%}")
print(f"\nRFSQ Perplexity (effective codes per stage):")
print(f"  Stage 0: {history['val_stage0_perplexity'][-1]:.0f}")
print(f"  Stage 1: {history['val_stage1_perplexity'][-1]:.0f}")
print(f"\nResidual Decay: {history['val_residual_decay'][-1]:.3f}")
print(f"  (< 0.5 means LayerNorm is preventing decay)")
print(f"\nError Similarity: {history['val_error_similarity'][-1]:.3f}")
print(f"  (< 0.4 = random errors, > 0.7 = material confusion)")
print(f"\nTraining time: {train_time/60:.1f} minutes")

# Comparison with v5.1
print("\n" + "="*70)
print("COMPARISON: v5.1 vs v6")
print("="*70)
print(f"{'Metric':<25} {'v5.1':>10} {'v6':>10} {'Change':>10}")
print("-" * 55)
v51 = {'Building Acc': 0.456, 'Building Recall': 0.847, 'Error Sim': 0.24}
v6_res = {
    'Building Acc': history['val_building_acc'][-1],
    'Building Recall': history['val_building_recall'][-1],
    'Error Sim': history['val_error_similarity'][-1],
}
for k in v51:
    change = v6_res[k] - v51[k]
    # The '+' sign flag keeps the sign attached to the number and the column
    # right-aligned under 'Change' for both positive and negative values.
    print(f"{k:<25} {v51[k]:>10.1%} {v6_res[k]:>10.1%} {change:>+10.1%}")