In [1]:
%%capture
import os
from IPython.utils import io
from datetime import datetime

%cd ../src

In [2]:
### load local libraries
from helper_config import Config
from helper_trainer import train
from helper_utils import prepare_configuration

### GARCH config
# Every subset of the three GARCH variants, ordered by subset size and then by
# position. train_for_index appends each entry after the base model name to
# form `model_type`; the empty entry trains the base model on its own.
garch_config = (
    [[]]                                                                       # single model (without garch)
    + [['garch'], ['egarch'], ['gjr_garch']]                                   # single garch models
    + [['garch', 'egarch'], ['garch', 'gjr_garch'], ['egarch', 'gjr_garch']]   # pairs of garch models
    + [['garch', 'egarch', 'gjr_garch']]                                       # all three garch models
)

def train_for_index(index, windows=(21, 28), models=('LSTM', 'Peephole_LSTM', 'GRU'),
                    garch_configs=None, commodity_options=(True, False),
                    data_dir='../inputs', models_dir='../models'):
    """Train every configuration of the experiment grid for a single index.

    The grid is every combination of rolling-window size, base model, GARCH
    combination and commodity-price flag, visited in that nesting order. For each
    combination a ``specs`` dict is built, turned into a config by
    ``prepare_configuration`` and passed to ``train``. Output printed by those
    helpers is captured, so the notebook shows one timestamped line per run.

    Args:
        index: Index name (e.g. ``'ATX'``). It is stored as ``specs['index']`` and
            used as the first component of the model name.
        windows: Rolling-window sizes to iterate over.
        models: Base model names. Each becomes the first element of ``model_type``.
        garch_configs: Sequences of GARCH model names appended after the base
            model. ``None`` (the default) uses the module-level ``garch_config``.
        commodity_options: Values tried for ``use_commodity_prices``.
        data_dir: Stored as ``specs['data_dir']``.
        models_dir: Stored as ``specs['models_dir']``.

    Returns:
        None. The value returned by ``train`` is discarded; presumably ``train``
        persists its artifacts under ``models_dir`` (TODO confirm).

    Raises:
        Any exception from ``prepare_configuration`` or ``train``. It is re-raised
        after the output captured for that run has been replayed.
    """
    if garch_configs is None:
        garch_configs = garch_config

    for window in windows:                      # iterate over rolling window size
        for model in models:
            for garch_cfg in garch_configs:
                for use_commodity_prices in commodity_options:
                    # Build a fresh list per run in case the helpers mutate specs['model_type'].
                    model_type = [model, *garch_cfg]
                    comm_label = 'with_commodity_prices' if use_commodity_prices else 'without_commodity_prices'

                    specs = dict(
                        index       = index,
                        window      = window,
                        model_type  = model_type,
                        use_commodity_prices = use_commodity_prices,
                        model_name  = f'{index}/{window}/{"_".join(model_type)}/{comm_label}',
                        data_dir    = data_dir,
                        models_dir  = models_dir
                    )

                    print(f'[{datetime.now()}]', specs['model_name'])

                    # `captured` is pre-bound so the except clause stays safe even
                    # if setting up the capture itself fails.
                    captured = None
                    try:
                        with io.capture_output() as captured:
                            config = prepare_configuration(specs = specs)
                            train(config)
                    except Exception:
                        # stdout/stderr are restored once the `with` exits. Replay
                        # what the helpers printed so a failed run can be diagnosed
                        # instead of losing its log.
                        if captured is not None:
                            captured.show()
                        raise

In [3]:
# Train every configuration for each selected index.
# Commodity feature files live in ../inputs/data next to the index CSVs but are
# not indices, so they are excluded when indices are discovered automatically.
commodities = ['Gold_Features.csv', 'Crude_Oil_WTI Futures.csv']

# Set to True to train on every index CSV found in ../inputs/data instead of
# the hand-picked list below. Sorting keeps the run order reproducible.
TRAIN_ALL_INDICES = False

if TRAIN_ALL_INDICES:
    indices = sorted(
        os.path.splitext(f)[0]
        for f in os.listdir('../inputs/data')
        if f.endswith('.csv') and f not in commodities
    )
else:
    indices = ['ATX']

for index in indices:
    print('-'*100)
    train_for_index(index = index)

----------------------------------------------------------------------------------------------------
[2023-09-16 12:23:34.554399] ATX/21/LSTM/with_commodity_prices
[2023-09-16 12:24:29.901348] ATX/21/LSTM/without_commodity_prices
[2023-09-16 12:24:56.980606] ATX/21/LSTM_garch/with_commodity_prices
[2023-09-16 12:26:06.012718] ATX/21/LSTM_garch/without_commodity_prices
[2023-09-16 12:26:48.209977] ATX/21/LSTM_egarch/with_commodity_prices
[2023-09-16 12:27:56.876992] ATX/21/LSTM_egarch/without_commodity_prices
[2023-09-16 12:28:37.075243] ATX/21/LSTM_gjr_garch/with_commodity_prices
[2023-09-16 12:29:46.015132] ATX/21/LSTM_gjr_garch/without_commodity_prices
[2023-09-16 12:30:28.047177] ATX/21/LSTM_garch_egarch/with_commodity_prices
[2023-09-16 12:31:43.633460] ATX/21/LSTM_garch_egarch/without_commodity_prices
[2023-09-16 12:32:33.109478] ATX/21/LSTM_garch_gjr_garch/with_commodity_prices
[2023-09-16 12:33:49.318949] ATX/21/LSTM_garch_gjr_garch/without_commodity_prices
[2023-09-16 12:34:39.