In [1]:
from stanza.runtime import setup
setup()

from stanza.data.sequence import SequenceData, Chunk
from stanza.env.mujoco.pusht import (
    PushTPosObs, PushTEnv, 
    PositionalControlTransform, PositionalObsTransform,
    KeypointObsTransform
)
from stanza.env import ImageRender
from stanza import canvas
from stanza.policy.transforms import Transform, ChainedTransform

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

from functools import partial

In [2]:
# Shard rendering and data over up to 8 devices along a single mesh axis 'x'.
# The shard count is the largest divisor of 8 that does not exceed the number of
# available devices (1, 2, 4 or 8): the rest of the notebook produces batches of
# 8 samples and truncates the dataset to a multiple of 8, so any divisor of 8
# splits those arrays evenly. Building the mesh unconditionally also keeps
# `render_sharding` / `data_sharding` defined on a single-device machine, where
# they were previously missing and later cells failed with NameError.
num_shards = min(len(jax.devices()), 8)
while 8 % num_shards != 0:
    num_shards -= 1

render_mesh = Mesh(jax.devices()[:num_shards], ('x',))
render_sharding = NamedSharding(render_mesh, PartitionSpec('x'))
data_mesh = Mesh(jax.devices()[:num_shards], ('x',))
data_sharding = NamedSharding(data_mesh, PartitionSpec('x'))

# Base PushT environment, wrapped so that actions go through
# PositionalControlTransform and observations through KeypointObsTransform.
env = PushTEnv()
env = ChainedTransform(
    [PositionalControlTransform(), KeypointObsTransform()]
).transform_env(env)

def draw(action_chunk, weight):
    """Build the canvas overlay for a single action chunk.

    One filled circle of radius 0.02 is created per entry of `action_chunk`
    (its leading dimension T gives the number of circles). The first color
    channel ramps from 0 towards 1 along the chunk, the second and third are
    zero, and the fourth channel is `weight` for every circle.

    The stacked circles are returned after applying a canvas transform with
    translation (1, -1) and scale (128, -128) -- presumably mapping
    environment coordinates onto the 256x256 image; TODO confirm.
    """
    num_points = action_chunk.shape[0]
    ramp = jnp.arange(num_points) / num_points
    zeros = jnp.zeros(num_points)
    fourth_channel = weight * jnp.ones(num_points)
    # one 4-channel color per circle -> shape (T, 4)
    colors = jnp.stack([ramp, zeros, zeros, fourth_channel], axis=-1)

    radii = 0.02 * jnp.ones(num_points)
    dots = canvas.fill(canvas.circle(action_chunk, radii), color=colors)
    dots = canvas.stack_batch(dots)
    return canvas.transform(dots, translation=(1, -1), scale=(128, -128))

@partial(jax.jit, out_shardings=render_sharding)
def render(state, action_chunks, weights):
    """Render `state` as a 256x256 image, optionally overlaying action chunks.

    When both `action_chunks` and `weights` are given, `draw` is vmapped over
    their leading axis and the resulting circles are painted on top of the
    environment image. If either is None the plain image is returned.
    """
    frame = env.render(ImageRender(256, 256), state)
    if action_chunks is None or weights is None:
        return frame
    overlay = canvas.stack_batch(jax.vmap(draw)(action_chunks, weights))
    return canvas.paint(frame, overlay)

In [3]:
from stanza.datasets.pusht import load_chen_pusht_data
# PushT demonstration data.
dataset = load_chen_pusht_data()

# Each training window holds `obs_length` observations followed by
# `action_length` action targets.
obs_length, action_length = 1, 8
def map_chunks(chunk : Chunk):
    """Turn one trajectory window into an (observation, action) training pair.

    Every state in the window is converted to an observation via `env.observe`.
    The first `obs_length` observations form the conditioning input; the agent
    positions of the last `action_length` observations are used as the action
    targets (i.e. future positions stand in for the actions).
    """
    observations = jax.vmap(env.observe)(chunk.chunk.state)

    def leading(leaf):
        return leaf[:obs_length]

    def trailing(leaf):
        return leaf[-action_length:]

    obs = jax.tree_util.tree_map(leading, observations)
    # Future agent positions are the targets. To train on the recorded actions
    # instead, slice `chunk.chunk.action` with `trailing` here.
    action = jax.tree_util.tree_map(trailing, observations.agent_pos)
    return obs, action
