# wrappers

> Environment wrappers that adapt environments to agent interfaces

In [None]:
# | default_exp wrappers

In [None]:
# | hide
# Enable IPython autoreload so edits to imported modules are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
# | hide
from fastcore.test import *
from hydra import compose
from hydra import initialize
from hydra.utils import instantiate
from nbdev.showdoc import *

In [None]:
# | export

from typing import *

from fastcore.basics import patch
import gymnasium as gym
from gymnasium.spaces import Box
from gymnasium.spaces import Discrete

import rlmm
from rlmm.core import *
from rlmm.utils import *

In [None]:
# class DiscreteActions(ActionWrapper):
#     def __init__(self, env, disc_to_cont):
#         super().__init__(env)
#         self.disc_to_cont = disc_to_cont
#         self.action_space = Discrete(len(disc_to_cont))

#     def action(self, act):
#         return self.disc_to_cont[act]

In [None]:
# | export


class DiscreteActionWrapper(gym.ActionWrapper):
    """Expose a discrete action space over a fixed table of underlying actions.

    The agent supplies an integer key; `action` looks it up in `action_dict`
    and the mapped value is what the wrapped environment receives.

    Args:
        env: Environment to wrap.
        action_dict: Mapping from discrete action id to the action value that
            is forwarded to the wrapped environment.
    """

    def __init__(self, env: gym.Env, action_dict: Dict[int, Any]):
        super().__init__(env)
        self.action_dict = action_dict
        # NOTE(review): `Discrete(n)` covers ids 0..n-1, so
        # `action_space.sample()` only yields valid keys when `action_dict` is
        # keyed exactly by 0..n-1 (the notebook example uses 1..3) — confirm
        # the intended key range.
        self.action_space = Discrete(len(action_dict))

    def action(self, action: int) -> Any:
        """Translate a discrete action id into the wrapped env's action.

        Args:
            action: Key into `self.action_dict`.

        Returns:
            The value stored in `self.action_dict` for `action`.

        Raises:
            ValueError: If `action` is not a key of `self.action_dict`.
        """
        if action not in self.action_dict:
            # Reference the instance mapping; a bare `action_dict` here would
            # only resolve when a global of that name happens to exist.
            raise ValueError(
                f"Discrete action {action} not in action_dict {self.action_dict}"
            )

        return self.action_dict[action]

In [None]:
# Generate the Hydra config stub for DiscreteActionWrapper via `hydra_nb`;
# only `_target_` is set here, so `action_dict` must come from elsewhere.
params = {
    "_target_": "rlmm.wrappers.DiscreteActionWrapper",
}

hydra_nb(obj=DiscreteActionWrapper, path="../conf/wrappers/discrete_action.yaml", params=params)

_target_: rlmm.wrappers.DiscreteActionWrapper



In [None]:
# Compose the top-level Hydra config (../conf/conf.yaml) into `cfg`.
with initialize(config_path="../conf", version_base=None):
    cfg = compose(config_name="conf.yaml")

In [None]:
# Instantiate both datasets, then feed them to the partially-instantiated env.
dataset_book = instantiate(cfg.dataset_book)
dataset_trades = instantiate(cfg.dataset_trades)
make_env = instantiate(cfg.envs, _partial_=True)
env = make_env(dataset_book, dataset_trades)
env.state

array([1.500000e+01, 3.150020e+00, 6.378770e+00, 1.503374e+01,
       4.091820e+00, 1.070000e+00, 3.210000e-01, 4.601000e+00,
       1.507110e+00, 1.298630e+00, 7.349200e-01, 5.729470e+00,
       2.474070e+00, 2.882000e-01, 5.542400e-01, 2.639190e+00,
       8.807530e+00, 3.251250e+01, 3.376740e+00, 2.746670e+00,
       0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
       0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
       0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
       0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
       0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
       1.000000e+06, 1.000000e+02])

In [None]:
# Discrete action id -> value forwarded to the env.
# Tuple fields are presumably (side, price, size) — confirm against OrderBookEnv.
action_dict = {
    1: [("ask", 2, 10)],
    2: [("bid", 4, 3)],
    3: [("ask", 1, 100)],
}

In [None]:
# Rebind `env` to the wrapped version so it accepts the keys of `action_dict`.
env = DiscreteActionWrapper(env=env, action_dict=action_dict)
env

<DiscreteActionWrapper<OrderBookEnv instance>>

In [None]:
# Discrete action 3 is translated to action_dict[3] == [("ask", 1, 100)].
env.step(3)

[('ask', 1, 100)]


(array([1.599750e+01, 3.359480e+00, 6.000000e-02, 1.000000e+00,
        3.712500e-01, 2.230250e+00, 6.750000e+00, 1.417510e+00,
        1.795550e+00, 4.289640e+00, 3.043117e+01, 2.440770e+00,
        1.000000e+01, 1.000000e+01, 2.025370e+00, 3.375400e-01,
        3.270205e+01, 3.680150e+00, 1.900000e+00, 2.746670e+00,
        0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
        0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
        0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
        0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
        0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
        1.000000e+06, 1.000000e+02]),
 205.0,
 False,
 {'idx': 2,
  'timestamp': '2021-03-05 06:18:00',
  'cash': 1000000,
  'inventory': 100,
  'portfolio_value': 1148155.5})

In [None]:
# Action 4 has no entry in action_dict, so the wrapper raises ValueError.
# Catch only that error so unrelated failures are not silently printed.
try:
    env.step(4)
except ValueError as e:
    print(e)

Discrete action 4 not in action_dict {1: [('ask', 2, 10)], 2: [('bid', 4, 3)], 3: [('ask', 1, 100)]}


In [None]:
# | hide
import nbdev

nbdev.nbdev_export()