# Stock trading test: training and evaluating a PPO agent on a multi-stock environment

In [1]:
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

# Star imports run first; the explicit imports below rebind any overlapping
# names, so keep them after the star imports.
from elegantrl.train.run_tutorial import *
from elegantrl.train.run_parallel import *
from elegantrl.train.config import Arguments
from elegantrl.agents import AgentPPO, AgentDDPG
from stock_env_multiple import StockEnvMultiple

# Render every numpy float with exactly two decimals (keeps env states readable).
np.set_printoptions(formatter={'float': '{:0.2f}'.format})

In [2]:
# Load the full price/indicator table and derive the ticker universe.
tics = pd.read_csv('data/final_filtered.csv', header=0)
all_tickers = tics['tic'].unique()
# Trading calendar taken from AAPL's rows
# (assumes AAPL has a row on every date of interest -- TODO confirm).
daterange = tics.loc[tics['tic'].eq('AAPL'), 'Date'].to_numpy()

In [None]:
# Peek at every 5th ticker among the first 500 unique symbols.
all_tickers[0:500:5]

In [None]:
def check_tickers(stock, tickers=None):
    """Return True if `stock` is one of the known tickers.

    Args:
        stock: Ticker symbol to look up, e.g. 'TSLA'.
        tickers: Optional collection to search. Defaults to the notebook-level
            `all_tickers` (unique values of the CSV's 'tic' column).

    Returns:
        bool: whether `stock` appears in `tickers`.
    """
    if tickers is None:
        tickers = all_tickers
    return stock in tickers

def check_date(date, dates=None):
    """Return True if `date` is one of the known trading dates.

    Args:
        date: Date string in the CSV's format, e.g. '2006-06-16'.
        dates: Optional collection to search. Defaults to the notebook-level
            `daterange` (the 'Date' values of AAPL's rows).

    Returns:
        bool: whether `date` appears in `dates`.
    """
    if dates is None:
        dates = daterange
    return date in dates

In [None]:
# Is 2006-06-16 in the AAPL-derived trading calendar?
check_date(date='2006-06-16')

In [None]:
# Is TSLA present in the dataset's ticker universe?
check_tickers(stock='TSLA')

In [None]:
# Preview the loaded price/indicator table (first rows only, to keep output small).
tics.head()

In [3]:
# --- Ticker universe ---------------------------------------------------------
# Larger candidate universe, kept for reference; not used by this run.
LARGE_UNIVERSE_TICKERS = ['PNW', 'BBY', 'BIO', 'RHI', 'CI', 'CSX', 'KO', 'CCZ', 'CMA', 'ETN',
                          'XOM', 'FDX', 'FRT', 'MTB', 'AJG', 'HGM', 'ITW', 'IFF', 'KSU', 'CVS',
                          'NEM', 'XEL', 'OKE', 'PNC', 'PEP', 'MO', 'RGE', 'SO', 'SYK', 'TER',
                          'TSN', 'AEE', 'WST', 'WDC', 'JKHY', 'TFC', 'MS', 'TROW', 'UNN', 'CDNS',
                          'DRE', 'ABMD', 'WRB', 'VLO', 'PBCT', 'PTC', 'XLNX', 'AZO', 'REGN',
                          'AES', 'STE', 'DHI', 'COST', 'EMN', 'ABC', 'WAB', 'HSIC', 'EL', 'RL',
                          'BXP', 'MTD', 'VRSN', 'RSG', 'MCO', 'GRMN', 'MDLZ', 'LVS', 'CE', 'UAA',
                          'FIS', 'CBRE']

# Tickers actually traded in this experiment.
used_tickers = ['MSFT', 'AAPL', 'TSLA']

# StockEnvMultiple takes a dict keyed by ticker; every value is 0 here
# (presumably the initial holding per ticker -- TODO confirm in StockEnvMultiple).
tickers = {x: 0 for x in used_tickers}

# --- Experiment constants ----------------------------------------------------
gamma = 0.99
# NOTE(review): max_stock, initial_capital and initial_stocks are not passed to
# StockEnvMultiple below, so they do not influence this run.
max_stock = 1e2
initial_capital = 1e6
initial_stocks = np.zeros(len(tickers), dtype=np.float32)

start_date = '2006-01-03'
end_date = '2015-06-16'
start_eval_date = '2016-01-01'
end_eval_date = '2021-01-01'

