# Reinforcement Learning Agent for Statistical Arbitrage

This notebook demonstrates how to configure and train a reinforcement learning agent for statistical arbitrage using a snapshot model. The agent will learn to identify and exploit statistical arbitrage opportunities in financial markets.

## Install necessary libraries

The following libraries are required for this notebook. If you haven't installed them yet, you can do so by running the cell below or by using `pip install` in your terminal.

In [None]:
"""
%pip -q install -U numpy pandas pyarrow gdown gymnasium stable-baselines3 torch matplotlib tensorboard
"""

## Configuration

The configuration section sets up the parameters for the reinforcement learning model.

In [None]:
# Central configuration dictionary read by the cells below.
# NOTE(review): later cells read CONFIG["ENV"]["features"] and
# CONFIG["DATA"]["price_point"], neither of which is defined here — those
# lookups raise KeyError until the keys are added.
CONFIG = {
    "DATA": {
        "forward_fill": True, # forward fill missing data
        "drop_na_after_ffill": True, # drop rows with NA values after forward filling
        "cache_dir": "./data_cache", # local cache directory to store downloaded files
        "timestamp_format": "%Y-%m-%d %H:%M:%S",
        "asset_price_format": "{ASSET}_close", # naming pattern of single-asset price columns
        "pair_feature_format": "{ASSET1}_{ASSET2}_{FEATURE}", # naming pattern of pair-feature columns
        "timestamp_col": "timestamp", # name of the timestamp column in the feature file
        "file_names": {
            "features": "historical_pairs_with_spreads"
            },
        # Google Drive ids used to build the download URLs
        "file_ids": {
            "price_folder_id": "1uXEBUyySypdsW_ZqL-RZ3d1bWdIZisij",
            "ADA_1h": "1ydaR3T68ReE_7j5t3wZbj0F-zdRPYoxg",
            "APT_1h": "1CxG9N2bqWPs9fOPOUryYNtHmONXo4SRi",
            "ARB_1h": "136FSMlAW3XHG8WocxxTEcSKiLMUBwWMi",
            "ATOM_1h": "1mhSQgEwRHn3nvu8Qu1ctQGzdW5JuxATR",
            "BTC_1h": "1-sBNQpEFGEpVO3GDFCkSZiV3Iaqp2vB_",
            "DOGE_1h": "14XlkoQMYr8WWecGninAKUavvjB3qNxk0",
            "DOT_1h": "1kCWB4ZZu3FnadbAquTa3Rcdcwkhnq6-s",
            "ENA_1h": "1TYTxexlD24cs7qmhyVoTacX7lqGOsfky",
            "ETC_1h": "1coBd9QiEX03MndMgX5_549mOPyY23ZcI",
            "ETH_1h": "1kj8G1scpFuEYTTXKEUzF9pwgGI2WFFL9",
            "HBAR_1h": "1LVseecBvXKl3Wl9hbPLsROYKR1Gp8zhQ",
            "LINK_1h": "1ZLEraxdV3H8jpf1FmPeVs1ySL7TzMvH5",
            "LTC_1h": "18d3_jD-tuYTQQR2QOwXupckeDgqvAIvx",
            "NEAR_1h": "1PqI2hD2gbDxUaRDPnJpvDNH5wPYv47G6",
            "SOL_1h": "17CjYYSEsTEqBdmm51zGLgmpkslxxjiji",
            "SUI_1h": "1bToOJts-x2Ia48tqXcMs4qFIQ5OV1lAP",
            "TON_1h": "1SARYo5zB6AunG82kw7KGF4Nird3lQ4zB",
            "TRX_1h": "1FlcZo1WRtKFQMbBrsb61Lp3_pplISW4U",
            "UNI_1h": "15L-eKWliyg9MBKuznlZZ-FJzm52Ovt20",
            "WLD_1h": "1XqD1K4-YZzPxYFHKHY3KmKWnnwi3zO20",
            "XLM_1h": "1_3E5-mORLWh3X16Hi0ccHwzVKg5QxoT4",
            "XRP_1h": "1crt2g_t0qpYnaGpcozl35yDeHhd4tmi4",
            "features": "1AWXfCp2egL4d9D1lntAG6cKlzJJ-62c4"
            }
    },
    "COINTEGRATION": {
        "training_window_days": 2, # number of days to use for training after cointegration period
    },
    "ENV": {
        "include_cash": True, # include cash as a third asset
        "shorting": True, # allow shorting
        "trading_window_days": "2D", # length of each trading window (pandas Timedelta string)
        "sliding_window_step": "1D", # offset between consecutive window starts (pandas Timedelta string)
        "lookback_window": 64, # lookback for feature slicing, meaning that the observation for each step will include data from the last 'lookback_window' time steps
        "transaction_costs": { # transaction costs settings for exchange (e.g. Hyperliquid)
            "commission_bps": 5.0, # commission in basis points (bps)
            "slippage_bps": 5.0, # slippage in basis points (bps)
        },
        "reward": {
            "risk_lambda":0.001 # risk penalty coefficient (lambda), a.k.a. risk aversion factor
        },
        "leverage": {
            "use_leverage": False, # NOTE(review): not read by PortfolioWeightsEnv in this notebook — confirm intent
            "long_cap": 2.0,  # maximum leverage for long positions
            "short_cap": 2.0, # maximum leverage for short positions
            "use_asymmetric": True, # whether to allow different caps for long and short positions
        },
        "constraints": { # NOTE(review): not read by PortfolioWeightsEnv in this notebook — confirm intent
            "min_weight": -2.0, # minimum weight for each asset (for shorting)
            "max_weight": 2.0,  # maximum weight for each asset (for leverage)
            "sum_to_one": False # whether weights must sum to one (needs to be False for leverage)
        },
        "seed": 42 # random seed for reproducibility
    },
    "SPLITS": { # date splits for training, validation, and testing
        "data_start":"2024-09-02", # start date of the entire dataset
        "data_end":"2025-09-02", # end date of the entire dataset
        "train":["2024-09-02","2025-06-15"], # training period
        "val":["2025-06-16","2025-07-15"], # validation period
        "test":["2025-07-16","2025-09-02"], # testing period
        "walk_forward": True, # whether to use walk-forward splits (e.g. sliding window)
        "wf_train_span_days": 180, # training window span in days for walk-forward
        "wf_test_span_days": 30, # testing window span in days for walk-forward
        "wf_step_days": 30 # step size in days to move the window forward
    },
    "RL": { # hyperparameters for the RL algorithm (PPO is imported below)
        "timesteps":10, # NOTE(review): presumably the total training timesteps; 10 looks like a smoke-test value — confirm
        "policy":"MlpPolicy",
        "gamma":0.99,
        "gae_lambda":0.95,
        "clip_range":0.2,
        "n_steps":1024,
        "batch_size":256,
        "learning_rate":3e-4,
        "ent_coef":0.0,
        "vf_coef":0.5,
        "max_grad_norm":0.5
    },
    "EVAL": {
        "plots":True,
        "reports_dir":"./reports"
    },
    "IO": {
        "models_dir":"./models",
        "tb_logdir":"./tb"
    }
}

