In [None]:
from typing import List
import seaborn as sns
import os

import nnts
import nnts.data
from nnts import utils
import nnts.torch.preprocessing as preprocessing
import trainers
import nnts.torch.models
import nnts.metrics
import nnts.torch.datasets
import nnts.loggers
import nnts.datasets
import nnts.trainers
import nnts.torch.utils
import covs 
import torch.nn.functional as F
import torch.optim
# Apply seaborn's default theme to all subsequent matplotlib figures.
# `set_theme` is the preferred interface; `sns.set` is only a legacy alias for it.
sns.set_theme()

In [None]:
# Experiment identity and the root directory for its outputs.
model_name = "base-lstm"
dataset_name = "hospital"
results_path = "nb-results"

# Load the dataset frame plus its metadata, then make sure the per-model,
# per-dataset results directory <results_path>/<model_name>/<dataset> exists.
df_orig, metadata = nnts.datasets.load_dataset(dataset_name)
PATH = os.path.join(results_path, model_name, metadata.dataset)
utils.makedirs_if_not_exists(PATH)

# Hyperparameters: only the optimizer and loss are set explicitly here;
# every other field keeps the Hyperparams default.
params = utils.Hyperparams(
    optimizer=torch.optim.AdamW,
    loss_fn=F.smooth_l1_loss,
)

In [None]:
# Baseline (univariate) scenarios: no covariates, zero error, one per seed.
scenario_list: List[covs.CovariateScenario] = [
    covs.CovariateScenario(metadata.prediction_length, error=0.0, covariates=0, seed=seed)
    for seed in [42, 43, 44, 45, 46]
]

# Full-forecast-horizon scenarios: 1, 2 and 3 covariates crossed with every
# error level configured for this dataset (covariate count is the outer loop).
scenario_list.extend(
    covs.CovariateScenario(metadata.prediction_length, error, covariates=covariates)
    for covariates in [1, 2, 3]
    for error in covs.errors[metadata.dataset]
)

# One extra scenario: three covariates, zero error, skip=1.
# NOTE(review): the meaning of `skip` is defined in `covs` — confirm there.
scenario_list.append(
    covs.CovariateScenario(metadata.prediction_length, 0, covariates=3, skip=1)
)

In [None]:
# Train, evaluate and log each scenario.
# NOTE(review): the `[:1]` slice runs only the first scenario; widen it to run
# the full sweep built in `scenario_list`.
for scenario in scenario_list[:1]:
    nnts.torch.utils.seed_everything(scenario.seed)
    # `covs.prepare` returns a (DataFrame, scenario) pair; the loop variable is
    # rebound to the returned scenario, whose `conts`/`name` are used below.
    df, scenario = covs.prepare(df_orig.copy(), scenario)
    trn_dl, val_dl, test_dl = nnts.torch.utils.create_dataloaders(
        df,
        nnts.datasets.split_test_val_train_last_horizon,
        metadata.context_length,
        metadata.prediction_length,
        Dataset=nnts.torch.datasets.TimeseriesDataset,
        dataset_options={
            "context_length": metadata.context_length,
            "prediction_length": metadata.prediction_length,
            "conts": scenario.conts,
        },
        batch_size=params.batch_size,
    )

    net = nnts.torch.models.BaseLSTM(
        nnts.torch.models.LinearModel,
        params,
        preprocessing.masked_mean_abs_scaling,
        scenario.covariates + 1,  # presumably covariates + the target series — confirm
    )
    logger = nnts.loggers.LocalFileRun(
        project=f"{model_name}-{metadata.dataset}",
        name=scenario.name,
        config={
            **params.__dict__,
            **metadata.__dict__,
            **scenario.__dict__,
        },
        path=PATH
    )
    trner = trainers.ValidationTorchEpochTrainer(
        nnts.trainers.TrainerState(),
        net,
        params,
        metadata,
        os.path.join(PATH, f"{scenario.name}.pt"),
        loss_fn=F.smooth_l1_loss,
    )
    logger.configure(trner.events)
    evaluator = trner.train(trn_dl, val_dl)

    # Record decoder activations only while evaluating on the test set. The hook
    # is removed in `finally` so a failing evaluation does not leave it attached
    # to `net.decoder` (RemovableHandle.remove is safe to call more than once).
    handle = net.decoder.register_forward_hook(logger.log_activations)
    try:
        y_hat, y = evaluator.evaluate(
            test_dl, scenario.prediction_length, metadata.context_length, hooks=handle
        )
    finally:
        handle.remove()

    test_metrics = nnts.metrics.calc_metrics(
        y, y_hat, nnts.metrics.calculate_seasonal_error(trn_dl, metadata.seasonality)
    )
    logger.log(test_metrics)
    logger.finish()

