In [None]:
# In a Jupyter cell

import math
import os
import sys

import matplotlib.pyplot as plt
import pandas as pd
import torch

module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

import config
from src.model import load_model
from src.visualization import plot_predictions, plot_feature_importance

# Load the processed time series (index parsed from the 'time' column) and the
# trained model, both from the paths configured in `config`.
df = pd.read_csv(
    config.PROCESSED_DATA_FILE,
    parse_dates=True,
    index_col='time',
)
model = load_model(config.MODEL_FILE)

# Select model inputs from the tail of the dataset, keyed on the `time_idx`
# column (not the datetime index).
_latest_time_idx = df['time_idx'].max()
_recent_rows = df['time_idx'] > _latest_time_idx - config.ENCODER_LENGTH

# Rows whose time_idx is strictly greater than max - ENCODER_LENGTH, i.e. the
# last ENCODER_LENGTH steps when time_idx is a contiguous integer index.
encoder_data = df[_recent_rows]
# Rows at the final time_idx only.
last_data = df[df['time_idx'] == _latest_time_idx]
# NOTE(review): identical row selection to `encoder_data`. No future rows are
# appended, and `last_data` is not used in this cell. Confirm whether decoder
# rows for the forecast horizon were meant to be built from `last_data` here.
decoder_data = df[_recent_rows]


# Run inference. mode="raw" returns the complete network output (all quantiles
# plus the interpretation weights). Both the quantile band and
# `interpret_output` below need it. mode="prediction" yields only a point
# forecast, so quantile columns cannot be indexed out of it.
raw_predictions = model.predict(decoder_data, mode="raw", return_x=True)

# Point forecast and per-quantile forecast, both derived from the raw output.
# Presumably shaped (n_series, horizon) and (n_series, horizon, n_quantiles)
# respectively; verify against the installed pytorch-forecasting version.
predictions = model.to_prediction(raw_predictions.output)
quantile_predictions = model.to_quantiles(raw_predictions.output)


def _quantile_position(quantiles, target):
    """Return the position of ``target`` within ``quantiles``.

    Raises ValueError when the model does not produce that quantile, rather
    than silently plotting a different band.
    """
    for position, value in enumerate(quantiles):
        if math.isclose(value, target, abs_tol=1e-9):
            return position
    raise ValueError(
        f"Quantile {target} is not among the model's quantiles {list(quantiles)}"
    )


# Resolve quantile columns by value instead of hard-coding positions. With
# pytorch-forecasting's default QuantileLoss quantiles
# [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98], position 2 is 0.25, not 0.1.
# NOTE(review): assumes model.loss is a QuantileLoss exposing `.quantiles`.
model_quantiles = list(model.loss.quantiles)
p10_position = _quantile_position(model_quantiles, 0.1)
p90_position = _quantile_position(model_quantiles, 0.9)

# Build a readable, hourly-indexed DataFrame for the first series' forecast.
# NOTE(review): the labels assume the forecast starts one hour after the last
# row of `df`. However, `decoder_data` is a subset of `df`'s own rows, so the
# decoder window may cover the last PREDICTION_LENGTH observed hours instead.
# Confirm the intended forecast window.
prediction_start_time = df.index[-1] + pd.Timedelta(hours=1)
prediction_index = pd.date_range(
    start=prediction_start_time, periods=config.PREDICTION_LENGTH, freq='h'
)
prediction_df = pd.DataFrame({
    'prediction': predictions[0].detach().cpu().numpy(),
    'p10': quantile_predictions[0, :, p10_position].detach().cpu().numpy(),  # Quantile 0.1
    'p90': quantile_predictions[0, :, p90_position].detach().cpu().numpy(),  # Quantile 0.9
}, index=prediction_index)


# Plot the forecast against the last ENCODER_LENGTH rows of history.
hist_df_for_plot = df.tail(config.ENCODER_LENGTH)
plot_predictions(hist_df_for_plot, prediction_df)

# Feature importance from the model's interpretation output (requires the raw
# output requested above), aggregated with reduction="sum".
interp = model.interpret_output(raw_predictions.output, reduction="sum")
feature_names = model.encoder_variables + model.decoder_variables
fig = plot_feature_importance(interp, feature_names)
# Create the output directory so savefig does not fail on a fresh checkout.
os.makedirs(config.FIGURES_DIR, exist_ok=True)
fig.savefig(config.FIGURES_DIR / "feature_importance.png")
plt.show()