# --- Agent -------------------------------------------------------------------
ppo = AgentPPO.AgentPPO()  # AgentSAC(), AgentTD3(), AgentDDPG()?
ppo.if_use_gae = True
ppo.lambda_entropy = 0.04
ppo.if_on_policy = True  # PPO is an on-policy algorithm

# --- Environment -------------------------------------------------------------
env = StockEnvMultiple(tickers=tickers, begin_date=start_date, end_date=end_date)
env.max_step = 100  # episode length used for max_step / max_memo below

args = Arguments(env, ppo)

# --- Hyperparameters ---------------------------------------------------------
args.gamma = gamma
args.cwd = './models/RLStockPPO_v3/'
args.break_step = int(2e5)
args.net_dim = 2 ** 9
args.max_step = args.env.max_step
args.max_memo = args.max_step * 4
args.reward_scale = 1e-4
args.batch_size = 2 ** 10
args.repeat_times = 1
args.eval_gap = 2 ** 4
args.eval_times1 = 2 ** 3
args.eval_times2 = 2 ** 5
args.worker_num = 4
args.thread_num = 16
args.if_allow_break = False
args.target_return = 1.1e7
args.rollout_num = 4  # the number of rollout workers (larger is not always faster)

Using stocks: ['MSFT', 'AAPL', 'TSLA']


In [4]:
# Smoke-test the environment on a single day before training.
day_data = env._get_data_dict(env.get_index('2006-01-04'))
env._build_state(day_data)
# One action value per traded ticker: the returned state holds one holding slot
# per ticker (assumes StockEnvMultiple's action_dim == len(tickers) -- TODO confirm).
env.step(np.full(len(tickers), 0.1))

