# VQ-VAE v4 Training - Shape Preservation Focus

## Philosophy: SHAPE FIRST, DETAILS SECOND

The key insight from Phase 2.5 validation: **buildings were disappearing** in reconstructions.
The model was predicting AIR where there should be structure blocks.

This version addresses that with:
1. **Asymmetric Cross-Entropy**: structure→air errors penalized 10x more than air→structure
2. **False Air Penalty**: explicit loss for predicting air where structure exists
3. **Volume Preservation**: penalize reconstructions that contain fewer non-air (structure) blocks than the original
4. **Structure Recall**: new key metric tracking shape preservation

## Key Improvements over v3

| Change | v3 | v4 |
|--------|-----|-----|
| Latent grid | 4x4x4 (512:1) | **8x8x8 (64:1)** |
| Embeddings | Frozen | **Trainable** |
| Loss | CE only | **CE + embedding similarity + shape preservation** |
| Metric | Exact-match | **Structure Recall (shape preservation)** |
| Key focus | Accuracy | **Don't erase buildings!** |

## New Metrics

- **Structure Recall**: Of blocks that SHOULD be structure, how many are NOT erased? (Target: >90%)
- **False Air Rate**: What % of structure was wrongly predicted as air? (Target: <10%)
- **Volume Ratio**: predicted_volume / original_volume (Target: ~1.0)

## Configuration

- **Epochs**: 15 total (5 warmup epochs with frozen embeddings + 10 full epochs with trainable embeddings)
- **Structure weight**: 50x (increased from 10x)
- **False air weight**: 5x
- **Structure→air weight**: 10x (asymmetric CE)

In [None]:
# ============================================================
# CELL 1: Imports
# ============================================================

import json
import random
import time
from pathlib import Path
from typing import Dict, List, Tuple, Any
from collections import Counter

import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Report which GPU we got and how much memory it has (device 0 only).
if device != "cpu":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# ============================================================
# CELL 2: Configuration
# ============================================================

# === Data Paths ===
DATA_DIR = "/kaggle/input/minecraft-schematics/minecraft_splits/splits/train"
VAL_DIR = "/kaggle/input/minecraft-schematics/minecraft_splits/splits/val"
VOCAB_PATH = "/kaggle/input/minecraft-schematics/tok2block.json"
V3_EMBEDDINGS_PATH = "/kaggle/input/minecraft-embeddings-v3/block_embeddings_v3.npy"

OUTPUT_DIR = "/kaggle/working"

# === V4 Model Architecture ===
# Edge length of the cubic input volume. NOTE(review): taken from the
# "32->8" comments in this notebook; confirm against the dataset.
INPUT_SIZE = 32
HIDDEN_DIMS = [96, 192]  # 2 stages for 32->8 (not 3 stages for 32->4)
LATENT_DIM = 256
NUM_CODEBOOK_ENTRIES = 512
DROPOUT = 0.1

# Each encoder stage is a stride-2 convolution, so every stage halves each axis.
# Deriving these keeps the printed banner in sync with HIDDEN_DIMS.
LATENT_GRID_SIZE = INPUT_SIZE // (2 ** len(HIDDEN_DIMS))
COMPRESSION_RATIO = (INPUT_SIZE // LATENT_GRID_SIZE) ** 3

# === VQ-VAE Settings ===
COMMITMENT_COST = 0.5
EMA_DECAY = 0.99
STRUCTURE_WEIGHT = 50.0  # INCREASED from 10.0 for shape preservation

# === V4 Specific Settings ===
EMBEDDING_LOSS_ALPHA = 0.5  # Weight for embedding similarity loss
STABILITY_WEIGHT = 0.01    # Embedding stability regularization
DIVERSITY_WEIGHT = 0.001   # Embedding diversity regularization

# === SHAPE PRESERVATION SETTINGS (NEW) ===
# These prevent the "buildings disappearing" problem
FALSE_AIR_WEIGHT = 5.0     # Heavily penalize predicting air where structure exists
VOLUME_WEIGHT = 2.0        # Penalize losing structure volume
STRUCTURE_TO_AIR_WEIGHT = 10.0  # Asymmetric CE: structure->air errors 10x worse
USE_SHAPE_LOSS = True      # Enable shape preservation loss
USE_ASYMMETRIC_LOSS = True # Enable asymmetric cross-entropy

# === Training ===
TOTAL_EPOCHS = 15
WARMUP_EPOCHS = 5  # Freeze embeddings for first N epochs
BATCH_SIZE = 4     # Reduced due to larger latent grid
BASE_LR = 3e-4
EMBEDDING_LR_SCALE = 0.1  # Embeddings train 10x slower
USE_AMP = True
GRAD_ACCUM_STEPS = 4

SEED = 42
NUM_WORKERS = 2

print("VQ-VAE v4 Configuration:")
print(f"  Latent grid: {LATENT_GRID_SIZE}x{LATENT_GRID_SIZE}x{LATENT_GRID_SIZE} "
      f"({COMPRESSION_RATIO}:1 compression)")
print(f"  Hidden dims: {HIDDEN_DIMS}")
print(f"  Epochs: {TOTAL_EPOCHS} ({WARMUP_EPOCHS} warmup + {TOTAL_EPOCHS - WARMUP_EPOCHS} full)")
print(f"  Embedding loss alpha: {EMBEDDING_LOSS_ALPHA}")
print(f"  Structure weight: {STRUCTURE_WEIGHT}x")
print("\nShape Preservation (NEW):")
print(f"  False air weight: {FALSE_AIR_WEIGHT}")
print(f"  Volume weight: {VOLUME_WEIGHT}")
print(f"  Structure->air weight: {STRUCTURE_TO_AIR_WEIGHT}")

In [None]:
# ============================================================
# CELL 3: Load Vocabulary and Embeddings
# ============================================================

# JSON object keys are always strings, so convert them back to integer token ids.
# The encoding is pinned so the read does not depend on the platform locale.
with open(VOCAB_PATH, 'r', encoding='utf-8') as f:
    tok2block = {int(k): v for k, v in json.load(f).items()}

block2tok = {v: k for k, v in tok2block.items()}
VOCAB_SIZE = len(tok2block)
print(f"Vocabulary size: {VOCAB_SIZE}")

# Find air tokens: any block whose name contains "air", excluding "stair" blocks
# (which contain the same substring but are solid structure).
AIR_TOKENS = set()
for tok, block in tok2block.items():
    block_name = block.lower()
    if 'air' in block_name and 'stair' not in block_name:
        AIR_TOKENS.add(tok)
        print(f"  Air token: {tok} = {block}")

if not AIR_TOKENS:
    # Every voxel would then count as structure and the shape metrics below
    # would be meaningless, so make the problem visible right away.
    print("WARNING: no air tokens found in vocabulary; shape-preservation metrics will be meaningless")

AIR_TOKENS_LIST = sorted(AIR_TOKENS)
AIR_TOKENS_TENSOR = torch.tensor(AIR_TOKENS_LIST, dtype=torch.long)

# Load V3 embeddings (compositional)
v3_embeddings = np.load(V3_EMBEDDINGS_PATH).astype(np.float32)
EMBEDDING_DIM = v3_embeddings.shape[1]
print(f"V3 embeddings: {v3_embeddings.shape} (dim={EMBEDDING_DIM})")

if v3_embeddings.shape[0] != VOCAB_SIZE:
    # The model copies these rows into an nn.Embedding, which fails on a row-count
    # mismatch; surface the discrepancy here where it is easy to diagnose.
    print(f"WARNING: embedding rows ({v3_embeddings.shape[0]}) != vocabulary size ({VOCAB_SIZE})")

In [None]:
# ============================================================
# CELL 4: Dataset
# ============================================================

class VQVAEDataset(Dataset):
    """Map-style dataset with one item per ``*.h5`` file in a directory.

    Files are listed once at construction (sorted by path) and opened lazily,
    one per ``__getitem__`` call.
    """

    def __init__(self, data_dir: str):
        """Index the ``*.h5`` files under *data_dir*.

        Raises:
            ValueError: if the directory contains no ``*.h5`` files.
        """
        root = Path(data_dir)
        discovered = sorted(root.glob("*.h5"))
        self.data_dir = root
        self.h5_files = discovered
        if len(discovered) == 0:
            raise ValueError(f"No H5 files in {data_dir}")
        print(f"Found {len(self.h5_files)} structures in {data_dir}")

    def __len__(self):
        """Number of indexed files."""
        return len(self.h5_files)

    def __getitem__(self, idx):
        """Return the first top-level entry of file *idx* as an int64 tensor."""
        path = self.h5_files[idx]
        with h5py.File(path, 'r') as handle:
            top_level_keys = list(handle.keys())
            voxels = handle[top_level_keys[0]][:]
        voxels = voxels.astype(np.int64)
        return torch.from_numpy(voxels).long()

# Index both splits; each constructor prints how many files it found.
train_dataset = VQVAEDataset(data_dir=DATA_DIR)
val_dataset = VQVAEDataset(data_dir=VAL_DIR)

# One-line summary of the split sizes.
print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")

In [None]:
# ============================================================
# CELL 5: VQ-VAE v4 Architecture with Shape Preservation
# ============================================================

class ResidualBlock3D(nn.Module):
    """Channel-preserving residual unit: conv-BN-ReLU-conv-BN, plus skip, then ReLU.

    Both convolutions use a 3x3x3 kernel with padding 1, so the spatial size of
    the input is preserved and the skip connection can be added directly.
    """

    def __init__(self, channels: int):
        super().__init__()
        same_size_conv = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv3d(channels, channels, **same_size_conv)
        self.conv2 = nn.Conv3d(channels, channels, **same_size_conv)
        self.bn1 = nn.BatchNorm3d(channels)
        self.bn2 = nn.BatchNorm3d(channels)

    def forward(self, x):
        """Apply the two conv/BN layers and add the unmodified input back in."""
        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = F.relu(hidden)
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + x)


