In [None]:
import os
from typing import List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

import nnts
import nnts.data
import nnts.experiments
from nnts import utils
import nnts.torch.preprocessing as preprocessing
import nnts.torch.models
import nnts.torch.trainers as trainers
import nnts.metrics
import nnts.torch.datasets
import nnts.loggers
import covs
import nnts.pandas


%load_ext autoreload
%autoreload 2

In [None]:
# --- Experiment settings: edit these to switch model / dataset ---
data_path = "data"
results_path = "nb-results"
model_name = "seg-lstm"
dataset_name = "hospital"

# Hyperparameter container and train/validation/test splitter used by the
# training cell further down.
params = utils.Hyperparams()
splitter = nnts.pandas.LastHorizonSplitter()

# Dataset metadata is looked up by dataset name in the per-model Monash JSON.
metadata_path = os.path.join(data_path, f"{model_name}-monash.json")
metadata = utils.load(dataset_name, path=metadata_path)

# Only the first item returned by read_tsf is kept; the rest is discarded.
datafile_path = os.path.join(data_path, metadata.filename)
df_orig, *_ = nnts.pandas.read_tsf(datafile_path)

# All run artefacts (checkpoints, result CSVs) live under
# <results_path>/<model_name>/<dataset>.
PATH = os.path.join(results_path, model_name, metadata.dataset)
nnts.loggers.makedirs_if_not_exists(PATH)

In [None]:
# Baseline scenarios: no covariates, zero error, one run per seed (42..46).
scenario_list: List[nnts.experiments.CovariateScenario] = [
    nnts.experiments.CovariateScenario(
        metadata.prediction_length, error=0.0, covariates=0, seed=seed
    )
    for seed in range(42, 47)
]

# Full-forecast-horizon models with 1, 2 and 3 covariates, one scenario per
# error level configured for this dataset in covs.errors.
scenario_list.extend(
    nnts.experiments.CovariateScenario(
        metadata.prediction_length, error, covariates=covariates
    )
    for covariates in (1, 2, 3)
    for error in covs.errors[metadata.dataset]
)

# One extra 3-covariate, zero-error scenario with skip=1.
# NOTE(review): the meaning of `skip` is defined by CovariateScenario and is
# not visible here - confirm before relying on this scenario's results.
scenario_list.append(
    nnts.experiments.CovariateScenario(
        metadata.prediction_length, 0, covariates=3, skip=1
    )
)

In [None]:
# Train, evaluate and log one model per scenario.
#
# The names bound inside the loop (df, scenario, y_hat, y, net, ...) are left
# in the notebook namespace on purpose: the plotting cell below reads df,
# y_hat and scenario from the last iteration.

# Loop-invariant setting, so it is applied once instead of on every iteration.
params.batches_per_epoch = 500

# NOTE: the [:1] slice restricts the run to the first scenario in the list
# (the seed-42 baseline); widen the slice to run the whole sweep.
for scenario in scenario_list[:1]:
    # Seed before any data preparation or model construction so the run is
    # reproducible for this scenario.
    nnts.torch.datasets.seed_everything(scenario.seed)

    # Work on a copy so df_orig stays untouched between scenarios. prepare()
    # hands back both the prepared frame and a (possibly updated) scenario.
    df, scenario = covs.prepare(df_orig.copy(), scenario)
    split_data = splitter(df, metadata.context_length, metadata.prediction_length)
    trn_dl, val_dl, test_dl = nnts.data.create_trn_val_test_dataloaders(
        split_data,
        metadata,
        scenario,
        params,
        nnts.torch.data.TorchTimeseriesDataLoaderFactory(),
    )

    # The run config is the union of hyperparameters, dataset metadata and
    # scenario attributes; on key clashes the later dict wins.
    logger = nnts.loggers.WandbRun(
        project=f"{model_name}-{metadata.dataset}",
        name=scenario.name,
        config={
            **params.__dict__,
            **metadata.__dict__,
            **scenario.__dict__,
        },
        path=PATH,
    )
    try:
        net = covs.model_factory(model_name, params, scenario, metadata)

        trner = trainers.TorchEpochTrainer(
            nnts.trainers.TrainerState(),
            net,
            params,
            metadata,
            os.path.join(PATH, f"{scenario.name}.pt"),
        )
        logger.configure(trner.events)

        evaluator = trner.train(trn_dl, val_dl)

        # Log decoder activations during the test pass only. The hook is
        # removed in a finally block so a failed evaluation cannot leave it
        # attached to the model.
        handle = net.decoder.register_forward_hook(logger.log_activations)
        try:
            y_hat, y = evaluator.evaluate(
                test_dl,
                scenario.prediction_length,
                metadata.context_length,
                hooks=handle,
            )
        finally:
            handle.remove()

        test_metrics = nnts.metrics.calc_metrics(
            y, y_hat, nnts.metrics.calculate_seasonal_error(trn_dl, metadata)
        )
        logger.log(test_metrics)
    finally:
        # Always close the logger run, even when training or evaluation
        # raises; otherwise the open run lingers in the kernel and the next
        # attempt may log into it.
        logger.finish()

