In [1]:
try:
    from google.colab import output
    output.enable_custom_widget_manager()
except ImportError:
    pass
try:
    import stanza
except:
    %pip install -v git+https://github.com/pfrommerd/stanza.git
    # Fix an issue with ipywidgets 8 and colab
    %pip install "ipywidgets<8"

In [2]:
from stanza.runtime import setup
setup()

from stanza.data.sequence import SequenceData, Chunk
from stanza.env.mujoco.pusht import (
    PushTPosObs, PushTEnv,
    PositionalControlTransform, PositionalObsTransform,
    KeypointObsTransform, RelKeypointObsTransform
)
from stanza.env import ImageRender
from stanza.env.transforms import ChainedTransform
from stanza import canvas
import stanza.env

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

from functools import partial

In [3]:
# Build the shardings used for rendered images and for the dataset.
# With more than one device, at most the first 8 devices are laid out along a
# single mesh axis named 'x' and arrays are partitioned along their leading axis.
if len(jax.devices()) > 1:
    render_mesh = Mesh(jax.devices()[:8], ('x',))
    render_sharding = NamedSharding(render_mesh, PartitionSpec('x'))
    data_mesh = Mesh(jax.devices()[:8], ('x',))
    data_sharding = NamedSharding(data_mesh, PartitionSpec('x'))
else:
    # Single device: no sharding. Both names must exist because later cells
    # pass them to jax.jit(out_shardings=...) and jax.device_put(...);
    # None makes those calls fall back to default placement.
    render_sharding = None
    data_sharding = None

# The environment is created through the registry; the PushTEnv class itself
# is still used later for its unbound `observe` method.
env = stanza.env.create("mujoco/pusht/rel_keypoint")

def draw(action_chunk, weight, width, height):
    """Build canvas geometry showing one action chunk as a row of filled circles.

    Args:
        action_chunk: array whose leading axis indexes the points of the chunk
            (presumably (T, 2) positions in normalized coordinates -- confirm).
        weight: value used for the fourth color component of every circle
            (presumably the alpha channel -- confirm against stanza.canvas).
        width, height: target image size; used to scale the geometry.

    Returns:
        The stacked, transformed canvas object for all circles of the chunk.
    """
    num_points = action_chunk.shape[0]
    # First color component ramps from 0 towards 1 along the chunk so earlier
    # and later points are distinguishable; second and third stay at zero.
    ramp = jnp.arange(num_points) / num_points
    zeros = jnp.zeros(num_points)
    fourth = weight * jnp.ones(num_points)
    rgba = jnp.array((ramp, zeros, zeros, fourth)).T
    radii = 0.02 * jnp.ones(num_points)
    dots = canvas.fill(canvas.circle(action_chunk, radii), color=rgba)
    dots = canvas.stack_batch(dots)
    # Shift by (1, -1) and scale by half the image size with a flipped y axis
    # (presumably maps [-1, 1] coordinates to pixel coordinates -- confirm).
    return canvas.transform(
        dots,
        translation=(1, -1),
        scale=(width/2, -height/2),
    )

@partial(jax.jit, out_shardings=render_sharding, static_argnames=("width", "height"))
def render(state, action_chunks=None, weights=None, width=256, height=256):
    """Render the environment state, optionally overlaying action chunks.

    Args:
        state: environment state passed to `env.render`.
        action_chunks: optional batch of action chunks; each one is drawn with `draw`.
        weights: optional per-chunk weights, batched along the same leading axis.
        width, height: output image size (static under jit).

    Returns:
        The rendered image, with the circles painted on top when both
        `action_chunks` and `weights` are given.
    """
    frame = env.render(state, ImageRender(width, height))
    # The overlay needs both the chunks and their weights; otherwise return the bare frame.
    if action_chunks is None or weights is None:
        return frame
    draw_each = jax.vmap(draw, in_axes=(0, 0, None, None))
    overlay = canvas.stack_batch(draw_each(action_chunks, weights, width, height))
    return canvas.paint(frame, overlay)

