# Stock Trading Test: PPO on a Multi-Stock Environment (ElegantRL)

In [1]:
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

from elegantrl.train.run_tutorial import *
from elegantrl.train.run_parallel import *
from elegantrl.train.config import Arguments
from elegantrl.agents import AgentPPO, AgentDDPG
from stock_env_multiple import StockEnvMultiple

# Render every float in printed numpy arrays with exactly two decimals.
np.set_printoptions(formatter={'float': lambda value: f"{value:0.2f}"})

In [2]:
# Load the dataset; the 'tic' column holds ticker symbols and 'Date' the row date.
tics = pd.read_csv('data/filtered_with_ti.csv', header=0)
all_tickers = tics['tic'].unique()
# Reference calendar: every date on which AAPL has a row.
daterange = tics.loc[tics['tic'] == 'AAPL', 'Date'].to_numpy()

def check_tickers(stock, universe=None):
    """Return True if ``stock`` is one of the known ticker symbols.

    Parameters
    ----------
    stock : str
        Ticker symbol to look up, e.g. ``'AAPL'``.
    universe : array-like, optional
        Collection of tickers to search. When omitted, the notebook-level
        ``all_tickers`` (unique values of ``tics['tic']``) is used, which
        keeps the original one-argument call working.

    Returns
    -------
    bool
        Plain Python bool (numpy's ``in`` would otherwise yield ``np.bool_``).
    """
    if universe is None:
        universe = all_tickers
    return bool(stock in universe)

def check_date(date, calendar=None):
    """Return True if ``date`` appears in the trading calendar.

    Parameters
    ----------
    date : str
        Date to look up, in the same format as the dataset's 'Date' column
        (e.g. ``'2006-06-16'``).
    calendar : array-like, optional
        Collection of dates to search. When omitted, the notebook-level
        ``daterange`` (the dates on which AAPL has a row) is used, which
        keeps the original one-argument call working.

    Returns
    -------
    bool
        Plain Python bool (numpy's ``in`` would otherwise yield ``np.bool_``).
    """
    if calendar is None:
        calendar = daterange
    return bool(date in calendar)

In [None]:
# Is 2006-06-16 in the AAPL-derived trading calendar?
check_date(date='2006-06-16')

In [None]:
# Is TSLA among the tickers present in the dataset?
check_tickers(stock='TSLA')

In [None]:
# Summary statistics of closing prices instead of dumping the whole column;
# the max row shows whether the outlier check in the next cell is needed.
tics['Close'].describe()

In [None]:
# Rows whose closing price exceeds 1e4 — possible data outliers (TODO confirm threshold).
tics.loc[tics['Close'] > 1e4]

In [4]:
# ---- Stock universe -------------------------------------------------------
used_tickers = ['PNW', 'BBY', 'BIO', 'RHI', 'CI', 'CSX', 'KO', 'CCZ', 'CMA', 'ETN',
               'XOM', 'FDX', 'FRT', 'MTB', 'AJG', 'HGM', 'ITW', 'IFF', 'KSU', 'CVS',
               'NEM', 'XEL', 'OKE', 'PNC', 'PEP', 'MO', 'RGE', 'SO', 'SYK', 'TER',
               'TSN', 'AEE', 'WST', 'WDC', 'JKHY', 'TFC', 'MS', 'TROW', 'UNN', 'CDNS']

for tic in used_tickers:
    print(tic, check_tickers(tic))

# Stop here instead of letting an unknown ticker reach the environment.
missing_tickers = [t for t in used_tickers if not check_tickers(t)]
assert not missing_tickers, f"Tickers not found in dataset: {missing_tickers}"

# Smaller universe, kept for reference:
# used_tickers = ['MSFT', 'AAPL', 'TSLA']

# Mapping passed to StockEnvMultiple; every ticker maps to 0 here.
tickers = {x: 0 for x in used_tickers}

# ---- Trading parameters ---------------------------------------------------
gamma = 0.99                # RL discount factor (copied into args.gamma below)
max_stock = 1e2
initial_capital = 1e6
initial_stocks = np.zeros(len(tickers), dtype=np.float32)  # one zero per ticker
# NOTE(review): max_stock, initial_capital and initial_stocks are not passed to
# StockEnvMultiple below — confirm whether the env is meant to receive them.

# Training window, then the later window used for evaluation.
start_date = '2006-01-03'
end_date = '2015-06-16'
start_eval_date = '2016-01-01'
end_eval_date = '2021-01-01'

# ---- Agent ----------------------------------------------------------------
ppo = AgentPPO.AgentPPO()  # alternatives: AgentSAC(), AgentTD3(), AgentDDPG()
ppo.if_use_gae = True
ppo.lambda_entropy = 0.04
ppo.if_on_policy = True    # PPO is an on-policy algorithm

# ---- Environment ----------------------------------------------------------
env = StockEnvMultiple(tickers=tickers, begin_date=start_date, end_date=end_date)
env.max_step = 100

args = Arguments(env, ppo)

# ---- Hyperparameters ------------------------------------------------------
args.gamma = gamma
# args.cwd = './models/RLStockPPO_v3/'
args.break_step = int(2e5)
args.net_dim = 2 ** 9          # the evaluation cell must use the same value to reload the model
args.max_step = args.env.max_step
args.max_memo = args.max_step * 4
args.reward_scale = 1e-4
args.batch_size = 2 ** 10
args.repeat_times = 1
args.eval_gap = 2 ** 4
args.eval_times1 = 2 ** 3
args.eval_times2 = 2 ** 5
args.worker_num = 4
args.thread_num = 16
args.if_allow_break = False
args.target_return = 1.1e7
args.rollout_num = 4  # the number of rollout workers (larger is not always faster)