## Imports

In [None]:
import os
import json
import gdown
import math
import glob
import random
import pytz
import csv

import time

from datetime import timedelta, datetime

from typing import Iterable, Tuple, Dict, List, Set, Optional
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import re

# Reinforcement Learning
import gymnasium as gym
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor

# Deep Learning
import torch

print("Imports complete.")

## Helper Functions

In [None]:
# Annualization factors keyed by sampling interval: the number of bars of that
# length in a 365-day year of 24-hour days (e.g. "1h" -> 365*24).
ANNUALIZATION = {"1m":365*24*60,
                 "5m":365*24*12,
                 "15m":365*24*4,
                 "1h":365*24,
                 "1d":365}

# create directory if it doesn't already exist
def ensure_dir(path: str):
    """
    Create `path` (including missing parents) if it does not exist yet.

    An empty path means "the current directory" (e.g. the result of
    `os.path.dirname("file.csv")`), so there is nothing to create and the call
    is a no-op instead of raising from `os.makedirs("")`.
    """
    if not path:
        return
    os.makedirs(path, exist_ok=True)

## Set Computation Device

This section sets the computation device for training the model. It checks if a GPU is available and sets it as the device; otherwise, it defaults to CPU.

In [None]:
# Select the torch device in order of preference: CUDA GPU, Apple Silicon (MPS), CPU.
# The chosen device is stored in the notebook-level variable `device`.

# run on CUDA GPU
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"CUDA is available. Using GPU: {torch.cuda.get_device_name(0)}")

# run on Apple Silicon GPU
elif torch.backends.mps.is_available():
    device = torch.device("mps")
    print("MPS is available. Using Apple Silicon GPU.")

# run on CPU (slow)
else:
    device = torch.device("cpu")
    print("CUDA and MPS are not available. Using CPU.")

## Set Seeds

This section sets the random seeds for various libraries to ensure that the results are reproducible.

Note: It is good practice to annotate function and method parameters with type hints; this makes the code easier to read and maintain.

In [None]:
def set_all_seeds(seed: int = 42):
    """
    Seed Python's `random`, NumPy's legacy global RNG, and torch.

    Parameters
    ----------
    seed : int
        Seed value applied to every generator.

    Notes
    -----
    Torch seeding stays best-effort so the notebook keeps running on unusual
    builds, but a failure is reported instead of being silently swallowed.
    """
    random.seed(seed) # seed for random module
    np.random.seed(seed) # seed for numpy module
    try:
        torch.manual_seed(seed) # seed for torch module
        if torch.cuda.is_available(): # seed every CUDA device
            torch.cuda.manual_seed_all(seed)
        elif torch.backends.mps.is_available(): # seed Apple Silicon device
            # The MPS seeding API lives in `torch.mps`; `torch.backends.mps`
            # only exposes availability checks.
            torch.mps.manual_seed(seed)
    except (AttributeError, RuntimeError) as exc:
        print(f"Warning: torch could not be fully seeded: {exc}")

# Seed all random number generators with the seed from the environment config.
set_all_seeds(CONFIG["ENV"]["seed"])

## Download Feature Data

In this section, we retrieve the .csv file created during feature engineering.

In [None]:
def download_file(file_name: str, file_id: str, out_dir: str):
    """
    Download a Google Drive file into `out_dir` as `<file_name>.csv`, unless it is cached.

    Parameters
    ----------
    file_name : str
        Base name (without extension) of the local file.
    file_id : str
        Google Drive file id used to build the download URL.
    out_dir : str
        Directory in which the file is stored; created if missing.

    Returns
    -------
    str | None
        Path of the local file when it is already cached, otherwise whatever
        `gdown.download` returns. None when the download attempt raised.
    """
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    out_path = os.path.join(out_dir, f"{file_name}.csv")  # full output path

    if os.path.exists(out_path):
        print(f"File {file_name} already exists in cache. Skipping download.")
        return out_path

    try:
        print(f"Downloading {file_name} -> {out_path}")
        url = f"https://drive.google.com/uc?id={file_id}"
        # TLS certificate verification is left at gdown's default (enabled);
        # disabling it would allow the downloaded data to be tampered with in transit.
        return gdown.download(url, out_path, quiet=False, use_cookies=False)
    except Exception as e:
        # Best effort: report the failure and let the caller decide what to do.
        print(f"Download attempt failed for {file_name}. Error: {str(e)}")
        return None
    
# download the engineered feature file (skipped when it is already cached)
file_name = CONFIG["DATA"]["file_names"]["features"]
file_id = CONFIG["DATA"]["file_ids"]["features"]
cache_dir = CONFIG["DATA"]["cache_dir"]

print("Downloading feature data...")
success = download_file(file_name, file_id, cache_dir)

# Report the real outcome: the file must exist locally for the next cells to work.
if os.path.exists(os.path.join(cache_dir, f"{file_name}.csv")):
    print("Download complete.")
else:
    print(f"Download failed: {file_name}.csv is not available in {cache_dir}.")

## Load Feature Data

Once the data is downloaded, this section loads the data into a pandas DataFrame for later processing.

In [None]:
def load_csv_to_df(
    path: str,
    parse_timestamp_col: str | None = "timestamp",
    encoding: str = "utf-8-sig",
    **read_csv_kwargs,
) -> pd.DataFrame:
    """
    Load a CSV into a pandas DataFrame.

    Parameters
    ----------
    path : str
        Filesystem path to the CSV.
    parse_timestamp_col : str | None
        If provided and present in the CSV, this column will be parsed to datetime.
        Set to None to skip datetime parsing.
    **read_csv_kwargs :
        Extra arguments passed to `pd.read_csv` (e.g., sep, dtype, usecols).

    Returns
    -------
    pd.DataFrame
    """
    # Try delimiter sniffing first
    with open(path, "r", encoding=encoding, newline="") as f:
        sample = f.read(2048)
    try:
        dialect = csv.Sniffer().sniff(sample, delimiters=[",",";","|","\t"])
        sep = dialect.delimiter
    except csv.Error:
        sep = ","  # fallback

    # Parse header-only to check for timestamp col presence
    head = pd.read_csv(path, sep=sep, encoding=encoding, nrows=0)
    if parse_timestamp_col and parse_timestamp_col in head.columns:
        read_csv_kwargs = {
            **read_csv_kwargs,
            "parse_dates": [parse_timestamp_col],
            "infer_datetime_format": True,
        }

    df = pd.read_csv(path, sep=sep, encoding=encoding, engine="python", **read_csv_kwargs)
    return df


# load features from the local cache
file_name = CONFIG["DATA"]["file_names"]["features"]
cache_dir = CONFIG["DATA"]["cache_dir"]
file_path = os.path.join(cache_dir, f"{file_name}.csv")
features_df = load_csv_to_df(file_path, parse_timestamp_col=CONFIG["DATA"]["timestamp_col"])

