In [4]:
import os
import math
import time
import datetime
import argparse

import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn import metrics
import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import talib as ta
import qlib
from qlib.data import D
from qlib.data.dataset.loader import QlibDataLoader
from qlib.constant import REG_CN, REG_US

# Location of the qlib US data bundle. The default is a machine-specific absolute path,
# so allow overriding it via the QLIB_PROVIDER_URI environment variable instead of
# editing the notebook on every machine.
provider_uri = os.environ.get("QLIB_PROVIDER_URI", "F:/qlib/qlib_data/us_data")  # data dir
qlib.init(provider_uri=provider_uri, region=REG_US)
# Limit torch's CPU intra-op parallelism to 6 threads.
torch.set_num_threads(6)

In [3]:
from data_utlis import *
from eval_utlis import *
from model import HDAT, TotalLoss

In [3]:
parser = argparse.ArgumentParser(description='Stock Forecasting')

# Every CLI option as (flag, type, default). Defaults apply unless the argv below overrides them.
_ARG_SPECS = [
    ('--prestart_time', str, '2000-01-01'),
    ('--start_time', str, '2004-10-31'),
    ('--end_time', str, '2020-01-01'),
    ('--lagend_time', str, '2020-10-31'),
    ('--save_path', str, './output'),
    ('--adj_path', str, './adj_rolling'),
    ('--lr', float, 0.0001),
    ('--weight_decay', float, 5e-4),
    ('--epochs', int, 400),
    ('--device', str, 'cuda:1'),
    ('--window_size', int, 1),
    ('--window_size_o', int, 1),
]
for _flag, _arg_type, _default in _ARG_SPECS:
    parser.add_argument(_flag, type=_arg_type, default=_default)

# Explicit argv for this notebook run, so parse_args does not read sys.argv.
_NOTEBOOK_ARGV = [
    '--save_path', './output',
    '--window_size', '24',      # Tr
    '--window_size_o', '12',    # To
    '--prestart_time', '2013-06-01',
    '--start_time', '2014-01-01',
    '--end_time', '2019-12-31',
    '--lagend_time', '2020-10-30',
    '--device', 'cpu',
]
args = parser.parse_args(args=_NOTEBOOK_ARGV)

In [4]:
# Read the ticker universe from a text file: commas are removed, surrounding whitespace
# is stripped, and blank lines are skipped so '' never ends up as a ticker.
company_1 = []
com_path = './company_pool.txt'
with open(com_path, 'r', encoding='utf-8') as f:
    for line in f:
        ticker = line.replace(',', '').strip()
        if ticker:
            company_1.append(ticker)
company_1.sort()

# NOTE(review): company_pool and company_nonan are computed but not used below;
# selected_tickers comes from the text file. The calls are kept in case these helpers
# have side effects. Confirm whether they are still needed.
company_pool = get_base_company(args=args)
company_nonan, _ = get_stocks_nonan(company_1=company_pool, args=args)
selected_tickers = company_1

In [5]:
## features
# Number of feature columns subtracted from features.size(2) when sizing the model
# input in main(); it is also forwarded to evaluate().
rmv_feature_num = 6
features, labels, company_final, final_timestamps = get_features_n_labels(
    args=args,
    selected_tickers=selected_tickers,
)
# 1 where the label is strictly positive, else 0, in the same dtype as labels
# (assumes labels is a numeric torch tensor).
binary_labels = (labels > 0).to(labels.dtype)

Loading base technical data...
Loading indicators...


## performance evaluation

In [1]:
## hyper param
# Read the learning rate from --lr (0.0001 in this run) so the argparse value and
# the optimizer cannot drift apart.
learning_rate = args.lr
# weight_decay = 1e-4
# NOTE(review): total_epoch overrides --epochs (args.epochs is 400); confirm which is intended.
total_epoch = 500
dropout = 0

# Honor --device rather than hard-coding "cuda:1". That string ignored the explicit
# '--device cpu' setting and fails on hosts with a single GPU (only cuda:0 exists).
# Fall back to CPU when the requested CUDA device is not present.
device = args.device
if device.startswith("cuda"):
    cuda_index = torch.device(device).index or 0
    if not torch.cuda.is_available() or cuda_index >= torch.cuda.device_count():
        print("Requested device '{}' is not available; falling back to CPU.".format(device))
        device = "cpu"
print("Device: '{}'.".format(device))

In [7]:
def main(split_time, pprint):
    """Build a fresh HDAT model and hand it to ``evaluate`` for one test window.

    Args:
        split_time: sequence returned by ``get_split_time``. Items 4 and 5 are used
            as the test-window start and end strings, and they also name the output
            directory.
        pprint: forwarded unchanged to ``evaluate``.

    Side effects:
        Creates ``<args.save_path>/<test_start>__<test_end>`` if it does not exist,
        prints progress messages, and passes the ``gdat.pt`` / ``gdat.txt`` paths
        inside that directory to ``evaluate``.
    """
    test_start, test_end = split_time[4], split_time[5]

    # One output directory per test window, holding the checkpoint and training log.
    window_tag = test_start + '__' + test_end
    output_path = '/'.join([args.save_path, window_tag])
    model_filename = '/'.join([output_path, 'gdat.pt'])
    train_log_filename = '/'.join([output_path, 'gdat.txt'])

    if os.path.exists(output_path):
        print(f"Output dir '{output_path}' is existed.")
    else:
        os.makedirs(output_path)
        print(f"Output dir '{output_path}' has been created.")

    print("Creating model...")
    # rmv_feature_num columns are excluded from the model's input width.
    input_width = features.size(2) - rmv_feature_num
    model = HDAT(n_feat=input_width, dropout=dropout, args=args).to(device)
    criterion = TotalLoss().to(device)
    # NOTE(review): args.weight_decay is parsed but not passed to the optimizer.
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # Multiply the learning rate by 0.9 every 10 scheduler.step() calls.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)
    print("Done.")

    evaluate(
        model, criterion, optimizer, scheduler,
        features, labels, final_timestamps, args, rmv_feature_num,
        split_time, train_log_filename, model_filename, total_epoch, pprint,
    )

In [2]:
# Roll a test window of `rolling_months` months forward from `day_start` and train and
# evaluate a fresh model for each window, stopping once the window start reaches
# `rolling_end` (exclusive). Dates are ISO 'YYYY-MM-DD' strings, so comparing them as
# strings gives chronological order.
day_start = '2019-02-01'
rolling_months = 1
rolling_end = '2020-01-01'
max_windows = 100  # safety cap on the number of windows
assert rolling_months > 0, "rolling_months must be positive or day_start never advances"

for _ in range(max_windows):
    if day_start >= rolling_end:
        break
    print(day_start)
    split_time = get_split_time(test_start=day_start, test_months_size=rolling_months)

    main(split_time=split_time, pprint=False)
    # Calendar-month arithmetic via pandas. This avoids `relativedelta`, which is never
    # imported explicitly in this notebook and would only exist if it leaked in through
    # a star import.
    next_start = pd.Timestamp(day_start) + pd.DateOffset(months=rolling_months)
    day_start = next_start.strftime('%Y-%m-%d')