# In [2]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np

class TradingEnv(gym.Env):
    """
    Discrete-action trading environment with risk-aware reward.

    The agent trades a single instrument one unit at a time. Each step it may
    hold (0), buy one unit (1) or sell one unit (2); selling with zero or
    negative inventory opens/extends a short position. Inventory is bounded by
    ``inventory_limit`` in both directions, and every trade pays a
    proportional ``transaction_cost``. Actions that would exceed the limit are
    treated as hold.

    Reward at each step is the change in mark-to-market portfolio value
    (``cash + inventory * price``) minus ``risk_lambda`` times the current
    drawdown from the running peak portfolio value.

    Observation (float32, shape ``(4,)``):
        0. price / prices[0]                      (>= 0, unbounded above)
        1. (inventory + limit) / (2 * limit)      (in [0, 1])
        2. cash clipped to [0, 1]
        3. t / max_steps                          (in [0, 1])

    Parameters
    ----------
    prices : array-like of float
        1-D price series with at least two strictly positive values.
    max_steps : int
        Requested episode length; clamped to ``len(prices) - 1`` because each
        step values the portfolio at the *next* price.
    transaction_cost : float
        Proportional cost applied to every buy and sell.
    risk_lambda : float
        Weight of the drawdown penalty in the reward.
    inventory_limit : int
        Maximum absolute inventory; must be at least 1.

    Raises
    ------
    ValueError
        If ``prices`` is not a 1-D series of at least two strictly positive
        values, or ``max_steps`` / ``inventory_limit`` is below 1.
    """

    def __init__(
        self,
        prices,
        max_steps=500,
        transaction_cost=0.001,
        risk_lambda=0.02,
        inventory_limit=10
    ):
        super().__init__()

        # np.array (rather than prices.astype) also accepts lists/tuples and,
        # like astype, stores a private copy of the series.
        self.prices = np.array(prices, dtype=np.float32)
        if self.prices.ndim != 1 or self.prices.size < 2:
            raise ValueError("prices must be a 1-D series with at least 2 values")
        # Prices are normalized by prices[0] and the observation's price
        # feature is declared non-negative; this also rejects NaN.
        if not np.all(self.prices > 0):
            raise ValueError("prices must be strictly positive")
        if max_steps < 1:
            raise ValueError("max_steps must be at least 1")
        if inventory_limit < 1:
            raise ValueError("inventory_limit must be at least 1")

        # step() reads prices[t + 1] and _obs() reads prices[t] with
        # t <= max_steps, so the episode cannot outlast the series.
        self.max_steps = min(int(max_steps), len(self.prices) - 1)
        self.tc = transaction_cost
        self.lambda_risk = risk_lambda
        self.inv_limit = inventory_limit

        self.action_space = spaces.Discrete(3)  # hold, buy, sell

        # The normalized price is unbounded above; the other three features
        # are scaled into [0, 1].
        self.observation_space = spaces.Box(
            low=np.zeros(4, dtype=np.float32),
            high=np.array([np.inf, 1.0, 1.0, 1.0], dtype=np.float32),
            dtype=np.float32,
        )

        self.reset()

    def reset(self, seed=None, options=None):
        """Start a new episode at t=0, flat inventory and cash 1.0.

        Returns the initial observation and an empty info dict.
        """
        super().reset(seed=seed)
        self.t = 0
        self.inventory = 0
        self.cash = 1.0
        self.prev_value = self.cash
        self.peak_value = self.cash
        return self._obs(), {}

    def step(self, action):
        """Apply ``action`` at the current price and advance one step.

        Returns ``(obs, reward, terminated, truncated, info)``; ``terminated``
        becomes True once ``max_steps`` steps have been taken and
        ``truncated`` is always False. ``info`` holds the mark-to-market
        ``portfolio_value``, the ``inventory`` and the ``drawdown`` from the
        running peak.
        """
        # Use Python floats for cash accounting so rounding error does not
        # accumulate in float32 over long episodes.
        price = float(self.prices[self.t])

        if action == 1 and self.inventory < self.inv_limit:
            self.inventory += 1
            self.cash -= price * (1 + self.tc)
        elif action == 2 and self.inventory > -self.inv_limit:
            self.inventory -= 1
            self.cash += price * (1 - self.tc)

        self.t += 1
        done = self.t >= self.max_steps

        # Mark the position to the next price; in range because
        # max_steps <= len(prices) - 1.
        value = self.cash + self.inventory * float(self.prices[self.t])
        reward = self._reward(value)

        self.prev_value = value
        self.peak_value = max(self.peak_value, value)

        info = {
            "portfolio_value": value,
            "inventory": self.inventory,
            "drawdown": self.peak_value - value
        }

        return self._obs(), reward, done, False, info

    def _reward(self, value):
        """Return step PnL minus ``lambda_risk`` times drawdown from the peak.

        Called before ``peak_value`` is updated; when ``value`` is a new peak
        the drawdown term is zero either way.
        """
        pnl = value - self.prev_value
        drawdown = max(0.0, self.peak_value - value)
        return float(pnl - self.lambda_risk * drawdown)

    def _obs(self):
        """Build the 4-feature float32 observation for the current step."""
        price_norm = self.prices[self.t] / self.prices[0]
        # Maps inventory in [-limit, limit] onto [0, 1].
        inv_norm = (self.inventory + self.inv_limit) / (2 * self.inv_limit)
        cash_norm = np.clip(self.cash, 0, 1)
        t_norm = self.t / self.max_steps
        return np.array([price_norm, inv_norm, cash_norm, t_norm], dtype=np.float32)
