In [1]:
%%capture
# Remount Google Drive so the project folder under /content/drive is available.
# Unmounting first (flushing pending writes) lets this cell be re-run without a
# stale mount; presumably a no-op when Drive is not mounted yet — confirm.
from google.colab import drive
drive.flush_and_unmount()
drive.mount('/content/drive')

In [2]:
%%capture
# Standard-library / IPython names used by the training cells below.
import os
from IPython.utils import io         # io.capture_output() silences per-run training output
from datetime import datetime

# Work from the project's src/ folder so the local helper_* modules are (presumably)
# importable from the cwd and relative paths resolve from src/.
%cd /content/drive/MyDrive/Developer/trading/Volatility Predictor/src

In [3]:
%%capture
!pip install arch

In [4]:
### load local libraries
from helper_config import Config
from helper_trainer import train
from helper_utils import prepare_configuration

### GARCH config
# GARCH feature sets appended to every base model: all subsets of
# {garch, egarch, gjr_garch}, ordered by size (empty list = base model alone).
garch_config = [
    [],                                         # base model only, no GARCH features
    ['garch'], ['egarch'], ['gjr_garch'],       # one GARCH model
    ['garch', 'egarch'],                        # two GARCH models
    ['garch', 'gjr_garch'],
    ['egarch', 'gjr_garch'],
    ['garch', 'egarch', 'gjr_garch'],           # all three GARCH models
]

# Project root on Google Drive; inputs/ and models/ are resolved relative to it.
root = '/content/drive/MyDrive/Developer/trading/Volatility Predictor'
def train_for_index(index, windows=(7, 14, 21, 28), models=('LSTM', 'GRU', 'Peephole_LSTM'),
                    garch_configs=None, commodity_options=(True,)):
    """Train every (window, model, GARCH combination, commodity flag) variant for one index.

    For each combination a spec dict is built, turned into a configuration by
    ``prepare_configuration`` and passed to ``train``. Output printed while
    preparing/training is captured and discarded; only a timestamped model name
    is printed per run. The default arguments reproduce the original sweep.

    Args:
        index: Name of the index to train on (e.g. ``'BIST100'``); it is the
            first component of every ``model_name``.
        windows: Rolling-window sizes to iterate over.
        models: Base recurrent model names.
        garch_configs: Sequences of GARCH model names appended to each base
            model; ``None`` means use the module-level ``garch_config``.
        commodity_options: Values tried for ``use_commodity_prices``. The
            default ``(True,)`` trains only the with-commodity variant; pass
            ``(True, False)`` to also train the without-commodity variant.

    Returns:
        None. ``data_dir``/``models_dir`` are passed through to
        ``prepare_configuration``; where artifacts land is up to the helpers.
    """
    from itertools import product

    if garch_configs is None:
        garch_configs = garch_config

    # product() iterates with the last axis fastest, i.e. the same order as the
    # window -> model -> garch -> commodity nested loops.
    for window, model, garch_cfg, use_commodity_prices in product(
            windows, models, garch_configs, commodity_options):
        model_type = [model, *garch_cfg]
        comm_label = 'with_commodity_prices' if use_commodity_prices else 'without_commodity_prices'

        specs = dict(
            index       = index,
            window      = window,
            model_type  = model_type,
            use_commodity_prices = use_commodity_prices,
            model_name  = f'{index}/log_sq_rtn/{window}/{"_".join(model_type)}/{comm_label}',
            data_dir    = f'{root}/inputs',
            models_dir  = f'{root}/models'
        )

        print(f'[{datetime.now()}]', specs['model_name'])

        # Hide the verbose configuration/training output of each run.
        with io.capture_output():
            config = prepare_configuration(specs = specs)
            train(config)

In [None]:
# Train one full model sweep per market index.
# Every CSV in the data folder is treated as an index, except the commodity
# feature files listed here.
commodities = ['Gold_Features.csv', 'Crude_Oil_WTI Futures.csv']
ignore_list = []                # indices to skip when training on all discovered indices
selected_indices = ['BIST100']  # restrict the run to these indices; set to None to train on all

# Resolve from `root` rather than '../inputs/data' so the cell does not depend
# on the cwd set by the earlier %cd.
index_data_dir = f'{root}/inputs/data'
available_indices = sorted(     # sorted: os.listdir order is arbitrary
    os.path.splitext(f)[0]
    for f in os.listdir(index_data_dir)
    if f.endswith('.csv') and f not in commodities
)

if selected_indices is None:
    indices = [index for index in available_indices if index not in ignore_list]
else:
    indices = list(selected_indices)

for index in indices:
    print('-'*100)
    train_for_index(index = index)

----------------------------------------------------------------------------------------------------
[2023-09-24 13:05:30.874619] BIST100/log_sq_rtn/7/LSTM/with_commodity_prices
[2023-09-24 13:05:55.697481] BIST100/log_sq_rtn/7/LSTM_garch/with_commodity_prices
[2023-09-24 13:06:14.325160] BIST100/log_sq_rtn/7/LSTM_egarch/with_commodity_prices
[2023-09-24 13:06:33.457038] BIST100/log_sq_rtn/7/LSTM_gjr_garch/with_commodity_prices
[2023-09-24 13:06:51.966294] BIST100/log_sq_rtn/7/LSTM_garch_egarch/with_commodity_prices
[2023-09-24 13:07:11.788017] BIST100/log_sq_rtn/7/LSTM_garch_gjr_garch/with_commodity_prices
[2023-09-24 13:07:30.214562] BIST100/log_sq_rtn/7/LSTM_egarch_gjr_garch/with_commodity_prices
[2023-09-24 13:07:48.528331] BIST100/log_sq_rtn/7/LSTM_garch_egarch_gjr_garch/with_commodity_prices
[2023-09-24 13:08:24.580802] BIST100/log_sq_rtn/7/GRU/with_commodity_prices
[2023-09-24 13:08:50.780373] BIST100/log_sq_rtn/7/GRU_garch/with_commodity_prices
[2023-09-24 13:09:17.312576] BIST