In [1]:
import gym_super_mario_bros
import torch
import numpy as np

from gym_super_mario_bros.actions import RIGHT_ONLY
from nes_py.wrappers import JoypadSpace

from pathlib import Path
from skvideo.io import vwrite

In [2]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

In [3]:
from wrapper import apply_wrappers

In [4]:
ENV_NAME = 'SuperMarioBros-1-1-v3'
ACTIONS_DIR = './actions/'

In [5]:
def load_actions(episode, rewards):
    print("### CWD: ", Path.cwd())
    path = Path(Path.cwd(), ACTIONS_DIR,
                f"agent_actions_ep:{episode}_rw:{rewards}.pt")
    actions = torch.load(path)
    return actions

In [6]:
def test_from_actions(episode, rewards):
    actions = load_actions(episode, rewards)

    env = gym_super_mario_bros.make(ENV_NAME)
    env.metadata['apply_api_compatibility'] = True
    env.metadata['render.modes'] = ['rgb_array']

    env = JoypadSpace(env, RIGHT_ONLY)
    env = apply_wrappers(env)

    state = env.reset()
    reward_sum = 0
    coins = 0
    final_score = 0
    images = []
    print("### env reset, state: ", state)
    len_actions = len(actions)

    for i, action in enumerate(actions):
        next_state, reward, done, info = env.step(action)
        reward_sum += reward
        coins += info['coins']
        final_score += info['score']
        print("i: ", i, "/", len_actions,
              "action: ", action,
              "rewards: ", reward_sum,
              "coins: ", coins,
              "xpos: ", info['x_pos'],
              "flag_get: ", info['flag_get'])
        # env.render()
        rgb_array = env.render(mode='rgb_array')
        # copy the array to avoid references
        images.append(np.array(rgb_array, dtype=np.uint8))
    print("### reward_sum: ", reward_sum)
    env.close()

    images = np.array(images, dtype=np.uint8)
    print("### images.shape: ", images.shape)
    path = Path(ACTIONS_DIR, f"test_ep{episode}_rw{rewards}.mp4")

    # scikit video
    vwrite(path, images)
    print("### video saved at: ", path)

In [7]:
test_from_actions(29155, 3032)

### CWD:  /Users/corentin/Documents/mario_rl
### env reset, state:  <gym.wrappers.frame_stack.LazyFrames object at 0x1762c3a60>
i:  0 / 537 action:  4 rewards:  0.0 coins:  0 xpos:  40 flag_get:  False
i:  1 / 537 action:  4 rewards:  1.0 coins:  0 xpos:  41 flag_get:  False
i:  2 / 537 action:  4 rewards:  2.0 coins:  0 xpos:  42 flag_get:  False
i:  3 / 537 action:  1 rewards:  4.0 coins:  0 xpos:  44 flag_get:  False
i:  4 / 537 action:  1 rewards:  7.0 coins:  0 xpos:  47 flag_get:  False
i:  5 / 537 action:  1 rewards:  9.0 coins:  0 xpos:  50 flag_get:  False
i:  6 / 537 action:  1 rewards:  13.0 coins:  0 xpos:  54 flag_get:  False
i:  7 / 537 action:  1 rewards:  17.0 coins:  0 xpos:  58 flag_get:  False
i:  8 / 537 action:  1 rewards:  22.0 coins:  0 xpos:  63 flag_get:  False
i:  9 / 537 action:  1 rewards:  28.0 coins:  0 xpos:  69 flag_get:  False
i:  10 / 537 action:  1 rewards:  34.0 coins:  0 xpos:  75 flag_get:  False
i:  11 / 537 action:  1 rewards:  39.0 coins:  0 xpo

  self._proc.stdin.write(vid.tostring())


### video saved at:  actions/test_ep29155_rw3032.mp4
