In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["OPENPI_DATA_HOME"] = "/home/lperez/.cache/openpi"

In [2]:
import numpy as np
import jax
import cv2
from flax import nnx
import time
from openpi.models import model as _model
import openpi.shared.nnx_utils as nnx_utils
import jax.numpy as jnp
from openpi.training.config import get_config
from openpi.models.tokenizer import PaligemmaTokenizer
from openpi.models.model import Observation
from openpi.models.pi0 import make_attn_mask

  import pynvml  # type: ignore[import]


In [3]:
PALIGEMMA_EOS_TOKEN = 1
max_decoding_steps = 20
temperature = 0.1

### Step 1: Initialize model and load pretrained params
config = get_config("right_pi05_20")
model_rng = jax.random.key(0)
rng = jax.random.key(0)
model = config.model.create(model_rng)

# Load pretrained params
print(f"debug line 0-1")
graphdef, state = nnx.split(model)
loader = config.weight_loader
params = nnx.state(model)
# Convert frozen params to bfloat16.
params = nnx_utils.state_map(params, config.freeze_filter, lambda p: p.replace(p.value.astype(jnp.bfloat16)))


debug line 0-1


In [4]:

params_shape = params.to_pure_dict()
loaded_params = loader.load(params_shape)
state.replace_by_pure_dict(loaded_params)
model = nnx.merge(graphdef, state)

### Step 2: Construct an observation batch
# load 3 images from tmp_test as uint8 format
img_share_path = '/home/lperez/main/nh/openpi_subtask_generation/test'
# img_name_list = ['to.png', 'lef.png', 'righ.png']

In [5]:
# img_name_list = ['leftImg.png', 'rightImg.png', 'faceImg.png']
# img_name_list = ['cats/1.png', "cats/2.png", "cats/3.png"]
# img_name_list = ['cup/1.png', "cup/2.png", "cup/3.png"]
img_name_list = ['open/1.png', "open/2.png", "open/3.png"]

In [6]:
img_list = []
for img_name in img_name_list:
    img_path = os.path.join(img_share_path, img_name)
    img = cv2.imread(img_path)
    img_list.append(img)
# Convert images from [0, 255] to [-1, 1] range as expected by the model
img_dict = {
    "base_0_rgb": jnp.array(img_list[0][np.newaxis, :, :, :]).astype(jnp.float32) / 127.5 - 1.0,
    "left_wrist_0_rgb": jnp.array(img_list[1][np.newaxis, :, :, :]).astype(jnp.float32) / 127.5 - 1.0,
    "right_wrist_0_rgb": jnp.array(img_list[2][np.newaxis, :, :, :]).astype(jnp.float32) / 127.5 - 1.0,
}

# Pick up the flashcard on the table
# Tokenize the prompt
# high_level_prompt = 'Put the fruits on the basket'
high_level_prompt = 'Put up the flashcard on the table'
low_level_prompt = 'ASD'
tokenizer = PaligemmaTokenizer(max_len=50)
tokenized_prompt, tokenized_prompt_mask, token_ar_mask, token_loss_mask = tokenizer.tokenize_high_low_prompt(high_level_prompt, low_level_prompt)


In [7]:
# form a observation
data = {
    'image': img_dict,
    'image_mask': {key: jnp.ones(1, dtype=jnp.bool) for key in img_dict.keys()},
    'state': jnp.zeros((1, 32), dtype=jnp.float32),
    # 'state': None,
    'tokenized_prompt': jnp.stack([tokenized_prompt], axis=0),
    'tokenized_prompt_mask': jnp.stack([tokenized_prompt_mask], axis=0),
    'token_ar_mask': jnp.stack([token_ar_mask], axis=0),
    'token_loss_mask': jnp.stack([token_loss_mask], axis=0),
}
observation = Observation.from_dict(data)
rng = jax.random.key(42)
observation = _model.preprocess_observation(rng, observation, train=False, image_keys=list(observation.images.keys()))

