In [28]:
import gym
from gym import spaces
from rljax.algorithm import DQN
from rljax.trainer import Trainer
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from itertools import product
from tqdm.auto import tqdm

from environment import Env

In [2]:
# Read the observation table and the state-transition matrix; in both files
# the first CSV column becomes the index.
data = pd.read_csv("SH_SDS_data_4.csv", index_col=0)
prob = pd.read_csv("SH_SDS_transition_matrix.csv", index_col=0)

# Normalise state labels to 3-character strings: prefix "000" and keep the
# last three characters (so 7 -> "007"; longer labels keep only their tail).
data.current_state = [f"000{label!s}"[-3:] for label in data.current_state]
prob.index = [f"000{label!s}"[-3:] for label in prob.index]

In [3]:
def base_reward(action, last_state, current_state):
    """Return ``action`` capped from above at 100.

    ``last_state`` and ``current_state`` are accepted but not used by this
    implementation.
    """
    if action > 100:
        return 100
    return action

In [None]:
# Seed and step budget, each passed to both the algorithm and the trainer.
SEED = 0
NUM_AGENT_STEPS = 5000

# Two separately constructed, identically configured environments: `env` is
# trained on, `env_test` is handed to the trainer as its test environment.
env, env_test = (
    Env(data, prob, base_reward, [1000, -500], 1000) for _ in range(2)
)

algo = DQN(
    # spaces are taken from the training environment
    state_space=env.state_space,
    action_space=env.action_space,
    num_agent_steps=NUM_AGENT_STEPS,
    seed=SEED,
    # loss / optimiser settings
    lr=1e-3,
    loss_type="l2",
    batch_size=256,
    # step-count settings
    start_steps=1000,
    update_interval=1,
    update_interval_target=400,
    eps_decay_steps=0,
)

trainer = Trainer(
    algo=algo,
    env=env,
    env_test=env_test,
    num_agent_steps=NUM_AGENT_STEPS,
    eval_interval=1000,
    seed=SEED,
    log_dir="",
)
trainer.train()

In [None]:
# Plot the environment with its default argument
# (presumably the portfolio history; confirm in Env.plot).
env.plot()

In [None]:
# Same plot call, this time selecting 'share_history'.
env.plot('share_history')

In [None]:
# Bare expression: shows the environment's `portfolio_history` attribute as the cell output.
env.portfolio_history

In [None]:
# Bare expression: shows the environment's `share_history` attribute as the cell output.
env.share_history

In [4]:
def sniper(state_1, state_2, env, start_state):
    """Roll out a fixed trigger-based policy until the environment reports done.

    At every step the first element of the current observation is compared
    with the two trigger values: a match with ``state_1`` selects action 0,
    otherwise a match with ``state_2`` selects action 1, otherwise action 2.

    Args:
        state_1: value that triggers action 0 (checked first).
        state_2: value that triggers action 1.
        env: object whose ``step(action)`` returns a 4-tuple
            ``(state, reward, done, info)``.
        start_state: observation to start from; the environment is not
            reset here.

    Returns:
        List of the rewards returned by each ``env.step`` call, in order.
        At least one step is always taken, so the list is never empty.
    """
    rewards = []
    observation, finished = start_state, False
    while not finished:
        marker = observation[0]
        # state_1 takes precedence when both triggers are equal.
        action = 0 if marker == state_1 else (1 if marker == state_2 else 2)
        observation, reward, finished, _ = env.step(action)
        rewards.append(reward)
    return rewards

In [6]:
# Fresh environment for the rule-based strategy sweep below; this rebinds
# `env`, replacing the instance handed to the DQN trainer earlier in the notebook.
env = Env(data, prob, base_reward, [1000, -500], 1000)

In [34]:
# Accumulator for the sweep: (s1, s2) -> reward of the episode's final step.
strategy_rewards = {}

# Reset first, then read the mapping from the reset environment; its keys
# are the candidate trigger states used to build the pairs.
start = env.reset()
mappings = env.mapping

In [35]:
# Every ordered pair of mapping keys, including the (k, k) pairs.
pairs = list(product(mappings.keys(), repeat=2))

# Start BOTH result containers empty so that re-running this cell cannot
# leave a reward from a previous run for a pair that fails this time.
strategy_rewards = dict()
bad_pairs = list()

for s1, s2 in tqdm(pairs):
    # Each pair gets a fresh episode. The mapping is re-read after the reset
    # (presumably it can differ from the one `pairs` was built from; verify
    # in Env.reset).
    start = env.reset()
    new_mappings = env.mapping
    try:
        rewards = sniper(new_mappings[s1], new_mappings[s2], env, start)
    except KeyError:
        # NOTE(review): this catches a KeyError from the mapping lookups AND
        # from anything inside sniper/env.step; confirm that is intended.
        bad_pairs.append((s1, s2))
    else:
        # Only the reward of the episode's last step is recorded.
        strategy_rewards[(s1, s2)] = rewards[-1]

HBox(children=(FloatProgress(value=0.0, max=1089.0), HTML(value='')))




In [37]:
len(bad_pairs)

112

In [38]:
len(strategy_rewards)

977