PNW True
BBY True
BIO True
RHI True
CI True
CSX True
KO True
CCZ True
CMA True
ETN True
XOM True
FDX True
FRT True
MTB True
AJG True
HGM True
ITW True
IFF True
KSU True
CVS True
NEM True
XEL True
OKE True
PNC True
PEP True
MO True
RGE True
SO True
SYK True
TER True
TSN True
AEE True
WST True
WDC True
JKHY True
TFC True
MS True
TROW True
UNN True
CDNS True
Using stocks: ['PNW', 'BBY', 'BIO', 'RHI', 'CI', 'CSX', 'KO', 'CCZ', 'CMA', 'ETN', 'XOM', 'FDX', 'FRT', 'MTB', 'AJG', 'HGM', 'ITW', 'IFF', 'KSU', 'CVS', 'NEM', 'XEL', 'OKE', 'PNC', 'PEP', 'MO', 'RGE', 'SO', 'SYK', 'TER', 'TSN', 'AEE', 'WST', 'WDC', 'JKHY', 'TFC', 'MS', 'TROW', 'UNN', 'CDNS']


In [None]:
# Roll the training env with a constant action and count the steps taken
# before the state first contains a NaN (or the episode ends).
env.reset()
n_stocks = len(tickers)  # one action per traded stock, not a hard-coded 100
count = 0
done = False
while not done:
    # A fresh array every step, in case the env modifies the action in place.
    state, _, done, _ = env.step(np.full(n_stocks, 0.1))
    if np.isnan(state.sum()):
        break
    count += 1
count

In [None]:
# Inspect the training environment's current `state` attribute (e.g. after the NaN probe loop).
env.state

In [None]:
# Run training (train_and_evaluate_mp comes from one of the elegantrl star
# imports) and report elapsed wall-clock time. perf_counter is monotonic,
# unlike time.time(), so clock adjustments cannot skew the measurement.
train_start = time.perf_counter()
train_and_evaluate_mp(args)
print(f"Training time: {time.perf_counter() - train_start:.1f} s")

| Remove cwd: ./AgentPPO_RLStockEnv-v3_(0,)


In [None]:
# Inspect `stocks` on the training env held by args — presumably holdings or the ticker list; confirm in StockEnvMultiple.
args.env.stocks

In [None]:
# Evaluate the trained PPO policy on a window that starts after training ends
# and plot its cumulative return.
start_eval_date = '2016-01-05'
end_eval_date = '2021-01-05'
# Larger universe, kept for reference (unused):
# used_tickers = ['PNW', 'BBY', 'BIO', 'RHI', 'CI', 'CSX', 'KO', 'CCZ', 'CMA', 'ETN',
#            'XOM', 'FDX', 'FRT', 'MTB', 'AJG', 'HGM', 'ITW', 'IFF', 'KSU', 'CVS',
#            'NEM', 'XEL', 'OKE', 'PNC', 'PEP', 'MO', 'RGE', 'SO', 'SYK', 'TER',
#            'TSN', 'AEE', 'WST', 'WDC', 'JKHY', 'TFC', 'MS', 'TROW', 'UNN', 'CDNS',
#            'DRE', 'ABMD', 'WRB', 'VLO', 'PBCT', 'PTC', 'XLNX', 'AZO', 'REGN',
#            'AES', 'STE', 'DHI', 'COST', 'EMN', 'ABC', 'WAB', 'HSIC', 'EL', 'RL',
#            'BXP', 'MTD', 'VRSN', 'RSG', 'MCO', 'GRMN', 'MDLZ', 'LVS', 'CE', 'UAA',
#            'FIS', 'CBRE']

tickers = {x: 0 for x in used_tickers}

agent = AgentPPO.AgentPPO()  # alternatives: AgentSAC(), AgentTD3(), AgentDDPG()
agent.if_use_gae = True
agent.lambda_entropy = 0.04
# PPO is on-policy; use the same flag value as the training cell.
# NOTE(review): this cell previously set False — confirm Arguments does not rely on that.
agent.if_on_policy = True

env_eval = StockEnvMultiple(tickers=tickers,
                            begin_date=start_eval_date,
                            end_date=end_eval_date)

args = Arguments(env_eval, agent)

args.if_remove = False  # do not delete the training output stored in cwd
args.cwd = './AgentPPO_RLStockEnv-v3_(0,)'  # directory the training run wrote to
args.gamma = gamma
args.break_step = int(2e5)
args.net_dim = 2 ** 9  # presumably must match training so saved weights load — keep in sync
args.max_step = args.env.max_step
args.max_memo = args.max_step * 4
args.batch_size = 2 ** 10
args.repeat_times = 2 ** 3
args.eval_gap = 2 ** 4
args.eval_times1 = 2 ** 3
args.eval_times2 = 2 ** 5
args.if_allow_break = False
args.target_return = 1.1e7
# NOTE(review): reward_scale (1e-4 in training) is not set here — confirm it is irrelevant for evaluation.

args.rollout_num = 6  # the number of rollout workers (larger is not always faster)
args.init_before_training()

env_eval.draw_cumulative_return(args, torch, 'PPO')
plt.show()

In [None]:
# NOTE(review): likely redundant — the evaluation cell above already calls plt.show(), so this probably displays nothing.
plt.show()