# Import packages

In [1]:
import os
import random
import re
import requests
import time
import typing
from typing import Any, Callable, Dict, Type
import warnings

from boruta import BorutaPy
import numpy as np
import optuna
from optuna.visualization import plot_optimization_history, plot_contour, plot_edf, \
    plot_intermediate_values, plot_optimization_history, plot_parallel_coordinate, \
    plot_param_importances, plot_slice
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sb3_contrib.tqc import TQC
from stable_baselines3.a2c import A2C
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnNoModelImprovement
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.logger import configure
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.sac import SAC
import stockstats
import tushare
import yfinance as yf

from environment.MultiStockTradingEnv import MultiStockTradingEnv
from utils.sample_funcs import *
from utils.utils import *

  from .autonotebook import tqdm as notebook_tqdm


# Config

In [2]:
warnings.simplefilter(action='ignore', category=FutureWarning)

# Setup directories
DATA_SAVE_DIR = 'datasets'
MODEL_DIR = 'models'
TENSORBOARD_LOG_DIR = 'tensorboard_log'
RAW_DATA_DIR = os.path.join(DATA_SAVE_DIR, 'raw')
CLEAN_DATA_DIR = os.path.join(DATA_SAVE_DIR, 'clean')
PREPROCESSED_DATA_DIR = os.path.join(DATA_SAVE_DIR, 'preprocessed')

check_and_make_directories([DATA_SAVE_DIR, MODEL_DIR, TENSORBOARD_LOG_DIR,
                            RAW_DATA_DIR, CLEAN_DATA_DIR, PREPROCESSED_DATA_DIR])

# Boundaries ('YYYY-MM-DD') of the train / test / trade periods.
TRAIN_START_DAY = '2008-01-01'
TRAIN_END_DAY = '2016-12-31'
TEST_START_DAY = '2017-01-01'
TEST_END_DAY = '2019-12-31'
TRADE_START_DAY = '2020-01-01'
TRADE_END_DAY = '2022-12-31'

# Tushare API token. Supply it through the TUSHARE_TOKEN environment variable.
# TODO(security): the literal fallback is a credential committed to source; it
# should be revoked and the fallback removed.
tushare_token = os.environ.get('TUSHARE_TOKEN') or '2bf5fdb105eefda26ef27cc9caa94e6f31ca66e408f7cc54d4fce032'

# Download data

## Retrieve SSE 50 component list

In [31]:
# SSE 50 components from http://www.sse.com.cn/market/sseindex/indexlist/basic/index.shtml?COMPANY_CODE=000016&INDEX_Code=000016&type=1
# SSE 50 constituents as whitespace-separated "name(six-digit code)" entries.
# Only the six-digit codes are used downstream (extracted with r'\d{6}').
SSE50_COM = \
"""包钢股份(600010)	中国石化(600028)	中信证券(600030)
三一重工(600031)	招商银行(600036)	保利发展(600048)
上汽集团(600104)	北方稀土(600111)	复星医药(600196)
恒瑞医药(600276)	万华化学(600309)	恒力石化(600346)
国电南瑞(600406)	片仔癀(600436)	通威股份(600438)
贵州茅台(600519)	海螺水泥(600585)	海尔智家(600690)
闻泰科技(600745)	山西汾酒(600809)	伊利股份(600887)
航发动力(600893)	长江电力(600900)	三峡能源(600905)
隆基绿能(601012)	中信建投(601066)	中国神华(601088)
兴业银行(601166)	陕西煤业(601225)	农业银行(601288)
中国平安(601318)	工商银行(601398)	中国太保(601601)
中国人寿(601628)	长城汽车(601633)	中国建筑(601668)
中国电建(601669)	华泰证券(601688)	中国石油(601857)
中国中免(601888)	紫金矿业(601899)	中远海控(601919)
中金公司(601995)	药明康德(603259)	合盛硅业(603260)
海天味业(603288)	韦尔股份(603501)	华友钴业(603799)
兆易创新(603986)	天合光能(688599)"""

In [38]:
# Pull every six-digit SSE code out of the listing and append Yahoo Finance's
# Shanghai suffix.
tic_list = [f'{code}.SS' for code in re.findall(r'\d{6}', SSE50_COM)]

## Download SSE 50 component histories with yfinance

In [35]:
def download_ticker_with_yfince(tic_list: List[str], download_dir: str) -> List[str]:
    """Download the full daily history of each ticker to ``<download_dir>/<tic>.csv``.

    Tickers whose CSV already exists are skipped. Tickers for which yfinance
    returns no rows, or whose download raises, are collected for a later retry.

    Args:
        tic_list: Yahoo Finance ticker symbols (e.g. ``'600010.SS'``).
        download_dir: Existing directory that receives the CSV files.

    Returns:
        The tickers that could not be downloaded.
    """
    retry_list = []
    for tic in tic_list:
        csv_path = os.path.join(download_dir, f'{tic}.csv')
        if os.path.exists(csv_path):
            print(f'File {csv_path} already exist. Skip')
            continue

        # Best-effort batch download: one failing ticker (network error, bad
        # symbol, ...) must not abort the remaining ones; queue it for retry.
        try:
            df = yf.Ticker(tic).history(period='max')
        except Exception as e:
            print(f'{tic}: download failed ({e!r}).')
            retry_list.append(tic)
            continue

        if df.shape[0] > 0:
            df.to_csv(csv_path)
            print(f'Download {tic}.csv')
            # Short pause between successful requests to throttle the data source.
            time.sleep(0.1)
        else:
            retry_list.append(tic)

    return retry_list

