In [1]:
%reload_ext autoreload
%autoreload 2

## Walk forward validation

In [2]:

from modules import utils
utils.configure_plotly_template()

In [3]:
import pandas as pd

In [4]:
df = pd.read_parquet('../../../data/UCIrvine/ElectricityLoadDiagrams20112014.parquet').asfreq('h').div(1_000)
df.columns = ['values']

In [5]:
df

Unnamed: 0,values
2012-01-01 00:00:00+00:00,416.921975
2012-01-01 01:00:00+00:00,450.201475
...,...
2014-12-31 22:00:00+00:00,580.852784
2014-12-31 23:00:00+00:00,513.952190


## TO REOMVE

In [6]:
df = df.iloc[:500]

In [7]:
df['values_diff'] = df['values'].diff()

In [8]:
fig = df.plot(facet_col='variable')
fig.update_yaxes(matches=None)

## Walk forward validation

In [9]:
from sklearn.model_selection import TimeSeriesSplit

### Bulk application with configs

In [10]:
horizon = 24
tsv = TimeSeriesSplit(test_size=horizon, max_train_size=horizon*365*3)

In [11]:
configs = {
    'sarima': {
        'model_params': {
            'order': (1, 1, 1),
            'seasonal_order': (1, 1, 1, 24),
            'enforce_stationarity': False,
            'enforce_invertibility': False,
        },
        'log_transform': True,
    },
    'ets': {
        'model_params': {
            'trend': 'add',
            'seasonal': 'mul',
            'seasonal_periods': 24,
        },
        'log_transform': False,
    },
    'prophet': {
        'model_params': {
            'seasonality_mode': 'multiplicative',
            'daily_seasonality': True,
        },
        'log_transform': True,
    },
}

configs

{'sarima': {'model_params': {'order': (1, 1, 1),
   'seasonal_order': (1, 1, 1, 24),
   'enforce_stationarity': False,
   'enforce_invertibility': False},
  'log_transform': True},
 'ets': {'model_params': {'trend': 'add',
   'seasonal': 'mul',
   'seasonal_periods': 24},
  'log_transform': False},
 'prophet': {'model_params': {'seasonality_mode': 'multiplicative',
   'daily_seasonality': True},
  'log_transform': True}}

In [12]:
from sklearn.metrics import root_mean_squared_error, mean_absolute_error

metrics = {
    'rmse': root_mean_squared_error,
    'mae': mean_absolute_error 
}

### Run experiment with all models

In [13]:
# burn_in = max(p + d + q, s * (P + D + Q))  # conservative offset
#         = max(1+1+1, 24*(1+1+1))
#         = max(3, 72)
#         = 72

In [14]:
from sklearn.metrics import root_mean_squared_error
from modules.utils import TimeSeriesForecaster

d_metrics = []
d_forecasts = []

series = df['values']
for fold, (train_idx, test_idx) in enumerate(tsv.split(series)):
    
    print(f"Fold {fold + 1}")
    
    series = series.tz_localize(None)
    
    data = {
        'train': series.iloc[train_idx],
        'test': series.iloc[test_idx],
    }
    
    tf = TimeSeriesForecaster(train=data['train'], test=data['test'], freq="h", idx_offset=72)
    
    df_forecast = tf.bulk_forecast(configs, metrics=metrics)
    df_forecast['fold'] = fold
    d_metrics.append(df_forecast)
    
    df_forecast = tf.combine_with_historical(df_forecast=df_forecast)
    df_forecast['fold'] = fold
    
    d_forecasts.append(df_forecast)

Fold 1



Maximum Likelihood optimization failed to converge. Check mle_retvals

19:25:48 - cmdstanpy - INFO - Chain [1] start processing
19:25:48 - cmdstanpy - INFO - Chain [1] done processing


Fold 2


19:25:51 - cmdstanpy - INFO - Chain [1] start processing
19:25:51 - cmdstanpy - INFO - Chain [1] done processing


Fold 3



Maximum Likelihood optimization failed to converge. Check mle_retvals

