In [1]:
import dataclasses

import jax

from openpi.models import model as _model
from openpi.policies import droid_policy
from openpi.policies import policy_config as _policy_config
from openpi.shared import download
from openpi.training import config as _config
from openpi.training import data_loader as _data_loader
import openpi.training.data_loader as _data_loader
import openpi.training.sharding as sharding


In [None]:
# %pdb on

Automatic pdb calling has been turned ON


# Policy inference

The following example shows how to create a policy from a checkpoint and run inference on a single sample drawn from the training data loader.

In [3]:
# Look up the training configuration registered under this name.
config = _config.get_config("pi0_franka_low_mem_finetune")

# NOTE(review): hard-coded absolute path to a checkpoint on a local disk; it will not
# exist on other machines. Consider making it configurable before sharing the notebook.
checkpoint_dir = "/mnt/data/josyula/openpi/checkpoints/pi0_franka_low_mem_finetune/my_experiment/29999" 
# Disabled alternative, kept for reference: fetch a published checkpoint instead.
#download.maybe_download("gs://openpi-assets/checkpoints/pi0_fast_droid")

# Create a trained policy.
policy = _policy_config.create_trained_policy(config, checkpoint_dir)

# Seed a JAX PRNG key from the config and split it into two keys.
# NOTE(review): `train_rng` and `init_rng` are not used by any cell visible in this notebook.
rng = jax.random.key(config.seed)
train_rng, init_rng = jax.random.split(rng)

# Build a device mesh and two shardings over it: one partitioned along
# `sharding.DATA_AXIS` and one with an empty PartitionSpec.
# NOTE(review): `replicated_sharding` is not used by any cell visible in this notebook.
mesh = sharding.make_mesh(config.fsdp_devices)
data_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(sharding.DATA_AXIS))
replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

# Create the data loader with shuffling disabled and pull the first batch from it.
# Later cells read `batch[0]` as the observation part of the batch.
data_loader = _data_loader.create_data_loader(
    config,
    sharding=data_sharding,
    shuffle=False,
)
data_iter = iter(data_loader)
batch = next(data_iter)

# Disabled: direct inference on an `example` object (not defined in this notebook).
# result = policy.infer(example)

# Delete the policy to free up memory.
# NOTE(review): the deletion is disabled because a later cell still calls `policy.infer`,
# so the policy remains in memory for the rest of the notebook.
# del policy

# print("Actions shape:", result["actions"].shape)

Resolving data files:   0%|          | 0/85 [00:00<?, ?it/s]

In [4]:
import jax.numpy as jnp
import numpy as np
import torch