In [39]:
# Fetch every SSE 50 component into raw/SSE50; failed symbols are kept in
# retry_list for a second attempt.
save_dir = os.path.join(RAW_DATA_DIR, 'SSE50')
check_and_make_directories(save_dir)
retry_list = download_ticker_with_yfince(tic_list=tic_list, download_dir=save_dir)

Download 600010.SS.csv
Download 600028.SS.csv
Download 600030.SS.csv
Download 600031.SS.csv
Download 600036.SS.csv
Download 600048.SS.csv
Download 600104.SS.csv
Download 600111.SS.csv
Download 600196.SS.csv
Download 600276.SS.csv
Download 600309.SS.csv
Download 600346.SS.csv
Download 600406.SS.csv
Download 600436.SS.csv
Download 600438.SS.csv
Download 600519.SS.csv
Download 600585.SS.csv
Download 600690.SS.csv
Download 600745.SS.csv
Download 600809.SS.csv
Download 600887.SS.csv
Download 600893.SS.csv
Download 600900.SS.csv
Download 600905.SS.csv
Download 601012.SS.csv
Download 601066.SS.csv
Download 601088.SS.csv
Download 601166.SS.csv
Download 601225.SS.csv
Download 601288.SS.csv
Download 601318.SS.csv
Download 601398.SS.csv
Download 601601.SS.csv
Download 601628.SS.csv
Download 601633.SS.csv
Download 601668.SS.csv
Download 601669.SS.csv
Download 601688.SS.csv
Download 601857.SS.csv
Download 601888.SS.csv
Download 601899.SS.csv
Download 601919.SS.csv
Download 601995.SS.csv
Download 60

In [None]:
# Retry failed SSE 50 downloads into the same raw/SSE50 directory as the first
# pass, so clean() (which reads raw/SSE50) picks them up.
retry_list = download_ticker_with_yfince(retry_list, os.path.join(RAW_DATA_DIR, 'SSE50'))

## Download SZSE Growth 40 components

In [45]:
# download .xls from http://www.szse.cn/market/exponent/sample/index.html
# Component list exported as .xlsx from http://www.szse.cn/market/exponent/sample/index.html
xls = pd.read_excel(os.path.join(DATA_SAVE_DIR, '深圳成长40指数.xlsx'))
# The ':06d' format assumes the codes were parsed as integers; zero-pad them back
# to six digits and add Yahoo Finance's Shenzhen suffix.
tic_list = [f'{code:06d}.SZ' for code in xls['证券代码']]

  warn("Workbook contains no default style, apply openpyxl's default")


In [46]:
# Fetch every SZSE Growth 40 component into raw/SZSEGrowth40; keep failures.
save_dir = os.path.join(RAW_DATA_DIR, 'SZSEGrowth40')
check_and_make_directories(save_dir)
retry_list = download_ticker_with_yfince(tic_list=tic_list, download_dir=save_dir)

Download 000661.SZ.csv
Download 000725.SZ.csv
Download 002030.SZ.csv
Download 002049.SZ.csv
Download 002129.SZ.csv
Download 002271.SZ.csv
Download 002414.SZ.csv
Download 002460.SZ.csv
Download 002475.SZ.csv
Download 002555.SZ.csv
Download 002709.SZ.csv
Download 002714.SZ.csv
Download 002932.SZ.csv
Download 300014.SZ.csv
Download 300059.SZ.csv
Download 300122.SZ.csv
Download 300450.SZ.csv
Download 300502.SZ.csv
Download 300593.SZ.csv
Download 300604.SZ.csv
Download 300630.SZ.csv
Download 300638.SZ.csv
Download 300661.SZ.csv
Download 300671.SZ.csv
Download 300676.SZ.csv
Download 300677.SZ.csv
Download 300724.SZ.csv
Download 300763.SZ.csv
Download 300769.SZ.csv
Download 300772.SZ.csv
Download 300782.SZ.csv
Download 300850.SZ.csv
Download 300869.SZ.csv
Download 300888.SZ.csv
Download 300894.SZ.csv
Download 300896.SZ.csv
Download 300957.SZ.csv
Download 300973.SZ.csv
Download 301050.SZ.csv
Download 301080.SZ.csv


## Download index histories (SZSE Growth 40, SSE 50, DOW 30)

In [48]:
# Yahoo Finance symbols of the benchmark indices, keyed by a readable name.
tic_dict = dict(
    SZSE_Growth_40='399326.SZ',
    SSE_50='000016.SS',
    DOW_30='^DJI',
)

In [49]:
# Index histories go to raw/index; the cell displays the symbols that failed.
save_dir = os.path.join(RAW_DATA_DIR, 'index')
check_and_make_directories(save_dir)
download_ticker_with_yfince([*tic_dict.values()], download_dir=save_dir)

399326.SZ: 1d data not available for startTime=-2208994789 and endTime=1675075747. Only 100 years worth of day granularity data are allowed to be fetched per request.
000016.SS: 1d data not available for startTime=-2208994789 and endTime=1675075749. Only 100 years worth of day granularity data are allowed to be fetched per request.
Download ^DJI.csv