In [None]:
# Pick the first scenario for the inspection cells that follow.
scenario = scenario_list[0]
import torch
import nnts.torch.trainers
import nnts.torch.utils

In [None]:
# Rebuild the data and model for `scenario`, then load its saved checkpoint.
nnts.torch.utils.seed_everything(scenario.seed)
df, scenario = covs.prepare(df_orig.copy(), scenario)
trn_dl, val_dl, test_dl = nnts.torch.utils.create_dataloaders(
    df,
    nnts.datasets.split_test_val_train_last_horizon,
    metadata.context_length,
    metadata.prediction_length,
    Dataset=nnts.torch.datasets.TimeseriesDataset,
    dataset_options={
        "context_length": metadata.context_length,
        "prediction_length": metadata.prediction_length,
        "conts": scenario.conts,
    },
    batch_size=params.batch_size,
)
net = covs.model_factory(model_name, params, scenario, metadata)
# Same path construction as the checkpoint path given to the trainer.
# map_location="cpu" lets a checkpoint saved on GPU load on a CPU-only machine;
# load_state_dict then copies the tensors onto whatever device `net` uses.
# NOTE: torch.load unpickles the file — only load checkpoints this project wrote.
best_state_dict = torch.load(
    os.path.join(PATH, f"{scenario.name}.pt"), map_location="cpu"
)
net.load_state_dict(best_state_dict)

class ActivationVisualizer:
    """Forward-hook callback that records one activation vector per forward call.

    Register ``hook_handler`` via ``module.register_forward_hook``; each call
    appends a NumPy array to ``activations``.
    """

    def __init__(self):
        # One NumPy array per hooked forward call, in call order.
        self.activations = []

    def hook_handler(self, module, input, output):
        """Append ``output[0][0][:, -1]`` as a detached CPU NumPy array.

        ``module`` and ``input`` are ignored. NOTE(review): this assumes
        ``output[0]`` is a batch-leading tensor, so ``output[0][0]`` is the first
        sample and ``[:, -1]`` takes the last entry of dimension 1 for every index
        of dimension 0 — confirm this is the intended slice.
        """
        first_sample = output[0][0]
        last_column = first_sample[:, -1]
        self.activations.append(last_column.detach().cpu().numpy())

# Fresh recorder each time this cell runs, so activations from earlier runs don't accumulate.
visualiser = ActivationVisualizer()
# Keep `handle` so the hook can be detached once the inspection pass is done.
handle = net.decoder.register_forward_hook(visualiser.hook_handler)
# Switch to evaluation mode; as the last expression, the module repr is displayed.
net.eval()

In [None]:
# Take a single batch from the test loader for inspection.
batch = next(iter(test_dl))

In [None]:
# Run one test batch through `validate` without building an autograd graph.
# Presumably this drives `net.decoder`, so the hook registered on it records
# activations into `visualiser` as a side effect — confirm in validate().
with torch.no_grad():
    output = nnts.torch.trainers.validate(
        net,
        batch,
        scenario.prediction_length,
        metadata.context_length,
    )

