# Superconductor VAE - Google Colab Training

Train the FullMaterialsVAE (V12) on Google Colab from a copy of the repository stored on your Google Drive.

**Setup**: Upload the entire `superconductor-vae` repository to your Google Drive, then run these cells in order.

**Checkpoints**: Saved to `outputs/` inside the repo on Drive, so they persist across Colab sessions.

## Cell 1: Mount Google Drive

In [None]:
from google.colab import drive
drive.mount('/content/drive')

## Cell 2: Configuration

Edit these values to match your setup.

In [None]:
# Path to the superconductor-vae repo on your Google Drive
REPO_PATH = "/content/drive/MyDrive/superconductor-vae"

# Training options (override defaults)
RESUME_FROM_CHECKPOINT = True   # Resume from best checkpoint if available
NUM_EPOCHS = 2000               # Total epochs, not additional ones (a resumed run continues from the checkpoint's epoch)
BATCH_SIZE = 'auto'             # 'auto' scales with GPU memory, or set integer (e.g. 32, 48)
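
A typo in these values would otherwise only surface once Cell 5 or the training loop runs. The sketch below checks them up front. `validate_config` is a hypothetical helper, not part of the repo, and it assumes that `'auto'` is the only non-numeric value `BATCH_SIZE` may take.

```python
def validate_config(num_epochs, batch_size):
    """Hypothetical helper: check the Cell 2 values and return them normalized."""
    # bool is a subclass of int, so reject it explicitly
    if isinstance(num_epochs, bool) or not isinstance(num_epochs, int) or num_epochs < 1:
        raise ValueError(f"NUM_EPOCHS must be a positive integer, got {num_epochs!r}")

    # 'auto' is passed through unchanged; the training script decides the size
    if batch_size == 'auto':
        return num_epochs, 'auto'

    # Cell 5 applies int(BATCH_SIZE), so mirror that conversion here
    try:
        size = int(batch_size)
    except (TypeError, ValueError):
        raise ValueError(f"BATCH_SIZE must be 'auto' or an integer, got {batch_size!r}") from None
    if size < 1:
        raise ValueError(f"BATCH_SIZE must be at least 1, got {size}")
    return num_epochs, size


print(validate_config(2000, 'auto'))
print(validate_config(2000, 48))
```

In the notebook this could be called as `validate_config(NUM_EPOCHS, BATCH_SIZE)` at the end of the configuration cell.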

## Cell 3: Install Dependencies

PyTorch, NumPy, pandas, and scikit-learn are pre-installed on Colab. We only need matminer and pymatgen.

In [None]:
!pip install matminer pymatgen
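
To confirm the install worked before moving on, a check along these lines can help. It is a minimal sketch: `missing_packages` is a hypothetical helper, and it only tests whether each top-level package can be located, not whether the installed version is compatible with the repo.

```python
import importlib.util


def missing_packages(names):
    """Return the top-level package names that Python cannot locate."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_packages(["matminer", "pymatgen"])
if missing:
    print(f"Missing packages: {', '.join(missing)}")
else:
    print("matminer and pymatgen are importable.")
```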

## Cell 4: Setup Paths and Verify Environment

In [None]:
import sys
from pathlib import Path

repo = Path(REPO_PATH)

# Add src/ to Python path so imports work
src_path = str(repo / "src")
if src_path not in sys.path:
    sys.path.insert(0, src_path)

# Verify key files exist
required_files = {
    "Training data": repo / "data/processed/supercon_fractions_combined.csv",
    "Holdout set": repo / "data/GENERATIVE_HOLDOUT_DO_NOT_TRAIN.json",
    "Training script": repo / "scripts/train_v12_clean.py",
    "VAE model": repo / "src/superconductor/models/attention_vae.py",
    "Decoder model": repo / "src/superconductor/models/autoregressive_decoder.py",
}

all_found = True
for name, path in required_files.items():
    exists = path.exists()
    status = "OK" if exists else "MISSING"
    print(f"  [{status}] {name}: {path}")
    if not exists:
        all_found = False

if not all_found:
    raise FileNotFoundError(
        f"Missing files. Check that REPO_PATH is correct: {REPO_PATH}"
    )

print()

# GPU info
import torch
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    capability = torch.cuda.get_device_capability(0)
    print(f"GPU: {gpu_name}")
    print(f"Memory: {gpu_mem:.1f} GB")
    print(f"Compute capability: {capability[0]}.{capability[1]}")
    print(f"PyTorch: {torch.__version__}")
    print(f"CUDA: {torch.version.cuda}")
else:
    print("WARNING: No GPU detected! Go to Runtime > Change runtime type > GPU.")
    print(f"PyTorch: {torch.__version__}")

## Cell 5: Apply Colab-Specific Config Overrides

Imports the training module and patches its config and paths for Colab compatibility before training starts.

In [None]:
# Add scripts/ to path so we can import train_v12_clean as a module
scripts_path = str(repo / "scripts")
if scripts_path not in sys.path:
    sys.path.insert(0, scripts_path)

# Import the training module
import train_v12_clean

# --- Colab-specific overrides ---

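# DataLoader: worker subprocesses can be unreliable in Colab's containerized
# environment, so load batches in the main process. persistent_workers only
# applies when num_workers > 0, so it is turned off as well.
train_v12_clean.TRAIN_CONFIG['num_workers'] = 0
train_v12_clean.TRAIN_CONFIG['persistent_workers'] = False
train_v12_clean.TRAIN_CONFIG['pin_memory'] = True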

# torch.compile: needs a working compiler toolchain (a C/C++ compiler and Triton),
# which Colab runtimes may lack, so it is disabled here
train_v12_clean.TRAIN_CONFIG['use_torch_compile'] = False

# --- User overrides from Cell 2 ---

train_v12_clean.TRAIN_CONFIG['num_epochs'] = NUM_EPOCHS

if BATCH_SIZE != 'auto':
    train_v12_clean.TRAIN_CONFIG['batch_size'] = int(BATCH_SIZE)

if RESUME_FROM_CHECKPOINT:
    train_v12_clean.TRAIN_CONFIG['resume_checkpoint'] = 'outputs/checkpoint_best.pt'
else:
    train_v12_clean.TRAIN_CONFIG['resume_checkpoint'] = None

# --- Redirect paths to Drive-based repo ---

train_v12_clean.PROJECT_ROOT = repo
train_v12_clean.DATA_PATH = repo / 'data/processed/supercon_fractions_combined.csv'
train_v12_clean.HOLDOUT_PATH = repo / 'data/GENERATIVE_HOLDOUT_DO_NOT_TRAIN.json'
train_v12_clean.OUTPUT_DIR = repo / 'outputs'

# Ensure output directory exists (create parent directories too if needed)
train_v12_clean.OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# --- Print final config ---

print("Colab config applied:")
print(f"  num_workers: {train_v12_clean.TRAIN_CONFIG['num_workers']}")
print(f"  persistent_workers: {train_v12_clean.TRAIN_CONFIG['persistent_workers']}")
print(f"  use_torch_compile: {train_v12_clean.TRAIN_CONFIG['use_torch_compile']}")
print(f"  num_epochs: {train_v12_clean.TRAIN_CONFIG['num_epochs']}")
print(f"  batch_size: {train_v12_clean.TRAIN_CONFIG['batch_size']}")
print(f"  resume_checkpoint: {train_v12_clean.TRAIN_CONFIG['resume_checkpoint']}")
print(f"  OUTPUT_DIR: {train_v12_clean.OUTPUT_DIR}")
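
The overrides above rely on Python looking up module-level names when a function runs, not when it is defined. The sketch below illustrates this with a stand-in module. `demo_train` and its contents are hypothetical and only mimic the shape of `train_v12_clean`.

```python
import types

# Build a stand-in module with a config dict, a path constant, and a train()
# function that reads both (hypothetical; not the real training script)
demo = types.ModuleType("demo_train")
exec(
    "TRAIN_CONFIG = {'num_workers': 4, 'num_epochs': 100}\n"
    "OUTPUT_DIR = 'outputs'\n"
    "def train():\n"
    "    # Module globals are resolved at call time\n"
    "    return TRAIN_CONFIG['num_workers'], OUTPUT_DIR\n",
    demo.__dict__,
)

demo.TRAIN_CONFIG['num_workers'] = 0          # in-place update of the dict
demo.OUTPUT_DIR = '/content/drive/outputs'    # rebinding a module attribute

result = demo.train()
print(result)
```

Both kinds of override are visible inside `train()`. One caveat: a value that the module copies at import time (for example into a default argument or a derived constant) would not pick up a later override. Whether `train_v12_clean` does this is not shown here, so the printed config is worth comparing against the first lines of training output.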

## Cell 6: Run Training

This calls the existing `train()` function directly. Training output streams to the notebook.

**Tip**: If Colab disconnects, re-run Cells 1-5, then this cell. With `RESUME_FROM_CHECKPOINT = True`, training resumes from `outputs/checkpoint_best.pt`, the checkpoint configured in Cell 5.

In [None]:
train_v12_clean.train()
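
On a first run there is no `checkpoint_best.pt` yet, and how `train()` handles a missing resume checkpoint is not shown in this notebook. A defensive check before calling `train()` might look like the sketch below. `resolve_resume` is a hypothetical helper, and it assumes the relative checkpoint path is resolved against the project root.

```python
import tempfile
from pathlib import Path


def resolve_resume(project_root, relative_ckpt, resume=True):
    """Hypothetical helper: return the checkpoint path if resuming is
    requested and the file exists, otherwise None."""
    if not resume:
        return None
    path = Path(project_root) / relative_ckpt
    return path if path.is_file() else None


# Demonstrate with a temporary directory standing in for the repo
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    before = resolve_resume(root, "outputs/checkpoint_best.pt")

    (root / "outputs").mkdir()
    (root / "outputs" / "checkpoint_best.pt").write_bytes(b"")
    after = resolve_resume(root, "outputs/checkpoint_best.pt")
    disabled = resolve_resume(root, "outputs/checkpoint_best.pt", resume=False)

print(before, after is not None, disabled)
```

In the notebook, the result could be used to set `TRAIN_CONFIG['resume_checkpoint']` to `None` when no checkpoint exists, so a fresh run starts from scratch.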

## Cell 7: Post-Training Summary (Optional)

Lists the saved checkpoints and prints the epoch and exact-match metrics stored in the best one.

In [None]:
from pathlib import Path

output_dir = Path(REPO_PATH) / 'outputs'

if output_dir.exists():
    checkpoints = sorted(output_dir.glob('checkpoint_*.pt'))
    if checkpoints:
        print(f"Saved checkpoints ({len(checkpoints)}):")
        for cp in checkpoints:
            size_mb = cp.stat().st_size / (1024 * 1024)
            print(f"  {cp.name} ({size_mb:.1f} MB)")

        # Load best checkpoint to show final metrics
        best_path = output_dir / 'checkpoint_best.pt'
        if best_path.exists():
            import torch
            ckpt = torch.load(best_path, map_location='cpu', weights_only=False)
            print("\nBest checkpoint:")
            print(f"  Epoch: {ckpt.get('epoch', 'unknown')}")
            print(f"  Best exact match: {ckpt.get('best_exact', 'unknown')}")
            if 'prev_exact' in ckpt:
                print(f"  Exact at save: {ckpt['prev_exact']:.4f}")
    else:
        print("No checkpoints found.")
else:
    print(f"Output directory not found: {output_dir}")