# VQ-VAE v8-B Training - 16×16×16 Latent Resolution Upgrade

## Changes from v6-freq

| Change | v6-freq | v8-B |
|--------|---------|------|
| Latent resolution | 8×8×8 (512 positions) | **16×16×16 (4,096 positions)** |
| Compression ratio | 64:1 | **8:1** |
| Downsampling stages | 2 (32→16→8) | **1 (32→16)** |
| ResBlocks at latent | 2 | **6 (more capacity)** |
| Batch size | 4 | **2 (larger model)** |
| Gradient checkpointing | No | **Yes (required for memory)** |
| Volume penalty | No | **Yes (fixes 1.68x over-prediction)** |
| Perceptual loss | No | **Yes (spatial smoothness)** |
| Training epochs | 25 | **35** |

## Why 16×16×16 Latent?

v6-freq's 8×8×8 latent (512 spatial positions) is the primary bottleneck:
- 64:1 compression ratio loses fine structural detail
- Only 512 positions to represent the entire 32³ structure

v8-B upgrades to 16×16×16 (4,096 positions):
- 8:1 compression ratio preserves more structure
- 8x more spatial positions
- Comparable token-grid size to MaskGIT's image setup (256×256 images → 16×16 tokens)

## Goals

| Metric | v6-freq | v8-B Target |
|--------|---------|-------------|
| Building Accuracy | 49.2% | **60-65%** |
| Building Recall | 97.0% | 92-95% |
| Volume Ratio | 1.68x | **1.1-1.2x** |
| False Air Rate | 3.0% | 4-6% |

## Architecture

```
Encoder: 32³ → conv(4,2) → 16³ → ResBlocks(6) → latent[4,16,16,16]
RFSQ: [5,5,5,5] levels × 2 stages (same as v6-freq)
Decoder: latent[4,16,16,16] → ResBlocks(6) → convT(4,2) → 32³ → logits
```

## Setup - Mount Google Drive (if using Colab)

In [None]:
# Uncomment if using Google Colab
# from google.colab import drive
# drive.mount('/content/drive')

import os
import sys
from pathlib import Path

# Marker file whose presence identifies the repository root.
PROJECT_MARKER = Path('src') / 'models' / 'vqvae.py'


def find_project_root(start: Path) -> Path:
    """Return the nearest directory at or above *start* that contains PROJECT_MARKER.

    Every ancestor is checked, including the filesystem root. If none matches,
    *start* itself is returned so callers always get a usable path.
    """
    for candidate in (start, *start.parents):
        if (candidate / PROJECT_MARKER).exists():
            return candidate
    return start


# Detect environment and set paths
if 'COLAB_GPU' in os.environ:
    # Google Colab
    OUTPUT_DIR = '/content/drive/MyDrive/minecraft_ai/vqvae_v8b'
    DRIVE_BASE = '/content/drive/MyDrive/minecraft_ai'
    # Try Drive first (where data is), then /content
    drive_root = Path(DRIVE_BASE)
    if (drive_root / PROJECT_MARKER).exists():
        PROJECT_ROOT = drive_root
        print(f"Using repository from Drive: {PROJECT_ROOT}")
    else:
        PROJECT_ROOT = Path('/content/minecraft_ai')  # Will be cloned here
        print(f"Using repository from /content: {PROJECT_ROOT}")
else:
    # Local: walk up from the working directory to the repository root,
    # falling back to the working directory itself.
    PROJECT_ROOT = find_project_root(Path.cwd())
    OUTPUT_DIR = str(PROJECT_ROOT / 'data' / 'output' / 'vqvae' / 'v8b')
    DRIVE_BASE = str(PROJECT_ROOT / 'data')

# Make `src.*` importable in both environments (previously only done on Colab).
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
    print(f"Added {PROJECT_ROOT} to Python path")

# On Colab the repository may still need to be cloned; tell the user how.
if 'COLAB_GPU' in os.environ and not (PROJECT_ROOT / PROJECT_MARKER).exists():
    print(f"WARNING: Repository not found at {PROJECT_ROOT}")
    print("Please clone the repository first:")
    print("  !git clone <your-repo-url> /content/minecraft_ai")
    print("Or if using Drive:")
    print("  PROJECT_ROOT = Path('/content/drive/MyDrive/minecraft_ai')")

os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"Project root: {PROJECT_ROOT}")
print(f"Python path includes project root: {str(PROJECT_ROOT) in sys.path}")
print(f"src/models/vqvae.py exists: {(PROJECT_ROOT / 'src' / 'models' / 'vqvae.py').exists()}")
print(f"Output will be saved to: {OUTPUT_DIR}")

## Cell 1: Imports

In [None]:
import json
import random
import sys
import time
from pathlib import Path
from typing import Dict, List, Tuple, Any, Set
from collections import Counter

import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.checkpoint import checkpoint
from tqdm.notebook import tqdm

import os  # used below; the setup cell (which also imports os) may not have run

# Ensure PROJECT_ROOT is set and add to path
if 'PROJECT_ROOT' not in globals():
    # Setup cell was skipped: derive the project root here.
    if 'COLAB_GPU' in os.environ:
        PROJECT_ROOT = Path('/content/minecraft_ai')
    else:
        # Nearest directory at or above cwd containing src/models/vqvae.py,
        # falling back to cwd when none does.
        cwd = Path.cwd()
        PROJECT_ROOT = next(
            (p for p in (cwd, *cwd.parents) if (p / 'src' / 'models' / 'vqvae.py').exists()),
            cwd,
        )

# Add project root to path
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
    print(f"Added to path: {PROJECT_ROOT}")

# Verify the module exists before importing from src (gives an actionable error).
vqvae_path = PROJECT_ROOT / 'src' / 'models' / 'vqvae.py'
if not vqvae_path.exists():
    if 'COLAB_GPU' in os.environ:
        raise FileNotFoundError(
            f"Repository not found at {PROJECT_ROOT}\n"
            f"Please clone the repository first:\n"
            f"  !git clone <your-repo-url> /content/minecraft_ai\n"
            f"Or if using Google Drive:\n"
            f"  PROJECT_ROOT = Path('/content/drive/MyDrive/minecraft_ai')"
        )
    else:
        raise FileNotFoundError(
            f"Could not find vqvae.py at {vqvae_path}\n"
            f"Project root: {PROJECT_ROOT}\n"
            f"Current directory: {Path.cwd()}"
        )

# Import v8-B model and loss from src
from src.models.vqvae import VQVAEv8B, ResidualBlock3D
from src.models.losses import FrequencyWeightedLoss, compute_frequency_weights
from src.models.rfsq import RFSQ

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## Cell 2: Configuration

In [None]:
# === Data Paths ===
DATA_DIR = f"{DRIVE_BASE}/splits/train"
VAL_DIR = f"{DRIVE_BASE}/splits/val"
VOCAB_PATH = f"{DRIVE_BASE}/vocabulary/tok2block.json"
V3_EMBEDDINGS_PATH = f"{DRIVE_BASE}/vocabulary/block_embeddings_v3.npy"

# Verify paths exist (report only; nothing here stops execution)
print("Checking paths...")
for name, path in [('DATA_DIR', DATA_DIR), ('VAL_DIR', VAL_DIR),
                   ('VOCAB_PATH', VOCAB_PATH), ('V3_EMBEDDINGS_PATH', V3_EMBEDDINGS_PATH)]:
    exists = Path(path).exists()
    print(f"  {name}: {path} {'[OK]' if exists else '[NOT FOUND]'}")

