## 这一节使用wrapper对gym环境进行处理

In [8]:
#先引入下相关包
import gym_super_mario_bros
from gym.spaces import Box
from gym import Wrapper
from nes_py.wrappers import JoypadSpace#BinarySpaceToDiscreteSpaceEnv
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT, COMPLEX_MOVEMENT, RIGHT_ONLY
import cv2
import numpy as np
import subprocess as sp

In [9]:
# Convert a raw RGB frame into a normalized grayscale image with OpenCV (cv2).
def process_frame(frame):
    """Turn an RGB frame into a (1, 84, 84) grayscale array divided by 255.

    The frame is converted with ``cv2.COLOR_RGB2GRAY`` and rescaled (not
    cropped) to 84x84. A leading channel axis is added. Assuming uint8 pixel
    values, the result lies in [0, 1].

    A ``None`` frame yields an all-zero array of shape (1, 84, 84).
    """
    if frame is None:
        return np.zeros((1, 84, 84))
    gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
    resized = cv2.resize(gray, (84, 84))  # rescale the whole image to 84x84
    return resized[np.newaxis, :, :] / 255.

In [10]:
#写一个继承Wrapper的包装类，一定注意在构造方法中调用父类的构造函数
#1.包装类copy了原环境的以下信息,在新的类初始化时要进行相应的修改,虽然这些量
#  在实现新的环境逻辑时不一定用得到;此外在包装类初始化时还要定义一些别的要
#  使用的变量
'''self.env = env
   self.action_space = self.env.action_space
   self.observation_space = self.env.observation_space
   self.reward_range = self.env.reward_range
   self.metadata = self.env.metadata'''
#2.新的环境逻辑主要通过重写step和reset方法实现（本例只需重写这两个方法；Wrapper也允许按需重写render、close等其他方法）

class CustomReward(Wrapper):
    """Wrapper that preprocesses observations and reshapes the reward.

    1. Observations: each RGB frame goes through ``process_frame``. It is
       converted to grayscale, resized (not cropped) to 84x84 and divided by
       255, giving arrays of shape (1, 84, 84).
    2. Reward shaping, added on top of the base environment reward. According
       to the original notes, the base reward covers progress toward the goal
       and penalties for elapsed time and dying.
       a) ``+ (score increase) / 40``: a small bonus for in-game score (coins,
          power-ups, ...). It is deliberately down-weighted so it does not
          distract from finishing the level.
       b) ``+50`` when an episode ends with the flag reached, ``-50`` when it
          ends any other way. This pushes the agent to reach the goal faster.
       The shaped total is then divided by 10.

    These rules are only examples; adapt them to your own needs. For instance,
    you could add a buffer that teaches the agent to back up when it gets
    stuck on a bad route.
    """

    def __init__(self, env=None):
        # ``env=None`` is kept for interface compatibility only; a real
        # environment must be passed for the wrapper to work.
        super().__init__(env)
        # process_frame divides by 255, so observation values lie in [0, 1],
        # not [0, 255] as previously declared.
        self.observation_space = Box(low=0.0, high=1.0, shape=(1, 84, 84))
        # Last seen in-game score, so that only the score *increase* is rewarded.
        self.curr_score = 0

    def step(self, action):
        """Step the wrapped env, preprocess the frame and shape the reward.

        Returns the 4-tuple ``(state, reward, done, info)``. ``state`` is the
        processed (1, 84, 84) frame and ``reward`` is the shaped reward / 10.
        """
        state, reward, done, info = self.env.step(action)
        state = process_frame(state)
        # Down-weighted bonus for score gained since the previous step.
        reward += (info["score"] - self.curr_score) / 40.
        self.curr_score = info["score"]
        if done:
            # Amplify the terminal outcome: reaching the flag vs. anything else.
            reward += 50 if info["flag_get"] else -50
        return state, reward / 10., done, info

    def reset(self, **kwargs):
        """Reset the score tracker and return the processed initial frame.

        Keyword arguments (e.g. ``seed``/``options`` on newer gym versions)
        are forwarded to the wrapped environment's ``reset``.
        """
        self.curr_score = 0
        return process_frame(self.env.reset(**kwargs))