# print dataframe info (DataFrame.info() prints its report itself and returns
# None, so wrapping it in print() would add a stray "None" line)
print("Features DataFrame Info:")
features_df.info()

## Identify Feature Structure

In [None]:
def identify_assets_features_pairs(
    df: pd.DataFrame,
    single_asset_format: str,
    pair_feature_format: str
    ) -> tuple[list[str], list[str], list[tuple[str, str]]]:
    """
    Scan the column names of `df` and collect assets, pair features, and asset pairs.

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe whose column names follow the two naming templates.
    single_asset_format : str
        Template of single-asset columns, e.g. "{ASSET}_close".
    pair_feature_format : str
        Template of pair-feature columns, e.g. "{ASSET1}_{ASSET2}_{FEATURE}".

    Returns
    -------
    (assets, features, asset_pairs)
        assets      : sorted list of tickers seen in single-asset or pair columns
        features    : sorted list of feature names seen in pair columns
        asset_pairs : sorted list of pairs, each stored as an alphabetically sorted tuple

    Notes
    -----
    Every placeholder matches letters and digits only, so a name whose
    placeholder value contains an underscore does not match its template.
    The column literally named "timestamp" is always skipped.
    """

    def _template_to_pattern(template: str) -> re.Pattern:
        # Escape the literal text, then turn each escaped "{NAME}" placeholder
        # into a named group, e.g. "{ASSET}_close" -> ^(?P<ASSET>[A-Za-z0-9]+)_close$
        escaped = re.escape(template)
        body = re.sub(r"\\\{(\w+)\\\}", r"(?P<\1>[A-Za-z0-9]+)", escaped)
        return re.compile(f"^{body}$")

    asset_re = _template_to_pattern(single_asset_format)
    pair_re = _template_to_pattern(pair_feature_format)

    found_assets = set()
    found_features = set()
    found_pairs = set()

    for name in df.columns:
        if name == "timestamp":
            continue

        # Single-asset columns take precedence over pair columns.
        asset_hit = asset_re.match(name)
        if asset_hit:
            found_assets.add(asset_hit.group("ASSET"))
            continue

        pair_hit = pair_re.match(name)
        if not pair_hit:
            continue

        first = pair_hit.group("ASSET1")
        second = pair_hit.group("ASSET2")
        feature = pair_hit.group("FEATURE")
        found_assets.update([first, second])
        found_pairs.add(tuple(sorted((first, second))))
        found_features.add(feature)

    return sorted(found_assets), sorted(found_features), sorted(found_pairs)

# Identify assets, pair features, and asset pairs from the column names of
# `features_df`, using the naming templates stored in CONFIG["DATA"].
single_asset_format = CONFIG["DATA"]["asset_price_format"]
pair_feature_format = CONFIG["DATA"]["pair_feature_format"]
assets, features, asset_pairs = identify_assets_features_pairs(
    features_df,
    single_asset_format,
    pair_feature_format
    )

# summary of what was found
print(f"Identified {len(assets)} assets: {assets}")
print(f"Identified {len(features)} features: {features}")
print(f"Identified {len(asset_pairs)} asset pairs: {asset_pairs}")

## Build Time Intervals

In [None]:
def build_time_intervals(
    df: pd.DataFrame,
    window: pd.Timedelta | str,
    step: Optional[pd.Timedelta | str] = None,
    timestamp_col: str = "timestamp",
    include_last_partial: bool = False,
) -> list[tuple[pd.Timestamp, pd.Timestamp]]:
    """
    Return fixed-length time intervals over the DataFrame's time span.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain a datetime-like 'timestamp' column or have a DatetimeIndex.
    window : pd.Timedelta | str
        Size of each window, e.g. '2D', '60min', '15T'.
    step : pd.Timedelta | str | None
        Step between consecutive window starts. Defaults to `window` (non-overlapping).
        Use a smaller step than `window` for sliding/overlapping windows.
    timestamp_col : str
        Name of the timestamp column (ignored if index is a DatetimeIndex).
    include_last_partial : bool
        If True, include the trailing partial window shorter than `window`.

    Returns
    -------
    list[tuple[pd.Timestamp, pd.Timestamp]]
        Half-open intervals [start, end).
    """
    W = pd.Timedelta(window)
    S = pd.Timedelta(step) if step is not None else W

    # Extract, sanitize, and sort timestamps
    if timestamp_col in df.columns:
        ts = pd.to_datetime(df[timestamp_col]).dropna().sort_values()
    elif isinstance(df.index, pd.DatetimeIndex):
        ts = pd.Series(df.index).dropna().sort_values()
    else:
        raise ValueError(
            f"Timestamp column '{timestamp_col}' not found and index is not DatetimeIndex."
        )

    intervals: list[tuple[pd.Timestamp, pd.Timestamp]] = []
    if ts.empty:
        return intervals

    t_min = ts.iloc[0]
    t_max = ts.iloc[-1]
    cur = t_min

    while cur < t_max:
        end = cur + W
        if end <= t_max:
            intervals.append((cur, end))
        elif include_last_partial:
            intervals.append((cur, t_max))
            break
        else:
            break
        cur = cur + S

    return intervals

# Build the trading windows over the time span of `features_df`, using the
# window length and step configured in CONFIG["ENV"].
window = CONFIG["ENV"]["trading_window_days"]
step = CONFIG["ENV"]["sliding_window_step"]
timestamp_col = CONFIG["DATA"]["timestamp_col"]

# only full-length windows are kept (no trailing partial window)
intervals = build_time_intervals(
    features_df,
    window,
    step,
    timestamp_col,
    include_last_partial=False
)

# print interval info
print(f"Built {len(intervals)} time intervals with window={window} and step={step}.")
print("First 3 intervals:")
for start, end in intervals[:3]:
    print(f"  {start} to {end}")

## Identify Feature Space

