In [1]:
import os
from typing import Dict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import ray
from ray import tune, air
from ray.tune import JupyterNotebookReporter
from ray.tune.logger import TBXLoggerCallback
from ray.rllib.algorithms import AlgorithmConfig
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.env import BaseEnv
from ray.rllib.evaluation import Episode, RolloutWorker
from ray.rllib.evaluation.episode_v2 import EpisodeV2
from ray.rllib.policy import Policy
from scipy.special import softmax
import seaborn as sns

from stocktradingv2.agent.mysac import MySAC, MySACConfig
from stocktradingv2.env.MultiStockTradingEnv import MultiStockTradingEnv

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Start (or reuse) a local Ray instance. ignore_reinit_error lets this cell be re-run
# without raising because Ray is already initialised in the kernel.
ray.init(ignore_reinit_error=True)

2023-03-11 21:38:26,766	INFO worker.py:1544 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8266


0,1
Python version:,3.9.13
Ray version:,2.3.0
Dashboard:,http://127.0.0.1:8266


In [None]:
# Load every per-stock file under the SSE50 dataset directory as (file name, DataFrame) pairs.
# Entries are visited in sorted order because os.walk order is filesystem-dependent, and the
# seeded shuffle of `dfs` is only reproducible if its input order is fixed.
DATA_DIR = "./datasets/SSE50/"
dfs = []
for root, dirs, files in os.walk(DATA_DIR):
    dirs.sort()  # in-place sort makes os.walk descend into sub-directories deterministically
    for file in sorted(files):
        # Join with `root` (the directory being walked), not DATA_DIR, so files found in
        # sub-directories resolve to their real location.
        path = os.path.join(root, file)
        df = pd.read_csv(path)
        dfs.append((file, df))

In [None]:
# Pick a reproducible random subset of stocks. SEED drives a dedicated generator, so the
# subset does not depend on (or disturb) numpy's global random state.
SEED = 114
N_STOCKS = 10
rng = np.random.default_rng(SEED)
rng.shuffle(dfs)  # shuffles the list of (file name, DataFrame) pairs in place
dfs = dfs[:N_STOCKS]
# Space-separated file names of the selected stocks.
tics = " ".join(tic for tic, _ in dfs)
print(tics)

In [None]:
# calculate baseline: log return between the first and last row of the first selected stock.
# The source frame has two close columns; swapping their names makes `close` refer to the
# file's "close_" column (TODO confirm which column is the adjusted price).
# `rename` returns a copy: renaming in place would swap the columns of this one stock inside
# `dfs` (and swap them back on every re-run), so it would be trained on different columns
# from the other stocks.
# NOTE(review): rows are taken in file order; this assumes the CSV is sorted by date.
baseline_tic, baseline_raw = dfs[0]
baseline_df = baseline_raw.rename(columns={"close_": "close", "close": "close_"})
print(np.log(baseline_df.close.iloc[-1] / baseline_df.close.iloc[0]))
sns.lineplot(baseline_df.close_);

In [None]:
# Chronological split boundaries:
#   train: date < test_start
#   test : test_start <= date < trade_start
#   trade: date >= trade_start
test_start = pd.to_datetime('2017-01-01', format='%Y-%m-%d')
trade_start = pd.to_datetime('2020-01-01', format='%Y-%m-%d')

# split: one date-sorted frame per stock and period; list order follows `dfs`.
dfs_train = []
dfs_test = []
dfs_trade = []
for tic, df in dfs:
    # Parse dates on the shared frame so every later comparison is datetime-based.
    df["date"] = pd.to_datetime(df["date"], format='%Y-%m-%d')
    # Sort by the date column (sort_index would only order by row label, which says nothing
    # about chronology). A stable sort keeps file order for equal dates; row labels are kept.
    df_sorted = df.sort_values("date", kind="stable")
    in_train = df_sorted["date"] < test_start
    in_trade = df_sorted["date"] >= trade_start
    dfs_train.append(df_sorted.loc[in_train].copy())
    dfs_test.append(df_sorted.loc[~in_train & ~in_trade].copy())
    dfs_trade.append(df_sorted.loc[in_trade].copy())