# Slice the dataset into windows of obs_length + action_length steps, convert
# each window with map_chunks and gather everything into a single pytree.
data = (
    dataset
    .chunk(obs_length + action_length)
    .map(map_chunks)
    .as_pytree()
)
# Drop the trailing remainder so the sample count is a multiple of 8 and the
# leading axis can be sharded evenly.
data_len = data[1].shape[0] - data[1].shape[0] % 8
data = jax.tree_util.tree_map(lambda leaf: leaf[:data_len], data)
data = jax.device_put(data, data_sharding)
#jax.debug.visualize_array_sharding(data[1][:,:,0])
#jax.debug.visualize_array_sharding(data[1][:,:,1])
#jax.debug.visualize_array_sharding(data[1][210])

# First (obs, action) sample; its action part is later passed to schedule.sample.
sample_chunk = jax.tree_util.tree_map(lambda leaf: leaf[0], data)

print("Loaded data", jax.tree_util.tree_map(lambda leaf: leaf.shape, data))

Loaded data (PushTKeypointObs(agent_pos=(24000, 1, 2), block_pos=(24000, 1, 2), block_end=(24000, 1, 2)), (24000, 8, 2))


In [4]:
from stanza.diffusion import DDPMSchedule, nonparametric
from stanza.policy import PolicyInput, PolicyOutput
from stanza.policy.transforms import ChunkTransform
from jax.random import PRNGKey

# DDPM schedule built with the squared-cosine ("squaredcos_cap_v2") constructor and
# prediction_type="sample"; 16 is presumably the number of diffusion steps -- TODO confirm.
schedule = DDPMSchedule.make_squaredcos_cap_v2(16, prediction_type="sample")

@jax.jit
def chunk_policy(input: PolicyInput) -> PolicyOutput:
    """Sample one action chunk for the given observation.

    A conditional nonparametric diffuser is built from the demonstration `data`
    using a log-Gaussian kernel with parameter 0.05, and `schedule.sample` is
    run with it, the caller's RNG key and the action part of `sample_chunk`.
    """
    # Alternative estimators/diffusers that were tried:
    #estimator = nonparametric.nw_local_poly_closed(data, schedule, 1, 0.001)
    #estimator = nonparametric.nw_local_poly(PRNGKey(42), data, schedule, 1, 
    #                          nonparametric.log_gaussian_kernel, nonparametric.log_gaussian_kernel, 
    #                          0.02, 0.001, 16)
    #diffuser = nonparametric.nw_diffuser(obs, estimator)
    #diffuser = nonparametric.closest_diffuser(obs, data)
    diffuser = nonparametric.nw_cond_diffuser(
        input.observation,
        data,
        schedule,
        nonparametric.log_gaussian_kernel,
        0.05,
    )
    action_example = sample_chunk[1]
    sampled = schedule.sample(input.rng_key, diffuser, action_example)
    return PolicyOutput(sampled)
# Wrap the chunk-level policy with a ChunkTransform configured with the same
# observation/action chunk lengths used to build the dataset; presumably this
# yields a policy that can be queried step by step -- TODO confirm.
transform = ChunkTransform(obs_length, action_length)
policy = transform.transform_policy(chunk_policy)

In [5]:
from stanza.util.ipython import as_image

# One-shot preview: reset the environment, sample a single action chunk and
# draw it on top of the rendered initial state.
state = env.reset(jax.random.key(0))

# Replicate the reset state obs_length times along a new leading axis so it
# has the same layout as an observation history.
state_batch = jax.tree_util.tree_map(
    lambda leaf: jnp.repeat(leaf[None], obs_length, 0),
    state,
)

output = chunk_policy(
    PolicyInput(
        jax.vmap(env.observe)(state_batch),
        rng_key=jax.random.key(42),
    )
).action

# render expects a batch of chunks, so add a leading axis of size 1 and give
# that single chunk a weight of 1.
as_image(render(state, jnp.expand_dims(output, 0), jnp.ones(1)))

