In [1]:
import gymnasium as gym


#自定义一个Wrapper
class Pendulum(gym.Wrapper):

    def __init__(self):
        env = gym.make('Pendulum-v1')
        super().__init__(env)
        self.env = env

    def reset(self, seed=None, options=None):
        state, info = self.env.reset()
        return state, info

    def step(self, action):
        state, reward, done, truncated, info = self.env.step(action)

        return state, reward, done, truncated, info


Pendulum().reset()

(array([-0.4241724, -0.9055815,  0.8657286], dtype=float32), {})

In [3]:
#测试一个环境
def test(env, wrap_action_in_list=False):
    print(env)

    state = env.reset()
    over = False
    step = 0

    while not over:
        action = env.action_space.sample()

        if wrap_action_in_list:
            action = [action]

        next_state, reward, over, truncated, info = env.step(action)

        if step % 20 == 0:
            print(step, state, action, reward, info)

        if step > 200:
            break

        state = next_state
        step += 1


test(Pendulum())

<Pendulum<TimeLimit<OrderEnforcing<PassiveEnvChecker<PendulumEnv<Pendulum-v1>>>>>>
0 (array([-0.75942934,  0.65058976,  0.33862457], dtype=float32), {}) [-1.8423958] -5.935478494375225 {}
20 [-0.8220362  -0.56943524 -1.309392  ] [-1.8641158] -6.605075509520823 {}
40 [-0.9606308   0.27782816  2.7733295 ] [-0.35866398] -8.949208123893145 {}
60 [-0.98197013  0.1890362  -3.0943265 ] [1.5095853] -9.670597822238605 {}
80 [-0.795866  -0.6054728  1.7512435] [-0.07221768] -6.512931672318567 {}
100 [-0.66456056  0.7472344  -1.2653979 ] [-1.4816382] -5.441752562969946 {}
120 [-0.5669427  -0.82375723  0.31710556] [-0.3166121] -4.734632097375697 {}
140 [-0.6127761   0.79025656  1.2994612 ] [0.82946855] -5.144076322153992 {}
160 [-0.8306981  -0.55672324 -2.8063476 ] [-1.1940153] -7.297384852546583 {}
180 [-0.99894196 -0.04598893  4.6388717 ] [-1.7274575] -11.737559182739165 {}
200 [-0.90032417  0.43521994 -3.363787  ] [1.6156356] -8.377266635543235 {}


In [3]:
#修改最大步数
class StepLimitWrapper(gym.Wrapper):

    def __init__(self, env):
        super().__init__(env)
        self.current_step = 0

    def reset(self,seed=None, options=None):
        self.current_step = 0
        return self.env.reset()

    def step(self, action):
        self.current_step += 1
        state, reward, done, truncated, info = self.env.step(action)

        #修改done字段
        if self.current_step >= 100:
            done = True

        return state, reward, done, info


test(StepLimitWrapper(Pendulum()))

<StepLimitWrapper<Pendulum<TimeLimit<OrderEnforcing<PassiveEnvChecker<PendulumEnv<Pendulum-v1>>>>>>>
0 [-0.7358193  -0.67717797 -0.40382123] [-0.5497488] -5.765440561412341
20 [-0.9347265   0.35536796  1.9101306 ] [-0.1730721] -8.08375711348729
40 [-0.9913878   0.13095891 -2.549925  ] [-0.81193155] -9.712515159385875
60 [-0.86229056 -0.5064138   1.1561615 ] [0.8439644] -6.949468564700318
80 [-0.6871128   0.72655076 -0.10954288] [-1.3248858] -5.423954906875421


In [4]:
import numpy as np


#修改动作空间
class NormalizeActionWrapper(gym.Wrapper):

    def __init__(self, env):
        #获取动作空间
        action_space = env.action_space

        #动作空间必须是连续值
        assert isinstance(action_space, gym.spaces.Box)

        #重新定义动作空间,在正负一之间的连续值
        #这里其实只影响env.action_space.sample的返回结果
        #实际在计算时,还是正负2之间计算的
        env.action_space = gym.spaces.Box(low=-1,
                                          high=1,
                                          shape=action_space.shape,
                                          dtype=np.float32)

        super().__init__(env)

    def reset(self):
        return self.env.reset()

    def step(self, action):
        #重新缩放动作的值域
        action = action * 2.0

        if action > 2.0:
            action = 2.0

        if action < -2.0:
            action = -2.0

        return self.env.step(action)


test(NormalizeActionWrapper(Pendulum()))

<NormalizeActionWrapper<Pendulum<TimeLimit<OrderEnforcing<PassiveEnvChecker<PendulumEnv<Pendulum-v1>>>>>>>
0 [-0.90479934  0.42583814  0.46141297] [0.71307397] -7.3225321184911145
20 [-0.9846407  -0.17459278 -2.415618  ] [-0.85981315] -9.384230063667319
40 [-0.99319226 -0.11648658  3.4871514 ] [-0.6373727] -10.367310571865762
60 [-0.86642474  0.49930772 -3.734519  ] [0.1263556] -8.252804318413528
80 [-0.6702318 -0.7421518  1.911543 ] [0.43951586] -5.680660931112576
100 [-0.5961345  0.8028846 -1.0250479] [0.12384028] -4.9869101896966495
120 [-0.64067    -0.76781636 -1.2842121 ] [-0.6126538] -5.301933755735472
140 [-0.7986208  0.6018345  2.6601052] [-0.72427845] -6.938714204658353
160 [-0.9983815   0.05687125 -3.7444353 ] [-0.37042484] -10.917945192112665
180 [-0.9121499 -0.4098567  3.2691987] [-0.35181427] -8.463830116640269
200 [-0.71257746  0.70159346 -1.1502234 ] [0.67964196] -5.722462683064076


