# VQ-VAE v4 Training - Shape Preservation Focus (Google Colab)

## Philosophy: SHAPE FIRST, DETAILS SECOND

The key insight from Phase 2.5 validation: **buildings were disappearing** in reconstructions.
The model was predicting AIR where there should be structure blocks.

This version addresses that with:
1. **Asymmetric Cross-Entropy**: structure->air errors penalized 10x more than air->structure
2. **False Air Penalty**: explicit loss for predicting air where structure exists
3. **Volume Preservation**: penalize if reconstruction has fewer blocks than original
4. **Structure Recall**: new key metric tracking shape preservation
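
The asymmetric weighting in item 1 can be sketched in plain Python. This is a minimal illustration rather than the notebook's implementation: the token ids, the `AIR` set, and the helper name `asymmetric_weights` are assumptions made for the example, while the 10x and 1x weights are the values listed above.

```python
# Hypothetical token ids for illustration: 0 is air, everything else is structure.
AIR = {0}

def asymmetric_weights(targets, predictions, struct_to_air=10.0, air_to_struct=1.0):
    """Return one loss weight per voxel based on the kind of error made."""
    weights = []
    for tgt, pred in zip(targets, predictions):
        if tgt not in AIR and pred in AIR:
            weights.append(struct_to_air)   # a building block was erased
        elif tgt in AIR and pred not in AIR:
            weights.append(air_to_struct)   # an extra block was added
        else:
            weights.append(1.0)             # correct class, or wrong block type
    return weights

targets     = [5, 5, 0, 0, 7]
predictions = [0, 5, 3, 0, 9]
print(asymmetric_weights(targets, predictions))  # [10.0, 1.0, 1.0, 1.0, 1.0]
```

Only the first voxel (structure predicted as air) receives the heavy weight; predicting the wrong block type (7 vs 9) is weighted like any ordinary error.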

## Key Improvements over v3

| Change | v3 | v4 |
|--------|-----|-----|
| Latent grid | 4x4x4 (512:1) | **8x8x8 (64:1)** |
| Embeddings | Frozen | **Trainable** |
| Loss | CE only | **CE + embedding similarity + shape preservation** |
| Metric | Exact-match | **Structure Recall (shape preservation)** |
| Key focus | Accuracy | **Don't erase buildings!** |
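
The compression ratios in the table follow from the grid sizes. A small check, assuming a 32x32x32 input volume (the size named in the encoder docstring later in this notebook); `compression_ratio` is a hypothetical helper:

```python
def compression_ratio(input_side: int, latent_side: int) -> int:
    """Number of input voxels represented by each latent cell."""
    return (input_side ** 3) // (latent_side ** 3)

print(compression_ratio(32, 4))  # 512 (v3)
print(compression_ratio(32, 8))  # 64  (v4)
```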

## New Metrics

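- **Structure Recall**: Of the voxels that SHOULD be structure, what fraction is predicted as any non-air block? A wrong block type still counts as preserved. (Target: >90%)
- **False Air Rate**: What fraction of structure voxels was wrongly predicted as air? By definition this equals 1 - Structure Recall. (Target: <10%)
- **Volume Ratio**: predicted structure voxel count / original structure voxel count (Target: ~1.0)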
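
A worked example of the three metrics on a tiny flattened structure. This is a sketch under stated assumptions: token `0` is treated as the only air token and `shape_metrics` is a hypothetical helper, not a function from this notebook.

```python
AIR = {0}  # assumption: token 0 is the only air token

def shape_metrics(original, predicted):
    """Compute structure recall, false air rate and volume ratio for flat token lists."""
    struct_idx = [i for i, tok in enumerate(original) if tok not in AIR]
    n = len(struct_idx)
    if n == 0:
        return {"structure_recall": 1.0, "false_air_rate": 0.0, "volume_ratio": 1.0}
    kept = sum(1 for i in struct_idx if predicted[i] not in AIR)
    pred_volume = sum(1 for tok in predicted if tok not in AIR)
    return {
        "structure_recall": kept / n,
        "false_air_rate": (n - kept) / n,
        "volume_ratio": pred_volume / n,
    }

original  = [1, 1, 2, 2, 0, 0, 0, 0]
predicted = [1, 3, 0, 2, 0, 4, 0, 0]
print(shape_metrics(original, predicted))
# {'structure_recall': 0.75, 'false_air_rate': 0.25, 'volume_ratio': 1.0}
```

Here one structure voxel is erased and one air voxel is filled in, so the volume ratio is exactly 1.0 even though recall is only 75%. Volume ratio alone therefore cannot confirm that the shape survived.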

## Setup

Upload the following to your Google Drive under `minecraft_ai/`:
- `splits/train/` - Training H5 files
- `splits/val/` - Validation H5 files  
- `tok2block.json` - Vocabulary file
- `block_embeddings_v3.npy` - V3 compositional embeddings

## Configuration

- **Epochs**: 15 total (5 warmup + 10 full)
- **Structure weight**: 50x (increased from 10x); with asymmetric CE enabled, this weight applies only to the embedding similarity loss
- **False air weight**: 5x
- **Structure->air weight**: 10x (asymmetric CE)
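
The two-phase schedule can be sketched as a function of the zero-based epoch index. `phase_for_epoch` is a hypothetical helper used only for illustration; the constants mirror the values above.

```python
TOTAL_EPOCHS = 15
WARMUP_EPOCHS = 5

def phase_for_epoch(epoch: int) -> str:
    """epoch is zero-based, as produced by range(TOTAL_EPOCHS)."""
    return "warmup" if epoch < WARMUP_EPOCHS else "full"

phases = [phase_for_epoch(e) for e in range(TOTAL_EPOCHS)]
print(phases.count("warmup"), phases.count("full"))  # 5 10
```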

In [None]:
# ============================================================
# CELL 0: Mount Google Drive & Prevent Idle Disconnect
# ============================================================

from google.colab import drive
drive.mount('/content/drive')

# Verify mount
import os
print("Drive mounted. Checking for minecraft_ai folder...")
drive_path = "/content/drive/MyDrive/minecraft_ai"
if os.path.exists(drive_path):
    print(f"Found: {drive_path}")
    print(f"Contents: {os.listdir(drive_path)}")
else:
    print(f"WARNING: {drive_path} not found!")
    print("Please create this folder and upload your data.")

# --- Prevent Colab from disconnecting due to idle timeout ---
from IPython.display import display, Javascript

keep_alive_js = Javascript('''
function KeepAlive() {
    console.log("Keep-alive ping at " + new Date().toLocaleTimeString());
    var buttons = document.querySelectorAll("colab-connect-button, colab-toolbar-button#connect");
    buttons.forEach(function(btn) {
        if (btn) btn.click();
    });
}
setInterval(KeepAlive, 60000);
console.log("Keep-alive script activated - will ping every 60 seconds");
''')

display(keep_alive_js)
print("\nKeep-alive script activated to prevent idle disconnect.")

In [None]:
# ============================================================
# CELL 1: Imports
# ============================================================

import json
import random
import time
from pathlib import Path
from typing import Dict, List, Tuple, Any
from collections import Counter

import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# ============================================================
# CELL 2: Configuration
# ============================================================

# === Google Drive Paths ===
DRIVE_BASE = "/content/drive/MyDrive/minecraft_ai"

DATA_DIR = f"{DRIVE_BASE}/splits/train"
VAL_DIR = f"{DRIVE_BASE}/splits/val"
VOCAB_PATH = f"{DRIVE_BASE}/tok2block.json"
V3_EMBEDDINGS_PATH = f"{DRIVE_BASE}/block_embeddings_v3.npy"

# Output to Drive so results persist after runtime ends
OUTPUT_DIR = f"{DRIVE_BASE}/output/vqvae_v4"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# === V4 Model Architecture ===
HIDDEN_DIMS = [96, 192]  # 2 stages for 32->8 (not 3 stages for 32->4)
LATENT_DIM = 256
NUM_CODEBOOK_ENTRIES = 512
DROPOUT = 0.1

# === VQ-VAE Settings ===
COMMITMENT_COST = 0.5
EMA_DECAY = 0.99
STRUCTURE_WEIGHT = 50.0  # INCREASED from 10.0; weights the embedding loss, and the CE only when USE_ASYMMETRIC_LOSS is False

# === V4 Specific Settings ===
EMBEDDING_LOSS_ALPHA = 0.5  # Weight for embedding similarity loss
STABILITY_WEIGHT = 0.01    # Embedding stability regularization
DIVERSITY_WEIGHT = 0.001   # Embedding diversity regularization

# === SHAPE PRESERVATION SETTINGS (NEW) ===
# These prevent the "buildings disappearing" problem
FALSE_AIR_WEIGHT = 5.0     # Heavily penalize predicting air where structure exists
VOLUME_WEIGHT = 2.0        # Penalize losing structure volume
STRUCTURE_TO_AIR_WEIGHT = 10.0  # Asymmetric CE: structure->air errors 10x worse
USE_SHAPE_LOSS = True      # Enable shape preservation loss
USE_ASYMMETRIC_LOSS = True # Enable asymmetric cross-entropy

# === Training ===
TOTAL_EPOCHS = 15
WARMUP_EPOCHS = 5  # Freeze embeddings for first N epochs
BATCH_SIZE = 4     # Reduced due to larger latent grid
BASE_LR = 3e-4
EMBEDDING_LR_SCALE = 0.1  # Embeddings train 10x slower
USE_AMP = True
GRAD_ACCUM_STEPS = 4

SEED = 42
NUM_WORKERS = 2

print("VQ-VAE v4 Configuration (Colab):")
print(f"  Drive base: {DRIVE_BASE}")
print(f"  Output: {OUTPUT_DIR}")
print(f"  Latent grid: 8x8x8 (64:1 compression)")
print(f"  Hidden dims: {HIDDEN_DIMS}")
print(f"  Epochs: {TOTAL_EPOCHS} ({WARMUP_EPOCHS} warmup + {TOTAL_EPOCHS - WARMUP_EPOCHS} full)")
print(f"  Embedding loss alpha: {EMBEDDING_LOSS_ALPHA}")
print(f"  Structure weight: {STRUCTURE_WEIGHT}x")
print(f"\nShape Preservation (NEW):")
print(f"  False air weight: {FALSE_AIR_WEIGHT}")
print(f"  Volume weight: {VOLUME_WEIGHT}")
print(f"  Structure->air weight: {STRUCTURE_TO_AIR_WEIGHT}")

# Verify paths exist
print("\nPath verification:")
for path, name in [(DATA_DIR, "train"), (VAL_DIR, "val"), (VOCAB_PATH, "vocab"), 
                   (V3_EMBEDDINGS_PATH, "V3 emb")]:
    exists = os.path.exists(path)
    status = "OK" if exists else "MISSING"
    print(f"  {name}: {status}")

In [None]:
# ============================================================
# CELL 3: Load Vocabulary and Embeddings
# ============================================================

with open(VOCAB_PATH, 'r') as f:
    tok2block = {int(k): v for k, v in json.load(f).items()}

block2tok = {v: k for k, v in tok2block.items()}
VOCAB_SIZE = len(tok2block)
print(f"Vocabulary size: {VOCAB_SIZE}")

# Find air tokens
AIR_TOKENS = set()
for tok, block in tok2block.items():
    if 'air' in block.lower() and 'stair' not in block.lower():
        AIR_TOKENS.add(tok)
        print(f"  Air token: {tok} = {block}")

AIR_TOKENS_LIST = sorted(AIR_TOKENS)
AIR_TOKENS_TENSOR = torch.tensor(AIR_TOKENS_LIST, dtype=torch.long)

# Load V3 embeddings (compositional)
v3_embeddings = np.load(V3_EMBEDDINGS_PATH).astype(np.float32)
EMBEDDING_DIM = v3_embeddings.shape[1]
print(f"V3 embeddings: {v3_embeddings.shape} (dim={EMBEDDING_DIM})")
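
The substring rule used above to find air tokens can be tried on a few names. The block names below are illustrative assumptions and may not match the entries in `tok2block.json`; `is_air_name` is a hypothetical helper that repeats the same test.

```python
def is_air_name(block_name: str) -> bool:
    """Same rule as the vocabulary scan: contains 'air' but not 'stair'."""
    name = block_name.lower()
    return "air" in name and "stair" not in name

samples = [
    "minecraft:air",
    "minecraft:cave_air",
    "minecraft:oak_stairs",
    "minecraft:stone",
]
for name in samples:
    print(f"{name}: {is_air_name(name)}")
```

Because the rule is a plain substring match, any other name that happens to contain "air" would also be flagged, so it is worth reading the printed `Air token:` lines once to confirm the set is what you expect.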

In [None]:
# ============================================================
# CELL 4: Dataset
# ============================================================

class VQVAEDataset(Dataset):
    def __init__(self, data_dir: str):
        self.data_dir = Path(data_dir)
        self.h5_files = sorted(self.data_dir.glob("*.h5"))
        if not self.h5_files:
            raise ValueError(f"No H5 files in {data_dir}")
        print(f"Found {len(self.h5_files)} structures in {data_dir}")
    
    def __len__(self):
        return len(self.h5_files)
    
    def __getitem__(self, idx):
        with h5py.File(self.h5_files[idx], 'r') as f:
            key = list(f.keys())[0]
            structure = f[key][:].astype(np.int64)
        return torch.from_numpy(structure).long()

train_dataset = VQVAEDataset(DATA_DIR)
val_dataset = VQVAEDataset(VAL_DIR)

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")

In [None]:
# ============================================================
# CELL 5: VQ-VAE v4 Architecture with Shape Preservation
# ============================================================

class ResidualBlock3D(nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        self.conv1 = nn.Conv3d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv3d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm3d(channels)
        self.bn2 = nn.BatchNorm3d(channels)
    
    def forward(self, x):
        residual = x
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        return F.relu(x + residual)


class VectorQuantizerEMA(nn.Module):
    """EMA-based vector quantizer with dead code reset."""
    
    def __init__(self, num_codes, latent_dim, commitment_cost=0.5, ema_decay=0.99):
        super().__init__()
        self.num_codes = num_codes
        self.latent_dim = latent_dim
        self.commitment_cost = commitment_cost
        self.ema_decay = ema_decay
        
        self.register_buffer('codebook', torch.randn(num_codes, latent_dim))
        self.codebook.data.uniform_(-1/num_codes, 1/num_codes)
        self.register_buffer('ema_cluster_size', torch.zeros(num_codes))
        self.register_buffer('ema_embed_sum', torch.zeros(num_codes, latent_dim))
        self.register_buffer('code_usage', torch.zeros(num_codes))
    
    def reset_epoch_stats(self):
        self.code_usage.zero_()
    
    def get_usage_fraction(self):
        return (self.code_usage > 0).float().mean().item()
    
    def get_perplexity(self):
        if self.code_usage.sum() == 0:
            return 0.0
        probs = self.code_usage / self.code_usage.sum()
        probs = probs[probs > 0]
        entropy = -(probs * probs.log()).sum()
        return entropy.exp().item()
    
    def forward(self, z_e):
        z_e_perm = z_e.permute(0, 2, 3, 4, 1).contiguous()
        flat = z_e_perm.view(-1, self.latent_dim)
        flat_f32 = flat.float()
        codebook_f32 = self.codebook.float()
        
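        # Autocast would run the matmul in half precision even with float32 inputs,
        # so disable it to keep the distance computation in full precision.
        with torch.amp.autocast('cuda', enabled=False):
            d = (flat_f32.pow(2).sum(1, keepdim=True) 
                 + codebook_f32.pow(2).sum(1) 
                 - 2 * flat_f32 @ codebook_f32.t())
        
        indices = d.argmin(dim=1)
        
        # Accumulate per-code usage counts in one vectorized call
        with torch.no_grad():
            self.code_usage += torch.bincount(indices, minlength=self.num_codes).to(self.code_usage.dtype)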
        
        z_q_flat = self.codebook[indices]
        z_q_perm = z_q_flat.view(z_e_perm.shape)
        
        if self.training:
            with torch.no_grad():
                encodings = F.one_hot(indices, self.num_codes).float()
                code_counts = encodings.sum(0)  # vectors assigned to each code in this batch
                
                self.ema_cluster_size = self.ema_decay * self.ema_cluster_size + (1 - self.ema_decay) * code_counts
                batch_sum = encodings.t() @ flat_f32
                self.ema_embed_sum = self.ema_decay * self.ema_embed_sum + (1 - self.ema_decay) * batch_sum
                
                # Laplace-smoothed cluster sizes avoid division by zero for unused codes
                n = self.ema_cluster_size.sum()
                smoothed = (self.ema_cluster_size + 1e-5) / (n + self.num_codes * 1e-5) * n
                self.codebook.data = self.ema_embed_sum / smoothed.unsqueeze(1)
                
                # Dead code reset: re-seed every code assigned fewer than 2 vectors in this batch
                dead = code_counts < 2
                if dead.any() and flat_f32.size(0) > 0:
                    n_dead = dead.sum().item()
                    rand_idx = torch.randint(0, flat_f32.size(0), (n_dead,), device=flat_f32.device)
                    self.codebook.data[dead] = flat_f32[rand_idx]
                    self.ema_cluster_size[dead] = 1
                    self.ema_embed_sum[dead] = flat_f32[rand_idx]
        
        commitment_loss = F.mse_loss(z_e_perm, z_q_perm.detach())
        vq_loss = self.commitment_cost * commitment_loss
        
        z_q_st = z_e_perm + (z_q_perm - z_e_perm).detach()
        z_q = z_q_st.permute(0, 4, 1, 2, 3).contiguous()
        
        return z_q, vq_loss, indices.view(z_e_perm.shape[:-1])


class EncoderV4(nn.Module):
    """32x32x32 -> 8x8x8 encoder (2 stages instead of 3)."""
    
    def __init__(self, in_channels, hidden_dims, latent_dim, dropout=0.1):
        super().__init__()
        layers = []
        current = in_channels
        
        for h in hidden_dims:
            layers.extend([
                nn.Conv3d(current, h, 4, stride=2, padding=1),
                nn.BatchNorm3d(h),
                nn.ReLU(inplace=True),
                nn.Dropout3d(dropout),
                ResidualBlock3D(h),
            ])
            current = h
        
        # Extra capacity at 8x8x8
        layers.extend([
            ResidualBlock3D(current),
            ResidualBlock3D(current),
            nn.Conv3d(current, latent_dim, 3, padding=1),
        ])
        
        self.encoder = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.encoder(x)


class DecoderV4(nn.Module):
    """8x8x8 -> 32x32x32 decoder."""
    
    def __init__(self, latent_dim, hidden_dims, num_blocks, dropout=0.1):
        super().__init__()
        layers = [
            ResidualBlock3D(latent_dim),
            ResidualBlock3D(latent_dim),
        ]
        
        current = latent_dim
        for h in hidden_dims:
            layers.extend([
                ResidualBlock3D(current),
                nn.ConvTranspose3d(current, h, 4, stride=2, padding=1),
                nn.BatchNorm3d(h),
                nn.ReLU(inplace=True),
                nn.Dropout3d(dropout),
            ])
            current = h
        
        layers.append(nn.Conv3d(current, num_blocks, 3, padding=1))
        self.decoder = nn.Sequential(*layers)
    
    def forward(self, z_q):
        return self.decoder(z_q)


def compute_similarity_matrix(embeddings):
    """Compute cosine similarity matrix scaled to [0,1]."""
    with torch.no_grad():
        normed = F.normalize(embeddings, dim=1)
        sim = normed @ normed.t()
        return (sim + 1) / 2


class ShapePreservationLoss(nn.Module):
    """Loss functions to prevent buildings from disappearing.
    
    Philosophy: SHAPE FIRST, DETAILS SECOND.
    It's better to predict the wrong block type than to predict air.
    """
    
    def __init__(self, false_air_weight=5.0, volume_weight=2.0):
        super().__init__()
        self.false_air_weight = false_air_weight
        self.volume_weight = volume_weight
    
    def forward(self, logits, targets, air_tokens):
        # The penalty terms use soft air probabilities so they stay differentiable;
        # masks built from argmax carry no gradient and would leave the penalty inert.
        p_air = F.softmax(logits, dim=1, dtype=torch.float32)[:, air_tokens].sum(dim=1)
        
        is_struct_orig = ~torch.isin(targets, air_tokens)
        struct_mask = is_struct_orig.float()
        orig_vol = struct_mask.sum()
        
        if orig_vol > 0:
            # False air: air probability assigned to voxels that should be structure
            false_air_loss = (p_air * struct_mask).sum() / orig_vol
            # Volume preservation: penalize expected structure volume below the original
            volume_loss = F.relu(orig_vol - (1 - p_air).sum()) / orig_vol
        else:
            false_air_loss = p_air.sum() * 0.0
            volume_loss = p_air.sum() * 0.0
        
        # Hard (argmax) rates are reported as metrics only
        with torch.no_grad():
            is_air_pred = torch.isin(logits.argmax(dim=1), air_tokens)
            if orig_vol > 0:
                false_air_rate = (is_struct_orig & is_air_pred).float().sum() / orig_vol
                struct_recall = (is_struct_orig & ~is_air_pred).float().sum() / orig_vol
            else:
                false_air_rate = torch.tensor(0.0, device=logits.device)
                struct_recall = torch.tensor(1.0, device=logits.device)
        
        total = self.false_air_weight * false_air_loss + self.volume_weight * volume_loss
        
        return {
            'false_air_rate': false_air_rate,
            'volume_loss': volume_loss,
            'structure_recall': struct_recall,
            'total': total,
        }


class AsymmetricStructureLoss(nn.Module):
    """Asymmetric CE that penalizes structure->air more than air->structure."""
    
    def __init__(self, structure_to_air_weight=10.0, air_to_structure_weight=1.0):
        super().__init__()
        self.structure_to_air_weight = structure_to_air_weight
        self.air_to_structure_weight = air_to_structure_weight
    
    def forward(self, logits, targets, air_tokens):
        predictions = logits.argmax(dim=1)
        
        is_air_tgt = torch.isin(targets, air_tokens)
        is_air_pred = torch.isin(predictions, air_tokens)
        is_struct_tgt = ~is_air_tgt
        is_struct_pred = ~is_air_pred
        
        weights = torch.ones_like(targets, dtype=torch.float)
        
        # Structure->air: HEAVY penalty (erases building)
        struct_to_air = is_struct_tgt & is_air_pred
        weights[struct_to_air] = self.structure_to_air_weight
        
        # Air->structure: light penalty (adds extra blocks)
        air_to_struct = is_air_tgt & is_struct_pred
        weights[air_to_struct] = self.air_to_structure_weight
        
        ce_loss = F.cross_entropy(logits, targets, reduction='none')
        return (ce_loss * weights).sum() / weights.sum()


class VQVAEv4(nn.Module):
    """VQ-VAE v4 with trainable embeddings, embedding-aware loss, and SHAPE PRESERVATION."""
    
    def __init__(self, vocab_size, emb_dim, hidden_dims, latent_dim, num_codes,
                 pretrained_emb, embedding_loss_alpha=0.5, stability_weight=0.01,
                 diversity_weight=0.001, false_air_weight=5.0, volume_weight=2.0,
                 structure_to_air_weight=10.0, dropout=0.1, commitment_cost=0.5, 
                 ema_decay=0.99):
        super().__init__()
        self.vocab_size = vocab_size
        self.emb_dim = emb_dim
        self.embedding_loss_alpha = embedding_loss_alpha
        self.stability_weight = stability_weight
        self.diversity_weight = diversity_weight
        self.false_air_weight = false_air_weight
        self.volume_weight = volume_weight
        self.train_embeddings = False  # Start frozen
        
        # Embeddings
        self.block_emb = nn.Embedding(vocab_size, emb_dim)
        self.block_emb.weight.data.copy_(torch.from_numpy(pretrained_emb))
        self.block_emb.weight.requires_grad = False  # Start frozen
        self.register_buffer('original_embeddings', torch.from_numpy(pretrained_emb.copy()))
        
        # Encoder 32->8
        self.encoder = EncoderV4(emb_dim, hidden_dims, latent_dim, dropout)
        
        # Quantizer
        self.quantizer = VectorQuantizerEMA(num_codes, latent_dim, commitment_cost, ema_decay)
        
        # Decoder 8->32
        self.decoder = DecoderV4(latent_dim, list(reversed(hidden_dims)), vocab_size, dropout)
        
        # Shape preservation loss (NEW)
        self.shape_loss = ShapePreservationLoss(false_air_weight, volume_weight)
        
        # Asymmetric loss (NEW)
        self.asymmetric_loss = AsymmetricStructureLoss(structure_to_air_weight)
        
        # Similarity cache
        self._sim_matrix = None
        self._sim_valid = False
    
    def set_train_embeddings(self, train: bool):
        """Enable/disable embedding training for phased training."""
        self.train_embeddings = train
        self.block_emb.weight.requires_grad = train
        if train:
            self._sim_valid = False
    
    def get_similarity_matrix(self):
        if not self._sim_valid or self._sim_matrix is None:
            self._sim_matrix = compute_similarity_matrix(self.block_emb.weight.detach())
            self._sim_valid = True
        return self._sim_matrix
    
    def forward(self, block_ids):
        x = self.block_emb(block_ids).permute(0, 4, 1, 2, 3).contiguous()
        z_e = self.encoder(x)
        z_q, vq_loss, indices = self.quantizer(z_e)
        logits = self.decoder(z_q)
        return {'logits': logits, 'vq_loss': vq_loss, 'indices': indices}
    
    def compute_embedding_regularization(self):
        current = self.block_emb.weight
        original = self.original_embeddings
        
        stability = F.mse_loss(current, original)
        
        normed = F.normalize(current, dim=1)
        sim = normed @ normed.t()
        off_diag = 1 - torch.eye(current.size(0), device=current.device)
        avg_sim = (sim * off_diag).sum() / off_diag.sum()
        diversity = F.relu(avg_sim - 0.3)
        
        return {'stability': stability, 'diversity': diversity}
    
    def compute_loss(self, block_ids, air_tokens, structure_weight, 
                     use_emb_loss=True, use_shape_loss=True, use_asymmetric_loss=True):
        out = self(block_ids)
        logits = out['logits']
        
        logits_flat = logits.permute(0, 2, 3, 4, 1).reshape(-1, self.vocab_size)
        targets_flat = block_ids.view(-1)
        
        # Air tokens for structure masking
        air_dev = air_tokens.to(targets_flat.device)
        is_air = torch.isin(targets_flat, air_dev)
        is_struct = ~is_air
        
        # === PRIMARY LOSS: Asymmetric Cross-Entropy ===
        if use_asymmetric_loss:
            ce_loss = self.asymmetric_loss(logits_flat, targets_flat, air_dev)
        else:
            weights = torch.ones_like(targets_flat, dtype=torch.float)
            weights[is_struct] = structure_weight
            ce = F.cross_entropy(logits_flat, targets_flat, reduction='none')
            ce_loss = (weights * ce).sum() / weights.sum()
        
        # === SHAPE PRESERVATION LOSS ===
        if use_shape_loss:
            shape = self.shape_loss(logits_flat, targets_flat, air_dev)
            shape_loss = shape['total']
            false_air_rate = shape['false_air_rate']
            struct_recall = shape['structure_recall']
        else:
            shape_loss = torch.tensor(0.0, device=block_ids.device)
            false_air_rate = torch.tensor(0.0, device=block_ids.device)
            struct_recall = torch.tensor(0.0, device=block_ids.device)
        
        # === EMBEDDING SIMILARITY LOSS ===
        if use_emb_loss and self.embedding_loss_alpha > 0:
            probs = F.softmax(logits_flat / 0.1, dim=1)
            emb_normed = F.normalize(self.block_emb.weight, dim=1)
            pred_emb = probs @ emb_normed
            target_emb = emb_normed[targets_flat]
            similarity = (pred_emb * target_emb).sum(dim=1)
            emb_loss_raw = 1 - similarity
            
            weights = torch.ones_like(targets_flat, dtype=torch.float)
            weights[is_struct] = structure_weight
            emb_loss = (weights * emb_loss_raw).sum() / weights.sum()
        else:
            emb_loss = torch.tensor(0.0, device=block_ids.device)
        
        # === EMBEDDING REGULARIZATION ===
        if self.train_embeddings:
            reg = self.compute_embedding_regularization()
            stab_loss = self.stability_weight * reg['stability']
            div_loss = self.diversity_weight * reg['diversity']
        else:
            stab_loss = torch.tensor(0.0, device=block_ids.device)
            div_loss = torch.tensor(0.0, device=block_ids.device)
        
        # === TOTAL LOSS ===
        total = (ce_loss + shape_loss + 
                 self.embedding_loss_alpha * emb_loss + 
                 out['vq_loss'] + stab_loss + div_loss)
        
        # === COMPUTE METRICS ===
        with torch.no_grad():
            preds = logits_flat.argmax(dim=1)
            correct = (preds == targets_flat).float()
            
            exact_acc = correct.mean()
            struct_exact = correct[is_struct].mean() if is_struct.any() else torch.tensor(0.0)
            air_exact = correct[is_air].mean() if is_air.any() else torch.tensor(0.0)
            
            # Similarity-weighted accuracy
            sim_matrix = self.get_similarity_matrix().to(preds.device)
            sim_scores = sim_matrix[preds, targets_flat]
            sim_acc = sim_scores.mean()
            struct_sim = sim_scores[is_struct].mean() if is_struct.any() else torch.tensor(0.0)
            
            # Volume ratio
            is_air_pred = torch.isin(preds, air_dev)
            orig_vol = is_struct.float().sum()
            pred_vol = (~is_air_pred).float().sum()
            vol_ratio = pred_vol / orig_vol if orig_vol > 0 else torch.tensor(1.0)
        
        return {
            'loss': total,
            'ce_loss': ce_loss,
            'shape_loss': shape_loss,
            'emb_loss': emb_loss,
            'vq_loss': out['vq_loss'],
            'stab_loss': stab_loss,
            'div_loss': div_loss,
            'exact_acc': exact_acc,
            'struct_exact': struct_exact,
            'air_exact': air_exact,
            'sim_acc': sim_acc,
            'struct_sim': struct_sim,
            # Shape preservation metrics (NEW - KEY METRICS)
            'false_air_rate': false_air_rate,  # Want LOW (<10%)
            'struct_recall': struct_recall,     # Want HIGH (>90%)
            'vol_ratio': vol_ratio,             # Want CLOSE TO 1.0
        }


print("VQ-VAE v4 architecture with shape preservation defined!")
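
The codebook statistics reported by `get_usage_fraction` and `get_perplexity` can be reproduced on plain lists. This is a sketch of the same arithmetic (fraction of codes used at least once, and the exponential of the usage entropy); `usage_stats` is a hypothetical helper and the counts are made up for illustration.

```python
import math

def usage_stats(code_counts):
    """Return (fraction of codes used, perplexity) from per-code usage counts."""
    total = sum(code_counts)
    if total == 0:
        return 0.0, 0.0
    used_fraction = sum(1 for c in code_counts if c > 0) / len(code_counts)
    probs = [c / total for c in code_counts if c > 0]
    entropy = -sum(p * math.log(p) for p in probs)
    return used_fraction, math.exp(entropy)

print(usage_stats([10, 10, 10, 10]))  # all codes used equally: perplexity close to 4
print(usage_stats([40, 0, 0, 0]))     # collapsed onto one code: (0.25, 1.0)
```

Perplexity approaches the number of codes when usage is uniform and drops to 1 when every vector maps to the same code, which makes it a quick indicator of codebook collapse.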

In [None]:
# ============================================================
# CELL 6: Training Functions with Shape Preservation Metrics
# ============================================================

def train_epoch(model, loader, optimizer, scaler, device, air_tokens, structure_weight,
                use_emb_loss, use_shape_loss, use_asymmetric_loss):
    model.train()
    model.quantizer.reset_epoch_stats()
    
    metrics = {k: 0.0 for k in ['loss', 'ce', 'shape', 'emb', 'vq', 'stab', 'div',
                                 'exact', 'struct_exact', 'air_exact', 'sim', 'struct_sim',
                                 'false_air', 'struct_recall', 'vol_ratio']}
    n = 0
    optimizer.zero_grad()
    
    for batch_idx, batch in enumerate(tqdm(loader, desc="Train", leave=False)):
        batch = batch.to(device)
        
        with torch.amp.autocast('cuda', enabled=USE_AMP):
            out = model.compute_loss(batch, air_tokens, structure_weight, 
                                     use_emb_loss, use_shape_loss, use_asymmetric_loss)
            loss = out['loss'] / GRAD_ACCUM_STEPS
        
        scaler.scale(loss).backward()
        
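        # Also step on the final batch so a trailing partial group is not discarded
        if (batch_idx + 1) % GRAD_ACCUM_STEPS == 0 or (batch_idx + 1) == len(loader):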
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
        
        metrics['loss'] += out['loss'].item()
        metrics['ce'] += out['ce_loss'].item()
        metrics['shape'] += out['shape_loss'].item()
        metrics['emb'] += out['emb_loss'].item()
        metrics['vq'] += out['vq_loss'].item()
        metrics['stab'] += out['stab_loss'].item()
        metrics['div'] += out['div_loss'].item()
        metrics['exact'] += out['exact_acc'].item()
        metrics['struct_exact'] += out['struct_exact'].item()
        metrics['air_exact'] += out['air_exact'].item()
        metrics['sim'] += out['sim_acc'].item()
        metrics['struct_sim'] += out['struct_sim'].item()
        # NEW shape preservation metrics
        metrics['false_air'] += out['false_air_rate'].item()
        metrics['struct_recall'] += out['struct_recall'].item()
        metrics['vol_ratio'] += out['vol_ratio'].item()
        n += 1
    
    metrics['cb_usage'] = model.quantizer.get_usage_fraction()
    metrics['perplexity'] = model.quantizer.get_perplexity()
    
    return {k: v/n if k not in ['cb_usage', 'perplexity'] else v for k, v in metrics.items()}


@torch.no_grad()
def validate(model, loader, device, air_tokens, structure_weight, 
             use_emb_loss, use_shape_loss, use_asymmetric_loss):
    model.eval()
    model.quantizer.reset_epoch_stats()
    
    metrics = {k: 0.0 for k in ['loss', 'ce', 'shape', 'emb', 'vq',
                                 'exact', 'struct_exact', 'air_exact', 'sim', 'struct_sim',
                                 'false_air', 'struct_recall', 'vol_ratio']}
    n = 0
    
    for batch in tqdm(loader, desc="Val", leave=False):
        batch = batch.to(device)
        
        with torch.amp.autocast('cuda', enabled=USE_AMP):
            out = model.compute_loss(batch, air_tokens, structure_weight,
                                     use_emb_loss, use_shape_loss, use_asymmetric_loss)
        
        metrics['loss'] += out['loss'].item()
        metrics['ce'] += out['ce_loss'].item()
        metrics['shape'] += out['shape_loss'].item()
        metrics['emb'] += out['emb_loss'].item()
        metrics['vq'] += out['vq_loss'].item()
        metrics['exact'] += out['exact_acc'].item()
        metrics['struct_exact'] += out['struct_exact'].item()
        metrics['air_exact'] += out['air_exact'].item()
        metrics['sim'] += out['sim_acc'].item()
        metrics['struct_sim'] += out['struct_sim'].item()
        # NEW shape preservation metrics
        metrics['false_air'] += out['false_air_rate'].item()
        metrics['struct_recall'] += out['struct_recall'].item()
        metrics['vol_ratio'] += out['vol_ratio'].item()
        n += 1
    
    metrics['cb_usage'] = model.quantizer.get_usage_fraction()
    metrics['perplexity'] = model.quantizer.get_perplexity()
    
    return {k: v/n if k not in ['cb_usage', 'perplexity'] else v for k, v in metrics.items()}


print("Training functions with shape preservation metrics defined!")
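
With `BATCH_SIZE = 4` and `GRAD_ACCUM_STEPS = 4`, each optimizer step in `train_epoch` covers an effective batch of 16 structures. The sketch below lists which batches trigger a step and what happens to a trailing partial group; `optimizer_step_batches` and its `flush_last` parameter are hypothetical names used only for this illustration.

```python
def optimizer_step_batches(num_batches: int, accum_steps: int, flush_last: bool = True):
    """Return the zero-based batch indices after which the optimizer steps."""
    steps = []
    for batch_idx in range(num_batches):
        is_boundary = (batch_idx + 1) % accum_steps == 0
        is_last = (batch_idx + 1) == num_batches
        if is_boundary or (flush_last and is_last):
            steps.append(batch_idx)
    return steps

print(optimizer_step_batches(10, 4))                    # [3, 7, 9]
print(optimizer_step_batches(10, 4, flush_last=False))  # [3, 7]
```

Without a flush on the last batch, the gradients from batches 8 and 9 would be accumulated but never applied.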

In [None]:
# ============================================================
# CELL 7: Create Model and Optimizer
# ============================================================

torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.empty_cache()

# Data loaders
g = torch.Generator().manual_seed(SEED)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=True, generator=g)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=True)

# Create model with shape preservation
model = VQVAEv4(
    vocab_size=VOCAB_SIZE,
    emb_dim=EMBEDDING_DIM,
    hidden_dims=HIDDEN_DIMS,
    latent_dim=LATENT_DIM,
    num_codes=NUM_CODEBOOK_ENTRIES,
    pretrained_emb=v3_embeddings,
    embedding_loss_alpha=EMBEDDING_LOSS_ALPHA,
    stability_weight=STABILITY_WEIGHT,
    diversity_weight=DIVERSITY_WEIGHT,
    false_air_weight=FALSE_AIR_WEIGHT,
    volume_weight=VOLUME_WEIGHT,
    structure_to_air_weight=STRUCTURE_TO_AIR_WEIGHT,
    dropout=DROPOUT,
    commitment_cost=COMMITMENT_COST,
    ema_decay=EMA_DECAY,
).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total params: {total_params:,}")
print(f"Trainable params: {trainable_params:,}")
print(f"Embeddings trainable: {model.train_embeddings}")

# Optimizer with separate LR for embeddings
emb_params = list(model.block_emb.parameters())
other_params = [p for n, p in model.named_parameters() if 'block_emb' not in n and p.requires_grad]

optimizer = optim.AdamW([
    {'params': other_params, 'lr': BASE_LR},
    {'params': emb_params, 'lr': BASE_LR * EMBEDDING_LR_SCALE},
], weight_decay=1e-5)

scaler = torch.amp.GradScaler('cuda', enabled=USE_AMP)

print(f"\nOptimizer: AdamW")
print(f"  Base LR: {BASE_LR}")
print(f"  Embedding LR: {BASE_LR * EMBEDDING_LR_SCALE}")
print(f"\nShape Preservation:")
print(f"  False air weight: {FALSE_AIR_WEIGHT}")
print(f"  Volume weight: {VOLUME_WEIGHT}")
print(f"  Structure->air weight: {STRUCTURE_TO_AIR_WEIGHT}")

In [None]:
# ============================================================
# CELL 8: Training Loop with Shape Preservation
# ============================================================

print("="*70)
print("VQ-VAE V4 TRAINING - SHAPE FIRST, DETAILS SECOND")
print("="*70)
print(f"Phase 1: Warmup (epochs 1-{WARMUP_EPOCHS}) - Embeddings FROZEN")
print(f"Phase 2: Full (epochs {WARMUP_EPOCHS+1}-{TOTAL_EPOCHS}) - Embeddings TRAINABLE")
print(f"\nShape Preservation Enabled:")
print(f"  - Asymmetric CE: structure->air penalized {STRUCTURE_TO_AIR_WEIGHT}x")
print(f"  - False air penalty: {FALSE_AIR_WEIGHT}x")
print(f"  - Volume preservation: {VOLUME_WEIGHT}x")
print()

history = {
    'train_loss': [], 'train_struct_exact': [], 'train_struct_sim': [],
    'train_cb_usage': [], 'train_perplexity': [],
    'train_false_air': [], 'train_struct_recall': [], 'train_vol_ratio': [],
    'val_loss': [], 'val_struct_exact': [], 'val_struct_sim': [],
    'val_cb_usage': [], 'val_perplexity': [],
    'val_false_air': [], 'val_struct_recall': [], 'val_vol_ratio': [],
    'phase': [],
}

best_struct_recall = 0  # NEW: track best shape preservation, not similarity
best_epoch = 0
start_time = time.time()

for epoch in range(TOTAL_EPOCHS):
    # Phased training
    if epoch < WARMUP_EPOCHS:
        phase = "warmup"
        model.set_train_embeddings(False)
        use_emb_loss = False
    else:
        phase = "full"
        model.set_train_embeddings(True)
        use_emb_loss = True
    
    # Train
    train_m = train_epoch(model, train_loader, optimizer, scaler, device,
                          AIR_TOKENS_TENSOR, STRUCTURE_WEIGHT, use_emb_loss,
                          USE_SHAPE_LOSS, USE_ASYMMETRIC_LOSS)
    
    # Validate
    val_m = validate(model, val_loader, device, AIR_TOKENS_TENSOR, STRUCTURE_WEIGHT,
                     use_emb_loss, USE_SHAPE_LOSS, USE_ASYMMETRIC_LOSS)
    
    # Record
    history['train_loss'].append(train_m['loss'])
    history['train_struct_exact'].append(train_m['struct_exact'])
    history['train_struct_sim'].append(train_m['struct_sim'])
    history['train_cb_usage'].append(train_m['cb_usage'])
    history['train_perplexity'].append(train_m['perplexity'])
    history['train_false_air'].append(train_m['false_air'])
    history['train_struct_recall'].append(train_m['struct_recall'])
    history['train_vol_ratio'].append(train_m['vol_ratio'])
    
    history['val_loss'].append(val_m['loss'])
    history['val_struct_exact'].append(val_m['struct_exact'])
    history['val_struct_sim'].append(val_m['struct_sim'])
    history['val_cb_usage'].append(val_m['cb_usage'])
    history['val_perplexity'].append(val_m['perplexity'])
    history['val_false_air'].append(val_m['false_air'])
    history['val_struct_recall'].append(val_m['struct_recall'])
    history['val_vol_ratio'].append(val_m['vol_ratio'])
    history['phase'].append(phase)
    
    # Best model - now track STRUCTURE RECALL (shape preservation) as the key metric
    if val_m['struct_recall'] > best_struct_recall:
        best_struct_recall = val_m['struct_recall']
        best_epoch = epoch + 1
        torch.save(model.state_dict(), f"{OUTPUT_DIR}/vqvae_v4_best.pt")
    
    # Log - now with shape preservation metrics
    emb_status = "FROZEN" if phase == "warmup" else "TRAIN"
    print(f"Epoch {epoch+1:2d} [{phase:6s}] Emb: {emb_status:6s} | "
          f"Recall: {train_m['struct_recall']:.1%}/{val_m['struct_recall']:.1%} | "
          f"FalseAir: {train_m['false_air']:.1%}/{val_m['false_air']:.1%} | "
          f"Vol: {val_m['vol_ratio']:.2f} | "
          f"Exact: {val_m['struct_exact']:.1%} | "
          f"CB: {val_m['cb_usage']:.0%}")

train_time = time.time() - start_time
print(f"\nTraining complete in {train_time/60:.1f} minutes")
print(f"Best val struct_recall: {best_struct_recall:.1%} at epoch {best_epoch}")
print(f"\n*** KEY METRIC: Structure Recall = {best_struct_recall:.1%} ***")
print("(This measures: of all blocks that SHOULD be structure, how many were NOT erased)")
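
### Sketch: how the shape-preservation metrics relate

`train_epoch` and `validate` are defined in earlier cells and are not repeated here. The sketch below is one plausible NumPy reading of the three logged metrics, based on their definitions in the introduction. The helper name `shape_metrics` and the token ids are hypothetical, and the actual implementation may differ (for example, by averaging per batch).

```python
import numpy as np

def shape_metrics(original, predicted, air_tokens):
    """Hypothetical helper: shape-preservation metrics for one voxel grid."""
    orig_struct = ~np.isin(original, air_tokens)   # voxels that should be structure
    pred_struct = ~np.isin(predicted, air_tokens)  # voxels predicted as structure
    n_struct = int(orig_struct.sum())
    kept = int((orig_struct & pred_struct).sum())  # structure that was not erased
    recall = kept / n_struct if n_struct else 0.0
    return {
        'struct_recall': recall,
        'false_air': (1.0 - recall) if n_struct else 0.0,
        'vol_ratio': int(pred_struct.sum()) / max(n_struct, 1),
    }

# Toy 2x2x2 grid: token 0 is air, tokens 5 and 7 are structure (made-up ids)
original = np.array([[[5, 5], [0, 0]], [[7, 7], [0, 0]]])
predicted = np.array([[[5, 0], [0, 0]], [[7, 5], [7, 0]]])
metrics = shape_metrics(original, predicted, air_tokens=[0])
print(metrics)
```

In this toy case one block is erased and one is added elsewhere, so the volume ratio is 1.0 while recall is only 75%. This is why structure recall, not volume ratio, is used to select the best model.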

In [None]:
# ============================================================
# CELL 9: Plot Training Curves - Shape Preservation Focus
# ============================================================

fig, axes = plt.subplots(2, 4, figsize=(20, 8))

# Use the recorded history length so the plots still work if training stopped early
epochs = range(1, len(history['train_loss']) + 1)

# Plot 1: STRUCTURE RECALL (THE KEY METRIC)
ax = axes[0, 0]
ax.plot(epochs, history['train_struct_recall'], 'b-', label='Train', linewidth=2)
ax.plot(epochs, history['val_struct_recall'], 'r--', label='Val', linewidth=2)
ax.axvline(x=WARMUP_EPOCHS, color='gray', linestyle=':', label='Warmup end')
ax.axhline(y=0.9, color='green', linestyle='--', alpha=0.5, label='Target (90%)')
ax.set_title('Structure Recall (KEY METRIC)', fontweight='bold')
ax.set_xlabel('Epoch')
ax.set_ylabel('Recall')
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_ylim(0, 1)

# Plot 2: FALSE AIR RATE (lower is better)
ax = axes[0, 1]
ax.plot(epochs, history['train_false_air'], 'b-', label='Train')
ax.plot(epochs, history['val_false_air'], 'r--', label='Val')
ax.axvline(x=WARMUP_EPOCHS, color='gray', linestyle=':')
ax.axhline(y=0.1, color='green', linestyle='--', alpha=0.5, label='Target (<10%)')
ax.set_title('False Air Rate (lower=better)')
ax.set_xlabel('Epoch')
ax.set_ylabel('Rate')
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_ylim(0, 1)

# Plot 3: Volume Ratio (want close to 1.0)
ax = axes[0, 2]
ax.plot(epochs, history['train_vol_ratio'], 'b-', label='Train')
ax.plot(epochs, history['val_vol_ratio'], 'r--', label='Val')
ax.axvline(x=WARMUP_EPOCHS, color='gray', linestyle=':')
ax.axhline(y=1.0, color='green', linestyle='--', alpha=0.5, label='Target (1.0)')
ax.set_title('Volume Ratio (1.0=perfect)')
ax.set_xlabel('Epoch')
ax.set_ylabel('Ratio')
ax.legend()
ax.grid(True, alpha=0.3)

# Plot 4: Structure accuracy (exact vs similarity)
ax = axes[0, 3]
ax.plot(epochs, history['train_struct_exact'], 'b-', label='Train Exact')
ax.plot(epochs, history['val_struct_exact'], 'b--', label='Val Exact')
ax.plot(epochs, history['train_struct_sim'], 'g-', label='Train Sim')
ax.plot(epochs, history['val_struct_sim'], 'g--', label='Val Sim')
ax.axvline(x=WARMUP_EPOCHS, color='gray', linestyle=':')
ax.set_title('Structure Accuracy')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.legend()
ax.grid(True, alpha=0.3)

# Plot 5: Loss
ax = axes[1, 0]
ax.plot(epochs, history['train_loss'], 'b-', label='Train')
ax.plot(epochs, history['val_loss'], 'r--', label='Val')
ax.axvline(x=WARMUP_EPOCHS, color='gray', linestyle=':')
ax.set_title('Total Loss')
ax.set_xlabel('Epoch')
ax.legend()
ax.grid(True, alpha=0.3)

# Plot 6: Codebook usage
ax = axes[1, 1]
ax.plot(epochs, history['train_cb_usage'], 'b-', label='Train')
ax.plot(epochs, history['val_cb_usage'], 'r--', label='Val')
ax.axhline(y=0.3, color='gray', linestyle='--', alpha=0.5)
ax.axvline(x=WARMUP_EPOCHS, color='gray', linestyle=':')
ax.set_title('Codebook Usage')
ax.set_xlabel('Epoch')
ax.legend()
ax.grid(True, alpha=0.3)

# Plot 7: Perplexity
ax = axes[1, 2]
ax.plot(epochs, history['train_perplexity'], 'b-', label='Train')
ax.plot(epochs, history['val_perplexity'], 'r--', label='Val')
ax.axvline(x=WARMUP_EPOCHS, color='gray', linestyle=':')
ax.set_title('Codebook Perplexity')
ax.set_xlabel('Epoch')
ax.legend()
ax.grid(True, alpha=0.3)

# Plot 8: Final comparison bar - SHAPE PRESERVATION FOCUS
ax = axes[1, 3]
final_metrics = {
    'Struct\nRecall': history['val_struct_recall'][-1],
    '1-False\nAir': 1 - history['val_false_air'][-1],
    'Exact\nAcc': history['val_struct_exact'][-1],
    'Sim\nAcc': history['val_struct_sim'][-1],
}
colors = ['green', 'orange', 'blue', 'purple']
bars = ax.bar(final_metrics.keys(), final_metrics.values(), color=colors)
ax.set_title('Final Metrics (Shape Focus)')
ax.set_ylabel('Score')
ax.set_ylim(0, 1)
for bar, val in zip(bars, final_metrics.values()):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
            f'{val:.1%}', ha='center', fontsize=9)

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/vqvae_v4_training.png", dpi=150, bbox_inches='tight')
plt.show()

# Summary
print("\n" + "="*60)
print("SHAPE PRESERVATION SUMMARY")
print("="*60)
print(f"Structure Recall: {history['val_struct_recall'][-1]:.1%} (target: >90%)")
print(f"False Air Rate:   {history['val_false_air'][-1]:.1%} (target: <10%)")
print(f"Volume Ratio:     {history['val_vol_ratio'][-1]:.2f} (target: ~1.0)")
print(f"Exact Accuracy:   {history['val_struct_exact'][-1]:.1%}")
print(f"Sim Accuracy:     {history['val_struct_sim'][-1]:.1%}")

In [None]:
# ============================================================
# CELL 10: Save Results and Checkpoint
# ============================================================

results = {
    'config': {
        'hidden_dims': HIDDEN_DIMS,
        'latent_dim': LATENT_DIM,
        'num_codes': NUM_CODEBOOK_ENTRIES,
        'total_epochs': TOTAL_EPOCHS,
        'warmup_epochs': WARMUP_EPOCHS,
        'batch_size': BATCH_SIZE,
        'base_lr': BASE_LR,
        'embedding_lr_scale': EMBEDDING_LR_SCALE,
        'embedding_loss_alpha': EMBEDDING_LOSS_ALPHA,
        'stability_weight': STABILITY_WEIGHT,
        'diversity_weight': DIVERSITY_WEIGHT,
        'structure_weight': STRUCTURE_WEIGHT,
        # NEW shape preservation config
        'false_air_weight': FALSE_AIR_WEIGHT,
        'volume_weight': VOLUME_WEIGHT,
        'structure_to_air_weight': STRUCTURE_TO_AIR_WEIGHT,
        'use_shape_loss': USE_SHAPE_LOSS,
        'use_asymmetric_loss': USE_ASYMMETRIC_LOSS,
        'seed': SEED,
    },
    'results': {
        # Shape preservation metrics (THE KEY METRICS)
        'best_struct_recall': float(best_struct_recall),
        'best_epoch': best_epoch,
        'final_struct_recall': float(history['val_struct_recall'][-1]),
        'final_false_air_rate': float(history['val_false_air'][-1]),
        'final_vol_ratio': float(history['val_vol_ratio'][-1]),
        # Original metrics
        'final_struct_exact': float(history['val_struct_exact'][-1]),
        'final_struct_sim': float(history['val_struct_sim'][-1]),
        'final_cb_usage': float(history['val_cb_usage'][-1]),
        'final_perplexity': float(history['val_perplexity'][-1]),
        'training_time_min': float(train_time / 60),
    },
    # 'phase' holds strings; every other entry is numeric (possibly NumPy or tensor scalars)
    'history': {k: [x if isinstance(x, str) else float(x) for x in v]
                for k, v in history.items()},
}

with open(f"{OUTPUT_DIR}/vqvae_v4_results.json", 'w') as f:
    json.dump(results, f, indent=2)

# Save complete checkpoint with all metadata for visualization script
checkpoint = {
    'model_state_dict': model.state_dict(),
    'config': {
        'vocab_size': VOCAB_SIZE,
        'emb_dim': EMBEDDING_DIM,
        'hidden_dims': HIDDEN_DIMS,
        'latent_dim': LATENT_DIM,
        'num_codes': NUM_CODEBOOK_ENTRIES,
        'embedding_loss_alpha': EMBEDDING_LOSS_ALPHA,
        'stability_weight': STABILITY_WEIGHT,
        'diversity_weight': DIVERSITY_WEIGHT,
        'false_air_weight': FALSE_AIR_WEIGHT,
        'volume_weight': VOLUME_WEIGHT,
        'structure_to_air_weight': STRUCTURE_TO_AIR_WEIGHT,
        'dropout': DROPOUT,
        'commitment_cost': COMMITMENT_COST,
        'ema_decay': EMA_DECAY,
    },
    'air_tokens': AIR_TOKENS_LIST,
    'best_struct_recall': float(best_struct_recall),
    'best_epoch': best_epoch,
    'training_time_min': float(train_time / 60),
}

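# Save best and final checkpoints.
# `model` still holds the final-epoch weights here, so load the best weights
# saved during training; otherwise both files would contain the same state.
best_path = f"{OUTPUT_DIR}/vqvae_v4_best.pt"
if os.path.exists(best_path):
    checkpoint['model_state_dict'] = torch.load(best_path, map_location='cpu')
torch.save(checkpoint, f"{OUTPUT_DIR}/vqvae_v4_best_checkpoint.pt")

checkpoint['model_state_dict'] = model.state_dict()  # Final-epoch weights
torch.save(checkpoint, f"{OUTPUT_DIR}/vqvae_v4_final_checkpoint.pt")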

# Also save just the state dict for backwards compatibility
torch.save(model.state_dict(), f"{OUTPUT_DIR}/vqvae_v4_final.pt")

print("\nResults saved to Google Drive:")
print(f"  - {OUTPUT_DIR}/vqvae_v4_results.json")
print(f"  - {OUTPUT_DIR}/vqvae_v4_best_checkpoint.pt (full checkpoint)")
print(f"  - {OUTPUT_DIR}/vqvae_v4_final_checkpoint.pt (full checkpoint)")
print(f"  - {OUTPUT_DIR}/vqvae_v4_final.pt (state dict only)")
print(f"  - {OUTPUT_DIR}/vqvae_v4_training.png")

print("\n" + "="*70)
print("FINAL RESULTS - SHAPE PRESERVATION")
print("="*70)
print(f"Best structure recall:     {best_struct_recall:.1%} (epoch {best_epoch})")
print(f"Final structure recall:    {history['val_struct_recall'][-1]:.1%}")
print(f"Final false air rate:      {history['val_false_air'][-1]:.1%}")
print(f"Final volume ratio:        {history['val_vol_ratio'][-1]:.2f}")
print(f"Final structure exact acc: {history['val_struct_exact'][-1]:.1%}")
print(f"Final codebook usage:      {history['val_cb_usage'][-1]:.1%}")
print(f"Training time:             {train_time/60:.1f} minutes")

print("\n" + "="*70)
print("INTERPRETATION")
print("="*70)
print("Structure Recall: % of original building blocks NOT erased")
print("  - Target: >90% (buildings should be preserved)")
print("  - If low: model is still erasing buildings -> increase shape loss weights")
print()
print("False Air Rate: % of structure blocks wrongly predicted as air")
print("  - Target: <10% (minimal building erasure)")
print("  - This is 1 - Structure Recall")
print()
print("Volume Ratio: predicted_volume / original_volume")
print("  - Target: ~1.0 (same amount of blocks)")
print("  - <1.0 means buildings shrunk, >1.0 means extra blocks added")

print("\n--- DOWNLOAD INSTRUCTIONS ---")
print("The checkpoint files are saved to your Google Drive at:")
print(f"  {OUTPUT_DIR}/")
print("\nDownload vqvae_v4_best_checkpoint.pt and use it with the visualization script:")
print("  python scripts/visualize_reconstruction_mcp.py \\")
print("      --checkpoint vqvae_v4_best_checkpoint.pt \\")
print("      --h5-file path/to/build.h5 \\")
print("      --output commands.txt")
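
### Sketch: checking a run against the targets

The interpretation printed above can also be applied programmatically to `vqvae_v4_results.json`. The sketch below assumes the `results` layout written by this cell. The helper name `check_targets` and the volume tolerance `vol_tol` are assumptions, since the notebook only states a target of "~1.0".

```python
def check_targets(results, recall_min=0.90, false_air_max=0.10, vol_tol=0.10):
    """Hypothetical helper: compare final validation metrics with the stated targets."""
    r = results['results']
    return {
        'struct_recall': r['final_struct_recall'] > recall_min,
        'false_air_rate': r['final_false_air_rate'] < false_air_max,
        # "~1.0" is read here as within +/- vol_tol; the tolerance is an assumption
        'vol_ratio': abs(r['final_vol_ratio'] - 1.0) <= vol_tol,
    }

# Made-up numbers for illustration only, not results from an actual run
example = {'results': {'final_struct_recall': 0.93,
                       'final_false_air_rate': 0.07,
                       'final_vol_ratio': 1.25}}
status = check_targets(example)
for name, ok in status.items():
    print(f"{name:<15} {'OK' if ok else 'MISSED'}")
```

For a real run, replace `example` with the dictionary returned by `json.load` on the saved results file.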

In [None]:
# ============================================================
# CELL 11: Reconstruction Visualizations - DIAGNOSTIC PLOTS
# ============================================================
# These visualizations help diagnose issues that metrics alone can miss,
# like the building deletion problem we discovered in Phase 2.5.

print("="*70)
print("DIAGNOSTIC VISUALIZATIONS")
print("="*70)
print("These plots help catch issues that metrics alone might miss:")
print("  1. Reconstruction slices - see if buildings are visually preserved")
print("  2. Block substitution matrix - what's replacing what")
print("  3. Error heatmaps - where are mistakes concentrated")
print("  4. Category accuracy - which block types fail completely")
print()

# Load best model for visualization
model.load_state_dict(torch.load(f"{OUTPUT_DIR}/vqvae_v4_best.pt", map_location=device))
model.eval()

# Get a batch for visualization
vis_samples = []
for i, sample in enumerate(val_dataset):
    if i >= 8:  # Get 8 samples for visualization
        break
    vis_samples.append(sample)

vis_batch = torch.stack(vis_samples).to(device)

# Get reconstructions
with torch.no_grad():
    out = model(vis_batch)
    logits = out['logits']
    preds = logits.argmax(dim=1)

print(f"Loaded {len(vis_samples)} samples for visualization")
print(f"Original shape: {vis_batch.shape}")
print(f"Predictions shape: {preds.shape}")

In [None]:
# ============================================================
# CELL 12: Slice-by-Slice Reconstruction Comparison
# ============================================================
# This is the MOST IMPORTANT diagnostic - visually compare original vs reconstruction
# If buildings are being erased, you'll SEE it here even if metrics look okay

def plot_reconstruction_slices(original, reconstructed, air_tokens, sample_idx=0, 
                               slices=(8, 12, 16, 20, 24), save_path=None):
    """
    Plot horizontal slices showing:
    - Original structure (colored by block)
    - Reconstructed structure
    - Difference (red=erased, blue=added)
    """
    orig = original.cpu().numpy()
    recon = reconstructed.cpu().numpy()
    
    air_set = set(air_tokens.tolist()) if hasattr(air_tokens, 'tolist') else set(air_tokens)
    
    # Create structure masks
    orig_struct = ~np.isin(orig, list(air_set))
    recon_struct = ~np.isin(recon, list(air_set))
    
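    # Keep only in-range slices so no subplot row is left blank, and use
    # squeeze=False so `axes` stays 2-D even when a single slice is requested
    slices = [y for y in slices if 0 <= y < orig.shape[0]]
    if not slices:
        print(f"Sample {sample_idx}: no requested slice is within range, nothing to plot")
        return
    fig, axes = plt.subplots(len(slices), 4, figsize=(16, 4*len(slices)), squeeze=False)
    
    for i, y in enumerate(slices):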
            
        # Original structure slice
        ax = axes[i, 0]
        orig_slice = orig[y]
        ax.imshow(orig_slice, cmap='tab20', interpolation='nearest')
        ax.set_title(f'Original (y={y})')
        ax.set_ylabel(f'Slice {y}')
        ax.set_xticks([])
        ax.set_yticks([])
        
        # Reconstructed slice
        ax = axes[i, 1]
        recon_slice = recon[y]
        ax.imshow(recon_slice, cmap='tab20', interpolation='nearest')
        ax.set_title(f'Reconstructed (y={y})')
        ax.set_xticks([])
        ax.set_yticks([])
        
        # Structure mask comparison (binary)
        ax = axes[i, 2]
        orig_mask = orig_struct[y].astype(float)
        recon_mask = recon_struct[y].astype(float)
        # Stack: original in green, recon in blue
        comparison = np.stack([np.zeros_like(orig_mask), orig_mask, recon_mask], axis=-1)
        ax.imshow(comparison)
        ax.set_title('Structure: Green=Orig, Blue=Recon')
        ax.set_xticks([])
        ax.set_yticks([])
        
        # Difference map
        ax = axes[i, 3]
        diff = np.zeros((*orig_mask.shape, 3))
        # Red: erased (was structure, now air) - THIS IS THE PROBLEM WE'RE LOOKING FOR
        erased = orig_struct[y] & ~recon_struct[y]
        diff[erased] = [1, 0, 0]  # Red
        # Blue: added (was air, now structure)
        added = ~orig_struct[y] & recon_struct[y]
        diff[added] = [0, 0, 1]  # Blue
        # Green: correct structure
        correct = orig_struct[y] & recon_struct[y]
        diff[correct] = [0, 0.5, 0]  # Dark green
        ax.imshow(diff)
        erased_count = erased.sum()
        added_count = added.sum()
        ax.set_title(f'Diff: RED=erased({erased_count}), BLUE=added({added_count})')
        ax.set_xticks([])
        ax.set_yticks([])
    
    plt.suptitle(f'Sample {sample_idx}: Reconstruction Analysis\n'
                 f'RED cells = BUILDING DELETED (structure->air) - This is what we want to minimize!', 
                 fontsize=12, fontweight='bold')
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
    
    # Summary stats
    total_erased = (orig_struct & ~recon_struct).sum()
    total_added = (~orig_struct & recon_struct).sum()
    total_orig = orig_struct.sum()
    total_recon = recon_struct.sum()
    
    print(f"\nSample {sample_idx} Structure Analysis:")
    print(f"  Original structure blocks: {total_orig}")
    print(f"  Reconstructed structure blocks: {total_recon}")
    print(f"  ERASED (structure->air): {total_erased} ({100*total_erased/max(total_orig,1):.1f}%)")
    print(f"  Added (air->structure): {total_added}")
    print(f"  Volume ratio: {total_recon/max(total_orig,1):.2f}")

# Plot first 4 samples
for idx in range(min(4, len(vis_samples))):
    plot_reconstruction_slices(
        vis_batch[idx], preds[idx], AIR_TOKENS_TENSOR,
        sample_idx=idx,
        slices=[4, 8, 12, 16, 20, 24, 28],
        save_path=f"{OUTPUT_DIR}/reconstruction_sample_{idx}.png"
    )

In [None]:
# ============================================================
# CELL 13: Block Substitution Analysis
# ============================================================
# This shows WHAT blocks are being replaced by WHAT
# Key insight: If everything is being replaced by air, that's the erasure problem
# If oak_planks -> spruce_planks, that's less concerning (similar blocks)

def analyze_block_substitutions(original, reconstructed, tok2block, air_tokens, top_k=20):
    """
    Analyze what blocks are being substituted for what.
    Returns the most common (original -> predicted) pairs.
    """
    orig = original.cpu().numpy().flatten()
    recon = reconstructed.cpu().numpy().flatten()
    
    air_set = set(air_tokens.tolist()) if hasattr(air_tokens, 'tolist') else set(air_tokens)
    
    # Count substitutions
    substitutions = Counter()
    struct_to_air = Counter()  # THE BAD ONES
    air_to_struct = Counter()
    struct_to_struct = Counter()
    
    for o, r in zip(orig, recon):
        if o != r:  # Only count errors
            orig_is_struct = o not in air_set
            recon_is_air = r in air_set
            recon_is_struct = not recon_is_air
            
            orig_name = tok2block.get(o, f"tok_{o}")[:25]
            recon_name = tok2block.get(r, f"tok_{r}")[:25]
            
            if orig_is_struct and recon_is_air:
                # STRUCTURE ERASED - this is the problem we're tracking
                struct_to_air[(orig_name, recon_name)] += 1
            elif not orig_is_struct and recon_is_struct:
                # Air -> structure (adding blocks)
                air_to_struct[(orig_name, recon_name)] += 1
            elif orig_is_struct and recon_is_struct:
                # Structure -> different structure (material confusion)
                struct_to_struct[(orig_name, recon_name)] += 1
                
            substitutions[(orig_name, recon_name)] += 1
    
    return {
        'all': substitutions.most_common(top_k),
        'struct_to_air': struct_to_air.most_common(top_k),
        'air_to_struct': air_to_struct.most_common(top_k),
        'struct_to_struct': struct_to_struct.most_common(top_k),
    }

# Analyze all validation samples
all_subs = {'all': Counter(), 'struct_to_air': Counter(), 
            'air_to_struct': Counter(), 'struct_to_struct': Counter()}

print("Analyzing block substitutions across validation set...")
for sample in tqdm(val_dataset, desc="Analyzing"):
    sample = sample.unsqueeze(0).to(device)
    with torch.no_grad():
        out = model(sample)
        pred = out['logits'].argmax(dim=1)
    
    # top_k=None keeps every pair; truncating per sample would undercount the totals
    subs = analyze_block_substitutions(sample[0], pred[0], tok2block, AIR_TOKENS_TENSOR, top_k=None)
    for key in all_subs:
        all_subs[key].update(dict(subs[key]))

# Plot top substitutions
fig, axes = plt.subplots(2, 2, figsize=(18, 14))

# 1. STRUCTURE -> AIR (THE PROBLEM WE'RE WATCHING FOR)
ax = axes[0, 0]
s2a = all_subs['struct_to_air'].most_common(15)
if s2a:
    labels = [f"{o}->{r}" for (o,r), _ in s2a]
    counts = [c for _, c in s2a]
    bars = ax.barh(range(len(labels)), counts, color='red', alpha=0.7)
    ax.set_yticks(range(len(labels)))
    ax.set_yticklabels(labels, fontsize=8)
    ax.invert_yaxis()
    ax.set_xlabel('Count')
    ax.set_title('STRUCTURE -> AIR (Buildings Erased)\nTHIS IS THE PROBLEM - should be minimal!', 
                 fontweight='bold', color='red')
else:
    ax.text(0.5, 0.5, 'No structure->air errors!\nGreat!', ha='center', va='center', fontsize=14)
    ax.set_title('STRUCTURE -> AIR', fontweight='bold', color='green')

# 2. AIR -> STRUCTURE
ax = axes[0, 1]
a2s = all_subs['air_to_struct'].most_common(15)
if a2s:
    labels = [f"{o}->{r}" for (o,r), _ in a2s]
    counts = [c for _, c in a2s]
    ax.barh(range(len(labels)), counts, color='blue', alpha=0.7)
    ax.set_yticks(range(len(labels)))
    ax.set_yticklabels(labels, fontsize=8)
    ax.invert_yaxis()
    ax.set_xlabel('Count')
    ax.set_title('AIR -> STRUCTURE (Extra Blocks Added)\nLess concerning - adds blocks instead of removing', 
                 color='blue')
else:
    ax.text(0.5, 0.5, 'No air->structure errors!', ha='center', va='center', fontsize=14)
    ax.set_title('AIR -> STRUCTURE', color='blue')

# 3. STRUCTURE -> STRUCTURE (material confusion)
ax = axes[1, 0]
s2s = all_subs['struct_to_struct'].most_common(15)
if s2s:
    labels = [f"{o}->{r}" for (o,r), _ in s2s]
    counts = [c for _, c in s2s]
    ax.barh(range(len(labels)), counts, color='orange', alpha=0.7)
    ax.set_yticks(range(len(labels)))
    ax.set_yticklabels(labels, fontsize=8)
    ax.invert_yaxis()
    ax.set_xlabel('Count')
    ax.set_title('STRUCTURE -> STRUCTURE (Material Confusion)\nAcceptable if similar blocks (oak->spruce)', 
                 color='orange')
else:
    ax.text(0.5, 0.5, 'No struct->struct errors!', ha='center', va='center', fontsize=14)
    ax.set_title('STRUCTURE -> STRUCTURE', color='orange')

# 4. Summary pie chart
ax = axes[1, 1]
total_s2a = sum(all_subs['struct_to_air'].values())
total_a2s = sum(all_subs['air_to_struct'].values())
total_s2s = sum(all_subs['struct_to_struct'].values())

sizes = [total_s2a, total_a2s, total_s2s]
labels = [f'Struct->Air\n(ERASURE)\n{total_s2a:,}', 
          f'Air->Struct\n{total_a2s:,}', 
          f'Struct->Struct\n{total_s2s:,}']
colors = ['red', 'blue', 'orange']
explode = (0.1, 0, 0)  # Emphasize the bad one

if sum(sizes) > 0:
    ax.pie(sizes, labels=labels, colors=colors, explode=explode, 
           autopct='%1.1f%%', startangle=90)
    ax.set_title('Error Type Distribution')
else:
    ax.text(0.5, 0.5, 'No errors!', ha='center', va='center', fontsize=14)

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/block_substitutions.png", dpi=150, bbox_inches='tight')
plt.show()

print("\n" + "="*60)
print("BLOCK SUBSTITUTION SUMMARY")
print("="*60)
print(f"Structure -> Air (ERASURE): {total_s2a:,} errors")
print(f"Air -> Structure (adding):  {total_a2s:,} errors")
print(f"Struct -> Struct (confusion): {total_s2s:,} errors")
print()
if total_s2a > total_s2s:
    print("WARNING: More erasure than material confusion!")
    print("The model is DELETING buildings more than getting materials wrong.")
else:
    print("Good: Material confusion > erasure. Shape is being preserved.")

In [None]:
# ============================================================
# CELL 14: Per-Category Accuracy Breakdown
# ============================================================
# In Phase 2.5, we saw 9/16 categories at 0% accuracy!
# This helps identify which block types the model can't predict at all

# Define block categories based on naming patterns
BLOCK_CATEGORIES = {
    'wood_planks': lambda b: 'planks' in b and 'slab' not in b and 'stair' not in b,
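    'logs': lambda b: 'log' in b or ('wood' in b and 'planks' not in b),
    # 'stone' is a substring of 'redstone' and 'glowstone'; exclude them so they
    # can reach the 'redstone' and 'lighting' categories below (first match wins)
    'stone': lambda b: any(s in b for s in ['stone', 'cobblestone', 'granite', 'diorite', 'andesite']) 
                       and 'slab' not in b and 'stair' not in b and 'wall' not in b
                       and 'redstone' not in b and 'glowstone' not in b,
    'stairs': lambda b: 'stair' in b,
    'slabs': lambda b: 'slab' in b,
    # wall_torch / wall_sign variants are not wall blocks
    'walls': lambda b: 'wall' in b and not any(x in b for x in ['banner', 'torch', 'sign']),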
    'fences': lambda b: 'fence' in b,
    'doors': lambda b: 'door' in b,
    'glass': lambda b: 'glass' in b,
    'wool': lambda b: 'wool' in b,
    'concrete': lambda b: 'concrete' in b,
    'terracotta': lambda b: 'terracotta' in b,
    'bricks': lambda b: 'brick' in b and 'slab' not in b and 'stair' not in b and 'wall' not in b,
    'ores': lambda b: 'ore' in b,
    'leaves': lambda b: 'leaves' in b or 'leaf' in b,
    'flowers': lambda b: any(f in b for f in ['dandelion', 'poppy', 'tulip', 'daisy', 'orchid', 'allium', 'rose']),
    'crops': lambda b: any(c in b for c in ['wheat', 'carrot', 'potato', 'beetroot', 'melon', 'pumpkin']),
    'redstone': lambda b: any(r in b for r in ['redstone', 'repeater', 'comparator', 'piston', 'observer', 'dispenser']),
    'lighting': lambda b: any(l in b for l in ['torch', 'lantern', 'lamp', 'glowstone', 'sea_lantern', 'shroomlight']),
    'water': lambda b: 'water' in b,
    'lava': lambda b: 'lava' in b,
}

def categorize_block(block_name):
    """Assign a block to a category."""
    block_lower = block_name.lower()
    for cat, check_fn in BLOCK_CATEGORIES.items():
        if check_fn(block_lower):
            return cat
    return 'other'

# Build token -> category mapping
tok2cat = {}
for tok, block in tok2block.items():
    # int() so lookups by integer token id work even if the JSON keys are strings
    tok2cat[int(tok)] = categorize_block(block)

# Compute per-category accuracy
cat_correct = Counter()
cat_total = Counter()
cat_recall = Counter()  # Structure recall per category
cat_struct_total = Counter()

print("Computing per-category accuracy...")
for sample in tqdm(val_dataset, desc="Categories"):
    sample_t = sample.unsqueeze(0).to(device)
    with torch.no_grad():
        out = model(sample_t)
        pred = out['logits'].argmax(dim=1)[0]
    
    orig = sample.cpu().numpy().flatten()
    recon = pred.cpu().numpy().flatten()
    air_set = set(AIR_TOKENS_TENSOR.tolist())
    
    for o, r in zip(orig, recon):
        cat = tok2cat.get(o, 'other')
        is_struct = o not in air_set
        
        if is_struct:
            cat_struct_total[cat] += 1
            if r not in air_set:  # Predicted as structure (not erased)
                cat_recall[cat] += 1
        
        cat_total[cat] += 1
        if o == r:
            cat_correct[cat] += 1

# Compute accuracies
categories = sorted(set(cat_total.keys()) - {'other', 'air'})
exact_accs = []
recalls = []

for cat in categories:
    total = cat_total.get(cat, 0)
    correct = cat_correct.get(cat, 0)
    struct_total = cat_struct_total.get(cat, 0)
    recall = cat_recall.get(cat, 0)
    
    exact_acc = correct / total if total > 0 else 0
    struct_recall = recall / struct_total if struct_total > 0 else 0
    
    exact_accs.append(exact_acc)
    recalls.append(struct_recall)

# Plot
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Exact accuracy per category
ax = axes[0]
colors = ['red' if acc == 0 else 'orange' if acc < 0.3 else 'green' for acc in exact_accs]
bars = ax.barh(categories, exact_accs, color=colors, alpha=0.7)
ax.axvline(x=0.5, color='gray', linestyle='--', alpha=0.5, label='50%')
ax.set_xlabel('Exact Accuracy')
ax.set_title('Per-Category Exact Accuracy\nRED = 0% (never predicted correctly)')
ax.set_xlim(0, 1)
ax.legend(loc='lower right')  # show the label of the dashed reference line
for bar, acc in zip(bars, exact_accs):
    ax.text(bar.get_width() + 0.02, bar.get_y() + bar.get_height()/2, 
            f'{acc:.1%}', va='center', fontsize=8)

# Structure recall per category (shape preservation)
ax = axes[1]
colors = ['red' if r < 0.5 else 'orange' if r < 0.8 else 'green' for r in recalls]
bars = ax.barh(categories, recalls, color=colors, alpha=0.7)
ax.axvline(x=0.9, color='gray', linestyle='--', alpha=0.5, label='90% target')
ax.set_xlabel('Structure Recall (not erased)')
ax.set_title('Per-Category Structure Recall\nRED = blocks being erased')
ax.set_xlim(0, 1)
ax.legend(loc='lower right')  # show the label of the dashed reference line
for bar, r in zip(bars, recalls):
    ax.text(bar.get_width() + 0.02, bar.get_y() + bar.get_height()/2, 
            f'{r:.1%}', va='center', fontsize=8)

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/category_accuracy.png", dpi=150, bbox_inches='tight')
plt.show()

# Print summary
print("\n" + "="*60)
print("PER-CATEGORY BREAKDOWN")
print("="*60)
print(f"{'Category':<15} {'Count':>8} {'Exact':>8} {'Recall':>8}")
print("-"*45)

zero_acc_cats = []
low_recall_cats = []

for cat, acc, rec in zip(categories, exact_accs, recalls):
    count = cat_total.get(cat, 0)
    print(f"{cat:<15} {count:>8,} {acc:>7.1%} {rec:>7.1%}")
    if acc == 0:
        zero_acc_cats.append(cat)
    if rec < 0.5:
        low_recall_cats.append(cat)

print("-"*45)
print(f"\nCategories with 0% exact accuracy: {len(zero_acc_cats)}")
if zero_acc_cats:
    print(f"  {', '.join(zero_acc_cats)}")
print(f"\nCategories with <50% recall (being erased): {len(low_recall_cats)}")
if low_recall_cats:
    print(f"  {', '.join(low_recall_cats)}")
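
### Sketch: per-category recall with `np.bincount`

The per-category tally above can also be computed without a per-voxel loop by turning the token-to-category mapping into a lookup array. The sketch below assumes integer token ids that index directly into that array. The names `per_category_recall` and `tok_to_cat_id` are hypothetical, and categories with no structure voxels are reported as NaN rather than 0 so they are not mistaken for erased categories.

```python
import numpy as np

def per_category_recall(orig, recon, tok_to_cat_id, air_tokens, n_cats):
    """Hypothetical helper: structure recall per category for one sample."""
    orig = np.asarray(orig).ravel()
    recon = np.asarray(recon).ravel()
    is_struct = ~np.isin(orig, air_tokens)
    kept = is_struct & ~np.isin(recon, air_tokens)   # structure not predicted as air
    cats = tok_to_cat_id[orig]                        # category id of each original voxel
    total = np.bincount(cats[is_struct], minlength=n_cats)
    hits = np.bincount(cats[kept], minlength=n_cats)
    recall = np.where(total > 0, hits / np.maximum(total, 1), np.nan)
    return total, recall

# Made-up vocabulary: token 0 = air (category 0), tokens 1-2 = category 1, token 3 = category 2
tok_to_cat_id = np.array([0, 1, 1, 2])
total, recall = per_category_recall(
    orig=[1, 2, 3, 3, 0], recon=[2, 0, 3, 0, 0],
    tok_to_cat_id=tok_to_cat_id, air_tokens=[0], n_cats=3)
print(total, recall)
```

Summing `total` and `hits` across samples before dividing gives dataset-level recall comparable to the loop above.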

In [None]:
# ============================================================
# CELL 15: Error Location Heatmaps
# ============================================================
# Where in the structure are errors concentrated?
# - Center vs edges?
# - Top vs bottom?
# - Are corners being erased?

def compute_error_heatmaps(original, reconstructed, air_tokens):
    """Compute spatial heatmaps of different error types."""
    orig = original.cpu().numpy()
    recon = reconstructed.cpu().numpy()
    air_set = set(air_tokens.tolist()) if hasattr(air_tokens, 'tolist') else set(air_tokens)
    
    orig_struct = ~np.isin(orig, list(air_set))
    recon_struct = ~np.isin(recon, list(air_set))
    
    # Error masks
    erasure = orig_struct & ~recon_struct  # Structure deleted
    addition = ~orig_struct & recon_struct  # Extra blocks
    confusion = orig_struct & recon_struct & (orig != recon)  # Wrong material
    correct = orig == recon
    
    return {
        'erasure': erasure.astype(float),
        'addition': addition.astype(float),
        'confusion': confusion.astype(float),
        'correct': correct.astype(float),
        'orig_struct': orig_struct.astype(float),
    }

# Aggregate error heatmaps across all validation samples
print("Computing spatial error distribution...")
agg_erasure = np.zeros((32, 32, 32))
agg_addition = np.zeros((32, 32, 32))
agg_confusion = np.zeros((32, 32, 32))
agg_struct = np.zeros((32, 32, 32))
n_samples = 0

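model.eval()  # inference mode for dropout/batch-norm layers, in case an earlier cell left the model in train mode
for sample in tqdm(val_dataset, desc="Error locations"):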
    sample_t = sample.unsqueeze(0).to(device)
    with torch.no_grad():
        out = model(sample_t)
        pred = out['logits'].argmax(dim=1)[0]
    
    heatmaps = compute_error_heatmaps(sample, pred, AIR_TOKENS_TENSOR)
    agg_erasure += heatmaps['erasure']
    agg_addition += heatmaps['addition']
    agg_confusion += heatmaps['confusion']
    agg_struct += heatmaps['orig_struct']
    n_samples += 1

# Normalize by number of samples
agg_erasure /= n_samples
agg_addition /= n_samples
agg_confusion /= n_samples
agg_struct /= n_samples

# Plot XZ projection (top-down view) and XY projection (side view)
fig, axes = plt.subplots(2, 4, figsize=(20, 10))

# Top-down views (sum over Y axis)
for idx, (name, data, cmap, title) in enumerate([
    ('erasure', agg_erasure, 'Reds', 'ERASURE (structure->air)'),
    ('addition', agg_addition, 'Blues', 'Addition (air->structure)'),
    ('confusion', agg_confusion, 'Oranges', 'Material Confusion'),
    ('struct', agg_struct, 'Greens', 'Original Structure Density'),
]):
    ax = axes[0, idx]
    proj = data.sum(axis=0)  # Sum over Y (height)
    im = ax.imshow(proj, cmap=cmap, interpolation='nearest')
    ax.set_title(f'{title}\n(Top-down projection)')
    ax.set_xlabel('X')
    ax.set_ylabel('Z')
    plt.colorbar(im, ax=ax, fraction=0.046)

# Side views (sum over Z axis).
# Axis 0 is treated as Y (height) in the top-down view above, which also labels
# axis 1 as Z and axis 2 as X, so collapse axis 1 here and keep Y on the rows.
for idx, (name, data, cmap, title) in enumerate([
    ('erasure', agg_erasure, 'Reds', 'ERASURE'),
    ('addition', agg_addition, 'Blues', 'Addition'),
    ('confusion', agg_confusion, 'Oranges', 'Confusion'),
    ('struct', agg_struct, 'Greens', 'Structure'),
]):
    ax = axes[1, idx]
    proj = data.sum(axis=1)  # Sum over Z -> shape (Y, X)
    im = ax.imshow(proj, cmap=cmap, origin='lower', interpolation='nearest')
    ax.set_title(f'{title}\n(Side projection)')
    ax.set_xlabel('X')
    ax.set_ylabel('Y (height)')
    plt.colorbar(im, ax=ax, fraction=0.046)

plt.suptitle('Spatial Error Distribution\nRed areas = where buildings are being erased', 
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/error_heatmaps.png", dpi=150, bbox_inches='tight')
plt.show()

# Analyze error distribution by region
print("\n" + "="*60)
print("SPATIAL ERROR ANALYSIS")
print("="*60)

# Compute errors in different regions
center = slice(10, 22)
edge_x = [slice(0, 8), slice(24, 32)]
edge_z = [slice(0, 8), slice(24, 32)]

center_erasure = agg_erasure[:, center, center].mean()
edge_erasure = np.mean([agg_erasure[:, e, :].mean() for e in edge_x] + 
                        [agg_erasure[:, :, e].mean() for e in edge_z])

center_struct = agg_struct[:, center, center].mean()
edge_struct = np.mean([agg_struct[:, e, :].mean() for e in edge_x] + 
                       [agg_struct[:, :, e].mean() for e in edge_z])

print(f"Erasure rate in CENTER: {center_erasure:.4f}")
print(f"Erasure rate at EDGES:  {edge_erasure:.4f}")
print(f"Structure density in CENTER: {center_struct:.4f}")
print(f"Structure density at EDGES:  {edge_struct:.4f}")

if center_erasure > edge_erasure * 1.5:
    print("\nWARNING: Center is being erased more than edges!")
    print("This might indicate the latent grid is too coarse for interior details.")
elif edge_erasure > center_erasure * 1.5:
    print("\nWARNING: Edges are being erased more than center!")
    print("The model might be having trouble with boundary regions.")
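
The raw rates printed above depend on how much structure each region contains, so a dense center can show more erasure than sparse edges even when the model treats both the same. One way to control for this is to divide erased voxels by original structure voxels within each region. The sketch below uses synthetic 32x32x32 aggregates and assumes axis 0 is height, as in the cell above; `region_erasure_fraction` is an illustrative helper, not part of the notebook.

```python
import numpy as np

def region_erasure_fraction(erasure, struct, region):
    """Fraction of original structure erased inside `region` (a tuple of slices)."""
    struct_total = struct[region].sum()
    if struct_total == 0:
        return float("nan")  # no structure in this region, fraction undefined
    return float(erasure[region].sum() / struct_total)

# Synthetic aggregates: dense center, sparse border, 10% of structure erased everywhere
struct = np.full((32, 32, 32), 0.1)
struct[:, 10:22, 10:22] = 0.8
erasure = 0.1 * struct

center = (slice(None), slice(10, 22), slice(10, 22))
border = (slice(None), slice(0, 8), slice(None))

raw_center = erasure[center].mean()
raw_border = erasure[border].mean()
norm_center = region_erasure_fraction(erasure, struct, center)
norm_border = region_erasure_fraction(erasure, struct, border)

print(f"raw:        center={raw_center:.3f}  border={raw_border:.3f}")
print(f"normalized: center={norm_center:.3f}  border={norm_border:.3f}")
```

Here the raw comparison would trigger the "center is being erased more" warning, while the normalized fractions are identical. With the notebook's arrays, `agg_erasure` and `agg_struct` would take the place of the synthetic inputs.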

In [None]:
# ============================================================
# CELL 16: Worst Reconstructions Analysis
# ============================================================
# Examine the failure cases to understand what the model struggles with
# Sorted by: lowest structure recall (most building deleted)

print("Finding worst reconstructions (most building erased)...")

sample_stats = []
air_list = AIR_TOKENS_TENSOR.tolist()  # air token ids, computed once instead of per sample
model.eval()  # inference mode for dropout/batch-norm layers
for idx, sample in enumerate(tqdm(val_dataset, desc="Evaluating")):
    sample_t = sample.unsqueeze(0).to(device)
    with torch.no_grad():
        out = model(sample_t)
        pred = out['logits'].argmax(dim=1)[0]
    
    orig = sample.cpu().numpy()
    recon = pred.cpu().numpy()
    
    orig_struct = ~np.isin(orig, air_list)
    recon_struct = ~np.isin(recon, air_list)
    
    # Cast to Python ints so the stats format and serialize cleanly
    orig_vol = int(orig_struct.sum())
    recon_vol = int(recon_struct.sum())
    erased = int((orig_struct & ~recon_struct).sum())
    
    # Samples with no structure count as perfectly preserved
    recall = 1 - (erased / orig_vol) if orig_vol > 0 else 1.0
    vol_ratio = recon_vol / orig_vol if orig_vol > 0 else 1.0
    
    sample_stats.append({
        'idx': idx,
        'orig_vol': orig_vol,
        'recon_vol': recon_vol,
        'erased': erased,
        'recall': recall,
        'vol_ratio': vol_ratio,
        'sample': sample,
        'pred': pred.cpu(),
    })

# Sort by worst recall (most erased)
sample_stats.sort(key=lambda x: x['recall'])

# Show worst 5
print("\n" + "="*60)
print("5 WORST RECONSTRUCTIONS (most building erased)")
print("="*60)

fig, axes = plt.subplots(5, 5, figsize=(20, 20))

for row, stats in enumerate(sample_stats[:5]):
    idx = stats['idx']
    orig = stats['sample'].numpy()
    recon = stats['pred'].numpy()
    
    air_set = set(AIR_TOKENS_TENSOR.tolist())
    orig_struct = ~np.isin(orig, list(air_set))
    recon_struct = ~np.isin(recon, list(air_set))
    
    # Find best slice to show (most structure)
    struct_per_slice = orig_struct.sum(axis=(1, 2))
    best_y = struct_per_slice.argmax()
    
    # Original blocks
    axes[row, 0].imshow(orig[best_y], cmap='tab20', interpolation='nearest')
    axes[row, 0].set_title(f'Original (y={best_y})')
    axes[row, 0].set_ylabel(f'Sample {idx}\nRecall: {stats["recall"]:.1%}')
    
    # Reconstructed blocks
    axes[row, 1].imshow(recon[best_y], cmap='tab20', interpolation='nearest')
    axes[row, 1].set_title('Reconstructed')
    
    # Original structure mask
    axes[row, 2].imshow(orig_struct[best_y], cmap='Greens', interpolation='nearest')
    axes[row, 2].set_title('Orig Structure')
    
    # Reconstructed structure mask
    axes[row, 3].imshow(recon_struct[best_y], cmap='Blues', interpolation='nearest')
    axes[row, 3].set_title('Recon Structure')
    
    # Difference (red = erased)
    diff = np.zeros((*orig_struct[best_y].shape, 3))
    erased = orig_struct[best_y] & ~recon_struct[best_y]
    added = ~orig_struct[best_y] & recon_struct[best_y]
    correct = orig_struct[best_y] & recon_struct[best_y]
    diff[erased] = [1, 0, 0]
    diff[added] = [0, 0, 1]
    diff[correct] = [0, 0.5, 0]
    axes[row, 4].imshow(diff)
    axes[row, 4].set_title(f'Diff: {stats["erased"]} erased')

for ax in axes.flat:
    ax.set_xticks([])
    ax.set_yticks([])

plt.suptitle('WORST 5 RECONSTRUCTIONS\n'
             'Red = erased blocks, Blue = added blocks, Green = preserved\n'
             'These show what the model struggles with most', 
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/worst_reconstructions.png", dpi=150, bbox_inches='tight')
plt.show()

# Summary stats
print("\nWorst 5 samples:")
print(f"{'Idx':>5} {'Orig Vol':>10} {'Recon Vol':>10} {'Erased':>8} {'Recall':>8} {'Vol Ratio':>10}")
print("-"*56)
for stats in sample_stats[:5]:
    print(f"{stats['idx']:>5} {stats['orig_vol']:>10,} {stats['recon_vol']:>10,} "
          f"{stats['erased']:>8,} {stats['recall']:>8.1%} {stats['vol_ratio']:>10.2f}")

# Show best 5 for comparison
print("\n" + "="*60)
print("5 BEST RECONSTRUCTIONS (for comparison)")
print("="*60)
print(f"{'Idx':>5} {'Orig Vol':>10} {'Recon Vol':>10} {'Erased':>8} {'Recall':>8} {'Vol Ratio':>10}")
print("-"*56)
for stats in sample_stats[-5:]:
    print(f"{stats['idx']:>5} {stats['orig_vol']:>10,} {stats['recon_vol']:>10,} "
          f"{stats['erased']:>8,} {stats['recall']:>8.1%} {stats['vol_ratio']:>10.2f}")
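
Storing every sample and prediction tensor in `sample_stats` keeps the whole validation set in memory. If that becomes a problem, one alternative is to keep only scalar statistics per sample and re-run the model on the few indices that are actually plotted. The sketch below shows the selection step with `heapq`; the recall values are made up for illustration.

```python
import heapq

# Hypothetical per-sample stats: scalars only, no tensors
sample_stats = [
    {"idx": 0, "recall": 0.91},
    {"idx": 1, "recall": 0.42},
    {"idx": 2, "recall": 0.99},
    {"idx": 3, "recall": 0.13},
    {"idx": 4, "recall": 0.77},
]

k = 2
# nsmallest/nlargest avoid sorting the full list when only k entries are needed
worst = heapq.nsmallest(k, sample_stats, key=lambda s: s["recall"])
best = heapq.nlargest(k, sample_stats, key=lambda s: s["recall"])

worst_ids = [s["idx"] for s in worst]
best_ids = [s["idx"] for s in best]
print("worst:", worst_ids)
print("best: ", best_ids)
```

The selected indices can then be used to fetch `val_dataset[idx]` again and recompute the prediction just for the plotted rows.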

In [None]:
# ============================================================
# CELL 17: Volume Distribution Analysis
# ============================================================
# Are buildings systematically shrinking or growing?
# This shows the distribution of volume ratios across all samples

# Extract volume stats (already computed in previous cell)
orig_volumes = [s['orig_vol'] for s in sample_stats]
recon_volumes = [s['recon_vol'] for s in sample_stats]
vol_ratios = [s['vol_ratio'] for s in sample_stats]
recalls = [s['recall'] for s in sample_stats]

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# 1. Histogram of volume ratios
ax = axes[0, 0]
ax.hist(vol_ratios, bins=50, color='blue', alpha=0.7, edgecolor='black')
ax.axvline(x=1.0, color='red', linestyle='--', linewidth=2, label='Perfect (1.0)')
ax.axvline(x=np.mean(vol_ratios), color='green', linestyle='-', linewidth=2, 
           label=f'Mean ({np.mean(vol_ratios):.2f})')
ax.set_xlabel('Volume Ratio (recon/orig)')
ax.set_ylabel('Count')
ax.set_title('Volume Ratio Distribution\n<1.0 = shrinking, >1.0 = growing')
ax.legend()

# 2. Scatter: Original vs Reconstructed volume
ax = axes[0, 1]
ax.scatter(orig_volumes, recon_volumes, alpha=0.5, s=20)
max_vol = max(max(orig_volumes), max(recon_volumes))
ax.plot([0, max_vol], [0, max_vol], 'r--', label='Perfect reconstruction')
ax.set_xlabel('Original Volume (blocks)')
ax.set_ylabel('Reconstructed Volume (blocks)')
ax.set_title('Original vs Reconstructed Volume\nPoints below line = shrinking')
ax.legend()

# 3. Histogram of structure recall
ax = axes[1, 0]
ax.hist(recalls, bins=50, color='green', alpha=0.7, edgecolor='black')
ax.axvline(x=0.9, color='red', linestyle='--', linewidth=2, label='Target (90%)')
ax.axvline(x=np.mean(recalls), color='blue', linestyle='-', linewidth=2,
           label=f'Mean ({np.mean(recalls):.1%})')
ax.set_xlabel('Structure Recall')
ax.set_ylabel('Count')
ax.set_title('Structure Recall Distribution\n% of original structure preserved')
ax.legend()

# 4. Recall vs Volume ratio scatter
ax = axes[1, 1]
ax.scatter(vol_ratios, recalls, alpha=0.5, s=20, c=orig_volumes, cmap='viridis')
ax.axhline(y=0.9, color='red', linestyle='--', alpha=0.5, label='Target recall')
ax.axvline(x=1.0, color='red', linestyle='--', alpha=0.5, label='Perfect volume')
ax.set_xlabel('Volume Ratio')
ax.set_ylabel('Structure Recall')
ax.set_title('Recall vs Volume Ratio\nColor = original volume')
cbar = plt.colorbar(ax.collections[0], ax=ax)
cbar.set_label('Original Volume')

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/volume_analysis.png", dpi=150, bbox_inches='tight')
plt.show()

# Summary statistics
print("\n" + "="*60)
print("VOLUME ANALYSIS SUMMARY")
print("="*60)
print(f"Volume Ratio:")
print(f"  Mean:   {np.mean(vol_ratios):.3f}")
print(f"  Median: {np.median(vol_ratios):.3f}")
print(f"  Std:    {np.std(vol_ratios):.3f}")
print(f"  Min:    {np.min(vol_ratios):.3f}")
print(f"  Max:    {np.max(vol_ratios):.3f}")
print()
print(f"Structure Recall:")
print(f"  Mean:   {np.mean(recalls):.1%}")
print(f"  Median: {np.median(recalls):.1%}")
print(f"  Std:    {np.std(recalls):.1%}")
print(f"  Min:    {np.min(recalls):.1%}")
print(f"  Max:    {np.max(recalls):.1%}")

# Interpretation
print("\n" + "="*60)
print("INTERPRETATION")
print("="*60)
if np.mean(vol_ratios) < 0.9:
    print("WARNING: Buildings are systematically SHRINKING!")
    print(f"  Average volume loss: {(1 - np.mean(vol_ratios))*100:.1f}%")
    print("  The model is erasing more structure than it should.")
elif np.mean(vol_ratios) > 1.1:
    print("NOTE: Buildings are systematically GROWING!")
    print(f"  Average volume gain: {(np.mean(vol_ratios) - 1)*100:.1f}%")
    print("  The model is adding extra blocks.")
else:
    print("Good: Volume is well-preserved on average.")

below_90 = sum(1 for r in recalls if r < 0.9)
below_50 = sum(1 for r in recalls if r < 0.5)
print(f"\nSamples below 90% recall: {below_90} ({100*below_90/len(recalls):.1f}%)")
print(f"Samples below 50% recall: {below_50} ({100*below_50/len(recalls):.1f}%)")
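
A volume ratio near 1.0 does not by itself mean the shape survived: erased blocks can be offset by blocks added elsewhere, which is why the plots above pair volume ratio with structure recall. A toy example with made-up boolean occupancy masks:

```python
import numpy as np

def recall_and_volume_ratio(orig_struct, recon_struct):
    """Structure recall and volume ratio for two boolean occupancy masks."""
    orig_vol = int(orig_struct.sum())
    recon_vol = int(recon_struct.sum())
    if orig_vol == 0:
        return 1.0, 1.0  # same convention as the cells above for empty samples
    erased = int((orig_struct & ~recon_struct).sum())
    return 1 - erased / orig_vol, recon_vol / orig_vol

# A 4-block wall in column 0 ...
orig = np.zeros((4, 4), dtype=bool)
orig[:, 0] = True
# ... reconstructed as a 4-block wall in column 2
shifted = np.zeros((4, 4), dtype=bool)
shifted[:, 2] = True

recall, ratio = recall_and_volume_ratio(orig, shifted)
print(f"recall={recall:.1%}  volume_ratio={ratio:.2f}")
```

The reconstruction has exactly the right number of blocks, yet none of the original structure is preserved.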

In [None]:
# ============================================================
# CELL 18: Embedding Drift Visualization
# ============================================================
# Since embeddings are trainable in v4, we need to track how much they changed
# Too much drift = lost semantic meaning
# Too little drift = didn't adapt to task

from sklearn.decomposition import PCA

# Get original and current embeddings
original_emb = model.original_embeddings.cpu().numpy()
current_emb = model.block_emb.weight.detach().cpu().numpy()

# Compute per-embedding drift
drifts = []
for i in range(len(original_emb)):
    orig = original_emb[i]
    curr = current_emb[i]
    # Cosine similarity
    cos_sim = np.dot(orig, curr) / (np.linalg.norm(orig) * np.linalg.norm(curr) + 1e-8)
    # Euclidean distance
    euc_dist = np.linalg.norm(curr - orig)
    drifts.append({
        'token': i,
        'block': tok2block.get(i, f'tok_{i}'),
        'cos_sim': cos_sim,
        'euc_dist': euc_dist,
    })

# Sort by most drifted
drifts.sort(key=lambda x: x['cos_sim'])

fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# 1. Histogram of cosine similarities
ax = axes[0, 0]
cos_sims = [d['cos_sim'] for d in drifts]
ax.hist(cos_sims, bins=50, color='blue', alpha=0.7, edgecolor='black')
ax.axvline(x=np.mean(cos_sims), color='red', linestyle='--', 
           label=f'Mean: {np.mean(cos_sims):.3f}')
ax.axvline(x=0.9, color='green', linestyle=':', label='Stable threshold (0.9)')
ax.set_xlabel('Cosine Similarity (original vs trained)')
ax.set_ylabel('Count')
ax.set_title('Embedding Stability\n1.0 = unchanged, lower = more drift')
ax.legend()

# 2. Top 20 most drifted embeddings
ax = axes[0, 1]
most_drifted = drifts[:20]
blocks = [d['block'][:20] for d in most_drifted]
sims = [d['cos_sim'] for d in most_drifted]
colors = ['red' if s < 0.5 else 'orange' if s < 0.8 else 'yellow' for s in sims]
ax.barh(range(len(blocks)), sims, color=colors)
ax.set_yticks(range(len(blocks)))
ax.set_yticklabels(blocks, fontsize=8)
ax.set_xlabel('Cosine Similarity')
ax.set_title('20 Most Drifted Embeddings\n(these changed the most during training)')
ax.set_xlim(0, 1)
ax.invert_yaxis()

# 3. PCA visualization: original vs trained
ax = axes[1, 0]

# Sample up to 500 embeddings for visualization (seeded so the plot is reproducible)
rng = np.random.default_rng(0)
sample_idx = rng.choice(len(original_emb), min(500, len(original_emb)), replace=False)

# Combine and do PCA
combined = np.vstack([original_emb[sample_idx], current_emb[sample_idx]])
pca = PCA(n_components=2)
combined_2d = pca.fit_transform(combined)

orig_2d = combined_2d[:len(sample_idx)]
curr_2d = combined_2d[len(sample_idx):]

# Plot with arrows showing drift
ax.scatter(orig_2d[:, 0], orig_2d[:, 1], c='blue', alpha=0.3, s=20, label='Original')
ax.scatter(curr_2d[:, 0], curr_2d[:, 1], c='red', alpha=0.3, s=20, label='Trained')

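# Draw arrows for the 20 most drifted embeddings in the sample.
# `drifts` was re-sorted by cos_sim above, so its list position no longer
# matches the token id; look similarities up by token instead.
cos_by_token = {d['token']: d['cos_sim'] for d in drifts}
sample_drifts = [(i, cos_by_token[int(tok)]) for i, tok in enumerate(sample_idx)]
sample_drifts.sort(key=lambda x: x[1])
for i, _ in sample_drifts[:20]:
    ax.annotate('', xy=curr_2d[i], xytext=orig_2d[i],
                arrowprops=dict(arrowstyle='->', color='gray', alpha=0.5))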

ax.set_xlabel('PCA 1')
ax.set_ylabel('PCA 2')
ax.set_title('Embedding Space: Original (blue) vs Trained (red)\nArrows show drift direction')
ax.legend()

# 4. Drift by category
ax = axes[1, 1]
cat_drifts = {}
for d in drifts:
    cat = categorize_block(d['block'])
    if cat not in cat_drifts:
        cat_drifts[cat] = []
    cat_drifts[cat].append(d['cos_sim'])

cat_names = []
cat_means = []
cat_stds = []
for cat, sims in sorted(cat_drifts.items(), key=lambda x: np.mean(x[1])):
    if len(sims) >= 5:  # Only categories with enough samples
        cat_names.append(cat)
        cat_means.append(np.mean(sims))
        cat_stds.append(np.std(sims))

colors = ['red' if m < 0.7 else 'orange' if m < 0.85 else 'green' for m in cat_means]
ax.barh(cat_names, cat_means, xerr=cat_stds, color=colors, alpha=0.7, capsize=3)
ax.axvline(x=0.9, color='gray', linestyle='--', alpha=0.5)
ax.set_xlabel('Mean Cosine Similarity')
ax.set_title('Embedding Drift by Category\nLower = category changed more')
ax.set_xlim(0, 1)

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/embedding_drift.png", dpi=150, bbox_inches='tight')
plt.show()

# Summary
print("\n" + "="*60)
print("EMBEDDING DRIFT SUMMARY")
print("="*60)
print(f"Mean cosine similarity: {np.mean(cos_sims):.3f}")
print(f"Median: {np.median(cos_sims):.3f}")
print(f"Min: {np.min(cos_sims):.3f}")
print(f"Max: {np.max(cos_sims):.3f}")
print()

# Interpretation
if np.mean(cos_sims) > 0.95:
    print("Embeddings barely changed. The regularization may be too strong,")
    print("or the embedding learning rate too low.")
elif np.mean(cos_sims) < 0.7:
    print("WARNING: Embeddings changed significantly!")
    print("This might mean semantic meaning was lost.")
    print("Consider increasing stability_weight or decreasing embedding LR.")
else:
    print("Good: Moderate embedding drift. Embeddings adapted while preserving structure.")

print("\n5 Most Drifted Embeddings:")
for d in drifts[:5]:
    print(f"  {d['block'][:30]:<30} cos_sim={d['cos_sim']:.3f}")

print("\n5 Most Stable Embeddings:")
for d in drifts[-5:]:
    print(f"  {d['block'][:30]:<30} cos_sim={d['cos_sim']:.3f}")
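
The per-embedding loop above is easy to read but can be slow for large vocabularies. The same cosine similarities can be computed row-wise in NumPy. The sketch below uses random stand-in matrices rather than the notebook's embeddings, and `rowwise_cosine` is an illustrative helper name.

```python
import numpy as np

def rowwise_cosine(a, b, eps=1e-8):
    """Cosine similarity between matching rows of a and b."""
    num = np.einsum("ij,ij->i", a, b)  # row-wise dot products
    den = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1) + eps
    return num / den

rng = np.random.default_rng(0)
original = rng.normal(size=(1000, 32))                   # stand-in for original embeddings
trained = original + 0.1 * rng.normal(size=(1000, 32))   # stand-in for trained embeddings

vec = rowwise_cosine(original, trained)

# Reference: the explicit loop, written the same way as in the cell above
loop = np.array([
    np.dot(o, c) / (np.linalg.norm(o) * np.linalg.norm(c) + 1e-8)
    for o, c in zip(original, trained)
])
print("matches loop:", np.allclose(vec, loop))
```

With the notebook's arrays this would be `rowwise_cosine(original_emb, current_emb)`, indexed by token id.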

In [None]:
# ============================================================
# CELL 19: Final Diagnostic Summary
# ============================================================
# This cell summarizes all diagnostics and provides actionable recommendations

print("="*70)
print("COMPREHENSIVE DIAGNOSTIC SUMMARY")
print("="*70)

print("\n### 1. SHAPE PRESERVATION (Primary Goal) ###")
print(f"Structure Recall:    {np.mean(recalls):.1%} (target: >90%)")
print(f"False Air Rate:      {1 - np.mean(recalls):.1%} (target: <10%)")
print(f"Volume Ratio:        {np.mean(vol_ratios):.2f} (target: ~1.0)")

shape_ok = np.mean(recalls) > 0.85
if shape_ok:
    print("✓ Shape preservation looks acceptable")
else:
    print("✗ ISSUE: Shape is not being preserved well")

print("\n### 2. BLOCK SUBSTITUTION ###")
print(f"Structure->Air (erasure):   {total_s2a:,} errors")
print(f"Struct->Struct (confusion): {total_s2s:,} errors")

if total_s2a < total_s2s:
    print("✓ Material confusion > erasure (good - shape preserved)")
else:
    print("✗ ISSUE: More erasure than confusion (buildings being deleted)")

print("\n### 3. PER-CATEGORY ###")
print(f"Categories with 0% accuracy: {len(zero_acc_cats)}")
print(f"Categories with <50% recall: {len(low_recall_cats)}")

if len(zero_acc_cats) < 5:
    print("✓ Most categories are being predicted")
else:
    print("✗ ISSUE: Many categories never predicted correctly")

print("\n### 4. EMBEDDING DRIFT ###")
print(f"Mean embedding similarity: {np.mean(cos_sims):.3f}")

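# Same thresholds as the interpretation in the embedding drift cell, so a mean
# of exactly 0.7 or 0.95 is not reported as excessive drift
if np.mean(cos_sims) > 0.95:
    print("⚠ Embeddings barely changed (may not have adapted)")
elif np.mean(cos_sims) < 0.7:
    print("✗ ISSUE: Embeddings drifted too much (semantic meaning may be lost)")
else:
    print("✓ Embeddings adapted moderately")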

print("\n### 5. WORST CASES ###")
print(f"Samples with <50% recall: {below_50}")
print(f"Samples with <90% recall: {below_90}")

print("\n" + "="*70)
print("RECOMMENDATIONS")
print("="*70)

recommendations = []

if np.mean(recalls) < 0.85:
    recommendations.append("- Increase FALSE_AIR_WEIGHT or STRUCTURE_TO_AIR_WEIGHT")
    recommendations.append("- Try longer training or lower learning rate")

if total_s2a > total_s2s:
    recommendations.append("- Structure is being erased too much")
    recommendations.append("- Increase VOLUME_WEIGHT to penalize shrinking")

if len(zero_acc_cats) >= 5:  # same cutoff as the per-category check above
    recommendations.append("- Many block categories never predicted")
    recommendations.append("- Check if these blocks exist in training data")
    recommendations.append("- Consider class weighting or focal loss")

if np.mean(cos_sims) < 0.7:
    recommendations.append("- Embeddings drifted too much")
    recommendations.append("- Increase STABILITY_WEIGHT")

if np.mean(cos_sims) > 0.95:
    recommendations.append("- Embeddings may not be learning")
    recommendations.append("- Decrease STABILITY_WEIGHT or increase embedding LR")

if np.mean(vol_ratios) < 0.8:
    recommendations.append("- Buildings are shrinking significantly")
    recommendations.append("- Volume preservation loss needs to be stronger")

if not recommendations:
    print("No major issues detected! The model appears to be training well.")
else:
    for rec in recommendations:
        print(rec)

print("\n" + "="*70)
print("OUTPUT FILES SAVED TO GOOGLE DRIVE")
print("="*70)
print(f"Location: {OUTPUT_DIR}/")
print()
print("Checkpoints:")
print("  - vqvae_v4_best_checkpoint.pt    (best model by struct_recall)")
print("  - vqvae_v4_final_checkpoint.pt   (final model)")
print("  - vqvae_v4_results.json          (training history)")
print()
print("Diagnostic Plots:")
print("  - vqvae_v4_training.png          (training curves)")
print("  - reconstruction_sample_*.png    (slice-by-slice reconstructions)")
print("  - block_substitutions.png        (what's replacing what)")
print("  - category_accuracy.png          (per-category breakdown)")
print("  - error_heatmaps.png             (spatial error distribution)")
print("  - worst_reconstructions.png      (failure cases)")
print("  - volume_analysis.png            (volume distribution)")
print("  - embedding_drift.png            (embedding changes)")

print("\n" + "="*70)
print("TRAINING COMPLETE!")
print("="*70)
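
The summary above reports the mean of per-sample recalls, which weights a ten-block structure the same as a thousand-block one. Pooling voxels across samples before dividing can give a noticeably different number, so it helps to know which one is being compared against the 90% target. A sketch with made-up counts:

```python
import numpy as np

# Hypothetical per-sample counts: one tiny structure mostly erased, one large one mostly intact
orig_vols = np.array([10, 1000])
erased = np.array([9, 50])

# Mean of per-sample recalls (what the summary above prints)
macro_recall = float(np.mean(1 - erased / orig_vols))

# Recall over all structure voxels pooled together
micro_recall = float(1 - erased.sum() / orig_vols.sum())

print(f"per-sample mean recall: {macro_recall:.1%}")
print(f"voxel-pooled recall:    {micro_recall:.1%}")
```

In the notebook, both variants can be derived from the `erased` and `orig_vol` fields already stored in `sample_stats`.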