In [5]:
import dataclasses

import jax

from openpi.models import model as _model
from openpi.policies import droid_policy
from openpi.policies import libero_policy
from openpi.policies import policy_config as _policy_config
from openpi.shared import download
from openpi.training import config as _config
from openpi.training import data_loader as _data_loader

# Policy inference

The following example shows how to create a policy from a checkpoint and run inference on a dummy example.

In [7]:
# Load the pi05 LIBERO training config and fetch its checkpoint; `checkpoint_dir` is
# the local checkpoint location consumed by `create_trained_policy` below.
config = _config.get_config("pi05_libero")
checkpoint_dir = download.maybe_download("gs://openpi-assets/checkpoints/pi05_libero")

# Create a trained policy.
policy = _policy_config.create_trained_policy(config, checkpoint_dir)

# Run inference on a dummy example. The example must match the policy's config: this
# cell loads a LIBERO checkpoint, so it uses a LIBERO example. For a DROID config,
# build the input with `droid_policy.make_droid_example()` instead.
example = libero_policy.make_libero_example()
result = policy.infer(example)

# Delete the policy to free up memory.
del policy

print("Actions:", result["actions"])
print("Actions shape:", result["actions"].shape)

x_next: [[[-2.24616671e+00 -1.79153991e+00  1.52077436e-01 -3.24375391e-01
   -6.99718177e-01 -1.02913773e+00 -9.71461773e-01  2.68479258e-01
   -1.17869997e+00  1.91731191e+00 -1.71628013e-01  8.67332518e-01
   -1.17024076e+00 -6.73010230e-01 -3.35156590e-01  3.98357093e-01
   -1.07164717e+00 -6.14736527e-02 -8.59867454e-01 -1.76345134e+00
   -9.96619225e-01 -2.90022790e-01  1.19751740e+00  7.31182158e-01
   -1.04212475e+00 -4.89774525e-01  7.27263987e-01  1.48956525e+00
   -3.68538380e-01  4.88530472e-02 -5.83446741e-01 -1.59083784e+00]
  [-2.39484310e-01  8.75555813e-01 -1.20171034e+00 -6.48959517e-01
   -2.17880487e+00 -1.48568797e+00 -3.68761569e-01 -1.54406083e+00
   -1.78951097e+00 -1.48470676e+00 -2.48204142e-01 -6.47193015e-01
    3.53990525e-01 -8.46832991e-01  1.72522593e+00  5.00060022e-01
   -7.17024505e-02  6.89931273e-01  8.24652016e-01 -1.01886320e+00
    6.22029185e-01 -1.23589694e+00  1.67558178e-01 -1.44404435e+00
   -4.26603466e-01  4.48907197e-01 -6.74657762e-01  8

# Working with a live model


The following example shows how to load a live model from a checkpoint and compute the training loss. We first demonstrate this with fake data.


In [None]:
# Load the pi0 ALOHA-sim config and fetch its checkpoint.
config = _config.get_config("pi0_aloha_sim")

checkpoint_dir = download.maybe_download("gs://openpi-assets/checkpoints/pi0_aloha_sim")

# `compute_loss` takes a PRNG key; a fixed seed keeps the printed results reproducible.
key = jax.random.key(0)

# Create a model from the parameters stored under the checkpoint's "params" directory.
model = config.model.load(_model.restore_params(checkpoint_dir / "params"))

# We can create fake observations and actions to test the model.
obs, act = config.model.fake_obs(), config.model.fake_act()

# Compute the training loss on the fake batch (this does not sample actions).
loss = model.compute_loss(key, obs, act)
print("Loss shape:", loss.shape)

Next, we create a data loader and compute the loss on a real batch of training data.

In [None]:
# Reduce the batch size to reduce memory usage.
config = dataclasses.replace(config, batch_size=2)

# Load a single batch of data. This is the same data that will be used during training.
# NOTE: In order to make this example self-contained, we are skipping the normalization step
# since it requires the normalization statistics to be generated using `compute_norm_stats`.
loader = _data_loader.create_data_loader(config, num_batches=1, skip_norm_stats=True)
obs, act = next(iter(loader))

# Compute the training loss on the real batch, reusing the model and PRNG key from
# the previous cell (this does not sample actions).
loss = model.compute_loss(key, obs, act)

# Delete the model to free up memory. NOTE: because `model` is deleted here, re-running
# this cell requires re-running the previous cell first.
del model

print("Loss shape:", loss.shape)