# === V8-B Architecture Configuration ===
HIDDEN_DIM = 192  # Single hidden dim (simpler than [96, 192])
RFSQ_LEVELS_PER_STAGE = [5, 5, 5, 5]  # Same as v6-freq
NUM_STAGES = 2
DROPOUT = 0.1
# NOTE(review): the two resblock counts below are printed/documented but are not
# passed to VQVAEv8B in the model cell; confirm the model's own defaults match.
NUM_RESBLOCKS_PER_STAGE = 2
NUM_RESBLOCKS_LATENT = 6  # More capacity at latent resolution

# === Frequency Weighting (from v6-freq) ===
# NOTE(review): this flag is only reported; the loss is always built with the
# frequency-weight tensor regardless of its value.
USE_FREQUENCY_WEIGHTING = True
FREQUENCY_WEIGHT_CAP = 5.0  # Reduced from 10.0 for better balance

# === NEW: Volume Penalty and Perceptual Loss ===
VOLUME_PENALTY_WEIGHT = 1.0  # NEW: targets v6-freq's 1.68x over-prediction
PERCEPTUAL_WEIGHT = 0.1      # NEW: Spatial smoothness

# === Terrain Settings ===
TERRAIN_WEIGHT = 0.2
BUILDING_WEIGHT = 1.0
AIR_WEIGHT = 0.1

# === Training Configuration ===
TOTAL_EPOCHS = 35  # Increased from 25
BATCH_SIZE = 2  # Reduced from 4 (larger model)
BASE_LR = 2e-4  # Slightly lower than v6-freq's 3e-4
USE_AMP = True
GRAD_ACCUM_STEPS = 8  # Effective batch = BATCH_SIZE * GRAD_ACCUM_STEPS = 16
# NOTE(review): not passed to VQVAEv8B in the model cell; confirm the model
# enables checkpointing on its own.
USE_GRADIENT_CHECKPOINTING = True  # NEW: Required for T4 memory

SEED = 42
NUM_WORKERS = 2

# Each RFSQ stage has prod(levels) codes; stages compose multiplicatively.
CODES_PER_STAGE = int(np.prod(RFSQ_LEVELS_PER_STAGE))
TOTAL_IMPLICIT_CODES = CODES_PER_STAGE ** NUM_STAGES

print("\nVQ-VAE v8-B (16x16x16 Latent) Configuration:")
print("  Latent resolution: 16x16x16 = 4,096 spatial positions")
print("  Compression ratio: 8:1 (vs v6-freq's 64:1)")
print(f"  Hidden dim: {HIDDEN_DIM}")
print(f"  ResBlocks at latent: {NUM_RESBLOCKS_LATENT}")
print(f"  RFSQ: {NUM_STAGES} stages x {CODES_PER_STAGE:,} codes")
print(f"  Total implicit codes: {TOTAL_IMPLICIT_CODES:,}")
print("\nNEW FEATURES:")
print(f"  Gradient checkpointing: {USE_GRADIENT_CHECKPOINTING}")
print(f"  Volume penalty weight: {VOLUME_PENALTY_WEIGHT}")
print(f"  Perceptual weight: {PERCEPTUAL_WEIGHT}")
print(f"  Frequency weight cap: {FREQUENCY_WEIGHT_CAP}x (reduced from 10x)")
print("\nTRAINING:")
print(f"  Batch size: {BATCH_SIZE} (effective: {BATCH_SIZE * GRAD_ACCUM_STEPS})")
print(f"  Epochs: {TOTAL_EPOCHS}")
print(f"  Learning rate: {BASE_LR}")

## Cell 3: Load Vocabulary and Embeddings

In [None]:
# Token-id -> block-name mapping; JSON object keys are strings, so convert back to int.
with open(VOCAB_PATH, 'r') as vocab_fh:
    tok2block = {int(tok_key): block_name for tok_key, block_name in json.load(vocab_fh).items()}

# Reverse lookup (if several ids share a name, the last one wins).
block2tok = {block_name: tok_id for tok_id, block_name in tok2block.items()}
VOCAB_SIZE = len(tok2block)
print(f"Vocabulary size: {VOCAB_SIZE}")

# Air-like tokens: names containing "air" but not "stair", so stair blocks are excluded.
AIR_TOKENS: Set[int] = set()
for tok_id, block_name in tok2block.items():
    lowered = block_name.lower()
    if 'air' not in lowered or 'stair' in lowered:
        continue
    AIR_TOKENS.add(tok_id)
    print(f"  Air token: {tok_id} = {block_name}")

AIR_TOKENS_LIST = sorted(AIR_TOKENS)

# Pretrained V3 block embeddings, cast to float32; columns give the embedding size.
v3_embeddings = np.load(V3_EMBEDDINGS_PATH).astype(np.float32)
EMBEDDING_DIM = v3_embeddings.shape[1]
print(f"V3 embeddings: {v3_embeddings.shape} (dim={EMBEDDING_DIM})")

## Cell 4: Compute Block Frequencies

In [None]:
print("Computing block frequencies from training data...")
print("(This may take a few minutes)")

# Scan all training data. Cast to int64 *before* torch.from_numpy (as VQVAEDataset
# does): older torch releases reject unsigned NumPy dtypes such as uint16.
train_files = sorted(Path(DATA_DIR).glob("*.h5"))
if not train_files:
    raise ValueError(f"No H5 files in {DATA_DIR}")

all_block_ids = []
for h5_file in tqdm(train_files, desc="Scanning training data"):
    with h5py.File(h5_file, 'r') as f:
        key = list(f.keys())[0]
        structure = f[key][:].astype(np.int64).ravel()
        all_block_ids.append(torch.from_numpy(structure))

all_block_ids = torch.cat(all_block_ids)
print(f"\nTotal blocks scanned: {len(all_block_ids):,}")

# Compute frequency weights using helper function
FREQUENCY_WEIGHT_TENSOR = compute_frequency_weights(
    all_block_ids,
    vocab_size=VOCAB_SIZE,
    smoothing=0.5  # sqrt weighting
)

# Cap at max weight
FREQUENCY_WEIGHT_TENSOR = FREQUENCY_WEIGHT_TENSOR.clamp(max=FREQUENCY_WEIGHT_CAP)

# Per-token occurrence counts, computed in one pass instead of rescanning
# all_block_ids for every printed token (bincount assumes ids are non-negative).
BLOCK_COUNTS = torch.bincount(
    all_block_ids, minlength=max(VOCAB_SIZE, FREQUENCY_WEIGHT_TENSOR.numel())
)


def _print_weighted_blocks(title, token_ids, weight_fmt):
    """Print name, capped frequency weight and raw count for each id in *token_ids*."""
    print(title)
    for tok in token_ids.tolist():
        weight = FREQUENCY_WEIGHT_TENSOR[tok].item()
        block_name = tok2block.get(tok, f"UNKNOWN_{tok}")
        count = BLOCK_COUNTS[tok].item()
        print(f"  {block_name}: weight={weight:{weight_fmt}}x (count={count:,})")


_print_weighted_blocks(
    "\nTop 10 RAREST blocks (highest weights):",
    torch.argsort(FREQUENCY_WEIGHT_TENSOR, descending=True)[:10],
    ".1f",
)
_print_weighted_blocks(
    "\nTop 10 MOST COMMON blocks (lowest weights):",
    torch.argsort(FREQUENCY_WEIGHT_TENSOR)[:10],
    ".2f",
)

# Identify rare block categories for tracking (substring match on block names)
RARE_BLOCK_KEYWORDS = ['chest', 'door', 'fence', 'trapdoor', 'carpet', 'bed', 'button', 'lever']
RARE_BLOCK_TOKENS = {
    tok for tok, block in tok2block.items()
    if any(keyword in block.lower() for keyword in RARE_BLOCK_KEYWORDS)
}

