
This notebook runs a `RandomizedSearchCV` for each stable model in the benchmark to find well-performing hyperparameters on each tuning dataset, and prints the best parameters as JSON.

---


In [1]:

import json
import gc
import time
import numpy as np
import polars as pl
from sklearn.model_selection import RandomizedSearchCV, train_test_split
from sklearn.pipeline import Pipeline
from sklearn.ensemble import RandomForestRegressor
from sklearn.neighbors import KNeighborsRegressor
import xgboost

# Custom wrappers/utils
from utils.migbt import SklearnMIXGBooster
from utils.kriging_wrapper import PyKrigeWrapper
from utils.gam_wrapper import PyGAMWrapper
from utils.functions import AddCoordinatesRotation, ConvertToPandas
from utils.s3 import get_df_from_s3

---
# Configuration
---
## Models

In [2]:
# --- Search budget -----------------------------------------------------------
# N_ITER is the number of parameter settings sampled per model; adjust it to
# your time constraints (10-20 is usually enough for a randomized search).
N_ITER = 15
CV_FOLDS = 3       # passed as `cv` to RandomizedSearchCV
RANDOM_STATE = 42  # passed as `random_state` to RandomizedSearchCV

# --- Per-model search spaces --------------------------------------------------
# Keys carry the "ml_model__" prefix because the estimator is tuned as the
# "ml_model" step of a Pipeline.
_RANDOM_FOREST_SPACE = {
    "ml_model__n_estimators": [100, 250, 500],
    "ml_model__max_features": ["sqrt", 1.0],
    "ml_model__min_samples_leaf": [1, 5, 10, 20],
}

_XGBOOST_SPACE = {
    "ml_model__n_estimators": [200, 500, 800],
    "ml_model__learning_rate": [0.01, 0.05, 0.1],
    "ml_model__max_depth": [6, 8, 12],
    "ml_model__subsample": [0.7, 0.8, 1.0],
}

_MIXGBOOST_SPACE = {
    "ml_model__k": [10, 20, 40],
    "ml_model__lamb": [0.001, 0.01, 0.05, 0.1],
    "ml_model__n_estimators": [100, 300],
    "ml_model__max_depth": [6, 10, 14],
}

_KRIGING_SPACE = {
    "ml_model__variogram_model": ["exponential", "spherical", "gaussian", "linear"],
    "ml_model__nlags": [6, 10, 20],
    "ml_model__weight": [True, False],
}

_GAM_SPACE = {
    "ml_model__n_splines": [15, 25, 40, 60],
    "ml_model__lam": [0.1, 0.6, 1.5, 5.0],
    "ml_model__spline_order": [2, 3],
}

# --- Model registry -----------------------------------------------------------
# One entry per tunable model: a display name, the estimator class (instantiated
# with default arguments at tuning time) and its search space.
TUNABLE_MODELS = [
    {"name": "random_forest", "class": RandomForestRegressor, "search_space": _RANDOM_FOREST_SPACE},
    {"name": "xgboost", "class": xgboost.XGBRegressor, "search_space": _XGBOOST_SPACE},
    {"name": "mixgboost", "class": SklearnMIXGBooster, "search_space": _MIXGBOOST_SPACE},
    {"name": "kriging", "class": PyKrigeWrapper, "search_space": _KRIGING_SPACE},
    {"name": "gam", "class": PyGAMWrapper, "search_space": _GAM_SPACE},
]

## Datasets

In [3]:
# Datasets to tune on, one record per dataset: (name, parquet path, row cap).
# The row cap "n" is forwarded to get_data(), which keeps the first n rows.
DATASETS_TO_TUNE = [
    {"name": ds_name, "path": ds_path, "n": ds_rows}
    for ds_name, ds_path, ds_rows in (
        ("S-G-Sm", "s3://projet-benchmark-spatial-interpolation/data/synthetic/S-G-Sm.parquet", 5000),
        ("S-G-Lg", "s3://projet-benchmark-spatial-interpolation/data/synthetic/S-G-Lg.parquet", 100000),
    )
]

# %%
def get_data(path, n):
    """Load the first ``n`` rows of a dataset and split them into features and target.

    Parameters
    ----------
    path : str
        Location passed unchanged to ``get_df_from_s3``.
    n : int
        Number of leading rows to keep (``head(n)``); this is not a random sample.

    Returns
    -------
    X
        The ``x`` and ``y`` coordinate columns of the collected frame
        (presumably a polars DataFrame -- verify against ``get_df_from_s3``).
    y : numpy.ndarray
        Flattened array holding the target column.

    Raises
    ------
    ValueError
        If the frame contains no column other than ``x`` and ``y``.
    """
    print(f"  Fetching data from {path}...")
    ldf = get_df_from_s3(path)
    # Truncate before collecting so only the first n rows are materialised.
    df = ldf.head(n).collect()
    X = df.select(["x", "y"])

    # The target is the FIRST column that is not a coordinate.
    candidate_targets = [c for c in df.columns if c not in ("x", "y")]
    if not candidate_targets:
        raise ValueError(
            f"No target column found in {path!r}: columns are {list(df.columns)}"
        )
    target_col = candidate_targets[0]

    y = df.select(pl.col(target_col)).to_numpy().ravel()
    return X, y

