In [1]:
!pip install gym_super_mario_bros==7.4.0 nes_py

!pip install torch torchvision torchaudio

!pip install stable-baselines3[extra]

%load_ext tensorboard

Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable


In [2]:
import gym_super_mario_bros

from nes_py.wrappers import JoypadSpace

from gym_super_mario_bros.actions import SIMPLE_MOVEMENT

import os

from stable_baselines3 import PPO
from stable_baselines3 import DQN
from stable_baselines3 import A2C

from stable_baselines3.common.callbacks import BaseCallback

In [3]:
# Base environment setup
env = gym_super_mario_bros.make('SuperMarioBros2-v0')
# Simplify controls
env = JoypadSpace(env, SIMPLE_MOVEMENT)

In [4]:
class TrainAndLoggingCallback(BaseCallback):

    def __init__(self, check_freq, save_path, verbose=1):
        super(TrainAndLoggingCallback, self).__init__(verbose)
        self.check_freq = check_freq
        self.save_path = save_path

    def _init_callback(self):
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)
    
    def _on_step(self):
        if self.n_calls % self.check_freq == 0:
            model_path = os.path.join(self.save_path, 'best_model_{}'.format(self.n_calls))
            self.model.save(model_path)

        return True

In [5]:
CHECKPOINT_DIR = './train/'
LOG_DIR = './logs/'

In [6]:
%tensorboard --logdir logs

Reusing TensorBoard on port 6006 (pid 11692), started 0:05:00 ago. (Use '!kill 11692' to kill it.)

In [7]:
callback = TrainAndLoggingCallback(check_freq=10000, save_path=CHECKPOINT_DIR)

In [8]:
model = DQN('CnnPolicy', env, verbose=1, tensorboard_log=LOG_DIR, learning_rate=0.001, buffer_size=256, seed=20)

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env in a VecTransposeImage.


In [9]:
import time

In [10]:
start_time = time.time()
model.learn(total_timesteps=1000000, callback=callback)
training_time = time.time() - start_time

