In [1]:
%%capture
# Setup cell: mount Google Drive, import shared utilities, and move into the
# project's `src` directory.
# NOTE(review): `%%capture` (first line) hides all output of this cell,
# including any messages printed by the mount calls below.

# Unmount before mounting — presumably so re-running this cell remounts Drive
# cleanly instead of reusing a stale mount.
from google.colab import drive
drive.flush_and_unmount()
drive.mount('/content/drive')

import os
from IPython.utils import io          # used later to silence training output
from datetime import datetime         # used later for progress timestamps

# Work from `src` so the local `helper_*` modules can be imported and paths
# relative to `src` (e.g. '../inputs/data') resolve.
%cd /content/drive/MyDrive/Developer/trading/Volatility Predictor/src

In [2]:
%%capture
!pip install arch

In [3]:
### load local libraries
from helper_config import Config
from helper_trainer import train
from helper_utils import prepare_configuration

### GARCH config
# Every combination (size 0..len) of the GARCH variants below is tried; each
# combination is appended to a base model's `model_type` in `train_for_index`.
# The empty combination means "single model (without garch)".
# Adding a variant to GARCH_MODELS extends the sweep automatically.
from itertools import combinations

GARCH_MODELS = ['garch', 'egarch', 'gjr_garch']

# Ordered by combination size, then lexicographically by position in
# GARCH_MODELS:
# [], [garch], [egarch], [gjr_garch], [garch, egarch], ..., [all three]
garch_config = [
    list(combo)
    for size in range(len(GARCH_MODELS) + 1)
    for combo in combinations(GARCH_MODELS, size)
]

def train_for_index(index,
                    windows=(7, 14, 21, 28),
                    models=('LSTM', 'Peephole_LSTM', 'GRU'),
                    garch_configs=None,
                    commodity_options=(True, False),
                    data_dir='/content/drive/MyDrive/Developer/trading/Volatility Predictor/inputs',
                    models_dir='/content/drive/MyDrive/Developer/trading/Volatility Predictor/models'):
    """Train every model variant in the hyper-parameter grid for one index.

    The grid is the Cartesian product of rolling-window sizes, base models,
    GARCH-variant combinations and the commodity-price switch. For each point a
    `specs` dict is built, converted with `prepare_configuration`, and passed
    to `train`.

    Args:
        index: Name of the target index (e.g. 'NSE'); stored in `specs` and
            used as the first component of the model name.
        windows: Rolling-window sizes to iterate over.
        models: Base model names; each becomes the first element of `model_type`.
        garch_configs: Sequence of GARCH-variant lists appended to the base
            model. Defaults to the module-level `garch_config`, looked up at
            call time.
        commodity_options: Values tried for `use_commodity_prices`.
        data_dir: Passed through to `specs['data_dir']`.
        models_dir: Passed through to `specs['models_dir']`.

    Returns:
        None. Whatever `train` returns is discarded — presumably the trainer
        persists its results under `models_dir` (TODO confirm).
    """
    from itertools import product

    if garch_configs is None:
        garch_configs = garch_config

    # `product` yields the same order as the original nested loops:
    # window outermost, commodity switch innermost.
    grid = product(windows, models, garch_configs, commodity_options)
    for window, model, garch_cfg, use_commodity_prices in grid:
        # list(...) also accepts tuples and never aliases the config entry.
        model_type = [model] + list(garch_cfg)
        comm_label = 'with_commodity_prices' if use_commodity_prices else 'without_commodity_prices'

        specs = dict(
            index       = index,
            window      = window,
            model_type  = model_type,
            use_commodity_prices = use_commodity_prices,
            # e.g. 'NSE/7/LSTM_garch/with_commodity_prices'
            model_name  = f'{index}/{window}/{"_".join(model_type)}/{comm_label}',
            data_dir    = data_dir,
            models_dir  = models_dir,
        )

        # prepare configuration, then train the model
        config = prepare_configuration(specs = specs)
        train(config)

In [None]:
# Train every configured model for each target index.

# CSVs in ../inputs/data that hold commodity features rather than an index.
commodities = ['Gold_Features.csv', 'Crude_Oil_WTI Futures.csv']

# Set to True to train on every index CSV found in ../inputs/data (relative to
# the `src` working directory) instead of only the explicit list below.
TRAIN_ALL_INDICES = False

if TRAIN_ALL_INDICES:
    indices = [
        os.path.splitext(f)[0]
        for f in sorted(os.listdir('../inputs/data'))
        if f.endswith('.csv') and f not in commodities
    ]
else:
    indices = ['NSE']

for index in indices:
    print(f'[{datetime.now()}] Training {index}')
    # Capture everything printed during training so only the progress lines
    # above are shown; `captured.show()` replays the last index's log.
    with io.capture_output() as captured:
        train_for_index(index = index)