In [1]:
import numpy as np

import jax
import jax.numpy as jnp

from openpi.models import model as _model
from openpi.policies import libero_policy
from openpi.policies import policy_config as _policy_config
from openpi.shared import download
from openpi.training import config as _config
from openpi.training import data_loader as _data_loader
from openpi.training import checkpoints as _checkpoints
import openpi.transforms as _transforms

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Training-config name and checkpoint location used throughout this notebook.
CONFIG_NAME = "pi0_fast_libero_predictor"
# NOTE(review): absolute, machine-specific path; adjust before running elsewhere.
CHECKPOINT_PATH = "/scratch/s5649552/openpi/checkpoints/pi0_fast_libero_predictor/predictor_v1/1999"

config = _config.get_config(CONFIG_NAME)
checkpoint_dir = download.maybe_download(CHECKPOINT_PATH)

In [3]:
# Restore the checkpoint's parameters (requested as bfloat16), then build the
# model from them via the model config.
params_path = checkpoint_dir / "params"
restored_params = _model.restore_params(params_path, dtype=jnp.bfloat16)
model = config.model.load(restored_params)

In [4]:
# Data config for this model, plus the normalization statistics stored under
# the checkpoint's "assets" directory for the config's asset id.
data_config = config.data.create(config.assets_dirs, config.model)
assets_dir = checkpoint_dir / "assets"
norm_stats = _checkpoints.load_norm_stats(assets_dir, data_config.asset_id)

# A default-constructed transform group and no default prompt.
repack_transforms = _transforms.Group()
default_prompt = None

# Assemble the input pipeline one stage at a time; stages are handed to
# `compose` in exactly the order they are appended here.
input_stages = []
input_stages.extend(repack_transforms.inputs)
input_stages.append(_transforms.InjectDefaultPrompt(default_prompt))
input_stages.extend(data_config.data_transforms.inputs)
input_stages.append(
    _transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm)
)
input_stages.extend(data_config.model_transforms.inputs)

transforms = _transforms.compose(input_stages)

Some kwargs in processor config are unused and will not have any effect: vocab_size, min_token, action_dim, time_horizon, scale. 
Some kwargs in processor config are unused and will not have any effect: vocab_size, min_token, action_dim, time_horizon, scale. 


In [5]:
example = libero_policy.make_libero_example()

# The identity tree-map rebuilds the container structure, so the transforms are
# applied to a fresh tree rather than to `example` itself.
inputs = transforms(jax.tree.map(lambda leaf: leaf, example))
# Convert every leaf to a JAX array and prepend an axis of size 1.
inputs = jax.tree.map(lambda leaf: jnp.asarray(leaf)[None, ...], inputs)
observation = _model.Observation.from_dict(inputs)

# Fixed PRNG key and an all-zero action array of shape (1, 10, 7).
rng = jax.random.PRNGKey(0)
actions = jnp.zeros((1, 10, 7))

In [6]:
# Number of copies of the base-camera image to stack along axis 0.
NUM_REPEATS = 10

# Tile from the first entry only. On a fresh run the leading axis has size 1, so
# this matches repeating the whole array; on a re-run it still yields
# NUM_REPEATS entries instead of multiplying the leading axis again (10 -> 100).
# NOTE(review): this still updates `observation.images` in place; only the
# "base_0_rgb" entry is changed.
first_frame = observation.images["base_0_rgb"][:1]
observation.images["base_0_rgb"] = jnp.repeat(first_frame, NUM_REPEATS, axis=0)
observation.images["base_0_rgb"].shape

(10, 224, 224, 3)

In [7]:
# Run the model's forward pass with the PRNG key, the prepared observation and
# the all-zero action array built in the earlier cells.
output = model.forward(rng, observation, actions)

x_noisy shape: (1, 10, 256, 2048)
timestep shape: (1,)
lc_his shape: (1, 10, 256, 2048)
a_next shape: (1, 10, 7)
is xbt same as xbta? True
is xbt same as xbta? True
is xbt same as xbta? True


In [8]:
# Display the shape of the forward-pass result.
output.shape

(1, 10, 256, 2048)