<a href="https://colab.research.google.com/github/unknown-yuser/ppo_super-mario-bros/blob/main/ppo_super_mario.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%shell

# Show GPU information
nvidia-smi

# Set up the Super Mario Bros. environment
pip install -q nes-py gym-super-mario-bros

# Install the reinforcement-learning library
pip install -q stable-baselines3[extra]

# Packages used for rendering/recording the training results
apt-get install -qq ffmpeg freeglut3-dev xvfb
pip install -q pyglet pyvirtualdisplay

In [None]:
#@title マリオブラザーズ環境

# World and stage numbers; both are substituted into the environment id below.
WORLD=1#@param {type:'integer'}
STAGE=1#@param {type:'integer'}
# Key into ACTION_TYPE_MAP selecting which action set the agent may use.
ACTION_TYPE_STR = "SIMPLE_MOVEMENT" #@param ["RIGHT_ONLY", "SIMPLE_MOVEMENT", "COMPLEX_MOVEMENT"]

# Environment id later passed to gym_super_mario_bros.make().
MARIO_ENV=f'SuperMarioBros-{WORLD}-{STAGE}-v0'

from gym_super_mario_bros.actions import RIGHT_ONLY, SIMPLE_MOVEMENT, COMPLEX_MOVEMENT

# Maps each ACTION_TYPE_STR choice to the action list of the same name
# imported from gym_super_mario_bros.actions.
ACTION_TYPE_MAP = {
    'RIGHT_ONLY': RIGHT_ONLY,
    'SIMPLE_MOVEMENT': SIMPLE_MOVEMENT,
    'COMPLEX_MOVEMENT': COMPLEX_MOVEMENT
}

# Accessor for the currently selected action set
def mario_action():
    """Return the ACTION_TYPE_MAP entry stored under ACTION_TYPE_STR."""
    selected_actions = ACTION_TYPE_MAP[ACTION_TYPE_STR]
    return selected_actions

In [None]:
#@title 学習/テスト 設定

OUTPUT_DIR='super_mario_bros/'#@param {type:'string'}
VIDEO_DIR='video/' #@param {type:'string'}
VIDEO_NAME_PREFIX="mario_play" #@param {type:'string'}

import os
# Create the output directory once and expose its path as an environment variable.
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.environ['OUTPUT_DIR'] = OUTPUT_DIR

# End points of the learning-rate schedule handed to PPO.
FIRST_LEARNING_RATE = 0.0002 #@param {type:"number"}
LAST_LEARNING_RATE = 0.00003 #@param {type:"number"}
# Passed to PPO as gamma and gae_lambda respectively.
GAMMA = 0.95 #@param {type:"slider", min:0.8, max:1, step:0.01}
LAMBDA = 0.84 #@param {type:"slider", min:0.8, max:1, step:0.01}
# Passed to StopTrainingOnRewardThreshold as reward_threshold.
REWARD_THRESHOLD = 3000 #@param {type:"number"}

# total_timesteps for player.learn() and eval_freq for EvalCallback.
MAX_STEPS = 600000 #@param {type:"integer"}
EVAL_INTERVAL = 40000 #@param {type:"integer"}

def best_model_path():
    """Return the ``best_model`` path located inside OUTPUT_DIR."""
    model_dir = os.path.join(OUTPUT_DIR, "best_model")
    return model_dir

In [None]:
import gym
from nes_py.wrappers import JoypadSpace
import gym_super_mario_bros
from stable_baselines3.common.atari_wrappers import MaxAndSkipEnv, WarpFrame
from stable_baselines3.common.monitor import Monitor

def mario_env(train:bool = False) -> gym.Env:
    """Build the wrapped Super Mario Bros environment.

    The environment named by ``MARIO_ENV`` is wrapped, in order, with
    ``JoypadSpace`` (using the action set from ``mario_action()``),
    ``MaxAndSkipEnv`` (skip=4) and ``WarpFrame`` (84x84).

    :param train: when True, additionally wrap the result in ``Monitor``
        with ``OUTPUT_DIR`` as its second argument
    :return: the wrapped environment
    """
    env = gym_super_mario_bros.make(MARIO_ENV)
    env = JoypadSpace(env, mario_action())
    env = MaxAndSkipEnv(env, skip=4)
    env = WarpFrame(env, width=84, height=84)
    if not train:
        return env
    return Monitor(env, OUTPUT_DIR)

In [None]:
from stable_baselines3 import PPO
from typing import Callable

def learning_rate_schedule(first_value: float, last_value: float) -> Callable[[float], float]:
    """Build a linear learning-rate schedule.

    :param first_value: rate produced when ``progress_remaining`` is 1
    :param last_value: rate produced when ``progress_remaining`` is 0
    :return: function mapping ``progress_remaining`` to the interpolated rate
    """
    def schedule(progress_remaining: float):
        # Linear interpolation: last_value at 0, first_value at 1.
        span = first_value - last_value
        return last_value + span * progress_remaining
    return schedule