19:25:55 - cmdstanpy - INFO - Chain [1] start processing
19:25:55 - cmdstanpy - INFO - Chain [1] done processing


Fold 4



Maximum Likelihood optimization failed to converge. Check mle_retvals


Optimization failed to converge. Check mle_retvals.

19:26:00 - cmdstanpy - INFO - Chain [1] start processing
19:26:00 - cmdstanpy - INFO - Chain [1] done processing


Fold 5


19:26:04 - cmdstanpy - INFO - Chain [1] start processing
19:26:05 - cmdstanpy - INFO - Chain [1] done processing


In [15]:
d_metrics[0]

Unnamed: 0,model,split,values,datetime,rmse,mae,fold
0,sarima,train,"[553.8216021808736, 464.5747858438577, 439.027...","DatetimeIndex(['2012-01-04 00:00:00', '2012-01...",11.065509,7.970610,0
1,sarima,test,"[974.6454045487005, 936.9443143484783, 853.581...","DatetimeIndex(['2012-01-16 20:00:00', '2012-01...",10.507690,8.464943,0
...,...,...,...,...,...,...,...
4,prophet,train,"[578.4625143888077, 489.36638149066613, 438.13...","DatetimeIndex(['2012-01-04 00:00:00', '2012-01...",21.402574,17.073393,0
5,prophet,test,"[1011.9958863342883, 951.249546716328, 836.419...","DatetimeIndex(['2012-01-16 20:00:00', '2012-01...",16.385133,13.027110,0


In [16]:
dfs = []
for df in d_metrics:
    for i, x in df.iterrows():
        start, end = x.datetime[[0, -1]]
        dfs.append({
            'fold': x.fold,
            'split': x.split,
            'start': start,
            'end': end,
            'model': x.model,
            'rmse': x.rmse,
        })

df = pd.DataFrame(dfs)
df

Unnamed: 0,fold,split,start,end,model,rmse
0,0,train,2012-01-04 00:00:00,2012-01-16 19:00:00,sarima,11.065509
1,0,test,2012-01-16 20:00:00,2012-01-17 19:00:00,sarima,10.507690
...,...,...,...,...,...,...
28,4,train,2012-01-04 00:00:00,2012-01-20 19:00:00,prophet,20.398574
29,4,test,2012-01-20 20:00:00,2012-01-21 19:00:00,prophet,22.020177


In [25]:
df.set_index(['fold', 'split', 'start', 'end', 'model']).unstack(level='model').style.background_gradient(cmap='Greens_r', axis=None).format(precision=2)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,rmse,rmse,rmse
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,model,ets,prophet,sarima
fold,split,start,end,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2
0,test,2012-01-16 20:00:00,2012-01-17 19:00:00,13.96,16.39,10.51
0,train,2012-01-04 00:00:00,2012-01-16 19:00:00,11.24,21.4,11.07
1,test,2012-01-17 20:00:00,2012-01-18 19:00:00,15.44,19.37,9.06
1,train,2012-01-04 00:00:00,2012-01-17 19:00:00,11.05,21.33,10.99
2,test,2012-01-18 20:00:00,2012-01-19 19:00:00,12.17,20.49,8.0
2,train,2012-01-04 00:00:00,2012-01-18 19:00:00,10.99,21.11,10.9
3,test,2012-01-19 20:00:00,2012-01-20 19:00:00,10.34,18.3,16.26
3,train,2012-01-04 00:00:00,2012-01-19 19:00:00,10.81,20.54,10.7
4,test,2012-01-20 20:00:00,2012-01-21 19:00:00,17.41,22.02,14.97
4,train,2012-01-04 00:00:00,2012-01-20 19:00:00,10.52,20.4,10.61


In [17]:
df.pivot(
    index=['split', 'fold'],
    columns=['model'],
    values='rmse'
).style.background_gradient(cmap='Greens_r', axis=None).format(precision=2)

