In [1]:
"""!pip install finrl
!pip install alpaca_trade_api
!pip install exchange_calendars
!pip install stockstats
!pip install wrds
!pip install yfinance"""

'!pip install finrl\n!pip install alpaca_trade_api\n!pip install exchange_calendars\n!pip install stockstats\n!pip install wrds\n!pip install yfinance'

In [2]:
import os
import pandas as pd
import numpy as np
import yfinance as yf
import heapq
import random
import torch
from scipy.spatial import KDTree
from torch import nn
import itertools
import IPython.display
import gc
from tqdm import tqdm
import re

import gymnasium as gym
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from gymnasium import spaces
from gymnasium.utils import seeding
from stable_baselines3.common.vec_env import DummyVecEnv

from finrl.meta.data_processor import DataProcessor
from finrl.meta.preprocessor.preprocessors import FeatureEngineer, data_split
from finrl.meta.env_stock_trading.env_stocktrading import StockTradingEnv

from finrl.agents.stablebaselines3.models import DRLAgent
from finrl.config import TRAINED_MODEL_DIR, RESULTS_DIR
from finrl.main import check_and_make_directories

from stable_baselines3 import DQN, SAC
from stable_baselines3.common.vec_env import DummyVecEnv, VecMonitor
from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback
from stable_baselines3.common.buffers import *
from stable_baselines3.common.logger import configure


In [3]:
# Download daily EURUSD=X bars and reshape them into the long format
# (date, OHLCV, tic) expected by the feature-engineering step.

TRAIN_START_DATE = "2010-01-01"
TRAIN_END_DATE = "2021-10-01"
TEST_START_DATE = "2021-10-01"
TEST_END_DATE = "2023-03-01"

dfs = []

temp_df = yf.download(
    "EURUSD=X", start=TRAIN_START_DATE, end=TEST_END_DATE, auto_adjust=False
)
if temp_df.empty:
    # yfinance reports download failures by returning an empty frame; stop here
    # instead of failing later with an unrelated error.
    raise RuntimeError("yfinance returned no rows for EURUSD=X")

# yfinance can return two-level (field, ticker) columns even for one symbol.
# Flatten them *before* adding/renaming/dropping columns so every following
# operation works on a plain column index (dropping on the two-level index is
# what triggered the warning previously shown for the `drop` call).
if isinstance(temp_df.columns, pd.MultiIndex):
    temp_df.columns = temp_df.columns.get_level_values(0)

temp_df["tic"] = "EURUSD=X"
dfs.append(temp_df)

data_df = pd.concat(dfs)
data_df = data_df.reset_index()

data_df = data_df.rename(
    columns={
        "Date": "date",
        "Open": "open",
        "High": "high",
        "Low": "low",
        "Close": "close",
        "Adj Close": "adjcp",
        "Volume": "volume",
    }
)

# The adjusted close is not used downstream; tolerate its absence.
data_df = data_df.drop(columns=["adjcp"], errors="ignore")
data_df["date"] = pd.to_datetime(data_df["date"]).dt.strftime("%Y-%m-%d")

# drop missing data
data_df = data_df.dropna().reset_index(drop=True)
print("Shape of DataFrame: ", data_df.shape)

data_df = data_df.sort_values(by=["date", "tic"]).reset_index(drop=True)

data_df.head()

[*********************100%***********************]  1 of 1 completed

Shape of DataFrame:  (3428, 7)



  data_df = data_df.drop(columns=["adjcp"])


Price,date,close,high,low,open,volume,tic
0,2010-01-01,1.438994,1.440196,1.432706,1.432706,0,EURUSD=X
1,2010-01-04,1.442398,1.445191,1.426208,1.431004,0,EURUSD=X
2,2010-01-05,1.436596,1.44831,1.435194,1.44271,0,EURUSD=X
3,2010-01-06,1.440403,1.44346,1.429123,1.436596,0,EURUSD=X
4,2010-01-07,1.431803,1.444481,1.430206,1.4403,0,EURUSD=X


In [4]:
# Technical indicators computed for every ticker and later exposed to the agent.
INDICATORS = ["macd", "rsi_30", "cci_30", "dx_30"]

fe = FeatureEngineer(
    use_technical_indicator=True,
    tech_indicator_list=INDICATORS,
    use_vix=False,
    use_turbulence=True,
    user_defined_feature=False,
)

processed = fe.preprocess_data(data_df)
print(processed.head())

