In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import json
import seaborn as sns
from tqdm.notebook import tqdm

import os
import sys
import warnings
from dotenv import load_dotenv

# Load settings such as REPO_PATH from a local .env file into the environment.
load_dotenv()
# NOTE(review): this silences *all* warnings for the whole session, which can
# hide pandas/matplotlib deprecations; consider narrowing it to specific categories.
warnings.filterwarnings("ignore")

REPO_PATH = os.getenv("REPO_PATH")
if not REPO_PATH:
    # Fail fast: otherwise the path below would become 'Nonesrc' and the
    # project imports would fail later with an unhelpful ModuleNotFoundError.
    raise RuntimeError(
        "REPO_PATH is not set; define it in your .env file or environment."
    )
# Make the project's `src` directory importable (utils.forecast_utils, utils.var_utils).
# os.path.join works whether or not REPO_PATH ends with a path separator.
sys.path.insert(0, os.path.join(REPO_PATH, 'src'))

from utils.forecast_utils import ForecastPredictions, forecast_rnn, load_models
from utils.var_utils import forecast_var

### RNN model results vs. VAR benchmark

In [None]:
# VAR run that serves as the reference forecast for the metrics table.
BENCHMARK = 'CLc1_VAR_BASE_2024.05.30_16.00'

# GRU runs to compare against the benchmark (archived under model_archive/).
MODEL_NAMES = [
    'CLc1_GRU_BASE_2024.05.29_12.03',
    'CLc1_GRU_BASE(T)_2024.05.30_18.38',
]

benchmark_fc = ForecastPredictions(BENCHMARK, forecast_var)
# load_models is a project helper. As used below, model_dict is keyed by model
# name and metric_df is a displayable metrics table (presumably a DataFrame).
model_dict, metric_df = load_models(MODEL_NAMES, benchmark_fc)

# Render floats with four decimals for the rest of the session.
pd.set_option('display.float_format', '{:.4f}'.format)
display(metric_df)

In [None]:
# Plot each model's predictions against the realized series over the trailing
# `view` observations of the test set.
view = 400  # number of trailing test observations to display

fig, ax = plt.subplots(figsize=(10, 5), dpi=200)

PLOT_MODEL = MODEL_NAMES

# Size the palette by the models actually plotted. Colors are indexed by
# position in PLOT_MODEL, so sizing by len(model_dict) could raise IndexError
# or assign inconsistent colors whenever the two lengths differ.
colors = sns.color_palette('twilight', n_colors=len(PLOT_MODEL))

for i, model_name in enumerate(PLOT_MODEL):
    model = model_dict[model_name]
    if i == 0:
        # Draw the realized series once, taken from the first model.
        # NOTE(review): assumes all models share the same y_test; confirm.
        actual = model.y_test[-view:]
        ax.plot(
            actual,
            label=f'Actual {model.model_name.split("_")[0]} RV',
            color='gray',
            lw=0.8,
        )
    ax.plot(
        model.y_pred[-view:],
        # e.g. 'CLc1_GRU_BASE_...' -> 'GRU BASE'
        label=' '.join(model.model_name.split('_')[1:3]),
        color=colors[i],
        lw=0.8,
    )

ax.set_title('Realized 5-min volatility, Model Fit vs Actual')
ax.set_xlabel('Time (5-min intervals)')
ax.set_ylabel('Realized Volatility')
ax.legend(frameon=False, ncols=3)
ax.grid(alpha=0.3)

### Training vs. validation loss comparison

In [None]:
# Gather the per-epoch loss history of every non-VAR model into one wide frame
# and plot training (dashed) against validation (solid) loss.
colors = sns.color_palette('twilight', n_colors=len(MODEL_NAMES))

loss_df_list = []
loss_colors = []
for i, model_name in enumerate(MODEL_NAMES):
    # VAR runs are skipped; presumably they have no epoch-wise loss_data.json.
    if model_name.split('_')[1] == 'VAR':
        continue
    with open(f'model_archive/{model_name}/loss_data.json', 'r', encoding='utf-8') as f:
        loss_dict = json.load(f)
    # Suffix columns with the run name so runs stay distinguishable after concat.
    loss_df_list.append(pd.DataFrame(loss_dict).add_suffix(f'_{model_name}'))
    # Tie each run's color to its position in MODEL_NAMES, so a skipped VAR
    # entry does not shift the colors of the runs after it.
    loss_colors.append(colors[i])

if not loss_df_list:
    raise ValueError(
        'No loss histories to plot: every entry in MODEL_NAMES is a VAR model.'
    )

loss_df = pd.concat(loss_df_list, axis=1)

fig, ax = plt.subplots(figsize=(10, 5), dpi=200)
# NOTE(review): color pairing assumes one 'train' and one 'val' column per run.
loss_df.filter(like='train').plot(ax=ax, lw=0.8, ls='--', color=loss_colors)
loss_df.filter(like='val').plot(ax=ax, lw=0.8, color=loss_colors)
ax.set_title('Model Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss [MSE]')
ax.set_yscale('log')
ax.legend(frameon=False)
ax.grid(alpha=0.3)