In [1]:
import gym


#自定义一个Wrapper
class Pendulum(gym.Wrapper):

    def __init__(self):
        env = gym.make('Pendulum-v1')
        super().__init__(env)
        self.env = env

    def reset(self):
        state, _ = self.env.reset()
        return state

    def step(self, action):
        state, reward, done, _, info = self.env.step(action)
        return state, reward, done, info


Pendulum().reset()

array([-0.9987739 , -0.04950466, -0.29113272], dtype=float32)

In [2]:
import gym


#测试一个环境
def test(env, wrap_action_in_list=False):
    print(env)

    state = env.reset()
    over = False
    step = 0

    while not over:
        action = env.action_space.sample()

        if wrap_action_in_list:
            action = [action]

        next_state, reward, over, _ = env.step(action)

        if step % 20 == 0:
            print(step, state, action, reward)

        if step > 200:
            break

        state = next_state
        step += 1


test(Pendulum())

<Pendulum<TimeLimit<OrderEnforcing<PassiveEnvChecker<PendulumEnv<Pendulum-v1>>>>>>
0 [ 0.07557004  0.99714047 -0.00593105] [0.77568775] -2.2360912073732293
20 [-0.09022966 -0.99592096 -1.0113312 ] [1.9640111] -2.8655519815412602
40 [-0.0988468   0.99510264  1.0879471 ] [-1.102183] -2.9078257844017066
60 [-0.10559524 -0.9944092  -2.0250401 ] [1.4726528] -3.223197442443063
80 [-0.430696    0.90249705  2.3240063 ] [1.7433363] -4.607638323265267
100 [-0.7445209  -0.66759914 -3.8428938 ] [0.12860739] -7.287856130160823
120 [-0.9827089   0.18515739  5.019113  ] [0.27670756] -11.253382467763519
140 [-0.9918974  0.1270416 -4.5611672] [0.13088569] -11.16588448336541
160 [-0.92749065 -0.3738463   4.023142  ] [0.45365784] -9.227764511724946
180 [-0.90662247  0.42194283 -3.7316737 ] [0.13845253] -8.715023595619744
200 [-0.569597   -0.82192415  2.3553548 ] [1.5546746] -5.295696103269491


In [3]:
#修改最大步数
class StepLimitWrapper(gym.Wrapper):

    def __init__(self, env):
        super().__init__(env)
        self.current_step = 0

    def reset(self):
        self.current_step = 0
        return self.env.reset()

    def step(self, action):
        self.current_step += 1
        state, reward, done, info = self.env.step(action)

        #修改done字段
        if self.current_step >= 100:
            done = True

        return state, reward, done, info


test(StepLimitWrapper(Pendulum()))

<StepLimitWrapper<Pendulum<TimeLimit<OrderEnforcing<PassiveEnvChecker<PendulumEnv<Pendulum-v1>>>>>>>
0 [-0.96306235  0.26927853 -0.44170246] [-0.27660012] -8.250454055829325
20 [-0.9896051  -0.1438115  -0.35573816] [-0.8561249] -8.997079924662561
40 [-0.99573505  0.09225862  0.9402986 ] [-1.6440495] -9.388755855980168
60 [-0.9821424  -0.18813926 -0.48762357] [-0.23724955] -8.740059353440317
80 [-0.99840486  0.05646031  0.33269486] [-1.5122217] -9.531211654731154


In [4]:
import numpy as np


#修改动作空间
class NormalizeActionWrapper(gym.Wrapper):

    def __init__(self, env):
        #获取动作空间
        action_space = env.action_space

        #动作空间必须是连续值
        assert isinstance(action_space, gym.spaces.Box)

        #重新定义动作空间,在正负一之间的连续值
        #这里其实只影响env.action_space.sample的返回结果
        #实际在计算时,还是正负2之间计算的
        env.action_space = gym.spaces.Box(low=-1,
                                          high=1,
                                          shape=action_space.shape,
                                          dtype=np.float32)

        super().__init__(env)

    def reset(self):
        return self.env.reset()

    def step(self, action):
        #重新缩放动作的值域
        action = action * 2.0

        if action > 2.0:
            action = 2.0

        if action < -2.0:
            action = -2.0

        return self.env.step(action)


test(NormalizeActionWrapper(Pendulum()))

<NormalizeActionWrapper<Pendulum<TimeLimit<OrderEnforcing<PassiveEnvChecker<PendulumEnv<Pendulum-v1>>>>>>>
0 [ 0.9977141   0.06757639 -0.5798606 ] [0.10924502] -0.03824510760645264
20 [-0.974494    0.22441359 -7.9314737 ] [-0.19993515] -14.789678010888132
40 [0.82223725 0.56914485 1.9122643 ] [-0.1456236] -0.7323487597177175
60 [ 0.95644176 -0.29192328  3.2361808 ] [-0.45272422] -1.135862910021885
80 [ 0.17755312 -0.9841112   6.3154745 ] [-0.9058826] -5.930294741967333
100 [-0.9652388   0.26136962  7.652307  ] [0.47379] -14.13468087321719
120 [0.97018814 0.24235296 1.3940586 ] [-0.802597] -0.25683890851431856
140 [-0.00675709 -0.9999772   5.308917  ] [-0.7024001] -5.309108278379465
160 [-0.59028745 -0.8071931  -6.0892677 ] [0.17197001] -8.557770839551926
180 [0.15892577 0.98729056 3.3852875 ] [0.8132015] -3.1401303461009196
200 [ 0.5729039  -0.8196226  -0.55536026] [0.618847] -0.9554186163032075


