In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["OPENPI_DATA_HOME"] = "/home/lperez/.cache/openpi"

In [2]:
import numpy as np
import jax
import cv2
from flax import nnx
import time
from openpi.models import model as _model
import openpi.shared.nnx_utils as nnx_utils
import jax.numpy as jnp
from openpi.training.config import get_config
from openpi.models.tokenizer import PaligemmaTokenizer
from openpi.models.model import Observation
from openpi.models.pi0 import make_attn_mask

In [3]:
PALIGEMMA_EOS_TOKEN = 1
max_decoding_steps = 20
temperature = 0.1

### Step 1: Initialize model and load pretrained params
config = get_config("right_pi05_20")
model_rng = jax.random.key(0)
rng = jax.random.key(0)
model = config.model.create(model_rng)

# Load pretrained params
print(f"debug line 0-1")
graphdef, state = nnx.split(model)
loader = config.weight_loader
params = nnx.state(model)
# Convert frozen params to bfloat16.
params = nnx_utils.state_map(params, config.freeze_filter, lambda p: p.replace(p.value.astype(jnp.bfloat16)))


debug line 0-1


In [None]:

params_shape = params.to_pure_dict()
loaded_params = loader.load(params_shape)
state.replace_by_pure_dict(loaded_params)
model = nnx.merge(graphdef, state)

### Step 2: Construct an observation batch
# load 3 images from tmp_test as uint8 format
img_share_path = '/home/lperez/main/nh/openpi/test'
img_paths = ['nh/1.png', 'nh/2.png', 'nh/3.png']

In [28]:
img_list = []
for img_name in img_paths:
    img_path = os.path.join(img_share_path, img_name)
    img = cv2.imread(img_path)
    img_list.append(img)
# Convert images from [0, 255] to [-1, 1] range as expected by the model
img_dict = {
    "base_0_rgb": jnp.array(img_list[0][np.newaxis, :, :, :]).astype(jnp.float32) / 127.5 - 1.0,
    "left_wrist_0_rgb": jnp.array(img_list[1][np.newaxis, :, :, :]).astype(jnp.float32) / 127.5 - 1.0,
    "right_wrist_0_rgb": jnp.array(img_list[2][np.newaxis, :, :, :]).astype(jnp.float32) / 127.5 - 1.0,
}

# Pick up the flascard on the table
# Tokenize the prompt
high_level_prompt = 'Put the fruits in the basket'
low_level_prompt = 'Put the apple in the basket'
tokenizer = PaligemmaTokenizer(max_len=50)
tokenized_prompt, tokenized_prompt_mask, token_ar_mask, token_loss_mask = tokenizer.tokenize_high_low_prompt(high_level_prompt, low_level_prompt)


In [29]:
# form a observation
data = {
    'image': img_dict,
    'image_mask': {key: jnp.ones(1, dtype=jnp.bool) for key in img_dict.keys()},
    'state': jnp.zeros((1, 32), dtype=jnp.float32),
    # 'state': None,
    'tokenized_prompt': jnp.stack([tokenized_prompt], axis=0),
    'tokenized_prompt_mask': jnp.stack([tokenized_prompt_mask], axis=0),
    'token_ar_mask': jnp.stack([token_ar_mask], axis=0),
    'token_loss_mask': jnp.stack([token_loss_mask], axis=0),
}
observation = Observation.from_dict(data)
rng = jax.random.key(42)
observation = _model.preprocess_observation(rng, observation, train=False, image_keys=list(observation.images.keys()))


In [30]:
#  observation = _model.preprocess_observation(None, new_observation, train=False, image_keys=list(observation.images.keys()))
observation = jax.tree.map(jax.device_put, observation)

In [31]:
rng = jax.random.key(42)
actions = jnp.zeros((1, 20, 32))
real_action_dim = 32

In [32]:
tok = observation.tokenized_prompt
tokenizer.detokenize(np.array(tok, dtype=np.int32))

'Task: put the fruits in the basket. Subtask: put the apple in the basket.;\nAction: '

In [39]:
        prefix_token_embeddings, prefix_mask, prefix_ar_mask = model.embed_prefix(observation)
        prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)

In [41]:
prefix_ar_mask

Array([False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False,

In [40]:
 prefix_mask

Array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,
      

In [None]:

        ### 1. Subtask-Generation Loss (Cross-Entropy Loss)
        # Compute one-hot targets: we predict *next* token, so shift the input tokens by one.
        targets = jax.nn.one_hot(
            observation.tokenized_prompt[:, 1:],
            model.PaliGemma.llm.module.vocab_size,
        )

In [34]:
        # Use prefix tokens to perform subtask generation (Prefix: images*3, high-level prompt, low-level prompt, state?)
        # We input the last token because the last token is used for flow loss
        prefix_positions = jnp.cumsum(prefix_mask, axis=1) - 1
        (prefix_out, _), kv_cache = model.PaliGemma.llm(
            [prefix_token_embeddings, None], 
            mask=prefix_attn_mask, 
            positions=prefix_positions, 
            adarms_cond=[None, None]
        )
        prefix_out = prefix_out[:, :-1]

        # decode from embedding to logits
        logits = model.PaliGemma.llm(
            prefix_out[:, -targets.shape[1] :], method='deembed'
        )
        logp = jax.nn.log_softmax(logits, axis=-1)

In [35]:
        # Compute CE loss on token targets
        assert observation.token_loss_mask is not None, "Token loss mask is required"
        loss_mask = observation.token_loss_mask[:, 1:]
        token_pplx = jnp.sum(targets * logp, axis=-1)
        subtask_generation_loss = -jnp.sum(token_pplx * loss_mask, axis=-1) / jnp.clip(jnp.sum(loss_mask, -1), 1)

        ### 2. Flow Matching Loss (MSE Loss)
        preprocess_rng, noise_rng, time_rng = jax.random.split(rng, 3)
        batch_shape = actions.shape[:-2]
        noise = jax.random.normal(noise_rng, actions.shape)
        time = jax.random.beta(time_rng, 1.5, 1, batch_shape) * 0.999 + 0.001
        time_expanded = time[..., None, None]
        x_t = time_expanded * noise + (1 - time_expanded) * actions
        u_t = noise - actions

In [36]:
        suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond = model.embed_suffix(observation, x_t, time)
        input_mask = jnp.concatenate([prefix_mask, suffix_mask], axis=1)
        ar_mask = jnp.concatenate([prefix_ar_mask, suffix_ar_mask], axis=0)
        attn_mask = make_attn_mask(input_mask, ar_mask)
        attn_mask = attn_mask[:, -suffix_tokens.shape[1]:, :] # Q is [B, action_dim, ...], KV is full length
        positions = jnp.cumsum(input_mask, axis=1) - 1
        positions = positions[:, -suffix_tokens.shape[1]:]
        (_, suffix_out), _ = model.PaliGemma.llm(
            [None, suffix_tokens], kv_cache=kv_cache, mask=attn_mask, positions=positions, adarms_cond=[None, adarms_cond]
        )
        v_t = model.action_out_proj(suffix_out[:, -model.action_horizon :])

        # Calculate flow loss with true actions (Real Action Dim <= Action Dim (Padding))
        flow_loss = jnp.mean(jnp.square(v_t[:, :, :real_action_dim] - u_t[:, :, :real_action_dim]), axis=-1)

In [37]:
loss = subtask_generation_loss + jnp.mean(flow_loss, axis=-1)


In [38]:
loss

Array([8.817083], dtype=float32)