/
rescale_action.py
35 lines (28 loc) · 1.46 KB
/
rescale_action.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import numpy as np
import warnings
import gym
from gym import spaces
class RescaleAction(gym.ActionWrapper):
    r"""Rescale the continuous action space of an environment to the range [a, b].

    The wrapper exposes a new ``Box`` action space whose bounds are ``a`` and
    ``b``, with the same shape and dtype as the wrapped environment's action
    space. :meth:`action` maps an action from ``[a, b]`` linearly back onto the
    wrapped space's ``[low, high]``.

    Example::

        env = RescaleAction(env, -1.0, 1.0)
        # env.action_space.low  is -1.0 in every component
        # env.action_space.high is  1.0 in every component

    Args:
        env: Environment whose ``action_space`` must be a ``spaces.Box``.
        a: Lower bound of the new action range; a scalar or an array that
            broadcasts to the wrapped action shape.
        b: Upper bound of the new action range, in the same form as ``a``.
            Must satisfy ``a <= b`` element-wise.
    """

    def __init__(self, env, a, b):
        assert isinstance(env.action_space, spaces.Box), "expected Box action space, got {}".format(type(env.action_space))
        assert np.less_equal(a, b).all(), (a, b)
        warnings.warn(
            "Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit",
            stacklevel=2,  # attribute the warning to the code constructing the wrapper
        )
        super().__init__(env)
        # Materialize the bounds as full arrays in the wrapped space's shape/dtype
        # so scalar and per-component bounds are handled uniformly below.
        self.a = np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + a
        self.b = np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + b
        # Build the new space from the broadcast arrays rather than the raw a/b:
        # gym's Box checks array bounds against ``shape``, so a list or a
        # lower-rank array bound would otherwise be rejected.
        self.action_space = spaces.Box(low=self.a, high=self.b, shape=env.action_space.shape, dtype=env.action_space.dtype)

    def action(self, action):
        """Map ``action`` from ``[a, b]`` onto the wrapped space's ``[low, high]``.

        Args:
            action: Action in the rescaled space; must lie within ``[a, b]``
                element-wise.

        Returns:
            The linearly rescaled action, clipped to ``[low, high]``.

        Components where ``a == b`` carry no information (their only valid
        value is ``a``). They are mapped to ``low`` instead of producing
        ``0 / 0 = nan``.

        NOTE: assumes the wrapped Box has finite ``low``/``high``; infinite
        bounds make ``high - low`` non-finite and the result meaningless.
        """
        assert np.all(np.greater_equal(action, self.a)), (action, self.a)
        assert np.all(np.less_equal(action, self.b)), (action, self.b)
        low = self.env.action_space.low
        high = self.env.action_space.high
        span = self.b - self.a
        # Degenerate components (span == 0) would compute 0 / 0. Silence the
        # resulting warnings and replace those fractions with 0 (i.e. map to low).
        with np.errstate(divide="ignore", invalid="ignore"):
            fraction = (action - self.a) / span
        fraction = np.where(span == 0, 0.0, fraction)
        action = low + (high - low) * fraction
        # The asserts bound the input, so clipping only absorbs floating-point
        # rounding (and still enforces the bounds if asserts are stripped by -O).
        return np.clip(action, low, high)