In [1]:
%%capture
import os
from IPython.utils import io
from datetime import datetime

# Switch the working directory to ../src; presumably this is what makes the
# local helper_* modules importable in the next cell — TODO confirm.
# NOTE(review): the path is relative, so re-running this cell without a kernel
# restart applies `../src` to the already-changed directory.
%cd ../src

In [2]:
### load local libraries
from helper_config import Config
from helper_trainer import train
from helper_utils import prepare_configuration

### GARCH config
# Every entry is the list of GARCH-family components that gets appended to the
# base recurrent model; the empty list means "no GARCH component".
_GARCH_VARIANTS = ('garch', 'egarch', 'gjr_garch')

garch_config = (
    # single model (without garch)
    [[]]
    # single garch models
    + [[variant] for variant in _GARCH_VARIANTS]
    # combine two garch models (each unordered pair once, in declaration order)
    + [
        [first, second]
        for pos, first in enumerate(_GARCH_VARIANTS)
        for second in _GARCH_VARIANTS[pos + 1:]
    ]
    # combine three garch models
    + [list(_GARCH_VARIANTS)]
)

def train_for_index(index):
    """Train every model variant of the experiment grid for a single index.

    The grid is the combination of rolling-window size, recurrent cell type,
    each entry of the module-level ``garch_config`` and the commodity-price
    flag. For each combination a ``specs`` dict is assembled, its model name
    is printed with a timestamp, and the model is configured and trained
    inside ``io.capture_output()`` so only the progress line is shown.

    Parameters
    ----------
    index
        Name of the index to train on (e.g. ``'BIST100'``); stored in
        ``specs['index']`` and used as the first component of the model name.

    Returns
    -------
    None
        The value returned by ``train`` is discarded.
    """
    window_sizes = [7, 14, 21, 28]      # rolling window sizes
    cell_types = ['LSTM', 'GRU']        # 'Peephole_LSTM' is currently left out
    commodity_flags = [False, True]

    for window in window_sizes:
        for cell_type in cell_types:
            for garch_parts in garch_config:
                for use_commodity_prices in commodity_flags:
                    # Built per combination so each specs dict owns its own list.
                    model_type = [cell_type] + garch_parts

                    if use_commodity_prices:
                        comm_label = 'with_commodity_prices'
                    else:
                        comm_label = 'without_commodity_prices'

                    specs = {
                        'index': index,
                        'window': window,
                        'model_type': model_type,
                        'use_commodity_prices': use_commodity_prices,
                        'model_name': f'{index}/log_sq_rtn/{window}/{"_".join(model_type)}/{comm_label}',
                        'data_dir': '../inputs',
                        'models_dir': '../models',
                    }

                    # Progress line is printed before capturing starts, so it stays visible.
                    print(f'[{datetime.now()}]', specs['model_name'])

                    with io.capture_output() as _suppressed:
                        # Build the run configuration, then fit the model.
                        config = prepare_configuration(specs = specs)
                        train(config)

In [3]:
# Train
# File names of the commodity price series; only referenced by the commented-out
# discovery line below, where they are excluded from the list of indices.
commodities = ['Gold_Features.csv', 'Crude_Oil_WTI Futures.csv']
# Alternative: derive the indices from every non-commodity CSV in ../inputs/data.
#indices = [f.replace('.csv', '') for f in os.listdir('../inputs/data') if (f not in commodities) & ('.csv' in f)]
# Currently restricted to a single hard-coded index.
indices = ['BIST100']
for index in indices:
    # Separator line between the progress logs of consecutive indices.
    print('-'*100)
    train_for_index(index = index)

----------------------------------------------------------------------------------------------------
[2023-09-20 03:38:15.848717] BIST100/log_sq_rtn/7/LSTM/without_commodity_prices
[2023-09-20 03:39:55.256908] BIST100/log_sq_rtn/7/LSTM/with_commodity_prices
