In [1]:
from nes_py.wrappers import JoypadSpace
import gym_super_mario_bros
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT
import time
from matplotlib import pyplot as plt
from gym.wrappers import GrayScaleObservation
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.vec_env import VecFrameStack
import os
from stable_baselines3 import PPO

from stable_baselines3.common.results_plotter import load_results, ts2xy
import numpy as np
from stable_baselines3.common.callbacks import BaseCallback

In [2]:
# Directory handed to the Monitor wrapper; the training callback later reads
# the monitor results back from this same folder.
monitor_dir = r'./monitor_log/'
os.makedirs(monitor_dir, exist_ok=True)

# Build the environment one wrapper at a time, giving each stage its own name
# so the wrapping order is explicit and re-running a line is unambiguous.
mario_env = JoypadSpace(
    gym_super_mario_bros.make('SuperMarioBros-v0'),
    SIMPLE_MOVEMENT,
)
monitored_env = Monitor(mario_env, monitor_dir)
gray_env = GrayScaleObservation(monitored_env, keep_dim=True)

# DummyVecEnv takes a list of zero-argument factories; the factory simply
# hands back the already-built single environment.
vec_env = DummyVecEnv([lambda: gray_env])

# Final environment used by the model: observations stacked over 4 frames.
env = VecFrameStack(vec_env, 4, channels_order='last')

In [3]:
# Tunable training settings, kept as module-level names so they can be
# inspected or changed in one place.
tensorboard_log = r'./tensorboard_log/'
learning_rate = 1e-6
n_steps = 2048

# Keyword arguments forwarded to PPO alongside the policy name and the env.
ppo_settings = dict(
    verbose=1,
    tensorboard_log=tensorboard_log,
    learning_rate=learning_rate,
    n_steps=n_steps,
)
model = PPO("CnnPolicy", env, **ppo_settings)

Using cuda device
Wrapping the env in a VecTransposeImage.


In [4]:
class SaveOnBestTrainingRewardCallback(BaseCallback):
    """
    Callback that checkpoints the model every ``check_freq`` steps and keeps a
    separate copy of the best model seen so far, judged by training reward
    (in practice, we recommend using ``EvalCallback``).

    On every check the callback:

    * saves the current model to ``<log_dir>/latest_model``;
    * loads the ``Monitor`` results from ``log_dir`` and, when the mean reward
      over the last 100 episodes exceeds the best mean recorded so far, saves
      the model to ``<log_dir>/best_model``.

    The two paths are deliberately different: writing the periodic checkpoint
    to the best-model path would overwrite the best model with whatever the
    current weights happen to be.

    :param check_freq: (int) Number of callback calls between two checks.
      Must be a positive number.
    :param log_dir: (str) Path to the folder where the model will be saved.
      It must contain the file created by the ``Monitor`` wrapper.
    :param verbose: (int) When greater than 0, a progress line is printed on
      every check.
    :raises ValueError: if ``check_freq`` is not positive.
    """

    def __init__(self, check_freq, log_dir, verbose=1):
        super(SaveOnBestTrainingRewardCallback, self).__init__(verbose)
        if check_freq <= 0:
            # Fail early: a zero frequency would otherwise surface as a
            # ZeroDivisionError in the middle of training.
            raise ValueError(
                "check_freq must be a positive number, got {!r}".format(check_freq))
        self.check_freq = check_freq
        self.log_dir = log_dir
        # Written only when the mean training reward improves.
        self.save_path = os.path.join(log_dir, 'best_model')
        # Written on every check, regardless of reward.
        self.latest_path = os.path.join(log_dir, 'latest_model')
        self.best_mean_reward = -np.inf

    def _init_callback(self):
        """Create ``save_path`` (and therefore ``log_dir``) as directories if missing."""
        os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self):
        """
        Run the periodic checkpoint / best-model check.

        :return: (bool) Always ``True``.
          NOTE(review): presumably this tells the trainer to keep going --
          confirm against the ``BaseCallback`` contract.
        """
        # Only act on every ``check_freq``-th call.
        if self.n_calls % self.check_freq != 0:
            return True

        if self.verbose > 0:
            print('self.n_calls: ', self.n_calls)

        # Rolling checkpoint; kept apart from the best model on purpose.
        self.model.save(self.latest_path)

        # Retrieve the training reward recorded so far.
        x, y = ts2xy(load_results(self.log_dir), 'timesteps')
        if len(x) == 0:
            # Nothing recorded yet, so there is no reward to compare.
            return True

        # Mean training reward over the last 100 episodes.
        mean_reward = np.mean(y[-100:])
        if mean_reward > self.best_mean_reward:
            self.best_mean_reward = mean_reward
            self.model.save(self.save_path)

        return True

In [5]:
# The callback reads the monitor results from, and saves models into, the
# same folder the Monitor wrapper was pointed at.
log_dir = monitor_dir
callback1 = SaveOnBestTrainingRewardCallback(check_freq=1000, log_dir=log_dir)

# Left as the cell's last expression so the returned model object is displayed.
model.learn(callback=callback1, total_timesteps=100)

Logging to ./tensorboard_log/PPO_4
self.n_calls:  10
self.n_calls:  20
self.n_calls:  30
self.n_calls:  40
self.n_calls:  50
self.n_calls:  60
self.n_calls:  70
self.n_calls:  80
self.n_calls:  90
self.n_calls:  100
self.n_calls:  110
self.n_calls:  120
----------------------------
| time/              |     |
|    fps             | 6   |
|    iterations      | 1   |
|    time_elapsed    | 20  |
|    total_timesteps | 128 |
----------------------------


<stable_baselines3.ppo.ppo.PPO at 0x22a09446e48>