In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import json
import seaborn as sns
from tqdm.notebook import tqdm

import os
import sys
import warnings
from dotenv import load_dotenv

# Load settings (notably REPO_PATH) from a local .env file into the environment.
load_dotenv()
# NOTE(review): this silences *all* warnings for the whole kernel session,
# including deprecation and numerical warnings raised by the model code.
warnings.filterwarnings("ignore")

# Root of the repository checkout; its `src` directory holds the `utils` package.
REPO_PATH = os.getenv("REPO_PATH")
if not REPO_PATH:
    # Without this guard the literal string "Nonesrc" would be put on sys.path
    # and the failure would only surface later as an opaque import error.
    raise RuntimeError(
        "REPO_PATH is not set; define it in the environment or in a .env file "
        "so that '<REPO_PATH>/src' can be added to sys.path."
    )
# os.path.join works whether or not REPO_PATH ends with a path separator.
sys.path.insert(0, os.path.join(REPO_PATH, 'src'))

from utils.forecast_utils import ForecastModel

### Load RNN models and calculate error metrics

In [None]:
# Archived model runs to compare. Each name is passed verbatim to ForecastModel
# and is also used as the row label of the metrics table below.
MODEL_NAMES = [
    'CLc1_BiGRU_test_1_2024.05.26_12.59',
    'CLc1_BiLSTM_test_1_2024.05.26_12.46'
]

metric_list = []
model_dict = dict()  # model name -> loaded ForecastModel instance
for model_name in tqdm(MODEL_NAMES, desc='Loading models'):
    model = ForecastModel(model_name)
    model_dict[model_name] = model
    # presumably describe() returns one record of error metrics per model —
    # TODO confirm against utils.forecast_utils
    metric_list.append(model.describe())

# One row per model, in the same order as MODEL_NAMES.
metric_df = pd.DataFrame(metric_list, index=MODEL_NAMES)

display(metric_df)

In [None]:
# Evaluate the models: overlay each model's test predictions on the actual series.
view = 500  # number of trailing test samples shown in the figure

fig, ax = plt.subplots(figsize=(10, 5), dpi=200)

for i, model in enumerate(model_dict.values()):
    if i == 0:
        # Draw the actual series only once, taken from the first model.
        # NOTE(review): assumes every model shares the same test targets — confirm.
        actual = model.test_targets[-view:]
        ax.plot(
            actual,
            # First underscore-separated token of the model name is used as the series name.
            label=f'Actual {model.model_name.split("_")[0]} RV',
            color='gray',
            lw=0.8
        )

    # Legend entry is built from tokens 1..4 of the underscore-separated model name.
    ax.plot(model.test_predictions[-view:], label=' '.join(model.model_name.split('_')[1:5]), lw=0.8)

ax.set_title('Realized 5-min volatility, Model Fit vs Actual')
ax.set_xlabel('Samples')
# The plotted series are labelled as realized volatility (RV), not prices.
ax.set_ylabel('Realized volatility')
ax.legend(frameon=False)
ax.grid(alpha=0.3)

### Loss comparison

In [None]:


# Collect the recorded loss history of every model into one wide frame whose
# columns are '<loss key>_<model name>'.
loss_df_list = []
for model_name in MODEL_NAMES:
    # NOTE(review): path is relative to the notebook's working directory.
    with open(f'model_archive/{model_name}/loss_data.json', 'r', encoding='utf-8') as f:
        loss_dict = json.load(f)

    # Suffix the columns with the model name so they stay unique after concat.
    model_loss_df = pd.DataFrame(loss_dict).add_suffix(f'_{model_name}')
    loss_df_list.append(model_loss_df)

loss_df = pd.concat(loss_df_list, axis=1)

# One colour per model, shared by its training and validation curves.
colors = sns.color_palette('twilight', n_colors=len(MODEL_NAMES))

fig, ax = plt.subplots(figsize=(10, 5), dpi=200)
# Columns whose name contains 'train' are drawn dashed, those containing 'val' solid.
# NOTE(review): filter(like=...) matches anywhere in the column name, including
# the model-name suffix — a model name containing 'train' or 'val' would match too.
loss_df.filter(like='train').plot(
    title='Model Loss',
    lw=0.8,
    ax=ax,
    ls='--',
    color=colors
)
loss_df.filter(like='val').plot(
    lw=0.8,
    ax=ax,
    color=colors,
)
ax.set_yscale('log')
ax.set_ylabel('Loss')
ax.legend(frameon=False)
ax.grid(alpha=0.3)