# Hyperparameter Tuning with Cross-Validation

This notebook demonstrates how to use cross-validation to select optimal hyperparameters for KLRfome models. We'll explore:

1. Basic cross-validation usage
2. Grid search over parameter space
3. Model selection based on CV results
4. Understanding validation metrics


## Setup


In [1]:
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from rasterio.transform import from_bounds
import geopandas as gpd
from shapely.geometry import Point

from klrfome import KLRfome
from klrfome.data.simulation import create_simulated_raster_stack
from klrfome.visualization import plot_roc_curve

SEED = 42
np.random.seed(SEED)


## Create Test Dataset

We'll create a dataset with clear separation between sites and background for demonstration.


In [2]:
# Create raster stack
raster_stack = create_simulated_raster_stack(cols=80, rows=80, n_bands=3, seed=SEED)

# Create site locations
n_sites = 15
site_points = [Point(np.random.uniform(0.1, 0.9), np.random.uniform(0.1, 0.9)) 
               for _ in range(n_sites)]
sites_gdf = gpd.GeoDataFrame(geometry=site_points, crs=raster_stack.crs)

print(f"Raster: {raster_stack.data.shape}")
print(f"Sites: {len(sites_gdf)}")


Raster: (3, 80, 80)
Sites: 15


## Part 1: Basic Cross-Validation

Use the built-in `cross_validate()` method to evaluate model performance.


In [3]:
# Prepare training data
model = KLRfome(sigma=1.0, lambda_reg=0.1, n_rff_features=256, seed=SEED)
training_data = model.prepare_data(
    raster_stack=raster_stack,
    sites=sites_gdf,
    n_background=40,
    samples_per_location=20
)

print(f"Training data: {training_data.n_locations} locations")
print(f"  Sites: {training_data.n_sites}")
print(f"  Background: {training_data.n_background}")


Training data: 55 locations
  Sites: 15
  Background: 40
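
To make the idea of a stratified split concrete for this dataset, here is a minimal sketch that deals the members of each class round-robin into folds. It only illustrates the concept and is not KLRfome's internal splitting logic; the helper name `stratified_folds` is hypothetical.

```python
import random

def stratified_folds(labels, n_folds, seed=0):
    """Assign each index to a fold so every fold gets a similar class mix."""
    rng = random.Random(seed)
    fold_of = [None] * len(labels)
    for cls in sorted(set(labels)):
        idx = [i for i, y in enumerate(labels) if y == cls]
        rng.shuffle(idx)
        # Deal the shuffled members of this class round-robin across folds
        for pos, i in enumerate(idx):
            fold_of[i] = pos % n_folds
    return fold_of

# 15 sites (label 1) and 40 background locations (label 0), as above
labels = [1] * 15 + [0] * 40
fold_of = stratified_folds(labels, n_folds=5, seed=42)

for k in range(5):
    test_idx = [i for i, f in enumerate(fold_of) if f == k]
    n_site = sum(labels[i] for i in test_idx)
    print(f"Fold {k}: {n_site} sites, {len(test_idx) - n_site} background")
```

With 15 sites and 40 background locations, every test fold in this sketch holds 3 sites and 8 background locations, so each per-fold metric rests on very few positive examples.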


In [4]:
# Run cross-validation
# Note: cross_validate() does not accept a `seed` argument
# (passing one raises TypeError), so it is omitted here.
cv_results = model.cross_validate(
    training_data=training_data,
    n_folds=5,
    stratified=True
)

print("Cross-validation completed:")
print(f"  Number of folds: {cv_results['n_folds']}")
print(f"  Mean train size: {cv_results['mean_train_size']:.1f}")
print(f"  Mean test size: {cv_results['mean_test_size']:.1f}")
print(f"  Best fold: {cv_results['best_fold']}")


In [None]:
# Display metrics for each fold
print("\nPer-fold metrics:")
print("-" * 80)
for fold_result in cv_results['folds']:
    metrics = fold_result['metrics']
    print(f"Fold {fold_result['fold']}:")
    print(f"  Train: {fold_result['n_train']}, Test: {fold_result['n_test']}")
    print(f"  AUC: {metrics['AUC']:.3f}")
    print(f"  Accuracy: {metrics['Accuracy']:.3f}")
    print(f"  Sensitivity: {metrics['Sensitivity']:.3f}")
    print(f"  Specificity: {metrics['Specificity']:.3f}")
    print(f"  Kappa: {metrics['Kappa']:.3f}")
    print()


In [None]:
# Display aggregated metrics
print("Aggregated metrics (mean ± std across folds):")
print("-" * 80)
agg = cv_results['aggregated_metrics']
key_metrics = ['AUC', 'Accuracy', 'Sensitivity', 'Specificity', 'Kappa', 'Precision', 'F_Measure']
for metric in key_metrics:
    mean_key = f'{metric}_mean'
    std_key = f'{metric}_std'
    if mean_key in agg and std_key in agg:
        print(f"  {metric:15s}: {agg[mean_key]:.3f} ± {agg[std_key]:.3f}")
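
The `*_mean` and `*_std` keys read above can be understood as a simple per-metric summary across folds. The sketch below shows one way such a summary could be computed; the fold values are made up, the helper `aggregate` is hypothetical, and whether KLRfome uses the sample or the population standard deviation is not stated in this notebook.

```python
from statistics import mean, stdev

# Made-up per-fold values, for illustration only
fold_metrics = [
    {"AUC": 0.80, "Accuracy": 0.70},
    {"AUC": 0.90, "Accuracy": 0.80},
    {"AUC": 0.70, "Accuracy": 0.60},
]

def aggregate(per_fold):
    """Return {'<metric>_mean': ..., '<metric>_std': ...} across folds."""
    agg = {}
    for name in per_fold[0]:
        values = [m[name] for m in per_fold]
        agg[f"{name}_mean"] = mean(values)
        # Sample standard deviation (n - 1 denominator); an assumption here
        agg[f"{name}_std"] = stdev(values)
    return agg

agg_demo = aggregate(fold_metrics)
for key, value in agg_demo.items():
    print(f"{key}: {value:.3f}")
```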


## Part 2: Grid Search Over Parameter Space

Systematically search over combinations of `sigma` and `lambda_reg` to find optimal hyperparameters.
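
Before choosing `sigma` values, it helps to know which kernel parameterization is in use, because libraries differ. The sketch below contrasts two common Gaussian (RBF) conventions; which one KLRfome uses is not stated in this notebook, so both functions are assumptions for illustration.

```python
import math

def rbf_width_form(distance, sigma):
    """exp(-d^2 / (2 sigma^2)): larger sigma gives a wider, smoother kernel."""
    return math.exp(-distance**2 / (2.0 * sigma**2))

def rbf_rate_form(distance, sigma):
    """exp(-sigma * d^2): larger sigma gives a narrower, more local kernel."""
    return math.exp(-sigma * distance**2)

# Similarity between two points one unit apart, for each sigma in the grid
sigmas = [0.5, 1.0, 2.0, 3.0]
width_values = [rbf_width_form(1.0, s) for s in sigmas]
rate_values = [rbf_rate_form(1.0, s) for s in sigmas]

for s, w, r in zip(sigmas, width_values, rate_values):
    print(f"sigma={s}: width form {w:.3f}, rate form {r:.3f}")
```

The two conventions move in opposite directions as `sigma` grows, so the same grid can mean "smoother" in one library and "more local" in another.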


In [None]:
# Define parameter grid
sigma_values = [0.5, 1.0, 2.0, 3.0]
lambda_values = [0.01, 0.1, 1.0]

print(f"Grid search over {len(sigma_values)} × {len(lambda_values)} = {len(sigma_values) * len(lambda_values)} combinations")
print(f"Sigma values: {sigma_values}")
print(f"Lambda values: {lambda_values}")
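
The same grid can be enumerated with `itertools.product`, which avoids nested loops and makes the total cost easy to estimate. This is a small sketch using only the standard library; the variable name `param_grid` is illustrative.

```python
from itertools import product

sigma_values = [0.5, 1.0, 2.0, 3.0]
lambda_values = [0.01, 0.1, 1.0]

# Every (sigma, lambda_reg) pair, in the same order as two nested loops
param_grid = list(product(sigma_values, lambda_values))

n_folds = 5
print(f"{len(param_grid)} combinations")
print(f"{len(param_grid) * n_folds} model fits with {n_folds}-fold CV")
print(param_grid[:3])
```

Each combination is fitted once per fold, so 12 combinations with 5 folds means 60 model fits.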


In [None]:
# Run grid search
grid_results = []

for sigma in sigma_values:
    for lambda_reg in lambda_values:
        print(f"\nTesting sigma={sigma:.2f}, lambda={lambda_reg:.2f}...")
        
        # Create model with these parameters
        model_grid = KLRfome(
            sigma=sigma,
            lambda_reg=lambda_reg,
            n_rff_features=256,
            window_size=3,
            seed=SEED
        )
        
        # Run cross-validation
        # cross_validate() does not accept a `seed` argument
        cv_result = model_grid.cross_validate(
            training_data=training_data,
            n_folds=5,
            stratified=True
        )
        
        # Store results
        grid_results.append({
            'sigma': sigma,
            'lambda': lambda_reg,
            'mean_auc': cv_result['aggregated_metrics']['AUC_mean'],
            'std_auc': cv_result['aggregated_metrics']['AUC_std'],
            'mean_accuracy': cv_result['aggregated_metrics']['Accuracy_mean'],
            'std_accuracy': cv_result['aggregated_metrics']['Accuracy_std'],
            'mean_kappa': cv_result['aggregated_metrics']['Kappa_mean'],
        })
        
        print(f"  Mean AUC: {cv_result['aggregated_metrics']['AUC_mean']:.3f} ± {cv_result['aggregated_metrics']['AUC_std']:.3f}")

print(f"\n✓ Grid search completed!")


In [None]:
# Convert to DataFrame for easier analysis
results_df = pd.DataFrame(grid_results)
print("\nGrid search results:")
print(results_df.to_string(index=False))


### Visualize Grid Search Results


In [None]:
# Create heatmap of AUC
pivot_auc = results_df.pivot(index='sigma', columns='lambda', values='mean_auc')

fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(pivot_auc, annot=True, fmt='.3f', cmap='viridis', ax=ax, cbar_kws={'label': 'Mean AUC'})
ax.set_title('Cross-Validation AUC by Hyperparameters')
ax.set_xlabel('Lambda (Regularization)')
ax.set_ylabel('Sigma (Kernel Bandwidth)')
plt.tight_layout()
plt.show()


In [None]:
# Find best parameters
best_idx = results_df['mean_auc'].idxmax()
best_params = results_df.loc[best_idx]

print("Best hyperparameters:")
print(f"  Sigma: {best_params['sigma']:.2f}")
print(f"  Lambda: {best_params['lambda']:.2f}")
print(f"  Mean AUC: {best_params['mean_auc']:.3f} ± {best_params['std_auc']:.3f}")
print(f"  Mean Accuracy: {best_params['mean_accuracy']:.3f}")
print(f"  Mean Kappa: {best_params['mean_kappa']:.3f}")
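
`idxmax()` returns the first row with the maximum value, so ties in mean AUC are broken by grid order. With small folds, ties and near-ties are likely, and one option is to prefer the more stable configuration. The sketch below uses made-up results and a tie-breaking rule chosen purely for illustration.

```python
# Made-up grid results, for illustration only
rows = [
    {"sigma": 0.5, "lambda": 0.1, "mean_auc": 0.81, "std_auc": 0.09},
    {"sigma": 1.0, "lambda": 0.1, "mean_auc": 0.84, "std_auc": 0.12},
    {"sigma": 2.0, "lambda": 1.0, "mean_auc": 0.84, "std_auc": 0.05},
]

# Highest mean AUC wins; among equal means, prefer the smaller std
best = max(rows, key=lambda r: (r["mean_auc"], -r["std_auc"]))
print(f"sigma={best['sigma']}, lambda={best['lambda']}")
```

Here the second and third rows tie on mean AUC, and the rule picks the third because its fold-to-fold spread is smaller.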


## Part 3: Fit Final Model with Best Parameters

Train the final model using the best hyperparameters found via grid search.


In [None]:
# Fit final model with best parameters
final_model = KLRfome(
    sigma=float(best_params['sigma']),        # cast pandas/NumPy scalar to a plain float
    lambda_reg=float(best_params['lambda']),
    n_rff_features=256,
    window_size=3,
    seed=SEED
)

final_model.fit(training_data)

print(f"Final model fitted:")
print(f"  Converged: {final_model._fit_result.converged}")
print(f"  Iterations: {final_model._fit_result.n_iterations}")
print(f"  Final loss: {final_model._fit_result.final_loss:.6f}")

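# Fit a model with the default parameters for reference.
# Note: training loss is not a fair basis for comparing models with
# different lambda_reg values; use the cross-validated metrics for that.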
default_model = KLRfome(sigma=1.0, lambda_reg=0.1, n_rff_features=256, window_size=3, seed=SEED)
default_model.fit(training_data)

print(f"\nDefault model (sigma=1.0, lambda=0.1):")
print(f"  Converged: {default_model._fit_result.converged}")
print(f"  Iterations: {default_model._fit_result.n_iterations}")
print(f"  Final loss: {default_model._fit_result.final_loss:.6f}")
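
The final loss values printed above are hard to compare across different `lambda_reg` settings. The sketch below shows why, using one common formulation of a penalized logistic loss. The exact objective KLRfome minimizes is not shown in this notebook, so the function `penalized_log_loss` and its form are assumptions.

```python
import math

def penalized_log_loss(weights, features, labels, lambda_reg):
    """Mean logistic loss plus an L2 penalty (one common formulation)."""
    data_term = 0.0
    for x, y in zip(features, labels):
        z = sum(w * xi for w, xi in zip(weights, x))
        p = 1.0 / (1.0 + math.exp(-z))
        data_term -= y * math.log(p) + (1 - y) * math.log(1 - p)
    data_term /= len(labels)
    penalty = 0.5 * lambda_reg * sum(w * w for w in weights)
    return data_term + penalty

# Tiny made-up dataset and a fixed weight vector
features = [[0.5, 0.1], [0.2, 0.4], [0.9, 0.3]]
labels = [1, 0, 1]
weights = [1.0, -2.0]

loss_small = penalized_log_loss(weights, features, labels, lambda_reg=0.01)
loss_large = penalized_log_loss(weights, features, labels, lambda_reg=1.0)
print(f"lambda=0.01: {loss_small:.3f}")
print(f"lambda=1.00: {loss_large:.3f}")
```

The weights and the data are identical in both calls, yet the loss differs because the penalty term scales with `lambda_reg`. Under this formulation a lower final loss does not by itself indicate a better model.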


## Part 4: Understanding Validation Metrics

KLRfome provides comprehensive metrics for model evaluation. Let's explore the key metrics.


In [None]:
from klrfome.utils.validation import metrics, CM_quads

# Get metrics from best fold
best_fold_result = next(f for f in cv_results['folds'] if f['fold'] == cv_results['best_fold'])
best_metrics = best_fold_result['metrics']

print("Key Metrics Explained:")
print("=" * 80)
print(f"AUC (Area Under ROC Curve): {best_metrics['AUC']:.3f}")
print("  - Measures ability to distinguish between classes")
print("  - Range: 0 to 1; 0.5 is chance level, 1.0 is perfect")
print("  - Rule of thumb: >0.7 acceptable, >0.8 good, >0.9 excellent")
print()

print(f"Accuracy: {best_metrics['Accuracy']:.3f}")
print("  - Overall proportion of correct predictions")
print()

print(f"Sensitivity (Recall, TPR): {best_metrics['Sensitivity']:.3f}")
print("  - Proportion of actual sites correctly identified")
print("  - True Positive Rate")
print()

print(f"Specificity (TNR): {best_metrics['Specificity']:.3f}")
print("  - Proportion of actual background correctly identified")
print("  - True Negative Rate")
print()

print(f"Precision (PPV): {best_metrics['Precision']:.3f}")
print("  - Proportion of predicted sites that are actually sites")
print()

print(f"Kappa: {best_metrics['Kappa']:.3f}")
print("  - Agreement between predictions and observations")
print("  - Accounts for chance agreement")
print("  - Range: -1 to 1, >0.6: good agreement")
print()

print(f"F-Measure: {best_metrics['F_Measure']:.3f}")
print("  - Harmonic mean of precision and recall")
print("  - Balances precision and sensitivity")
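
The threshold-based metrics described above all derive from the four confusion-matrix counts. The sketch below computes them from the standard textbook definitions; it is not the code in `klrfome.utils.validation`, the counts are made up, and the helper `metrics_from_counts` is hypothetical.

```python
def metrics_from_counts(tp, fp, tn, fn):
    """Standard classification metrics from confusion-matrix counts.

    No zero-division guards: assumes each denominator is non-zero.
    """
    n = tp + fp + tn + fn
    accuracy = (tp + tn) / n
    sensitivity = tp / (tp + fn)   # recall, true positive rate
    specificity = tn / (tn + fp)   # true negative rate
    precision = tp / (tp + fp)     # positive predictive value
    f_measure = 2 * precision * sensitivity / (precision + sensitivity)
    # Agreement expected by chance, from the row and column totals
    expected = ((tp + fp) * (tp + fn) + (fn + tn) * (fp + tn)) / n**2
    kappa = (accuracy - expected) / (1 - expected)
    return {
        "Accuracy": accuracy,
        "Sensitivity": sensitivity,
        "Specificity": specificity,
        "Precision": precision,
        "F_Measure": f_measure,
        "Kappa": kappa,
    }

# Made-up counts for a test fold with 3 sites and 8 background locations
demo = metrics_from_counts(tp=2, fp=2, tn=6, fn=1)
for name, value in demo.items():
    print(f"{name:12s}: {value:.3f}")
```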


### Confusion Matrix

The confusion matrix shows the breakdown of predictions vs. observations.


In [None]:
# Display confusion matrix for best fold
print(f"Confusion Matrix (Fold {cv_results['best_fold']}):")
print(f"  True Positives (TP):  {best_fold_result['TP']}")
print(f"  False Positives (FP): {best_fold_result['FP']}")
print(f"  True Negatives (TN):  {best_fold_result['TN']}")
print(f"  False Negatives (FN): {best_fold_result['FN']}")
print()

# Visualize confusion matrix (rows = actual class, columns = predicted class)
cm_data = np.array([
    [best_fold_result['TP'], best_fold_result['FN']],  # actual sites
    [best_fold_result['FP'], best_fold_result['TN']]   # actual background
])

fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(cm_data, annot=True, fmt='d', cmap='Blues', ax=ax,
            xticklabels=['Predicted Site', 'Predicted Background'],
            yticklabels=['Actual Site', 'Actual Background'])
ax.set_title('Confusion Matrix')
plt.tight_layout()
plt.show()
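
For reference, the four counts can be derived from predicted probabilities and true labels by applying a decision threshold. The sketch below uses made-up values and a threshold of 0.5; the threshold KLRfome applies internally is not stated in this notebook, and the helper `confusion_counts` is hypothetical.

```python
def confusion_counts(y_true, y_prob, threshold=0.5):
    """Count TP, FP, TN, FN for binary labels at a given threshold."""
    counts = {"TP": 0, "FP": 0, "TN": 0, "FN": 0}
    for y, p in zip(y_true, y_prob):
        predicted_site = p >= threshold
        if y == 1:
            counts["TP" if predicted_site else "FN"] += 1
        else:
            counts["FP" if predicted_site else "TN"] += 1
    return counts

# Made-up labels (1 = site, 0 = background) and predicted probabilities
y_true = [1, 1, 1, 0, 0, 0, 0, 0]
y_prob = [0.9, 0.6, 0.3, 0.7, 0.4, 0.2, 0.1, 0.05]

c = confusion_counts(y_true, y_prob)
# Rows are actual classes, columns are predicted classes
matrix = [
    [c["TP"], c["FN"]],  # actual site
    [c["FP"], c["TN"]],  # actual background
]
print(c)
print(matrix)
```

Changing `threshold` shifts counts between the two columns of each row, which is why the threshold-based metrics depend on it while AUC does not.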


## Summary

This tutorial demonstrated:

1. **Cross-Validation**: Use `model.cross_validate()` to evaluate model performance
2. **Grid Search**: Systematically search parameter space to find optimal hyperparameters
3. **Model Selection**: Choose best parameters based on CV metrics (typically AUC)
4. **Metrics Interpretation**: Understand key validation metrics

### Key Takeaways:

- **AUC** is often a good metric for model selection because it is threshold-independent and less sensitive to class imbalance than accuracy
- **Sigma** controls kernel bandwidth (smoothness of predictions)
- **Lambda** controls regularization (prevents overfitting)
- Use **stratified CV** to maintain class balance across folds
- **5-fold CV** is a common default; with only 15 sites, each stratified test fold here holds about 3 sites, so expect noisy per-fold metrics
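
The first takeaway rests on AUC being a ranking measure: it equals the probability that a randomly chosen site scores higher than a randomly chosen background location. The sketch below computes it by direct pairwise comparison on made-up scores; it is a didactic implementation, not the one KLRfome uses, and the helper `auc_from_scores` is hypothetical.

```python
def auc_from_scores(site_scores, background_scores):
    """AUC as the fraction of (site, background) pairs ranked correctly."""
    wins = 0.0
    for s in site_scores:
        for b in background_scores:
            if s > b:
                wins += 1.0
            elif s == b:
                wins += 0.5  # ties count as half
    return wins / (len(site_scores) * len(background_scores))

# Made-up model scores
site_scores = [0.9, 0.8, 0.4]
background_scores = [0.7, 0.3, 0.2, 0.1]

auc_balanced = auc_from_scores(site_scores, background_scores)
# Ten times more background with the same score distribution
auc_imbalanced = auc_from_scores(site_scores, background_scores * 10)

print(f"AUC: {auc_balanced:.4f}")
print(f"AUC with 10x background: {auc_imbalanced:.4f}")
```

Both calls give the same value because replicating the background scores changes the class ratio but not the ranking, and no decision threshold is involved.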

### Next Steps:

- Use the best parameters to fit your final model
- Generate predictions on your full dataset
- Validate on independent test data if available
