# TSMixer

### Setup

In [1]:
import ast
import os

import pandas as pd
import torch
from darts.models import TSMixerModel

from config import ENCODERS, FORECAST_DATES, HORIZON, NUM_SAMPLES, RANDOM_SEEDS, ROOT, SHARED_ARGS
from src.realtime_utils import (
    compute_forecast,
    load_realtime_training_data,
)


  # No namespace-package declaration here: this is a notebook, not a package __init__.


In [2]:
# Map the optimizer name stored in the grid-search results to the torch optimizer class.
OPTIMIZER_DICT = dict(
    Adam=torch.optim.Adam,
    AdamW=torch.optim.AdamW,
    SGD=torch.optim.SGD,
)

# Load best model

In [3]:
def get_best_parameters(csv_path: str) -> dict:
    """
    Load a grid-search CSV and return the configuration with the lowest WIS.

    The covariate-lag columns (``lags_past_covariates``, ``lags_future_covariates``)
    come back from the CSV as strings and are parsed into Python objects with
    ``ast.literal_eval``. The error columns (``error_flag``, ``error_msg``) and
    the score columns (``WIS``, ``WIS_1``, ``WIS_2``, ``WIS_3``, ``WIS_std``) are
    removed from the result. The best WIS is printed.

    Args:
        csv_path: Path to the grid-search results CSV.

    Returns:
        The best run's remaining columns as a dict, with keys in sorted order.

    Raises:
        ValueError: If no row of the CSV has a (non-missing) WIS value.
    """
    gs = pd.read_csv(csv_path)

    # Lists are written to CSV as their string repr; literal_eval restores them
    # without executing arbitrary code (unlike eval). Missing cells stay NaN.
    for col in ("lags_past_covariates", "lags_future_covariates"):
        if col in gs.columns:
            gs[col] = gs[col].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)

    gs = gs.drop(columns=[c for c in ("error_flag", "error_msg") if c in gs.columns])

    # Runs without a score (e.g. failed ones) cannot be ranked. If none are scored,
    # idxmin yields no usable label, so fail with an explicit message instead.
    scored = gs.dropna(subset=["WIS"])
    if scored.empty:
        raise ValueError(f"No grid-search run with a valid WIS in {csv_path!r}")

    best_row = scored.loc[scored["WIS"].idxmin()].to_dict()
    wis = best_row.pop("WIS")  # keep the score separately, not as a hyperparameter

    for key in ("WIS_1", "WIS_2", "WIS_3", "WIS_std"):
        best_row.pop(key, None)  # tolerate CSVs that lack some score columns

    print(f"WIS of best run: {wis:.3f}")
    return {k: best_row[k] for k in sorted(best_row)}


In [4]:
# Best TSMixer configuration from the grid-search results.
params = get_best_parameters(csv_path="gridsearch_tsmixer.csv")

WIS of best run: 335.628


In [5]:
# Pull out the grid-search settings that are not TSMixerModel constructor
# arguments; whatever remains in ``params`` is passed to the model as **kwargs.
name, use_covariates, use_features, use_encoders, optimizer, sample_weight = (
    params.pop(key)
    for key in ("model", "use_covariates", "use_features", "use_encoders", "optimizer", "sample_weight")
)

# The results store the optimizer by name; the model expects the class itself.
params["optimizer_cls"] = OPTIMIZER_DICT[optimizer]

In [6]:
# Fold the flattened "optimizer_kwargs.<name>" columns back into one nested dict.
optimizer_kwargs = {
    key: params.pop(f"optimizer_kwargs.{key}") for key in ("lr", "weight_decay")
}
params["optimizer_kwargs"] = optimizer_kwargs

In [7]:
# Show the final keyword arguments that are unpacked into TSMixerModel below.
params

{'activation': 'ReLU',
 'batch_size': 32,
 'dropout': 0.2,
 'ff_size': 64,
 'hidden_size': 32,
 'input_chunk_length': 8,
 'n_epochs': 1000,
 'norm_type': 'TimeBatchNorm2d',
 'normalize_before': False,
 'num_blocks': 6,
 'use_static_covariates': False,
 'optimizer_cls': torch.optim.adamw.AdamW,
 'optimizer_kwargs': {'lr': 0.0005, 'weight_decay': 0.0001}}

# Train model

