In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import jax
import jax.numpy as jnp

from rgm import *
from atari.common import *

import mediapy

In [3]:
# Breakout environment with the MINI observation type; the next cell rolls out a
# uniformly random policy in it as a reward baseline.
env = make_game(game_config("Breakout", ObservationType.MINI))

In [4]:
from tqdm import tqdm

# Random-policy baseline: roll out `num_episodes` episodes of at most `horizon`
# steps with uniformly sampled actions, then report the per-episode return.
observations = []
actions = []
rewards = []

num_episodes = 1000
horizon = 128

for _ in tqdm(range(num_episodes)):
    # Per-episode buffers (named ep_* so the stdlib module name `os` is not shadowed).
    ep_actions = []
    ep_obs = []
    ep_rewards = []

    obs, info = env.reset()
    for _ in range(horizon):
        action = env.action_space.sample()
        next_obs, reward, done, trunc, info = env.step(action)
        # Store the observation the action was taken from, not the one it produced.
        ep_actions.append([action])
        ep_obs.append(obs)
        ep_rewards.append(reward)
        obs = next_obs

        # The 5-tuple step API reports termination and truncation separately;
        # either one ends the episode. The outer loop resets for the next episode.
        if done or trunc:
            break

    observations.append(ep_obs)
    actions.append(ep_actions)
    rewards.append(ep_rewards)

# Every episode has at least one step (horizon >= 1), so each sum is well defined.
totals = jnp.asarray([jnp.sum(jnp.asarray(r)) for r in rewards])
print("Got reward (avg/max)", jnp.mean(totals), jnp.max(totals))

100%|██████████| 1000/1000 [00:00<00:00, 11673.80it/s]


Got reward (avg/max) 0.395 4


In [None]:
import numpy as np
import seaborn as sns
from matplotlib import colors

def render(observation):
    """Turn a stack of binary observation channels into an RGB image.

    Only the first four channels are drawn. Each channel gets its own
    "cubehelix" colour; cells with no active channel are black. Where several
    channels are active in one cell, the highest-numbered channel wins.

    Args:
        observation: array indexed as (row, col, channel).

    Returns:
        Float RGB image of shape (rows, cols, 3), as produced by the colormap.
    """
    channels = observation[:, :, :4]
    n_ch = channels.shape[-1]

    # Colour 0 is the black background, colours 1..n_ch are one per channel.
    palette = [(0, 0, 0)] + list(sns.color_palette("cubehelix", n_ch))
    cmap = colors.ListedColormap(palette)
    norm = colors.BoundaryNorm(list(range(n_ch + 2)), n_ch + 1)

    # Label each cell 1..n_ch by its highest active channel (0 if none), then
    # shift by 0.5 so the label falls strictly inside its BoundaryNorm bin.
    weights = (np.arange(n_ch) + 1).reshape(1, 1, -1)
    labels = np.amax(channels * weights, 2) + 0.5

    rgba = cmap(norm(labels))
    return rgba[..., :3]

In [None]:
# add one-hot dimension for background
def to_one_hot(o):
    """Append a "background" channel to a stack of observation channels.

    The background channel is 1 wherever every existing channel is 0, and 0
    elsewhere. When the input channels are mutually exclusive, every cell then
    has exactly one active channel.

    Args:
        o: array whose last axis indexes channels. Callers in this notebook pass
            shape (batch, time, H, W, C). Any rank >= 1 and any C are accepted.

    Returns:
        float32 array of shape o.shape[:-1] + (C + 1,). Channels 0..C-1 are `o`
        cast to float32, and channel C is the background mask.
    """
    o = jnp.asarray(o)
    # True where no channel is active; shape o.shape[:-1].
    background = jnp.all(o == 0, axis=-1)
    # Concatenating (instead of a boolean-mask scatter into a fixed index) works
    # for any channel count and is also traceable under jax.jit.
    return jnp.concatenate(
        [o.astype(jnp.float32), background[..., None].astype(jnp.float32)],
        axis=-1,
    )

In [None]:
from tqdm import tqdm

# Online training loop.
# - Act with the RGM agent, falling back to a random action when it returns -1.
# - After each terminated episode that contains a positive reward, take the
#   shortest STEP_LENGTH-aligned prefix that still includes the last reward.
# - Render that prefix for inspection and pass it to `rgm.learn_structure`.
env = make_game(game_config("Asterix", ObservationType.MINI))
env.unwrapped.sticky_action_prob = 0.0
# NOTE(review): this cell trains on Asterix, while later cells label videos
# "breakout" and save to "mini_breakout_rgm.npz"; confirm which game is intended.

MAX_EPISODES = 20  # episodes to play in this cell
STEP_LENGTH = 8    # training sequences are truncated to a multiple of this many steps

steps = 0
last_step = 0
train_steps = 0

observations = []
actions = []
rewards = []

train_sequences = {}

num_episodes = 0
horizon = 128

rgm = RGM(max_levels=8, n_bins=5, dx=2, size=(10, 10), action_range=(0, 4), svd=False)
rgm_agent = RGMAgent(rgm)