HBox(children=(Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x01\x00\x00\x00\x01\x00\x08\x06\x00\x…

In [6]:
#%timeit chunk_policy(PolicyInput(jax.vmap(env.observe)(state_batch), rng_key=jax.random.key(42))).action 

In [7]:
from stanza.policy.ipython import StreamingInterface
from threading import Thread
from ipywidgets import Label, Button
import time


@partial(jax.jit, out_shardings=render_sharding)
def batch_policy(obs, rng_key):
    """Sample 8 action chunks for the same observation.

    `rng_key` is split into 8 keys and `chunk_policy` is vmapped over them;
    the observation is shared (not mapped) across the batch. Returns the
    stacked `action` field of the batched policy output.
    """
    sample_keys = jax.random.split(rng_key, 8)
    # map over the rng_key field only; the observation is broadcast
    batched_policy = jax.vmap(
        chunk_policy,
        in_axes=(PolicyInput(None, rng_key=0),),
    )
    outputs = batched_policy(PolicyInput(obs, rng_key=sample_keys))
    return outputs.action

# State shared between the button callback and the background loop.
executing = True    # True: the loop executes queued policy actions; False: it follows the mouse
action_queue = []   # pending actions; the loop pops from the front and refills it when empty

# Status label; the loop overwrites its text with the current reward each frame.
label = Label(value="Hello world")

# Toggle button; its caption is switched between "Stop" and "Execute" by button_click.
button = Button(description="Stop")
def button_click(_):
    """Toggle between policy execution and manual control.

    Flips the global `executing` flag and updates the button caption to the
    action the next click will perform. When switching back to execution the
    pending `action_queue` is replaced with an empty list.
    """
    global executing, action_queue
    resuming = not executing
    button.description = "Stop" if resuming else "Execute"
    executing = resuming
    if resuming:
        action_queue = []
# Register the toggle callback on the button.
button.on_click(button_click)

# 256x256 streaming widget: the loop pushes rendered frames to it via update()
# and reads the mouse position from it via mouse_pos().
interactive = StreamingInterface(256, 256)
def _push_history(history, latest):
    """Shift `history` one step back along axis 0 and store `latest` in the last slot."""
    # axis=0 matters: without an axis jnp.roll shifts the *flattened* array,
    # which would mix values across feature dimensions whenever the history
    # holds more than one entry (obs_length > 1).
    return jnp.roll(history, -1, axis=0).at[-1].set(latest)


def loop():
    """Run the interactive simulation/render loop forever.

    Each frame the environment is stepped with either the next queued policy
    action (while `executing` and the queue is non-empty) or the current mouse
    position, the reward is shown in `label`, the observation history is
    updated, and the rendered frame is pushed to `interactive`.

    While executing, a fresh batch of action chunks is sampled whenever the
    queue runs dry; the first chunk is queued for execution (each action
    repeated several frames) and drawn at full weight, the others at 0.2.
    While not executing, chunks are re-sampled periodically for display only.

    The loop never returns; it is meant to run on a daemon thread.
    """
    fps = 30              # target frame rate
    resample_every = 30   # frames between re-sampling while not executing
    action_repeat = 3     # consecutive frames each queued action is applied

    state = env.reset(jax.random.key(42))
    #state = dataset[0][0].state
    # observation history: the reset state repeated obs_length times on a new leading axis
    state_batch = jax.tree_util.tree_map(
        lambda x: jnp.repeat(x[None], obs_length, 0),
        state
    )
    key = jax.random.key(42)
    frame = 0

    action_chunks = None
    weights = None
    while True:
        frame_start = time.monotonic()
        key, r = jax.random.split(key)
        if executing and len(action_queue) > 0:
            action = action_queue.pop(0)
        else:
            action = interactive.mouse_pos()

        prev_state = state
        state = env.step(state, action)
        reward = env.reward(prev_state, action, state)
        label.value = f"reward: {reward}"
        state_batch = jax.tree_util.tree_map(_push_history, state_batch, state)
        obs = jax.vmap(env.observe)(state_batch)

        if executing and len(action_queue) == 0:
            action_chunks = batch_policy(obs, r)
            # highlight the chunk that is about to be executed
            weights = (0.2*jnp.ones(action_chunks.shape[0])).at[0].set(1)
            for a in action_chunks[0]:
                action_queue.extend([a] * action_repeat)
        elif not executing and frame % resample_every == 0:
            action_chunks = batch_policy(obs, r)
            weights = jnp.ones(action_chunks.shape[0])
        image = render(state, action_chunks, weights)

        interactive.update(image)
        # sleep off whatever is left of this frame's time budget
        elapsed = time.monotonic() - frame_start
        time.sleep(max(0, 1/fps - elapsed))
        frame = frame + 1

# Run the interaction loop on a background daemon thread so the cell returns
# immediately and the widgets stay responsive; a daemon thread does not keep
# the interpreter alive at shutdown.
# NOTE: the previous `t.__del__ = lambda: t.stop()` hook was removed. It never
# ran (finalizers are looked up on the type, not the instance) and Thread has
# no stop() method. `loop` has no exit condition, so re-running this cell
# starts an additional loop thread; restart the kernel to stop old ones.
t = Thread(target=loop, daemon=True)
t.start()
display(label)
display(interactive)
display(button)

Label(value='Hello world')

HBox(children=(ImageStream(image=Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x01\x00\x00\x00\x01…

Button(description='Stop', style=ButtonStyle())