In [13]:
from pathlib import Path

from rljax.algorithm import SAC_Discrete
from rljax.trainer import Trainer
import numpy as np
import pandas as pd

from micro_price_trading.config import TWENTY_SECOND_DAY
from micro_price_trading import Preprocess, OptimalExecutionEnvironment

# Absolute path of the kernel's current working directory.
# cwd() is a classmethod, so it is called on Path itself instead of on a
# throwaway Path() instance.
PATH = Path.cwd()

In [14]:
# Build the preprocessor for the TBT/TBF CSV and run it; the resulting `data`
# is what the environment cell passes to OptimalExecutionEnvironment.
# NOTE(review): `res_bin=6` presumably sets a bin count used by Preprocess -
# confirm against its signature.
# NOTE(review): the file name is relative; PATH from the imports cell is not
# used to locate it.
raw = Preprocess('TBT_TBF_data.csv', res_bin=6)
data = raw.process()

In [15]:
# A 9:30am-4pm session is 23,400 seconds long.
# NOTE(review): the earlier note here said the day is broken into 10 second
# increments, yet the step constant used below is TWENTY_SECOND_DAY - confirm
# which interval is actually intended.

SEED = 0
NUM_AGENT_STEPS = 10000

# Training environment settings; the evaluation environment is a copy of the
# environment built from them.
env_settings = dict(
    risk_weights=(2, 1),
    trade_penalty=100,
    max_purchase=4,
    steps=TWENTY_SECOND_DAY,
    # Ideally, this should be `TWENTY_SECOND_DAY//5*2`
    end_units_risk=TWENTY_SECOND_DAY,
    seed=SEED,
)
env = OptimalExecutionEnvironment(data, **env_settings)
env_test = env.copy_env()

# Discrete SAC agent, sized from the environment's observation/action spaces.
agent_settings = dict(
    num_agent_steps=NUM_AGENT_STEPS,
    state_space=env.observation_space,
    action_space=env.action_space,
    seed=SEED,
    batch_size=256,
    start_steps=1000,
    update_interval=1,
    update_interval_target=400,
)
algo = SAC_Discrete(**agent_settings)

# Trainer wiring: same step budget and seed as the agent, evaluation every
# 2500 steps, empty log directory string.
trainer_settings = dict(
    env=env,
    env_test=env_test,
    algo=algo,
    log_dir="",
    num_agent_steps=NUM_AGENT_STEPS,
    eval_interval=2500,
    seed=SEED,
)
trainer = Trainer(**trainer_settings)
trainer.train()

Num steps: 2500     Return: -1025196.0   Time: 0:00:44
Num steps: 5000     Return: -1024288.6   Time: 0:01:30
Num steps: 7500     Return: -42564.8   Time: 0:02:16
Num steps: 10000    Return: -40400.5   Time: 0:03:01


In [26]:
# Raise both pandas display limits to 2000 so the frame below is not elided.
for display_option in ('display.max_rows', 'display.max_columns'):
    pd.set_option(display_option, 2000)

# Portfolio records from the evaluation environment (n=20), shown as the
# cell's output.
df = env_test.portfolios_to_df(n=20)
df

