<a href="https://colab.research.google.com/github/diegoramfin/Reinforcement-Learning-Agent-For-Portfolio-Rebalancing/blob/main/Reinforcement_Learning_Agent_for_Portfolio_Rebalancing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Reinforcement Learning for Portfolio Rebalancing

Treat rebalancing as a sequential decision problem, at each time step the agent observes market and portfolio data to then choose which portfolio action to do.

In [1]:
!pip install yfinance pandas numpy matplotlib gymnasium stable-baselines3[extra] finrl ta

Collecting finrl
  Downloading FinRL-0.3.7-py3-none-any.whl.metadata (909 bytes)
Collecting ta
  Downloading ta-0.11.0.tar.gz (25 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting stable-baselines3[extra]
  Downloading stable_baselines3-2.7.0-py3-none-any.whl.metadata (4.8 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch<3.0,>=2.3->stable-baselines3[extra])
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch<3.0,>=2.3->stable-baselines3[extra])
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch<3.0,>=2.3->stable-baselines3[extra])
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch<3.0,>=2.3->stable-baselines3[extra])
  Downloading nvidia_cudnn_cu12-9.1.0.7

In [2]:
import yfinance as yf
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import gymnasium
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3 import A2C

Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.


In [3]:
TICKERS = ['SPY', 'TLT', '^VIX', 'DX-Y.NYB', 'WTI', 'GC=F', 'QQQ', 'BND']
START = '2022-01-01'
END = '2024-12-31'
FREQ = '1wk'  # weekly
WINDOW = 6

# Download Market Data
prices = yf.download(TICKERS, start=START, end=END, interval=FREQ, progress=False, auto_adjust=False)['Adj Close']
prices = prices.dropna()
prices.columns = ['SPY', 'TLT', '^VIX', 'DXY', 'WTI', 'GOLD','QQQ', 'BND']
returns = np.log(prices) - np.log(prices.shift(1)) #Using Logartihmic Returns rather than PCT Change1
returns = returns.dropna()
returns.head()

Unnamed: 0_level_0,SPY,TLT,^VIX,DXY,WTI,GOLD,QQQ,BND
Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
2022-01-08,-0.002997,-0.005971,0.010793,0.000395,-0.002944,-0.001126,0.190097,0.022662
2022-01-15,0.00096,0.004926,0.008388,-0.077447,-0.059262,0.010709,-0.073233,0.407721
2022-01-22,-0.003965,0.016899,-0.025937,0.000313,0.009023,-0.003487,0.021819,-0.042123
2022-01-29,-0.011747,-0.018574,0.012084,0.017498,0.015158,-0.029207,0.049133,-0.174973
2022-02-05,-0.002743,0.006264,0.018754,-0.031063,-0.018535,-0.004091,0.040273,0.164068


In [4]:
price_matrix = prices.loc[returns.index].values
dates= returns.index

In [5]:
#2) Custom Gym Environment
# A simple continuous-action environment where actions represent *target weights* for each asset.
# - Observations: flattened recent log-return window + current weights
# - Action: continuous vector (N,) interpreted as target weights (non-negative; normalized)
# - Reward: next-step portfolio return minus transaction costs

