In [1]:
# Project runtime initialisation. It is invoked before the remaining imports
# (presumably so it can configure the environment first — TODO confirm).
from stanza.runtime import setup
setup()

from stanza.data.sequence import SequenceData, Chunk
from stanza.envs.pusht import PushTEnv, PositionControlTransform, PositionObsTransform
from stanza.envs import ImageRender
from stanza import canvas
from stanza.policy.transforms import Transform, ChainedTransform

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

from functools import partial

In [2]:
# Build a 1-D device mesh (axis name 'x') over at most the first 8 devices,
# and a sharding that splits an array's leading axis across that mesh.
#
# This is done unconditionally: the jitted functions in this notebook pass
# `sharding` as `out_shardings`, so defining it only when several devices are
# present left the name undefined (NameError) on a single-device machine.
# A mesh containing one device is valid and simply keeps everything on it.
mesh = Mesh(jax.devices()[:8], ('x',))
sharding = NamedSharding(mesh, PartitionSpec('x'))

# PushT environment wrapped so that the position-control transform and then
# the position-observation transform are applied (in that order) by a single
# chained transform.
env = ChainedTransform([
    PositionControlTransform(),
    PositionObsTransform(),
]).transform_env(PushTEnv())

def draw(action_chunk):
    """Turn one action chunk into a stacked canvas geometry of filled circles.

    One circle of radius 0.02 is created per entry along the leading axis of
    ``action_chunk``. Circle ``i`` is coloured ``(i / T, 0, 0)`` where ``T`` is
    the chunk length, so later steps are drawn in a stronger red.

    The result is translated by ``(1, -1)`` and scaled by ``(128, -128)``
    (presumably mapping a [-1, 1] y-up frame onto a 256x256 image — confirm
    against the canvas conventions).
    """
    num_steps = action_chunk.shape[0]

    # Per-step RGB: only the red channel varies along the chunk.
    red = jnp.arange(num_steps) / num_steps
    zero = jnp.zeros(num_steps)
    colors = jnp.stack((red, zero, zero), axis=-1)

    radii = 0.02 * jnp.ones(num_steps)
    dots = canvas.fill(canvas.circle(action_chunk, radii), color=colors)
    dots = canvas.stack_batch(dots)
    return canvas.transform(dots, translation=(1, -1), scale=(128, -128))

@partial(jax.jit, out_shardings=sharding)
def render(state, action_chunks):
    """Render ``state`` as a 256x256 image with action chunks painted on top.

    ``action_chunks`` carries a leading batch axis; ``draw`` is vmapped over
    it and the per-chunk geometries are stacked into a single overlay before
    being painted onto the environment image.
    """
    background = env.render(ImageRender(256, 256), state)
    per_chunk = jax.vmap(draw)(action_chunks)
    overlay = canvas.stack_batch(per_chunk)
    return canvas.paint(background, overlay)

In [3]:
# Chunk geometry: each example pairs `obs_length` leading steps (used as the
# observation) with `action_length` trailing steps (used as the target).
obs_length, action_length = 1, 6

# NOTE(review): the ".pkl" suffix suggests a pickle file — only load this
# from a trusted source.
data = SequenceData.load("pusht_data.pkl")
def map_chunks(chunk : Chunk):
    """Convert one raw chunk into an ``(observation, action)`` training pair.

    Every state in the chunk is passed through ``env.observe``. The first
    ``obs_length`` observations form the conditioning input, and the agent
    positions of the last ``action_length`` observations are used as the
    action targets. The actions recorded in the chunk are ignored.
    """
    raw_states, _ = chunk.chunk
    observed = jax.vmap(env.observe)(raw_states)
    obs = jax.tree_util.tree_map(lambda leaf: leaf[:obs_length], observed)
    # Future agent positions stand in for the actions. The alternative of
    # slicing the recorded actions the same way is intentionally not used.
    action = observed.agent_pos[-action_length:]
    return obs, action
# Slice the sequences into windows of obs_length + action_length steps, map
# each window to an (obs, action) pair and materialise the result as a pytree.
window = obs_length + action_length
data = data.chunk(window).map(map_chunks).as_pytree()

# First example of the dataset, kept as a structural template for sampling.
sample_chunk = jax.tree_util.tree_map(lambda leaf: leaf[0], data)

shapes = jax.tree_util.tree_map(lambda leaf: leaf.shape, data)
print("Loaded data", shapes)

Loaded data (PushTPosObs(agent_pos=(4192, 1, 2), block_pos=(4192, 1, 2), block_rot=(4192, 1)), (4192, 6, 2))


In [4]:
from stanza.diffusion import DDPMSchedule, nonparametric
from stanza.policy import PolicyInput, PolicyOutput
from stanza.policy.transforms import ChunkTransform

# DDPM noise schedule (squared-cosine "cap v2" variant) configured with
# prediction_type="sample".
# NOTE(review): the positional 16 is presumably the number of diffusion
# steps — confirm against DDPMSchedule.make_squaredcos_cap_v2.
schedule = DDPMSchedule.make_squaredcos_cap_v2(16, prediction_type="sample")

