# 🏰 H-MOLQD: Hybrid Masked-diffusion Optimization with Logical Quality Diversity

**GPU-Accelerated Training Notebook for Zelda Dungeon Generation**

## Architecture Overview (6 Blocks)

```
VGLC Data (18 dungeons × .txt grids + .dot mission graphs)
    │
    ▼
┌────────────────────────────────────────────────────────────────────┐
│  Block I  - VQ-VAE: 44 tiles → 512 codebook → latent_dim=64        │
│  Block II - Dual-Stream Condition Encoder (GATv2Conv GNN + Local)  │
│  Block III- Latent Diffusion (U-Net, cosine schedule, DDIM)        │
│  Block IV - LogicNet (differentiable solvability + key-lock check) │
│  Block V  - WFC Refiner (inference-only post-processing)           │
│  Block VI - Cognitive Validator (A* solver + CBS agent)            │
└────────────────────────────────────────────────────────────────────┘
```

## Training Pipeline
- **Stage 1**: VQ-VAE pretraining (reconstructing dungeon grids)
- **Stage 2**: Latent Diffusion with **real .dot graph conditioning** + LogicNet guidance

## How to Use on Kaggle
1. Upload your project repo as a **Kaggle Dataset**
2. Enable **GPU accelerator** (T4/P100)
3. Run all cells sequentially
4. Download checkpoints from outputs

In [None]:
# ============================================================================
# CELL 1: Environment Setup
# ============================================================================
import os, sys, subprocess
from pathlib import Path

IS_KAGGLE = os.path.exists('/kaggle/working')
IS_COLAB = 'google.colab' in sys.modules
ENV_NAME = "Kaggle" if IS_KAGGLE else ("Colab" if IS_COLAB else "Local")
print(f"🖥️  Environment: {ENV_NAME}")

# --- Find project root ---
if IS_KAGGLE:
    WORKING_DIR = Path('/kaggle/working')
    candidates = [
        Path('/kaggle/input/kltn'), Path('/kaggle/input/hmolqd'),
        Path('/kaggle/input/kltn/KLTN'), Path('/kaggle/working/KLTN'),
    ]
elif IS_COLAB:
    WORKING_DIR = Path('/content')
    candidates = [Path('/content/KLTN'), Path('/content/drive/MyDrive/KLTN')]
else:
    WORKING_DIR = Path('.').resolve()
    candidates = [Path('.').resolve(), Path('..').resolve()]

PROJECT_ROOT = None
for c in candidates:
    if (c / 'src' / 'train_vqvae.py').exists():
        PROJECT_ROOT = c
        break

if PROJECT_ROOT is None and (IS_KAGGLE or IS_COLAB):
    clone_target = WORKING_DIR / 'KLTN'
    if not clone_target.exists():
        subprocess.run(['git', 'clone', 'https://github.com/YOUR_USERNAME/KLTN.git',
                        str(clone_target)], check=False)
    if (clone_target / 'src' / 'train_vqvae.py').exists():
        PROJECT_ROOT = clone_target

if PROJECT_ROOT is None:
    raise FileNotFoundError("❌ Could not find project root! Upload repo as Kaggle dataset.")

if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
os.chdir(PROJECT_ROOT)

CHECKPOINT_DIR = (WORKING_DIR if IS_KAGGLE else PROJECT_ROOT) / 'checkpoints'
OUTPUT_DIR = (WORKING_DIR if IS_KAGGLE else PROJECT_ROOT) / 'output'
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

DATA_DIR = PROJECT_ROOT / 'data' / 'The Legend of Zelda'
if not DATA_DIR.exists():
    DATA_DIR = PROJECT_ROOT / 'Data' / 'The Legend of Zelda'

print(f"✅ Project root: {PROJECT_ROOT}")
print(f"📁 Data dir: {DATA_DIR} (exists={DATA_DIR.exists()})")
print(f"💾 Checkpoints: {CHECKPOINT_DIR}")

# Install dependencies
req = PROJECT_ROOT / 'requirements-hmolqd.txt'
if req.exists():
    subprocess.run([sys.executable, '-m', 'pip', 'install', '-q', '-r', str(req)], check=False)
print("🎉 Setup complete!")

In [None]:
# ============================================================================
# CELL 3: Multi-Path Checkpoint Discovery System
# ============================================================================
# ⚠️  IMPORTANT: This cell MUST be executed BEFORE training cells (VQ-VAE/Diffusion)
# If you get "unexpected keyword argument" errors, restart kernel and run all cells in order
from dataclasses import dataclass, field
from typing import Optional, List, Tuple, Dict
from pathlib import Path
import torch

@dataclass
class CheckpointInfo:
    """Metadata about a discovered checkpoint file."""
    path: Path
    source_type: str  # 'working', 'input_dataset', 'notebook_output', 'direct_file'
    source_location: str  # Human-readable source directory
    file_size_mb: float
    modified_time: float
    
    # Checkpoint content (extracted after validation)
    epoch: Optional[int] = None
    accuracy: Optional[float] = None
    solvability: Optional[float] = None
    loss: Optional[float] = None
    
    # Validation status
    is_valid: bool = False
    validation_msg: str = ""

def find_checkpoint_locations() -> Dict[str, List[Path]]:
    """
    Discover all potential checkpoint directories across Kaggle environment.
    
    Returns:
        Dict with keys: 'working', 'input_datasets', 'notebook_outputs', 'direct_files'
    """
    locations = {
        'working': [],
        'input_datasets': [],
        'notebook_outputs': [],
        'direct_files': []
    }
    
    # Priority 1: Working directory (writable, current run)
    working_ckpt = Path('/kaggle/working/checkpoints')
    if working_ckpt.exists():
        locations['working'].append(working_ckpt)
    
    # Priority 2: Input datasets (read-only, uploaded runs)
    input_base = Path('/kaggle/input')
    if input_base.exists():
        for dataset_dir in input_base.iterdir():
            if not dataset_dir.is_dir() or dataset_dir.name == 'notebooks':
                continue
            
            # Check common checkpoint subdirectories
            for subdir_name in ['checkpoints', 'output', 'outputs']:
                ckpt_subdir = dataset_dir / subdir_name
                if ckpt_subdir.exists() and ckpt_subdir.is_dir():
                    locations['input_datasets'].append(ckpt_subdir)
            
            # Also check dataset root for flat checkpoint structure
            if any(f.is_file() and f.suffix == '.pth' for f in dataset_dir.iterdir()):
                locations['direct_files'].append(dataset_dir)
    
    # Priority 3: Notebook outputs (read-only, /kaggle/input/notebooks/user/notebook-name/)
    notebook_base = Path('/kaggle/input/notebooks')
    if notebook_base.exists():
        for user_dir in notebook_base.iterdir():
            if not user_dir.is_dir():
                continue
            for notebook_dir in user_dir.iterdir():
                if not notebook_dir.is_dir():
                    continue
                
                # Check for checkpoint subdirectories
                for subdir_name in ['checkpoints', 'output', 'outputs']:
                    ckpt_subdir = notebook_dir / subdir_name
                    if ckpt_subdir.exists() and ckpt_subdir.is_dir():
                        locations['notebook_outputs'].append(ckpt_subdir)
                
                # Check notebook root for flat structure
                if any(f.is_file() and f.suffix == '.pth' for f in notebook_dir.iterdir()):
                    locations['direct_files'].append(notebook_dir)
    
    return locations

def validate_and_load_checkpoint(
    ckpt_path: Path,
    required_keys: Optional[List[str]] = None,
    source_type: str = '',
    source_location: str = ''
) -> CheckpointInfo:
    """
    Validate checkpoint file and extract metadata.
    
    Args:
        ckpt_path: Path to checkpoint file
        required_keys: Keys that must be present in checkpoint
        source_type: Type of source ('working', 'input_dataset', etc.)
        source_location: Human-readable source directory
        
    Returns:
        CheckpointInfo with all extracted metadata
    """
    info = CheckpointInfo(
        path=ckpt_path,
        source_type=source_type,
        source_location=source_location,
        file_size_mb=ckpt_path.stat().st_size / (1024**2),
        modified_time=ckpt_path.stat().st_mtime
    )
    
    # Try to load and validate checkpoint
    try:
        ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
        
        # Validate required keys
        if required_keys:
            missing = [k for k in required_keys if k not in ckpt]
            if missing:
                info.is_valid = False
                info.validation_msg = f"Missing keys: {missing}"
                return info
        
        # Extract metadata
        info.epoch = ckpt.get('epoch', None)
        info.accuracy = ckpt.get('accuracy', None)
        info.solvability = ckpt.get('val_solvability', ckpt.get('solvability', None))
        info.loss = ckpt.get('loss', None)
        
        # Sanity checks
        if info.epoch is not None and info.epoch < 0:
            info.is_valid = False
            info.validation_msg = "Invalid epoch (< 0)"
            return info
        
        if info.accuracy is not None and not (0 <= info.accuracy <= 1):
            info.is_valid = False
            info.validation_msg = f"Invalid accuracy ({info.accuracy})"
            return info
        
        info.is_valid = True
        info.validation_msg = "Valid"
        return info
        
    except Exception as e:
        info.is_valid = False
        info.validation_msg = f"Load failed: {str(e)[:100]}"
        return info

def find_best_checkpoint_across_sources(
    checkpoint_filename: str = 'vqvae_pretrained.pth',
    required_keys: Optional[List[str]] = None,
    prefer_metric: Optional[str] = None  # 'accuracy', 'solvability', 'epoch'
) -> Tuple[Optional[Path], Optional[CheckpointInfo], List[CheckpointInfo]]:
    """
    Find the best checkpoint across ALL available sources in Kaggle.
    
    Args:
        checkpoint_filename: Name of checkpoint to search for (e.g., 'vqvae_pretrained.pth')
        required_keys: Keys that must be present in checkpoint
        prefer_metric: Metric to prioritize when multiple checkpoints found
        
    Returns:
        Tuple of (best_path, best_info, all_valid_checkpoints)
        
    Priority logic:
        1. If prefer_metric specified: choose checkpoint with best metric value
        2. Otherwise: prefer working > input_datasets > notebook_outputs > direct_files
        3. Within same priority level: choose most recent (by modified time)
    """
    locations = find_checkpoint_locations()
    all_candidates: List[Tuple[CheckpointInfo, int]] = []  # (info, priority index)
    
    # Search priority order
    search_order = [
        ('working', locations['working']),
        ('input_datasets', locations['input_datasets']),
        ('notebook_outputs', locations['notebook_outputs']),
        ('direct_files', locations['direct_files'])
    ]
    
    print(f"🔍 Searching for '{checkpoint_filename}' across all sources...")
    
    for priority_idx, (source_type, dirs) in enumerate(search_order):
        if not dirs:
            continue
        
        for checkpoint_dir in dirs:
            ckpt_path = checkpoint_dir / checkpoint_filename
            if not ckpt_path.exists():
                continue
            
            # Validate checkpoint
            info = validate_and_load_checkpoint(
                ckpt_path,
                required_keys=required_keys,
                source_type=source_type,
                source_location=str(checkpoint_dir)
            )
            
            # Display found checkpoint
            status_icon = "✅" if info.is_valid else "❌"
            print(f"   {status_icon} [{source_type:15s}] {ckpt_path}")
            
            if info.is_valid:
                metrics_parts = [f"{info.file_size_mb:.1f}MB"]
                if info.epoch is not None:
                    metrics_parts.append(f"epoch={info.epoch}")
                if info.accuracy is not None:
                    metrics_parts.append(f"acc={info.accuracy:.3f}")
                if info.solvability is not None:
                    metrics_parts.append(f"solv={info.solvability:.3f}")
                print(f"      {', '.join(metrics_parts)} - {info.validation_msg}")
                
                all_candidates.append((info, priority_idx))
            else:
                print(f"      Invalid: {info.validation_msg}")
    
    # No valid checkpoints found
    if not all_candidates:
        print("   ❌ No valid checkpoints found")
        return None, None, []
    
    # Select best checkpoint
    metric_names = ('accuracy', 'solvability', 'epoch')
    if prefer_metric in metric_names:
        # Choose by best metric value; a missing metric counts as 0,
        # and ties go to the higher-priority source
        best_info, _ = max(
            all_candidates,
            key=lambda x: (getattr(x[0], prefer_metric) or 0, -x[1])
        )
        best_value = getattr(best_info, prefer_metric)
        if best_value is None:
            metric_display = f"{prefer_metric}=N/A"
        elif prefer_metric == 'epoch':
            metric_display = f"epoch={best_value}"
        else:
            metric_display = f"{prefer_metric}={best_value:.3f}"
        print(f"\n🎯 Selected checkpoint by best {prefer_metric}: {metric_display}")
    else:
        # No metric (or an unknown one): choose by priority
        # (working > input > notebook > direct), then most recent
        best_info, _ = min(all_candidates, key=lambda x: (x[1], -x[0].modified_time))
        print(f"\n🎯 Selected checkpoint by priority: {best_info.source_type}")
    
    print(f"   📂 {best_info.path}")
    epoch_display = best_info.epoch if best_info.epoch is not None else 'N/A'
    print(f"   📊 Epoch {epoch_display}", end='')
    if best_info.accuracy is not None:
        print(f", accuracy={best_info.accuracy:.3f}", end='')
    if best_info.solvability is not None:
        print(f", solvability={best_info.solvability:.3f}", end='')
    print()
    
    all_valid_infos = [info for info, _ in all_candidates]
    return best_info.path, best_info, all_valid_infos

def copy_checkpoint_to_working(
    source_path: Path,
    target_filename: str,
    working_dir: Path = Path('/kaggle/working/checkpoints')
) -> Path:
    """
    Copy checkpoint from read-only input to writable working directory.
    
    Args:
        source_path: Source checkpoint path (may be read-only)
        target_filename: Target filename in working directory
        working_dir: Working directory path (default: /kaggle/working/checkpoints)
        
    Returns:
        Path to copied checkpoint in working directory
    """
    import shutil
    
    # Create working directory if needed
    working_dir.mkdir(parents=True, exist_ok=True)
    target_path = working_dir / target_filename
    
    # If source is already inside the working directory, no copy needed
    # (compare resolved paths, not string prefixes, so 'checkpoints2' does not match)
    if working_dir.resolve() in source_path.resolve().parents:
        return source_path
    
    # Copy file
    shutil.copy2(source_path, target_path)
    print(f"   📋 Copied to working: {target_path}")
    
    return target_path

def discover_and_validate_all_checkpoints(show_invalid: bool = False) -> Dict[str, List[CheckpointInfo]]:
    """
    Comprehensive scan of all checkpoint files across all sources.
    Useful for debugging and understanding what's available.
    
    Args:
        show_invalid: Whether to print invalid checkpoints
        
    Returns:
        Dict mapping checkpoint filename to list of CheckpointInfo objects
    """
    locations = find_checkpoint_locations()
    all_checkpoints: Dict[str, List[CheckpointInfo]] = {}
    
    print("🔎 Comprehensive checkpoint scan:")
    print("=" * 70)
    
    search_order = [
        ('working', locations['working']),
        ('input_datasets', locations['input_datasets']),
        ('notebook_outputs', locations['notebook_outputs']),
        ('direct_files', locations['direct_files'])
    ]
    
    for source_type, dirs in search_order:
        if not dirs:
            continue
        
        print(f"\n📁 {source_type.upper().replace('_', ' ')}:")
        for checkpoint_dir in dirs:
            print(f"   {checkpoint_dir}")
            
            # Find all .pth files
            pth_files = list(checkpoint_dir.glob('*.pth'))
            if not pth_files:
                print("      (no .pth files found)")
                continue
            
            for pth_file in pth_files:
                info = validate_and_load_checkpoint(
                    pth_file,
                    required_keys=None,  # No requirements for discovery
                    source_type=source_type,
                    source_location=str(checkpoint_dir)
                )
                
                if info.is_valid or show_invalid:
                    status = "✅" if info.is_valid else "❌"
                    metrics = []
                    if info.epoch is not None:
                        metrics.append(f"epoch={info.epoch}")
                    if info.accuracy is not None:
                        metrics.append(f"acc={info.accuracy:.3f}")
                    if info.solvability is not None:
                        metrics.append(f"solv={info.solvability:.3f}")
                    
                    metric_str = f" ({', '.join(metrics)})" if metrics else ""
                    print(f"      {status} {pth_file.name}{metric_str}")
                    if not info.is_valid:
                        print(f"         ⚠️  {info.validation_msg}")
                
                # Add to results
                filename = pth_file.name
                if filename not in all_checkpoints:
                    all_checkpoints[filename] = []
                all_checkpoints[filename].append(info)
    
    print(f"\n📊 Summary: Found {sum(len(v) for v in all_checkpoints.values())} checkpoint files")
    return all_checkpoints

print("✅ Multi-path checkpoint discovery system loaded")
print("   Functions available:")
print("   - find_best_checkpoint_across_sources()")
print("   - copy_checkpoint_to_working()")
print("   - discover_and_validate_all_checkpoints()")

---
## 🔄 Multi-Path Checkpoint Discovery System

**Purpose**: Automatically find and resume from checkpoints across multiple Kaggle runs.

### How It Works

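The system searches **ALL** possible checkpoint locations (in each dataset or notebook output it also checks `output/` and `outputs/`):
1. `/kaggle/working/checkpoints/` - Current run (writable)
2. `/kaggle/input/*/checkpoints/` - Previous runs uploaded as datasets (read-only)
3. `/kaggle/input/notebooks/*/*/checkpoints/` - Auto-versioned notebook outputs
4. `.pth` files stored directly in a dataset or notebook-output root (flat layout)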
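
As a rough illustration of how the directory patterns listed above translate into a search order, here is a minimal sketch using only `pathlib`. The helper name `candidate_checkpoint_dirs` and the `root` parameter are hypothetical (they do not appear in the notebook); on Kaggle, `root` would be `/kaggle`.

```python
from pathlib import Path
from typing import List


def candidate_checkpoint_dirs(root: Path) -> List[Path]:
    """List existing checkpoint directories under `root`, highest priority first.

    Hypothetical helper for illustration; the notebook's own
    find_checkpoint_locations() is the authoritative version.
    """
    patterns = [
        "working/checkpoints",              # current run (writable)
        "input/*/checkpoints",              # runs uploaded as datasets
        "input/notebooks/*/*/checkpoints",  # auto-versioned notebook outputs
    ]
    found: List[Path] = []
    for pattern in patterns:
        # sorted() makes the order within one priority level deterministic
        for path in sorted(root.glob(pattern)):
            if path.is_dir() and path not in found:
                found.append(path)
    return found
```

Passing the root as a parameter keeps the sketch testable outside Kaggle, for example against a temporary directory.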

### User Workflow (Zero Configuration!)

**First Run (0-9 hours)**:
```
Start notebook → Train → Download outputs → Upload as Kaggle dataset
```

**Resume Run (9+ hours)**:
```
Add dataset to inputs → Run notebook → Automatically resumes from best checkpoint!
```

### Key Features

✅ **Automatic Discovery** - Finds checkpoints from ANY previous run  
✅ **Intelligent Selection** - Chooses best checkpoint by accuracy, solvability, or epoch  
✅ **Validation** - Checks checkpoint integrity before loading  
✅ **Auto-Copy** - Handles read-only → writable directory  
✅ **Detailed Logging** - Shows exactly what was found and why  
✅ **Error Recovery** - Falls back to fresh training if no valid checkpoint  
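
The selection rule can be sketched in a few lines. This is a simplified stand-in, not the notebook's implementation: `Candidate` is a hypothetical reduction of `CheckpointInfo`, and only the accuracy metric is modelled.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Candidate:
    """Hypothetical, reduced stand-in for CheckpointInfo."""
    name: str
    priority: int            # 0 = working dir; larger means lower priority
    modified_time: float
    accuracy: Optional[float] = None


def pick_best(cands: List[Candidate], by_accuracy: bool) -> Optional[Candidate]:
    if not cands:
        return None  # caller falls back to fresh training
    if by_accuracy:
        # Highest accuracy wins; a missing value counts as 0.
        # Ties go to the higher-priority (lower index) source.
        return max(cands, key=lambda c: (c.accuracy or 0.0, -c.priority))
    # No metric requested: best priority first, then most recently modified.
    return min(cands, key=lambda c: (c.priority, -c.modified_time))
```

Encoding the tie-break inside the sort key keeps the choice deterministic when two runs report the same metric.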

### Example Output

```
🔍 Searching for checkpoints across all sources...
   ✅ [working        ] working/checkpoints/vqvae_pretrained.pth
      23.4MB, epoch=50, acc=0.875 - Valid
   ✅ [input_datasets ] input/hmolqd-run1/checkpoints/vqvae_pretrained.pth
      23.1MB, epoch=45, acc=0.860 - Valid

🎯 Selected checkpoint by best accuracy:
   📂 working/checkpoints/vqvae_pretrained.pth
   📊 Epoch 50, accuracy=0.875
```

### Troubleshooting

**Checkpoint not found?**
```python
discover_and_validate_all_checkpoints(show_invalid=True)
```

### Documentation

- 📖 **Full Guide**: [`docs/KAGGLE_CHECKPOINT_RESUME_GUIDE.md`](../docs/KAGGLE_CHECKPOINT_RESUME_GUIDE.md)
- 📄 **Quick Ref**: [`docs/KAGGLE_CHECKPOINT_QUICK_REF.md`](../docs/KAGGLE_CHECKPOINT_QUICK_REF.md)
- 📊 **Visual Guide**: [`docs/KAGGLE_CHECKPOINT_VISUAL_GUIDE.md`](../docs/KAGGLE_CHECKPOINT_VISUAL_GUIDE.md)

---

In [None]:
# ============================================================================
# CELL 2: Import Verification & GPU Check
# ============================================================================
import torch
import numpy as np
import matplotlib.pyplot as plt
import logging

logging.basicConfig(level=logging.INFO, format='%(asctime)s | %(levelname)s | %(message)s', datefmt='%H:%M:%S')

print("=" * 60)
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)} ({torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB)")
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {DEVICE}")

# Import project modules
from src.core.vqvae import SemanticVQVAE, create_vqvae, VQVAETrainer
from src.core.latent_diffusion import create_latent_diffusion
from src.core.condition_encoder import create_condition_encoder
from src.core.logic_net import LogicNet
from src.data.zelda_loader import create_dataloader, graph_collate_fn
from src.train_vqvae import grids_to_onehot
from src.train_diffusion import DiffusionTrainingConfig, DiffusionTrainer
print("✅ All modules imported")

# Quick data test
loader_test = create_dataloader(str(DATA_DIR), batch_size=2, use_vglc=True, normalize=True, load_graphs=True)
print(f"\n📊 Dataset: {len(loader_test.dataset)} dungeons")
for batch in loader_test:
    if isinstance(batch, (list, tuple)):
        imgs, graphs = batch
        g = graphs[0]
        print(f"   Image batch: {imgs.shape}")
        print(f"   Graph[0]: {g['num_nodes']} nodes, {g['num_edges']} edges, features={g['node_features'].shape}")
    break
print("✅ Data loading verified with real .dot graphs!")

---
## 🔵 Stage 1: VQ-VAE Pretraining (Block I)

Train the Semantic VQ-VAE to reconstruct dungeon grids.
- **Input**: 44-class one-hot tiles `[B, 44, H, W]`
- **Codebook**: 512 embeddings, latent_dim=64
- **Goal**: ≥85% reconstruction accuracy before Stage 2
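
For readers unfamiliar with vector quantization, the following dependency-free sketch shows the standard VQ-VAE lookup step: each latent vector is replaced by its nearest codebook entry. It is an illustration under assumptions, not the `SemanticVQVAE` code (which is not shown in this notebook); the function names are hypothetical and the toy codebook in the usage note stands in for the real 512 × 64 one.

```python
from typing import List, Sequence, Tuple


def nearest_code(z: Sequence[float], codebook: Sequence[Sequence[float]]) -> int:
    """Return the index of the codebook vector closest to z (squared L2 distance)."""
    def sq_dist(a: Sequence[float], b: Sequence[float]) -> float:
        return sum((x - y) ** 2 for x, y in zip(a, b))
    return min(range(len(codebook)), key=lambda k: sq_dist(z, codebook[k]))


def quantize(
    latents: Sequence[Sequence[float]],
    codebook: Sequence[Sequence[float]],
) -> Tuple[List[int], List[List[float]]]:
    """Map every latent vector to (code index, quantized vector)."""
    indices = [nearest_code(z, codebook) for z in latents]
    return indices, [list(codebook[k]) for k in indices]
```

With a toy codebook `[[0, 0], [1, 1], [4, 4]]`, the latent `[3.0, 3.5]` maps to index 2 because `[4, 4]` is its closest entry. In the actual model the same lookup would run on tensors over every spatial position of the encoder output.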

In [None]:
# ============================================================================
# CELL 3: Stage 1 - VQ-VAE Pretraining (PRODUCTION VERSION)
# ============================================================================
import json, time
from pathlib import Path

VQVAE_EPOCHS = 300           # Production setting (upper end of the 200-300 range)
VQVAE_BATCH_SIZE = 2
VQVAE_LR = 3e-4
VQVAE_TARGET_ACCURACY = 0.85
VQVAE_SAVE_PATH = CHECKPOINT_DIR / 'vqvae_pretrained.pth'

print("=" * 60)
print("🔵 STAGE 1: VQ-VAE PRETRAINING")
print("=" * 60)

def validate_vqvae_checkpoint(ckpt_path):
    """Validate VQ-VAE checkpoint integrity."""
    required_keys = ['epoch', 'model_state_dict', 'optimizer_state_dict', 'accuracy']
    try:
        ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
        missing = [k for k in required_keys if k not in ckpt]
        if missing:
            raise ValueError(f"Missing keys: {missing}")
        # Sanity checks
        if ckpt['epoch'] < 0:
            raise ValueError("Invalid epoch")
        if not (0 <= ckpt['accuracy'] <= 1):
            raise ValueError("Invalid accuracy")
        return ckpt
    except Exception as e:
        print(f"⚠️  Invalid checkpoint: {e}")
        return None

# Smart resume logic with MULTI-PATH checkpoint discovery
SKIP_VQVAE = False
resume_epoch = 0
vqvae_history = []
resume_checkpoint = None

# Search for VQ-VAE checkpoints across ALL available sources
best_ckpt_path, best_ckpt_info, all_ckpts = find_best_checkpoint_across_sources(
    checkpoint_filename='vqvae_pretrained.pth',
    required_keys=['epoch', 'model_state_dict', 'optimizer_state_dict', 'accuracy'],
    prefer_metric='accuracy'  # Choose checkpoint with highest accuracy
)

if best_ckpt_path is not None:
    # Load the best checkpoint found
    ckpt = torch.load(best_ckpt_path, map_location='cpu', weights_only=False)
    existing_acc = ckpt['accuracy']
    resume_epoch = ckpt['epoch'] + 1
    resume_checkpoint = ckpt
    vqvae_history = ckpt.get('history', [])
    
    print(f"⚡ Loaded VQ-VAE checkpoint from: {best_ckpt_info.source_location}")
    print(f"   Epoch {resume_epoch}, accuracy={existing_acc:.3f}")
    
    # Copy to working directory if from input dataset (allows overwriting with better checkpoints)
    if IS_KAGGLE and best_ckpt_info.source_type != 'working':
        VQVAE_SAVE_PATH = copy_checkpoint_to_working(best_ckpt_path, 'vqvae_pretrained.pth')
        print("   📋 Copied to working directory for incremental saves")
    else:
        VQVAE_SAVE_PATH = best_ckpt_path
    
    if existing_acc >= VQVAE_TARGET_ACCURACY:
        print(f"   ✅ Accuracy ≥ {VQVAE_TARGET_ACCURACY} - skipping training")
        SKIP_VQVAE = True
    else:
        print(f"   🔄 Resuming from epoch {resume_epoch}")
        if vqvae_history:
            print(f"   📊 Loaded {len(vqvae_history)} epochs of history")
else:
    print("ℹ️  No existing VQ-VAE checkpoint found - starting fresh training")
    # Ensure VQVAE_SAVE_PATH points to working directory
    if IS_KAGGLE:
        VQVAE_SAVE_PATH = Path('/kaggle/working/checkpoints') / 'vqvae_pretrained.pth'
    else:
        VQVAE_SAVE_PATH = CHECKPOINT_DIR / 'vqvae_pretrained.pth'

if not SKIP_VQVAE:
    # Create model and trainer
    vqvae_model = create_vqvae(
        num_classes=44, 
        latent_dim=64, 
        codebook_size=512  # create_vqvae takes 'codebook_size', not 'num_embeddings'
    ).to(DEVICE)
    
    # VQVAETrainer takes 'lr' (not 'learning_rate') and no 'device' argument
    trainer = VQVAETrainer(vqvae_model, lr=VQVAE_LR)
    
    # Create dataloader
    dataloader = create_dataloader(
        str(DATA_DIR), 
        batch_size=VQVAE_BATCH_SIZE,
        shuffle=True, 
        use_vglc=True, 
        normalize=True, 
        load_graphs=False
    )
    print(f"📊 Training: {len(dataloader.dataset)} dungeons, {len(dataloader)} batches/epoch")
    
    # Resume model and optimizer states AFTER creating model/trainer
    if resume_checkpoint is not None:
        try:
            vqvae_model.load_state_dict(resume_checkpoint['model_state_dict'])
            trainer.optimizer.load_state_dict(resume_checkpoint['optimizer_state_dict'])
            print("   🔄 Resumed model and optimizer state")
        except Exception as e:
            print(f"⚠️  Could not resume state: {e}")
            print("   Starting fresh from epoch 0")
            resume_epoch = 0
            vqvae_history = []
    
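    # Training tracking
    best_loss = float('inf')
    if vqvae_history:
        best_loss = min(h['loss'] for h in vqvae_history)
        print(f"   📈 Previous best loss: {best_loss:.4f}")
    
    history = vqvae_history
    # Fallback for the final summary in case the epoch loop below never runs
    # (e.g. resume_epoch >= VQVAE_EPOCHS), which would leave eval_acc undefined
    eval_acc = float(resume_checkpoint['accuracy']) if resume_checkpoint is not None else 0.0
    t0 = time.time()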
    
    print(f"\n{'Epoch':>6} | {'Loss':>8} | {'Recon':>8} | {'VQ':>6} | {'Acc':>6} | {'Time':>6}")
    print("─" * 60)

    for epoch in range(resume_epoch, VQVAE_EPOCHS):
        # Training loop
        vqvae_model.train()
        metrics_sum = {'loss': 0, 'recon_loss': 0, 'vq_loss': 0, 'perplexity': 0}
        nb = 0
        
        for batch in dataloader:
            if isinstance(batch, (list, tuple)):
                batch = batch[0]
            batch = batch.to(DEVICE)
            x_onehot = grids_to_onehot(batch, num_classes=44)
            _, m = trainer.train_step(x_onehot)
            for k in metrics_sum:
                metrics_sum[k] += float(m.get(k, 0.0))  # float() keeps history JSON-serializable
            nb += 1
        
        # Average metrics
        for k in metrics_sum:
            metrics_sum[k] /= max(nb, 1)

        # Evaluation
        vqvae_model.eval()
        acc_sum, acc_n = 0, 0
        with torch.no_grad():
            for batch in dataloader:
                if isinstance(batch, (list, tuple)):
                    batch = batch[0]
                batch = batch.to(DEVICE)
                x_onehot = grids_to_onehot(batch, num_classes=44)
                recon, _ = vqvae_model(x_onehot)
                pred_tiles = recon.argmax(dim=1)
                orig_tiles = x_onehot.argmax(dim=1)
                acc_sum += (pred_tiles == orig_tiles).float().mean().item()
                acc_n += 1
        eval_acc = acc_sum / max(acc_n, 1)

        # Record history
        history.append({
            'epoch': epoch + 1,
            'loss': metrics_sum['loss'],
            'recon_loss': metrics_sum['recon_loss'],
            'vq_loss': metrics_sum['vq_loss'],
            'perplexity': metrics_sum['perplexity'],
            'accuracy': eval_acc
        })
        
        # Save checkpoint when the loss improves or the accuracy target is reached
        if metrics_sum['loss'] < best_loss or eval_acc >= VQVAE_TARGET_ACCURACY:
            checkpoint_state = {
                'epoch': epoch,
                'model_state_dict': vqvae_model.state_dict(),
                'optimizer_state_dict': trainer.optimizer.state_dict(),
                'loss': metrics_sum['loss'],
                'accuracy': eval_acc,
                'perplexity': metrics_sum['perplexity'],
                'history': history
            }
            
            # Atomic save: write to a temp file, then replace the target in one step.
            # Path.replace() overwrites an existing file, so the old checkpoint is
            # never deleted before the new one is in place.
            temp_path = VQVAE_SAVE_PATH.parent / f".{VQVAE_SAVE_PATH.name}.tmp"
            torch.save(checkpoint_state, temp_path)
            temp_path.replace(VQVAE_SAVE_PATH)
            
            # Only lower best_loss; a save triggered by accuracy alone must not raise it
            best_loss = min(best_loss, metrics_sum['loss'])
        
        # Progress logging
        if (epoch + 1) % 5 == 0 or epoch == resume_epoch or epoch == 0:
            elapsed = time.time() - t0
            print(f"  {epoch+1:4d}   | {metrics_sum['loss']:.4f}  | "
                  f"{metrics_sum['recon_loss']:.4f}  | {metrics_sum['vq_loss']:.4f} | "
                  f"{eval_acc:.3f} | {elapsed/60:.1f}m")
        
        # Early stopping
        if eval_acc >= VQVAE_TARGET_ACCURACY:
            print(f"\n🎯 Target accuracy {VQVAE_TARGET_ACCURACY:.2f} reached!")
            break

    print(f"\n✅ VQ-VAE complete! Best loss={best_loss:.4f}, Final acc={eval_acc:.3f}")
    print(f"   Checkpoint: {VQVAE_SAVE_PATH}")
    
    # Save training history separately (for analysis)
    history_path = CHECKPOINT_DIR / 'vqvae_history.json'
    with open(history_path, 'w') as f:
        json.dump(history, f, indent=2)
    
    # Plot training curves
    if len(history) > 1:
        fig, axes = plt.subplots(1, 3, figsize=(15, 4))
        ep_x = [h['epoch'] for h in history]
        axes[0].plot(ep_x, [h['loss'] for h in history], 'b-')
        axes[0].set_title('Total Loss'); axes[0].grid(True, alpha=0.3)
        
        axes[1].plot(ep_x, [h['recon_loss'] for h in history], 'r-', label='Recon')
        axes[1].plot(ep_x, [h['vq_loss'] for h in history], 'g-', label='VQ')
        axes[1].set_title('Loss Components'); axes[1].legend(); axes[1].grid(True, alpha=0.3)
        
        axes[2].plot(ep_x, [h['accuracy'] for h in history], 'm-')
        axes[2].set_title('Accuracy'); axes[2].set_ylim(0, 1); axes[2].grid(True, alpha=0.3)
        
        plt.suptitle('Stage 1: VQ-VAE Training', fontweight='bold')
        plt.tight_layout()
        plt.savefig(str(OUTPUT_DIR / 'vqvae_curves.png'), dpi=150)
        plt.show()
else:
    print("⏭️  Skipping VQ-VAE training - target accuracy achieved")

---
## 🟢 Stage 2: Latent Diffusion with Real Graph Conditioning

Full pipeline training using **real .dot mission graphs** from VGLC:
- **U-Net** denoiser in VQ-VAE latent space
- **GATv2Conv GNN** encodes real dungeon graph topology (nodes=rooms, edges=doors)
- **LogicNet** gradient guidance: differentiable solvability + key-lock checking
- **EMA** model weights for stable sampling
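
The EMA bullet can be illustrated with a small sketch. It assumes the usual update rule `ema = decay * ema + (1 - decay) * param`; the helper name `ema_update` and the decay values are illustrative and are not taken from `DiffusionTrainer`, whose EMA implementation is not shown in this notebook.

```python
def ema_update(ema_params, params, decay=0.999):
    """Return new EMA values: decay * ema + (1 - decay) * param, per key."""
    return {
        name: decay * ema_params[name] + (1.0 - decay) * value
        for name, value in params.items()
    }

# Toy example with a single scalar "weight" (hypothetical values).
ema = {"w": 0.0}
for _ in range(3):
    ema = ema_update(ema, {"w": 1.0}, decay=0.9)
print(round(ema["w"], 3))  # 0.271
```

Because the averaged weights change slowly, sampling from them tends to be less sensitive to the noise of individual optimizer steps, which matters with a batch size as small as the one used here.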

| Phase | Epochs | Loss Components |
|-------|--------|-----------------|
| Warmup | 1–5 | Diffusion only (no logic) |
| Full | 6+ | Diffusion + α × LogicNet |
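
The two phases in the table can be sketched as a single function of the epoch index. This is an assumption about how the terms are combined: the trainer's actual loss code is not shown here, `combined_loss` is a hypothetical name, and the diffusion term is given weight 1 to match `alpha_visual=1.0` in the training config.

```python
def combined_loss(diffusion_loss, logic_loss, epoch, alpha_logic=0.1, warmup_epochs=5):
    """Total loss for a 0-based epoch index.

    Epochs 0..warmup_epochs-1 (displayed as 1..5) use the diffusion loss only;
    later epochs add the LogicNet term scaled by alpha_logic.
    """
    if epoch < warmup_epochs:
        return diffusion_loss
    return diffusion_loss + alpha_logic * logic_loss

# Last warmup epoch vs. first full epoch (hypothetical loss values).
print(combined_loss(1.0, 2.0, epoch=4))
print(combined_loss(1.0, 2.0, epoch=5))
```

The `epoch < warmup_epochs` test is the same comparison the training loop uses to choose between the 🔒 and ✅ flags in its log line.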

In [None]:
# ============================================================================
# CELL 4: Stage 2 - Diffusion Training with Real Graph Conditioning (PRODUCTION VERSION)
# ============================================================================
import time

# Helper functions for compact checkpointing (Kaggle disk space optimization)
def _save_minimal_checkpoint(ckpt_path, trainer, metrics_dict):
    """Save checkpoint without optimizer/scheduler states to save disk space."""
    # NOTE: both model keys hold the EMA weights, and neither matches the
    # 'diffusion_state_dict' / 'ema_diffusion_state_dict' keys that the resume
    # search below lists in required_keys, so compact checkpoints are unlikely
    # to pass that validation as written.
    state = {
        'epoch': trainer.epoch,
        'model_state_dict': trainer.ema_diffusion.state_dict(),
        'ema_state_dict': trainer.ema_diffusion.state_dict(),
        'condition_encoder_state_dict': trainer.condition_encoder.state_dict(),
        'logic_net_state_dict': trainer.logic_net.state_dict()
    }
    if metrics_dict:
        state.update(metrics_dict)
    
    torch.save(state, ckpt_path)
    print(f"💾 Saved compact checkpoint: {ckpt_path.name} ({ckpt_path.stat().st_size / 1024**2:.1f} MB)")

def _prune_checkpoints(checkpoint_dir, keep=3, pattern='checkpoint_*.pth'):
    """Keep only the most recent N checkpoints to save disk space."""
    checkpoints = sorted(checkpoint_dir.glob(pattern), key=lambda x: x.stat().st_mtime, reverse=True)
    for old_ckpt in checkpoints[keep:]:
        size_mb = old_ckpt.stat().st_size / 1024**2
        old_ckpt.unlink()
        print(f"🗑️  Pruned: {old_ckpt.name} ({size_mb:.1f} MB freed)")

def verify_checkpoint_integrity(ckpt_path, required_keys):
    """Validate checkpoint before loading."""
    try:
        ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
        missing = [k for k in required_keys if k not in ckpt]
        if missing:
            return False, f"Missing keys: {missing}"
        # Check epoch sanity
        if 'epoch' in ckpt and ckpt['epoch'] < 0:
            return False, "Invalid epoch value"
        return True, "Valid"
    except Exception as e:
        return False, str(e)

DIFFUSION_EPOCHS = 500       # 100+ recommended; 500 for production
DIFFUSION_BATCH_SIZE = 2
DIFFUSION_LR = 1e-4
ALPHA_LOGIC = 0.1
WARMUP_EPOCHS = 5

# Checkpointing policy: compact checkpoints for Kaggle (reduce disk usage)
if IS_KAGGLE:
    SAVE_EVERY = 50  # less frequent full saves
    MAX_KEEP_CHECKPOINTS = 3
    COMPACT_CHECKPOINTS = True
else:
    SAVE_EVERY = 10
    MAX_KEEP_CHECKPOINTS = 10
    COMPACT_CHECKPOINTS = False

print("=" * 60)
print("🟢 STAGE 2: LATENT DIFFUSION TRAINING")
print("=" * 60)

if not VQVAE_SAVE_PATH.exists():
    raise FileNotFoundError(f"❌ VQ-VAE checkpoint not found at {VQVAE_SAVE_PATH}. Run Stage 1 first!")

config = DiffusionTrainingConfig(
    data_dir=str(DATA_DIR), batch_size=DIFFUSION_BATCH_SIZE, use_vglc=True,
    vqvae_checkpoint=str(VQVAE_SAVE_PATH), latent_dim=64, model_channels=128,
    context_dim=256, num_timesteps=1000, schedule_type='cosine',
    num_logic_iterations=30, guidance_scale=1.0, epochs=DIFFUSION_EPOCHS,
    learning_rate=DIFFUSION_LR, alpha_visual=1.0, alpha_logic=ALPHA_LOGIC,
    warmup_epochs=WARMUP_EPOCHS, checkpoint_dir=str(CHECKPOINT_DIR),
    save_every=SAVE_EVERY, device='cuda' if torch.cuda.is_available() else 'cpu',
)

# Data loaders with REAL graph data from .dot files
train_loader = create_dataloader(config.data_dir, batch_size=config.batch_size,
    shuffle=True, use_vglc=True, normalize=True, load_graphs=True)
# NOTE: val_loader is built from the same data_dir with the same arguments, so unless
# create_dataloader splits internally, validation solvability is measured on training dungeons.
val_loader = create_dataloader(config.data_dir, batch_size=config.batch_size,
    shuffle=False, use_vglc=True, normalize=True, load_graphs=True)
print(f"📊 {len(train_loader.dataset)} dungeons with real .dot graphs, {len(train_loader)} batches/epoch")

# Create trainer
diff_trainer = DiffusionTrainer(config)
print(f"🏗️  All models on {config.device}, VQ-VAE frozen ✅")

# Smart resume logic: find the latest valid checkpoint across ALL sources
def find_latest_checkpoint_multi_source(
    checkpoint_patterns: List[str] = ['final_model.pth', 'best_model.pth', 'checkpoint_*.pth'],
    required_keys: Optional[List[str]] = None
) -> Tuple[Optional[Path], Optional[CheckpointInfo]]:
    """
    Find the most recent valid checkpoint across all Kaggle sources.
    Supports glob patterns like 'checkpoint_*.pth' to find numbered checkpoints.
    
    Returns:
        Tuple of (best_checkpoint_path, checkpoint_info)
    """
    all_candidates: List[CheckpointInfo] = []
    locations = find_checkpoint_locations()
    
    print("🔍 Searching for diffusion checkpoints...")
    
    search_order = [
        ('working', locations['working']),
        ('input_datasets', locations['input_datasets']),
        ('notebook_outputs', locations['notebook_outputs']),
        ('direct_files', locations['direct_files'])
    ]
    
    for source_type, dirs in search_order:
        for checkpoint_dir in dirs:
            for pattern in checkpoint_patterns:
                # Handle glob patterns
                if '*' in pattern:
                    matches = sorted(checkpoint_dir.glob(pattern), 
                                   key=lambda p: int(p.stem.split('_')[-1]) if p.stem.split('_')[-1].isdigit() else 0,
                                   reverse=True)
                else:
                    matches = [checkpoint_dir / pattern] if (checkpoint_dir / pattern).exists() else []
                
                for ckpt_path in matches:
                    if not ckpt_path.exists():
                        continue
                    
                    # Validate with validate_and_load_checkpoint from Cell 3
                    info = validate_and_load_checkpoint(
                        ckpt_path=ckpt_path,
                        required_keys=required_keys,
                        source_type=source_type,
                        source_location=str(checkpoint_dir.relative_to('/kaggle') if IS_KAGGLE else checkpoint_dir)
                    )
                    
                    if info.is_valid:
                        all_candidates.append(info)
                        status = "✅"
                        epoch_str = f"epoch={info.epoch}" if info.epoch is not None else "epoch=?"
                        # Compare against None so a solvability of 0.0 is still reported
                        solv_str = f", solv={info.solvability:.3f}" if info.solvability is not None else ""
                        print(f"   {status} [{source_type:15}] {ckpt_path.name} - {epoch_str}{solv_str}")
    
    if not all_candidates:
        print("ℹ️  No valid diffusion checkpoints found")
        return None, None
    
    # Sort by epoch (most recent first), then by source priority
    priority_map = {'working': 0, 'input_datasets': 1, 'notebook_outputs': 2, 'direct_files': 3}
    all_candidates.sort(
        key=lambda c: (-(c.epoch if c.epoch is not None else -1), priority_map.get(c.source_type, 99))
    )
    
    best = all_candidates[0]
    print(f"\n🎯 Selected: {best.path.name} from {best.source_location}")
    print(f"   📊 Epoch {best.epoch}", end="")
    if best.solvability is not None:
        print(f", validation solvability={best.solvability:.3f}", end="")
    print()
    
    return best.path, best

# Resume from checkpoint if exists (search across ALL sources)
start_epoch = 0
resume_history = []
latest_ckpt_path, latest_ckpt_info = find_latest_checkpoint_multi_source(
    checkpoint_patterns=['final_model.pth', 'best_model.pth', 'checkpoint_*.pth'],
    required_keys=['epoch', 'diffusion_state_dict', 'ema_diffusion_state_dict']
)

if latest_ckpt_path is not None:
    try:
        # Copy to working directory if from input dataset
        if IS_KAGGLE and latest_ckpt_info.source_type != 'working':
            working_ckpt_path = copy_checkpoint_to_working(latest_ckpt_path, latest_ckpt_path.name)
            print("📋 Copied to working directory for incremental saves")
        else:
            working_ckpt_path = latest_ckpt_path
        
        # Load the checkpoint into trainer
        diff_trainer.load_checkpoint(str(latest_ckpt_path))
        start_epoch = diff_trainer.epoch + 1
        
        # Try to load training history
        history_path = CHECKPOINT_DIR / 'diffusion_history.json'
        if history_path.exists():
            with open(history_path, 'r') as f:
                resume_history = json.load(f)
                print(f"📊 Loaded {len(resume_history)} epochs of training history")
        
        print(f"🔄 Resuming from epoch {start_epoch}")
    except Exception as e:
        print(f"⚠️  Error loading checkpoint {latest_ckpt_path.name}: {e}")
        print("   Starting fresh training")
        start_epoch = 0
        resume_history = []
else:
    print("ℹ️  No existing diffusion checkpoint found - starting fresh training")

best_solv, history = 0.0, resume_history
if resume_history:
    # Find the best validation solvability from history
    best_solv = max((h.get('val_solvability', 0) for h in resume_history), default=0.0)
    print(f"📈 Previous best validation solvability: {best_solv:.4f}")
    
t0 = time.time()
print(f"\n{'Epoch':>6} | {'Loss':>8} | {'Diffusion':>10} | {'Logic':>8} | {'Val Solv':>10} | {'Time':>6}")
print("─" * 65)

for epoch in range(start_epoch, DIFFUSION_EPOCHS):
    train_m = diff_trainer.train_epoch(train_loader)
    val_m = diff_trainer.validate(val_loader, num_samples=4)
    lr = diff_trainer.scheduler.get_last_lr()[0]

    rec = {'epoch': epoch+1, **train_m, 'val_solvability': val_m['val_solvability'], 'lr': lr}
    history.append(rec)

    logic_flag = "🔒" if epoch < WARMUP_EPOCHS else "✅"
    elapsed = time.time() - t0
    print(f"  {epoch+1:4d}   | {train_m['loss']:.4f}  | {train_m['diffusion_loss']:.6f}  | "
          f"{train_m['logic_loss']:.4f}{logic_flag} | {val_m['val_solvability']:.4f}     | {elapsed/60:.1f}m")

    # Save periodic checkpoints
    if (epoch+1) % SAVE_EVERY == 0:
        ckpt_path = CHECKPOINT_DIR / f'checkpoint_{epoch+1:04d}.pth'
        if COMPACT_CHECKPOINTS:
            _save_minimal_checkpoint(ckpt_path, diff_trainer, rec)
            _prune_checkpoints(CHECKPOINT_DIR, keep=MAX_KEEP_CHECKPOINTS)
        else:
            diff_trainer.save_checkpoint(str(ckpt_path), rec)

    # Save best model
    if val_m['val_solvability'] > best_solv:
        best_solv = val_m['val_solvability']
        best_path = CHECKPOINT_DIR / 'best_model.pth'
        if COMPACT_CHECKPOINTS:
            _save_minimal_checkpoint(best_path, diff_trainer, rec)
        else:
            diff_trainer.save_checkpoint(str(best_path), rec)

# Save final checkpoint (compact if configured)
final_path = CHECKPOINT_DIR / 'final_model.pth'
if COMPACT_CHECKPOINTS:
    _save_minimal_checkpoint(final_path, diff_trainer, history[-1] if history else None)
else:
    diff_trainer.save_checkpoint(str(final_path), history[-1] if history else None)
print(f"\n✅ Done! Best val solvability: {best_solv:.4f}, Time: {(time.time()-t0)/60:.1f} min")

# Save training history separately
with open(CHECKPOINT_DIR / 'diffusion_history.json', 'w') as f:
    json.dump(history, f, indent=2)

# Plot training curves
if len(history) > 1:
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    ep_x = [h['epoch'] for h in history]
    
    axes[0,0].plot(ep_x, [h['loss'] for h in history], 'b-')
    axes[0,0].set_title('Total Loss'); axes[0,0].grid(True, alpha=0.3)
    
    axes[0,1].plot(ep_x, [h['diffusion_loss'] for h in history], 'r-', label='Diffusion')
    axes[0,1].plot(ep_x, [h['logic_loss'] for h in history], 'g-', label='Logic')
    axes[0,1].axvline(x=WARMUP_EPOCHS, color='gray', ls='--', alpha=0.5, label='Warmup end')
    axes[0,1].set_title('Loss Components'); axes[0,1].legend(); axes[0,1].grid(True, alpha=0.3)
    
    axes[1,0].plot(ep_x, [h['solvability'] for h in history], 'c-', label='Train')
    axes[1,0].plot(ep_x, [h['val_solvability'] for h in history], 'm-', label='Val')
    axes[1,0].set_title('Solvability'); axes[1,0].legend(); axes[1,0].set_ylim(0,1); axes[1,0].grid(True, alpha=0.3)
    
    axes[1,1].plot(ep_x, [h['lr'] for h in history], 'k-')
    axes[1,1].set_title('Learning Rate')
    axes[1,1].set_yscale('log'); axes[1,1].grid(True, alpha=0.3)
    
    plt.suptitle('Stage 2: Diffusion Training', fontweight='bold')
    plt.tight_layout()
    plt.savefig(str(OUTPUT_DIR / 'diffusion_curves.png'), dpi=150)
    plt.show()


---
## 🎨 Generation & Visualization

Generate sample dungeons using the trained model with graph-conditioned DDIM sampling.
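
The internals of `gen.ema_diffusion.sample` are not shown in this notebook. As a rough illustration of what deterministic DDIM sampling over a cosine schedule involves, here is a minimal sketch assuming an epsilon-prediction model and η = 0. The names `cosine_alpha_bar` and `ddim_step`, and the offset `s = 0.008`, are assumptions made for the example rather than values read from the trainer.

```python
import math

def cosine_alpha_bar(t, num_timesteps=1000, s=0.008):
    """Cumulative signal level for the cosine schedule, normalized so alpha_bar(0) = 1."""
    def f(u):
        return math.cos((u / num_timesteps + s) / (1 + s) * math.pi / 2) ** 2
    return f(t) / f(0)

def ddim_step(x_t, eps_pred, ab_t, ab_prev):
    """One deterministic DDIM update (eta = 0), elementwise over lists of floats."""
    out = []
    for x, e in zip(x_t, eps_pred):
        # Estimate the clean sample implied by the predicted noise
        x0_pred = (x - math.sqrt(1 - ab_t) * e) / math.sqrt(ab_t)
        # Re-noise that estimate to the previous (less noisy) level
        out.append(math.sqrt(ab_prev) * x0_pred + math.sqrt(1 - ab_prev) * e)
    return out

# Toy check: with the true noise as the "prediction", stepping to t = 0 recovers x0.
x0 = [0.5, -1.0]
eps = [0.3, 0.7]
ab_t = cosine_alpha_bar(500)
x_t = [math.sqrt(ab_t) * a + math.sqrt(1 - ab_t) * e for a, e in zip(x0, eps)]
recovered = ddim_step(x_t, eps, ab_t, ab_prev=cosine_alpha_bar(0))
print(recovered)
```

In a real sampler the noise prediction would come from the U-Net, conditioned on the encoded graph, and the step would be repeated over a decreasing subsequence of timesteps.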

In [None]:
# ============================================================================
# CELL 5: Generate Sample Dungeons
# ============================================================================
import matplotlib.colors as mcolors

NUM_SAMPLES = 4

print("🎨 GENERATING SAMPLE DUNGEONS")

# Load best checkpoint
ckpt_path = CHECKPOINT_DIR / 'best_model.pth'
if not ckpt_path.exists():
    ckpt_path = CHECKPOINT_DIR / 'final_model.pth'

gen_config = DiffusionTrainingConfig(
    data_dir=str(DATA_DIR), batch_size=1, use_vglc=True,
    vqvae_checkpoint=str(VQVAE_SAVE_PATH), latent_dim=64,
    model_channels=128, context_dim=256, num_timesteps=1000,
    schedule_type='cosine', device='cuda' if torch.cuda.is_available() else 'cpu',
)
gen = DiffusionTrainer(gen_config)
gen.load_checkpoint(str(ckpt_path))
print(f"✅ Loaded model (epoch {gen.epoch})")

# Get real graph conditioning
cond_loader = create_dataloader(str(DATA_DIR), batch_size=1, shuffle=True,
    use_vglc=True, normalize=True, load_graphs=True)
conditionings, shapes = [], []
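for batch_data in cond_loader:
    if isinstance(batch_data, (list, tuple)) and len(batch_data) == 2:
        imgs, graphs = batch_data
        imgs = imgs.to(DEVICE)
        for g in graphs:
            try:
                conditionings.append(gen._encode_graph_conditioning(g))
            except Exception as e:
                # Skip graphs that fail to encode, but report them instead of hiding the error
                print(f"⚠️  Skipping graph: {e}")
        shapes.append(gen.encode_to_latent(imgs).shape)
    if len(conditionings) >= NUM_SAMPLES: break

# Without this guard, the modulo in the generation loop below raises ZeroDivisionError
if not conditionings:
    raise RuntimeError("No graph conditioning could be encoded; cannot generate samples.")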

# Generate
gen.ema_diffusion.eval()
generated = []
with torch.no_grad():
    for i in range(NUM_SAMPLES):
        c = conditionings[i % len(conditionings)]
        z = gen.ema_diffusion.sample(c, shape=shapes[0])
        logits = gen.decode_from_latent(z)
        tile_ids = logits.argmax(dim=1).squeeze(0).cpu().numpy()
        generated.append(tile_ids)
        print(f"  Sample {i+1}: {tile_ids.shape}, {len(np.unique(tile_ids))} tile types")

# Visualize
TILE_COLORS = {
    0: '#1a1a2e', 1: '#e8d5b7', 2: '#4a4a4a', 3: '#8b7355',
    10: '#90EE90', 11: '#FFD700', 12: '#FF6347', 13: '#9370DB',
    14: '#DC143C', 15: '#FFA500', 20: '#FF4444', 21: '#00FF00',
    22: '#FFD700', 23: '#8B0000', 30: '#FFFF00', 31: '#FF69B4',
    32: '#00CED1', 33: '#87CEEB', 40: '#4169E1', 41: '#6495ED',
    42: '#DEB887', 43: '#DA70D6',
}

def grid_to_rgb(grid):
    h, w = grid.shape
    rgb = np.full((h, w, 3), 0.5, dtype=np.float32)
    for tid, hex_c in TILE_COLORS.items():
        mask = grid == tid
        if mask.any():
            r, g, b = mcolors.hex2color(hex_c)
            rgb[mask] = [r, g, b]
    return rgb

fig, axes = plt.subplots(1, NUM_SAMPLES, figsize=(5*NUM_SAMPLES, 8))
if NUM_SAMPLES == 1: axes = [axes]
for i, (ax, grid) in enumerate(zip(axes, generated)):
    ax.imshow(grid_to_rgb(grid), interpolation='nearest', aspect='auto')
    ax.set_title(f'Dungeon {i+1}\n{grid.shape[0]}x{grid.shape[1]}')
    ax.axis('off')
plt.suptitle('H-MOLQD Generated Dungeons (Graph-Conditioned)', fontweight='bold')
plt.tight_layout()
plt.savefig(str(OUTPUT_DIR / 'generated_samples.png'), dpi=150)
plt.show()

for i, grid in enumerate(generated):
    np.save(str(OUTPUT_DIR / f'dungeon_{i+1}.npy'), grid)
print(f"💾 Saved {NUM_SAMPLES} dungeons to {OUTPUT_DIR}")

---

## 📦 Summary & Download

### Checkpoints Saved

| File | Description |
|------|-------------|
| `checkpoints/vqvae_pretrained.pth` | VQ-VAE encoder/decoder (44 tiles, 512 codebook) |
| `checkpoints/best_model.pth` | Best diffusion model (highest validation solvability) |
| `checkpoints/final_model.pth` | Final diffusion model (last epoch) |

### Generated Outputs

| File | Description |
|------|-------------|
| `outputs/generated_samples.png` | Visualization of generated dungeons |
| `outputs/dungeon_*.npy` | Raw tile grids (NumPy arrays) |

### Download Results (Kaggle)

```python
import shutil
shutil.make_archive('/kaggle/working/hmolqd_results', 'zip', '/kaggle/working/hmolqd_outputs')
```

Then download `hmolqd_results.zip` from the **Output** tab.

### Next Steps (Local)
1. **WFC Refinement** (Block V): `python -m src.generation.wfc_refiner --input outputs/dungeon_1.npy`
2. **Cognitive Validation** (Block VI): `python -m src.simulation.cognitive_validator --dungeon outputs/dungeon_1.npy`
3. **Quality-Diversity Search**: Use MAP-Elites with trained models for diverse dungeon generation
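
For the quality-diversity step, the core of MAP-Elites is an archive that keeps the best solution found in each behavior-descriptor cell. The sketch below is a toy, self-contained version: `map_elites`, `toy_evaluate`, and `toy_mutate` are hypothetical names, and the scalar "solution" stands in for a generated dungeon whose fitness and descriptors would in practice come from the trained models and validators.

```python
import random

def map_elites(evaluate, mutate, seed_solutions, iterations=200, rng=None):
    """Minimal MAP-Elites: keep the highest-fitness solution per descriptor cell."""
    rng = rng or random.Random(0)
    archive = {}  # cell -> (fitness, solution)

    def try_insert(solution):
        fitness, cell = evaluate(solution)
        if cell not in archive or fitness > archive[cell][0]:
            archive[cell] = (fitness, solution)

    for solution in seed_solutions:
        try_insert(solution)
    for _ in range(iterations):
        # Pick a random elite, perturb it, and try to place the child
        _, parent = rng.choice(list(archive.values()))
        try_insert(mutate(parent, rng))
    return archive

def toy_evaluate(x):
    cell = min(int(x), 9)              # descriptor bin 0..9
    fitness = -abs(x - (cell + 0.5))   # best at the bin centre
    return fitness, cell

def toy_mutate(x, rng):
    return min(max(x + rng.uniform(-1.0, 1.0), 0.0), 10.0)

archive = map_elites(toy_evaluate, toy_mutate, seed_solutions=[5.0], iterations=500)
print(sorted(archive))
```

For dungeons, the descriptor cell might be built from properties such as room count or key count, with solvability as the fitness; those choices are left open here.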