['399326.SZ', '000016.SS']

In [None]:
# Download SSE 50 history from https://www.investing.com/indices/shanghai-se-50-historical-data
# Download SZSE Growth 40 history from https://www.investing.com/indices/szse-growth-price-historical-data

## Download index history

# Clean data

In [51]:
def get_calendar_with_tushare(start: str, end: str) -> pd.Series:
    start = start.replace('-', '')
    end = end.replace('-', '')

    tushare.set_token(tushare_token)
    tu_pro = tushare.pro_api()
    calendar_ss = tu_pro.trade_cal(exchange='SSE', start_date=start, end_date=end, is_open=1)
    calendar_sz = tu_pro.trade_cal(exchange='SZSE', start_date=start, end_date=end, is_open=1)
    if calendar_ss.shape[0] != calendar_ss.shape[0]:
        calendar = pd.merge(calendar_ss.cal_date, calendar_sz.cal_date, on=['cal_date'], how='outer')
    else:
        calendar = calendar_ss.cal_date

    calendar = pd.to_datetime(calendar, format='%Y%m%d')
    calendar.rename('date', inplace=True)
    
    return calendar

In [52]:
def clean_data_from_yfinance(data: pd.DataFrame, calendar: pd.Series = None) -> pd.DataFrame:
    """Normalise a yfinance history frame and optionally align it to a calendar.

    The 'Dividends' and 'Stock Splits' columns are dropped when present. The
    OHLCV/date columns are lower-cased. 'date' keeps only the text before the
    first space, parsed as 'YYYY-MM-DD'. When ``calendar`` is given, the data is
    left-joined onto it, so every calendar day gets a row and days without data
    become NaN.

    Args:
        data: Frame with 'Date', 'Open', 'High', 'Low', 'Close', 'Volume'
            columns, e.g. read back from a downloaded yfinance CSV. It is not
            modified.
        calendar: Optional datetime Series named 'date'. When None, the rows of
            ``data`` are returned without alignment.

    Returns:
        A new DataFrame with columns 'date', 'open', 'high', 'low', 'close',
        'volume' (plus any other columns ``data`` carried).
    """
    # TODO: calculate adjusted price.
    data = data.drop(columns=['Dividends', 'Stock Splits'], errors='ignore').rename(columns={
        'Date': 'date',
        'Open': 'open',
        'High': 'high',
        'Low': 'low',
        'Close': 'close',
        'Volume': 'volume',
        })
    data['date'] = pd.to_datetime(
        data['date'].astype(str).str.split(' ').str[0], format='%Y-%m-%d')
    if calendar is not None:
        data = pd.merge(calendar, data, how='left', on='date')

    return data

In [53]:
NA_THRESHOLD = 0.1

def clean(from_dir, to_dir, calendar=None, na_threshold=None):
    """Clean every raw yfinance CSV in ``from_dir`` and write the result to ``to_dir``.

    Each file is aligned to the trading calendar. Rows containing NaN are dropped,
    as are rows whose (open, high, low, close, volume) duplicate an earlier row.
    The file is kept only if at least ``len(calendar) * (1 - na_threshold)``
    rows survive. Files already present in ``to_dir`` are skipped.

    Args:
        from_dir: Directory with raw CSV files (only its top level is scanned).
        to_dir: Existing directory for the cleaned CSV files.
        calendar: Optional precomputed trading calendar. Passing one lets
            several calls share a single tushare request. When None it is
            fetched for TRAIN_START_DAY..TRADE_END_DAY.
        na_threshold: Maximum fraction of calendar days a file may lose.
            When None, the module-level NA_THRESHOLD is read at call time.

    Returns:
        Names of the files discarded for having too many missing rows.
    """
    if calendar is None:
        calendar = get_calendar_with_tushare(TRAIN_START_DAY, TRADE_END_DAY)
    if na_threshold is None:
        na_threshold = NA_THRESHOLD
    min_rows = len(calendar) * (1 - na_threshold)

    na_list = []
    _, _, files = next(os.walk(from_dir))
    for file in files:
        result_path = os.path.join(to_dir, file)
        if os.path.exists(result_path):
            continue

        file_path = os.path.join(from_dir, file)
        df = pd.read_csv(file_path, index_col=False)

        df = clean_data_from_yfinance(df, calendar)

        len_df = df.shape[0]
        df.dropna(inplace=True)
        df.drop_duplicates(['open', 'high', 'low', 'close', 'volume'], inplace=True)
        print(f'{len_df - df.shape[0]} rows droped from {file}.')

        if df.shape[0] >= min_rows:
            df.to_csv(result_path, index=False)
        else:
            na_list.append(file)
            print(f'{file}: too many NaNs, discard.')

    return na_list

In [56]:
# Clean raw/SSE50 into clean/SSE50, then report how many files were discarded.
from_dir, to_dir = (os.path.join(root, 'SSE50') for root in (RAW_DATA_DIR, CLEAN_DATA_DIR))
check_and_make_directories([from_dir, to_dir])
na_list = clean(from_dir, to_dir)
print('\n', len(na_list))
' '.join(na_list)

