In [1]:
# Re-import edited modules (e.g. the local `rgm` package) before each cell runs,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import jax
import jax.numpy as jnp
from jax import random as jr

# Every random draw in this notebook is split off this one key; change SEED to
# get a different (but still reproducible) dataset and rollout.
SEED = 0
key = jr.PRNGKey(SEED)

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [3]:
from rgm.envs.pong import Pong

# Pong environment. Its reset/step return `o`, whose last entry is the reward and
# whose remaining entries are the observation channels (see how `o` is split below).
env = Pong()

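Generate random play and stitch together "rewarding" sequences using Karl Friston's "Maxwell's demon" approach. Any episode that ends with the agent scoring (reward 1) is kept, and the next episode restarts from its final state. Episodes that end with any other non-zero reward, or that reach the horizon without a reward, are discarded: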

In [4]:
from tqdm import tqdm

# Roll out uniformly random moves for up to `horizon` steps per episode. Only
# episodes that end with reward == 1 are kept, and the next episode restarts from
# the kept episode's final state, so kept episodes stitch into one sequence.
# NOTE: this cell consumes the global `key` and advances `env`. Re-running it
# yields a different dataset; restart from the seed cell to reproduce results.
observations = []
actions = []
rewards = []

start_state = None

num_episodes = 128
horizon = 64

for _ in tqdm(range(num_episodes)):
    # Per-episode buffers, merged into the dataset only if the episode scores.
    # (Named so as not to shadow the stdlib `os` module at notebook scope.)
    ep_actions = []
    ep_observations = []
    ep_rewards = []

    keys = jr.split(key, 2)
    key = keys[0]

    # `o` = observation channels followed by the reward as the last entry.
    o, env = env.reset(keys[1:], state=start_state)
    ep_observations.append(jnp.asarray(o[:-1]))
    ep_rewards.append(o[-1])

    for _ in range(horizon):
        keys = jr.split(key, 3)
        key = keys[0]
        # Random move in {0, 1, 2} (randint's maxval is exclusive); the first
        # action column is always 0.
        move = jr.randint(keys[1], (1,), 0, 3)
        a = jnp.array([[0, move[0]]])
        ep_actions.append(a)
        o, env = env.step(rng_key=keys[2:], actions=a)
        obs = jnp.asarray(o[:-1])
        reward = o[-1]

        if reward[0, 0] != 0:
            # Any non-zero reward ends the episode (presumably -1 is a point for
            # the opponent -- confirm in rgm.envs.pong). Keep it only if reward == 1.
            # NOTE(review): the observation/reward that triggered this branch is
            # never appended, so the saved `rewards` contain no 1s; the last entry
            # of `ep_actions` is the action that led to the point.
            if reward[0, 0] == 1:
                start_state = env.state
                actions += ep_actions
                observations += ep_observations
                rewards += ep_rewards
            break

        ep_observations.append(obs)
        ep_rewards.append(reward)

# Fail with an actionable message instead of jnp.concatenate's "need at least
# one array" error when random play never scored.
if not actions:
    raise RuntimeError(
        f"No rewarding episode in {num_episodes} random episodes of {horizon} steps; "
        "increase num_episodes or horizon."
    )

actions = jnp.concatenate(actions, axis=0)
observations = jnp.concatenate(observations, axis=1)
rewards = jnp.concatenate(rewards, axis=1).squeeze(0)
print(rewards)

100%|██████████| 128/128 [01:34<00:00,  1.36it/s]

[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]





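Fraction of simulated steps kept as training data (kept frames divided by `num_episodes * horizon`, the maximum number of steps that could have been kept):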

In [5]:
# Kept frames relative to the maximum possible number of steps.
kept_fraction = len(rewards) / (num_episodes * horizon)
print(kept_fraction)

0.0079345703125


In [6]:
# Bundle the collected arrays and write them into a single .npz archive.
data = {
    "observations": observations,
    "rewards": rewards,
    "actions": actions,
}
jnp.savez("../data/pong.npz", **data)

In [7]:
# Set to True to reload the dataset written by the save cell above instead of
# using the freshly generated arrays.
LOAD_SAVED_DATA = False


def load_pong_data(path="../data/pong.npz"):
    """Load the arrays written by the save cell.

    Args:
        path: location of the .npz archive (defaults to the path the save cell uses).

    Returns:
        Tuple ``(observations, rewards, actions)`` as JAX arrays.
    """
    # The context manager closes the underlying NpzFile handle once the arrays
    # have been copied out.
    with jnp.load(path) as saved:
        return (
            jnp.asarray(saved["observations"]),
            jnp.asarray(saved["rewards"]),
            jnp.asarray(saved["actions"]),
        )


if LOAD_SAVED_DATA:
    observations, rewards, actions = load_pong_data()

In [8]:
import mediapy
from rgm.envs.pong import obs2img

# Render every time step of the training sequence as one GIF frame on the
# 12 x 9 grid.
# TODO should be vmapped
frame_count = observations.shape[1]
imgs = [obs2img(observations, 12, 9, t) for t in range(frame_count)]
with mediapy.set_show_save_dir("/tmp"):
    mediapy.show_videos({"pong": imgs}, width=180, height=240, fps=10, codec="gif")

0
pong


In [9]:
# Sanity-check the dataset layout: observations are (cells, time, 1) with
# 108 = 12 * 9 cells, and actions are (time, 2) -- see the output below.
for arr in (observations, actions):
    print(arr.shape)

(108, 65, 1)
(65, 2)


In [10]:
from rgm.fast_structure_learning import *
from rgm.rgm import *

In [11]:
# Configure the RGM for the Pong data.
# NOTE(review): random play only samples moves 0-2, but action_range=(0, 4)
# with n_bins=5 gives action_bins [0, 1, 2, 3, 4] -- confirm the extra bins are intended.
rgm = RGM(
    max_levels=8,
    n_bins=5,
    dx=3,
    size=(12, 9),  # same 12 x 9 grid that obs2img uses for rendering
    action_range=(0, 4),
    svd=False,
)

In [12]:
# Learn the model's level/group structure from the stitched training sequences
# (prints per-level, per-group timings).
rgm.learn_structure(observations, actions)

structure learn level 0 group 0 time 2.5012364387512207
structure learn level 0 group 1 time 0.9617540836334229
structure learn level 0 group 2 time 0.6817452907562256
structure learn level 0 group 3 time 0.8439693450927734
structure learn level 0 group 4 time 0.07053422927856445
structure learn level 0 group 5 time 0.08694243431091309
structure learn level 0 group 6 time 0.06847572326660156
structure learn level 0 group 7 time 0.15839076042175293
structure learn level 0 group 8 time 0.07990074157714844
structure learn level 0 group 9 time 0.09043288230895996
structure learn level 0 group 10 time 0.07900238037109375
structure learn level 0 group 11 time 0.08293867111206055
structure learn level 1 group 0 time 0.9294495582580566
structure learn level 1 group 1 time 0.4440317153930664
structure learn level 2 group 0 time 0.7982876300811768


In [13]:
# Save the learned RGM to disk for later reuse.
rgm.save("../data/rgms/karl_pong_rgm.npz")

In [14]:
# Infer states for the first few time steps of the training sequence.
n_steps = 4
qs = rgm.infer_states(observations[:, :n_steps, :], actions[:n_steps], None)

In [15]:
# Reconstruct from the inferred states. Presumably r holds the observation
# reconstruction (decoded via argmax below) and u the actions (same shape as
# actions[:4]) -- confirm against RGM.reconstruct.
r, u = rgm.reconstruct(qs)

In [16]:
# Discrete action values used by the model (shown below).
rgm.action_bins

Array([0., 1., 2., 3., 4.], dtype=float32)

In [17]:
# Compare the reconstructed u with the ground-truth actions for the same 4 steps.
for arr in (u, actions[:4]):
    print(arr)

[[0.0000000e+00 1.9989027e+00]
 [0.0000000e+00 4.8774609e-04]
 [0.0000000e+00 1.0974287e-03]
 [0.0000000e+00 2.4387304e-04]]
[[0 2]
 [0 0]
 [0 0]
 [0 0]]


In [18]:
# Decode the reconstruction to one discrete value per cell (argmax over the last
# axis). This is computed once instead of once per frame, as it does not depend
# on the frame index. Then render one frame per reconstructed time step.
# presumably r holds per-bin scores for each cell -- confirm against RGM.reconstruct
r_frames = jnp.argmax(r, axis=-1, keepdims=True)
imgs = [obs2img(r_frames, 12, 9, t) for t in range(r.shape[1])]
with mediapy.set_show_save_dir("/tmp"):
    mediapy.show_videos({"reconstruction": imgs}, width=180, height=240, fps=2, codec="gif")

0
reconstruction


In [32]:
from tqdm import tqdm

# Roll out the RGM agent for T steps from a fresh Pong episode.
# NOTE(review): this rebinds `observations`, overwriting the training dataset
# built above. Cells earlier in the notebook that are re-run afterwards will
# silently use this rollout instead.
observations = []
play_rewards = []  # per-step reward entry, kept so the rollout can be scored

T = 200

agent = RGMAgent(rgm)

keys = jr.split(key, 2)
key = keys[0]
o, env = env.reset(keys[1:])
for _ in tqdm(range(T)):
    keys = jr.split(key, 2)
    key = keys[0]

    # `o` = observation channels followed by the reward as the last entry.
    obs = jnp.asarray(o[:-1])
    observations.append(obs)
    play_rewards.append(o[-1])

    a = agent.act(obs)
    a = jnp.expand_dims(a, axis=0)  # add a leading axis, matching the [[...]] actions used in data collection
    o, env = env.step(rng_key=keys[1:], actions=a)

observations = jnp.concatenate(observations, axis=1)
# Same reduction as the training rewards: (1, 1) per step -> (T,).
play_rewards = jnp.concatenate(play_rewards, axis=1).squeeze(0)

In [33]:
# Render the agent's rollout as one GIF frame per time step on the 12 x 9 grid.
frame_count = observations.shape[1]
imgs = [obs2img(observations, 12, 9, t) for t in range(frame_count)]
with mediapy.set_show_save_dir("/tmp"):
    mediapy.show_videos({"play": imgs}, width=180, height=240, fps=10, codec="gif")

0
play
