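# Super Mario Bros: replaying a trained agent

This notebook loads a trained policy network's weights (`final_agent.tch`), lets it play a selection of levels, records each run with gym's `Monitor`, and embeds the resulting videos inline.

## Install dependencies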

In [1]:
!pip install nes-py > /dev/null
!pip install gym-super-mario-bros > /dev/null

## Imports

In [2]:
# `files` is used further down to upload the trained agent's weights.
from google.colab import drive, files
# NOTE(review): nothing visible in this notebook reads from /content/drive;
# this mount may be unnecessary -- confirm before removing.
drive.mount('/content/drive')

Mounted at /content/drive


In [38]:
import base64
from functools import reduce
import glob
import io
from IPython import display as ipythondisplay
from IPython.display import HTML
import numpy as np
import pickle
from typing import List, Tuple, Union, Optional, NamedTuple
from tqdm.notebook import tqdm as tqdm_notebook
import operator
import os

import torch
from torch import nn

from nes_py.wrappers import JoypadSpace
import gym
import gym_super_mario_bros
from gym_super_mario_bros.smb_env import SuperMarioBrosEnv
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT
from gym.wrappers import Monitor

## Definitions

In [23]:
class Individual(NamedTuple('Individual', (
    ('weights', torch.Tensor),
    ))):
    """A candidate agent: one tensor holding every network parameter.

    ``Network.get_weights`` builds it by flattening and concatenating the
    parameters in ``nn.Module.parameters()`` order, and ``Network.set_weights``
    consumes it in that same order.
    """

    def __copy__(self):
        # torch.Tensor has no ``.copy()`` (that is the NumPy spelling); ``clone()``
        # is the tensor equivalent and returns an independent copy of the data.
        return Individual(self.weights.clone())

In [24]:
# Device the policy network (and the frames fed to it) are placed on.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


class Network(nn.Module):
    """Policy network mapping one game frame to a discrete action index.

    The layer stack is supplied by the caller via ``layers``; this wrapper
    reshapes the input, picks the arg-max output, and (de)serialises all
    parameters as a single flat tensor.
    """

    def __init__(self, layers: nn.Module) -> None:
        super().__init__()
        self.layers = layers

    def forward(self, state: torch.Tensor) -> int:
        """Return the index of the highest-scoring output for a single frame.

        ``state`` must contain exactly 3 * 240 * 256 elements; it is reshaped
        with ``view`` to (1, 3, 240, 256). NOTE(review): if frames arrive as
        (240, 256, 3) -- TODO confirm -- ``view`` reinterprets memory rather than
        permuting axes. The saved agent was presumably trained with this same
        layout, so it is intentionally left unchanged.
        """
        # Only the arg-max is used, so no autograd graph is needed.
        with torch.no_grad():
            output = self.layers(state.view(1, 3, 240, 256))
        return torch.argmax(output).item()

    def set_weights(self, weights: torch.Tensor) -> None:
        """Copy a flat parameter vector into this network in place.

        Args:
            weights: tensor with exactly as many elements as the network has
                parameters, laid out in ``self.parameters()`` order (the layout
                produced by ``get_weights``). It may live on any device.

        Raises:
            ValueError: if ``weights`` has the wrong number of elements.
        """
        params = list(self.parameters())
        expected = sum(param.numel() for param in params)
        if weights.numel() != expected:
            # Previously surplus values were silently ignored.
            raise ValueError(f'expected {expected} weights, got {weights.numel()}')

        flat = weights.reshape(-1)
        offset = 0
        with torch.no_grad():
            for param in params:
                # numel() also handles 0-dim parameters, where
                # reduce(operator.mul, ()) would raise TypeError.
                count = param.numel()
                # copy_ converts device/dtype to match the parameter.
                param.copy_(flat[offset:offset + count].view_as(param))
                offset += count

    def get_weights(self) -> 'Individual':
        """Return all parameters flattened and concatenated into an ``Individual``.

        The tensor is detached from autograd and stays on the parameters' device.
        """
        return Individual(weights=torch.hstack([param.detach().flatten() for param in self.parameters()]))

In [25]:
def make_wrapped_env(env_name: str, folder: str) -> 'Monitor':
    """Create a Super Mario Bros env using the simple action set, recorded to disk.

    Args:
        env_name: id passed to ``gym_super_mario_bros.make``.
        folder: sub-directory of ``VIDEOS_FOLDER_PATH`` for the recordings.

    Returns:
        The ``Monitor`` wrapper (outermost layer), not the ``TimeLimit`` env.
    """
    env = gym_super_mario_bros.make(env_name)
    env = JoypadSpace(env, SIMPLE_MOVEMENT)
    return _wrap_env(env, folder)


def _wrap_env(env, folder: str) -> 'Monitor':
    """Wrap ``env`` in a ``Monitor`` that writes to ``VIDEOS_FOLDER_PATH/folder``.

    ``force=True`` is passed to ``Monitor``; in gym this presumably clears earlier
    monitor files in that directory -- confirm against the installed gym version.
    """
    # os.path.join keeps path building consistent with show_videos.
    return Monitor(env, os.path.join(VIDEOS_FOLDER_PATH, folder), force=True)

In [26]:
def _save_selected_video(
    actor: Network,
    folder: str,
    max_iters: int,
    world_stage: Tuple[int, int],
) -> None:
    """Play one level with ``actor`` and record it under ``VIDEOS_FOLDER_PATH/folder``.

    Args:
        actor: network whose ``forward`` maps a frame tensor to an action index.
        folder: sub-directory (relative to ``VIDEOS_FOLDER_PATH``) for the video.
        max_iters: maximum number of environment steps before stopping.
        world_stage: level to play, passed as ``SuperMarioBrosEnv(target=...)``.
    """
    env = gym_super_mario_bros.make('SuperMarioBrosRandomStages-v0')
    env = JoypadSpace(env, SIMPLE_MOVEMENT)

    # The Monitor must be reset before env.step is called. The inner level env
    # is then swapped for the requested stage and reset directly (bypassing the
    # Monitor), so the step counters are zeroed by hand afterwards.
    env = _wrap_env(env, folder)
    env.reset()
    env.env.env = SuperMarioBrosEnv(target=world_stage)
    state = env.env.env.reset()
    # NOTE(review): after the swap env.env.env is no longer the env gym.make
    # returned, so this attribute may have no effect -- confirm.
    env.env._elapsed_steps = 0
    env.stats_recorder.total_steps = 0

    # Only the chosen action is needed; skip building autograd graphs.
    with torch.no_grad():
        for _ in range(max_iters):
            # Use the actor passed in rather than a notebook-level global.
            action = actor.forward(torch.Tensor(state.copy()).to(device))
            state, reward, done, info = env.step(action)

            if done:
                break


def save_selected_videos(
    individual: Individual,
    network: Network,
    folder: str,
    max_iters_per_map: int,
    all_world_stages: List[Tuple[int, int]],
) -> None:
    """Load ``individual`` into ``network`` and record one video per level.

    Each level is written to ``VIDEOS_FOLDER_PATH/folder/<str(world_stage)>``.

    Args:
        individual: weights to play with.
        network: network that receives the weights (modified in place).
        folder: sub-directory of ``VIDEOS_FOLDER_PATH`` for all recordings.
        max_iters_per_map: step budget for each level.
        all_world_stages: levels to play, in order.
    """
    # set_weights mutates the network in place and returns None, so the
    # network itself is the actor.
    network.set_weights(individual.weights)

    for world_stage in tqdm_notebook(all_world_stages, desc='Maps played', position=0):
        _save_selected_video(
            actor=network,
            folder=os.path.join(folder, str(world_stage)),
            max_iters=max_iters_per_map,
            world_stage=world_stage,
        )

In [27]:
def show_videos(folder: str):
    absolute_folder_path = os.path.join(VIDEOS_FOLDER_PATH, folder)
    for subfolder in os.listdir(absolute_folder_path):
        video_folder_path = os.path.join(absolute_folder_path, subfolder)
        if os.path.isdir(video_folder_path):
            _show_video(video_folder_path)


def _show_video(folder):
    path = f'{folder}/*.mp4'
    mp4list = glob.glob(path)
    if len(mp4list) > 0:
        mp4 = mp4list[0]
        video = io.open(mp4, 'r+b').read()
        encoded = base64.b64encode(video)
        ipythondisplay.display(HTML(data='''<video alt="test" autoplay 
                loop controls style="height: 300px;">
                <source src="data:video/mp4;base64,{0}" type="video/mp4" />
             </video>'''.format(encoded.decode('ascii'))))
    else: 
        print("Could not find video")

In [28]:
def load_individual(path: str, map_location: Union[str, torch.device, None] = 'cpu') -> Individual:
    """Load a tensor saved with ``torch.save`` and wrap it in an ``Individual``.

    Args:
        path: file written by ``torch.save``.
        map_location: forwarded to ``torch.load``. The ``'cpu'`` default lets a
            checkpoint saved on a GPU load on a CPU-only runtime; pass ``None``
            to keep the tensor on the device it was saved from.

    Returns:
        An ``Individual`` whose ``weights`` is the loaded object.
    """
    # Passing the path (not an unclosed open() handle) lets torch manage the file.
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only load checkpoints from a trusted source.
    weights = torch.load(path, map_location=map_location)

    return Individual(weights=weights)

## Initialization

In [29]:
# Network.forward reshapes every frame to (1, 3, 240, 256). At that size the
# three convolutions below yield a 64 x 26 x 28 feature map, hence the Linear
# layer's 46592 inputs. Its 7 outputs are one logit per SIMPLE_MOVEMENT action.
policy_layers = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=8, stride=4), nn.ReLU(),   # -> 32 x 59 x 63
    nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),  # -> 64 x 28 x 30
    nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(),  # -> 64 x 26 x 28
    nn.Flatten(),
    nn.Linear(64 * 26 * 28, 7),
)
network = Network(layers=policy_layers).to(device)

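## Visualize the trained agent playing

Upload the trained weights file (`final_agent.tch`) when prompted. The agent then plays each selected level, and the recordings are embedded at the end.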

In [30]:
# Root directory (relative to the working directory) for all Monitor recordings.
# _wrap_env, _save_selected_video and show_videos read this global at call time,
# so it only has to be defined before they are called, not before they are defined.
VIDEOS_FOLDER_PATH = 'videos'

In [31]:
AGENT_PATH = 'final_agent.tch'

# files.upload() returns {filename: bytes}. When a file with the same name
# already exists, Colab stores the new upload under another name (the recorded
# output shows "Saving final_agent.tch to final_agent (1).tch"), so loading
# AGENT_PATH directly would silently use the stale file. Write the uploaded
# bytes to AGENT_PATH explicitly instead.
uploaded_files = files.upload()
if len(uploaded_files) != 1:
    raise ValueError(f'Expected exactly one uploaded file, got {len(uploaded_files)}')
(agent_bytes,) = uploaded_files.values()
with open(AGENT_PATH, 'wb') as agent_file:
    agent_file.write(agent_bytes)
individual = load_individual(path=AGENT_PATH)

Saving final_agent.tch to final_agent (1).tch


In [50]:
video_folder_name = 'selected_agent_videos'

# (world, stage) pairs of levels the agent performs comparatively well on.
showcase_world_stages = [(7, 4), (3, 4), (4, 4), (2, 2), (6, 1), (8, 2), (4, 1), (7, 2)]
# Upper bound on environment steps recorded per level.
steps_per_level = 2000

save_selected_videos(
    individual=individual,
    network=network,
    folder=video_folder_name,
    max_iters_per_map=steps_per_level,
    all_world_stages=showcase_world_stages,
)

Maps played:   0%|          | 0/8 [00:00<?, ?it/s]

In [51]:
# Embed one recording from each per-level sub-folder written by the cell above.
show_videos(video_folder_name)