# Build the full (calendar day x ticker) grid spanning the processed data ...
list_ticker = processed["tic"].unique().tolist()
list_date = (
    pd.date_range(start=processed["date"].min(), end=processed["date"].max())
    .astype(str)
    .tolist()
)
combination = [(day, tic) for day in list_date for tic in list_ticker]

# ... left-join the features onto it, keep only the days that actually occur in
# `processed`, order the rows and replace any remaining gaps with 0.
processed_full = (
    pd.DataFrame(combination, columns=["date", "tic"])
    .merge(processed, on=["date", "tic"], how="left")
    .loc[lambda grid: grid["date"].isin(processed["date"])]
    .sort_values(["date", "tic"])
    .fillna(0)
)

Successfully added technical indicators
Successfully added turbulence index
         date     close      high       low      open  volume       tic  \
0  2010-01-01  1.438994  1.440196  1.432706  1.432706       0  EURUSD=X   
1  2010-01-04  1.442398  1.445191  1.426208  1.431004       0  EURUSD=X   
2  2010-01-05  1.436596  1.448310  1.435194  1.442710       0  EURUSD=X   
3  2010-01-06  1.440403  1.443460  1.429123  1.436596       0  EURUSD=X   
4  2010-01-07  1.431803  1.444481  1.430206  1.440300       0  EURUSD=X   

       macd      rsi_30      cci_30       dx_30  turbulence  
0  0.000000  100.000000   66.666667  100.000000         0.0  
1  0.000076  100.000000   66.666667  100.000000         0.0  
2 -0.000083   36.189879  100.000000   33.643671         0.0  
3 -0.000015   55.476684  -42.150727   60.221140         0.0  
4 -0.000321   32.513547 -140.430918   49.776959         0.0  


In [5]:
# Split the processed frame into the training window and the trading (test) window.
train = data_split(processed_full, TRAIN_START_DATE, TRAIN_END_DATE)
trade = data_split(processed_full, TEST_START_DATE, TEST_END_DATE)
print(len(train), len(trade), sep="\n")

# Optional: persist the splits for later reuse
#   train.to_csv('train_data.csv')
#   trade.to_csv('trade_data.csv')
#
# Optional: index either split by date
#   train = train.set_index("date", drop=False); train.index.names = [""]
#   trade = trade.set_index("date", drop=False); trade.index.names = [""]

print(train.head())

3060
368
         date       tic     close      high       low      open  volume  \
0  2010-01-01  EURUSD=X  1.438994  1.440196  1.432706  1.432706     0.0   
1  2010-01-04  EURUSD=X  1.442398  1.445191  1.426208  1.431004     0.0   
2  2010-01-05  EURUSD=X  1.436596  1.448310  1.435194  1.442710     0.0   
3  2010-01-06  EURUSD=X  1.440403  1.443460  1.429123  1.436596     0.0   
4  2010-01-07  EURUSD=X  1.431803  1.444481  1.430206  1.440300     0.0   

       macd      rsi_30      cci_30       dx_30  turbulence  
0  0.000000  100.000000   66.666667  100.000000         0.0  
1  0.000076  100.000000   66.666667  100.000000         0.0  
2 -0.000083   36.189879  100.000000   33.643671         0.0  
3 -0.000015   55.476684  -42.150727   60.221140         0.0  
4 -0.000321   32.513547 -140.430918   49.776959         0.0  


In [6]:
# Environment configuration shared by the training and trading environments.
# INDICATORS is intentionally NOT redefined here: it is defined once in the
# feature-engineering cell, so the state size below always matches the columns
# that were actually computed.

stock_dimension = len(train.tic.unique())
# One cash slot + (price, holdings) per asset + one slot per indicator per asset.
state_space = 1 + 2 * stock_dimension + len(INDICATORS) * stock_dimension
print(f"Stock Dimension: {stock_dimension}, State Space: {state_space}")

# Build two independent lists; a chained assignment would bind both names to
# the same list object, so mutating one would silently change the other.
buy_cost_list = [0.001] * stock_dimension
sell_cost_list = [0.001] * stock_dimension
num_stock_shares = [0] * stock_dimension

env_kwargs = {
    "hmax": 100,
    "initial_amount": 1000000,
    "num_stock_shares": num_stock_shares,
    "buy_cost_pct": buy_cost_list,
    "sell_cost_pct": sell_cost_list,
    "state_space": state_space,
    "stock_dim": stock_dimension,
    "tech_indicator_list": INDICATORS,
    "action_space": stock_dimension,
    "reward_scaling": 1e-4,
}