# Set the low level task tokens to padding according to the loss mask (loss mask is the indication of low-level prompt)
# We move it from inside model to outside because the inside func need to be jittable
loss_mask = jnp.array(observation.token_loss_mask)
new_tokenized_prompt = observation.tokenized_prompt.at[loss_mask].set(0)
new_tokenized_prompt_mask = observation.tokenized_prompt_mask.at[loss_mask].set(False)


In [8]:
new_observation = _model.Observation(
                    images=observation.images,
                    image_masks=observation.image_masks,
                    state=observation.state,
                    tokenized_prompt=new_tokenized_prompt,
                    tokenized_prompt_mask=new_tokenized_prompt_mask,
                    token_ar_mask=observation.token_ar_mask,
                    token_loss_mask=observation.token_loss_mask,
                    )

In [9]:
# form a observation
data = {
    'image': img_dict,
    'image_mask': {key: jnp.ones(1, dtype=jnp.bool) for key in img_dict.keys()},
    'state': jnp.zeros((1, 32), dtype=jnp.float32),
    # 'state': None,
    'tokenized_prompt': jnp.stack([tokenized_prompt], axis=0),
    'tokenized_prompt_mask': jnp.stack([tokenized_prompt_mask], axis=0),
    'token_ar_mask': jnp.stack([token_ar_mask], axis=0),
    'token_loss_mask': jnp.stack([token_loss_mask], axis=0),
}
observation = Observation.from_dict(data)
rng = jax.random.key(42)
observation = _model.preprocess_observation(rng, observation, train=False, image_keys=list(observation.images.keys()))

# Set the low level task tokens to padding according to the loss mask (loss mask is the indication of low-level prompt)
# We move it from inside model to outside because the inside func need to be jittable
loss_mask = jnp.array(observation.token_loss_mask)
new_tokenized_prompt = observation.tokenized_prompt.at[loss_mask].set(0)
new_tokenized_prompt_mask = observation.tokenized_prompt_mask.at[loss_mask].set(False)
new_observation = _model.Observation(
                    images=observation.images,
                    image_masks=observation.image_masks,
                    state=observation.state,
                    tokenized_prompt=new_tokenized_prompt,
                    tokenized_prompt_mask=new_tokenized_prompt_mask,
                    token_ar_mask=observation.token_ar_mask,
                    token_loss_mask=observation.token_loss_mask,
                    )
observation = _model.preprocess_observation(None, new_observation, train=False, image_keys=list(observation.images.keys()))
observation = jax.tree.map(jax.device_put, observation)

In [10]:
observation.tokenized_prompt

Array([[     2,   7071, 235292,   2507,    908,    573,  12995,   5306,
           611,    573,   3037, 235265,   4284,   8277, 235292, 235248,
             0,      0,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0,      0,      0,      0,      0,
             0,      0]], dtype=int32)

In [11]:
observation.tokenized_prompt_mask

Array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False]], dtype=bool)

In [12]:
from openpi.models.pi05 import left_to_right_align
from openpi.models.pi05 import put_along_last_axis

In [13]:
batch_size = observation.tokenized_prompt.shape[0]
prefix_token_embeddings, prefix_mask, prefix_ar_mask = model.embed_prefix(observation)

In [14]:
prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)

In [15]:
prefix_attn_mask

Array([[[ True,  True,  True, ..., False, False, False],
        [ True,  True,  True, ..., False, False, False],
        [ True,  True,  True, ..., False, False, False],
        ...,
        [False, False, False, ..., False, False, False],
        [False, False, False, ..., False, False, False],
        [False, False, False, ..., False, False, False]]], dtype=bool)

In [16]:
        # left to right align all input token sequences
        prefix_token_embeddings, prefix_mask, prefix_attn_mask = left_to_right_align(
            prefix_token_embeddings, prefix_mask, prefix_attn_mask
        )

In [17]:

        prefill_size = prefix_token_embeddings.shape[1]
        prefill_len = jnp.sum(prefix_mask, axis=-1)
        prefix_start = prefill_size - prefill_len

        prefix_attn_mask = jnp.pad(prefix_attn_mask, ((0, 0), (0, 0), (0, max_decoding_steps)))
        prefix_positions = jnp.cumsum(prefix_mask, axis=-1) - 1

