In [1]:
import numpy as np
import gym
from stable_baselines3.common.env_checker import check_env


class GoLeftEnv(gym.Env):
    """One-dimensional toy environment where the goal is to walk left to position 0.

    The agent occupies an integer position in ``[0, max_pos]``. Each step it
    moves one cell left (action 0) or right (action 1); moves past either end
    are clipped. Reaching position 0 ends the episode with reward 1; every
    other step yields reward 0.
    """

    # Supported render modes; 'human' mode is not supported inside Jupyter.
    metadata = {'render.modes': ['console']}

    # Action codes understood by step().
    LEFT = 0
    RIGHT = 1

    def __init__(self, max_pos=10, start_pos=9):
        """Create the environment.

        Args:
            max_pos: Largest reachable position (upper bound of the axis).
            start_pos: Position the agent is placed at on construction and on
                every reset(); must lie in ``[0, max_pos]``.

        Raises:
            ValueError: If ``start_pos`` is outside ``[0, max_pos]``.
        """
        super().__init__()

        if not 0 <= start_pos <= max_pos:
            raise ValueError(
                'start_pos={!r} must lie in [0, {!r}]'.format(start_pos, max_pos))

        self.max_pos = max_pos
        self.start_pos = start_pos

        # Initial position.
        self.pos = start_pos

        # Action space: only two actions exist, left and right.
        self.action_space = gym.spaces.Discrete(2)

        # Observation space: a single coordinate on a one-dimensional axis.
        self.observation_space = gym.spaces.Box(low=0,
                                                high=max_pos,
                                                shape=(1, ),
                                                dtype=np.float32)

    def _observation(self):
        """Return the current position as a float32 array of shape (1,)."""
        return np.array([self.pos], dtype=np.float32)

    def reset(self):
        """Move the agent back to the start position and return the observation."""
        self.pos = self.start_pos
        return self._observation()

    def step(self, action):
        """Apply one action.

        Args:
            action: 0 to move left, 1 to move right.

        Returns:
            A 4-tuple ``(observation, reward, done, info)`` where ``done`` is
            True and ``reward`` is 1 exactly when the new position is 0, and
            ``info`` is an empty dict.

        Raises:
            ValueError: If ``action`` is neither 0 nor 1.
        """
        if action == self.LEFT:
            self.pos -= 1
        elif action == self.RIGHT:
            self.pos += 1
        else:
            # Previously an unknown action was silently treated as "stay put".
            raise ValueError(
                'Received invalid action={!r}; expected 0 (left) or 1 (right)'.
                format(action))

        # Keep the agent on the axis. np.clip returns a NumPy scalar, so cast
        # back to int to keep self.pos a plain Python integer.
        self.pos = int(np.clip(self.pos, 0, self.max_pos))

        # The episode ends, and is rewarded, only on reaching the left edge.
        done = self.pos == 0
        reward = 1 if done else 0

        return self._observation(), reward, done, {}

    def render(self, mode='console'):
        """Print the current position; only the 'console' mode is implemented."""
        if mode != 'console':
            raise NotImplementedError()
        print(self.pos)

    def close(self):
        """Nothing to release."""
        pass


# Instantiate the custom environment with its default settings.
env = GoLeftEnv()

# Check that the environment is valid.
check_env(env, warn=True)

In [2]:
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

# Wrap the environment. The class is passed (instead of a lambda returning the
# already-created `env`) so the vectorized wrapper builds its own instance and
# manual stepping of `env` elsewhere cannot disturb the training environment.
train_env = make_vec_env(GoLeftEnv, n_envs=1)

# Define the model.
model = PPO('MlpPolicy', train_env, verbose=0)

In [3]:
import gym


#测试一个环境
def test(model, env):
    state = env.reset()
    over = False
    step = 0

    for i in range(100):
        action = model.predict(state)[0]

        next_state, reward, over, _ = env.step(action)

        if step % 1 == 0:
            print(step, state, action, reward)

        state = next_state
        step += 1

        if over:
            break


# Roll out the model before any call to model.learn in this notebook.
test(model, env)

0 [9.] 1 0
1 [10.] 0 0
2 [9.] 0 0
3 [8.] 0 0
4 [7.] 0 0
5 [6.] 1 0
6 [7.] 1 0
7 [8.] 0 0
8 [7.] 0 0
9 [6.] 0 0
10 [5.] 1 0
11 [6.] 0 0
12 [5.] 0 0
13 [4.] 1 0
14 [5.] 0 0
15 [4.] 1 0
16 [5.] 0 0
17 [4.] 1 0
18 [5.] 1 0
19 [6.] 1 0
20 [7.] 0 0
21 [6.] 0 0
22 [5.] 1 0
23 [6.] 1 0
24 [7.] 0 0
25 [6.] 0 0
26 [5.] 1 0
27 [6.] 0 0
28 [5.] 0 0
29 [4.] 0 0
30 [3.] 1 0
31 [4.] 0 0
32 [3.] 0 0
33 [2.] 1 0
34 [3.] 0 0
35 [2.] 0 0
36 [1.] 0 1


In [4]:
# Train the model. 5000 is passed as learn's first positional argument
# (presumably the total number of training timesteps; confirm against the SB3 API).
model.learn(5000)

# Test
test(model, env)

0 [9.] 0 0
1 [8.] 0 0
2 [7.] 0 0
3 [6.] 0 0
4 [5.] 0 0
5 [4.] 0 0
6 [3.] 0 0
7 [2.] 0 0
8 [1.] 0 1