In [5]:
from gym.wrappers import TimeLimit


#修改状态
class StateStepWrapper(gym.Wrapper):

    def __init__(self, env):

        #状态空间必须是连续值
        assert isinstance(env.observation_space, gym.spaces.Box)

        #增加一个新状态字段
        low = np.concatenate([env.observation_space.low, [0.0]])
        high = np.concatenate([env.observation_space.high, [1.0]])

        env.observation_space = gym.spaces.Box(low=low,
                                               high=high,
                                               dtype=np.float32)

        super().__init__(env)

        self.step_current = 0

    def reset(self):
        self.step_current = 0
        return np.concatenate([self.env.reset(), [0.0]])

    def step(self, action):
        self.step_current += 1
        state, reward, done, info = self.env.step(action)

        #根据step_max修改done
        if self.step_current >= 100:
            done = True

        return self.get_state(state), reward, done, info

    def get_state(self, state):
        #添加一个新的state字段
        state_step = self.step_current / 100

        return np.concatenate([state, [state_step]])


test(StateStepWrapper(Pendulum()))

<StateStepWrapper<Pendulum<TimeLimit<OrderEnforcing<PassiveEnvChecker<PendulumEnv<Pendulum-v1>>>>>>>
0 [0.59817356 0.80136657 0.85264915 0.        ] [-0.00020923] -0.9368131282951905
20 [ 0.56020814 -0.82835186  1.70206869  0.2       ] [-1.9285277] -1.2463098720796288
40 [ 0.44708437  0.89449179 -4.06003475  0.4       ] [0.5458641] -2.874784491887043
60 [-0.93931633 -0.34305215  6.48627567  0.6       ] [-1.5630432] -12.001692269034855
80 [-0.99761033 -0.06909105 -5.71979904  0.8       ] [-0.394344] -12.711692973575024


  logger.warn(f"Box bound precision lowered by casting to {self.dtype}")


In [6]:
from stable_baselines3 import A2C
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv

#使用Monitor Wrapper,会在训练的过程中输出rollout/ep_len_mean和rollout/ep_rew_mean,就是增加些日志
#gym升级到0.26以后失效了,可能是因为使用了自定义的wapper
env = DummyVecEnv([lambda: Monitor(Pendulum())])

A2C('MlpPolicy', env, verbose=1).learn(1000)

Using cpu device
------------------------------------
| time/                 |          |
|    fps                | 872      |
|    iterations         | 100      |
|    time_elapsed       | 0        |
|    total_timesteps    | 500      |
| train/                |          |
|    entropy_loss       | -1.46    |
|    explained_variance | -0.0073  |
|    learning_rate      | 0.0007   |
|    n_updates          | 99       |
|    policy_loss        | -44.6    |
|    std                | 1.04     |
|    value_loss         | 981      |
------------------------------------
------------------------------------
| time/                 |          |
|    fps                | 871      |
|    iterations         | 200      |
|    time_elapsed       | 1        |
|    total_timesteps    | 1000     |
| train/                |          |
|    entropy_loss       | -1.45    |
|    explained_variance | 0.00435  |
|    learning_rate      | 0.0007   |
|    n_updates          | 199      |
|    policy_loss     

<stable_baselines3.a2c.a2c.A2C at 0x7f6c8c5a6bb0>

In [7]:
from stable_baselines3.common.vec_env import VecNormalize, VecFrameStack

#VecNormalize,他会对state和reward进行Normalize
env = DummyVecEnv([Pendulum])
env = VecNormalize(env)

test(env, wrap_action_in_list=True)

<stable_baselines3.common.vec_env.vec_normalize.VecNormalize object at 0x7f6c8c5a6af0>
0 [[ 3.2637883e-05 -7.0706955e-03 -6.9902935e-03]] [array([-0.25961152], dtype=float32)] [-10.]
20 [[1.2276741  0.97891146 2.1453943 ]] [array([-0.24803689], dtype=float32)] [-0.05282598]
40 [[ 0.60790956 -1.172277   -0.47112164]] [array([-0.12413894], dtype=float32)] [-0.0540317]
60 [[-0.17742442  0.9338299   0.91889554]] [array([-0.7250232], dtype=float32)] [-0.06084454]
80 [[-1.1442417 -0.5317964 -0.8895803]] [array([-0.43278834], dtype=float32)] [-0.08468775]
100 [[-1.346504   -0.00736775  1.2767539 ]] [array([0.1718505], dtype=float32)] [-0.09404851]
120 [[-0.685307    0.76545364 -0.9292815 ]] [array([-0.5423339], dtype=float32)] [-0.05717669]
140 [[-0.4067262  -1.0522009   0.80991924]] [array([0.6089338], dtype=float32)] [-0.0439776]
160 [[ 0.60329753  1.177517   -0.4289072 ]] [array([1.5845369], dtype=float32)] [-0.0280198]
180 [[ 0.37099332 -1.2547264  -0.08798675]] [array([-1.9815657], dtype