# VQ-VAE v5 Training - FSQ + Terrain-Aware Metrics

## Key Changes from v4

| Change | v4 | v5 |
|--------|-----|-----|
| Quantization | VQ-VAE (EMA codebook) | **FSQ (Finite Scalar Quantization)** |
| Codebook | 512 learned codes | **390,625 implicit codes** |
| Collapse risk | High (97% similarity) | **None (by design)** |
| Terrain handling | None | **Terrain-weighted loss + metrics** |
| Key metric | Structure Recall | **Building Accuracy** (excludes terrain) |

## Why FSQ?

Our v4 codebook collapsed - all 512 codes had 0.97 cosine similarity.
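FSQ eliminates this by using fixed quantization levels instead of learned codes:
each of the 8 latent dimensions is rounded to one of 5 fixed levels, so the
5^8 = 390,625 codes are implicit and there is no codebook that could collapse.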

## Why Terrain-Aware?

Many builds sit on terrain bases (dirt, grass). Reconstructing flat terrain
is trivial and inflates our metrics. By separating terrain from buildings,
we get a truer picture of reconstruction quality.

## Setup - Mount Google Drive

In [None]:
# Colab setup: mount Google Drive at /content/drive.
import os

from google.colab import drive

drive.mount('/content/drive')

# Directory on the mounted Drive that receives this notebook's outputs.
# It is created on first run; an existing directory is left as-is.
OUTPUT_DIR = '/content/drive/MyDrive/minecraft_ai/vqvae_v5'
os.makedirs(OUTPUT_DIR, exist_ok=True)
print("Output will be saved to: {}".format(OUTPUT_DIR))

## Cell 1: Imports

In [None]:
import json
import random
import time
from pathlib import Path
from typing import Dict, List, Tuple, Any, Set
from collections import Counter

import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm

# Select the compute device: CUDA when PyTorch reports a usable GPU, else CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

if device == "cuda":
    # Report the name and total memory (in units of 1e9 bytes) of GPU 0.
    gpu_name = torch.cuda.get_device_name(0)
    gpu_total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"GPU: {gpu_name}")
    print(f"Memory: {gpu_total_bytes / 1e9:.1f} GB")

## Cell 2: Configuration

**IMPORTANT**: Update the data paths below to match your Google Drive structure!

In [None]:
# === Data Paths (UPDATE THESE FOR YOUR DRIVE) ===
DRIVE_BASE = '/content/drive/MyDrive/minecraft_ai'

DATA_DIR = DRIVE_BASE + "/splits/train"                         # Training H5 files
VAL_DIR = DRIVE_BASE + "/splits/val"                            # Validation H5 files
VOCAB_PATH = DRIVE_BASE + "/tok2block.json"                     # Token to block mapping
V3_EMBEDDINGS_PATH = DRIVE_BASE + "/block_embeddings_v3.npy"    # V3 embeddings

OUTPUT_DIR = DRIVE_BASE + "/vqvae_v5"                           # Output directory

# === V5 Model Architecture ===
HIDDEN_DIMS = [96, 192]  # 2 stride-2 stages: 32 -> 16 -> 8
LATENT_DIM = 8  # FSQ dimension (NOT 256 like v4)
# One entry per FSQ dimension; the implicit codebook size is the product of
# the entries, i.e. 5^8 = 390,625 codes.
FSQ_LEVELS = [5] * 8
DROPOUT = 0.1

# === FSQ Settings (replaces VQ) ===
# No commitment cost, no EMA decay - FSQ has no learned codebook!

# === Structure weights ===
STRUCTURE_WEIGHT = 50.0
FALSE_AIR_WEIGHT = 5.0
VOLUME_WEIGHT = 2.0
STRUCTURE_TO_AIR_WEIGHT = 10.0
USE_SHAPE_LOSS = True
USE_ASYMMETRIC_LOSS = True

# === TERRAIN SETTINGS (NEW in v5) ===
TERRAIN_WEIGHT = 0.2   # Lower weight for terrain blocks
BUILDING_WEIGHT = 1.0  # Full weight for building blocks
AIR_WEIGHT = 0.1       # Very low weight for air

# === Training ===
TOTAL_EPOCHS = 15
BATCH_SIZE = 4
BASE_LR = 3e-4
USE_AMP = True
GRAD_ACCUM_STEPS = 4

SEED = 42
NUM_WORKERS = 2

# Number of distinct codes FSQ can emit (product of the per-dimension levels).
IMPLICIT_CODEBOOK_SIZE = int(np.prod(FSQ_LEVELS))

# Echo the effective configuration so it is captured in the notebook output.
_config_report = [
    "VQ-VAE v5 (FSQ) Configuration:",
    "  Latent grid: 8x8x8",
    f"  FSQ levels: {FSQ_LEVELS}",
    f"  Implicit codebook size: {IMPLICIT_CODEBOOK_SIZE:,}",
    f"  Hidden dims: {HIDDEN_DIMS}",
    f"  Epochs: {TOTAL_EPOCHS}",
    "\nTerrain-Aware Training (NEW):",
    f"  Terrain weight: {TERRAIN_WEIGHT}",
    f"  Building weight: {BUILDING_WEIGHT}",
    f"  Air weight: {AIR_WEIGHT}",
]
print("\n".join(_config_report))

## Cell 3: Load Vocabulary and Embeddings

In [None]:
# Token id -> block string. JSON object keys are strings, so they are
# converted back to integer token ids while loading.
with open(VOCAB_PATH, 'r') as f:
    _raw_vocab = json.load(f)
tok2block = {int(tok_str): block_name for tok_str, block_name in _raw_vocab.items()}

# Reverse lookup (block string -> token id) and vocabulary size.
block2tok = {block_name: tok_id for tok_id, block_name in tok2block.items()}
VOCAB_SIZE = len(tok2block)
print(f"Vocabulary size: {VOCAB_SIZE}")

# Find air tokens: any block whose name contains "air", except names that
# contain "stair" (which would otherwise match on the same substring).
AIR_TOKENS: Set[int] = set()
for tok, block in tok2block.items():
    _lowered = block.lower()
    if 'stair' in _lowered or 'air' not in _lowered:
        continue
    AIR_TOKENS.add(tok)
    print(f"  Air token: {tok} = {block}")

# Sorted list (stored in the checkpoint) and tensor form (used for torch.isin).
AIR_TOKENS_LIST = sorted(AIR_TOKENS)
AIR_TOKENS_TENSOR = torch.tensor(AIR_TOKENS_LIST, dtype=torch.long)

# Load V3 embeddings as float32; the second axis is the embedding width.
v3_embeddings = np.load(V3_EMBEDDINGS_PATH).astype(np.float32)
EMBEDDING_DIM = v3_embeddings.shape[1]
print(f"V3 embeddings: {v3_embeddings.shape} (dim={EMBEDDING_DIM})")

## Cell 4: Terrain Detection

In [None]:
# Block names (without any "[...]" suffix) that are treated as terrain
# rather than as part of a build.
TERRAIN_BLOCKS: Set[str] = {
    # Dirt family
    'minecraft:dirt', 'minecraft:grass_block', 'minecraft:coarse_dirt',
    'minecraft:podzol', 'minecraft:mycelium', 'minecraft:rooted_dirt',
    'minecraft:dirt_path', 'minecraft:farmland', 'minecraft:mud',

    # Stone family
    'minecraft:stone', 'minecraft:cobblestone', 'minecraft:mossy_cobblestone',
    'minecraft:bedrock', 'minecraft:deepslate', 'minecraft:tuff',
    'minecraft:granite', 'minecraft:diorite', 'minecraft:andesite',

    # Sand family
    'minecraft:sand', 'minecraft:red_sand', 'minecraft:gravel', 'minecraft:clay',

    # Fluids
    'minecraft:water', 'minecraft:lava',

    # Terracotta
    'minecraft:terracotta', 'minecraft:white_terracotta', 'minecraft:orange_terracotta',
    'minecraft:brown_terracotta', 'minecraft:red_terracotta',

    # Nether
    'minecraft:netherrack', 'minecraft:soul_sand', 'minecraft:soul_soil',

    # End
    'minecraft:end_stone',

    # Snow/Ice
    'minecraft:snow_block', 'minecraft:ice', 'minecraft:packed_ice',
}

# Build terrain token set: a token is terrain when the part of its block
# string before the first '[' is one of the names above.
TERRAIN_TOKENS: Set[int] = {
    tok
    for tok, block in tok2block.items()
    if block.partition('[')[0] in TERRAIN_BLOCKS
}

TERRAIN_TOKENS_TENSOR = torch.tensor(sorted(TERRAIN_TOKENS), dtype=torch.long)
print(f"Terrain tokens: {len(TERRAIN_TOKENS)}")

def detect_terrain(block_ids: torch.Tensor, device: torch.device) -> torch.Tensor:
    """Build a boolean mask marking terrain voxels.

    Args:
        block_ids: Tensor of token ids (presumably already on ``device`` -
            TODO confirm against callers).
        device: Device the module-level terrain/air token tensors are moved to
            before the membership tests.

    Returns:
        Boolean tensor with the shape of ``block_ids``; True where the token
        is in TERRAIN_TOKENS_TENSOR and is not in AIR_TOKENS_TENSOR.
    """
    terrain_ids = TERRAIN_TOKENS_TENSOR.to(device)
    air_ids = AIR_TOKENS_TENSOR.to(device)

    not_air = torch.isin(block_ids, air_ids).logical_not()
    return torch.isin(block_ids, terrain_ids) & not_air

## Cell 5: Dataset

In [None]:
class VQVAEDataset(Dataset):
    """Map-style dataset with one item per ``*.h5`` file in a directory."""

    def __init__(self, data_dir: str):
        """Collect the ``*.h5`` files directly inside ``data_dir`` in sorted order.

        Raises:
            ValueError: If the directory contains no ``*.h5`` files.
        """
        root = Path(data_dir)
        found = sorted(root.glob("*.h5"))
        if len(found) == 0:
            raise ValueError(f"No H5 files in {data_dir}")

        self.data_dir = root
        self.h5_files = found
        print(f"Found {len(self.h5_files)} structures in {data_dir}")

    def __len__(self):
        """Number of H5 files found at construction time."""
        return len(self.h5_files)

    def __getitem__(self, idx):
        """Read the first dataset of file ``idx`` and return it as an int64 tensor."""
        with h5py.File(self.h5_files[idx], 'r') as h5:
            # Only the first key in the file is read.
            first_key = list(h5.keys())[0]
            voxels = h5[first_key][:]
        as_int64 = voxels.astype(np.int64)
        return torch.from_numpy(as_int64).long()

# Instantiate the train and validation datasets (train first, then val).
train_dataset, val_dataset = (
    VQVAEDataset(split_dir) for split_dir in (DATA_DIR, VAL_DIR)
)

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")

## Cell 6: FSQ Module (replaces VQ-VAE codebook)

In [None]:
class FSQ(nn.Module):
    """
    Finite Scalar Quantization - no learned codebook, no collapse.

    Each latent dimension is squashed with tanh and rounded onto a fixed grid
    of ``levels[i]`` values in [-1, 1]. The implicit codebook size is the
    product of all levels; every grid point maps to one integer index.

    Usage counts are accumulated on every forward pass, in both train and eval
    mode, so that ``reset_usage()`` followed by ``get_usage_stats()`` gives
    meaningful numbers around a validation pass as well. Pass
    ``track_usage=False`` to disable the bookkeeping.

    NOTE(review): the rounding scheme maps levels one-to-one onto indices only
    for odd level counts; even counts are not handled specially here.

    Args:
        levels: Number of quantization levels for each latent dimension.
        eps: Stored on the module; not used by the computation in this class.
        track_usage: Whether ``forward`` updates the per-code usage counter.
    """

    def __init__(self, levels: List[int], eps: float = 1e-3, track_usage: bool = True):
        super().__init__()
        self.levels = levels
        self.dim = len(levels)
        self.eps = eps
        self.track_usage = track_usage
        self.codebook_size = int(np.prod(levels))

        self.register_buffer('_levels', torch.tensor(levels, dtype=torch.float32))

        # Mixed-radix place values: the last dimension is the least
        # significant digit, so basis[i] = prod(levels[i+1:]).
        basis = []
        acc = 1
        for n_levels in reversed(levels):
            basis.append(acc)
            acc *= n_levels
        self.register_buffer('_basis', torch.tensor(list(reversed(basis)), dtype=torch.long))

        # (L - 1) / 2: the integer grid for a dimension spans [-half, +half].
        half_levels = [(n_levels - 1) / 2 for n_levels in levels]
        self.register_buffer('_half_levels', torch.tensor(half_levels, dtype=torch.float32))

        # Per-code hit counter used by get_usage_stats().
        self.register_buffer('_usage', torch.zeros(self.codebook_size))

    def reset_usage(self):
        """Zero the per-code usage counter."""
        self._usage.zero_()

    def get_usage_stats(self) -> Tuple[float, float]:
        """Return (usage_fraction, perplexity).

        ``usage_fraction`` is the share of codes hit at least once since the
        last reset. ``perplexity`` is exp(entropy) of the empirical code
        distribution, or 0.0 when nothing has been counted.
        """
        usage = (self._usage > 0).float().mean().item()

        total = self._usage.sum()
        if total == 0:
            return usage, 0.0

        probs = self._usage / total
        probs = probs[probs > 0]  # drop unused codes so log() stays finite
        entropy = -(probs * probs.log()).sum()
        perplexity = entropy.exp().item()

        return usage, perplexity

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Quantize latent vectors.

        Args:
            z: Shape [..., dim] continuous latent vectors

        Returns:
            z_q: Quantized vectors, same shape. Gradients pass straight
                through the rounding to tanh(z).
            indices: Integer indices [...], each in [0, codebook_size)
        """
        # Bound to (-1, 1)
        z_bounded = torch.tanh(z)

        # Match the latent dtype so the result keeps the input's precision
        # (e.g. stays float16 under autocast) instead of promoting to float32.
        half = self._half_levels.to(z_bounded.dtype)

        # Scale every dimension to its integer grid, round, clamp to the grid
        # range and scale back to [-1, 1]. Broadcasting over the last axis
        # handles all dimensions at once.
        grid = torch.round(z_bounded * half)
        grid = torch.maximum(torch.minimum(grid, half), -half)
        z_hard = grid / half

        # Straight-through estimator
        z_q = z_bounded + (z_hard - z_bounded).detach()

        # Per-dimension level index in [0, L - 1], then mixed-radix encode.
        level_idx = (z_q * half + half).round().long()
        max_idx = self._levels.long() - 1
        level_idx = torch.minimum(level_idx.clamp(min=0), max_idx)
        indices = (level_idx * self._basis).sum(dim=-1)

        # Track usage with a single histogram instead of one pass per code.
        if self.track_usage:
            with torch.no_grad():
                counts = torch.bincount(indices.reshape(-1), minlength=self.codebook_size)
                self._usage += counts[: self.codebook_size].to(self._usage.dtype)

        return z_q, indices

# Report the configured FSQ levels and the resulting implicit codebook size.
print("FSQ module: {} -> {:,} implicit codes".format(FSQ_LEVELS, IMPLICIT_CODEBOOK_SIZE))

## Cell 7: VQ-VAE v5 Architecture with FSQ

In [None]:
class ResidualBlock3D(nn.Module):
    """Two 3x3x3 conv + batch-norm layers with an identity skip connection.

    Channel count and spatial size are unchanged (kernel 3, padding 1).
    """

    def __init__(self, channels: int):
        super().__init__()
        conv_opts = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv3d(channels, channels, **conv_opts)
        self.conv2 = nn.Conv3d(channels, channels, **conv_opts)
        self.bn1 = nn.BatchNorm3d(channels)
        self.bn2 = nn.BatchNorm3d(channels)

    def forward(self, x):
        """Return relu(bn2(conv2(relu(bn1(conv1(x))))) + x)."""
        out = self.conv1(x)
        out = F.relu(self.bn1(out))
        out = self.bn2(self.conv2(out))
        out = out + x  # skip connection is added before the final activation
        return F.relu(out)


class EncoderV5(nn.Module):
    """Convolutional encoder that emits ``fsq_dim`` channels.

    Each entry of ``hidden_dims`` adds one stage that halves the spatial size
    (kernel-4, stride-2, padding-1 conv), so two stages take 32x32x32 to 8x8x8.
    """

    def __init__(self, in_channels: int, hidden_dims: List[int], fsq_dim: int, dropout: float = 0.1):
        super().__init__()

        stages = []
        width = in_channels
        for next_width in hidden_dims:
            stages += self._down_stage(width, next_width, dropout)
            width = next_width

        # Bottleneck at the final resolution, then project to the FSQ dimension.
        head = [
            ResidualBlock3D(width),
            ResidualBlock3D(width),
            nn.Conv3d(width, fsq_dim, 3, padding=1),  # Output FSQ dim
        ]

        self.encoder = nn.Sequential(*stages, *head)

    @staticmethod
    def _down_stage(in_ch: int, out_ch: int, dropout: float) -> List[nn.Module]:
        """Modules for one stride-2 downsampling stage followed by a residual block."""
        return [
            nn.Conv3d(in_ch, out_ch, 4, stride=2, padding=1),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(inplace=True),
            nn.Dropout3d(dropout),
            ResidualBlock3D(out_ch),
        ]

    def forward(self, x):
        """Run the stacked stages on a channels-first 5-D input."""
        return self.encoder(x)


class DecoderV5(nn.Module):
    """Convolutional decoder from ``fsq_dim`` channels to per-voxel block logits.

    One stride-2 transposed conv is added per entry of ``hidden_dims[1:]``,
    plus a final one, each doubling the spatial size; with two hidden dims
    that is 8x8x8 -> 32x32x32.
    """

    def __init__(self, fsq_dim: int, hidden_dims: List[int], num_blocks: int, dropout: float = 0.1):
        super().__init__()

        # Stem: lift the FSQ channels to the first hidden width.
        stem_width = hidden_dims[0]
        modules = [
            nn.Conv3d(fsq_dim, stem_width, 3, padding=1),
            nn.BatchNorm3d(stem_width),
            nn.ReLU(inplace=True),
            ResidualBlock3D(stem_width),
            ResidualBlock3D(stem_width),
        ]

        # Intermediate upsampling stages (the only place dropout is applied).
        width = stem_width
        for next_width in hidden_dims[1:]:
            modules += [
                ResidualBlock3D(width),
                nn.ConvTranspose3d(width, next_width, 4, stride=2, padding=1),
                nn.BatchNorm3d(next_width),
                nn.ReLU(inplace=True),
                nn.Dropout3d(dropout),
            ]
            width = next_width

        # Final upsample at constant width, then project to block logits.
        modules += [
            ResidualBlock3D(width),
            nn.ConvTranspose3d(width, width, 4, stride=2, padding=1),
            nn.BatchNorm3d(width),
            nn.ReLU(inplace=True),
            nn.Conv3d(width, num_blocks, 3, padding=1),
        ]

        self.decoder = nn.Sequential(*modules)

    def forward(self, z_q):
        """Decode channels-first quantized latents into logits."""
        return self.decoder(z_q)


class TerrainWeightedLoss(nn.Module):
    """Weighted cross-entropy with separate weights for terrain, building and air.

    The result is the weighted mean of the per-element cross-entropy, i.e.
    ``sum(w * ce) / sum(w)``.
    """

    def __init__(self, terrain_weight: float = 0.2, building_weight: float = 1.0, air_weight: float = 0.1):
        super().__init__()
        self.terrain_weight = terrain_weight
        self.building_weight = building_weight
        self.air_weight = air_weight

    def forward(self, logits: torch.Tensor, targets: torch.Tensor,
                terrain_mask: torch.Tensor, air_mask: torch.Tensor) -> torch.Tensor:
        """Compute the weighted mean cross-entropy.

        Args:
            logits: Class scores accepted by ``F.cross_entropy``.
            targets: Target class ids.
            terrain_mask: Boolean mask of terrain elements.
            air_mask: Boolean mask of air elements; applied last, so it wins
                where both masks are set.
        """
        per_element = F.cross_entropy(logits, targets, reduction='none')

        # Everything starts as "building"; terrain and then air override it.
        weights = (
            torch.full_like(per_element, self.building_weight)
            .masked_fill(terrain_mask, self.terrain_weight)
            .masked_fill(air_mask, self.air_weight)
        )

        weighted_total = (per_element * weights).sum()
        return weighted_total / weights.sum()


class VQVAEv5(nn.Module):
    """VQ-VAE v5 with FSQ and terrain-aware training.

    Pipeline: frozen block embedding -> EncoderV5 -> FSQ -> DecoderV5, with a
    TerrainWeightedLoss as the primary training loss.

    Args:
        vocab_size: Number of block tokens (embedding rows and output classes).
        emb_dim: Width of the block embedding.
        hidden_dims: Encoder channel widths; the decoder uses them reversed.
        fsq_levels: Per-dimension FSQ levels; its length is the latent width.
        pretrained_emb: Array copied into the embedding table; it must be
            copyable into a ``(vocab_size, emb_dim)`` weight.
        terrain_weight, building_weight, air_weight: Loss weights forwarded to
            TerrainWeightedLoss.
        dropout: Dropout probability forwarded to encoder and decoder.
    """

    def __init__(self, vocab_size: int, emb_dim: int, hidden_dims: List[int],
                 fsq_levels: List[int], pretrained_emb: np.ndarray,
                 terrain_weight: float = 0.2, building_weight: float = 1.0,
                 air_weight: float = 0.1, dropout: float = 0.1):
        super().__init__()

        self.vocab_size = vocab_size
        self.emb_dim = emb_dim
        self.fsq_dim = len(fsq_levels)

        # Embeddings (frozen - no point training with FSQ)
        self.block_emb = nn.Embedding(vocab_size, emb_dim)
        self.block_emb.weight.data.copy_(torch.from_numpy(pretrained_emb))
        self.block_emb.weight.requires_grad = False

        # Encoder: 32x32x32 -> 8x8x8 x fsq_dim
        self.encoder = EncoderV5(emb_dim, hidden_dims, self.fsq_dim, dropout)

        # FSQ: quantize to implicit codebook
        self.fsq = FSQ(fsq_levels)

        # Decoder: 8x8x8 x fsq_dim -> 32x32x32 x vocab_size
        # Note: hidden_dims reversed for upsampling
        self.decoder = DecoderV5(self.fsq_dim, list(reversed(hidden_dims)), vocab_size, dropout)

        # Terrain-weighted loss
        self.terrain_loss = TerrainWeightedLoss(terrain_weight, building_weight, air_weight)

    def forward(self, block_ids: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Encode, quantize and decode a batch of block-id volumes.

        Args:
            block_ids: Integer tensor of token ids, indexed as [B, X, Y, Z].

        Returns:
            Dict with:
              - 'logits': [B, vocab_size, X, Y, Z] decoder output.
              - 'indices': FSQ code index per latent cell.
              - 'z_e': encoder output in channels-LAST layout [B, ..., fsq_dim].
              - 'z_q': quantized latents in channels-FIRST layout [B, fsq_dim, ...].
            Note that 'z_e' and 'z_q' use different layouts.
        """
        # Embed blocks
        x = self.block_emb(block_ids)  # [B, X, Y, Z, emb_dim]
        x = x.permute(0, 4, 1, 2, 3).contiguous()  # [B, emb_dim, X, Y, Z]

        # Encode
        z_e = self.encoder(x)  # [B, fsq_dim, 8, 8, 8]

        # FSQ quantizes along the last axis, so move channels last.
        z_e = z_e.permute(0, 2, 3, 4, 1).contiguous()

        # Quantize with FSQ
        z_q, indices = self.fsq(z_e)

        # Permute back: [B, fsq_dim, 8, 8, 8]
        z_q = z_q.permute(0, 4, 1, 2, 3).contiguous()

        # Decode
        logits = self.decoder(z_q)  # [B, vocab_size, X, Y, Z]

        return {
            'logits': logits,
            'indices': indices,
            'z_e': z_e,
            'z_q': z_q,
        }

    def compute_loss(self, block_ids: torch.Tensor,
                     air_tokens: torch.Tensor,
                     terrain_tokens: torch.Tensor,
                     structure_weight: float = 50.0,
                     use_terrain_weighting: bool = True) -> Dict[str, torch.Tensor]:
        """Compute loss with terrain-aware weighting, plus reconstruction metrics.

        Args:
            block_ids: [B, X, Y, Z] target token ids (also the model input).
            air_tokens: 1-D tensor of token ids treated as air.
            terrain_tokens: 1-D tensor of token ids treated as terrain.
            structure_weight: Weight of non-air voxels in the fallback loss;
                ignored when ``use_terrain_weighting`` is True.
            use_terrain_weighting: Use TerrainWeightedLoss (True) or the
                structure-weighted cross-entropy fallback (False).

        Returns:
            Dict of scalar tensors: 'loss' (differentiable) and the metrics
            'overall_acc', 'terrain_acc', 'building_acc', 'building_recall',
            'building_false_air', 'struct_acc', 'struct_recall', 'vol_ratio',
            'terrain_frac', 'building_frac'. A metric whose voxel category is
            absent from the batch is reported as 0.0 ('vol_ratio' as 1.0).
        """

        out = self(block_ids)
        logits = out['logits']

        # Flatten for loss computation: one row of class scores per voxel.
        num_classes = logits.shape[1]
        logits_flat = logits.permute(0, 2, 3, 4, 1).reshape(-1, num_classes)
        # reshape (not view) so non-contiguous block_ids, e.g. the result of a
        # permute or flip, are accepted just like the logits above.
        targets_flat = block_ids.reshape(-1)

        device = targets_flat.device
        air_dev = air_tokens.to(device)
        terrain_dev = terrain_tokens.to(device)

        # Masks: every voxel is exactly one of air / terrain / building.
        is_air = torch.isin(targets_flat, air_dev)
        is_terrain = torch.isin(targets_flat, terrain_dev) & ~is_air
        is_building = ~is_air & ~is_terrain

        # Primary loss: terrain-weighted CE
        if use_terrain_weighting:
            loss = self.terrain_loss(logits_flat, targets_flat, is_terrain, is_air)
        else:
            # Fallback: structure-weighted CE
            weights = torch.ones_like(targets_flat, dtype=torch.float)
            weights[~is_air] = structure_weight
            ce = F.cross_entropy(logits_flat, targets_flat, reduction='none')
            loss = (weights * ce).sum() / weights.sum()

        # Compute metrics
        with torch.no_grad():
            preds = logits_flat.argmax(dim=1)
            is_air_pred = torch.isin(preds, air_dev)

            # Overall accuracy
            correct = (preds == targets_flat).float()
            overall_acc = correct.mean()

            # Terrain accuracy
            if is_terrain.any():
                terrain_acc = correct[is_terrain].mean()
            else:
                terrain_acc = torch.tensor(0.0, device=device)

            # Building accuracy (THE KEY METRIC)
            if is_building.any():
                building_acc = correct[is_building].mean()
                num_building = is_building.sum().float()

                # Building recall: building voxels predicted as any non-air block
                building_preserved = is_building & ~is_air_pred
                building_recall = building_preserved.sum().float() / num_building

                # Building false air: building voxels predicted as air
                building_erased = is_building & is_air_pred
                building_false_air = building_erased.sum().float() / num_building
            else:
                building_acc = torch.tensor(0.0, device=device)
                building_recall = torch.tensor(0.0, device=device)
                building_false_air = torch.tensor(0.0, device=device)

            # Structure metrics (terrain + building)
            is_struct = ~is_air
            if is_struct.any():
                struct_acc = correct[is_struct].mean()
                struct_preserved = is_struct & ~is_air_pred
                struct_recall = struct_preserved.sum().float() / is_struct.sum().float()
            else:
                struct_acc = torch.tensor(0.0, device=device)
                struct_recall = torch.tensor(0.0, device=device)

            # Volume ratio: predicted non-air voxel count over target non-air count
            orig_vol = is_struct.sum().float()
            pred_vol = (~is_air_pred).sum().float()
            vol_ratio = pred_vol / orig_vol if orig_vol > 0 else torch.tensor(1.0, device=device)

            # Composition stats
            total = targets_flat.numel()
            terrain_frac = is_terrain.sum().float() / total
            building_frac = is_building.sum().float() / total

        return {
            'loss': loss,
            'overall_acc': overall_acc,
            'terrain_acc': terrain_acc,
            'building_acc': building_acc,  # THE KEY METRIC
            'building_recall': building_recall,
            'building_false_air': building_false_air,
            'struct_acc': struct_acc,
            'struct_recall': struct_recall,
            'vol_ratio': vol_ratio,
            'terrain_frac': terrain_frac,
            'building_frac': building_frac,
        }


# Cell-completion marker for the notebook output.
_arch_status = "VQ-VAE v5 architecture with FSQ and terrain-awareness defined!"
print(_arch_status)

## Cell 8: Training Functions

In [None]:
def train_epoch(model, loader, optimizer, scaler, device,
                air_tokens, terrain_tokens, structure_weight):
    """Run one training epoch with AMP and gradient accumulation.

    Gradients are accumulated over GRAD_ACCUM_STEPS batches before each
    optimizer step. A trailing group of fewer batches (when the batch count is
    not a multiple of GRAD_ACCUM_STEPS) is also stepped at the end of the
    epoch instead of being discarded.

    Args:
        model: Model exposing ``compute_loss(...)`` and an ``fsq`` submodule
            with ``reset_usage()`` / ``get_usage_stats()``.
        loader: Iterable of batches (tensors of block ids).
        optimizer: Optimizer over the model's trainable parameters.
        scaler: Gradient scaler used for mixed-precision training.
        device: Device the batches are moved to.
        air_tokens, terrain_tokens, structure_weight: Forwarded to
            ``model.compute_loss``.

    Returns:
        Dict of per-batch averages for the tracked metrics, plus 'fsq_usage'
        and 'fsq_perplexity' as reported by the FSQ module (not averaged).
    """
    model.train()
    model.fsq.reset_usage()

    metrics = {k: 0.0 for k in [
        'loss', 'overall_acc', 'terrain_acc', 'building_acc',
        'building_recall', 'building_false_air', 'struct_acc',
        'struct_recall', 'vol_ratio', 'terrain_frac', 'building_frac'
    ]}
    n = 0
    pending = 0  # batches whose gradients are accumulated but not yet applied
    optimizer.zero_grad()

    def _apply_accumulated_grads():
        # Unscale first so clipping sees true gradient magnitudes.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    for batch in tqdm(loader, desc="Train", leave=False):
        batch = batch.to(device)

        with torch.amp.autocast('cuda', enabled=USE_AMP):
            out = model.compute_loss(batch, air_tokens, terrain_tokens, structure_weight)
            # Divide so the accumulated gradient is the mean over the group.
            loss = out['loss'] / GRAD_ACCUM_STEPS

        scaler.scale(loss).backward()
        pending += 1

        if pending == GRAD_ACCUM_STEPS:
            _apply_accumulated_grads()
            pending = 0

        for k in metrics:
            if k in out:
                metrics[k] += out[k].item() if torch.is_tensor(out[k]) else out[k]
        n += 1

    # Flush the incomplete final group. Its gradient is still divided by
    # GRAD_ACCUM_STEPS, so it is a proportionally smaller update.
    if pending:
        _apply_accumulated_grads()

    # FSQ usage stats
    usage, perplexity = model.fsq.get_usage_stats()
    metrics['fsq_usage'] = usage
    metrics['fsq_perplexity'] = perplexity

    return {k: v/n if k not in ['fsq_usage', 'fsq_perplexity'] else v for k, v in metrics.items()}


@torch.no_grad()
def validate(model, loader, device, air_tokens, terrain_tokens, structure_weight):
    """Evaluate the model over ``loader`` without gradient tracking.

    The model is put in eval mode and the FSQ usage counter is reset before
    the pass.

    Returns:
        Dict of per-batch averages for the tracked metrics, plus 'fsq_usage'
        and 'fsq_perplexity' as read from ``model.fsq`` after the pass
        (not averaged).
    """
    model.eval()
    model.fsq.reset_usage()

    averaged_keys = (
        'loss', 'overall_acc', 'terrain_acc', 'building_acc',
        'building_recall', 'building_false_air', 'struct_acc',
        'struct_recall', 'vol_ratio', 'terrain_frac', 'building_frac',
    )
    totals = dict.fromkeys(averaged_keys, 0.0)
    num_batches = 0

    for batch in tqdm(loader, desc="Val", leave=False):
        batch = batch.to(device)

        with torch.amp.autocast('cuda', enabled=USE_AMP):
            out = model.compute_loss(batch, air_tokens, terrain_tokens, structure_weight)

        for key in averaged_keys:
            if key not in out:
                continue
            value = out[key]
            totals[key] += value.item() if torch.is_tensor(value) else value
        num_batches += 1

    usage, perplexity = model.fsq.get_usage_stats()

    # Average the accumulated sums; the FSQ statistics are appended as-is.
    summary = {key: total / num_batches for key, total in totals.items()}
    summary['fsq_usage'] = usage
    summary['fsq_perplexity'] = perplexity
    return summary


# Cell-completion marker for the notebook output.
_train_fn_status = "Training functions defined!"
print(_train_fn_status)

## Cell 9: Create Model and Optimizer

In [None]:
# Seed PyTorch, NumPy and Python's random module (in that order).
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.empty_cache()

# Data loaders. The training loader shuffles with its own seeded generator;
# the validation loader keeps dataset order.
g = torch.Generator().manual_seed(SEED)
_shared_loader_opts = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, generator=g, **_shared_loader_opts)
val_loader = DataLoader(val_dataset, shuffle=False, **_shared_loader_opts)

# Create model
model = VQVAEv5(
    vocab_size=VOCAB_SIZE,
    emb_dim=EMBEDDING_DIM,
    hidden_dims=HIDDEN_DIMS,
    fsq_levels=FSQ_LEVELS,
    pretrained_emb=v3_embeddings,
    terrain_weight=TERRAIN_WEIGHT,
    building_weight=BUILDING_WEIGHT,
    air_weight=AIR_WEIGHT,
    dropout=DROPOUT,
)
model = model.to(device)

# Count parameters (all vs. those that receive gradients)
_param_info = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(size for size, _ in _param_info)
trainable_params = sum(size for size, needs_grad in _param_info if needs_grad)
print(f"Total params: {total_params:,}")
print(f"Trainable params: {trainable_params:,}")
print(f"FSQ implicit codebook: {model.fsq.codebook_size:,}")

# Optimizer (no embedding params - they're frozen)
_trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(_trainable, lr=BASE_LR, weight_decay=1e-5)

# Loss scaler for mixed precision; a no-op when USE_AMP is False.
scaler = torch.amp.GradScaler('cuda', enabled=USE_AMP)

print(f"\nOptimizer: AdamW, LR={BASE_LR}")

## Cell 10: Training Loop

In [None]:
_banner = "=" * 70
print(_banner)
print("VQ-VAE V5 TRAINING - FSQ + TERRAIN-AWARE")
print(_banner)
print("Key Differences from v4:")
print("  - FSQ instead of VQ-VAE (no codebook collapse possible)")
print("  - Terrain-weighted loss (buildings prioritized)")
print("  - Building accuracy as key metric (not inflated by terrain)")
print()

# Metrics recorded once per epoch for both splits; history keys are
# "<split>_<metric>", with all train keys first and then all val keys.
_TRACKED_METRICS = (
    'loss', 'building_acc', 'building_recall',
    'terrain_acc', 'struct_recall',
    'fsq_usage', 'fsq_perplexity',
)
history = {
    f"{split}_{metric}": []
    for split in ('train', 'val')
    for metric in _TRACKED_METRICS
}

best_building_acc = 0  # Track building accuracy as key metric
best_epoch = 0
start_time = time.time()

for epoch in range(TOTAL_EPOCHS):
    # Train
    train_m = train_epoch(model, train_loader, optimizer, scaler, device,
                          AIR_TOKENS_TENSOR, TERRAIN_TOKENS_TENSOR, STRUCTURE_WEIGHT)

    # Validate
    val_m = validate(model, val_loader, device,
                     AIR_TOKENS_TENSOR, TERRAIN_TOKENS_TENSOR, STRUCTURE_WEIGHT)

    # Record
    for _split, _epoch_metrics in (('train', train_m), ('val', val_m)):
        for _metric in _TRACKED_METRICS:
            history[f"{_split}_{_metric}"].append(_epoch_metrics[_metric])

    # Best model - selected on validation BUILDING ACCURACY (strict improvement)
    _val_building_acc = val_m['building_acc']
    if _val_building_acc > best_building_acc:
        best_building_acc = _val_building_acc
        best_epoch = epoch + 1  # stored 1-based
        torch.save(model.state_dict(), f"{OUTPUT_DIR}/vqvae_v5_best.pt")

    # Log: building/terrain show train/val, recall and FSQ stats are val only
    print(f"Epoch {epoch+1:2d} | "
          f"Building: {train_m['building_acc']:.1%}/{val_m['building_acc']:.1%} | "
          f"Terrain: {train_m['terrain_acc']:.1%}/{val_m['terrain_acc']:.1%} | "
          f"Recall: {val_m['building_recall']:.1%} | "
          f"FSQ: {val_m['fsq_usage']:.0%} ({val_m['fsq_perplexity']:.0f})")

train_time = time.time() - start_time
print(f"\nTraining complete in {train_time/60:.1f} minutes")
print(f"Best val building accuracy: {best_building_acc:.1%} at epoch {best_epoch}")

## Cell 11: Plot Training Curves

In [None]:
fig, axes = plt.subplots(2, 4, figsize=(20, 8))

# Use the number of epochs actually recorded, so the curves still plot when
# training stopped before TOTAL_EPOCHS (a fixed range would not match the
# history length and matplotlib would raise a shape error).
epochs = range(1, len(history['train_loss']) + 1)


def _plot_metric(ax, metric, title, *, bold=False, ylabel=None, target=None, **line_kwargs):
    """Draw the train (solid blue) and val (dashed red) curves of one metric.

    ``target`` adds a dashed green horizontal reference line; ``line_kwargs``
    are passed to both ``ax.plot`` calls.
    """
    ax.plot(epochs, history[f'train_{metric}'], 'b-', label='Train', **line_kwargs)
    ax.plot(epochs, history[f'val_{metric}'], 'r--', label='Val', **line_kwargs)
    if target is not None:
        ax.axhline(y=target, color='green', linestyle='--', alpha=0.5, label='Target')
    if bold:
        ax.set_title(title, fontweight='bold')
    else:
        ax.set_title(title)
    ax.set_xlabel('Epoch')
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True, alpha=0.3)


# Top row: accuracy / recall curves
_plot_metric(axes[0, 0], 'building_acc', 'Building Accuracy (KEY METRIC)',
             bold=True, ylabel='Accuracy', linewidth=2)
_plot_metric(axes[0, 1], 'building_recall', 'Building Recall', target=0.9)
_plot_metric(axes[0, 2], 'terrain_acc', 'Terrain Accuracy (easy)')
_plot_metric(axes[0, 3], 'struct_recall', 'Structure Recall (combined)')

# Bottom row: loss and FSQ statistics
_plot_metric(axes[1, 0], 'loss', 'Loss')
_plot_metric(axes[1, 1], 'fsq_usage', 'FSQ Code Usage')
_plot_metric(axes[1, 2], 'fsq_perplexity', 'FSQ Perplexity (effective codes)')

# Final metrics comparison: last-epoch validation values as a bar chart
ax = axes[1, 3]
final = {
    'Building\nAcc': history['val_building_acc'][-1],
    'Building\nRecall': history['val_building_recall'][-1],
    'Terrain\nAcc': history['val_terrain_acc'][-1],
    'Struct\nRecall': history['val_struct_recall'][-1],
}
colors = ['green', 'orange', 'gray', 'blue']
bars = ax.bar(final.keys(), final.values(), color=colors)
ax.set_title('Final Metrics')
ax.set_ylim(0, 1)
for bar, val in zip(bars, final.values()):
    # Percentage label just above each bar
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
            f'{val:.1%}', ha='center', fontsize=9)

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/vqvae_v5_training.png", dpi=150, bbox_inches='tight')
plt.show()

## Cell 12: Save Results

In [None]:
# Run summary: configuration, headline numbers and full per-epoch history.
results = {
    'config': {
        'hidden_dims': HIDDEN_DIMS,
        'fsq_levels': FSQ_LEVELS,
        'fsq_codebook_size': IMPLICIT_CODEBOOK_SIZE,
        'total_epochs': TOTAL_EPOCHS,
        'batch_size': BATCH_SIZE,
        'base_lr': BASE_LR,
        'structure_weight': STRUCTURE_WEIGHT,
        'terrain_weight': TERRAIN_WEIGHT,
        'building_weight': BUILDING_WEIGHT,
        'air_weight': AIR_WEIGHT,
        'seed': SEED,
    },
    'results': {
        'best_building_acc': float(best_building_acc),
        'best_epoch': best_epoch,
        'final_building_acc': float(history['val_building_acc'][-1]),
        'final_building_recall': float(history['val_building_recall'][-1]),
        'final_terrain_acc': float(history['val_terrain_acc'][-1]),
        'final_struct_recall': float(history['val_struct_recall'][-1]),
        'final_fsq_usage': float(history['val_fsq_usage'][-1]),
        'final_fsq_perplexity': float(history['val_fsq_perplexity'][-1]),
        'training_time_min': float(train_time / 60),
    },
    # Cast to plain floats so the history is JSON-serialisable.
    'history': {k: [float(x) for x in v] for k, v in history.items()},
}

with open(f"{OUTPUT_DIR}/vqvae_v5_results.json", 'w') as f:
    json.dump(results, f, indent=2)

# The "best" checkpoint must hold the weights from the best epoch, not the
# weights left in `model` after the last epoch. Those best weights were
# written to vqvae_v5_best.pt during training; reload them here. If no epoch
# ever improved on the initial best (best_epoch == 0), no such file was
# written by this run, so fall back to the current weights.
_best_weights_path = Path(OUTPUT_DIR) / "vqvae_v5_best.pt"
if best_epoch > 0 and _best_weights_path.exists():
    _best_state_dict = torch.load(str(_best_weights_path), map_location='cpu')
else:
    _best_state_dict = model.state_dict()

# Save checkpoint: best weights plus everything needed to rebuild the model
checkpoint = {
    'model_state_dict': _best_state_dict,
    'config': {
        'vocab_size': VOCAB_SIZE,
        'emb_dim': EMBEDDING_DIM,
        'hidden_dims': HIDDEN_DIMS,
        'fsq_levels': FSQ_LEVELS,
        'terrain_weight': TERRAIN_WEIGHT,
        'building_weight': BUILDING_WEIGHT,
        'air_weight': AIR_WEIGHT,
        'dropout': DROPOUT,
    },
    'air_tokens': AIR_TOKENS_LIST,
    'terrain_tokens': sorted(TERRAIN_TOKENS),
    'best_building_acc': float(best_building_acc),
    'best_epoch': best_epoch,
}

torch.save(checkpoint, f"{OUTPUT_DIR}/vqvae_v5_best_checkpoint.pt")
# Final-epoch weights are kept separately.
torch.save(model.state_dict(), f"{OUTPUT_DIR}/vqvae_v5_final.pt")

print("\nResults saved:")
print(f"  - {OUTPUT_DIR}/vqvae_v5_results.json")
print(f"  - {OUTPUT_DIR}/vqvae_v5_best_checkpoint.pt")
print(f"  - {OUTPUT_DIR}/vqvae_v5_final.pt")
print(f"  - {OUTPUT_DIR}/vqvae_v5_training.png")

print("\n" + "="*70)
print("FINAL RESULTS - FSQ + TERRAIN-AWARE")
print("="*70)
print(f"Best building accuracy: {best_building_acc:.1%} at epoch {best_epoch}")
print(f"Final building recall:  {history['val_building_recall'][-1]:.1%}")
print(f"Final terrain accuracy: {history['val_terrain_acc'][-1]:.1%}")
print(f"FSQ perplexity:         {history['val_fsq_perplexity'][-1]:.0f} effective codes")
print(f"Training time:          {train_time/60:.1f} minutes")

print("\n" + "="*70)
print("KEY INSIGHT")
print("="*70)
print("Building Accuracy is the honest metric that matters.")
print("It excludes trivial terrain blocks that inflate overall accuracy.")
print("If this is high, we're actually reconstructing buildings well.")