In [4]:
from stanza.datasets.env.pusht import load_chi_pusht_data
print("Loading data...")
dataset = load_chi_pusht_data().cache()
print("Computing full mujoco state...")
# Chunk geometry shared by the data pipeline and the policy: number of
# observation steps kept per chunk and number of action steps per chunk.
obs_length = 1
action_length = 16
def map_elements(el):
    """Expand a dataset element's reduced state into a full state via `env.full_state`."""
    return env.full_state(el.reduced_state)
try:
    data = dataset.map_elements(map_elements).cache()
except Exception:
    # Print the full traceback for debugging, then re-raise. Swallowing the
    # error would report success below and leave `data` undefined for the
    # rest of the notebook.
    import traceback
    traceback.print_exc()
    raise
print("Converted reduced state to full state!")

def map_chunks(chunk : Chunk):
    """Turn a chunk of consecutive states into an (observation, action) training pair.

    Returns:
        A tuple `(obs, action)`: `obs` holds the first `obs_length` observations
        of the chunk; `action` holds the last `action_length` agent positions
        expressed relative to the agent position at index `obs_length - 1`.
    """
    states = chunk.elements
    observations = jax.vmap(env.observe)(states)
    observations = jax.tree.map(lambda leaf: leaf[:obs_length], observations)
    # Future agent positions serve as the actions. PushTEnv.observe is called
    # unbound so the base observation (which exposes agent_pos) is used.
    agent_positions = jax.vmap(lambda s: PushTEnv.observe(env, s))(states).agent_pos
    # Express them relative to the position at the last observed step.
    anchor_index = obs_length - 1
    relative_actions = jax.tree.map(
        lambda pos: pos[-action_length:] - pos[anchor_index],
        agent_positions,
    )
    return observations, relative_actions
# Split the state sequence into chunks of obs_length + action_length elements,
# convert each chunk into an (obs, action) pair and materialize as one pytree.
data = (
    data
    .chunk(obs_length + action_length)
    .map(map_chunks)
    .as_pytree()
)
# Truncate the leading axis to a multiple of 8 so it divides evenly for sharding.
num_chunks = data[1].shape[0]
data_len = num_chunks - num_chunks % 8
data = jax.tree.map(lambda leaf: leaf[:data_len], data)
data = jax.device_put(data, data_sharding)
# Debugging aids for inspecting how the action array is sharded:
#jax.debug.visualize_array_sharding(data[1][:,:,0])
#jax.debug.visualize_array_sharding(data[1][:,:,1])
#jax.debug.visualize_array_sharding(data[1][210])

# First (obs, action) pair; used later as the structure/shape template for sampling.
sample_chunk = jax.tree.map(lambda leaf: leaf[0], data)

print("Loaded data", jax.tree.map(lambda leaf: leaf.shape, data))

Loading data...
Computing full mujoco state...
Converted reduced state to full state!
Loaded data (PushTKeypointRelObs(agent_block_pos=(22352, 1, 2), agent_block_end=(22352, 1, 2), rel_block_pos=(22352, 1, 2), rel_block_end=(22352, 1, 2)), (22352, 16, 2))


In [5]:
from stanza.diffusion import DDPMSchedule, nonparametric
from stanza.policy import PolicyInput, PolicyOutput
from stanza.policy.transforms import ChunkingTransform
from jax.random import PRNGKey

# Diffusion schedule shared by the policy below. The first argument (16) is
# presumably the number of diffusion steps -- confirm against DDPMSchedule.
# prediction_type="sample" selects the schedule's prediction mode.
schedule = DDPMSchedule.make_squaredcos_cap_v2(16, prediction_type="sample")

@jax.jit
def chunk_policy(input: PolicyInput) -> PolicyOutput:
    """Sample one action chunk from the nonparametric conditional diffuser.

    The diffuser is conditioned on `input.observation` over the module-level
    `data` and `schedule`. The sampled chunk is offset by the agent's current
    position, since the training actions were stored relative to it.

    Returns:
        PolicyOutput whose action (also stored in `info`) is the offset chunk.
    """
    # Alternatives previously tried here: nonparametric.nw_local_poly_closed,
    # nonparametric.nw_local_poly wrapped in nw_diffuser, and
    # nonparametric.closest_diffuser(obs, data).
    diffuser = nonparametric.nw_cond_diffuser(
        input.observation,
        data,
        schedule,
        nonparametric.log_gaussian_kernel,
        0.01,  # presumably the kernel bandwidth -- confirm
    )
    # sample_chunk[1] supplies the structure of a single action chunk to sample.
    relative_chunk = schedule.sample(input.rng_key, diffuser, sample_chunk[1])
    current_pos = PushTEnv.observe(env, input.state).agent_pos
    absolute_chunk = relative_chunk + current_pos
    return PolicyOutput(absolute_chunk, info=absolute_chunk)
