# VQ-VAE v8-B Training - TEMPERATURE-SCALED Volume Control

**FIX FOR 2.0x VOLUME RATIO**

## Why Direct Logit Failed

The direct logit approach optimized margin violations to ~0, but:
- 10% of GT air locations still had small violations
- This 10% error = 2.0x volume ratio
- CE loss (0.45) dominated volume loss (0.002)

## New Approach: Temperature-Scaled Softmax

```python
soft_probs = F.softmax(logits / 0.1, dim=1)  # temp=0.1 = near-argmax
soft_volume_ratio = soft_structure_count / gt_structure_count
volume_loss = (soft_volume_ratio - 1.0) ** 2  # Direct ratio penalty!
```

This loss does NOT diminish to zero when volume ratio is wrong!

## Expected Results

| Metric | Previous (stuck) | Expected Now |
|--------|-----------------|--------------|
| Volume Ratio | 2.0x (all epochs) | **~1.0-1.3x** |
| Building Acc | 80% (inflated) | **55-65%** (real) |
| Recall | 99.7% | **>90%** |

## 1. Setup - Mount Google Drive

In [None]:
# Mount Google Drive (Colab only - comment these two lines out when running elsewhere)
from google.colab import drive
drive.mount('/content/drive')

import os
from pathlib import Path

# Set paths - CHANGE THIS to match your Google Drive structure
# DRIVE_BASE is the project folder on the mounted drive; OUTPUT_DIR is the
# folder created below for this run.
DRIVE_BASE = '/content/drive/MyDrive/minecraft_ai'
OUTPUT_DIR = '/content/drive/MyDrive/minecraft_ai/vqvae_v8b'

# Create the output folder (and any missing parents); no error if it exists.
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
print(f"Output directory: {OUTPUT_DIR}")
print(f"Data directory: {DRIVE_BASE}")

## 2. Imports

In [None]:
import json
import random
import time
from pathlib import Path
from typing import Dict, List, Tuple, Any, Set, Optional
from collections import Counter

import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm

# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Report which GPU was picked up and how much memory it has.
if device == "cuda":
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {gpu_props.total_memory / 1e9:.1f} GB")

## 3. Model Code - FSQ and RFSQ Quantization

All model code is included inline - no external imports!

In [None]:
# ============================================================================
# FSQ (Finite Scalar Quantization)
# ============================================================================

class FSQ(nn.Module):
    """Finite Scalar Quantization.

    Every channel of the last dimension is squashed with ``tanh`` and snapped
    onto a fixed per-channel grid defined by ``levels``; there is no learned
    codebook. The per-channel grid positions are combined into one integer
    code per latent position (mixed-radix encoding).
    """

    def __init__(self, levels: List[int], eps: float = 1e-3):
        """
        Args:
            levels: Number of quantization levels for each latent channel.
            eps: Stored on the module; not read by any method of this class.
        """
        super().__init__()
        self.levels = levels
        self.dim = len(levels)
        self.eps = eps
        self.codebook_size = int(np.prod(levels))

        self.register_buffer('_levels', torch.tensor(levels, dtype=torch.float32))

        # Mixed-radix place values: the last channel is the least significant
        # digit, each earlier channel is worth the product of the level counts
        # of all channels after it.
        place_values = []
        stride = 1
        for n_levels in levels[::-1]:
            place_values.insert(0, stride)
            stride *= n_levels
        self.register_buffer('_basis', torch.tensor(place_values, dtype=torch.long))

        # Half-width of each channel's grid, i.e. (levels - 1) / 2.
        half_widths = [(n_levels - 1) / 2 for n_levels in levels]
        self.register_buffer('_half_levels', torch.tensor(half_widths, dtype=torch.float32))

    @property
    def num_codes(self) -> int:
        """Total number of distinct codes (product of all level counts)."""
        return self.codebook_size

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Quantize ``z`` (channels in the last dimension).

        Returns:
            ``(z_q, indices)``: the quantized tensor (same shape as ``z``,
            gradients pass straight through to the ``tanh`` output) and the
            integer code for every position (shape ``z.shape[:-1]``).
        """
        bounded = torch.tanh(z)
        snapped = self._quantize(bounded)
        # Straight-through estimator: forward value is the snapped grid point,
        # backward gradient is that of the bounded input.
        z_q = bounded + (snapped - bounded).detach()
        return z_q, self._to_indices(z_q)

    def _quantize(self, z: torch.Tensor) -> torch.Tensor:
        """Round each channel of ``z`` (expected in [-1, 1]) to its grid."""
        channels = []
        for i, half in enumerate(self._half_levels):
            grid_pos = torch.round(z[..., i] * half)
            grid_pos = torch.clamp(grid_pos, -half, half)
            channels.append(grid_pos / half)
        return torch.stack(channels, dim=-1)

    def _to_indices(self, z_q: torch.Tensor) -> torch.Tensor:
        """Convert quantized values into a single integer code per position."""
        indices = torch.zeros(z_q.shape[:-1], dtype=torch.long, device=z_q.device)
        for i in range(self.dim):
            n_levels = self._levels[i].long()
            half = self._half_levels[i]
            # Shift from [-1, 1] back to a 0-based slot on this channel's grid.
            slot = (z_q[..., i] * half + half).round().long()
            slot = torch.clamp(slot, 0, n_levels - 1)
            indices = indices + slot * self._basis[i]
        return indices

    def get_codebook_usage(self, indices: torch.Tensor) -> Tuple[float, float]:
        """Return ``(usage, perplexity)`` for a tensor of codes.

        ``usage`` is the fraction of the codebook that occurs at least once;
        ``perplexity`` is ``exp`` of the entropy of the empirical code
        distribution.
        """
        hist = torch.bincount(indices.flatten(), minlength=self.codebook_size).float()
        usage = (hist > 0).float().mean().item()
        probs = hist / hist.sum()
        observed = probs[probs > 0]  # drop unused codes so log() stays finite
        entropy = -(observed * torch.log(observed)).sum()
        return usage, torch.exp(entropy).item()


# ============================================================================
# RFSQ (Residual FSQ with LayerNorm)
# ============================================================================

class InvertibleLayerNorm(nn.Module):
    """Normalization that remembers its statistics so it can be undone.

    ``forward`` standardizes a 5-D tensor over its three spatial dimensions
    (separately for every sample and channel), applies a learned per-channel
    scale and shift, and stores the mean/std it used. ``inverse`` reverses
    both steps using the statistics from the most recent ``forward`` call.

    The layout is detected from the shape: if the last dimension equals
    ``num_features`` the tensor is treated as channels-last, otherwise as
    channels-first (channels in dimension 1).
    """

    def __init__(self, num_features: int, eps: float = 1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        # Filled in by forward(); non-persistent so they never enter state_dict.
        self.register_buffer('stored_mean', None, persistent=False)
        self.register_buffer('stored_std', None, persistent=False)

    def _layout(self, x: torch.Tensor):
        """Return ``(spatial_dims, weight, bias)`` broadcastable against ``x``."""
        if x.shape[-1] == self.num_features:
            # Channels-last: the 1-D parameters broadcast over the last dim.
            return (1, 2, 3), self.weight, self.bias
        per_channel = (1, -1, 1, 1, 1)
        return (2, 3, 4), self.weight.view(per_channel), self.bias.view(per_channel)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` and remember the statistics for ``inverse``."""
        spatial_dims, gamma, beta = self._layout(x)
        self.stored_mean = x.mean(dim=spatial_dims, keepdim=True)
        self.stored_std = x.std(dim=spatial_dims, keepdim=True) + self.eps
        return (x - self.stored_mean) / self.stored_std * gamma + beta

    def inverse(self, x_norm: torch.Tensor) -> torch.Tensor:
        """Undo the last ``forward`` call.

        Raises:
            RuntimeError: If ``forward`` has not been called yet.
        """
        if self.stored_mean is None or self.stored_std is None:
            raise RuntimeError("Must call forward() before inverse()")
        _, gamma, beta = self._layout(x_norm)
        return (x_norm - beta) / gamma * self.stored_std + self.stored_mean