73 rows droped from 600010.SS.csv.
7 rows droped from 600028.SS.csv.
31 rows droped from 600030.SS.csv.
17 rows droped from 600031.SS.csv.
22 rows droped from 600036.SS.csv.
9 rows droped from 600048.SS.csv.
66 rows droped from 600104.SS.csv.
8 rows droped from 600111.SS.csv.
20 rows droped from 600196.SS.csv.
5 rows droped from 600276.SS.csv.
136 rows droped from 600309.SS.csv.
303 rows droped from 600346.SS.csv.
142 rows droped from 600406.SS.csv.
20 rows droped from 600436.SS.csv.
159 rows droped from 600438.SS.csv.
5 rows droped from 600519.SS.csv.
16 rows droped from 600585.SS.csv.
96 rows droped from 600690.SS.csv.
607 rows droped from 600745.SS.csv.
600745.SS.csv: too many NaNs, discard.
14 rows droped from 600809.SS.csv.
36 rows droped from 600887.SS.csv.
372 rows droped from 600893.SS.csv.
600893.SS.csv: too many NaNs, discard.
368 rows droped from 600900.SS.csv.
600900.SS.csv: too many NaNs, discard.
3268 rows droped from 600905.SS.csv.
600905.SS.csv: too many NaNs, discard.


'600745.SS.csv 600893.SS.csv 600900.SS.csv 600905.SS.csv 601012.SS.csv 601066.SS.csv 601225.SS.csv 601288.SS.csv 601633.SS.csv 601668.SS.csv 601669.SS.csv 601688.SS.csv 601888.SS.csv 601995.SS.csv 603259.SS.csv 603260.SS.csv 603288.SS.csv 603501.SS.csv 603799.SS.csv 603986.SS.csv 688599.SS.csv'

In [57]:
# Clean raw/SZSEGrowth40 into clean/SZSEGrowth40, then report discarded files.
from_dir, to_dir = (os.path.join(root, 'SZSEGrowth40') for root in (RAW_DATA_DIR, CLEAN_DATA_DIR))
check_and_make_directories([from_dir, to_dir])
na_list = clean(from_dir, to_dir)
print('\n', len(na_list))
' '.join(na_list)

39 rows droped from 000661.SZ.csv.
21 rows droped from 000725.SZ.csv.
58 rows droped from 002030.SZ.csv.
278 rows droped from 002049.SZ.csv.
464 rows droped from 002129.SZ.csv.
002129.SZ.csv: too many NaNs, discard.
182 rows droped from 002271.SZ.csv.
706 rows droped from 002414.SZ.csv.
002414.SZ.csv: too many NaNs, discard.
715 rows droped from 002460.SZ.csv.
002460.SZ.csv: too many NaNs, discard.
686 rows droped from 002475.SZ.csv.
002475.SZ.csv: too many NaNs, discard.
1039 rows droped from 002555.SZ.csv.
002555.SZ.csv: too many NaNs, discard.
1497 rows droped from 002709.SZ.csv.
002709.SZ.csv: too many NaNs, discard.
1500 rows droped from 002714.SZ.csv.
002714.SZ.csv: too many NaNs, discard.
2561 rows droped from 002932.SZ.csv.
002932.SZ.csv: too many NaNs, discard.
495 rows droped from 300014.SZ.csv.
300014.SZ.csv: too many NaNs, discard.
593 rows droped from 300059.SZ.csv.
300059.SZ.csv: too many NaNs, discard.
739 rows droped from 300122.SZ.csv.
300122.SZ.csv: too many NaNs, dis

'002129.SZ.csv 002414.SZ.csv 002460.SZ.csv 002475.SZ.csv 002555.SZ.csv 002709.SZ.csv 002714.SZ.csv 002932.SZ.csv 300014.SZ.csv 300059.SZ.csv 300122.SZ.csv 300450.SZ.csv 300502.SZ.csv 300593.SZ.csv 300604.SZ.csv 300630.SZ.csv 300638.SZ.csv 300661.SZ.csv 300671.SZ.csv 300676.SZ.csv 300677.SZ.csv 300724.SZ.csv 300763.SZ.csv 300769.SZ.csv 300772.SZ.csv 300782.SZ.csv 300850.SZ.csv 300869.SZ.csv 300888.SZ.csv 300894.SZ.csv 300896.SZ.csv 300957.SZ.csv 300973.SZ.csv 301050.SZ.csv 301080.SZ.csv'

# Feature engineering

In [178]:
# columns after init_all()
# Inspect the columns stockstats' init_all() adds, using one sample file.
stats = stockstats.StockDataFrame.retype(df := pd.read_csv('./datasets/clean/000001.SZ.csv', index_col=False))
stats.init_all()
stats.columns

Index(['open', 'high', 'low', 'close', 'volume', 'change', 'rs_14', 'rsi',
       'rsi_14', 'stochrsi', 'rate', 'middle', 'tp', 'boll', 'boll_ub',
       'boll_lb', 'macd', 'macds', 'macdh', 'ppo', 'ppos', 'ppoh', 'rsv_9',
       'kdjk_9', 'kdjk', 'kdjd_9', 'kdjd', 'kdjj_9', 'kdjj', 'cr', 'cr-ma1',
       'cr-ma2', 'cr-ma3', 'cci', 'tr', 'atr', 'high_delta', 'um', 'low_delta',
       'dm', 'pdm', 'pdm_14_ema', 'pdm_14', 'atr_14', 'pdi_14', 'pdi', 'mdm',
       'mdm_14_ema', 'mdm_14', 'mdi_14', 'mdi', 'dx_14', 'dx', 'adx', 'adxr',
       'trix', 'tema', 'vr', 'close_10_sma', 'close_50_sma', 'dma', 'vwma',
       'chop', 'log-ret', 'mfi', 'wt1', 'wt2', 'wr', 'supertrend_ub',
       'supertrend_lb', 'supertrend'],
      dtype='object')

