# In [None]:
# JUPYTER WORKFLOW (safe mode): do NOT reset DB -> if DB exists, raise error
# -------------------------------------------------------------------------
# Steps:
# 1) Check DB state: if the CSV already exists or the DB folder is non-empty -> raise an error (abort).
# 2) Initialize Oracle -> will create fresh CSV header and DB structure.
# 3) Generate N compositions from Config elements and run Oracle for each.
# 4) Build DataLoaders, create model, and train on the freshly created data.
#
# Notes:
# - Code comments are in English (per your preference).
# - No argparse; tweak parameters at the top of this cell.

# =========================
# Parameters (edit here)
# =========================
N_POINTS = 5000            # number of compositions to generate; each one is passed to oracle.calculate() below
SEED = 42               # random seed for composition sampling (forwarded to dirichlet_compositions)
VAL_SPLIT = None        # set to a float (e.g., 0.2) to override Config.val_split, or leave None to keep Config's value
EPOCHS = None           # set to an int to override Config.epochs, or leave None to keep Config's value
DISABLE_WANDB = False    # True: force config.use_wandb = False (disable wandb for quick tests)

# =========================
# Imports and path setup
# =========================
import sys
from pathlib import Path
import time
import numpy as np

# Ensure module path (adjust if needed)
MODULE_DIRS = ["/mnt/data", "."]
for d in MODULE_DIRS:
    if d not in sys.path:
        sys.path.insert(0, d)

# Project modules expected to be available
from config import Config
from oracle import Oracle
from dataset import create_dataloaders
from template_graph_builder import TemplateGraphBuilder
from model import create_model_from_config
from trainer import Trainer


# =========================
# Helper functions
# =========================
def print_header(title: str):
    """Print a blank line, then *title* framed by two 70-character '=' rules."""
    rule = "=" * 70
    # A single call with a newline separator emits the same four lines
    # (blank, rule, title, rule) as printing them one after another.
    print("", rule, title, rule, sep="\n")

def ensure_database_absent_or_raise(cfg: "Config") -> None:
    """
    Enforce a clean start WITHOUT deleting anything.

    Reads ``cfg.csv_path`` and ``cfg.database_dir`` and aborts when either one
    already holds data, i.e. when:

    - the CSV path exists, or
    - the DB path exists and is a non-empty directory, or
    - the DB path exists but is not a directory (e.g. a stray file).

    A missing or *empty* DB directory counts as a clean state.

    Args:
        cfg: object exposing ``csv_path`` and ``database_dir`` (the annotation
            is quoted so this helper can be defined without Config in scope).

    Raises:
        FileExistsError: if an existing CSV or database is detected; the
            message reports the state of both paths.
    """
    csv_path = Path(cfg.csv_path)
    db_dir = Path(cfg.database_dir)

    csv_present = csv_path.exists()

    # Classify the DB path once; the label is reused in the error message.
    if not db_dir.exists():
        db_state, db_present = "missing", False
    elif not db_dir.is_dir():
        # iterdir() would raise NotADirectoryError on a regular file; report it
        # as an existing database instead so nothing gets overwritten.
        db_state, db_present = "EXISTS (not a directory)", True
    elif any(db_dir.iterdir()):
        db_state, db_present = "EXISTS & non-empty", True
    else:
        # An empty directory holds no data, so it does not block a fresh start.
        db_state, db_present = "exists & empty", False

    if csv_present or db_present:
        # Build a helpful error message
        msg_lines = [
            "Existing database detected – refusing to overwrite.",
            f"- CSV path: {csv_path}  -> {'EXISTS' if csv_present else 'missing'}",
            f"- DB dir:   {db_dir}   -> {db_state}",
            "",
            "To proceed, move or delete the existing database/CSV and re-run this cell.",
        ]
        raise FileExistsError("\n".join(msg_lines))

def dirichlet_compositions(elements, n_points, seed=42):
    """
    Generate n compositions on the simplex using a Dirichlet distribution.

    Args:
        elements: iterable of distinct, hashable element labels; it is
            materialised once, so a one-shot iterator is accepted.
        n_points: number of compositions to return (0 yields an empty list).
        seed: seed for ``np.random.default_rng``; the same seed reproduces
            the same compositions.

    Returns:
        A list of ``n_points`` dicts ``{element: fraction}`` whose fractions
        are Python floats summing to 1.0 (up to rounding). When
        ``n_points >= 1`` the first entry is overwritten with the equimolar
        composition, so only ``n_points - 1`` entries are random draws.

    Raises:
        ValueError: if ``elements`` is empty or contains duplicate labels
            (duplicates would collapse in the dict and break the sum-to-one
            guarantee).
    """
    elements = list(elements)
    if not elements:
        raise ValueError("elements must contain at least one element")
    if len(set(elements)) != len(elements):
        raise ValueError(f"elements must be unique, got {elements}")

    rng = np.random.default_rng(seed)
    alpha = np.ones(len(elements), dtype=float)  # uniform Dirichlet
    samples = rng.dirichlet(alpha, size=n_points)

    # Defensive clean-up against rounding: clamp into [0, 1] and renormalise
    # every row so it sums to one.
    samples = np.clip(samples, 0.0, 1.0)
    samples = samples / samples.sum(axis=1, keepdims=True)

    # tolist() converts numpy scalars to plain Python floats.
    comps = [dict(zip(elements, row)) for row in samples.tolist()]

    # Ensure the equimolar point appears (helpful for sanity)
    if comps:
        comps[0] = {el: 1.0 / len(elements) for el in elements}

    return comps