In [8]:
# Train and save one TSMixer model per (forecast date, random seed).
# NOTE(review): the slices restrict training to the first forecast date and the
# first seed, while the forecast cells load every seed for every date from
# ``../models/post-covid/<date>/`` rather than ``ROOT / "models/<date>"`` —
# confirm the intended scope and model directory.
for forecast_date in FORECAST_DATES[:1]:
    path = ROOT / f"models/{forecast_date}/"
    os.makedirs(path, exist_ok=True)

    # Training data snapshot as of the forecast date.
    targets, covariates = load_realtime_training_data(as_of=forecast_date)

    for seed in RANDOM_SEEDS[:1]:
        model_path = path / f"{forecast_date}-tsmixer-{seed}.pt"
        model = TSMixerModel(
            **params,
            add_encoders=(ENCODERS if use_encoders else None),
            **SHARED_ARGS,
            random_state=seed,
        )
        model.fit(
            targets,
            past_covariates=(covariates if use_covariates else None),
            sample_weight=sample_weight,
            # Forwarded to the underlying torch DataLoader.
            dataloader_kwargs={"pin_memory": False},
        )
        model.save(str(model_path))

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Epoch 999: 100%|██████████| 15/15 [00:00<00:00, 42.97it/s, train_loss=3.180]

`Trainer.fit` stopped: `max_epochs=1000` reached.


Epoch 999: 100%|██████████| 15/15 [00:00<00:00, 42.97it/s, train_loss=3.180]


# Forecast

In [13]:
# Model name used in the submission directory and file names.
NAME = "tsmixer"

In [14]:
# No ``as_of`` here — presumably the most recent data (confirm); these globals
# are read by compute_ensemble for every forecast date.
(targets, covariates) = load_realtime_training_data()

In [15]:
# Display whether the best grid-search run uses past covariates.
use_covariates

False

In [16]:
def compute_ensemble(forecast_date, export=False):
    """
    Average the forecasts of all per-seed TSMixer models for one forecast date.

    For every seed in ``RANDOM_SEEDS`` the saved model is loaded and a forecast is
    produced with ``compute_forecast`` from the module-level ``targets`` (plus
    ``covariates`` when ``use_covariates`` is set). The per-seed ``value`` columns
    are then averaged per location, age group, forecast date, target end date,
    horizon, type and quantile.

    Args:
        forecast_date: Forecast date; used in the model and submission paths.
        export: If True, also write the ensemble to
            ``../data/post-covid/submissions/<NAME>/<date>-icosari-sari-<NAME>.csv``,
            creating the directory if needed.

    Returns:
        The averaged forecast DataFrame, sorted by location, age group, horizon
        and quantile.
    """
    dfs = []
    for seed in RANDOM_SEEDS:
        print(seed)
        # NOTE(review): the training cell saves to ROOT / "models/<date>", not to
        # ../models/post-covid/<date> — confirm these refer to the same models.
        model_path = f'../models/post-covid/{forecast_date}/{forecast_date}-tsmixer-{seed}.pt'
        model = TSMixerModel.load(model_path)
        # NOTE(review): ``compute_forecast`` is looked up at call time; a later cell
        # in this notebook redefines it (defaulting to oracle mode).
        df = compute_forecast(
            model,
            targets,
            covariates if use_covariates else None,
            forecast_date,
            HORIZON,
            NUM_SAMPLES,
            vincentization=False,
            probabilistic_nowcast=True,
            local=True,
        )
        dfs.append(df)

    df = pd.concat(dfs)
    # Ensemble by taking the mean value across seeds for each quantile.
    df = (
        df.groupby(['location', 'age_group', 'forecast_date', 'target_end_date', 'horizon', 'type', 'quantile'])
        .agg({'value': 'mean'})
        .reset_index()
    )
    df = df.sort_values(['location', 'age_group', 'horizon', 'quantile'])

    if export:
        out_path = f'../data/post-covid/submissions/{NAME}/{forecast_date}-icosari-sari-{NAME}.csv'
        # DataFrame.to_csv does not create missing parent directories.
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        df.to_csv(out_path, index=False)

    return df

In [17]:
# Produce and export the seed-ensemble forecast for every configured forecast date.
forecasts = []
for forecast_date in FORECAST_DATES:
    print(forecast_date)
    forecast = compute_ensemble(forecast_date, export=True)
    forecasts += [forecast]

# Oracle nowcast

In [13]:
# Override the imported seed list: ensemble over seeds 1 through 10.
RANDOM_SEEDS = list(range(1, 11))

In [14]:
# Model name for the oracle-nowcast submissions (directory and file names).
NAME = "tsmixer_oracle"

In [15]:
#forecast_dates = sorted([filename.split('/')[-1][:10] for filename in glob.glob('../data/nowcasts/KIT-baseline/*.csv')])

In [16]:
# No ``as_of`` here — presumably the most recent data (confirm); these globals
# are read by compute_ensemble and truncated per forecast date in oracle mode.
(targets, covariates) = load_realtime_training_data()