In [5]:
from gym.wrappers import TimeLimit


#修改状态
class StateStepWrapper(gym.Wrapper):

    def __init__(self, env):

        #状态空间必须是连续值
        assert isinstance(env.observation_space, gym.spaces.Box)

        #增加一个新状态字段
        low = np.concatenate([env.observation_space.low, [0.0]])
        high = np.concatenate([env.observation_space.high, [1.0]])

        env.observation_space = gym.spaces.Box(low=low,
                                               high=high,
                                               dtype=np.float32)

        super().__init__(env)

        self.step_current = 0

    def reset(self):
        self.step_current = 0
        return np.concatenate([self.env.reset(), [0.0]])

    def step(self, action):
        self.step_current += 1
        state, reward, done, info = self.env.step(action)

        #根据step_max修改done
        if self.step_current >= 100:
            done = True

        return self.get_state(state), reward, done, info

    def get_state(self, state):
        #添加一个新的state字段
        state_step = self.step_current / 100

        return np.concatenate([state, [state_step]])


test(StateStepWrapper(Pendulum()))

<StateStepWrapper<Pendulum<TimeLimit<OrderEnforcing<PassiveEnvChecker<PendulumEnv<Pendulum-v1>>>>>>>
0 [ 0.59643936 -0.80265814  0.02045725  0.        ] [0.0625756] -0.8681826781206066
20 [ 0.69300234  0.72093534 -2.91545701  0.2       ] [0.48333085] -1.4984907639906544
40 [-0.40385997  0.91482079  5.91973925  0.4       ] [-0.16275077] -7.450653647426798
60 [0.8851468  0.46531188 2.53248    0.6       ] [1.4304487] -0.8776350784009858
80 [ 0.82157564 -0.57009953  1.69356799  0.8       ] [0.3554661] -0.6549398995047444


  logger.warn(f"Box bound precision lowered by casting to {self.dtype}")


In [6]:
from stable_baselines3 import A2C
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv

#使用Monitor Wrapper,会在训练的过程中输出rollout/ep_len_mean和rollout/ep_rew_mean,就是增加些日志
#gym升级到0.26以后失效了,可能是因为使用了自定义的wapper
env = DummyVecEnv([lambda: Monitor(Pendulum())])

A2C('MlpPolicy', env, verbose=1).learn(1000)

Using cpu device
------------------------------------
| time/                 |          |
|    fps                | 919      |
|    iterations         | 100      |
|    time_elapsed       | 0        |
|    total_timesteps    | 500      |
| train/                |          |
|    entropy_loss       | -1.44    |
|    explained_variance | -0.00454 |
|    learning_rate      | 0.0007   |
|    n_updates          | 99       |
|    policy_loss        | -43.2    |
|    std                | 1.02     |
|    value_loss         | 976      |
------------------------------------
------------------------------------
| time/                 |          |
|    fps                | 897      |
|    iterations         | 200      |
|    time_elapsed       | 1        |
|    total_timesteps    | 1000     |
| train/                |          |
|    entropy_loss       | -1.43    |
|    explained_variance | 4.43e-05 |
|    learning_rate      | 0.0007   |
|    n_updates          | 199      |
|    policy_loss     

<stable_baselines3.a2c.a2c.A2C at 0x7f94cc7fee80>

In [7]:
from stable_baselines3.common.vec_env import VecNormalize, VecFrameStack

#VecNormalize,他会对state和reward进行Normalize
env = DummyVecEnv([Pendulum])
env = VecNormalize(env)

test(env, wrap_action_in_list=True)

<stable_baselines3.common.vec_env.vec_normalize.VecNormalize object at 0x7f949b273e20>
0 [[-0.00487219  0.00638567  0.00554428]] [array([1.8279244], dtype=float32)] [-10.]
20 [[-0.03349784 -0.66757905 -2.173565  ]] [array([-1.3404763], dtype=float32)] [-0.16482905]
40 [[-1.4011567   0.07794451  1.3022798 ]] [array([1.6657785], dtype=float32)] [-0.16956142]
60 [[-1.3572015   0.41527337 -1.5391531 ]] [array([1.3103601], dtype=float32)] [-0.12919044]
80 [[-0.34073314 -0.9262497   1.2184559 ]] [array([0.99134326], dtype=float32)] [-0.07326685]
100 [[ 1.6391766  1.4822493 -1.3587774]] [array([-1.5195391], dtype=float32)] [-0.04477553]
120 [[-0.01510759 -1.150826    1.7960454 ]] [array([-0.46556672], dtype=float32)] [-0.07235025]
140 [[-0.5499776  -0.83480734 -2.1066453 ]] [array([-0.03180405], dtype=float32)] [-0.08347733]
160 [[ 2.0203962  -0.5923113  -0.62100804]] [array([1.302856], dtype=float32)] [-0.00734869]
180 [[ 1.8016039   0.8581704  -0.43550822]] [array([1.2153829], dtype=float32