class SimpleRebalEnv(gymnasium.Env):

  metadata = {"render_modes": ["human"]}

  def __init__(self, price_matrix, dates=None, init_cash=1e6, max_trade_cost= 0.001, window=WINDOW):
    self.prices = price_matrix
    self.dates = dates
    self.T, self.N = self.prices.shape
    self.init_cash = init_cash
    self.max_trade_cost = max_trade_cost
    self.window = window

    obs_dim = (self.window - 1) * self.N + self.N + 1  # +1 for cash weight
    self.observation_space = gymnasium.spaces.Box(low=-np.inf, high=np.inf, shape=(obs_dim,), dtype=np.float32)

    # action: target weights for N assets. We'll allow [0,1] per asset and normalize in step.
    self.action_space = gymnasium.spaces.Box(low=0.0, high=1.0, shape=(self.N,), dtype=np.float32)

    self.seed()
    self.reset()

  def seed(self, seed=None):
    np.random.seed(seed)

  def reset(self, seed=None, options=None, **kwargs): # Accept arbitrary keyword arguments
    super().reset(seed=seed)
    self.t = self.window
    self.portfolio_value = float(self.init_cash)
    self.weights = np.zeros(self.N, dtype=float)
    self.cash_weight = 1.0
    self.history = {'portfolio_value': [self.portfolio_value], 'weights':[self.weights.copy()], 'dates':[] }
    return self._get_obs(), {} # Return observation and info dictionary

  def _get_obs(self):
    start = self.t - self.window
    window_prices = self.prices[start:self.t]
    lr = np.log(window_prices[1:] / window_prices[:-1])
    obs = np.concatenate([lr.flatten(), self.weights, [self.cash_weight]])
    return obs.astype(np.float32)

  def step (self, action):
    target = np.clip(action, 0.0, 1.0)
    if target.sum() <= 0:
      target_w = np.zeros_like(target)
      cash_w = 1.0
    else:
      target_w = target / target.sum() + 1e-12
      cash_w = 0.0

    prices_now = self.prices[self.t - 1]
    prices_next = self.prices[self.t]

    prev_value = self.portfolio_value

    #convert to dollar holdings

    target_value = target_w * prev_value
    prev_holdings_value = self.weights * prev_value

    # traded volume
    traded = np.abs(target_value - prev_holdings_value).sum()
    cost = traded * self.max_trade_cost

    new_holdings_value = (target_value * (prices_next / prices_now)).sum()

    cash_value = cash_w * prev_value

    new_portfolio_value = new_holdings_value + cash_value - cost
    reward = (new_portfolio_value - prev_value) / (prev_value + 1e-9)

    #update states

    self.portfolio_value = float(new_portfolio_value)
    self.weights = target_w
    self.cash_weight = cash_w
    self.history['portfolio_value'].append(self.portfolio_value)
    self.history['weights'].append(self.weights.copy())
    self.history['dates'].append(self.t)

    self.t += 1

    terminated = False
    truncated = (self.t >= self.T)

    obs = self._get_obs() if not truncated else np.zeros_like(self._get_obs())
    info = {'portfolio_value': self.portfolio_value, 'cost': cost}

    return obs, float(reward), terminated, truncated, info

In [6]:
# 3) Train / Eval helpers
from functools import partial

def make_env_from_slice(price_matrix, dates, start_idx, end_idx, **kwargs):
  slice_prices = price_matrix[start_idx:end_idx]
  slice_dates = dates[start_idx:end_idx] if dates is not None else None
  return SimpleRebalEnv(slice_prices, dates= slice_dates, window=WINDOW, **kwargs)

def compute_metrics(equity_curve, periods_per_year= 52):
  arr = np.array(equity_curve)
  returns = arr[1:] / arr[:-1] - 1.0
  cum_return = arr[-1] / arr[0] - 1.0
  ann_return = (1 + cum_return) ** (periods_per_year / len(returns)) - 1
  ann_vol = returns.std() * (periods_per_year ** 0.5)
  sharpe = ann_return / (ann_vol + 1e-9)
  # max drawdown
  peak = np.maximum.accumulate(arr)
  drawdown = (arr - peak) / peak
  max_dd = drawdown.min()
  return {
      'cumulative_return': float(cum_return),
      'annual_return': float(ann_return),
      'annual_vol': float(ann_vol),
      'sharpe': float(sharpe),
      'max_drawdown': float(max_dd)
  }

In [7]:
# 4) Train-test split & baseline
#We'll use a time-based split: train on first 70% of bars, validate on next 15%, test on final 15%.

T = price_matrix.shape[0]
train_end = int(T * 0.7)
val_end = int(T * 0.85)

print('Total bars', T)
print('Training bars', train_end)
print('Validation bars', val_end - train_end)
print('Testing bars', T - val_end)

