# 🏗️ Sparse Structure Transformer for Minecraft Build Reconstruction

## What This Notebook Does

This notebook trains a **Sparse Structure Transformer** - a neural network that learns to reconstruct Minecraft builds. The model takes a 3D structure (like a house or castle) and learns to compress it into a compact representation, then decompress it back to the original structure.

**The Problem We're Solving:**
Minecraft builds are stored as 32×32×32 grids of blocks (32,768 total positions). However, ~80% of these positions are just air! Our dense VQ-VAE baseline spends most of its computation on that empty space, drifts toward predicting "air" everywhere, and gets 0% accuracy on important detail blocks like stairs, doors, and slabs.

**Our Solution:**
Instead of processing the entire dense grid, we treat each build as a **sparse set of (position, block) pairs**. A house with 500 blocks becomes a set of 500 elements, not 32,768. This is like describing a room by listing the furniture, not by describing every cubic inch of air.
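A minimal NumPy sketch of that conversion, assuming for illustration that air is the single token `0` (the notebook detects the real air tokens from the vocabulary in Section 3.2). `np.argwhere` and boolean indexing both walk the grid in the same row-major order, so positions and block IDs stay aligned, and scattering them back into an all-air grid recovers the original:

```python
import numpy as np

AIR = 0  # assumption for this sketch; the notebook detects air tokens from the vocabulary

def to_sparse(grid: np.ndarray):
    """Dense (32, 32, 32) grid of block IDs -> (positions [N, 3], block_ids [N])."""
    mask = grid != AIR
    positions = np.argwhere(mask)          # integer (x, y, z) of every non-air voxel
    block_ids = grid[mask]                 # same row-major order as argwhere
    return positions, block_ids

def to_dense(positions: np.ndarray, block_ids: np.ndarray, size: int = 32):
    """Inverse of to_sparse: scatter block IDs back into an all-air grid."""
    grid = np.full((size, size, size), AIR, dtype=block_ids.dtype)
    grid[positions[:, 0], positions[:, 1], positions[:, 2]] = block_ids
    return grid

# Toy "build": a 5x5 floor of block 7 with one block 12 on top
grid = np.zeros((32, 32, 32), dtype=np.int64)
grid[10:15, 0, 10:15] = 7
grid[12, 1, 12] = 12

positions, block_ids = to_sparse(grid)
print(positions.shape, block_ids.shape)    # (26, 3) (26,)
assert np.array_equal(to_dense(positions, block_ids), grid)
```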

---

## What is a Sparse Transformer?

### The Core Idea

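A **Sparse Transformer** processes variable-length sets of elements using self-attention, rather than fixed-size grids using convolutions. Here "sparse" describes the input (only non-air blocks are kept); the attention over that set is ordinary full self-attention, not the strided/fixed sparse attention patterns of OpenAI's similarly named Sparse Transformer.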

```
Traditional VQ-VAE (Dense):
┌──────────────────────────────────────────────────────────────────────────────┐
│  32×32×32 Grid  →  3D CNN Encoder  →  Latent  →  3D CNN Decoder  →  Grid     │
│  (32,768 voxels)   (80% air!)         Codes      (still)            (32,768) │
│                                                  predicts                    │
│                                                  mostly air                  │
└──────────────────────────────────────────────────────────────────────────────┘

Sparse Transformer:
┌──────────────────────────────────────────────────────────────────────────────┐
│  Set of ~500 blocks  →  Transformer  →  Latent  →  Transformer  →  Set       │
│  {(pos₁, oak_planks),   Encoder         Codes      Decoder         of ~500   │
│   (pos₂, oak_stairs),   (attention)                (cross-attn)    predicted │
│   (pos₃, glass_pane),                                              blocks    │
│   ...}                                                                       │
└──────────────────────────────────────────────────────────────────────────────┘
```

### Key Components

1. **Sparse Representation**: Extract only non-air blocks as (x, y, z, block_embedding) tuples
2. **Fourier Positional Encoding**: Encode 3D coordinates as sinusoidal features at multiple frequencies (like NeRF)
3. **Transformer Encoder**: Self-attention lets each block "see" all other blocks in the structure
4. **Set Pooling**: Compress variable-length encoded blocks into fixed-size latent codes
5. **Vector Quantization (optional)**: Discretize latent space for generation
6. **Transformer Decoder**: Cross-attention from target positions to latent codes
7. **Embedding Prediction**: Output block embeddings, match to vocabulary via nearest neighbor
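Step 4 is what makes the latent size independent of the build size. Below is a single-head NumPy sketch of attention pooling with learnable seed vectors, the idea behind the notebook's `SetPooling`, but without its multi-head split, output projection, or LayerNorm. Random numbers stand in for trained weights, and toy sizes replace `HIDDEN_DIM` and `NUM_LATENT_CODES`. It shows that 5 or 500 input blocks both yield `K` codes, and that masked padding slots get zero attention:

```python
import numpy as np

rng = np.random.default_rng(0)
D, K = 8, 4                                  # toy hidden size and number of latent codes
seeds = rng.normal(size=(K, D))              # stand-in for the learned seed vectors

def set_pool(x: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """x: [N, D] encoded blocks, mask: [N] True for real blocks -> [K, D]."""
    scores = seeds @ x.T / np.sqrt(D)                    # [K, N] seed-to-block similarity
    scores = np.where(mask[None, :], scores, -np.inf)    # padded blocks get no weight
    scores -= scores.max(axis=1, keepdims=True)          # numerically stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ x                                   # each code = weighted average of blocks

small = rng.normal(size=(5, D))
large = rng.normal(size=(500, D))
print(set_pool(small, np.ones(5, bool)).shape, set_pool(large, np.ones(500, bool)).shape)
# (4, 8) (4, 8)

# Appending masked padding rows does not change the result
padded = np.vstack([small, np.zeros((3, D))])
mask = np.array([True] * 5 + [False] * 3)
assert np.allclose(set_pool(padded, mask), set_pool(small, np.ones(5, bool)))
```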

### Why This Works Better

| Aspect | Dense VQ-VAE | Sparse Transformer |
|--------|-------------|-------------------|
| **Representation** | 32³ = 32,768 voxels | ~500 non-air blocks |
| **Air handling** | 80% of computation wasted | Air not represented |
| **Loss function** | 3717-way classification | Embedding regression (MSE) |
| **Class imbalance** | Air dominates training | Equal weight per block |
| **Embedding usage** | Frozen, mostly ignored | Direct input AND output target |
| **Stairs/doors accuracy** | 0% | Expected >50% |
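The "equal weight per block" row comes from averaging the embedding-regression loss over real blocks only. A minimal NumPy sketch, assuming a zero-padded batch plus a boolean mask like the one `collate_sparse` builds; the notebook's actual loss (Section 7, which also has an auxiliary classification term weighted by `AUX_WEIGHT`) may differ in details:

```python
import numpy as np

def masked_mse(pred: np.ndarray, target: np.ndarray, mask: np.ndarray) -> float:
    """pred/target: [B, N, E] padded embeddings; mask: [B, N] True for real blocks.

    Every real block contributes equally; padding contributes nothing.
    """
    per_block = ((pred - target) ** 2).mean(axis=-1)   # [B, N] MSE per block
    return float((per_block * mask).sum() / mask.sum())

# Two structures: 3 real blocks and 1 real block, padded to length 3
target = np.zeros((2, 3, 4))
pred = np.zeros((2, 3, 4))
pred[0, 0] = 1.0          # one wrong block in structure 0 (per-block MSE = 1)
pred[1, 2] = 100.0        # garbage in a padding slot: must be ignored
mask = np.array([[True, True, True], [True, False, False]])

print(masked_mse(pred, target, mask))   # 0.25 = (1 + 0 + 0 + 0) / 4 real blocks
```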

---

## How to Test With Your Own Build

After training, you can test the model on any Minecraft structure:

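```python
# Assumes `model`, `device`, `all_embeddings`, and `all_embeddings_tensor` from this
# notebook, plus the `extract_sparse` / `reconstruct_dense` helpers defined later on.
import h5py
import torch

# 1. Load your build (32×32×32 grid of block IDs)
with h5py.File("my_build.h5", "r") as f:
    structure = f[list(f.keys())[0]][:]   # first dataset, as in SparseStructureDataset

# 2. Convert to sparse format
positions, block_ids, embeddings = extract_sparse(structure, all_embeddings)

# 3. Run through model (batch of one, no padding, so the mask is all True)
pos_t = torch.as_tensor(positions, dtype=torch.float32, device=device).unsqueeze(0)
emb_t = torch.as_tensor(embeddings, dtype=torch.float32, device=device).unsqueeze(0)
attention_mask = torch.ones(pos_t.shape[:2], dtype=torch.bool, device=device)

model.eval()
with torch.no_grad():
    outputs = model(pos_t, emb_t, attention_mask)
    pred_embeddings = outputs["pred_embeddings"][0]          # [N, EMBED_DIM]

# 4. Find nearest block for each prediction
pred_block_ids = torch.cdist(pred_embeddings, all_embeddings_tensor).argmin(dim=-1)

# 5. Non-air accuracy (a dense comparison would be inflated by matching air)
true_ids = torch.as_tensor(block_ids, device=device)
accuracy = (pred_block_ids == true_ids).float().mean().item()

# 6. Reconstruct to a dense grid (e.g. for visualization)
reconstructed = reconstruct_dense(positions, pred_block_ids.cpu())
```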

The notebook includes a full example of this at the end.

---

## Expected Results

Based on our VQ-VAE baseline (~49% non-air accuracy, 0% on stairs/doors):

| Metric | VQ-VAE Baseline | Sparse Transformer (expected) |
|--------|-----------------|-------------------------------|
| Non-air accuracy | ~49% | 70-80% |
| Stairs | 0% | >50% |
| Slabs | 0% | >50% |
| Doors | 0% | >50% |
| Fences | 0% | >50% |
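Per-category numbers such as the stairs and doors rows can be computed by grouping sparse predictions by a keyword in the true block's name. A sketch with a hypothetical `category_accuracy` helper and keyword list, assuming names like `minecraft:oak_stairs`; the grouping used in Section 9 may differ. Substring matching is coarse (`door` also matches `trapdoor`), so the keyword list is worth checking against the vocabulary:

```python
from collections import defaultdict

CATEGORIES = ["stairs", "slab", "door", "fence"]   # hypothetical keywords to group by

def category_accuracy(true_ids, pred_ids, tok2block):
    """Accuracy per category, keyed by the first keyword found in the true block's name."""
    hits, totals = defaultdict(int), defaultdict(int)
    for t, p in zip(true_ids, pred_ids):
        name = tok2block[t].lower()
        cat = next((c for c in CATEGORIES if c in name), "other")
        totals[cat] += 1
        hits[cat] += int(t == p)
    return {c: hits[c] / totals[c] for c in totals}

tok2block = {1: "minecraft:oak_planks", 2: "minecraft:oak_stairs",
             3: "minecraft:spruce_stairs", 4: "minecraft:oak_door"}
true_ids = [1, 1, 2, 3, 4]
pred_ids = [1, 1, 2, 1, 1]   # one of the two stairs right, the door wrong
print(category_accuracy(true_ids, pred_ids, tok2block))
# {'other': 1.0, 'stairs': 0.5, 'door': 0.0}
```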

---

## Table of Contents

1. **Setup** - Install dependencies and mount Google Drive
2. **Configuration** - Set paths and hyperparameters
3. **Data Loading** - Load embeddings and vocabulary
4. **Dataset** - Define sparse structure dataset
5. **Model Components** - Positional encoding, pooling, VQ
6. **Sparse Transformer** - The main model architecture
7. **Training Functions** - Loss computation and training loop
8. **Training** - Run the training
9. **Visualizations** - Training curves, category accuracy, confusion matrix
10. **Testing** - Test on a real structure
11. **Save Results** - Export model and metrics


---
# 1. Setup - Install Dependencies and Mount Google Drive


In [None]:
# ============================================================
# 1.1 Mount Google Drive and Install Dependencies
# ============================================================
# Mount Google Drive to access data and save results

from google.colab import drive
drive.mount('/content/drive')

# Install any missing dependencies (h5py should be pre-installed)
# Using %pip for better Colab compatibility
%pip install -q h5py tqdm

print("✓ Google Drive mounted")
print("✓ Dependencies installed")


In [None]:
# ============================================================
# 1.2 Import Libraries
# ============================================================

import json
import time
import random
import math
import os
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional, Set
from collections import defaultdict

import h5py
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from tqdm.notebook import tqdm

# Check GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if device.type == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("⚠️ WARNING: No GPU detected! Training will be very slow.")
    print("  Go to Runtime → Change runtime type → GPU")


---
# 2. Configuration - Set Paths and Hyperparameters

**⚠️ IMPORTANT**: Update the paths below to match your Google Drive structure.

The data files needed are:
- `block_embeddings_v3.npy` - Block2Vec embeddings (40-dimensional vectors for each block type)
- `tok2block.json` - Vocabulary mapping token IDs to block names
- `train/` folder with `.h5` structure files
- `val/` folder with `.h5` structure files
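Since the embeddings file and the vocabulary are indexed by the same token IDs, a cheap consistency check before training can catch a mismatched pair. A sketch with a hypothetical `check_vocab` helper, assuming `tok2block.json` maps stringified IDs to names (as the loader in Section 3.2 expects) and the `.npy` file has one 40-dimensional row per token; in-memory stand-ins replace the real files:

```python
import json
import numpy as np

def check_vocab(embeddings: np.ndarray, tok2block_json: str, embed_dim: int = 40) -> dict:
    """Raise ValueError if the embedding matrix and the token->block mapping disagree."""
    tok2block = {int(k): v for k, v in json.loads(tok2block_json).items()}
    n_tokens, dim = embeddings.shape
    if dim != embed_dim:
        raise ValueError(f"expected {embed_dim}-d embeddings, got {dim}-d")
    missing = set(range(n_tokens)) - set(tok2block)
    if missing:
        raise ValueError(f"{len(missing)} embedding rows have no block name, e.g. {min(missing)}")
    extra = [t for t in tok2block if not 0 <= t < n_tokens]
    if extra:
        raise ValueError(f"{len(extra)} token IDs have no embedding row, e.g. {min(extra)}")
    return tok2block

# Stand-ins for block_embeddings_v3.npy and tok2block.json
emb = np.zeros((3, 40), dtype=np.float32)
vocab_json = json.dumps({"0": "minecraft:air", "1": "minecraft:stone", "2": "minecraft:oak_stairs"})
print(len(check_vocab(emb, vocab_json)))   # 3
```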


In [None]:
# ============================================================
# 2.1 Path Configuration (Google Colab / Google Drive)
# ============================================================
# ⚠️ UPDATE THESE PATHS to match your Google Drive structure!

# Base path in Google Drive
DRIVE_BASE = "/content/drive/MyDrive/minecraft_ai"

# Data paths
DATA_DIR = f"{DRIVE_BASE}/data/splits/train"
VAL_DIR = f"{DRIVE_BASE}/data/splits/val"
EMBEDDINGS_PATH = f"{DRIVE_BASE}/embeddings/block_embeddings_v3.npy"
VOCAB_PATH = f"{DRIVE_BASE}/vocabulary/tok2block.json"
OUTPUT_DIR = "/content/output"  # Local output, copy to Drive later

# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Verify paths exist
paths_ok = True
for name, path in [("Embeddings", EMBEDDINGS_PATH), ("Vocabulary", VOCAB_PATH), 
                    ("Train data", DATA_DIR), ("Val data", VAL_DIR)]:
    if os.path.exists(path):
        print(f"✓ {name}: {path}")
    else:
        print(f"✗ {name} NOT FOUND: {path}")
        paths_ok = False

if not paths_ok:
    print("\n⚠️ Some paths are missing! Please update the paths above.")


In [None]:
# ============================================================
# 2.2 Model and Training Hyperparameters
# ============================================================

# Air tokens will be detected dynamically from the vocabulary (see Section 3.2)
AIR_TOKENS: Set[int] = set()  # Placeholder - populated after loading vocab

# === Model Architecture ===
EMBED_DIM = 40           # Block2Vec embedding dimension
HIDDEN_DIM = 256         # Transformer hidden dimension
N_ENCODER_LAYERS = 6     # Number of transformer encoder layers
N_DECODER_LAYERS = 6     # Number of transformer decoder layers
N_HEADS = 8              # Number of attention heads
NUM_LATENT_CODES = 16    # Number of latent codes after pooling
VQ_NUM_EMBEDDINGS = 1024 # VQ codebook size (0 = disable VQ)
DROPOUT = 0.1            # Dropout rate

# === Training Settings ===
EPOCHS = 20              # Number of training epochs
BATCH_SIZE = 16          # Batch size (reduce if OOM)
LEARNING_RATE = 1e-4     # Initial learning rate
WEIGHT_DECAY = 0.01      # AdamW weight decay
AUX_WEIGHT = 0.1         # Weight for auxiliary classification loss
MAX_BLOCKS = 2048        # Maximum blocks per structure

# === Other ===
SEED = 42                # Random seed for reproducibility
NUM_WORKERS = 2          # DataLoader workers

# Print configuration summary
print("=" * 60)
print("CONFIGURATION SUMMARY")
print("=" * 60)
print(f"Model: {HIDDEN_DIM}d hidden, {N_ENCODER_LAYERS}+{N_DECODER_LAYERS} layers, {N_HEADS} heads")
print(f"Latent: {NUM_LATENT_CODES} codes" + (f", VQ with {VQ_NUM_EMBEDDINGS} entries" if VQ_NUM_EMBEDDINGS > 0 else ", no VQ"))
print(f"Training: {EPOCHS} epochs, batch={BATCH_SIZE}, lr={LEARNING_RATE}")
print(f"Max blocks per structure: {MAX_BLOCKS}")
print("=" * 60)
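For the "reduce if OOM" note: if attention scores were fully materialized, one self-attention layer would hold a `[BATCH_SIZE, N_HEADS, N, N]` tensor, where `N` is the longest structure in the batch (`collate_sparse` pads to it, capped at `MAX_BLOCKS`). A back-of-the-envelope sketch in float32 that ignores activations, gradients, and any memory-efficient attention kernel PyTorch may select (which avoids storing this matrix):

```python
def attention_scores_gib(batch: int, heads: int, n: int, bytes_per_value: int = 4) -> float:
    """Size in GiB of one [batch, heads, n, n] attention-score tensor."""
    return batch * heads * n * n * bytes_per_value / 2**30

print(attention_scores_gib(16, 8, 2048))   # 2.0 GiB per layer when a batch hits MAX_BLOCKS=2048
print(attention_scores_gib(16, 8, 512))    # 0.125 GiB when the longest build has ~500 blocks
```

Because the cost grows with `N²`, lowering `MAX_BLOCKS` relieves memory faster than lowering `BATCH_SIZE` by the same factor.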


---
# 3. Data Loading - Load Embeddings and Vocabulary


In [None]:
# ============================================================
# 3.1 Set Random Seeds for Reproducibility
# ============================================================

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True

print(f"✓ Random seeds set to {SEED}")


In [None]:
# ============================================================
# 3.2 Load Block2Vec Embeddings and Vocabulary
# ============================================================

# Load Block2Vec embeddings
# These are 40-dimensional vectors that capture semantic similarity between blocks
# e.g., oak_stairs is close to spruce_stairs, oak_planks, etc.
all_embeddings = np.load(EMBEDDINGS_PATH).astype(np.float32)

# Check for NaN/Inf in embeddings
if np.isnan(all_embeddings).any():
    print("⚠️  WARNING: NaN values found in embeddings! Replacing with zeros.")
    all_embeddings = np.nan_to_num(all_embeddings, nan=0.0)
if np.isinf(all_embeddings).any():
    print("⚠️  WARNING: Inf values found in embeddings! Replacing with zeros.")
    all_embeddings = np.nan_to_num(all_embeddings, posinf=0.0, neginf=0.0)

# Rescale extreme embeddings so the largest absolute value becomes 10
embed_max = np.abs(all_embeddings).max()
if embed_max > 100:
    print(f"⚠️  WARNING: Embeddings have large values (max={embed_max:.2f}). Rescaling to max 10.")
    all_embeddings = all_embeddings / embed_max * 10

# EMBED_DIM must match the Block2Vec file, or the model's input/output sizes will be wrong
assert all_embeddings.shape[1] == EMBED_DIM, (
    f"EMBED_DIM={EMBED_DIM} but embeddings are {all_embeddings.shape[1]}-dimensional"
)

all_embeddings_tensor = torch.from_numpy(all_embeddings).to(device)
VOCAB_SIZE = all_embeddings.shape[0]
print(f"✓ Loaded embeddings: {all_embeddings.shape}")
print(f"  - {VOCAB_SIZE} unique block types")
print(f"  - {all_embeddings.shape[1]}-dimensional embeddings")
print(f"  - Value range: [{all_embeddings.min():.3f}, {all_embeddings.max():.3f}]")

# Load vocabulary (token ID → block name mapping)
with open(VOCAB_PATH, 'r') as f:
    tok2block = {int(k): v for k, v in json.load(f).items()}
print(f"✓ Loaded vocabulary: {len(tok2block)} blocks")

# Dynamically detect air tokens from vocabulary
# This ensures we use the correct tokens regardless of vocabulary version
AIR_TOKENS = set()
for tok, block in tok2block.items():
    block_lower = block.lower()
    # Match "air", "cave_air", "void_air" but not "stairs"
    if 'air' in block_lower and 'stair' not in block_lower:
        AIR_TOKENS.add(tok)

print(f"✓ Detected {len(AIR_TOKENS)} air tokens: {sorted(AIR_TOKENS)}")
for tok in sorted(AIR_TOKENS):
    print(f"    {tok}: {tok2block[tok]}")

# Show some example blocks
print("\nExample blocks in vocabulary:")
for tok in [0, 100, 500, 1000, 2000, 3000]:
    if tok in tok2block:
        print(f"  Token {tok}: {tok2block[tok]}")

# Verify air tokens
print("\nAir tokens:")
for tok in AIR_TOKENS:
    if tok in tok2block:
        print(f"  Token {tok}: {tok2block[tok]}")
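The air rule can be checked in isolation. This sketch copies the cell's substring rule into a function and applies it to a few illustrative names (not read from `tok2block.json`). Because it is a substring test, any other block whose name contains "air" would also be flagged, which is why printing the detected tokens, as the cell does, is worth a glance:

```python
def is_air(block_name: str) -> bool:
    """Same rule as the cell above: name contains 'air' but not 'stair'."""
    name = block_name.lower()
    return "air" in name and "stair" not in name

names = ["minecraft:air", "minecraft:cave_air", "minecraft:void_air",
         "minecraft:oak_stairs", "minecraft:stone"]
print({n: is_air(n) for n in names})
# air, cave_air, void_air -> True; oak_stairs, stone -> False
```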


---
# 4. Dataset - Sparse Structure Dataset

This is where the magic happens! Instead of loading 32×32×32 dense grids, we extract only the non-air blocks as a sparse set of (position, block_id, embedding) tuples.

**Key insight**: A typical Minecraft structure has ~500-2000 non-air blocks out of 32,768 total positions. By focusing only on these, we:
1. Shrink each structure's input by roughly 16-65x (32,768 / 2,000 ≈ 16; 32,768 / 500 ≈ 65)
2. Give equal weight to every block (no air domination)
3. Process structures in proportion to their complexity


In [None]:
# ============================================================
# 4.1 Sparse Structure Dataset Class
# ============================================================

class SparseStructureDataset(Dataset):
    """
    Dataset that converts dense 32x32x32 structures to sparse (position, block) sets.
    
    Instead of representing a structure as 32,768 voxels (80% air), we extract only
    the non-air blocks as a set of (x, y, z, block_id) tuples.
    
    For example, a small house with 500 blocks becomes:
    {
        positions: [[0,0,0], [0,0,1], [0,1,0], ...],  # 500 x 3
        block_ids: [oak_planks, oak_planks, glass, ...],  # 500
        embeddings: [[0.1, 0.2, ...], ...]  # 500 x 40
    }
    """
    
    def __init__(
        self,
        data_dir: str,
        embeddings: np.ndarray,
        max_files: Optional[int] = None,
        max_blocks: int = 2048,
        augment: bool = False,
        seed: int = 42,
    ):
        self.data_dir = Path(data_dir)
        self.embeddings = embeddings
        self.max_blocks = max_blocks
        self.augment = augment
        self.rng = random.Random(seed)
        
        # Find all .h5 structure files
        # First try the directory directly
        self.h5_files = sorted(self.data_dir.glob("*.h5"))
        
        # If no files found, try searching recursively in subdirectories
        if len(self.h5_files) == 0:
            print(f"  No .h5 files in root directory, searching subdirectories...")
            self.h5_files = sorted(self.data_dir.glob("**/*.h5"))
            if len(self.h5_files) > 0:
                print(f"  Found {len(self.h5_files)} .h5 files in subdirectories")
        
        if max_files:
            self.h5_files = self.h5_files[:max_files]
        
        if len(self.h5_files) == 0:
            # Provide helpful error message
            error_msg = f"No .h5 files found in {data_dir}\n"
            error_msg += f"  Please check:\n"
            error_msg += f"  1. The path is correct: {data_dir}\n"
            error_msg += f"  2. Files are uploaded to Google Drive\n"
            error_msg += f"  3. Files have .h5 extension"
            raise ValueError(error_msg)
        
        print(f"Found {len(self.h5_files)} structures in {data_dir}")
    
    def __len__(self):
        return len(self.h5_files)
    
    def __getitem__(self, idx):
        # Load the dense 32x32x32 structure
        with h5py.File(self.h5_files[idx], 'r') as f:
            key = list(f.keys())[0]
            structure = f[key][:].astype(np.int64)
        
        # Extract non-air blocks (the sparse representation!)
        non_air_mask = ~np.isin(structure, list(AIR_TOKENS))
        positions = np.argwhere(non_air_mask).astype(np.float32)  # [N, 3]
        block_ids = structure[non_air_mask]  # [N]
        
        # Handle empty structures: add a dummy block at origin
        if len(block_ids) == 0:
            positions = np.array([[0.0, 0.0, 0.0]], dtype=np.float32)
            block_ids = np.array([0], dtype=np.int64)  # Use first block in vocab
        
        # Clamp block IDs to valid embedding range
        max_block_id = len(self.embeddings) - 1
        block_ids = np.clip(block_ids, 0, max_block_id)
        
        # Randomly sample if too many blocks (for memory efficiency)
        n_blocks = len(block_ids)
        if n_blocks > self.max_blocks:
            indices = self.rng.sample(range(n_blocks), self.max_blocks)
            indices = sorted(indices)  # Keep spatial ordering
            positions = positions[indices]
            block_ids = block_ids[indices]
        
        # Data augmentation: random rotation and flips
        if self.augment:
            positions = self._augment(positions)
        
        # Look up embeddings for each block
        embeddings = self.embeddings[block_ids]
        
        # Check for NaN in embeddings (shouldn't happen but safety check)
        if np.isnan(embeddings).any():
            print(f"Warning: NaN in embeddings for file {self.h5_files[idx]}")
            embeddings = np.nan_to_num(embeddings, nan=0.0)
        
        return {
            "positions": torch.from_numpy(positions).float(),
            "block_ids": torch.from_numpy(block_ids).long(),
            "embeddings": torch.from_numpy(embeddings).float(),
            "num_blocks": torch.tensor(len(block_ids), dtype=torch.long),
        }
    
    def _augment(self, positions, grid_size=32):
        """Apply random 90° rotation around Y axis and horizontal flips."""
        positions = positions.copy()
        
        # Random rotation around Y axis (0, 90, 180, or 270 degrees)
        k = self.rng.randint(0, 3)
        if k > 0:
            x, y, z = positions[:, 0], positions[:, 1], positions[:, 2]
            for _ in range(k):
                new_x = grid_size - 1 - z
                new_z = x.copy()
                x, z = new_x, new_z
            positions[:, 0], positions[:, 2] = x, z
        
        # Random horizontal flips
        if self.rng.random() > 0.5:
            positions[:, 0] = grid_size - 1 - positions[:, 0]
        if self.rng.random() > 0.5:
            positions[:, 2] = grid_size - 1 - positions[:, 2]
        
        return positions


def collate_sparse(batch):
    """
    Custom collate function that pads variable-length sequences.
    
    Since structures have different numbers of blocks, we pad to the
    maximum length in the batch and create attention masks.
    """
    positions = pad_sequence([b["positions"] for b in batch], batch_first=True)
    block_ids = pad_sequence([b["block_ids"] for b in batch], batch_first=True)
    embeddings = pad_sequence([b["embeddings"] for b in batch], batch_first=True)
    num_blocks = torch.stack([b["num_blocks"] for b in batch])
    
    # Create attention mask: True for valid positions, False for padding
    max_len = positions.size(1)
    attention_mask = torch.arange(max_len).unsqueeze(0) < num_blocks.unsqueeze(1)
    
    return {
        "positions": positions,
        "block_ids": block_ids,
        "embeddings": embeddings,
        "num_blocks": num_blocks,
        "attention_mask": attention_mask,
    }

print("✓ SparseStructureDataset defined")
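`_augment` should only permute voxels, never move them off the grid or merge two of them. The NumPy sketch below reimplements its single 90° step, `(x, z) -> (31 - z, x)`, and checks that the step is a bijection on the 32×32 footprint and that four steps are the identity. One caveat this check cannot cover: if the vocabulary encodes orientation (for example a stair's facing), those block IDs are not rotated along with the positions.

```python
import numpy as np

def rotate_y_90(positions: np.ndarray, grid_size: int = 32) -> np.ndarray:
    """One 90-degree step around Y, same mapping as _augment: (x, z) -> (size-1-z, x)."""
    out = positions.copy()
    out[:, 0] = grid_size - 1 - positions[:, 2]
    out[:, 2] = positions[:, 0]
    return out

# Every (x, z) column of the grid at y = 0
xs, zs = np.meshgrid(np.arange(32), np.arange(32), indexing="ij")
grid = np.stack([xs.ravel(), np.zeros(32 * 32, dtype=int), zs.ravel()], axis=1)

once = rotate_y_90(grid)
assert once.min() >= 0 and once.max() <= 31                # stays inside the grid
assert len({tuple(p) for p in once}) == len(grid)           # no two voxels collide
four = grid
for _ in range(4):
    four = rotate_y_90(four)
assert np.array_equal(four, grid)                           # 4 x 90 degrees = identity
print("rotation is a bijection on the grid")
```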


---
# 5. Model Components - Positional Encoding, Pooling, VQ


In [None]:
# ============================================================
# 5.1 Fourier Positional Encoding
# ============================================================
# Transforms 3D coordinates into high-dimensional features using
# sin/cos functions at multiple frequencies. This allows the network
# to learn both low and high frequency spatial patterns.

class FourierPositionalEncoding(nn.Module):
    """
    Fourier feature encoding for 3D coordinates (like NeRF).
    
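    Maps each coordinate c, normalized to [-1, 1], to
    [sin(2^0 * π * c), ..., sin(2^(L-1) * π * c), cos(2^0 * π * c), ..., cos(2^(L-1) * π * c)]
    and concatenates the results for x, y, z (after the normalized coords if include_input).
    
    This creates 3 * num_frequencies * 2 features (sin + cos for each freq and coord),
    plus 3 when include_input=True: 63 with the defaults.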
    """
    
    def __init__(self, num_frequencies: int = 10, max_coord: int = 32, include_input: bool = True):
        super().__init__()
        self.num_frequencies = num_frequencies
        self.max_coord = max_coord
        self.include_input = include_input
        
        # Frequency bands: 2^0, 2^1, 2^2, ..., 2^(L-1)
        freq_bands = 2.0 ** torch.linspace(0, num_frequencies - 1, num_frequencies)
        self.register_buffer("freq_bands", freq_bands)
        
        # Output dimension
        self.output_dim = 3 * num_frequencies * 2  # 3 coords * L freqs * 2 (sin+cos)
        if include_input:
            self.output_dim += 3  # Also include normalized input coords
    
    def forward(self, positions):
        # Normalize to [-1, 1]
        normalized = positions / self.max_coord * 2 - 1
        
        # Apply frequencies: [B, N, 3, num_freq]
        scaled = normalized.unsqueeze(-1) * self.freq_bands * math.pi
        
        # Sin and cos: [B, N, 3, num_freq * 2]
        encoded = torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=-1)
        
        # Flatten: [B, N, 3 * num_freq * 2]
        encoded = encoded.view(*positions.shape[:-1], -1)
        
        if self.include_input:
            encoded = torch.cat([normalized, encoded], dim=-1)
        
        return encoded

print("✓ FourierPositionalEncoding defined")
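A NumPy re-implementation with the same feature layout makes the encoding easy to inspect. With the defaults it yields `3 * 2 * 10 + 3 = 63` features. It also exposes a property of this setup that is easy to miss: positions are integers normalized with `p / 32 * 2 - 1`, so for every band with `2**j >= 32` both sin and cos are constant across the grid, and the `2**4` band's sine is constant too. Only the lower bands vary, which suggests (untested here) that a smaller `num_frequencies` would lose nothing:

```python
import numpy as np

def fourier_encode(positions: np.ndarray, num_frequencies: int = 10, max_coord: int = 32):
    """Same layout as FourierPositionalEncoding: [normalized xyz, per coord: sin(all f), cos(all f)]."""
    normalized = positions / max_coord * 2 - 1                        # [N, 3]
    freqs = 2.0 ** np.arange(num_frequencies)                          # 1, 2, 4, ..., 2^(L-1)
    scaled = normalized[..., None] * freqs * np.pi                     # [N, 3, L]
    encoded = np.concatenate([np.sin(scaled), np.cos(scaled)], -1)     # [N, 3, 2L]
    return np.concatenate([normalized, encoded.reshape(len(positions), -1)], -1)

coords = np.arange(32, dtype=np.float64)
grid = np.stack(np.meshgrid(coords, coords, coords, indexing="ij"), -1).reshape(-1, 3)
enc = fourier_encode(grid)
print(enc.shape)                       # (32768, 63)

# Which sin/cos bands of the x coordinate actually vary across the 32^3 integer grid?
x_feats = enc[:, 3:3 + 20]             # sin for 10 freqs, then cos for 10 freqs
varies = x_feats.std(axis=0) > 1e-6
print("sin bands that vary:", np.flatnonzero(varies[:10]))    # [0 1 2 3]
print("cos bands that vary:", np.flatnonzero(varies[10:]))    # [0 1 2 3 4]
```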


In [None]:
# ============================================================
# 5.2 Set Pooling and Vector Quantization
# ============================================================

class SetPooling(nn.Module):
    """
    Attention-based pooling to compress variable-length sets into fixed-size.
    
    Uses learnable "seed" vectors that attend to the input set, producing
    a fixed number of output vectors regardless of input length.
    """
    
    def __init__(self, input_dim, output_dim, num_outputs=16, num_heads=8):
        super().__init__()
        self.num_outputs = num_outputs
        self.seeds = nn.Parameter(torch.randn(num_outputs, input_dim))
        self.attention = nn.MultiheadAttention(input_dim, num_heads, batch_first=True)
        self.proj = nn.Linear(input_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim)
    
    def forward(self, x, mask=None):
        batch_size = x.size(0)
        seeds = self.seeds.unsqueeze(0).expand(batch_size, -1, -1)
        key_padding_mask = ~mask if mask is not None else None
        
        pooled, _ = self.attention(seeds, x, x, key_padding_mask=key_padding_mask)
        return self.norm(self.proj(pooled))


class VectorQuantizerEMA(nn.Module):
    """
    Vector Quantization with Exponential Moving Average codebook updates.
    
    Discretizes continuous latent vectors by mapping each to its nearest
    codebook entry. Uses EMA (not gradients) to update the codebook.
    """
    
    def __init__(self, num_embeddings=1024, embedding_dim=256, commitment_cost=0.5, decay=0.99, epsilon=1e-5):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost
        self.decay = decay
        self.epsilon = epsilon
        
        # Initialize codebook with smaller variance (0.1 instead of 1.0)
        self.register_buffer("codebook", torch.randn(num_embeddings, embedding_dim) * 0.1)
        # Initialize cluster sizes to 1 (not 0) to prevent division issues
        self.register_buffer("ema_cluster_size", torch.ones(num_embeddings))
        # Initialize embed_sum to match codebook
        self.register_buffer("ema_embed_sum", self.codebook.clone())
        self.register_buffer("initialized", torch.tensor(False))
    
    def forward(self, z_e):
        B, K, D = z_e.shape
        flat = z_e.view(-1, D)
        
        # Initialize codebook from randomly chosen vectors of the first batch
        # (plain random sampling, not k-means++ seeding; with fewer vectors than
        # codebook entries, only that many entries are seeded from data)
        if not self.initialized and self.training:
            n_samples = min(flat.size(0), self.num_embeddings)
            indices = torch.randperm(flat.size(0))[:n_samples]
            self.codebook.data[:n_samples] = flat[indices].detach()
            self.ema_embed_sum.data[:n_samples] = flat[indices].detach()
            self.initialized.fill_(True)
            print("  → VQ codebook initialized from data")
        
        # Find nearest codebook entries
        distances = (flat ** 2).sum(1, keepdim=True) + (self.codebook ** 2).sum(1) - 2 * flat @ self.codebook.t()
        indices = distances.argmin(dim=1)
        z_q_flat = F.embedding(indices, self.codebook)
        
        # EMA codebook update. Runs under no_grad: `flat` carries gradients, and adding
        # it in place to a buffer would otherwise attach autograd history to the buffer
        # that keeps growing from step to step.
        if self.training:
            with torch.no_grad():
                encodings = F.one_hot(indices, self.num_embeddings).float()
                cluster_counts = encodings.sum(0)
                
                # Update EMA stats
                self.ema_cluster_size.mul_(self.decay).add_(cluster_counts, alpha=1-self.decay)
                embed_sum = encodings.t() @ flat
                self.ema_embed_sum.mul_(self.decay).add_(embed_sum, alpha=1-self.decay)
                
                # Update codebook with Laplace smoothing
                n = self.ema_cluster_size.sum()
                smoothed = (self.ema_cluster_size + self.epsilon) / (n + self.num_embeddings * self.epsilon) * n
                # New code vectors = EMA vector sums / smoothed EMA counts
                new_codebook = self.ema_embed_sum / (smoothed.unsqueeze(1) + self.epsilon)
                # Clamp to prevent extreme values
                new_codebook = torch.clamp(new_codebook, -10, 10)
                self.codebook.data.copy_(new_codebook)
        
        z_q = z_q_flat.view(B, K, D)
        vq_loss = self.commitment_cost * F.mse_loss(z_e, z_q.detach())
        z_q = z_e + (z_q - z_e).detach()  # Straight-through estimator
        
        return z_q, vq_loss, indices.view(B, K)

print("✓ SetPooling and VectorQuantizerEMA defined")
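Without the initialization and clamping details, the EMA update in `VectorQuantizerEMA` assigns each vector to its nearest code, keeps moving averages of per-code counts and per-code vector sums, and sets each code to sum / smoothed count. A NumPy sketch of one step with the same formulas (the helper name `ema_vq_step` is illustrative); with `decay=0.0`, a used code jumps straight to the mean of its assigned vectors:

```python
import numpy as np

def ema_vq_step(flat, codebook, ema_count, ema_sum, decay=0.99, eps=1e-5):
    """One EMA codebook update (same math as VectorQuantizerEMA, minus clamping).

    flat: [M, D] encoder outputs; codebook: [K, D]; ema_count: [K]; ema_sum: [K, D].
    Returns (new_codebook, ema_count, ema_sum, indices).
    """
    K = codebook.shape[0]
    d2 = (flat**2).sum(1, keepdims=True) + (codebook**2).sum(1) - 2 * flat @ codebook.T
    idx = d2.argmin(1)                                   # nearest code for each vector
    onehot = np.eye(K)[idx]                              # [M, K]
    ema_count = decay * ema_count + (1 - decay) * onehot.sum(0)
    ema_sum = decay * ema_sum + (1 - decay) * onehot.T @ flat
    n = ema_count.sum()
    smoothed = (ema_count + eps) / (n + K * eps) * n     # Laplace smoothing of the counts
    return ema_sum / (smoothed[:, None] + eps), ema_count, ema_sum, idx

codebook = np.array([[0.0, 0.0], [10.0, 10.0]])
flat = np.array([[1.0, 0.0], [0.0, 1.0], [9.0, 9.0]])
new_cb, count, summ, idx = ema_vq_step(flat, codebook, np.ones(2), codebook.copy(), decay=0.0)
print(idx)              # [0 0 1]
print(new_cb.round(3))  # code 0 is about (0.5, 0.5), code 1 about (9, 9)
```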


---
# 6. Sparse Structure Transformer - The Main Model


In [None]:
# ============================================================
# 6.1 Sparse Structure Transformer Architecture
# ============================================================

class SparseStructureTransformer(nn.Module):
    """
    Transformer-based autoencoder for sparse Minecraft structures.
    
    Architecture:
    1. Input: Set of (position, block_embedding) pairs
    2. Positional Encoding: Fourier features for 3D coordinates
    3. Transformer Encoder: Self-attention over all blocks
    4. Set Pooling: Compress to fixed-size latent codes
    5. Vector Quantization: Discretize latent space (optional)
    6. Transformer Decoder: Cross-attention from positions to latent
    7. Output: Predicted block embeddings → nearest neighbor lookup
    """
    
    def __init__(
        self,
        embed_dim: int = 40,
        hidden_dim: int = 256,
        n_encoder_layers: int = 6,
        n_decoder_layers: int = 6,
        n_heads: int = 8,
        num_latent_codes: int = 16,
        vq_num_embeddings: int = 1024,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.use_vq = vq_num_embeddings > 0
        
        # Positional encoding for 3D coordinates
        self.pos_encoder = FourierPositionalEncoding(num_frequencies=10, max_coord=32)
        pos_dim = self.pos_encoder.output_dim
        
        # Input projection: position features + block embedding → hidden_dim
        self.input_proj = nn.Linear(pos_dim + embed_dim, hidden_dim)
        
        # Transformer Encoder: self-attention over blocks
        encoder_layer = nn.TransformerEncoderLayer(
            hidden_dim, n_heads, hidden_dim * 4, dropout, batch_first=True, norm_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, n_encoder_layers)
        
        # Set Pooling: variable-length ‚Üí fixed-size latent
        self.pool = SetPooling(hidden_dim, hidden_dim, num_latent_codes, n_heads)
        
        # Vector Quantization (optional)
        self.vq = None
        if self.use_vq:
            self.vq = VectorQuantizerEMA(vq_num_embeddings, hidden_dim)
        
        # Transformer Decoder: cross-attention from positions to latent
        decoder_layer = nn.TransformerDecoderLayer(
            hidden_dim, n_heads, hidden_dim * 4, dropout, batch_first=True, norm_first=True
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, n_decoder_layers)
        
        # Position query projection for decoder
        self.pos_query_proj = nn.Linear(pos_dim, hidden_dim)
        
        # Output projection: hidden_dim → block embedding
        self.output_proj = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim),
        )
        
        self._init_weights()
    
    def _init_weights(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    
    def encode(self, positions, embeddings, attention_mask=None):
        """Encode sparse structure to latent codes."""
        # Combine position features and block embeddings
        pos_features = self.pos_encoder(positions)
        x = self.input_proj(torch.cat([pos_features, embeddings], dim=-1))
        
        # Transformer encoder with masking
        src_key_padding_mask = ~attention_mask if attention_mask is not None else None
        encoded = self.encoder(x, src_key_padding_mask=src_key_padding_mask)
        
        # Pool to fixed-size latent
        latent = self.pool(encoded, attention_mask)
        
        # Optional VQ
        vq_loss = None
        if self.use_vq and self.vq is not None:
            latent, vq_loss, _ = self.vq(latent)
        
        return latent, vq_loss
    
    def decode(self, latent, positions, attention_mask=None):
        """Decode latent codes to block embeddings at given positions."""
        # Create position queries
        pos_features = self.pos_encoder(positions)
        queries = self.pos_query_proj(pos_features)
        
        # Cross-attention: queries attend to latent codes
        tgt_key_padding_mask = ~attention_mask if attention_mask is not None else None
        decoded = self.decoder(queries, latent, tgt_key_padding_mask=tgt_key_padding_mask)
        
        # Project to embedding space
        return self.output_proj(decoded)
    
    def forward(self, positions, embeddings, attention_mask=None):
        """Full forward pass: encode → decode."""
        latent, vq_loss = self.encode(positions, embeddings, attention_mask)
        pred_embeddings = self.decode(latent, positions, attention_mask)
        
        return {
            "pred_embeddings": pred_embeddings,
            "vq_loss": vq_loss if vq_loss is not None else torch.tensor(0.0, device=positions.device),
        }

print("✓ SparseStructureTransformer defined")
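
The encoder passes `~attention_mask` as `src_key_padding_mask` because the notebook's mask is `True` for real blocks, while PyTorch's key padding masks use `True` for positions to *ignore*. The NumPy sketch below shows what that mask does inside attention. Padded keys get a score of `-inf`, so they receive zero weight and the outputs at real positions do not depend on whatever sits in the padding slots. It is a simplified single-head version with no learned projections, and `masked_attention` is a hypothetical helper, not part of the notebook.

```python
import numpy as np


def masked_attention(q, k, v, valid):
    """Single-head scaled dot-product attention with a key padding mask.

    q, k, v: (N, D) arrays for one sequence; valid: (N,) bool, True = real block.
    Padded keys get a score of -inf, so their softmax weight is exactly zero.
    """
    scores = q @ k.T / np.sqrt(q.shape[-1])               # (N, N) query-key scores
    scores = np.where(valid[None, :], scores, -np.inf)    # mask padded keys (columns)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)                              # exp(-inf) == 0 for padding
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v


rng = np.random.default_rng(0)
n_real, n_pad, d = 3, 2, 4
x_real = rng.normal(size=(n_real, d))
x_padded = np.concatenate([x_real, rng.normal(size=(n_pad, d))])  # junk in padding slots
valid = np.array([True] * n_real + [False] * n_pad)

out_padded = masked_attention(x_padded, x_padded, x_padded, valid)
out_real = masked_attention(x_real, x_real, x_real, np.ones(n_real, dtype=bool))

# Real positions give the same output with or without the padded slots
print(np.allclose(out_padded[:n_real], out_real))  # True
```

In PyTorch terms, `valid` plays the role of `attention_mask`, and `~valid` is what `src_key_padding_mask` expects. The rows computed for padded queries are still produced, which is why the losses below also mask them out.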


In [None]:
# ============================================================
# 6.2 Create Datasets and Model
# ============================================================

# Debug: Check what's in the directories before creating datasets
print("=" * 60)
print("DEBUGGING: Checking data directories...")
print("=" * 60)

for dir_name, dir_path in [("Train", DATA_DIR), ("Val", VAL_DIR)]:
    print(f"\n{dir_name} directory: {dir_path}")
    if os.path.exists(dir_path):
        print(f"  ✓ Path exists")
        # List all items
        items = list(Path(dir_path).iterdir())
        print(f"  Items in directory: {len(items)}")
        
        # Count .h5 files
        h5_files = list(Path(dir_path).glob("*.h5"))
        print(f"  .h5 files found: {len(h5_files)}")
        
        if len(h5_files) == 0:
            print(f"  ⚠️  No .h5 files found!")
            # Show first few items
            print(f"  First 10 items:")
            for item in items[:10]:
                item_type = "DIR" if item.is_dir() else "FILE"
                print(f"    [{item_type}] {item.name}")
            
            # Check subdirectories
            subdirs = [d for d in items if d.is_dir()]
            if subdirs:
                print(f"\n  Found {len(subdirs)} subdirectories:")
                for subdir in subdirs[:5]:
                    h5_in_sub = list(subdir.glob("*.h5"))
                    print(f"    {subdir.name}/: {len(h5_in_sub)} .h5 files")
        else:
            print(f"  ✓ Found {len(h5_files)} .h5 files (showing first 3):")
            for f in h5_files[:3]:
                print(f"    {f.name}")
    else:
        print(f"  ✗ Path does NOT exist!")

print("\n" + "=" * 60)
print("Creating datasets...")
print("=" * 60)

# Create datasets
train_dataset = SparseStructureDataset(
    DATA_DIR, all_embeddings, max_blocks=MAX_BLOCKS, augment=True, seed=SEED
)
val_dataset = SparseStructureDataset(
    VAL_DIR, all_embeddings, max_blocks=MAX_BLOCKS, augment=False, seed=SEED
)

# Create data loaders
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=NUM_WORKERS, collate_fn=collate_sparse, pin_memory=True
)
val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, collate_fn=collate_sparse, pin_memory=True
)

print(f"✓ Train: {len(train_dataset)} structures, {len(train_loader)} batches")
print(f"✓ Val: {len(val_dataset)} structures, {len(val_loader)} batches")

# Create model
model = SparseStructureTransformer(
    embed_dim=EMBED_DIM,
    hidden_dim=HIDDEN_DIM,
    n_encoder_layers=N_ENCODER_LAYERS,
    n_decoder_layers=N_DECODER_LAYERS,
    n_heads=N_HEADS,
    num_latent_codes=NUM_LATENT_CODES,
    vq_num_embeddings=VQ_NUM_EMBEDDINGS,
    dropout=DROPOUT,
).to(device)

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"✓ Model created: {num_params:,} parameters")
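
The notebook's `collate_sparse`, which is not shown in this section, turns variable-length structures into fixed-shape batches for the `DataLoader`. Its exact behaviour is an assumption here: the hypothetical `pad_sparse_batch` below zero-pads every sample to the longest set in the batch and returns the same keys the training code reads (`positions`, `embeddings`, `block_ids`, `attention_mask`, `num_blocks`).

```python
import numpy as np


def pad_sparse_batch(samples, embed_dim):
    """Hypothetical stand-in for collate_sparse.

    samples: list of (positions (n, 3), block_ids (n,), embeddings (n, embed_dim)).
    Returns zero-padded arrays plus attention_mask (True = real block).
    """
    batch = len(samples)
    max_n = max(len(ids) for _, ids, _ in samples)
    positions = np.zeros((batch, max_n, 3), dtype=np.float32)
    block_ids = np.zeros((batch, max_n), dtype=np.int64)  # padding reuses id 0
    embeddings = np.zeros((batch, max_n, embed_dim), dtype=np.float32)
    attention_mask = np.zeros((batch, max_n), dtype=bool)
    for b, (pos, ids, emb) in enumerate(samples):
        n = len(ids)
        positions[b, :n] = pos
        block_ids[b, :n] = ids
        embeddings[b, :n] = emb
        attention_mask[b, :n] = True
    return {
        "positions": positions,
        "block_ids": block_ids,
        "embeddings": embeddings,
        "attention_mask": attention_mask,
        "num_blocks": attention_mask.sum(axis=1),
    }


rng = np.random.default_rng(0)
small = (rng.integers(0, 32, size=(2, 3)), np.array([5, 7]), rng.normal(size=(2, 4)))
large = (rng.integers(0, 32, size=(4, 3)), np.array([1, 2, 3, 4]), rng.normal(size=(4, 4)))
batch = pad_sparse_batch([small, large], embed_dim=4)
print(batch["num_blocks"])  # [2 4]
```

Because padded slots carry token id 0 and zero embeddings, every loss and metric downstream has to be weighted by `attention_mask`; otherwise padding would be scored as if it were real blocks.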


---
# 7. Training Functions - Loss Computation and Training Loop


In [None]:
# ============================================================
# 7.1 Helper Functions and Category Classification
# ============================================================

def get_category(block_name):
    """Categorize a block by its shape/type for detailed metrics.

    Check order matters: buttons and light sources are tested before the broad
    "stone"/"wood" substring checks, so e.g. "stone_button" and "redstone_torch"
    are not counted as stone.
    """
    name = block_name.replace("minecraft:", "").split("[")[0].lower()
    if "stair" in name: return "stairs"
    if "slab" in name: return "slabs"
    if "door" in name: return "doors"          # also matches trapdoors
    if "fence" in name: return "fences"        # also matches fence gates
    if name.endswith("_wall"): return "walls"  # wall blocks only, not wall_torch/wall_sign
    if "button" in name: return "buttons"
    if "torch" in name or "lantern" in name: return "lighting"
    if "planks" in name: return "planks"
    if "log" in name or "wood" in name: return "logs"
    if "glass" in name: return "glass"
    if "wool" in name: return "wool"
    if "concrete" in name: return "concrete"
    if "stone" in name or "cobble" in name: return "stone"
    if "brick" in name: return "bricks"
    return "other"


def compute_loss(model, batch, all_embeddings, aux_weight=0.1):
    """Compute training loss with MSE for embeddings and optional classification."""
    positions = batch["positions"].to(device)
    embeddings = batch["embeddings"].to(device)
    block_ids = batch["block_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    
    outputs = model(positions, embeddings, attention_mask)
    pred_embeddings = outputs["pred_embeddings"]
    vq_loss = outputs["vq_loss"]
    
    # Primary loss: embedding MSE over real (unpadded) blocks, averaged per embedding dim
    # (the mask count is clamped to >= 1 to avoid division by zero)
    mask = attention_mask.unsqueeze(-1)
    embed_diff = (pred_embeddings - embeddings) ** 2
    mask_sum = mask.sum().clamp(min=1.0)  # Prevent division by zero
    embed_loss = (embed_diff * mask).sum() / mask_sum / pred_embeddings.size(-1)
    
    # Check for NaN and replace with zero
    if torch.isnan(embed_loss):
        print("Warning: NaN in embed_loss, replacing with 0")
        embed_loss = torch.tensor(0.0, device=device, requires_grad=True)
    if torch.isnan(vq_loss):
        print("Warning: NaN in vq_loss, replacing with 0")
        vq_loss = torch.tensor(0.0, device=device)
    
    total_loss = embed_loss + vq_loss
    
    # Auxiliary loss: Classification via nearest neighbor
    aux_loss = torch.tensor(0.0, device=device)
    accuracy = torch.tensor(0.0, device=device)
    
    if aux_weight > 0:
        pred_flat = pred_embeddings.view(-1, pred_embeddings.size(-1))
        distances = torch.cdist(pred_flat, all_embeddings)
        logits = -distances  # Negative distance as "logit"
        
        targets_flat = block_ids.view(-1)
        mask_flat = attention_mask.view(-1)
        
        ce_loss = F.cross_entropy(logits, targets_flat, reduction='none')
        mask_sum_flat = mask_flat.sum().clamp(min=1.0)
        aux_loss = (ce_loss * mask_flat.float()).sum() / mask_sum_flat
        
        # Check for NaN
        if torch.isnan(aux_loss):
            print("Warning: NaN in aux_loss, replacing with 0")
            aux_loss = torch.tensor(0.0, device=device)
        
        total_loss = total_loss + aux_weight * aux_loss
        
        # Compute accuracy
        with torch.no_grad():
            preds = distances.argmin(dim=1)
            correct = (preds == targets_flat).float()
            accuracy = (correct * mask_flat.float()).sum() / mask_sum_flat
    
    return {
        "loss": total_loss,
        "embed_loss": embed_loss,
        "vq_loss": vq_loss,
        "aux_loss": aux_loss,
        "accuracy": accuracy,
    }

print("✓ Helper functions defined")
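
`compute_loss` combines two masked terms: an embedding MSE averaged over real blocks and embedding dimensions, and an auxiliary cross-entropy whose logits are the negative Euclidean distances from each predicted embedding to every vocabulary embedding, so the nearest block gets the highest logit. The NumPy sketch below reproduces both terms on a toy batch to show that padded slots contribute nothing. The helper names are hypothetical, the VQ term is omitted, and the notebook itself uses `torch.cdist` and `F.cross_entropy`.

```python
import numpy as np


def masked_embed_mse(pred, target, mask):
    """Squared error averaged over real blocks and embedding dims (padding ignored)."""
    m = mask[..., None].astype(pred.dtype)            # (B, N, 1): 1.0 for real blocks
    denom = max(int(mask.sum()), 1) * pred.shape[-1]  # (#real blocks) x embedding dim
    return float(((pred - target) ** 2 * m).sum() / denom)


def nn_cross_entropy(pred, vocab, targets, mask):
    """Masked cross-entropy whose logits are negative Euclidean distances to vocab rows."""
    flat = pred.reshape(-1, pred.shape[-1])                               # (B*N, D)
    dists = np.linalg.norm(flat[:, None, :] - vocab[None, :, :], axis=-1)  # (B*N, V)
    logits = -dists
    logits = logits - logits.max(axis=1, keepdims=True)                   # stable log-softmax
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    nll = -log_probs[np.arange(len(flat)), targets.reshape(-1)]
    m = mask.reshape(-1).astype(float)
    return float((nll * m).sum() / max(m.sum(), 1.0))


vocab = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])  # 3 toy block embeddings
targets = np.array([[1, 2, 0]])                          # last slot is padding
mask = np.array([[True, True, False]])

exact = vocab[targets]                  # (1, 3, 2): predictions equal the targets
noisy_pad = exact.copy()
noisy_pad[0, 2] += 100.0                # corrupt only the padded slot
swapped = vocab[np.array([[2, 1, 0]])]  # real slots predict the wrong blocks

print(masked_embed_mse(noisy_pad, exact, mask))  # 0.0: the padded slot is ignored
print(nn_cross_entropy(exact, vocab, targets, mask)
      < nn_cross_entropy(swapped, vocab, targets, mask))  # True
```

Using negative distance as the logit ties the auxiliary loss to the same nearest-neighbour rule used for accuracy, with an implicit temperature of 1.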


In [None]:
# ============================================================
# 7.2 Training and Validation Functions
# ============================================================

def train_epoch(model, loader, optimizer, all_embeddings):
    """Train for one epoch."""
    model.train()
    metrics = {"loss": 0, "embed_loss": 0, "vq_loss": 0, "accuracy": 0}
    n = 0
    nan_batches = 0
    
    for batch in tqdm(loader, desc="Train", leave=False):
        optimizer.zero_grad()
        losses = compute_loss(model, batch, all_embeddings, AUX_WEIGHT)
        
        # Skip batch if loss is NaN or Inf
        if torch.isnan(losses["loss"]) or torch.isinf(losses["loss"]):
            nan_batches += 1
            continue
        
        losses["loss"].backward()
        
        # Check for NaN/Inf gradients
        has_nan_grad = False
        for param in model.parameters():
            if param.grad is not None and (torch.isnan(param.grad).any() or torch.isinf(param.grad).any()):
                has_nan_grad = True
                break
        
        if has_nan_grad:
            optimizer.zero_grad()  # Clear bad gradients
            nan_batches += 1
            continue
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        for k in metrics:
            metrics[k] += losses[k].item()
        n += 1
    
    if nan_batches > 0:
        print(f"  ⚠️  Skipped {nan_batches} batches with NaN/Inf loss or gradients")
    
    if n == 0:
        print("  ❌ Every batch had a NaN/Inf loss or gradient. Check your data/model.")
        return {"loss": float('nan'), "embed_loss": float('nan'), "vq_loss": float('nan'), "accuracy": 0}
    
    return {k: v / n for k, v in metrics.items()}


@torch.no_grad()
def validate(model, loader, all_embeddings, detailed=False):
    """Validate and compute detailed metrics."""
    model.eval()
    metrics = {"loss": 0, "embed_loss": 0, "accuracy": 0}
    total_blocks = 0
    total_correct = 0
    n = 0
    
    # Per-category tracking
    category_correct = defaultdict(int)
    category_total = defaultdict(int)
    confusion = defaultdict(lambda: defaultdict(int))  # true_cat -> pred_cat -> count
    
    for batch in tqdm(loader, desc="Val", leave=False):
        positions = batch["positions"].to(device)
        embeddings = batch["embeddings"].to(device)
        block_ids = batch["block_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        num_blocks = batch["num_blocks"]
        
        outputs = model(positions, embeddings, attention_mask)
        pred_embeddings = outputs["pred_embeddings"]
        
        # Embedding loss (mask count clamped to >= 1 to avoid division by zero)
        mask = attention_mask.unsqueeze(-1)
        embed_diff = (pred_embeddings - embeddings) ** 2
        mask_sum = mask.sum().clamp(min=1.0)
        embed_loss = (embed_diff * mask).sum() / mask_sum / pred_embeddings.size(-1)
        
        # Skip if NaN
        if torch.isnan(embed_loss):
            continue
            
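        metrics["embed_loss"] += embed_loss.item()
        # Validation "loss" is the embedding MSE only (no VQ or auxiliary terms),
        # so it is not on the same scale as the training total loss.
        metrics["loss"] += embed_loss.item()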
        
        # Nearest neighbor prediction
        B, N, D = pred_embeddings.shape
        pred_flat = pred_embeddings.view(-1, D)
        distances = torch.cdist(pred_flat, all_embeddings)
        pred_ids = distances.argmin(dim=1).view(B, N)
        
        # Accuracy
        correct = (pred_ids == block_ids) & attention_mask
        total_correct += correct.sum().item()
        total_blocks += attention_mask.sum().item()
        
        # Per-category metrics (only the first 200 blocks of each structure, for speed;
        # category totals therefore undercount large builds)
        for b in range(B):
            for i in range(min(num_blocks[b].item(), 200)):
                true_id = block_ids[b, i].item()
                pred_id = pred_ids[b, i].item()
                if true_id in tok2block:
                    true_cat = get_category(tok2block[true_id])
                    category_total[true_cat] += 1
                    if pred_id == true_id:
                        category_correct[true_cat] += 1
                    if detailed and pred_id in tok2block:
                        pred_cat = get_category(tok2block[pred_id])
                        confusion[true_cat][pred_cat] += 1
        
        n += 1
    
    metrics["accuracy"] = total_correct / max(total_blocks, 1)
    metrics["loss"] /= max(n, 1)
    metrics["embed_loss"] /= max(n, 1)
    
    # Category accuracy
    cat_acc = {}
    for cat in category_total:
        if category_total[cat] > 0:
            cat_acc[cat] = category_correct[cat] / category_total[cat]
    metrics["category_accuracy"] = cat_acc
    metrics["category_total"] = dict(category_total)
    
    if detailed:
        metrics["confusion"] = {k: dict(v) for k, v in confusion.items()}
    
    return metrics

print("✓ Training and validation functions defined")
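
Training and validation accuracy are aggregated differently. `train_epoch` averages per-batch accuracies, so every batch counts equally, while `validate` divides total correct blocks by total real blocks, so every block counts equally. Training accuracy is also measured in `train()` mode (dropout active) before the optimizer step. The two curves can therefore diverge even for the same model. A minimal illustration with made-up counts:

```python
def per_batch_mean(batches):
    """Mean of per-batch accuracies (train_epoch style): every batch weighs the same."""
    return sum(correct / total for correct, total in batches) / len(batches)


def pooled_accuracy(batches):
    """Total correct / total real blocks (validate style): every block weighs the same."""
    return sum(c for c, _ in batches) / sum(t for _, t in batches)


# Made-up (correct, real_blocks) counts: one batch of large builds, one of small builds
batches = [(900, 1000), (10, 100)]
print(round(per_batch_mean(batches), 3))   # 0.5
print(round(pooled_accuracy(batches), 3))  # 0.827
```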


---
# 8. Training - Run the Training Loop


In [None]:
# ============================================================
# 8.1 Training Loop
# ============================================================

optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=LEARNING_RATE/10)

history = {
    "train_loss": [], "train_embed_loss": [], "train_vq_loss": [], "train_accuracy": [],
    "val_loss": [], "val_accuracy": [], "learning_rate": [],
}
best_val_acc = 0
category_history = []  # Track category accuracy over epochs

print("=" * 60)
print("🚀 SPARSE STRUCTURE TRANSFORMER TRAINING")
print("=" * 60)

start_time = time.time()

for epoch in range(EPOCHS):
    epoch_start = time.time()
    
    # Train
    train_metrics = train_epoch(model, train_loader, optimizer, all_embeddings_tensor)
    
    # Validate
    val_metrics = validate(model, val_loader, all_embeddings_tensor, detailed=(epoch == EPOCHS - 1))
    
    current_lr = scheduler.get_last_lr()[0]  # LR used during this epoch (read before stepping)
    scheduler.step()
    
    # Record history
    history["train_loss"].append(train_metrics["loss"])
    history["train_embed_loss"].append(train_metrics["embed_loss"])
    history["train_vq_loss"].append(train_metrics["vq_loss"])
    history["train_accuracy"].append(train_metrics["accuracy"])
    history["val_loss"].append(val_metrics["loss"])
    history["val_accuracy"].append(val_metrics["accuracy"])
    history["learning_rate"].append(current_lr)
    category_history.append(val_metrics["category_accuracy"])
    
    epoch_time = time.time() - epoch_start
    
    # Print progress
    print(f"\n📊 Epoch {epoch+1}/{EPOCHS} ({epoch_time:.1f}s)")
    print(f"   Train: loss={train_metrics['loss']:.4f}, acc={train_metrics['accuracy']:.2%}")
    print(f"   Val:   loss={val_metrics['loss']:.4f}, acc={val_metrics['accuracy']:.2%}")
    
    # Category accuracy preview
    if val_metrics["category_accuracy"]:
        cats = sorted(val_metrics["category_accuracy"].items(), key=lambda x: -x[1])[:5]
        cat_str = " | ".join([f"{c}: {a:.0%}" for c, a in cats])
        print(f"   Categories: {cat_str}")
    
    # Save best model
    if val_metrics["accuracy"] > best_val_acc:
        best_val_acc = val_metrics["accuracy"]
        torch.save(model.state_dict(), f"{OUTPUT_DIR}/sparse_transformer_best.pt")
        print(f"   ✓ New best model saved!")

total_time = time.time() - start_time
print("\n" + "=" * 60)
print(f"✅ Training complete in {total_time/60:.1f} minutes")
print(f"🏆 Best validation accuracy: {best_val_acc:.2%}")
print("=" * 60)
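
`CosineAnnealingLR` is stepped once per epoch here, with `T_max=EPOCHS` and `eta_min=LEARNING_RATE/10`. The PyTorch documentation gives its closed form as `eta_min + (eta_max - eta_min) * (1 + cos(pi * T_cur / T_max)) / 2`. The sketch below evaluates that formula per epoch with illustrative values (not the notebook's config): the first epoch trains at the base rate, and `eta_min` is only reached after the final `scheduler.step()`, so the last epoch still trains slightly above it.

```python
import math


def cosine_lr(epoch, base_lr, t_max, eta_min):
    """LR during a given 0-indexed epoch under cosine annealing stepped once per epoch:
    eta_min + (base_lr - eta_min) * (1 + cos(pi * epoch / t_max)) / 2
    """
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


base_lr, epochs = 3e-4, 10  # illustrative values only
eta_min = base_lr / 10      # mirrors eta_min=LEARNING_RATE/10 in the notebook
schedule = [cosine_lr(e, base_lr, epochs, eta_min) for e in range(epochs)]

print(f"first epoch: {schedule[0]:.2e}, last epoch: {schedule[-1]:.2e}")
print(f"after the final step: {cosine_lr(epochs, base_lr, epochs, eta_min):.2e}")
```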


---
# 9. Visualizations - Training Curves, Category Accuracy, Analysis


In [None]:
# ============================================================
# 9.1 Training Curves - Loss and Accuracy
# ============================================================

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Loss curves
ax = axes[0, 0]
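ax.plot(history["train_loss"], label="Train (total)", linewidth=2, color='#2196F3')
ax.plot(history["train_embed_loss"], label="Train (embed MSE)", linewidth=2, linestyle='--', color='#2196F3')
ax.plot(history["val_loss"], label="Val (embed MSE)", linewidth=2, color='#FF5722')
ax.set_title("📉 Loss", fontsize=14, fontweight='bold')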
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.legend()
ax.grid(True, alpha=0.3)

# Accuracy curves
ax = axes[0, 1]
ax.plot([a * 100 for a in history["train_accuracy"]], label="Train", linewidth=2, color='#2196F3')
ax.plot([a * 100 for a in history["val_accuracy"]], label="Val", linewidth=2, color='#FF5722')
ax.axhline(y=49, color='gray', linestyle='--', alpha=0.5, label='VQ-VAE Baseline')
ax.set_title("📈 Block Accuracy (%)", fontsize=14, fontweight='bold')
ax.set_xlabel("Epoch")
ax.set_ylabel("Accuracy (%)")
ax.legend()
ax.grid(True, alpha=0.3)

# VQ Loss (if using)
ax = axes[1, 0]
if VQ_NUM_EMBEDDINGS > 0:
    ax.plot(history["train_vq_loss"], linewidth=2, color='#9C27B0')
    ax.set_title("🔮 VQ Commitment Loss", fontsize=14, fontweight='bold')
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.grid(True, alpha=0.3)
else:
    ax.text(0.5, 0.5, "VQ Disabled", ha='center', va='center', fontsize=14, color='gray')
    ax.set_title("🔮 VQ Loss", fontsize=14, fontweight='bold')
    ax.axis('off')

# Learning rate
ax = axes[1, 1]
ax.plot(history["learning_rate"], linewidth=2, color='#4CAF50')
ax.set_title("📚 Learning Rate Schedule", fontsize=14, fontweight='bold')
ax.set_xlabel("Epoch")
ax.set_ylabel("Learning Rate")
ax.set_yscale('log')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/training_curves.png", dpi=150, bbox_inches='tight')
plt.show()


In [None]:
# ============================================================
# 9.2 Per-Category Accuracy Bar Chart
# ============================================================

# Load best model for final evaluation
model.load_state_dict(torch.load(f"{OUTPUT_DIR}/sparse_transformer_best.pt"))
final_metrics = validate(model, val_loader, all_embeddings_tensor, detailed=True)

# Prepare data for bar chart
categories = sorted(final_metrics["category_accuracy"].keys())
accuracies = [final_metrics["category_accuracy"][c] * 100 for c in categories]
totals = [final_metrics["category_total"].get(c, 0) for c in categories]

# Sort by accuracy
sorted_data = sorted(zip(categories, accuracies, totals), key=lambda x: -x[1])
categories, accuracies, totals = zip(*sorted_data)

# Color code by performance
colors = []
for acc in accuracies:
    if acc >= 70: colors.append('#4CAF50')  # Green - good
    elif acc >= 50: colors.append('#FF9800')  # Orange - decent
    elif acc >= 30: colors.append('#FF5722')  # Red-orange - poor
    else: colors.append('#F44336')  # Red - bad

fig, ax = plt.subplots(figsize=(14, 8))
bars = ax.barh(range(len(categories)), accuracies, color=colors, edgecolor='white', linewidth=0.5)

# Add labels
for i, (cat, acc, total) in enumerate(zip(categories, accuracies, totals)):
    ax.text(acc + 1, i, f"{acc:.1f}% ({total:,})", va='center', fontsize=10)

ax.set_yticks(range(len(categories)))
ax.set_yticklabels(categories)
ax.set_xlabel("Accuracy (%)", fontsize=12)
ax.set_title("📊 Per-Category Reconstruction Accuracy", fontsize=14, fontweight='bold')
ax.axvline(x=49, color='gray', linestyle='--', alpha=0.7, label='VQ-VAE Baseline (49%)')
ax.axvline(x=final_metrics["accuracy"]*100, color='blue', linestyle='-', alpha=0.7, label=f'Overall ({final_metrics["accuracy"]*100:.1f}%)')
ax.legend(loc='lower right')
ax.set_xlim(0, 105)
ax.grid(True, axis='x', alpha=0.3)

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/category_accuracy.png", dpi=150, bbox_inches='tight')
plt.show()

print(f"\n🏆 Overall Accuracy: {final_metrics['accuracy']:.2%}")


In [None]:
# ============================================================
# 9.3 VQ-VAE vs Sparse Transformer Comparison
# ============================================================

# Critical categories that VQ-VAE failed on
critical_cats = ['stairs', 'slabs', 'doors', 'fences', 'walls', 'buttons', 'logs', 'planks']
# Approximate VQ-VAE per-category accuracies (%), hard-coded for comparison;
# these values are not computed in this notebook
vqvae_baseline = {'stairs': 0, 'slabs': 0, 'doors': 0, 'fences': 0, 'walls': 0,
                  'buttons': 0, 'logs': 10, 'planks': 15}

fig, ax = plt.subplots(figsize=(12, 6))

x = np.arange(len(critical_cats))
width = 0.35

# VQ-VAE baseline
vqvae_acc = [vqvae_baseline.get(c, 0) for c in critical_cats]
bars1 = ax.bar(x - width/2, vqvae_acc, width, label='VQ-VAE (Baseline)', color='#B0BEC5', edgecolor='white')

# Sparse Transformer
sparse_acc = [final_metrics["category_accuracy"].get(c, 0) * 100 for c in critical_cats]
bars2 = ax.bar(x + width/2, sparse_acc, width, label='Sparse Transformer', color='#4CAF50', edgecolor='white')

ax.set_ylabel('Accuracy (%)', fontsize=12)
ax.set_title('🆚 Critical Category Comparison: VQ-VAE vs Sparse Transformer', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(critical_cats, rotation=45, ha='right')
ax.legend()
ax.set_ylim(0, 100)
ax.grid(True, axis='y', alpha=0.3)

# Add value labels
for bar, val in zip(bars1, vqvae_acc):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1, f'{val:.0f}%', 
            ha='center', va='bottom', fontsize=9, color='gray')
for bar, val in zip(bars2, sparse_acc):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1, f'{val:.0f}%', 
            ha='center', va='bottom', fontsize=9, color='#2E7D32', fontweight='bold')

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/vqvae_comparison.png", dpi=150, bbox_inches='tight')
plt.show()

# Summary table
print("\n" + "=" * 60)
print("📋 CRITICAL CATEGORY COMPARISON")
print("=" * 60)
print(f"{'Category':<12} {'VQ-VAE':>10} {'Sparse Trans.':>14} {'Change (pp)':>12}")
print("-" * 60)
for cat in critical_cats:
    vq = vqvae_baseline.get(cat, 0)
    sp = final_metrics["category_accuracy"].get(cat, 0) * 100
    imp = sp - vq  # absolute change in percentage points (negative if worse)
    print(f"{cat:<12} {vq:>9.1f}% {sp:>13.1f}% {imp:>+12.1f}")
print("=" * 60)


---
# 10. Testing - Test on a Real Structure

This section shows how to test the model on an individual Minecraft structure. The model:
1. Takes a sparse structure (positions + block embeddings)
2. Encodes it to a latent representation
3. Decodes back to predicted block embeddings
4. Maps each predicted embedding to the vocabulary block whose embedding is nearest in Euclidean distance (`torch.cdist` followed by `argmin`)
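
The four steps can be traced on a toy grid with NumPy. This sketch stubs out the encoder and decoder (steps 2 and 3) by perturbing the true embeddings slightly, and uses a one-hot embedding table and a hypothetical `AIR_IDS` list as stand-ins for the notebook's Block2Vec embeddings and `AIR_TOKENS`. It only demonstrates the sparse extraction in step 1 and the nearest-neighbour lookup in step 4.

```python
import numpy as np

AIR_IDS = [0]  # hypothetical air token id(s); the notebook uses its own AIR_TOKENS

# Toy 4x4x4 build: two blocks of token 5 and one of token 9, everything else air
grid = np.zeros((4, 4, 4), dtype=np.int64)
grid[0, 0, 0] = 5
grid[0, 0, 1] = 5
grid[1, 0, 0] = 9

# Step 1: dense grid -> sparse (position, block) pairs
non_air = ~np.isin(grid, AIR_IDS)
positions = np.argwhere(non_air)  # (n, 3) voxel coordinates, in C order
block_ids = grid[non_air]         # (n,) token ids, same C order as positions
print(positions.tolist(), block_ids.tolist())  # [[0, 0, 0], [0, 0, 1], [1, 0, 0]] [5, 5, 9]

# Steps 2-3 stubbed: pretend the decoder returns slightly-off embeddings
vocab = np.eye(10)                   # toy embedding table: one-hot row per token id
pred_embeddings = vocab[block_ids] + 0.01

# Step 4: nearest-neighbour lookup in Euclidean distance
dists = np.linalg.norm(pred_embeddings[:, None, :] - vocab[None, :, :], axis=-1)
pred_ids = dists.argmin(axis=1)
print(pred_ids.tolist())  # [5, 5, 9]
```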


In [None]:
# ============================================================
# 10.1 Test on a Single Structure
# ============================================================

def test_single_structure(model, structure_path, embeddings, all_embeddings_tensor, tok2block):
    """
    Test the model on a single structure and visualize the results.
    
    Args:
        model: Trained SparseStructureTransformer
        structure_path: Path to .h5 structure file
        embeddings: Block2Vec embedding table (NumPy array, one row per token id)
        all_embeddings_tensor: The same embedding table as a tensor on `device`
        tok2block: Token to block name mapping
    
    Returns:
        Dictionary with reconstruction metrics
    """
    model.eval()
    
    # Load structure
    with h5py.File(structure_path, 'r') as f:
        key = list(f.keys())[0]
        structure = f[key][:].astype(np.int64)
    
    # Extract sparse representation
    non_air_mask = ~np.isin(structure, list(AIR_TOKENS))
    positions = np.argwhere(non_air_mask).astype(np.float32)
    block_ids = structure[non_air_mask]
    block_embeddings = embeddings[block_ids]
    
    print(f"📦 Structure: {structure_path}")
    print(f"   Shape: {structure.shape}")
    print(f"   Non-air blocks: {len(block_ids)}")
    print(f"   Air percentage: {(1 - len(block_ids) / structure.size) * 100:.1f}%")
    
    # Convert to tensors
    positions_t = torch.from_numpy(positions).float().unsqueeze(0).to(device)
    embeddings_t = torch.from_numpy(block_embeddings).float().unsqueeze(0).to(device)
    block_ids_t = torch.from_numpy(block_ids).long().unsqueeze(0).to(device)
    attention_mask = torch.ones(1, len(block_ids), dtype=torch.bool, device=device)
    
    # Forward pass
    with torch.no_grad():
        outputs = model(positions_t, embeddings_t, attention_mask)
        pred_embeddings = outputs["pred_embeddings"]
    
    # Predict blocks via nearest neighbor
    pred_flat = pred_embeddings.view(-1, pred_embeddings.size(-1))
    distances = torch.cdist(pred_flat, all_embeddings_tensor)
    pred_ids = distances.argmin(dim=1).cpu().numpy()
    
    # Compute metrics
    correct = (pred_ids == block_ids)
    accuracy = correct.mean()
    
    # Category breakdown
    category_correct = defaultdict(int)
    category_total = defaultdict(int)
    
    for true_id, pred_id in zip(block_ids, pred_ids):
        if true_id in tok2block:
            cat = get_category(tok2block[true_id])
            category_total[cat] += 1
            if true_id == pred_id:
                category_correct[cat] += 1
    
    print(f"\n📊 Results:")
    print(f"   Overall Accuracy: {accuracy:.2%}")
    print(f"   Correct blocks: {correct.sum()} / {len(block_ids)}")
    
    print(f"\n   Per-category:")
    for cat in sorted(category_total.keys()):
        if category_total[cat] > 0:
            acc = category_correct[cat] / category_total[cat]
            print(f"     {cat:<12}: {acc:.2%} ({category_correct[cat]}/{category_total[cat]})")
    
    # Show some examples of predictions
    print(f"\n   Sample predictions (first 10 blocks):")
    for i in range(min(10, len(block_ids))):
        true_name = tok2block.get(int(block_ids[i]), "unknown")
        pred_name = tok2block.get(int(pred_ids[i]), "unknown")
        status = "✓" if block_ids[i] == pred_ids[i] else "✗"
        true_short = true_name.replace("minecraft:", "").split("[")[0]
        pred_short = pred_name.replace("minecraft:", "").split("[")[0]
        print(f"     {status} True: {true_short:<20} → Pred: {pred_short}")
    
    return {
        "accuracy": accuracy,
        "correct": correct.sum(),
        "total": len(block_ids),
        "category_accuracy": {c: category_correct[c]/category_total[c] for c in category_total},
    }


# Test on a random validation structure
val_files = list(Path(VAL_DIR).glob("*.h5"))
if val_files:
    test_file = random.choice(val_files)
    result = test_single_structure(model, test_file, all_embeddings, all_embeddings_tensor, tok2block)
else:
    print("No validation files found to test on.")
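
Both `validate` and `test_single_structure` tally per-category accuracy with a per-block Python loop, and `validate` only looks at the first 200 blocks of each structure. A vectorized alternative is sketched below with `np.bincount`. It assumes a precomputed integer array `token_category` that maps each token id to a category index, which could be built once by running `get_category` over `tok2block`. That array and the `category_accuracy` helper are hypothetical, and token ids missing from `tok2block` would need their own sentinel category.

```python
import numpy as np


def category_accuracy(true_ids, pred_ids, token_category):
    """Per-category accuracy without a per-block Python loop.

    true_ids, pred_ids: (n,) int arrays of token ids for real blocks.
    token_category: (vocab_size,) int array mapping token id -> category index.
    """
    cats = token_category[true_ids]
    n_cats = token_category.max() + 1
    totals = np.bincount(cats, minlength=n_cats)
    hits = (true_ids == pred_ids).astype(float)
    correct = np.bincount(cats, weights=hits, minlength=n_cats)
    with np.errstate(invalid="ignore", divide="ignore"):
        acc = correct / totals  # NaN for categories with no blocks
    return acc, totals


token_category = np.array([0, 0, 1, 2])  # toy vocab: tokens 0-1 -> cat 0, 2 -> cat 1, 3 -> cat 2
true_ids = np.array([0, 1, 2, 2, 3])
pred_ids = np.array([0, 0, 2, 1, 3])
acc, totals = category_accuracy(true_ids, pred_ids, token_category)
print(totals.tolist(), acc.round(2).tolist())  # [2, 2, 1] [0.5, 0.5, 1.0]
```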


---
# 11. Save Results - Export Model and Metrics


In [None]:
# ============================================================
# 11.1 Save All Results
# ============================================================

# Prepare results dictionary
results = {
    "best_val_accuracy": best_val_acc,
    "final_val_accuracy": final_metrics["accuracy"],
    "category_accuracy": final_metrics["category_accuracy"],
    "category_total": final_metrics.get("category_total", {}),
    "training_time_minutes": total_time / 60,
    "num_epochs": EPOCHS,
    "history": history,
    "config": {
        "embed_dim": EMBED_DIM,
        "hidden_dim": HIDDEN_DIM,
        "n_encoder_layers": N_ENCODER_LAYERS,
        "n_decoder_layers": N_DECODER_LAYERS,
        "n_heads": N_HEADS,
        "num_latent_codes": NUM_LATENT_CODES,
        "vq_num_embeddings": VQ_NUM_EMBEDDINGS,
        "dropout": DROPOUT,
        "epochs": EPOCHS,
        "batch_size": BATCH_SIZE,
        "learning_rate": LEARNING_RATE,
        "weight_decay": WEIGHT_DECAY,
        "aux_weight": AUX_WEIGHT,
        "max_blocks": MAX_BLOCKS,
    },
}

# Save results JSON
with open(f"{OUTPUT_DIR}/sparse_transformer_results.json", "w") as f:
    json.dump(results, f, indent=2)

# Copy outputs to Google Drive
import shutil
drive_output = f"{DRIVE_BASE}/outputs/sparse_transformer_v1"
os.makedirs(drive_output, exist_ok=True)

for fname in ["sparse_transformer_best.pt", "sparse_transformer_results.json", 
              "training_curves.png", "category_accuracy.png", "vqvae_comparison.png"]:
    src = f"{OUTPUT_DIR}/{fname}"
    if os.path.exists(src):
        shutil.copy(src, drive_output)
        print(f"✓ Copied {fname} to Drive")

print(f"\n📁 Results saved to: {drive_output}")

# Final summary
print("\n" + "=" * 60)
print("🎉 SPARSE STRUCTURE TRANSFORMER - FINAL SUMMARY")
print("=" * 60)
print(f"📊 Best Validation Accuracy: {best_val_acc:.2%}")
print(f"⏱️  Training Time: {total_time/60:.1f} minutes")
print(f"🔢 Model Parameters: {num_params:,}")
print(f"\n📈 Improvement over VQ-VAE Baseline (49%):")
improvement = (best_val_acc - 0.49) / 0.49 * 100
# Signed formatting, so a drop below the baseline prints as negative
print(f"   {improvement:+.1f}% relative ({(best_val_acc - 0.49) * 100:+.1f} percentage points absolute)")
print("=" * 60)
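
The summary reports a *relative* improvement, `(acc - 0.49) / 0.49`, which is easy to confuse with the *absolute* difference in percentage points. A small hypothetical helper makes the distinction explicit; the accuracy passed in is an illustrative value, not a result from this notebook.

```python
def improvement_vs_baseline(acc, baseline=0.49):
    """Return (absolute change in percentage points, relative change in percent)."""
    absolute_pp = (acc - baseline) * 100
    relative_pct = (acc - baseline) / baseline * 100
    return absolute_pp, relative_pct


abs_pp, rel_pct = improvement_vs_baseline(0.735)  # illustrative accuracy only
print(f"{abs_pp:+.1f} pp absolute, {rel_pct:+.1f}% relative")  # +24.5 pp absolute, +50.0% relative
```

Reporting both numbers avoids overstating a gain: the same 24.5-point jump reads as "+50%" in relative terms.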