class VectorQuantizerEMA(nn.Module):
    """EMA-based vector quantizer with dead code reset.

    The codebook is a buffer (not a parameter): it is updated by exponential
    moving averages of the encoder outputs assigned to each code, not by the
    optimizer. Only the commitment loss sends gradients to the encoder.

    Args:
        num_codes: number of codebook entries.
        latent_dim: dimensionality of each code / encoder output vector.
        commitment_cost: multiplier on the commitment (encoder -> code) MSE.
        ema_decay: decay factor for the EMA cluster statistics.
    """

    def __init__(self, num_codes, latent_dim, commitment_cost=0.5, ema_decay=0.99):
        super().__init__()
        self.num_codes = num_codes
        self.latent_dim = latent_dim
        self.commitment_cost = commitment_cost
        self.ema_decay = ema_decay

        self.register_buffer('codebook', torch.randn(num_codes, latent_dim))
        self.codebook.data.uniform_(-1/num_codes, 1/num_codes)
        self.register_buffer('ema_cluster_size', torch.zeros(num_codes))
        self.register_buffer('ema_embed_sum', torch.zeros(num_codes, latent_dim))
        # Per-code assignment counts since the last reset_epoch_stats() call.
        self.register_buffer('code_usage', torch.zeros(num_codes))

    def reset_epoch_stats(self):
        """Clear the per-code usage counters."""
        self.code_usage.zero_()

    def get_usage_fraction(self):
        """Fraction of codes selected at least once since the last reset."""
        return (self.code_usage > 0).float().mean().item()

    def get_perplexity(self):
        """exp(entropy) of the empirical code distribution; 0.0 if nothing was counted."""
        if self.code_usage.sum() == 0:
            return 0.0
        probs = self.code_usage / self.code_usage.sum()
        probs = probs[probs > 0]
        entropy = -(probs * probs.log()).sum()
        return entropy.exp().item()

    def forward(self, z_e):
        """Quantize encoder output to its nearest codebook entries.

        Args:
            z_e: encoder output with the channel axis at dim 1 and three
                trailing spatial axes (it is permuted to channels-last here).

        Returns:
            Tuple ``(z_q, vq_loss, indices)``: the straight-through quantized
            tensor in the same layout as ``z_e``, the weighted commitment loss,
            and the selected code index for every spatial position.
        """
        z_e_perm = z_e.permute(0, 2, 3, 4, 1).contiguous()
        flat = z_e_perm.view(-1, self.latent_dim)

        # The nearest-code search and usage bookkeeping need no gradients, so the
        # (N x num_codes) distance matrix is kept out of the autograd graph.
        with torch.no_grad():
            # Force fp32 so the distance computation is stable under autocast.
            flat_f32 = flat.float()
            codebook_f32 = self.codebook.float()

            # Squared Euclidean distance: |x|^2 + |c|^2 - 2 x.c
            d = (flat_f32.pow(2).sum(1, keepdim=True)
                 + codebook_f32.pow(2).sum(1)
                 - 2 * flat_f32 @ codebook_f32.t())

            indices = d.argmin(dim=1)

            # One vectorized reduction replaces a Python loop over unique codes;
            # the one-hot matrix is reused by the EMA update below.
            encodings = F.one_hot(indices, self.num_codes).float()
            batch_counts = encodings.sum(0)
            self.code_usage += batch_counts

        # Gather before the EMA update so the forward pass uses the codes that
        # were actually selected.
        z_q_flat = self.codebook[indices]
        z_q_perm = z_q_flat.view(z_e_perm.shape)

        if self.training:
            with torch.no_grad():
                self.ema_cluster_size = self.ema_decay * self.ema_cluster_size + (1 - self.ema_decay) * batch_counts
                batch_sum = encodings.t() @ flat_f32
                self.ema_embed_sum = self.ema_decay * self.ema_embed_sum + (1 - self.ema_decay) * batch_sum

                # Laplace smoothing keeps the divisor non-zero for unused codes.
                n = self.ema_cluster_size.sum()
                smoothed = (self.ema_cluster_size + 1e-5) / (n + self.num_codes * 1e-5) * n
                self.codebook.data = self.ema_embed_sum / smoothed.unsqueeze(1)

                # Dead code reset: re-seed rarely used codes from random encoder outputs.
                # NOTE(review): "dead" is judged on this single batch (< 2 hits), not on
                # the EMA cluster size, so live-but-rare codes may be re-seeded often.
                # Confirm this is intended before changing it.
                dead = batch_counts < 2
                if dead.any() and flat_f32.size(0) > 0:
                    n_dead = dead.sum().item()
                    rand_idx = torch.randint(0, flat_f32.size(0), (n_dead,), device=flat_f32.device)
                    self.codebook.data[dead] = flat_f32[rand_idx]
                    self.ema_cluster_size[dead] = 1
                    self.ema_embed_sum[dead] = flat_f32[rand_idx]

        # Commitment loss pulls encoder outputs toward their (fixed) codes.
        commitment_loss = F.mse_loss(z_e_perm, z_q_perm.detach())
        vq_loss = self.commitment_cost * commitment_loss

        # Straight-through estimator: forward uses z_q, backward copies grads to z_e.
        z_q_st = z_e_perm + (z_q_perm - z_e_perm).detach()
        z_q = z_q_st.permute(0, 4, 1, 2, 3).contiguous()

        return z_q, vq_loss, indices.view(z_e_perm.shape[:-1])