In [None]:
def is_timeframe_valid(
    df: pd.DataFrame,
    pair: tuple[str, str],
    start: pd.Timestamp,
    end: pd.Timestamp,
    feature_name: str = "spread",
    pair_feature_format: str = "{ASSET1}_{ASSET2}_{FEATURE}",
    timestamp_col: str = "timestamp"
) -> bool:
    """
    Check if the given time frame has complete data for the specified asset pair.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame containing the data.
    pair : tuple[str, str]
        Asset pair (asset1, asset2). The feature column is looked up in this
        order first and in reversed order second, because callers may hold the
        pair as an alphabetically sorted tuple while the column keeps its
        original asset order.
    start : pd.Timestamp
        Start of the time frame (inclusive).
    end : pd.Timestamp
        End of the time frame (exclusive).
    feature_name : str
        Name of the pair feature whose completeness is checked.
    pair_feature_format : str
        Format string for pair-feature columns (e.g., "{ASSET1}_{ASSET2}_{FEATURE}").
    timestamp_col : str
        Name of the timestamp column.

    Returns
    -------
    bool
        True if the time frame contains at least one row and no missing values
        in the timestamp or feature column. False if the window is empty, has
        missing values, or the pair has no such feature column.
    """

    asset1, asset2 = pair

    # Construct candidate column names (given order first, then reversed)
    candidates = (
        pair_feature_format.format(ASSET1=asset1, ASSET2=asset2, FEATURE=feature_name),
        pair_feature_format.format(ASSET1=asset2, ASSET2=asset1, FEATURE=feature_name),
    )
    feature_col_name = next((name for name in candidates if name in df.columns), None)
    if feature_col_name is None:
        return False

    # Filter DataFrame to the specified time frame
    mask = (df[timestamp_col] >= start) & (df[timestamp_col] < end)
    df_timeframe = df.loc[mask, [timestamp_col, feature_col_name]]

    # A window without any rows has no data at all, so it cannot be complete.
    if df_timeframe.empty:
        return False

    # Check for missing values in any of the required columns
    return not bool(df_timeframe.isnull().values.any())

# For every asset pair, keep only the intervals in which the pair's "spread"
# column passes is_timeframe_valid.
# NOTE(review): the variable name is misspelled ("intervails"); it is left
# unchanged here because other cells may reference it.
valid_intervails_per_pair = {}
for pair in asset_pairs:
    valid_intervals = []
    for start, end in intervals:
        if is_timeframe_valid(
            features_df,
            pair,
            start,
            end,
            feature_name="spread",
            pair_feature_format=pair_feature_format,
            timestamp_col=timestamp_col):

            valid_intervals.append((start, end))
    valid_intervails_per_pair[pair] = valid_intervals
    print(f"Pair {pair} has {len(valid_intervals)} valid intervals out of {len(intervals)} total intervals.")

# show a small sample of the result
print("First 3 valid intervals for first 3 pairs:")
for pair in list(valid_intervails_per_pair.keys())[:3]:
    print(f"Pair {pair}:")
    for start, end in valid_intervails_per_pair[pair][:3]:
        print(f"  {start} to {end}")

## Identify Feature Splits

In [None]:
# NOTE(review): the CONFIG defined at the top of this notebook has no
# "features" entry under "ENV" and no "price_point" entry under "DATA", so both
# lookups raise KeyError as written — add the keys or update this cell.
feat_cfg = CONFIG["ENV"]["features"]
price_col = CONFIG["DATA"]["price_point"]

# based on the columns in the file identify...
# ...assets
# ...pairs
# ...time frames
# ...features

# relative strength index (momentum indicator) of a price series
def compute_rsi(series: pd.Series, period: int = 14) -> pd.Series:
    """
    Compute the relative strength index of `series`.

    Gains and losses are smoothed with an exponentially weighted mean using
    alpha = 1/period (adjust=False). The first element is NaN because the
    first price change is undefined.
    """
    change = series.diff() # price change from one bar to the next
    smoothing = 1/period
    gains = change.clip(lower=0) # positive moves only
    losses = -change.clip(upper=0) # magnitude of negative moves only
    avg_gain = gains.ewm(alpha=smoothing, adjust=False).mean()
    avg_loss = losses.ewm(alpha=smoothing, adjust=False).mean()
    strength = avg_gain / (avg_loss + 1e-12) # small epsilon avoids division by zero
    return 100 - (100 / (1 + strength)) # map the ratio into the 0..100 range

# build the per-asset feature frame (ret, vol, rsi, volchg)
def make_features(df: pd.DataFrame, price_col: str, vol_window: int, rsi_period: int, volume_change: bool):
    """
    Derive four features from one asset's data, indexed like `df`.

    Columns of the result:
      - ret    : one-step log return of `price_col` (first row is NaN)
      - vol    : rolling standard deviation of `ret` over `vol_window`, NaN -> 0.0
      - rsi    : compute_rsi of `price_col` with `rsi_period`, NaN -> 50.0
      - volchg : log change of the "volume" column (zeros treated as missing,
                 NaN -> 0.0) when that column exists and `volume_change` is
                 truthy; otherwise 0.0 everywhere
    """
    prices = df[price_col]
    result = pd.DataFrame(index=df.index)

    result["ret"] = np.log(prices).diff(1)
    result["vol"] = result["ret"].rolling(vol_window).std().fillna(0.0)
    result["rsi"] = compute_rsi(prices, rsi_period).fillna(50.0)

    use_volume = "volume" in df.columns and volume_change
    if not use_volume:
        result["volchg"] = 0.0 # no usable volume information
        return result

    # zero volume would give log(0); treat it as missing instead
    safe_volume = df["volume"].replace(0, np.nan)
    result["volchg"] = np.log(safe_volume).diff().fillna(0.0)
    return result

# Build one feature frame per ticker, then join them side by side into `panel`.
# NOTE(review): `tickers` and `dfs` are not defined anywhere in the visible
# part of this notebook — confirm where they come from.
features_by_ticker = {}
for t in tickers: # create features for each ticker
    fdf = make_features(dfs[t], price_col, feat_cfg["vol_window"], feat_cfg["rsi_period"], feat_cfg["volume_change"])
    features_by_ticker[t] = fdf # store features in dictionary

# column labels as (ticker, feature) tuples, in the same order as the concat below
panel_cols = []
for t in tickers: # create multi-index columns
    for col in ["ret","vol","rsi","volchg"]:
        panel_cols.append((t, col))
panel = pd.concat([features_by_ticker[t][["ret","vol","rsi","volchg"]] for t in tickers], axis=1) # combine features horizontally
panel.columns = pd.MultiIndex.from_tuples(panel_cols, names=["ticker","feature"]) # set multi-index columns
panel = panel.dropna() # remove rows with missing values (e.g. the first row, whose return is NaN)

print(panel.tail())
panel.describe()

## Feature scaling and state tensor construction

This section normalizes the features and constructs the state tensors required for training the reinforcement learning agent. State tensors are multi-dimensional arrays that represent the current state of the environment that the RL agent uses to make decisions.

In [None]:
lookback = CONFIG["ENV"]["lookback_window"] # number of past time steps in each observation
# NOTE(review): CONFIG["ENV"] as defined at the top of this notebook has no
# "features" entry, so the next line raises KeyError as written.
normalize = CONFIG["ENV"]["features"]["normalize"] # normalization setting passed to build_state_tensor

# z-score every column against its own trailing window
def rolling_zscore(df: pd.DataFrame, window: int = 256) -> pd.DataFrame:
    """
    Return `df` standardized with a rolling mean and standard deviation.

    Rows without a full window, and rows whose rolling standard deviation is
    exactly zero, come out as 0.0.
    """
    roller = df.rolling(window)
    center = roller.mean() # rolling mean
    scale = roller.std().replace(0, np.nan) # zero std becomes NaN so the result is filled with 0.0 below
    scored = (df - center) / (scale + 1e-12)
    return scored.fillna(0.0)

