# VQ-VAE v9 Training - Two-Phase Volume Fix

## Overview
This notebook implements the v9 training strategy to fix the volume ratio problem. In v8, the ratio of predicted to ground-truth structure (non-air) voxels was stuck at 2.0x.

## Key Innovations

### 1. Two-Phase Training
- **Phase 1 (epochs 1-10)**: Binary air/structure classification only
  - Learns correct volume ratio FIRST
  - Strong volume penalty (100x weight)
  - No frequency weighting
- **Phase 2 (epochs 11-20)**: Full vocabulary training (currently 10 epochs; planned to grow to 30, i.e. epochs 11-40, once Phase 1 works)
  - Starts from Phase 1 checkpoint
  - Model already knows air/structure boundary

### 2. Asymmetric Focal Loss
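- Focal loss (gamma=2.0) down-weights voxels the model already predicts confidently, focusing training on hard examples
- Air boost (3x) multiplies the loss at ground-truth air voxels, so missing air costs more
- Keeps the model from "coasting" on easy predictions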

### 3. Volume-Aware Logit Adjustment
- Dynamically adjusts air logits based on current volume ratio
- If over-predicting: boost air logits to encourage more air predictions
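
The adjustment is not defined in the code shown in this section, so the following is only a hypothetical sketch of one way to do it, assuming a per-voxel logit vector and a running volume ratio (predicted structure voxels / ground-truth structure voxels). The helper names and the `log(ratio)` rule are illustrative assumptions: adding log(r) to every air logit multiplies each voxel's air-vs-structure odds by r, pushing predictions back toward air in proportion to the over-prediction.

```python
import math
from typing import List, Sequence, Set


def air_logit_bias(volume_ratio: float, strength: float = 1.0, max_bias: float = 5.0) -> float:
    """Hypothetical rule: bias air logits by strength * log(ratio) when over-predicting."""
    if volume_ratio <= 1.0:
        return 0.0  # at or below the target volume: leave logits unchanged
    return min(max_bias, strength * math.log(volume_ratio))


def adjust_air_logits(logits: Sequence[float], air_ids: Set[int], bias: float) -> List[float]:
    """Add `bias` to the logit of every air class for a single voxel."""
    return [x + bias if i in air_ids else x for i, x in enumerate(logits)]


def p_air(logits: Sequence[float], air_ids: Set[int]) -> float:
    """Total softmax probability assigned to the air classes."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    return sum(e for i, e in enumerate(exps) if i in air_ids) / sum(exps)


# Example: class 0 is air, classes 1-2 are structure blocks; volume stuck at 2.0x
voxel = [0.0, 0.5, 0.2]
bias = air_logit_bias(2.0)  # log(2), about 0.693
adjusted = adjust_air_logits(voxel, {0}, bias)
print(f"bias={bias:.3f}  P(air) before={p_air(voxel, {0}):.3f}  after={p_air(adjusted, {0}):.3f}")
```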

### 4. Reduced Frequency Cap
- Cap reduced from 5.0x to 2.0x
- Reduces incentive to over-predict rare blocks

## Expected Results
| Metric | v8-B (Failed) | v9 Target |
|--------|---------------|-----------|
| Volume Ratio | 2.0x (stuck) | 0.9-1.1x |
| Building Accuracy | 80% (inflated by over-predicting structure) | 45-55% (real) |
| Recall | 99.7% (inflated) | 85-95% |

## 1. Setup - Mount Google Drive

**What this does (Technical):** Mounts the Google Drive filesystem to access training data stored in the cloud, and creates the output directory for model checkpoints and results.

**What this does (Simple):** Connects to your Google Drive so we can load the Minecraft structure data and save our trained model.

In [None]:
from google.colab import drive
drive.mount('/content/drive')

import os
from pathlib import Path

DRIVE_BASE = '/content/drive/MyDrive/minecraft_ai'
OUTPUT_DIR = '/content/drive/MyDrive/minecraft_ai/vqvae_v9'

os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"Output directory: {OUTPUT_DIR}")

## 2. Imports

**What this does (Technical):** Imports PyTorch for deep learning, NumPy for numerical operations, h5py for reading HDF5 structure files, matplotlib for plotting training curves, and tqdm for progress bars, then selects the GPU if one is available.

**What this does (Simple):** Loads all the software libraries we need to train our AI model.

In [None]:
import json
import random
import time
from pathlib import Path
from typing import Dict, List, Tuple, Set, Optional
from collections import defaultdict

import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## 3. RFSQ Quantization Module

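**What this does (Technical):** Implements Residual Finite Scalar Quantization (RFSQ) with LayerNorm conditioning. Each stage normalizes the current residual with `InvertibleLayerNorm`, quantizes every latent dimension to a fixed grid (with the configured `[5, 5, 5, 5]` levels, 625 codes per stage), maps the result back, and subtracts it from the residual. This discretizes continuous latent vectors into a finite set of codes, forcing a compressed discrete representation. A straight-through estimator lets gradients pass through the rounding step.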

**What this does (Simple):** Converts the AI's internal representation into discrete codes (like converting analog to digital). This compression forces the model to learn the most important features of Minecraft structures.
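
To make the quantization concrete, here is a pure-Python sketch of what `FSQ._quantize` and `FSQ._to_indices` compute for a single latent vector. It is a simplified restatement for illustration (no batching, no straight-through gradient, hypothetical function name) and, like the module, assumes odd level counts. Each dimension is squashed with `tanh`, snapped to one of 5 grid points {-1, -0.5, 0, 0.5, 1}, and the per-dimension level indices are combined in mixed radix (basis `[125, 25, 5, 1]` for `[5, 5, 5, 5]`) into a code in `[0, 624]`.

```python
import math
from typing import List, Tuple


def fsq_quantize(z: List[float], levels: List[int]) -> Tuple[List[float], int]:
    """Quantize one latent vector like FSQ._quantize/_to_indices (odd levels assumed)."""
    # Mixed-radix basis: the last dimension varies fastest, e.g. [125, 25, 5, 1]
    basis = []
    acc = 1
    for L in reversed(levels):
        basis.append(acc)
        acc *= L
    basis.reverse()

    codes, index = [], 0
    for z_i, L, b in zip(z, levels, basis):
        half = (L - 1) / 2
        level = float(round(math.tanh(z_i) * half))  # integer step in [-half, half]
        level = max(-half, min(half, level))
        codes.append(level / half)                  # grid value in [-1, 1]
        index += int(level + half) * b              # shift to [0, L-1], then mixed radix
    return codes, index


print(fsq_quantize([0.0, 0.0, 0.0, 0.0], [5, 5, 5, 5]))   # the centre code
print(fsq_quantize([10.0, -10.0, 0.0, 0.3], [5, 5, 5, 5]))
```

With two stages, each 8x8x8 latent position is therefore described by a pair of codes, one per stage.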

In [None]:
class FSQ(nn.Module):
    """Finite Scalar Quantization - quantizes each dimension to a fixed number of levels (assumes odd level counts, e.g. 5)."""
    def __init__(self, levels: List[int], eps: float = 1e-3):
        super().__init__()
        self.levels = levels
        self.dim = len(levels)
        self.eps = eps
        self.codebook_size = int(np.prod(levels))
        self.register_buffer('_levels', torch.tensor(levels, dtype=torch.float32))
        basis = []
        acc = 1
        for L in reversed(levels):
            basis.append(acc)
            acc *= L
        self.register_buffer('_basis', torch.tensor(list(reversed(basis)), dtype=torch.long))
        half_levels = [(L - 1) / 2 for L in levels]
        self.register_buffer('_half_levels', torch.tensor(half_levels, dtype=torch.float32))

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        z_bounded = torch.tanh(z)
        z_q = self._quantize(z_bounded)
        z_q = z_bounded + (z_q - z_bounded).detach()  # Straight-through
        indices = self._to_indices(z_q)
        return z_q, indices

    def _quantize(self, z: torch.Tensor) -> torch.Tensor:
        z_q_list = []
        for i in range(self.dim):
            half_L = self._half_levels[i]
            z_i = z[..., i] * half_L
            z_i = torch.round(z_i).clamp(-half_L, half_L) / half_L
            z_q_list.append(z_i)
        return torch.stack(z_q_list, dim=-1)

    def _to_indices(self, z_q: torch.Tensor) -> torch.Tensor:
        indices = torch.zeros(z_q.shape[:-1], dtype=torch.long, device=z_q.device)
        for i in range(self.dim):
            L = self._levels[i].long()
            half_L = self._half_levels[i]
            level_idx = ((z_q[..., i] * half_L) + half_L).round().long().clamp(0, L - 1)
            indices = indices + level_idx * self._basis[i]
        return indices

    def get_codebook_usage(self, indices: torch.Tensor) -> Tuple[float, float]:
        flat = indices.flatten()
        counts = torch.bincount(flat, minlength=self.codebook_size).float()
        usage = (counts > 0).float().mean().item()
        probs = counts / counts.sum()
        probs = probs[probs > 0]
        entropy = -(probs * torch.log(probs)).sum()
        perplexity = torch.exp(entropy).item()
        return usage, perplexity


class InvertibleLayerNorm(nn.Module):
    """Per-channel normalization over the spatial dims of a [B, X, Y, Z, C] tensor; stores the statistics so the transform can be inverted."""
    def __init__(self, num_features: int, eps: float = 1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.register_buffer('stored_mean', None, persistent=False)
        self.register_buffer('stored_std', None, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.stored_mean = x.mean(dim=(1, 2, 3), keepdim=True)
        self.stored_std = x.std(dim=(1, 2, 3), keepdim=True) + self.eps
        x_norm = (x - self.stored_mean) / self.stored_std
        return x_norm * self.weight + self.bias

    def inverse(self, x_norm: torch.Tensor) -> torch.Tensor:
        x = (x_norm - self.bias) / self.weight
        return x * self.stored_std + self.stored_mean


class RFSQStage(nn.Module):
    """Single stage of RFSQ with LayerNorm conditioning."""
    def __init__(self, levels: List[int]):
        super().__init__()
        self.fsq = FSQ(levels)
        self.layernorm = InvertibleLayerNorm(len(levels))

    def forward(self, residual: torch.Tensor):
        z_norm = self.layernorm(residual)
        z_q_norm, indices = self.fsq(z_norm)
        z_q = self.layernorm.inverse(z_q_norm)
        new_residual = residual - z_q
        return z_q, new_residual, indices


class RFSQ(nn.Module):
    """Residual FSQ with multiple stages."""
    def __init__(self, levels_per_stage: List[int], num_stages: int = 2):
        super().__init__()
        self.num_stages = num_stages
        self.stages = nn.ModuleList([RFSQStage(levels_per_stage) for _ in range(num_stages)])
        self._usage_indices = []

    def reset_usage(self):
        self._usage_indices = []

    def forward(self, z: torch.Tensor):
        residual = z
        z_q_sum = torch.zeros_like(z)
        all_indices = []
        for stage in self.stages:
            z_q, residual, indices = stage(residual)
            z_q_sum = z_q_sum + z_q
            all_indices.append(indices)
        self._usage_indices.append(all_indices)
        return z_q_sum, all_indices

    def get_usage_stats(self) -> Dict:
        if not self._usage_indices:
            return {}
        stats = {}
        for stage_idx in range(self.num_stages):
            all_stage_indices = torch.cat([b[stage_idx].flatten() for b in self._usage_indices])
            usage, perplexity = self.stages[stage_idx].fsq.get_codebook_usage(all_stage_indices)
            stats[f'stage{stage_idx}'] = (usage, perplexity)
        return stats

print("RFSQ modules defined")

## 4. VQ-VAE v9 Architecture (8x8x8 Latent)

**What this does (Technical):** Defines the encoder-decoder architecture with 8x8x8 latent resolution (reduced from v8's 16x16x16). The encoder compresses 32x32x32 input through convolutional layers to 8x8x8 latent space. The decoder reconstructs back to 32x32x32. Uses residual blocks for stable gradients.

**What this does (Simple):** The neural network that learns to compress Minecraft structures into a small code and then expand them back. We use a smaller bottleneck (8x8x8) to force better compression and prevent volume bias.
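
The resolution claims can be checked with the standard output-size formulas for `nn.Conv3d` and `nn.ConvTranspose3d` (per spatial axis, dilation 1, as given in the PyTorch docs). This sketch traces one axis through the layers defined below: `kernel_size=4, stride=2, padding=1` exactly halves (or, transposed, doubles) the size, while the `kernel_size=3, padding=1` convolutions, including those inside the residual blocks, preserve it.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length of Conv3d along one axis (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size: int, kernel: int, stride: int = 1, padding: int = 0,
                       output_padding: int = 0) -> int:
    """Output length of ConvTranspose3d along one axis (dilation=1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


# Encoder: initial (k3,p1) -> down1 (k4,s2,p1) -> down2 (k4,s2,p1) -> proj (k3,p1)
enc = [32]
enc.append(conv_out(enc[-1], 3, padding=1))
enc.append(conv_out(enc[-1], 4, stride=2, padding=1))
enc.append(conv_out(enc[-1], 4, stride=2, padding=1))
enc.append(conv_out(enc[-1], 3, padding=1))

# Decoder: initial (k3,p1) -> up1 (k4,s2,p1) -> up2 (k4,s2,p1) -> final (k3,p1)
dec = [enc[-1]]
dec.append(conv_out(dec[-1], 3, padding=1))
dec.append(conv_transpose_out(dec[-1], 4, stride=2, padding=1))
dec.append(conv_transpose_out(dec[-1], 4, stride=2, padding=1))
dec.append(conv_out(dec[-1], 3, padding=1))

print("encoder:", enc)
print("decoder:", dec)
print("latent positions:", enc[-1] ** 3, "for", 32 ** 3, "voxels")
```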

In [None]:
class ResidualBlock3D(nn.Module):
    """3D residual block with BatchNorm."""
    def __init__(self, channels: int):
        super().__init__()
        self.conv1 = nn.Conv3d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm3d(channels)
        self.conv2 = nn.Conv3d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm3d(channels)

    def forward(self, x):
        residual = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + residual)


class EncoderV9(nn.Module):
    """Encoder: 32x32x32 -> 8x8x8 latent (back to v6 resolution)."""
    def __init__(self, in_channels: int = 40, hidden_dim: int = 128, rfsq_dim: int = 4):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv3d(in_channels, hidden_dim, 3, padding=1),
            nn.BatchNorm3d(hidden_dim),
            nn.ReLU(inplace=True)
        )
        # 32 -> 16
        self.down1 = nn.Sequential(
            nn.Conv3d(hidden_dim, hidden_dim, 4, stride=2, padding=1),
            nn.BatchNorm3d(hidden_dim),
            nn.ReLU(inplace=True)
        )
        self.res1 = ResidualBlock3D(hidden_dim)
        # 16 -> 8
        self.down2 = nn.Sequential(
            nn.Conv3d(hidden_dim, hidden_dim, 4, stride=2, padding=1),
            nn.BatchNorm3d(hidden_dim),
            nn.ReLU(inplace=True)
        )
        self.res2 = nn.Sequential(*[ResidualBlock3D(hidden_dim) for _ in range(4)])
        self.proj = nn.Conv3d(hidden_dim, rfsq_dim, 3, padding=1)

    def forward(self, x):
        x = self.initial(x)
        x = self.res1(self.down1(x))
        x = self.res2(self.down2(x))
        return self.proj(x)


class DecoderV9(nn.Module):
    """Decoder: 8x8x8 latent -> 32x32x32 output."""
    def __init__(self, rfsq_dim: int = 4, hidden_dim: int = 128, num_blocks: int = 3717):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv3d(rfsq_dim, hidden_dim, 3, padding=1),
            nn.BatchNorm3d(hidden_dim),
            nn.ReLU(inplace=True)
        )
        self.res1 = nn.Sequential(*[ResidualBlock3D(hidden_dim) for _ in range(4)])
        # 8 -> 16
        self.up1 = nn.Sequential(
            nn.ConvTranspose3d(hidden_dim, hidden_dim, 4, stride=2, padding=1),
            nn.BatchNorm3d(hidden_dim),
            nn.ReLU(inplace=True)
        )
        self.res2 = ResidualBlock3D(hidden_dim)
        # 16 -> 32
        self.up2 = nn.Sequential(
            nn.ConvTranspose3d(hidden_dim, hidden_dim, 4, stride=2, padding=1),
            nn.BatchNorm3d(hidden_dim),
            nn.ReLU(inplace=True)
        )
        self.final = nn.Conv3d(hidden_dim, num_blocks, 3, padding=1)

    def forward(self, z_q):
        x = self.initial(z_q)
        x = self.res1(x)
        x = self.res2(self.up1(x))
        x = self.up2(x)
        return self.final(x)


class VQVAEv9(nn.Module):
    """VQ-VAE v9 with 8x8x8 latent and two-phase training support."""
    def __init__(self, vocab_size: int = 3717, emb_dim: int = 40, hidden_dim: int = 128,
                 rfsq_levels: List[int] = None, num_stages: int = 2,
                 pretrained_embeddings: torch.Tensor = None):
        super().__init__()
        if rfsq_levels is None:
            rfsq_levels = [5, 5, 5, 5]
        self.rfsq_dim = len(rfsq_levels)
        self.vocab_size = vocab_size
        
        self.block_emb = nn.Embedding(vocab_size, emb_dim)
        if pretrained_embeddings is not None:
            self.block_emb.weight.data.copy_(pretrained_embeddings)
            self.block_emb.weight.requires_grad = False
        
        self.encoder = EncoderV9(emb_dim, hidden_dim, self.rfsq_dim)
        self.quantizer = RFSQ(rfsq_levels, num_stages)
        self.decoder = DecoderV9(self.rfsq_dim, hidden_dim, vocab_size)

    def forward(self, block_ids):
        z_e = self.encode(block_ids)
        z_q, indices = self.quantize(z_e)
        logits = self.decode(z_q)
        return logits, z_q, indices

    def encode(self, block_ids):
        x = self.block_emb(block_ids).permute(0, 4, 1, 2, 3)
        z_e = self.encoder(x).permute(0, 2, 3, 4, 1)
        return z_e

    def quantize(self, z_e):
        return self.quantizer(z_e)

    def decode(self, z_q):
        return self.decoder(z_q.permute(0, 4, 1, 2, 3))

print("VQ-VAE v9 architecture defined (8x8x8 latent)")

## 5. Asymmetric Focal Loss with Air Boost

**What this does (Technical):** Implements a focal loss that down-weights easy examples (gamma=2.0), up-weights the loss at ground-truth air voxels (air_boost=3.0), and adds a volume-ratio penalty. This addresses the root cause of volume over-prediction: under plain CE loss, predicting more structure blocks is an "easy win" for the model.

**What this does (Simple):** A smarter way to calculate the training error that:
1. Focuses more on hard-to-predict blocks (focal loss)
2. Makes mistakes on air 3x more costly (air boost)
3. Penalizes predicting more (or less) structure than the ground truth contains (volume penalty)

Together, these stop the model from cheating by over-predicting structure blocks.

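Per voxel, the Phase 2 loss in the cell below multiplies four factors: the focal weight (1 - p_t)^gamma, the cross-entropy -log p_t, the capped frequency weight of the target block, and `air_boost` when the ground truth is air. This pure-Python sketch (the function name is hypothetical) shows how gamma=2.0 shrinks the contribution of an already-confident prediction far more than plain CE does. Plain CE weights p_t=0.1 versus p_t=0.9 only about 22x apart; with the focal weight the gap grows to well over 1000x.

```python
import math


def phase2_voxel_loss(p_t: float, gamma: float = 2.0, freq_weight: float = 1.0,
                      gt_is_air: bool = False, air_boost: float = 3.0) -> float:
    """Focal weight * CE * frequency weight * air weight for one voxel."""
    focal_weight = (1.0 - p_t) ** gamma
    ce = -math.log(p_t)
    air_weight = air_boost if gt_is_air else 1.0
    return focal_weight * ce * freq_weight * air_weight


easy = phase2_voxel_loss(0.9)  # confident, correct prediction
hard = phase2_voxel_loss(0.1)  # only 10% probability on the true block
print(f"easy: {easy:.4f}  hard: {hard:.4f}  ratio: {hard / easy:.0f}x")
print(f"air voxel, p_t=0.5: {phase2_voxel_loss(0.5, gt_is_air=True):.4f}")
```
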
In [None]:
class AsymmetricFocalLoss(nn.Module):
    """Focal loss with asymmetric air boosting to fix volume over-prediction.
    
    Key insight: Standard CE loss treats all predictions equally. But we NEED
    the model to predict air correctly, otherwise it over-predicts structure.
    
    Solution:
    1. Focal loss: (1-p_t)^gamma * CE - focuses on hard examples
    2. Air boost: Multiply loss by air_boost at GT air locations
    3. Volume penalty: Strong L2 penalty on volume ratio deviation
    """
    def __init__(self, air_tokens: Set[int], gamma: float = 2.0, air_boost: float = 3.0,
                 volume_penalty: float = 100.0, frequency_weights: torch.Tensor = None,
                 frequency_cap: float = 2.0):
        super().__init__()
        self.air_tokens = list(air_tokens)
        self.gamma = gamma
        self.air_boost = air_boost
        self.volume_penalty = volume_penalty
        
        if frequency_weights is not None:
            clamped = frequency_weights.clamp(max=frequency_cap)
            self.register_buffer('freq_weights', clamped)
        else:
            self.freq_weights = None

    def forward(self, logits: torch.Tensor, target: torch.Tensor, z_q: torch.Tensor,
                phase: int = 2) -> Dict[str, torch.Tensor]:
        """Compute loss. phase=1 for binary, phase=2 for full vocab."""
        device = logits.device
        B, C, X, Y, Z = logits.shape
        
        # Create air mask
        air_tensor = torch.tensor(self.air_tokens, device=device, dtype=target.dtype)
        gt_is_air = torch.isin(target, air_tensor)
        gt_is_struct = ~gt_is_air
        
        if phase == 1:
            # PHASE 1: Binary classification (air vs structure)
            return self._phase1_loss(logits, target, gt_is_air, gt_is_struct, device)
        else:
            # PHASE 2: Full vocabulary with focal loss
            return self._phase2_loss(logits, target, gt_is_air, gt_is_struct, z_q, device)

    def _phase1_loss(self, logits, target, gt_is_air, gt_is_struct, device):
        """Binary air/structure loss for Phase 1."""
        B, C, X, Y, Z = logits.shape
        
        # Collapse logits to binary: air vs non-air
        air_mask = torch.zeros(C, device=device, dtype=torch.bool)
        for t in self.air_tokens:
            if t < C:
                air_mask[t] = True
        
        # Max air logit and max non-air logit per voxel
        air_logits = logits[:, air_mask, :, :, :].max(dim=1)[0] if air_mask.any() else torch.zeros(B, X, Y, Z, device=device)
        struct_logits = logits[:, ~air_mask, :, :, :].max(dim=1)[0]
        
        # Binary logits: [B, 2, X, Y, Z]
        binary_logits = torch.stack([air_logits, struct_logits], dim=1)
        binary_target = gt_is_struct.long()  # 0=air, 1=structure
        
        # Binary CE loss
        binary_logits_flat = binary_logits.permute(0, 2, 3, 4, 1).reshape(-1, 2)
        binary_target_flat = binary_target.reshape(-1)
        ce_loss = F.cross_entropy(binary_logits_flat, binary_target_flat, reduction='mean')
        
        # Strong volume penalty
        probs = F.softmax(binary_logits, dim=1)
        pred_struct_prob = probs[:, 1, :, :, :]  # Probability of structure
        pred_volume = pred_struct_prob.sum()
        gt_volume = gt_is_struct.float().sum()
        volume_ratio = pred_volume / (gt_volume + 1e-6)
        volume_loss = (volume_ratio - 1.0) ** 2
        
        total_loss = ce_loss + self.volume_penalty * volume_loss
        
        # Metrics
        with torch.no_grad():
            pred_binary = binary_logits.argmax(dim=1)  # 0=air, 1=struct
            pred_is_struct = pred_binary == 1
            correct = (pred_binary == binary_target)
            accuracy = correct.float().mean()
            vol_ratio_hard = pred_is_struct.float().sum() / (gt_is_struct.float().sum() + 1e-6)
            recall = (gt_is_struct & pred_is_struct).float().sum() / (gt_is_struct.float().sum() + 1e-6)
            precision = (gt_is_struct & pred_is_struct).float().sum() / (pred_is_struct.float().sum() + 1e-6)
        
        return {
            'loss': total_loss,
            'ce_loss': ce_loss.detach(),
            'volume_loss': volume_loss.detach(),
            'focal_loss': torch.tensor(0.0, device=device),
            'volume_ratio': vol_ratio_hard,
            'recall': recall,
            'precision': precision,
            'accuracy': accuracy,
            'phase': torch.tensor(1.0, device=device),
        }

    def _phase2_loss(self, logits, target, gt_is_air, gt_is_struct, z_q, device):
        """Full vocabulary loss with focal loss and air boost."""
        B, C, X, Y, Z = logits.shape
        
        logits_flat = logits.permute(0, 2, 3, 4, 1).reshape(-1, C)
        target_flat = target.reshape(-1)
        gt_is_air_flat = gt_is_air.reshape(-1)
        
        # Compute probabilities for focal loss
        log_probs = F.log_softmax(logits_flat, dim=1)
        probs = torch.exp(log_probs)
        
        # Get probability of correct class
        p_t = probs.gather(1, target_flat.unsqueeze(1)).squeeze(1)
        
        # Focal loss weight: (1 - p_t)^gamma
        focal_weight = (1 - p_t) ** self.gamma
        
        # CE loss per sample
        ce_per_sample = -log_probs.gather(1, target_flat.unsqueeze(1)).squeeze(1)
        
        # Apply frequency weights if available
        if self.freq_weights is not None:
            freq_w = self.freq_weights[target_flat]
            ce_per_sample = ce_per_sample * freq_w
        
        # Apply air boost at GT air locations
        air_weight = torch.where(gt_is_air_flat, 
                                  torch.tensor(self.air_boost, device=device),
                                  torch.tensor(1.0, device=device))
        
        # Focal loss
        focal_loss = (focal_weight * ce_per_sample * air_weight).mean()
        
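        # Volume penalty on the SOFT structure volume (expected number of non-air
        # voxels). An argmax-based count has no gradient, so it could not steer training.
        air_tensor = torch.tensor(self.air_tokens, device=device, dtype=target.dtype)
        air_prob = probs[:, air_tensor].sum(dim=1)  # P(any air token) per voxel
        gt_volume = gt_is_struct.float().sum()
        soft_volume_ratio = (1.0 - air_prob).sum() / (gt_volume + 1e-6)
        volume_loss = (soft_volume_ratio - 1.0) ** 2

        # Hard predictions are kept for the reported metrics
        pred_hard = logits_flat.argmax(dim=1)
        pred_is_air = torch.isin(pred_hard, air_tensor)
        pred_is_struct = ~pred_is_air
        volume_ratio = pred_is_struct.float().sum() / (gt_volume + 1e-6)

        total_loss = focal_loss + self.volume_penalty * volume_loss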
        
        # Metrics
        with torch.no_grad():
            correct = (pred_hard == target_flat)
            accuracy = correct.float().mean()
            building_acc = correct[~gt_is_air_flat].float().mean() if (~gt_is_air_flat).any() else torch.tensor(0.0)
            air_acc = correct[gt_is_air_flat].float().mean() if gt_is_air_flat.any() else torch.tensor(0.0)
            recall = (pred_is_struct & ~gt_is_air_flat).float().sum() / ((~gt_is_air_flat).float().sum() + 1e-6)
            precision = (pred_is_struct & ~gt_is_air_flat).float().sum() / (pred_is_struct.float().sum() + 1e-6)
            false_air_rate = (pred_is_air & ~gt_is_air_flat).float().sum() / ((~gt_is_air_flat).float().sum() + 1e-6)
        
        return {
            'loss': total_loss,
            'ce_loss': focal_loss.detach(),
            'focal_loss': focal_loss.detach(),
            'volume_loss': volume_loss.detach(),
            'volume_ratio': volume_ratio.detach(),
            'recall': recall,
            'precision': precision,
            'accuracy': accuracy,
            'building_acc': building_acc,
            'air_acc': air_acc,
            'false_air_rate': false_air_rate,
            'phase': torch.tensor(2.0, device=device),
        }

print("AsymmetricFocalLoss defined")
print(f"  - Focal gamma: focuses on hard examples")
print(f"  - Air boost: makes air predictions more important")
print(f"  - Volume penalty: strong L2 on ratio deviation")
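
A volume penalty only shapes training if the predicted volume it measures is differentiable with respect to the logits. This toy pure-Python sketch (hypothetical helper names, two voxels, class 0 = air) contrasts a hard count of argmax-structure voxels with a soft count, the summed non-air probability. Nudging the structure logits by a small step leaves the hard count unchanged, so its gradient is zero almost everywhere, while the soft count moves smoothly.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]


def hard_volume(voxels: List[List[float]]) -> float:
    """Number of voxels whose argmax is a structure class (index != 0)."""
    return float(sum(1 for v in voxels if v.index(max(v)) != 0))


def soft_volume(voxels: List[List[float]]) -> float:
    """Expected number of structure voxels: sum over voxels of 1 - P(air)."""
    return sum(1.0 - softmax(v)[0] for v in voxels)


voxels = [[2.0, 0.5], [0.2, 1.0]]  # voxel 0 leans air, voxel 1 leans structure
eps = 1e-3
nudged = [[2.0, 0.5 + eps], [0.2, 1.0 + eps]]

hard_slope = (hard_volume(nudged) - hard_volume(voxels)) / eps
soft_slope = (soft_volume(nudged) - soft_volume(voxels)) / eps
print(f"hard volume: {hard_volume(voxels)}  finite-difference slope: {hard_slope}")
print(f"soft volume: {soft_volume(voxels):.3f}  finite-difference slope: {soft_slope:.3f}")
```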

## 6. Configuration

**What this does (Technical):** Defines all hyperparameters for the v9 two-phase training experiment. Key changes from v8:
- **Latent resolution**: Reduced to 8x8x8 (from 16x16x16) to force better compression and prevent volume bias encoding
- **Frequency cap**: Reduced to 2.0x (from 5.0x) to reduce incentive for over-predicting rare blocks
- **Volume penalty**: Increased to 100x (from 10x) for strong volume ratio enforcement
- **Focal gamma**: Set to 2.0 to focus training on hard-to-predict examples
- **Air boost**: Set to 3.0x to make correct air predictions more valuable

**What this does (Simple):** Sets up all the training settings. The key insight is that we train in two phases:
1. **Phase 1 (10 epochs)**: Only teaches air vs non-air (binary). This establishes the correct volume ratio FIRST.
2. **Phase 2 (10 epochs)**: Teaches specific block types. Since volume is already correct, the model can focus on accuracy.

This is like teaching someone to draw outlines before filling in colors - get the shape right first!

In [None]:
# === Data Paths ===
DATA_DIR = f"{DRIVE_BASE}/splits/train"
VAL_DIR = f"{DRIVE_BASE}/splits/val"
VOCAB_PATH = f"{DRIVE_BASE}/vocabulary/tok2block.json"
V3_EMBEDDINGS_PATH = f"{DRIVE_BASE}/embeddings/block_embeddings_v3.npy"

# Verify paths
print("Checking paths...")
for name, path in [('DATA_DIR', DATA_DIR), ('VAL_DIR', VAL_DIR),
                   ('VOCAB_PATH', VOCAB_PATH), ('V3_EMBEDDINGS_PATH', V3_EMBEDDINGS_PATH)]:
    exists = Path(path).exists()
    print(f"  {name}: {'[OK]' if exists else '[NOT FOUND]'}")

# === V9 Architecture (back to 8x8x8 latent) ===
HIDDEN_DIM = 128  # Reduced from v8's 192
RFSQ_LEVELS = [5, 5, 5, 5]  # 4 dims x 5 levels = 625 codes per stage
NUM_STAGES = 2

# === V9 Loss Weights (KEY CHANGES) ===
FREQUENCY_CAP = 2.0       # REDUCED from 5.0 - less incentive to over-predict
FOCAL_GAMMA = 2.0         # Focus on hard examples
AIR_BOOST = 3.0           # Air predictions are 3x more important
VOLUME_PENALTY = 100.0    # STRONG volume penalty (was 10.0)

# === Two-Phase Training Schedule ===
PHASE1_EPOCHS = 10        # Binary air/structure training
PHASE2_EPOCHS = 10        # Full vocabulary (increase to 30 later if Phase 1 works)
TOTAL_EPOCHS = PHASE1_EPOCHS + PHASE2_EPOCHS

# === Training ===
BATCH_SIZE = 4            # Increased since 8x8x8 uses less memory
BASE_LR = 2e-4
USE_AMP = True
GRAD_ACCUM_STEPS = 4
SEED = 42
NUM_WORKERS = 2

print(f"\n{'='*60}")
print("V9 CONFIGURATION - TWO-PHASE TRAINING")
print(f"{'='*60}")
print(f"Latent resolution: 8x8x8 (reduced from v8's 16x16x16)")
print(f"Hidden dim: {HIDDEN_DIM}")
print(f"\nPhase 1: {PHASE1_EPOCHS} epochs - Binary air/structure")
print(f"  - Learns correct volume ratio FIRST")
print(f"  - Volume penalty: {VOLUME_PENALTY}x")
print(f"\nPhase 2: {PHASE2_EPOCHS} epochs - Full vocabulary")
print(f"  - Focal gamma: {FOCAL_GAMMA}")
print(f"  - Air boost: {AIR_BOOST}x")
print(f"  - Frequency cap: {FREQUENCY_CAP}x (reduced from 5.0)")
print(f"\nTotal epochs: {TOTAL_EPOCHS}")
print(f"{'='*60}")

## 7. Load Vocabulary and Embeddings

**What this does (Technical):** Loads the block vocabulary mapping (token ID to block name) and the pre-trained v3 block embeddings. Air tokens are identified as block names that contain 'air' but not 'stair', so stairs blocks are not mistaken for air.

**What this does (Simple):** Loads the dictionary that maps numbers to Minecraft block names, and the pre-learned representations of each block type.

In [None]:
with open(VOCAB_PATH, 'r') as f:
    tok2block = {int(k): v for k, v in json.load(f).items()}

VOCAB_SIZE = len(tok2block)
print(f"Vocabulary size: {VOCAB_SIZE}")

# Find air tokens
AIR_TOKENS: Set[int] = set()
for tok, block in tok2block.items():
    if 'air' in block.lower() and 'stair' not in block.lower():
        AIR_TOKENS.add(tok)
        print(f"  Air token: {tok} = {block}")

AIR_TOKENS_TENSOR = torch.tensor(sorted(AIR_TOKENS), dtype=torch.long)

# Load embeddings
v3_embeddings = np.load(V3_EMBEDDINGS_PATH).astype(np.float32)
EMBEDDING_DIM = v3_embeddings.shape[1]
print(f"\nEmbeddings: {v3_embeddings.shape}")

## 8. Compute Frequency Weights

**What this does (Technical):** Scans all training structures to count block frequencies, then computes inverse-square-root-frequency weights, w = sqrt(total / count), so rare blocks get higher weights. Weights are capped at 2.0x to prevent extreme over-prediction of rare blocks.

**What this does (Simple):** Counts how often each block type appears in the training data. Rare blocks get slightly higher importance, but capped to prevent the model from over-predicting them.
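
The weighting formula used in the cell below is w = min(sqrt(total / count), cap). A toy pure-Python version with hypothetical counts shows a consequence of the 2.0 cap: any block that makes up 25% or less of the voxels reaches the cap, since sqrt(1 / 0.25) = 2, so the cap mostly separates the very common blocks (typically air) from everything else. Every weight is also at least 1.0, because total / count >= 1.

```python
from typing import List


def frequency_weights(counts: List[int], cap: float = 2.0) -> List[float]:
    """Inverse-sqrt-frequency weights, mirroring the notebook's formula on plain lists."""
    counts = [max(c, 1) for c in counts]  # unseen tokens count as 1, like clamp(min=1)
    total = sum(counts)
    return [min((total / c) ** 0.5, cap) for c in counts]


# Hypothetical counts: air, stone, planks, glass, an unseen block
counts = [700, 200, 60, 40, 0]
names = ["air", "stone", "planks", "glass", "unseen"]
for name, w in zip(names, frequency_weights(counts)):
    print(f"{name:>7}: {w:.3f}")
```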

In [None]:
print("Computing block frequencies...")

# Accumulate counts file by file instead of concatenating every voxel in memory
counts = torch.zeros(VOCAB_SIZE, dtype=torch.long)
total_scanned = 0
train_files = sorted(Path(DATA_DIR).glob("*.h5"))

for h5_file in tqdm(train_files, desc="Scanning"):
    with h5py.File(h5_file, 'r') as f:
        key = list(f.keys())[0]
        structure = torch.from_numpy(f[key][:].astype(np.int64)).flatten()
    # Assumes every token ID is < VOCAB_SIZE, so bincount returns VOCAB_SIZE bins
    counts += torch.bincount(structure, minlength=VOCAB_SIZE)
    total_scanned += structure.numel()

print(f"Total blocks scanned: {total_scanned:,}")

# Inverse-sqrt-frequency weights with the REDUCED cap (unseen tokens count as 1)
counts = counts.clamp(min=1)
total = counts.sum()
FREQUENCY_WEIGHTS = ((total.float() / counts.float()) ** 0.5).clamp(max=FREQUENCY_CAP)

print(f"Frequency weights computed (cap={FREQUENCY_CAP}x)")
print(f"  Max weight: {FREQUENCY_WEIGHTS.max():.2f}")
print(f"  Min weight: {FREQUENCY_WEIGHTS.min():.2f}")

## 9. Dataset

**What this does (Technical):** PyTorch Dataset class that loads Minecraft structures from HDF5 files. Each structure is a 32x32x32 array of block token IDs.

**What this does (Simple):** Defines how to load Minecraft structures from files for training.

In [None]:
class VQVAEDataset(Dataset):
    def __init__(self, data_dir: str):
        self.h5_files = sorted(Path(data_dir).glob("*.h5"))
        if not self.h5_files:
            raise ValueError(f"No H5 files in {data_dir}")
        print(f"Found {len(self.h5_files)} structures")

    def __len__(self):
        return len(self.h5_files)

    def __getitem__(self, idx):
        with h5py.File(self.h5_files[idx], 'r') as f:
            structure = f[list(f.keys())[0]][:].astype(np.int64)
        return torch.from_numpy(structure).long()

train_dataset = VQVAEDataset(DATA_DIR)
val_dataset = VQVAEDataset(VAL_DIR)

## 10. Create Model and Optimizer

**What this does (Technical):** Instantiates the VQ-VAE v9 model with an 8x8x8 latent, the AsymmetricFocalLoss criterion, an AdamW optimizer with weight decay, a cosine-annealing LR scheduler spanning both phases (`T_max=TOTAL_EPOCHS`), and a GradScaler for mixed-precision training.

**What this does (Simple):** Creates the neural network and sets up the training optimizer that will adjust the model's parameters to minimize errors.
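
`CosineAnnealingLR` follows the closed form lr(t) = eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T_max)) / 2 described in the PyTorch docs. Because `T_max=TOTAL_EPOCHS` in the cell below, the schedule spans both phases. Assuming `scheduler.step()` is called once per epoch (the loop that calls it is not shown in this section), Phase 2 therefore starts at roughly half the base learning rate. A pure-Python sketch with the configured values:

```python
import math


def cosine_lr(epoch: int, base_lr: float = 2e-4, eta_min: float = 1e-5, t_max: int = 20) -> float:
    """Closed-form cosine annealing; `epoch` counts completed scheduler steps."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


for epoch in (0, 5, 10, 15, 20):
    print(f"after {epoch:2d} steps: lr = {cosine_lr(epoch):.2e}")
```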

In [None]:
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

model = VQVAEv9(
    vocab_size=VOCAB_SIZE,
    emb_dim=EMBEDDING_DIM,
    hidden_dim=HIDDEN_DIM,
    rfsq_levels=RFSQ_LEVELS,
    num_stages=NUM_STAGES,
    pretrained_embeddings=torch.from_numpy(v3_embeddings),
).to(device)

total_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {total_params:,}")

criterion = AsymmetricFocalLoss(
    air_tokens=AIR_TOKENS,
    gamma=FOCAL_GAMMA,
    air_boost=AIR_BOOST,
    volume_penalty=VOLUME_PENALTY,
    frequency_weights=FREQUENCY_WEIGHTS,
    frequency_cap=FREQUENCY_CAP,
).to(device)

optimizer = optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=TOTAL_EPOCHS, eta_min=1e-5)
scaler = torch.amp.GradScaler('cuda', enabled=USE_AMP)

print("Model, criterion, optimizer created")

## 11. Data Loaders

**What this does (Technical):** Creates PyTorch DataLoaders that batch, shuffle, and parallelize data loading for training efficiency.

**What this does (Simple):** Sets up efficient data loading that feeds batches of structures to the model during training.

In [None]:
g = torch.Generator().manual_seed(SEED)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=True, generator=g)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=True)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Effective batch size: {BATCH_SIZE * GRAD_ACCUM_STEPS}")

## 12. Training Functions

**What this does (Technical):** Defines train_epoch() and validate() functions that iterate through data, compute loss, backpropagate gradients, and collect metrics. Supports two-phase training via the `phase` parameter.

**What this does (Simple):** The core training logic that:
1. Shows structures to the model
2. Compares predictions to ground truth
3. Adjusts model to reduce errors
4. Tracks all important metrics

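The accumulation logic in the cell below steps the optimizer every `GRAD_ACCUM_STEPS` batches. This small pure-Python sketch (hypothetical helper) lists which batch indices trigger a step for a given epoch length, with and without an extra step on the final batch. Without that flush, a trailing partial group (when the number of batches is not a multiple of `GRAD_ACCUM_STEPS`) is never applied, because `train_epoch` calls `optimizer.zero_grad()` before its first batch.

```python
from typing import List


def step_batches(num_batches: int, accum_steps: int, flush_last: bool) -> List[int]:
    """Batch indices after which optimizer.step() would run."""
    steps = []
    for batch_idx in range(num_batches):
        end_of_group = (batch_idx + 1) % accum_steps == 0
        last_batch = (batch_idx + 1) == num_batches
        if end_of_group or (flush_last and last_batch):
            steps.append(batch_idx)
    return steps


print(step_batches(10, 4, flush_last=False))  # batches 8 and 9 never reach a step
print(step_batches(10, 4, flush_last=True))
```
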
In [None]:
def train_epoch(model, criterion, loader, optimizer, scaler, device, phase: int):
    """Train one epoch. phase=1 for binary, phase=2 for full vocab."""
    model.train()
    model.quantizer.reset_usage()
    
    metrics_sum = defaultdict(float)
    n = 0
    optimizer.zero_grad()
    
    for batch_idx, batch in enumerate(tqdm(loader, desc=f"Train P{phase}", leave=False)):
        batch = batch.to(device)
        
        with torch.amp.autocast('cuda', enabled=USE_AMP):
            logits, z_q, indices = model(batch)
            loss_dict = criterion(logits, batch, z_q, phase=phase)
            loss = loss_dict['loss'] / GRAD_ACCUM_STEPS
        
        scaler.scale(loss).backward()
        
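        # Step every GRAD_ACCUM_STEPS batches, and flush a trailing partial group
        # on the last batch so its gradients are not discarded
        is_last_batch = (batch_idx + 1) == len(loader)
        if (batch_idx + 1) % GRAD_ACCUM_STEPS == 0 or is_last_batch: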
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
        
        with torch.no_grad():
            for k, v in loss_dict.items():
                metrics_sum[k] += v.item() if torch.is_tensor(v) else v
        n += 1
    
    metrics = {k: v / n for k, v in metrics_sum.items()}
    
    # RFSQ usage stats
    for name, (usage, perp) in model.quantizer.get_usage_stats().items():
        metrics[f'{name}_usage'] = usage
        metrics[f'{name}_perplexity'] = perp
    
    return metrics


@torch.no_grad()
def validate(model, criterion, loader, device, phase: int):
    """Validate. phase=1 for binary, phase=2 for full vocab."""
    model.eval()
    model.quantizer.reset_usage()
    
    metrics_sum = defaultdict(float)
    n = 0
    
    for batch in tqdm(loader, desc=f"Val P{phase}", leave=False):
        batch = batch.to(device)
        
        with torch.amp.autocast('cuda', enabled=USE_AMP):
            logits, z_q, indices = model(batch)
            loss_dict = criterion(logits, batch, z_q, phase=phase)
        
        for k, v in loss_dict.items():
            metrics_sum[k] += v.item() if torch.is_tensor(v) else v
        n += 1
    
    metrics = {k: v / n for k, v in metrics_sum.items()}
    
    for name, (usage, perp) in model.quantizer.get_usage_stats().items():
        metrics[f'{name}_usage'] = usage
        metrics[f'{name}_perplexity'] = perp
    
    return metrics

print("Training functions defined")
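`train_epoch()` divides each micro-batch loss by `GRAD_ACCUM_STEPS` before calling `backward()`. Gradients add up across `backward()` calls. Scaling each micro-batch's mean loss by 1/K therefore makes the accumulated gradient equal to the gradient of the mean loss over the whole effective batch, provided the K micro-batches are the same size. The sketch below checks this numerically in plain Python. The model is a one-parameter least-squares fit and the data are made up; neither is the VQ-VAE.

```python
def mean_sq_grad(w: float, xs, ys) -> float:
    """Gradient of mean((w*x - y)**2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w: float, xs, ys, accum_steps: int) -> float:
    """Accumulate gradients of (micro-batch mean loss / accum_steps)."""
    assert len(xs) % accum_steps == 0, "equal-sized micro-batches assumed"
    size = len(xs) // accum_steps
    grad = 0.0
    for i in range(accum_steps):
        mb_x = xs[i * size:(i + 1) * size]
        mb_y = ys[i * size:(i + 1) * size]
        # Same effect as (loss / GRAD_ACCUM_STEPS).backward() on this micro-batch
        grad += mean_sq_grad(w, mb_x, mb_y) / accum_steps
    return grad


xs = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0]
ys = [1.1, 1.9, 3.2, 3.9, 5.1, 6.0, 7.2, 7.8]
w = 1.5
full = mean_sq_grad(w, xs, ys)
accum = accumulated_grad(w, xs, ys, accum_steps=4)
print(f"full-batch grad: {full:.6f}, accumulated grad: {accum:.6f}")
```

A smaller final micro-batch breaks the equality. Its loss is still divided by `GRAD_ACCUM_STEPS`, so that last update is somewhat smaller than a full one.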

## 13. Two-Phase Training Loop

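**What this does (Technical):** Executes two-phase training:
- **Phase 1** trains binary air/structure classification with a strong volume penalty.
- **Phase 2** continues from the same weights with the full vocabulary and focal loss.

Checkpoints are saved at the end of Phase 1 and every 5 epochs. A best checkpoint is also saved whenever the current phase's selection metric improves: closeness of the volume ratio to 1.0 in Phase 1, building accuracy in Phase 2.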

**What this does (Simple):** The main training loop that:
1. **Phase 1**: Teaches the model to distinguish air from non-air (establishes correct volume)
2. **Phase 2**: Teaches the model to identify specific block types (improves accuracy)
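The loop below uses a 0-based `epoch` index, but it prints and saves 1-based epoch numbers, which makes the phase boundary easy to misplace. This pure-Python sketch reproduces the loop's schedule rules: phase choice, the end-of-Phase-1 checkpoint, and the every-5-epoch checkpoint. The helper name `training_schedule` is hypothetical. The 10/40 split mirrors the overview, and best-model saves are left out.

```python
def training_schedule(phase1_epochs: int, total_epochs: int, ckpt_every: int = 5):
    """Yield (epoch_number, phase, saved_files) using the loop's rules.

    epoch_number is 1-based, matching the printed 'E..' labels and file names.
    """
    for epoch in range(total_epochs):          # 0-based, as in the training loop
        phase = 1 if epoch < phase1_epochs else 2
        saved = []
        if epoch == phase1_epochs - 1:         # last Phase 1 epoch
            saved.append("vqvae_v9_phase1_complete.pt")
        if (epoch + 1) % ckpt_every == 0:
            saved.append(f"vqvae_v9_epoch{epoch + 1}.pt")
        yield epoch + 1, phase, saved


schedule = list(training_schedule(phase1_epochs=10, total_epochs=40))
for epoch_number, phase, saved in schedule[8:12]:   # epochs 9-12
    print(epoch_number, phase, saved)
```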

In [None]:
print("="*70)
print("VQ-VAE V9 TWO-PHASE TRAINING")
print("="*70)
print(f"Phase 1 (epochs 1-{PHASE1_EPOCHS}): Binary air/structure")
print(f"Phase 2 (epochs {PHASE1_EPOCHS+1}-{TOTAL_EPOCHS}): Full vocabulary")
print(f"Target: volume_ratio ~1.0x, building_acc >45%")
print("="*70)

history = {
    'train_loss': [], 'val_loss': [],
    'train_volume_ratio': [], 'val_volume_ratio': [],
    'train_recall': [], 'val_recall': [],
    'train_precision': [], 'val_precision': [],
    'train_accuracy': [], 'val_accuracy': [],
    'train_building_acc': [], 'val_building_acc': [],
    'train_air_acc': [], 'val_air_acc': [],
    'train_false_air_rate': [], 'val_false_air_rate': [],
    'train_ce_loss': [], 'val_ce_loss': [],
    'train_volume_loss': [], 'val_volume_loss': [],
    'train_focal_loss': [], 'val_focal_loss': [],
    'train_stage0_perplexity': [], 'val_stage0_perplexity': [],
    'learning_rate': [], 'phase': [],
}

best_metric = float('-inf')  # Phase 1 score (1 - |ratio - 1|) is negative above 2.0x
best_epoch = 0
start_time = time.time()

for epoch in range(TOTAL_EPOCHS):
    # Determine phase
    phase = 1 if epoch < PHASE1_EPOCHS else 2
    
    # Train and validate
    train_m = train_epoch(model, criterion, train_loader, optimizer, scaler, device, phase)
    val_m = validate(model, criterion, val_loader, device, phase)
    scheduler.step()
    current_lr = scheduler.get_last_lr()[0]
    
    # Record metrics
    for prefix, m in [('train', train_m), ('val', val_m)]:
        history[f'{prefix}_loss'].append(m.get('loss', 0))
        history[f'{prefix}_volume_ratio'].append(m.get('volume_ratio', 0))
        history[f'{prefix}_recall'].append(m.get('recall', 0))
        history[f'{prefix}_precision'].append(m.get('precision', 0))
        history[f'{prefix}_accuracy'].append(m.get('accuracy', 0))
        history[f'{prefix}_building_acc'].append(m.get('building_acc', 0))
        history[f'{prefix}_air_acc'].append(m.get('air_acc', 0))
        history[f'{prefix}_false_air_rate'].append(m.get('false_air_rate', 0))
        history[f'{prefix}_ce_loss'].append(m.get('ce_loss', 0))
        history[f'{prefix}_volume_loss'].append(m.get('volume_loss', 0))
        history[f'{prefix}_focal_loss'].append(m.get('focal_loss', 0))
        history[f'{prefix}_stage0_perplexity'].append(m.get('stage0_perplexity', 0))
    history['learning_rate'].append(current_lr)
    history['phase'].append(phase)
    
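    # Best model tracking (different metric per phase). Reset the running best when
    # Phase 2 starts. Otherwise a Phase 1 score near 1.0 blocks every Phase 2 save,
    # because the Phase 2 building-accuracy target is only 45-55%.
    if epoch == PHASE1_EPOCHS:
        best_metric = float('-inf')
    if phase == 1:
        current_metric = 1.0 - abs(val_m['volume_ratio'] - 1.0)  # Closer to 1.0 = better
    else:
        current_metric = val_m.get('building_acc', val_m.get('accuracy', 0))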
    
    if current_metric > best_metric:
        best_metric = current_metric
        best_epoch = epoch + 1
        torch.save(model.state_dict(), f"{OUTPUT_DIR}/vqvae_v9_best.pt")
    
    # Save at phase transition
    if epoch == PHASE1_EPOCHS - 1:
        torch.save(model.state_dict(), f"{OUTPUT_DIR}/vqvae_v9_phase1_complete.pt")
        print(f"\n{'='*70}")
        print("PHASE 1 COMPLETE - Transitioning to Phase 2 (full vocabulary)")
        print(f"Volume ratio at transition: {val_m['volume_ratio']:.3f}x")
        print(f"{'='*70}\n")
    
    # Checkpoint every 5 epochs
    if (epoch + 1) % 5 == 0:
        torch.save(model.state_dict(), f"{OUTPUT_DIR}/vqvae_v9_epoch{epoch+1}.pt")
    
    # Print progress ('(!!)' flags a volume ratio outside the 0.9-1.1x target zone)
    vol_str = f"{val_m['volume_ratio']:.3f}x"
    vol_flag = '' if 0.9 <= val_m['volume_ratio'] <= 1.1 else ' (!!)'
    
    if phase == 1:
        print(f"E{epoch+1:2d} P1 | Acc: {val_m['accuracy']:.1%} | "
              f"Vol: {vol_str}{vol_flag} | Recall: {val_m['recall']:.1%} | LR: {current_lr:.2e}")
    else:
        ba = val_m.get('building_acc', 0)
        print(f"E{epoch+1:2d} P2 | Build: {ba:.1%} | "
              f"Vol: {vol_str}{vol_flag} | Recall: {val_m['recall']:.1%} | LR: {current_lr:.2e}")

train_time = time.time() - start_time
print(f"\n{'='*70}")
print("TRAINING COMPLETE")
print(f"{'='*70}")
print(f"Time: {train_time/60:.1f} minutes")
print(f"Best epoch: {best_epoch}")
print(f"Final volume ratio: {history['val_volume_ratio'][-1]:.3f}x")
print(f"Final recall: {history['val_recall'][-1]:.1%}")
if history['val_building_acc'][-1] > 0:
    print(f"Final building accuracy: {history['val_building_acc'][-1]:.1%}")
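The best-checkpoint rule compares scores on different scales. Phase 1 scores how close the volume ratio is to 1.0, as `1 - |ratio - 1|`, while Phase 2 scores building accuracy. Below is a plain-Python sketch of per-phase selection that restarts the running best at the phase boundary, so a strong Phase 1 score cannot suppress Phase 2 saves. The helper names are hypothetical, and the validation numbers are invented purely for illustration.

```python
def phase_metric(phase: int, val_metrics: dict) -> float:
    """Selection score used by the training loop for each phase."""
    if phase == 1:
        return 1.0 - abs(val_metrics["volume_ratio"] - 1.0)  # 1.0 is a perfect ratio
    return val_metrics.get("building_acc", val_metrics.get("accuracy", 0.0))


def select_best(history, phase1_epochs: int) -> dict:
    """Return the 1-based best epoch per phase, resetting the best at the boundary."""
    best = {}
    best_score = float("-inf")
    for epoch, val_metrics in enumerate(history):
        phase = 1 if epoch < phase1_epochs else 2
        if epoch == phase1_epochs:
            best_score = float("-inf")  # new phase, new metric scale
        score = phase_metric(phase, val_metrics)
        if score > best_score:
            best_score = score
            best[phase] = epoch + 1
    return best


# Invented validation metrics for a 2 + 3 epoch run
fake_history = [
    {"volume_ratio": 1.30},
    {"volume_ratio": 1.02},
    {"volume_ratio": 1.05, "building_acc": 0.41},
    {"volume_ratio": 1.01, "building_acc": 0.47},
    {"volume_ratio": 0.98, "building_acc": 0.46},
]
print(select_best(fake_history, phase1_epochs=2))  # {1: 2, 2: 4}
```

Without the reset, the Phase 1 best of 0.98 would exceed every building accuracy here, and no Phase 2 epoch would ever be selected.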

## 14. Plot Training Results

**What this does (Technical):** Creates a 3x3 grid of per-epoch plots of the key metrics. The plots mark the Phase 1/Phase 2 transition and draw target and baseline reference lines, including v8's stuck 2.0x volume ratio. The cell then saves the figure as a PNG and prints a pass/fail check against the v9 targets.

**What this does (Simple):** Shows how well training went. We want the volume ratio to settle near 1.0 and building accuracy to rise over time.
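Volume ratio, recall and precision are computed inside `criterion`, which is defined earlier in the notebook and not shown here. The sketch below gives one plausible set of definitions, under these assumptions:
- block id 0 is air (`AIR_ID`, hypothetical);
- volume ratio = predicted non-air voxels / true non-air voxels;
- recall and precision treat non-air as the positive class.

Check these against the actual criterion before relying on them. The toy example shows why a 2.0x volume ratio can coexist with near-perfect recall.

```python
AIR_ID = 0  # assumption: block id 0 is air


def occupancy_metrics(pred_ids, true_ids, air_id: int = AIR_ID) -> dict:
    """Binary structure metrics over flattened voxel grids (non-air = positive)."""
    pred_solid = [p != air_id for p in pred_ids]
    true_solid = [t != air_id for t in true_ids]
    tp = sum(p and t for p, t in zip(pred_solid, true_solid))
    n_pred, n_true = sum(pred_solid), sum(true_solid)
    return {
        # > 1.0 means the model fills in more voxels than the real structure has
        "volume_ratio": n_pred / n_true if n_true else float("nan"),
        "recall": tp / n_true if n_true else float("nan"),
        "precision": tp / n_pred if n_pred else float("nan"),
    }


# v8-style failure: every structure voxel is kept, and just as many air voxels are filled in
true_ids = [0, 0, 0, 0, 5, 5, 7, 7]
pred_ids = [3, 3, 1, 3, 5, 5, 7, 5]
print(occupancy_metrics(pred_ids, true_ids))
# {'volume_ratio': 2.0, 'recall': 1.0, 'precision': 0.5}
```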

In [None]:
fig, axes = plt.subplots(3, 3, figsize=(16, 14))
epochs = range(1, len(history['train_loss']) + 1)
phase_transition = PHASE1_EPOCHS + 0.5

# 1. Volume Ratio (MOST IMPORTANT)
ax = axes[0, 0]
ax.plot(epochs, history['train_volume_ratio'], 'b-', label='Train', linewidth=2)
ax.plot(epochs, history['val_volume_ratio'], 'r--', label='Val', linewidth=2)
ax.axhline(y=1.0, color='g', linestyle='--', linewidth=2, label='Target (1.0x)')
ax.axhline(y=2.0, color='orange', linestyle=':', alpha=0.7, label='v8 stuck (2.0x)')
ax.axvline(x=phase_transition, color='purple', linestyle=':', alpha=0.5, label='Phase transition')
ax.fill_between(epochs, 0.9, 1.1, alpha=0.2, color='green', label='Target zone')
ax.set_xlabel('Epoch'); ax.set_ylabel('Volume Ratio')
ax.set_title('Volume Ratio (KEY METRIC)', fontweight='bold', fontsize=12, color='darkred')
ax.legend(loc='upper right', fontsize=8); ax.grid(True, alpha=0.3)
ax.set_ylim(0, max(2.5, max(history['val_volume_ratio']) * 1.1))

# 2. Recall
ax = axes[0, 1]
ax.plot(epochs, history['train_recall'], 'b-', label='Train', linewidth=2)
ax.plot(epochs, history['val_recall'], 'r--', label='Val', linewidth=2)
ax.axhline(y=0.90, color='g', linestyle='--', alpha=0.5, label='Target (90%)')
ax.axvline(x=phase_transition, color='purple', linestyle=':', alpha=0.5)
ax.set_xlabel('Epoch'); ax.set_ylabel('Recall')
ax.set_title('Recall (Structure Preservation)', fontweight='bold', fontsize=12)
ax.legend(loc='lower right'); ax.grid(True, alpha=0.3); ax.set_ylim(0, 1.05)

# 3. Precision
ax = axes[0, 2]
ax.plot(epochs, history['train_precision'], 'b-', label='Train', linewidth=2)
ax.plot(epochs, history['val_precision'], 'r--', label='Val', linewidth=2)
ax.axvline(x=phase_transition, color='purple', linestyle=':', alpha=0.5)
ax.set_xlabel('Epoch'); ax.set_ylabel('Precision')
ax.set_title('Precision', fontweight='bold', fontsize=12)
ax.legend(loc='lower right'); ax.grid(True, alpha=0.3); ax.set_ylim(0, 1.05)

# 4. Building Accuracy (Phase 2 only)
ax = axes[1, 0]
ax.plot(epochs, history['train_building_acc'], 'b-', label='Train', linewidth=2)
ax.plot(epochs, history['val_building_acc'], 'r--', label='Val', linewidth=2)
ax.axhline(y=0.492, color='g', linestyle=':', alpha=0.5, label='v6-freq (49.2%)')
ax.axvline(x=phase_transition, color='purple', linestyle=':', alpha=0.5, label='Phase 2 start')
ax.set_xlabel('Epoch'); ax.set_ylabel('Accuracy')
ax.set_title('Building Accuracy', fontweight='bold', fontsize=12)
ax.legend(loc='lower right'); ax.grid(True, alpha=0.3)

# 5. Air Accuracy (Phase 2 only)
ax = axes[1, 1]
ax.plot(epochs, history['train_air_acc'], 'b-', label='Train', linewidth=2)
ax.plot(epochs, history['val_air_acc'], 'r--', label='Val', linewidth=2)
ax.axvline(x=phase_transition, color='purple', linestyle=':', alpha=0.5)
ax.set_xlabel('Epoch'); ax.set_ylabel('Accuracy')
ax.set_title('Air Accuracy', fontweight='bold', fontsize=12)
ax.legend(loc='lower right'); ax.grid(True, alpha=0.3)

# 6. False Air Rate
ax = axes[1, 2]
ax.plot(epochs, history['train_false_air_rate'], 'b-', label='Train', linewidth=2)
ax.plot(epochs, history['val_false_air_rate'], 'r--', label='Val', linewidth=2)
ax.axhline(y=0.10, color='orange', linestyle='--', alpha=0.5, label='Max acceptable (10%)')
ax.axvline(x=phase_transition, color='purple', linestyle=':', alpha=0.5)
ax.set_xlabel('Epoch'); ax.set_ylabel('Rate')
ax.set_title('False Air Rate (Erasure)', fontweight='bold', fontsize=12)
ax.legend(loc='upper right'); ax.grid(True, alpha=0.3)

# 7. Loss Components
ax = axes[2, 0]
ax.plot(epochs, history['val_ce_loss'], 'b-', label='CE/Focal', linewidth=2)
ax.plot(epochs, history['val_volume_loss'], 'r-', label='Volume', linewidth=2)
ax.axvline(x=phase_transition, color='purple', linestyle=':', alpha=0.5)
ax.set_xlabel('Epoch'); ax.set_ylabel('Loss')
ax.set_title('Loss Components (Val)', fontweight='bold', fontsize=12)
ax.legend(); ax.grid(True, alpha=0.3)

# 8. Total Loss
ax = axes[2, 1]
ax.plot(epochs, history['train_loss'], 'b-', label='Train', linewidth=2)
ax.plot(epochs, history['val_loss'], 'r--', label='Val', linewidth=2)
ax.axvline(x=phase_transition, color='purple', linestyle=':', alpha=0.5)
ax.set_xlabel('Epoch'); ax.set_ylabel('Loss')
ax.set_title('Total Loss', fontweight='bold', fontsize=12)
ax.legend(); ax.grid(True, alpha=0.3)

# 9. RFSQ Perplexity
ax = axes[2, 2]
ax.plot(epochs, history['train_stage0_perplexity'], 'b-', label='Train', linewidth=2)
ax.plot(epochs, history['val_stage0_perplexity'], 'r--', label='Val', linewidth=2)
ax.axvline(x=phase_transition, color='purple', linestyle=':', alpha=0.5)
ax.set_xlabel('Epoch'); ax.set_ylabel('Perplexity')
ax.set_title('RFSQ Perplexity', fontweight='bold', fontsize=12)
ax.legend(); ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/vqvae_v9_training.png", dpi=150, bbox_inches='tight')
plt.show()

# Summary
print("\n" + "="*70)
print("RESULTS SUMMARY")
print("="*70)
final_vol = history['val_volume_ratio'][-1]
final_recall = history['val_recall'][-1]
final_build = history['val_building_acc'][-1]

print(f"Final volume ratio: {final_vol:.3f}x")
print(f"Final recall: {final_recall:.1%}")
print(f"Final building accuracy: {final_build:.1%}")
print()

# Success criteria
vol_ok = 0.9 <= final_vol <= 1.1
recall_ok = final_recall >= 0.85
acc_ok = final_build >= 0.45

print("Success Criteria:")
print(f"  Volume 0.9-1.1x: {'PASS' if vol_ok else 'FAIL'} ({final_vol:.3f}x)")
print(f"  Recall >= 85%: {'PASS' if recall_ok else 'FAIL'} ({final_recall:.1%})")
print(f"  Build acc >= 45%: {'PASS' if acc_ok else 'FAIL'} ({final_build:.1%})")
print()

if vol_ok and recall_ok and acc_ok:
    print("SUCCESS! All targets met.")
elif vol_ok:
    print("PARTIAL SUCCESS: Volume ratio fixed, but recall and/or building accuracy missed the target. "
          "More Phase 2 epochs may help.")
else:
    print("NEEDS WORK: Volume ratio is outside 0.9-1.1x. Check Phase 1 metrics; "
          "volume ratio should be ~1.0 before Phase 2.")
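The pass/fail checks above are simple threshold tests on the last epoch's validation metrics. The sketch below packages them as a hypothetical `check_targets` function with the same 0.9-1.1x, 85% and 45% thresholds, so the same rule can be applied to other runs' saved results. The first call uses the v8-B numbers from the overview table. The second uses a hypothetical v9 run.

```python
def check_targets(volume_ratio: float, recall: float, building_acc: float,
                  vol_range=(0.9, 1.1), min_recall=0.85, min_build_acc=0.45) -> dict:
    """Apply the v9 success thresholds and return per-criterion pass flags."""
    checks = {
        "volume": vol_range[0] <= volume_ratio <= vol_range[1],
        "recall": recall >= min_recall,
        "building_acc": building_acc >= min_build_acc,
    }
    if all(checks.values()):
        verdict = "SUCCESS"
    elif checks["volume"]:
        verdict = "PARTIAL"  # volume fixed; recall and/or accuracy short
    else:
        verdict = "NEEDS WORK"
    return {**checks, "verdict": verdict}


print(check_targets(volume_ratio=2.0, recall=0.997, building_acc=0.80))  # v8-B
print(check_targets(volume_ratio=1.03, recall=0.91, building_acc=0.48))  # hypothetical v9
```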

## 15. Save Results

**What this does (Technical):** Serializes training configuration, final metrics, and full training history to JSON. Saves final model checkpoint.

**What this does (Simple):** Saves all training results and the trained model to files so we can use them later.

In [None]:
results = {
    'config': {
        'version': 'v9-TWO-PHASE',
        'latent_resolution': '8x8x8',
        'hidden_dim': HIDDEN_DIM,
        'phase1_epochs': PHASE1_EPOCHS,
        'phase2_epochs': PHASE2_EPOCHS,
        'total_epochs': TOTAL_EPOCHS,
        'batch_size': BATCH_SIZE,
        'base_lr': BASE_LR,
        'focal_gamma': FOCAL_GAMMA,
        'air_boost': AIR_BOOST,
        'volume_penalty': VOLUME_PENALTY,
        'frequency_cap': FREQUENCY_CAP,
        'seed': SEED,
    },
    'results': {
        'final_volume_ratio': float(history['val_volume_ratio'][-1]),
        'final_recall': float(history['val_recall'][-1]),
        'final_precision': float(history['val_precision'][-1]),
        'final_building_acc': float(history['val_building_acc'][-1]),
        'final_air_acc': float(history['val_air_acc'][-1]),
        'final_false_air_rate': float(history['val_false_air_rate'][-1]),
        'phase1_final_volume_ratio': float(history['val_volume_ratio'][PHASE1_EPOCHS-1]),
        'best_epoch': best_epoch,
        'training_time_min': float(train_time / 60),
        'volume_target_met': bool(0.9 <= history['val_volume_ratio'][-1] <= 1.1),
        'recall_target_met': bool(history['val_recall'][-1] >= 0.85),
        'accuracy_target_met': bool(history['val_building_acc'][-1] >= 0.45),
    },
    'history': {k: [float(x) for x in v] for k, v in history.items()},
}

with open(f"{OUTPUT_DIR}/vqvae_v9_results.json", 'w') as f:
    json.dump(results, f, indent=2)

torch.save(model.state_dict(), f"{OUTPUT_DIR}/vqvae_v9_final.pt")

print("Results saved:")
print(f"  {OUTPUT_DIR}/vqvae_v9_results.json")
print(f"  {OUTPUT_DIR}/vqvae_v9_best.pt")
print(f"  {OUTPUT_DIR}/vqvae_v9_phase1_complete.pt")
print(f"  {OUTPUT_DIR}/vqvae_v9_final.pt")
print(f"  {OUTPUT_DIR}/vqvae_v9_training.png")
print("\nDone!")
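Because the results file is plain JSON, a later notebook can reload it without PyTorch. Below is a minimal sketch that assumes only the keys written above (`config`, `results`, `history`); the helper `summarize_results` is hypothetical. So that it runs anywhere, it writes a small stand-in file with placeholder values to a temporary directory. In practice you would point it at `vqvae_v9_results.json` in `OUTPUT_DIR`.

```python
import json
import os
import tempfile

TARGET_KEYS = ("volume_target_met", "recall_target_met", "accuracy_target_met")


def summarize_results(path: str) -> dict:
    """Load a v9 results JSON and pull out the headline numbers."""
    with open(path) as f:
        data = json.load(f)
    res, hist = data["results"], data["history"]
    return {
        "version": data["config"]["version"],
        "epochs_run": len(hist["val_volume_ratio"]),
        "best_epoch": res["best_epoch"],
        "phase1_volume_ratio": res["phase1_final_volume_ratio"],
        "final_volume_ratio": res["final_volume_ratio"],
        "all_targets_met": all(res[k] for k in TARGET_KEYS),
    }


# Stand-in file with the same layout (placeholder values, not real results)
stub = {
    "config": {"version": "v9-TWO-PHASE"},
    "results": {"best_epoch": 3, "phase1_final_volume_ratio": 1.1,
                "final_volume_ratio": 1.04, "volume_target_met": True,
                "recall_target_met": True, "accuracy_target_met": False},
    "history": {"val_volume_ratio": [1.4, 1.1, 1.04]},
}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "vqvae_v9_results.json")
    with open(path, "w") as f:
        json.dump(stub, f)
    summary = summarize_results(path)
print(summary)
```

To reuse the weights, the standard PyTorch pattern is `model.load_state_dict(torch.load(path, map_location=device))`. The model must be built with the same architecture settings, such as `HIDDEN_DIM`.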