In [283]:
X_y_filename = 'x_y.csv'
X_y_path = os.path.join(DATA_SAVE_DIR, X_y_filename)

# Feature/target table for feature selection. Bind it up front so later cells
# can test `X_y is None` even when the cached CSV exists and nothing is rebuilt.
X_y = None

if not os.path.exists(X_y_path):
    frames = []

    _, _, files = next(os.walk(CLEAN_DATA_DIR))
    for file in files:
        file_path = os.path.join(CLEAN_DATA_DIR, file)
        df = pd.read_csv(file_path, index_col=False)
        stats = stockstats.StockDataFrame.retype(df)
        stats.init_all()

        # drop duplicated columns
        stats.drop_column(['rsi', 'kdjk', 'kdjd', 'kdjj'], inplace=True)

        # Additional indicators. Indexing a stockstats column by name computes it
        # and adds it to the frame, so a bare lookup is enough.
        for indicator in ('close_14_smma', 'close_14_mstd', 'close_14_mvar',
                          'close_5_sma', 'wr_6', 'rsi_6'):
            stats[indicator]
        # log2 ratios against the stockstats '*_-1_s' shifted columns, plus log2(close / open)
        stats['log_diff_high'] = np.log2(stats['high'] / stats['high_-1_s'])
        stats['log_diff_low'] = np.log2(stats['low'] / stats['low_-1_s'])
        stats['log_diff_open'] = np.log2(stats['open'] / stats['open_-1_s'])
        stats['log_diff_vol'] = np.log2(stats['volume'] / stats['volume_-1_s'])
        stats['log_close/open'] = np.log2(stats['close'] / stats['open'])
        stats.drop_column(['high_-1_s', 'low_-1_s', 'open_-1_s'], inplace=True)
        # Target: the shifted log return, presumably the next row's value — confirm
        # against stockstats' '_N_s' shift convention.
        stats['log-ret_1_s']
        stats.rename(columns={'log-ret_1_s': 'y'}, inplace=True)

        # drop date (the index after retype, presumably)
        stats.reset_index(drop=True, inplace=True)

        # deal with nan
        stats.dropna(inplace=True)

        frames.append(stats)
        print(f'Add {file} to X_y.')

    if frames:
        # Concatenate once instead of re-copying the growing table for every file.
        X_y = pd.concat(frames)
        # Cache the table so later runs skip the rebuild.
        X_y.to_csv(X_y_path, index=False)
    else:
        print(f'No files found in {CLEAN_DATA_DIR}; {X_y_path} not written.')

Added 000001.SZ.csv.
Added 000002.SZ.csv.
Added 000063.SZ.csv.
Added 000069.SZ.csv.
Added 000100.SZ.csv.
Added 000157.SZ.csv.
Added 000166.SZ.csv.
Added 000301.SZ.csv.
Added 000338.SZ.csv.
Added 000425.SZ.csv.
Added 000538.SZ.csv.
Added 000568.SZ.csv.
Added 000596.SZ.csv.
Added 000625.SZ.csv.
Added 000651.SZ.csv.
Added 000661.SZ.csv.
Added 000708.SZ.csv.
Added 000723.SZ.csv.
Added 000725.SZ.csv.
Added 000733.SZ.csv.
Added 000768.SZ.csv.
Added 000786.SZ.csv.
Added 000800.SZ.csv.
Added 000858.SZ.csv.
Added 000876.SZ.csv.
Added 000877.SZ.csv.
Added 000895.SZ.csv.
Added 000938.SZ.csv.
Added 000963.SZ.csv.
Added 000977.SZ.csv.
Added 001979.SZ.csv.
Added 002001.SZ.csv.
Added 002007.SZ.csv.
Added 002008.SZ.csv.
Added 002027.SZ.csv.
Added 002032.SZ.csv.
Added 002049.SZ.csv.
Added 002050.SZ.csv.
Added 002064.SZ.csv.
Added 002074.SZ.csv.
Added 002120.SZ.csv.
Added 002142.SZ.csv.
Added 002179.SZ.csv.
Added 002180.SZ.csv.
Added 002202.SZ.csv.
Added 002230.SZ.csv.
Added 002236.SZ.csv.
Added 002241.

# Feature selection

In [None]:
X_y_filename = 'x_y.csv'
X_y_path = os.path.join(DATA_SAVE_DIR, X_y_filename)
# Reuse the in-memory table when the feature-engineering cell built it. Otherwise
# (fresh kernel, or the table was only cached on disk) load it from the CSV.
# globals().get avoids a NameError when X_y was never bound in this kernel.
if globals().get('X_y') is None:
    X_y = pd.read_csv(X_y_path, index_col=False)

In [None]:
model = RandomForestRegressor(n_estimators=100, max_depth=5, random_state=42)

feat_selector = BorutaPy(
    verbose=2,
    estimator=model,
    n_estimators='auto',
    max_iter=10
)

