# Multi-Task Superconductor Generator - Google Colab Training (V12.43 + Phase 2)

Train the multi-task encoder-decoder on Google Colab using your repo uploaded to Google Drive.

**Setup**: Upload the `superconductor-vae` repo to Google Drive, then run these cells.

**Current model: V12.43** (pre-V13/V14/V15). This is the V12.41 best checkpoint (epoch 3292, 85.4% exact), widened by 12.5% via Net2Net. Key specs:
- **Vocab**: 148 tokens (original digit-by-digit fraction encoding)
- **Architecture**: d_model=576, dim_feedforward=2304, fusion_dim=288, latent_dim=2048
- **Heads**: Token type classifier, enriched decoder memory (heads_to_memory), stoich conditioning, stop head
- **NOT applied**: V13 semantic fractions (vocab 4355), V14 isotopes (vocab 4647), V15 memory bottleneck

The training script contains code for V13-V15 features but they are **disabled** (`use_semantic_fractions: False`, `memory_bottleneck_dim: 0`). No auto-migration will occur.

**Phase 2 Self-Supervised**: Interleaved self-supervised sub-epochs that use the model's own generations to improve generalization. Runs every 2 supervised epochs once teacher-forced (TF) exact match >= 90% and autoregressive (AR) exact match >= 60% (the thresholds set in Cell 5). This is a **model enhancement** algorithm: novel superconductor discoveries are flagged opportunistically but are not the goal. See `docs/PHASE2_SELF_SUPERVISED_DESIGN.md`.
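
As a rough illustration of how this schedule behaves, the sketch below combines the interval with the activation thresholds that Cell 5 writes into `TRAIN_CONFIG` (`phase2_auto_min_exact = 0.90`, `phase2_min_ar_exact = 0.60`). The function name `phase2_due` and the `epoch % interval` rule are assumptions made for illustration; the real logic lives in `scripts/train_v12_clean.py` and may differ. The sketch also ignores `phase2_min_resume_epochs`.

```python
def phase2_due(epoch, tf_exact, ar_exact, *, interval=2,
               min_tf_exact=0.90, min_ar_exact=0.60):
    """Return True if a Phase 2 sub-epoch would follow this supervised epoch.

    Hypothetical helper: gate on both exact-match thresholds, then fire
    on every `interval`-th supervised epoch.
    """
    gate_open = tf_exact >= min_tf_exact and ar_exact >= min_ar_exact
    return gate_open and epoch % interval == 0


# Gate closed: AR exact is still below 60%, so Phase 2 never runs.
print([e for e in range(100, 106) if phase2_due(e, tf_exact=0.92, ar_exact=0.55)])
# Gate open: Phase 2 follows every other supervised epoch.
print([e for e in range(100, 106) if phase2_due(e, tf_exact=0.92, ar_exact=0.65)])
```

Keeping the thresholds as keyword arguments makes it easy to see how a stricter or looser gate would change the schedule.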

**Checkpoints**: Saved to `outputs/` inside the repo on Drive, so they persist across Colab sessions. `checkpoint_best.pt` is auto-detected on resume.

**Remote monitoring (optional)**: Set `GIST_ID` in Cell 2 to push live training metrics to a GitHub Gist. Includes all physics losses and Phase 2 metrics (z-MSE, valid rate, discoveries).

## Cell 1: Mount Google Drive

In [None]:
from google.colab import drive
drive.mount('/content/drive')

## Cell 1b: Sync Repo with GitHub

Pull the latest code from GitHub so Colab matches your most recent push. This avoids `ImportError` or missing-feature bugs when the Drive copy is stale.

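**Note**: The cell uses `git pull --ff-only`, so it never merges. If the Drive copy has diverged from GitHub or has conflicting local changes (unlikely), the pull fails and leaves the repo untouched. The simplest fix is then to delete the repo folder from Drive and re-clone, but copy `outputs/` somewhere safe first: checkpoints are stored inside the repo folder.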
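
To check for local changes before pulling, one option is to inspect `git status --porcelain`, which prints one line per modified or untracked path and nothing when the working tree is clean. The helper names below (`changed_paths`, `local_changes`) are illustrative and not part of the repo.

```python
import subprocess


def changed_paths(porcelain_output):
    """Parse `git status --porcelain` (v1) text into a list of paths.

    Each line is a two-character status code, a space, then the path.
    """
    return [line[3:] for line in porcelain_output.splitlines() if line.strip()]


def local_changes(repo_dir):
    """Run git in `repo_dir` and return the paths with local changes."""
    result = subprocess.run(
        ["git", "status", "--porcelain"],
        cwd=repo_dir, capture_output=True, text=True, timeout=60, check=True,
    )
    return changed_paths(result.stdout)


# Example with canned output: one modified file, one untracked file.
sample = " M scripts/train_v12_clean.py\n?? notes.txt\n"
print(changed_paths(sample))
```

An empty list means the working tree is clean. Untracked files, such as checkpoints under `outputs/`, show up with the `??` code unless they are ignored by git.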

In [None]:
# Sync repo with GitHub (pulls latest code changes)
# REPO_PATH is defined in the Configuration cell below. The path is hardcoded
# here so this cell can run on its own before Configuration; keep the two in sync.
_REPO = "/content/drive/My Drive/Colab Notebooks/SuperconductorVAE/superconductor-vae"

import subprocess, os

if os.path.isdir(_REPO):
    print("Pulling latest from GitHub...")
    result = subprocess.run(
        ["git", "pull", "--ff-only"],
        cwd=_REPO,
        capture_output=True, text=True, timeout=60,
    )
    print(result.stdout.strip())
    if result.returncode != 0:
        print(f"git pull failed (exit {result.returncode}):")
        print(result.stderr.strip())
        print("\nIf the pull cannot fast-forward (diverged history or local changes),")
        print("back up outputs/ and restart with a fresh clone:")
        print(f"  !rm -rf '{_REPO}'")
        print(f"  !git clone https://github.com/jamesconde/superconductor-vae.git '{_REPO}'")
    else:
        # Show current commit for reference
        commit = subprocess.run(
            ["git", "log", "--oneline", "-1"],
            cwd=_REPO, capture_output=True, text=True,
        )
        print(f"Current commit: {commit.stdout.strip()}")
else:
    print(f"Repo not found at: {_REPO}")
    print("Check that Google Drive is mounted and the path is correct.")

## Cell 2: Configuration

Edit these values to match your setup.

In [None]:
# Path to the superconductor-vae repo on your Google Drive
REPO_PATH = "/content/drive/My Drive/Colab Notebooks/SuperconductorVAE/superconductor-vae"

# Training options (override defaults)
# 'auto' prefers checkpoint_best.pt, then highest checkpoint_epoch_*.pt.
# Set to a specific path (e.g. 'outputs/checkpoint_epoch_3999.pt') to override,
# or None to train from scratch.
RESUME_CHECKPOINT = 'auto'
NUM_EPOCHS = 5000               # Total epochs (training resumes from last epoch)
BATCH_SIZE = 'auto'             # 'auto' scales with GPU memory, or set integer (e.g. 32, 48)

# --- Phase 2: Self-Supervised Training ---
# Interleaved self-supervised sub-epochs to improve generalization.
# Purpose: MODEL ENHANCEMENT (not discovery). Runs every N supervised epochs.
# Novel superconductors found during generation are flagged and saved to
# outputs/phase2_discoveries.jsonl but are a side effect, not the goal.
PHASE2_ENABLED = True           # Master toggle
PHASE2_START = 'auto'           # 'auto' = activate once the Cell 5 exact-match thresholds are met, or set epoch number
PHASE2_N_SAMPLES = 'auto'       # 'auto' = scale with VRAM: clamp(3.2*GB, 32, 512)
PHASE2_INTERVAL = 2             # Run Phase 2 every N supervised epochs (2 = every other epoch)

# --- Remote monitoring via GitHub Gist ---
# Set GIST_ID to enable live training metrics visible outside Colab.
# 1. Create a personal access token at https://github.com/settings/tokens with "gist" scope
# 2. In Colab: click the key icon (left sidebar) > add secret named GITHUB_TOKEN
# 3. Create a gist at https://gist.github.com with one file named "training_log.json"
#    containing just "{}" — then copy the gist ID from the URL.
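# Set to None to disable gist logging. Replace the ID below with your own gist;
# updates fail unless your GITHUB_TOKEN has write access to it.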
GIST_ID = "acceed7daef4d6893801cc7337531b68"
GIST_LOG_EVERY = 5  # Update gist every N epochs

## Cell 3: Install Dependencies

PyTorch, NumPy, pandas, and scikit-learn are pre-installed on Colab. scipy is usually pre-installed as well, but it is listed explicitly because the V12.20 Magpie quantile transforms require it.
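
Before running the install cell, a quick check like the following can show which packages are actually missing from the runtime. It is a small sketch built on the standard-library `importlib.util.find_spec`; the helper name `missing_packages` is made up for this example.

```python
import importlib.util


def missing_packages(names):
    """Return the top-level package names that cannot be imported."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# Import names for the packages this notebook relies on.
# Note: scikit-learn is imported as "sklearn".
wanted = ["torch", "numpy", "pandas", "sklearn", "scipy", "matminer", "pymatgen"]
print("Missing:", missing_packages(wanted) or "none")
```

`find_spec` only looks the package up; it does not import it, so the check stays fast even for heavy libraries.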

In [None]:
!pip install matminer pymatgen scipy

## Cell 4: Setup Paths and Verify Environment

In [None]:
import sys
import os
from pathlib import Path

repo = Path(REPO_PATH)

# Add src/ to Python path so imports work
src_path = str(repo / "src")
if src_path not in sys.path:
    sys.path.insert(0, src_path)

# Verify key files exist
required_files = {
    "Training data (contrastive)": repo / "data/processed/supercon_fractions_contrastive.csv",
    "Holdout set": repo / "data/GENERATIVE_HOLDOUT_DO_NOT_TRAIN.json",
    "Training script": repo / "scripts/train_v12_clean.py",
    "VAE model": repo / "src/superconductor/models/attention_vae.py",
    "Decoder model": repo / "src/superconductor/models/autoregressive_decoder.py",
    "Fraction vocab": repo / "data/fraction_vocab.json",
    "Fraction vocab (old)": repo / "data/fraction_vocab_old.json",
}

# Optional files (nice to have, not required)
optional_files = {
    "Best checkpoint": repo / "outputs/checkpoint_best.pt",
    "Isotope vocab": repo / "data/isotope_vocab.json",
}

# Add specific resume checkpoint to optional checks (if not 'auto')
if RESUME_CHECKPOINT != 'auto' and RESUME_CHECKPOINT is not None:
    optional_files["Resume checkpoint"] = repo / RESUME_CHECKPOINT

all_found = True
for name, path in required_files.items():
    exists = path.exists()
    status = "OK" if exists else "MISSING"
    print(f"  [{status}] {name}: {path.name}")
    if not exists:
        all_found = False

print()
for name, path in optional_files.items():
    exists = path.exists()
    status = "OK" if exists else "---"
    print(f"  [{status}] {name}: {path.name}")

if not all_found:
    raise FileNotFoundError(
        f"Missing required files. Check that REPO_PATH is correct: {REPO_PATH}"
    )

# Show resume mode
if RESUME_CHECKPOINT == 'auto':
    best_path = repo / "outputs/checkpoint_best.pt"
    if best_path.exists():
        print(f"\n  Resume: 'auto' — will load checkpoint_best.pt")
    else:
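        # Check for epoch checkpoints. Sort by epoch number, not by name:
        # a plain sorted() is lexicographic and would rank epoch_999 above epoch_3999.
        import re
        epoch_files = sorted(
            (repo / "outputs").glob("checkpoint_epoch_*.pt"),
            key=lambda p: int(re.sub(r"\D", "", p.stem) or 0),
        )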
        if epoch_files:
            print(f"\n  Resume: 'auto' — no checkpoint_best.pt, will use {epoch_files[-1].name}")
        else:
            print(f"\n  Resume: 'auto' — no checkpoints found, will train from scratch")
elif RESUME_CHECKPOINT is not None:
    resume_path = repo / RESUME_CHECKPOINT
    if resume_path.exists():
        print(f"\n  Resume: will load {resume_path.name}")
    else:
        print(f"\n  WARNING: Resume checkpoint not found: {resume_path.name}")
        print(f"  Consider setting RESUME_CHECKPOINT = 'auto' to auto-detect.")
else:
    print(f"\n  Resume: None — training from scratch")

print()

# GPU info
import torch
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    capability = torch.cuda.get_device_capability(0)
    print(f"GPU: {gpu_name}")
    print(f"Memory: {gpu_mem:.1f} GB")
    print(f"Compute capability: {capability[0]}.{capability[1]}")
    print(f"PyTorch: {torch.__version__}")
    print(f"CUDA: {torch.version.cuda}")
else:
    print("WARNING: No GPU detected! Go to Runtime > Change runtime type > GPU.")
    print(f"PyTorch: {torch.__version__}")

## Cell 5: Apply Colab-Specific Config Overrides

Patches the training config for Colab before importing the training function.
`detect_environment()` auto-detects Colab and sets optimal DataLoader/torch.compile settings.

**V12.43 model**: `resume_checkpoint='auto'` auto-detects the best checkpoint (prefers `checkpoint_best.pt`, then highest `checkpoint_epoch_*.pt`). The current checkpoint is V12.43 (vocab=148, d_model=576). V13/V14/V15 features are disabled — no vocab expansion or auto-migration will occur.
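
The `'auto'` preference order described above can be sketched as follows. This is an illustration only, not the resolver in `train_v12_clean.py`: the function name `pick_resume_checkpoint` is hypothetical, and it assumes epoch checkpoints are named `checkpoint_epoch_<N>.pt`. Epochs are compared numerically, since a plain string sort would place `checkpoint_epoch_999.pt` after `checkpoint_epoch_3999.pt`.

```python
import re
from pathlib import Path


def pick_resume_checkpoint(output_dir):
    """Return the checkpoint 'auto' would prefer, or None to train from scratch."""
    output_dir = Path(output_dir)
    best = output_dir / "checkpoint_best.pt"
    if best.exists():
        return best

    # Fall back to the epoch checkpoint with the highest epoch number.
    candidates = []
    for path in output_dir.glob("checkpoint_epoch_*.pt"):
        match = re.fullmatch(r"checkpoint_epoch_(\d+)\.pt", path.name)
        if match:
            candidates.append((int(match.group(1)), path))
    if not candidates:
        return None
    return max(candidates, key=lambda item: item[0])[1]
```

Calling `pick_resume_checkpoint(repo / "outputs")` after Cell 4 would show which file a resume is expected to pick up.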

**Phase 2**: Self-supervised sub-epochs enabled via `PHASE2_ENABLED` in Cell 2. n_samples auto-scales with GPU VRAM (A100 40GB -> 128, A100 80GB -> 256). Discovery tracker logs novel formulas to `outputs/phase2_discoveries.jsonl`.
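
The VRAM scaling rule quoted in Cell 2, `clamp(3.2*GB, 32, 512)`, works out as shown here. The helper name and the rounding step are assumptions for illustration; the training script may round differently.

```python
def phase2_n_samples(vram_gb, per_gb=3.2, lo=32, hi=512):
    """Scale the Phase 2 sample count with GPU memory, clamped to [lo, hi]."""
    return int(max(lo, min(hi, round(per_gb * vram_gb))))


for gb in (16, 40, 80, 200):
    print(f"{gb:>3} GB -> {phase2_n_samples(gb)} samples")
```

This reproduces the two values quoted above (128 for a 40 GB A100, 256 for an 80 GB A100) and shows the clamp taking over at the extremes.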

In [None]:
import importlib

# Add scripts/ to path so we can import train_v12_clean as a module
scripts_path = str(repo / "scripts")
if scripts_path not in sys.path:
    sys.path.insert(0, scripts_path)

# Import (or reload) the training module.
# IMPORTANT: reload() ensures TRAIN_CONFIG is reset to defaults when cells are
# re-run without restarting the runtime. Without this, batch_size_multiplier
# from detect_environment() gets applied multiple times (42 -> 2100 -> 105000).
import train_v12_clean
importlib.reload(train_v12_clean)

# --- Colab-specific overrides ---
# NOTE: DataLoader settings (num_workers, pin_memory, persistent_workers)
# and torch.compile settings (use_torch_compile, compile_mode) are
# auto-detected by detect_environment() inside train(). No manual override needed.

# More frequent checkpoints — Colab can disconnect even on Pro tier
train_v12_clean.TRAIN_CONFIG['checkpoint_interval'] = 25

# Gradient checkpointing incompatible with torch.compile; disabled for V12.20+
train_v12_clean.TRAIN_CONFIG['use_gradient_checkpointing'] = False

# Disable z_cache writes to save Colab disk space (~400MB/epoch)
train_v12_clean.TRAIN_CONFIG['z_cache_every_epoch'] = False

# --- RL gating: gate on AR exact >= 40% (script default) ---
# Adaptive TF (V15.1) closes the exposure bias gap first. RL kicks in
# when 40% of RLOO samples get positive reward — enough signal for useful
# gradients. Saves 5x epoch time while TF does the heavy lifting.
# (No override needed — uses script default rl_min_ar_exact=0.40)

# --- PhysZ: always on — physical regularizers should always be active ---
# PhysZ and SC losses are physically grounded regularizers, not auxiliary
# tricks. They should always contribute to z-space organization.
# The warmup ramp still applies on first activation to avoid sudden shock.
train_v12_clean.TRAIN_CONFIG['use_physics_z'] = True

# --- Phase 2: Self-Supervised Training ---
train_v12_clean.TRAIN_CONFIG['phase2_enabled'] = PHASE2_ENABLED
train_v12_clean.TRAIN_CONFIG['phase2_start'] = PHASE2_START
train_v12_clean.TRAIN_CONFIG['phase2_n_samples'] = PHASE2_N_SAMPLES
train_v12_clean.TRAIN_CONFIG['phase2_interval'] = PHASE2_INTERVAL
# Phase 2 activation: TF exact >= 90% AND AR exact >= 60%.
# Phase 2 generates formulas and re-encodes them (round-trip consistency).
# At low AR, most generated formulas are garbage — the training signal is
# thin and noisy. At 60% AR, enough generated formulas are valid for
# meaningful self-supervised learning. RL (which activates at 40% AR)
# already provides sequence-level training at lower AR thresholds.
train_v12_clean.TRAIN_CONFIG['phase2_auto_min_exact'] = 0.90
train_v12_clean.TRAIN_CONFIG['phase2_min_ar_exact'] = 0.60
train_v12_clean.TRAIN_CONFIG['phase2_min_resume_epochs'] = 50

# --- User overrides from Cell 2 ---

train_v12_clean.TRAIN_CONFIG['num_epochs'] = NUM_EPOCHS

# Batch size: 'auto' lets detect_environment() apply the correct multiplier
# to the default (42). Setting a specific int skips the multiplier.
if BATCH_SIZE == 'auto':
    # Keep default (42) — detect_environment() in train() will apply
    # the correct multiplier for the GPU class (e.g. x25 for A100 80GB = 1050)
    pass
else:
    train_v12_clean.TRAIN_CONFIG['batch_size'] = int(BATCH_SIZE)

# Checkpoint resume — 'auto' prefers checkpoint_best.pt, then highest
# checkpoint_epoch_*.pt. Set RESUME_CHECKPOINT to a specific path to override.
train_v12_clean.TRAIN_CONFIG['resume_checkpoint'] = RESUME_CHECKPOINT

# --- Redirect paths to Drive-based repo ---

train_v12_clean.PROJECT_ROOT = repo
train_v12_clean.CONTRASTIVE_DATA_PATH = repo / 'data/processed/supercon_fractions_contrastive.csv'
train_v12_clean.DATA_PATH = repo / 'data/processed/supercon_fractions_contrastive.csv'
train_v12_clean.HOLDOUT_PATH = repo / 'data/GENERATIVE_HOLDOUT_DO_NOT_TRAIN.json'
train_v12_clean.OUTPUT_DIR = repo / 'outputs'

# Ensure output directory exists
train_v12_clean.OUTPUT_DIR.mkdir(exist_ok=True)

# --- Print final config ---

print("Colab config applied (V12.43 + Phase 2):")
print(f"  resume_checkpoint: {train_v12_clean.TRAIN_CONFIG['resume_checkpoint']}")
print(f"  checkpoint_interval: {train_v12_clean.TRAIN_CONFIG['checkpoint_interval']}")
print(f"  num_epochs: {train_v12_clean.TRAIN_CONFIG['num_epochs']}")
print(f"  batch_size: {train_v12_clean.TRAIN_CONFIG['batch_size']} (detect_environment will apply GPU-specific multiplier)")
print(f"  rl_min_ar_exact: {train_v12_clean.TRAIN_CONFIG.get('rl_min_ar_exact', 0)} (suppress RL below this AR exact)")
print(f"  rl_auto_scale: {train_v12_clean.TRAIN_CONFIG.get('rl_auto_scale', False)}")
print(f"  rl_auto_scale_target: {train_v12_clean.TRAIN_CONFIG.get('rl_auto_scale_target', 10.0)}")
print(f"  use_physics_z: {train_v12_clean.TRAIN_CONFIG.get('use_physics_z', False)} (PhysZ regularization)")
print(f"  tf_onset: {train_v12_clean.TRAIN_CONFIG.get('tf_onset', 0.80)} | tf_floor: {train_v12_clean.TRAIN_CONFIG.get('tf_floor', 0.20)}")
print(f"  phase2_enabled: {train_v12_clean.TRAIN_CONFIG['phase2_enabled']}")
print(f"  phase2_auto_min_exact: {train_v12_clean.TRAIN_CONFIG['phase2_auto_min_exact']}")
print(f"  phase2_min_ar_exact: {train_v12_clean.TRAIN_CONFIG.get('phase2_min_ar_exact', 0)} (suppress Phase 2 below this AR exact)")
print(f"  phase2_min_resume_epochs: {train_v12_clean.TRAIN_CONFIG['phase2_min_resume_epochs']}")
print(f"  OUTPUT_DIR: {train_v12_clean.OUTPUT_DIR}")
print("  (torch.compile + DataLoader settings auto-detected by detect_environment())")

## Cell 5b: Setup Gist Logging (Optional)

If `GIST_ID` is set in Cell 2, this hooks into the training loop to push metrics to a GitHub Gist every N epochs. You can then monitor training remotely via `gh gist view GIST_ID` or by visiting the gist URL.
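
For scripted monitoring from another machine, the gist can also be read through GitHub's REST API (`GET /gists/{gist_id}`), which returns each file's text under `files[<name>]["content"]`. The sketch below assumes the log layout written by `update_gist` in the cell that follows (`status`, `current_epoch`, `total_epochs`, `best_exact_match`, `history`). The function names are illustrative, the sample values are placeholders, and very large files may come back truncated, in which case the file's `raw_url` would be needed instead.

```python
import json
import urllib.request


def fetch_training_log(gist_id, filename="training_log.json"):
    """Download and parse the log file from a gist via the GitHub REST API."""
    url = f"https://api.github.com/gists/{gist_id}"
    request = urllib.request.Request(url, headers={"Accept": "application/vnd.github+json"})
    with urllib.request.urlopen(request, timeout=10) as response:
        gist = json.load(response)
    return json.loads(gist["files"][filename]["content"])


def summarize_log(log):
    """Format a one-line status summary from a parsed training log."""
    history = log.get("history", [])
    latest = history[-1] if history else {}
    return (
        f"[{log.get('status', 'unknown')}] "
        f"epoch {log.get('current_epoch', '?')}/{log.get('total_epochs', '?')} | "
        f"best exact {log.get('best_exact_match', '?')}% | "
        f"latest exact {latest.get('exact_match', '?')}% | "
        f"loss {latest.get('loss', '?')}"
    )


# Example with a hand-written log; swap in fetch_training_log(GIST_ID) for live data.
example = {
    "status": "training", "current_epoch": 3300, "total_epochs": 5000,
    "best_exact_match": 85.4,
    "history": [{"epoch": 3300, "exact_match": 84.9, "loss": 0.1234}],
}
print(summarize_log(example))
```

Separating the fetch from the formatting keeps the summary testable without network access.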

In [None]:
import json
import io
import sys
import requests
from datetime import datetime, timezone

_gist_token = None
_gist_log = {"history": [], "console_log": [], "status": "initialized"}

def _get_github_token():
    """Get GitHub token from Colab secrets."""
    global _gist_token
    if _gist_token is not None:
        return _gist_token
    try:
        from google.colab import userdata
        _gist_token = userdata.get('GITHUB_TOKEN')
        return _gist_token
    except Exception as e:
        print(f"  [Gist] Could not read GITHUB_TOKEN from Colab secrets: {e}")
        return None

def update_gist(epoch, metrics, best_exact, status="training", console_lines=None):
    """Push current training metrics + console log to the GitHub Gist."""
    if GIST_ID is None:
        return

    token = _get_github_token()
    if token is None:
        return

    entry = {
        "epoch": epoch,
        # Core losses
        "loss": round(metrics.get("loss", 0), 4),
        "exact_match": round(metrics.get("exact_match", 0) * 100, 2),
        "token_accuracy": round(metrics.get("accuracy", 0) * 100, 2),
        "tc_loss": round(metrics.get("tc_loss", 0), 4),
        "magpie_loss": round(metrics.get("magpie_loss", 0), 4),
        "stoich_loss": round(metrics.get("stoich_loss", 0), 4),
        "reinforce_loss": round(metrics.get("reinforce_loss", 0), 4),
        "mean_reward": round(metrics.get("mean_reward", 0), 3),
        "entropy": round(metrics.get("entropy", 0), 3),
        "z_norm": round(metrics.get("z_norm", 0), 1),
        # Physics & classification losses (V15.0: added back)
        "hp_loss": round(metrics.get("hp_loss", 0), 4),
        "sc_loss": round(metrics.get("sc_loss", 0), 4),
        "theory_loss": round(metrics.get("theory_loss", 0), 4),
        "physics_z_loss": round(metrics.get("physics_z_loss", 0), 4),
        "family_loss": round(metrics.get("family_loss", 0), 4),
        "tc_class_loss": round(metrics.get("tc_class_loss", 0), 4),
        # Decoder auxiliary losses
        "stop_loss": round(metrics.get("stop_loss", 0), 4),
        "type_loss": round(metrics.get("type_loss", 0), 4),
        "type_accuracy": round(metrics.get("type_accuracy", 0) * 100, 2),
        # Constraint zoo
        "constraint_zoo_loss": round(metrics.get("constraint_zoo_loss", 0), 4),
        # SC / non-SC breakdown
        "sc_exact_match": round(metrics.get("sc_exact_match", 0) * 100, 2),
        "non_sc_exact_match": round(metrics.get("non_sc_exact_match", 0) * 100, 2),
        # Phase 2: Self-supervised metrics (default to 0 on epochs without a Phase 2 sub-epoch)
        "phase2_total_loss": round(metrics.get("phase2_total_loss", 0), 4),
        "phase2_z_mse": round(metrics.get("phase2_z_mse", 0), 4),
        "phase2_tc_mse": round(metrics.get("phase2_tc_mse", 0), 4),
        "phase2_valid_rate": round(metrics.get("phase2_valid_rate", 0), 3),
        "phase2_unique_rate": round(metrics.get("phase2_unique_rate", 0), 3),
        "phase2_weight": round(metrics.get("phase2_weight", 0), 4),
        "phase2_n_valid": int(metrics.get("phase2_n_valid", 0)),
        "phase2_n_degenerate": int(metrics.get("phase2_n_degenerate", 0)),
        "phase2_n_novel": int(metrics.get("phase2_n_novel", 0)),
        "phase2_n_holdout_recovered": int(metrics.get("phase2_n_holdout_recovered", 0)),
        "phase2_collapse_active": bool(metrics.get("phase2_collapse_active", False)),
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }

    _gist_log["status"] = status
    _gist_log["best_exact_match"] = round(best_exact * 100, 2)
    _gist_log["last_update"] = entry["timestamp"]
    _gist_log["current_epoch"] = epoch
    _gist_log["total_epochs"] = NUM_EPOCHS

    # Keep last 200 metric entries
    _gist_log["history"].append(entry)
    if len(_gist_log["history"]) > 200:
        _gist_log["history"] = _gist_log["history"][-200:]

    # Append console lines (the full printed output from training)
    if console_lines:
        _gist_log["console_log"].extend(console_lines)
    # Keep last 500 console lines to avoid gist getting huge
    if len(_gist_log["console_log"]) > 500:
        _gist_log["console_log"] = _gist_log["console_log"][-500:]

    try:
        resp = requests.patch(
            f"https://api.github.com/gists/{GIST_ID}",
            headers={
                "Authorization": f"token {token}",
                "Accept": "application/vnd.github.v3+json",
            },
            json={"files": {
                "training_log.json": {"content": json.dumps(_gist_log, indent=2)},
                "console_output.txt": {"content": "\n".join(_gist_log["console_log"][-500:])},
            }},
            timeout=10,
        )
        if resp.status_code != 200:
            print(f"  [Gist] Update failed (HTTP {resp.status_code})")
    except Exception as e:
        # Don't let gist errors interrupt training
        print(f"  [Gist] Update error: {e}")

# --- Stdout capture for full console logging ---
# Tee stdout so all print() output is captured AND still displayed in Colab.

class _GistTeeStream:
    """Captures stdout lines for gist logging while still printing to Colab."""
    def __init__(self, original_stdout):
        self._original = original_stdout
        self._buffer = []
        self._line_buf = ""

    def write(self, text):
        self._original.write(text)
        # Buffer lines
        self._line_buf += text
        while "\n" in self._line_buf:
            line, self._line_buf = self._line_buf.split("\n", 1)
            if line.strip():  # Skip blank lines
                self._buffer.append(line.strip())

    def flush(self):
        self._original.flush()
        # Flush partial line if any
        if self._line_buf.strip():
            self._buffer.append(self._line_buf.strip())
            self._line_buf = ""

    def drain(self):
        """Return and clear captured lines."""
        lines = self._buffer
        self._buffer = []
        return lines

    # Forward all other attributes to original stream
    def __getattr__(self, name):
        return getattr(self._original, name)

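# Reuse a tee installed by an earlier run of this cell. Resetting it to None
# would leave the already-patched hooks with nothing to drain.
_gist_tee = sys.stdout if type(sys.stdout).__name__ == "_GistTeeStream" else None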

# --- Hook into the training loop ---
# Monkey-patch save_checkpoint and train_epoch to push metrics + console to gist.
# Guard against double-patching (re-running this cell) which causes recursion.

_latest_metrics = {}
_latest_best_exact = 0.0

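# Inspect the live function instead of a module-level flag: importlib.reload()
# in Cell 5 restores the original functions but leaves old module attributes set.
_GIST_PATCHED = getattr(train_v12_clean.train_epoch, '__name__', '') == '_train_epoch_with_capture'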

if not _GIST_PATCHED:
    _original_save_checkpoint = train_v12_clean.save_checkpoint
    _original_train_epoch = train_v12_clean.train_epoch

    def _get_real_epoch():
        """Get the actual epoch number from the training loop's shutdown state."""
        return train_v12_clean._shutdown_state.get('epoch', 0)

    def _save_checkpoint_with_gist(encoder, decoder, epoch, suffix='', **kwargs):
        """Wrapper that calls original save_checkpoint then updates gist."""
        global _latest_best_exact, _gist_tee
        _original_save_checkpoint(encoder, decoder, epoch, suffix=suffix, **kwargs)

        best = kwargs.get('best_exact', _latest_best_exact)
        _latest_best_exact = max(_latest_best_exact, best)

        if GIST_ID is not None and _latest_metrics:
            status = "training"
            if suffix == 'final':
                status = "completed"
            elif suffix == 'interrupt':
                status = "interrupted"
            # Drain captured console lines
            console_lines = _gist_tee.drain() if _gist_tee else None
            update_gist(epoch, _latest_metrics, _latest_best_exact,
                       status=status, console_lines=console_lines)

    def _train_epoch_with_capture(*args, **kwargs):
        """Wrapper that captures epoch metrics and pushes to gist periodically."""
        global _latest_metrics, _gist_tee
        metrics = _original_train_epoch(*args, **kwargs)
        _latest_metrics = metrics

        # Get the REAL epoch number from the training loop's state
        real_epoch = _get_real_epoch()

        # Push gist every GIST_LOG_EVERY epochs (not just checkpoint saves)
        if GIST_ID is not None and real_epoch % GIST_LOG_EVERY == 0:
            console_lines = _gist_tee.drain() if _gist_tee else None
            update_gist(real_epoch, metrics, _latest_best_exact,
                       console_lines=console_lines)
        return metrics

    if GIST_ID is not None:
        token = _get_github_token()
        if token:
            train_v12_clean.save_checkpoint = _save_checkpoint_with_gist
            train_v12_clean.train_epoch = _train_epoch_with_capture
            train_v12_clean._GIST_PATCHED = True  # Prevent double-patching

            # Install stdout tee to capture all console output (at most once)
            if _gist_tee is None:
                _gist_tee = _GistTeeStream(sys.stdout)
                sys.stdout = _gist_tee

            # Push initial status
            update_gist(0, {}, 0, status="starting",
                       console_lines=["=== Training session started ==="])
            print(f"Gist logging enabled: https://gist.github.com/{GIST_ID}")
            print(f"  Metrics + full console log pushed every {GIST_LOG_EVERY} epochs and on checkpoint saves")
            print(f"  Console output also saved to 'console_output.txt' in gist")
        else:
            print("Gist logging disabled (no GITHUB_TOKEN found in Colab secrets)")
    else:
        print("Gist logging disabled (GIST_ID is None)")
else:
    print("Gist logging already patched (safe to re-run this cell)")

## Cell 6: Run Training

This calls the existing `train()` function directly. Training output streams to the notebook.

**V12.43 flow**: On first run, you'll see:
1. Auto-detect checkpoint: `[AUTO] Found checkpoint_best.pt (epoch XXXX, best_exact=0.YYYY)`
2. Pre-training baseline eval: exact match % of loaded model before any training
3. RL probe mode: tiny RL weight for calibration, then auto-scaled
4. Phase 2 self-supervised sub-epochs (once TF exact >= 90% and AR exact >= 60%, if `PHASE2_ENABLED`)

**Note**: V13/V14/V15 features are disabled. No auto-migration or vocab expansion will occur.

**Tip**: If Colab disconnects, re-run Cells 1-5b, then this cell. With `RESUME_CHECKPOINT = 'auto'`, training resumes from `checkpoint_best.pt`.

In [None]:
train_v12_clean.train()

## Cell 7: Post-Training Summary (Optional)

List saved checkpoints and show basic info.

In [None]:
import os
from pathlib import Path

output_dir = Path(REPO_PATH) / 'outputs'

if output_dir.exists():
    checkpoints = sorted(output_dir.glob('checkpoint_*.pt'))
    if checkpoints:
        print(f"Saved checkpoints ({len(checkpoints)}):")
        for cp in checkpoints:
            size_mb = cp.stat().st_size / (1024 * 1024)
            print(f"  {cp.name} ({size_mb:.1f} MB)")

        # Load best checkpoint to show final metrics
        best_path = output_dir / 'checkpoint_best.pt'
        if best_path.exists():
            import torch
            ckpt = torch.load(best_path, map_location='cpu', weights_only=False)
            print(f"\nBest checkpoint:")
            print(f"  Epoch: {ckpt.get('epoch', 'unknown')}")
            print(f"  Best exact match: {ckpt.get('best_exact', 'unknown')}")
            if 'prev_exact' in ckpt:
                print(f"  Exact at save: {ckpt['prev_exact']:.4f}")
    else:
        print("No checkpoints found.")
else:
    print(f"Output directory not found: {output_dir}")