def to_jax(obj):
    """Recursively convert array leaves of a nested structure to ``jnp`` arrays.

    Dicts, lists and tuples are walked recursively and rebuilt with the same
    container type. ``torch.Tensor`` and ``np.ndarray`` leaves are converted to
    JAX arrays; every other leaf (strings, numbers, ``None``, JAX arrays, ...)
    is returned unchanged. Note that a list is treated as a container: its
    elements are converted one by one, the list itself is not turned into an array.

    Args:
        obj: Any nested combination of dicts / lists / tuples / arrays / tensors.

    Returns:
        A structure with the same layout as ``obj`` whose torch / NumPy leaves
        have been replaced by JAX arrays.
    """
    if isinstance(obj, dict):
        return {k: to_jax(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        converted = [to_jax(v) for v in obj]
        if isinstance(obj, tuple) and hasattr(obj, "_fields"):
            # namedtuple constructors take one positional argument per field,
            # so they cannot be rebuilt from a single iterable.
            return type(obj)(*converted)
        # Plain lists / tuples (and subclasses that accept an iterable) keep their type.
        return type(obj)(converted)
    if isinstance(obj, torch.Tensor):
        # Detach from autograd and move to host memory before handing over to NumPy.
        return jnp.array(obj.detach().cpu().numpy())
    if isinstance(obj, np.ndarray):
        return jnp.array(obj)
    return obj



In [5]:
import numpy as np
import torch
import jax.numpy as jnp

def convert_arrays(obj, target="jax"):
    """
    Recursively convert array leaves of a nested structure to JAX or Torch.

    Dicts, lists and tuples (including namedtuples) are walked recursively and
    rebuilt with the same container type. Only ``torch.Tensor`` and
    ``np.ndarray`` leaves are converted; every other leaf (strings, numbers,
    ``None``, JAX arrays, ...) is returned unchanged.

    Args:
        obj: Any nested combination of dicts/lists/tuples/arrays/tensors.
        target (str): 'jax' or 'torch' for output type.

    Returns:
        The converted structure. With ``target="torch"`` tensors are produced
        via ``torch.from_numpy``, so torch inputs come back as detached CPU tensors.

    Raises:
        ValueError: If ``target`` is neither 'jax' nor 'torch'.
    """
    # Reject unknown targets instead of silently treating them as 'torch'.
    if target not in ("jax", "torch"):
        raise ValueError(f"target must be 'jax' or 'torch', got {target!r}")

    if isinstance(obj, dict):
        return {k: convert_arrays(v, target=target) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        converted = [convert_arrays(v, target=target) for v in obj]
        if isinstance(obj, tuple) and hasattr(obj, "_fields"):
            # namedtuple constructors take one positional argument per field,
            # so they cannot be rebuilt from a single iterable.
            return type(obj)(*converted)
        return type(obj)(converted)
    if isinstance(obj, torch.Tensor):
        # Detach from autograd and move to host memory before handing over to NumPy.
        np_arr = obj.detach().cpu().numpy()
        return jnp.array(np_arr) if target == "jax" else torch.from_numpy(np_arr)
    if isinstance(obj, np.ndarray):
        return jnp.array(obj) if target == "jax" else torch.from_numpy(obj)
    return obj



In [6]:
import random
import random

def take_single_sample(obs_dict, idx=None):
    """
    Extract one batch element from a batched observation dict.

    The input is not modified: a shallow copy is made and the batched entries
    are replaced by their ``idx``-th element. ``"images"`` is required;
    ``"image_mask"`` and the per-sample array fields are sliced only when they
    are present and not ``None``. All other entries are carried over as-is.

    Args:
        obs_dict: Mapping with an ``"images"`` dict of batched arrays and
            optional batched fields.
        idx: Batch index to extract. When ``None``, an index is drawn at random
            using the leading dimension of the first image as the batch size.

    Returns:
        A ``(sample, idx)`` tuple with the single-element dict and the index used.
    """
    batched_images = obs_dict["images"]

    # No index requested: draw one uniformly from the batch dimension.
    if idx is None:
        batch_size = next(iter(batched_images.values())).shape[0]
        idx = random.randrange(batch_size)

    sample = {**obs_dict}  # shallow copy, nested values are shared until replaced
    sample["images"] = {name: frames[idx] for name, frames in batched_images.items()}

    # Optional dict-of-arrays field.
    batched_masks = obs_dict.get("image_mask")
    if batched_masks is not None:
        sample["image_mask"] = {name: mask[idx] for name, mask in batched_masks.items()}

    # Optional plain array fields.
    array_fields = (
        "state",
        "tokenized_prompt",
        "tokenized_prompt_mask",
        "token_ar_mask",
        "token_loss_mask",
        "actions",
    )
    for field in array_fields:
        batched_value = obs_dict.get(field)
        if batched_value is not None:
            sample[field] = batched_value[idx]

    return sample, idx


In [7]:
def print_keys_recursive(d, prefix=""):
    """Print every key of a nested dict, one per line, as a dotted path.

    Keys are printed depth-first in insertion order; nested keys are joined to
    their parent path with ``"."``. Non-dict values are ignored.

    Args:
        d: The (possibly nested) dict to walk. Anything that is not a dict
            prints nothing.
        prefix: Dotted path of the parent key; empty at the top level.
    """
    if not isinstance(d, dict):
        return
    for key, child in d.items():
        # Top-level keys are printed bare; nested ones get the parent path prepended.
        path = key if not prefix else f"{prefix}.{key}"
        print(path)
        print_keys_recursive(child, prefix=path)

# Example usage
# Turn the first element of the batch into a plain dict via its `to_dict()` method.
new_dict = batch[0].to_dict()
# Expose the "image" entry under the key "images" as well, which is the key
# `take_single_sample` reads. Both keys refer to the same object (no copy is made).
new_dict["images"] = new_dict["image"]

print_keys_recursive(new_dict)

# # for key in new_dict["images"].keys(): 
# #     print("shape is ", new_dict["images"][key].shape)
# NOTE(review): `convert_arrays` only converts torch.Tensor / np.ndarray leaves; any other
# leaf type (e.g. JAX arrays, if that is what the loader yields) passes through unchanged — confirm.
new_dict_torch = convert_arrays(new_dict, target="torch")
# Without an explicit `idx` a random batch index is drawn, so the selected sample
# can differ between runs.
single_dict, chosen_idx = take_single_sample(new_dict_torch)  # or take_single_sample(new_dict_torch, idx=0)
# print("Picked index:", chosen_idx)
# for k, v in single_dict["images"].items():
#     print(k, v.shape)  # now (C, H, W) or (H, W, C) depending on your data

# 

state
tokenized_prompt
tokenized_prompt_mask
token_ar_mask
token_loss_mask
image
image.base_0_rgb
image.left_wrist_0_rgb
image.right_wrist_0_rgb
image_mask
image_mask.base_0_rgb
image_mask.left_wrist_0_rgb
image_mask.right_wrist_0_rgb
images
images.base_0_rgb
images.left_wrist_0_rgb
images.right_wrist_0_rgb


In [8]:
def pad_vector(vector, new_dim):
    """Zero-pad the trailing axis of `vector` so that it has length `new_dim`.

    Args:
        vector: Array with at least one dimension.
        new_dim: Desired size of the last axis.

    Returns:
        `vector` itself when the last axis already has size `new_dim`;
        otherwise the result of `jnp.pad` with zeros appended to the last axis.
    """
    missing = new_dim - vector.shape[-1]
    if missing == 0:
        # Nothing to add; hand back the input untouched (OK during tracing).
        return vector

    # Leave every leading axis alone and append `missing` zeros to the last one.
    widths = [(0, 0) for _ in range(vector.ndim - 1)] + [(0, missing)]
    return jnp.pad(vector, widths, mode="constant")

In [9]:
# Zero-pad the last dimension of the state to 32 entries.
# NOTE(review): this overwrites `single_dict["state"]` in place; `pad_vector` returns its
# input unchanged once the last dimension is already 32.
single_dict["state"] = pad_vector(single_dict["state"], 32)  
print(single_dict["state"].shape)
# NOTE(review): "0" looks like a placeholder prompt — confirm what the policy expects here.
single_dict["prompt"] = "0"
# Run inference on the single sample; as the cell's last expression, the "actions"
# entry of the result is displayed.
policy.infer(single_dict)["actions"]

(32,)
stats.std shape  (8,)
x shape  (10, 32)
stats.mean shape  (8,)


array([[ 2.37385526e-02,  1.86081792e-01,  1.21667742e-02,
        -2.57891323e+00, -1.18848384e-01,  3.08934758e+00,
         8.66676721e-01,  1.05411286e+00,  9.87530742e-04,
        -1.63340732e-03,  1.46769134e-03,  1.45018246e-04,
         1.63424178e-03,  1.79445923e-03, -4.39793311e-04,
        -5.68688484e-04,  6.18458413e-04,  1.41337655e-04,
         4.84109409e-04,  1.19710088e-03,  4.51088403e-04,
        -4.72009654e-04, -7.24793205e-04, -2.63154770e-04,
         1.33228435e-03, -2.72572313e-04, -7.63357449e-04,
         3.07441065e-04, -6.18458413e-04, -3.61130003e-05,
        -4.38929043e-04,  2.91216665e-03],
       [ 2.41655590e-02,  1.86388751e-01,  1.36677216e-02,
        -2.57866729e+00, -1.19059219e-01,  3.08984818e+00,
         8.67079913e-01,  1.05342753e+00, -4.48465796e-04,
         6.44803692e-04,  8.19296464e-04,  1.73276836e-03,
        -1.33812561e-05,  1.79982365e-03, -1.33657589e-03,
        -6.72400670e-04, -2.30282775e-04, -9.86696276e-04,
        -5.43

# Working with a live model


The following example shows how to create a live model from a checkpoint and compute training loss. First, we are going to demonstrate how to do it with fake data.


In [None]:
# config = _config.get_config("pi0_aloha_sim")

# checkpoint_dir = download.maybe_download("gs://openpi-assets/checkpoints/pi0_aloha_sim")
# key = jax.random.key(0)

# # Create a model from the checkpoint.
# model = config.model.load(_model.restore_params(checkpoint_dir / "params"))

# # We can create fake observations and actions to test the model.
# obs, act = config.model.fake_obs(), config.model.fake_act()

# # Compute the training loss on the fake observations and actions.
# loss = model.compute_loss(key, obs, act)
# print("Loss shape:", loss.shape)

2025-08-13 13:32:11.042224: W external/xla/xla/tsl/framework/bfc_allocator.cc:501] Allocator (GPU_0_bfc) ran out of memory trying to allocate 1.96GiB (rounded to 2106589184)requested by op 
2025-08-13 13:32:11.042582: W external/xla/xla/tsl/framework/bfc_allocator.cc:512] ***************************************************************************************************_


XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 2106589184 bytes.

> [32m/mnt/data/josyula/openpi/.venv/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py[39m([92m243[39m)[36mbatched_device_put[39m[34m()[39m
[32m    241[39m       return array.ArrayImpl(
[32m    242[39m           aval, sharding, bufs, committed=committed, _skip_checks=True)
[32m--> 243[39m     [38;5;28;01mreturn[39;00m xc.batched_device_put(aval, sharding, xs, list(devices), committed)
[32m    244[39m   [38;5;28;01mfinally[39;00m:
[32m    245[39m     util.test_event([33m"batched_device_put_end"[39m)



Now, we are going to create a data loader and use a real batch of training data to compute the loss.

In [None]:
# Reduce the batch size to reduce memory usage.
# `config` is rebound to a copy with `batch_size=2`; this reuses whatever `config`
# an earlier cell left in the kernel.
config = dataclasses.replace(config, batch_size=2)

# Load a single batch of data. This is the same data that will be used during training.
# NOTE: In order to make this example self-contained, we are skipping the normalization step
# since it requires the normalization statistics to be generated using `compute_norm_stats`.
loader = _data_loader.create_data_loader(config, num_batches=1, skip_norm_stats=True)
obs, act = next(iter(loader))

# Compute the training loss on the real batch.
# NOTE(review): in the visible cells `model` and `key` are only assigned inside the
# commented-out "live model" cell, so on a fresh kernel this line raises NameError
# unless that cell is restored first.
loss = model.compute_loss(key, obs, act)

# Delete the model to free up memory.
del model

print("Loss shape:", loss.shape)