# The log2 ratio features become +/-inf when a divisor is 0; scikit-learn rejects
# infinite inputs, so treat them like missing rows.
X_y_finite = X_y.replace([np.inf, -np.inf], np.nan).dropna()

# 'y' is a column, so drop it along the columns axis. Keep the feature names:
# the numpy arrays handed to Boruta carry no column labels.
X_df = X_y_finite.drop(columns=['y'])
feature_names = list(X_df.columns)
X = X_df.to_numpy()
y = X_y_finite['y'].to_numpy()
feat_selector.fit(X, y)

# print support and ranking for each feature
print("\n------Support and Ranking for each feature------")
for name, passed, rank in zip(feature_names, feat_selector.support_, feat_selector.ranking_):
    if passed:
        print("Passes the test: ", name, " - Ranking: ", rank)
    else:
        print("Doesn't pass the test: ", name, " - Ranking: ", rank)

Used Features:
* volume
* ppo
* cr-ma3
* trix 
* log_diff_high
* log_diff_low
* log_diff_open
* log_close/open

# Preprocess data

In [59]:
def preprocess(from_dir, to_dir):
    """Build the per-ticker feature table for every cleaned CSV in ``from_dir``.

    For each file the features 'change', 'ppo', 'cr-ma3', 'trix',
    'log_close/open', 'log-ret' and the log2 ratios of high/low/open to their
    stockstats '*_-1_s' shifted values are computed. Rows with NaN are dropped.
    The table, including a 'date' column, is written to ``to_dir`` under the
    same file name. Files already present in ``to_dir`` are loaded from disk
    instead of recomputed.

    Args:
        from_dir: Directory with cleaned CSV files (only its top level is scanned).
        to_dir: Existing directory for the preprocessed CSV files.

    Returns:
        Dict mapping the ticker code (file name up to the first '.') to its
        feature DataFrame, covering both newly processed and previously saved files.
    """
    df_dict = {}
    _, _, files = next(os.walk(from_dir))
    for file in files:
        tic = file.split('.')[0]
        processed_file_path = os.path.join(to_dir, file)
        if os.path.exists(processed_file_path):
            # Already preprocessed: reuse the saved table so it is still returned.
            df_dict[tic] = pd.read_csv(processed_file_path, index_col=False)
            continue

        # load
        clean_file_path = os.path.join(from_dir, file)
        stats = pd.read_csv(clean_file_path, index_col=False)
        df = pd.DataFrame(index=stats['date'])

        # NOTE(review): the assignments below align on index, which assumes
        # retype() indexes `stats` by the same 'date' values — confirm.
        stats = stockstats.StockDataFrame.retype(stats)
        df['change'] = stats['change']

        # add indicators
        df['ppo'] = stats['ppo']
        df['cr-ma3'] = stats['cr-ma3']
        df['trix'] = stats['trix']

        # add differential features
        df['log_close/open'] = np.log2(stats['close'] / stats['open'])
        df['log-ret'] = stats['log-ret']
        df['log_diff_high'] = np.log2(stats['high'] / stats['high_-1_s'])
        df['log_diff_low'] = np.log2(stats['low'] / stats['low_-1_s'])
        df['log_diff_open'] = np.log2(stats['open'] / stats['open_-1_s'])

        # clean; reset_index turns the 'date' index back into a column
        df.dropna(inplace=True)
        df.reset_index(inplace=True)

        # save
        df.to_csv(processed_file_path, index=False)
        df_dict[tic] = df.copy()
    return df_dict

In [62]:
# Preprocess clean/SSE50 into preprocessed/SSE50.
from_dir, to_dir = (os.path.join(root, 'SSE50') for root in (CLEAN_DATA_DIR, PREPROCESSED_DATA_DIR))
check_and_make_directories(to_dir)
df_dict_SSE50 = preprocess(from_dir, to_dir)

In [None]:
# Preprocess clean/SZSEGrowth40 into preprocessed/SZSEGrowth40.
from_dir, to_dir = (os.path.join(root, 'SZSEGrowth40') for root in (CLEAN_DATA_DIR, PREPROCESSED_DATA_DIR))
check_and_make_directories(to_dir)
df_dict_SZSE40 = preprocess(from_dir, to_dir)

# Setup environment

In [3]:
dataset_dir = os.path.join(PREPROCESSED_DATA_DIR, 'SSE50')
# Always (re)load every preprocessed file, keyed by file name without '.csv'.
df_dict = {}
_, _, files = next(os.walk(dataset_dir))
for file in files:
    processed_file_path = os.path.join(dataset_dir, file)
    df = pd.read_csv(processed_file_path, index_col=False)
    # Real check rather than `assert`, which is stripped under `python -O`.
    if df.isna().to_numpy().any():
        raise ValueError(f'Nan found in {file}.')
    tic = file.replace('.csv', '')
    df_dict[tic] = df.copy()

In [4]:
# Split every ticker's history by date into train (< TEST_START_DAY),
# test ([TEST_START_DAY, TRADE_START_DAY)) and trade (>= TRADE_START_DAY).
df_dict_train = dict()
df_dict_test = dict()
df_dict_trade = dict()

TEST_START_DAY = pd.to_datetime(TEST_START_DAY, format='%Y-%m-%d')
TRADE_START_DAY = pd.to_datetime(TRADE_START_DAY, format='%Y-%m-%d')