In [None]:
# Build an aggregator over the run directory PATH with the name "results" and
# invoke it.
# NOTE(review): presumably this merges the per-run CSV files under PATH into
# PATH/results.csv (the file the next cell reads) - confirm against
# nnts.pandas.CSVFileAggregator.
csv_aggregator = nnts.pandas.CSVFileAggregator(PATH, "results")
results = csv_aggregator()

In [None]:
# Reload the results table from disk; this replaces the `results` value
# bound by the aggregation cell above.
results_csv = f"{PATH}/results.csv"
results = pd.read_csv(results_csv)
results

In [None]:
import nnts.experiments.plotting

# NOTE: this cell depends on `df`, `y_hat` and `scenario` as left behind by
# the last iteration of the training loop, so run that cell first.
horizon = scenario.prediction_length

# Attach the forecasts to the prepared frame, then plot them.
df_list = covs.add_y_hat(df, y_hat, horizon)
sample_preds = nnts.experiments.plotting.plot(df_list, horizon)

In [None]:
# Baseline rows: no covariates, forecasting the dataset's full horizon.
is_univariate = results["covariates"] == 0
is_full_horizon = results["prediction_length"] == metadata.prediction_length
metric_cols = ["smape", "mape", "rmse", "mae"]

univariate_results = results.loc[is_univariate & is_full_horizon, metric_cols]

# Per-metric mean, standard deviation and number of runs across the baselines.
univariate_results.mean(), univariate_results.std(), univariate_results.count()

In [None]:
# Columns shown in the covariate comparison table below.
cols = [
    "dataset",
    "error",
    "pearson",
    "covariates",
    "prediction_length",
    "smape",
]

In [None]:
# Covariate runs at three selected error levels, ordered by number of
# covariates and then by error.
#
# `results` was reloaded from CSV, and read_csv's default float parser is not
# guaranteed to reproduce the written value bit-for-bit, so the error levels
# are matched with a tight tolerance instead of exact float equality (isin).
selected_errors = np.array([0.000000, 0.4714285714285714, 1.65])
error_values = results["error"].to_numpy(dtype=float)

# Shape (rows, levels): True where a row's error equals one of the levels.
# NaN errors never match, which is also how isin treated them here.
at_selected_error = np.isclose(
    error_values[:, None], selected_errors[None, :], rtol=1e-9, atol=1e-12
).any(axis=1)

results.loc[
    (results["covariates"] > 0) & at_selected_error,
    cols,
].sort_values(by=["covariates", "error"])

In [None]:
# Distinct error levels present in the results, in ascending order.
distinct_errors = results["error"].unique().tolist()
sorted(distinct_errors)

In [None]:
import numpy as np

# Eight evenly spaced values from 0 to 0.6 inclusive, as plain Python floats.
# NOTE(review): looks like a scratch cell for choosing an error grid; nothing
# else in the notebook reads its output.
np.linspace(start=0, stop=0.6, num=8).tolist()