# build state tensor
def build_state_tensor(panel: pd.DataFrame, lookback: int, normalize: bool = False):
    """
    Turn the (ticker, feature) panel into sliding-window state tensors.

    Parameters
    ----------
    panel : pd.DataFrame
        Frame with two-level columns (ticker, feature). Every ticker needs the
        "ret" and "vol" features.
    lookback : int
        Number of consecutive rows in each observation window.
    normalize : bool
        If truthy, the observation windows are taken from `rolling_zscore(panel)`
        (default window) instead of the raw panel. Returns and volatilities are
        always taken from the raw panel.

    Returns
    -------
    (X, y_ret, inst_vol, tickers, features, times)
        X        : float32 array (samples, tickers, features, lookback); sample k
                   covers rows k .. k+lookback-1
        y_ret    : float32 array (samples, tickers); "ret" at row k+lookback+1
        inst_vol : float32 array (samples, tickers); "vol" at row k+lookback
        tickers  : sorted ticker names, matching axis 1 of X
        features : sorted feature names, matching axis 2 of X
        times    : panel index from position lookback+1 onwards
    """
    # The `normalize` flag selects the source of the observation windows only.
    source = rolling_zscore(panel) if normalize else panel

    # organize data by tickers, features, and time
    tickers = sorted(panel.columns.unique(level=0)) # ensure consistent order
    features = sorted(panel.columns.unique(level=1)) # feature order
    times = panel.index # time index

    # One row i needs `lookback` rows before it and one row after it.
    n_samples = max(len(times) - lookback - 1, 0)

    # (tickers, time, features), with features selected explicitly so that the
    # tensor layout always matches the returned `features` list
    values = np.stack(
        [source[t][features].to_numpy(dtype=float) for t in tickers], axis=0
    )

    X = np.empty((n_samples, len(tickers), len(features), lookback), dtype=np.float32)
    for k in range(n_samples):
        # window rows k .. k+lookback-1, rearranged to (tickers, features, lookback)
        X[k] = values[:, k:k + lookback, :].transpose(0, 2, 1)

    returns = panel.loc[:, [(t, "ret") for t in tickers]].to_numpy(dtype=float)
    vols = panel.loc[:, [(t, "vol") for t in tickers]].to_numpy(dtype=float)

    first = lookback # row i of the first sample
    y_ret = returns[first + 1:first + 1 + n_samples].astype(np.float32) # next period returns (row i+1)
    inst_vol = vols[first:first + n_samples].astype(np.float32) # current period volatilities (row i)
    return X, y_ret, inst_vol, tickers, features, times[lookback+1:]

# Build the state tensors for the whole panel; the results are stored in
# notebook-level variables.
X_all, R_all, VOL_all, TICKER_ORDER, FEAT_ORDER, TIME_INDEX = build_state_tensor(
    panel, lookback=lookback, normalize=normalize
)

# sanity check of the resulting array shapes
print("State tensor:", X_all.shape, "Returns:", R_all.shape, "InstVol:", VOL_all.shape)

## Define Splits and Adjust Timezones

This section defines the training and validation splits for the dataset. It ensures that the data is divided appropriately to train the model, validate its performance during training, and test its final performance on unseen data. Also, it adjusts the timezones of the datetime indices to ensure consistency across the dataset. This is necessary in case the data comes from multiple sources with different timezone settings.

In [None]:
# function to create boolean mask for date slicing
def date_slice_mask(times: pd.DatetimeIndex, start: str, end: str):
    """
    Return a boolean mask selecting the entries of `times` within [start, end].

    Parameters
    ----------
    times : pd.DatetimeIndex
        Time index to slice. A tz-naive index is interpreted as UTC; a tz-aware
        index is converted to UTC.
    start, end : str
        Inclusive bounds; anything `pd.Timestamp` accepts. tz-naive values are
        interpreted as UTC, tz-aware values are converted to UTC. A date-only
        string means midnight at the start of that day, so later timestamps on
        the `end` date are NOT included.

    Returns
    -------
    Boolean array with one entry per element of `times`.
    """
    def _as_utc(value) -> pd.Timestamp:
        # tz_localize raises on tz-aware input, so branch on awareness
        ts = pd.Timestamp(value)
        return ts.tz_localize('UTC') if ts.tzinfo is None else ts.tz_convert('UTC')

    start_ts = _as_utc(start)
    end_ts = _as_utc(end)

    # ensure time index is in UTC (tz_convert is a no-op when it already is)
    times = times.tz_localize('UTC') if times.tz is None else times.tz_convert('UTC')

    return (times >= start_ts) & (times <= end_ts) # return boolean mask for splits later on

def create_cointegration_splits(coint_df: pd.DataFrame, price_data_index: pd.DatetimeIndex):
    """
    Create train/validation/test masks over `price_data_index` from cointegration periods.

    The rows of `coint_df` are sorted by 'train_start' and split chronologically:
    the first 70% of the periods go to training, the next 15% to validation, and
    the rest to testing. A timestamp is selected for a split when it lies within
    [train_start, train_end] of at least one period of that split.

    Parameters
    ----------
    coint_df : pd.DataFrame
        Must contain the columns 'train_start' and 'train_end'. tz-naive bounds
        are interpreted as UTC.
    price_data_index : pd.DatetimeIndex
        Time index the masks refer to. A tz-naive index is interpreted as UTC.

    Returns
    -------
    (train_mask, val_mask, test_mask)
        Boolean arrays with one entry per element of `price_data_index`.
        Masks can overlap when periods from different splits overlap in time.
    """
    def _as_utc(value) -> pd.Timestamp:
        # Period bounds must be tz-aware before they can be compared with the
        # tz-aware index; comparing naive and aware timestamps raises TypeError.
        ts = pd.Timestamp(value)
        return ts.tz_localize('UTC') if ts.tzinfo is None else ts.tz_convert('UTC')

    # Ensure the index is timezone-aware and in UTC
    if price_data_index.tz is None:
        price_data_index = price_data_index.tz_localize('UTC')
    else:
        price_data_index = price_data_index.tz_convert('UTC')

    # Sort periods chronologically
    coint_df = coint_df.sort_values('train_start')

    # Split cointegration periods into train/val/test
    split_point1 = int(0.7 * len(coint_df))  # first 70% of periods for training
    split_point2 = int(0.85 * len(coint_df))  # next 15% for validation, remainder for testing

    def _mask_for(periods: pd.DataFrame) -> np.ndarray:
        # OR together the membership masks of all periods in this split
        mask = np.zeros(len(price_data_index), dtype=bool)
        for start, end in zip(periods['train_start'], periods['train_end']):
            within = (price_data_index >= _as_utc(start)) & (price_data_index <= _as_utc(end))
            mask |= np.asarray(within, dtype=bool)
        return mask

    train_mask = _mask_for(coint_df.iloc[:split_point1])
    val_mask = _mask_for(coint_df.iloc[split_point1:split_point2])
    test_mask = _mask_for(coint_df.iloc[split_point2:])

    print(f"Training windows: {train_mask.sum()} timesteps")
    print(f"Validation windows: {val_mask.sum()} timesteps")
    print(f"Testing windows: {test_mask.sum()} timesteps")

    return train_mask, val_mask, test_mask

