In [1]:
from pathlib import Path

from rljax.algorithm import SAC_Discrete
from rljax.trainer import Trainer
import numpy as np
import pandas as pd

from micro_price_trading.config import TWENTY_SECOND_DAY
from micro_price_trading import Preprocess, OptimalExecutionEnvironment

# Working directory the notebook was launched from. `Path.cwd()` is a
# classmethod, so call it on the class rather than on a throwaway instance.
# NOTE(review): PATH is not referenced by any later cell in this notebook;
# the data file below is loaded by a bare relative filename instead.
PATH = Path.cwd()

In [2]:
# Load 'TBT_TBF_data.csv' (resolved relative to the working directory) and run
# the package's preprocessing; `data` is what the environment is built from below.
# `res_bin=6` is forwarded to Preprocess as-is -- see micro_price_trading for its meaning.
raw = Preprocess(
    'TBT_TBF_data.csv',
    res_bin=6,
)
data = raw.process()

In [3]:
# A trading day runs 23,400 seconds (9:30am-4pm). The environment below is
# configured with `steps=TWENTY_SECOND_DAY`, i.e. (judging by the constant's
# name) the day split into 20-second increments -- TODO confirm the value in
# micro_price_trading.config.

NUM_AGENT_STEPS = 10000   # total agent steps used for training
SEED = 0                  # passed to the env, the algorithm and the trainer
EVAL_INTERVAL = 2500      # agent steps between evaluation reports (see log below)

env = OptimalExecutionEnvironment(
    data,
    risk_weights=(2, 1),
    trade_penalty=100,
    max_purchase=4,
    steps=TWENTY_SECOND_DAY,
    # TODO: ideally this should be `TWENTY_SECOND_DAY // 5 * 2`.
    end_units_risk=TWENTY_SECOND_DAY,
    seed=SEED,
)
# A copy of `env` (via copy_env) handed to the Trainer as its test environment.
env_test = env.copy_env()

algo = SAC_Discrete(
    num_agent_steps=NUM_AGENT_STEPS,
    state_space=env.observation_space,
    action_space=env.action_space,
    seed=SEED,
    batch_size=256,
    start_steps=1000,
    update_interval=1,
    update_interval_target=400,
)

trainer = Trainer(
    env=env,
    env_test=env_test,
    algo=algo,
    # NOTE(review): an empty log_dir presumably makes the Trainer write its
    # logs relative to the current working directory; consider a dedicated folder.
    log_dir="",
    num_agent_steps=NUM_AGENT_STEPS,
    eval_interval=EVAL_INTERVAL,
    seed=SEED,
)
trainer.train()



Num steps: 2500     Return: -81669.3   Time: 0:00:42
Num steps: 5000     Return: -82439.7   Time: 0:01:32
Num steps: 7500     Return: -83168.5   Time: 0:02:23
Num steps: 10000    Return: -80800.0   Time: 0:03:08


In [4]:
# Raise pandas' row/column display limits so the portfolio table is not truncated.
# NOTE: these are global options and stay in effect for every later cell.
for option_name in ('display.max_rows', 'display.max_columns'):
    pd.set_option(option_name, 2000)

# Tabulate env_test's portfolio history (n=20 is forwarded to portfolios_to_df).
df = env_test.portfolios_to_df(n=20)
df