In [11]:
## 学习时并不需要对每一帧都做决策：同一个动作连续执行若干帧（默认4帧），并把得到的各帧沿通道维堆叠成一个观测
class CustomSkipFrame(Wrapper):
    """Repeat each action for ``skip`` frames and stack those frames.

    Assumes the wrapped environment returns single-channel 84x84 observations
    of shape (1, 84, 84), as CustomReward/process_frame produce. Observations
    returned by ``step``/``reset`` have shape (1, skip, 84, 84): a leading
    batch axis followed by ``skip`` stacked frames, cast to float32.
    """

    def __init__(self, env, skip=4) -> None:
        super().__init__(env)
        if skip < 1:
            raise ValueError("skip must be a positive integer, got {!r}".format(skip))
        # The channel count follows ``skip`` (it was hard-coded to 4).
        # ``shape[0]`` is read by callers as the number of input channels, so
        # the batch axis added in step/reset is intentionally left out here.
        # The bounds are kept as before; with CustomReward underneath, the
        # actual values lie in [0, 1].
        self.observation_space = Box(low=0, high=255, shape=(skip, 84, 84))
        self.skip = skip

    def step(self, action):
        """Apply ``action`` for up to ``skip`` frames.

        Returns ``(stacked_frames, summed_reward, done, info)``. If the
        episode ends early, no further env steps are taken; the last frame is
        repeated to fill the remaining slots.
        """
        total_reward = 0.0
        frames = []
        done = False
        # skip >= 1 and done starts False, so the first iteration always
        # steps and binds state/info before they are used.
        for _ in range(self.skip):
            if not done:
                state, reward, done, info = self.env.step(action)
                total_reward += reward
            frames.append(state)
        stacked = np.concatenate(frames, 0)[None, :, :, :]
        # Return the reward summed over all frames actually played; previously
        # only the last frame's reward was returned.
        return stacked.astype(np.float32), total_reward, done, info

    def reset(self, **kwargs):
        """Reset the wrapped env and stack ``skip`` copies of the first frame.

        Keyword arguments are forwarded to the wrapped environment's ``reset``.
        """
        state = self.env.reset(**kwargs)
        stacked = np.concatenate([state] * self.skip, 0)[None, :, :, :]
        return stacked.astype(np.float32)

In [12]:
#至此，我们完成了超级玛丽环境的自定义，封装如下：
def create_train_env(world, stage, action_type, output_path=None):
    """Build the wrapped Super Mario Bros training environment.

    Args:
        world: world number used in the env id.
        stage: stage number used in the env id.
        action_type: ``"right"`` -> RIGHT_ONLY, ``"simple"`` ->
            SIMPLE_MOVEMENT; any other value -> COMPLEX_MOVEMENT.
        output_path: accepted for compatibility but not used here.

    Returns:
        A tuple ``(env, num_channels, num_actions)``. ``env`` is the fully
        wrapped environment, ``num_channels`` is ``observation_space.shape[0]``
        of the outermost wrapper, and ``num_actions`` is the size of the chosen
        action set.
    """
    actions = (RIGHT_ONLY if action_type == "right"
               else SIMPLE_MOVEMENT if action_type == "simple"
               else COMPLEX_MOVEMENT)
    env_id = "SuperMarioBros-{w}-{s}-v0".format(w=world, s=stage)
    # Innermost to outermost: raw env -> discrete joypad actions ->
    # frame preprocessing + reward shaping -> frame skipping/stacking.
    wrapped = CustomSkipFrame(
        CustomReward(JoypadSpace(gym_super_mario_bros.make(env_id), actions))
    )
    num_channels = wrapped.observation_space.shape[0]
    return wrapped, num_channels, len(actions)

In [13]:
# Quick smoke test: build World 1-1 with the simple action set and show the
# returned (env, num_channels, num_actions) tuple.
custom_env = create_train_env(world=1, stage=1, action_type="simple")
print(custom_env)


(<CustomSkipFrame<CustomReward<JoypadSpace<TimeLimit<SuperMarioBrosEnv<SuperMarioBros-1-1-v0>>>>>>, 4, 7)
