# Block2Vec V2: Hybrid Skip-gram + CBOW Training

## What This Notebook Does

This notebook trains an **improved version** of Block2Vec that aims to fix the problems we discovered in V1. The goal: after training, blocks that are **semantically similar** (like all types of planks) should have similar embeddings, not just blocks that happen to appear next to each other.

## Why V2?

V1 had serious problems that we only discovered after evaluating the trained embeddings:

| Problem | V1 Result | What It Means |
|---------|-----------|---------------|
| Category coherence | 20.4% | oak_planks neighbors weren't other planks |
| Block state consistency | 0.486 | oak_stairs[facing=north] vs [facing=south] had LOW similarity |
| Analogy accuracy | 11.1% | "oak is to oak_log as spruce is to ?" failed |

The core issue: **Skip-gram only learns co-occurrence, not semantic similarity.**

---

# Part 1: Understanding V1's Failure

## What Skip-gram Actually Learned

Skip-gram learns: "Blocks that appear NEAR each other should have SIMILAR embeddings."

This sounds reasonable, but consider:

```
In an oak house:
  oak_planks is surrounded by: oak_stairs, oak_log, glass, torch
  
In a spruce house (a different build):
  spruce_planks is surrounded by: spruce_stairs, spruce_log, glass, torch
```

**The problem:** `oak_planks` and `spruce_planks` never appear NEAR each other (they're in different builds), so skip-gram doesn't learn that they're similar!

Instead, skip-gram learned:
- `oak_planks` ≈ `oak_stairs` (they co-occur in oak builds)
- `spruce_planks` ≈ `spruce_stairs` (they co-occur in spruce builds)
- `oak_planks` ≉ `spruce_planks` (they never appear together!)

## The Block State Disaster

V1 treated every block state as a separate token:

```
Token 1234: minecraft:oak_stairs[facing=north,half=bottom]
Token 1235: minecraft:oak_stairs[facing=south,half=bottom]
Token 1236: minecraft:oak_stairs[facing=east,half=bottom]
...
```

These are ALL oak stairs, but because `[facing=north]` appears on the NORTH side of buildings and `[facing=south]` appears on the SOUTH side, they have **completely different context neighborhoods**!

Result: Some variants of the SAME block had **negative similarity** to each other.

---

# Part 2: The Two Types of Similarity

## Type 1: Co-occurrence Similarity

**"Blocks that appear NEAR each other are similar."**

This is what Skip-gram directly learns:

```
Training pair: (oak_log, oak_leaves)
Result: oak_log embedding ≈ oak_leaves embedding
```

Because they appear together in trees, skip-gram pushes their embeddings closer.

**Good for:** Learning that torches go on walls, chests go near crafting tables, ores cluster in caves.

**Bad for:** Learning that oak_planks ≈ spruce_planks (they never appear together).

## Type 2: Distributional Similarity

**"Blocks with similar NEIGHBORS are similar."**

This is a different concept:

```
oak_log's typical neighbors:   {oak_leaves, air, dirt, grass}
birch_log's typical neighbors: {birch_leaves, air, dirt, grass}
```

Notice: Both logs have **similar neighbor distributions** (leaves, air, dirt, grass). Even though `oak_log` and `birch_log` never appear NEAR each other, they appear in **similar contexts**.

Distributional similarity says: If two blocks have similar neighbors, they should have similar embeddings.

**Good for:** Learning that oak_planks ≈ spruce_planks (both have similar neighbor patterns).
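Here is a tiny NumPy sketch of the difference. The neighbor counts are made up to mirror the log example above. `oak_log` never has `birch_log` as a neighbor, so their co-occurrence is zero. Their neighbor distributions are still close.

```python
import numpy as np

# Toy neighbor counts (made-up numbers, for illustration only).
# Entry j = how often vocab[j] appears next to the row's block.
vocab = ["oak_leaves", "birch_leaves", "air", "dirt", "grass", "oak_log", "birch_log"]
neighbor_counts = {
    "oak_log":   np.array([4, 0, 10, 3, 3, 0, 0], dtype=float),
    "birch_log": np.array([0, 4, 10, 3, 3, 0, 0], dtype=float),
}

def cosine(a, b):
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

# Co-occurrence: how often birch_log shows up as a neighbor of oak_log
cooccurrence = neighbor_counts["oak_log"][vocab.index("birch_log")]
# Distributional similarity: compare the two neighbor distributions
distributional = cosine(neighbor_counts["oak_log"], neighbor_counts["birch_log"])

print(f"co-occurrence count:   {cooccurrence:.0f}")    # 0
print(f"distributional cosine: {distributional:.3f}")  # 0.881
```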

## The Key Insight

We need BOTH types of similarity:
- Co-occurrence: torch ≈ wall (they go together)
- Distributional: oak_planks ≈ spruce_planks (similar contexts)

V1 only had co-occurrence. V2 adds distributional.

---

# Part 3: CBOW - Learning Distributional Similarity

## Skip-gram vs CBOW

Skip-gram and CBOW are opposites:

```
SKIP-GRAM: center → predict context
  Input:  oak_planks
  Output: Predict [oak_stairs, glass, torch, air, ...]
  
CBOW: context → predict center
  Input:  [oak_stairs, glass, torch, air, ...]
  Output: Predict oak_planks
```
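The same 6-neighborhood can feed either objective. This sketch uses a hypothetical helper, `training_examples`, on a toy 3×3×3 grid of token IDs. Skip-gram turns one block into one pair per neighbor, while CBOW turns it into a single example. The offsets match the `NEIGHBORS_6` list in the dataset cell later in this notebook.

```python
import numpy as np

# Same (dy, dx, dz) offsets as the dataset's NEIGHBORS_6
NEIGHBORS_6 = [(-1, 0, 0), (1, 0, 0), (0, -1, 0), (0, 1, 0), (0, 0, -1), (0, 0, 1)]

def training_examples(build, y, x, z):
    """Return (skip-gram pairs, CBOW example) for the block at (y, x, z)."""
    h, w, d = build.shape
    center = int(build[y, x, z])
    contexts = []
    for dy, dx, dz in NEIGHBORS_6:
        ny, nx, nz = y + dy, x + dx, z + dz
        if 0 <= ny < h and 0 <= nx < w and 0 <= nz < d:  # skip out-of-bounds
            contexts.append(int(build[ny, nx, nz]))
    # Skip-gram: one (center -> context) prediction per neighbor
    skipgram_pairs = [(center, ctx) for ctx in contexts]
    # CBOW: one (all contexts -> center) prediction
    cbow_example = (contexts, center)
    return skipgram_pairs, cbow_example

# Toy "build" whose token IDs are simply 0..26
build = np.arange(27).reshape(3, 3, 3)
pairs, cbow = training_examples(build, 1, 1, 1)
print(pairs)  # [(13, 4), (13, 22), (13, 10), (13, 16), (13, 12), (13, 14)]
print(cbow)   # ([4, 22, 10, 16, 12, 14], 13)
```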

## Why CBOW Captures Distributional Similarity

Consider two training examples:

```
Example 1 (oak house):
  Context: [oak_stairs, glass, torch, air]
  Center:  oak_planks
  
Example 2 (spruce house):
  Context: [spruce_stairs, glass, torch, air]
  Center:  spruce_planks
```

CBOW averages the context embeddings and predicts the center:

```
avg([oak_stairs, glass, torch, air]) → oak_planks
avg([spruce_stairs, glass, torch, air]) → spruce_planks
```

If `glass`, `torch`, and `air` are the same in both, and `oak_stairs` ≈ `spruce_stairs` (both are stairs), then:

```
avg(context1) ≈ avg(context2)
```

Since similar context averages predict both `oak_planks` and `spruce_planks`, CBOW learns that:

```
oak_planks ≈ spruce_planks
```

**This is exactly what we want!**

## The Math Behind CBOW

1. Take all context blocks: `[ctx1, ctx2, ctx3, ...]`
2. Look up their embeddings: `[emb1, emb2, emb3, ...]`
3. Average them: `context_avg = mean([emb1, emb2, emb3, ...])`
4. Use `context_avg` to predict the center block
5. Loss = how wrong was our prediction?

Blocks that appear in similar contexts get pulled together because they need to be predicted from similar context averages.
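Here are the five steps written out for a single example. This NumPy sketch assumes negative sampling, which the model cell later uses. `cbow_loss` and `log_sigmoid` are illustrative names, not part of the notebook. Masked-out neighbors (air or out of bounds) contribute nothing to the average.

```python
import numpy as np

def log_sigmoid(x):
    # Numerically stable log(sigmoid(x)) = -log(1 + exp(-x))
    return -np.logaddexp(0.0, -x)

def cbow_loss(in_emb, out_emb, context_ids, mask, center_id, negative_ids):
    """Negative-sampling CBOW loss for ONE example.

    in_emb:  [V, D] matrix used for the context (input) blocks
    out_emb: [V, D] matrix used for the predicted (output) block
    """
    ctx = in_emb[context_ids]                        # steps 1-2: look up [6, D]
    m = mask.astype(float)[:, None]                  # 1.0 = valid neighbor, 0.0 = ignore
    ctx_avg = (ctx * m).sum(axis=0) / max(m.sum(), 1.0)          # step 3: masked mean [D]
    pos = log_sigmoid(ctx_avg @ out_emb[center_id])              # step 4: score true center
    neg = log_sigmoid(-(out_emb[negative_ids] @ ctx_avg)).sum()  # push random blocks away
    return float(-(pos + neg))                       # step 5: lower is better

rng = np.random.default_rng(0)
V, D = 10, 4
in_emb = rng.normal(scale=0.1, size=(V, D))
out_emb = rng.normal(scale=0.1, size=(V, D))
context_ids = np.array([1, 2, 3, 0, 0, 0])
mask = np.array([True, True, True, False, False, False])  # last three: air / out of bounds
loss = cbow_loss(in_emb, out_emb, context_ids, mask, center_id=5, negative_ids=np.array([7, 8]))
print(loss)
```

Invalid neighbors are multiplied by zero. Whatever ID sits in a masked slot (the dataset uses 0) therefore has no effect on the loss.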

---

# Part 4: The Hybrid Approach

## Why Not Just Use CBOW?

CBOW has its own weakness: it loses **specific** co-occurrence information.

Skip-gram knows: "torch often appears next to wall" (specific pair)
CBOW knows: "torch appears in wall-like contexts" (averaged out)

For structure generation, we want both:
- Distributional: All planks are similar (for material substitution)
- Co-occurrence: Torches go on walls (for structural patterns)

## The Hybrid Loss

V2 uses both losses together:

```python
total_loss = alpha * skipgram_loss + beta * cbow_loss
```

Where:
- `alpha = 1.0` (skip-gram weight)
- `beta = 1.0` (CBOW weight)

Both losses update the SAME embedding matrices, so the embeddings learn from both signals.

## What the Model Learns

After hybrid training:

| Similarity | Source | Example |
|------------|--------|--------|
| oak_planks ≈ spruce_planks | CBOW | Similar contexts |
| oak_planks ≈ oak_stairs | Skip-gram | Co-occur in builds |
| diamond_ore ≈ emerald_ore | Both | Same caves + similar contexts |

The embeddings become richer, capturing multiple types of relationships.

---

# Part 5: Block State Collapsing

## The Problem

V1 vocabulary: 3,717 tokens
V2 vocabulary: 1,007 tokens

Where did 2,710 tokens go?

```
V1 had separate tokens for:
  oak_stairs[facing=north,half=bottom,shape=straight]  → Token 29
  oak_stairs[facing=east,half=bottom,shape=straight]   → Token 30
  oak_stairs[facing=south,half=bottom,shape=straight]  → Token 31
  oak_stairs[facing=west,half=bottom,shape=straight]   → Token 32
  ... (40 total variants!)

V2 collapses all to:
  oak_stairs → Token 42
```

## Why Collapse?

1. **They're the same block:** `oak_stairs[facing=north]` and `oak_stairs[facing=south]` are both oak stairs.

2. **States are orientation, not identity:** The `facing` property is about which way the stairs point, not what material they are.

3. **Different contexts:** North-facing stairs appear on north sides, south-facing on south sides. Skip-gram sees these as completely different blocks with different neighbors.

## What We Preserve

Some states ARE semantically meaningful:

| Block | State | Why Keep It |
|-------|-------|-------------|
| water | level=0 vs level=7 | Source vs flowing water |
| bed | part=head vs part=foot | Different parts of multi-block |
| door | half=upper vs half=lower | Different parts of door |
| snow | layers=1..8 | Height matters |

We keep these because they represent fundamentally different things, not just orientation.

## The Mapping

We create a mapping from original tokens to collapsed tokens:

```python
original_to_collapsed = {
    29: 42,   # oak_stairs[facing=north] → oak_stairs
    30: 42,   # oak_stairs[facing=east]  → oak_stairs
    31: 42,   # oak_stairs[facing=south] → oak_stairs
    32: 42,   # oak_stairs[facing=west]  → oak_stairs
    ...
}
```

During training, we convert all builds to use collapsed tokens.
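The notebook loads this mapping from `original_to_collapsed.json` but doesn't show how it was built. Below is one plausible way to derive such a mapping from block names, assuming the preserved properties from the table above. `KEEP_PROPERTIES`, `collapse_name`, and `build_mapping` are hypothetical. The IDs they assign will not match the real file; for example, oak_stairs will not necessarily be 42.

```python
# Hypothetical rules: which state properties survive collapsing.
# Keys are matched with str.endswith on the block id (no "minecraft:" prefix),
# so "_bed" covers red_bed, white_bed, ... and "_door" covers oak_door, ...
KEEP_PROPERTIES = {
    "water": {"level"},
    "_bed": {"part"},
    "_door": {"half"},
    "snow": {"layers"},
}

def collapse_name(name: str) -> str:
    """Strip orientation-like states, keeping only properties listed above."""
    if "[" not in name:
        return name
    base, _, props = name.partition("[")
    props = props.rstrip("]")
    block_id = base.split(":")[-1]
    keep = set()
    for suffix, allowed in KEEP_PROPERTIES.items():
        if block_id.endswith(suffix):
            keep |= allowed
    kept = [p for p in props.split(",") if p.split("=")[0] in keep]
    return f"{base}[{','.join(sorted(kept))}]" if kept else base

def build_mapping(tok2block_original):
    """Return (collapsed tok2block, original_to_collapsed) with dense new IDs."""
    name_to_new = {}
    original_to_collapsed = {}
    for tok in sorted(tok2block_original):
        new_name = collapse_name(tok2block_original[tok])
        if new_name not in name_to_new:
            name_to_new[new_name] = len(name_to_new)
        original_to_collapsed[tok] = name_to_new[new_name]
    tok2block_collapsed = {i: n for n, i in name_to_new.items()}
    return tok2block_collapsed, original_to_collapsed

vocab = {
    29: "minecraft:oak_stairs[facing=north,half=bottom,shape=straight]",
    30: "minecraft:oak_stairs[facing=east,half=bottom,shape=straight]",
    5: "minecraft:water[level=0]",
    6: "minecraft:water[level=7]",
    7: "minecraft:red_bed[facing=north,occupied=false,part=head]",
}
t2b, o2c = build_mapping(vocab)
print(t2b)
```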

---

# Part 6: New Tracking Metrics

V1 only tracked total loss. V2 tracks much more to help diagnose problems.

## Loss Components

```python
history = {
    "total_loss": [...],      # Combined loss
    "skipgram_loss": [...],   # Skip-gram component
    "cbow_loss": [...],       # CBOW component
}
```

**Why track separately?**

If skip-gram loss drops to 0.5 but CBOW stays at 2.0, CBOW is probably not learning well. We can tune `alpha` and `beta` to rebalance.

## Gradient Norms

```python
history = {
    "sg_grad_norm": [...],    # Skip-gram gradient magnitude
    "cbow_grad_norm": [...],  # CBOW gradient magnitude
}
```

**Why track gradients?**

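Gradients tell us how much each loss is actually updating the model. If CBOW gradients are 10x larger than skip-gram's, CBOW will dominate training unless the loss weights compensate. The weights scale each gradient directly, so equal weights do not guarantee equal influence.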

**What to look for:**
- Balanced gradients: SG and CBOW norms are similar
- Imbalanced: One is 10x the other → adjust weights
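Separate norms need separate backward passes, because `total_loss.backward()` only gives the combined gradient. A minimal PyTorch sketch, assuming dense embedding gradients (the `nn.Embedding` default). `component_grad_norms` is a hypothetical helper:

```python
import torch

def component_grad_norms(loss_a, loss_b, params):
    """L2 norm of each loss's gradient w.r.t. `params`, computed separately.

    retain_graph=True keeps the graph alive, so the combined loss can
    still be backpropagated afterwards.
    """
    params = [p for p in params if p.requires_grad]
    norms = []
    for loss in (loss_a, loss_b):
        grads = torch.autograd.grad(loss, params, retain_graph=True, allow_unused=True)
        squares = [g.pow(2).sum() for g in grads if g is not None]
        norms.append(torch.stack(squares).sum().sqrt().item() if squares else 0.0)
    return tuple(norms)

# Toy check: two "losses" touching different rows of one embedding table
emb = torch.nn.Embedding(10, 4)
loss_a = emb(torch.tensor([1, 2])).sum()    # gradient: ones in rows 1 and 2
loss_b = 10 * emb(torch.tensor([3])).sum()  # gradient: tens in row 3
sg_norm, cbow_norm = component_grad_norms(loss_a, loss_b, emb.parameters())
print(sg_norm, cbow_norm)  # about 2.828 (sqrt(8)) and 20.0
(loss_a + loss_b).backward()  # the graph is still usable
```

In the training loop, this could be called as `component_grad_norms(out["skipgram_loss"], out["cbow_loss"], model.parameters())` before `out["total_loss"].backward()`. These are unweighted norms. Multiply them by `ALPHA` and `BETA` to compare what actually reaches the optimizer. The extra backward passes add cost, so measuring only every few hundred steps is probably enough.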

## Category Coherence

```python
history = {
    "category_coherence": [
        {"epoch": 1, "avg": 0.25, "per_cat": {...}},
        {"epoch": 5, "avg": 0.35, "per_cat": {...}},
        ...
    ]
}
```

**Why track coherence during training?**

V1 peaked at epoch 10 and started overfitting by epoch 20, but we didn't notice because we only looked at loss. Tracking coherence during training makes this kind of turnaround visible. For example, with illustrative numbers:

- Epoch 5: coherence 25% (improving)
- Epoch 10: coherence 45% (best!)
- Epoch 15: coherence 40% (declining: overfitting)

In a run like this, we would use the epoch 10 checkpoint.

## Loss Ratio

We plot `skipgram_loss / cbow_loss` over time.

- Ratio ≈ 1.0: Balanced learning
- Ratio >> 1.0: Skip-gram not learning as fast
- Ratio << 1.0: CBOW not learning as fast

This helps diagnose if the hybrid is working or if one component is dominating.
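Computing the ratio from the `history` dict is a one-liner. The labelling thresholds below are an assumption (an arbitrary ±50% band), not values from this notebook.

```python
def loss_ratios(history):
    """skipgram_loss / cbow_loss for each recorded epoch."""
    return [sg / cb for sg, cb in zip(history["skipgram_loss"], history["cbow_loss"])]

def describe_ratio(ratio, tolerance=0.5):
    """Label one ratio. `tolerance` is an arbitrary band around 1.0, not a tuned value."""
    if ratio > 1.0 + tolerance:
        return "skip-gram learning slower"
    if ratio < 1.0 / (1.0 + tolerance):
        return "CBOW learning slower"
    return "balanced"

# Illustrative losses; the last epoch matches the 0.5 vs 2.0 example above
history = {"skipgram_loss": [2.0, 1.2, 0.5], "cbow_loss": [2.0, 1.5, 2.0]}
for epoch, r in enumerate(loss_ratios(history), start=1):
    print(f"epoch {epoch}: ratio={r:.2f} ({describe_ratio(r)})")
```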

---

# Part 7: Early Stopping and Checkpoints

## The Overfitting Problem

V1 trained for 50 epochs. Looking at the loss curve:

```
Epoch 1:  loss = 1.957
Epoch 10: loss = 1.596  ← Plateau reached
Epoch 20: loss = 1.594  ← Lowest, but only 0.002 below epoch 10
Epoch 50: loss = 1.596  ← Back to the epoch 10 value
```

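The loss stopped meaningfully improving after epoch 10, but we trained for 40 more epochs! A flat training loss is not proof of overfitting on its own. Overfitting shows up as a held-out or downstream metric getting worse. A flat loss does mean the extra epochs bought nothing, and any damage to embedding quality would be invisible in this curve.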

## V2 Solution: Early Stopping

```python
EPOCHS = 25        # Max epochs
PATIENCE = 5       # Stop if no improvement for 5 epochs
```

How it works:
1. Track the best loss seen so far
2. If current loss is better, save the model and reset patience
3. If current loss is not better, increment the patience counter
4. If patience counter reaches 5, stop training

Example:
```
Epoch 1:  loss=2.0, best=2.0, patience=0 (new best!)
Epoch 2:  loss=1.8, best=1.8, patience=0 (new best!)
Epoch 3:  loss=1.9, best=1.8, patience=1 (worse)
Epoch 4:  loss=1.7, best=1.7, patience=0 (new best!)
Epoch 5:  loss=1.75, best=1.7, patience=1
Epoch 6:  loss=1.76, best=1.7, patience=2
Epoch 7:  loss=1.77, best=1.7, patience=3
Epoch 8:  loss=1.78, best=1.7, patience=4
Epoch 9:  loss=1.79, best=1.7, patience=5 → STOP!
```

We save the epoch 4 model (best loss) and stop at epoch 9.
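The rule can be checked by replaying the trace above. `early_stopping_trace` is a hypothetical helper that only mirrors the bookkeeping. In the real loop, the "new best" branch is where the best model would be saved.

```python
def early_stopping_trace(losses, patience=5):
    """Replay the early-stopping rule on per-epoch losses.

    Returns (best_epoch, stop_epoch), both 1-indexed.
    stop_epoch is None if training would run to the end.
    """
    best_loss = float("inf")
    best_epoch = None
    bad_epochs = 0
    for epoch, loss in enumerate(losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch  # new best: save the model here
            bad_epochs = 0
        else:
            bad_epochs += 1  # no improvement this epoch
            if bad_epochs >= patience:
                return best_epoch, epoch
    return best_epoch, None

losses = [2.0, 1.8, 1.9, 1.7, 1.75, 1.76, 1.77, 1.78, 1.79]
print(early_stopping_trace(losses))  # (4, 9)
```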

## Checkpoints

We also save checkpoints every 5 epochs:

```
block2vec_v2_epoch5.pt
block2vec_v2_epoch10.pt
block2vec_v2_epoch15.pt
block2vec_v2_best.pt  ← Best by loss
```

**Why keep checkpoints?**

Sometimes the "best loss" model isn't the best for downstream tasks. If epoch 10 has the best category coherence but epoch 15 has the best loss, we might want epoch 10. Checkpoints let us try different epochs.
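If coherence is the metric we care about, the evaluation history can pick the checkpoint directly. This sketch assumes `history["category_coherence"]` entries shaped like the example in Part 6 and the filename pattern listed above. `best_checkpoint_by_coherence` is hypothetical, and the coherence values are illustrative.

```python
def best_checkpoint_by_coherence(coherence_history, checkpoint_every=5,
                                 template="block2vec_v2_epoch{}.pt"):
    """Return (epoch, filename) for the saved epoch with the best avg coherence.

    Only epochs divisible by `checkpoint_every` have a saved file, so other
    evaluations (e.g. epoch 1) are ignored.
    """
    saved = [e for e in coherence_history if e["epoch"] % checkpoint_every == 0]
    if not saved:
        raise ValueError("no evaluated epoch has a matching checkpoint")
    best = max(saved, key=lambda e: e["avg"])
    return best["epoch"], template.format(best["epoch"])

# Illustrative numbers only
history_coherence = [
    {"epoch": 1, "avg": 0.50},   # evaluated but never checkpointed
    {"epoch": 5, "avg": 0.25},
    {"epoch": 10, "avg": 0.45},
    {"epoch": 15, "avg": 0.40},
]
print(best_checkpoint_by_coherence(history_coherence))  # (10, 'block2vec_v2_epoch10.pt')
```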

---

# Part 8: What Could Still Go Wrong

Even with all these improvements, V2 might not work perfectly. Here are potential issues we're watching for:

## Issue 1: Loss Weight Imbalance

If `beta` (CBOW weight) is too high, we might lose co-occurrence learning. If `alpha` is too high, we don't fix the problem.

**How to detect:** Look at the loss ratio plot. If it's far from 1.0, adjust weights.

## Issue 2: CBOW Gradient Dominance

CBOW averages up to 6 context embeddings into 1 prediction, while skip-gram makes one prediction per valid neighbor (up to 6). The two losses can therefore produce gradients of different magnitudes.

**How to detect:** Look at gradient norm plot. If CBOW >> Skip-gram, CBOW dominates.

## Issue 3: Collapsing Too Much

We kept `water[level=0]` and `water[level=7]` as separate tokens (source vs flowing). But what if we collapsed some other state that mattered?

**How to detect:** Check if water blocks cluster correctly in evaluation.

## Issue 4: Smaller Vocab = Faster Overfitting

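With 1,007 tokens instead of 3,717, there are far fewer parameters to train (1,007 × 32 × 2 = 64,448 vs 237,888 in V1). Each token also gets more training examples, because the same data is spread over fewer tokens. The model might memorize faster.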

**How to detect:** Watch category coherence. If it peaks early then drops, we're overfitting.

## Issue 5: Semantic vs Structural Confusion

Skip-gram pulls oak_planks toward oak_stairs (structural).
CBOW pulls oak_planks toward spruce_planks (semantic).

These pulls compete with each other! The embedding might end up as a confused compromise.

**How to detect:** Check nearest neighbors. If oak_planks has both planks AND stairs as neighbors, the hybrid is working. If it has random blocks, the signals are conflicting.

---

# Part 9: Let's Start Training!

Now that you understand the concepts, let's implement it. The code below is extensively commented.

In [None]:
# ============================================================
# CELL 1: Imports and Setup
# ============================================================
# These are the libraries we need:
# - torch: PyTorch, the deep learning framework
# - numpy: For numerical operations on arrays
# - h5py: For reading HDF5 files (our training data)
# - json: For reading the vocabulary file
# - matplotlib: For visualization
# - sklearn: For t-SNE and cosine similarity
# - tqdm: For progress bars

import json
import random
import time
from pathlib import Path
from typing import Iterator, Optional

import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.manifold import TSNE
from sklearn.metrics.pairwise import cosine_similarity
from torch.utils.data import DataLoader, IterableDataset
from tqdm.notebook import tqdm

# Check if GPU is available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# ============================================================
# CELL 2: Configuration
# ============================================================
# V2 configuration with changes from V1 highlighted

# === Data Paths ===
DATA_DIR = "/kaggle/input/minecraft-schematics/minecraft_splits/splits/train"
VOCAB_PATH = "/kaggle/input/minecraft-schematics/tok2block.json"  # Original (3,717 tokens)
COLLAPSED_VOCAB_PATH = "/kaggle/input/minecraft-schematics/v2/tok2block_collapsed.json"  # NEW (1,007 tokens)
MAPPING_PATH = "/kaggle/input/minecraft-schematics/v2/original_to_collapsed.json"  # NEW
OUTPUT_DIR = "/kaggle/working"

# === Model Architecture ===
EMBEDDING_DIM = 32  # Same as V1

# === V2 Hybrid Loss Weights ===
# alpha controls skip-gram (co-occurrence)
# beta controls CBOW (distributional similarity)
ALPHA = 1.0  # Skip-gram weight
BETA = 1.0   # CBOW weight - NEW!

# === Training Hyperparameters ===
EPOCHS = 25          # V1 was 50, but peaked at 10. V2 uses 25 with early stopping
BATCH_SIZE = 4096    # Half of V1's 8192 due to CBOW overhead
LEARNING_RATE = 0.001
WEIGHT_DECAY = 0.0001
PATIENCE = 5         # NEW: Early stopping patience

# === Negative Sampling ===
NUM_NEGATIVE_SAMPLES = 15  # V1 was 5. More negatives = better discrimination

# === Subsampling ===
SUBSAMPLE_THRESHOLD = 0.0001  # V1 was 0.001. More aggressive
INCLUDE_AIR = False           # NEW: Completely exclude air blocks

# === Checkpointing ===
CHECKPOINT_EVERY = 5  # Save checkpoint every N epochs
EVAL_EVERY = 5        # Evaluate category coherence every N epochs

# === Other ===
SEED = 42

print("V2 Configuration:")
print(f"  Epochs: {EPOCHS} (with early stopping, patience={PATIENCE})")
print(f"  Loss weights: alpha={ALPHA} (skip-gram), beta={BETA} (CBOW)")
print(f"  Negative samples: {NUM_NEGATIVE_SAMPLES} (V1 was 5)")
print(f"  Include air: {INCLUDE_AIR} (V1 was True with subsampling)")

In [None]:
# ============================================================
# CELL 3: Load Vocabularies
# ============================================================
# V2 uses TWO vocabularies:
# 1. Original: 3,717 tokens (what's in the H5 files)
# 2. Collapsed: 1,007 tokens (what we train on)
# Plus a mapping between them

# Original vocabulary (for reference)
with open(VOCAB_PATH, 'r') as f:
    tok2block_original = {int(k): v for k, v in json.load(f).items()}
print(f"Original vocabulary: {len(tok2block_original)} tokens")

# Collapsed vocabulary (what we train on)
with open(COLLAPSED_VOCAB_PATH, 'r') as f:
    tok2block = {int(k): v for k, v in json.load(f).items()}
VOCAB_SIZE = len(tok2block)
print(f"Collapsed vocabulary: {VOCAB_SIZE} tokens")
print(f"Reduction: {len(tok2block_original) - VOCAB_SIZE} tokens removed ({100*(1-VOCAB_SIZE/len(tok2block_original)):.1f}%)")

# Mapping: original token ID → collapsed token ID
with open(MAPPING_PATH, 'r') as f:
    original_to_collapsed = {int(k): int(v) for k, v in json.load(f).items()}

# Find air token in collapsed vocabulary
AIR_TOKEN = None
for tok, name in tok2block.items():
    if name == "minecraft:air":
        AIR_TOKEN = tok
        break
print(f"\nAir token (collapsed): {AIR_TOKEN}")

# Show some examples of collapsing
print("\nExample collapses:")
shown = set()
for orig_tok, collapsed_tok in list(original_to_collapsed.items())[:500]:
    orig_name = tok2block_original.get(orig_tok, "?")
    collapsed_name = tok2block.get(collapsed_tok, "?")
    if "[" in orig_name and collapsed_name not in shown and len(shown) < 5:
        print(f"  {orig_name}")
        print(f"    → {collapsed_name}")
        shown.add(collapsed_name)

In [None]:
# ============================================================
# CELL 4: Category Coherence Evaluation Function
# ============================================================
# We define this early so we can use it during training.
# This measures: "For each planks block, are its nearest neighbors also planks?"

# Categories we'll evaluate
# These are blocks that SHOULD be similar to each other
EVAL_CATEGORIES = {
    "planks": ["oak_planks", "spruce_planks", "birch_planks", "jungle_planks", 
               "acacia_planks", "dark_oak_planks", "crimson_planks", "warped_planks"],
    "logs": ["oak_log", "spruce_log", "birch_log", "jungle_log", 
             "acacia_log", "dark_oak_log", "crimson_stem", "warped_stem"],
    "stone": ["stone", "cobblestone", "stone_bricks", "andesite", "diorite", "granite"],
    "ores": ["coal_ore", "iron_ore", "gold_ore", "diamond_ore", "emerald_ore"],
    "wool": ["white_wool", "red_wool", "blue_wool", "green_wool", "yellow_wool"],
}

def find_token_for_block(block_name, tok2block):
    """Find the token ID for a block name."""
    for tok, name in tok2block.items():
        if f"minecraft:{block_name}" == name:
            return tok
    return None

def evaluate_category_coherence(embeddings, tok2block, categories, k=5):
    """
    For each block in each category, check if its k nearest neighbors
    are also in the same category.
    
    Returns a dict of category → precision (0.0 to 1.0)
    """
    results = {}
    
    # Handle both torch tensors and numpy arrays
    if torch.is_tensor(embeddings):
        embeddings = embeddings.cpu().numpy()
    
    for cat_name, block_names in categories.items():
        # Get tokens for blocks in this category
        tokens = [find_token_for_block(b, tok2block) for b in block_names]
        tokens = [t for t in tokens if t is not None]
        
        if len(tokens) < 2:
            continue
        
        correct = 0
        total = 0
        
        for token in tokens:
            # Get k nearest neighbors, excluding the query block itself
            query = embeddings[token].reshape(1, -1)
            sims = cosine_similarity(query, embeddings)[0]
            sims[token] = -np.inf  # never count a block as its own neighbor
            top_indices = np.argsort(sims)[::-1][:k]
            
            # Check if neighbors are in same category
            for idx in top_indices:
                total += 1
                if idx in tokens:
                    correct += 1
        
        if total > 0:
            results[cat_name] = correct / total
    
    return results

print("Evaluation function defined.")
print(f"Will evaluate {len(EVAL_CATEGORIES)} categories: {list(EVAL_CATEGORIES.keys())}")

---

# Part 10: The V2 Model

The key change from V1 is adding CBOW loss alongside skip-gram loss.

In [None]:
# ============================================================
# CELL 5: Block2Vec V2 Model
# ============================================================

class Block2VecV2(nn.Module):
    """
    Hybrid Skip-gram + CBOW model for block embeddings.
    
    The model has TWO embedding matrices (same as V1):
    1. center_embeddings: Used when a block is the CENTER
    2. context_embeddings: Used when a block is in the CONTEXT
    
    NEW in V2: We compute BOTH losses and combine them:
    - Skip-gram: center → predict each context
    - CBOW: average of contexts → predict center
    
    Total loss = alpha * skipgram_loss + beta * cbow_loss
    """
    
    def __init__(self, vocab_size, embedding_dim=32, alpha=1.0, beta=1.0):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.alpha = alpha  # Skip-gram weight
        self.beta = beta    # CBOW weight
        
        # Two embedding matrices
        self.center_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.context_embeddings = nn.Embedding(vocab_size, embedding_dim)
        
        # Initialize with small random values
        init_range = 0.5 / embedding_dim
        self.center_embeddings.weight.data.uniform_(-init_range, init_range)
        self.context_embeddings.weight.data.uniform_(-init_range, init_range)
    
    def forward(self, center_ids, context_id, negative_ids, all_context_ids, context_mask):
        """
        Compute hybrid loss.
        
        Args:
            center_ids: Center block tokens [batch_size]
            context_id: ONE positive context [batch_size] (for skip-gram)
            negative_ids: Negative samples [batch_size, num_negatives]
            all_context_ids: ALL context tokens [batch_size, 6] (for CBOW)
            context_mask: Which contexts are valid [batch_size, 6]
        """
        
        # =========================================
        # SKIP-GRAM LOSS: center → predict context
        # =========================================
        
        # Get embeddings
        center_emb = self.center_embeddings(center_ids)      # [B, D]
        context_emb = self.context_embeddings(context_id)    # [B, D]
        
        # Positive score: dot product of center and context
        # We want this to be HIGH
        pos_score = torch.sum(center_emb * context_emb, dim=1)  # [B]
        pos_loss = F.logsigmoid(pos_score)  # log(sigmoid(x)) → 0 when x is large
        
        # Negative scores: center against random blocks
        # We want these to be LOW
        neg_emb = self.context_embeddings(negative_ids)  # [B, N, D]
        neg_score = torch.sum(center_emb.unsqueeze(1) * neg_emb, dim=2)  # [B, N]
        neg_loss = F.logsigmoid(-neg_score).sum(dim=1)  # log(sigmoid(-x)) → 0 when x is very negative
        
        # Skip-gram total (negate because we maximize log-likelihood)
        skipgram_loss = -(pos_loss + neg_loss).mean()
        
        # =========================================
        # CBOW LOSS: context average → predict center
        # =========================================
        
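        # Look up ALL neighbors in center_embeddings (the input matrix).
        # As in word2vec's CBOW, context blocks play the "input" role here,
        # and the predicted center uses context_embeddings (the output matrix) below.
        ctx_emb = self.center_embeddings(all_context_ids)  # [B, 6, D]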
        
        # Apply mask (some contexts might be invalid - out of bounds or air)
        mask_exp = context_mask.unsqueeze(-1).float()  # [B, 6, 1]
        
        # Compute masked average
        ctx_sum = (ctx_emb * mask_exp).sum(dim=1)  # [B, D]
        ctx_count = context_mask.sum(dim=1, keepdim=True).float().clamp(min=1)  # [B, 1]
        ctx_avg = ctx_sum / ctx_count  # [B, D] - this is the "context representation"
        
        # CBOW positive: context average should predict center
        # We use context_embeddings for the output (standard practice)
        center_out = self.context_embeddings(center_ids)  # [B, D]
        cbow_pos = torch.sum(ctx_avg * center_out, dim=1)  # [B]
        cbow_pos_loss = F.logsigmoid(cbow_pos)
        
        # CBOW negatives: context average should NOT predict random blocks
        cbow_neg_score = torch.sum(ctx_avg.unsqueeze(1) * neg_emb, dim=2)  # [B, N]
        cbow_neg_loss = F.logsigmoid(-cbow_neg_score).sum(dim=1)  # [B]
        
        # CBOW total
        cbow_loss = -(cbow_pos_loss + cbow_neg_loss).mean()
        
        # =========================================
        # COMBINED LOSS
        # =========================================
        total_loss = self.alpha * skipgram_loss + self.beta * cbow_loss
        
        # Return all losses for tracking
        return {
            "total_loss": total_loss,
            "skipgram_loss": skipgram_loss,
            "cbow_loss": cbow_loss,
        }
    
    def get_embeddings(self):
        """Return the learned embeddings (center embeddings)."""
        return self.center_embeddings.weight.data.clone()


# Test it
model = Block2VecV2(VOCAB_SIZE, EMBEDDING_DIM, ALPHA, BETA)
print(f"Model created!")
print(f"  Vocabulary: {VOCAB_SIZE} tokens")
print(f"  Embedding dim: {EMBEDDING_DIM}")
print(f"  Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"  = {VOCAB_SIZE} × {EMBEDDING_DIM} × 2 matrices")

---

# Part 11: The V2 Dataset

The dataset handles:
1. Loading H5 files
2. Collapsing block states using our mapping
3. Extracting (center, all_contexts) for each block
4. Sampling negatives

In [None]:
# ============================================================
# CELL 6: V2 Dataset with State Collapsing
# ============================================================

class Block2VecDatasetV2(IterableDataset):
    """
    V2 Dataset changes from V1:
    1. Collapses block states during loading
    2. Returns ALL 6 contexts (for CBOW), not just one
    3. Completely excludes air blocks
    """
    
    # The 6 neighbors: up, down, left, right, front, back
    NEIGHBORS_6 = [
        (-1, 0, 0), (1, 0, 0),   # y-axis (up/down in Minecraft)
        (0, -1, 0), (0, 1, 0),   # x-axis
        (0, 0, -1), (0, 0, 1),   # z-axis
    ]
    
    def __init__(self, data_dir, vocab_size, original_to_collapsed,
                 num_negative_samples=15, subsample_threshold=0.0001,
                 air_token=0, include_air=False, seed=42):
        self.data_dir = Path(data_dir)
        self.vocab_size = vocab_size
        self.original_to_collapsed = original_to_collapsed
        self.num_negative_samples = num_negative_samples
        self.subsample_threshold = subsample_threshold
        self.air_token = air_token
        self.include_air = include_air
        self.seed = seed
        
        # Find all H5 files
        self.h5_files = sorted(self.data_dir.glob("*.h5"))
        print(f"Found {len(self.h5_files)} training files")
        
        # Will be computed on first iteration
        self._block_freqs = None
        self._negative_table = None
        self._subsample_probs = None
    
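    def _collapse_build(self, build):
        """
        Convert a build from original tokens to collapsed tokens.
        
        Example:
          Input:  [29, 30, 31, 32]  (different oak_stairs orientations)
          Output: [42, 42, 42, 42]  (all collapsed to oak_stairs)
        
        Uses a lookup table (lut[original_id] = collapsed_id), so the whole
        build is converted in one vectorized indexing step instead of one
        boolean mask per vocabulary entry.
        """
        if getattr(self, "_lut", None) is None:
            lut_size = max(self.original_to_collapsed) + 1
            self._lut = np.zeros(lut_size, dtype=np.int64)
            for orig_id, collapsed_id in self.original_to_collapsed.items():
                self._lut[orig_id] = collapsed_id
        build = np.asarray(build, dtype=np.int64)
        # IDs missing from the mapping become 0, as in the previous version
        valid = (build >= 0) & (build < len(self._lut))
        return np.where(valid, self._lut[np.clip(build, 0, len(self._lut) - 1)], 0)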
    
    def _compute_frequencies(self):
        """Count block frequencies for subsampling and negative sampling."""
        print("Computing collapsed block frequencies...")
        freqs = np.zeros(self.vocab_size, dtype=np.float64)
        
        for h5_path in tqdm(self.h5_files, desc="Counting blocks"):
            try:
                with h5py.File(h5_path, 'r') as f:
                    build = f[list(f.keys())[0]][:]
                collapsed = self._collapse_build(build)
                unique, counts = np.unique(collapsed, return_counts=True)
                for tok, count in zip(unique, counts):
                    if tok < self.vocab_size:
                        freqs[tok] += count
            except Exception:
                continue
        
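        # Air's share of all blocks, measured before it is (optionally) removed
        air_share = freqs[self.air_token] / freqs.sum()
        
        # V2 excludes air entirely: it is never a center or a context, so it
        # should not be drawn as a negative or inflate the subsampling totals
        if not self.include_air:
            freqs[self.air_token] = 0.0
        
        # Normalize to probabilities
        freqs /= freqs.sum()
        self._block_freqs = freqs
        
        # Build negative sampling table (frequency^0.75 weighting)
        # Seeded so every worker builds the same table
        weighted = np.power(freqs + 1e-10, 0.75)
        weighted /= weighted.sum()
        self._negative_table = np.random.default_rng(self.seed).choice(
            self.vocab_size, size=100_000_000, p=weighted
        )
        
        # Compute subsampling probabilities
        self._subsample_probs = np.ones(self.vocab_size, dtype=np.float32)
        for i, freq in enumerate(freqs):
            if freq > self.subsample_threshold:
                self._subsample_probs[i] = np.sqrt(self.subsample_threshold / freq)
        
        print(f"  Air frequency (before exclusion): {air_share:.4f}")
        print(f"  Non-zero blocks: {(freqs > 0).sum()}/{self.vocab_size}")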
    
    def __iter__(self):
        """Yield training examples."""
        if self._block_freqs is None:
            self._compute_frequencies()
        
        # Handle multi-worker loading
        worker_info = torch.utils.data.get_worker_info()
        files = self.h5_files
        worker_id = 0
        if worker_info:
            per_worker = len(self.h5_files) // worker_info.num_workers
            worker_id = worker_info.id
            start = worker_id * per_worker
            end = start + per_worker if worker_id < worker_info.num_workers - 1 else len(self.h5_files)
            files = self.h5_files[start:end]
        
        rng = random.Random(self.seed + worker_id)
        
        for h5_path in files:
            try:
                with h5py.File(h5_path, 'r') as f:
                    build = f[list(f.keys())[0]][:]
            except Exception:
                continue
            
            # Collapse block states
            build = self._collapse_build(build)
            h, w, d = build.shape
            
            # Iterate through every position
            for y in range(h):
                for x in range(w):
                    for z in range(d):
                        center = int(build[y, x, z])
                        
                        # Skip air (V2 excludes air entirely)
                        if not self.include_air and center == self.air_token:
                            continue
                        
                        # Subsampling check (skip frequent blocks sometimes)
                        if rng.random() >= self._subsample_probs[center]:
                            continue
                        
                        # Get ALL 6 neighbors (for CBOW)
                        all_contexts = []
                        context_mask = []
                        
                        for dy, dx, dz in self.NEIGHBORS_6:
                            ny, nx, nz = y + dy, x + dx, z + dz
                            
                            if 0 <= ny < h and 0 <= nx < w and 0 <= nz < d:
                                ctx = int(build[ny, nx, nz])
                                if not self.include_air and ctx == self.air_token:
                                    # Air context - mark as invalid
                                    all_contexts.append(0)
                                    context_mask.append(False)
                                else:
                                    all_contexts.append(ctx)
                                    context_mask.append(True)
                            else:
                                # Out of bounds - mark as invalid
                                all_contexts.append(0)
                                context_mask.append(False)
                        
                        # Skip if no valid contexts
                        if not any(context_mask):
                            continue
                        
                        # Sample negatives
                        # Sample negatives (one set per center, shared by all of its pairs)
                        neg_indices = rng.sample(range(len(self._negative_table)), self.num_negative_samples)
                        negatives = self._negative_table[neg_indices]
                        
                        # Build the CBOW arrays once per center instead of once per yield
                        ctx_array = np.array(all_contexts, dtype=np.int64)
                        mask_array = np.array(context_mask, dtype=bool)
                        
                        # Yield one example PER VALID CONTEXT (for skip-gram);
                        # every example also carries ALL six contexts (for CBOW)
                        for ctx, valid in zip(all_contexts, context_mask):
                            if valid:
                                yield (center, ctx, negatives, ctx_array, mask_array)


def collate_fn_v2(batch):
    """Convert list of examples to tensors."""
    centers = torch.tensor([b[0] for b in batch], dtype=torch.long)
    contexts = torch.tensor([b[1] for b in batch], dtype=torch.long)
    negatives = torch.tensor(np.stack([b[2] for b in batch]), dtype=torch.long)
    all_contexts = torch.tensor(np.stack([b[3] for b in batch]), dtype=torch.long)
    masks = torch.tensor(np.stack([b[4] for b in batch]), dtype=torch.bool)
    return centers, contexts, negatives, all_contexts, masks


print("V2 Dataset class defined!")
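The per-voxel expansion inside `__iter__` is easier to check on a toy grid when it is pulled out into a small function. The NumPy sketch below is a hypothetical `expand_position` helper, not part of the notebook; it assumes a 6-neighbor offset list like `NEIGHBORS_6` (the notebook's ordering may differ) and that air is token `0` (the notebook's `AIR_TOKEN` may differ).

```python
import numpy as np

# 6-connected offsets as (dy, dx, dz); assumed to match NEIGHBORS_6 up to order
NEIGHBORS_6 = [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)]

def expand_position(build, y, x, z, air_token=0):
    """Return (skip-gram pairs, CBOW context slots, CBOW mask) for one voxel."""
    h, w, d = build.shape
    center = int(build[y, x, z])
    contexts = np.zeros(6, dtype=np.int64)  # invalid slots stay 0 (padding)
    mask = np.zeros(6, dtype=bool)          # True only for in-bounds, non-air neighbors
    for i, (dy, dx, dz) in enumerate(NEIGHBORS_6):
        ny, nx, nz = y + dy, x + dx, z + dz
        if 0 <= ny < h and 0 <= nx < w and 0 <= nz < d:
            ctx = int(build[ny, nx, nz])
            if ctx != air_token:
                contexts[i] = ctx
                mask[i] = True
    # One skip-gram pair per valid neighbor; CBOW sees all six slots at once
    pairs = [(center, int(c)) for c, m in zip(contexts, mask) if m]
    return pairs, contexts, mask

build = np.zeros((3, 3, 3), dtype=np.int64)
build[1, 1, 1] = 5   # center block
build[0, 1, 1] = 7   # neighbor below
build[1, 1, 2] = 9   # neighbor along +z
pairs, ctx, mask = expand_position(build, 1, 1, 1)
print(pairs)  # [(5, 7), (5, 9)]
```

A center with k valid neighbors contributes k skip-gram rows, and each of those rows carries the same six-slot CBOW context and mask, so in the dataset above every CBOW example is effectively repeated k times per center.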

In [None]:
# ============================================================
# CELL 7: Create DataLoader
# ============================================================

torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

train_dataset = Block2VecDatasetV2(
    data_dir=DATA_DIR,
    vocab_size=VOCAB_SIZE,
    original_to_collapsed=original_to_collapsed,
    num_negative_samples=NUM_NEGATIVE_SAMPLES,
    subsample_threshold=SUBSAMPLE_THRESHOLD,
    air_token=AIR_TOKEN,
    include_air=INCLUDE_AIR,
    seed=SEED,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn_v2,
    pin_memory=(torch.device(device).type == "cuda"),  # works whether device is a str or torch.device
)

print(f"DataLoader ready with batch size {BATCH_SIZE}")
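`SUBSAMPLE_THRESHOLD` controls how aggressively frequent blocks are skipped by the `rng.random() >= self._subsample_probs[center]` check. `_compute_frequencies` is not shown in this section, so the exact formula is an assumption here; one common choice, from the original word2vec paper, keeps a token with probability min(1, sqrt(t / f)), where f is its relative frequency and t the threshold. A minimal sketch with a hypothetical `keep_probabilities` helper:

```python
import numpy as np

def keep_probabilities(counts, threshold=1e-3):
    """Word2vec-style subsampling (assumed formula): P(keep) = min(1, sqrt(t / f))."""
    counts = np.asarray(counts, dtype=np.float64)
    freqs = counts / counts.sum()          # relative frequency f of each token
    with np.errstate(divide="ignore"):
        keep = np.sqrt(threshold / freqs)  # inf for unseen tokens (f == 0)
    return np.minimum(keep, 1.0)           # clip to a valid probability

print(keep_probabilities([900, 90, 10], threshold=0.01))  # ~[0.105, 0.333, 1.0]
```

Under this formula, with t = 0.01 a block that makes up 90% of the counted voxels is kept about 11% of the time, while any block at or below 1% frequency is always kept.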

In [None]:
# ============================================================
# CELL 8: Setup Model and Optimizer
# ============================================================

model = Block2VecV2(
    vocab_size=VOCAB_SIZE,
    embedding_dim=EMBEDDING_DIM,
    alpha=ALPHA,
    beta=BETA
).to(device)

optimizer = optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

# History tracking - V2 tracks much more than V1!
history = {
    "total_loss": [],
    "skipgram_loss": [],    # NEW: track SG separately
    "cbow_loss": [],        # NEW: track CBOW separately
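    "sg_grad_norm": [],     # NEW: grad norm of the CENTER embedding table (despite the name)
    "cbow_grad_norm": [],   # NEW: grad norm of the CONTEXT embedding table (despite the name)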
    "category_coherence": [],  # NEW: quality metric during training
    "best_epoch": 0,
}

print(f"Model on {device}")
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Optimizer: AdamW (lr={LEARNING_RATE}, weight_decay={WEIGHT_DECAY})")
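`Block2VecV2` is defined earlier in the notebook, and its exact loss is not shown in this section. As a reference for what `alpha` and `beta` weight, here is a minimal NumPy sketch of one plausible hybrid objective: skip-gram and CBOW, both with negative sampling, sharing a center table and a context table. The table pairing, the reuse of the same negatives for both terms, and the masked mean over the six neighbors are assumptions, not the notebook's confirmed design.

```python
import numpy as np

def log_sigmoid(x):
    # Numerically stable log(sigmoid(x)) = -log(1 + exp(-x))
    return -np.logaddexp(0.0, -x)

def hybrid_loss(center_emb, context_emb, center, context, negatives,
                all_ctx, ctx_mask, alpha=1.0, beta=1.0):
    """alpha * skip-gram NS loss + beta * CBOW NS loss, each averaged over the batch."""
    c = center_emb[center]                                      # (B, D) center vectors
    # Skip-gram: the center predicts one context; negatives come from the context table
    o = context_emb[context]                                    # (B, D)
    neg = np.einsum("bkd,bd->bk", context_emb[negatives], c)    # (B, K)
    sg = -(log_sigmoid(np.sum(c * o, axis=1)) + log_sigmoid(-neg).sum(axis=1))

    # CBOW: the masked mean of the six neighbors predicts the center;
    # negatives are scored against the center table
    m = ctx_mask[..., None].astype(np.float64)                  # (B, 6, 1)
    h = (context_emb[all_ctx] * m).sum(axis=1) / np.maximum(m.sum(axis=1), 1.0)  # (B, D)
    neg_c = np.einsum("bkd,bd->bk", center_emb[negatives], h)   # (B, K)
    cbow = -(log_sigmoid(np.sum(h * c, axis=1)) + log_sigmoid(-neg_c).sum(axis=1))

    sg_mean, cbow_mean = float(sg.mean()), float(cbow.mean())
    return {"skipgram_loss": sg_mean, "cbow_loss": cbow_mean,
            "total_loss": alpha * sg_mean + beta * cbow_mean}
```

For a loss of this form, all-zero embeddings give every score 0, so each term equals (K + 1)·log 2 for K negatives. That value is a handy sanity check against the first logged losses.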

---

# Part 12: Training Loop

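The training loop follows V1, with four additions:
- Separate skip-gram and CBOW loss tracking
- Gradient-norm monitoring for the center and context embedding tables
- Category-coherence evaluation after epoch 1 and every `EVAL_EVERY` epochs
- Early stopping on the average training loss, with patience `PATIENCE`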
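Gradient-norm monitoring is easiest to interpret when each group's number is a true L2 norm, the same kind of quantity `torch.nn.utils.clip_grad_norm_` compares against `max_norm` (with the default `norm_type=2`). The NumPy sketch below uses hypothetical `group_grad_norms` and `global_norm` helpers, with parameter names modeled on a model that has `center` and `context` embedding tables. It shows that squared norms add across tensors while plain norms do not.

```python
import numpy as np

def group_grad_norms(grads, group_of):
    """L2 norm per group: sqrt of the sum of squared entries of the group's tensors."""
    sq = {}
    for name, g in grads.items():
        key = group_of(name)
        sq[key] = sq.get(key, 0.0) + float(np.sum(np.square(g)))
    return {key: value ** 0.5 for key, value in sq.items()}

def global_norm(grads):
    # The 2-norm over all tensors together, as used for gradient clipping
    return float(np.sqrt(sum(np.sum(np.square(g)) for g in grads.values())))

grads = {"center_embeddings.weight": np.array([[3.0, 0.0]]),
         "context_embeddings.weight": np.array([[0.0, 4.0]])}
groups = group_grad_norms(grads, lambda n: "center" if "center" in n else "context")
print(groups, global_norm(grads))  # {'center': 3.0, 'context': 4.0} 5.0
```

Here the two groups have norms 3 and 4 and the global norm is 5. Adding the per-group norms would report 7, which overstates what clipping actually sees.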

In [None]:
# ============================================================
# CELL 9: Training Loop with Enhanced Tracking
# ============================================================

print("="*60)
print("Starting V2 Training (Hybrid Skip-gram + CBOW)")
print(f"Epochs: {EPOCHS}, Early stopping patience: {PATIENCE}")
print("="*60)

start_time = time.time()
best_loss = float('inf')
patience_counter = 0

for epoch in range(EPOCHS):
    model.train()
    epoch_stats = {"total": 0, "sg": 0, "cbow": 0, "sg_grad": 0, "cbow_grad": 0}
    num_batches = 0
    
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    
    for batch in pbar:
        # Unpack and move to device
        center_ids, context_ids, neg_ids, all_ctx, ctx_mask = [b.to(device) for b in batch]
        
        # Forward pass
        losses = model(center_ids, context_ids, neg_ids, all_ctx, ctx_mask)
        
        # Backward pass
        optimizer.zero_grad()
        losses["total_loss"].backward()
        
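        # Gradient norms BEFORE clipping, grouped by embedding table.
        # Both tables receive gradient from BOTH skip-gram and CBOW, so these
        # numbers show center-vs-context imbalance, not SG-vs-CBOW contribution.
        # The variable names sg_grad/cbow_grad are kept to match the history keys.
        center_sq = 0.0
        other_sq = 0.0
        for name, param in model.named_parameters():
            if param.grad is not None:
                sq = param.grad.detach().pow(2).sum().item()
                if 'center' in name:
                    center_sq += sq   # center embedding table
                else:
                    other_sq += sq    # context embedding table (and any other params)
        # Combine squared norms, then take the root: a true per-group L2 norm
        sg_grad = center_sq ** 0.5
        cbow_grad = other_sq ** 0.5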
        
        # Gradient clipping (prevents exploding gradients)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        # Update weights
        optimizer.step()
        
        # Track statistics
        epoch_stats["total"] += losses["total_loss"].item()
        epoch_stats["sg"] += losses["skipgram_loss"].item()
        epoch_stats["cbow"] += losses["cbow_loss"].item()
        epoch_stats["sg_grad"] += sg_grad
        epoch_stats["cbow_grad"] += cbow_grad
        num_batches += 1
        
        pbar.set_postfix({
            "loss": f"{losses['total_loss'].item():.3f}",
            "sg": f"{losses['skipgram_loss'].item():.3f}",
            "cbow": f"{losses['cbow_loss'].item():.3f}",
        })
    
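    # Compute epoch averages (fail loudly instead of dividing by zero on an empty epoch)
    if num_batches == 0:
        raise RuntimeError("DataLoader produced no batches; check DATA_DIR and the HDF5 files")
    avg_total = epoch_stats["total"] / num_batches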
    avg_sg = epoch_stats["sg"] / num_batches
    avg_cbow = epoch_stats["cbow"] / num_batches
    avg_sg_grad = epoch_stats["sg_grad"] / num_batches
    avg_cbow_grad = epoch_stats["cbow_grad"] / num_batches
    
    # Store in history
    history["total_loss"].append(avg_total)
    history["skipgram_loss"].append(avg_sg)
    history["cbow_loss"].append(avg_cbow)
    history["sg_grad_norm"].append(avg_sg_grad)
    history["cbow_grad_norm"].append(avg_cbow_grad)
    
    elapsed = time.time() - start_time
    print(f"Epoch {epoch+1}: loss={avg_total:.4f}, sg={avg_sg:.4f}, cbow={avg_cbow:.4f}, "
          f"sg_grad={avg_sg_grad:.2f}, cbow_grad={avg_cbow_grad:.2f}, time={elapsed:.0f}s")
    
    # Evaluate category coherence periodically
    if (epoch + 1) % EVAL_EVERY == 0 or epoch == 0:
        model.eval()
        with torch.no_grad():
            emb = model.get_embeddings().cpu().numpy()
            coherence = evaluate_category_coherence(emb, tok2block, EVAL_CATEGORIES)
            avg_coh = np.mean(list(coherence.values()))
            history["category_coherence"].append({
                "epoch": epoch + 1,
                "avg": avg_coh,
                "per_cat": coherence
            })
            print(f"  Category coherence: {avg_coh*100:.1f}% (V1 was 20.4%)")
        model.train()
    
    # Save checkpoint
    if (epoch + 1) % CHECKPOINT_EVERY == 0:
        torch.save(model.state_dict(), f"{OUTPUT_DIR}/block2vec_v2_epoch{epoch+1}.pt")
        print(f"  Checkpoint saved: epoch {epoch+1}")
    
    # Early stopping on the average TRAINING loss (there is no validation split)
    if avg_total < best_loss:
        best_loss = avg_total
        history["best_epoch"] = epoch + 1
        torch.save(model.state_dict(), f"{OUTPUT_DIR}/block2vec_v2_best.pt")
        patience_counter = 0
        print("  -> New best model!")
    else:
        patience_counter += 1
        if patience_counter >= PATIENCE:
            print(f"\nEarly stopping at epoch {epoch+1} (no improvement for {PATIENCE} epochs)")
            break

print("\n" + "="*60)
print(f"Training complete in {time.time() - start_time:.0f}s")
print(f"Best epoch: {history['best_epoch']}")
print("="*60)
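The stopping rule in the loop above keys on the average training loss, not on a held-out metric or on category coherence. The pure-Python sketch below is a hypothetical `early_stopping_trace` helper that replays the same rule on a list of per-epoch losses, so you can predict from `history["total_loss"]` when a given `PATIENCE` would have stopped.

```python
def early_stopping_trace(losses, patience):
    """Replay the cell's early-stopping rule on per-epoch losses.

    Returns (best_epoch, stop_epoch), both 1-based; stop_epoch is None
    if training would run through every epoch.
    """
    best_loss, best_epoch, bad_epochs = float("inf"), 0, 0
    for epoch, loss in enumerate(losses, start=1):
        if loss < best_loss:             # strict improvement resets the counter
            best_loss, best_epoch, bad_epochs = loss, epoch, 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:   # stop after `patience` non-improving epochs
                return best_epoch, epoch
    return best_epoch, None

print(early_stopping_trace([1.0, 0.8, 0.85, 0.82, 0.9], patience=3))  # (2, 5)
```

Because the checkpoint is chosen by training loss, `block2vec_v2_best.pt` is not necessarily the epoch with the highest category coherence; compare against `history["category_coherence"]` if coherence is the metric you care about.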

In [None]:
# ============================================================
# CELL 10: Save Results
# ============================================================

# Load best model
model.load_state_dict(torch.load(f"{OUTPUT_DIR}/block2vec_v2_best.pt", map_location=device))
model.eval()

# Save embeddings
embeddings = model.get_embeddings().cpu().numpy()
np.save(f"{OUTPUT_DIR}/block_embeddings_v2.npy", embeddings)
print(f"Embeddings saved: {embeddings.shape}")

# Save training history
with open(f"{OUTPUT_DIR}/training_history_v2.json", 'w') as f:
    json.dump(history, f, indent=2, default=float)  # default=float handles NumPy scalars such as np.float32
print("Training history saved")

# Save vocabulary
with open(f"{OUTPUT_DIR}/tok2block_collapsed.json", 'w') as f:
    json.dump({str(k): v for k, v in tok2block.items()}, f, indent=2)
print("Vocabulary saved")
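Downstream code (for example the VQ-VAE) will reload these files. JSON object keys are always strings, which is why the vocabulary is written with `str(k)`; a loader has to convert them back to integers before indexing embeddings. A minimal sketch of the round trip, using hypothetical `save_artifacts`/`load_artifacts` helpers that mirror the file names in this cell:

```python
import json
import os
import tempfile

import numpy as np

def save_artifacts(out_dir, embeddings, tok2block):
    np.save(os.path.join(out_dir, "block_embeddings_v2.npy"), embeddings)
    with open(os.path.join(out_dir, "tok2block_collapsed.json"), "w") as f:
        json.dump({str(k): v for k, v in tok2block.items()}, f)

def load_artifacts(out_dir):
    emb = np.load(os.path.join(out_dir, "block_embeddings_v2.npy"))
    with open(os.path.join(out_dir, "tok2block_collapsed.json")) as f:
        # JSON object keys are always strings; convert back to int token ids
        tok2block = {int(k): v for k, v in json.load(f).items()}
    return emb, tok2block

with tempfile.TemporaryDirectory() as d:
    save_artifacts(d, np.eye(3, dtype=np.float32), {0: "minecraft:air", 2: "minecraft:stone"})
    emb, vocab = load_artifacts(d)
    print(emb.shape, vocab[2])  # (3, 3) minecraft:stone
```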

In [None]:
# ============================================================
# CELL 11: Plot Training Analysis
# ============================================================

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Plot 1: All losses (x-axis uses 1-based epochs, like history["best_epoch"])
loss_epochs = range(1, len(history["total_loss"]) + 1)
axes[0,0].plot(loss_epochs, history["total_loss"], 'b-', label='Total', linewidth=2)
axes[0,0].plot(loss_epochs, history["skipgram_loss"], 'g--', label='Skip-gram', linewidth=1.5)
axes[0,0].plot(loss_epochs, history["cbow_loss"], 'r--', label='CBOW', linewidth=1.5)
axes[0,0].axvline(history["best_epoch"], color='k', linestyle=':',
                  label=f'Best (epoch {history["best_epoch"]})')
axes[0,0].set_xlabel('Epoch')
axes[0,0].set_ylabel('Loss')
axes[0,0].set_title('Training Losses')
axes[0,0].legend()
axes[0,0].grid(True, alpha=0.3)

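# Plot 2: Gradient norms (history keys say sg/cbow, but they hold center/context table norms)
grad_epochs = range(1, len(history["sg_grad_norm"]) + 1)
axes[0,1].plot(grad_epochs, history["sg_grad_norm"], 'g-', label='Center embeddings', linewidth=2)
axes[0,1].plot(grad_epochs, history["cbow_grad_norm"], 'r-', label='Context embeddings', linewidth=2)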
axes[0,1].set_xlabel('Epoch')
axes[0,1].set_ylabel('Gradient Norm')
axes[0,1].set_title('Gradient Norms (detect imbalance)')
axes[0,1].legend()
axes[0,1].grid(True, alpha=0.3)

# Plot 3: Category coherence
if history["category_coherence"]:
    epochs = [c["epoch"] for c in history["category_coherence"]]
    coherences = [c["avg"] * 100 for c in history["category_coherence"]]
    axes[1,0].plot(epochs, coherences, 'b-o', linewidth=2, markersize=8)
    axes[1,0].axhline(20.4, color='r', linestyle='--', label='V1 baseline (20.4%)')
    axes[1,0].set_xlabel('Epoch')
    axes[1,0].set_ylabel('Category Coherence (%)')
    axes[1,0].set_title('Category Coherence Over Training')
    axes[1,0].legend()
    axes[1,0].grid(True, alpha=0.3)

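# Plot 4: Loss ratio (1-based epochs, consistent with Plot 3)
loss_ratio = [sg / cbow if cbow > 0 else 1.0
              for sg, cbow in zip(history["skipgram_loss"], history["cbow_loss"])]
ratio_epochs = range(1, len(loss_ratio) + 1)
axes[1,1].plot(ratio_epochs, loss_ratio, color='purple', linewidth=2)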
axes[1,1].axhline(1.0, color='k', linestyle='--', label='Balanced (ratio=1)')
axes[1,1].set_xlabel('Epoch')
axes[1,1].set_ylabel('Skip-gram / CBOW Loss')
axes[1,1].set_title('Loss Ratio (detect if one dominates)')
axes[1,1].legend()
axes[1,1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/training_analysis_v2.png", dpi=150)
plt.show()

In [None]:
# ============================================================
# CELL 12: t-SNE Visualization
# ============================================================

import re

def get_block_category(block_name):
    """Categorize a block for coloring in t-SNE."""
    name = block_name.replace("minecraft:", "")
    name = re.sub(r"\[.*\]", "", name)
    
    if any(x in name for x in ["planks", "log", "wood", "fence", "door"]):
        return "wood"
    elif any(x in name for x in ["stone", "cobble", "brick", "andesite", "diorite", "granite"]):
        return "stone"
    elif "ore" in name or any(x in name for x in ["diamond", "gold", "iron", "coal", "emerald"]):
        return "ore/mineral"
    elif "glass" in name:
        return "glass"
    elif "wool" in name or "carpet" in name:
        return "wool"
    elif "concrete" in name:
        return "concrete"
    elif "leaves" in name:
        return "leaves"
    else:
        return "other"

# Sample for visualization
sample_size = min(500, VOCAB_SIZE)
indices = np.random.choice(VOCAB_SIZE, sample_size, replace=False)
sampled_embeddings = embeddings[indices]

categories = [get_block_category(tok2block.get(i, "unknown")) for i in indices]
unique_cats = sorted(set(categories))  # sorted so each category keeps the same color across runs

print(f"Running t-SNE on {len(indices)} blocks...")
tsne = TSNE(n_components=2, random_state=42, perplexity=30)
coords = tsne.fit_transform(sampled_embeddings)

plt.figure(figsize=(14, 10))
cmap = plt.get_cmap('tab10', len(unique_cats))  # plt.cm.get_cmap was removed in Matplotlib 3.9

for i, cat in enumerate(unique_cats):
    mask = np.array([c == cat for c in categories])
    plt.scatter(coords[mask, 0], coords[mask, 1], c=[cmap(i)], label=cat, alpha=0.6, s=30)

plt.title('Block2Vec V2 Embeddings (t-SNE)', fontsize=14)
plt.xlabel('t-SNE dimension 1')
plt.ylabel('t-SNE dimension 2')
plt.legend(loc='upper right')
plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/tsne_embeddings_v2.png", dpi=150)
plt.show()

In [None]:
# ============================================================
# CELL 13: Final Evaluation
# ============================================================

print("\nFinal Category Coherence:")
print("="*50)

coherence = evaluate_category_coherence(embeddings, tok2block, EVAL_CATEGORIES)
for cat, prec in coherence.items():
    status = "GOOD" if prec > 0.5 else "POOR" if prec > 0.3 else "BAD"
    print(f"  {cat:15} {prec*100:5.1f}% [{status}]")

avg_coherence = np.mean(list(coherence.values()))
print(f"\n  V2 Average: {avg_coherence*100:.1f}%")
print(f"  V1 Average: 20.4%")
print(f"  Improvement: {avg_coherence*100 - 20.4:+.1f} percentage points")
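`evaluate_category_coherence` is defined earlier, and its exact definition is not shown here. If it follows the usual nearest-neighbor precision idea (for each block in a category, the fraction of its top-k cosine neighbors that share the category), it reduces to something like the hypothetical `neighbor_precision` below. Treat this as an assumption for intuition, not as the notebook's metric.

```python
import numpy as np

def neighbor_precision(embeddings, members, k=5):
    """Mean fraction of each member's k nearest neighbors (cosine) that are also members."""
    members = np.asarray(members)
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit = embeddings / np.clip(norms, 1e-12, None)      # L2-normalize rows
    sims = unit[members] @ unit.T                        # (M, V) cosine similarities
    sims[np.arange(len(members)), members] = -np.inf     # exclude each block itself
    top = np.argsort(-sims, axis=1)[:, :k]               # k nearest neighbors per member
    return float(np.isin(top, members).mean())

emb = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
print(neighbor_precision(emb, [0, 1], k=1), neighbor_precision(emb, [0, 1], k=2))  # 1.0 0.5
```

With a metric like this, the `Improvement` line is a difference of percentages, so it is measured in percentage points: going from 20.4% to 40.8% is +20.4 points, a doubling.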

In [None]:
# ============================================================
# CELL 14: Nearest Neighbor Examples
# ============================================================

def find_similar_blocks(block_name, embeddings, tok2block, top_k=5):
    """Find and print the most similar blocks."""
    token = find_token_for_block(block_name, tok2block)
    if token is None:
        print(f"Block '{block_name}' not found")
        return
    
    query = embeddings[token].reshape(1, -1)
    sims = cosine_similarity(query, embeddings)[0]
    top_indices = np.argsort(sims)[::-1]
    
    print(f"\nNearest neighbors for '{tok2block[token]}':")
    count = 0
    for idx in top_indices:
        if idx != token:
            name = tok2block.get(idx, "unknown")
            short = name.replace("minecraft:", "")[:40]
            print(f"  {count+1}. {short:40} sim={sims[idx]:.3f}")
            count += 1
            if count >= top_k:
                break

# Test key blocks
test_blocks = ["oak_planks", "stone", "diamond_ore", "white_wool", "glass"]
for block in test_blocks:
    find_similar_blocks(block, embeddings, tok2block)
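`find_similar_blocks` sorts all `VOCAB_SIZE` similarities just to print five neighbors. For many queries or a larger vocabulary, a common alternative is to L2-normalize once and use `np.argpartition`, which selects the top k in linear time and sorts only those k. A sketch with a hypothetical `top_k_neighbors` helper (it assumes `k` is smaller than the vocabulary size):

```python
import numpy as np

def top_k_neighbors(embeddings, token, k=5):
    """Indices and cosine similarities of the k nearest tokens, excluding `token`."""
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit = embeddings / np.clip(norms, 1e-12, None)  # L2-normalize rows
    sims = unit @ unit[token]                        # (V,) cosine similarity to the query
    sims[token] = -np.inf                            # never return the query itself
    cand = np.argpartition(-sims, k)[:k]             # k best, unordered, linear time
    order = cand[np.argsort(-sims[cand])]            # sort only those k
    return order, sims[order]

emb = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [-1.0, 0.0]])
idx, sims = top_k_neighbors(emb, 0, k=2)
print(idx.tolist())  # [1, 2]
```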

In [None]:
# ============================================================
# CELL 15: Summary
# ============================================================

print("="*60)
print("BLOCK2VEC V2 TRAINING COMPLETE")
print("="*60)

print(f"\nModel:")
print(f"  Vocabulary: {VOCAB_SIZE} collapsed tokens (was 3,717)")
print(f"  Embedding dim: {EMBEDDING_DIM}")
print(f"  Loss weights: alpha={ALPHA} (skip-gram), beta={BETA} (CBOW)")

print(f"\nTraining:")
print(f"  Epochs completed: {len(history['total_loss'])}/{EPOCHS}")
print(f"  Best epoch: {history['best_epoch']}")
print(f"  Best loss: {best_loss:.4f}")

print(f"\nEvaluation:")
print(f"  V2 category coherence: {avg_coherence*100:.1f}%")
print(f"  V1 category coherence: 20.4%")
improvement = "IMPROVED" if avg_coherence > 0.204 else "NO IMPROVEMENT"
print(f"  Result: {improvement}")

print(f"\nOutput files:")
print(f"  - block_embeddings_v2.npy (use this for VQ-VAE)")
print(f"  - block2vec_v2_best.pt")
print(f"  - tok2block_collapsed.json (token id -> block name)")
print(f"  - training_history_v2.json (for debugging)")
print(f"  - training_analysis_v2.png")
print(f"  - tsne_embeddings_v2.png")

print("\n" + "="*60)