In [None]:
# Detach the activation-recording hook so later passes through `net.decoder` aren't captured.
handle.remove()

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Heatmap of every recorded vector: one row per hooked decoder forward call.
fig, ax = plt.subplots(figsize=(10, 6))
sns.heatmap(visualiser.activations, cmap="coolwarm", linewidths=0.5, ax=ax)
ax.set(
    title="Recorded decoder activations",
    xlabel="index along dim 0 of output[0][0]",
    ylabel="decoder forward call",
)
plt.show()

In [None]:
# Index 0 of the test batch (presumably the first series): its first
# `context_length` steps and its final `prediction_length` steps.
(
    batch["X"][0, : metadata.context_length, ...].cpu().numpy(),
    batch["X"][0, -scenario.prediction_length :, ...].cpu().numpy(),
)

In [None]:
# Inspect `output[0][0]` from validate(); its structure is defined by nnts.torch.trainers.validate.
output[0][0]

In [None]:
# Heatmap of `output[0][0]` returned by validate().
# NOTE(review): if this is a torch tensor on a GPU, seaborn cannot convert it —
# add `.cpu().numpy()` in that case.
fig, ax = plt.subplots(figsize=(10, 6))
sns.heatmap(output[0][0], cmap="coolwarm", linewidths=0.5, ax=ax)
ax.set(title="validate() output[0][0]")
plt.show()

In [None]:
# Only the last element of each recorded vector, one row per decoder forward call.
fig, ax = plt.subplots(figsize=(10, 6))
sns.heatmap([v[-1:] for v in visualiser.activations], cmap="coolwarm", linewidths=0.5, ax=ax)
ax.set(
    title="Last recorded activation per decoder call",
    ylabel="decoder forward call",
)
plt.show()

In [None]:
# Only the last two elements of each recorded vector, one row per decoder forward call.
fig, ax = plt.subplots(figsize=(10, 6))
sns.heatmap([v[-2:] for v in visualiser.activations], cmap="coolwarm", linewidths=0.5, ax=ax)
ax.set(
    title="Last two recorded activations per decoder call",
    ylabel="decoder forward call",
)
plt.show()

In [None]:
import nnts.datasets


# Combine the CSV results found under PATH into a single table
# (file matching is defined by nnts.utils.CSVFileAggregator).
csv_aggregator = utils.CSVFileAggregator(PATH, "results")
results = csv_aggregator()

In [None]:
import nnts.experiments.plotting


# NOTE(review): `y_hat` comes from the training-loop cell and is not recomputed
# here, so it lines up with `df` only while that loop's last scenario is
# `scenario_list[0]`.
df_list = covs.add_y_hat(
    df,
    y_hat,
    scenario.prediction_length,
)
sample_preds = nnts.experiments.plotting.plot(
    df_list,
    scenario.prediction_length,
)

In [None]:
# Baseline runs without covariates at this dataset's prediction length.
univariate_results = results.loc[
    (results["prediction_length"] == metadata.prediction_length)
    & (results["covariates"] == 0),
    ["smape", "mape", "rmse", "mae"],
]

# Per-metric mean, standard deviation and number of matching runs
# (the baseline scenarios are trained with several seeds).
univariate_results.mean(), univariate_results.std(), univariate_results.count()

In [None]:
# Columns shown when comparing covariate scenarios.
cols = [
    "dataset",
    "error",
    "pearson",
    "covariates",
    "prediction_length",
    "smape",
]

In [None]:
# Covariate scenarios at a few selected error levels, sorted for comparison.
# Both sides are rounded before matching: error values read back from CSV are
# not guaranteed to round-trip bit-exactly with pandas' default float parser,
# so exact `isin` on float literals can silently drop rows.
selected_errors = [0.0, 0.4714285714285714, 1.65]
results.loc[
    (results["covariates"] > 0)
    & (results["error"].round(6).isin([round(e, 6) for e in selected_errors])),
    cols,
].sort_values(by=["covariates", "error"])

In [None]:
# Distinct error levels present in the aggregated results, ascending — handy for picking filter values.
sorted(results["error"].unique().tolist())