RARE_BLOCK_TOKENS_TENSOR = torch.tensor(sorted(RARE_BLOCK_TOKENS), dtype=torch.long)
print(f"\nRare block tokens identified: {len(RARE_BLOCK_TOKENS)}")
print(f"Keywords: {RARE_BLOCK_KEYWORDS}")

## Cell 5: Terrain Detection

In [None]:
# Base block names (without "[state]" suffix) treated as natural terrain.
TERRAIN_BLOCKS: Set[str] = {
    'minecraft:dirt', 'minecraft:grass_block', 'minecraft:coarse_dirt',
    'minecraft:podzol', 'minecraft:mycelium', 'minecraft:rooted_dirt',
    'minecraft:dirt_path', 'minecraft:farmland', 'minecraft:mud',
    'minecraft:stone', 'minecraft:cobblestone', 'minecraft:mossy_cobblestone',
    'minecraft:bedrock', 'minecraft:deepslate', 'minecraft:tuff',
    'minecraft:granite', 'minecraft:diorite', 'minecraft:andesite',
    'minecraft:sand', 'minecraft:red_sand', 'minecraft:gravel', 'minecraft:clay',
    'minecraft:water', 'minecraft:lava',
    'minecraft:terracotta', 'minecraft:white_terracotta', 'minecraft:orange_terracotta',
    'minecraft:brown_terracotta', 'minecraft:red_terracotta',
    'minecraft:netherrack', 'minecraft:soul_sand', 'minecraft:soul_soil',
    'minecraft:end_stone',
    'minecraft:snow_block', 'minecraft:ice', 'minecraft:packed_ice',
}

# A token is terrain when its name, with any "[...]" block-state suffix stripped,
# is in TERRAIN_BLOCKS (split('[')[0] is the whole name when there is no '[').
TERRAIN_TOKENS: Set[int] = {
    tok_id
    for tok_id, block_name in tok2block.items()
    if block_name.split('[')[0] in TERRAIN_BLOCKS
}

TERRAIN_TOKENS_TENSOR = torch.tensor(sorted(TERRAIN_TOKENS), dtype=torch.long)
print(f"Terrain tokens: {len(TERRAIN_TOKENS)}")

## Cell 6: Dataset

In [None]:
class VQVAEDataset(Dataset):
    """Map-style dataset yielding one voxel structure per ``*.h5`` file in a directory.

    Files are sorted by path. Each item is the first dataset (in h5py key order)
    of its file, converted to an int64 tensor.
    """

    def __init__(self, data_dir: str):
        self.data_dir = Path(data_dir)
        self.h5_files = sorted(self.data_dir.glob("*.h5"))
        if len(self.h5_files) == 0:
            raise ValueError(f"No H5 files in {data_dir}")
        print(f"Found {len(self.h5_files)} structures in {data_dir}")

    def __len__(self):
        return len(self.h5_files)

    def __getitem__(self, idx):
        h5_path = self.h5_files[idx]
        with h5py.File(h5_path, 'r') as h5:
            keys = list(h5.keys())
            voxels = h5[keys[0]][:].astype(np.int64)
        return torch.from_numpy(voxels).long()

# Build train/val datasets; each constructor prints how many files it found.
train_dataset, val_dataset = VQVAEDataset(DATA_DIR), VQVAEDataset(VAL_DIR)
print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")

## Cell 7: Create Model

In [None]:
# Seed every RNG the notebook touches for reproducibility.
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.empty_cache()

# Create model
# NOTE(review): NUM_RESBLOCKS_LATENT, NUM_RESBLOCKS_PER_STAGE and
# USE_GRADIENT_CHECKPOINTING from the config cell are not passed here, so the
# model uses whatever VQVAEv8B sets internally; confirm against src/models/vqvae.py.
model = VQVAEv8B(
    vocab_size=VOCAB_SIZE,
    emb_dim=EMBEDDING_DIM,
    hidden_dim=HIDDEN_DIM,
    rfsq_levels=RFSQ_LEVELS_PER_STAGE,
    num_stages=NUM_STAGES,
    dropout=DROPOUT,
    pretrained_embeddings=torch.from_numpy(v3_embeddings),
).to(device)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nModel Statistics:")
print(f"  Total params: {total_params:,}")
print(f"  Trainable params: {trainable_params:,}")
print(f"  RFSQ total codes: {model.quantizer.codebook_size:,}")
print(f"  Latent shape: [B, {model.rfsq_dim}, 16, 16, 16]")

# Create loss function
criterion = FrequencyWeightedLoss(
    frequency_weights=FREQUENCY_WEIGHT_TENSOR,
    frequency_cap=FREQUENCY_WEIGHT_CAP,
    terrain_weight=TERRAIN_WEIGHT,
    building_weight=BUILDING_WEIGHT,
    air_weight=AIR_WEIGHT,
    volume_penalty_weight=VOLUME_PENALTY_WEIGHT,
    perceptual_weight=PERCEPTUAL_WEIGHT,
    air_tokens=AIR_TOKENS,
).to(device)

print(f"\nLoss Function:")
print(f"  Frequency weighting: {USE_FREQUENCY_WEIGHTING} (cap={FREQUENCY_WEIGHT_CAP}x)")
print(f"  Volume penalty: {VOLUME_PENALTY_WEIGHT}")
print(f"  Perceptual loss: {PERCEPTUAL_WEIGHT}")

# Create optimizer and scheduler
optimizer = optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=BASE_LR,
    weight_decay=1e-4,
    betas=(0.9, 0.999)
)

# T_max is in epochs, so this assumes scheduler.step() is called once per epoch.
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=TOTAL_EPOCHS,
    eta_min=1e-5
)

# A CUDA GradScaler only does anything on CUDA; on a CPU-only host it would warn
# and disable itself, so enable it explicitly only when it can take effect.
scaler = torch.amp.GradScaler('cuda', enabled=USE_AMP and device == "cuda")

print(f"\nOptimizer: AdamW")
print(f"  LR: {BASE_LR}")
print(f"  Weight decay: 1e-4")
print(f"  Scheduler: CosineAnnealingLR")
print(f"  Mixed precision: {USE_AMP}")

## Cell 8: Create Data Loaders

In [None]:
# Page-locked host memory only speeds up host->GPU copies; on CPU it is useless
# and DataLoader warns about it, so enable it only when training on CUDA.
PIN_MEMORY = device == "cuda"

# Seeded generator makes the training shuffle order reproducible.
g = torch.Generator().manual_seed(SEED)
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    generator=g
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY
)

print(f"Data loaders created:")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")

## Cell 9: Compute Metrics Function