class EncoderV4(nn.Module):
    """32x32x32 -> 8x8x8 encoder (2 stages instead of 3).

    Every entry in ``hidden_dims`` adds one downsampling stage (a stride-2
    convolution followed by BN, ReLU, dropout and a residual block). Two more
    residual blocks and a 3x3x3 projection to ``latent_dim`` follow at the
    final resolution.
    """

    def __init__(self, in_channels, hidden_dims, latent_dim, dropout=0.1):
        super().__init__()
        # Consecutive channel widths: input -> stage 1 -> stage 2 -> ...
        widths = [in_channels, *hidden_dims]

        modules = []
        for c_in, c_out in zip(widths, widths[1:]):
            # kernel 4 / stride 2 / padding 1 halves each spatial axis.
            modules.append(nn.Conv3d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            modules.append(nn.BatchNorm3d(c_out))
            modules.append(nn.ReLU(inplace=True))
            modules.append(nn.Dropout3d(dropout))
            modules.append(ResidualBlock3D(c_out))

        # Extra capacity at 8x8x8
        final_width = widths[-1]
        modules.append(ResidualBlock3D(final_width))
        modules.append(ResidualBlock3D(final_width))
        modules.append(nn.Conv3d(final_width, latent_dim, kernel_size=3, padding=1))

        self.encoder = nn.Sequential(*modules)

    def forward(self, x):
        """Run the input through the sequential encoder stack."""
        return self.encoder(x)


class DecoderV4(nn.Module):
    """8x8x8 -> 32x32x32 decoder.

    Starts with two residual blocks at the latent width, then adds one
    upsampling stage per entry in ``hidden_dims`` (residual block, stride-2
    transposed convolution, BN, ReLU, dropout), and ends with a 3x3x3
    convolution that outputs ``num_blocks`` logits per voxel.
    """

    def __init__(self, latent_dim, hidden_dims, num_blocks, dropout=0.1):
        super().__init__()
        # Consecutive channel widths: latent -> stage 1 -> stage 2 -> ...
        widths = [latent_dim, *hidden_dims]

        modules = [ResidualBlock3D(latent_dim), ResidualBlock3D(latent_dim)]
        for c_in, c_out in zip(widths, widths[1:]):
            modules.append(ResidualBlock3D(c_in))
            # kernel 4 / stride 2 / padding 1 doubles each spatial axis.
            modules.append(nn.ConvTranspose3d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            modules.append(nn.BatchNorm3d(c_out))
            modules.append(nn.ReLU(inplace=True))
            modules.append(nn.Dropout3d(dropout))

        # Per-voxel classification head over the block vocabulary.
        modules.append(nn.Conv3d(widths[-1], num_blocks, kernel_size=3, padding=1))
        self.decoder = nn.Sequential(*modules)

    def forward(self, z_q):
        """Decode quantized latents into per-voxel logits."""
        return self.decoder(z_q)


@torch.no_grad()
def compute_similarity_matrix(embeddings):
    """Return pairwise cosine similarity of the rows, rescaled from [-1, 1] to [0, 1].

    The result is computed without gradient tracking.
    """
    unit_rows = F.normalize(embeddings, dim=1)
    cosine = torch.matmul(unit_rows, unit_rows.t())
    return (cosine + 1) / 2


class ShapePreservationLoss(nn.Module):
    """Loss functions to prevent buildings from disappearing.

    Philosophy: SHAPE FIRST, DETAILS SECOND.
    It's better to predict the wrong block type than to predict air.

    The penalty returned as ``'total'`` is built from the predicted *probability*
    of air rather than from ``argmax`` predictions. ``argmax`` has no gradient,
    so a penalty computed from hard predictions cannot influence training.
    Hard (argmax-based) rates are still reported as metrics.

    Args:
        false_air_weight: multiplier on the false-air term.
        volume_weight: multiplier on the volume-preservation term.
    """

    def __init__(self, false_air_weight=5.0, volume_weight=2.0):
        super().__init__()
        self.false_air_weight = false_air_weight
        self.volume_weight = volume_weight

    def forward(self, logits, targets, air_tokens):
        """Compute the shape penalty and its diagnostic metrics.

        Args:
            logits: class scores with the class axis at dim 1.
            targets: integer block ids, matching ``logits`` without the class axis.
            air_tokens: 1-D long tensor of token ids that count as air, on the
                same device as ``logits``.

        Returns:
            dict with
              ``'total'``            differentiable weighted penalty,
              ``'false_air_loss'``   mean predicted air probability on structure voxels,
              ``'volume_loss'``      relative shortfall of expected structure volume,
              ``'false_air_rate'``   hard metric: fraction of structure predicted as air,
              ``'structure_recall'`` hard metric: fraction of structure kept as structure.
        """
        device = logits.device
        is_struct_orig = ~torch.isin(targets, air_tokens)
        struct_mask = is_struct_orig.float()
        orig_vol = struct_mask.sum()
        has_structure = bool(orig_vol > 0)

        # --- Differentiable surrogates -------------------------------------
        # P(air) per voxel = sum of softmax mass on the air classes, computed via
        # logsumexp so the full softmax tensor is never materialized.
        if air_tokens.numel() > 0:
            logits_f32 = logits.float()
            log_norm = torch.logsumexp(logits_f32, dim=1)
            log_air = torch.logsumexp(logits_f32.index_select(1, air_tokens), dim=1)
            p_air = (log_air - log_norm).exp()
        else:
            p_air = torch.zeros_like(struct_mask)

        if has_structure:
            # False air: probability mass placed on air where structure existed.
            false_air_loss = (p_air * struct_mask).sum() / orig_vol
            # Volume preservation: penalize only a shortfall in expected volume.
            expected_pred_vol = (1.0 - p_air).sum()
            volume_loss = F.relu(orig_vol - expected_pred_vol) / orig_vol
        else:
            false_air_loss = torch.tensor(0.0, device=device)
            volume_loss = torch.tensor(0.0, device=device)

        total = self.false_air_weight * false_air_loss + self.volume_weight * volume_loss

        # --- Hard metrics (reporting only) ---------------------------------
        with torch.no_grad():
            predictions = logits.argmax(dim=1)
            is_air_pred = torch.isin(predictions, air_tokens)
            if has_structure:
                false_air_rate = (is_struct_orig & is_air_pred).float().sum() / orig_vol
                struct_recall = (is_struct_orig & ~is_air_pred).float().sum() / orig_vol
            else:
                false_air_rate = torch.tensor(0.0, device=device)
                struct_recall = torch.tensor(1.0, device=device)

        return {
            'false_air_rate': false_air_rate,
            'false_air_loss': false_air_loss,
            'volume_loss': volume_loss,
            'structure_recall': struct_recall,
            'total': total,
        }


class AsymmetricStructureLoss(nn.Module):
    """Cross-entropy whose per-voxel weight depends on the kind of mistake made.

    Voxels where structure was predicted as air get ``structure_to_air_weight``;
    voxels where air was predicted as structure get ``air_to_structure_weight``;
    all other voxels get weight 1. The result is the weighted mean of the
    per-voxel cross-entropy.
    """

    def __init__(self, structure_to_air_weight=10.0, air_to_structure_weight=1.0):
        super().__init__()
        self.structure_to_air_weight = structure_to_air_weight
        self.air_to_structure_weight = air_to_structure_weight

    def forward(self, logits, targets, air_tokens):
        """Return the error-type-weighted mean cross-entropy (a scalar tensor)."""
        target_is_air = torch.isin(targets, air_tokens)
        pred_is_air = torch.isin(logits.argmax(dim=1), air_tokens)

        # The two error masks are disjoint, so the fill order does not matter.
        erased = pred_is_air & ~target_is_air          # structure -> air (erases building)
        hallucinated = target_is_air & ~pred_is_air    # air -> structure (adds extra blocks)

        voxel_weight = torch.ones_like(targets, dtype=torch.float)
        voxel_weight = voxel_weight.masked_fill(erased, self.structure_to_air_weight)
        voxel_weight = voxel_weight.masked_fill(hallucinated, self.air_to_structure_weight)

        per_voxel_ce = F.cross_entropy(logits, targets, reduction='none')
        weighted_sum = (per_voxel_ce * voxel_weight).sum()
        return weighted_sum / voxel_weight.sum()


class VQVAEv4(nn.Module):
    """VQ-VAE v4 with trainable embeddings, embedding-aware loss, and SHAPE PRESERVATION.

    Pipeline: block ids -> embedding lookup -> ``EncoderV4`` ->
    ``VectorQuantizerEMA`` -> ``DecoderV4`` -> per-voxel logits over the vocabulary.

    Args:
        vocab_size: number of block tokens (rows of the embedding table).
        emb_dim: embedding dimensionality.
        hidden_dims: encoder stage widths (the decoder uses them reversed).
        latent_dim: channel width of the latent grid / codebook vectors.
        num_codes: number of codebook entries.
        pretrained_emb: numpy array copied into the embedding table; a copy is
            also kept as the reference for the stability regularizer.
        embedding_loss_alpha: weight of the embedding-similarity loss.
        stability_weight: weight of the MSE-to-original embedding regularizer.
        diversity_weight: weight of the embedding diversity regularizer.
        false_air_weight, volume_weight: forwarded to ``ShapePreservationLoss``.
        structure_to_air_weight: forwarded to ``AsymmetricStructureLoss``.
        dropout: dropout probability for encoder and decoder.
        commitment_cost, ema_decay: forwarded to ``VectorQuantizerEMA``.
    """

    def __init__(self, vocab_size, emb_dim, hidden_dims, latent_dim, num_codes,
                 pretrained_emb, embedding_loss_alpha=0.5, stability_weight=0.01,
                 diversity_weight=0.001, false_air_weight=5.0, volume_weight=2.0,
                 structure_to_air_weight=10.0, dropout=0.1, commitment_cost=0.5,
                 ema_decay=0.99):
        super().__init__()
        self.vocab_size = vocab_size
        self.emb_dim = emb_dim
        self.embedding_loss_alpha = embedding_loss_alpha
        self.stability_weight = stability_weight
        self.diversity_weight = diversity_weight
        self.false_air_weight = false_air_weight
        self.volume_weight = volume_weight
        self.train_embeddings = False  # Start frozen

        # Embeddings
        self.block_emb = nn.Embedding(vocab_size, emb_dim)
        self.block_emb.weight.data.copy_(torch.from_numpy(pretrained_emb))
        self.block_emb.weight.requires_grad = False  # Start frozen
        self.register_buffer('original_embeddings', torch.from_numpy(pretrained_emb.copy()))

        # Encoder 32->8
        self.encoder = EncoderV4(emb_dim, hidden_dims, latent_dim, dropout)

        # Quantizer
        self.quantizer = VectorQuantizerEMA(num_codes, latent_dim, commitment_cost, ema_decay)

        # Decoder 8->32
        self.decoder = DecoderV4(latent_dim, list(reversed(hidden_dims)), vocab_size, dropout)

        # Shape preservation loss (NEW)
        self.shape_loss = ShapePreservationLoss(false_air_weight, volume_weight)

        # Asymmetric loss (NEW)
        self.asymmetric_loss = AsymmetricStructureLoss(structure_to_air_weight)

        # Similarity cache
        self._sim_matrix = None
        self._sim_valid = False

    def set_train_embeddings(self, train: bool):
        """Enable/disable embedding training for phased training."""
        self.train_embeddings = train
        self.block_emb.weight.requires_grad = train
        if train:
            self._sim_valid = False

    def get_similarity_matrix(self):
        """Return the [0, 1]-scaled cosine similarity between all block embeddings.

        The matrix is cached only while the embeddings are frozen. Once they are
        trainable the weights can change on every optimizer step, so a cached
        matrix would make the similarity metrics silently stale; it is rebuilt
        on each call instead.
        """
        if self.train_embeddings or not self._sim_valid or self._sim_matrix is None:
            self._sim_matrix = compute_similarity_matrix(self.block_emb.weight.detach())
            self._sim_valid = True
        return self._sim_matrix

    def forward(self, block_ids):
        """Encode, quantize and decode a batch of block-id volumes.

        Args:
            block_ids: integer tensor with one batch axis and three spatial axes
                (the embedding axis is appended and moved to dim 1).

        Returns:
            dict with ``'logits'`` (decoder output), ``'vq_loss'`` and ``'indices'``.
        """
        x = self.block_emb(block_ids).permute(0, 4, 1, 2, 3).contiguous()
        z_e = self.encoder(x)
        z_q, vq_loss, indices = self.quantizer(z_e)
        logits = self.decoder(z_q)
        return {'logits': logits, 'vq_loss': vq_loss, 'indices': indices}

    def compute_embedding_regularization(self):
        """Return the stability and diversity regularizers for the embedding table.

        ``stability`` is the MSE between current and original embeddings.
        ``diversity`` is the amount by which the mean off-diagonal cosine
        similarity exceeds 0.3 (zero if it does not).
        """
        current = self.block_emb.weight
        original = self.original_embeddings

        stability = F.mse_loss(current, original)

        normed = F.normalize(current, dim=1)
        sim = normed @ normed.t()
        off_diag = 1 - torch.eye(current.size(0), device=current.device)
        avg_sim = (sim * off_diag).sum() / off_diag.sum()
        diversity = F.relu(avg_sim - 0.3)

        return {'stability': stability, 'diversity': diversity}

    def compute_loss(self, block_ids, air_tokens, structure_weight,
                     use_emb_loss=True, use_shape_loss=True, use_asymmetric_loss=True):
        """Run a forward pass and return the total loss with all metrics.

        Args:
            block_ids: batch of block-id volumes (see ``forward``).
            air_tokens: 1-D tensor of token ids treated as air.
            structure_weight: weight on non-air voxels in the plain weighted CE
                (when ``use_asymmetric_loss`` is False) and in the embedding loss.
            use_emb_loss: include the embedding-similarity loss.
            use_shape_loss: include the shape-preservation penalty.
            use_asymmetric_loss: use ``AsymmetricStructureLoss`` as the CE term.

        Returns:
            dict of scalar tensors: ``'loss'`` (sum of all terms), the individual
            loss terms, accuracy metrics, and the shape metrics
            ``'false_air_rate'``, ``'struct_recall'`` and ``'vol_ratio'``. The
            shape metrics are always computed from the argmax predictions,
            whether or not ``use_shape_loss`` is set.
        """
        out = self(block_ids)
        logits = out['logits']
        dev = block_ids.device

        # reshape (not view) so non-contiguous id tensors are accepted too.
        logits_flat = logits.permute(0, 2, 3, 4, 1).reshape(-1, self.vocab_size)
        targets_flat = block_ids.reshape(-1)

        # Air tokens for structure masking
        air_dev = air_tokens.to(targets_flat.device)
        is_air = torch.isin(targets_flat, air_dev)
        is_struct = ~is_air

        # === PRIMARY LOSS: Asymmetric Cross-Entropy ===
        if use_asymmetric_loss:
            ce_loss = self.asymmetric_loss(logits_flat, targets_flat, air_dev)
        else:
            weights = torch.ones_like(targets_flat, dtype=torch.float)
            weights[is_struct] = structure_weight
            ce = F.cross_entropy(logits_flat, targets_flat, reduction='none')
            ce_loss = (weights * ce).sum() / weights.sum()

        # === SHAPE PRESERVATION LOSS ===
        if use_shape_loss:
            shape_loss = self.shape_loss(logits_flat, targets_flat, air_dev)['total']
        else:
            shape_loss = torch.tensor(0.0, device=dev)

        # === EMBEDDING SIMILARITY LOSS ===
        if use_emb_loss and self.embedding_loss_alpha > 0:
            # Temperature 0.1 sharpens the distribution before taking the
            # probability-weighted average embedding.
            probs = F.softmax(logits_flat / 0.1, dim=1)
            emb_normed = F.normalize(self.block_emb.weight, dim=1)
            pred_emb = probs @ emb_normed
            target_emb = emb_normed[targets_flat]
            similarity = (pred_emb * target_emb).sum(dim=1)
            emb_loss_raw = 1 - similarity

            weights = torch.ones_like(targets_flat, dtype=torch.float)
            weights[is_struct] = structure_weight
            emb_loss = (weights * emb_loss_raw).sum() / weights.sum()
        else:
            emb_loss = torch.tensor(0.0, device=dev)

        # === EMBEDDING REGULARIZATION ===
        if self.train_embeddings:
            reg = self.compute_embedding_regularization()
            stab_loss = self.stability_weight * reg['stability']
            div_loss = self.diversity_weight * reg['diversity']
        else:
            stab_loss = torch.tensor(0.0, device=dev)
            div_loss = torch.tensor(0.0, device=dev)

        # === TOTAL LOSS ===
        total = (ce_loss + shape_loss +
                 self.embedding_loss_alpha * emb_loss +
                 out['vq_loss'] + stab_loss + div_loss)

        # === COMPUTE METRICS ===
        with torch.no_grad():
            preds = logits_flat.argmax(dim=1)
            correct = (preds == targets_flat).float()
            has_struct = bool(is_struct.any())
            has_air = bool(is_air.any())

            exact_acc = correct.mean()
            struct_exact = correct[is_struct].mean() if has_struct else torch.tensor(0.0, device=dev)
            air_exact = correct[is_air].mean() if has_air else torch.tensor(0.0, device=dev)

            # Similarity-weighted accuracy
            sim_matrix = self.get_similarity_matrix().to(preds.device)
            sim_scores = sim_matrix[preds, targets_flat]
            sim_acc = sim_scores.mean()
            struct_sim = sim_scores[is_struct].mean() if has_struct else torch.tensor(0.0, device=dev)

            # Shape preservation metrics, computed here so they stay meaningful
            # even when the shape loss is disabled.
            is_air_pred = torch.isin(preds, air_dev)
            orig_vol = is_struct.float().sum()
            pred_vol = (~is_air_pred).float().sum()
            if has_struct:
                vol_ratio = pred_vol / orig_vol
                false_air_rate = (is_struct & is_air_pred).float().sum() / orig_vol
                struct_recall = (is_struct & ~is_air_pred).float().sum() / orig_vol
            else:
                # No structure to preserve: nothing erased, everything recalled.
                vol_ratio = torch.tensor(1.0, device=dev)
                false_air_rate = torch.tensor(0.0, device=dev)
                struct_recall = torch.tensor(1.0, device=dev)

        return {
            'loss': total,
            'ce_loss': ce_loss,
            'shape_loss': shape_loss,
            'emb_loss': emb_loss,
            'vq_loss': out['vq_loss'],
            'stab_loss': stab_loss,
            'div_loss': div_loss,
            'exact_acc': exact_acc,
            'struct_exact': struct_exact,
            'air_exact': air_exact,
            'sim_acc': sim_acc,
            'struct_sim': struct_sim,
            # Shape preservation metrics (NEW - KEY METRICS)
            'false_air_rate': false_air_rate,  # Want LOW (<10%)
            'struct_recall': struct_recall,     # Want HIGH (>90%)
            'vol_ratio': vol_ratio,             # Want CLOSE TO 1.0
        }


# Cell-completion marker: confirms in the notebook output that the cell ran.
print("VQ-VAE v4 architecture with shape preservation defined!")

In [None]:
# ============================================================
# CELL 6: Training Functions with Shape Preservation Metrics
# ============================================================

def train_epoch(model, loader, optimizer, scaler, device, air_tokens, structure_weight,
                use_emb_loss, use_shape_loss, use_asymmetric_loss):
    """Train for one epoch with AMP and gradient accumulation.

    Gradients are accumulated over ``GRAD_ACCUM_STEPS`` batches before each
    optimizer step. If the number of batches is not a multiple of
    ``GRAD_ACCUM_STEPS``, the trailing partial window is stepped after the loop
    so those batches still contribute instead of being wiped by the next
    epoch's ``zero_grad``.

    Args:
        model: object exposing ``compute_loss(...)`` and a ``quantizer`` with
            ``reset_epoch_stats`` / ``get_usage_fraction`` / ``get_perplexity``.
        loader: iterable of batches (tensors moved to ``device`` here).
        optimizer: optimizer stepped through ``scaler``.
        scaler: gradient scaler used for mixed-precision training.
        device: target device for the batches.
        air_tokens, structure_weight, use_emb_loss, use_shape_loss,
        use_asymmetric_loss: forwarded unchanged to ``model.compute_loss``.

    Returns:
        dict of per-batch averages for every tracked metric, plus the epoch-level
        ``'cb_usage'`` and ``'perplexity'`` read from the quantizer.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()
    model.quantizer.reset_epoch_stats()

    # Reported metric name -> key in the dict returned by model.compute_loss.
    metric_sources = {
        'loss': 'loss', 'ce': 'ce_loss', 'shape': 'shape_loss', 'emb': 'emb_loss',
        'vq': 'vq_loss', 'stab': 'stab_loss', 'div': 'div_loss',
        'exact': 'exact_acc', 'struct_exact': 'struct_exact', 'air_exact': 'air_exact',
        'sim': 'sim_acc', 'struct_sim': 'struct_sim',
        # NEW shape preservation metrics
        'false_air': 'false_air_rate', 'struct_recall': 'struct_recall',
        'vol_ratio': 'vol_ratio',
    }
    metrics = {k: 0.0 for k in metric_sources}
    n = 0
    optimizer.zero_grad()

    def _apply_accumulated_gradients():
        # Unscale first so clipping acts on the true gradient magnitudes.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    pending_grads = False
    for batch_idx, batch in enumerate(tqdm(loader, desc="Train", leave=False)):
        batch = batch.to(device)

        with torch.amp.autocast('cuda', enabled=USE_AMP):
            out = model.compute_loss(batch, air_tokens, structure_weight,
                                     use_emb_loss, use_shape_loss, use_asymmetric_loss)
            # Divide so the accumulated gradient is the mean over the window.
            loss = out['loss'] / GRAD_ACCUM_STEPS

        scaler.scale(loss).backward()
        pending_grads = True

        if (batch_idx + 1) % GRAD_ACCUM_STEPS == 0:
            _apply_accumulated_gradients()
            pending_grads = False

        for name, source in metric_sources.items():
            metrics[name] += out[source].item()
        n += 1

    # Flush a trailing partial accumulation window (its gradient is scaled by
    # 1/GRAD_ACCUM_STEPS, so this step is simply a smaller update).
    if pending_grads:
        _apply_accumulated_gradients()

    if n == 0:
        raise ValueError("train_epoch received an empty loader (no batches to train on)")

    metrics['cb_usage'] = model.quantizer.get_usage_fraction()
    metrics['perplexity'] = model.quantizer.get_perplexity()

    return {k: v/n if k not in ['cb_usage', 'perplexity'] else v for k, v in metrics.items()}


@torch.no_grad()
def validate(model, loader, device, air_tokens, structure_weight, 
             use_emb_loss, use_shape_loss, use_asymmetric_loss):
    """Evaluate the model over ``loader`` without gradient tracking.

    Puts the model in eval mode, resets the quantizer's usage counters, and
    averages every tracked metric over the batches. The returned dict also
    carries the epoch-level ``'cb_usage'`` and ``'perplexity'`` read from the
    quantizer. The loss-flag arguments are forwarded unchanged to
    ``model.compute_loss``.
    """
    model.eval()
    model.quantizer.reset_epoch_stats()

    # Reported metric name -> key in the dict returned by model.compute_loss.
    source_key = {
        'loss': 'loss',
        'ce': 'ce_loss',
        'shape': 'shape_loss',
        'emb': 'emb_loss',
        'vq': 'vq_loss',
        'exact': 'exact_acc',
        'struct_exact': 'struct_exact',
        'air_exact': 'air_exact',
        'sim': 'sim_acc',
        'struct_sim': 'struct_sim',
        # NEW shape preservation metrics
        'false_air': 'false_air_rate',
        'struct_recall': 'struct_recall',
        'vol_ratio': 'vol_ratio',
    }
    totals = dict.fromkeys(source_key, 0.0)
    num_batches = 0

    for batch in tqdm(loader, desc="Val", leave=False):
        batch = batch.to(device)

        with torch.amp.autocast('cuda', enabled=USE_AMP):
            out = model.compute_loss(batch, air_tokens, structure_weight,
                                     use_emb_loss, use_shape_loss, use_asymmetric_loss)

        for name, key in source_key.items():
            totals[name] += out[key].item()
        num_batches += 1

    cb_usage = model.quantizer.get_usage_fraction()
    perplexity = model.quantizer.get_perplexity()

    summary = {name: total / num_batches for name, total in totals.items()}
    summary['cb_usage'] = cb_usage
    summary['perplexity'] = perplexity
    return summary


# Cell-level confirmation message, emitted once this cell has been executed.
print("Training functions with shape preservation metrics defined!")

In [None]:
# ============================================================
# CELL 7: Create Model and Optimizer
# ============================================================

# Seed torch, NumPy and the stdlib RNG with the same value.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.empty_cache()

# Data loaders: only the training loader shuffles, driven by a seeded generator.
g = torch.Generator().manual_seed(SEED)
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, generator=g, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

# Create model with shape preservation
_model_kwargs = dict(
    vocab_size=VOCAB_SIZE,
    emb_dim=EMBEDDING_DIM,
    hidden_dims=HIDDEN_DIMS,
    latent_dim=LATENT_DIM,
    num_codes=NUM_CODEBOOK_ENTRIES,
    pretrained_emb=v3_embeddings,
    embedding_loss_alpha=EMBEDDING_LOSS_ALPHA,
    stability_weight=STABILITY_WEIGHT,
    diversity_weight=DIVERSITY_WEIGHT,
    false_air_weight=FALSE_AIR_WEIGHT,
    volume_weight=VOLUME_WEIGHT,
    structure_to_air_weight=STRUCTURE_TO_AIR_WEIGHT,
    dropout=DROPOUT,
    commitment_cost=COMMITMENT_COST,
    ema_decay=EMA_DECAY,
)
model = VQVAEv4(**_model_kwargs).to(device)

# Count parameters (all of them, and the subset that currently requires grad)
_param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(size for size, _ in _param_sizes)
trainable_params = sum(size for size, needs_grad in _param_sizes if needs_grad)
print(f"Total params: {total_params:,}")
print(f"Trainable params: {trainable_params:,}")
print(f"Embeddings trainable: {model.train_embeddings}")

# Optimizer with two parameter groups: the block embeddings get their own
# (scaled) learning rate, everything else that requires grad uses BASE_LR.
emb_params = list(model.block_emb.parameters())
other_params = [
    param
    for name, param in model.named_parameters()
    if param.requires_grad and 'block_emb' not in name
]
_embedding_lr = BASE_LR * EMBEDDING_LR_SCALE

optimizer = optim.AdamW(
    [
        {'params': other_params, 'lr': BASE_LR},
        {'params': emb_params, 'lr': _embedding_lr},
    ],
    weight_decay=1e-5,
)

scaler = torch.amp.GradScaler('cuda', enabled=USE_AMP)

print("\nOptimizer: AdamW")
print(f"  Base LR: {BASE_LR}")
print(f"  Embedding LR: {_embedding_lr}")
print("\nShape Preservation:")
print(f"  False air weight: {FALSE_AIR_WEIGHT}")
print(f"  Volume weight: {VOLUME_WEIGHT}")
print(f"  Structure->air weight: {STRUCTURE_TO_AIR_WEIGHT}")

In [None]:
# ============================================================
# CELL 8: Training Loop with Shape Preservation
# ============================================================

_BANNER = "=" * 70
print(_BANNER)
print("VQ-VAE V4 TRAINING - SHAPE FIRST, DETAILS SECOND")
print(_BANNER)
print(f"Phase 1: Warmup (epochs 1-{WARMUP_EPOCHS}) - Embeddings FROZEN")
print(f"Phase 2: Full (epochs {WARMUP_EPOCHS+1}-{TOTAL_EPOCHS}) - Embeddings TRAINABLE")
print("\nShape Preservation Enabled:")
print(f"  - Asymmetric CE: structure->air penalized {STRUCTURE_TO_AIR_WEIGHT}x")
print(f"  - False air penalty: {FALSE_AIR_WEIGHT}x")
print(f"  - Volume preservation: {VOLUME_WEIGHT}x")
print()

# Metrics recorded every epoch for both splits; the same keys are read from
# the dicts returned by train_epoch() and validate().
_HISTORY_KEYS = (
    'loss', 'struct_exact', 'struct_sim',
    'cb_usage', 'perplexity',
    'false_air', 'struct_recall', 'vol_ratio',
)
history = {
    f'{_split}_{_key}': []
    for _split in ('train', 'val')
    for _key in _HISTORY_KEYS
}
history['phase'] = []

# Model selection is driven by validation structure recall.
best_struct_recall = 0
best_epoch = 0
start_time = time.time()

for epoch in range(TOTAL_EPOCHS):
    # Phased training: embedding training and the embedding loss are both
    # switched off during warmup and switched on afterwards.
    in_warmup = epoch < WARMUP_EPOCHS
    phase = "warmup" if in_warmup else "full"
    use_emb_loss = not in_warmup
    model.set_train_embeddings(use_emb_loss)

    # One training pass, then one validation pass with identical loss settings
    train_m = train_epoch(model, train_loader, optimizer, scaler, device,
                          AIR_TOKENS_TENSOR, STRUCTURE_WEIGHT, use_emb_loss,
                          USE_SHAPE_LOSS, USE_ASYMMETRIC_LOSS)
    val_m = validate(model, val_loader, device, AIR_TOKENS_TENSOR, STRUCTURE_WEIGHT,
                     use_emb_loss, USE_SHAPE_LOSS, USE_ASYMMETRIC_LOSS)

    # Record this epoch's numbers
    for _split, _epoch_metrics in (('train', train_m), ('val', val_m)):
        for _key in _HISTORY_KEYS:
            history[f'{_split}_{_key}'].append(_epoch_metrics[_key])
    history['phase'].append(phase)

    # Keep the weights of the epoch with the highest validation structure recall
    current_recall = val_m['struct_recall']
    if current_recall > best_struct_recall:
        best_struct_recall = current_recall
        best_epoch = epoch + 1
        torch.save(model.state_dict(), f"{OUTPUT_DIR}/vqvae_v4_best.pt")

    # Per-epoch log line (train/val where both are shown)
    emb_status = "FROZEN" if phase == "warmup" else "TRAIN"
    print(f"Epoch {epoch+1:2d} [{phase:6s}] | "
          f"Recall: {train_m['struct_recall']:.1%}/{val_m['struct_recall']:.1%} | "
          f"FalseAir: {train_m['false_air']:.1%}/{val_m['false_air']:.1%} | "
          f"Vol: {val_m['vol_ratio']:.2f} | "
          f"Exact: {val_m['struct_exact']:.1%} | "
          f"CB: {val_m['cb_usage']:.0%}")

train_time = time.time() - start_time
print(f"\nTraining complete in {train_time/60:.1f} minutes")
print(f"Best val struct_recall: {best_struct_recall:.1%} at epoch {best_epoch}")
print(f"\n*** KEY METRIC: Structure Recall = {best_struct_recall:.1%} ***")
print("(This measures: of all blocks that SHOULD be structure, how many were NOT erased)")

In [None]:
# ============================================================
# CELL 9: Plot Training Curves - Shape Preservation Focus
# ============================================================

# Plot whatever was actually recorded rather than assuming TOTAL_EPOCHS
# entries: if the training loop was stopped early, the history lists are
# shorter and a fixed-length x-axis would make every ax.plot() call fail.
num_recorded_epochs = len(history['train_loss'])
if num_recorded_epochs == 0:
    raise RuntimeError("history is empty - run the training loop before plotting")

fig, axes = plt.subplots(2, 4, figsize=(20, 8))

epochs = range(1, num_recorded_epochs + 1)

# Plot 1: STRUCTURE RECALL (THE KEY METRIC)
ax = axes[0, 0]
ax.plot(epochs, history['train_struct_recall'], 'b-', label='Train', linewidth=2)
ax.plot(epochs, history['val_struct_recall'], 'r--', label='Val', linewidth=2)
ax.axvline(x=WARMUP_EPOCHS, color='gray', linestyle=':', label='Warmup end')
ax.axhline(y=0.9, color='green', linestyle='--', alpha=0.5, label='Target (90%)')
ax.set_title('Structure Recall (KEY METRIC)', fontweight='bold')
ax.set_xlabel('Epoch')
ax.set_ylabel('Recall')
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_ylim(0, 1)

# Plot 2: FALSE AIR RATE (lower is better)
ax = axes[0, 1]
ax.plot(epochs, history['train_false_air'], 'b-', label='Train')
ax.plot(epochs, history['val_false_air'], 'r--', label='Val')
ax.axvline(x=WARMUP_EPOCHS, color='gray', linestyle=':')
ax.axhline(y=0.1, color='green', linestyle='--', alpha=0.5, label='Target (<10%)')
ax.set_title('False Air Rate (lower=better)')
ax.set_xlabel('Epoch')
ax.set_ylabel('Rate')
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_ylim(0, 1)

# Plot 3: Volume Ratio (want close to 1.0)
ax = axes[0, 2]
ax.plot(epochs, history['train_vol_ratio'], 'b-', label='Train')
ax.plot(epochs, history['val_vol_ratio'], 'r--', label='Val')
ax.axvline(x=WARMUP_EPOCHS, color='gray', linestyle=':')
ax.axhline(y=1.0, color='green', linestyle='--', alpha=0.5, label='Target (1.0)')
ax.set_title('Volume Ratio (1.0=perfect)')
ax.set_xlabel('Epoch')
ax.set_ylabel('Ratio')
ax.legend()
ax.grid(True, alpha=0.3)

# Plot 4: Structure accuracy (exact vs similarity)
ax = axes[0, 3]
ax.plot(epochs, history['train_struct_exact'], 'b-', label='Train Exact')
ax.plot(epochs, history['val_struct_exact'], 'b--', label='Val Exact')
ax.plot(epochs, history['train_struct_sim'], 'g-', label='Train Sim')
ax.plot(epochs, history['val_struct_sim'], 'g--', label='Val Sim')
ax.axvline(x=WARMUP_EPOCHS, color='gray', linestyle=':')
ax.set_title('Structure Accuracy')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.legend()
ax.grid(True, alpha=0.3)

# Plot 5: Loss
ax = axes[1, 0]
ax.plot(epochs, history['train_loss'], 'b-', label='Train')
ax.plot(epochs, history['val_loss'], 'r--', label='Val')
ax.axvline(x=WARMUP_EPOCHS, color='gray', linestyle=':')
ax.set_title('Total Loss')
ax.set_xlabel('Epoch')
ax.legend()
ax.grid(True, alpha=0.3)

# Plot 6: Codebook usage (dashed horizontal reference line at 0.3)
ax = axes[1, 1]
ax.plot(epochs, history['train_cb_usage'], 'b-', label='Train')
ax.plot(epochs, history['val_cb_usage'], 'r--', label='Val')
ax.axhline(y=0.3, color='gray', linestyle='--', alpha=0.5)
ax.axvline(x=WARMUP_EPOCHS, color='gray', linestyle=':')
ax.set_title('Codebook Usage')
ax.set_xlabel('Epoch')
ax.legend()
ax.grid(True, alpha=0.3)

# Plot 7: Perplexity
ax = axes[1, 2]
ax.plot(epochs, history['train_perplexity'], 'b-', label='Train')
ax.plot(epochs, history['val_perplexity'], 'r--', label='Val')
ax.axvline(x=WARMUP_EPOCHS, color='gray', linestyle=':')
ax.set_title('Codebook Perplexity')
ax.set_xlabel('Epoch')
ax.legend()
ax.grid(True, alpha=0.3)

# Plot 8: Final comparison bar - SHAPE PRESERVATION FOCUS
# Uses the last recorded validation values; false-air is shown as
# (1 - rate) so that higher is better for every bar.
ax = axes[1, 3]
final_metrics = {
    'Struct\nRecall': history['val_struct_recall'][-1],
    '1-False\nAir': 1 - history['val_false_air'][-1],
    'Exact\nAcc': history['val_struct_exact'][-1],
    'Sim\nAcc': history['val_struct_sim'][-1],
}
colors = ['green', 'orange', 'blue', 'purple']
bars = ax.bar(final_metrics.keys(), final_metrics.values(), color=colors)
ax.set_title('Final Metrics (Shape Focus)')
ax.set_ylabel('Score')
ax.set_ylim(0, 1)
# Annotate each bar with its value just above the bar top
for bar, val in zip(bars, final_metrics.values()):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
            f'{val:.1%}', ha='center', fontsize=9)

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/vqvae_v4_training.png", dpi=150, bbox_inches='tight')
plt.show()

# Summary (values from the last recorded epoch)
print("\n" + "="*60)
print("SHAPE PRESERVATION SUMMARY")
print("="*60)
print(f"Structure Recall: {history['val_struct_recall'][-1]:.1%} (target: >90%)")
print(f"False Air Rate:   {history['val_false_air'][-1]:.1%} (target: <10%)")
print(f"Volume Ratio:     {history['val_vol_ratio'][-1]:.2f} (target: ~1.0)")
print(f"Exact Accuracy:   {history['val_struct_exact'][-1]:.1%}")
print(f"Sim Accuracy:     {history['val_struct_sim'][-1]:.1%}")

In [None]:
# ============================================================
# CELL 10: Save Results and Checkpoint
# ============================================================

# JSON summary: run configuration, headline results and the full per-epoch
# history (numeric entries coerced to plain floats so they serialize).
results = {
    'config': {
        'hidden_dims': HIDDEN_DIMS,
        'latent_dim': LATENT_DIM,
        'num_codes': NUM_CODEBOOK_ENTRIES,
        'total_epochs': TOTAL_EPOCHS,
        'warmup_epochs': WARMUP_EPOCHS,
        'batch_size': BATCH_SIZE,
        'base_lr': BASE_LR,
        'embedding_lr_scale': EMBEDDING_LR_SCALE,
        'embedding_loss_alpha': EMBEDDING_LOSS_ALPHA,
        'stability_weight': STABILITY_WEIGHT,
        'diversity_weight': DIVERSITY_WEIGHT,
        'structure_weight': STRUCTURE_WEIGHT,
        # NEW shape preservation config
        'false_air_weight': FALSE_AIR_WEIGHT,
        'volume_weight': VOLUME_WEIGHT,
        'structure_to_air_weight': STRUCTURE_TO_AIR_WEIGHT,
        'use_shape_loss': USE_SHAPE_LOSS,
        'use_asymmetric_loss': USE_ASYMMETRIC_LOSS,
        'seed': SEED,
    },
    'results': {
        # Shape preservation metrics (THE KEY METRICS)
        'best_struct_recall': float(best_struct_recall),
        'best_epoch': best_epoch,
        'final_struct_recall': float(history['val_struct_recall'][-1]),
        'final_false_air_rate': float(history['val_false_air'][-1]),
        'final_vol_ratio': float(history['val_vol_ratio'][-1]),
        # Original metrics
        'final_struct_exact': float(history['val_struct_exact'][-1]),
        'final_struct_sim': float(history['val_struct_sim'][-1]),
        'final_cb_usage': float(history['val_cb_usage'][-1]),
        'final_perplexity': float(history['val_perplexity'][-1]),
        'training_time_min': float(train_time / 60),
    },
    'history': {k: [float(x) if isinstance(x, (int, float)) else x for x in v] 
                for k, v in history.items()},
}

with open(f"{OUTPUT_DIR}/vqvae_v4_results.json", 'w') as f:
    json.dump(results, f, indent=2)

# The "best" checkpoint must carry the weights of the best epoch, not the
# weights the model happens to hold after the last epoch. The training loop
# writes those weights to vqvae_v4_best.pt whenever validation structure
# recall improves, so reload them from there.
_best_state_path = Path(OUTPUT_DIR) / "vqvae_v4_best.pt"
if best_epoch > 0 and _best_state_path.exists():
    # weights_only=True: the file is a plain state dict (tensors only).
    best_state_dict = torch.load(_best_state_path, map_location="cpu", weights_only=True)
else:
    # No epoch was ever recorded as best in this run (or the file is gone);
    # the final weights are the only ones available.
    print("WARNING: no best-epoch weights found; "
          "vqvae_v4_best_checkpoint.pt will contain the FINAL weights")
    best_state_dict = model.state_dict()

# Save complete checkpoint with all metadata for visualization script
checkpoint = {
    'model_state_dict': best_state_dict,
    'config': {
        'vocab_size': VOCAB_SIZE,
        'emb_dim': EMBEDDING_DIM,
        'hidden_dims': HIDDEN_DIMS,
        'latent_dim': LATENT_DIM,
        'num_codes': NUM_CODEBOOK_ENTRIES,
        'embedding_loss_alpha': EMBEDDING_LOSS_ALPHA,
        'stability_weight': STABILITY_WEIGHT,
        'diversity_weight': DIVERSITY_WEIGHT,
        'false_air_weight': FALSE_AIR_WEIGHT,
        'volume_weight': VOLUME_WEIGHT,
        'structure_to_air_weight': STRUCTURE_TO_AIR_WEIGHT,
        'dropout': DROPOUT,
        'commitment_cost': COMMITMENT_COST,
        'ema_decay': EMA_DECAY,
    },
    'air_tokens': AIR_TOKENS_LIST,
    'best_struct_recall': float(best_struct_recall),
    'best_epoch': best_epoch,
    'training_time_min': float(train_time / 60),
}

# Save best and final checkpoints (same metadata, different weights)
torch.save(checkpoint, f"{OUTPUT_DIR}/vqvae_v4_best_checkpoint.pt")

checkpoint['model_state_dict'] = model.state_dict()  # Switch to final-epoch weights
torch.save(checkpoint, f"{OUTPUT_DIR}/vqvae_v4_final_checkpoint.pt")

# Also save just the state dict for backwards compatibility
torch.save(model.state_dict(), f"{OUTPUT_DIR}/vqvae_v4_final.pt")

print("\nResults saved:")
print(f"  - {OUTPUT_DIR}/vqvae_v4_results.json")
print(f"  - {OUTPUT_DIR}/vqvae_v4_best_checkpoint.pt (full checkpoint)")
print(f"  - {OUTPUT_DIR}/vqvae_v4_final_checkpoint.pt (full checkpoint)")
print(f"  - {OUTPUT_DIR}/vqvae_v4_final.pt (state dict only)")
print(f"  - {OUTPUT_DIR}/vqvae_v4_training.png")

print("\n" + "="*70)
print("FINAL RESULTS - SHAPE PRESERVATION")
print("="*70)
print(f"Best structure recall:     {best_struct_recall:.1%} (epoch {best_epoch})")
print(f"Final structure recall:    {history['val_struct_recall'][-1]:.1%}")
print(f"Final false air rate:      {history['val_false_air'][-1]:.1%}")
print(f"Final volume ratio:        {history['val_vol_ratio'][-1]:.2f}")
print(f"Final structure exact acc: {history['val_struct_exact'][-1]:.1%}")
print(f"Final codebook usage:      {history['val_cb_usage'][-1]:.1%}")
print(f"Training time:             {train_time/60:.1f} minutes")

print("\n" + "="*70)
print("INTERPRETATION")
print("="*70)
print("Structure Recall: % of original building blocks NOT erased")
print("  - Target: >90% (buildings should be preserved)")
print("  - If low: model is still erasing buildings -> increase shape loss weights")
print()
print("False Air Rate: % of structure blocks wrongly predicted as air")
print("  - Target: <10% (minimal building erasure)")
print("  - This is 1 - Structure Recall")
print()
print("Volume Ratio: predicted_volume / original_volume")
print("  - Target: ~1.0 (same amount of blocks)")
print("  - <1.0 means buildings shrunk, >1.0 means extra blocks added")

print("\n--- NEXT STEPS ---")
print("Download vqvae_v4_best_checkpoint.pt and use with visualization script:")
print("  python scripts/visualize_reconstruction_mcp.py \\")
print("      --checkpoint vqvae_v4_best_checkpoint.pt \\")
print("      --h5-file path/to/build.h5 \\")
print("      --output commands.txt")