# Create the PPO agent with the "CnnPolicy" policy on the monitored training environment.
player = PPO(
    "CnnPolicy",
    mario_env(train=True), 
    # Learning rate interpolated linearly between FIRST_LEARNING_RATE and LAST_LEARNING_RATE.
    learning_rate=learning_rate_schedule(FIRST_LEARNING_RATE, LAST_LEARNING_RATE),
    batch_size=32,
    gamma=GAMMA,
    gae_lambda=LAMBDA,
    # NOTE(review): confirm the installed stable-baselines3 release still accepts create_eval_env.
    create_eval_env=True,
    ent_coef=0.02,
    vf_coef=1.0,
    # TensorBoard logs go under OUTPUT_DIR.
    tensorboard_log=OUTPUT_DIR,
    verbose=1)

In [None]:
?PPO

In [None]:
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold

# Callback configured with REWARD_THRESHOLD; EvalCallback invokes it through callback_on_new_best.
callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=REWARD_THRESHOLD, verbose=1)
# NOTE(review): the evaluation env is built with train=True, so it is wrapped in Monitor with the
# same OUTPUT_DIR as the training env — confirm the two monitors do not overwrite each other's log.
eval_callback = EvalCallback(
    mario_env(train=True),
    callback_on_new_best=callback_on_best,
    eval_freq=EVAL_INTERVAL,
    best_model_save_path=best_model_path(),
    deterministic=False,
    verbose=1)

# Train for up to MAX_STEPS timesteps with the evaluation callback attached.
%time player.learn(total_timesteps=MAX_STEPS, callback=eval_callback)

In [None]:
%load_ext tensorboard
%tensorboard --logdir $OUTPUT_DIR

In [None]:
def load_best_player():
    """Load the PPO model stored as ``best_model`` under ``best_model_path()``."""
    checkpoint = os.path.join(best_model_path(), 'best_model')
    return PPO.load(checkpoint)

In [None]:
from gym.wrappers import RecordVideo

def record_video(model: "BaseAlgorithm", name_prefix:str, video_folder: str):
    """
    Play the model in a RecordVideo-wrapped environment until 10 episodes have ended.

    :param model: (RL model) agent whose ``predict`` picks the action for each step
    :param name_prefix: (str) passed to RecordVideo as ``name_prefix``
    :param video_folder: (str) passed to RecordVideo as ``video_folder``
    """
    # The annotation above is a string because BaseAlgorithm is not imported in this
    # notebook; an unquoted name would raise NameError when the function is defined.
    eval_env = RecordVideo(
        mario_env(),
        video_length=5000,
        video_folder=video_folder,
        name_prefix=name_prefix
    )

    max_cnt_terminal_reached = 10
    cnt_terminal_reached = 0
    try:
        obs = eval_env.reset()
        while cnt_terminal_reached < max_cnt_terminal_reached:
            action, _ = model.predict(obs)
            obs, _, done, _ = eval_env.step(action)
            if done:
                cnt_terminal_reached += 1
                # Keep the observation of the new episode; otherwise the next
                # prediction would be made from the stale terminal observation.
                obs = eval_env.reset()
    finally:
        # Close the recorder even if playback fails part-way through.
        eval_env.close_video_recorder()

# Record the best saved model, using VIDEO_NAME_PREFIX for file names and VIDEO_DIR as the folder.
record_video(load_best_player(), VIDEO_NAME_PREFIX, VIDEO_DIR)

In [None]:
import base64
from pathlib import Path
from IPython import display as ipdisplay
from pyvirtualdisplay import Display

# from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv

# Start a non-visible 400x300 virtual display (presumably needed for rendering on a headless host).
display = Display(visible=0, size=(400, 300))
display.start()

def show_video(name_prefix:str, video_folder: str):
    """
    Display every matching ``.mp4`` file inline as a base64-embedded HTML video.

    Reference from https://github.com/eleurent/highway-env

    :param name_prefix: (str) Filter the videos, showing only those whose file name starts with this prefix
    :param video_folder: (str) Path to the folder containing videos
    """
    video_tags = []
    # Path.glob does not guarantee any ordering; sort so the videos always
    # appear in the same order.
    for mp4 in sorted(Path(video_folder).glob('{}*.mp4'.format(name_prefix))):
        video_b64 = base64.b64encode(mp4.read_bytes())
        video_tags.append('''<video alt="{}" autoplay
        loop controls style="height: 400px;">
        <source src="data:video/mp4;base64,{}" type="video/mp4" />
        </video>'''.format(mp4, video_b64.decode('ascii')))
    ipdisplay.display(ipdisplay.HTML(data="<br>".join(video_tags)))

# Show the videos in VIDEO_DIR whose names start with VIDEO_NAME_PREFIX.
show_video(VIDEO_NAME_PREFIX, VIDEO_DIR)