Logging to ./logs/DQN_1
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.71e+04 |
|    ep_rew_mean      | 554      |
|    exploration_rate | 0.351    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 188      |
|    time_elapsed     | 363      |
|    total_timesteps  | 68334    |
| train/              |          |
|    learning_rate    | 0.001    |
|    loss             | 0.0318   |
|    n_updates        | 4583     |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.2e+04  |
|    ep_rew_mean      | 1.05e+03 |
|    exploration_rate | 0.0857   |
| time/               |          |
|    episodes         | 8        |
|    fps              | 119      |
|    time_elapsed     | 806      |
|    total_timesteps  | 96239    |
| train/              |          |
|    learning_rate    | 0.001    |
|    loss             | 0.0309 

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.26e+03 |
|    ep_rew_mean      | 1.36e+03 |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 68       |
|    fps              | 77       |
|    time_elapsed     | 2854     |
|    total_timesteps  | 221990   |
| train/              |          |
|    learning_rate    | 0.001    |
|    loss             | 1.19     |
|    n_updates        | 42997    |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.18e+03 |
|    ep_rew_mean      | 1.39e+03 |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 72       |
|    fps              | 77       |
|    time_elapsed     | 2968     |
|    total_timesteps  | 229063   |
| train/              |          |
|    learning_rate    | 0.001    |
|    loss             | 0.275    |
|    n_updates      

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.69e+03 |
|    ep_rew_mean      | 1.34e+03 |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 132      |
|    fps              | 71       |
|    time_elapsed     | 4511     |
|    total_timesteps  | 323799   |
| train/              |          |
|    learning_rate    | 0.001    |
|    loss             | 0.359    |
|    n_updates        | 68449    |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.68e+03 |
|    ep_rew_mean      | 1.34e+03 |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 136      |
|    fps              | 71       |
|    time_elapsed     | 4591     |
|    total_timesteps  | 328763   |
| train/              |          |
|    learning_rate    | 0.001    |
|    loss             | 0.34     |
|    n_updates      

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.63e+03 |
|    ep_rew_mean      | 1.3e+03  |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 196      |
|    fps              | 68       |
|    time_elapsed     | 6254     |
|    total_timesteps  | 430719   |
| train/              |          |
|    learning_rate    | 0.001    |
|    loss             | 0.341    |
|    n_updates        | 95179    |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.66e+03 |
|    ep_rew_mean      | 1.3e+03  |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 200      |
|    fps              | 68       |
|    time_elapsed     | 6392     |
|    total_timesteps  | 439290   |
| train/              |          |
|    learning_rate    | 0.001    |
|    loss             | 0.249    |
|    n_updates      

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.79e+03 |
|    ep_rew_mean      | 1.41e+03 |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 260      |
|    fps              | 67       |
|    time_elapsed     | 8174     |
|    total_timesteps  | 548545   |
| train/              |          |
|    learning_rate    | 0.001    |
|    loss             | 0.332    |
|    n_updates        | 124636   |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.8e+03  |
|    ep_rew_mean      | 1.43e+03 |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 264      |
|    fps              | 67       |
|    time_elapsed     | 8321     |
|    total_timesteps  | 557599   |
| train/              |          |
|    learning_rate    | 0.001    |
|    loss             | 0.0874   |
|    n_updates      

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.57e+03 |
|    ep_rew_mean      | 1.37e+03 |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 324      |
|    fps              | 66       |
|    time_elapsed     | 9656     |
|    total_timesteps  | 639693   |
| train/              |          |
|    learning_rate    | 0.001    |
|    loss             | 0.535    |
|    n_updates        | 147423   |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.57e+03 |
|    ep_rew_mean      | 1.38e+03 |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 328      |
|    fps              | 66       |
|    time_elapsed     | 9765     |
|    total_timesteps  | 646385   |
| train/              |          |
|    learning_rate    | 0.001    |
|    loss             | 0.185    |
|    n_updates      

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.59e+03 |
|    ep_rew_mean      | 1.33e+03 |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 388      |
|    fps              | 65       |
|    time_elapsed     | 11480    |
|    total_timesteps  | 751676   |
| train/              |          |
|    learning_rate    | 0.001    |
|    loss             | 0.199    |
|    n_updates        | 175418   |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.6e+03  |
|    ep_rew_mean      | 1.34e+03 |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 392      |
|    fps              | 65       |
|    time_elapsed     | 11557    |
|    total_timesteps  | 756514   |
| train/              |          |
|    learning_rate    | 0.001    |
|    loss             | 0.903    |
|    n_updates      

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.76e+03 |
|    ep_rew_mean      | 1.42e+03 |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 452      |
|    fps              | 64       |
|    time_elapsed     | 13319    |
|    total_timesteps  | 864885   |
| train/              |          |
|    learning_rate    | 0.001    |
|    loss             | 0.41     |
|    n_updates        | 203721   |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.8e+03  |
|    ep_rew_mean      | 1.45e+03 |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 456      |
|    fps              | 64       |
|    time_elapsed     | 13451    |
|    total_timesteps  | 872943   |
| train/              |          |
|    learning_rate    | 0.001    |
|    loss             | 0.292    |
|    n_updates      

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.85e+03 |
|    ep_rew_mean      | 1.46e+03 |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 516      |
|    fps              | 64       |
|    time_elapsed     | 15324    |
|    total_timesteps  | 987955   |
| train/              |          |
|    learning_rate    | 0.001    |
|    loss             | 3.12     |
|    n_updates        | 234488   |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.88e+03 |
|    ep_rew_mean      | 1.47e+03 |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 520      |
|    fps              | 64       |
|    time_elapsed     | 15470    |
|    total_timesteps  | 996927   |
| train/              |          |
|    learning_rate    | 0.001    |
|    loss             | 0.378    |
|    n_updates      

In [11]:
import numpy as np
import tensorflow as tf

In [12]:
obs = env.reset()
start_time = time.time()
total_rewards = 0
game_scores = []
steps_per_episode = []
episode_rewards = []
while True:
    action, states = model.predict(obs.copy())
    action = action.item()  # convert NumPy array to scalar integer
    obs, rewards, done, info = env.step((action))
    total_rewards += rewards
    episode_rewards.append(rewards)
    if done:
        game_scores.append(info['score'])
        steps_per_episode.append(len(episode_rewards))
        episode_rewards = []
        obs = env.reset()
        if len(game_scores) == 10:
            break
test_time = time.time() - start_time

In [13]:
print(f'Average steps per episode for seed 10: {sum(steps_per_episode)/len(steps_per_episode)}')
print(f'Average mean reward for seed 10: {total_rewards/len(game_scores)}')
print(f'Average game score for seed 10: {sum(game_scores)/len(game_scores)}')
print(f'Training time for seed 10: {training_time:.2f} seconds')
print(f'Test time for seed 10: {test_time:.2f} seconds')

Average steps per episode for seed 10: 10106.1
Average mean reward for seed 10: 966.7
Average game score for seed 10: 600.0
Training time for seed 10: 15521.58 seconds
Test time for seed 10: 365.16 seconds


In [14]:
model = DQN.load('./train/best_model_50000/')