In [None]:
!pip install stable-baselines3

In [None]:
from drl4t_data import download
from drl4t_env import DRL4TEnv

# Fetch the NYSE price history, already split into a training and a test part.
# The test part is iterated later as (symbol, data) pairs; the training part is
# wrapped in the trading environment the agent learns on.
DATA_FILE = 'nyse.csv'
train_data, test_data = download(DATA_FILE)
env = DRL4TEnv(train_data)

In [None]:
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3 import DQN

# Train a DQN agent on the training environment and checkpoint it.
# learning_starts is passed explicitly because its default has changed across
# stable-baselines3 releases (older 1.x versions used 50_000). With only 1000
# timesteps, such a default would mean the network is never updated.
model = DQN(
    'MlpPolicy',
    DummyVecEnv([lambda: env]),
    learning_rate=0.001,
    learning_starts=100,
    verbose=1,
)
# For off-policy algorithms, log_interval counts episodes, not timesteps.
model.learn(total_timesteps=1000, log_interval=10)
# SB3 writes its own zip archive here; the `.pt` extension is only part of the name.
model.save('nyse_dqn_model.pt')

In [None]:
# Resume training from the saved checkpoint and overwrite it.
# Each run of this cell adds another 1000 timesteps, so it is not idempotent.
# reset_num_timesteps=False continues the saved step counter instead of restarting
# at 0. Restarting would push the exploration schedule (epsilon) back to its
# initial value and restart the logged timestep count.
# NOTE(review): model.save() does not store the replay buffer, so the resumed run
# starts with an empty buffer (see save_replay_buffer / load_replay_buffer).
model = DQN.load('nyse_dqn_model.pt', env=DummyVecEnv([lambda: env]))
model.learn(total_timesteps=1000, log_interval=10, reset_num_timesteps=False)
model.save('nyse_dqn_model.pt')

In [None]:
import pandas as pd

# Roll out the trained policy once per test symbol and keep a per-step log of the
# env's `info` dicts, indexed by their 'Date' entry.
model = DQN.load('nyse_dqn_model.pt')

logs = []
for symbol, data in test_data.items():
    env = DRL4TEnv({ symbol: data })
    # Validates that the test env's spaces match the trained model's.
    model.set_env(DummyVecEnv([lambda: env]))

    obs = env.reset()
    done = False

    # Gather one record per step and build the frame once at the end;
    # concatenating inside the loop is quadratic in episode length.
    records = []
    while not done:
        # deterministic=True takes the greedy (argmax-Q) action. The default
        # (False) would take a random action with probability `exploration_rate`.
        action, _ = model.predict(obs, deterministic=True)
        obs, _, done, info = env.step(action)
        records.append(info)

    log = pd.DataFrame(records, index=[record['Date'] for record in records])
    logs.append(log)

In [None]:
# Sum a buy-and-hold benchmark and the policy's 'Total' across all test symbols,
# aligned on date, and persist the result.
# NOTE(review): `env` is the last environment left over from the evaluation loop;
# this assumes every DRL4TEnv shares the same starting_balance.
# NOTE(review): with fill_value=0, a symbol that has no row on a given date
# contributes 0 to that date's sum. Confirm all test symbols cover the same dates.
val = pd.DataFrame()
for log in logs:
    # .iloc makes the "first row" lookup positional; plain [0] on a non-integer
    # index relies on a deprecated fallback.
    first_close = log['Close'].iloc[0]
    log['Benchmark'] = env.starting_balance / first_close * log['Close']
    log['Policy'] = log['Total']
    val = val.add(log[['Benchmark', 'Policy']], fill_value=0)
val.to_csv('nyse_dqn_val.csv')

In [None]:
# Aggregated daily Benchmark / Policy balances across all test symbols.
# Fail loudly if the aggregation produced nothing (e.g. empty test set).
assert not val.empty, 'no evaluation results were aggregated'
assert {'Benchmark', 'Policy'} <= set(val.columns), 'missing expected columns'
val

In [None]:
import matplotlib.pyplot as plt

val = pd.read_csv('nyse_dqn_val.csv', parse_dates=True, index_col=0)

# Scale both curves by the benchmark's first value so they share a scale and the
# benchmark starts at 1.0. .iloc keeps the lookup positional on the DatetimeIndex.
base = val['Benchmark'].iloc[0]
val['Policy'] /= base
val['Benchmark'] /= base

fig, ax = plt.subplots()
val[['Policy', 'Benchmark']].plot(ax=ax, title='Normalized Policy vs. Benchmark')
ax.set_xlabel('Date')
ax.set_ylabel('Normalized Balance')
plt.show()