In [None]:
!pip install gym_super_mario_bros==7.3.0

In [None]:
!pip install nes_py

In [2]:
import gym_super_mario_bros
from nes_py.wrappers import JoypadSpace
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT

In [None]:
#Setup game
env = gym_super_mario_bros.make('SuperMarioBros-v0')
env = JoypadSpace(env, SIMPLE_MOVEMENT)

In [None]:
done = True
for step in range(100000):
    if done:
        env.reset()
    state, reward, done, info = env.step(env.action_space.sample())
    env.render()
env.close()

# Preprocess Environment

In [None]:
!pip install stable-baselines3[extra]

In [3]:
from gym.wrappers import FrameStack, GrayScaleObservation
from stable_baselines3.common.vec_env import VecFrameStack, DummyVecEnv
from matplotlib import pyplot as plt

In [4]:
#Base environment
env = gym_super_mario_bros.make('SuperMarioBros-v0')
#Simplify controls
env = JoypadSpace(env, SIMPLE_MOVEMENT)
#Grayscale
env = GrayScaleObservation(env, keep_dim = True)
#Dummy environment
env = DummyVecEnv([lambda:env])
#Stack the frames
env = VecFrameStack(env, 4, channels_order='last')

In [4]:
state = env.reset()

In [5]:
state.shape

(1, 240, 256, 4)

In [6]:
state, reward, done, info = env.step([env.action_space.sample()])

In [None]:
plt.figure(figsize=(10,8))
for idx in range(state.shape[3]):
    plt.subplot(1,4, idx+1)
    plt.imshow(state[0][:,:,idx])
plt.show()

# Train the RL Model

In [6]:
import os
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback

In [8]:
class TrainAndLoggingCallback(BaseCallback):
    
    def __init__(self, check_freq, save_path, verbose=1):
        super(TrainAndLoggingCallback, self).__init__(verbose)
        self.check_freq = check_freq
        self.save_path = save_path
    
    def __init_callback(self):
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)
            
    def _on_step(self):
        if self.n_calls % self.check_freq == 0:
            model_path = os.path.join(self.save_path, 'best_model_{}'.format(self.n_calls))
            self.model.save(model_path)
            
        return True

In [10]:
CHECKPOINT_DIR = './train/'
LOG_DIR = './logs/'

In [12]:
callback = TrainAndLoggingCallback(check_freq = 10000, save_path = CHECKPOINT_DIR)

In [19]:
model = PPO('CnnPolicy', env, verbose=1, tensorboard_log=LOG_DIR, learning_rate=0.0001, n_steps=512)

Using cpu device
Wrapping the env in a VecTransposeImage.


In [20]:
model.learn(total_timesteps=1000000, callback=callback)

Logging to ./logs/PPO_2
----------------------------
| time/              |     |
|    fps             | 113 |
|    iterations      | 1   |
|    time_elapsed    | 4   |
|    total_timesteps | 512 |
----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 17          |
|    iterations           | 2           |
|    time_elapsed         | 59          |
|    total_timesteps      | 1024        |
| train/                  |             |
|    approx_kl            | 0.012356691 |
|    clip_fraction        | 0.142       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.94       |
|    explained_variance   | -0.00121    |
|    learning_rate        | 0.0001      |
|    loss                 | 8.47        |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.0058     |
|    value_loss           | 64          |
-----------------------------------------
-----------------

-----------------------------------------
| time/                   |             |
|    fps                  | 8           |
|    iterations           | 13          |
|    time_elapsed         | 781         |
|    total_timesteps      | 6656        |
| train/                  |             |
|    approx_kl            | 0.013161514 |
|    clip_fraction        | 0.176       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.78       |
|    explained_variance   | -0.046      |
|    learning_rate        | 0.0001      |
|    loss                 | 0.0934      |
|    n_updates            | 120         |
|    policy_gradient_loss | -0.00707    |
|    value_loss           | 0.268       |
-----------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 8           |
|    iterations           | 14          |
|    time_elapsed         | 851         |
|    total_timesteps      | 7168  

----------------------------------------
| time/                   |            |
|    fps                  | 7          |
|    iterations           | 24         |
|    time_elapsed         | 1544       |
|    total_timesteps      | 12288      |
| train/                  |            |
|    approx_kl            | 0.12818667 |
|    clip_fraction        | 0.104      |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.244     |
|    explained_variance   | 0.375      |
|    learning_rate        | 0.0001     |
|    loss                 | 18.8       |
|    n_updates            | 230        |
|    policy_gradient_loss | 0.0157     |
|    value_loss           | 219        |
----------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 7            |
|    iterations           | 25           |
|    time_elapsed         | 1626         |
|    total_timesteps      | 12800        |
| tr

----------------------------------------
| time/                   |            |
|    fps                  | 7          |
|    iterations           | 35         |
|    time_elapsed         | 2425       |
|    total_timesteps      | 17920      |
| train/                  |            |
|    approx_kl            | 0.01785919 |
|    clip_fraction        | 0.0176     |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.119     |
|    explained_variance   | 0.0264     |
|    learning_rate        | 0.0001     |
|    loss                 | 131        |
|    n_updates            | 340        |
|    policy_gradient_loss | -0.00565   |
|    value_loss           | 263        |
----------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 7           |
|    iterations           | 36          |
|    time_elapsed         | 2515        |
|    total_timesteps      | 18432       |
| train/  

-----------------------------------------
| time/                   |             |
|    fps                  | 6           |
|    iterations           | 46          |
|    time_elapsed         | 3412        |
|    total_timesteps      | 23552       |
| train/                  |             |
|    approx_kl            | 0.005086939 |
|    clip_fraction        | 0.145       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1          |
|    explained_variance   | -0.0636     |
|    learning_rate        | 0.0001      |
|    loss                 | 0.114       |
|    n_updates            | 450         |
|    policy_gradient_loss | 0.00244     |
|    value_loss           | 1.75        |
-----------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 6            |
|    iterations           | 47           |
|    time_elapsed         | 3498         |
|    total_timesteps      | 2

KeyboardInterrupt: 

# Test

In [21]:
model = PPO.load('./train/best_model_30000')

In [22]:
state = env.reset()
while True:
    action, _state = model.predict(state)
    state, reward, done, info = env.step(action)
    env.render()



KeyboardInterrupt: 