In [None]:
def compute_metrics(logits, targets, air_tokens, terrain_tokens, rare_tokens, block_embeddings):
    """Compute ALL training metrics (complete checklist from CLAUDE.md).

    Voxels are classified by their *target* token as air (``air_tokens``),
    terrain (``terrain_tokens`` that are not air) or building (everything
    else); ``rare_tokens`` marks a separately tracked subset. Predictions are
    the per-voxel argmax over the class axis.

    Args:
        logits: ``[B, C, X, Y, Z]`` class scores.
        targets: Token ids with ``B*X*Y*Z`` elements, in the same voxel order
            as ``logits.permute(0, 2, 3, 4, 1)``.
        air_tokens, terrain_tokens, rare_tokens: 1-D integer tensors of ids.
        block_embeddings: Table indexed by token id whose rows are embedding
            vectors (torch tensor or NumPy array); used to measure how similar
            wrong predictions are to the ground truth.

    Returns:
        Dict mapping metric name to a 0-d tensor on ``logits.device``. Ratios
        whose denominator is empty are reported as 0.0, except ``vol_ratio``
        and ``air_ratio`` which default to 1.0.
    """
    device = logits.device
    num_classes = logits.shape[1]

    logits_flat = logits.permute(0, 2, 3, 4, 1).reshape(-1, num_classes)
    targets_flat = targets.reshape(-1)

    air_dev = air_tokens.to(device)
    terrain_dev = terrain_tokens.to(device)
    rare_dev = rare_tokens.to(device)
    # Accept a NumPy table or a tensor on another device; index it on the logits device.
    embeddings = torch.as_tensor(block_embeddings, device=device)

    # --- target classes ---
    is_air = torch.isin(targets_flat, air_dev)
    is_terrain = torch.isin(targets_flat, terrain_dev) & ~is_air
    is_building = ~is_air & ~is_terrain
    is_rare = torch.isin(targets_flat, rare_dev)
    is_struct = ~is_air

    # --- predictions ---
    preds = logits_flat.argmax(dim=1)
    is_air_pred = torch.isin(preds, air_dev)
    pred_non_air = ~is_air_pred
    # Keep correctness as a bool mask for set logic: ``bool & float`` has no
    # bitwise kernel in torch and raises (this crashed the precision metrics).
    is_correct = preds == targets_flat
    correct = is_correct.float()

    def _zero():
        return torch.tensor(0.0, device=device)

    def _rate(hits, population):
        """Fraction of ``population`` voxels also in ``hits``; caller ensures population is non-empty."""
        return (hits & population).sum().float() / population.sum()

    def _error_similarity(mask):
        """Mean cosine similarity of predicted vs. true embeddings over ``mask`` (0 if empty)."""
        if not mask.any():
            return _zero()
        pred_emb = embeddings[preds[mask]]
        true_emb = embeddings[targets_flat[mask]]
        return F.cosine_similarity(pred_emb, true_emb, dim=-1).mean()

    metrics = {}

    # === CORE METRICS ===
    metrics['overall_acc'] = correct.mean()
    metrics['terrain_acc'] = correct[is_terrain].mean() if is_terrain.any() else _zero()

    # === BUILDING METRICS ===
    if is_building.any():
        metrics['building_acc'] = correct[is_building].mean()
        metrics['building_recall'] = _rate(pred_non_air, is_building)
        metrics['building_false_air'] = _rate(is_air_pred, is_building)
        # Precision over every voxel predicted non-air (terrain predictions included).
        metrics['building_precision'] = (
            _rate(is_building & is_correct, pred_non_air) if pred_non_air.any() else _zero()
        )
        precision = metrics['building_precision']
        recall = metrics['building_recall']
        metrics['building_f1'] = (
            2 * precision * recall / (precision + recall) if precision + recall > 0 else _zero()
        )
    else:
        for key in ('building_acc', 'building_recall', 'building_false_air',
                    'building_precision', 'building_f1'):
            metrics[key] = _zero()

    # === AIR METRICS (COMPLETE) ===
    if is_air.any():
        metrics['air_acc'] = correct[is_air].mean()
        metrics['air_precision'] = (
            _rate(is_air & is_correct, is_air_pred) if is_air_pred.any() else _zero()
        )
        # False block rate: true air predicted as some non-air block.
        metrics['false_block_rate'] = _rate(pred_non_air, is_air)
    else:
        metrics['air_acc'] = _zero()
        metrics['air_precision'] = _zero()
        metrics['false_block_rate'] = _zero()

    # === RARE BLOCK METRICS ===
    if is_rare.any():
        metrics['rare_acc'] = correct[is_rare].mean()
        metrics['rare_recall'] = _rate(pred_non_air, is_rare)
        pred_rare = torch.isin(preds, rare_dev)
        metrics['rare_precision'] = (
            _rate(is_rare & is_correct, pred_rare) if pred_rare.any() else _zero()
        )
    else:
        metrics['rare_acc'] = _zero()
        metrics['rare_recall'] = _zero()
        metrics['rare_precision'] = _zero()

    # === STRUCTURE RECALL ===
    metrics['struct_recall'] = _rate(pred_non_air, is_struct) if is_struct.any() else _zero()

    # === VOLUME METRICS ===
    # Predicted non-air volume relative to true non-air volume (>1 = over-prediction).
    orig_vol = is_struct.sum().float()
    pred_vol = pred_non_air.sum().float()
    metrics['vol_ratio'] = pred_vol / orig_vol if orig_vol > 0 else torch.tensor(1.0, device=device)

    orig_air_vol = is_air.sum().float()
    pred_air_vol = is_air_pred.sum().float()
    metrics['air_ratio'] = pred_air_vol / orig_air_vol if orig_air_vol > 0 else torch.tensor(1.0, device=device)

    # === ERROR SIMILARITY (CRITICAL FOR v9 ANALYSIS) ===
    wrong = ~is_correct
    metrics['error_similarity'] = _error_similarity(wrong)
    metrics['terrain_error_similarity'] = _error_similarity(is_terrain & wrong)
    metrics['building_error_similarity'] = _error_similarity(is_building & wrong)

    return metrics

print("Metrics computation function defined (COMPLETE - all metrics from CLAUDE.md)")

## Cell 10: Training Functions

In [None]:
def _rfsq_residual_norms(model, batch):
    """Return residual L2 norms through the RFSQ stages for one batch (diagnostic only).

    The list holds the norm of the residual entering each stage, followed by the
    norm remaining after the last stage. Each stage is applied as
    ``layernorm -> fsq -> layernorm.inverse`` and its output is subtracted from
    the residual. Runs without autograd, in whatever train/eval mode the model is in.
    """
    with torch.no_grad():
        residual = model.encode(batch)
        norms = []
        for stage in model.quantizer.stages:
            norms.append(residual.norm().item())
            z_q_norm, _ = stage.fsq(stage.layernorm(residual))
            residual = residual - stage.layernorm.inverse(z_q_norm)
        norms.append(residual.norm().item())  # Final residual
    return norms


