In [None]:
# %pip (rather than !pip) installs into the environment of the running kernel
%pip install tirex-ts utilsforecast

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from utilsforecast.plotting import plot_series
from utilsforecast.evaluation import evaluate
from utilsforecast.losses import *

import warnings
warnings.filterwarnings('ignore')

In [None]:
DATA_URL = "https://raw.githubusercontent.com/marcopeix/FoundationModelsForTimeSeriesForecasting/refs/heads/main/data/Walmart.csv"

# Read the raw CSV and parse the day-month-year date strings in one chain
df = pd.read_csv(DATA_URL).assign(
    Date=lambda frame: pd.to_datetime(frame["Date"], format="%d-%m-%Y")
)

df.head()

## Zero-shot forecasting

In [None]:
import torch
from tirex import ForecastModel, load_model

In [None]:
# Load the pretrained TiRex model ("NX-AI/TiRex" checkpoint id)
model: ForecastModel = load_model("NX-AI/TiRex")


In [None]:
# Forecast horizon in time steps (13 weeks, roughly one quarter, assuming weekly data)
HORIZON: int = 13

In [None]:
# One context series per store, ordered by store id, values in chronological order.
# This ordering must match the one used by create_forecast_df (sorted unique ids).
unique_stores = sorted(df['Store'].unique())
inputs_list = [
    df.loc[df['Store'] == store_id].sort_values(by='Date')['Weekly_Sales'].to_numpy()
    for store_id in unique_stores
]

# np.stack raises a clear error if stores have different history lengths, and it avoids
# torch's slow (and warning-emitting) conversion of a Python list of ndarrays.
input_tensor = torch.tensor(np.stack(inputs_list), dtype=torch.float32)
input_tensor.shape

In [None]:
# Get quantile forecast and mean forecast for every store, HORIZON steps ahead
quantiles, mean = model.forecast(context=input_tensor, prediction_length=HORIZON)


In [None]:
# Presumably (n_stores, HORIZON, n_quantiles) and (n_stores, HORIZON) -- confirm from the output
print(quantiles.shape, mean.shape, sep="\n")

In [None]:
# Create DataFrame from TiRex's output
def create_forecast_df(
    quantile_forecast_tensor: torch.Tensor,
    original_df: pd.DataFrame,
    id_col: str,
    time_col: str,
    target_col: str,
    horizon: int,
    freq: str,
) -> pd.DataFrame:
    """Turn a TiRex quantile forecast into a long-format forecast DataFrame.

    Row ``i`` of the forecast is attributed to the ``i``-th id of
    ``sorted(original_df[id_col].unique())``. The model inputs must therefore
    have been stacked in that same sorted-id order.

    Args:
        quantile_forecast_tensor: Forecast of shape
            ``(num_series, num_steps, num_quantiles)``. Quantile indices 0, 4
            and 8 are read as the 10th, 50th and 90th percentiles (assumes
            TiRex's 0.1..0.9 quantile grid -- confirm if the model changes).
            A NumPy array is accepted as well.
        original_df: History the forecast was produced from; the last
            timestamp of each id anchors that id's forecast dates.
        id_col: Series identifier column.
        time_col: Timestamp column.
        target_col: Target column name. Unused here; kept for signature
            compatibility.
        horizon: Number of steps to emit per id; must not exceed ``num_steps``.
        freq: pandas frequency alias (e.g. ``"W-FRI"``) used to step forward
            from each id's last timestamp. It must match the data's anchor,
            or forecast dates will not line up with actuals.

    Returns:
        DataFrame with columns ``[id_col, time_col, "tirex", "tirex-lo-80",
        "tirex-hi-80"]`` and ``horizon`` rows per id, ids in sorted order.

    Raises:
        ValueError: If the forecast is not 3-D, its series count does not
            match the ids in ``original_df``, it has fewer steps than
            ``horizon``, or it has fewer than 9 quantiles.
    """
    from pandas.tseries.frequencies import to_offset

    # Transform tensor to array (detach/cpu are no-ops for plain CPU tensors)
    if isinstance(quantile_forecast_tensor, torch.Tensor):
        quantile_forecast_array = quantile_forecast_tensor.detach().cpu().numpy()
    else:
        quantile_forecast_array = np.asarray(quantile_forecast_tensor)

    if quantile_forecast_array.ndim != 3:
        raise ValueError(
            "expected a 3-D forecast (series, steps, quantiles), "
            f"got shape {quantile_forecast_array.shape}"
        )

    # Get num_series, horizon, quantile output
    num_series, num_steps, num_quantiles = quantile_forecast_array.shape

    # Sorted unique_ids: this order must match the row order of the forecast
    unique_ids = sorted(original_df[id_col].unique())
    if len(unique_ids) != num_series:
        raise ValueError(
            f"forecast has {num_series} series but original_df has {len(unique_ids)} ids"
        )
    if num_steps < horizon:
        raise ValueError(f"forecast has {num_steps} steps, fewer than horizon={horizon}")
    if num_quantiles < 9:
        raise ValueError(f"expected at least 9 quantiles, got {num_quantiles}")

    step = to_offset(freq)
    # Last known date for every id, computed once instead of re-filtering per id
    last_dates = original_df.groupby(id_col)[time_col].max()

    # List for all rows
    all_forecast_rows = []
    for i, uid in enumerate(unique_ids):
        # Forecast dates start one `freq` step after this id's last known date
        forecast_dates = pd.date_range(
            start=last_dates[uid] + step, periods=horizon, freq=step
        )

        for h in range(horizon):
            all_forecast_rows.append({
                id_col: uid,
                time_col: forecast_dates[h],
                'tirex': quantile_forecast_array[i, h, 4],        # Median
                'tirex-lo-80': quantile_forecast_array[i, h, 0],  # 10th percentile
                'tirex-hi-80': quantile_forecast_array[i, h, 8],  # 90th percentile
            })

    # Fixed column list so an empty result still carries the expected schema
    columns = [id_col, time_col, 'tirex', 'tirex-lo-80', 'tirex-hi-80']
    return pd.DataFrame(all_forecast_rows, columns=columns)