Total bars 156
Training bars 109
Validation bars 23
Testing bars 24


In [None]:
#5) Train PPO and SAC
#Vectorized envs for stable-baseline.
from stable_baselines3 import PPO, SAC

train_env = DummyVecEnv([lambda: make_env_from_slice(price_matrix, dates, 0, train_end, init_cash=1e6, max_trade_cost=0.001)])
val_env = DummyVecEnv([lambda: make_env_from_slice(price_matrix, dates, train_end, val_end, init_cash=1e6, max_trade_cost=0.001)])

ppo_model = PPO('MlpPolicy', train_env, verbose=1)
ppo_model.learn(total_timesteps=80_000)

sac_model =SAC('MlpPolicy', train_env, verbose=1)
sac_model.learn(total_timesteps=80_000)

Using cpu device
-----------------------------
| time/              |      |
|    fps             | 567  |
|    iterations      | 1    |
|    time_elapsed    | 3    |
|    total_timesteps | 2048 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 512         |
|    iterations           | 2           |
|    time_elapsed         | 7           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.008583594 |
|    clip_fraction        | 0.0586      |
|    clip_range           | 0.2         |
|    entropy_loss         | -11.4       |
|    explained_variance   | -2.03       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0141     |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.0144     |
|    std                  | 1.01        |
|    value_loss           | 0.00782     |
-----------------

In [None]:
#6) Save Models
ppo_model.save('ppo_model_rebal')
sac_model.save('sac_model_rebal')

In [None]:
#7) Evaluation on test set
#Run deterministic episodes through the test slice and record portfolio values.

def run_episode(env, model):
  obs, _ = env.reset() # Call reset on the environment and unpack the observation and info
  done = False
  pv =[env.portfolio_value]
  weights = [env.weights.copy()]
  while not done:
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action) # Include terminated and truncated
    done = terminated or truncated # Update done status
    pv.append(info['portfolio_value'] if 'portfolio_value' in info else env.portfolio_value)
    weights.append(env.weights.copy())
  return np.array(pv), np.array(weights)

test_env = make_env_from_slice(price_matrix, dates, val_end, T, init_cash=1e6, max_trade_cost=0.001)

# run episodes
pv_ppo, w_ppo = run_episode(test_env, ppo_model) # Pass env first, then model
# reset test env for SAC (fresh instance)
test_env2 = make_env_from_slice(price_matrix, dates, val_end, T, init_cash=1e6, max_trade_cost=0.001)
pv_sac, w_sac = run_episode(test_env2, sac_model) # Pass env first, then model

metrics_ppo = compute_metrics(pv_ppo)
metrics_sac = compute_metrics(pv_sac)

print('PPO metrics:', metrics_ppo)
print('SAC metrics:', metrics_sac)

In [None]:
# 8) Visualizations

plt.figure(figsize=(10,5))
plt.plot(pv_ppo / pv_ppo[0], label='PPO')
plt.plot(pv_sac / pv_sac[0], label='SAC')
# baseline: buy-and-hold SPY on same test dates
bh_prices = price_matrix[val_end:T, 0]
bh_equity = bh_prices / bh_prices[0]  # assuming $1 start
# scale to initial capital
plt.plot(bh_equity * 1.0, label='BuyHold SPY (normalized)')
plt.legend(); plt.title('Normalized Equity Curves (Test)')
plt.show()

plt.figure(figsize=(10,4))
plt.imshow(np.vstack([w_ppo.T, w_sac.T])[:,:], aspect='auto')
plt.title('Weights over time (stacked: PPO then SAC)')
plt.xlabel('Time step')
plt.ylabel('Asset idx')
plt.colorbar(label='Weight')
plt.show()


In [None]:
# 9) Quick analysis & next steps
# - This is a compact demo: increase training timesteps, tune hyperparameters, and add more realistic transaction/slippage models for production-style experiments.
# - Important learning experiments: change reward function to penalize volatility, compare DQN on discretized 2-asset env, add features (indicators), and test out-of-sample robustness.
