# Hyperparameter Optimization for Weather Prediction Models

This notebook demonstrates how to use the HPO module for tuning weather prediction models.

## Supported Model Types
- `simple-linear`: Simple linear regression
- `linear`: Linear regression with regularization (Ridge, Lasso, ElasticNet)
- `lgbm`: LightGBM gradient boosting
- `gnn`: Graph Neural Network (main architecture)
- `cnn`: U-Net Convolutional Neural Network

In [None]:
import sys
sys.path.insert(0, "../src")

from hpo import HPO, HPOConfig, ModelType
from stwp.features import Features

## 1. Configuration

HPO can be configured using either an `HPOConfig` dataclass or a dictionary.

In [None]:
# Model types grouped by family: classical baselines vs. neural networks.
BASELINES = [
    ModelType.SIMPLE_LINEAR,
    ModelType.LINEAR,
    ModelType.LGBM,
]
NEURAL_NETS = [
    ModelType.GNN,
    ModelType.CNN,
]

# Show the feature list exposed by `Features` and the reported feature count.
print(f"Features: {Features.as_list()}")
print(f"Feature count: {Features.COUNT}")

## 2. Baseline Models HPO

For baseline models (`simple-linear`, `linear`, `lgbm`), we can optimize:
- Input sequence length
- Forecast horizon
- Model-specific hyperparameters (alpha, regressor type, n_estimators, etc.)

In [None]:
# Option A: build the configuration with the HPOConfig dataclass.
baseline_config = HPOConfig(
    model_type=ModelType.SIMPLE_LINEAR,
    # --- starting point for the studies ---
    sequence_length=5,         # Initial sequence length
    forecast_horizon=1,        # Initial forecast horizon
    use_neighbours=False,      # Whether to use spatial neighbours
    # --- search budgets ---
    n_trials=10,               # Number of Optuna trials for param search
    sequence_n_trials=15,      # Max sequence length to try
    fh_n_trials=10,            # Max forecast horizon to try
)

# Option B: the same settings as a plain dictionary. Keys mirror the keyword
# names used above; the model type is given by its string name.
# NOTE(review): `use_neighbours` is not set here, so this matches
# `baseline_config` only if HPOConfig defaults it to False — confirm.
baseline_config_dict = {
    "model_type": "simple-linear",
    "n_trials": 10,
    "sequence_n_trials": 15,
    "fh_n_trials": 10,
    "sequence_length": 5,
    "forecast_horizon": 1,
}

In [None]:
# Initialize HPO for simple linear regression.
# `hpo` is the object that the commented-out examples in the later analysis
# cells (test_scalers, monthly_error, save_results, report) refer to.
hpo = HPO(baseline_config)

# Run full study: sequence -> params -> forecast horizon
# (left commented out; uncomment to execute)
# results = hpo.run_full_study()

# Or run individual studies:
# seq_results = hpo.determine_best_sequence()
# param_results = hpo.run_param_study()
# fh_results = hpo.determine_best_fh()

In [None]:
# Regularized linear regression (Ridge, Lasso, ElasticNet)
linear_config = HPOConfig(
    model_type=ModelType.LINEAR,
    max_alpha=10.0,        # Max regularization strength for Optuna
    # --- search budgets ---
    n_trials=50,
    sequence_n_trials=15,
    fh_n_trials=10,
)

# Uncomment to run the full study with this configuration:
# hpo_linear = HPO(linear_config)
# results = hpo_linear.run_full_study()

In [None]:
# Gradient boosting with LightGBM
lgbm_config = HPOConfig(
    model_type=ModelType.LGBM,
    # --- search budgets ---
    sequence_n_trials=15,
    fh_n_trials=10,
    n_trials=100,
)

# Uncomment to run the full study with this configuration:
# hpo_lgbm = HPO(lgbm_config)
# results = hpo_lgbm.run_full_study()

## 3. Neural Network Models HPO

For neural networks (GNN, CNN), we optimize:
- Input sequence length
- Forecast horizon
- GNN-specific: number of graph cells (layers)

In [None]:
from stwp.models.gnn.gnn_module import ArchitectureType

# Configuration for the Graph Neural Network
gnn_config = HPOConfig(
    model_type=ModelType.GNN,
    # --- GNN-specific params ---
    gnn_architecture=ArchitectureType.TRANSFORMER,
    gnn_hidden_dim=32,
    gnn_lr=1e-3,
    # --- training and data ---
    num_epochs=100,             # Training epochs per trial
    subset=None,                # Use full dataset (or int for subset)
    # --- starting point and study ranges ---
    sequence_length=5,
    forecast_horizon=1,
    sequence_n_trials=10,
    fh_n_trials=10,
)

In [None]:
# Initialize GNN HPO
# hpo_gnn = HPO(gnn_config)

# Run GNN layer study (find optimal number of graph cells)
# layer_results = hpo_gnn.gnn_layer_study()
# print(f"Best layer count: {layer_results.best_value}")

# Run sequence length study
# seq_results = hpo_gnn.determine_best_sequence()
# print(f"Best sequence length: {seq_results.best_value}")

# Run forecast horizon study
# fh_results = hpo_gnn.determine_best_fh()
# print(f"Best forecast horizon: {fh_results.best_value}")

In [None]:
# Configuration for the CNN (U-Net) model
cnn_config = HPOConfig(
    model_type=ModelType.CNN,
    # --- training ---
    num_epochs=100,
    # --- starting point and study ranges ---
    sequence_length=5,
    forecast_horizon=1,
    sequence_n_trials=10,
    fh_n_trials=10,
)