{'AAPL': {'gvkey': 1690, 'iid': 1, 'Date': '2006-01-04', 'Volume': 22110430.0, 'Close': 74.97, 'High': 75.98, 'Low': 74.5, 'Open': 74.96, 'rsi': nan, 'macd': nan, 'macdsignal': nan, 'macdhist': nan, 'obv': 50823260.0, 'cci': nan, 'adx': nan, 'multiplier': 1.0}, 'MSFT': {'gvkey': 12141, 'iid': 1, 'Date': '2006-01-04', 'Volume': 57967200.0, 'Close': 26.97, 'High': 27.0801, 'Low': 26.77, 'Open': 26.77, 'rsi': nan, 'macd': nan, 'macdsignal': nan, 'macdhist': nan, 'obv': 137937750.0, 'cci': nan, 'adx': nan, 'multiplier': 1.0}}
{'AAPL': {'gvkey': 1690, 'iid': 1, 'Date': '2006-01-04', 'Volume': 22110430.0, 'Close': 74.97, 'High': 75.98, 'Low': 74.5, 'Open': 74.96, 'rsi': nan, 'macd': nan, 'macdsignal': nan, 'macdhist': nan, 'obv': 50823260.0, 'cci': nan, 'adx': nan, 'multiplier': 1.0}, 'MSFT': {'gvkey': 12141, 'iid': 1, 'Date': '2006-01-04', 'Volume': 57967200.0, 'Close': 26.97, 'High': 27.0801, 'Low': 26.77, 'Open': 26.77, 'rsi': nan, 'macd': nan, 'macdsignal': nan, 'macdhist': nan, 'obv': 1

(array([998980.60, 10.00, 10.00, 0.00, 26.77, 74.96, 0.00, 26.97, 74.97,
        0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 137937750.00,
        50823260.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00]),
 0.0,
 False,
 {})

In [5]:
# Train with ElegantRL's multiprocess trainer and report wall-clock time.
# NOTE(review): the env appears to print a data dict per step (see output), which
# trips Jupyter's IOPub rate limit; consider silencing those prints.
train_start = time.perf_counter()  # monotonic clock, intended for durations
train_and_evaluate_mp(args)
print(f"Training wall time: {time.perf_counter() - train_start:.1f} s")

| Remove cwd: ./models/RLStockPPO_v3/
{'AAPL': {'gvkey': 1690, 'iid': 1, 'Date': '2006-01-04', 'Volume': 22110430.0, 'Close': 74.97, 'High': 75.98, 'Low': 74.5, 'Open': 74.96, 'rsi': nan, 'macd': nan, 'macdsignal': nan, 'macdhist': nan, 'obv': 50823260.0, 'cci': nan, 'adx': nan, 'multiplier': 1.0}, 'MSFT': {'gvkey': 12141, 'iid': 1, 'Date': '2006-01-04', 'Volume': 57967200.0, 'Close': 26.97, 'High': 27.0801, 'Low': 26.77, 'Open': 26.77, 'rsi': nan, 'macd': nan, 'macdsignal': nan, 'macdhist': nan, 'obv': 137937750.0, 'cci': nan, 'adx': nan, 'multiplier': 1.0}}
{'AAPL': {'gvkey': 1690, 'iid': 1, 'Date': '2006-01-04', 'Volume': 22110430.0, 'Close': 74.97, 'High': 75.98, 'Low': 74.5, 'Open': 74.96, 'rsi': nan, 'macd': nan, 'macdsignal': nan, 'macdhist': nan, 'obv': 50823260.0, 'cci': nan, 'adx': nan, 'multiplier': 1.0}, 'MSFT': {'gvkey': 12141, 'iid': 1, 'Date': '2006-01-04', 'Volume': 57967200.0, 'Close': 26.97, 'High': 27.0801, 'Low': 26.77, 'Open': 26.77, 'rsi': nan, 'macd': nan, 'macds

{'AAPL': {'gvkey': 1690, 'iid': 1, 'Date': '2006-03-24', 'Volume': 38281340.0, 'Close': 59.96, 'High': 60.94, 'Low': 59.03, 'Open': 60.25, 'rsi': 31.49231139365696, 'macd': -2.609892538862347, 'macdsignal': -2.285335560439872, 'macdhist': -0.3245569784224758, 'obv': -235591120.0, 'cci': -157.75438864587056, 'adx': 27.31865178821356, 'multiplier': 1.0}, 'MSFT': {'gvkey': 12141, 'iid': 1, 'Date': '2006-03-24', 'Volume': 69156990.0, 'Close': 27.01, 'High': 27.21, 'Low': 26.62, 'Open': 26.715, 'rsi': 43.70404726081485, 'macd': 0.080314677051998, 'macdsignal': 0.1082771031808004, 'macdhist': -0.0279624261288023, 'obv': -163216250.0, 'cci': -92.929724420011, 'adx': 18.097008586308565, 'multiplier': 1.0}}
{'AAPL': {'gvkey': 1690, 'iid': 1, 'Date': '2006-03-27', 'Volume': 39570770.0, 'Close': 59.51, 'High': 61.38, 'Low': 59.4, 'Open': 60.335, 'rsi': 31.84058243949772, 'macd': -2.719236553531509, 'macdsignal': -2.372115759058199, 'macdhist': -0.3471207944733101, 'obv': -275161890.0, 'cci': -122

{'AAPL': {'gvkey': 1690, 'iid': 1, 'Date': '2006-07-12', 'Volume': 33088870.0, 'Close': 52.96, 'High': 55.24, 'Low': 52.9176, 'Open': 55.17, 'rsi': 33.95126029261755, 'macd': -1.734762032719097, 'macdsignal': -1.5701827378415314, 'macdhist': -0.1645792948775652, 'obv': -510963611.0, 'cci': -161.60941238147612, 'adx': 30.80215759903111, 'multiplier': 1.0}, 'MSFT': {'gvkey': 12141, 'iid': 1, 'Date': '2006-07-12', 'Volume': 77374580.0, 'Close': 22.64, 'High': 22.88, 'Low': 22.62, 'Open': 22.79, 'rsi': 44.235449194244445, 'macd': 0.0605441957215902, 'macdsignal': 0.0273961510232345, 'macdhist': 0.0331480446983556, 'obv': -899947400.0, 'cci': -105.77278149049894, 'adx': 16.32350517057287, 'multiplier': 1.0}}
{'AAPL': {'gvkey': 1690, 'iid': 1, 'Date': '2006-07-12', 'Volume': 33088870.0, 'Close': 52.96, 'High': 55.24, 'Low': 52.9176, 'Open': 55.17, 'rsi': 33.95126029261755, 'macd': -1.734762032719097, 'macdsignal': -1.5701827378415314, 'macdhist': -0.1645792948775652, 'obv': -510963611.0, 'cc

{'AAPL': {'gvkey': 1690, 'iid': 1, 'Date': '2006-11-10', 'Volume': 13190950.0, 'Close': 83.12, 'High': 83.6, 'Low': 82.5, 'Open': 83.55, 'rsi': 67.70518326420725, 'macd': 1.820642669100493, 'macdsignal': 1.7067309358014708, 'macdhist': 0.1139117332990227, 'obv': -202134241.0, 'cci': 133.14783105022826, 'adx': 30.72746633087045, 'multiplier': 1.0}, 'MSFT': {'gvkey': 12141, 'iid': 1, 'Date': '2006-11-10', 'Volume': 36928550.0, 'Close': 29.24, 'High': 29.29, 'Low': 29.15, 'Open': 29.175, 'rsi': 70.39013704928206, 'macd': 0.4283221980409166, 'macdsignal': 0.426032708804716, 'macdhist': 0.0022894892362005, 'obv': 854098220.0, 'cci': 133.11804197416046, 'adx': 49.33242884105319, 'multiplier': 1.0}}
{'AAPL': {'gvkey': 1690, 'iid': 1, 'Date': '2006-10-17', 'Volume': 17166050.0, 'Close': 74.29, 'High': 75.27, 'Low': 74.04, 'Open': 75.04, 'rsi': 56.76134810564377, 'macd': 0.8180784401629353, 'macdsignal': 1.1716103877592252, 'macdhist': -0.3535319475962899, 'obv': -403109601.0, 'cci': -32.523687

{'AAPL': {'gvkey': 1690, 'iid': 1, 'Date': '2007-02-09', 'Volume': 30717650.0, 'Close': 83.27, 'High': 86.2, 'Low': 83.21, 'Open': 85.88, 'rsi': 47.53151224779924, 'macd': -0.8969127589889041, 'macdsignal': -0.5821492865198584, 'macdhist': -0.3147634724690457, 'obv': -339122691.0, 'cci': -109.09853249475808, 'adx': 20.833670512374784, 'multiplier': 1.0}, 'MSFT': {'gvkey': 12141, 'iid': 1, 'Date': '2007-02-09', 'Volume': 69808500.0, 'Close': 28.98, 'High': 29.4, 'Low': 28.93, 'Open': 29.345, 'rsi': 38.453766488133226, 'macd': -0.2160570138923603, 'macdsignal': 0.0091878298398331, 'macdhist': -0.2252448437321935, 'obv': 565034920.0, 'cci': -130.92968642077716, 'adx': 17.763822716022663, 'multiplier': 1.0}}
{'AAPL': {'gvkey': 1690, 'iid': 1, 'Date': '2007-04-12', 'Volume': 23423460.0, 'Close': 92.19, 'High': 92.31, 'Low': 90.72, 'Open': 92.04, 'rsi': 49.48166037599426, 'macd': 1.1485830550023906, 'macdsignal': 1.5178490563079492, 'macdhist': -0.3692660013055586, 'obv': -190236711.0, 'cci'

{'AAPL': {'gvkey': 1690, 'iid': 1, 'Date': '2007-08-06', 'Volume': 32993850.0, 'Close': 135.25, 'High': 135.27, 'Low': 128.3, 'Open': 132.9, 'rsi': 47.33883637996386, 'macd': 2.2571764515648454, 'macdsignal': 3.7052210911058814, 'macdhist': -1.448044639541037, 'obv': 519783049.0, 'cci': -97.28011070520296, 'adx': 27.58011333251528, 'multiplier': 1.0}, 'MSFT': {'gvkey': 12141, 'iid': 1, 'Date': '2007-08-06', 'Volume': 59296350.0, 'Close': 29.54, 'High': 29.54, 'Low': 28.75, 'Open': 29.05, 'rsi': 39.86763618631056, 'macd': -0.227198937371277, 'macdsignal': -0.0936044891376387, 'macdhist': -0.1335944482336383, 'obv': 938398290.0, 'cci': -69.0802319779586, 'adx': 20.181070071179214, 'multiplier': 1.0}}
{'AAPL': {'gvkey': 1690, 'iid': 1, 'Date': '2007-08-07', 'Volume': 33890770.0, 'Close': 135.0301, 'High': 137.24, 'Low': 132.63, 'Open': 134.94, 'rsi': 50.28179398107674, 'macd': 1.97549086004608, 'macdsignal': 3.3592750448939217, 'macdhist': -1.3837841848478425, 'obv': 485892279.0, 'cci': -

{'AAPL': {'gvkey': 1690, 'iid': 1, 'Date': '2007-11-27', 'Volume': 46984390.0, 'Close': 174.81, 'High': 175.79, 'Low': 170.01, 'Open': 175.22, 'rsi': 53.03065622388776, 'macd': -0.1850962330415484, 'macdsignal': 0.3374714754599908, 'macdhist': -0.5225677085015392, 'obv': 498753169.0, 'cci': 51.48942695786463, 'adx': 22.50805728239153, 'multiplier': 1.0}, 'MSFT': {'gvkey': 12141, 'iid': 1, 'Date': '2007-11-27', 'Volume': 84048140.0, 'Close': 33.06, 'High': 33.6, 'Low': 32.68, 'Open': 33.27, 'rsi': 47.23203750234606, 'macd': 0.3024234560504908, 'macdsignal': 0.6617482029618714, 'macdhist': -0.3593247469113806, 'obv': 1478524110.0, 'cci': -147.67142290923465, 'adx': 28.041214280097154, 'multiplier': 1.0}}
{'AAPL': {'gvkey': 1690, 'iid': 1, 'Date': '2007-11-28', 'Volume': 40897670.0, 'Close': 180.22, 'High': 180.6, 'Low': 175.35, 'Open': 176.82, 'rsi': 54.47297107826687, 'macd': 0.5497858705280407, 'macdsignal': 0.3799343544736008, 'macdhist': 0.1698515160544399, 'obv': 539650839.0, 'cci':

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



acd': -0.0087947927986213, 'macdsignal': 0.0661122109057105, 'macdhist': -0.0749070037043319, 'obv': 17223366.0, 'cci': -38.70944721250751, 'adx': 13.419734950039024, 'multiplier': 1.0}}
{'AAPL': {'gvkey': 1690, 'iid': 1, 'Date': '2011-08-03', 'Volume': 26379400.0, 'Close': 392.57, 'High': 393.55, 'Low': 382.24, 'Open': 390.98, 'rsi': 62.712667252946986, 'macd': 13.244110346556624, 'macdsignal': 13.49334129419251, 'macdhist': -0.2492309476358851, 'obv': 1244532721.0, 'cci': 11.584572073050984, 'adx': 35.203817100623844, 'multiplier': 1.0}, 'MSFT': {'gvkey': 12141, 'iid': 1, 'Date': '2011-08-03', 'Volume': 64563970.0, 'Close': 26.92, 'High': 27.0, 'Low': 26.48, 'Open': 26.83, 'rsi': 53.36366208547075, 'macd': 0.4312783883761355, 'macdsignal': 0.5554138204670884, 'macdhist': -0.1241354320909529, 'obv': -955072550.0, 'cci': -93.60780336325942, 'adx': 28.47076156024371, 'multiplier': 1.0}, 'TSLA': {'gvkey': 184996, 'iid': 1, 'Date': '2011-08-03', 'Volume': 1790210.0, 'Close': 27.2, 'High':

{'AAPL': {'gvkey': 1690, 'iid': 1, 'Date': '2011-08-25', 'Volume': 31089320.0, 'Close': 373.72, 'High': 375.45, 'Low': 365.0, 'Open': 365.08, 'rsi': 45.88547979885346, 'macd': 0.1127816540846993, 'macdsignal': 1.4627934108823604, 'macdhist': -1.350011756797661, 'obv': 1139563691.0, 'cci': 11.16631008932577, 'adx': 19.66960746550488, 'multiplier': 1.0}, 'MSFT': {'gvkey': 12141, 'iid': 1, 'Date': '2011-08-25', 'Volume': 48175630.0, 'Close': 24.57, 'High': 25.16, 'Low': 24.5, 'Open': 25.08, 'rsi': 46.26223854694869, 'macd': -0.4815243858201832, 'macdsignal': -0.4008440634305434, 'macdhist': -0.0806803223896397, 'obv': -1498639630.0, 'cci': -12.698412698412929, 'adx': 35.33234766911112, 'multiplier': 1.0}, 'TSLA': {'gvkey': 184996, 'iid': 1, 'Date': '2011-08-25', 'Volume': 679740.0, 'Close': 23.11, 'High': 23.87, 'Low': 22.9, 'Open': 23.87, 'rsi': 42.72335111781079, 'macd': -1.1540487557689223, 'macdsignal': -0.9833355844692876, 'macdhist': -0.1707131712996348, 'obv': 4942771.0, 'cci': -62

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



 1.0}, 'TSLA': {'gvkey': 184996, 'iid': 1, 'Date': '2013-07-15', 'Volume': 9914409.0, 'Close': 127.26, 'High': 133.259, 'Low': 126.82, 'Open': 133.03, 'rsi': 78.28929057850289, 'macd': 8.86627462804583, 'macdsignal': 8.022743127644201, 'macdhist': 0.8435315004016282, 'obv': 201022124.0, 'cci': 115.42673768965582, 'adx': 41.32232592251167, 'multiplier': 1.0}}

{'AAPL': {'gvkey': 1690, 'iid': 1, 'Date': '2013-10-01', 'Volume': 12599180.0, 'Close': 487.96, 'High': 489.14, 'Low': 478.381, 'Open': 478.45, 'rsi': 48.64708335361049, 'macd': 1.0879749713668048, 'macdsignal': 0.4456481043324518, 'macdhist': 0.6423268670343529, 'obv': 929250449.0, 'cci': 66.08206103485105, 'adx': 18.482926346275445, 'multiplier': 1.0}, 'MSFT': {'gvkey': 12141, 'iid': 1, 'Date': '2013-10-01', 'Volume': 36664230.0, 'Close': 33.58, 'High': 33.61, 'Low': 33.3, 'Open': 33.35, 'rsi': 55.83506310639174, 'macd': 0.1731514185109119, 'macdsignal': 0.0914542536533875, 'macdhist': 0.0816971648575244, 'obv': -2159707480.0, '

Process Process-6:
Traceback (most recent call last):
  File "/usr/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  File "/usr/lib/python3.7/multiprocessing/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
  File "/home/saad/envs/stocks/lib/python3.7/site-packages/elegantrl/train/run_parallel.py", line 216, in run
    traj_lists = comm_exp.explore(agent)
  File "/home/saad/envs/stocks/lib/python3.7/site-packages/elegantrl/train/run_parallel.py", line 59, in explore
    traj_lists = [pipe1.recv() for pipe1 in self.pipe1s]
  File "/home/saad/envs/stocks/lib/python3.7/site-packages/elegantrl/train/run_parallel.py", line 59, in <listcomp>
    traj_lists = [pipe1.recv() for pipe1 in self.pipe1s]
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 250, in recv
    buf = self._recv_bytes()
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 407, in _recv_bytes
    buf = self._recv(4)
  File "/usr/lib/python3.

KeyboardInterrupt: 

In [None]:
# Inspect which stocks the training environment ended up using.
getattr(args.env, 'stocks')

In [None]:
# --- Evaluation window --------------------------------------------------------
start_eval_date = '2016-01-05'
end_eval_date = '2021-01-05'

# Same ticker universe as the training configuration cell.
tickers = {x: 0 for x in used_tickers}

# --- Agent (must match the trained model's architecture) -----------------------
agent = AgentPPO.AgentPPO()  # AgentSAC(), AgentTD3(), AgentDDPG()?
agent.if_use_gae = True
agent.lambda_entropy = 0.04
agent.if_on_policy = True  # PPO is on-policy; consistent with the training cell

env_eval = StockEnvMultiple(tickers=tickers,
                            begin_date=start_eval_date,
                            end_date=end_eval_date)

args = Arguments(env_eval, agent)

# Do not wipe cwd (the training run logged "Remove cwd" with the default setting).
args.if_remove = False
# NOTE(review): the training cell saves to './models/RLStockPPO_v3/'. Confirm this
# directory holds a checkpoint trained on the same tickers, since the state layout
# depends on them.
args.cwd = './models/RLStockPPO_v3_100/'
args.gamma = gamma
args.break_step = int(2e5)
args.net_dim = 2 ** 9
args.max_step = args.env.max_step
args.max_memo = args.max_step * 4
args.batch_size = 2 ** 10
args.repeat_times = 2 ** 3
args.eval_gap = 2 ** 4
args.eval_times1 = 2 ** 3
args.eval_times2 = 2 ** 5
args.if_allow_break = False
args.target_return = 1.1e7

args.rollout_num = 6  # the number of rollout workers (larger is not always faster)
args.init_before_training()

env_eval.draw_cumulative_return(args, torch)
plt.show()

In [None]:
# Flush any figures still pending; a no-op if the previous cell already showed them.
plt.show(block=None)