In [None]:
# === Active Learning Loop with Initial Data Creation and Start Confirmation ===
# Console: German; Comments: English
# Run this cell directly in VSCode / Jupyter.

import sys
import traceback
from pathlib import Path
from datetime import datetime
import shutil
import numpy as np
import pandas as pd
import torch

# --- Project imports ---
from config import Config
from oracle import Oracle
from inference import run_inference_cycle, cleanup_gpu
from trainer import Trainer
from dataset import create_dataloaders
from template_graph_builder import TemplateGraphBuilder
from model import create_model_from_config, count_parameters
from utils import get_node_input_dim, save_model_for_inference

# --- Optional Weights & Biases support ---
try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False

# ============================================================================
# Helper functions
# ============================================================================

def is_csv_missing_or_empty(csv_path: str) -> bool:
    """Report whether the CSV at ``csv_path`` offers no usable rows.

    Returns ``True`` when the file does not exist, when pandas cannot read
    it for any reason, or when it parses to zero data rows. Returns
    ``False`` only if at least one data row could be read.
    """
    csv_file = Path(csv_path)
    if not csv_file.exists():
        return True

    try:
        row_count = pd.read_csv(csv_file).shape[0]
    except Exception:
        # Best effort: an unreadable file is handled like a missing one.
        return True

    return row_count == 0


def sample_simplex_uniform(n: int, k: int) -> np.ndarray:
    """Draw ``n`` random points from the (k-1)-simplex.

    The points come from a Dirichlet distribution whose ``k`` concentration
    parameters are all one. The result has one row per sample and one
    column per component. NumPy's global random state is used, so seed it
    with ``np.random.seed`` for reproducible draws.
    """
    concentration = np.ones(k)
    return np.random.dirichlet(concentration, n)


def initial_data_creation_if_needed(config: Config, oracle: Oracle):
    """
    Seed the database when ``config.csv_path`` is missing or has no rows.

    If the CSV already contains data the function only reports that and
    returns. Otherwise ``config.al_initial_samples`` random compositions
    over ``config.elements`` are drawn from the simplex and each one is
    passed to ``oracle.calculate``. A sample counts as added unless the
    call raises or returns ``False``.

    Raises:
        RuntimeError: if ``elements`` is empty or ``al_initial_samples`` is
            not positive, or if not a single sample was added.
    """
    target_csv = config.csv_path
    if not is_csv_missing_or_empty(target_csv):
        print(f"✓ Datenbank gefunden: {target_csv}")
        return

    element_names = list(getattr(config, "elements", []))
    seed_count = int(getattr(config, "al_initial_samples", 0) or 0)
    if seed_count <= 0 or not element_names:
        raise RuntimeError(
            "Initial data creation requires 'elements' and 'al_initial_samples' > 0 in config."
        )

    print("\n======================================================================")
    print("ERSTELLUNG INITIALER DATEN (CSV leer oder nicht vorhanden)")
    print("======================================================================")
    print(f"Elemente: {element_names}")
    print(f"Zu erstellende Samples: {seed_count}")
    print(f"Ziel-CSV: {target_csv}")

    def _as_composition(fractions) -> dict:
        # Map element -> fraction; renormalise only if float rounding
        # pushed the sum measurably away from one.
        composition = dict(zip(element_names, (float(v) for v in fractions)))
        total = sum(composition.values())
        if abs(total - 1.0) > 1e-12:
            composition = {name: share / total for name, share in composition.items()}
        return composition

    # All compositions are drawn and built before the first oracle call.
    drawn = sample_simplex_uniform(seed_count, len(element_names))
    compositions = [_as_composition(row) for row in drawn]

    added = 0
    for index, composition in enumerate(compositions, start=1):
        try:
            print(f"  [{index}/{seed_count}] Berechne: {composition}")
            outcome = oracle.calculate(composition)
        except Exception as err:
            # A failing sample must not abort the whole seeding run.
            print(f"   ✗ Fehler bei Sample {index}: {err}")
        else:
            # Only an explicit False is a rejection; None counts as success.
            if outcome is not False:
                added += 1

    if added == 0:
        raise RuntimeError("Initial data creation failed: no valid samples added.")
    print(f"\n✓ Initiale Datenerstellung abgeschlossen. {added} Samples hinzugefügt.\n")