# Create train/validation/test masks based on cointegration periods
# NOTE(review): `coint_df` is not defined anywhere in the visible part of this
# notebook — confirm where it is loaded.
train_mask, val_mask, test_mask = create_cointegration_splits(coint_df, TIME_INDEX)

# Single named split holding the three boolean masks over TIME_INDEX
SPLITS = [{
    "name": "CointegrationSplit",
    "train": train_mask,
    "val": val_mask,
    "test": test_mask
}]

## Environment Configuration

This is a custom Gymnasium environment for portfolio optimization.

In [None]:
class PortfolioWeightsEnv(gym.Env):
    """
    Gymnasium environment in which each action is a vector of portfolio weights.

    One step corresponds to one sample of the supplied arrays. The observation is
    the flattened, clipped state tensor `X[t]`. The reward is the return earned by
    the weights held during the step, minus trading costs for moving to the new
    target weights, minus a volatility penalty.

    Parameters
    ----------
    X : array
        State tensors indexed by sample along axis 0; presumably
        (samples, assets, features, lookback) as produced by build_state_tensor.
    R : array
        Per-sample asset returns, shape (samples, assets).
    VOL : array
        Per-sample asset volatilities, shape (samples, assets).
    tickers : list
        Asset names; its length defines the number of assets.
    lookback : int
        Lookback window length (stored for reference).
    cfg_env : dict
        Environment configuration with the keys "include_cash", "shorting",
        "transaction_costs", "reward", "leverage" and optionally "seed".

    Raises
    ------
    ValueError
        If any of X, R, VOL contains no samples.
    """
    metadata = {"render_modes": []}

    # observations are clipped to [-OBS_BOUND, OBS_BOUND] (±5 for z-scored features)
    OBS_BOUND = 5.0

    def __init__(self, X, R, VOL, tickers, lookback, cfg_env):
        super().__init__()
        # Fail early with a clear message instead of an IndexError inside reset().
        if len(X) == 0 or len(R) == 0 or len(VOL) == 0:
            raise ValueError(
                "PortfolioWeightsEnv needs at least one sample; the selected data slice is empty."
            )

        self.X = X # state tensor
        self.R = R # next period returns
        self.VOL = VOL # current period volatilities
        self.tickers = tickers # list of assets
        self.lookback = lookback # lookback window
        self.cfg = cfg_env # environment configuration taken from CONFIG

        self.n_assets = len(tickers) # number of assets
        self.include_cash = cfg_env["include_cash"] # whether to include cash as an asset
        self.shorting = cfg_env["shorting"] # whether shorting is allowed
        self.dim_action = self.n_assets + (1 if self.include_cash else 0) # one weight per asset (+1 for cash if included)

        # The observation is X[t] flattened, so its size is the product of the
        # per-sample dimensions of X.
        obs_dim = int(np.prod(self.X.shape[1:]))
        self.observation_space = spaces.Box(
            low=-self.OBS_BOUND, high=self.OBS_BOUND, shape=(obs_dim,), dtype=np.float32
        )

        # Action bounds follow the leverage caps; the lower bound is zero without shorting.
        action_low = -cfg_env["leverage"]["short_cap"] if self.shorting else 0.0
        self.action_space = spaces.Box(
            low=action_low,
            high=cfg_env["leverage"]["long_cap"],
            shape=(self.dim_action,),
            dtype=np.float32
        )

        self.commission = cfg_env["transaction_costs"]["commission_bps"] / 1e4 # commission in decimal form
        self.slippage = cfg_env["transaction_costs"]["slippage_bps"] / 1e4 # slippage in decimal form
        self.risk_lambda = cfg_env["reward"]["risk_lambda"] # defines risk aversion in reward function

        # Leverage settings
        self.long_cap = cfg_env["leverage"]["long_cap"]
        self.short_cap = cfg_env["leverage"]["short_cap"]
        self.use_asymmetric = cfg_env["leverage"]["use_asymmetric"]

        self.reset(seed=cfg_env.get("seed", 42))

    def _to_obs(self, t):
        """Return sample `t` as a flat float32 vector clipped to the observation bounds."""
        arr = self.X[t].reshape(-1).astype(np.float32)
        # Clip here so that reset() and step() both return observations inside
        # the declared observation space.
        return np.clip(arr, -self.OBS_BOUND, self.OBS_BOUND)

    # project raw actions to asset weights with leverage
    def _project_weights(self, a):
        """Clip the raw action (including the cash entry, if any) to the leverage caps."""
        if self.use_asymmetric:
            # Apply asymmetric leverage caps
            w = np.clip(a, -self.short_cap if self.shorting else 0.0, self.long_cap)
        else:
            # Use symmetric leverage cap
            max_leverage = max(self.long_cap, self.short_cap)
            w = np.clip(a, -max_leverage if self.shorting else 0.0, max_leverage)
        return w

    # reset environment to initial state
    def reset(self, seed=None, options=None):
        """Start a new episode at the first sample with value 1.0 and no asset exposure."""
        super().reset(seed=seed)
        self.t = 0 # reset time index
        self.portfolio_value = 1.0 # reset portfolio value
        self.w = np.zeros(self.dim_action) # initialize with zero weights
        if self.include_cash:
            self.w[-1] = 1.0  # start with all cash
        obs = self._to_obs(self.t)
        return obs, {}

    def step(self, action):
        """
        Advance one sample.

        The weights held BEFORE this step earn the return `R[t]`; the action only
        sets the weights for the next step and determines the trading cost paid now.
        """
        w_target = self._project_weights(action) # get new target weights from raw actions
        turnover = np.sum(np.abs(w_target - self.w)) # sum of absolute weight changes (cash included)
        trading_cost = (self.commission + self.slippage) * turnover # total trading cost based on turnover

        asset_w_prev = self.w[:self.n_assets] # previous weights excluding cash
        asset_ret = np.dot(asset_w_prev, self.R[self.t]) # return earned by the previous weights
        # Volatility exposure uses absolute weights: a short position carries
        # risk too, so it must not reduce (or flip the sign of) the penalty.
        inst_vol = np.dot(np.abs(asset_w_prev), self.VOL[self.t])

        # reward function: asset return minus trading costs and risk penalty
        reward = float(asset_ret - trading_cost - self.risk_lambda * inst_vol)

        self.portfolio_value *= math.exp(asset_ret - trading_cost) # update portfolio value using log returns

        self.w = w_target # update calculated weights for next step
        self.t += 1 # move to next time step
        terminated = (self.t >= len(self.R)-1) # episode ends if we reach the end of the data
        truncated = False

        # On termination the last valid observation is repeated.
        obs = self._to_obs(self.t) if not terminated else self._to_obs(self.t-1)

        # Additional info for monitoring leveraged positions
        leverage = np.sum(np.abs(self.w[:self.n_assets]))
        info = {
            "portfolio_value": self.portfolio_value,
            "turnover": turnover,
            "inst_vol": inst_vol,
            "asset_ret": asset_ret,
            "total_leverage": leverage
        }
        return obs, reward, terminated, truncated, info