Unnamed: 0,time,cash,shares,prices,total_risk,res_imbalance_state,trade,penalty_trade,trade_asset,trade_shares,trade_risk,trade_price,trade_cost,trade_penalty,risk,next_risk_target,distance_to_next_risk_target,rewards,observations,raw_action,action
0,0,0.0,"(0, 0)","(17.845000000000002, 16.394999999999996)",0,221,,,,,,,,,,5.0,5.0,,,0.0,-4.0
1,1,-71.38,"(4, 0)","(17.785, 16.415)",8,201,"Trade(asset=1, shares=4, risk=8, price=17.8450...",,1.0,4.0,8.0,17.845,71.38,False,,5.0,-3.0,"(-40.12000000000002, over risk penalty)","[3, 0]",0.0,-4.0
2,2,-142.52,"(8, 0)","(17.785, 16.415)",16,201,"Trade(asset=1, shares=4, risk=8, price=17.785,...",,1.0,4.0,8.0,17.785,71.14,False,,5.0,-11.0,"(-42.20000000000003, over risk penalty)","[2, 0]",0.0,-4.0
3,3,-213.66,"(12, 0)","(17.775, 16.415)",24,200,"Trade(asset=1, shares=4, risk=8, price=17.785,...",,1.0,4.0,8.0,17.785,71.14,False,,5.0,-19.0,"(-43.800000000000054, over risk penalty)","[1, 0]",0.0,-4.0
4,4,-284.76,"(16, 0)","(17.775, 16.415)",32,200,"Trade(asset=1, shares=4, risk=8, price=17.775,...",,1.0,4.0,8.0,17.775,71.1,False,,5.0,-27.0,"(-46.480000000000246, over risk penalty)","[0, 0]",0.0,-4.0
5,5,-355.86,"(20, 0)","(17.775, 16.415)",40,201,"Trade(asset=1, shares=4, risk=8, price=17.775,...",,1.0,4.0,8.0,17.775,71.1,False,5.0,5.0,-35.0,"(-48.40000000000032, over risk penalty)","[4, 0]",0.0,-4.0
6,6,-426.96,"(24, 0)","(17.775, 16.415)",48,201,"Trade(asset=1, shares=4, risk=8, price=17.775,...",,1.0,4.0,8.0,17.775,71.1,False,,10.0,-38.0,"(-49.120000000000346, over risk penalty)","[3, 0]",0.0,-4.0
7,7,-498.06,"(28, 0)","(17.775, 16.415)",56,201,"Trade(asset=1, shares=4, risk=8, price=17.775,...",,1.0,4.0,8.0,17.775,71.1,False,,10.0,-46.0,"(-40.0, over risk penalty)","[2, 0]",0.0,-4.0
8,8,-569.16,"(32, 0)","(17.775, 16.415)",64,211,"Trade(asset=1, shares=4, risk=8, price=17.775,...",,1.0,4.0,8.0,17.775,71.1,False,,10.0,-54.0,"(-40.0, over risk penalty)","[1, 0]",0.0,-4.0
9,9,-640.26,"(36, 0)","(17.775, 16.415)",72,211,"Trade(asset=1, shares=4, risk=8, price=17.775,...",,1.0,4.0,8.0,17.775,71.1,False,,10.0,-62.0,"(-40.0, over risk penalty)","[0, 0]",0.0,-4.0


In [5]:
# Total the numeric part of each logged reward tuple (e.g. `(-40.0, 'over risk penalty')`).
# Rows with no reward (NaN, such as the initial row) become 0 and are skipped as falsy.
s = sum(reward[0] for reward in df.rewards.fillna(0) if reward)
s

-85819.96000000606

In [None]:
# Plot the training environment (`env`, the one passed to Trainer as env) with plot()'s default argument.
env.plot()

In [None]:
# Plot the test environment (`env_test`, passed to Trainer as env_test) with plot()'s default argument.
env_test.plot()

In [None]:
# Plot env_test's 'position_history' view.
env_test.plot('position_history')

In [None]:
# Plot env_test's 'asset_paths' view.
env_test.plot('asset_paths')

In [None]:
# Plot env_test's 'summarize_decisions' view.
env_test.plot('summarize_decisions')

In [None]:
# Plot env_test's 'learning_progress' view.
env_test.plot('learning_progress')

In [None]:
# Inspect the training environment's `portfolio_values` attribute.
env.portfolio_values

In [None]:
# Inspect the training environment's `portfolio_history` attribute.
env.portfolio_history

In [None]:
# NOTE(review): identical to the previous cell (`env.portfolio_history`); likely a
# leftover duplicate -- delete it, or point it at `env_test` if the test env was intended.
env.portfolio_history