class RFSQStage(nn.Module):
    """One residual-quantization stage: normalize, FSQ-quantize, de-normalize."""

    def __init__(self, levels: List[int]):
        super().__init__()
        self.levels = levels
        self.fsq = FSQ(levels)
        # One normalized feature per FSQ channel.
        self.layernorm = InvertibleLayerNorm(len(levels))

    @property
    def codebook_size(self) -> int:
        """Number of codes of the wrapped FSQ quantizer."""
        return self.fsq.codebook_size

    def forward(self, residual: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Quantize ``residual``.

        Returns:
            ``(z_q, new_residual, indices)``: the quantized approximation
            mapped back to the input's scale, what is left over after
            subtracting it from the input, and the FSQ codes.
        """
        quantized_norm, indices = self.fsq(self.layernorm(residual))
        # Map the quantized values back using the statistics stored above.
        z_q = self.layernorm.inverse(quantized_norm)
        return z_q, residual - z_q, indices


class RFSQ(nn.Module):
    """Robust Residual FSQ with multiple stages.

    Each stage quantizes what the previous stages failed to represent; the
    final quantized latent is the sum of all stage outputs.

    Codebook-usage tracking keeps one running histogram per stage (a tensor
    of ``codes_per_stage`` counts) instead of retaining every batch's index
    tensors, so memory stays constant no matter how many forward passes
    happen between two ``reset_usage`` calls.
    """

    def __init__(self, levels_per_stage: List[int], num_stages: int = 2):
        """
        Args:
            levels_per_stage: FSQ level counts, shared by every stage.
            num_stages: Number of residual stages.
        """
        super().__init__()
        self.levels_per_stage = levels_per_stage
        self.num_stages = num_stages
        self.dim = len(levels_per_stage)

        self.stages = nn.ModuleList([
            RFSQStage(levels_per_stage) for _ in range(num_stages)
        ])

        codes_per_stage = int(np.prod(levels_per_stage))
        self.codebook_size = codes_per_stage ** num_stages
        self.codes_per_stage = codes_per_stage
        # Running per-stage code histograms; None until the first forward().
        self._usage_counts = None

    @property
    def num_codes(self) -> int:
        """Size of the combined codebook (``codes_per_stage ** num_stages``)."""
        return self.codebook_size

    def reset_usage(self):
        """Forget all codebook-usage statistics gathered so far."""
        self._usage_counts = None

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Quantize ``z`` through all stages.

        Returns:
            ``(z_q_sum, all_indices)``: the sum of the stage outputs and the
            list of per-stage code tensors.
        """
        residual = z
        z_q_sum = torch.zeros_like(z)
        all_indices = []

        for stage in self.stages:
            z_q, residual, indices = stage(residual)
            z_q_sum = z_q_sum + z_q
            all_indices.append(indices)

        self._accumulate_usage(all_indices)
        return z_q_sum, all_indices

    def _accumulate_usage(self, all_indices: List[torch.Tensor]) -> None:
        """Add one batch of per-stage codes to the running histograms."""
        with torch.no_grad():
            batch_counts = [
                torch.bincount(idx.flatten(), minlength=self.codes_per_stage)
                for idx in all_indices
            ]
            if self._usage_counts is None:
                self._usage_counts = batch_counts
            else:
                self._usage_counts = [
                    total + new.to(total.device)
                    for total, new in zip(self._usage_counts, batch_counts)
                ]

    def get_usage_stats(self) -> Dict[str, Tuple[float, float]]:
        """Return ``{'stage<i>': (usage, perplexity)}`` since the last reset.

        ``usage`` is the fraction of a stage's codes seen at least once and
        ``perplexity`` is ``exp`` of the entropy of the observed code
        distribution. Returns an empty dict if nothing has been recorded.
        """
        if self._usage_counts is None:
            return {}
        stats = {}
        for stage_idx, counts in enumerate(self._usage_counts):
            counts = counts.float()
            usage = (counts > 0).float().mean().item()
            probs = counts / counts.sum()
            probs = probs[probs > 0]  # drop unused codes so log() stays finite
            entropy = -(probs * torch.log(probs)).sum()
            stats[f'stage{stage_idx}'] = (usage, torch.exp(entropy).item())
        return stats

# The check mark had been corrupted by an encoding round-trip.
print("✓ FSQ and RFSQ modules defined")

## 4. Model Code - VQ-VAE v8-B Architecture

In [None]:
class ResidualBlock3D(nn.Module):
    """3D residual block: two 3x3x3 conv + BatchNorm layers plus a skip path."""

    def __init__(self, in_channels: int, out_channels: int = None):
        """
        Args:
            in_channels: Channels of the input volume.
            out_channels: Channels of the output; defaults to ``in_channels``.
        """
        super().__init__()
        out_channels = in_channels if out_channels is None else out_channels

        self.conv1 = nn.Conv3d(in_channels, out_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.conv2 = nn.Conv3d(out_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm3d(out_channels)

        # The skip path needs a 1x1x1 projection only when the channel count
        # changes; otherwise the input is added back unchanged.
        needs_projection = in_channels != out_channels
        self.shortcut = (
            nn.Conv3d(in_channels, out_channels, 1) if needs_projection else nn.Identity()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block; the spatial size is preserved."""
        skip = self.shortcut(x)
        h = F.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        return F.relu(h + skip)


class EncoderV8B(nn.Module):
    """Encoder: 32x32x32 → 16x16x16 latent.

    A stem convolution, one stride-2 convolution that halves every spatial
    dimension, two stacks of residual blocks, and a final projection down to
    ``rfsq_dim`` channels.
    """

    def __init__(
        self,
        in_channels: int = 40,
        hidden_dim: int = 192,
        rfsq_dim: int = 4,
        num_resblocks_per_stage: int = 2,
        num_resblocks_latent: int = 6,
        dropout: float = 0.1,
    ):
        """
        Args:
            in_channels: Channels of the input volume.
            hidden_dim: Width of every hidden layer.
            rfsq_dim: Channels of the produced latent.
            num_resblocks_per_stage: Residual blocks right after downsampling.
            num_resblocks_latent: Residual blocks before the projection.
            dropout: ``Dropout3d`` probability after the downsampling conv.
        """
        super().__init__()

        def res_stack(count: int) -> nn.Sequential:
            # `count` same-width residual blocks applied one after another.
            return nn.Sequential(*(ResidualBlock3D(hidden_dim) for _ in range(count)))

        self.initial = nn.Sequential(
            nn.Conv3d(in_channels, hidden_dim, 3, padding=1),
            nn.BatchNorm3d(hidden_dim),
            nn.ReLU(inplace=True),
        )

        # Kernel 4 / stride 2 / padding 1 halves each spatial dimension.
        self.downsample = nn.Sequential(
            nn.Conv3d(hidden_dim, hidden_dim, 4, stride=2, padding=1),
            nn.BatchNorm3d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout3d(dropout),
        )

        self.stage_blocks = res_stack(num_resblocks_per_stage)
        self.latent_blocks = res_stack(num_resblocks_latent)

        self.latent_proj = nn.Conv3d(hidden_dim, rfsq_dim, 3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a channels-first volume into the pre-quantization latent."""
        for stage in (self.initial, self.downsample, self.stage_blocks, self.latent_blocks):
            x = stage(x)
        return self.latent_proj(x)


class DecoderV8B(nn.Module):
    """Decoder: 16x16x16 latent → 32x32x32 output.

    A stem convolution, two stacks of residual blocks, one transposed
    convolution that doubles every spatial dimension, and a final
    convolution producing one logit per block type.
    """

    def __init__(
        self,
        rfsq_dim: int = 4,
        hidden_dim: int = 192,
        num_blocks: int = 3717,
        num_resblocks_per_stage: int = 2,
        num_resblocks_latent: int = 6,
        dropout: float = 0.1,
    ):
        """
        Args:
            rfsq_dim: Channels of the quantized latent.
            hidden_dim: Width of every hidden layer.
            num_blocks: Number of output classes (logit channels).
            num_resblocks_per_stage: Residual blocks right before upsampling.
            num_resblocks_latent: Residual blocks right after the stem.
            dropout: ``Dropout3d`` probability after the upsampling conv.
        """
        super().__init__()

        def res_stack(count: int) -> nn.Sequential:
            # `count` same-width residual blocks applied one after another.
            return nn.Sequential(*(ResidualBlock3D(hidden_dim) for _ in range(count)))

        self.initial = nn.Sequential(
            nn.Conv3d(rfsq_dim, hidden_dim, 3, padding=1),
            nn.BatchNorm3d(hidden_dim),
            nn.ReLU(inplace=True),
        )

        self.latent_blocks = res_stack(num_resblocks_latent)
        self.stage_blocks = res_stack(num_resblocks_per_stage)

        # Kernel 4 / stride 2 / padding 1 doubles each spatial dimension.
        self.upsample = nn.Sequential(
            nn.ConvTranspose3d(hidden_dim, hidden_dim, 4, stride=2, padding=1),
            nn.BatchNorm3d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout3d(dropout),
        )

        self.final = nn.Conv3d(hidden_dim, num_blocks, 3, padding=1)

    def forward(self, z_q: torch.Tensor) -> torch.Tensor:
        """Decode a channels-first quantized latent into per-voxel logits."""
        x = z_q
        for stage in (self.initial, self.latent_blocks, self.stage_blocks, self.upsample):
            x = stage(x)
        return self.final(x)


class VQVAEv8B(nn.Module):
    """VQ-VAE v8-B with 16×16×16 latent resolution.

    Block ids are embedded, encoded, quantized with residual FSQ and decoded
    back to per-voxel logits over the vocabulary.
    """

    def __init__(
        self,
        vocab_size: int = 3717,
        emb_dim: int = 40,
        hidden_dim: int = 192,
        rfsq_levels: List[int] = None,
        num_stages: int = 2,
        dropout: float = 0.1,
        pretrained_embeddings: torch.Tensor = None,
    ):
        """
        Args:
            vocab_size: Number of block tokens (embedding rows and logit channels).
            emb_dim: Size of each block embedding.
            hidden_dim: Width of the encoder and decoder.
            rfsq_levels: FSQ levels per latent channel; ``[5, 5, 5, 5]`` if omitted.
            num_stages: Number of residual quantization stages.
            dropout: Dropout probability handed to encoder and decoder.
            pretrained_embeddings: Optional embedding table; when given it is
                copied into the embedding layer, which is then frozen.
        """
        super().__init__()

        levels = [5, 5, 5, 5] if rfsq_levels is None else rfsq_levels
        self.rfsq_dim = len(levels)

        self.block_emb = nn.Embedding(vocab_size, emb_dim)
        if pretrained_embeddings is not None:
            self.block_emb.weight.data.copy_(pretrained_embeddings)
            self.block_emb.weight.requires_grad_(False)

        self.encoder = EncoderV8B(
            in_channels=emb_dim,
            hidden_dim=hidden_dim,
            rfsq_dim=self.rfsq_dim,
            dropout=dropout,
        )
        self.quantizer = RFSQ(levels_per_stage=levels, num_stages=num_stages)
        self.decoder = DecoderV8B(
            rfsq_dim=self.rfsq_dim,
            hidden_dim=hidden_dim,
            num_blocks=vocab_size,
            dropout=dropout,
        )

    def forward(self, block_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        """Run the full autoencoder.

        Returns:
            ``(logits, z_q, indices)``. ``logits`` is channels-first; ``z_q``
            is the quantized latent in channels-last layout, exactly as the
            quantizer produced it; ``indices`` are the per-stage codes.
        """
        z_q, indices = self.quantize(self.encode(block_ids))
        return self.decode(z_q), z_q, indices

    def encode(self, block_ids: torch.Tensor) -> torch.Tensor:
        """Embed block ids and encode them; the result is channels-last."""
        embedded = self.block_emb(block_ids)
        # Embedding puts channels last; the conv encoder wants them first.
        latent = self.encoder(embedded.permute(0, 4, 1, 2, 3))
        return latent.permute(0, 2, 3, 4, 1)

    def quantize(self, z_e: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Quantize a channels-last latent with the residual FSQ."""
        return self.quantizer(z_e)

    def decode(self, z_q: torch.Tensor) -> torch.Tensor:
        """Decode a channels-last quantized latent into channels-first logits."""
        return self.decoder(z_q.permute(0, 4, 1, 2, 3))

# The check mark had been corrupted by an encoding round-trip.
print("✓ VQ-VAE v8-B architecture defined")

## 5. Model Code - Loss Functions

In [None]:
class FrequencyWeightedLoss(nn.Module):
    """Frequency-weighted CE loss with TEMPERATURE-SCALED volume control.

    THE PROBLEM WITH DIRECT LOGIT APPROACH:
    - Mean margin violation goes to ~0, but volume ratio stays at 2.0x
    - 10% of GT air locations have small violations = 2.0x volume ratio
    - CE loss (0.45) dominates volume loss (0.002)

    THE FIX - TEMPERATURE-SCALED SOFTMAX:
    - Use very low temperature (0.1) to approximate argmax
    - Softmax(logits / 0.1) is nearly one-hot but DIFFERENTIABLE
    - Volume loss = (soft_volume_ratio - 1.0)^2
    - This loss does NOT go to zero when volume ratio is wrong!

    The total loss is the sum of five terms:

    1. class-weighted cross entropy,
    2. ``volume_penalty_weight * (soft_volume_ratio - 1)^2``,
    3. ``false_air_penalty_weight *`` mean *structure* probability at
       ground-truth air voxels (reported as ``false_air_loss``),
    4. ``false_air_penalty_weight *`` mean *air* probability at ground-truth
       structure voxels (reported as ``structure_loss``),
    5. ``perceptual_weight *`` mean absolute difference between spatially
       neighbouring latent vectors.
    """

    def __init__(
        self,
        frequency_weights: torch.Tensor,
        frequency_cap: float = 5.0,
        volume_penalty_weight: float = 50.0,   # Increased from 10!
        false_air_penalty_weight: float = 25.0, # Increased from 5!
        perceptual_weight: float = 0.1,
        temperature: float = 0.1,  # NEW: Low temp = near-argmax
        air_tokens: Optional[Set[int]] = None,
        latent_channels_last: bool = True,
    ):
        """
        Args:
            frequency_weights: Per-class CE weights, one entry per vocab token.
            frequency_cap: Upper bound applied to ``frequency_weights``.
            volume_penalty_weight: Weight of the volume-ratio term.
            false_air_penalty_weight: Weight of both air/structure confusion terms.
            perceptual_weight: Weight of the latent smoothness term.
            temperature: Softmax temperature for the soft volume estimate.
            air_tokens: Token ids counted as air; ``{102, 576, 3352}`` if omitted.
            latent_channels_last: Layout of the ``z_q`` passed to ``forward``.
                ``True`` means ``[B, H, W, D, C]`` (what ``VQVAEv8B.forward``
                returns); ``False`` means ``[B, C, H, W, D]``.

        Raises:
            ValueError: If ``temperature`` is not positive.
        """
        super().__init__()
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        clamped_weights = frequency_weights.clamp(max=frequency_cap)
        self.register_buffer('freq_weights', clamped_weights)
        self.volume_penalty_weight = volume_penalty_weight
        self.false_air_penalty_weight = false_air_penalty_weight
        self.perceptual_weight = perceptual_weight
        self.temperature = temperature
        self.latent_channels_last = latent_channels_last
        if air_tokens is None:
            air_tokens = {102, 576, 3352}
        self.air_tokens = list(air_tokens)
        # Kept as a buffer so it follows .to(device) and is not rebuilt on
        # every forward; non-persistent so state_dict keys are unchanged.
        self.register_buffer(
            '_air_token_ids',
            torch.tensor(self.air_tokens, dtype=torch.long),
            persistent=False,
        )

    def forward(
        self,
        logits: torch.Tensor,
        target: torch.Tensor,
        z_q: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Compute the combined loss and logging metrics.

        Args:
            logits: Channels-first class scores (class dimension is dim 1,
                followed by three spatial dims).
            target: Ground-truth token ids with the spatial shape of ``logits``.
            z_q: Quantized latent, laid out per ``latent_channels_last``.

        Returns:
            Dict with the differentiable ``'loss'`` plus detached components
            (``ce_loss``, ``volume_loss``, ``false_air_loss``,
            ``structure_loss``, ``perceptual_loss``) and metrics
            (``volume_ratio`` and ``recall`` from hard argmax predictions,
            ``soft_volume_ratio`` from the tempered softmax).
        """
        vocab_size = logits.shape[1]

        # 1. Frequency-weighted CE loss
        logits_flat = logits.permute(0, 2, 3, 4, 1).reshape(-1, vocab_size)
        target_flat = target.reshape(-1)
        ce_loss = F.cross_entropy(logits_flat, target_flat, weight=self.freq_weights, reduction='mean')

        # 2. TEMPERATURE-SCALED VOLUME CONTROL
        # Low temperature softmax approximates argmax but keeps gradients.
        soft_probs = F.softmax(logits / self.temperature, dim=1)  # [B, vocab_size, H, W, D]

        # Total probability mass on air tokens at every voxel.
        air_prob = torch.zeros_like(soft_probs[:, 0])  # [B, H, W, D]
        for air_tok in self.air_tokens:
            if air_tok < vocab_size:
                air_prob = air_prob + soft_probs[:, air_tok]
        structure_prob = 1.0 - air_prob

        # Ground truth masks
        air_tokens_tensor = self._air_token_ids.to(device=target.device, dtype=target.dtype)
        gt_is_air = torch.isin(target, air_tokens_tensor)
        gt_is_structure = ~gt_is_air
        has_air = bool(gt_is_air.any())
        has_structure = bool(gt_is_structure.any())
        zero = torch.zeros((), device=logits.device)

        # Soft volumes
        soft_pred_volume = structure_prob.sum()
        gt_volume = gt_is_structure.float().sum()
        soft_volume_ratio = soft_pred_volume / (gt_volume + 1e-6)

        # Volume ratio loss. The ratio is undefined for a batch with no
        # structure voxels (it would be divided by ~1e-6 and swamp every other
        # term), so that case contributes nothing here; over-prediction is
        # still penalized by the false-air term below.
        volume_loss = (soft_volume_ratio - 1.0) ** 2 if has_structure else zero

        # 3. Structure probability at GT air locations - should be close to 0.
        false_air_loss = structure_prob[gt_is_air].mean() if has_air else zero

        # 4. Air probability at GT structure locations - should be close to 0.
        structure_loss = air_prob[gt_is_structure].mean() if has_structure else zero

        # 5. Perceptual loss (spatial smoothness of the latent): mean absolute
        # difference between neighbours along each of the three spatial axes.
        # The axes depend on the layout; the channel axis is never differenced.
        spatial_dims = (1, 2, 3) if self.latent_channels_last else (2, 3, 4)
        diff_h, diff_w, diff_d = (
            torch.diff(z_q, dim=axis).abs().mean() for axis in spatial_dims
        )
        perceptual_loss = (diff_h + diff_w + diff_d) / 3.0

        # Combined loss
        total_loss = (
            ce_loss +
            self.volume_penalty_weight * volume_loss +
            self.false_air_penalty_weight * false_air_loss +
            self.false_air_penalty_weight * structure_loss +  # Same weight
            self.perceptual_weight * perceptual_loss
        )

        # Hard (argmax) metrics for logging only.
        with torch.no_grad():
            pred_hard = torch.argmax(logits, dim=1)
            pred_is_air_hard = torch.isin(pred_hard, air_tokens_tensor)
            pred_volume_hard = (~pred_is_air_hard).float().sum()
            volume_ratio_hard = pred_volume_hard / (gt_volume + 1e-6)

            # Recall = structure voxels preserved / total structure voxels
            structure_preserved = gt_is_structure & (~pred_is_air_hard)
            recall = structure_preserved.float().sum() / (gt_is_structure.float().sum() + 1e-6)

        return {
            'loss': total_loss,
            'ce_loss': ce_loss.detach(),
            'volume_loss': volume_loss.detach(),
            'false_air_loss': false_air_loss.detach(),
            'structure_loss': structure_loss.detach(),
            'perceptual_loss': perceptual_loss.detach(),
            'volume_ratio': volume_ratio_hard,
            'soft_volume_ratio': soft_volume_ratio.detach(),
            'recall': recall,
        }


def compute_frequency_weights(
    block_ids: torch.Tensor,
    vocab_size: int,
    smoothing: float = 0.5,
) -> torch.Tensor:
    """Compute inverse-frequency class weights.

    Args:
        block_ids: Integer token ids of any shape.
        vocab_size: Minimum length of the returned weight vector.
        smoothing: Exponent applied to the inverse frequency (0.5 = square root).

    Returns:
        ``(total / count) ** smoothing`` per token. Tokens that never occur
        are treated as occurring once, and ``total`` is the sum of the counts
        after that adjustment.
    """
    histogram = torch.bincount(block_ids.flatten(), minlength=vocab_size)
    histogram = histogram.clamp(min=1)  # unseen tokens count as one occurrence
    inverse_freq = histogram.sum().float() / histogram.float()
    return inverse_freq ** smoothing

# Banner summarizing the loss design (a stray literal "\n" after the final
# print made the cell a syntax error).
print("="*70)
print("TEMPERATURE-SCALED VOLUME CONTROL")
print("="*70)
print("WHY THIS WORKS:")
print("  - Low temperature (0.1) makes softmax nearly one-hot")
print("  - Soft volume ratio approximates hard volume ratio")
print("  - Loss = (soft_ratio - 1.0)^2 does NOT go to zero!")
print("  - Strong direct signal to fix volume ratio")
print()
print("KEY CHANGES:")
print("  - temperature = 0.1 (near-argmax)")
print("  - volume_penalty_weight = 50 (5x stronger)")
print("  - false_air_penalty_weight = 25 (5x stronger)")
print("  - Added structure_loss (protect recall)")
print("="*70)

## 6. Configuration

In [None]:
# === Loss Weights (TEMPERATURE-SCALED) ===
FREQUENCY_WEIGHT_CAP = 5.0
VOLUME_PENALTY_WEIGHT = 50.0        # 5x stronger than before!
FALSE_AIR_PENALTY_WEIGHT = 25.0     # 5x stronger than before!
PERCEPTUAL_WEIGHT = 0.1
TEMPERATURE = 0.1                   # Low temp = near-argmax softmax

# NOTE(review): HIDDEN_DIM, BATCH_SIZE, GRAD_ACCUM_STEPS, TOTAL_EPOCHS and
# BASE_LR are printed below but are not defined in the visible part of this
# cell - confirm they are defined before this point.
# The "Configuration:" header print had been fused into the comment above.
print("\nConfiguration:")
print(f"  Latent: 16x16x16 (4,096 positions)")
print(f"  Hidden dim: {HIDDEN_DIM}")
print(f"  Batch size: {BATCH_SIZE} (effective: {BATCH_SIZE * GRAD_ACCUM_STEPS})")
print(f"  Epochs: {TOTAL_EPOCHS}")
print(f"  Learning rate: {BASE_LR}")
print(f"\nLoss weights:")
print(f"  Volume penalty: {VOLUME_PENALTY_WEIGHT} (prevents over-prediction)")
print(f"  False air penalty: {FALSE_AIR_PENALTY_WEIGHT} (protects recall)")

## 7. Load Vocabulary and Embeddings

In [None]:
# Load vocabulary: the JSON maps token id (stored as a string) -> block name.
with open(VOCAB_PATH, 'r') as f:
    tok2block = {int(k): v for k, v in json.load(f).items()}

VOCAB_SIZE = len(tok2block)
print(f"Vocabulary size: {VOCAB_SIZE}")

# Find air tokens: any block whose name contains "air", skipping names that
# contain "stair" (which would otherwise match as well).
AIR_TOKENS: Set[int] = set()
for tok, block in tok2block.items():
    block_lc = block.lower()
    if 'air' not in block_lc or 'stair' in block_lc:
        continue
    AIR_TOKENS.add(tok)
    print(f"  Air token: {tok} = {block}")

# Sorted so the tensor contents do not depend on set iteration order.
AIR_TOKENS_TENSOR = torch.tensor(sorted(AIR_TOKENS), dtype=torch.long)

# Load embeddings; the second axis is the embedding dimension.
v3_embeddings = np.load(V3_EMBEDDINGS_PATH).astype(np.float32)
EMBEDDING_DIM = v3_embeddings.shape[1]
print(f"\nV3 embeddings: {v3_embeddings.shape} (dim={EMBEDDING_DIM})")

## 8. Compute Frequency Weights

In [None]:
print("Computing block frequencies from training data...")
print("(This may take a few minutes)")


def _read_flat_block_ids(h5_path):
    """Return the first dataset of an HDF5 file as a flat int64 tensor."""
    with h5py.File(h5_path, 'r') as f:
        first_key = list(f.keys())[0]
        flat = f[first_key][:].flatten()
        return torch.from_numpy(flat).long()


train_files = sorted(Path(DATA_DIR).glob("*.h5"))

# One flat id tensor per training file, concatenated into a single vector.
all_block_ids = torch.cat([
    _read_flat_block_ids(h5_file)
    for h5_file in tqdm(train_files, desc="Scanning")
])
print(f"\nTotal blocks scanned: {len(all_block_ids):,}")

# Square-root inverse-frequency weights, capped so rare blocks cannot dominate.
FREQUENCY_WEIGHT_TENSOR = compute_frequency_weights(
    all_block_ids,
    vocab_size=VOCAB_SIZE,
    smoothing=0.5,
).clamp(max=FREQUENCY_WEIGHT_CAP)

print(f"Frequency weights computed (cap={FREQUENCY_WEIGHT_CAP}x)")

## 9. Dataset

In [None]:
class VQVAEDataset(Dataset):
    """Dataset over a folder of ``*.h5`` files, one structure per file."""

    def __init__(self, data_dir: str):
        """Index every ``*.h5`` file directly inside ``data_dir`` (sorted).

        Raises:
            ValueError: If the folder contains no ``*.h5`` files.
        """
        self.data_dir = Path(data_dir)
        self.h5_files = sorted(self.data_dir.glob("*.h5"))
        if len(self.h5_files) == 0:
            raise ValueError(f"No H5 files in {data_dir}")
        print(f"Found {len(self.h5_files)} structures in {data_dir}")

    def __len__(self):
        """Number of structure files found."""
        return len(self.h5_files)

    def __getitem__(self, idx):
        """Load the first dataset of file ``idx`` as an int64 tensor."""
        path = self.h5_files[idx]
        with h5py.File(path, 'r') as f:
            first_key = list(f.keys())[0]
            voxels = f[first_key][:].astype(np.int64)
        return torch.from_numpy(voxels).long()

# Build the training and validation datasets; each constructor raises
# ValueError if its folder contains no *.h5 files.
train_dataset = VQVAEDataset(DATA_DIR)
val_dataset = VQVAEDataset(VAL_DIR)

## 10. Create Model

In [None]:
# Loss module configured from the constants above and moved to the training
# device (a stray literal "\n" after .to(device) made this a syntax error).
criterion = FrequencyWeightedLoss(
    frequency_weights=FREQUENCY_WEIGHT_TENSOR,
    frequency_cap=FREQUENCY_WEIGHT_CAP,
    volume_penalty_weight=VOLUME_PENALTY_WEIGHT,
    false_air_penalty_weight=FALSE_AIR_PENALTY_WEIGHT,
    perceptual_weight=PERCEPTUAL_WEIGHT,
    temperature=TEMPERATURE,
    air_tokens=AIR_TOKENS,
).to(device)

## 11. Data Loaders

In [None]:
# Seeded generator so the training shuffle order is reproducible.
g = torch.Generator().manual_seed(SEED)

# Options shared by both loaders.
_loader_opts = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, generator=g, **_loader_opts)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_opts)

for _split, _loader in (("Train", train_loader), ("Val", val_loader)):
    print(f"{_split} batches: {len(_loader)}")

## 12. Training Functions

In [None]:
def compute_metrics(logits, targets, air_tokens):
    """Compute accuracy, recall and volume metrics from hard predictions.

    Args:
        logits: Class scores with the class dimension at dim 1.
        targets: Ground-truth token ids matching the remaining dims of ``logits``.
        air_tokens: Tensor of token ids that count as air.

    Returns:
        Dict of 0-dim tensors:
        ``overall_acc`` - fraction of voxels predicted exactly;
        ``building_acc`` - same, restricted to non-air ground-truth voxels;
        ``building_recall`` / ``struct_recall`` - fraction of non-air
        ground-truth voxels predicted as any non-air token (the two are the
        same quantity under two names);
        ``volume_ratio`` - predicted non-air count / ground-truth non-air count.
        With no non-air ground truth, the accuracies/recalls are 0 and the
        ratio is 1.
    """
    device = logits.device

    # Argmax over the class dim directly; flattening a permuted copy of the
    # full logits tensor first would duplicate it in memory for no benefit.
    # reshape() (not view()) so non-contiguous targets are accepted.
    preds = logits.argmax(dim=1).reshape(-1)
    targets_flat = targets.reshape(-1)

    air_dev = air_tokens.to(device)
    is_struct = ~torch.isin(targets_flat, air_dev)
    is_air_pred = torch.isin(preds, air_dev)
    correct = (preds == targets_flat).float()

    metrics = {'overall_acc': correct.mean()}

    if is_struct.any():
        n_struct = is_struct.sum().float()
        struct_recall = (is_struct & ~is_air_pred).sum().float() / n_struct
        metrics['building_acc'] = correct[is_struct].mean()
        metrics['building_recall'] = struct_recall
        metrics['struct_recall'] = struct_recall
        metrics['volume_ratio'] = (~is_air_pred).sum().float() / n_struct
    else:
        metrics['building_acc'] = torch.tensor(0.0, device=device)
        metrics['building_recall'] = torch.tensor(0.0, device=device)
        metrics['struct_recall'] = torch.tensor(0.0, device=device)
        metrics['volume_ratio'] = torch.tensor(1.0, device=device)

    return metrics


def train_epoch(model, criterion, loader, optimizer, scaler, device, air_tokens):
    """Train for one epoch with AMP and gradient accumulation.

    Args:
        model: Network returning ``(logits, z_q, indices)`` and exposing
            ``quantizer.reset_usage()`` / ``quantizer.get_usage_stats()``.
        criterion: Callable ``(logits, batch, z_q) -> dict`` whose ``'loss'``
            entry is differentiable.
        loader: Iterable of training batches.
        optimizer: Optimizer stepped once per ``GRAD_ACCUM_STEPS`` batches.
        scaler: Gradient scaler used for ``scale`` / ``unscale_`` / ``step`` /
            ``update``.
        device: Device the batches are moved to.
        air_tokens: Tensor of air token ids, forwarded to ``compute_metrics``.

    Returns:
        Dict of per-batch averages of every tracked metric, plus
        ``stage<i>_usage`` / ``stage<i>_perplexity`` from the quantizer.
        All averages are 0.0 if the loader yields no batches.
    """
    model.train()
    model.quantizer.reset_usage()

    # Every key that is accumulated below must be listed here.
    metrics_sum = {
        'loss': 0.0, 'ce_loss': 0.0, 'volume_loss': 0.0,
        'false_air_loss': 0.0, 'structure_loss': 0.0, 'perceptual_loss': 0.0,
        'soft_volume_ratio': 0.0, 'recall': 0.0,
        'overall_acc': 0.0, 'building_acc': 0.0, 'building_recall': 0.0,
        'struct_recall': 0.0, 'volume_ratio': 0.0,
    }
    # 'volume_ratio' is deliberately NOT taken from the loss dict:
    # compute_metrics reports the same quantity, and adding both would double
    # the logged value.
    loss_keys = (
        'loss', 'ce_loss', 'volume_loss', 'false_air_loss', 'structure_loss',
        'perceptual_loss', 'soft_volume_ratio', 'recall',
    )
    n = 0
    pending_grads = False  # True while accumulated gradients await a step

    def apply_step():
        # Unscale before clipping so the norm threshold is in real units.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    optimizer.zero_grad()

    for batch_idx, batch in enumerate(tqdm(loader, desc="Train", leave=False)):
        batch = batch.to(device)

        with torch.amp.autocast('cuda', enabled=USE_AMP):
            logits, z_q, indices = model(batch)
            loss_dict = criterion(logits, batch, z_q)
            loss = loss_dict['loss'] / GRAD_ACCUM_STEPS

        scaler.scale(loss).backward()
        pending_grads = True

        if (batch_idx + 1) % GRAD_ACCUM_STEPS == 0:
            apply_step()
            pending_grads = False

        with torch.no_grad():
            # Collect loss metrics
            for k in loss_keys:
                if k in loss_dict:
                    val = loss_dict[k]
                    metrics_sum[k] += val.item() if torch.is_tensor(val) else val

            batch_metrics = compute_metrics(logits, batch, air_tokens)
            for k, v in batch_metrics.items():
                metrics_sum[k] = metrics_sum.get(k, 0.0) + v.item()

        n += 1

    # Apply gradients from a trailing, incomplete accumulation group instead
    # of discarding them. Each micro-batch loss was already divided by
    # GRAD_ACCUM_STEPS, so this last step is proportionally smaller.
    if pending_grads:
        apply_step()

    metrics = {k: v / max(n, 1) for k, v in metrics_sum.items()}

    stage_stats = model.quantizer.get_usage_stats()
    for stage_name, (usage, perp) in stage_stats.items():
        metrics[f'{stage_name}_usage'] = usage
        metrics[f'{stage_name}_perplexity'] = perp

    return metrics


@torch.no_grad()
def validate(model, criterion, loader, device, air_tokens):
    """Evaluate the model on ``loader`` without updating it.

    Args:
        model: Network returning ``(logits, z_q, indices)`` and exposing
            ``quantizer.reset_usage()`` / ``quantizer.get_usage_stats()``.
        criterion: Callable ``(logits, batch, z_q) -> dict`` of loss values.
        loader: Iterable of validation batches.
        device: Device the batches are moved to.
        air_tokens: Tensor of air token ids, forwarded to ``compute_metrics``.

    Returns:
        Dict of per-batch averages of every tracked metric, plus
        ``stage<i>_usage`` / ``stage<i>_perplexity`` from the quantizer.
        All averages are 0.0 if the loader yields no batches.
    """
    model.eval()
    model.quantizer.reset_usage()

    # Every key that is accumulated below must be listed here.
    metrics_sum = {
        'loss': 0.0, 'ce_loss': 0.0, 'volume_loss': 0.0,
        'false_air_loss': 0.0, 'structure_loss': 0.0, 'perceptual_loss': 0.0,
        'soft_volume_ratio': 0.0, 'recall': 0.0,
        'overall_acc': 0.0, 'building_acc': 0.0, 'building_recall': 0.0,
        'struct_recall': 0.0, 'volume_ratio': 0.0,
    }
    # 'volume_ratio' is deliberately NOT taken from the loss dict:
    # compute_metrics reports the same quantity, and adding both would double
    # the logged value.
    loss_keys = (
        'loss', 'ce_loss', 'volume_loss', 'false_air_loss', 'structure_loss',
        'perceptual_loss', 'soft_volume_ratio', 'recall',
    )
    n = 0

    for batch in tqdm(loader, desc="Val", leave=False):
        batch = batch.to(device)

        with torch.amp.autocast('cuda', enabled=USE_AMP):
            logits, z_q, indices = model(batch)
            loss_dict = criterion(logits, batch, z_q)

        # Collect loss metrics
        for k in loss_keys:
            if k in loss_dict:
                val = loss_dict[k]
                metrics_sum[k] += val.item() if torch.is_tensor(val) else val

        batch_metrics = compute_metrics(logits, batch, air_tokens)
        for k, v in batch_metrics.items():
            metrics_sum[k] = metrics_sum.get(k, 0.0) + v.item()

        n += 1

    metrics = {k: v / max(n, 1) for k, v in metrics_sum.items()}

    stage_stats = model.quantizer.get_usage_stats()
    for stage_name, (usage, perp) in stage_stats.items():
        metrics[f'{stage_name}_usage'] = usage
        metrics[f'{stage_name}_perplexity'] = perp

    return metrics

# Cell-end banner confirming the training helpers above were defined.
print("Training functions defined")
print("  - Tracking: volume_ratio, recall, false_air_loss")

## 13. Training Loop

In [None]:
# Per-epoch metric history; every value is a list with one entry per epoch.
# Note the short `vol_ratio` spelling here versus the `volume_ratio` key used
# in the metrics dicts built by train_epoch/validate.
history = {
    # Core metrics
    'train_loss': [], 'val_loss': [],
    'train_building_acc': [], 'val_building_acc': [],
    'train_vol_ratio': [], 'val_vol_ratio': [],
    'train_recall': [], 'val_recall': [],

    # Loss components
    'train_ce_loss': [], 'val_ce_loss': [],
    'train_volume_loss': [], 'val_volume_loss': [],
    'train_false_air_loss': [], 'val_false_air_loss': [],
    'train_structure_loss': [], 'val_structure_loss': [],
    'train_perceptual_loss': [], 'val_perceptual_loss': [],

    # Soft volume ratio (should track hard ratio)
    'train_soft_vol_ratio': [], 'val_soft_vol_ratio': [],

    # RFSQ metrics
    'train_stage0_perplexity': [], 'val_stage0_perplexity': [],

    # Training info
    'learning_rate': [],
}

# NOTE(review): the epoch loop that fills `history` and defines `train_time`,
# `best_building_acc` and `best_epoch` is not present at this point - it looks
# like it was lost when the notebook was exported. Restore it here before
# running the prints below.
print(f"\nTraining complete in {train_time/60:.1f} minutes")
print(f"Best val building accuracy: {best_building_acc:.1%} at epoch {best_epoch}")
print(f"Final volume ratio: {history['val_vol_ratio'][-1]:.2f}x")
print(f"Final recall: {history['val_recall'][-1]:.1%}")

## 14. Plot Results

In [None]:
# Six-panel training report (3 rows x 2 columns), saved as a PNG in OUTPUT_DIR.
fig, axes = plt.subplots(3, 2, figsize=(14, 15))
epochs = range(1, len(history['train_loss']) + 1)


def _plot_train_val(ax, metric):
    """Draw the train (solid blue) and val (dashed red) curves for `metric`."""
    ax.plot(epochs, history[f'train_{metric}'], 'b-', label='Train', linewidth=2)
    ax.plot(epochs, history[f'val_{metric}'], 'r--', label='Val', linewidth=2)


def _finish_axis(ax, ylabel, title, legend_loc=None, **title_kwargs):
    """Apply the shared axis labels, bold title, legend and grid."""
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title, fontweight='bold', fontsize=12, **title_kwargs)
    if legend_loc is None:
        ax.legend()
    else:
        ax.legend(loc=legend_loc)
    ax.grid(True, alpha=0.3)


# 1. Building Accuracy (TOP LEFT)
ax = axes[0, 0]
_plot_train_val(ax, 'building_acc')
ax.axhline(y=0.492, color='g', linestyle=':', alpha=0.5, label='v6-freq (49.2%)')
ax.axhline(y=0.60, color='orange', linestyle='--', alpha=0.5, label='Target (60%)')
_finish_axis(ax, 'Accuracy', 'Building Accuracy', legend_loc='lower right')

# 2. Volume Ratio (TOP RIGHT) - THE KEY METRIC TO WATCH!
ax = axes[0, 1]
_plot_train_val(ax, 'vol_ratio')
ax.axhline(y=1.0, color='g', linestyle='--', linewidth=2, label='Target (1.0x)')
ax.axhline(y=1.3, color='orange', linestyle=':', alpha=0.7, label='Max acceptable (1.3x)')
ax.axhline(y=2.0, color='red', linestyle=':', alpha=0.5, label='Previous stuck (2.0x)')
_finish_axis(ax, 'Volume Ratio', 'Volume Ratio (MUST decrease from 2.0!)',
             legend_loc='upper right', color='darkred')
# Size the y-axis from BOTH curves so a high train ratio is not clipped;
# `default` keeps an empty history from raising ValueError.
ratio_peak = max(history['train_vol_ratio'] + history['val_vol_ratio'], default=0.0)
ax.set_ylim(0, max(2.5, ratio_peak * 1.1))

# 3. Recall (MIDDLE LEFT) - Structure preservation
ax = axes[1, 0]
_plot_train_val(ax, 'recall')
ax.axhline(y=0.90, color='g', linestyle='--', alpha=0.5, label='Target (90%)')
_finish_axis(ax, 'Recall', 'Recall (Structure Preservation)', legend_loc='lower right')
ax.set_ylim(0, 1.05)

# 4. Loss Components (MIDDLE RIGHT) - Debugging (validation curves only)
ax = axes[1, 1]
ax.plot(epochs, history['val_ce_loss'], 'b-', label='CE Loss', linewidth=2)
ax.plot(epochs, history['val_volume_loss'], 'r-', label='Volume Loss', linewidth=2)
ax.plot(epochs, history['val_false_air_loss'], 'g-', label='False Air Loss', linewidth=2)
ax.plot(epochs, history['val_perceptual_loss'], 'm-', label='Perceptual Loss', linewidth=1, alpha=0.7)
_finish_axis(ax, 'Loss Value', 'Loss Components (Val)', legend_loc='upper right')

# 5. Total Loss (BOTTOM LEFT)
ax = axes[2, 0]
_plot_train_val(ax, 'loss')
_finish_axis(ax, 'Total Loss', 'Total Loss')

# 6. RFSQ Perplexity (BOTTOM RIGHT)
ax = axes[2, 1]
_plot_train_val(ax, 'stage0_perplexity')
_finish_axis(ax, 'Perplexity', 'RFSQ Stage 0 Perplexity')

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/vqvae_v8b_training.png", dpi=150, bbox_inches='tight')
plt.show()

# Print summary with ALL key metrics, then grade the run against the targets.
# "Final" values are the last validation entry of each history series;
# building accuracy is judged on the best epoch rather than the last one.
final_vol_ratio = history['val_vol_ratio'][-1]
final_recall = history['val_recall'][-1]

print("\n" + "="*70)
print("TRAINING RESULTS SUMMARY")
print("="*70)
print(f"Best building accuracy: {best_building_acc:.1%} (epoch {best_epoch})")
print(f"Final building accuracy: {history['val_building_acc'][-1]:.1%}")
print(f"Final volume ratio: {final_vol_ratio:.2f}x")
print(f"Final recall: {final_recall:.1%}")
print(f"Training time: {train_time/60:.1f} minutes")
print()
print("Loss components (final epoch):")
print(f"  CE Loss: {history['val_ce_loss'][-1]:.4f}")
print(f"  Volume Loss: {history['val_volume_loss'][-1]:.4f}")
print(f"  False Air Loss: {history['val_false_air_loss'][-1]:.4f}")
print(f"  Perceptual Loss: {history['val_perceptual_loss'][-1]:.4f}")
print()

# Success criteria check
vol_ok = final_vol_ratio <= 1.3
recall_ok = final_recall >= 0.90
acc_ok = best_building_acc >= 0.60

print("Success Criteria:")
print(f"  Volume ratio <= 1.3x: {'[PASS]' if vol_ok else '[FAIL]'} ({final_vol_ratio:.2f}x)")
print(f"  Recall >= 90%: {'[PASS]' if recall_ok else '[FAIL]'} ({final_recall:.1%})")
print(f"  Building acc >= 60%: {'[PASS]' if acc_ok else '[FAIL]'} ({best_building_acc:.1%})")
print()

if vol_ok and recall_ok and acc_ok:
    print("[SUCCESS] All targets met! Ready for Stage 2 (v8-C)")
elif vol_ok and recall_ok:
    print("[PARTIAL] Volume and recall OK, but accuracy below target")
elif not vol_ok:
    # This run trains the temperature-scaled volume loss, not the earlier
    # direct-logit variant, so the message names the mechanism under test.
    print("[CRITICAL] Volume ratio still too high - temperature-scaled volume loss may not be working")
else:
    # Reached only when the volume ratio is fine but recall missed its target.
    print("[NEEDS ANALYSIS] Check the loss component plots above")


## 15. Save Results

In [None]:
# Bundle the run configuration, headline numbers and full per-epoch history
# into one JSON-serialisable dict, then persist it next to the final weights.
results = {
    'config': {
        # Tag for the temperature-scaled volume-loss variant trained by this
        # notebook (the previous tag named the abandoned direct-logit run).
        'version': 'v8-B-TEMP-SCALED',
        'latent_resolution': '16x16x16',
        'hidden_dim': HIDDEN_DIM,
        'total_epochs': TOTAL_EPOCHS,
        'batch_size': BATCH_SIZE,
        'base_lr': BASE_LR,
        'seed': SEED,
        'volume_penalty_weight': VOLUME_PENALTY_WEIGHT,
        'false_air_penalty_weight': FALSE_AIR_PENALTY_WEIGHT,
    },
    'results': {
        # float()/bool() casts turn values into plain Python types for json.
        'best_building_acc': float(best_building_acc),
        'best_epoch': best_epoch,
        'final_building_acc': float(history['val_building_acc'][-1]),
        'final_volume_ratio': float(history['val_vol_ratio'][-1]),
        'final_recall': float(history['val_recall'][-1]),
        'final_ce_loss': float(history['val_ce_loss'][-1]),
        'final_volume_loss': float(history['val_volume_loss'][-1]),
        'final_false_air_loss': float(history['val_false_air_loss'][-1]),
        'training_time_min': float(train_time / 60),
        'target_60pct_met': bool(best_building_acc >= 0.60),
        'volume_target_met': bool(history['val_vol_ratio'][-1] <= 1.3),
        'recall_target_met': bool(history['val_recall'][-1] >= 0.90),
    },
    'history': {k: [float(x) for x in v] for k, v in history.items()},
}

# Explicit encoding so the file is written identically on every platform.
with open(f"{OUTPUT_DIR}/vqvae_v8b_results.json", 'w', encoding='utf-8') as f:
    json.dump(results, f, indent=2)

torch.save(model.state_dict(), f"{OUTPUT_DIR}/vqvae_v8b_final.pt")

print("\nResults saved to:")
print(f"  - {OUTPUT_DIR}/vqvae_v8b_results.json")
# NOTE(review): vqvae_v8b_best.pt is not written by this cell; presumably it
# is saved during training - confirm before trusting this line.
print(f"  - {OUTPUT_DIR}/vqvae_v8b_best.pt")
print(f"  - {OUTPUT_DIR}/vqvae_v8b_final.pt")
print(f"  - {OUTPUT_DIR}/vqvae_v8b_training.png")
print("\nDone!")