def train_epoch(model, criterion, loader, optimizer, scaler, device, air_tokens, terrain_tokens, rare_tokens, block_embeddings):
    """Train for one epoch and return metrics averaged over batches.

    Uses autocast (``USE_AMP``) and gradient accumulation over
    ``GRAD_ACCUM_STEPS`` batches with gradient-norm clipping at 1.0. A trailing
    partial accumulation window is applied at the end of the epoch so those
    batches' gradients are not discarded.

    Returns:
        Dict with averaged loss components and ``compute_metrics`` outputs,
        per-stage RFSQ usage/perplexity, mean pre-clipping gradient norm
        (``grad_norm``) and ``residual_decay`` (final / first residual norm).
    """
    model.train()
    model.quantizer.reset_usage()

    metric_keys = (
        'loss', 'ce_loss', 'volume_loss', 'perceptual_loss',
        'overall_acc', 'terrain_acc', 'building_acc',
        'building_recall', 'building_false_air', 'building_precision',
        'building_f1',
        'struct_recall', 'vol_ratio', 'air_ratio',
        'air_acc', 'air_precision', 'false_block_rate',
        'rare_acc', 'rare_recall', 'rare_precision',
        'error_similarity', 'terrain_error_similarity', 'building_error_similarity',
    )
    metrics_sum = dict.fromkeys(metric_keys, 0.0)
    grad_norms = []
    all_residual_norms = []
    n = 0

    def _optimizer_step():
        """Unscale, clip, step and reset the accumulated gradients."""
        scaler.unscale_(optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        grad_norms.append(grad_norm.item())
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    optimizer.zero_grad()

    for batch_idx, batch in enumerate(tqdm(loader, desc="Train", leave=False)):
        batch = batch.to(device)

        with torch.amp.autocast('cuda', enabled=USE_AMP):
            logits, z_q, _indices = model(batch)
            loss_dict = criterion(logits, batch, z_q)
            # Divide so the accumulated gradient is the mean over the window.
            loss = loss_dict['loss'] / GRAD_ACCUM_STEPS

        scaler.scale(loss).backward()

        # Track residual norms every 100 batches
        if batch_idx % 100 == 0:
            all_residual_norms.append(_rfsq_residual_norms(model, batch))

        if (batch_idx + 1) % GRAD_ACCUM_STEPS == 0:
            _optimizer_step()

        with torch.no_grad():
            for key in ('loss', 'ce_loss', 'volume_loss', 'perceptual_loss'):
                metrics_sum[key] += loss_dict[key].item()
            # 'vol_ratio' comes only from compute_metrics below. Also adding
            # loss_dict['volume_ratio'] into the same key double-counted it (~2x).
            # NOTE(review): validate() adds loss_dict['volume_ratio'] to 'vol_ratio'
            # too; if it also sums compute_metrics' value it needs the same fix.
            batch_metrics = compute_metrics(logits, batch, air_tokens, terrain_tokens, rare_tokens, block_embeddings)
            for k, v in batch_metrics.items():
                metrics_sum[k] += v.item()

        n += 1

    # Apply a leftover partial accumulation window instead of silently dropping it.
    if n % GRAD_ACCUM_STEPS != 0:
        _optimizer_step()

    # Average all metrics (guard against an empty loader)
    metrics = {k: v / max(n, 1) for k, v in metrics_sum.items()}

    # Add RFSQ stats
    stage_stats = model.quantizer.get_usage_stats()
    for stage_name, (usage, perp) in stage_stats.items():
        metrics[f'{stage_name}_usage'] = usage
        metrics[f'{stage_name}_perplexity'] = perp

    metrics['grad_norm'] = sum(grad_norms) / len(grad_norms) if grad_norms else 0.0

    # Residual decay (CRITICAL RFSQ DIAGNOSTIC): mean final norm / mean first norm.
    if all_residual_norms:
        avg_norms = [sum(x) / len(all_residual_norms) for x in zip(*all_residual_norms)]
        metrics['residual_decay'] = avg_norms[-1] / avg_norms[0] if avg_norms[0] > 0 else 1.0
    else:
        metrics['residual_decay'] = 1.0

    return metrics


@torch.no_grad()
def validate(model, criterion, loader, device, air_tokens, terrain_tokens, rare_tokens, block_embeddings):
    """Run one pass over the validation set and return batch-averaged metrics.

    Args:
        model: VQ-VAE called as ``model(batch) -> (logits, z_q, indices)`` that
            also exposes ``encode(batch)`` and a ``quantizer`` with ``stages``,
            ``reset_usage()`` and ``get_usage_stats()``.
        criterion: loss called as ``criterion(logits, batch, z_q)``; must return a
            dict containing ``loss``, ``ce_loss``, ``volume_loss``,
            ``perceptual_loss`` and ``volume_ratio`` tensors.
        loader: iterable of validation batches (each moved to ``device``).
        device: target device for the batches.
        air_tokens, terrain_tokens, rare_tokens, block_embeddings: forwarded
            unchanged to ``compute_metrics``.

    Returns:
        dict of metric name -> float averaged over batches, plus
        ``<stage>_usage`` / ``<stage>_perplexity`` from the quantizer and
        ``residual_decay`` (mean final residual norm / mean encoder-output norm,
        sampled every 50 batches; 1.0 if it could not be computed). The loss's
        volume ratio is reported under both ``vol_ratio`` and ``volume_ratio``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    model.quantizer.reset_usage()

    metrics_sum = {
        'loss': 0.0, 'ce_loss': 0.0, 'volume_loss': 0.0, 'perceptual_loss': 0.0,
        'overall_acc': 0.0, 'terrain_acc': 0.0, 'building_acc': 0.0,
        'building_recall': 0.0, 'building_false_air': 0.0, 'building_precision': 0.0,
        'building_f1': 0.0,
        'struct_recall': 0.0, 'vol_ratio': 0.0, 'air_ratio': 0.0,
        'air_acc': 0.0, 'air_precision': 0.0, 'false_block_rate': 0.0,
        'rare_acc': 0.0, 'rare_recall': 0.0, 'rare_precision': 0.0,
        'error_similarity': 0.0, 'terrain_error_similarity': 0.0, 'building_error_similarity': 0.0,
    }
    all_residual_norms = []
    n = 0

    for batch_idx, batch in enumerate(tqdm(loader, desc="Val", leave=False)):
        batch = batch.to(device)

        with torch.amp.autocast('cuda', enabled=USE_AMP):
            logits, z_q, indices = model(batch)
            loss_dict = criterion(logits, batch, z_q)

        # Every 50 batches, replay the residual quantizer stage by stage and record
        # ||residual|| before each stage plus the norm left after the last stage.
        if batch_idx % 50 == 0:
            residual = model.encode(batch)
            norms = []
            for stage in model.quantizer.stages:
                norms.append(residual.norm().item())
                z_q_norm, _ = stage.fsq(stage.layernorm(residual))
                residual = residual - stage.layernorm.inverse(z_q_norm)
            norms.append(residual.norm().item())
            all_residual_norms.append(norms)

        for key in ('loss', 'ce_loss', 'volume_loss', 'perceptual_loss'):
            metrics_sum[key] += loss_dict[key].item()
        metrics_sum['vol_ratio'] += loss_dict['volume_ratio'].item()

        batch_metrics = compute_metrics(logits, batch, air_tokens, terrain_tokens, rare_tokens, block_embeddings)
        for k, v in batch_metrics.items():
            metrics_sum[k] += v.item()

        n += 1

    if n == 0:
        # Averaging below would otherwise fail with an opaque ZeroDivisionError.
        raise ValueError("validate() received an empty loader; cannot average metrics over zero batches")

    metrics = {k: v / n for k, v in metrics_sum.items()}
    # Keep the historical 'vol_ratio' key and also expose the value under the
    # loss's own name, which is what the training history looks up.
    metrics['volume_ratio'] = metrics['vol_ratio']

    for stage_name, (usage, perp) in model.quantizer.get_usage_stats().items():
        metrics[f'{stage_name}_usage'] = usage
        metrics[f'{stage_name}_perplexity'] = perp

    # Residual decay: final residual norm relative to the encoder-output norm.
    if all_residual_norms:
        avg_norms = [sum(col) / len(all_residual_norms) for col in zip(*all_residual_norms)]
        metrics['residual_decay'] = avg_norms[-1] / avg_norms[0] if avg_norms[0] > 0 else 1.0
    else:
        metrics['residual_decay'] = 1.0

    return metrics


print("Training functions defined (with ALL metrics)")

## Cell 11: Training Loop

In [None]:
print("="*70)
print("VQ-VAE V8-B TRAINING - 16×16×16 LATENT RESOLUTION")
print("="*70)
print("Key improvements over v6-freq:")
print("  - 8x more spatial positions (4,096 vs 512)")
print("  - 8:1 compression ratio (vs 64:1)")
print("  - Volume penalty (fixes 1.68x over-prediction)")
print("  - Perceptual loss (spatial smoothness)")
print("\nNEW METRICS TRACKED:")
print("  - building_f1, air_precision, air_ratio")
print("  - false_block_rate, rare_precision")
print("  - error_similarity (overall + terrain + building)")
print("  - residual_decay (RFSQ diagnostic)")
print()

# Block embeddings passed to compute_metrics() for error-similarity metrics.
# NOTE(review): `.data` is detached from autograd, and `.to(device)` returns the
# same tensor when it is already on `device`, so this only tracks updates to a
# trainable block_emb in that case; otherwise it is a frozen snapshot. Confirm intent.
block_embeddings_tensor = model.block_emb.weight.data.to(device)

# One list per tracked metric; 'train_<m>' / 'val_<m>' are filled from the
# dicts returned by train_epoch() / validate() each epoch.
history = {
    'train_loss': [], 'train_ce_loss': [], 'train_volume_loss': [], 'train_perceptual_loss': [],
    'train_building_acc': [], 'train_building_recall': [], 'train_building_precision': [], 'train_building_f1': [],
    'train_terrain_acc': [], 'train_struct_recall': [],
    'train_air_acc': [], 'train_air_precision': [], 'train_air_ratio': [], 'train_false_block_rate': [],
    'train_building_false_air': [],
    'train_volume_ratio': [],
    'train_rare_acc': [], 'train_rare_recall': [], 'train_rare_precision': [],
    'train_error_similarity': [], 'train_terrain_error_similarity': [], 'train_building_error_similarity': [],
    'val_loss': [], 'val_ce_loss': [], 'val_volume_loss': [], 'val_perceptual_loss': [],
    'val_building_acc': [], 'val_building_recall': [], 'val_building_precision': [], 'val_building_f1': [],
    'val_terrain_acc': [], 'val_struct_recall': [],
    'val_air_acc': [], 'val_air_precision': [], 'val_air_ratio': [], 'val_false_block_rate': [],
    'val_building_false_air': [],
    'val_volume_ratio': [],
    'val_rare_acc': [], 'val_rare_recall': [], 'val_rare_precision': [],
    'val_error_similarity': [], 'val_terrain_error_similarity': [], 'val_building_error_similarity': [],
    'train_stage0_usage': [], 'train_stage0_perplexity': [],
    'train_stage1_usage': [], 'train_stage1_perplexity': [],
    'val_stage0_usage': [], 'val_stage0_perplexity': [],
    'val_stage1_usage': [], 'val_stage1_perplexity': [],
    'train_grad_norm': [], 'train_residual_decay': [], 'val_residual_decay': [],
    'learning_rate': [],
}

# The epoch functions accumulate the loss's volume ratio under 'vol_ratio',
# while the history uses 'volume_ratio'; map between the two names.
METRIC_ALIASES = {'volume_ratio': 'vol_ratio'}
_missing_metric_keys = set()


def _history_value(metrics, name, history_key):
    """Return metrics[name] (or its alias); warn once per history key and return 0 if absent."""
    if name in metrics:
        return metrics[name]
    alias = METRIC_ALIASES.get(name)
    if alias in metrics:
        return metrics[alias]
    if history_key not in _missing_metric_keys:
        _missing_metric_keys.add(history_key)
        print(f"WARNING: no metric '{name}' for history key '{history_key}'; recording 0")
    return 0


best_building_acc = 0
best_epoch = 0
start_time = time.time()

for epoch in range(TOTAL_EPOCHS):
    train_m = train_epoch(model, criterion, train_loader, optimizer, scaler, device,
                          AIR_TOKENS_TENSOR, TERRAIN_TOKENS_TENSOR, RARE_BLOCK_TOKENS_TENSOR, block_embeddings_tensor)
    val_m = validate(model, criterion, val_loader, device,
                     AIR_TOKENS_TENSOR, TERRAIN_TOKENS_TENSOR, RARE_BLOCK_TOKENS_TENSOR, block_embeddings_tensor)

    # Scheduler is stepped once per epoch
    scheduler.step()
    current_lr = scheduler.get_last_lr()[0]

    # Record all metrics
    for key, values in history.items():
        if key == 'learning_rate':
            values.append(current_lr)
        elif key.startswith('train_'):
            values.append(_history_value(train_m, key[len('train_'):], key))
        elif key.startswith('val_'):
            values.append(_history_value(val_m, key[len('val_'):], key))

    # Save best model (selected on validation building accuracy)
    if val_m['building_acc'] > best_building_acc:
        best_building_acc = val_m['building_acc']
        best_epoch = epoch + 1
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_building_acc': best_building_acc,
        }, f"{OUTPUT_DIR}/vqvae_v8b_best.pt")

    # Save checkpoint every 5 epochs (weights only)
    if (epoch + 1) % 5 == 0:
        torch.save(model.state_dict(), f"{OUTPUT_DIR}/vqvae_v8b_epoch{epoch+1}.pt")

    print(f"Epoch {epoch+1:2d} | "
          f"Build: {train_m['building_acc']:.1%}/{val_m['building_acc']:.1%} | "
          f"Vol: {val_m['vol_ratio']:.2f}x | "
          f"ErrSim: {val_m['error_similarity']:.2f} | "
          f"Decay: {val_m['residual_decay']:.2f} | "
          f"LR: {current_lr:.2e}")

train_time = time.time() - start_time
print(f"\nTraining complete in {train_time/60:.1f} minutes")
print(f"Best val building accuracy: {best_building_acc:.1%} at epoch {best_epoch}")

## Cell 12: Plot Training Curves

In [None]:
fig, axes = plt.subplots(5, 4, figsize=(24, 20))

# Size the x-axis from what was actually recorded rather than TOTAL_EPOCHS, so
# the figure still renders (with fewer points) after an interrupted run.
n_recorded = len(history['val_building_acc'])
if n_recorded == 0:
    raise RuntimeError("history is empty - run the training loop cell before plotting")
epochs = range(1, n_recorded + 1)
# Last recorded value of every non-empty history series.
last_values = {key: values[-1] for key, values in history.items() if values}

TRAIN_STYLE = dict(color='b', linestyle='-')     # same as fmt 'b-'
VAL_STYLE = dict(color='r', linestyle='--')      # same as fmt 'r--'
BOLD = dict(fontweight='bold')
NEW_METRIC = dict(fontweight='bold', color='blue')
CRITICAL = dict(fontweight='bold', color='red')
NOTE_BOX = dict(boxstyle='round', facecolor='wheat', alpha=0.5)


def _plot_curves(ax, curves):
    """Plot each (history_key, line_style, label) curve on ax; keys missing from history are skipped."""
    for key, line_style, label in curves:
        values = history.get(key)
        if values:
            ax.plot(epochs, values, label=label, **line_style)


def _plot_train_val(ax, metric, train_label='Train', val_label='Val'):
    """Plot the train (solid blue) and val (dashed red) history of metric on ax."""
    _plot_curves(ax, [(f'train_{metric}', TRAIN_STYLE, train_label),
                      (f'val_{metric}', VAL_STYLE, val_label)])


def _finish(ax, title, legend_kw=None, **title_kw):
    """Set the panel title, then draw the legend and a light grid."""
    ax.set_title(title, **title_kw)
    ax.legend(**(legend_kw or {}))
    ax.grid(True, alpha=0.3)


# === ROW 1: Core Building Metrics ===
ax = axes[0, 0]
_plot_train_val(ax, 'building_acc')
ax.axhline(y=0.492, color='g', linestyle=':', alpha=0.5, label='v6-freq (49.2%)')
ax.axhline(y=0.60, color='orange', linestyle='--', alpha=0.5, label='Target (60%)')
_finish(ax, 'Building Accuracy', **BOLD)

_plot_train_val(axes[0, 1], 'building_f1')
_finish(axes[0, 1], 'Building F1 Score (NEW)', **NEW_METRIC)

_plot_train_val(axes[0, 2], 'building_precision')
_finish(axes[0, 2], 'Building Precision', **BOLD)

_plot_train_val(axes[0, 3], 'building_recall')
_finish(axes[0, 3], 'Building Recall')

# === ROW 2: Volume & Air Metrics ===
ax = axes[1, 0]
_plot_train_val(ax, 'volume_ratio')
ax.axhline(y=1.0, color='g', linestyle='--', alpha=0.5, label='Target (1.0x)')
ax.axhline(y=1.68, color='orange', linestyle=':', alpha=0.5, label='v6-freq (1.68x)')
_finish(ax, 'Volume Ratio (Building)', **CRITICAL)

ax = axes[1, 1]
_plot_train_val(ax, 'air_ratio')
ax.axhline(y=1.0, color='g', linestyle='--', alpha=0.5, label='Target (1.0x)')
_finish(ax, 'Air Ratio (NEW)', **NEW_METRIC)

_plot_train_val(axes[1, 2], 'air_acc')
_finish(axes[1, 2], 'Air Accuracy')

_plot_train_val(axes[1, 3], 'air_precision')
_finish(axes[1, 3], 'Air Precision (NEW)', **NEW_METRIC)

# === ROW 3: Error Analysis (CRITICAL FOR v9) ===
ax = axes[2, 0]
_plot_train_val(ax, 'error_similarity')
ax.axhline(y=0.3, color='orange', linestyle='--', alpha=0.5, label='Random (0.3)')
ax.axhline(y=0.5, color='g', linestyle=':', alpha=0.5, label='Systematic (0.5)')
ax.text(0.5, 0.95, 'Low=random errors\nHigh=similar blocks', transform=ax.transAxes,
        ha='center', va='top', fontsize=8, bbox=dict(NOTE_BOX))
_finish(ax, 'Error Similarity (CRITICAL)', **CRITICAL)

for ax, metric, title in ((axes[2, 1], 'terrain_error_similarity', 'Terrain Error Similarity (NEW)'),
                          (axes[2, 2], 'building_error_similarity', 'Building Error Similarity (NEW)')):
    _plot_train_val(ax, metric)
    ax.axhline(y=0.3, color='orange', linestyle='--', alpha=0.5)
    _finish(ax, title, **NEW_METRIC)

ax = axes[2, 3]
_plot_curves(ax, [
    ('train_false_block_rate', TRAIN_STYLE, 'Train (air→block)'),
    ('val_false_block_rate', VAL_STYLE, 'Val (air→block)'),
    ('train_building_false_air', dict(color='g', linestyle='-'), 'Train (block→air)'),
    ('val_building_false_air', dict(color='orange', linestyle='--'), 'Val (block→air)'),
])
_finish(ax, 'False Prediction Rates', legend_kw=dict(fontsize=7), **BOLD)

# === ROW 4: Loss Components & RFSQ ===
_plot_train_val(axes[3, 0], 'ce_loss')
_finish(axes[3, 0], 'CE Loss')

_plot_train_val(axes[3, 1], 'volume_loss')
_finish(axes[3, 1], 'Volume Loss (NEW)')

_plot_train_val(axes[3, 2], 'perceptual_loss')
_finish(axes[3, 2], 'Perceptual Loss (NEW)')

ax = axes[3, 3]
_plot_train_val(ax, 'residual_decay')
ax.axhline(y=0.5, color='orange', linestyle='--', alpha=0.5, label='Target (<0.5)')
ax.axhline(y=0.12, color='g', linestyle=':', alpha=0.5, label='v6-freq (0.12)')
ax.text(0.5, 0.95, 'LayerNorm working if <0.5', transform=ax.transAxes,
        ha='center', va='top', fontsize=8, bbox=dict(NOTE_BOX))
_finish(ax, 'Residual Decay (RFSQ)', **CRITICAL)

# === ROW 5: Comparisons & Summary ===
for stage in (0, 1):
    ax = axes[4, stage]
    _plot_train_val(ax, f'stage{stage}_perplexity', f'Train S{stage}', f'Val S{stage}')
    _finish(ax, f'Stage {stage} Perplexity')

# v6-freq comparison (final epoch of this run vs recorded v6-freq numbers)
ax = axes[4, 2]
v6_freq = {'Build\nAcc': 0.492, 'Vol\nRatio': 1.68, 'Err\nSim': 0.31, 'Decay': 0.12}
v8b = {
    'Build\nAcc': last_values['val_building_acc'],
    'Vol\nRatio': last_values['val_volume_ratio'],
    'Err\nSim': last_values['val_error_similarity'],
    'Decay': last_values['val_residual_decay'],
}
x = np.arange(len(v6_freq))
width = 0.35
ax.bar(x - width/2, list(v6_freq.values()), width, label='v6-freq', color='gray')
ax.bar(x + width/2, list(v8b.values()), width, label='v8-B', color='green')
ax.set_xticks(x)
ax.set_xticklabels(v6_freq.keys())
_finish(ax, 'v6-freq vs v8-B', **BOLD)

# Summary text
ax = axes[4, 3]
ax.axis('off')
val_acc = last_values['val_building_acc']
val_vol = last_values['val_volume_ratio']
val_err = last_values['val_error_similarity']
val_decay = last_values['val_residual_decay']
target_met = '[OK]' if val_acc >= 0.60 else '[X]'
vol_fixed = '[OK]' if val_vol <= 1.3 else '[X]'
err_random = '[!]' if val_err < 0.3 else '[OK]'
rfsq_ok = '[OK]' if val_decay < 0.5 else '[X]'
summary = f'''VQ-VAE v8-B Results
──────────────────────
Best: {best_building_acc:.1%} (ep {best_epoch})

TARGETS:
{target_met} Build ≥60%: {val_acc:.1%}
{vol_fixed} Vol ≤1.3x: {val_vol:.2f}x

DIAGNOSTICS:
{err_random} Err Sim: {val_err:.2f}
  (<0.3 = random errors)
{rfsq_ok} Decay: {val_decay:.2f}
  (<0.5 = RFSQ working)

NEW METRICS:
  F1: {last_values['val_building_f1']:.1%}
  Air Prec: {last_values['val_air_precision']:.1%}
  Rare Prec: {last_values['val_rare_precision']:.1%}

Time: {train_time/60:.0f} min'''
ax.text(0.1, 0.9, summary, transform=ax.transAxes, fontsize=10,
        verticalalignment='top', fontfamily='monospace')

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/vqvae_v8b_training.png", dpi=150, bbox_inches='tight')
plt.show()

def _final_val(metric):
    """Return the last recorded validation value of metric, or NaN if it was never recorded."""
    values = history.get(f'val_{metric}')
    return values[-1] if values else float('nan')


print("\n" + "="*70)
print("COMPREHENSIVE RESULTS ANALYSIS")
print("="*70)
print("\nCORE METRICS:")
print(f"  Building accuracy: {_final_val('building_acc'):.1%} (target: ≥60%)")
print(f"  Building F1: {_final_val('building_f1'):.1%}")
print(f"  Building precision: {_final_val('building_precision'):.1%}")
print(f"  Building recall: {_final_val('building_recall'):.1%}")
print("\nVOLUME METRICS:")
print(f"  Volume ratio: {_final_val('volume_ratio'):.2f}x (target: ≤1.3x)")
print(f"  Air ratio: {_final_val('air_ratio'):.2f}x")
print("\nAIR METRICS:")
print(f"  Air accuracy: {_final_val('air_acc'):.1%}")
print(f"  Air precision: {_final_val('air_precision'):.1%}")
print(f"  False block rate: {_final_val('false_block_rate'):.1%}")
print(f"  False air rate: {_final_val('building_false_air'):.1%}")

print("\nERROR ANALYSIS (CRITICAL FOR v9):")
err_sim = _final_val('error_similarity')
print(f"  Overall error similarity: {err_sim:.3f}")
# Three bands: <0.3 random, 0.3-0.5 moderately systematic, >=0.5 systematic.
if err_sim < 0.3:
    print("    -> RANDOM errors (like v5.1) - architecture needs fundamental change for v9")
elif err_sim < 0.5:
    print("    -> Moderately systematic - v9 could improve with better loss weights")
else:
    print("    -> Systematic errors - v9 can use material similarity loss")
print(f"  Terrain error similarity: {_final_val('terrain_error_similarity'):.3f}")
print(f"  Building error similarity: {_final_val('building_error_similarity'):.3f}")

print("\nRFSQ DIAGNOSTICS:")
decay = _final_val('residual_decay')
print(f"  Residual decay: {decay:.3f} (v6-freq: 0.12)")
if decay > 0.5:
    print("    -> WARNING: LayerNorm may not be working correctly!")
else:
    print("    -> RFSQ working correctly")
print(f"  Stage 0 perplexity: {_final_val('stage0_perplexity'):.0f}")
print(f"  Stage 1 perplexity: {_final_val('stage1_perplexity'):.0f}")

print("\nRARE BLOCKS:")
print(f"  Rare accuracy: {_final_val('rare_acc'):.1%}")
print(f"  Rare recall: {_final_val('rare_recall'):.1%}")
print(f"  Rare precision: {_final_val('rare_precision'):.1%}")
print()

final_acc = _final_val('building_acc')
if final_acc >= 0.60 and _final_val('volume_ratio') <= 1.3:
    print("[SUCCESS] Stage 1 targets met! Ready for Stage 2 (v8-C with attention)")
elif final_acc >= 0.55:
    print("[PARTIAL] Close to target, consider Stage 2 with caution")
else:
    print("[BELOW TARGET] Analyze error patterns before Stage 2")

## Cell 13: Save Results

In [None]:
def _final_val_or_none(metric):
    """Return the last recorded validation value of metric as a float.

    Returns None (written as JSON null) when the series was never recorded, so a
    single missing metric does not prevent the rest of the results being saved.
    """
    values = history.get(f'val_{metric}')
    return float(values[-1]) if values else None


# Values that drive the pass/fail flags below are required; read them once.
final_building_acc = float(history['val_building_acc'][-1])
final_volume_ratio = float(history['val_volume_ratio'][-1])
final_error_similarity = float(history['val_error_similarity'][-1])
final_residual_decay = float(history['val_residual_decay'][-1])
# Binary split for the saved record: < 0.3 is treated as random errors.
error_type = 'random' if final_error_similarity < 0.3 else 'systematic'

results = {
    'config': {
        'version': 'v8-B',
        'changes_from_v6freq': [
            '16×16×16 latent (vs 8×8×8)',
            '8:1 compression ratio (vs 64:1)',
            'Volume penalty loss',
            'Perceptual loss',
            f'Frequency cap: {FREQUENCY_WEIGHT_CAP}x (reduced from 10x)',
            'COMPLETE metrics tracking (all CLAUDE.md requirements)',
        ],
        'hidden_dim': HIDDEN_DIM,
        'rfsq_levels_per_stage': RFSQ_LEVELS_PER_STAGE,
        'num_stages': NUM_STAGES,
        'total_epochs': TOTAL_EPOCHS,
        'batch_size': BATCH_SIZE,
        'base_lr': BASE_LR,
        'volume_penalty_weight': VOLUME_PENALTY_WEIGHT,
        'perceptual_weight': PERCEPTUAL_WEIGHT,
        'seed': SEED,
    },
    'results': {
        # Core
        'best_building_acc': float(best_building_acc),
        'best_epoch': best_epoch,
        'final_building_acc': final_building_acc,
        'final_building_f1': _final_val_or_none('building_f1'),
        'final_building_precision': _final_val_or_none('building_precision'),
        'final_building_recall': _final_val_or_none('building_recall'),
        'final_terrain_acc': _final_val_or_none('terrain_acc'),
        # Volume
        'final_volume_ratio': final_volume_ratio,
        'final_air_ratio': _final_val_or_none('air_ratio'),
        # Air
        'final_air_acc': _final_val_or_none('air_acc'),
        'final_air_precision': _final_val_or_none('air_precision'),
        'final_false_block_rate': _final_val_or_none('false_block_rate'),
        'final_building_false_air': _final_val_or_none('building_false_air'),
        # Rare
        'final_rare_acc': _final_val_or_none('rare_acc'),
        'final_rare_recall': _final_val_or_none('rare_recall'),
        'final_rare_precision': _final_val_or_none('rare_precision'),
        # Error analysis (CRITICAL FOR v9)
        'final_error_similarity': final_error_similarity,
        'final_terrain_error_similarity': _final_val_or_none('terrain_error_similarity'),
        'final_building_error_similarity': _final_val_or_none('building_error_similarity'),
        'error_type': error_type,
        # RFSQ diagnostics
        'final_residual_decay': final_residual_decay,
        'rfsq_working': bool(final_residual_decay < 0.5),
        'final_stage0_perplexity': _final_val_or_none('stage0_perplexity'),
        'final_stage1_perplexity': _final_val_or_none('stage1_perplexity'),
        # Meta
        'training_time_min': float(train_time / 60),
        'target_60pct_met': bool(final_building_acc >= 0.60),
        'volume_target_met': bool(final_volume_ratio <= 1.3),
    },
    'history': {k: [float(x) for x in v] for k, v in history.items()},
    'v9_recommendations': {
        'error_analysis': error_type,
        'rfsq_status': 'working' if final_residual_decay < 0.5 else 'broken',
        'volume_status': 'fixed' if final_volume_ratio <= 1.3 else 'needs_work',
    }
}

with open(f"{OUTPUT_DIR}/vqvae_v8b_results.json", 'w', encoding='utf-8') as f:
    json.dump(results, f, indent=2)

torch.save(model.state_dict(), f"{OUTPUT_DIR}/vqvae_v8b_final.pt")

print("\nResults saved:")
print(f"  - {OUTPUT_DIR}/vqvae_v8b_results.json")
print(f"  - {OUTPUT_DIR}/vqvae_v8b_best.pt")
print(f"  - {OUTPUT_DIR}/vqvae_v8b_final.pt")
print(f"  - {OUTPUT_DIR}/vqvae_v8b_training.png")

# Console wrap-up driven by the final validation epoch and the flags already
# stored in results['results'].
banner = "=" * 70
last_err_sim = history['val_error_similarity'][-1]
last_decay = history['val_residual_decay'][-1]
err_kind = 'random' if last_err_sim < 0.3 else 'systematic'
decay_flag = 'OK' if last_decay < 0.5 else 'WARNING'

print("\n" + banner)
print("FINAL SUMMARY - VQ-VAE v8-B (COMPLETE METRICS)")
print(banner)
print(f"Best building accuracy: {best_building_acc:.1%} at epoch {best_epoch}")
print(f"Final volume ratio: {history['val_volume_ratio'][-1]:.2f}x")
print(f"Error similarity: {last_err_sim:.3f} ({err_kind})")
print(f"RFSQ residual decay: {last_decay:.3f} ({decay_flag})")
print(f"Training time: {train_time/60:.1f} minutes")
print()

flags = results['results']
if not flags['target_60pct_met']:
    # Accuracy target missed: distinguish "close" from "well below".
    if history['val_building_acc'][-1] >= 0.55:
        verdict = "[CLOSE] Consider Stage 2 with caution"
    else:
        verdict = "[BELOW TARGET] Analyze error patterns - may need v9 instead of Stage 2"
elif flags['volume_target_met']:
    verdict = "[SUCCESS] Stage 1 complete: Proceed to Stage 2 (v8-C with attention)"
else:
    verdict = "[PARTIAL] Building acc met, but volume needs work"
print(verdict)