# Wrap the chunk-level policy with a ChunkingTransform built from the
# observation/action chunk lengths. The resulting `policy` is what
# roll_video hands to stanza.policy.rollout.
transform = ChunkingTransform(obs_length, action_length)
policy = transform.apply(chunk_policy)

In [6]:
from stanza.util.ipython import as_image, as_video

# Reset the environment and replicate the state obs_length times along a new
# leading axis so it can stand in for an observation history.
state = env.reset(jax.random.key(0))
state_batch = jax.tree.map(
    lambda leaf: jnp.repeat(leaf[None], obs_length, axis=0),
    state,
)
# Sample a single action chunk for this state.
output = chunk_policy(
    PolicyInput(
        jax.vmap(env.observe)(state_batch),
        state,
        rng_key=jax.random.key(42),
    )
).action
# Show the state with the sampled chunk overlaid (batch of one chunk, weight 1).
as_image(render(state, output[None], jnp.ones(1)))

HBox(children=(Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x01\x00\x00\x00\x01\x00\x08\x06\x00\x…

In [10]:
import stanza.policy
from stanza import canvas

# If this decorator is re-enabled, `num_samples` must be marked static
# (e.g. via static_argnames), because jax.random.split needs a concrete count.
#@partial(jax.jit, out_shardings=render_sharding)
def batch_policy(obs, state, rng_key, num_samples=8):
    """Sample several independent action chunks for the same observation and state.

    Args:
        obs: observation passed unchanged to every sample.
        state: environment state passed unchanged to every sample.
        rng_key: PRNG key, split into one sub-key per sample.
        num_samples: number of chunks to draw (defaults to 8).

    Returns:
        The `action` field of the batched policy output, with a leading axis
        of length `num_samples`.
    """
    keys = jax.random.split(rng_key, num_samples)
    # Only the rng key is mapped over; observation and state are broadcast.
    return jax.vmap(chunk_policy, in_axes=(PolicyInput(None, None, rng_key=0),))(
        PolicyInput(obs, state, rng_key=keys)
    ).action

def roll_video(rng_key):
    """Roll out the module-level `policy` for 100 steps and render every frame.

    Args:
        rng_key: PRNG key; split into a policy key and an initial-state key.

    Returns:
        The batch of 128x128 rendered frames, one per rollout state, each
        overlaid with that step's `info` entry drawn as a single action chunk.
    """
    policy_key, reset_key = jax.random.split(rng_key)
    rollout = stanza.policy.rollout(
        env.step,
        env.reset(reset_key),
        policy,
        policy_rng_key=policy_key,
        observe=env.observe,
        length=100,
        last_action=True,
    )

    def render_frame(state, action):
        # Add a leading batch axis of one chunk with weight 1.
        return render(state, action[None], jnp.ones((1,)), 128, 128)

    return jax.vmap(render_frame)(rollout.states, rollout.info)

@jax.jit
def generate_video(rng_key):
    """Render four independent rollouts and tile them into one video.

    Args:
        rng_key: PRNG key, split into one key per rollout.

    Returns:
        A sequence of frames; each is the `image_grid` of the four rollouts'
        frames at that step.
    """
    rollout_keys = jax.random.split(rng_key, 4)
    clips = jax.vmap(roll_video)(rollout_keys)
    # Map over axis 1 (the step axis) so each grid combines one frame per rollout.
    tile_frames = jax.vmap(
        lambda frames: stanza.canvas.image_grid(frames),
        in_axes=1,
        out_axes=0,
    )
    return tile_frames(clips)
# Generate the tiled rollout video for a fixed seed and display it as the cell output.
as_video(generate_video(jax.random.key(42)))

Video(value=b'\x00\x00\x00 ftypisom\x00\x00\x02\x00isomiso2avc1mp41\x00\x00\x00\x08free...')

In [8]:
#%timeit chunk_policy(PolicyInput(jax.vmap(env.observe)(state_batch), rng_key=jax.random.key(42))).action

In [9]:
from stanza.policy.ipython import StreamingInterface
from threading import Thread
from ipywidgets import Label, Button
import time


# Shared state between the button callback and the background loop:
# `executing` toggles between policy execution and mouse control;
# `action_queue` holds the actions still to be executed from the last sampled chunk.
executing = True
action_queue = []

# Status label; the loop overwrites its value with the reward and iteration count.
label = Label(value="Hello world")

# Toggle button; its description is flipped by button_click.
button = Button(description="Stop")
def button_click(_):
    """Toggle between executing the policy and manual (mouse) control.

    Flips the module-level `executing` flag, updates the button caption to
    name the next available action, and clears the pending action queue when
    execution resumes.
    """
    global executing, action_queue
    resuming = not executing
    button.description = "Stop" if resuming else "Execute"
    executing = resuming
    if resuming:
        # Drop any stale actions so execution restarts from a fresh sample.
        action_queue = []
# Register the toggle callback on the button.
button.on_click(button_click)

# Streaming image widget (constructed with 256, 256); the loop pushes frames
# to it and reads the mouse position from it.
interactive = StreamingInterface(256, 256)
def loop():
    """Run the interactive simulation forever, targeting about 30 frames per second.

    Each iteration steps the environment with either the next queued policy
    action (while `executing` and the queue is non-empty) or the current mouse
    position, updates the reward label, refreshes the observation history,
    re-samples action chunks when needed, and streams the rendered frame to
    the `interactive` widget. Never returns; intended to run in a daemon thread.
    """
    state = env.reset(jax.random.key(42))
    #state = dataset[0][0].state
    # Observation history: the initial state repeated obs_length times.
    state_batch = jax.tree.map(
        lambda x: jnp.repeat(x[None], obs_length, 0),
        state
    )
    key = jax.random.key(43)
    frame = 0

    action_chunks = None
    weights = None
    iterations = 0
    while True:
        frame_start = time.time()
        key, r = jax.random.split(key)
        if executing and len(action_queue) > 0:
            action = action_queue.pop(0)
        else:
            action = interactive.mouse_pos()

        prev_state = state
        state = env.step(state, action)
        reward = env.reward(prev_state, action, state)
        label.value = f"reward: {reward} {iterations}"
        # Shift the history by one step along the leading (time) axis and
        # write the newest state last. axis=0 is required: without it
        # jnp.roll shifts the flattened array, which would scramble
        # multi-dimensional leaves whenever obs_length > 1.
        state_batch = jax.tree.map(
            lambda x, s: jnp.roll(x, -1, axis=0).at[-1].set(s),
            state_batch, state
        )
        obs = jax.vmap(env.observe)(state_batch)

        if executing and len(action_queue) == 0:
            # Queue exhausted: sample new chunks, execute the first one and
            # highlight it (weight 1) over the others (weight 0.2).
            action_chunks = batch_policy(obs, state, r)
            weights = (0.2*jnp.ones(action_chunks.shape[0])).at[0].set(1)
            action_queue.extend(action_chunks[0])
            iterations += action_length
        elif not executing and frame % 30 == 0: # re-sample actions every 30 frames
            action_chunks = batch_policy(obs, state, r)
            weights = jnp.ones(action_chunks.shape[0])
        image = render(state, action_chunks, weights)

        interactive.update(image)
        # Sleep off the remainder of the 1/30 s frame budget, if any.
        elapsed = time.time() - frame_start
        time.sleep(max(0, 1/30 - elapsed))
        frame += 1

# Run the simulation loop in a background daemon thread so this cell returns
# immediately. No explicit stop hook is installed: threading.Thread has no
# stop() method, and daemon=True lets the interpreter exit without joining.
t = Thread(target=loop, daemon=True)
t.start()
display(label)
display(interactive)
display(button)

Label(value='Hello world')

HBox(children=(ImageStream(image=Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x01\x00\x00\x00\x01…

Button(description='Stop', style=ButtonStyle())