for tic, df in df_dict.items():
    # Converts the shared df_dict frames' 'date' column in place.
    df['date'] = pd.to_datetime(df['date'], format='%Y-%m-%d')
    # Sort by date explicitly: the frames are read with index_col=False, so their
    # index is just the CSV row number and sort_index() would not guarantee order.
    df_sorted = df.sort_values('date', kind='stable')
    dates = df_sorted['date']
    df_dict_train[tic] = df_sorted.loc[dates < TEST_START_DAY].copy()
    df_dict_test[tic] = df_sorted.loc[(dates >= TEST_START_DAY) & (dates < TRADE_START_DAY)].copy()
    df_dict_trade[tic] = df_sorted.loc[dates >= TRADE_START_DAY].copy()

In [5]:
def get_envs(n_tickers: int = 10) -> Tuple[MultiStockTradingEnv, MultiStockTradingEnv, MultiStockTradingEnv]:
    """Build monitored train / test / trade environments over the same random tickers.

    ``n_tickers`` tickers are sampled from ``df_dict_train``. For each split,
    their frames are combined and only dates on which every sampled ticker has
    data are kept. The resulting frame, indexed by (date, tic) and sorted, is
    wrapped in ``Monitor(MultiStockTradingEnv(...))``.

    Args:
        n_tickers: Number of tickers to sample; at most ``len(df_dict_train)``.

    Returns:
        (train_env, test_env, trade_env), in that order.

    Raises:
        ValueError: If more tickers are requested than are available.
    """
    if n_tickers > len(df_dict_train):
        raise ValueError(
            f'n_tickers={n_tickers} exceeds the {len(df_dict_train)} available tickers.')

    # random.sample() needs a sequence: dict keys are deprecated since Python 3.9
    # and rejected from 3.11. Sorting also makes the draw reproducible under a seed.
    tic_list = random.sample(sorted(df_dict_train), n_tickers)

    env_list = list()
    for _df_dict in [df_dict_train, df_dict_test, df_dict_trade]:
        # assign() returns copies, so the shared split frames are not mutated.
        _dfs = pd.concat([_df_dict[tic].assign(tic=tic) for tic in tic_list])
        feature_columns = [c for c in _dfs.columns if c not in ('date', 'tic')]
        # Wide date x (feature, tic) table; dropna() removes dates missing for any
        # ticker, and stack() turns it back into long (date, tic) rows.
        _dfs = (_dfs.pivot_table(values=feature_columns, index='date', columns='tic')
                .dropna().stack().reset_index())
        _dfs.sort_values(['date', 'tic'], inplace=True)
        _dfs.set_index(['date', 'tic'], inplace=True)
        env_list.append(Monitor(MultiStockTradingEnv(_dfs)))

    return tuple(env_list)

# Hyper parameter tuning

In [6]:
# Verbosity level forwarded to the early-stopping and evaluation callbacks.
VERBOSE: int = 0

In [7]:
def objective_factory(
    model_name: str,
    model_class: Type[BaseAlgorithm],
    sample_param_func: Callable[[optuna.Trial], Tuple[Dict, int]],
    ) -> Callable[[optuna.Trial], float]:
    """Create an Optuna objective that trains and scores one ``model_class`` agent.

    Each trial samples hyperparameters and a timestep budget with
    ``sample_param_func`` and trains on fresh environments from ``get_envs()``.
    Evaluation-based early stopping runs on the test environment, and the best
    model is saved under ``<MODEL_DIR>/<model_name>/trial_<n>_best_model``.
    The trial is scored by the final model's mean reward over 3 test episodes.
    Trials whose training raises ValueError or RuntimeError score -99.

    Args:
        model_name: Name used for the model/log directories and TensorBoard run.
        model_class: stable-baselines3 algorithm class to instantiate.
        sample_param_func: Returns (constructor kwargs, total_timesteps) for a trial.

    Returns:
        The objective function to pass to ``study.optimize``.
    """
    def objective(trial: optuna.Trial) -> float:
        model_path = os.path.join(MODEL_DIR, model_name, f'trial_{trial.number}_best_model')
        tb_log_path = os.path.join(TENSORBOARD_LOG_DIR, model_name)
        check_and_make_directories([model_path, tb_log_path])

        # Create model with sampled hyperparameters and
        # train it with early stop callback
        hyperparameters, total_timesteps = sample_param_func(trial)
        # NOTE(review): logs go to a fixed machine-specific directory; tb_log_path
        # (created above) is the project-relative alternative.
        hyperparameters['tensorboard_log'] = '/root/tf-logs'

        envs = get_envs()
        env_train, env_test, _ = envs
        try:
            model = model_class('MlpPolicy', env_train, **hyperparameters)

            stop_train_callback = StopTrainingOnNoModelImprovement(
                max_no_improvement_evals=4, min_evals=2, verbose=VERBOSE)
            eval_callback = EvalCallback(
                env_test,
                callback_after_eval=stop_train_callback,
                n_eval_episodes=3,
                eval_freq=10000,
                best_model_save_path=model_path,
                verbose=VERBOSE
                )

            # deal with gradient explosion: a failed run gets a very low score
            # instead of aborting the whole study
            try:
                model.learn(total_timesteps=total_timesteps,
                    tb_log_name=f'{model_name}_{trial.number}', callback=eval_callback)
            except (ValueError, RuntimeError) as e:
                print(e)
                return -99

            # validation
            mean_reward, _ = evaluate_policy(model, env_test, n_eval_episodes=3)
            return mean_reward
        finally:
            # Every trial creates its own environments; release them when done.
            for env in envs:
                env.close()

    return objective