Unnamed: 0,time,cash,shares,prices,total_risk,res_imbalance_state,trade,penalty_trade,trade_asset,trade_shares,trade_risk,trade_price,trade_cost,trade_penalty,risk,next_risk_target,distance_to_next_risk_target,rewards,observations,raw_action,action
0,0,0.0,"(0, 0)","(17.845000000000002, 16.394999999999996)",0,221,,,,,,,,,,5.0,5.0,,,4.0,0.0
1,1,0.0,"(0, 0)","(17.785, 16.415)",0,201,,,,,,,,,,5.0,5.0,"(0, actual)","[3, 5]",0.0,-4.0
2,2,-71.14,"(4, 0)","(17.785, 16.415)",8,201,"Trade(asset=1, shares=4, risk=8, price=17.785,...",,1.0,4.0,8.0,17.785,71.14,False,,5.0,-3.0,"(-10.600000000000009, risk penalty)","[2, 0]",0.0,-4.0
3,3,-142.28,"(8, 0)","(17.775, 16.415)",16,200,"Trade(asset=1, shares=4, risk=8, price=17.785,...",,1.0,4.0,8.0,17.785,71.14,False,,5.0,-11.0,"(-12.200000000000031, risk penalty)","[1, 0]",0.0,-4.0
4,4,-213.38,"(12, 0)","(17.775, 16.415)",24,200,"Trade(asset=1, shares=4, risk=8, price=17.775,...",,1.0,4.0,8.0,17.775,71.1,False,,5.0,-19.0,"(-14.560000000000173, risk penalty)","[0, 0]",0.0,-4.0
5,5,-284.48,"(16, 0)","(17.775, 16.415)",32,201,"Trade(asset=1, shares=4, risk=8, price=17.775,...",,1.0,4.0,8.0,17.775,71.1,False,5.0,5.0,-27.0,"(-16.480000000000246, risk penalty)","[4, 0]",1.0,-3.0
6,6,-337.805,"(19, 0)","(17.775, 16.415)",38,201,"Trade(asset=1, shares=3, risk=6, price=17.775,...",,1.0,3.0,6.0,17.775,53.325,False,,10.0,-28.0,"(-15.040000000000191, risk penalty)","[3, 0]",0.0,-4.0
7,7,-408.905,"(23, 0)","(17.775, 16.415)",46,201,"Trade(asset=1, shares=4, risk=8, price=17.775,...",,1.0,4.0,8.0,17.775,71.1,False,,10.0,-36.0,"(-10.0, risk penalty)","[2, 0]",0.0,-4.0
8,8,-480.005,"(27, 0)","(17.775, 16.415)",54,211,"Trade(asset=1, shares=4, risk=8, price=17.775,...",,1.0,4.0,8.0,17.775,71.1,False,,10.0,-44.0,"(-10.0, risk penalty)","[1, 0]",0.0,-4.0
9,9,-551.105,"(31, 0)","(17.775, 16.415)",62,211,"Trade(asset=1, shares=4, risk=8, price=17.775,...",,1.0,4.0,8.0,17.775,71.1,False,,10.0,-52.0,"(-10.0, risk penalty)","[0, 0]",0.0,-4.0


In [27]:
# Accumulate the first element of every reward entry; missing rewards are
# replaced by 0 and, like any other falsy entry, contribute nothing.
s = 0
FIRST = 0  # position of the numeric part inside a reward entry
for reward_entry in df.rewards.fillna(0):
    if not reward_entry:
        continue
    s += reward_entry[FIRST]
s

-44832.85000000524

In [None]:
# Call plot() on the training environment with its default arguments.
env.plot()

In [None]:
# Call plot() with default arguments on the evaluation environment (the copy
# handed to the Trainer as env_test).
env_test.plot()

In [None]:
# Call plot() on the evaluation environment with 'position_history';
# presumably the string selects which view is drawn - confirm against
# OptimalExecutionEnvironment.plot.
env_test.plot('position_history')

In [None]:
# Call plot() on the evaluation environment with 'asset_paths'; presumably
# the string selects which view is drawn - confirm against
# OptimalExecutionEnvironment.plot.
env_test.plot('asset_paths')

In [None]:
# Call plot() on the evaluation environment with 'summarize_decisions';
# presumably the string selects which view is drawn - confirm against
# OptimalExecutionEnvironment.plot.
env_test.plot('summarize_decisions')

In [None]:
# Call plot() on the evaluation environment with 'learning_progress';
# presumably the string selects which view is drawn - confirm against
# OptimalExecutionEnvironment.plot.
env_test.plot('learning_progress')

In [None]:
# Show the training environment's portfolio_values attribute as the cell output.
env.portfolio_values

In [None]:
# Show the training environment's portfolio_history attribute as the cell output.
env.portfolio_history

In [None]:
# NOTE(review): same expression as the preceding cell (env.portfolio_history);
# presumably either env_test was intended or this cell is redundant - confirm.
env.portfolio_history