In [17]:
def compute_forecast(model, target_series, covariates, forecast_date, horizon, num_samples, vincentization=True, probabilistic_nowcast=True, local=False, oracle_nowcast=True):
    '''
    Forecast ``horizon`` steps with a fitted model, optionally from an oracle nowcast.

    Defining this function shadows the ``compute_forecast`` imported from
    ``src.realtime_utils`` for the rest of the session.

    With ``oracle_nowcast=True`` (the default) the model is conditioned on
    ``target_series`` truncated at ``forecast_date``, i.e. on the loaded data
    itself instead of a nowcast. Otherwise a nowcast is loaded; for every sample
    path built from the nowcasted quantiles a probabilistic forecast is computed,
    and these are combined by ``reshape_forecast``.

    Args:
        model: Fitted model exposing ``predict(n=, series=, past_covariates=, num_samples=)``.
        target_series: Target series. In non-oracle mode, the indicator is read
            from its first component name (the part after the first '-').
        covariates: Past covariates, or None.
        forecast_date: Cut-off / nowcast date.
        horizon: Number of steps to predict.
        num_samples: Number of samples drawn by ``predict``.
        vincentization: Not used by this implementation; kept so the signature
            matches the imported version.
        probabilistic_nowcast: Passed to ``load_nowcast`` (non-oracle mode only).
        local: Passed to ``load_nowcast`` (non-oracle mode only).
        oracle_nowcast: Select oracle mode (see above).

    Returns:
        The DataFrame produced by ``reshape_forecast``.
    '''
    # NOTE(review): these helpers are not imported anywhere in this notebook, so
    # the previous body raised NameError. They are assumed to live alongside the
    # imported compute_forecast in src.realtime_utils — confirm.
    from src.realtime_utils import reshape_forecast

    if oracle_nowcast:
        target_list = target_series[:pd.Timestamp(forecast_date)]
    else:
        from src.realtime_utils import encode_static_covariates, load_nowcast, make_target_paths

        indicator = target_series.components[0].split('-')[1]
        ts_nowcast = load_nowcast(forecast_date, probabilistic_nowcast, indicator, local)
        target_list = make_target_paths(target_series, ts_nowcast)
        target_list = [encode_static_covariates(t, ordinal=False) for t in target_list]

        # One covariate entry per sample path so the lists line up in predict().
        covariates = [covariates] * len(target_list) if covariates is not None else None

    fct = model.predict(n=horizon,
                        series=target_list,
                        past_covariates=covariates,
                        num_samples=num_samples)

    return reshape_forecast(fct)

In [21]:
def compute_ensemble(forecast_date, export=False):
    """
    Average the oracle-nowcast forecasts of all per-seed TSMixer models for one date.

    For every seed in ``RANDOM_SEEDS`` the saved model is loaded and a forecast is
    produced with ``compute_forecast(..., oracle_nowcast=True)`` from the
    module-level ``targets`` (plus ``covariates`` when ``use_covariates`` is set).
    The per-seed ``value`` columns are then averaged per location, age group,
    forecast date, target end date, horizon, type and quantile.

    Args:
        forecast_date: Forecast date; used in the model and submission paths.
        export: If True, also write the ensemble to
            ``../data/post-covid/submissions/<NAME>/<date>-icosari-sari-<NAME>.csv``,
            creating the directory if needed.

    Returns:
        The averaged forecast DataFrame, sorted by location, age group, horizon
        and quantile.
    """
    dfs = []
    for seed in RANDOM_SEEDS:
        print(seed)
        # NOTE(review): the training cell saves to ROOT / "models/<date>", not to
        # ../models/post-covid/<date> — confirm these refer to the same models.
        model_path = f'../models/post-covid/{forecast_date}/{forecast_date}-tsmixer-{seed}.pt'
        model = TSMixerModel.load(model_path)
        df = compute_forecast(
            model,
            targets,
            covariates if use_covariates else None,
            forecast_date,
            HORIZON,
            NUM_SAMPLES,
            vincentization=False,
            probabilistic_nowcast=True,
            local=True,
            oracle_nowcast=True,
        )
        dfs.append(df)

    df = pd.concat(dfs)
    # Ensemble by taking the mean value across seeds for each quantile.
    df = (
        df.groupby(['location', 'age_group', 'forecast_date', 'target_end_date', 'horizon', 'type', 'quantile'])
        .agg({'value': 'mean'})
        .reset_index()
    )
    df = df.sort_values(['location', 'age_group', 'horizon', 'quantile'])

    if export:
        out_path = f'../data/post-covid/submissions/{NAME}/{forecast_date}-icosari-sari-{NAME}.csv'
        # DataFrame.to_csv does not create missing parent directories.
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        df.to_csv(out_path, index=False)

    return df

In [22]:
# Produce and export the oracle ensemble for every configured forecast date.
# ``forecast_dates`` (lowercase) is never defined in this notebook — its only
# definition is commented out — so iterate over the imported FORECAST_DATES.
forecasts = []
for forecast_date in FORECAST_DATES:
    print(forecast_date)
    forecast = compute_ensemble(forecast_date, export=True)
    forecasts.append(forecast)

