In [1]:
# Run stanza's runtime setup once at the start of the session.
# NOTE(review): presumably configures the JAX/logging environment for the
# cells below — confirm against stanza.runtime.setup.
from stanza.runtime import setup
setup()

In [2]:
import jax.random

from stanza.env import ImageRender
from stanza.env.mujoco.pusht import (
    PositionalControlTransform,  # not used in the cells shown here; kept for later use
    PushTEnv,
)

# One Push-T environment instance, reused by every cell that follows
# (reset, observe, reward and render all go through `env`).
env = PushTEnv()

In [3]:
# Show the MuJoCo (MJCF) model string defined in the pusht module.
# `print` is used instead of a bare expression so the newlines render
# instead of appearing as escaped "\n" in the string repr.
from stanza.env.mujoco.pusht import XML
print(XML)


<mujoco>
<option timestep="0.005"/>
<worldbody>
    # The manipulator agent body
    <body pos="0.5 0.5 0" name="agent">
        # TODO: Replace with cylinder when MJX supports
        <geom type="sphere" size="0.05952" pos="0 0 0.05952" mass="0.1" rgba="0.1 0.1 0.9 1"/>
        <joint type="slide" axis="1 0 0" damping="0.1" stiffness="0" ref="0.5"/>
        <joint type="slide" axis="0 1 0" damping="0.1" stiffness="0" ref="0.5"/>
    </body>
    # The block body
    <body pos="-0.5 -0.5 0" name="block">
        # The horizontal box
        <geom type="box" size="0.2381 0.05952380952380952 0.5" 
            pos="0 -0.05952 0.5" mass="0.03" rgba="0.467 0.533 0.6 1"/>
        # The vertical box
        <geom type="box" size="0.05952 0.1786 0.5"
            pos="0 -0.2976190476190476 0.5" mass="0.03" rgba="0.467 0.533 0.6 1"/>

        <joint type="slide" axis="1 0 0" damping="5" stiffness="0" ref="-0.5"/>
        <joint type="slide" axis="0 1 0" damping="5" stiffness="0" ref="-0.5"/>
   

In [4]:
# Fixed PRNG seed so the state returned by env.reset (and every output
# derived from it below) is reproducible on Restart & Run All.
RESET_SEED = 47
s = env.reset(jax.random.key(RESET_SEED))

In [5]:
# Inspect the observation for the freshly reset state `s`...
initial_obs = env.observe(s)
print(initial_obs)
# ...and its reward. Only the state is supplied; the first two arguments
# are passed as None (NOTE(review): presumably not needed for the reward
# computation — confirm against PushTEnv.reward).
env.reward(None, None, s)

PushTObs(agent_pos=Array([0.4887165, 0.5627138], dtype=float32), agent_vel=Array([0., 0.], dtype=float32), block_pos=Array([-0.2641813, -0.379523 ], dtype=float32), block_vel=Array([0., 0.], dtype=float32), block_rot=Array(0.00401996, dtype=float32), block_rot_vel=Array(0., dtype=float32))


Array(0., dtype=float32)

In [6]:
from stanza.util.ipython import as_image

# Render the reset state with a 256 x 256 ImageRender and display it inline.
render_cfg = ImageRender(256, 256)
frame = env.render(render_cfg, s)
as_image(frame)

HBox(children=(Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x01\x00\x00\x00\x01\x00\x08\x06\x00\x…

In [7]:
from stanza.datasets.pusht import load_chen_pusht_data
from stanza.dataclasses import replace
import jax.numpy as jnp

# Pull the simulator state stored in the first element of the first entry
# of the dataset, then render it with the same environment.
demo_step = load_chen_pusht_data()[0][0]
state = demo_step.state
print(state)
as_image(env.render(ImageRender(256, 256), state))
# Disabled experiment: keep q[:2] and overwrite the remaining q entries with
# [0, 0, -pi/4] (presumably block x, y and rotation — confirm joint order):
# as_image(env.render(
#     ImageRender(256, 256),
#     replace(state, q=jnp.concatenate((state.q[:2], jnp.array([0, 0, -jnp.pi / 4])))),
# ))

PushTState(q=Array([-0.13492064,  0.63095236, -0.15476194, -0.1428571 , -3.0079994 ],      dtype=float32), qd=Array([0., 0., 0., 0., 0.], dtype=float32))


HBox(children=(Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x01\x00\x00\x00\x01\x00\x08\x06\x00\x…

In [8]:
from stanza.datasets.util import cache_path
from stanza.util.ipython import as_image
from numpy import uint8
import zarr

# Cached Push-T data: a zarr store packed into a zip archive.
zip_path = cache_path("pusht", "pusht_data.zarr.zip")
# Open read-only. zarr.open defaults to mode="a" (read/write), which opens
# the zip archive for appending: unnecessary for inspection and it fails
# if the cache file is not writable.
with zarr.open(zip_path, mode="r") as zf:
    # Rows [0, episode_ends[0]) — presumably the first episode, assuming
    # episode_ends holds exclusive end indices.
    first_episode_end = zf["meta/episode_ends"][0]
    images = zf["data/img"][0:first_episode_end].astype(uint8)
as_image(images[0])

HBox(children=(Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00`\x00\x00\x00`\x08\x02\x00\x00\x00…