In [1]:
import jax
import jax.numpy as jnp

from openpi.models import model as _model
from openpi.policies import aloha_policy
from openpi.policies import policy_config as _policy_config
from openpi.shared import download
from openpi.shared import nnx_utils
from openpi.training import config as _config
from openpi.training import data_loader as _data_loader


# Policy inference

The following example shows how to create a policy from a checkpoint and run inference on a dummy example.

In [2]:
# Look up the "pi0_aloha_sim" training config and resolve the matching checkpoint directory.
# NOTE(review): `maybe_download` presumably fetches the S3 assets only when they are not already
# available locally -- confirm against `openpi.shared.download`.
config = _config.get_config("pi0_aloha_sim")
checkpoint_dir = download.maybe_download("s3://openpi-assets/checkpoints/pi0_aloha_sim")

# Create a trained policy from the config and the checkpoint directory.
policy = _policy_config.create_trained_policy(config, checkpoint_dir)

# Run inference on a dummy example. This example corresponds to observations produced by the Aloha runtime.
example = aloha_policy.make_aloha_example()
result = policy.infer(example)

# Delete the policy to free up memory. `result` was computed above, so it stays usable.
del policy

# Only the shape of the "actions" entry is shown; the recorded run printed (50, 14).
# NOTE(review): presumably (action horizon, action dimension) -- confirm against the policy's output spec.
print(result["actions"].shape)

(50, 14)


# Model inference

The following example shows how to create a live model from a checkpoint and run inference on a batch of training data.


In [3]:
# `config` and `checkpoint_dir` are re-created here (same values as in the policy inference cell),
# so this cell only depends on the imports cell.
config = _config.get_config("pi0_aloha_sim")
checkpoint_dir = download.maybe_download("s3://openpi-assets/checkpoints/pi0_aloha_sim")
# PRNG key with a fixed seed. The same key is passed to both sampling calls below, so the
# difference printed at the end of the cell does not come from different random draws.
key = jax.random.key(0)

# Create a model from the checkpoint. The parameters are read from the "params" entry of the checkpoint directory.
# NOTE: We are converting the model weights to bfloat16 to reduce memory usage and speed up inference.
model = config.model.load(_model.restore_params(checkpoint_dir / "params", dtype=jnp.bfloat16))

# Load a single batch of data. This is the same data that will be used during training.
# NOTE: In order to make this example self-contained, we are skipping the normalization step
# since it requires the normalization statistics to be generated using `compute_norm_stats`.
loader = _data_loader.create_data_loader(config, num_batches=1, skip_norm_stats=True)
# The batch unpacks into two values; only `obs` is used in this cell (`act` is not read again here).
obs, act = next(iter(loader))

# Sample actions from the model by calling the method directly (no explicit compilation step).
actions = model.sample_actions(key, obs)

# Alternatively, we can compile the model using `nnx_utils.module_jit` to speed up inference.
sample_actions = nnx_utils.module_jit(model.sample_actions)
actions_jit = sample_actions(key, obs)

# Delete the model and sample_actions to free up memory.
# `actions` and `actions_jit` were computed above, so they are still available for the comparison below.
del model
del sample_actions

# NOTE: Compiled and uncompiled results can be slightly different.
# The recorded run printed a max absolute difference of 0.032226562 and a shape of (32, 50, 24) for both results.
# NOTE(review): the shape is presumably (batch, action horizon, action dimension) -- confirm against the model config.
print("max diff:", jnp.max(jnp.abs(actions - actions_jit)))
print(actions.shape)
print(actions_jit.shape)

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

Fetching 106 files:   0%|          | 0/106 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/50 [00:00<?, ?it/s]

max diff: 0.032226562
(32, 50, 24)
(32, 50, 24)