Stock Dimension: 1, State Space: 7


In [7]:
# Build the training environment and its stable-baselines3 vectorised wrapper.
print(train.head())
try:
    e_train_gym = StockTradingEnv(df=train, **env_kwargs)
    env_train, _ = e_train_gym.get_sb_env()
except KeyError as e:
    print(f"KeyError: {e}. Check if 'train' DataFrame is correctly initialized and not empty.")
    # Re-raise after the hint: continuing would leave `env_train` undefined (or
    # stale from an earlier run) and make the following cells fail confusingly.
    raise
print(type(env_train))

         date       tic     close      high       low      open  volume  \
0  2010-01-01  EURUSD=X  1.438994  1.440196  1.432706  1.432706     0.0   
1  2010-01-04  EURUSD=X  1.442398  1.445191  1.426208  1.431004     0.0   
2  2010-01-05  EURUSD=X  1.436596  1.448310  1.435194  1.442710     0.0   
3  2010-01-06  EURUSD=X  1.440403  1.443460  1.429123  1.436596     0.0   
4  2010-01-07  EURUSD=X  1.431803  1.444481  1.430206  1.440300     0.0   

       macd      rsi_30      cci_30       dx_30  turbulence  
0  0.000000  100.000000   66.666667  100.000000         0.0  
1  0.000076  100.000000   66.666667  100.000000         0.0  
2 -0.000083   36.189879  100.000000   33.643671         0.0  
3 -0.000015   55.476684  -42.150727   60.221140         0.0  
4 -0.000321   32.513547 -140.430918   49.776959         0.0  
<class 'stable_baselines3.common.vec_env.dummy_vec_env.DummyVecEnv'>


In [8]:
# TODO: still need to fix/add "reward" strategy; also check coverage