while num_episodes < MAX_EPISODES:
    # Total steps so far, and steps taken in the previous episode.
    print(steps, "\t", steps - last_step)
    last_step = steps

    ep_actions = []
    ep_obs = []
    ep_rewards = []

    num_episodes += 1
    rgm_agent.reset()
    obs, info = env.reset()
    for _ in range(horizon):
        # One-hot encode the frame and lay it out as (100 cells, 1 step, 5 channels).
        # Assumes a 10x10 grid with 4 raw channels; TODO confirm for other games.
        o = to_one_hot(jnp.asarray([[obs]]))[0].reshape(-1, 100, 5).transpose(1, 0, 2)
        action = int(rgm_agent.act(o)[0])
        if action == -1:
            # if rgm_agent returns -1, randomly sample an action
            action = env.action_space.sample()
        next_obs, reward, done, trunc, info = env.step(action)
        steps += 1

        ep_actions.append([action])
        ep_obs.append(obs)
        ep_rewards.append([reward])
        obs = next_obs

        if done:
            achieved_reward = jnp.asarray(ep_rewards)
            reward_idx = jnp.where(achieved_reward > 0)[0]
            if len(reward_idx) > 0:
                last_reward_idx = int(reward_idx[-1])
                # Smallest multiple of STEP_LENGTH that still includes the last reward.
                size = STEP_LENGTH * (1 + last_reward_idx // STEP_LENGTH)
                if len(ep_obs) >= size:
                    # Render and train on exactly the `size`-step prefix, so the data
                    # matches both the length check above and the train_steps count.
                    seq_obs = ep_obs[:size]
                    seq_actions = ep_actions[:size]
                    seq_rewards = ep_rewards[:size]

                    train_sequences[str(len(train_sequences))] = [render(frame) for frame in seq_obs]
                    train_steps += size

                    o = to_one_hot(jnp.asarray([seq_obs]))[0].reshape(-1, 100, 5).transpose(1, 0, 2)
                    # lump reward on there as an "action" modality to use for acting
                    a = jnp.concatenate([jnp.asarray(seq_actions), jnp.asarray(seq_rewards)], axis=-1)
                    rgm.learn_structure(o, a)
            break

    observations.append(ep_obs)
    actions.append(ep_actions)
    rewards.append(ep_rewards)

    jax.clear_caches()


print("Interacted with the environment for", steps, "steps")

totals = []
for r in rewards:
    totals.append(jnp.sum(jnp.asarray(r)))
print("Got reward (avg/max)", jnp.mean(jnp.asarray(totals)), jnp.max(jnp.asarray(totals)))
print("Trained on", len(train_sequences.keys()), "sequences, totalling", train_steps, "steps")

with mediapy.set_show_save_dir("/tmp"):
    mediapy.show_videos(train_sequences, width=160, height=160, columns=5, fps=4, codec="gif")

In [None]:
from PIL import Image


def resize(x):
    """Upscale an RGB frame from `render` (float values in [0, 1]) to 256x256.

    Nearest-neighbour resampling keeps the grid cells as crisp blocks.
    """
    pixels = (x * 255).astype(jnp.uint8)
    upscaled = Image.fromarray(pixels).resize((256, 256), Image.NEAREST)
    return jnp.array(upscaled)


# Upscaled copy of every training sequence, keyed like `train_sequences`.
train_sequences_big = {}
for key, frames in train_sequences.items():
    train_sequences_big[key] = [resize(frame) for frame in frames]

with mediapy.set_show_save_dir("/tmp"):
    mediapy.show_videos(train_sequences_big, width=160, height=160, columns=5, fps=6, codec="gif")

In [None]:
import mediapy
from PIL import Image

# Every recorded frame, episode after episode, as a single RGB sequence.
imgs = [render(frame) for episode in observations for frame in episode]


def resize(x):
    """Upscale an RGB frame from `render` (float values in [0, 1]) to 256x256, nearest-neighbour."""
    pixels = (x * 255).astype(jnp.uint8)
    return jnp.array(Image.fromarray(pixels).resize((256, 256), Image.NEAREST))


big = [resize(img) for img in imgs]

# NOTE(review): the video is labelled "breakout", but `observations` comes from
# whichever environment was used last; confirm the label.
with mediapy.set_show_save_dir("/tmp"):
    mediapy.show_videos({"breakout": big}, width=320, height=320, fps=20, codec="gif")

In [None]:
# Save the RGM built by the training cell (the on-disk format is whatever `RGM.save` writes).
# NOTE(review): the training cell uses the "Asterix" environment, but this file name
# says "breakout"; confirm the intended game before relying on this path.
rgm.save("../data/rgms/mini_breakout_rgm.npz")

In [None]:
# Evaluate the trained RGM: play `n` episodes with a fresh agent each time and
# report the episode returns.
episode_rewards = []
play = []          # every observation seen, across all evaluated episodes
eval_actions = []  # separate name so the training `actions` list is not overwritten
num_random = 0     # steps where the agent returned -1 and a random action was used

n = 1
for _ in tqdm(range(n)):
    total_reward = 0

    rgm_agent = RGMAgent(rgm)
    rgm_agent.reset()
    obs, info = env.reset()
    for _ in range(horizon):
        # Same (100 cells, 1 step, 5 channels) layout the training cell feeds the agent.
        o = to_one_hot(jnp.asarray([[obs]]))[0].reshape(-1, 100, 5).transpose(1, 0, 2)
        action = int(rgm_agent.act(o)[0])
        if action == -1:
            # if rgm_agent returns -1, randomly sample an action
            num_random += 1
            action = env.action_space.sample()
        eval_actions.append(action)
        next_obs, reward, done, trunc, info = env.step(action)
        total_reward += reward
        play.append(obs)
        obs = next_obs

        # The 5-tuple step API reports termination and truncation separately;
        # either one ends the episode.
        if done or trunc:
            break

    episode_rewards.append(total_reward)

print(
    "Played ", n, " games with reward (avg / max)",
    jnp.mean(jnp.asarray(episode_rewards)),
    jnp.max(jnp.asarray(episode_rewards)),
)
# One summary line instead of printing on every random fallback step.
print("Random fallback actions:", num_random)

In [None]:
import mediapy

# One RGB frame per observation recorded during evaluation.
imgs = list(map(render, play))

with mediapy.set_show_save_dir("/tmp"):
    mediapy.show_videos({"breakout": imgs}, width=320, height=320, fps=2, codec="gif")