@jax.jit
def chunk_policy(input: PolicyInput) -> PolicyOutput:
    """Sample one action chunk for the given observation.

    A non-parametric conditional diffuser is built from the observation, the
    module-level ``data`` and ``schedule``, using a log-Gaussian kernel with
    parameter 0.01. ``schedule.sample`` is then run with ``input.rng_key``
    and the action part of ``sample_chunk`` as the structural template.
    """
    # Alternative diffuser: nonparametric.closest_diffuser(input.observation, data)
    denoiser = nonparametric.nw_cond_diffuser(
        input.observation, data, schedule,
        nonparametric.log_gaussian_kernel, 0.01,
    )
    sampled = schedule.sample(input.rng_key, denoiser, sample_chunk[1])
    return PolicyOutput(sampled)
# Wrap the chunk-level policy with a ChunkTransform built from the same
# obs_length / action_length that were used to build the dataset.
# NOTE(review): presumably this yields a per-step policy that buffers
# observations and plays back action chunks — confirm against ChunkTransform.
transform = ChunkTransform(obs_length, action_length)
policy = transform.transform_policy(chunk_policy)

In [5]:
from stanza.util.ipython import as_image

# Start from a fixed initial state and tile it `obs_length` times along a new
# leading axis so it has the shape of an observation history.
state = env.reset(jax.random.key(0))
state_batch = jax.tree_util.tree_map(
    lambda leaf: jnp.repeat(leaf[None], obs_length, axis=0), state
)

# Sample a single action chunk with a fixed key.
policy_input = PolicyInput(
    jax.vmap(env.observe)(state_batch),
    rng_key=jax.random.key(42),
)
output = chunk_policy(policy_input).action

# `render` expects a batch of chunks, so add a leading axis of size one.
as_image(render(state, output[None, ...]))

(256, 256, 4)


HBox(children=(Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x01\x00\x00\x00\x01\x00\x08\x06\x00\x…

In [6]:
from stanza.policy.ipython import StreamingInterface
from threading import Thread
from ipywidgets import Label
import time


@partial(jax.jit, out_shardings=sharding)
def batch_policy(obs, rng_key):
    """Draw 8 independent action chunks for a single observation.

    ``rng_key`` is split into 8 keys and ``chunk_policy`` is vmapped over the
    key axis only; the observation is shared by every sample. Returns the
    stacked ``.action`` outputs.
    """
    sample_keys = jax.random.split(rng_key, 8)
    # in_axes mirrors PolicyInput: broadcast the observation, map the keys.
    axes = PolicyInput(None, rng_key=0)
    batched = jax.vmap(chunk_policy, in_axes=(axes,))
    return batched(PolicyInput(obs, rng_key=sample_keys)).action

# Text widget with placeholder text; `loop` overwrites its value with the
# latest reward on every frame.
label = Label(value="Hello world")
# Streaming image widget that `loop` reads the mouse position from and pushes
# rendered frames to (256, 256 presumably being the frame size, matching the
# ImageRender(256, 256) used by `render` — confirm).
interactive = StreamingInterface(256, 256)
def loop():
    """Run the interactive simulate / predict / render loop forever.

    Each frame: read the mouse position as the action, step the environment,
    show the reward in ``label``, push the new state into the rolling history
    of the last ``obs_length`` states, sample action chunks with
    ``batch_policy``, draw the first sample on top of the rendered state and
    send the image to ``interactive``. Frames are paced to roughly 30 per
    second. The function never returns; it is meant to run on a daemon thread.
    """
    state = env.reset(jax.random.key(42))
    # History buffer: the initial state repeated obs_length times on axis 0.
    state_batch = jax.tree_util.tree_map(
        lambda x: jnp.repeat(x[None], obs_length, 0),
        state
    )
    # The policy is sampled with the same key on every frame, so the drawn
    # chunk is a deterministic function of the observation.
    # NOTE(review): the original split a fresh key per frame but never used
    # it; pass a per-frame split key here if fresh samples are wanted.
    policy_key = jax.random.key(42)
    frame_period = 1/30
    while True:
        # Monotonic clock: elapsed-time measurement must not be affected by
        # wall-clock adjustments.
        frame_start = time.monotonic()
        action = interactive.mouse_pos()
        prev_state = state
        state = env.step(state, action)
        reward = env.reward(prev_state, action, state)
        label.value = f"reward: {reward}"
        # Shift the history one step along the time axis and write the newest
        # state into the last slot. axis=0 is required: without it jnp.roll
        # flattens each leaf first, which shifts by a single scalar instead of
        # a whole time step whenever obs_length > 1.
        state_batch = jax.tree_util.tree_map(
            lambda x, s: jnp.roll(x, -1, axis=0).at[-1].set(s),
            state_batch, state
        )
        obs = jax.vmap(env.observe)(state_batch)
        output = batch_policy(obs, policy_key)
        # Only the first sampled chunk is drawn (kept as a batch of one).
        output = jax.tree_util.tree_map(lambda x: x[:1], output)
        image = render(state, output)
        interactive.update(image)
        elapsed = time.monotonic() - frame_start
        time.sleep(max(0, frame_period - elapsed))
# Run the loop on a background daemon thread so the cell returns immediately
# and the thread does not keep the interpreter alive at shutdown.
#
# There is deliberately no `t.__del__ = ...` hook here: Python looks up
# special methods on the type, not the instance, so such an assignment is
# never invoked, and threading.Thread has no stop() method anyway. The thread
# runs until the kernel exits; re-running this cell starts an additional one.
t = Thread(target=loop, daemon=True)
t.start()
display(label)
interactive

Label(value='Hello world')

HBox(children=(ImageStream(image=Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x01\x00\x00\x00\x01…