# In [None]:
"""
Training Script for Diffusion Barrier Prediction - Jupyter Notebook Version

Simple script to train a GNN model on diffusion barrier data in Jupyter.

Usage in Jupyter:
    %run train_notebook.py
"""

import torch
from pathlib import Path

from config import Config
from template_graph_builder import TemplateGraphBuilder
from dataset import create_dataloaders
from model import create_model_from_config, count_parameters
from trainer import Trainer
from utils import save_model_for_inference, get_node_input_dim


def _print_section(title, rule="-", width=70):
    """Print *title* framed above and below by a horizontal rule.

    A blank line is emitted before the top rule so that consecutive sections
    are visually separated in the console / notebook output.
    """
    bar = rule * width
    print("\n" + bar)
    print(title)
    print(bar)


def train(config, save_dir: str = "checkpoints"):
    """
    Run the full training pipeline and export the model for inference.

    Steps, in order:
        1. Build a ``TemplateGraphBuilder`` from ``config`` and derive the
           node input dimension from it.
        2. Create the model from ``config`` and move it to CUDA when
           available, otherwise CPU.
        3. Create the train/validation dataloaders using ``config.val_split``
           and ``config.random_seed``.
        4. Train with ``Trainer.train(..., verbose=True)``.
        5. Reload ``<save_dir>/best_model.pt`` if it exists, then write
           ``<save_dir>/final_model_for_inference.pt`` together with a small
           metadata dict (best val loss, epoch count, sample counts).
        6. Print a summary of the run.

    Args:
        config: Configuration object. This function reads ``gnn_num_layers``,
            ``gnn_hidden_dim``, ``mlp_hidden_dims``, ``val_split`` and
            ``random_seed`` from it and forwards it to the builder, model
            factory, dataloader factory, trainer and export helper.
        save_dir: Directory used for checkpoints and the exported model.
            It is created if it does not exist yet.

    Returns:
        A ``(trainer, history)`` tuple, where ``history`` is whatever
        ``Trainer.train`` returned.
    """
    _print_section("DIFFUSION BARRIER PREDICTION - TRAINING", rule="=")

    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\nDevice: {device}")

    # ------------------------------------------------------------------
    # 1. Graph builder
    # ------------------------------------------------------------------
    _print_section("1. SETUP GRAPH BUILDER")

    builder = TemplateGraphBuilder(config)
    node_input_dim = get_node_input_dim(builder)

    print("\n✓ Graph builder ready")
    print(f"  Elements: {builder.elements}")
    print(f"  Node input dim: {node_input_dim}")

    # ------------------------------------------------------------------
    # 2. Model
    # ------------------------------------------------------------------
    _print_section("2. CREATE MODEL")

    model = create_model_from_config(config, node_input_dim)
    model = model.to(device)

    params = count_parameters(model)

    print("\n✓ Model created")
    print("  Architecture:")
    print(f"    GNN layers: {config.gnn_num_layers}")
    print(f"    GNN hidden dim: {config.gnn_hidden_dim}")
    print(f"    MLP hidden dims: {config.mlp_hidden_dims}")
    print("  Parameters:")
    print(f"    Encoder: {params['encoder']:,}")
    print(f"    Predictor: {params['predictor']:,}")
    print(f"    Total: {params['total']:,}")

    # ------------------------------------------------------------------
    # 3. Data
    # ------------------------------------------------------------------
    _print_section("3. LOAD DATA")

    train_loader, val_loader = create_dataloaders(
        config,
        val_split=config.val_split,
        random_seed=config.random_seed
    )

    # ------------------------------------------------------------------
    # 4. Training
    # ------------------------------------------------------------------
    _print_section("4. TRAIN MODEL")

    trainer = Trainer(model, config, save_dir=save_dir)
    history = trainer.train(train_loader, val_loader, verbose=True)

    # ------------------------------------------------------------------
    # 5. Export
    # ------------------------------------------------------------------
    _print_section("5. SAVE FINAL MODEL")

    checkpoint_dir = Path(save_dir)
    # Do not rely on Trainer having created the directory before exporting.
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    best_checkpoint_path = checkpoint_dir / "best_model.pt"
    if best_checkpoint_path.exists():
        trainer.load_checkpoint(str(best_checkpoint_path))
        print("\n✓ Loaded best model from training")
    else:
        # Make the fallback visible instead of silently exporting whatever
        # weights the trainer currently holds.
        print(f"\n⚠ Best checkpoint not found at {best_checkpoint_path}; "
              "exporting the model state left by training")

    # Compare against None explicitly: the truth value of a DataLoader goes
    # through __len__, so an empty loader would be mistaken for "no
    # validation set" and a loader over a length-less dataset could raise.
    has_val = val_loader is not None
    train_samples = len(train_loader.dataset)
    val_samples = len(val_loader.dataset) if has_val else 0
    total_epochs = trainer.current_epoch + 1  # current_epoch is treated as 0-based here

    final_model_path = checkpoint_dir / "final_model_for_inference.pt"
    save_model_for_inference(
        model=trainer.model,
        node_input_dim=node_input_dim,
        elements=builder.elements,
        config=config,
        filepath=str(final_model_path),
        metadata={
            'best_val_loss': trainer.best_val_loss,
            'total_epochs': total_epochs,
            'train_samples': train_samples,
            'val_samples': val_samples
        }
    )

    # ------------------------------------------------------------------
    # 6. Summary
    # ------------------------------------------------------------------
    _print_section("TRAINING SUMMARY", rule="=")

    print("\nTraining completed:")
    print(f"  Total epochs: {total_epochs}")
    print(f"  Best val loss: {trainer.best_val_loss:.4f}")
    print(f"  Train samples: {train_samples}")
    if has_val:
        print(f"  Val samples: {val_samples}")

    print("\nFiles saved:")
    print(f"  Best checkpoint: {best_checkpoint_path}")
    print(f"  Final model: {final_model_path}")
    # NOTE(review): this function never writes training_history.json itself;
    # presumably Trainer does - confirm against trainer.py.
    print(f"  History: {checkpoint_dir / 'training_history.json'}")

    print("\nTo use this model for prediction:")
    print("  from utils import load_model_for_inference")
    print(f"  model, checkpoint = load_model_for_inference('{final_model_path}', config)")

    print("\n" + "="*70 + "\n")

    return trainer, history


# ============================================================================
# SCRIPT BODY - executes top to bottom as soon as the file is run
# ============================================================================

# Start from the default configuration.
config = Config()

# To customize the run, override fields on `config` here, e.g.:
# config.epochs = 500
# config.batch_size = 64
# config.learning_rate = 0.001
# config.use_wandb = True
# config.wandb_run_name = "my-experiment"
# config.wandb_tags = ["baseline", "v1"]

# Echo the key settings before training starts.
print(
    "Starting training with configuration:",
    f"  Epochs: {config.epochs}",
    f"  Batch size: {config.batch_size}",
    f"  Learning rate: {config.learning_rate}",
    f"  Scheduler: {config.scheduler_type}",
    f"  Wandb: {config.use_wandb}",
    sep="\n",
)
if config.use_wandb:
    # Wandb details are only relevant when logging is switched on.
    print(
        f"  Wandb project: {config.wandb_project}",
        f"  Wandb run name: {config.wandb_run_name or 'auto-generated'}",
        sep="\n",
    )

# Launch training; results stay bound at module level for interactive use.
trainer, history = train(config, "checkpoints")

print("\n✓ Training completed! Trainer and history are available in variables.")