## Imports

In [None]:
from pathlib import Path

import os
import torch
import skorch
import numpy as np
import pandas as pd
from skorch.callbacks import EarlyStopping
from tqdm import tqdm
from matplotlib import pyplot as plt
from time_series_predictor import TimeSeriesPredictor

from src.oze_dataset import OzeEvaluationDataset, OzeNPZDataset, npz_check, OZELoss
from src.model import BenchmarkLSTM

## Construct and configure the time series predictor

In [None]:
# Network under benchmark, configured with hidden_dim=70 and num_layers=3.
benchmark_net = BenchmarkLSTM(hidden_dim=70, num_layers=3)

# Keyword settings handed to TimeSeriesPredictor together with the network.
# The criterion is passed as the class itself (not an instance).
predictor_settings = dict(
    early_stopping=EarlyStopping(patience=30),
    lr=1e-2,
    max_epochs=500,
    train_split=skorch.dataset.CVSplit(5),
    optimizer=torch.optim.Adam,
    criterion=torch.nn.MSELoss,  # alternative: OZELoss(alpha=0.3)
)

tsp = TimeSeriesPredictor(benchmark_net, **predictor_settings)

## Load the dataset

In [None]:
# Location of the labels file; the same path is reused further down when the
# evaluation dataset is built.
labels_path = os.path.join('src', 'oze_dataset', 'labels.json')

# npz_check receives the datasets directory and the base name 'dataset';
# whatever it returns is handed to OzeNPZDataset as dataset_path.
npz_path = npz_check(Path('datasets'), 'dataset')
dataset = OzeNPZDataset(dataset_path=npz_path, labels_path=labels_path)

## Train the benchmark

In [None]:
# Fit the predictor on `dataset`. The validation split, epoch limit, learning
# rate and early stopping are the ones given when `tsp` was constructed.
tsp.fit(dataset)

## Plot training evolution

In [None]:
# NOTE: this whole cell is commented out, so no figure is produced. When
# enabled it reads the 'train_loss' and 'valid_loss' history columns from the
# fitted regressor and plots both curves against the epoch index.
# NOTE(review): `plt.axes()` is called after the curves are drawn; depending on
# the matplotlib version this may add a new Axes instead of returning the one
# already plotted on (`plt.gca()` may be what is intended) -- confirm.
# train_loss = tsp.ttr.regressor_['regressor'].history[:, 'train_loss']
# valid_loss = tsp.ttr.regressor_['regressor'].history[:, 'valid_loss']
# plt.figure(figsize=(20, 20))
# plt.plot(train_loss, 'o-', label='training')
# plt.plot(valid_loss, 'o-', label='validation')
# axes = plt.axes()
# axes.set_xlabel('Epoch')
# axes.set_ylabel('MSE')
# plt.legend()

### Plot the results for a training example

In [None]:
# NOTE: this whole cell is commented out, so nothing is plotted or saved. When
# enabled it picks one random training sample, runs `tsp.sample_predict` on
# it, draws one subplot per output column (truth vs. prediction) and saves the
# figure to a file named "fig".
# NOTE(review): this cell reads `tsp.ttr.regressor[...]` while the training
# evolution cell reads `tsp.ttr.regressor_[...]` (trailing underscore) --
# confirm which attribute is intended before re-enabling.

# # Select training example
# idx = np.random.randint(0, len(tsp.dataset))
# dataloader = tsp.ttr.regressor['regressor'].get_iterator(tsp.dataset)
# x, y = dataloader.dataset[idx]

# # Run predictions
# netout = tsp.sample_predict(x)

# # One subplot per output column of the prediction
# d_output = netout.shape[1]
# plt.figure(figsize=(30, 30))
# for idx_output_var in range(d_output):
#     # Select real temperature
#     y_true = y[:, idx_output_var]

#     y_pred = netout[:, idx_output_var]

#     plt.subplot(d_output, 1, idx_output_var+1)

#     plt.plot(y_true, label="Truth")
#     plt.plot(y_pred, label="Prediction")
#     plt.title(dataloader.dataset.labels["X"][idx_output_var])
#     plt.legend()
# plt.savefig("fig")

# Evaluation

### Load evaluation dataset

In [None]:
# K is the second entry of the training dataset's x shape; it is passed to the
# evaluation dataset here and later sizes the CSV header.
K = tsp.dataset.get_x_shape()[1]

# Evaluation inputs come from the test CSV and share the training labels file.
eval_csv_path = os.path.join('datasets', 'x_test_QK7dVsy.csv')
dataset_eval = OzeEvaluationDataset(eval_csv_path, K, labels_path=labels_path)

# Number of evaluation samples, reused for forecasting and for the export.
dataset_eval_length = len(dataset_eval)

### Run prediction

In [None]:
# predictions = tsp.predict(dataset_eval.x)

In [None]:
# Call forecast with the evaluation dataset length, then keep only the
# trailing `dataset_eval_length` entries along the first axis (all of the
# second and third axes are retained).
forecast_output = tsp.forecast(dataset_eval_length)
trailing_rows = slice(-dataset_eval_length, None)
predictions = forecast_output[trailing_rows, :, :]

### Export as csv

In [None]:
# One row per evaluation sample; every remaining axis is flattened into columns.
flat_shape = (dataset_eval_length, -1)
lines_output = predictions.reshape(flat_shape)

# Column names are "<label>_<step>", with the step index varying fastest.
csv_header = [
    f"{label}_{step}"
    for label in dataset_eval.labels['X']
    for step in range(K)
]

df = pd.DataFrame(data=lines_output, columns=csv_header)

# Leading 'index' column: the row position offset by the training set length.
row_ids = df.index + len(dataset)
df.insert(loc=0, column='index', value=row_ids)

df.to_csv('y_bench.csv', index=False)