# TSMixer-Skip

### Setup

In [2]:
import sys
sys.path.append('../')
from src.load_data import *
from src.realtime_utils import *
from src.plot_functions import *
from src.hp_tuning_functions import *
import torch
from darts.utils.likelihood_models import NegativeBinomialLikelihood
from pytorch_lightning.callbacks import RichProgressBar

In [3]:
from darts.models import TSMixerModel



In [4]:
# Lookup table from an optimizer's name (as a string) to the matching
# torch.optim class. Key order is Adam, AdamW, SGD.
OPTIMIZER_DICT = {
    optimizer_name: getattr(torch.optim, optimizer_name)
    for optimizer_name in ("Adam", "AdamW", "SGD")
}

In [5]:
def compute_forecast(model, target_series, covariates, forecast_date, horizon, num_samples, vincentization=True, probabilistic_nowcast=True, local=False, skip_last=False):
    '''
    For every sample path given by the nowcasted quantiles, a probabilistic forecast is computed.
    These are then aggregated into one forecast by combining all predicted paths.

    Parameters
    ----------
    model
        Fitted model; only its ``predict(n=..., series=..., past_covariates=..., num_samples=...)``
        method is used.
    target_series
        Target series. The indicator is taken from the name of its first component
        (second piece after splitting on '-').
    covariates
        Past covariates. If truthy, the same object is reused for every target path;
        otherwise no past covariates are passed to the model.
    forecast_date
        Date used to load the nowcast; with ``skip_last`` it is also written into the
        ``forecast_date`` column of the result.
    horizon
        Number of steps to predict (passed as ``n`` to ``model.predict``).
    num_samples
        Number of samples drawn per target path.
    vincentization
        If True, each path's forecast is reshaped on its own and ``value`` is averaged
        per (location, age_group, forecast_date, target_end_date, horizon, type, quantile).
        If False, all forecasts are concatenated along the sample axis and reshaped once.
    probabilistic_nowcast, local
        Forwarded unchanged to ``load_nowcast``.
    skip_last
        If True, the last time step of every target path is dropped before predicting;
        afterwards ``forecast_date`` is overwritten and ``horizon`` is reduced by one.

    Returns
    -------
    DataFrame produced by ``reshape_hfc`` (aggregated) or ``reshape_forecast``.
    '''
    indicator = target_series.components[0].split('-')[1]
    ts_nowcast = load_nowcast(forecast_date, probabilistic_nowcast, indicator, local)
    target_list = make_target_paths(target_series, ts_nowcast)
    target_list = [encode_static_covariates(t, ordinal=False) for t in target_list]

    if skip_last:
        # Drop the most recent time step from every path.
        target_list = [t[:-1] for t in target_list]

    # model.predict receives one covariate entry per target path.
    covariates = [covariates] * len(target_list) if covariates else None

    fct = model.predict(n=horizon,
                        series=target_list,
                        past_covariates=covariates,
                        num_samples=num_samples)

    if vincentization:
        df = reshape_hfc(fct)
        df = df.groupby(['location', 'age_group', 'forecast_date', 'target_end_date', 'horizon', 'type', 'quantile']).agg({'value': 'mean'}).reset_index()
    else:
        ts_forecast = concatenate(fct, axis='sample')
        df = reshape_forecast(ts_forecast)

    if skip_last:
        # Bracket assignment always targets the column; attribute assignment would
        # only set an instance attribute if the column were ever missing.
        df['forecast_date'] = forecast_date
        df['horizon'] = df['horizon'] - 1

    return df

In [9]:
def compute_ensemble(forecast_date, export=False):
    '''
    Build an ensemble forecast for `forecast_date` from the models saved per random seed.

    One TSMixerModel is loaded for each entry of RANDOM_SEEDS and run through
    compute_forecast (no vincentization, probabilistic nowcast, local files, last
    observation skipped) on the module-level `targets` and `covariates`. The
    per-seed results are stacked, `value` is averaged within each
    (location, age_group, forecast_date, target_end_date, horizon, type, quantile)
    group, and the rows are sorted by location, age_group, horizon and quantile.

    If `export` is True the result is additionally written to the tsmixer_skip
    submissions folder as CSV. The ensemble DataFrame is returned in either case.
    '''
    group_keys = ['location', 'age_group', 'forecast_date', 'target_end_date', 'horizon', 'type', 'quantile']

    def _forecast_for_seed(seed):
        # Load the checkpoint trained with this seed and forecast with it.
        checkpoint = f'../models/post-covid/{forecast_date}/{forecast_date}-tsmixer_covariates-{seed}.pt'
        seed_model = TSMixerModel.load(checkpoint)
        return compute_forecast(
            seed_model,
            targets,
            covariates,
            forecast_date,
            HORIZON,
            NUM_SAMPLES,
            vincentization=False,
            probabilistic_nowcast=True,
            local=True,
            skip_last=True,
        )

    stacked = pd.concat([_forecast_for_seed(seed) for seed in RANDOM_SEEDS])

    # Average the per-seed values, then put the rows into a stable order.
    ensemble = stacked.groupby(group_keys).agg({'value': 'mean'}).reset_index()
    ensemble = ensemble.sort_values(['location', 'age_group', 'horizon', 'quantile'])

    if export:
        ensemble.to_csv(f'../data/post-covid/submissions/tsmixer_skip/{forecast_date}-icosari-sari-tsmixer_skip.csv', index=False)

    return ensemble

In [10]:
# Load the training targets and covariates once at module level.
# compute_ensemble reads them as the globals `targets` and `covariates`.
targets, covariates = load_realtime_training_data()

In [None]:
# Compute and export the ensemble forecast for every date in FORECAST_DATES,
# keeping each resulting DataFrame in `forecasts` (same order as the dates).
forecasts = []
for forecast_date in FORECAST_DATES:
    print(forecast_date)  # progress indicator
    # export=True also writes the ensemble to CSV inside compute_ensemble.
    forecast = compute_ensemble(forecast_date, export=True)
    forecasts.append(forecast)