# Superconductor VAE - Google Colab Training

Train the FullMaterialsVAE (V12) on Google Colab using your repo uploaded to Google Drive.

**Setup**: Upload the entire `superconductor-vae` repository to your Google Drive (Cell 2 expects it at `MyDrive/superconductor-vae` by default), then run these cells in order.

**Checkpoints**: Saved to `outputs/` inside the repo on Drive, so they persist across Colab sessions.

**Remote monitoring (optional)**: Set `GIST_ID` in Cell 2 to push live training metrics to a GitHub Gist. This lets you (or Claude Code) check training progress without being in the Colab tab.
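Reading the log back does not require Colab. Below is a minimal standard-library sketch. It assumes the gist layout described in Cell 2 (a single `training_log.json` file) and the JSON fields written by `update_gist()` in Cell 5b. `fetch_training_log` and `summarize_training_log` are hypothetical helper names and are not part of the repo.

```python
import json
import urllib.request


def fetch_training_log(gist_id, filename="training_log.json", token=None):
    """Download the gist via the GitHub REST API and parse the log file's JSON."""
    req = urllib.request.Request(
        f"https://api.github.com/gists/{gist_id}",
        headers={"Accept": "application/vnd.github+json"},
    )
    if token:
        # Optional: authenticated requests get a higher API rate limit
        req.add_header("Authorization", f"Bearer {token}")
    with urllib.request.urlopen(req, timeout=10) as resp:
        gist = json.load(resp)
    return json.loads(gist["files"][filename]["content"])


def summarize_training_log(log):
    """One-line status built from the fields Cell 5b writes to the gist."""
    history = log.get("history", [])
    last = history[-1] if history else {}
    return (
        f"{log.get('status', '?')}: epoch {log.get('current_epoch', '?')}/"
        f"{log.get('total_epochs', '?')}, best exact {log.get('best_exact_match', '?')}%, "
        f"last loss {last.get('loss', '?')}"
    )


# Usage (needs network access):
# print(summarize_training_log(fetch_training_log("YOUR_GIST_ID")))
```

Missing fields are shown as `?`. This covers the initial `"starting"` push, for example, which happens before any epoch history exists.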

## Cell 1: Mount Google Drive

In [None]:
from google.colab import drive
drive.mount('/content/drive')

## Cell 2: Configuration

Edit these values to match your setup.

In [None]:
# Path to the superconductor-vae repo on your Google Drive
REPO_PATH = "/content/drive/MyDrive/superconductor-vae"

# Training options (override defaults)
RESUME_FROM_CHECKPOINT = True   # Resume from best checkpoint if available
NUM_EPOCHS = 2000               # Total epochs; a resumed run continues from the checkpoint's saved epoch
BATCH_SIZE = 'auto'             # 'auto' scales with GPU memory, or set integer (e.g. 32, 48)

# --- Remote monitoring via GitHub Gist ---
# Set GIST_ID to enable live training metrics visible outside Colab.
# 1. Create a personal access token at https://github.com/settings/tokens with "gist" scope
# 2. In Colab: click the key icon (left sidebar) > add secret named GITHUB_TOKEN
# 3. Create a gist at https://gist.github.com with one file named "training_log.json"
#    containing just "{}" — then copy the gist ID from the URL.
# Set to None to disable gist logging.
GIST_ID = "acceed7daef4d6893801cc7337531b68"
GIST_LOG_EVERY = 5  # Currently unused: Cell 5b pushes metrics on every checkpoint save instead
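A quick sanity check of the Cell 2 values can catch a typo before a long run starts. Below is a minimal sketch using a hypothetical `check_config` helper. One assumption is worth stating: a gist ID is the hexadecimal string at the end of the gist URL, like the 32-character ID above. A common mistake, pasting the full URL, then fails the check.

```python
import re


def check_config(num_epochs, batch_size, gist_id):
    """Return a list of problems found in the Cell 2 settings (empty if none)."""
    problems = []
    if not (isinstance(num_epochs, int) and num_epochs > 0):
        problems.append("NUM_EPOCHS must be a positive integer")
    if batch_size != "auto" and not (isinstance(batch_size, int) and batch_size > 0):
        problems.append("BATCH_SIZE must be 'auto' or a positive integer")
    # Assumption: gist IDs are lowercase hexadecimal strings
    if gist_id is not None and not re.fullmatch(r"[0-9a-f]+", gist_id):
        problems.append("GIST_ID should be the hex ID from the gist URL, or None")
    return problems


print(check_config(2000, "auto", "acceed7daef4d6893801cc7337531b68"))  # []
```

In the notebook, you would call it as `check_config(NUM_EPOCHS, BATCH_SIZE, GIST_ID)` at the end of Cell 2.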

## Cell 3: Install Dependencies

PyTorch, NumPy, pandas, and scikit-learn are pre-installed on Colab. We only need matminer and pymatgen.

In [None]:
!pip install matminer pymatgen

## Cell 4: Set Up Paths and Verify Environment

In [None]:
import sys
import os
from pathlib import Path

repo = Path(REPO_PATH)

# Add src/ to Python path so imports work
src_path = str(repo / "src")
if src_path not in sys.path:
    sys.path.insert(0, src_path)

# Verify key files exist
required_files = {
    "Training data": repo / "data/processed/supercon_fractions_combined.csv",
    "Holdout set": repo / "data/GENERATIVE_HOLDOUT_DO_NOT_TRAIN.json",
    "Training script": repo / "scripts/train_v12_clean.py",
    "VAE model": repo / "src/superconductor/models/attention_vae.py",
    "Decoder model": repo / "src/superconductor/models/autoregressive_decoder.py",
}

all_found = True
for name, path in required_files.items():
    exists = path.exists()
    status = "OK" if exists else "MISSING"
    print(f"  [{status}] {name}: {path}")
    if not exists:
        all_found = False

if not all_found:
    raise FileNotFoundError(
        f"Missing files. Check that REPO_PATH is correct: {REPO_PATH}"
    )

print()

# GPU info
import torch
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    capability = torch.cuda.get_device_capability(0)
    print(f"GPU: {gpu_name}")
    print(f"Memory: {gpu_mem:.1f} GB")
    print(f"Compute capability: {capability[0]}.{capability[1]}")
    print(f"PyTorch: {torch.__version__}")
    print(f"CUDA: {torch.version.cuda}")
else:
    print("WARNING: No GPU detected! Go to Runtime > Change runtime type > GPU.")
    print(f"PyTorch: {torch.__version__}")
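`BATCH_SIZE = 'auto'` scales the batch size with the GPU memory reported above. The real rule lives in `train_v12_clean.py` and is not reproduced here. The function below is a hypothetical linear-scaling rule that illustrates the idea and can help you pick an integer `BATCH_SIZE` by hand. Its parameters are all assumptions: a base of 32 at 16 GB, rounding down to a multiple of 8, and a cap of 256.

```python
def auto_batch_size(gpu_mem_gb, base_batch=32, base_mem_gb=16.0, multiple=8, max_batch=256):
    """Hypothetical 'auto' rule: scale batch size linearly with GPU memory.

    Scales base_batch by gpu_mem_gb / base_mem_gb, rounds down to a multiple
    of `multiple`, and clamps the result to [multiple, max_batch].
    """
    scaled = int(base_batch * gpu_mem_gb / base_mem_gb)
    scaled = (scaled // multiple) * multiple
    return max(multiple, min(max_batch, scaled))


for mem in (15.0, 22.5, 40.0, 80.0):
    print(mem, auto_batch_size(mem))  # 24, 40, 80, 160 respectively
```

Linear scaling ignores model-specific memory overheads. Treat the result as a starting point, and lower it if you see out-of-memory errors.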

## Cell 5: Apply Colab-Specific Config Overrides

Imports `scripts/train_v12_clean.py` as a module, then patches its `TRAIN_CONFIG` and path constants for Colab before `train()` is called in Cell 6.

In [None]:
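import importlib

# Add scripts/ to path so we can import train_v12_clean as a module
scripts_path = str(repo / "scripts")
if scripts_path not in sys.path:
    sys.path.insert(0, scripts_path)

# Import the training module. On a re-run, reload it so TRAIN_CONFIG starts from
# the script's defaults and any functions patched by Cell 5b are restored.
if "train_v12_clean" in sys.modules:
    train_v12_clean = importlib.reload(sys.modules["train_v12_clean"])
else:
    import train_v12_clean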

# --- Colab-specific overrides ---

# DataLoader: Colab's containerized env has issues with multiprocessing workers
train_v12_clean.TRAIN_CONFIG['num_workers'] = 0
train_v12_clean.TRAIN_CONFIG['persistent_workers'] = False
train_v12_clean.TRAIN_CONFIG['pin_memory'] = True

# torch.compile: disabled. The default Inductor backend depends on Triton and a
# C/C++ compiler, which can fail to set up on some Colab runtimes.
train_v12_clean.TRAIN_CONFIG['use_torch_compile'] = False

# More frequent checkpoints — Colab can disconnect even on Pro tier
train_v12_clean.TRAIN_CONFIG['checkpoint_interval'] = 25

# Gradient checkpointing off: A100 (40/80 GB) and L4 (24 GB) have enough VRAM.
# It may be worth re-enabling on a 16 GB T4 if you hit out-of-memory errors.
train_v12_clean.TRAIN_CONFIG['use_gradient_checkpointing'] = False

# --- User overrides from Cell 2 ---

train_v12_clean.TRAIN_CONFIG['num_epochs'] = NUM_EPOCHS

if BATCH_SIZE != 'auto':
    train_v12_clean.TRAIN_CONFIG['batch_size'] = int(BATCH_SIZE)

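best_ckpt = repo / 'outputs' / 'checkpoint_best.pt'
if RESUME_FROM_CHECKPOINT and best_ckpt.exists():
    # Absolute path, so resuming doesn't depend on Colab's working directory (/content)
    train_v12_clean.TRAIN_CONFIG['resume_checkpoint'] = str(best_ckpt)
else:
    # Resuming disabled, or no checkpoint yet (first run): start from scratch
    train_v12_clean.TRAIN_CONFIG['resume_checkpoint'] = None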

# --- Redirect paths to Drive-based repo ---

train_v12_clean.PROJECT_ROOT = repo
train_v12_clean.DATA_PATH = repo / 'data/processed/supercon_fractions_combined.csv'
train_v12_clean.HOLDOUT_PATH = repo / 'data/GENERATIVE_HOLDOUT_DO_NOT_TRAIN.json'
train_v12_clean.OUTPUT_DIR = repo / 'outputs'

# Ensure output directory exists
train_v12_clean.OUTPUT_DIR.mkdir(exist_ok=True)

# --- Print final config ---

print("Colab config applied:")
print(f"  num_workers: {train_v12_clean.TRAIN_CONFIG['num_workers']}")
print(f"  persistent_workers: {train_v12_clean.TRAIN_CONFIG['persistent_workers']}")
print(f"  use_torch_compile: {train_v12_clean.TRAIN_CONFIG['use_torch_compile']}")
print(f"  checkpoint_interval: {train_v12_clean.TRAIN_CONFIG['checkpoint_interval']}")
print(f"  use_gradient_checkpointing: {train_v12_clean.TRAIN_CONFIG['use_gradient_checkpointing']}")
print(f"  num_epochs: {train_v12_clean.TRAIN_CONFIG['num_epochs']}")
print(f"  batch_size: {train_v12_clean.TRAIN_CONFIG['batch_size']}")
print(f"  resume_checkpoint: {train_v12_clean.TRAIN_CONFIG['resume_checkpoint']}")
print(f"  OUTPUT_DIR: {train_v12_clean.OUTPUT_DIR}")
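Overriding module-level names such as `DATA_PATH` and `OUTPUT_DIR` after import only affects code that reads them at call time. Anything derived from them while the module was being imported keeps the old value. Examples include a default argument or another constant built from `PROJECT_ROOT`. Whether `train_v12_clean.py` does this is not visible here. The toy module below is hypothetical and built in memory; it shows the difference.

```python
import types

# Source of a hypothetical stand-in for the training script
src = '''
DATA_PATH = "local/data.csv"

def load_late():
    return DATA_PATH      # global looked up each time the function runs

def load_early(path=DATA_PATH):
    return path           # default captured once, when the def executed
'''
mod = types.ModuleType("fake_train")
exec(src, mod.__dict__)

mod.DATA_PATH = "/content/drive/MyDrive/data.csv"  # Cell 5 style override
print(mod.load_late())   # /content/drive/MyDrive/data.csv
print(mod.load_early())  # local/data.csv
```

If the script turns out to capture paths at import time, those captured names need to be overridden too. Alternatively, pass the values explicitly.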

## Cell 5b: Set Up Gist Logging (Optional)

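If `GIST_ID` is set in Cell 2 and a `GITHUB_TOKEN` secret is available, this cell wraps the training script's `save_checkpoint` and `train_epoch` functions. The latest epoch metrics are then pushed to a GitHub Gist every time a checkpoint is saved: new best models, periodic saves every `checkpoint_interval` epochs, and the final or interrupt save. You can then monitor training remotely via `gh gist view GIST_ID` or by visiting the gist URL.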

In [None]:
import json
import requests
from datetime import datetime, timezone

_gist_token = None
_gist_log = {"history": [], "status": "initialized"}

def _get_github_token():
    """Get GitHub token from Colab secrets."""
    global _gist_token
    if _gist_token is not None:
        return _gist_token
    try:
        from google.colab import userdata
        _gist_token = userdata.get('GITHUB_TOKEN')
        return _gist_token
    except Exception as e:
        print(f"  [Gist] Could not read GITHUB_TOKEN from Colab secrets: {e}")
        return None

def update_gist(epoch, metrics, best_exact, status="training"):
    """Push current training metrics to the GitHub Gist."""
    if GIST_ID is None:
        return

    token = _get_github_token()
    if token is None:
        return

    entry = {
        "epoch": epoch,
        "loss": round(metrics.get("loss", 0), 4),
        "exact_match": round(metrics.get("exact_match", 0) * 100, 2),
        "token_accuracy": round(metrics.get("accuracy", 0) * 100, 2),
        "tc_loss": round(metrics.get("tc_loss", 0), 4),
        "magpie_loss": round(metrics.get("magpie_loss", 0), 4),
        "stoich_loss": round(metrics.get("stoich_loss", 0), 4),
        "reinforce_loss": round(metrics.get("reinforce_loss", 0), 4),
        "mean_reward": round(metrics.get("mean_reward", 0), 3),
        "entropy": round(metrics.get("entropy", 0), 3),
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }

    _gist_log["status"] = status
    _gist_log["best_exact_match"] = round(best_exact * 100, 2)
    _gist_log["last_update"] = entry["timestamp"]
    _gist_log["current_epoch"] = epoch
    _gist_log["total_epochs"] = NUM_EPOCHS

    # Keep last 200 entries to avoid gist getting huge
    _gist_log["history"].append(entry)
    if len(_gist_log["history"]) > 200:
        _gist_log["history"] = _gist_log["history"][-200:]

    try:
        resp = requests.patch(
            f"https://api.github.com/gists/{GIST_ID}",
            headers={
                "Authorization": f"token {token}",
                "Accept": "application/vnd.github.v3+json",
            },
            json={"files": {"training_log.json": {"content": json.dumps(_gist_log, indent=2)}}},
            timeout=10,
        )
        if resp.status_code != 200:
            print(f"  [Gist] Update failed (HTTP {resp.status_code})")
    except Exception as e:
        # Don't let gist errors interrupt training
        print(f"  [Gist] Update error: {e}")

# --- Hook into the training loop ---
# Monkey-patch save_checkpoint to also push metrics to gist.
# This fires on best checkpoints and periodic checkpoints without
# modifying the training script.

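_latest_metrics = {}  # Filled by our patched train_epoch
_latest_best_exact = 0.0

# Capture the unpatched functions. If this cell is re-run without re-running
# Cell 5, the module already holds the wrappers defined below. Capturing those
# as "originals" would make each wrapper call itself (RecursionError), so the
# previously captured originals are kept instead.
if train_v12_clean.save_checkpoint is not globals().get('_save_checkpoint_with_gist'):
    _original_save_checkpoint = train_v12_clean.save_checkpoint
if train_v12_clean.train_epoch is not globals().get('_train_epoch_with_capture'):
    _original_train_epoch = train_v12_clean.train_epoch

def _save_checkpoint_with_gist(encoder, decoder, epoch, suffix='', **kwargs):
    """Wrapper that calls original save_checkpoint then updates gist."""
    global _latest_best_exact
    result = _original_save_checkpoint(encoder, decoder, epoch, suffix=suffix, **kwargs)

    try:
        best = kwargs.get('best_exact', _latest_best_exact)
        _latest_best_exact = max(_latest_best_exact, best)

        if GIST_ID is not None and _latest_metrics:
            status = {'final': 'completed', 'interrupt': 'interrupted'}.get(suffix, 'training')
            update_gist(epoch, _latest_metrics, _latest_best_exact, status=status)
    except Exception as e:
        # e.g. unexpected metric types; never let gist logging interrupt training
        print(f"  [Gist] Skipped update: {e}")
    return result

# Patch train_epoch to capture metrics for the gist hook
def _train_epoch_with_capture(*args, **kwargs):
    """Wrapper that captures epoch metrics for gist logging."""
    global _latest_metrics
    metrics = _original_train_epoch(*args, **kwargs)
    _latest_metrics = metrics
    return metrics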

if GIST_ID is not None:
    token = _get_github_token()
    if token:
        train_v12_clean.save_checkpoint = _save_checkpoint_with_gist
        train_v12_clean.train_epoch = _train_epoch_with_capture
        # Push initial status
        update_gist(0, {}, 0, status="starting")
        print(f"Gist logging enabled: https://gist.github.com/{GIST_ID}")
        print("  Metrics pushed on every checkpoint save")
    else:
        print("Gist logging disabled (no GITHUB_TOKEN found in Colab secrets)")
else:
    print("Gist logging disabled (GIST_ID is None)")
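Monkey-patching from a notebook cell has a classic pitfall. If the cell captures "the original" by reading the module attribute, a second run captures the wrapper instead. A wrapper that finds its original through a global name can then end up calling itself. Below is a generic, re-run-safe sketch using a hypothetical `wrap_once` helper. It stores the true original on the wrapper itself, and it isolates hook errors so that a failed logging side effect cannot interrupt training.

```python
import functools
import types


def wrap_once(module, name, after):
    """Patch module.<name> so that after(result, *args, **kwargs) runs after each call.

    The true original is remembered on the wrapper (_patched_original), so
    calling wrap_once again, e.g. when a notebook cell is re-run, wraps the
    original rather than the previous wrapper.
    """
    current = getattr(module, name)
    original = getattr(current, "_patched_original", current)

    @functools.wraps(original)
    def wrapper(*args, **kwargs):
        result = original(*args, **kwargs)
        try:
            after(result, *args, **kwargs)
        except Exception as e:
            # Logging side effects must never interrupt training
            print(f"  [hook] {name}: {e}")
        return result

    wrapper._patched_original = original
    setattr(module, name, wrapper)
    return original


demo = types.ModuleType("demo")
demo.train_epoch = lambda n: {"loss": 1.0 / n}
seen = []
wrap_once(demo, "train_epoch", lambda metrics, *a, **k: seen.append(metrics))
wrap_once(demo, "train_epoch", lambda metrics, *a, **k: seen.append(metrics))  # cell re-run
print(demo.train_epoch(4))  # {'loss': 0.25}
print(len(seen))            # 1: only one wrapper layer ran
```

A custom attribute is used instead of the `__wrapped__` attribute set by `functools.wraps`. This is because `__wrapped__` may already be present when the training script's own function is decorated, and unwrapping through it would strip that decorator.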

## Cell 6: Run Training

This calls the existing `train()` function directly. Training output streams to the notebook.

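**Tip**: If Colab disconnects, re-run Cells 1 through 5b, then this cell. With `RESUME_FROM_CHECKPOINT = True`, training picks up from the best checkpoint saved so far (`outputs/checkpoint_best.pt`), so any epochs trained after that checkpoint was written are repeated.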

In [None]:
train_v12_clean.train()
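Resuming from `checkpoint_best.pt` discards any epochs trained after the last best-model save. If you would rather continue from the most recently written checkpoint, the hypothetical helper below picks it by modification time. The `checkpoint_*.pt` pattern is the one Cell 7 lists. This sketch does not check whether a given periodic checkpoint contains everything needed to resume, such as optimizer or scheduler state. That depends on `train_v12_clean.save_checkpoint`.

```python
from pathlib import Path


def newest_checkpoint(output_dir, pattern="checkpoint_*.pt"):
    """Path of the most recently modified checkpoint in output_dir, or None."""
    candidates = [p for p in Path(output_dir).glob(pattern) if p.is_file()]
    if not candidates:
        return None
    return max(candidates, key=lambda p: p.stat().st_mtime)
```

To use it, you could call `p = newest_checkpoint(repo / 'outputs')` in Cell 5. When `p` is not `None`, set `TRAIN_CONFIG['resume_checkpoint']` to `str(p)`.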

## Cell 7: Post-Training Summary (Optional)

List saved checkpoints and show basic info.

In [None]:
import os
from pathlib import Path

output_dir = Path(REPO_PATH) / 'outputs'

if output_dir.exists():
    checkpoints = sorted(output_dir.glob('checkpoint_*.pt'))
    if checkpoints:
        print(f"Saved checkpoints ({len(checkpoints)}):")
        for cp in checkpoints:
            size_mb = cp.stat().st_size / (1024 * 1024)
            print(f"  {cp.name} ({size_mb:.1f} MB)")

        # Load best checkpoint to show final metrics
        best_path = output_dir / 'checkpoint_best.pt'
        if best_path.exists():
            import torch
            ckpt = torch.load(best_path, map_location='cpu', weights_only=False)
            print(f"\nBest checkpoint:")
            print(f"  Epoch: {ckpt.get('epoch', 'unknown')}")
            print(f"  Best exact match: {ckpt.get('best_exact', 'unknown')}")
            if 'prev_exact' in ckpt:
                print(f"  Exact at save: {ckpt['prev_exact']:.4f}")
    else:
        print("No checkpoints found.")
else:
    print(f"Output directory not found: {output_dir}")