In [1]:
import os
import time
import numpy as np
import cv2
import torch
from torch import nn
import gym
import imageio

from matplotlib import pyplot as plt

from nes_py.wrappers import JoypadSpace

import gym_super_mario_bros
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT
from gym.wrappers import GrayScaleObservation

from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3.common.vec_env import VecVideoRecorder
from stable_baselines3.common.results_plotter import load_results, ts2xy
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

In [2]:
# Root folder for monitor logs, model checkpoints and recorded videos.
# Set the RL_MARIO_BASE_PATH environment variable to relocate the project
# without editing the notebook; the default keeps the original location.
# The value must end with a path separator because other cells build
# sub-paths from it by plain string concatenation.
base_path = os.environ.get('RL_MARIO_BASE_PATH', 'C:\\Projects\\rl_mario\\')

In [3]:
class ResizeEnv(gym.ObservationWrapper):
    """Observation wrapper that rescales every frame to a square ``size`` x ``size`` image.

    The wrapped environment must expose a 3-dimensional observation space
    (height, width, channels); the channel count is carried over unchanged.
    """

    def __init__(self, env, size):
        """Wrap ``env`` and advertise a ``(size, size, channels)`` uint8 Box in [0, 255]."""
        super().__init__(env)
        # Unpacking enforces a 3-dimensional shape; only the channel count is reused.
        _, _, channels = env.observation_space.shape
        self.observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(size, size, channels),
            dtype=np.uint8,
        )

    def observation(self, frame):
        """Return ``frame`` resized (area interpolation) to the advertised height and width."""
        target_h, target_w = self.observation_space.shape[:2]
        # cv2.resize takes the destination size as (width, height).
        resized = cv2.resize(frame, (target_w, target_h), interpolation=cv2.INTER_AREA)
        # If the resized result comes back 2-D, restore the trailing channel axis.
        return resized[:, :, None] if resized.ndim == 2 else resized

In [4]:
class SkipFrame(gym.Wrapper):
    """Wrapper that repeats each action for ``skip`` consecutive frames.

    The rewards of the repeated frames are summed; the observation, done flag
    and info returned are those of the last frame that was actually stepped.
    """

    def __init__(self, env, skip):
        """Wrap ``env``.

        Args:
            env: environment to wrap.
            skip: number of frames each action is repeated for; must be >= 1.

        Raises:
            ValueError: if ``skip`` is smaller than 1.
        """
        # With skip < 1 the loop in step() would never run and obs/info would
        # be unbound, so reject the value up front with a clear message.
        if skip < 1:
            raise ValueError(f"skip must be >= 1, got {skip!r}")
        super().__init__(env)
        self._skip = skip

    def step(self, action):
        """Apply ``action`` up to ``skip`` times and return the accumulated result.

        Returns:
            ``(obs, total_reward, done, info)`` where ``total_reward`` is the sum
            of the per-frame rewards and the other values come from the last
            frame stepped.
        """
        total_reward = 0.0
        done = False
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            total_reward += reward
            # Stop repeating as soon as the episode ends.
            if done:
                break
        return obs, total_reward, done, info

In [5]:
# Directory that receives the Monitor episode statistics.
monitor_dir = os.path.join(base_path, 'monitor_log', '')
os.makedirs(monitor_dir, exist_ok=True)


def make_mario_env():
    """Build one preprocessed Super Mario Bros 1-1 environment.

    Wrapper order: restricted joypad actions -> Monitor logging ->
    4-frame action repeat -> grayscale (channel axis kept) -> 84x84 resize.
    """
    mario_env = gym_super_mario_bros.make('SuperMarioBros-1-1-v0')
    mario_env = JoypadSpace(mario_env, SIMPLE_MOVEMENT)
    mario_env = Monitor(mario_env, monitor_dir)
    mario_env = SkipFrame(mario_env, skip=4)
    mario_env = GrayScaleObservation(mario_env, keep_dim=True)
    return ResizeEnv(mario_env, size=84)