def _ensemble(frames):
    """Combine per-stock frames into one frame indexed by (date, tic).

    Only dates on which every stock has a value for every (numeric) feature are kept.
    NOTE(review): assumes each frame carries a `tic` column; the file name bound to `tic`
    in the split loop is not written into the frames — TODO confirm the CSVs contain it.
    """
    combined = pd.concat(frames)
    # Wide format (one column per feature x tic) -> dropna removes dates where any stock is
    # missing -> back to long format.
    combined = combined.pivot_table(index=['date'], columns=['tic']).dropna().stack().reset_index()
    return combined.sort_values(['date', 'tic']).set_index(['date', 'tic'])


# ensemble: bind each aligned frame to its own name so the result is kept (rebinding a loop
# variable would discard it). The per-stock lists above are left unchanged for env_config.
df_train_ens = _ensemble(dfs_train)
df_test_ens = _ensemble(dfs_test)
df_trade_ens = _ensemble(dfs_trade)
display(df_train_ens.head(5), df_test_ens.head(5), df_trade_ens.head(5))

In [None]:
class MyCallbacks(DefaultCallbacks):
    """RLlib callbacks for the multi-stock trading experiments.

    At the end of every episode this records the episode's log-return as the custom metric
    ``"log-ret"``. For selected episodes it also saves a stacked-area plot of the
    softmax-normalised action history to ``./{q_model type}_{counter}_{episode_id:05d}.png``.
    Every evaluation episode is plotted; training episodes are plotted once every
    ``TRAIN_PLOT_EVERY`` episodes.
    """

    # Save one action plot per this many training episodes (evaluation episodes always plot).
    TRAIN_PLOT_EVERY = 10

    def __init__(self, legacy_callbacks_dict: Dict[str, callable] = None):
        # Separate per-instance counters for evaluation and training episodes; they are
        # embedded in the saved figure names.
        self._eval_counter = 0
        self._train_counter = 0
        super().__init__(legacy_callbacks_dict)

    def on_episode_end(
        self,
        *,
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[str, Policy],
        episode: EpisodeV2,
        env_index: int,
        **kwargs
    ):
        """Record the episode log-return and, when due, save the action plot.

        Args:
            worker: Worker that ran the episode; its ``policy_config`` supplies
                ``"in_evaluation"`` and ``["q_model_config"]["type"]``.
            base_env: Vectorised env; the finished sub-environment must expose
                ``asset_memory`` and ``action_memory``.
            policies: Unused.
            episode: Episode whose ``custom_metrics`` receive ``"log-ret"``.
            env_index: Index of the sub-environment that finished the episode.
        """
        # Read the sub-environment that actually ended this episode; a fixed index 0 would
        # report another env's episode whenever num_envs_per_worker > 1.
        sub_envs = base_env.get_sub_environments()
        env = sub_envs[0 if env_index is None else env_index]

        episode.custom_metrics["log-ret"] = np.log(env.asset_memory[-1] / env.asset_memory[0])

        in_eval = worker.policy_config["in_evaluation"]
        dqn_type = worker.policy_config["q_model_config"]["type"]
        if in_eval:
            counter = self._eval_counter
            self._eval_counter += 1
            should_plot = True
        else:
            counter = self._train_counter
            self._train_counter += 1
            should_plot = counter % self.TRAIN_PLOT_EVERY == 0

        # Only build the figure when it will be saved.
        if should_plot:
            self._save_action_plot(
                env.action_memory,
                f"./{dqn_type}_{counter}_{episode.episode_id:05d}.png",
            )

    @staticmethod
    def _save_action_plot(action_memory, path):
        """Save a stacked-area plot of softmax-normalised actions to ``path``."""
        # Presumably action_memory is (steps, action_dim) -- TODO confirm. After the transpose,
        # softmax over axis 0 turns each step's action vector into weights summing to 1.
        weights = softmax(np.array(action_memory).transpose(), axis=0)
        fig, ax = plt.subplots(figsize=(16, 7))
        try:
            ax.stackplot(np.arange(weights.shape[1]), weights)
            fig.savefig(path)
        finally:
            # Runs once per plotted episode; unclosed figures would accumulate in memory.
            plt.close(fig)

