In [1]:
import dataclasses
import gc

import jax

from openpi.models import model as _model
from openpi.policies import rby1_policy
from openpi.policies import policy_config as _policy_config
from openpi.shared import download
from openpi.training import config as _config
from openpi.training import data_loader as _data_loader


def _cleanup():
    """Best-effort cleanup after exceptions.

    Notes:
    - JAX/XLA may keep GPU memory reserved for reuse; this still helps release Python refs
      and clear compilation caches.
    - If you're running a long notebook session and hit OOM, a kernel restart is the
      most reliable way to fully reset GPU memory.
    """

    # Drop Python references.
    gc.collect()

    # Clear JAX compilation/executable caches.
    try:
        jax.clear_caches()
    except Exception:
        pass

    # If torch is present (e.g., PyTorch checkpoints), clear its CUDA cache.
    try:
        import torch

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except Exception:
        pass

  import pynvml  # type: ignore[import]


# Policy inference

The following example shows how to create a policy from a checkpoint and run inference on a dummy example.

In [16]:
# Look up the training config and resolve the checkpoint directory for the policy.
config = _config.get_config("pi05_rby1")
checkpoint_dir = download.maybe_download("../checkpoints/pi05_rby1/pi05_rby1_test/200")
# Alternative kept from the original author (bypasses the download helper):
# checkpoint_dir = "../checkpoints/pi05_rby1/pi05_rby1_test/200"

# Bind both names before entering `try` so the `finally` clause can always
# delete them, even when policy creation or inference raises part-way through.
policy = result = None

try:
    # Build the policy from the checkpoint.
    policy = _policy_config.create_trained_policy(config, checkpoint_dir)

    # Feed it a single dummy example and report the shape of the predicted actions.
    example = rby1_policy.make_rby1_example()
    result = policy.infer(example)

    print("Actions shape:", result["actions"].shape)
finally:
    # Drop this cell's references to the policy and its output, then run the
    # shared best-effort cleanup. This runs whether or not an exception occurred.
    del policy, result
    _cleanup()

Actions shape: (50, 16)


# Working with a live model


The following example shows how to create a live model from a checkpoint and compute training loss. First, we are going to demonstrate how to do it with fake data.


In [17]:
config = _config.get_config("pi05_rby1")

checkpoint_dir = download.maybe_download("../checkpoints/pi05_rby1/pi05_rby1_test/200")
key = jax.random.key(0)

# Bind every name that `finally` deletes, so cleanup succeeds no matter where
# the `try` body fails.
model = obs = act = loss = None

try:
    # Restore the parameters stored under "params" in the checkpoint and load
    # them into a live model.
    model = config.model.load(_model.restore_params(checkpoint_dir / "params"))

    # Fake observations and actions provided by the model config are enough to
    # exercise the loss computation.
    obs = config.model.fake_obs()
    act = config.model.fake_act()

    # Compute the loss and report its shape.
    loss = model.compute_loss(key, obs, act)
    print("Loss shape:", loss.shape)
finally:
    # Release this cell's references (including the model itself) before the
    # shared best-effort cleanup runs.
    del model, obs, act, loss
    _cleanup()

Loss shape: (1, 50)


Now, we are going to create a data loader and use a real batch of training data to compute the loss.

In [None]:
# Reduce the batch size to reduce memory usage.
config = dataclasses.replace(config, batch_size=2)

# Bind every name that `finally` deletes, so cleanup succeeds no matter where
# the `try` body fails.
model = loader = None
obs = act = loss = None

try:
    # The previous cell deletes `model` in its `finally` clause, so it is no
    # longer defined here. Restore it from the same checkpoint; `config`,
    # `checkpoint_dir` and `key` are still bound from that cell.
    model = config.model.load(_model.restore_params(checkpoint_dir / "params"))

    # Load a single batch of data. This is the same data that will be used during training.
    # NOTE: In order to make this example self-contained, we are skipping the normalization step
    # since it requires the normalization statistics to be generated using `compute_norm_stats`.
    loader = _data_loader.create_data_loader(config, num_batches=1, skip_norm_stats=True)
    obs, act = next(iter(loader))

    # Compute the loss on the real batch and report its shape.
    loss = model.compute_loss(key, obs, act)
    print("Loss shape:", loss.shape)
finally:
    # Release the model, the loader and the batch before the shared
    # best-effort cleanup runs. This runs whether or not an exception occurred.
    del model, loader, obs, act, loss
    _cleanup()