# Reinforcement Learning Agent for Statistical Arbitrage

This notebook demonstrates how to configure and train a reinforcement learning agent for statistical arbitrage using a snapshot model. The agent will learn to identify and exploit statistical arbitrage opportunities in financial markets.

## Install necessary libraries

The following libraries are required for this notebook. If you haven't installed them yet, you can do so by running the cell below or by running `pip install` in your terminal.

In [None]:
# Install/upgrade the notebook's dependencies into the active kernel's environment.
# NOTE(review): versions are unpinned (-U), so results can drift between runs; pin versions for reproducibility.
# NOTE(review): pytz is imported later but not listed here; presumably it arrives as a pandas dependency — confirm.
%pip -q install -U numpy pandas pyarrow gdown gymnasium stable-baselines3 torch matplotlib tensorboard

## Configuration

The configuration section sets up the parameters for the reinforcement learning model.

In [None]:
# Central configuration: data source, environment, splits, PPO hyper-parameters, evaluation and output paths.
CONFIG = {
    "DATA": {
        "drive_folder_id": "1uXEBUyySypdsW_ZqL-RZ3d1bWdIZisij", # Google Drive folder ID (last segment of the folder URL); mirrored when file_ids is empty
        "structure": {"1d":"ohlcv_1d", # sampling frequency -> subfolder name inside the Drive folder
                      "1h":"ohlcv_1h",
                      "15m":"ohlcv_15m",
                      "5m":"ohlcv_5m",
                      "1m":"ohlcv_1m"},
        "file_pattern": "{TICKER}_{FREQ}.parquet", # file naming pattern, filled with a ticker and a sampling key
        "tickers": ["BTC","ETH"], # assets traded against each other (the pair)
        "sampling": "1h", # data sampling frequency (1m, 5m, 15m, 1h, 1d); must be a key of "structure"
        "price_point": "close", # price column used for returns and RSI (open, high, low, close, ...)
        "forward_fill": True, # forward fill missing data
        "drop_na_after_ffill": True, # drop rows that are still NA after forward filling
        "cache_dir": "./data_cache", # local cache directory to store downloaded files
        "file_ids": {"BTC_1h": "1-sBNQpEFGEpVO3GDFCkSZiV3Iaqp2vB_", # cache name -> Drive file ID; leave empty ({}) to mirror the whole folder instead
                     "ETH_1h": "1kj8G1scpFuEYTTXKEUzF9pwgGI2WFFL9"
                     },
    },
    "ENV": {
        "include_cash": True, # add a cash weight after the asset weights; it earns no return
        "shorting": True, # allow shorting -- NOTE(review): the action space is [0, 1], so weights are never negative whatever this flag says
        "lookback_window": 64, # number of past time steps included in each observation window
        "features": {
            "vol_window": 64, # rolling volatility window (bars)
            "rsi_period": 14, # RSI lookback period (bars)
            "volume_change": True, # include log volume change as a feature (0.0 if no volume column)
            "normalize": True # z-score features with a rolling window of max(2 * lookback_window, 256) bars
        },
        "transaction_costs": { # transaction costs settings for exchange (e.g. Hyperliquid), charged on turnover
            "commission_bps": 5.0, # commission in basis points (bps)
            "slippage_bps": 5.0, # slippage in basis points (bps)
        },
        "reward": {
            "risk_lambda":0.001 # penalty on the weight-averaged asset volatility (risk aversion)
            },
        "constraints": {
            "min_weight":0.0, # lower bound per weight (see PortfolioWeightsEnv._project_weights for when it applies)
            "max_weight":1.0, # upper bound per weight (see PortfolioWeightsEnv._project_weights for when it applies)
            "sum_to_one":True # project actions onto weights that sum to one
            },
        "seed": 42 # seeds random/numpy/torch and the environment's reset
    },
    "SPLITS": { # date splits for training, validation, and testing
        "data_start":"2024-09-02", # first day of the walk-forward schedule (used only when walk_forward is True)
        "data_end":"2025-09-02", # last day of the walk-forward schedule (used only when walk_forward is True)
        "train":["2024-09-02","2025-06-15"], # static training period (used only when walk_forward is False)
        "val":["2025-06-16","2025-07-15"], # static validation period (used only when walk_forward is False)
        "test":["2025-07-16","2025-09-02"], # static testing period (used only when walk_forward is False)
        "walk_forward": True, # use rolling train/test windows instead of the static split above
        "wf_train_span_days": 180, # training window span in days for walk-forward
        "wf_test_span_days": 30, # testing window span in days for walk-forward
        "wf_step_days": 30 # step size in days to move the window forward
    },
    "RL": {
        "timesteps":10, # NOTE(review): far below n_steps; PPO still collects at least one full n_steps rollout — presumably a smoke-test value, raise for real training
        "policy":"MlpPolicy",
        "gamma":0.99, # discount factor
        "gae_lambda":0.95, # GAE bias/variance trade-off
        "clip_range":0.2, # PPO policy-ratio clipping
        "n_steps":1024, # environment steps collected per update
        "batch_size":256, # minibatch size for each gradient step
        "learning_rate":3e-4,
        "ent_coef":0.0, # entropy bonus coefficient
        "vf_coef":0.5, # value-loss coefficient
        "max_grad_norm":0.5 # gradient-norm clipping
    },
    "EVAL": {
        "plots":True, # draw equity/drawdown plots per split
        "reports_dir":"./reports" # metrics CSV/JSON output directory
    },
    "IO": {
        "models_dir":"./models", # saved PPO models / evaluation logs
        "tb_logdir":"./tb", # TensorBoard log directory
    }
}

## Imports

In [None]:
import os
import json
import gdown
import math
import glob
import random
import pytz

from datetime import timedelta, datetime

from typing import Dict
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Reinforcement Learning
import gymnasium as gym
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor

# Deep Learning
import torch

## Set Computation Device

This section sets the computation device for training the model. It checks if a GPU is available and sets it as the device; otherwise, it defaults to CPU.

In [None]:
# run on cuda GPU
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"CUDA is available. Using GPU: {torch.cuda.get_device_name(0)}")

# run on Apple Silicon
elif torch.backends.mps.is_available():
    device = torch.device("mps")
    print("MPS is available. Using Apple Silicon GPU.")

# run on CPU (slow)
else:
    device = torch.device("cpu")
    print("CUDA and MPS are not available. Using CPU.")

## Set seeds for reproducibility

This section sets the random seeds for various libraries to ensure that the results are reproducible.

Note: It is good practice to annotate function and method parameters with type hints for better code maintainability.

In [None]:
def set_all_seeds(seed: int = 42):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch for reproducible runs.

    ``torch.manual_seed`` seeds the generator of every device. CUDA and MPS are
    additionally seeded explicitly when available. No errors are swallowed: a
    seeding failure should be visible rather than silently break reproducibility.

    Args:
        seed: the seed applied to every generator.
    """
    random.seed(seed)  # Python's random module
    np.random.seed(seed)  # NumPy's legacy global RNG
    torch.manual_seed(seed)  # CPU and all accelerator devices
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # the MPS seeding API lives in torch.mps (torch.backends.mps only exposes availability checks)
    mps = getattr(torch, "mps", None)
    if mps is not None and hasattr(mps, "manual_seed") and torch.backends.mps.is_available():
        mps.manual_seed(seed)

# Seed every RNG once, before any data processing or model construction, using CONFIG["ENV"]["seed"].
set_all_seeds(CONFIG["ENV"]["seed"]) # set all seeds for reproducibility

# Bars per (365-day) year for each sampling frequency; used to annualize return/volatility metrics.
# Derived from each bar's length in minutes so the table cannot drift out of sync.
ANNUALIZATION = {
    freq: (365 * 24 * 60) // bar_minutes
    for freq, bar_minutes in (("1m", 1), ("5m", 5), ("15m", 15), ("1h", 60), ("1d", 24 * 60))
}

## Fetch Data from Google Drive

This section handles downloading data from Google Drive. It supports two methods: downloading an entire folder or downloading specific files by their IDs. The data will be cached locally for faster subsequent loading.

In [None]:
ROOT_ID = CONFIG["DATA"]["drive_folder_id"] # Google Drive folder ID, mirrored when no individual file IDs are configured
CACHE_DIR = CONFIG["DATA"]["cache_dir"] # data will be cached here for faster subsequent loading
# cache name -> Drive file ID; a missing, None or empty entry all mean "download the whole folder"
FILE_IDS = CONFIG["DATA"].get("file_ids") or {}

download_all = not FILE_IDS # mirror the entire folder only when no specific files are listed (see CONFIG)

def ensure_dir(path: str):
    """Create ``path`` together with any missing parent directories.

    Calling it on a directory that already exists is a no-op.
    """
    target = path
    os.makedirs(target, mode=0o777, exist_ok=True)

def download_drive_folder(root_id: str, out_dir: str):
    """Mirror a whole (public) Google Drive folder into ``out_dir`` using gdown.

    Args:
        root_id: Drive folder ID.
        out_dir: local destination; the folder's tree is reproduced beneath it.

    NOTE(review): gdown's folder download is documented to have a per-folder file-count
    limit — confirm it covers every file of the dataset.
    """
    print("Mirroring Google Drive folder locally...")
    gdown.download_folder(
        id=root_id,
        output=out_dir,
        quiet=False,
        use_cookies=False,
    )
    print("Folder mirroring complete.")

def targeted_download_by_ids(file_id_map: Dict[str, str], out_dir: str):
    """Download individual Google Drive files into ``out_dir``, skipping cached ones.

    Each file is stored as ``<name>.parquet``; the ``.parquet`` suffix is added
    unless ``name`` already has it. The cache check uses exactly that path.

    Args:
        file_id_map: cache name (e.g. ``"BTC_1h"``) -> Drive file ID.
        out_dir: destination directory; created if missing.

    Raises:
        RuntimeError: if gdown returns None for a download (how some gdown versions
            report failure). The problem then surfaces here instead of as a missing
            file at load time.
    """
    ensure_dir(out_dir)
    for name, fid in file_id_map.items():
        fname = name if name.endswith(".parquet") else f"{name}.parquet"  # ensure .parquet suffix
        out_path = os.path.join(out_dir, fname)
        # check the same path the download would write to
        if os.path.exists(out_path):
            print(f"File {name} already exists in cache. Skipping download.")
            continue

        print(f"Downloading {name} -> {out_path}")
        url = f"https://drive.google.com/uc?id={fid}"
        if gdown.download(url, out_path, quiet=False, use_cookies=False) is None:
            raise RuntimeError(f"Failed to download {name} (file id {fid}) from Google Drive.")

# Make sure the cache exists, then fetch either the listed files or the whole Drive folder.
ensure_dir(CACHE_DIR)

if not download_all:
    targeted_download_by_ids(FILE_IDS, CACHE_DIR)  # only the files named in CONFIG["DATA"]["file_ids"]
else:
    download_drive_folder(ROOT_ID, CACHE_DIR)  # no file IDs configured: mirror everything

print("Download step complete.")

## Load data

Once the data is downloaded, this section loads the data into a pandas DataFrame for further processing.

In [None]:
# Pull the data-loading settings out of CONFIG["DATA"]:
#   sampling            - bar frequency (1m, 5m, 15m, 1h, 1d)
#   pattern_fmt         - file naming pattern (s. CONFIG)
#   tickers             - tickers to load
#   forward_fill        - forward-fill missing values
#   drop_na_after_ffill - drop rows still NA after forward filling
(sampling, pattern_fmt, tickers, forward_fill, drop_na_after_ffill) = (
    CONFIG["DATA"][key]
    for key in ("sampling", "file_pattern", "tickers", "forward_fill", "drop_na_after_ffill")
)
# the Drive subfolder is chosen by the sampling frequency
subfolder = CONFIG["DATA"]["structure"][sampling]

def find_parquet_path(ticker: str, sampling: str) -> str:
    """Locate the cached parquet file for ``ticker`` at ``sampling`` frequency.

    Search order:
      1. ``<subfolder>/<file>`` anywhere below CACHE_DIR, where ``subfolder`` is
         ``CONFIG["DATA"]["structure"][sampling]`` (the layout of a mirrored Drive folder);
      2. ``<file>`` anywhere below CACHE_DIR, including directly inside it (the layout
         of targeted file-ID downloads).
    Matches are sorted so the chosen file is deterministic when several exist.

    Args:
        ticker: asset symbol substituted for ``{TICKER}`` in the file pattern.
        sampling: frequency key substituted for ``{FREQ}``; also selects the subfolder.

    Returns:
        Path of the first matching file.

    Raises:
        FileNotFoundError: if no matching file exists under CACHE_DIR.
    """
    fname = pattern_fmt.format(TICKER=ticker, FREQ=sampling)
    # derive the subfolder from the *argument*; fall back to the module-level one for unknown keys
    sub = CONFIG["DATA"]["structure"].get(sampling, subfolder)
    root = glob.escape(CACHE_DIR)
    target = glob.escape(fname)
    # with recursive=True, "**" also matches zero directories, so pattern 2 covers CACHE_DIR/<file>
    patterns = (
        os.path.join(root, "**", glob.escape(sub), target),
        os.path.join(root, "**", target),
    )
    for pattern in patterns:
        candidates = sorted(glob.glob(pattern, recursive=True))
        if candidates:
            return candidates[0]
    raise FileNotFoundError(f"Could not find {fname} under {CACHE_DIR}.")

def localize_and_align(df: pd.DataFrame, tz_in: str = None, tz_out: str = None) -> pd.DataFrame:
    """Standardize a raw OHLCV frame: lower-case columns, time index, sorted rows.

    Columns are lower-cased first, so ``Datetime``/``DATETIME`` are recognized too.
    A ``datetime`` column becomes a UTC ``timestamp`` index: numeric values are read
    as epoch milliseconds, anything else is parsed by ``pd.to_datetime``. The
    ``datetime`` column itself is kept.

    Args:
        df: raw frame; it is not modified.
        tz_in: timezone to localize a tz-naive DatetimeIndex to (ignored when the index
            is already tz-aware or not a DatetimeIndex).
        tz_out: if given, convert a tz-aware DatetimeIndex to this timezone.

    Returns:
        A new frame sorted by its index.
    """
    # rename() returns a new frame, so the caller's frame is left untouched
    out = df.rename(columns=lambda c: c.lower())

    if "datetime" in out.columns:
        raw = out["datetime"]
        if pd.api.types.is_numeric_dtype(raw):
            stamps = pd.to_datetime(raw, unit="ms", utc=True)  # epoch milliseconds
        else:
            stamps = pd.to_datetime(raw, utc=True)  # strings / datetime64 values
        out = out.assign(timestamp=stamps).set_index("timestamp")

    if isinstance(out.index, pd.DatetimeIndex):
        if out.index.tz is None and tz_in is not None:
            out = out.tz_localize(tz_in)
        if tz_out is not None and out.index.tz is not None:
            out = out.tz_convert(tz_out)

    return out.sort_index()

# Load every ticker's frame, standardize it and handle missing values.
dfs = {}
for ticker in tickers:
    frame = localize_and_align(pd.read_parquet(find_parquet_path(ticker, sampling)))
    if forward_fill:
        frame = frame.ffill()
    if drop_na_after_ffill:
        frame = frame.dropna()
    dfs[ticker] = frame

# Keep only timestamps present for *every* ticker (intersection of all indices).
common_index = None
for frame in dfs.values():
    if common_index is None:
        common_index = frame.index
    else:
        common_index = common_index.intersection(frame.index)

# Re-index each frame onto the shared timeline and drop anything still missing.
dfs = {ticker: dfs[ticker].reindex(common_index).dropna() for ticker in tickers}

# print the shape of each dataframe to ensure alignment
print({t: dfs[t].shape for t in tickers})

## Feature engineering

This section performs feature engineering on the loaded data. It includes creating new features, normalizing data, and preparing the dataset for training the reinforcement learning agent. The features are chosen to capture the signals a statistical-arbitrage strategy relies on: per-asset log returns, rolling volatility, RSI momentum, and log volume changes.

In [None]:
# feature settings (windows, RSI period, volume flag) and the price column used for returns/RSI
feat_cfg, price_col = CONFIG["ENV"]["features"], CONFIG["DATA"]["price_point"]

def compute_rsi(series: pd.Series, period: int = 14) -> pd.Series:
    """Relative Strength Index with Wilder-style exponential smoothing (alpha = 1/period).

    RSI = 100 - 100 / (1 + avg_gain / avg_loss), bounded in [0, 100]. When neither
    gains nor losses have occurred yet (a flat stretch), the ratio is 0/0. That case
    returns the neutral value 50 instead of 0. The first element is NaN because it
    has no price change.

    Args:
        series: price series.
        period: smoothing period in bars; must be positive.

    Returns:
        RSI series aligned with ``series``.

    Raises:
        ValueError: if ``period`` is not positive.
    """
    if period <= 0:
        raise ValueError(f"period must be positive, got {period}")
    delta = series.diff()  # price changes
    avg_gain = delta.clip(lower=0).ewm(alpha=1 / period, adjust=False).mean()
    avg_loss = (-delta.clip(upper=0)).ewm(alpha=1 / period, adjust=False).mean()
    rs = avg_gain / (avg_loss + 1e-12)  # >1: upward momentum dominates
    rsi = 100 - (100 / (1 + rs))
    # no movement at all -> neutral RSI rather than the misleading 0 that 0/eps would give
    flat = (avg_gain == 0) & (avg_loss == 0)
    return rsi.mask(flat, 50.0)

def make_features(df: pd.DataFrame, price_col: str, vol_window: int, rsi_period: int, volume_change: bool):
    """Per-asset feature frame with columns ``ret``, ``vol``, ``rsi``, ``volchg``.

    - ``ret``: one-bar log return of ``price_col`` (NaN in the first row);
    - ``vol``: rolling std of ``ret`` over ``vol_window`` bars, 0.0 while the window fills;
    - ``rsi``: RSI over ``rsi_period`` bars, 50.0 where undefined;
    - ``volchg``: log change in volume (zero volumes treated as missing, gaps -> 0.0),
      or 0.0 everywhere when there is no ``volume`` column or ``volume_change`` is False.
    """
    price = df[price_col]
    log_ret = np.log(price).diff(1)

    use_volume = volume_change and "volume" in df.columns
    vol_chg = np.log(df["volume"].replace(0, np.nan)).diff().fillna(0.0) if use_volume else 0.0

    return pd.DataFrame(index=df.index).assign(
        ret=log_ret,
        vol=log_ret.rolling(vol_window).std().fillna(0.0),
        rsi=compute_rsi(price, rsi_period).fillna(50.0),
        volchg=vol_chg,
    )

# Per-ticker feature frames, in CONFIG ticker order.
features_by_ticker = {
    t: make_features(dfs[t], price_col, feat_cfg["vol_window"], feat_cfg["rsi_period"], feat_cfg["volume_change"])
    for t in tickers
}

# (ticker, feature) labels, in exactly the order the frames are concatenated below.
panel_cols = [(t, col) for t in tickers for col in ["ret", "vol", "rsi", "volchg"]]

# Side-by-side panel with MultiIndex columns; rows with any missing value are dropped.
panel = pd.concat(
    [features_by_ticker[t][["ret", "vol", "rsi", "volchg"]] for t in tickers],
    axis=1,
)
panel.columns = pd.MultiIndex.from_tuples(panel_cols, names=["ticker", "feature"])
panel = panel.dropna()

print(panel.tail())
panel.describe()

## Feature scaling and state tensor construction

This section normalizes the features and constructs the state tensors required for training the reinforcement learning agent. State tensors are multi-dimensional arrays that represent the current state of the environment that the RL agent uses to make decisions.

In [None]:
# observation window length and whether features get rolling z-score normalization
lookback, normalize = CONFIG["ENV"]["lookback_window"], CONFIG["ENV"]["features"]["normalize"]

def rolling_zscore(df: pd.DataFrame, window: int = 256) -> pd.DataFrame:
    """Column-wise rolling z-score: (x - rolling mean) / rolling std over ``window`` rows.

    Only past and current rows enter each window, so there is no look-ahead. Rows
    where the score is undefined become 0.0: the window is not yet full, or the
    rolling std is exactly zero.
    """
    roller = df.rolling(window)
    mean, spread = roller.mean(), roller.std()
    spread = spread.where(spread != 0)  # zero spread -> NaN, avoiding division by zero
    scores = (df - mean) / (spread + 1e-12)
    return scores.fillna(0.0)

def _normalize_panel(panel: pd.DataFrame, lookback: int) -> pd.DataFrame:
    """Rolling z-score every (ticker, feature) column of ``panel``, keeping the column labels.

    The window is max(2 * lookback, 256) bars, i.e. 256 unless lookback exceeds 128.
    Columns come out grouped by feature, then ticker.
    """
    window = max(lookback * 2, 256)
    columns = {}
    for feature in panel.columns.unique(level=1):
        z_scored = rolling_zscore(panel.xs(feature, level=1, axis=1), window=window)
        for ticker in z_scored.columns:
            columns[(ticker, feature)] = z_scored[ticker]
    normalized = pd.concat(list(columns.values()), axis=1)
    normalized.columns = pd.MultiIndex.from_tuples(list(columns), names=["ticker", "feature"])
    return normalized


def build_state_tensor(panel: pd.DataFrame, lookback: int, normalize: bool = False):
    """Turn a (ticker, feature) panel into sliding-window RL samples.

    One sample is produced for every row ``i`` with ``lookback <= i < len(panel) - 1``:
      * ``X``: features of rows ``[i - lookback, i)``, shaped (tickers, features, lookback);
        z-scored when ``normalize`` is True;
      * ``y_ret``: the raw ``ret`` of each ticker at row ``i + 1`` (the next return);
      * ``inst_vol``: the raw ``vol`` of each ticker at row ``i``.

    Args:
        panel: DataFrame with (ticker, feature) MultiIndex columns containing ``ret`` and ``vol``.
        lookback: window length in rows.
        normalize: apply a rolling z-score to the features used in ``X``.

    Returns:
        X (n, tickers, features, lookback) float32, y_ret (n, tickers) float32,
        inst_vol (n, tickers) float32, tickers (sorted), features (the order used along
        X's feature axis), and the time index of each sample's target return.
    """
    normalized = _normalize_panel(panel, lookback) if normalize else panel

    tickers = sorted(panel.columns.unique(level=0))  # consistent asset order
    # order of appearance in the panel: this is exactly the order written into X
    features = list(panel.columns.unique(level=1))
    times = normalized.index
    n_assets, n_features = len(tickers), len(features)
    n_samples = max(len(times) - 1 - lookback, 0)

    if n_samples == 0:
        empty = np.empty((0, n_assets), dtype=np.float32)
        X_empty = np.empty((0, n_assets, n_features, lookback), dtype=np.float32)
        return X_empty, empty, empty.copy(), tickers, features, times[lookback + 1:]

    # (time, ticker, feature) cube of the (possibly normalized) features
    cube = np.stack(
        [normalized[t][features].to_numpy(dtype=np.float64) for t in tickers], axis=1
    )
    # windows[j] == cube[j:j + lookback] with time moved last -> (ticker, feature, lookback)
    windows = np.lib.stride_tricks.sliding_window_view(cube, lookback, axis=0)
    X = windows[:n_samples].astype(np.float32)

    # targets and risk inputs come from the raw (un-normalized) panel
    ret = np.column_stack([panel[(t, "ret")].to_numpy(dtype=np.float64) for t in tickers])
    vol = np.column_stack([panel[(t, "vol")].to_numpy(dtype=np.float64) for t in tickers])
    y_ret = ret[lookback + 1:].astype(np.float32)
    inst_vol = vol[lookback:lookback + n_samples].astype(np.float32)
    return X, y_ret, inst_vol, tickers, features, times[lookback + 1:]

# Build the full-history sample arrays once; splits later select rows from them by mask.
(X_all, R_all, VOL_all,
 TICKER_ORDER, FEAT_ORDER, TIME_INDEX) = build_state_tensor(panel, lookback=lookback, normalize=normalize)

print("State tensor:", X_all.shape, "Returns:", R_all.shape, "InstVol:", VOL_all.shape)

## Define Splits and Adjust Timezones

This section defines the training, validation, and test splits for the dataset: either a single static split or a series of walk-forward windows, depending on the configuration. It ensures that the data is divided appropriately to train the model, validate its performance during training, and test its final performance on unseen data. It also adjusts the timezones of the datetime indices to ensure consistency across the dataset. This is necessary in case the data comes from multiple sources with different timezone settings.

In [None]:
def date_slice_mask(times: pd.DatetimeIndex, start: str, end: str):
    """Boolean mask selecting entries of ``times`` between ``start`` and ``end`` (UTC).

    ``start`` is inclusive. ``end`` is inclusive as well. When ``end`` is a bare date
    (midnight), the *whole* day is included, so consecutive ranges such as
    [..., "2025-06-15"] and ["2025-06-16", ...] leave no gap in intraday data.
    Naive timestamps (bounds or index) are interpreted as UTC; tz-aware ones are
    converted to UTC.

    Returns:
        numpy boolean array aligned with ``times``.
    """
    def _to_utc(value):
        ts = pd.Timestamp(value)
        return ts.tz_localize("UTC") if ts.tzinfo is None else ts.tz_convert("UTC")

    raw_end = pd.Timestamp(end)
    whole_day = raw_end == raw_end.normalize()  # date-only bound -> include the full day
    start_ts, end_ts = _to_utc(start), _to_utc(end)

    times = times.tz_localize("UTC") if times.tz is None else times.tz_convert("UTC")

    if whole_day:
        upper = times < end_ts + pd.Timedelta(days=1)
    else:
        upper = times <= end_ts
    return (times >= start_ts) & upper

def build_splits(times: pd.DatetimeIndex, cfg: dict):
    """Build train/val/test boolean masks over ``times`` from a SPLITS config.

    Static mode (``walk_forward`` False): one split named ``"BaseSplit"`` built from
    the ``train``/``val``/``test`` date ranges.

    Walk-forward mode: starting at ``data_start``, windows of ``wf_train_span_days``
    training days are followed by ``wf_test_span_days`` test days. The windows are
    shifted by ``wf_step_days`` until a test window would pass ``data_end``. If the
    optional ``wf_val_span_days`` (default 0) is set, the last days of each training
    window become its validation slice. With 0 the validation mask is empty.

    Args:
        times: sample timestamps (naive is treated as UTC).
        cfg: the SPLITS configuration; falls back to CONFIG["SPLITS"] when None.

    Returns:
        list of dicts with keys ``name``, ``train``, ``val``, ``test`` (numpy bool arrays).

    Raises:
        ValueError: if ``wf_step_days`` is not positive (the loop would never end) or
            ``wf_val_span_days`` is negative or not shorter than the training span.
    """
    splits_cfg = cfg if cfg is not None else CONFIG["SPLITS"]

    def _to_utc(value):
        ts = pd.Timestamp(value)
        return ts.tz_localize("UTC") if ts.tzinfo is None else ts.tz_convert("UTC")

    # ensure time index is in UTC
    times = times.tz_localize("UTC") if times.tz is None else times.tz_convert("UTC")

    if not splits_cfg["walk_forward"]:
        return [{
            "name": "BaseSplit",
            "train": date_slice_mask(times, splits_cfg["train"][0], splits_cfg["train"][1]),
            "val": date_slice_mask(times, splits_cfg["val"][0], splits_cfg["val"][1]),
            "test": date_slice_mask(times, splits_cfg["test"][0], splits_cfg["test"][1]),
        }]

    start = _to_utc(splits_cfg["data_start"])
    end = _to_utc(splits_cfg["data_end"])
    train_span = pd.Timedelta(days=splits_cfg["wf_train_span_days"])
    test_span = pd.Timedelta(days=splits_cfg["wf_test_span_days"])
    step = pd.Timedelta(days=splits_cfg["wf_step_days"])
    val_span = pd.Timedelta(days=splits_cfg.get("wf_val_span_days", 0))
    if step <= pd.Timedelta(0):
        raise ValueError("wf_step_days must be positive.")
    if not (pd.Timedelta(0) <= val_span < train_span):
        raise ValueError("wf_val_span_days must be >= 0 and shorter than wf_train_span_days.")

    spans = []
    cur_train_start = start
    while cur_train_start + train_span + test_span <= end:
        train_end = cur_train_start + train_span
        test_end = train_end + test_span
        val_start = train_end - val_span  # == train_end when no validation slice is requested
        spans.append({
            "name": f"WF_{cur_train_start.strftime('%Y%m%d')}_{test_end.strftime('%Y%m%d')}",
            "train": (times >= cur_train_start) & (times <= val_start),
            "val": (times > val_start) & (times <= train_end),
            "test": (times > train_end) & (times <= test_end),
        })
        cur_train_start = cur_train_start + step
    return spans

# Masks are aligned with TIME_INDEX, i.e. with the rows of X_all / R_all / VOL_all.
SPLITS = build_splits(TIME_INDEX, CONFIG["SPLITS"])
print(f"Built {len(SPLITS)} split(s). Example:", (SPLITS[0]["name"] if len(SPLITS) > 0 else "None"))

## Environment Configuration

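This is a custom Gymnasium environment for portfolio optimization. Each observation is a flattened window of features for all assets. Each action is a vector of target weights, one per asset plus cash when enabled. The reward is the portfolio return net of trading costs, minus a volatility penalty.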

In [None]:
class PortfolioWeightsEnv(gym.Env):
    """Gymnasium environment that allocates a portfolio across ``tickers`` (plus optional cash).

    One episode walks once through every pre-built sample. At step ``t`` the agent sees
    the flattened feature window ``X[t]`` and outputs allocation scores. These are
    projected to weights, which are held over ``R[t]`` (the log returns that follow the
    observed window), so each action is credited with the outcome it actually
    influences. The reward is the log growth of the portfolio after trading costs,
    minus ``risk_lambda`` times the weight-averaged asset volatility.

    Args:
        X: state tensor. ``X[t]`` is flattened into the observation, and ``X.shape[2]``
           is read as the feature count. It is built upstream as (samples, tickers,
           features, lookback) — confirm if reused elsewhere.
        R: next-period log returns, one row per sample, columns ordered like ``tickers``.
        VOL: current rolling volatility per asset, aligned with ``R``.
        tickers: asset names, in the order of the asset axis of ``X``/``R``/``VOL``.
        lookback: number of time steps in each observation window.
        cfg_env: the ``CONFIG["ENV"]`` section.
        sampling: data frequency label (stored for reference).

    Raises:
        ValueError: if there are no samples (``R`` is empty).
    """

    metadata = {"render_modes": []}
    OBS_CLIP = 10.0  # bound of observation_space; observations are clipped into it

    def __init__(self, X, R, VOL, tickers, lookback, cfg_env, sampling="1h"):
        super().__init__()
        if len(R) == 0:
            raise ValueError("PortfolioWeightsEnv needs at least one sample (R is empty).")
        self.X = X  # state tensor
        self.R = R  # next-period log returns
        self.VOL = VOL  # current-period volatilities
        self.tickers = tickers  # list of assets
        self.lookback = lookback  # lookback window
        self.cfg = cfg_env  # environment configuration taken from CONFIG
        self.sampling = sampling  # data sampling frequency

        self.n_assets = len(tickers)
        self.include_cash = cfg_env["include_cash"]
        # one weight per asset, plus a trailing cash weight when enabled
        self.dim_action = self.n_assets + (1 if self.include_cash else 0)

        obs_dim = self.n_assets * self.X.shape[2] * self.lookback  # tickers x features x lookback
        self.observation_space = spaces.Box(
            low=-self.OBS_CLIP, high=self.OBS_CLIP, shape=(obs_dim,), dtype=np.float32
        )
        # raw allocation scores in [0, 1]; _project_weights turns them into (long-only) weights
        self.action_space = spaces.Box(low=0.0, high=1.0, shape=(self.dim_action,), dtype=np.float32)

        self.commission = cfg_env["transaction_costs"]["commission_bps"] / 1e4  # bps -> fraction
        self.slippage = cfg_env["transaction_costs"]["slippage_bps"] / 1e4  # bps -> fraction
        self.risk_lambda = cfg_env["reward"]["risk_lambda"]  # volatility penalty coefficient

        # weight constraints
        self.min_w = cfg_env["constraints"]["min_weight"]
        self.max_w = cfg_env["constraints"]["max_weight"]
        self.sum_to_one = cfg_env["constraints"]["sum_to_one"]

        # seeds self.np_random and initializes t / portfolio_value / w so the env is usable at once
        self.reset(seed=cfg_env.get("seed", 42))

    def _to_obs(self, t):
        """Flatten ``X[t]`` into the 1-D float32 vector declared by ``observation_space``."""
        return np.clip(self.X[t].reshape(-1), -self.OBS_CLIP, self.OBS_CLIP).astype(np.float32)

    def _project_weights(self, a):
        """Map a raw action to portfolio weights.

        Scores are clipped to [min_weight, max_weight] (and to [0, 1] when shorting is
        off). With ``sum_to_one`` they are then rescaled to sum to 1, falling back to
        equal weights if the scores sum to ~0. Unlike a softmax over [0, 1] scores, this
        can reach corner allocations such as 100% cash or 100% in one asset.
        NOTE: after rescaling, a weight may exceed ``max_weight`` when ``max_weight`` < 1.
        """
        w = np.clip(np.asarray(a, dtype=np.float64).reshape(-1), self.min_w, self.max_w)
        if not self.cfg["shorting"]:
            w = np.clip(w, 0.0, 1.0)
        if self.sum_to_one:
            total = w.sum()
            w = w / total if total > 1e-12 else np.full_like(w, 1.0 / w.size)
        return w

    def reset(self, seed=None, options=None):
        """Start a new episode at the first sample with equal weights and portfolio value 1.0."""
        super().reset(seed=seed)
        self.t = 0
        self.portfolio_value = 1.0
        self.w = np.ones(self.dim_action) / self.dim_action
        return self._to_obs(self.t), {}

    def step(self, action):
        """Rebalance to the projected weights, hold them over ``R[t]`` and advance one sample."""
        w_target = self._project_weights(action)
        turnover = float(np.sum(np.abs(w_target - self.w)))  # sum of absolute weight changes
        trading_cost = (self.commission + self.slippage) * turnover

        asset_w = w_target[:self.n_assets]  # cash (if any) earns nothing
        # weighted *simple* returns compound correctly; a weighted sum of log returns does not
        gross_growth = 1.0 + float(np.dot(asset_w, np.expm1(self.R[self.t])))
        net_growth = max(gross_growth * (1.0 - trading_cost), 1e-12)  # guard log of <= 0
        asset_ret = math.log(max(gross_growth, 1e-12))  # portfolio log return before costs
        inst_vol = float(np.dot(asset_w, self.VOL[self.t]))

        reward = math.log(net_growth) - self.risk_lambda * inst_vol
        self.portfolio_value *= net_growth

        self.w = w_target
        self.t += 1
        terminated = self.t >= len(self.R)  # every sample's return is used exactly once
        truncated = False

        obs = self._to_obs(min(self.t, len(self.X) - 1))  # terminal obs repeats the last window
        info = {"portfolio_value": self.portfolio_value, "turnover": turnover, "inst_vol": inst_vol, "asset_ret": asset_ret}
        return obs, reward, terminated, truncated, info


In [None]:

def slice_by_mask(X, R, VOL, mask: np.ndarray):
    """Select the rows of ``X``, ``R`` and ``VOL`` where the 1-D ``mask`` is true.

    Returns:
        tuple ``(X_sel, R_sel, VOL_sel)`` in the original row order.
    """
    rows = np.nonzero(mask)[0]
    return tuple(arr[rows] for arr in (X, R, VOL))

def make_env_from_mask(mask, name="env"):
    """Build a Monitor-wrapped PortfolioWeightsEnv over the samples selected by ``mask``.

    Args:
        mask: boolean array over the rows of X_all / R_all / VOL_all.
        name: label for the split, used in error messages.

    Raises:
        ValueError: if ``mask`` selects no samples (e.g. a split window outside the data range).
    """
    X_s, R_s, V_s = slice_by_mask(X_all, R_all, VOL_all, mask)
    if len(R_s) == 0:
        raise ValueError(f"Split '{name}' selects no samples; check the split dates against the data range.")
    env = PortfolioWeightsEnv(X_s, R_s, V_s, TICKER_ORDER, CONFIG["ENV"]["lookback_window"], CONFIG["ENV"], sampling=CONFIG["DATA"]["sampling"])
    # Monitor records episode reward/length, which SB3's EvalCallback expects
    return Monitor(env, filename=None)

In [None]:
def annualize_factor(sampling: str):
    """Number of bars per year for ``sampling``; unknown keys fall back to hourly (365*24)."""
    if sampling in ANNUALIZATION:
        return ANNUALIZATION[sampling]
    return 365*24

def compute_metrics(equity_curve: pd.Series, sampling: str, turnover_series: pd.Series = None):
    """Summary performance statistics for an equity curve.

    Args:
        equity_curve: portfolio value per bar (any positive scale).
        sampling: bar frequency key used for annualization (see ``annualize_factor``).
        turnover_series: optional per-step turnover; its mean is reported.

    Returns:
        dict with CAGR, Sharpe, Sortino (downside deviation w.r.t. a 0 target),
        MaxDrawdown (negative fraction), Calmar (CAGR / |MaxDrawdown|), annualized
        Volatility, mean Turnover (NaN if not given) and HitRatio (share of positive bars).
    """
    ret = equity_curve.pct_change().dropna()
    ann = annualize_factor(sampling)
    mu = ret.mean() * ann  # annualized arithmetic mean return
    sigma = ret.std() * math.sqrt(ann)
    sharpe = mu / (sigma + 1e-12)
    # downside deviation: RMS of the negative part of *every* return, not the std of losses alone
    downside = math.sqrt(float((ret.clip(upper=0.0) ** 2).mean())) * math.sqrt(ann)
    sortino = mu / (downside + 1e-12)

    if len(equity_curve) > 1:
        if isinstance(equity_curve.index, pd.DatetimeIndex):
            dt_years = (equity_curve.index[-1] - equity_curve.index[0]).total_seconds() / (365 * 24 * 3600)
        else:
            # N points span N - 1 bar intervals
            dt_years = (len(equity_curve) - 1) / ann
        dt_years = float(dt_years) if float(dt_years) != 0 else 1e-12
        cagr = (equity_curve.iloc[-1] / equity_curve.iloc[0]) ** (1 / dt_years) - 1
    else:
        cagr = 0.0

    maxdd = float((equity_curve / equity_curve.cummax() - 1).min())
    calmar = cagr / (abs(maxdd) + 1e-12)  # Calmar ratio: compound annual return over max drawdown
    hit_ratio = (ret > 0).mean()
    turnover = turnover_series.mean() if turnover_series is not None and len(turnover_series) > 0 else np.nan
    return {"CAGR": cagr, "Sharpe": sharpe, "Sortino": sortino, "MaxDrawdown": maxdd, "Calmar": calmar, "Volatility": sigma, "Turnover": turnover, "HitRatio": hit_ratio}

def plot_series(series: pd.Series, title: str):
    """Line plot of ``series`` against its index, titled ``title`` (10x4 inches)."""
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(series.index, series.values)
    ax.set_title(title)
    ax.set_xlabel("Time")
    ax.set_ylabel("Value")
    plt.show()

def backtest_env(env: "gym.Env", model=None):
    """Run one pass through ``env`` and record the equity curve and turnover.

    Args:
        env: a (possibly wrapped) portfolio environment whose ``unwrapped`` exposes
            ``R`` and ``dim_action``.
        model: object with ``predict(obs, deterministic=True)``; when None, an equal
            score vector is used as the action at every step.

    Returns:
        (equity_curve, turnover). ``equity_curve`` holds the portfolio value right after
        ``reset`` at position 0 (``portfolio_value`` if the env exposes it, else 1.0),
        followed by the value after each step. Its k-th return is therefore the k-th
        step's return and no step is dropped from the metrics. ``turnover`` holds the
        per-step turnover, indexed 1..n to line up with the equity point it produced.
    """
    unwrapped = env.unwrapped if hasattr(env, "unwrapped") else env

    obs, _ = env.reset()
    values = [float(getattr(unwrapped, "portfolio_value", 1.0))]
    turns = []
    for _step in range(len(unwrapped.R) - 1):
        if model is None:
            action = np.ones(unwrapped.dim_action) / unwrapped.dim_action
        else:
            action, _ = model.predict(obs, deterministic=True)
        obs, _reward, done, truncated, info = env.step(action)
        values.append(info["portfolio_value"])
        turns.append(info["turnover"])
        if done or truncated:
            break

    ec = pd.Series(values, index=pd.RangeIndex(len(values)))
    to = pd.Series(turns, index=pd.RangeIndex(1, len(values)))
    return ec, to

In [None]:
ensure_dir(CONFIG["IO"]["models_dir"])
ensure_dir(CONFIG["EVAL"]["reports_dir"])


def _hold_curve(log_returns: np.ndarray, weights: np.ndarray, index: pd.Index) -> pd.Series:
    """Equity curve starting at 1.0 for fixed ``weights`` applied to per-step log returns.

    The curve has ``len(index)`` points, so it consumes the first ``len(index) - 1`` rows
    of ``log_returns``; each step grows the value by ``exp(weights @ log_returns[k])``.
    """
    n_steps = max(len(index) - 1, 0)
    growth = np.exp(np.cumsum(np.asarray(log_returns[:n_steps], dtype=np.float64) @ weights))
    values = np.concatenate(([1.0], growth))[:len(index)]
    return pd.Series(values, index=index)


RESULTS = []
freq = CONFIG["DATA"]["sampling"]
n_assets = len(TICKER_ORDER)

for split in SPLITS:
    split_name = split["name"]
    print(f"\n=== Training on split: {split_name} ===")
    train_env = make_env_from_mask(split["train"], name=f"{split_name}_train")
    vec_train = DummyVecEnv([lambda env=train_env: env])

    model = PPO(
        policy=CONFIG["RL"]["policy"],
        env=vec_train,
        gamma=CONFIG["RL"]["gamma"],
        gae_lambda=CONFIG["RL"]["gae_lambda"],
        clip_range=CONFIG["RL"]["clip_range"],
        n_steps=CONFIG["RL"]["n_steps"],
        batch_size=CONFIG["RL"]["batch_size"],
        learning_rate=CONFIG["RL"]["learning_rate"],
        ent_coef=CONFIG["RL"]["ent_coef"],
        vf_coef=CONFIG["RL"]["vf_coef"],
        max_grad_norm=CONFIG["RL"]["max_grad_norm"],
        tensorboard_log=CONFIG["IO"]["tb_logdir"],
        device=device,
        verbose=1
    )

    # Evaluation during training must never look at the test window. Evaluate on the
    # validation slice when the split has one (walk-forward splits have none unless
    # wf_val_span_days is configured). Keep each split's best checkpoint in its own
    # folder instead of overwriting a shared best_model.zip.
    eval_callback = None
    if np.any(split["val"]):
        split_dir = os.path.join(CONFIG["IO"]["models_dir"], split_name)
        val_env = make_env_from_mask(split["val"], name=f"{split_name}_val")
        eval_callback = EvalCallback(
            DummyVecEnv([lambda env=val_env: env]),
            best_model_save_path=split_dir,
            log_path=split_dir,
            eval_freq=1000,
            n_eval_episodes=1,  # deterministic policy on a deterministic env: more episodes just repeat
            deterministic=True,
            render=False
        )

    model.learn(total_timesteps=CONFIG["RL"]["timesteps"], callback=eval_callback)
    model_path = os.path.join(CONFIG["IO"]["models_dir"], f"ppo_{split_name}.zip")
    model.save(model_path)
    print("Saved model:", model_path)
    vec_train.close()

    test_env = make_env_from_mask(split["test"], name=f"{split_name}_test")
    ec, to = backtest_env(test_env, model=model)

    # Benchmarks over the same test returns, aligned with the model's equity index.
    R_test = R_all[np.where(split["test"])[0]]
    ec_bench = _hold_curve(R_test, np.full(n_assets, 1.0 / n_assets), ec.index)

    result = {
        "split": split_name,
        "model": compute_metrics(ec, freq, to),
        "equal_weight": compute_metrics(ec_bench, freq),
    }
    # buy-and-hold of each asset, labelled by ticker instead of a hard-coded column position
    for k, ticker in enumerate(TICKER_ORDER):
        result[f"buy_and_hold_{ticker}"] = compute_metrics(_hold_curve(R_test, np.eye(n_assets)[k], ec.index), freq)
    RESULTS.append(result)

    if CONFIG["EVAL"]["plots"]:
        plot_series(ec, f"Equity Curve — PPO ({split_name})")
        plot_series((ec / ec.cummax()) - 1.0, f"Drawdown — PPO ({split_name})")
        plot_series(ec_bench, f"Equity Curve — Equal-Weight Hold ({split_name})")

print("Done. RESULTS collected.")

In [None]:

# One row per split: "split" first, then "<strategy>_<metric>" columns in insertion order.
rows = [
    {
        "split": res["split"],
        **{
            f"{strategy}_{metric_name}": metric_value
            for strategy, metrics in res.items()
            if strategy != "split"
            for metric_name, metric_value in metrics.items()
        },
    }
    for res in RESULTS
]

df_results = pd.DataFrame(rows)
df_results

In [None]:
def _json_default(obj):
    """``json.dump`` fallback: convert numpy scalars/arrays (e.g. float32 metrics) to Python values."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Timestamped CSV (flat table) and JSON (nested RESULTS) reports.
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
ensure_dir(CONFIG["EVAL"]["reports_dir"])
out_json = os.path.join(CONFIG["EVAL"]["reports_dir"], f"metrics_{ts}.json")
out_csv  = os.path.join(CONFIG["EVAL"]["reports_dir"], f"metrics_{ts}.csv")
df_results.to_csv(out_csv, index=False)
with open(out_json, "w") as f:
    # NOTE: NaN metrics are written as the non-standard JSON token NaN (json's default allow_nan)
    json.dump(RESULTS, f, indent=2, default=_json_default)
print("Saved:", out_json, "and", out_csv)