In [None]:
# Tune search space: MySAC's default config with the overrides below. The grid_search over
# q_model_config.type makes Tune run one trial per listed critic type.
param_space = MySACConfig().to_dict()
param_space.update(
    {
        "framework": "torch",
        # NOTE(review): "MultiStockTrading" must be registered as an env name (e.g. via
        # tune.register_env); no registration is visible in this notebook -- TODO confirm.
        "env": "MultiStockTrading",
        "env_config": {
            "df": dfs_train,
        },
        "policy_model_config": {
            "lstm_dim": 64,
            "net_arch": [256, 256],
        },
        "q_model_config": {
            "type": tune.grid_search(["dqn", "cqn", "qrdqn", "iqn"]),
            "lstm_dim": 64,
            "num_atoms": 50,
            "net_arch": [256, 256],
            "num_critics": 1,
            # cqn
            "vmin": -80.0,
            "vmax": 0,
            # iqn
            "risk_distortion_measure": None,
            "cos_embedding_dim": 64,
        },
        "tau": 0.01,
        "target_entropy": "auto",
        "n_step": 1,
        # Each key appears exactly once; the literal previously repeated three of them.
        "train_batch_size": 256,
        "target_network_update_freq": 1,
        "grad_clip": 40,
        "min_sample_timesteps_per_iteration": 200,
        "num_steps_sampled_before_learning_starts": 256,
        "metrics_num_episodes_for_smoothing": 5,
        "num_workers": 0,
        "num_envs_per_worker": 2,
        "num_cpus_per_worker": 1,
        "callbacks": MyCallbacks,

        # Evaluation: every 3 training iterations, run 1 episode on one evaluation worker,
        # on the test split, without exploration.
        "evaluation_interval": 3,
        "evaluation_duration": 1,
        "evaluation_duration_unit": "episodes",
        "evaluation_num_workers": 1,
        "evaluation_config": {
            "explore": False,
            "env_config": {
                "df": dfs_test,
            },
        }
    }
)
param_space

In [None]:
# Build the Tune experiment for MySAC over `param_space` (one sample per grid point).
# A trial stops as soon as any criterion in `stop_criteria` is reached.
stop_criteria = {
    "episode_reward_mean": 10,
    "timesteps_total": 10000,
}
run_config = air.RunConfig(
    stop=stop_criteria,
    callbacks=[TBXLoggerCallback()],
    progress_reporter=JupyterNotebookReporter(),
)
tuner = tune.Tuner(
    MySAC,
    param_space=param_space,
    tune_config=tune.TuneConfig(num_samples=1),
    run_config=run_config,
)

In [None]:
# Run all trials defined by `tuner`; `results` collects one result per trial.
results = tuner.fit()

In [None]:
# Select the trial with the highest episode_reward_mean and restore its algorithm.
result = results.get_best_result(metric="episode_reward_mean", mode="max")
# best_checkpoints holds (checkpoint, metrics) pairs; it is empty when the trial never
# checkpointed, so fail with an actionable message instead of a bare IndexError.
if not result.best_checkpoints:
    raise RuntimeError(
        "The best trial has no saved checkpoints; enable checkpointing "
        "(e.g. air.CheckpointConfig in the RunConfig) and re-run the tuner."
    )
# NOTE(review): entry [0] is not guaranteed to be the best-scoring checkpoint unless the
# checkpoint config sets a score attribute -- confirm the intended selection.
cp = result.best_checkpoints[0][0]
algo = MySAC.from_checkpoint(cp)