In [None]:
# === Active Learning Loop with Logging and Cycle Management ===
# Console: German; Comments: English
# Run this cell directly in VSCode / Jupyter.

import sys
import traceback
import logging
from pathlib import Path
from datetime import datetime
import shutil
import numpy as np
import pandas as pd
import torch

# --- Project imports ---
from config import Config
from oracle import Oracle
from inference import run_inference_cycle, cleanup_gpu
from trainer import Trainer
from dataset import create_dataloaders
from template_graph_builder import TemplateGraphBuilder
from model import create_model_from_config, count_parameters
from utils import get_node_input_dim, save_model_for_inference

# --- Optional Weights & Biases support ---
try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False

# ============================================================================
# Setup Main Logger
# ============================================================================

def setup_main_logger(config: Config):
    """Configure and return the ``active_learning`` logger.

    Handlers left on the logger by an earlier call (for example when the
    notebook cell is re-run) are removed *and closed* first, so the previous
    log-file handle is released instead of leaking. The log file is opened as
    UTF-8 because the messages logged in this file contain non-ASCII symbols
    that platform-default encodings such as cp1252 cannot represent.

    Args:
        config: Object providing ``log_level`` (name of a ``logging`` level,
            case-insensitive), ``log_dir`` (directory that receives
            ``active_learning.log``; created if missing) and
            ``log_to_console`` (truthy to also attach a stream handler).

    Returns:
        The configured ``logging.Logger`` named ``"active_learning"``.

    Raises:
        AttributeError: If ``config.log_level.upper()`` is not an attribute of
            the ``logging`` module.
    """
    # Resolve the level once; logger and all handlers share it.
    level = getattr(logging, config.log_level.upper())

    logger = logging.getLogger("active_learning")
    logger.setLevel(level)

    # Detach AND close old handlers; merely resetting the list would keep the
    # previous file stream open.
    for stale in list(logger.handlers):
        logger.removeHandler(stale)
        stale.close()

    formatter = logging.Formatter(
        '%(asctime)s | %(levelname)-8s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    # File handler (always present)
    log_file = Path(config.log_dir) / "active_learning.log"
    log_file.parent.mkdir(parents=True, exist_ok=True)
    handlers = [logging.FileHandler(log_file, encoding="utf-8")]

    # Console handler (optional)
    if config.log_to_console:
        handlers.append(logging.StreamHandler())

    for handler in handlers:
        handler.setLevel(level)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger

# ============================================================================
# Helper functions
# ============================================================================

def is_csv_missing_or_empty(csv_path: str) -> bool:
    """Return True when the CSV at ``csv_path`` is absent or has no data rows.

    Only one data row is requested from pandas (``nrows=1``): a single row is
    enough to decide "not empty", so the whole database does not have to be
    parsed just for this check.

    Args:
        csv_path: Path of the CSV file to inspect.

    Returns:
        True if the path does not exist, if the file cannot be read as CSV
        (including a zero-byte file), or if it contains a header but no rows;
        False as soon as one data row can be read.
    """
    p = Path(csv_path)
    if not p.exists():
        return True
    try:
        first_rows = pd.read_csv(p, nrows=1)
    except Exception:
        # Deliberate best-effort: an unreadable file is treated like a
        # missing one, exactly as callers of this helper expect.
        return True
    return len(first_rows.index) == 0


def sample_simplex_uniform(n: int, k: int) -> np.ndarray:
    """Draw ``n`` points uniformly from the (k-1)-simplex.

    The points come from a Dirichlet distribution whose ``k`` concentration
    parameters are all 1, sampled with NumPy's global random state.

    Args:
        n: Number of points to draw.
        k: Number of components per point.

    Returns:
        Array with one sampled point per row (``n`` rows, ``k`` columns).
    """
    flat_alpha = np.ones(k)
    points = np.random.dirichlet(flat_alpha, size=n)
    return points


def initial_data_creation_if_needed(config: Config, oracle: Oracle, logger: logging.Logger):
    """Seed the database through the oracle when the CSV holds no data yet.

    Does nothing (apart from one log line) if ``config.csv_path`` already
    contains rows. Otherwise ``al_initial_samples`` random compositions over
    ``config.elements`` are drawn from the simplex and handed one by one to
    ``oracle.calculate``. A sample counts as successful unless the oracle
    returns ``False`` or raises; exceptions are logged and the loop goes on.

    Raises:
        RuntimeError: If ``elements`` is empty or ``al_initial_samples`` is
            not positive, or if not a single sample succeeded.
    """
    target_csv = config.csv_path
    if not is_csv_missing_or_empty(target_csv):
        logger.info(f"✓ Database found: {target_csv}")
        return

    elements = list(getattr(config, "elements", []))
    n_seed = int(getattr(config, "al_initial_samples", 0) or 0)
    if n_seed <= 0 or not elements:
        raise RuntimeError(
            "Initial data creation requires 'elements' and 'al_initial_samples' > 0 in config."
        )

    rule = "="*70
    logger.info(rule)
    logger.info("INITIAL DATA CREATION (CSV empty or not found)")
    logger.info(rule)
    logger.info(f"Elements: {elements}")
    logger.info(f"Samples to create: {n_seed}")
    logger.info(f"Target CSV: {target_csv}")

    def _as_composition(fractions):
        # Map element -> fraction; renormalise only if the sum drifted from 1.
        comp = dict(zip(elements, map(float, fractions)))
        total = sum(comp.values())
        if abs(total - 1.0) > 1e-12:
            comp = {el: frac / total for el, frac in comp.items()}
        return comp

    compositions = [
        _as_composition(row)
        for row in sample_simplex_uniform(n_seed, len(elements))
    ]

    n_ok = 0
    for idx, comp in enumerate(compositions, start=1):
        try:
            logger.info(f"  [{idx}/{n_seed}] Calculating: {comp}")
            outcome = oracle.calculate(comp)
        except Exception as err:
            logger.error(f"   ✗ Error at sample {idx}: {err}")
            continue
        # Anything but an explicit False (including None) counts as success.
        if outcome is not False:
            n_ok += 1

    if n_ok == 0:
        raise RuntimeError("Initial data creation failed: no valid samples added.")
    logger.info(f"✓ Initial data creation completed. {n_ok} samples added.")


def get_database_stats(csv_path: str) -> dict:
    """Summarise the database CSV as a row count and a distinct-composition count.

    Args:
        csv_path: Path of the database CSV.

    Returns:
        Dict with ``"n_samples"`` (number of rows) and ``"n_compositions"``
        (number of distinct values in the ``composition_string`` column, or 0
        when that column is absent). Both are 0 if the file does not exist or
        cannot be read.
    """
    stats = {"n_samples": 0, "n_compositions": 0}

    db_file = Path(csv_path)
    if not db_file.exists():
        return stats

    try:
        table = pd.read_csv(db_file)
    except Exception:
        return stats

    stats["n_samples"] = len(table)
    if 'composition_string' in table.columns:
        stats["n_compositions"] = table['composition_string'].nunique()
    return stats


def train_cycle_model(config: Config, cycle: int, logger: logging.Logger) -> dict:
    """Train a fresh model for the given active learning cycle.

    Creates ``<config.checkpoint_dir>/cycle_<cycle>``, builds the data
    loaders, graph builder, model and ``Trainer`` from ``config`` and runs
    ``trainer.train``. Any exception raised while doing so is logged (with
    its traceback, so it also reaches the log file) and NOT re-raised; the
    caller can inspect the returned summary instead.

    Args:
        config: Project configuration (reads ``checkpoint_dir`` and
            ``csv_path`` here; passed on to the project helpers).
        cycle: Active learning cycle the model belongs to.
        logger: Logger that receives progress and error messages.

    Returns:
        Dict with the keys ``"cycle"``, ``"success"`` (True only if training
        ran to completion), ``"n_samples"`` (rows in the database CSV when
        training started, or None if that could not be determined) and
        ``"save_dir"`` (directory handed to the trainer).
    """
    logger.info("="*70)
    logger.info(f"TRAINING MODEL - CYCLE {cycle}")
    logger.info("="*70)

    outdir = Path(config.checkpoint_dir) / f"cycle_{cycle}"
    outdir.mkdir(parents=True, exist_ok=True)

    summary = {
        "cycle": cycle,
        "success": False,
        "n_samples": None,
        "save_dir": str(outdir),
    }

    try:
        # Record the database size this model is trained on
        summary["n_samples"] = get_database_stats(config.csv_path)['n_samples']

        train_loader, val_loader = create_dataloaders(config)
        builder = TemplateGraphBuilder(config)
        node_input_dim = get_node_input_dim(builder)
        model = create_model_from_config(config, node_input_dim)

        # The trainer receives the cycle number and the per-cycle save directory
        trainer = Trainer(model, config, save_dir=str(outdir), cycle=cycle)
        trainer.train(train_loader, val_loader, verbose=True)
    except Exception as e:
        # exc_info routes the traceback through the logger's handlers instead
        # of printing it to stderr only.
        logger.error(f"✗ Training failed: {e}", exc_info=True)
    else:
        summary["success"] = True
        logger.info(f"✓ Model training completed (Cycle {cycle})")

    return summary


def active_learning_loop(config: Config, logger: logging.Logger):
    """Run the active learning cycles, creating initial data first if needed.

    For every cycle in ``range(config.al_max_cycles)``:

    1. Make sure the checkpoint for THIS cycle is available. Cycle 0 always
       trains. Later cycles normally reuse the model trained at the end of
       the previous cycle and only retrain when that checkpoint is missing
       (for example because the previous cycle's inference failed before the
       retraining step).
    2. Run ``run_inference_cycle`` with that checkpoint.
    3. Unless this is the last cycle, train the model for the next cycle on
       the updated database.

    An exception in steps 2-3 is logged with its traceback and the loop moves
    on to the next cycle.

    Args:
        config: Project configuration (reads ``al_max_cycles``, ``al_n_test``,
            ``al_n_query``, ``elements``, ``csv_path``, ``checkpoint_dir``).
        logger: Logger that receives progress and error messages.
    """
    banner = "="*70
    logger.info(banner)
    logger.info("ACTIVE LEARNING LOOP STARTING")
    logger.info(banner)
    logger.info(f"Cycles: {config.al_max_cycles}")
    logger.info(f"Test samples per cycle: {config.al_n_test}")
    logger.info(f"Query samples per cycle: {config.al_n_query}")
    logger.info(f"Elements: {config.elements}")
    logger.info(banner)

    oracle = Oracle(config)
    initial_data_creation_if_needed(config, oracle, logger)

    # Cycle numbers are used as-is: one iteration per configured cycle.
    for cycle in range(config.al_max_cycles):
        logger.info(banner)
        logger.info(f"Cycle {cycle}/{config.al_max_cycles - 1}")
        logger.info(banner)

        db_stats = get_database_stats(config.csv_path)
        logger.info(f"Current database: {db_stats['n_samples']} samples, {db_stats['n_compositions']} compositions")

        # Checkpoint that inference will load for this cycle.
        # NOTE(review): assumes the trainer writes "best_model.pt" into its
        # save_dir - confirm against Trainer.
        current_model = Path(config.checkpoint_dir) / f"cycle_{cycle}" / "best_model.pt"

        if cycle == 0:
            logger.info("→ Training initial model ...")
            train_cycle_model(config, cycle, logger)
        elif not current_model.exists():
            # Check the checkpoint that is actually used below, not the one
            # of the previous cycle; otherwise a single failed cycle leaves
            # every following cycle without a model.
            logger.warning(f"⚠️ Model for cycle {cycle} not found, training new model ...")
            train_cycle_model(config, cycle, logger)
        else:
            logger.info(f"→ Using model: {current_model}")

        logger.info("→ Starting inference cycle ...")
        try:
            run_inference_cycle(cycle, str(current_model), oracle, config, verbose=True)

            # After inference, train the model for the next cycle
            if cycle < config.al_max_cycles - 1:
                logger.info("→ Training model with new data ...")
                train_cycle_model(config, cycle + 1, logger)

        except Exception as e:
            # exc_info puts the traceback into the log file as well.
            logger.error(f"✗ Inference failed: {e}", exc_info=True)
            continue

        logger.info(f"✓ Cycle {cycle} completed.")


# ============================================================================
# Execution (show config + confirmation)
# ============================================================================

# Build the configuration and the main logger for this run
config = Config()
logger = setup_main_logger(config)

# Log every config attribute so the user can review it before confirming
separator = "="*70
logger.info(separator)
logger.info("CURRENT CONFIG")
logger.info(separator)
for name, value in vars(config).items():
    logger.info(f"{name:25s}: {value}")
logger.info(separator)

# Only an explicit "y" (case-insensitive, surrounding whitespace ignored)
# starts the workflow; any other answer aborts.
confirm = input("❓ Do you want to start the Active Learning workflow? (y/n): ").strip().lower()
if confirm != "y":
    logger.info("↪️  Aborted – no execution started.")
else:
    try:
        active_learning_loop(config, logger)
    except Exception as e:
        logger.error(f"✗ Run aborted: {e}")
        traceback.print_exc()