def get_database_stats(csv_path: str) -> dict:
    """Summarise the database CSV as a small dict of counts.

    The result has two keys: ``n_samples`` is the number of rows and
    ``n_compositions`` is the number of distinct values in the
    ``composition_string`` column (0 if that column is absent). Both are 0
    when the file does not exist or cannot be read.
    """
    db_file = Path(csv_path)
    if not db_file.exists():
        return {"n_samples": 0, "n_compositions": 0}

    try:
        table = pd.read_csv(db_file)
    except Exception:
        # An unreadable database is reported like an empty one.
        return {"n_samples": 0, "n_compositions": 0}

    if "composition_string" in table.columns:
        distinct = table["composition_string"].nunique()
    else:
        distinct = 0
    return {"n_samples": len(table), "n_compositions": distinct}


def train_cycle_model(config: Config, cycle: int) -> dict:
    """Train the model for the given active learning cycle.

    The trainer is given ``<config.checkpoint_dir>/cycle_<cycle>`` as its
    save directory; the directory is created if necessary. Training errors
    are not raised: they are printed with a traceback and reported through
    the returned status so the surrounding loop can keep running.

    Args:
        config: Project configuration; ``checkpoint_dir`` is read here and
            the object is passed on to the data, model and trainer factories.
        cycle: Index of the active learning cycle, used for the directory
            name and the console output.

    Returns:
        A status dict with the keys ``cycle`` (the given index),
        ``success`` (``True`` if training finished without an exception),
        ``save_dir`` (the cycle's checkpoint directory as a string) and
        ``error`` (the exception message, or ``None`` on success).
    """
    print("\n======================================================================")
    print(f"TRAINING MODEL - ZYKLUS {cycle}")
    print("======================================================================")

    outdir = Path(config.checkpoint_dir) / f"cycle_{cycle}"
    outdir.mkdir(parents=True, exist_ok=True)

    # The function is annotated to return a dict, so always hand one back
    # instead of falling off the end with None.
    status = {
        "cycle": cycle,
        "success": False,
        "save_dir": str(outdir),
        "error": None,
    }

    try:
        train_loader, val_loader = create_dataloaders(config)
        builder = TemplateGraphBuilder(config)
        node_input_dim = get_node_input_dim(builder)
        model = create_model_from_config(config, node_input_dim)
        trainer = Trainer(model, config, save_dir=str(outdir))
        trainer.train(train_loader, val_loader, verbose=True)
    except Exception as e:
        # Deliberate best effort: report the failure but do not abort the
        # caller's loop.
        print(f"✗ Training fehlgeschlagen: {e}")
        traceback.print_exc()
        status["error"] = str(e)
    else:
        print(f"✓ Modelltraining abgeschlossen (Zyklus {cycle})")
        status["success"] = True

    return status


