# TSMixer-Uncorrected

### Setup

In [1]:
import sys
sys.path.append('../')
from src.load_data import *
from src.realtime_utils import *
from src.plot_functions import *
from src.hp_tuning_functions import *
import torch
from darts.utils.likelihood_models import NegativeBinomialLikelihood
from pytorch_lightning.callbacks import RichProgressBar

In [2]:
from darts.models import TSMixerModel



### Configuration

In [3]:
# Number of steps to forecast and number of sample paths drawn per forecast.
HORIZON = 4
NUM_SAMPLES = 1000

# Datetime attributes declared as future encoders.
# NOTE(review): not referenced by the cells visible here — presumably consumed
# at model-construction time elsewhere; confirm.
ENCODERS = {"datetime_attribute": {"future": ["month", "weekofyear"]}}

# Keyword arguments grouped for model construction: output length, a negative
# binomial likelihood, and CPU-only Lightning trainer settings with a
# persistent rich progress bar.
SHARED_ARGS = {
    "output_chunk_length": HORIZON,
    "likelihood": NegativeBinomialLikelihood(),
    "pl_trainer_kwargs": dict(
        enable_progress_bar=True,
        enable_model_summary=False,
        accelerator="cpu",
        callbacks=[RichProgressBar(leave=True)],
    ),
}

# Lookup from optimizer name to the corresponding torch optimizer class.
OPTIMIZER_DICT = dict(
    Adam=torch.optim.Adam,
    AdamW=torch.optim.AdamW,
    SGD=torch.optim.SGD,
)

In [4]:
# Seeds 1..10; one stored model per seed is loaded when building the ensemble.
RANDOM_SEEDS = list(range(1, 11))

In [5]:
import os

# Take the first 10 characters of every baseline nowcast file name as its date
# (assumes the names start with YYYY-MM-DD — confirm against the data folder).
# os.path.basename is used rather than splitting on '/' so the prefix is still
# extracted correctly when glob returns OS-native separators (e.g. on Windows).
forecast_dates = sorted(
    os.path.basename(filename)[:10]
    for filename in glob.glob('../data/nowcasts/KIT-baseline/*.csv')
)
# Keep only dates inside the evaluation window (inclusive on both ends);
# ISO-formatted date strings compare chronologically.
forecast_dates = [f for f in forecast_dates if '2023-11-16' <= f <= '2024-09-12']

In [6]:
# Load targets and covariates using the loader's default arguments.
# NOTE(review): compute_forecast reloads its own data per forecast date, so
# these module-level values are not used by the cells visible here.
targets, covariates = load_realtime_training_data()

In [7]:
def compute_forecast(model, forecast_date, horizon, num_samples):
    """Draw a sampled forecast from `model` using the data as of `forecast_date`.

    Args:
        model: Fitted model exposing `predict(n, series, past_covariates,
            num_samples)`.
        forecast_date: Passed to the data loader as `as_of`.
        horizon: Number of steps to predict.
        num_samples: Number of samples requested from `predict`.

    Returns:
        Whatever `reshape_forecast` produces for the prediction (the caller
        treats it as a pandas DataFrame).
    """
    # Incomplete observations are kept (drop_incomplete=False).
    series, past_covs = load_realtime_training_data(
        as_of=forecast_date,
        drop_incomplete=False,
    )

    prediction = model.predict(
        n=horizon,
        series=series,
        past_covariates=past_covs,
        num_samples=num_samples,
    )

    return reshape_forecast(prediction)

In [8]:
def compute_ensemble(forecast_date, export=False):
    """Average the forecasts of all seeded TSMixer models for one forecast date.

    One stored model per entry of `RANDOM_SEEDS` is loaded, a forecast is
    computed with each, and the `value` column is averaged across models for
    every (location, age_group, forecast_date, target_end_date, horizon, type,
    quantile) combination.

    Args:
        forecast_date: Date string used to locate the model files, load the
            data and name the exported file.
        export: When True, additionally write the ensemble to the
            `tsmixer_uncorrected` submissions folder as CSV.

    Returns:
        The averaged forecast as a DataFrame sorted by location, age_group,
        horizon and quantile.
    """
    import os

    dfs = []
    for seed in RANDOM_SEEDS:
        print(seed)  # progress indicator
        model_path = f'../models/post-covid/{forecast_date}/{forecast_date}-tsmixer_covariates-{seed}.pt'
        model = TSMixerModel.load(model_path)
        dfs.append(compute_forecast(model, forecast_date, HORIZON, NUM_SAMPLES))

    group_keys = ['location', 'age_group', 'forecast_date', 'target_end_date', 'horizon', 'type', 'quantile']
    # NOTE(review): groupby drops rows whose key columns contain NaN by default;
    # if reshape_forecast ever emits rows without a `quantile`, they are excluded
    # here — confirm this is intended.
    df = pd.concat(dfs).groupby(group_keys).agg({'value': 'mean'}).reset_index()
    df = df.sort_values(['location', 'age_group', 'horizon', 'quantile'])

    if export:
        out_path = f'../data/post-covid/submissions/tsmixer_uncorrected/{forecast_date}-icosari-sari-tsmixer_uncorrected.csv'
        # to_csv does not create missing directories, so make sure the target
        # folder exists before writing.
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        df.to_csv(out_path, index=False)

    return df

In [9]:
# Build the ensemble forecast for every selected forecast date. Each result is
# written to CSV (export=True) and also collected in `forecasts`.
forecasts = []
for forecast_date in forecast_dates:
    print(forecast_date)  # progress indicator
    forecast = compute_ensemble(forecast_date, export=True)
    forecasts.append(forecast)