# Uncomment to run the full study with this configuration:
# hpo_cnn = HPO(cnn_config)
# results = hpo_cnn.run_full_study()

## 4. Additional Analysis Methods

After finding optimal hyperparameters, run additional analyses.

In [None]:
# Test different scalers
# Supported: 'standard', 'min_max', 'max_abs', 'robust'
# scaler_results = hpo.test_scalers()
# for scaler, metrics in scaler_results.items():
#     print(f"{scaler}: RMSE={metrics['rmse']:.4f}, time={metrics['execution_time']:.2f}s")

In [None]:
# Calculate monthly prediction errors
# monthly_errors = hpo.monthly_error()
# for month, error in monthly_errors.items():
#     print(f"{month}: RMSE={error:.4f}")

In [None]:
# Save results to JSON files
# hpo.save_results()

# Print summary report
# hpo.report()

## 5. Accessing Results

HPO results are stored in an `HPOResults` dataclass.

In [None]:
# After running studies, access results:
# print(f"Best sequence length: {hpo.best_sequence}")
# print(f"Best forecast horizon: {hpo.best_fh}")
# print(f"Best parameters: {hpo.best_params}")

# Detailed results structure:
# results = hpo.results
# results.sequence_results  # StudyResults for sequence optimization
# results.fh_results        # StudyResults for forecast horizon
# results.params            # Best hyperparameters
# results.scaler_metrics    # Scaler comparison results
# results.monthly_errors    # Monthly error breakdown
# results.layer_results     # GNN layer study results

## 6. Visualization

Results are automatically saved to `modelsplots.json` for visualization.

In [None]:
import json
import matplotlib.pyplot as plt

def plot_study_results(model_type: str, results_path: str = "modelsplots.json") -> None:
    """Plot saved HPO study curves for one model type.

    Reads the JSON results file and draws two side-by-side panels: mean RMSE
    against sequence length and mean RMSE against forecast horizon. A panel
    is left empty when its x or y data is missing or empty. Problems with the
    results file (missing, not valid JSON, no entry for the model) are
    reported with a printed message and the function returns without plotting.

    Args:
        model_type: Key of the model in the results file, e.g. "simple-linear".
        results_path: Location of the JSON results file. Defaults to
            "modelsplots.json", resolved against the current working directory.
    """
    try:
        with open(results_path, encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        print("No results file found. Run HPO studies first.")
        return
    except json.JSONDecodeError as err:
        # A truncated or hand-edited file is reported the same way as a missing one.
        print(f"Results file {results_path} is not valid JSON: {err}")
        return

    if model_type not in data:
        print(f"No results for {model_type}")
        return

    results = data[model_type]

    # One entry per panel, left to right: (x key, y key, x-axis label, title suffix).
    panels = [
        ("sequence_plot_x", "sequence_plot_y", "Sequence Length", "Sequence Length Study"),
        ("fh_plot_x", "fh_plot_y", "Forecast Horizon", "Forecast Horizon Study"),
    ]

    _, axes = plt.subplots(1, 2, figsize=(14, 5))

    for ax, (x_key, y_key, x_label, title) in zip(axes, panels):
        # .get() so an entry that lacks one of the series does not raise KeyError.
        xs = results.get(x_key)
        ys = results.get(y_key)
        if not xs or not ys:
            continue
        ax.plot(xs, ys, marker="o")
        ax.set_xlabel(x_label)
        ax.set_ylabel("Mean RMSE")
        ax.set_title(f"{model_type}: {title}")
        ax.grid(True)

    plt.tight_layout()
    plt.show()

# plot_study_results("simple-linear")

## 7. Using Pre-trained Models

For neural networks, you can load pre-trained models for evaluation.

In [None]:
# Load pre-trained GNN model
# from stwp.models.gnn.trainer import Trainer as GNNTrainer
# from stwp.models.gnn.gnn_module import ArchitectureType

# trainer = GNNTrainer(
#     architecture=ArchitectureType.TRANSFORMER,
#     hidden_dim=32,
#     lr=1e-3,
# )
# trainer.load_model("path/to/model.pt")

# Evaluate on test set
# metrics, y_hat = trainer.evaluate("test", verbose=True)
# rmse, mae = metrics

# Plot predictions
# trainer.plot_predictions("test", pretty=True)

## 8. Complete Example: Full HPO Pipeline

In [None]:
def run_complete_hpo(model_type: str, quick: bool = True) -> "HPO":
    """Run the complete HPO pipeline for a model type.

    Builds an HPOConfig, runs the full study, then the scaler comparison and
    the monthly error analysis, and finally saves and reports the results.

    Args:
        model_type: One of 'simple-linear', 'linear', 'lgbm', 'gnn', 'cnn'
        quick: If True, use fewer trials and epochs for faster results

    Returns:
        The HPO instance the pipeline was run on.

    Raises:
        ValueError: presumably, when `model_type` is not a known model name
            (assumes ModelType is an Enum looked up by value — confirm).
    """
    config = HPOConfig(
        model_type=ModelType(model_type),
        n_trials=10 if quick else 50,
        sequence_n_trials=5 if quick else 15,
        fh_n_trials=5 if quick else 10,
        num_epochs=10 if quick else 100,
    )

    hpo = HPO(config)

    # Run full optimization. The return value is not bound: nothing below uses
    # it, and the caller gets the HPO object itself.
    hpo.run_full_study()

    # Additional analyses
    hpo.test_scalers()
    hpo.monthly_error()

    # Save and report
    hpo.save_results()
    hpo.report()

    return hpo

# Example: Run quick HPO for simple linear model
# hpo = run_complete_hpo("simple-linear", quick=True)