# Grid Search

This quick guide shows how to use grid search to find the best hyperparameters for each group-level model in ``ForecastFlowML``.

## Import packages

In [5]:
from forecastflowml.meta_model import ForecastFlowML
from forecastflowml.preprocessing import FeatureExtractor
from forecastflowml.data.loader import load_walmart_m5
from lightgbm import LGBMRegressor
from pyspark.sql import SparkSession
import pyspark.sql.functions as F

## Initialize Spark

In [6]:
# Local Spark session for this walkthrough: 4 local cores, 8 GB of driver
# memory and 4 shuffle partitions (kept small to match the local core count).
spark = (
    SparkSession.builder.master("local[4]")
    .config("spark.driver.memory", "8g")
    .config("spark.sql.shuffle.partitions", "4")
    # Enable Arrow for JVM <-> Python data transfer. Since Spark 3.0 this is the
    # supported key; the old "spark.sql.execution.arrow.enabled" is deprecated.
    .config("spark.sql.execution.arrow.pyspark.enabled", "true")
    .getOrCreate()
)

## Sample Dataset

In [7]:
# Load the Walmart M5 sample bundled with forecastflowml and preview a few rows.
df = load_walmart_m5(spark)
df.show(n=10)

+--------------------+-----------+-------+------+--------+--------+----------+-----+
|                  id|    item_id|dept_id|cat_id|store_id|state_id|      date|sales|
+--------------------+-----------+-------+------+--------+--------+----------+-----+
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-01-29|  2.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-01-30|  5.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-01-31|  3.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-02-01|  0.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-02-02|  0.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-02-03|  0.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-02-04|  0.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-02-05|  1.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      T

## Feature Engineering

In [8]:
# Weekly lags of the target: 7, 14, 21 and 28 days back.
weekly_lags = [7 * week for week in range(1, 5)]

feature_extractor = FeatureExtractor(
    id_col="id",
    date_col="date",
    target_col="sales",
    lag_window_features={"lag": weekly_lags},
)

# localCheckpoint() materialises the features and truncates the logical plan,
# so downstream cells reuse them instead of replaying the lag computation.
df_features = feature_extractor.transform(df).localCheckpoint()
df_features.show(n=10)

+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+
|                  id|    item_id|dept_id|cat_id|store_id|state_id|      date|sales|lag_7|lag_14|lag_21|lag_28|
+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-01-31|  2.0| null|  null|  null|  null|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-01|  0.0| null|  null|  null|  null|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-02|  0.0| null|  null|  null|  null|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-03|  0.0| null|  null|  null|  null|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-04|  0.0| null|  null|  null|  null|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-05|  0.0| null|  null|  null|

## Train/Test Dataset

In [9]:
# Rows strictly before the split date are used for training; rows on or after
# it form the test set.
split_date = "2016-04-25"
df_train = df_features.filter(F.col("date") < split_date)
df_test = df_features.filter(F.col("date") >= split_date)

## Initialize Model

In [10]:
# Baseline meta-model: a default LGBMRegressor for every store group, with
# models covering 7-day chunks of a 28-day daily forecast horizon.
forecast_flow = ForecastFlowML(
    model=LGBMRegressor(),
    group_col="store_id",
    id_col="id",
    date_col="date",
    target_col="sales",
    date_frequency="days",
    model_horizon=7,
    max_forecast_horizon=28,
)

## Search Hyperparameters with Grid Search

In [11]:
# Candidate hyperparameters; each combination is scored separately per group.
param_grid = {
    "n_estimators": [50, 100],
    "num_leaves": [20, 30],
}

# Negated MSE is used as the score, so higher (closer to zero) is better.
trials = forecast_flow.grid_search(
    df_train,
    param_grid=param_grid,
    n_cv_splits=3,
    scoring_metric="neg_mean_squared_error",
)
trials



Unnamed: 0,group,score,n_estimators,num_leaves
0,WI_3,-7.672328,100,20
1,WI_3,-7.672407,100,30
2,WI_3,-7.679507,50,20
3,WI_3,-7.680971,50,30
4,WI_2,-15.551556,100,30
5,WI_2,-15.552167,100,20
6,WI_2,-15.562159,50,30
7,WI_2,-15.563983,50,20
8,WI_1,-3.089326,50,20
9,WI_1,-3.089694,50,30