In [None]:
# Create forecast DataFrame.
# The Walmart dates appear to fall on Fridays, so the anchored "W-FRI" frequency is used;
# a plain "W" is Sunday-anchored and would put forecasts on the wrong weekday.
assert (df["Date"].dt.dayofweek == 4).all(), "expected Friday-dated weeks; adjust freq"
fcsts_df = create_forecast_df(
    quantile_forecast_tensor=quantiles,
    original_df=df,
    id_col="Store",
    time_col="Date",
    target_col="Weekly_Sales",
    horizon=HORIZON,
    freq="W-FRI",
)
fcsts_df.head()

In [None]:
# History plus zero-shot forecasts for at most 6 stores, shading the 80% interval
# (the tirex-lo-80 / tirex-hi-80 columns)
plot_series(
    df=df,
    forecasts_df=fcsts_df,
    max_ids=6,
    level=[80],
    target_col="Weekly_Sales",
    time_col="Date",
    id_col="Store",
)

## Cross-validation

In [None]:
# Cross-validation with TiRex
def tirex_cv(
    df: pd.DataFrame,
    model: ForecastModel,
    horizon: int,
    n_windows: int,
    id_col: str,
    time_col: str,
    target_col: str,
    freq: str,
) -> pd.DataFrame:
    """Rolling-origin cross-validation of a TiRex model.

    The data is cut at ``n_windows`` consecutive cutoffs spaced ``horizon``
    ``freq`` steps apart, so the evaluation windows do not overlap. The last
    cutoff lies ``horizon`` steps before the latest timestamp in ``df``. For
    each cutoff the model sees only rows on or before the cutoff and
    forecasts ``horizon`` steps ahead.

    Args:
        df: Long-format data with ``id_col``, ``time_col`` and ``target_col``.
        model: Loaded TiRex model; ``model.forecast(context=..., prediction_length=...)``
            is expected to return ``(quantiles, mean)``.
        horizon: Steps forecast per window (>= 1).
        n_windows: Number of cutoffs (>= 1).
        id_col: Series identifier column.
        time_col: Timestamp column.
        target_col: Target column.
        freq: pandas frequency alias matching the data's timestamps (e.g. ``"W-FRI"``).

    Returns:
        Forecasts from ``create_forecast_df`` for every window, with a
        ``cutoff`` column and the actual ``target_col`` values merged in.

    Raises:
        ValueError: On invalid ``horizon``/``n_windows``, when a window has no
            training data, when the series in a window have different lengths
            (raised by ``np.stack``), or when forecast dates cannot be matched
            to actuals in ``df`` (usually a wrong ``freq``).
    """
    from pandas.tseries.frequencies import to_offset

    if horizon < 1 or n_windows < 1:
        raise ValueError(
            f"horizon and n_windows must be >= 1, got {horizon} and {n_windows}"
        )

    step = to_offset(freq)

    # List for all forecasts
    all_cv_forecasts = []

    # Max date
    max_date = df[time_col].max()

    for i in range(n_windows):
        # Calculate the cutoff date for the current window: window i ends
        # (n_windows - i - 1) * horizon steps before max_date
        cutoff_date = max_date - step * (horizon * (n_windows - i))

        # Create a training DataFrame up to the cutoff_date
        df_train = df[df[time_col] <= cutoff_date]
        if df_train.empty:
            raise ValueError(
                f"no training data on or before cutoff {cutoff_date}; "
                "reduce n_windows or horizon"
            )

        # Prepare inputs_list: sorted ids and time-ordered values, the same
        # ordering create_forecast_df uses to attribute forecast rows
        unique_ids = sorted(df_train[id_col].unique())
        inputs_list = []
        for uid in unique_ids:
            sub_df = df_train[df_train[id_col] == uid].sort_values(by=time_col)
            inputs_list.append(sub_df[target_col].to_numpy())

        # np.stack fails loudly on unequal lengths and avoids torch's slow list path
        input_tensor = torch.tensor(np.stack(inputs_list), dtype=torch.float32)

        # Generate forecasts (the mean forecast is not needed here)
        quantiles, _ = model.forecast(context=input_tensor, prediction_length=horizon)

        # Convert forecasts to DataFrame
        fcsts_df = create_forecast_df(
            quantile_forecast_tensor=quantiles,
            original_df=df_train,
            id_col=id_col,
            time_col=time_col,
            target_col=target_col,
            horizon=horizon,
            freq=freq,
        )

        # Add cutoff column
        fcsts_df['cutoff'] = cutoff_date

        all_cv_forecasts.append(fcsts_df)

    cv_df = pd.concat(all_cv_forecasts, ignore_index=True)
    cv_df = cv_df.merge(df[[id_col, time_col, target_col]], how="left", on=[id_col, time_col])

    # Missing actuals mean the generated forecast dates do not exist in df --
    # typically a freq whose anchor differs from the data (e.g. "W" vs Friday weeks).
    # Without this check the metrics would silently be computed on NaNs.
    n_missing = int(cv_df[target_col].isna().sum())
    if n_missing:
        raise ValueError(
            f"{n_missing} forecast rows have no matching actual in df; "
            f"check that freq={freq!r} matches the data's timestamps"
        )
    return cv_df

