In [None]:
import os
import glob
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('whitegrid')
sns.set_palette('icefire')

import warnings
warnings.filterwarnings('ignore')

import pickle

In [None]:
# Name of the index whose trained models are analysed; interpolated into `root`.
index = 'PSEi'
# Base directory of the model artefacts. The loading cells glob three directory
# levels beneath it and unpack them as <window>/<model_config>/<commodity>.
root = f'../models/{index}/log_sq_rtn'

### Summaries (Project metrics)

In [None]:
# Build one summary record per model directory found three levels below `root`
# (<window>/<model_config>/<commodity>), combining the run's identifying fields
# with the metrics stored in its 'project_metrics-1.pkl'.
results = []
for path in sorted(glob.glob(f'{root}/*/*/*')):  # sorted: glob order is arbitrary
    try:
        # NOTE(security): pickle.load can execute arbitrary code; only read
        # pickles written by this project's own training runs.
        with open(os.path.join(path, 'project_metrics-1.pkl'), 'rb') as f:
            data = pickle.load(f)

        # Take the last five path components rather than splitting on '/' at a
        # fixed offset, so parsing works with any OS separator. Distinct local
        # names keep the notebook-level `index` and `model_config` untouched.
        index_name, _target, window_dir, config_dir, commodity_dir = \
            os.path.normpath(path).split(os.sep)[-5:]

        # The directory name starts with the base model; whatever follows is
        # the hybrid configuration ('' when nothing follows).
        if config_dir.startswith('Peephole_LSTM'):
            base = 'Peephole_LSTM'
        elif 'GRU' in config_dir:
            base = 'GRU'
        else:
            base = 'LSTM'
        config_name = config_dir.replace(f'{base}_', '').replace(f'{base}', '')

        r = dict(index = index_name, window = window_dir, base_model = base,
                 model_config = config_name, commodity = commodity_dir)
        r.update(data)
        results.append(r)
    except Exception as err:
        # Best effort: report which run was skipped and why, then keep going.
        # `Exception` (not a bare except) lets KeyboardInterrupt propagate.
        print(f'{path}: {err!r}')

In [None]:
# Tabulate the collected records and cast the `window` column to int in the
# same expression.
df = pd.DataFrame(results).assign(window = lambda frame: frame['window'].map(int))
df.head()

In [None]:
# Persist the summary table directly under `root`.
summary_path = f'{root}/summary.csv'
df.to_csv(summary_path)

##### Effect of training with commodity prices

In [None]:
# MSE distribution for each hybrid configuration, with boxes split by base model.
order = [
    '', 'garch', 'egarch', 'gjr_garch',
    'garch_egarch', 'garch_gjr_garch', 'egarch_gjr_garch',
    'garch_egarch_gjr_garch',
]
fig, ax = plt.subplots(figsize = (16, 4))
sns.boxplot(data = df, x = 'model_config', y = 'mse', hue = 'base_model', order = order, ax = ax)
plt.show()

In [None]:
# MSE distribution for each window value, with boxes split by base model.
fig, ax = plt.subplots(figsize = (6, 4))
sns.boxplot(data = df, x = 'window', y = 'mse', hue = 'base_model', ax = ax)
plt.show()

##### Performance comparison

In [None]:
## Read predictions
# One 'valid_predictions-1.pkl' is expected in every directory three levels
# below `root` (<window>/<model_config>/<commodity>).
dfs = []
for path in sorted(glob.glob(f'{root}/*/*/*')):  # sorted: glob order is arbitrary
    # NOTE(security): pickle.load can execute arbitrary code; only read
    # pickles written by this project's own training runs.
    with open(os.path.join(path, 'valid_predictions-1.pkl'), 'rb') as f:
        preds = pickle.load(f)

    # presumably preds is [dates, ground truth, predictions] — TODO confirm.
    # The last entry is flattened before the three entries become columns.
    preds[-1] = preds[-1].flatten()
    run_preds = pd.DataFrame(preds).T
    run_preds.columns = ['date', 'gt', 'prediction']

    # Take the last five path components rather than splitting on '/' at a
    # fixed offset, so parsing works with any OS separator. Distinct local
    # names keep the notebook-level `index` and `model_config` untouched.
    index_name, _target, window_dir, config_dir, commodity_dir = \
        os.path.normpath(path).split(os.sep)[-5:]

    # The directory name starts with the base model; whatever follows is the
    # hybrid configuration ('' when nothing follows).
    if config_dir.startswith('Peephole_LSTM'):
        base = 'Peephole_LSTM'
    elif 'GRU' in config_dir:
        base = 'GRU'
    else:
        base = 'LSTM'
    config_name = config_dir.replace(f'{base}_', '').replace(f'{base}', '')

    # Explicit per-column broadcast of the run's identifying fields.
    run_preds = run_preds.assign(**{
        'index': index_name,
        'window': int(window_dir),
        'base_model': base,
        'model_config': config_name,
        'commodity': commodity_dir,
    })
    run_preds['date'] = pd.to_datetime(run_preds['date'])
    dfs.append(run_preds)

