In [1]:
import gym


#定义环境
class MyWrapper(gym.Wrapper):

    def __init__(self):
        env = gym.make('CartPole-v1')
        super().__init__(env)
        self.env = env

    def reset(self):
        state, _ = self.env.reset()
        return state

    def step(self, action):
        state, reward, done, _, info = self.env.step(action)
        return state, reward, done, info


env = MyWrapper()

env.reset()

array([-0.02095911, -0.04298958,  0.04404614, -0.0466527 ], dtype=float32)

In [2]:
from stable_baselines3.common.callbacks import BaseCallback


#Callback语法
class CustomCallback(BaseCallback):

    def __init__(self, verbose=0):
        super().__init__(verbose)

        #可以访问的变量
        #self.model
        #self.training_env
        #self.n_calls
        #self.num_timesteps
        #self.locals
        #self.globals
        #self.logger
        #self.parent

    def _on_training_start(self) -> None:
        #第一个rollout开始前调用
        pass

    def _on_rollout_start(self) -> None:
        #rollout开始前
        pass

    def _on_step(self) -> bool:
        #env.step()之后调用,返回False后停止训练
        return True

    def _on_rollout_end(self) -> None:
        #更新参数前调用
        pass

    def _on_training_end(self) -> None:
        #训练结束前调用
        pass


CustomCallback()

<__main__.CustomCallback at 0x7f53a8ae42b0>

In [3]:
from stable_baselines3 import PPO


#让训练只执行N步的callback
class SimpleCallback(BaseCallback):

    def __init__(self):
        super().__init__(verbose=0)
        self.call_count = 0

    def _on_step(self):
        self.call_count += 1

        if self.call_count % 20 == 0:
            print(self.call_count)

        if self.call_count >= 100:
            return False

        return True


model = PPO('MlpPolicy', MyWrapper(), verbose=0)

model.learn(8000, callback=SimpleCallback())

20
40
60
80
100


<stable_baselines3.ppo.ppo.PPO at 0x7f53a8ae4430>

In [4]:
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3 import A2C
from stable_baselines3.common.evaluation import evaluate_policy
import gym


def test_callback(callback):

    #创建Monitor封装的环境,这会在训练过程中写出日志文件到models文件夹
    env = make_vec_env(MyWrapper, n_envs=1, monitor_dir='models')

    #等价写法
    # from stable_baselines3.common.monitor import Monitor
    # from stable_baselines3.common.vec_env import DummyVecEnv
    # env = Monitor(MyWrapper(), 'models')
    # env = DummyVecEnv([lambda: env])

    #训练
    model = A2C('MlpPolicy', env, verbose=0).learn(total_timesteps=5000,
                                                   callback=callback)

    #测试
    return evaluate_policy(model, MyWrapper(), n_eval_episodes=20)


#使用Monitor封装的环境训练一个模型,保存下日志
#只是为了测试load_results, ts2xy这两个函数
test_callback(None)



(1267.75, 845.1674316370692)

In [5]:
from stable_baselines3.common.results_plotter import load_results, ts2xy

#加载日志,这里找的是models/*.monitor.csv
load_results('models')

Unnamed: 0,index,r,l,t
0,0,32.0,32,0.049117
1,1,20.0,20,0.071115
2,2,23.0,23,0.097009
3,3,27.0,27,0.128370
4,4,41.0,41,0.173688
...,...,...,...,...
79,79,88.0,88,5.220903
80,80,80.0,80,5.312054
81,81,82.0,82,5.447958
82,82,121.0,121,5.603233


In [6]:
ts2xy(load_results('models'), 'timesteps')