In [18]:
prefix_attn_mask[0][818]

Array([False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,

FIRST CALL

In [19]:
(prefix_out, _), kv_cache = model.PaliGemma.llm(
            [prefix_token_embeddings, None], mask=prefix_attn_mask, positions=prefix_positions, adarms_cond=[None, None]
        )

In [20]:
       
        last_token_embedding = prefix_out[:, -1:]
        last_logits = model.PaliGemma.llm(last_token_embedding, method="deembed")
        last_logits = jax.nn.log_softmax(last_logits, axis=-1)
        output_tokens = jnp.zeros((batch_size, max_decoding_steps))

In [21]:
        def step(carry):
            rng, last_logit, output_tokens, cache, _, step = carry

            # Sample token from last logit
            # Split RNG for this step
            rng, rng_step = jax.random.split(rng)
            token = jax.lax.cond(
                temperature > 0.0,
                lambda _: jax.random.categorical(rng_step, last_logit / temperature, axis=-1),
                lambda _: jnp.argmax(last_logit, axis=-1),
                operand=None,
            )
            output_tokens = put_along_last_axis(output_tokens, jnp.broadcast_to(step, (token.shape[0], 1)), token)

            # Check for early stopping --> stop if all batch elements have EOS token
            ### TODO: erase extra decoded token due to mismatch
            has_eos = jnp.any(token == PALIGEMMA_EOS_TOKEN, axis=-1)
            all_eos = jnp.all(has_eos)

            # Decode one step
            token_embedding =  model.PaliGemma.llm(token, method="embed")
            positions = prefill_len[:, None] + step
            jax.debug.print("positions: {s}", s=positions)
            mask = jnp.logical_and(
                jnp.arange(prefill_size + max_decoding_steps)[None, None, :] >= prefix_start[:, None, None],
                jnp.arange(prefill_size + max_decoding_steps)[None, None, :]
                < (jnp.broadcast_to(prefill_size + step + 1, (prefix_start.shape[0], 1, 1))),
            )
            #jax.debug.print("step: {s}", s=step)
            jax.debug.print("runningmask.shape . {m}", m=mask.astype(jnp.int32).shape)
            jax.debug.print("runningmask. . {m}", m=mask.astype(jnp.int32))
            (prefix_out, _), kv_cache = model.PaliGemma.llm(
                [token_embedding, None], mask=mask, positions=positions, adarms_cond=[None, None], kv_cache=cache
            )
            last_token_embedding = prefix_out[:, -1:]
            last_logits = model.PaliGemma.llm(last_token_embedding, method="deembed")
            last_logits = jax.nn.log_softmax(last_logits, axis=-1)

            #jax.debug.print("STEP: {s}", s=step)
            #jax.debug.print("MASK: {m}", m=mask.astype(jnp.int32))
            #jax.debug.print("POSITIONS: {p}", p=positions.astype(jnp.int32))

            return rng, last_logits, output_tokens, kv_cache, all_eos, step + 1

In [22]:
        def cond(carry):
            _, _, _, _, all_eos, step = carry
            return (~all_eos) & (step < max_decoding_steps)

        # Use lax.while_loop so we can jit the full decoding loop.
        _, _, output_tokens, kv_cache, _, _ = jax.lax.while_loop(
            cond, step, (rng, last_logits, output_tokens, kv_cache, False, 0)
        )

positions: [[784]]
runningmask.shape . (Array(1, dtype=int32), Array(1, dtype=int32), Array(838, dtype=int32))
runningmask. . [[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1
   1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
   1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
   1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
   1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
   1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
   1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
   1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
   1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
   1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
   1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
   1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 

In [23]:
tokenizer.detokenize(np.array(output_tokens, dtype=np.int32))

'pick up white tape'

Prompt, describe image, open: "The image features a person's legs, a foot, and a foot. The person is wearing"