### Imports and setup

In [None]:
import os
import sys
import warnings
from typing import Any

import matplotlib.pyplot as plt
import pandas as pd
import yaml
from dotenv import load_dotenv

# Load variables from a local .env file (if any) into the process environment,
# then resolve the repository root used for imports and config lookup.
load_dotenv()
REPO_PATH = os.getenv("REPO_PATH")
if not REPO_PATH:
    # Fail here with a clear message: otherwise the paths below would be built
    # from the string 'None' and surface later as a confusing
    # ModuleNotFoundError / FileNotFoundError.
    raise EnvironmentError(
        "REPO_PATH is not set; define it in the environment or in a .env file."
    )

# os.path.join yields the same path whether or not REPO_PATH ends with a separator.
sys.path.insert(0, os.path.join(REPO_PATH, 'src'))

# Project-local helpers; importable only once src/ is on sys.path (line above).
from utils.model_utils import save_model_info, train_RNN
from utils.forecast_utils import ForecastPredictions

# Feature-group configuration; later cells index it by group name.
# NOTE(review): FullLoader can construct some Python-specific YAML tags. If the
# config is plain mappings/lists, yaml.safe_load would be the safer choice — confirm.
with open(os.path.join(REPO_PATH, 'variable_config.yaml'), 'r') as file:
    var_config = yaml.load(file, Loader=yaml.FullLoader)

# Suppress every warning for the rest of the session and render floats in
# pandas output with four decimal places.
warnings.filterwarnings("ignore")
pd.options.display.float_format = '{:.4f}'.format

### Build and fit model


In [None]:
# Feature list: concatenation of the BASE, TEMPORAL and S1 groups from the
# variable config loaded in the setup cell.
SELECTED_FEATURES = [
    *var_config['BASE'],
    *var_config['TEMPORAL'],
    *var_config['S1']
]

# Run-level settings. They are passed to train_RNN and/or embedded in the
# saved model's name below.
RNN_TYPE: str = 'LSTM'
FUTURE: str = 'CLc1'
MAX_EPOCHS: int = 150
IDENTIFIER: str = 'S1'  # tag for this feature set; also labels the metrics column

# Hyper-parameters forwarded unchanged to train_RNN.
# NOTE(review): the high-precision values look like the output of a
# hyper-parameter search — confirm provenance.
MODEL_PARAMS: dict[str, Any] = {
    "units_first_layer": 128,
    "units_second_layer": 96,
    "l2_strength": 0.0001251522417143188,
    "learning_rate": 0.0010428078227294694,
    "batch_size": 64,
    "noise_std": 0.04627878070884051,
    "window_size": 20
}

# Data-preparation settings forwarded unchanged to train_RNN.
DATA_PARAMS: dict[str, Any] = {
    'feature_columns': SELECTED_FEATURES,
    'target_column': 'REALIZED_VOL',
    'test_size': 0.2,
    'val_size': 0.2,
    'scaler_type': 'RobustScaler'
}

# NOTE(review): no random seed is set in this notebook; if train_RNN does not
# seed internally, repeated runs will not be reproducible — confirm.
model, gen, loss_dict = train_RNN(
    FUTURE,
    DATA_PARAMS,
    MODEL_PARAMS,
    RNN_TYPE,
    MAX_EPOCHS
)

# Timestamped name so successive runs do not share an artefact name
# (resolution is one minute).
current_dt: str = pd.Timestamp.now().strftime('%Y.%m.%d_%H.%M')
model_name: str = f'{FUTURE}_{RNN_TYPE}_{IDENTIFIER}_{current_dt}'

# Persist the trained model together with the parameters and losses that
# produced it; the evaluation cell reloads it by model_name.
save_model_info(
    model,
    model_name,
    MODEL_PARAMS,
    DATA_PARAMS,
    loss_dict
)


In [None]:
# Evaluate the model saved in the previous cell.
VIEW = 500  # number of trailing test samples shown in the plot

forecast = ForecastPredictions(model_name)

# One column (named after IDENTIFIER) with one row per metric; metric names
# are upper-cased for display.
# NOTE(review): assumes forecast.metrics() returns a flat mapping of
# metric name -> scalar — confirm.
metrics = pd.DataFrame(
    forecast.metrics(),
    index=[IDENTIFIER]
).T
metrics.index = metrics.index.str.upper()

# Overlay the last VIEW actual and predicted values.
fig, ax = plt.subplots(figsize=(10, 5), dpi=200)
ax.plot(forecast.y_test[-VIEW:], label='Actual', lw=0.7)
ax.plot(forecast.y_pred[-VIEW:], label='Predicted', lw=0.7)
ax.set_title('Model Fit vs Actual')
ax.set_xlabel('Samples')
# Label the axis with the configured target instead of the previous
# hard-coded 'Price', which did not match the modelled quantity.
# NOTE(review): assumes y_test / y_pred hold that target column — confirm.
ax.set_ylabel(DATA_PARAMS['target_column'])
ax.legend(frameon=False)
ax.grid(alpha=0.3)

display(metrics)