In [1]:
import gymnasium as gym


#定义环境
class MyWrapper(gym.Wrapper):

    def __init__(self):
        env = gym.make('CartPole-v1')
        super().__init__(env)
        self.env = env

    def reset(self,**kwargs):

        return self.env.reset(**kwargs)

    def step(self, action):
        state, reward, terminated, truncated, info = self.env.step(action)
        return state, reward, terminated, truncated, info


MyWrapper().reset()

(array([ 0.01549727,  0.03949478, -0.03268917,  0.03669816], dtype=float32),
 {})

In [2]:
from stable_baselines3.common.callbacks import BaseCallback


#Callback语法
class CustomCallback(BaseCallback):

    def __init__(self, verbose=0):
        super().__init__(verbose)

        #可以访问的变量
        #self.model
        #self.training_env
        #self.n_calls
        #self.num_timesteps
        #self.locals
        #self.globals
        #self.logger
        #self.parent

    def _on_training_start(self) -> None:
        #第一个rollout开始前调用
        pass

    def _on_rollout_start(self) -> None:
        #rollout开始前
        pass

    def _on_step(self) -> bool:
        #env.step()之后调用,返回False后停止训练
        return True

    def _on_rollout_end(self) -> None:
        #更新参数前调用
        pass

    def _on_training_end(self) -> None:
        #训练结束前调用
        pass


CustomCallback()

  from .autonotebook import tqdm as notebook_tqdm


<__main__.CustomCallback at 0x1d706152f80>

In [3]:
from stable_baselines3 import PPO


#让训练只执行N步的callback
class SimpleCallback(BaseCallback):

    def __init__(self):
        super().__init__(verbose=0)
        self.call_count = 0

    def _on_step(self):
        self.call_count += 1

        if self.call_count % 20 == 0:
            print(self.call_count)

        if self.call_count >= 100:
            return False

        return True


model = PPO('MlpPolicy', MyWrapper(), verbose=0)

model.learn(8000, callback=SimpleCallback())

20
40
60
80
100


<stable_baselines3.ppo.ppo.PPO at 0x1d7121c3c40>

In [6]:
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3 import A2C
from stable_baselines3.common.evaluation import evaluate_policy



def test_callback(callback):

    #创建Monitor封装的环境,这会在训练过程中写出日志文件到models文件夹
    env = make_vec_env(MyWrapper, n_envs=2, monitor_dir='models')

    #等价写法
    # from stable_baselines3.common.monitor import Monitor
    # from stable_baselines3.common.vec_env import DummyVecEnv
    # env = Monitor(MyWrapper(), 'models')
    # env = DummyVecEnv([lambda: env])

    #训练
    model = A2C('MlpPolicy', env, verbose=0).learn(total_timesteps=1000,
                                                   callback=callback)

    #测试
    return evaluate_policy(model, MyWrapper(), n_eval_episodes=20)


#使用Monitor封装的环境训练一个模型,保存下日志
#只是为了测试load_results, ts2xy这两个函数
test_callback(None)

(40.7, 10.149384217774003)

In [7]:
from stable_baselines3.common.results_plotter import load_results, ts2xy

#加载日志,这里找的是models/*.monitor.csv
load_results('models')

Unnamed: 0,index,r,l,t
0,0,14.0,14,0.070162
1,0,45.0,45,0.152162
2,1,35.0,35,0.164675
3,1,17.0,17,0.200032
4,2,18.0,18,0.217546
5,3,12.0,12,0.250064
6,2,26.0,26,0.275986
7,4,43.0,43,0.363406
8,3,47.0,47,0.385153
9,5,16.0,16,0.392659


In [8]:
ts2xy(load_results('models'), 'timesteps')

(array([ 14,  59,  94, 111, 129, 141, 167, 210, 257, 273, 300, 336, 402,
        421, 443, 479, 513, 610, 647, 663, 678, 721, 778, 840, 865, 884,
        919, 956], dtype=int64),
 array([14., 45., 35., 17., 18., 12., 26., 43., 47., 16., 27., 36., 66.,
        19., 22., 36., 34., 97., 37., 16., 15., 43., 57., 62., 25., 19.,
        35., 37.]))

In [10]:
#保存最优模型
class SaveOnBestTrainingRewardCallback(BaseCallback):

    def __init__(self):
        super().__init__(verbose=0)

        self.best = -float('inf')

    def _on_step(self):
        #self.n_calls是个从1开始的数
        if self.n_calls % 1000 != 0:
            return True

        #读取日志
        x, y = ts2xy(load_results('models'), 'timesteps')

        #求最后100个reward的均值
        mean_reward = sum(y[-100:]) / len(y[-100:])

        print(self.num_timesteps, self.best, mean_reward)

        #判断保存
        if mean_reward > self.best:
            self.best = mean_reward
            print('save', x[-1])
            self.model.save('models/best_model')

        return True


test_callback(SaveOnBestTrainingRewardCallback())

(9.1, 0.7000000000000001)

In [11]:
#可以打印或者画图的callback
class PlottingCallback(BaseCallback):

    def __init__(self, verbose=0):
        super().__init__(verbose=0)

    def _on_step(self) -> bool:
        if self.n_calls % 1000 != 0:
            return True

        x, y = ts2xy(load_results('models'), 'timesteps')
        print(self.n_calls)
        print('x=', x)
        print('y=', y)

        return True


test_callback(PlottingCallback())

(33.2, 6.786751800382861)

In [16]:
from tqdm.auto import tqdm


#更新进度条的callback
class ProgressBarCallback(BaseCallback):

    def __init__(self):
        super().__init__()
        self.pbar = tqdm(total=5000)

    def _on_step(self):
        self.pbar.update(1)

    def _on_training_end(self) -> None:
        self.pbar.close()


test_callback(ProgressBarCallback())

  0%|          | 1/5000 [00:00<00:48, 102.98it/s]


(9.75, 0.6224949798994366)

In [17]:
#同时使用多个callback
test_callback([PlottingCallback(), ProgressBarCallback()])

  0%|          | 1/5000 [00:00<00:49, 100.66it/s]


(10.8, 2.2494443758403984)