In [None]:
def slice_by_mask(X, R, VOL, mask: np.ndarray):
    """Select the rows of ``X``, ``R`` and ``VOL`` where ``mask`` is truthy.

    Args:
        X, R, VOL: arrays indexed along their first axis by the same row positions.
        mask: 1-D boolean-like selector over those rows.

    Returns:
        Tuple ``(X_sel, R_sel, VOL_sel)``. When nothing is selected each array
        has length 0 on the first axis and keeps its trailing dimensions *and*
        its dtype (an earlier special case returned float64 zeros instead).
    """
    idx = np.where(mask)[0]  # positions of the selected rows
    # Indexing with an empty integer array already yields correctly shaped,
    # correctly typed empty arrays, so no special case is needed.
    return X[idx], R[idx], VOL[idx]

def make_env_from_mask(mask, name="env"):
    """Build a Monitor-wrapped ``PortfolioWeightsEnv`` over the rows selected by ``mask``.

    The rows are taken from the notebook-level ``X_all``, ``R_all`` and ``VOL_all``
    arrays; environment settings come from ``CONFIG["ENV"]``.
    ``name`` is accepted for call-site readability but is not used here.
    """
    features, returns, vols = slice_by_mask(X_all, R_all, VOL_all, mask)
    env_cfg = CONFIG["ENV"]
    return Monitor(
        PortfolioWeightsEnv(features, returns, vols, TICKER_ORDER, env_cfg["lookback_window"], env_cfg),
        filename=None,  # no monitor file is written
    )

In [None]:
def annualize_factor(sampling: str):
    """Return the periods-per-year factor registered for ``sampling`` in ``ANNUALIZATION``.

    Unknown keys fall back to 365 * 24 (the number of hours in a 365-day year).
    """
    fallback = 365 * 24
    return ANNUALIZATION.get(sampling, fallback)

def compute_metrics(equity_curve: pd.Series, sampling: str, turnover_series: pd.Series = None):
    """Summarise an equity curve with standard performance statistics.

    Args:
        equity_curve: portfolio value per sample, indexed by a ``DatetimeIndex``
            or by plain positions.
        sampling: key passed to ``annualize_factor`` to obtain periods per year.
        turnover_series: optional per-step turnover; its mean is reported.

    Returns:
        Dict with keys ``CAGR``, ``Sharpe``, ``Sortino``, ``MaxDrawdown``,
        ``Calmar``, ``Volatility``, ``Turnover`` and ``HitRatio``.

    Notes:
        * Sharpe, Sortino and Calmar use the annualised arithmetic mean return
          (not CAGR) as numerator and assume a zero risk-free rate.
        * ``Turnover`` is NaN when no (or an empty) turnover series is given.
        * Without a ``DatetimeIndex`` the elapsed time is ``(N - 1)`` periods for
          ``N`` samples, matching the first-to-last span used for datetime indexes.
    """
    ret = equity_curve.pct_change().dropna()
    ann = annualize_factor(sampling)
    mu = ret.mean() * ann
    sigma = ret.std() * math.sqrt(ann)
    sharpe = mu / (sigma + 1e-12)  # epsilon guards against division by zero
    downside = ret[ret < 0].std() * math.sqrt(ann)  # deviation of losing periods only
    sortino = mu / (downside + 1e-12)

    if len(equity_curve) > 1:
        if isinstance(equity_curve.index, pd.DatetimeIndex):
            # Elapsed wall-clock time between the first and last sample.
            dt_years = (equity_curve.index[-1] - equity_curve.index[0]).total_seconds() / (365 * 24 * 3600)
        else:
            # N samples span N-1 return periods; counting N would understate the growth rate.
            dt_years = (len(equity_curve) - 1) / ann
        dt_years = float(dt_years) if float(dt_years) != 0 else 1e-12
        cagr = (equity_curve.iloc[-1] / equity_curve.iloc[0]) ** (1 / dt_years) - 1
    else:
        cagr = 0.0

    # Max drawdown: worst relative drop from the running peak (<= 0).
    cummax = equity_curve.cummax()
    maxdd = float((equity_curve / cummax - 1).min())
    calmar = mu / (abs(maxdd) + 1e-12)
    hit_ratio = (ret > 0).mean()  # share of periods with a positive return
    has_turnover = turnover_series is not None and len(turnover_series) > 0
    turnover = turnover_series.mean() if has_turnover else np.nan
    return {
        "CAGR": cagr,
        "Sharpe": sharpe,
        "Sortino": sortino,
        "MaxDrawdown": maxdd,
        "Calmar": calmar,
        "Volatility": sigma,
        "Turnover": turnover,
        "HitRatio": hit_ratio,
    }

def plot_series(series: pd.Series, title: str):
    """Show ``series`` as a single line chart (index on x, values on y) titled ``title``."""
    _, ax = plt.subplots(figsize=(10, 4))
    ax.plot(series.index, series.values)
    ax.set_title(title)
    ax.set_xlabel("Time")
    ax.set_ylabel("Value")
    ax.grid(True)
    plt.show()

def plot_training_periods(coint_df: pd.DataFrame, price_index: pd.DatetimeIndex):
    """Plot one horizontal bar per row of ``coint_df``, from ``train_start`` to ``train_end``.

    Rows are stacked bottom-up in iteration order, so overlapping training
    windows show up as vertically aligned bars. ``price_index`` is accepted
    but not read by this function.
    """
    plt.figure(figsize=(15, 5))
    for level, (_, period) in enumerate(coint_df.iterrows()):
        # level = position of the row in iteration order, independent of the frame's index labels
        plt.hlines(level, period['train_start'], period['train_end'], 'blue', alpha=0.3)
    plt.title("Training Periods Timeline")
    plt.xlabel("Time")
    plt.ylabel("Period Index")
    plt.grid(True)
    plt.show()