---
# Execution

---




In [6]:
# Kriging-only safeguards: cap on the number of training points and standard
# deviation of the coordinate jitter added to them.
KRIGING_MAX_SAMPLES = 2500
KRIGING_JITTER_STD = 1e-9

# Seeded generator so the Kriging subsample and jitter are identical on every
# Restart & Run All (the global np.random state is never seeded in this notebook).
rng = np.random.default_rng(RANDOM_STATE)

# Best parameters found, keyed as {dataset name: {model name: {param: value}}}.
results_to_export = {}

for ds_cfg in DATASETS_TO_TUNE:
    print(f"\n{'='*60}\nSTARTING TUNING FOR: {ds_cfg['name']}\n{'='*60}")

    # Load base data
    X_full, y_full = get_data(ds_cfg["path"], ds_cfg["n"])
    results_to_export[ds_cfg['name']] = {}

    for m_cfg in TUNABLE_MODELS:
        print(f"  > Optimizing {m_cfg['name']}...")

        # SPECIAL CASE: Limit Kriging to avoid O(N^3) crashes and singular matrices
        if m_cfg['name'] == 'kriging':
            # 1. Draw a reproducible subsample of row positions (without replacement).
            n_keep = min(len(X_full), KRIGING_MAX_SAMPLES)
            idx = rng.choice(len(X_full), n_keep, replace=False)
            X_tune = X_full[idx]  # row selection by integer positions
            y_tune = y_full[idx]  # y_full is already a numpy array from get_data

            # 2. Add jitter to both coordinates to prevent singular matrices
            X_tune = X_tune.with_columns([
                (pl.col("x") + rng.normal(0, KRIGING_JITTER_STD, len(X_tune))),
                (pl.col("y") + rng.normal(0, KRIGING_JITTER_STD, len(X_tune)))
            ])

            current_n_jobs = 1  # Sequential to prevent RAM explosion
        else:
            X_tune, y_tune = X_full, y_full
            current_n_jobs = -1  # Parallelize other models

        # The estimator is built with default arguments; the search only varies
        # the "ml_model__*" parameters listed in the model's search space.
        pipeline = Pipeline([
            ("coord_rotation", AddCoordinatesRotation(coordinates_names=("x", "y"), number_axis=1)),
            ("pandas_converter", ConvertToPandas()),
            ("ml_model", m_cfg["class"]())
        ])

        search = RandomizedSearchCV(
            estimator=pipeline,
            param_distributions=m_cfg["search_space"],
            n_iter=N_ITER,
            cv=CV_FOLDS,
            scoring='r2',
            n_jobs=current_n_jobs,
            random_state=RANDOM_STATE
        )

        # Best effort: a failing model is reported and skipped so the remaining
        # models and datasets are still tuned. A skipped model has no entry in
        # results_to_export.
        try:
            search.fit(X_tune, y_tune)
            # Strip only the leading "<step>__" prefix; maxsplit=1 keeps any
            # further "__" inside a nested parameter name intact.
            clean_params = {k.split("__", 1)[1]: v for k, v in search.best_params_.items()}
            results_to_export[ds_cfg['name']][m_cfg['name']] = clean_params
            print(f"    Done. Best R2: {search.best_score_:.4f}")
        except Exception as e:
            print(f"    [SKIP] {m_cfg['name']} failed: {e}")

    # Drop every name that still references this dataset (X_tune/y_tune alias
    # X_full/y_full for non-Kriging models, and the fitted search keeps its own
    # references); otherwise gc.collect() has nothing to reclaim.
    X_tune = y_tune = pipeline = search = None
    del X_full, y_full
    gc.collect()


print("\n" + "#" * 30)
print("FINAL TUNED PARAMETERS JSON")
print("#" * 30 + "\n")
print(json.dumps(results_to_export, indent=4))


STARTING TUNING FOR: S-G-Sm
  Fetching data from s3://projet-benchmark-spatial-interpolation/data/synthetic/S-G-Sm.parquet...
  > Optimizing random_forest...
    Done. Best R2: 0.4128
  > Optimizing xgboost...
    Done. Best R2: 0.3864
  > Optimizing mixgboost...
  [MI-GBT] Compute W (k=10)...   [MI-GBT] Compute W (k=10)...   [MI-GBT] Compute W (k=10)...   [MI-GBT] Compute W (k=40)...   [MI-GBT] Compute W (k=40)...   [MI-GBT] Compute W (k=40)...   [MI-GBT] Compute W (k=10)...   [MI-GBT] Compute W (k=10)...   [MI-GBT] Compute W (k=10)...   [MI-GBT] Compute W (k=10)...   [MI-GBT] Compute W (k=10)...   [MI-GBT] Compute W (k=10)...   [MI-GBT] Compute W (k=20)...   [MI-GBT] Compute W (k=20)...   [MI-GBT] Compute W (k=20)...   [MI-GBT] Compute W (k=40)...   [MI-GBT] Compute W (k=40)...   [MI-GBT] Compute W (k=40)...   [MI-GBT] Compute W (k=10)...   [MI-GBT] Compute W (k=10)...   [MI-GBT] Compute W (k=10)...   [MI-GBT] Compute W (k=20)...   [MI-GBT] Compute W (k=20)...   [MI-GBT] Compute W (