In [8]:
def tune(
    model_name: str,
    model_class: Type[BaseAlgorithm],
    sample_param_func: Callable[[optuna.Trial], Any],
    n_trials: int = 100,
    callbacks: Optional[List[Callable]] = None,
    load_if_exists: bool = True,
    ) -> optuna.Study:
    """Run an Optuna hyperparameter search for ``model_class``.

    The study is stored in ``sqlite:///<model_name>_study.db`` under the name
    ``<model_name>_study`` and maximises the objective built by
    ``objective_factory``.

    Args:
        model_name: Prefix for the study name, storage file and model directories.
        model_class: stable-baselines3 algorithm class to tune.
        sample_param_func: Samples (hyperparameters, total_timesteps) per trial.
        n_trials: Number of trials to run in this call.
        callbacks: Optional Optuna study callbacks passed to ``study.optimize``.
        load_if_exists: Resume an existing study in the storage instead of
            failing because the name is already taken.

    Returns:
        The (possibly resumed) study after optimisation.
    """
    sampler = optuna.samplers.TPESampler(seed=None)
    objective = objective_factory(model_name, model_class, sample_param_func)

    study_name = f'{model_name}_study'
    storage_name = f'sqlite:///{study_name}.db'
    study = optuna.create_study(
        study_name=study_name,
        direction='maximize',
        sampler=sampler,
        # NOTE(review): a pruner only acts on values reported with trial.report();
        # confirm the objective reports intermediate rewards, otherwise it is inert.
        pruner=optuna.pruners.HyperbandPruner(),
        storage=storage_name,
        load_if_exists=load_if_exists,
        )
    study.optimize(
        objective,
        n_trials=n_trials,
        callbacks=callbacks,
        )

    return study

In [9]:
# Cap PyTorch's CUDA caching-allocator block splitting (max_split_size_mb) to
# reduce fragmentation-related out-of-memory errors.
# NOTE(review): presumably only effective if no CUDA memory has been allocated
# yet in this kernel — confirm.
import os
os.environ.update(PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:64')

In [10]:
# study_A2C = tune('A2C', A2C, \
#     sample_param_func=sample_a2c_param)

# plot_optimization_history(study_A2C)
# plot_param_importances(study_A2C)

study_SAC = tune('SAC', SAC, sample_param_func=sample_sac_param)

# A notebook only renders a cell's last expression automatically, so show both
# figures explicitly; otherwise the optimization history is silently discarded.
plot_optimization_history(study_SAC).show()
plot_param_importances(study_SAC).show()

# study_TQC = tune('TQC', TQC, \
#     sample_param_func=sample_tqc_param)

# plot_optimization_history(study_TQC)
# plot_param_importances(study_TQC)

[32m[I 2023-02-04 09:31:17,409][0m A new study created in RDB with name: SAC_study[0m
since Python 3.9 and will be removed in a subsequent version.
  tic_list = random.sample(df_dict_train.keys(), n_tickers)


In [None]:
# TODO: test with strict condition
# PruneCallback comes from the utils star imports; the meaning of its arguments
# is defined there.
early_stop_callback = PruneCallback(threshold=1, patience=1, trial_number=1)

In [2]:
# Open the persistent TQC study in TQC_study.db, resuming it if it already exists.
study_TQC = optuna.create_study(
    storage='sqlite:///TQC_study.db',
    study_name='TQC_study',
    load_if_exists=True,
    direction='maximize',
    sampler=optuna.samplers.TPESampler(seed=None),
    pruner=optuna.pruners.HyperbandPruner(),
)

[32m[I 2023-02-04 09:19:10,566][0m Using an existing study with name 'TQC_study' instead of creating a new one.[0m


# Train models

In [None]:
from stable_baselines3.a2c import A2C
# Quick smoke test: train a default A2C agent for 1000 steps on a fresh set of
# environments (the cell displays the returned model).
env_train, env_test, env_trade = get_envs()
model = A2C(policy='MlpPolicy', env=env_train)
model.learn(total_timesteps=1000)

<stable_baselines3.a2c.a2c.A2C at 0x7efa7a793c10>

# Backtest

In [None]:
# df_t = dfs_test[3]
# list_asset, actions = simulate_trading_masked(env_factory([df_t]), model)
# sr_asset = pd.Series(list_asset)
# sr_return = get_daily_return(sr_asset)
# backtest_stats(sr_return)
# sr_baseline_return = get_daily_return(df_t.close).dropna()
# sr_baseline_return = sr_baseline_return[len(sr_baseline_return) - len(sr_asset):]
# backtest_stats(sr_baseline_return)
# %matplotlib inline
# sr_date = df_t.date
# sr_date = sr_date[len(sr_date) - len(sr_asset):]
# sr_return.set_axis(sr_date, inplace=True)
# sr_baseline_return.set_axis(sr_date, inplace=True)
# backtest_plot(sr_return, sr_baseline_return)
# sum(actions)

# Plot