In [None]:
import pandas as pd
import yaml
import json
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm
import os
import sys
import warnings
from dotenv import load_dotenv

# NOTE(review): this silences *every* warning (deprecations, pandas/numpy
# warnings, ...) for the whole kernel; consider narrowing by category/module.
warnings.filterwarnings("ignore")

# Read REPO_PATH from the environment / .env file and make the project's
# `src_HF` package directory importable (needed for `utils.forecast_utils`).
load_dotenv()
REPO_PATH = os.getenv("REPO_PATH")
if not REPO_PATH:
    raise RuntimeError(
        "REPO_PATH is not set; define it in your environment or .env file"
    )
# os.path.join works whether or not REPO_PATH ends with a path separator,
# unlike plain string concatenation.
sys.path.insert(0, os.path.join(REPO_PATH, 'src_HF'))

from utils.forecast_utils import ForecastModel

### Load models

In [None]:
# Saved model runs to compare. Each name is passed to ForecastModel, which
# presumably loads that run's artifacts by name — TODO confirm.
MODEL_NAMES = [
    'CLc1_BiLSTM_sent_2024.05.23_19.11',
    'CLc1_BiLSTM_loss_2024.05.23_19.21',
]
# Duplicates would silently overwrite each other in model_dict below.
assert len(set(MODEL_NAMES)) == len(MODEL_NAMES), "duplicate entries in MODEL_NAMES"

# Map each run name to its ForecastModel; insertion order follows MODEL_NAMES.
model_dict = {}
for model_name in tqdm(MODEL_NAMES, desc='Loading models'):
    model_dict[model_name] = ForecastModel(model_name)


In [None]:
# Show each loaded model's summary (via ForecastModel.describe), one blank
# line between consecutive models.
for loaded_model in model_dict.values():
    loaded_model.describe()
    print()

In [None]:
# Evaluate the models: overlay each model's test-set predictions on the actual
# series. `view` limits the plot to the last N test samples.
view = 300

fig, ax = plt.subplots(figsize=(10, 5), dpi=200)

for i, model in enumerate(model_dict.values()):
    if i == 0:
        # The actual series is drawn once, from the first model only; this
        # assumes every model was evaluated on the same test targets — TODO confirm.
        actual = model.test_targets[-view:]
        ax.plot(actual, label='Actual ' + model.model_name.split('_')[0], color='gray', lw=0.8)

    # Legend label: the name parts after the ticker prefix,
    # e.g. 'CLc1_BiLSTM_sent_2024.05.23_19.11' -> 'BiLSTM sent 2024.05.23 19.11'.
    ax.plot(model.test_predictions[-view:], label=' '.join(model.model_name.split('_')[1:5]), lw=0.8)

ax.set_title('Realized 5-min volatility, Model Fit vs Actual')
ax.set_xlabel('Samples')
# The y-axis label agrees with the title: the plotted series is realized volatility.
ax.set_ylabel('Realized volatility')
ax.legend(frameon=False)
ax.grid(alpha=0.3)
plt.show()