# In [2]:
import gymnasium as gym
from gymnasium.spaces import Discrete
from gymnasium.spaces import Box
import numpy as np

# In [3]:
class TradingEnv(gym.Env):
    """Single-asset trading environment over a table of market features.

    Each step moves one row forward in ``data``, applies the chosen change to
    the fraction of the portfolio held in the asset, and rewards the agent
    with the resulting relative change in portfolio value.

    Actions (``Discrete(5)``), expressed as a change in asset allocation:
    0 = sell 25%, 1 = sell 10%, 2 = no change, 3 = buy 10%, 4 = buy 25%.
    Percentages are of total portfolio value (asset + cash) at each timestep.

    Observations are the current row of ``data`` followed by the current
    asset allocation.
    """

    def __init__(self, data, episode_length=250, budget=10000):
        """Create the environment.

        Args:
            data: Table indexed positionally through ``.iloc``. Column 0 is
                used as the asset price when computing rewards.
                NOTE(review): the observation space assumes five columns
                (Close, Volume, SMA Ratio, RSI, Bandwidth) -- confirm
                against the caller.
            episode_length: Number of steps in one episode.
            budget: Portfolio value at the start of every episode.
        """
        super().__init__()
        self._initial_budget = budget
        self.portfolio_value = budget
        self.cur_row_num = 0
        self.starting_row_num = 0
        self.asset_allocation = 0.0
        self.data = data
        self.episode_length = episode_length

        # action space: sell 25%, sell 10%, no change, buy 10%, buy 25%
        self.action_space = Discrete(5)

        # observation space: Close, Volume, SMA Ratio, RSI, Bandwidth,
        # Asset Allocation
        self.observation_space = Box(
            low=np.zeros(6, dtype=np.float32),
            high=np.array(
                [np.inf, np.inf, np.inf, 100.0, np.inf, 1.0], dtype=np.float32
            ),
            dtype=np.float32,
        )

    def _get_obs(self):
        """Return the current data row with the asset allocation appended."""
        features = np.asarray(
            self.data.iloc[self.cur_row_num, :], dtype=np.float32
        )
        # Cast after appending so the result matches the space's dtype.
        return np.append(features, self.asset_allocation).astype(np.float32)

    def _get_info(self):
        """Return auxiliary diagnostics: the current portfolio value."""
        return {'Portfolio Value': self.portfolio_value}

    def reset(self, seed=None, options=None):
        """Start a new episode at a random row of ``data``.

        Args:
            seed: Optional seed forwarded to ``gym.Env.reset``; it seeds
                ``self.np_random``, which drives every random draw here.
            options: Accepted for Gymnasium API compatibility; unused.

        Returns:
            Tuple ``(observation, info)`` for the first row of the episode.

        Raises:
            ValueError: If ``data`` has too few rows for one episode.
        """
        super().reset(seed=seed)

        # Leave room for a whole episode after the starting row; the extra
        # margin of 2 is carried over from the original implementation.
        max_start = len(self.data) - self.episode_length - 2
        if max_start <= 0:
            raise ValueError(
                "data has too few rows for an episode of length "
                f"{self.episode_length}"
            )

        # Draw from the environment's own generator so `seed` is honoured.
        self.starting_row_num = int(self.np_random.integers(0, max_start))
        self.cur_row_num = self.starting_row_num
        self.portfolio_value = self._initial_budget

        # 30% of episodes start fully in cash, the rest at a random allocation.
        if self.np_random.random() > 0.7:
            self.asset_allocation = 0.0
        else:
            self.asset_allocation = float(self.np_random.random())

        return self._get_obs(), self._get_info()

    def step(self, action):
        """Advance one row, apply ``action`` and compute the reward.

        Args:
            action: Index into the action space (see the class docstring).

        Returns:
            Tuple ``(observation, reward, terminated, truncated, info)``.
            ``terminated`` becomes True once ``episode_length`` steps have
            been taken; ``truncated`` is always False.
        """
        self.cur_row_num += 1
        self.asset_allocation = self._action_to_allocation(action)

        obs = self._get_obs()
        reward = self._get_reward()
        info = self._get_info()

        steps_taken = self.cur_row_num - self.starting_row_num
        terminated = steps_taken >= self.episode_length
        truncated = False
        return obs, reward, terminated, truncated, info

    def _action_to_allocation(self, action):
        """Return the asset allocation that results from ``action``.

        The result is clamped to ``[0.0, 1.0]`` so the allocation can be
        neither negative nor larger than the whole portfolio. Any action
        other than 0-3 is treated as "buy 25%".
        """
        if action == 0:
            allocation_change = -0.25
        elif action == 1:
            allocation_change = -0.1
        elif action == 2:
            allocation_change = 0.0
        elif action == 3:
            allocation_change = 0.1
        else:
            allocation_change = 0.25
        target = self.asset_allocation + allocation_change
        return max(0.0, min(1.0, target))

    def _get_reward(self):
        """Update the portfolio value and return its relative change.

        The asset return is the relative change of column 0 of ``data``
        between the previous and the current row; only the allocated
        fraction of the portfolio is exposed to it.
        """
        prev_price = self.data.iloc[self.cur_row_num - 1, 0]
        cur_price = self.data.iloc[self.cur_row_num, 0]
        asset_change = (cur_price - prev_price) / prev_price

        growth = (
            self.asset_allocation * (1.0 + asset_change)
            + (1.0 - self.asset_allocation)
        )
        new_portfolio_value = self.portfolio_value * growth
        reward = (
            (new_portfolio_value - self.portfolio_value) / self.portfolio_value
        )
        self.portfolio_value = new_portfolio_value
        return float(reward)


    # def close(self):
    #     pass