def active_learning_loop(config: Config):
    """Main active learning loop with initial data creation.

    Creates the oracle, seeds the database if it is missing or empty, and
    then runs ``config.al_max_cycles`` cycles. Each cycle prints the current
    database statistics, selects (or trains) a model and runs one inference
    cycle with it. A failing inference cycle is reported with a traceback
    and the loop moves on to the next cycle.

    Model selection per cycle:
      * cycle 0: a model is trained and its own checkpoint is used.
      * later cycles: the previous cycle's ``best_model.pt`` is used if it
        exists; otherwise a model is trained for the current cycle and that
        checkpoint is used.

    Args:
        config: Project configuration; reads ``al_max_cycles``, ``al_n_test``,
            ``al_n_query``, ``elements``, ``csv_path`` and ``checkpoint_dir``.
    """
    print("\n======================================================================")
    print("AKTIVER LERNZYKLUS STARTET")
    print("======================================================================")
    print(f"Zyklen: {config.al_max_cycles}")
    print(f"Samples pro Test: {config.al_n_test}")
    print(f"Samples pro Abfrage: {config.al_n_query}")
    print(f"Elemente: {config.elements}")
    print("======================================================================\n")

    oracle = Oracle(config)
    initial_data_creation_if_needed(config, oracle)

    checkpoint_root = Path(config.checkpoint_dir)

    for cycle in range(config.al_max_cycles):
        print("\n======================================================================")
        print(f"Zyklus {cycle}/{config.al_max_cycles}")
        print("======================================================================")

        db_stats = get_database_stats(config.csv_path)
        print(f"Aktuelle Datenbank: {db_stats['n_samples']} Samples, {db_stats['n_compositions']} Kompositionen")

        # Model path logic.
        # NOTE(review): assumes the trainer stores its best checkpoint as
        # "best_model.pt" inside the cycle directory, which is the file name
        # the previous-cycle lookup below already relies on - confirm.
        current_model = checkpoint_root / f"cycle_{cycle}" / "best_model.pt"
        if cycle == 0:
            print("→ Trainiere initiales Modell ...")
            train_cycle_model(config, cycle)
            # There is no previous cycle; use the model that was just trained.
            model_path = current_model
        else:
            prev_model = checkpoint_root / f"cycle_{cycle-1}" / "best_model.pt"
            if prev_model.exists():
                print(f"→ Verwende Modell: {prev_model}")
                model_path = prev_model
            else:
                print("⚠️ Vorheriges Modell nicht gefunden, trainiere neu ...")
                train_cycle_model(config, cycle)
                # The previous checkpoint is known to be missing, so point
                # inference at the freshly trained one instead.
                model_path = current_model

        # Run inference
        print("→ Starte Inferenzzyklus ...")
        try:
            run_inference_cycle(cycle, str(model_path), oracle, config, verbose=True)
        except Exception as e:
            print(f"✗ Inferenz fehlgeschlagen: {e}")
            traceback.print_exc()
            continue

        print(f"✓ Zyklus {cycle} abgeschlossen.")


# ============================================================================
# Execution (show config + confirmation)
# ============================================================================

# Build the configuration that drives the whole run.
config = Config()

# Print config summary for user review: one "name: value" line per
# attribute stored on the config instance.
print("\n====================== AKTUELLE CONFIG ======================")
for key, val in vars(config).items():
    print(f"{key:25s}: {val}")
print("==============================================================\n")

# Ask for confirmation; only an explicit "y" (any case, surrounding
# whitespace ignored) starts the workflow.
confirm = input("❓ Willst du den Active-Learning-Workflow wirklich starten? (y/n): ").strip().lower()
if confirm != "y":
    print("↪️  Abgebrochen — keine Ausführung gestartet.")
else:
    try:
        active_learning_loop(config)
    except Exception as e:
        # Top-level boundary: report the failure instead of letting the
        # exception propagate out of the cell.
        print(f"\n✗ Lauf abgebrochen: {e}")
        traceback.print_exc()



csv_path                 : database_navi.csv
checkpoint_dir           : checkpoints
database_dir             : database
elements                 : ['Mo', 'Nb', 'Ta', 'W', 'Cr']
supercell_size           : 4
lattice_parameter        : 3.2
cutoff_radius            : 3.5
max_neighbors            : 50
batch_size               : 32
num_workers              : 0
min_barrier              : 0.1
max_barrier              : 15.0
val_split                : 0.1
random_seed              : 42
gnn_hidden_dim           : 64
gnn_num_layers           : 5
gnn_embedding_dim        : 64
mlp_hidden_dims          : [1024, 512, 256]
dropout                  : 0.15
learning_rate            : 0.0005
weight_decay             : 0.01
gradient_clip_norm       : 1.0
epochs                   : 1000
patience                 : 50
save_interval            : 50
use_scheduler            : True
scheduler_type           : plateau
scheduler_factor         : 0.5
scheduler_patience       : 10
scheduler_step_size      : 100
sched