In [None]:
# Define the hyperparameter grid
# Candidate values for each tuned XGBoost parameter (3 * 2 * 3 * 1 * 1 = 18
# combinations in total). Single-element lists pin a parameter to one value
# while still passing it through the grid.
hyper_params = dict(
    ntrees=[400, 600, 1000],        # number of trees
    max_depth=[8, 10],              # maximum tree depth
    learn_rate=[0.01, 0.05, 0.1],   # learning rate
    col_sample_rate=[0.8],          # column sample rate
    min_rows=[15],                  # minimum number of rows
)

# Search criteria for the grid.
# "Cartesian" is the exhaustive strategy. H2O's documented alternative is
# "RandomDiscrete"; when switching to it, also add e.g. 'max_models': 20 to
# cap the number of models and 'seed': 1234 for reproducibility.
search_criteria = dict(strategy="Cartesian")

# Base XGBoost regressor; the grid search overrides the tuned parameters
# on top of these fixed settings.
xgb_estimator = h2o.estimators.H2OXGBoostEstimator(
    # Problem definition
    distribution='gaussian',                 # Gaussian distribution, i.e. regression
    # Cross-validation setup
    nfolds=5,                                # 5 cross-validation folds
    fold_assignment='Modulo',                # fold assignment scheme
    keep_cross_validation_predictions=True,  # retain the CV holdout predictions
    # Scoring and early stopping
    score_tree_interval=10,                  # tree interval between scoring passes
    stopping_rounds=10,                      # early stopping when scoring stops improving
    # Reproducibility
    seed=1234,
)

# Tie the base estimator to the hyperparameter grid and the search settings
grid = H2OGridSearch(model=xgb_estimator, hyper_params=hyper_params, search_criteria=search_criteria)

# Fit the grid candidates: predictors `features_lat`, response
# `target_lat_sin`, with a separate validation frame passed alongside the
# training frame.
grid.train(x=features_lat, y=target_lat_sin,
           training_frame=train_data_shifted_without_validation,
           validation_frame=validation_data_shifted)

# Display all models in the grid
print(grid)

# Rank the grid models by MSE, lowest (best) first.
# `get_grid(sort_by=...)` expects a bare metric name such as 'mse'; the
# previously used 'validation_mse' is not one of the documented choices.
# NOTE(review): with nfolds > 1 H2O is expected to rank on the
# cross-validation metrics rather than the validation_frame metrics -- confirm.
# If a strict validation-frame ranking is required, order `grid.models` by
# `m.mse(valid=True)` instead.
sorted_grid = grid.get_grid(
    sort_by='mse',      # metric name understood by H2O for regression models
    decreasing=False,   # ascending order: lower MSE is better
)

print("Sorted Grid Models by MSE:")
print(sorted_grid)

# The first model of the sorted grid is the top-ranked one
best_model = sorted_grid.models[0]
print("Best Model:", best_model)

# Report the best model's metrics on the validation frame supplied at training
performance = best_model.model_performance(valid=True)
print("Performance on Validation Set:")
print(performance)