Unnamed: 0_level_0,model,ets,prophet,sarima
split,fold,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
test,0,13.96,16.39,10.51
test,1,15.44,19.37,9.06
test,2,12.17,20.49,8.0
test,3,10.34,18.3,16.26
test,4,17.41,22.02,14.97
train,0,11.24,21.4,11.07
train,1,11.05,21.33,10.99
train,2,10.99,21.11,10.9
train,3,10.81,20.54,10.7
train,4,10.52,20.4,10.61


In [18]:
df.pivot(
    index=['fold', 'model'],
    columns=['split'],
    values='rmse'
).style

Unnamed: 0_level_0,split,test,train
fold,model,Unnamed: 2_level_1,Unnamed: 3_level_1
0,ets,13.963529,11.24065
0,prophet,16.385133,21.402574
0,sarima,10.50769,11.065509
1,ets,15.444116,11.048169
1,prophet,19.368396,21.333012
1,sarima,9.056767,10.987483
2,ets,12.169242,10.991956
2,prophet,20.489893,21.110611
2,sarima,8.000882,10.896112
3,ets,10.34208,10.807888


In [18]:
dfs

[{'fold': 0,
  'split': 'train',
  'start': Timestamp('2012-01-04 00:00:00'),
  'end': Timestamp('2012-01-16 19:00:00'),
  'model': 'sarima',
  'rmse': 11.065508684802587},
 {'fold': 0,
  'split': 'test',
  'start': Timestamp('2012-01-16 20:00:00'),
  'end': Timestamp('2012-01-17 19:00:00'),
  'model': 'sarima',
  'rmse': 10.50768998074886},
 {'fold': 0,
  'split': 'train',
  'start': Timestamp('2012-01-04 00:00:00'),
  'end': Timestamp('2012-01-16 19:00:00'),
  'model': 'ets',
  'rmse': 11.240649924823403},
 {'fold': 0,
  'split': 'test',
  'start': Timestamp('2012-01-16 20:00:00'),
  'end': Timestamp('2012-01-17 19:00:00'),
  'model': 'ets',
  'rmse': 13.963529011483931},
 {'fold': 0,
  'split': 'train',
  'start': Timestamp('2012-01-04 00:00:00'),
  'end': Timestamp('2012-01-16 19:00:00'),
  'model': 'prophet',
  'rmse': 21.40257440149653},
 {'fold': 0,
  'split': 'test',
  'start': Timestamp('2012-01-16 20:00:00'),
  'end': Timestamp('2012-01-17 19:00:00'),
  'model': 'prophet',
  

In [16]:
import plotly.express as px

fig = px.line(
    df,
    x="datetime",
    y="values",
    color="model",
    facet_col="fold",
    facet_row="split",
    category_orders={"split": ["train", "test"]},
    height=600,
    width=1200,
)

fig.update_yaxes(matches=None)
fig.update_xaxes(matches=None, tickangle=45)

In [16]:
import plotly.express as px

fig = px.line(
    df,
    x="datetime",
    y="values",
    color="model",
    facet_col="split",
    facet_row="fold",
    category_orders={"split": ["train", "test"]},
)
fig.update_yaxes(matches=None)
fig.update_xaxes(matches=None)

from collections import defaultdict

# Map: facet_col_value -> list of xaxis names (e.g., 'xaxis', 'xaxis2', ...)
facet_col_to_xaxes = defaultdict(list)

# Build mapping by parsing annotations
for ann in fig.layout.annotations:
    if "facet_col" in ann.text:  # skip titles
        continue
    split_label = ann.text  # assumes facet_col values are labeled here
    xref = ann.xref
    if xref.startswith("x"):
        facet_col_to_xaxes[split_label].append(xref)

# Assign .matches dynamically
for group in facet_col_to_xaxes.values():
    base = group[0]
    for xref in group[1:]:
        fig.layout[xref].matches = base
        
fig

In [17]:
dfs = []
for i in d:
    
    df = i[['model', 'split', 'datetime']]
    dfs.append(tf.last_combined_df)
# df.to_parquet('results/walkforward_results.parquet')

df = pd.concat(dfs)
df

Unnamed: 0,model,split,datetime,values,fold
0,ets,test,2012-01-20 20:00:00,1000.781306,4
1,ets,test,2012-01-20 21:00:00,962.776591,4
...,...,...,...,...,...
1898,historical,test,2012-01-21 18:00:00,1041.27012,4
1899,historical,test,2012-01-21 19:00:00,1039.680049,4


In [18]:
dfp = df.pivot(
    index=['fold', 'split', 'start', 'end'],
    columns='model',
    values='rmse'
)

dfp.style.background_gradient(cmap='Greens_r', axis=None).format(precision=2)

KeyError: 'start'

### How `TimeSeriesSplit` works

In [18]:
from sklearn.model_selection import TimeSeriesSplit

In [8]:
ts = TimeSeriesSplit(test_size=200)

In [9]:
splits = ts.split(X=df)

In [10]:
split1= next(splits)

In [11]:
split1

(array([   0,    1,    2, ..., 2647, 2648, 2649], shape=(2650,)),
 array([2650, 2651, 2652, 2653, 2654, 2655, 2656, 2657, 2658, 2659, 2660,
        2661, 2662, 2663, 2664, 2665, 2666, 2667, 2668, 2669, 2670, 2671,
        2672, 2673, 2674, 2675, 2676, 2677, 2678, 2679, 2680, 2681, 2682,
        2683, 2684, 2685, 2686, 2687, 2688, 2689, 2690, 2691, 2692, 2693,
        2694, 2695, 2696, 2697, 2698, 2699, 2700, 2701, 2702, 2703, 2704,
        2705, 2706, 2707, 2708, 2709, 2710, 2711, 2712, 2713, 2714, 2715,
        2716, 2717, 2718, 2719, 2720, 2721, 2722, 2723, 2724, 2725, 2726,
        2727, 2728, 2729, 2730, 2731, 2732, 2733, 2734, 2735, 2736, 2737,
        2738, 2739, 2740, 2741, 2742, 2743, 2744, 2745, 2746, 2747, 2748,
        2749, 2750, 2751, 2752, 2753, 2754, 2755, 2756, 2757, 2758, 2759,
        2760, 2761, 2762, 2763, 2764, 2765, 2766, 2767, 2768, 2769, 2770,
        2771, 2772, 2773, 2774, 2775, 2776, 2777, 2778, 2779, 2780, 2781,
        2782, 2783, 2784, 2785, 2786, 2787, 27

In [12]:
split2= next(splits)

In [13]:
split2

(array([   0,    1,    2, ..., 2847, 2848, 2849], shape=(2850,)),
 array([2850, 2851, 2852, 2853, 2854, 2855, 2856, 2857, 2858, 2859, 2860,
        2861, 2862, 2863, 2864, 2865, 2866, 2867, 2868, 2869, 2870, 2871,
        2872, 2873, 2874, 2875, 2876, 2877, 2878, 2879, 2880, 2881, 2882,
        2883, 2884, 2885, 2886, 2887, 2888, 2889, 2890, 2891, 2892, 2893,
        2894, 2895, 2896, 2897, 2898, 2899, 2900, 2901, 2902, 2903, 2904,
        2905, 2906, 2907, 2908, 2909, 2910, 2911, 2912, 2913, 2914, 2915,
        2916, 2917, 2918, 2919, 2920, 2921, 2922, 2923, 2924, 2925, 2926,
        2927, 2928, 2929, 2930, 2931, 2932, 2933, 2934, 2935, 2936, 2937,
        2938, 2939, 2940, 2941, 2942, 2943, 2944, 2945, 2946, 2947, 2948,
        2949, 2950, 2951, 2952, 2953, 2954, 2955, 2956, 2957, 2958, 2959,
        2960, 2961, 2962, 2963, 2964, 2965, 2966, 2967, 2968, 2969, 2970,
        2971, 2972, 2973, 2974, 2975, 2976, 2977, 2978, 2979, 2980, 2981,
        2982, 2983, 2984, 2985, 2986, 2987, 29

In [14]:
list_df_train = []
list_df_test = []

for index_train, index_test in ts.split(df):
    list_df_train.append(df.iloc[index_train])
    list_df_test.append(df.iloc[index_test])

In [15]:
list_df_train[0]

Unnamed: 0_level_0,values,values_diff
Date,Unnamed: 1_level_1,Unnamed: 2_level_1
1981-01-01,20.7,
1981-01-02,17.9,-2.8
...,...,...
1988-04-03,16.4,1.8
1988-04-04,13.6,-2.8


In [16]:
list_df_test[0]

Unnamed: 0_level_0,values,values_diff
Date,Unnamed: 1_level_1,Unnamed: 2_level_1
1988-04-05,15.9,2.3
1988-04-06,11.9,-4.0
...,...,...
1988-10-20,15.6,7.4
1988-10-21,10.3,-5.3


In [17]:
list_df_train[1]

Unnamed: 0_level_0,values,values_diff
Date,Unnamed: 1_level_1,Unnamed: 2_level_1
1981-01-01,20.7,
1981-01-02,17.9,-2.8
...,...,...
1988-10-20,15.6,7.4
1988-10-21,10.3,-5.3


In [18]:
list_df_test[1]

Unnamed: 0_level_0,values,values_diff
Date,Unnamed: 1_level_1,Unnamed: 2_level_1
1988-10-22,11.4,1.1
1988-10-23,9.7,-1.7
...,...,...
1989-05-09,9.5,-1.0
1989-05-10,12.5,3.0


In [19]:
from sklearn.model_selection import TimeSeriesSplit

horizon = int(52/4)
tscv = TimeSeriesSplit(test_size=horizon, max_train_size=52*5)
tscv

TimeSeriesSplit(gap=0, max_train_size=260, n_splits=5, test_size=13)

In [20]:
series = pd.concat([train, test])

In [21]:
series.shape

(523,)

In [22]:
from statsmodels.tsa.statespace.sarimax import SARIMAX
from sklearn.metrics import root_mean_squared_error

# Fit SARIMA on weekly data
model = SARIMAX(train, order=(1, 1, 0))
model_fit = model.fit()

d = []
for fold, (train_idx, test_idx) in enumerate(tscv.split(series)):
    print(f"\nFold {fold + 1}")
    full_series = pd.concat([train, test])
    train_series = full_series.iloc[train_idx]
    test_series = full_series.iloc[test_idx][:horizon]

    print(train_series.shape)
    model = SARIMAX(
        train_series,
        order=(1, 1, 1),
        seasonal_order=(1, 1, 1, 52),
        enforce_invertibility=False,
        enforce_stationarity=False
    )
    model_fit = model.fit()
    
    start, end = train_series.index[[0, -1]]
    pred_train = model_fit.predict(start, end)
    
    start, end = test_series.index[[0, -1]]
    pred_test = model_fit.predict(start, end)

    r = {
        'training': {
            'start': train_series.index[0],
            'end': train_series.index[-1],
            'rmse': root_mean_squared_error(train_series, pred_train)
        },
        'test': {
            'start': test_series.index[0],
            'end': test_series.index[-1],
            'rmse': root_mean_squared_error(test_series, pred_test)
        },
    }
    
    print(r)
    
    d.append(r)


Fold 1
(260,)




{'training': {'start': Timestamp('1984-10-21 00:00:00'), 'end': Timestamp('1989-10-08 00:00:00'), 'rmse': 2.1349312247349927}, 'test': {'start': Timestamp('1989-10-15 00:00:00'), 'end': Timestamp('1990-01-07 00:00:00'), 'rmse': 1.931887387258325}}

Fold 2
(260,)




{'training': {'start': Timestamp('1985-01-20 00:00:00'), 'end': Timestamp('1990-01-07 00:00:00'), 'rmse': 2.306163179602357}, 'test': {'start': Timestamp('1990-01-14 00:00:00'), 'end': Timestamp('1990-04-08 00:00:00'), 'rmse': 1.1543645298858036}}

Fold 3
(260,)




{'training': {'start': Timestamp('1985-04-21 00:00:00'), 'end': Timestamp('1990-04-08 00:00:00'), 'rmse': 1.9631449951618536}, 'test': {'start': Timestamp('1990-04-15 00:00:00'), 'end': Timestamp('1990-07-08 00:00:00'), 'rmse': 1.4738060909566544}}

Fold 4
(260,)
{'training': {'start': Timestamp('1985-07-21 00:00:00'), 'end': Timestamp('1990-07-08 00:00:00'), 'rmse': 1.822987739535035}, 'test': {'start': Timestamp('1990-07-15 00:00:00'), 'end': Timestamp('1990-10-07 00:00:00'), 'rmse': 1.555563934881822}}

Fold 5
(260,)


KeyboardInterrupt: 

In [16]:
dfs = []
for i in d:
    df = pd.DataFrame(i)
    dfs.append(df)

df = pd.concat(dfs).reset_index(names='when')
df

Unnamed: 0,when,training,test
0,start,1981-01-04 00:00:00,1989-11-19 00:00:00
1,end,1989-11-12 00:00:00,1990-02-04 00:00:00
...,...,...,...
13,end,1990-10-14 00:00:00,1991-01-06 00:00:00
14,rmse,2.206555,1.696887


In [None]:
# Use a range of horizons
horizons = [4, 12, 26, 52]  # 1, 3, 6, and 12 months


In [17]:
df.pivot(columns='when')

Unnamed: 0_level_0,training,training,training,test,test,test
when,end,rmse,start,end,rmse,start
0,,,1981-01-04 00:00:00,,,1989-11-19 00:00:00
1,1989-11-12 00:00:00,,,1990-02-04 00:00:00,,
...,...,...,...,...,...,...
13,1990-10-14 00:00:00,,,1991-01-06 00:00:00,,
14,,2.206555,,,1.696887,


In [11]:
from sklearn.metrics import root_mean_squared_error

rmse = root_mean_squared_error(actuals, preds)
print(f"\nTimeSeriesSplit RMSE: {rmse:.4f}")


TimeSeriesSplit RMSE: 1.5995


In [4]:
from statsmodels.tsa.statespace.sarimax import SARIMAX

# Fit SARIMA on weekly data
model = SARIMAX(train, order=(1, 1, 1), seasonal_order=(1, 1, 1, 52))
model_fit = model.fit()

# Forecast
forecast = model_fit.forecast(steps=len(test))

# Plot
plt.figure(figsize=(10, 4))
plt.plot(test.index, test.values, label="Actual")
plt.plot(test.index, forecast, label="Forecast", color="orange")
plt.title("SARIMA Forecast (Weekly Data)")
plt.legend()

KeyboardInterrupt: 

In [7]:
horizon = 1  # forecast 1 week
history = train.copy()
preds = []
actuals = []

for i in range(0, len(test) - horizon + 1):
    model = SARIMAX(
        history,
        order=(1, 1, 1),
        seasonal_order=(1, 1, 1, 52),
        enforce_stationarity=False,
        enforce_invertibility=False,
    )
    model_fit = model.fit(disp=False)

    forecast = model_fit.forecast(steps=horizon)
    preds.extend(forecast)

    step_actual = test.iloc[i : i + horizon].value
    actuals.extend(step_actual)

    history = pd.concat([history, step_actual])

rmse = mean_squared_error(actuals, preds, squared=False)
print(f"Walk-forward RMSE: {rmse:.4f}")

KeyboardInterrupt: 

In [None]:
# Plot
plt.figure(figsize=(10, 4))
plt.plot(test.index[:len(preds)], actuals, label="Actual")
plt.plot(test.index[:len(preds)], preds, label="Forecast", color="orange")
plt.title("SARIMA Walk-Forward Forecast")
plt.legend()
plt.tight_layout()
plt.show()