# Gym Wrappers
了解如何使用 Gym Wrappers：它们可以实现监控、标准化、限制步数、特征增强等功能。

此外还将介绍模型的加载和保存函数，以及如何读取输出的文件以便导出。

In [None]:
!pip install swig
!pip install stable-baselines3

# 导入gym和RL算法库

In [4]:
import gymnasium as gym
from stable_baselines3 import A2C, SAC, PPO, TD3

# 保存和加载模型

In [6]:
import os

save_dir = "/tmp/gym/"
os.makedirs(save_dir, exist_ok=True)
# Build the model path once with os.path.join: save_dir already ends with a
# slash, so string concatenation would yield "/tmp/gym//PPO_tutorial", and a
# single variable keeps the save and load locations from drifting apart.
model_path = os.path.join(save_dir, "PPO_tutorial")

# Train a PPO agent on Pendulum for 8,000 timesteps and write it to disk.
model = PPO("MlpPolicy", "Pendulum-v1", verbose=0).learn(8_000)
model.save(model_path)

# Sample one observation so the same input is used before and after reloading.
obs = model.env.observation_space.sample()

print("pre saved", model.predict(obs, deterministic=True))

# Drop the in-memory model so the prediction below comes from the reloaded copy.
del model

loaded_model = PPO.load(model_path)
print("loaded", loaded_model.predict(obs, deterministic=True))

pre saved (array([-0.04750421], dtype=float32), None)
loaded (array([-0.04750421], dtype=float32), None)


# Gym 和 VecEnv 包装器

## 自定义一个简单的包装器

In [7]:
class CustomWrapper(gym.Wrapper):
    """
    Minimal pass-through wrapper: reset and step are forwarded to the wrapped
    environment and their results are returned unchanged.

    :param env: (gym.Env) Gym environment that will be wrapped
    """

    def __init__(self, env):
        # Call the parent constructor so that self.env can be accessed later
        super().__init__(env)

    def reset(self, **kwargs):
        """
        Reset the wrapped environment.

        :param kwargs: keyword arguments forwarded as-is to the wrapped env's reset
        :return: the (observation, info) pair returned by the wrapped environment
        """
        observation, reset_info = self.env.reset(**kwargs)
        return observation, reset_info

    def step(self, action):
        """
        Forward one action to the wrapped environment.

        :param action: ([float] or int) Action taken by the agent
        :return: (np.ndarray, float, bool, bool, dict) observation, reward,
            terminated (episode reached a final state), truncated (episode was
            cut short artificially, e.g. by a step limit), additional information
        """
        transition = self.env.step(action)
        observation, reward, terminated, truncated, info = transition
        return observation, reward, terminated, truncated, info

## 限制episode长度
包装器的一个实际用例是限制每个 episode 的步数：达到步数上限时，需要覆盖环境返回的结束信号（将 truncated 置为 True）。同时，在 info 字典中传递这一信息也是一种良好的做法。

In [9]:
class TimeLimitWrapper(gym.Wrapper):
    """
    Cut episodes short once a fixed number of steps has been taken.

    :param env: (gym.Env) Gym environment that will be wrapped
    :param max_steps: (int) Max number of steps per episode
    """

    def __init__(self, env, max_steps=100):
        super().__init__(env)
        self.max_steps = max_steps
        # Number of steps taken since the last reset
        self.current_step = 0

    def reset(self, **kwargs):
        """
        Reset the step counter, then reset the wrapped environment.

        :param kwargs: keyword arguments forwarded as-is to the wrapped env's reset
        :return: whatever the wrapped environment's reset returns
        """
        self.current_step = 0
        return self.env.reset(**kwargs)

    def step(self, action):
        """
        Step the wrapped environment and enforce the step limit.

        :param action: ([float] or int) Action taken by the agent
        :return: (np.ndarray, float, bool, bool, dict) observation, reward,
            terminated, truncated, info. Once ``max_steps`` steps have been
            taken in the current episode, ``truncated`` is forced to True and
            ``info["time_limit_reached"]`` is set to True.
        """
        self.current_step += 1
        obs, reward, terminated, truncated, info = self.env.step(action)
        if self.current_step >= self.max_steps:
            # Override the truncation signal coming from the wrapped env
            truncated = True
            # Also record the reason in the info dict so callers can tell a
            # time-limit cut-off apart from other causes of truncation
            info["time_limit_reached"] = True
        return obs, reward, terminated, truncated, info

测试包装器