class SERReplayBuffer(BaseBuffer):
    """
    Replay buffer that mixes a FIFO ring buffer with a score-ranked long-term memory.

    Every added transition is written to the FIFO arrays (as in a plain replay
    buffer) and additionally offered to ``long_term_memory``, a min-heap that keeps
    at most ``priority_queue_size`` transitions with the highest scores. Scores are
    produced by :meth:`compute_score` according to ``strategy``:

    * ``"reward"``       -- absolute reward,
    * ``"distribution"`` -- a random standard-normal score,
    * ``"surprise"``     -- absolute one-step TD error (needs :meth:`set_q_nets`),
    * ``"coverage"``     -- negative count of stored observations within
      ``dist_threshold`` of the new observation.

    Any other strategy name yields a constant score of 0.

    The long-term memory stores *copies* of the transitions, so its content stays
    valid after the FIFO buffer wraps around and overwrites the original slots.

    :param buffer_size: Max number of elements in the FIFO buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param strategy: Name of the scoring strategy (see above)
    :param priority_queue_size: Max number of transitions kept in long-term memory
    :param priority_queue_percent: Fraction of each sampled batch drawn from the
        long-term memory (e.g. ``0.1`` for 10%); must lie in ``[0, 1]``
    :param device: Device on which sampled tensors are created
    :param n_envs: Number of parallel environments
    :param optimize_memory_usage: Store ``next_obs`` in the ``observations`` array
        instead of a separate array
    :param handle_timeout_termination: Record ``TimeLimit.truncated`` from ``infos``
        so such episode ends are not treated as true terminations when sampling
    :param gamma: Discount factor used by the ``"surprise"`` score
    :param dist_threshold: Neighbourhood radius used by the ``"coverage"`` score
    :raises ValueError: if ``priority_queue_percent`` is outside ``[0, 1]``,
        ``priority_queue_size`` is negative, or both ``optimize_memory_usage`` and
        ``handle_timeout_termination`` are enabled
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        strategy: str,
        priority_queue_size: int,
        priority_queue_percent: float,  # 10% use --> 0.1
        device: Union[th.device, str] = "auto",
        n_envs: int = 1,
        optimize_memory_usage: bool = False,
        handle_timeout_termination: bool = True,
        gamma: float = 0.99,
        dist_threshold: float = 0.5,
    ):
        super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)

        if not 0.0 <= priority_queue_percent <= 1.0:
            raise ValueError(f"priority_queue_percent must be in [0, 1], got {priority_queue_percent}")
        if priority_queue_size < 0:
            raise ValueError(f"priority_queue_size must be >= 0, got {priority_queue_size}")

        # Adjust buffer size
        self.buffer_size = max(buffer_size // n_envs, 1)

        # Check that the replay buffer can fit into the memory
        if psutil is not None:
            mem_available = psutil.virtual_memory().available

        # there is a bug if both optimize_memory_usage and handle_timeout_termination are true
        # see https://github.com/DLR-RM/stable-baselines3/issues/934
        if optimize_memory_usage and handle_timeout_termination:
            raise ValueError(
                "ReplayBuffer does not support optimize_memory_usage = True "
                "and handle_timeout_termination = True simultaneously."
            )
        self.optimize_memory_usage = optimize_memory_usage

        self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=observation_space.dtype)

        if not optimize_memory_usage:
            # When optimizing memory, `observations` contains also the next observation
            self.next_observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=observation_space.dtype)

        self.actions = np.zeros(
            (self.buffer_size, self.n_envs, self.action_dim), dtype=self._maybe_cast_dtype(action_space.dtype)
        )

        self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        # Handle timeouts termination properly if needed
        # see https://github.com/DLR-RM/stable-baselines3/issues/284
        self.handle_timeout_termination = handle_timeout_termination
        self.timeouts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)

        if psutil is not None:
            total_memory_usage: float = (
                self.observations.nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes
            )

            if not optimize_memory_usage:
                total_memory_usage += self.next_observations.nbytes

            if total_memory_usage > mem_available:
                # Convert to GB
                total_memory_usage /= 1e9
                mem_available /= 1e9
                warnings.warn(
                    "This system does not have apparently enough memory to store the complete "
                    f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB"
                )

        self.strategy = strategy
        self.gamma = gamma
        self.dist_threshold = dist_threshold
        self.priority_queue_size = priority_queue_size
        self.priority_queue_percent = priority_queue_percent
        # Min-heap of (score, insertion_id, transition) where
        # transition = (obs, next_obs, action, reward, done, timeout).
        # insertion_id is unique, so tuple comparison never reaches the arrays.
        self.long_term_memory = []
        self._insert_count = 0
        self.q_net = None
        self.q_net_target = None

    def add(
        self,
        obs: np.ndarray,
        next_obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        done: np.ndarray,
        infos: list[dict[str, Any]],
    ) -> None:
        """
        Store one step (for all envs) in the FIFO buffer and offer it to long-term memory.

        :param obs: Observations before the step, one row per env
        :param next_obs: Observations after the step, one row per env
        :param action: Actions taken, one row per env
        :param reward: Rewards received, one entry per env
        :param done: Episode-end flags, one entry per env
        :param infos: Per-env info dicts; ``"TimeLimit.truncated"`` is read from them
            when ``handle_timeout_termination`` is enabled
        """
        # Reshape needed when using multiple envs with discrete observations
        # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
        if isinstance(self.observation_space, spaces.Discrete):
            obs = obs.reshape((self.n_envs, *self.obs_shape))
            next_obs = next_obs.reshape((self.n_envs, *self.obs_shape))

        # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
        action = action.reshape((self.n_envs, self.action_dim))

        slot = self.pos

        # Copy to avoid modification by reference
        self.observations[slot] = np.array(obs)

        if self.optimize_memory_usage:
            self.observations[(slot + 1) % self.buffer_size] = np.array(next_obs)
        else:
            self.next_observations[slot] = np.array(next_obs)

        self.actions[slot] = np.array(action)
        self.rewards[slot] = np.array(reward)
        self.dones[slot] = np.array(done)

        if self.handle_timeout_termination:
            self.timeouts[slot] = np.array([info.get("TimeLimit.truncated", False) for info in infos])

        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True
            self.pos = 0

        if self.priority_queue_size <= 0:
            return

        # Read the just-written rows back so the long-term copies have exactly the
        # dtype/shape the FIFO arrays use.
        if self.optimize_memory_usage:
            stored_next_obs = self.observations[(slot + 1) % self.buffer_size]
        else:
            stored_next_obs = self.next_observations[slot]

        for env_idx in range(self.n_envs):
            score = float(
                self.compute_score(obs[env_idx], next_obs[env_idx], action[env_idx], reward[env_idx], done[env_idx])
            )
            transition = (
                self.observations[slot, env_idx].copy(),
                stored_next_obs[env_idx].copy(),
                self.actions[slot, env_idx].copy(),
                float(self.rewards[slot, env_idx]),
                float(self.dones[slot, env_idx]),
                float(self.timeouts[slot, env_idx]),
            )
            self._insert_count += 1
            entry = (score, self._insert_count, transition)
            # Keep the heap bounded for every pushed entry (not just once per call):
            # once full, pushing evicts the lowest-scored entry.
            if len(self.long_term_memory) < self.priority_queue_size:
                heapq.heappush(self.long_term_memory, entry)
            else:
                heapq.heappushpop(self.long_term_memory, entry)

    def sample(self, batch_size: int, env: Optional[VecNormalize] = None):
        """
        Sample a batch mixing FIFO transitions with long-term-memory transitions.

        Up to ``int(batch_size * priority_queue_percent)`` transitions come from the
        long-term memory (fewer if it holds fewer); the rest are drawn uniformly
        from the FIFO buffer, so the batch always has ``batch_size`` rows.

        :param batch_size: Number of transitions to sample
        :param env: Optional ``VecNormalize`` used to normalize observations/rewards
        :return: ``ReplayBufferSamples`` with FIFO rows first, long-term rows last
        """
        long_size = min(int(batch_size * self.priority_queue_percent), len(self.long_term_memory))
        fifo_size = batch_size - long_size

        parts = []

        # sample from FIFO
        if fifo_size > 0 or long_size == 0:
            if self.optimize_memory_usage:
                # Do not sample the element with index `self.pos` as the transitions is invalid
                # (we use only one array to store `obs` and `next_obs`)
                if self.full:
                    fifo_inds = (np.random.randint(1, self.buffer_size, size=fifo_size) + self.pos) % self.buffer_size
                else:
                    fifo_inds = np.random.randint(0, self.pos, size=fifo_size)
            else:
                upper_bound = self.buffer_size if self.full else self.pos
                fifo_inds = np.random.randint(0, upper_bound, size=fifo_size)
            parts.append(self._get_samples(fifo_inds, env=env))

        # sample from long term memory
        if long_size > 0:
            picked = [entry[2] for entry in random.sample(self.long_term_memory, long_size)]

            long_obs = np.stack([t[0] for t in picked])
            long_next_obs = np.stack([t[1] for t in picked])
            long_actions = np.stack([t[2] for t in picked])
            long_rewards = np.array([t[3] for t in picked], dtype=np.float32).reshape(-1, 1)
            # Only use dones that are not due to timeouts
            long_dones = np.array([t[4] * (1 - t[5]) for t in picked], dtype=np.float32).reshape(-1, 1)

            parts.append(
                ReplayBufferSamples(
                    self.to_torch(self._normalize_obs(long_obs, env)),
                    self.to_torch(long_actions),
                    self.to_torch(self._normalize_obs(long_next_obs, env)),
                    self.to_torch(long_dones),
                    self.to_torch(self._normalize_reward(long_rewards, env)),
                )
            )

        if len(parts) == 1:
            return parts[0]

        # combine FIFO + long-term mem, field by field
        return ReplayBufferSamples(*(torch.cat(field, dim=0) for field in zip(*parts)))

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
        """
        Gather FIFO transitions at ``batch_inds`` (env picked at random per row).

        :param batch_inds: Indices into the FIFO buffer
        :param env: Optional ``VecNormalize`` used to normalize observations/rewards
        :return: The gathered transitions as torch tensors
        """
        # Sample randomly the env idx
        env_indices = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),))

        if self.optimize_memory_usage:
            next_obs = self._normalize_obs(self.observations[(batch_inds + 1) % self.buffer_size, env_indices, :], env)
        else:
            next_obs = self._normalize_obs(self.next_observations[batch_inds, env_indices, :], env)

        data = (
            self._normalize_obs(self.observations[batch_inds, env_indices, :], env),
            self.actions[batch_inds, env_indices, :],
            next_obs,
            # Only use dones that are not due to timeouts
            # deactivated by default (timeouts is initialized as an array of False)
            (self.dones[batch_inds, env_indices] * (1 - self.timeouts[batch_inds, env_indices])).reshape(-1, 1),
            self._normalize_reward(self.rewards[batch_inds, env_indices].reshape(-1, 1), env),
        )
        return ReplayBufferSamples(*tuple(map(self.to_torch, data)))

    def compute_score(self, obs, next_obs, action, reward, done):
        """
        Score a single-env transition; higher scores are kept longer in long-term memory.

        :param obs: Observation before the step
        :param next_obs: Observation after the step
        :param action: Action taken
        :param reward: Reward received
        :param done: Episode-end flag
        :return: The score for ``self.strategy`` (0 for an unknown strategy)
        :raises RuntimeError: for ``"surprise"`` when the Q-networks were not set
        """
        if self.strategy == "reward":
            return float(abs(reward))  # TODO: need to implement this one
        elif self.strategy == "distribution":
            return float(np.random.normal())

        elif self.strategy == "surprise":
            if self.q_net is None or self.q_net_target is None:
                raise RuntimeError(
                    'strategy "surprise" needs the Q-networks: call set_q_nets(q_net, q_net_target) first'
                )
            with torch.no_grad():
                obs_tensor = self.to_torch(obs).unsqueeze(0)
                next_obs_tensor = self.to_torch(next_obs).unsqueeze(0)
                # `.long()` + `gather` below treat the action as an index into the Q-values
                action_tensor = self.to_torch(action).long().unsqueeze(0)
                reward_tensor = self.to_torch(reward).unsqueeze(0)
                done_tensor = self.to_torch(done).unsqueeze(0).float()

                q_values = self.q_net(obs_tensor)
                q_sa = q_values.gather(1, action_tensor)

                next_q_values = self.q_net_target(next_obs_tensor)
                max_q_next = next_q_values.max(1, keepdim=True).values

                td_target = reward_tensor + self.gamma * (1.0 - done_tensor) * max_q_next
                td_error = torch.abs(td_target - q_sa)
                return float(td_error.item())

        elif self.strategy == "coverage":
            norm_obs = self._normalize_obs(obs, env=None).flatten()
            # Observations currently held in long-term memory (rebuilt on every call)
            stored_obs = [
                self._normalize_obs(entry[2][0], env=None).flatten() for entry in self.long_term_memory
            ]
            if len(stored_obs) == 0:
                return 0.0
            # build KD tree
            tree = KDTree(np.stack(stored_obs))
            neighbors = tree.query_ball_point(norm_obs, r=self.dist_threshold)
            # Fewer neighbours -> less covered region -> higher score
            return float(-len(neighbors))
        return 0.0

    def set_q_nets(self, q_net, q_net_target):
        """
        Attach the networks used by the ``"surprise"`` strategy.

        :param q_net: Online Q-network, called on the current observation
        :param q_net_target: Target Q-network, called on the next observation
        """
        self.q_net = q_net
        self.q_net_target = q_net_target

    @staticmethod
    def _maybe_cast_dtype(dtype: np.typing.DTypeLike) -> np.typing.DTypeLike:
        """
        Cast `np.float64` action datatype to `np.float32`,
        keep the others dtype unchanged.
        See GH#1572 for more information.

        :param dtype: The original action space dtype
        :return: ``np.float32`` if the dtype was float64,
            the original dtype otherwise.
        """
        if dtype == np.float64:
            return np.float32
        return dtype

In [9]:
# Default experiment configuration: run/bookkeeping settings, policy network
# settings, and the keyword arguments forwarded to the RL algorithm.
base_config = dict(
    # TRAINING PARAMETERS
    log_dir="./train_logs/base",
    n_envs=32,  # number of parallel environments to use for training
    checkpoint=None,  # path to a checkpoint to load from
    checkpoint_freq=10000,  # save a model checkpoint every _ steps
    eval_freq=5000,  # evaluate the model every _ steps
    n_eval_episodes=10,  # number of episodes to evaluate the model on
    n_train_timesteps=int(1e6),  # total number of training steps
    verbose_training=True,
    # RL PARAMETERS (all set to defaults right now, except for seed)
    policy_args=dict(
        net_arch=[64, 64],
        activation_fn=nn.ReLU,
    ),
    algo_kwargs=dict(
        learning_rate=1e-4,
        buffer_size=int(1e6),
        learning_starts=100,  # transitions collected before learning starts
        batch_size=32,
        tau=1.0,  # soft ("Polyak") update coefficient in [0, 1]; 1 means hard update
        gamma=0.99,
        train_freq=(4, "step"),  # update the model every ``train_freq`` steps
        gradient_steps=1,  # gradient steps performed after each rollout
        target_update_interval=int(1e4),  # env steps between target-network updates
        # exploration_fraction=0.1,     # fraction of training over which exploration is reduced
        # exploration_initial_eps=1.0,  # initial random-action probability
        # exploration_final_eps=0.05,   # final random-action probability
        seed=42,
    ),
)

In [10]:
from stable_baselines3.common.logger import configure

# Hyper-parameters forwarded to the SAC constructor.
SAC_PARAMS = dict(
    batch_size=128,
    buffer_size=100000,
    learning_rate=0.0001,
    learning_starts=100,
    ent_coef="auto_0.1",
)

# Build a SAC model on the training environment.
agent = DRLAgent(env=env_train)
model_sac = agent.get_model("sac", model_kwargs=SAC_PARAMS)

# Route the model's logs to stdout and to a CSV file under the results directory.
tmp_path = f"{RESULTS_DIR}/sac"
new_logger_sac = configure(tmp_path, ["stdout", "csv"])
model_sac.set_logger(new_logger_sac)

{'batch_size': 128, 'buffer_size': 100000, 'learning_rate': 0.0001, 'learning_starts': 100, 'ent_coef': 'auto_0.1'}
Using cuda device
Logging to results/sac


In [11]:
def train_rl(agent, env_train, eval_env, algo, config):
    """
    Build (or resume) a model and train it with periodic checkpoints and evaluation.

    :param agent: Object exposing ``get_model(algo, model_kwargs=...)``
    :param env_train: Training environment, used when resuming from a checkpoint
    :param eval_env: Environment handed to the evaluation callback
    :param algo: Algorithm name passed to ``agent.get_model`` (e.g. ``"sac"``)
    :param config: Dict with the keys ``log_dir``, ``algo_kwargs``,
        ``checkpoint_freq``, ``eval_freq``, ``n_eval_episodes``,
        ``n_train_timesteps``, ``verbose_training`` and optionally ``checkpoint``
        (path of a saved model to resume from)
    :return: The trained model
    :raises Exception: any error raised during training is printed and re-raised
    """
    print("Initializing...")

    log_dir = config["log_dir"]
    ckpt = config.get("checkpoint", None)
    algo_kwargs = config["algo_kwargs"]
    checkpoint_freq = config["checkpoint_freq"]
    eval_freq = config["eval_freq"]
    n_eval_episodes = config["n_eval_episodes"]
    n_train_timesteps = config["n_train_timesteps"]
    verbose_training = config["verbose_training"]

    # exist_ok avoids the check-then-create race of `if not exists: makedirs`.
    os.makedirs(log_dir, exist_ok=True)

    # NOTE(review): config["policy_args"] is not consumed here, so the policy
    # network uses the algorithm's defaults -- confirm whether that is intended.
    model = agent.get_model(algo, model_kwargs=algo_kwargs)

    if ckpt is not None:
        model = model.load(ckpt, env=env_train)

    # SER support: when the replay buffer can use Q-networks for scoring and the
    # model has them (e.g. DQN-style models), attach them. Models without
    # `q_net` / `q_net_target` attributes are left untouched.
    replay_buffer = getattr(model, "replay_buffer", None)
    if (
        replay_buffer is not None
        and hasattr(replay_buffer, "set_q_nets")
        and hasattr(model, "q_net")
        and hasattr(model, "q_net_target")
    ):
        replay_buffer.set_q_nets(model.q_net, model.q_net_target)

    # Logger
    new_logger = configure(log_dir, ["stdout", "csv"])
    model.set_logger(new_logger)

    # Callbacks
    checkpoint_callback = CheckpointCallback(
        save_freq=checkpoint_freq,
        save_path=log_dir,
        # Checkpoint files are prefixed with the log directory's base name
        name_prefix=os.path.basename(log_dir),
        verbose=0,
    )

    eval_callback = EvalCallback(
        eval_env,
        best_model_save_path=log_dir,
        log_path=log_dir,
        eval_freq=eval_freq,
        n_eval_episodes=n_eval_episodes,
        deterministic=True,
        render=False,
        verbose=0,
    )

    print("Training...")
    try:
        model.learn(
            total_timesteps=n_train_timesteps,
            callback=[checkpoint_callback, eval_callback],
            log_interval=1 if verbose_training else None,
            reset_num_timesteps=False,
            progress_bar=bool(verbose_training),
        )
    except Exception as error:
        # Report and propagate real errors; KeyboardInterrupt/SystemExit pass
        # through without being mislabelled as a training error.
        print(f"Error during training: {error}")
        raise

    return model

In [None]:
# Environments: train on the training window, evaluate on the trading window.
e_train_gym = StockTradingEnv(df=train, **env_kwargs)
env_train, _ = e_train_gym.get_sb_env()

e_eval_gym = StockTradingEnv(df=trade, **env_kwargs)
env_eval, _ = e_eval_gym.get_sb_env()

agent = DRLAgent(env=env_train)

SAC_PARAMS = {
    "batch_size": 128,
    "buffer_size": 100000,
    "learning_rate": 0.0001,
    "learning_starts": 100,
    "ent_coef": "auto_0.1",
}

log_dir = "./train_logs/base"
base_config = {
    # TRAINING PARAMETERS
    "log_dir": log_dir,
    "n_envs": 32,  # number of parallel environments to use for training
    "checkpoint": None,  # path to a checkpoint to load from
    "checkpoint_freq": 10000,  # save a model checkpoint every _ steps
    "eval_freq": 5000,  # evaluate the model every _ steps
    "n_eval_episodes": 10,  # number of episodes to evaluate the model on
    "n_train_timesteps": int(1e6),  # total number of training steps
    "verbose_training": True,
    # RL PARAMETERS (all set to defaults right now, except for seed)
    "policy_args": {
        "net_arch": [64, 64],
        "activation_fn": nn.ReLU,
    },
    "algo_kwargs": SAC_PARAMS,
}


log_dir = "./train_logs/ser"
# Give the SER run its own nested dicts: `base_config.copy()` is shallow, so
# writing into ser_config["algo_kwargs"] would also modify SAC_PARAMS and
# base_config["algo_kwargs"].
ser_config = {
    **base_config,
    "log_dir": log_dir,
    "policy_args": dict(base_config["policy_args"]),
    "algo_kwargs": dict(base_config["algo_kwargs"]),
}

# Resume from the checkpoint with the highest step count, if any exists.
latest_step = -1
checkpoint_path = None
if os.path.exists(log_dir):
    # Find files matching the checkpoint pattern (e.g., ser_10000_steps.zip)
    pattern = re.compile(r"ser_(\d+)_steps\.zip")
    for filename in os.listdir(log_dir):
        # fullmatch: ignore names with trailing text such as "ser_10_steps.zip.bak"
        match = pattern.fullmatch(filename)
        if match:
            steps = int(match.group(1))
            if steps > latest_step:
                latest_step = steps
                checkpoint_path = os.path.normpath(os.path.join(log_dir, filename))
if checkpoint_path is not None:
    print(f"Found latest checkpoint to load: {checkpoint_path}")
    ser_config["checkpoint"] = checkpoint_path
else:
    print("No previous checkpoint found. Starting training from scratch.")

ser_config["algo_kwargs"]["replay_buffer_class"] = SERReplayBuffer

ser_config["strategy"] = "distribution"
ser_config["priority_queue_size"] = 5000
ser_config["priority_queue_percent"] = 0.5


ser_config["algo_kwargs"]["replay_buffer_kwargs"] = {
    "strategy": ser_config["strategy"],
    "priority_queue_size": ser_config["priority_queue_size"],
    "priority_queue_percent": ser_config["priority_queue_percent"],
}
ser_config["algo_kwargs"]["device"] = "cpu"


IPython.display.clear_output(wait=True)
gc.collect()

# Best-effort reset of tqdm's registry of live progress bars; the private
# attribute may not exist, which is the only failure worth ignoring here.
try:
    tqdm._instances.clear()
except AttributeError:
    pass


model = train_rl(agent, env_train, env_eval, "sac", ser_config)

NameError: name 'checkpoint_path' is not defined

In [None]:
# Backtest the trained model on the trading window.
# The processed data was built with use_vix=False and use_turbulence=True, so it
# has a "turbulence" column and no "vix" column; the risk indicator must
# therefore point at "turbulence".
e_trade_gym = StockTradingEnv(
    df=trade,
    turbulence_threshold=70,
    risk_indicator_col="turbulence",
    **env_kwargs,
)
df_account_value_sac, df_actions_sac = DRLAgent.DRL_prediction(
    model=model,
    environment=e_trade_gym,
)

# Index the account values by the frame's first column.
df_result_sac = (
    df_account_value_sac.set_index(df_account_value_sac.columns[0])
)

# we can add all the different SER strategies + FIFO here + maybe a baseline comparison?
# i think there's either mean variance or DJIA

result = pd.DataFrame(
    {
        "sac": df_result_sac["account_value"]
    }
)

plt.rcParams["figure.figsize"] = (15, 5)
# Draw on an explicit axes: calling plt.figure() and then DataFrame.plot()
# without `ax` leaves the first figure empty and plots into a second one.
fig, ax = plt.subplots()
result.plot(ax=ax)
ax.set(title="Account value on the trading window", ylabel="account value");