In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from openpi.training import config as _config
import jax
import flax.nnx as nnx
from scripts.train import _load_weights_and_validate
from PIL import Image
import jax.numpy as jnp
import openpi.transforms as _transforms
import numpy as np
import openpi.training.data_loader as _data_loader
import openpi.models.tokenizer as _tokenizer
import openpi.training.sharding as sharding
import lerobot.common.datasets.lerobot_dataset as lerobot_dataset
from openpi.models import model as _model
from openpi.models.pi0_fast import Pi0FAST, make_attn_mask
import argparse
from plot_reward import plot_values

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def load_model(config, seed: int = 42) -> Pi0FAST:
    """Instantiate ``config.model`` and overwrite its parameters with loaded weights.

    Args:
        config: openpi training config. ``config.model.create`` builds the model and
            ``config.weight_loader`` is handed to ``_load_weights_and_validate``.
        seed: seed for the key used to initialise the model before the loaded
            weights are swapped in. Defaults to 42, the value previously hard-coded.

    Returns:
        The model rebuilt (``nnx.merge``) from its graph definition and a state in
        which the loaded parameters replaced the initial ones.
    """
    # Keep the single split so the key passed to ``create`` is identical to the
    # previous hard-coded flow for the default seed.
    model_rng, _ = jax.random.split(jax.random.key(seed))
    model = config.model.create(model_rng)

    # Current parameters as a plain nested dict; passed to the loader as the
    # expected structure. NOTE(review): parameters the loader does not return are
    # presumably left at their random init — confirm with _load_weights_and_validate.
    params_shape = nnx.state(model).to_pure_dict()
    loaded_params = _load_weights_and_validate(config.weight_loader, params_shape)

    graphdef, state = nnx.split(model)
    state.replace_by_pure_dict(loaded_params)
    return nnx.merge(graphdef, state)

def get_dataset(config):
    """Return the dataset described by ``config.data`` with its data transforms applied.

    The raw dataset comes from ``_data_loader.create_torch_dataset`` (using the
    model's action horizon) and is then wrapped by ``_data_loader.transform_dataset``
    with the same data config.
    """
    model_config = config.model
    data_config = config.data.create(config.assets_dirs, model_config)
    raw_dataset = _data_loader.create_torch_dataset(
        data_config, model_config.action_horizon, model_config
    )
    return _data_loader.transform_dataset(raw_dataset, data_config)

def get_episode_data_index(config):
    """Return ``episode_data_index`` of the LeRobot dataset named by ``config.data.repo_id``.

    NOTE(review): this constructs a separate ``LeRobotDataset`` from the one built by
    ``get_dataset``, so the dataset is opened twice.
    """
    return lerobot_dataset.LeRobotDataset(config.data.repo_id).episode_data_index


def get_observation(dataset, index):
    """Fetch ``dataset[index]`` as a batch of one and wrap it in an ``Observation``.

    Every leaf of the sample is converted to a JAX array and given a new leading
    axis of size 1 before ``Observation.from_dict`` is called.
    """
    def _with_batch_axis(leaf):
        # Equivalent to expand_dims(..., axis=0): prepend a length-1 batch axis.
        return jnp.array(leaf)[None, ...]

    sample = dataset[index]
    return _model.Observation.from_dict(jax.tree.map(_with_batch_axis, sample))

In [4]:
# Training config being probed; it drives the dataset, the model and the episode index.
CONFIG_NAME = "pi0_fast_libero"
config = _config.get_config(CONFIG_NAME)

dataset = get_dataset(config)  # transformed samples, indexed by frame below
model = load_model(config)  # weights from config.weight_loader
episode_data_index = get_episode_data_index(config)  # its "to" entry is used below

Some kwargs in processor config are unused and will not have any effect: min_token, scale, action_dim, time_horizon, vocab_size. 
Some kwargs in processor config are unused and will not have any effect: min_token, scale, action_dim, time_horizon, vocab_size. 
The dataset you requested (physical-intelligence/libero) is in 2.0 format.
While current version of LeRobot is backward-compatible with it, the version of your dataset still uses global
stats instead of per-episode stats. Update your dataset stats to the new format using this command:
```
python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py --repo-id=physical-intelligence/libero
```

If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).

The dataset you requested (physical-intelligence/libero) is in 2.0 format.
While current version of LeRobot is backward-compatible with it, the 

In [5]:
# Probe a single frame: the last frame of episode EPISODE_IDX.
EPISODE_IDX = 0
end_episodes = episode_data_index["to"].cpu().numpy()
# "to" is treated as an exclusive end index, hence the -1.
index = int(end_episodes[EPISODE_IDX]) - 1
assert 0 <= index < len(dataset), f"frame index {index} out of range for dataset of length {len(dataset)}"
observation = get_observation(dataset, index)

rng = jax.random.key(0)
train = False

# Only the preprocessing key is needed here; the other two splits are unused.
preprocess_rng, _, _ = jax.random.split(rng, 3)
# NOTE(review): Pi0FAST preprocesses with the observation's own image keys rather
# than the default (pi0) camera names; mirror that here so the images fed to
# embed_inputs match what the model sees — confirm against openpi.models.pi0_fast.
observation = _model.preprocess_observation(
    preprocess_rng,
    observation,
    train=train,
    image_keys=list(observation.images.keys()),
)

In [12]:
# Embed the model inputs and build the attention mask from the input/AR masks.
input_token_embeddings, input_mask, ar_mask = model.embed_inputs(observation)
attn_mask = make_attn_mask(input_mask, ar_mask)

# Single forward pass through the PaliGemma LLM. return_prelogits=True presumably
# requests hidden states instead of logits; only the first of the three outputs is kept.
llm_call_kwargs = {
    "embedded_prefix": input_token_embeddings,
    "mask": attn_mask,
    "return_prelogits": True,
}
fused_sequence_embeddings, _, _ = model.PaliGemma.llm(**llm_call_kwargs)

In [21]:
# Masked mean-pool over the sequence axis: average only positions where
# input_mask is True. Accumulate in float32 so a reduced-precision embedding
# dtype (e.g. bfloat16 — TODO confirm the model dtype) does not lose precision
# when summing hundreds of tokens.
mask_expanded = input_mask.astype(jnp.float32)[..., None]  # (B, S, 1)
summed_embeddings = jnp.sum(fused_sequence_embeddings.astype(jnp.float32) * mask_expanded, axis=1)
num_valid_tokens = jnp.sum(mask_expanded, axis=1)  # (B, 1)
# Clamp at 1 so a row with no valid tokens yields zeros instead of NaN.
pooled_fused_embedding = summed_embeddings / jnp.maximum(num_valid_tokens, 1.0)
assert pooled_fused_embedding.shape == (fused_sequence_embeddings.shape[0], fused_sequence_embeddings.shape[-1])

In [26]:
# Embedding at the last *valid* position of each sequence. Plain [:, -1, :] picks
# the final slot, which is padding whenever input_mask is False at the end.
# Search the reversed mask for the first True to find the last valid index per row.
last_valid_index = input_mask.shape[1] - 1 - jnp.argmax(input_mask[:, ::-1].astype(jnp.int32), axis=1)
last_token_embedding = fused_sequence_embeddings[jnp.arange(input_mask.shape[0]), last_valid_index]
last_token_embedding.shape

(1, 2048)