# =========================
# Main notebook workflow
# =========================
# 1) Load and optionally tweak config
print_header("CONFIG SUMMARY")
config = Config()  # constructed without arguments; the overrides below are applied on top

# Each override is applied only when the matching cell-level parameter was
# changed from its "leave as is" value (False / None).
if DISABLE_WANDB:
    config.use_wandb = False  # turn off wandb logging for quick tests

if EPOCHS is not None:
    config.epochs = int(EPOCHS)  # coerce, so e.g. 10.0 is stored as 10

if VAL_SPLIT is not None:
    config.val_split = float(VAL_SPLIT)

# Echo the effective settings (after overrides) before any work is done.
print(f"Elements (from Config): {config.elements}")
print(f"Database dir:          {config.database_dir}")
print(f"CSV path:              {config.csv_path}")
print(f"W&B enabled:           {config.use_wandb}")
print(f"Epochs:                {config.epochs}")
print(f"Val split:             {config.val_split}")

# 2) Safety check: do NOT reset DB; if something exists, raise error
print_header("SAFETY CHECK (DB must be absent)")
# Raises FileExistsError (stopping the cell here) when existing data is found;
# nothing is deleted or modified by the check itself.
ensure_database_absent_or_raise(config)
print("✓ No existing DB/CSV detected -> safe to create a new one.")

# 3) Initialize Oracle (this will create the CSV header and required folders)
# NOTE(review): Oracle's constructor is not visible in this cell; the claim that
# it creates the CSV/folders is an expectation to confirm against oracle.py.
# It is deliberately constructed only after the safety check has passed.
print_header("INIT ORACLE")
oracle = Oracle(config)  # expected to create fresh CSV header and structure

# 4) Build composition list of length N_POINTS based on elements in Config
print_header("GENERATE COMPOSITIONS")
elements = list(config.elements)
comps = dirichlet_compositions(elements, n_points=N_POINTS, seed=SEED)
# Prints one line per composition (N_POINTS lines in total), 1-based index.
# The index is zero-padded to two digits only, so it widens beyond 99 entries.
for i, c in enumerate(comps, 1):
    fr_str = ", ".join(f"{k}={v:.3f}" for k, v in c.items())
    print(f"[{i:02d}] {fr_str}")

# 5) Run NEB/CHGNet for each composition to populate the database
# NOTE(review): what oracle.calculate does internally (NEB/CHGNet, CSV writes)
# is not visible here; this loop only relies on its truthy/falsy return value.
print_header("POPULATE DATABASE (NEB runs)")
t0 = time.time()  # wall-clock start, used only for the elapsed-time report below
successes = 0
for idx, comp in enumerate(comps, 1):  # idx is not used inside the loop
    ok = oracle.calculate(comp)  # performs full pipeline and writes CSV/structures
    if ok:  # any truthy return value counts as a successful data point
        successes += 1
dt = time.time() - t0  # elapsed seconds; reported in minutes
print(f"\n✓ Finished generating data: {successes}/{len(comps)} successful in {dt/60.0:.1f} min")

# Stop before training when there is nothing to train on.
if successes == 0:
    raise RuntimeError("No successful data points were created; aborting training.")

# 6) Create dataloaders (reads config.csv_path, filters barriers, builds graphs on-the-fly)
# NOTE(review): the behaviour described above belongs to create_dataloaders in
# dataset.py and is not verifiable from this cell.
# The split is seeded with config.random_seed, not with the cell-level SEED
# (SEED only drives composition sampling).
print_header("CREATE DATALOADERS")
train_loader, val_loader = create_dataloaders(
    config,
    val_split=config.val_split,
    random_seed=config.random_seed
)

# 7) Determine dynamic node_input_dim using the template graph builder
#    Node features are: 3 (positions) + N (one-hot) + 4 (atomic properties)
# NOTE(review): the constants 3 and 4 are hard-coded here and must match the
# node features TemplateGraphBuilder actually produces; confirm against it.
print_header("BUILD TEMPLATE & MODEL")
builder = TemplateGraphBuilder(config, csv_path=config.csv_path)
node_input_dim = 3 + len(builder.elements) + 4  # N is taken from builder.elements, not config.elements
print(f"Detected elements in DB: {builder.elements}")
print(f"node_input_dim = 3 (pos) + {len(builder.elements)} (one-hot) + 4 (props) = {node_input_dim}")

# 8) Create model and trainer
model = create_model_from_config(config, node_input_dim=node_input_dim)
trainer = Trainer(model, config, save_dir=config.checkpoint_dir)

# 9) Train
print_header("TRAINING")
history = trainer.train(train_loader, val_loader, verbose=True)  # return value kept as `history` for later inspection

# Final report: best validation loss as tracked by the trainer, and the
# absolute path of the checkpoint directory.
print_header("DONE")
print("Best validation loss:", trainer.best_val_loss)
print("Checkpoints:", Path(config.checkpoint_dir).resolve())