def backtest_env(env: "gym.Env", model=None):
    """Roll a policy through ``env`` once and record the resulting curves.

    Args:
        env: environment (possibly wrapped). The underlying environment must
            expose ``R`` (its length bounds the number of steps) and ``dim_action``.
        model: object with ``predict(obs, deterministic=True)``. When ``None``,
            a constant raw action with equal entries is submitted at every step.

    Returns:
        Tuple ``(equity_curve, turnover, leverage)`` of ``pd.Series`` sharing a
        ``RangeIndex`` with one entry per executed step. Values are read from the
        step ``info`` dict (``portfolio_value``, ``turnover``, ``total_leverage``;
        leverage defaults to 0 when the key is absent).
    """
    # Reach through wrappers for the data attributes; fall back to env itself if unwrapped is missing.
    unwrapped = getattr(env, "unwrapped", env)

    obs, _ = env.reset()
    pv, turns, leverage = [], [], []

    max_steps = len(unwrapped.R) - 1
    for _step in range(max_steps):
        if model is None:
            action = np.ones(unwrapped.dim_action) / unwrapped.dim_action
        else:
            action, _state = model.predict(obs, deterministic=True)
        obs, _reward, terminated, truncated, info = env.step(action)
        pv.append(info["portfolio_value"])
        turns.append(info["turnover"])
        leverage.append(info.get("total_leverage", 0))
        # An episode is over on either flag; stepping past a truncation would
        # drive an already-finished episode.
        if terminated or truncated:
            break

    idx = pd.RangeIndex(start=0, stop=len(pv), step=1)
    ec = pd.Series(pv, index=idx)
    to = pd.Series(turns, index=idx)
    lev = pd.Series(leverage, index=idx)

    return ec, to, lev

# Plot cointegration training periods: one horizontal bar per row of coint_df.
# TIME_INDEX is forwarded as price_index, which plot_training_periods does not currently read.
print("Visualizing training periods timeline...")
plot_training_periods(coint_df, TIME_INDEX)

In [None]:
# Train one PPO agent per split, backtest it on the split's test rows and
# compare it against simple hold benchmarks. One dict per split is appended to RESULTS.
ensure_dir(CONFIG["IO"]["models_dir"])
ensure_dir(CONFIG["EVAL"]["reports_dir"])

RESULTS = []

for split in SPLITS:
    print(f"\n=== Training on split: {split['name']} ===")
    train_env = make_env_from_mask(split["train"], name=f"{split['name']}_train")
    eval_env  = make_env_from_mask(split["test"], name=f"{split['name']}_test")

    vec_train = DummyVecEnv([lambda: train_env])
    vec_eval  = DummyVecEnv([lambda: eval_env])

    model = PPO(
        policy=CONFIG["RL"]["policy"],
        env=vec_train,
        gamma=CONFIG["RL"]["gamma"],
        gae_lambda=CONFIG["RL"]["gae_lambda"],
        clip_range=CONFIG["RL"]["clip_range"],
        n_steps=CONFIG["RL"]["n_steps"],
        batch_size=CONFIG["RL"]["batch_size"],
        learning_rate=CONFIG["RL"]["learning_rate"],
        ent_coef=CONFIG["RL"]["ent_coef"],
        vf_coef=CONFIG["RL"]["vf_coef"],
        max_grad_norm=CONFIG["RL"]["max_grad_norm"],
        tensorboard_log=CONFIG["IO"]["tb_logdir"],
        device=device,
        verbose=1
    )

    # NOTE(review): the callback evaluates on the *test* rows during training and
    # writes to the same models_dir for every split — confirm both are intended.
    eval_callback = EvalCallback(
        vec_eval,
        best_model_save_path=CONFIG["IO"]["models_dir"],
        log_path=CONFIG["IO"]["models_dir"],
        eval_freq=10000,
        deterministic=True,
        render=False
    )

    model.learn(total_timesteps=CONFIG["RL"]["timesteps"], callback=eval_callback)
    model_path = os.path.join(CONFIG["IO"]["models_dir"], f"ppo_{split['name']}.zip")
    model.save(model_path)
    print("Saved model:", model_path)

    # Fresh environment for the backtest so it starts from a clean reset.
    test_env = make_env_from_mask(split["test"], name=f"{split['name']}_test")
    # backtest_env returns (equity curve, turnover, leverage); unpacking only two
    # values raised "too many values to unpack". Leverage is not used below.
    ec, to, _lev = backtest_env(test_env, model=model)

    # Equal-weight benchmark compounded from the test-period returns.
    # NOTE(review): this starts at R_test[0]; confirm it lines up with the first
    # return the environment actually consumes after reset.
    idx = np.where(split["test"])[0]
    R_test = R_all[idx]
    ew = np.ones(len(TICKER_ORDER))/len(TICKER_ORDER)
    ec_bench = [1.0]
    # Only iterate through the same number of returns as we have in ec.index
    for i in range(len(ec.index)-1):
        ec_bench.append(ec_bench[-1]*math.exp(np.dot(ew, R_test[i])))
    ec_bench = pd.Series(ec_bench, index=ec.index)

    # Single-asset hold benchmarks.
    # NOTE(review): columns 0 and 1 of R_all are assumed to be BTC and ETH —
    # verify against TICKER_ORDER.
    bh_btc, bh_eth = [1.0], [1.0]
    # Only iterate through the same number of returns as we have in ec.index
    for i in range(len(ec.index)-1):
        bh_btc.append(bh_btc[-1]*math.exp(R_test[i][0]))
        bh_eth.append(bh_eth[-1]*math.exp(R_test[i][1]))
    bh_btc = pd.Series(bh_btc, index=ec.index)
    bh_eth = pd.Series(bh_eth, index=ec.index)

    m_model = compute_metrics(ec, CONFIG["DATA"]["sampling"], to)
    m_ew    = compute_metrics(ec_bench, CONFIG["DATA"]["sampling"])
    m_btc   = compute_metrics(bh_btc, CONFIG["DATA"]["sampling"])
    m_eth   = compute_metrics(bh_eth, CONFIG["DATA"]["sampling"])

    RESULTS.append({"split": split["name"], "model": m_model, "equal_weight": m_ew, "buy_and_hold_BTC": m_btc, "buy_and_hold_ETH": m_eth})

    if CONFIG["EVAL"]["plots"]:
        plot_series(ec, f"Equity Curve — PPO ({split['name']})")
        plot_series((ec / ec.cummax()) - 1.0, f"Drawdown — PPO ({split['name']})")
        plot_series(ec_bench, f"Equity Curve — Equal-Weight Hold ({split['name']})")

print("Done. RESULTS collected.")

In [None]:

# Flatten RESULTS into one row per split: every metric of every strategy
# becomes a column named "<strategy>_<metric>".
rows = [
    {
        "split": result["split"],
        **{
            f"{strategy}_{metric_name}": metric_value
            for strategy, strategy_metrics in result.items()
            if strategy != "split"
            for metric_name, metric_value in strategy_metrics.items()
        },
    }
    for result in RESULTS
]

df_results = pd.DataFrame(rows)
df_results

In [None]:
# Persist the metrics twice under a shared timestamp: the flat table as CSV
# and the nested RESULTS structure as JSON.
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
ensure_dir(CONFIG["EVAL"]["reports_dir"])
out_json, out_csv = (
    os.path.join(CONFIG["EVAL"]["reports_dir"], f"metrics_{ts}.{ext}")
    for ext in ("json", "csv")
)
df_results.to_csv(out_csv, index=False)
with open(out_json, "w") as f:
    json.dump(RESULTS, f, indent=2)
print("Saved:", out_json, "and", out_csv)