# A named factory is passed instead of `lambda: env`: the lambda captured the
# very name that was being rebound to the vectorized env on the same line.
env = DummyVecEnv([make_mario_env])
# Stack the 4 most recent frames along the last (channel) axis.
env = VecFrameStack(env, 4, channels_order='last')

  logger.warn(
  deprecation(
  deprecation(


In [6]:
# PPO hyperparameters; unpacked as keyword arguments when the model is created.
best_params = dict(
    n_steps=512,
    learning_rate=1e-4,
    batch_size=64,
    n_epochs=10,
    gamma=0.9,
    gae_lambda=1.0,
    ent_coef=0.01,
)

In [7]:
class MarioNet(BaseFeaturesExtractor):
    """Convolutional feature extractor used by the PPO policy.

    Four 3x3, stride-2, 32-filter convolutions (each followed by ReLU) are
    flattened and projected to ``features_dim`` by a linear layer with ReLU.
    """

    def __init__(self, observation_space: gym.spaces.Box, features_dim):
        """Build the network for ``observation_space``.

        Args:
            observation_space: image observation space; axis 0 of its shape is
                used as the number of input channels.
            features_dim: size of the feature vector produced by ``forward``.
        """
        super().__init__(observation_space, features_dim)

        # Same layer sequence as writing the four conv/ReLU pairs out by hand,
        # so module indices (and therefore state-dict keys) are unchanged.
        in_channels = observation_space.shape[0]
        conv_stack = []
        for _ in range(4):
            conv_stack.append(nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1))
            conv_stack.append(nn.ReLU())
            in_channels = 32
        conv_stack.append(nn.Flatten())
        self.cnn = nn.Sequential(*conv_stack)

        # Push one sampled observation through the conv stack to learn the
        # flattened size that the linear layer must accept.
        probe = torch.as_tensor(observation_space.sample()[None]).float()
        with torch.no_grad():
            n_flatten = self.cnn(probe).shape[1]

        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Map a batch of observations to ``features_dim``-sized feature vectors."""
        conv_features = self.cnn(observations)
        return self.linear(conv_features)

# Policy options: use MarioNet as the feature extractor with 512 output features.
policy_kwargs = {
    'features_extractor_class': MarioNet,
    'features_extractor_kwargs': {'features_dim': 512},
}

In [8]:
# Create the PPO agent on the preprocessed env with the custom feature
# extractor; the tuned hyperparameters are unpacked from best_params.
model = PPO(
    policy="CnnPolicy",
    env=env,
    device='cuda',
    policy_kwargs=policy_kwargs,
    verbose=1,
    #learning_rate=linear_schedule(3e-4),
    **best_params,
)

Using cuda device
Wrapping the env in a VecTransposeImage.


In [9]:
# Load the weights of the 500k-step checkpoint into the freshly built model.
# The path is derived from base_path so the project root is defined in one
# place (the old raw string also contained doubled backslash separators).
# NOTE(review): the saved output of this cell shows a `_generator_ctor()`
# exception message while loading — presumably a library version mismatch
# with the checkpoint; confirm the weights are actually applied.
model.set_parameters(base_path + 'models\\model_500000.zip')

Exception: _generator_ctor() takes from 0 to 1 positional arguments but 2 were given
  th_object = th.load(file_content, map_location=device)


In [10]:
# keys = ['high', 'low', 'bounded_above', 'bounded_below']
# setattr(env.observation_space, '_shape', (4,240,256))
# for k in keys:
#     new_attr = getattr(env.observation_space, k).reshape(4,240,256)
#     setattr(env.observation_space, k, new_attr)
# model = PPO.load(r'C:\\Projects\\rl_mario\\models\\model_2000000.zip', env=env, 
#     custom_objects = {'observation_space': env.observation_space, 'action_space': env.action_space})

In [11]:
# Reset the vectorized environment and keep the initial observation in `obs`,
# the variable the playback cells feed to model.predict().
obs = env.reset()

In [12]:
# video_length = 100
# video_folder = base_path + "video\\"

# env = VecVideoRecorder(env, 
#                            video_folder, 
#                            record_video_trigger=lambda x: x == 0, 
#                            video_length=video_length,
#                            name_prefix=f"mario")

# env.reset()

In [13]:
def save_gif(frames, i):
    """Write ``frames`` as an animated GIF named ``mario_<i>.gif`` in the video folder.

    Args:
        frames: sequence of image frames (anything ``np.array`` accepts).
        i: label inserted into the file name.

    The ``video`` folder under ``base_path`` is created if it does not exist.
    An empty ``frames`` sequence is skipped, since there is nothing to encode.
    """
    if len(frames) == 0:
        print(f'save_gif: no frames to write for mario_{i}.gif, skipping')
        return

    video_dir = base_path + "video\\"
    # The writer does not create missing folders itself.
    os.makedirs(video_dir, exist_ok=True)
    imageio.mimsave(
        video_dir + f"mario_{i}.gif",
        [np.array(img) for img in frames],
        fps=27
    )

In [14]:
# record
# Play the agent for at most `total_length` steps, saving one GIF per
# completed level (flag reached) and stopping after the third win.
frames = []
total_length = 1800
done = True
win = 0

for step in range(total_length):
    if done:
        # Keep the reset observation in `obs`: it is what model.predict()
        # reads (it used to be stored in an unused `state` variable).
        obs = env.reset()

    img = env.render(mode="rgb_array")
    frames.append(img.copy())
    action, _ = model.predict(obs)
    obs, _, done, info = env.step(action)

    if info[0]["flag_get"]:
        print('win!')
        win += 1

        print(len(frames))
        save_gif(frames, win)
        # Start a fresh clip for the next attempt.
        frames = []

        if win > 2:
            break

# Save whatever was recorded since the last win; after the final win the
# buffer is empty and there is nothing left to write.
if frames:
    save_gif(frames, 'f2')
# imageio.mimsave(
#     base_path + "video\\mario.gif", 
#     [np.array(img) for i, img in enumerate(frames) if i%2 == 0], fps=50)
# imageio.mimsave(
#     base_path + f"video\\mario_{win}.gif", 
#     [np.array(img) for i, img in enumerate(frames)], fps=27)

See here for more information: https://www.gymlibrary.ml/content/api/[0m
  deprecation(
  logger.warn(
  logger.deprecation(
  if not isinstance(done, (bool, np.bool8)):


In [27]:
# for i in range(12):
#     plt.figure(figsize=(16,16))
#     plt.subplot(2,6,i+1)
#     plt.imshow(frames[i])
# plt.show()

In [None]:
# for ima in paths:
#     img = cv2.imread(ima)
#     frames.append(img)
# imageio.mimsave(base_path + "video\\mario.gif", frames, 'GIF', duration=0.1)

In [14]:
# run
# Render the agent live until it reaches the flag. The loop has no step
# limit, so it may need to be interrupted manually; try/finally makes sure
# the environment is closed in that case too.
done = True
try:
    while True:
        if done:
            # Keep the reset observation in `obs`: it is what model.predict()
            # reads (it used to be stored in an unused `state` variable).
            obs = env.reset()
        action, _ = model.predict(obs)
        obs, rewards, done, info = env.step(action)
        env.render()

        if info[0]["flag_get"]:
            print('win!')
            break
finally:
    env.close()

  logger.deprecation(
  if not isinstance(done, (bool, np.bool8)):
See here for more information: https://www.gymlibrary.ml/content/api/[0m
  deprecation(
  logger.warn(
Your graphics drivers do not support OpenGL 2.0.
You may experience rendering issues or crashes.
Microsoft Corporation
GDI Generic
1.1.0


KeyboardInterrupt: 