if not dfs:
    # pd.concat([]) would fail with a message that does not mention the path.
    raise ValueError(f'No prediction directories found under {root!r}')

df_preds = pd.concat(dfs)
df_preds.shape

In [None]:
# Hybrid configurations, in the order their panels are drawn. '' is what remains
# of a directory name once the base-model prefix has been stripped.
model_config = [
    '',
    'garch',
    'egarch',
    'gjr_garch',
    'garch_egarch',
    'garch_gjr_garch',
    'egarch_gjr_garch',
    'garch_egarch_gjr_garch',
]

In [None]:
def make_plot(base_model, window, save = False, show_plot = True):
    """Draw ground truth against predictions for one base model and window.

    A 4x2 grid is created with one panel per entry of the notebook-level
    `model_config` list; surplus panels are removed. Each panel shows the
    'gt' series and the 'prediction' series of the rows tagged
    'with_commodity_prices', plus the 'prediction' series of the rows tagged
    'without_commodity_prices', all read from the notebook-level `df_preds`.

    Parameters
    ----------
    base_model : str
        Value matched against `df_preds['base_model']`.
    window : int
        Value matched against `df_preds['window']`.
    save : bool, default False
        When True, write the figure to `{root}/plots/{base_model}_{window}.png`,
        creating the `plots` directory if needed.
    show_plot : bool, default True
        When True, display the figure; otherwise close it.
    """
    print(f'Base model: {base_model}, window: {window}')
    fig, axes = plt.subplots(4, 2, figsize = (12, 12), sharex = False)

    # The base-model/window filter is the same for every panel: apply it once.
    selected = df_preds[(df_preds['base_model'] == base_model) &
                        (df_preds['window'] == window)]

    for idx, ax in enumerate(axes.flatten()):
        if idx >= len(model_config):
            # More grid cells than configurations: drop the unused axes.
            ax.remove()
            continue

        config = model_config[idx]
        sub = selected[selected['model_config'] == config]
        with_prices = sub[sub['commodity'] == 'with_commodity_prices']
        without_prices = sub[sub['commodity'] == 'without_commodity_prices']

        # Ground truth is taken from the with-commodity rows only.
        # NOTE(review): `ci` is deprecated in newer seaborn releases in favour
        # of `errorbar`; kept so the call still works on older installs.
        sns.lineplot(data = with_prices, ci = None,
                     x = 'date', y = 'gt', color = 'blue', label = 'Ground truth', ax = ax)
        sns.lineplot(data = with_prices, ci = None,
                     x = 'date', y = 'prediction', label = 'with_commodity', color = 'red', ax = ax)
        sns.lineplot(data = without_prices, ci = None,
                     x = 'date', y = 'prediction', label = 'without_commodity', color = 'orange', ax = ax)
        ax.set_title(config)
        ax.set_ylabel('Realized volatility')
        ax.legend()
        ax.set_xlabel(None)

    fig.tight_layout()
    if save:
        # Only touch the filesystem when a file is actually going to be written.
        os.makedirs(f'{root}/plots', exist_ok = True)
        fig.savefig(f'{root}/plots/{base_model}_{window}.png', bbox_inches = 'tight')

    if show_plot:
        plt.show()
    else:
        # Close this specific figure so batch runs do not accumulate open figures.
        plt.close(fig)

In [None]:
# Save the comparison figure for every base model / window pair without
# displaying it inline.
for base_model, window in (
    (model_name, window_size)
    for model_name in ['LSTM', 'GRU', 'Peephole_LSTM']
    for window_size in [7, 14, 21, 28]
):
    make_plot(base_model, window, save = True, show_plot = False)