(array([  32,   52,   75,  102,  143,  153,  169,  183,  200,  212,  227,
         252,  283,  317,  358,  377,  399,  421,  445,  505,  537,  627,
         674,  725,  753,  808,  852,  900,  989, 1053, 1117, 1203, 1251,
        1304, 1320, 1424, 1458, 1483, 1604, 1679, 1711, 1769, 1851, 1890,
        1960, 2048, 2159, 2204, 2324, 2387, 2424, 2476, 2541, 2564, 2587,
        2643, 2678, 2703, 2769, 2823, 2952, 3010, 3040, 3129, 3294, 3392,
        3420, 3562, 3699, 3740, 3987, 4057, 4087, 4151, 4261, 4302, 4331,
        4363, 4385, 4473, 4553, 4635, 4756, 4994]),
 array([ 32.,  20.,  23.,  27.,  41.,  10.,  16.,  14.,  17.,  12.,  15.,
         25.,  31.,  34.,  41.,  19.,  22.,  22.,  24.,  60.,  32.,  90.,
         47.,  51.,  28.,  55.,  44.,  48.,  89.,  64.,  64.,  86.,  48.,
         53.,  16., 104.,  34.,  25., 121.,  75.,  32.,  58.,  82.,  39.,
         70.,  88., 111.,  45., 120.,  63.,  37.,  52.,  65.,  23.,  23.,
         56.,  35.,  25.,  66.,  54., 129.,  58.,  30.,  89.

In [7]:
#保存最优模型
class SaveOnBestTrainingRewardCallback(BaseCallback):

    def __init__(self):
        super().__init__(verbose=0)

        self.best = -float('inf')

    def _on_step(self):
        #self.n_calls是个从1开始的数
        if self.n_calls % 1000 != 0:
            return True

        #读取日志
        x, y = ts2xy(load_results('models'), 'timesteps')

        #求最后100个reward的均值
        mean_reward = sum(y[-100:]) / len(y[-100:])

        print(self.num_timesteps, self.best, mean_reward)

        #判断保存
        if mean_reward > self.best:
            self.best = mean_reward
            print('save', x[-1])
            self.model.save('models/best_model')

        return True


test_callback(SaveOnBestTrainingRewardCallback())

1000 -inf 37.23076923076923
save 968
2000 37.23076923076923 45.86046511627907
save 1972
3000 45.86046511627907 54.870370370370374
save 2963
4000 54.870370370370374 58.705882352941174
save 3992
5000 58.705882352941174 65.13333333333334
save 4885




(175.35, 34.45616780781055)

In [8]:
#可以打印或者画图的callback
class PlottingCallback(BaseCallback):

    def __init__(self, verbose=0):
        super().__init__(verbose=0)

    def _on_step(self) -> bool:
        if self.n_calls % 1000 != 0:
            return True

        x, y = ts2xy(load_results('models'), 'timesteps')
        print(self.n_calls)
        print('x=', x)
        print('y=', y)

        return True


test_callback(PlottingCallback())

1000
x= [ 50  80 110 136 167 204 240 270 295 322 391 515 601 620 662 699 738 764
 805 838 873 904 929 954 979]
y= [ 50.  30.  30.  26.  31.  37.  36.  30.  25.  27.  69. 124.  86.  19.
  42.  37.  39.  26.  41.  33.  35.  31.  25.  25.  25.]
2000
x= [  50   80  110  136  167  204  240  270  295  322  391  515  601  620
  662  699  738  764  805  838  873  904  929  954  979 1021 1042 1058
 1080 1108 1157 1210 1234 1264 1297 1337 1397 1451 1468 1523 1561 1591
 1632 1692 1752 1781 1814 1850 1898 1918 1959]
y= [ 50.  30.  30.  26.  31.  37.  36.  30.  25.  27.  69. 124.  86.  19.
  42.  37.  39.  26.  41.  33.  35.  31.  25.  25.  25.  42.  21.  16.
  22.  28.  49.  53.  24.  30.  33.  40.  60.  54.  17.  55.  38.  30.
  41.  60.  60.  29.  33.  36.  48.  20.  41.]
3000
x= [  50   80  110  136  167  204  240  270  295  322  391  515  601  620
  662  699  738  764  805  838  873  904  929  954  979 1021 1042 1058
 1080 1108 1157 1210 1234 1264 1297 1337 1397 1451 1468 1523 1561 1591
 1632 



(101.4, 16.73439571660716)

In [9]:
from tqdm.auto import tqdm


#更新进度条的callback
class ProgressBarCallback(BaseCallback):

    def __init__(self):
        super().__init__()
        self.pbar = tqdm(total=5000)

    def _on_step(self):
        self.pbar.update(1)

    def _on_training_end(self) -> None:
        self.pbar.close()


test_callback(ProgressBarCallback())

  0%|          | 0/5000 [00:00<?, ?it/s]

(175.3, 13.849548729110273)

In [10]:
#同时使用多个callback
test_callback([PlottingCallback(), ProgressBarCallback()])

  0%|          | 0/5000 [00:00<?, ?it/s]

1000
x= [ 28  55  68  77  95 138 159 196 268 279 310 341 363 405 439 462 492 546
 600 657 722 765 851 882 930 969]
y= [28. 27. 13.  9. 18. 43. 21. 37. 72. 11. 31. 31. 22. 42. 34. 23. 30. 54.
 54. 57. 65. 43. 86. 31. 48. 39.]
2000
x= [  28   55   68   77   95  138  159  196  268  279  310  341  363  405
  439  462  492  546  600  657  722  765  851  882  930  969 1038 1088
 1133 1177 1209 1245 1308 1362 1506 1613 1675 1757 1954]
y= [ 28.  27.  13.   9.  18.  43.  21.  37.  72.  11.  31.  31.  22.  42.
  34.  23.  30.  54.  54.  57.  65.  43.  86.  31.  48.  39.  69.  50.
  45.  44.  32.  36.  63.  54. 144. 107.  62.  82. 197.]
3000
x= [  28   55   68   77   95  138  159  196  268  279  310  341  363  405
  439  462  492  546  600  657  722  765  851  882  930  969 1038 1088
 1133 1177 1209 1245 1308 1362 1506 1613 1675 1757 1954 2071 2231 2305
 2448 2675 2739 2874]
y= [ 28.  27.  13.   9.  18.  43.  21.  37.  72.  11.  31.  31.  22.  42.
  34.  23.  30.  54.  54.  57.  65.  43.  86.  31



(737.35, 255.4991731884861)