In [1]:
from src.grid_search.config import GridSearchConfig
from src.grid_search.core import run_grid_search
from src.grid_search.main import main
import pandas as pd
from src.grid_search.utils import load_data
from pathlib import Path

from src.grid_search.utils import create_empty_results, generate_model_combinations, split_predictions_for_grid_search
from src.splitter import TimeSeriesBacktest
from src.forecast.segmented import SegmentedForecastModel

from src.new_forecast.models.arima import ArimaModel
from src.grid_search.utils import load_data



In [2]:
output_dir = 'dataset'
feature_file_name = "feature_df"
grid_search_file_name = "grid_search_results"
date_col = "forecast_month"
segment_col = "eom_pattern_primary"

feature_df = load_data(f"{output_dir}/{feature_file_name}.csv", date_col, segment_col)
data = feature_df

[32m2025-10-06 10:57:50.245[0m | [1mINFO    [0m | [36msrc.grid_search.utils[0m:[36mload_data[0m:[36m28[0m - [1mLoading data from dataset/feature_df.csv[0m
[32m2025-10-06 10:57:50.253[0m | [1mINFO    [0m | [36msrc.grid_search.utils[0m:[36mload_data[0m:[36m33[0m - [1mLoaded 726 rows with 7 segments[0m
[32m2025-10-06 10:57:50.253[0m | [1mINFO    [0m | [36msrc.grid_search.utils[0m:[36mload_data[0m:[36m34[0m - [1mDate range: 2023-01-01 00:00:00 to 2025-09-01 00:00:00[0m


In [3]:
feature_df

Unnamed: 0,dim_value,forecast_month,year,month_num,target_eom_amount,overall_importance_tier,eom_importance_tier,overall_importance_score,eom_importance_score,eom_risk_flag,...,raw_pm__eom_cv,raw_pm__monthly_cv,raw_pm__transaction_regularity,raw_pm__activity_rate,raw_pm__quarter_end_concentration,raw_pm__year_end_concentration,raw_pm__transaction_dispersion,raw_pm__has_eom_history,raw_pm__months_inactive,raw_pm__eom_periodicity
0,series_011,2024-11-01,2024,11,105.277692,NONE,MEDIUM,0.35071,0.49897,False,...,0.922591,0.199964,0.216814,0.366897,0.075569,0.508820,3.892909,True,2.0,0.963683
1,series_011,2023-12-01,2023,12,108.888773,NONE,MEDIUM,0.35796,0.52020,False,...,0.716601,0.017124,0.215836,0.384466,0.033778,0.573006,3.308501,True,1.0,0.564584
2,series_011,2024-12-01,2024,12,127.162087,NONE,MEDIUM,0.37565,0.59231,False,...,0.834834,0.086359,0.267368,0.420463,0.034566,0.590754,4.569974,True,2.0,0.079149
3,series_011,2024-07-01,2024,7,98.181968,NONE,MEDIUM,0.33320,0.53995,False,...,0.643883,0.037668,0.236763,0.364848,0.000000,0.645877,4.147771,True,3.0,0.057766
4,series_011,2025-04-01,2025,4,174.163281,NONE,MEDIUM,0.40839,0.54398,False,...,0.915389,0.022872,0.226845,0.447852,0.161968,0.637963,3.473558,True,-1.0,0.115853
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
721,others::OUT,2025-05-01,2025,5,45547.569702,,,,,,...,,,,,,,,,,
722,others::OUT,2025-06-01,2025,6,41291.877834,,,,,,...,,,,,,,,,,
723,others::OUT,2025-07-01,2025,7,48222.943785,,,,,,...,,,,,,,,,,
724,others::OUT,2025-08-01,2025,8,7968.221146,,,,,,...,,,,,,,,,,


In [4]:
model_configs = {
    # "arima": [
    #     {"type": "arima", "params": {"order": (1, 1, 1)}},
    #     {"type": "arima", "params": {"order": (2, 1, 1)}},
    # ],
    "moving_average": [
        {"type": "moving_average", "params": {"window": 3}},
        {"type": "moving_average", "params": {"window": 6}},
    ],
    "null": [
        {"type": "null", "params": {}},
    ],
}

model_segment_mapping = {
    "null": ["RARE_STALE", "NO_EOM", "EMERGING", "AGGREGATED_OTHERS"],
}

config = GridSearchConfig(
    feature_file_path="outputs/feature_df.csv",
    segment_col="eom_pattern_primary",
    target_col="target_eom_amount",
    date_col="forecast_month",
    dimensions=["dim_value"],
    test_predictions=6,
    validation_predictions=6,
    input_steps=12,
    min_backtest_iterations=3,
    primary_metric="mae",
    model_configs=model_configs,
    model_segment_mapping=model_segment_mapping,
    output_dir=f"{output_dir}/{grid_search_file_name}",
    save_detailed_results=True,
)

In [5]:
segments = sorted(feature_df[segment_col].unique())
unmapped_strategy: str = "all_models"
segment_col="eom_pattern_primary"
target_col = 'target_eom_amount'
dimensions = ['dim_value']

forecast_horizon = 1
input_steps = 12
expanding_window = True

date_col = 'forecast_month'
min_backtest_iterations = 1
stride = 1

model_combinations = generate_model_combinations(segments, model_configs, model_segment_mapping, unmapped_strategy)

[32m2025-10-06 10:57:50.271[0m | [1mINFO    [0m | [36msrc.grid_search.utils[0m:[36mgenerate_model_combinations[0m:[36m56[0m - [1mGenerating model combinations...[0m
[32m2025-10-06 10:57:50.271[0m | [1mINFO    [0m | [36msrc.grid_search.utils[0m:[36m_generate_mapped_combinations[0m:[36m90[0m - [1mUsing model-segment mapping to generate combinations[0m
[32m2025-10-06 10:57:50.271[0m | [1mINFO    [0m | [36msrc.grid_search.utils[0m:[36m_generate_all_combinations[0m:[36m68[0m - [1mNo model-segment mapping specified, testing all models on all segments[0m
[32m2025-10-06 10:57:50.271[0m | [1mINFO    [0m | [36msrc.grid_search.utils[0m:[36m_generate_all_combinations[0m:[36m79[0m - [1mGenerated 243 model combinations for 5 segments[0m
[32m2025-10-06 10:57:50.272[0m | [1mINFO    [0m | [36msrc.grid_search.utils[0m:[36m_generate_mapped_combinations[0m:[36m110[0m - [1mUnmapped segments using all_models: ['CONTINUOUS_STABLE', 'CONTINUOUS_VOLATI

In [6]:
model_mapping = model_combinations[0]

model_mapping

{'AGGREGATED_OTHERS': {'type': 'null', 'params': {}},
 'RARE_STALE': {'type': 'null', 'params': {}},
 'CONTINUOUS_STABLE': {'type': 'moving_average', 'params': {'window': 3}},
 'CONTINUOUS_VOLATILE': {'type': 'moving_average', 'params': {'window': 3}},
 'INTERMITTENT_ACTIVE': {'type': 'moving_average', 'params': {'window': 3}},
 'INTERMITTENT_DORMANT': {'type': 'moving_average', 'params': {'window': 3}},
 'RARE_RECENT': {'type': 'moving_average', 'params': {'window': 3}}}

In [7]:
splitter = TimeSeriesBacktest(
        forecast_horizon=forecast_horizon,
        input_steps=input_steps,
        expanding_window=expanding_window,
        stride=stride,
        date_column=date_col,
        min_backtest_iterations=min_backtest_iterations,
    )

In [8]:
train_idx, test_idx = next(splitter.split(feature_df))

In [9]:
feature_df.iloc[train_idx]['forecast_month'].unique()

<DatetimeArray>
['2023-12-01 00:00:00', '2023-01-01 00:00:00', '2023-03-01 00:00:00',
 '2023-08-01 00:00:00', '2023-07-01 00:00:00', '2023-09-01 00:00:00',
 '2023-11-01 00:00:00', '2023-05-01 00:00:00', '2023-10-01 00:00:00',
 '2023-06-01 00:00:00', '2023-04-01 00:00:00', '2023-02-01 00:00:00']
Length: 12, dtype: datetime64[ns]

In [10]:
feature_df.iloc[test_idx]['forecast_month'].unique()

<DatetimeArray>
['2024-01-01 00:00:00']
Length: 1, dtype: datetime64[ns]

In [11]:
model = SegmentedForecastModel(
                segment_col=segment_col,
                target_col=target_col,
                dimensions=dimensions,
                model_mapping=model_mapping,
                fallback_model={"type": "moving_average", "params": {"window": 3}},
            )

### Model fit

In [12]:
for segment in feature_df[segment_col].unique():
    print(segment)


segment = 'CONTINUOUS_STABLE'

INTERMITTENT_ACTIVE
CONTINUOUS_STABLE
INTERMITTENT_DORMANT
RARE_RECENT
RARE_STALE
CONTINUOUS_VOLATILE
AGGREGATED_OTHERS


In [13]:
segment_data = feature_df[feature_df[segment_col] == 'CONTINUOUS_STABLE'].copy()

In [14]:
segment_data

Unnamed: 0,dim_value,forecast_month,year,month_num,target_eom_amount,overall_importance_tier,eom_importance_tier,overall_importance_score,eom_importance_score,eom_risk_flag,...,raw_pm__eom_cv,raw_pm__monthly_cv,raw_pm__transaction_regularity,raw_pm__activity_rate,raw_pm__quarter_end_concentration,raw_pm__year_end_concentration,raw_pm__transaction_dispersion,raw_pm__has_eom_history,raw_pm__months_inactive,raw_pm__eom_periodicity
31,series_013,2025-07-01,2025,7,107.501197,NONE,LOW,1.00000,0.41360,False,...,0.421585,0.195559,0.850726,0.099710,0.000000,0.352295,4.862043,True,6.0,0.136223
32,series_013,2023-11-01,2023,11,61.043785,NONE,LOW,0.96766,0.33504,False,...,0.361300,0.177833,0.796668,0.011029,0.000000,0.315393,3.735004,False,10.0,0.126526
33,series_013,2024-05-01,2024,5,120.977242,NONE,LOW,1.00000,0.37457,False,...,0.419282,0.234080,0.862668,0.011046,0.136813,0.344309,4.167364,False,10.0,0.844816
34,series_013,2023-10-01,2023,10,58.807453,NONE,LOW,0.98661,0.37863,False,...,0.347832,0.248935,0.845964,0.107916,0.023575,0.371188,4.686743,True,10.0,0.162972
35,series_013,2024-03-01,2024,3,138.994946,NONE,LOW,0.93829,0.44439,False,...,0.445251,0.219682,0.862678,0.052377,0.000000,0.367856,4.459427,True,9.0,0.049849
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
262,series_016,2024-04-01,2024,4,12776.458580,MEDIUM,CRITICAL,0.10000,0.48527,False,...,0.395712,0.980536,0.138842,0.320893,1.000000,0.074728,2.295135,True,7.0,0.160554
267,series_016,2024-06-01,2024,6,8910.016358,MEDIUM,CRITICAL,0.24603,0.60790,False,...,0.584222,0.945195,0.157039,0.200267,0.962669,0.036633,3.301329,True,8.0,0.541252
268,series_016,2024-11-01,2024,11,11587.598804,MEDIUM,CRITICAL,0.21452,0.46715,False,...,0.449331,0.839778,0.042407,0.365175,0.907822,0.000000,1.723001,False,9.0,0.461974
276,series_016,2024-07-01,2024,7,9493.260831,MEDIUM,CRITICAL,0.19524,0.55830,False,...,0.578413,0.998381,0.118821,0.269128,1.000000,0.067439,2.956769,True,7.0,0.421235


In [15]:
model_config = model.model_mapping[segment]

In [16]:
mini_model = model._create_model(model_config)

In [17]:
mini_model

<src.forecast.models.moving_average.MovingAverageModel at 0x13a0d0c20>

In [18]:
arima = ArimaModel(
    dimensions=['dim_value'],
    target_col=target_col,
    date_col=date_col
)

In [19]:
arima.fit(data=data)

  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  warn('Non-invertible starting MA parameters found.'
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  warn('Non-invertible starting MA parameters found.'
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, fr

ArimaModel(name='arima', order=(1, 1, 1), dimensions=['dim_value'], target_col='target_eom_amount', date_col='forecast_month', forecast_horizon=1, models={'others::IN': <statsmodels.tsa.arima.model.ARIMAResultsWrapper object at 0x13a0d2660>, 'others::OUT': <statsmodels.tsa.arima.model.ARIMAResultsWrapper object at 0x13a555090>, 'series_001': <statsmodels.tsa.arima.model.ARIMAResultsWrapper object at 0x13a555d10>, 'series_002': <statsmodels.tsa.arima.model.ARIMAResultsWrapper object at 0x13a03a190>, 'series_003': <statsmodels.tsa.arima.model.ARIMAResultsWrapper object at 0x13a03ab10>, 'series_004': <statsmodels.tsa.arima.model.ARIMAResultsWrapper object at 0x13a0e5eb0>, 'series_005': <statsmodels.tsa.arima.model.ARIMAResultsWrapper object at 0x13a032470>, 'series_006': <statsmodels.tsa.arima.model.ARIMAResultsWrapper object at 0x13a0327a0>, 'series_007': <statsmodels.tsa.arima.model.ARIMAResultsWrapper object at 0x13a509450>, 'series_008': <statsmodels.tsa.arima.model.ARIMAResultsWrappe

In [20]:
arima.predict(data=data)

  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  results = data_grouped.apply(


Unnamed: 0,dim_value,forecast_month,year,month_num,target_eom_amount,overall_importance_tier,eom_importance_tier,overall_importance_score,eom_importance_score,eom_risk_flag,...,raw_pm__transaction_regularity,raw_pm__activity_rate,raw_pm__quarter_end_concentration,raw_pm__year_end_concentration,raw_pm__transaction_dispersion,raw_pm__has_eom_history,raw_pm__months_inactive,raw_pm__eom_periodicity,level_1,prediction
0,series_011,2024-11-01,2024.0,11.0,105.277692,NONE,MEDIUM,0.35071,0.49897,False,...,0.216814,0.366897,0.075569,0.508820,3.892909,True,2.0,0.963683,,
1,series_011,2023-12-01,2023.0,12.0,108.888773,NONE,MEDIUM,0.35796,0.52020,False,...,0.215836,0.384466,0.033778,0.573006,3.308501,True,1.0,0.564584,,
2,series_011,2024-12-01,2024.0,12.0,127.162087,NONE,MEDIUM,0.37565,0.59231,False,...,0.267368,0.420463,0.034566,0.590754,4.569974,True,2.0,0.079149,,
3,series_011,2024-07-01,2024.0,7.0,98.181968,NONE,MEDIUM,0.33320,0.53995,False,...,0.236763,0.364848,0.000000,0.645877,4.147771,True,3.0,0.057766,,
4,series_011,2025-04-01,2025.0,4.0,174.163281,NONE,MEDIUM,0.40839,0.54398,False,...,0.226845,0.447852,0.161968,0.637963,3.473558,True,-1.0,0.115853,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
743,series_016,2025-10-01,,,,,,,,,...,,,,,,,,,33.0,6070.342468
744,series_017,2025-10-01,,,,,,,,,...,,,,,,,,,33.0,6726.378914
745,series_018,2025-10-01,,,,,,,,,...,,,,,,,,,33.0,67.116520
746,series_019,2025-10-01,,,,,,,,,...,,,,,,,,,33.0,1052.965092


In [None]:
data.groupby(['dim_value']).apply(lambda x: arima.models.to_dict()[x.name].forecast(steps=1).to_frame(name='predictions'))