In [None]:
# Run cross-validation with N_WINDOWS forecast windows of HORIZON weeks each.
# "W-FRI" matches the Friday-dated Walmart weeks; a plain "W" is Sunday-anchored.
N_WINDOWS = 4
cv_df = tirex_cv(
    df=df,
    model=model,
    horizon=HORIZON,
    n_windows=N_WINDOWS,
    id_col="Store",
    time_col="Date",
    target_col="Weekly_Sales",
    freq="W-FRI",
)
cv_df.head()

In [None]:
# Overlay the forecasts of every CV window on the history; the bookkeeping columns
# (cutoff, actuals) are dropped so only the forecast columns remain
plot_series(
    df=df,
    forecasts_df=cv_df.drop(["cutoff", "Weekly_Sales"], axis=1),
    max_ids=6,
    level=[80],
    target_col="Weekly_Sales",
    time_col="Date",
    id_col="Store",
)

In [None]:
# Score the median forecast ("tirex") against the actuals with MAE and sMAPE,
# reducing the per-store scores with their mean; "cutoff" is not a model column
eval_df = evaluate(
    cv_df.drop(["cutoff"], axis=1),
    models=["tirex"],
    metrics=[mae, smape],
    id_col="Store",
    time_col="Date",
    target_col="Weekly_Sales",
    agg_fn="mean",
)
eval_df