In [12]:
# Keep, for every store group, the single trial with the highest score. The
# metric is a negated error ("neg_mean_squared_error"), so larger is better.
# groupby(...)["score"].idxmax() returns the row label of each group's best
# trial (first occurrence on ties, so the choice is deterministic). It avoids
# DataFrameGroupBy.apply, which pandas >= 2.2 warns about because the applied
# function also receives the grouping column. Rows come back ordered by group.
best_trial = trials.loc[trials.groupby("group")["score"].idxmax()]

# Hyperparameter columns to carry over; these must match the param_grid keys
# passed to grid_search.
best_params = best_trial.set_index("group")[["n_estimators", "num_leaves"]].to_dict(
    orient="index"
)
best_params

{'CA_1': {'n_estimators': 100, 'num_leaves': 30},
 'CA_2': {'n_estimators': 50, 'num_leaves': 20},
 'CA_3': {'n_estimators': 50, 'num_leaves': 20},
 'CA_4': {'n_estimators': 50, 'num_leaves': 20},
 'TX_1': {'n_estimators': 100, 'num_leaves': 30},
 'TX_2': {'n_estimators': 50, 'num_leaves': 20},
 'TX_3': {'n_estimators': 50, 'num_leaves': 20},
 'WI_1': {'n_estimators': 50, 'num_leaves': 20},
 'WI_2': {'n_estimators': 100, 'num_leaves': 30},
 'WI_3': {'n_estimators': 100, 'num_leaves': 20}}

In [16]:
# One LGBMRegressor per store group, configured with that group's best
# hyperparameters.
group_models = {
    group: LGBMRegressor(**params) for group, params in best_params.items()
}
group_models

{'CA_1': LGBMRegressor(num_leaves=30),
 'CA_2': LGBMRegressor(n_estimators=50, num_leaves=20),
 'CA_3': LGBMRegressor(n_estimators=50, num_leaves=20),
 'CA_4': LGBMRegressor(n_estimators=50, num_leaves=20),
 'TX_1': LGBMRegressor(num_leaves=30),
 'TX_2': LGBMRegressor(n_estimators=50, num_leaves=20),
 'TX_3': LGBMRegressor(n_estimators=50, num_leaves=20),
 'WI_1': LGBMRegressor(n_estimators=50, num_leaves=20),
 'WI_2': LGBMRegressor(num_leaves=30),
 'WI_3': LGBMRegressor(num_leaves=20)}

## Train ForecastFlowML with Optimized Hyperparameters 

In [17]:
# Same meta-model setup as the baseline, but `model` now receives the
# group -> tuned estimator dict held in `group_models`.
forecast_flow = ForecastFlowML(
    model=group_models,
    group_col="store_id",
    id_col="id",
    date_col="date",
    target_col="sales",
    date_frequency="days",
    model_horizon=7,
    max_forecast_horizon=28,
)

In [18]:
# Train the per-group models; the returned DataFrame summarises each group's
# run (forecast horizons, serialized models and timings).
training_summary = forecast_flow.train(df_train)
training_summary.show()

+-----+--------------------+--------------------+--------------------+--------------------+---------------+
|group|    forecast_horizon|               model|          start_time|            end_time|elapsed_seconds|
+-----+--------------------+--------------------+--------------------+--------------------+---------------+
| CA_2|[[1, 2, 3, 4, 5, ...|[�clightgbm.skle...|11-Apr-2023 (03:4...|11-Apr-2023 (03:4...|            0.4|
| CA_3|[[1, 2, 3, 4, 5, ...|[�clightgbm.skle...|11-Apr-2023 (03:4...|11-Apr-2023 (03:4...|            0.5|
| WI_2|[[1, 2, 3, 4, 5, ...|[�clightgbm.skle...|11-Apr-2023 (03:4...|11-Apr-2023 (03:4...|            2.6|
| WI_3|[[1, 2, 3, 4, 5, ...|[�clightgbm.skle...|11-Apr-2023 (03:4...|11-Apr-2023 (03:4...|            0.8|
| CA_1|[[1, 2, 3, 4, 5, ...|[�clightgbm.skle...|11-Apr-2023 (03:4...|11-Apr-2023 (03:4...|            5.3|
| CA_4|[[1, 2, 3, 4, 5, ...|[�clightgbm.skle...|11-Apr-2023 (03:4...|11-Apr-2023 (03:4...|            0.6|
| TX_1|[[1, 2, 3, 4, 5, ...|