In [1]:
import os
from PIL import Image
import requests
import matplotlib.pyplot as plt
import numpy as np
import jax
import cv2
import tensorflow_datasets as tfds
import tqdm
import mediapy
import numpy as np

import tensorflow as tf

tf.config.set_visible_devices([], "GPU")

import jax.numpy as jnp
from absl import app, flags
import optax
import json
import numpy as np
from functools import partial
import orbax.checkpoint as ocp
from tensorflow_text import SentencepieceTokenizer
# from palivla.dataset import prepare_image
from palivla.tokenizer import Tokenizer
from palivla.load_model import load_model_params_decode
from scalax.sharding import MeshShardingHelper, FSDPShardingRule, PartitionSpec
from flax.training.train_state import TrainState
from jax.experimental import multihost_utils
from ml_collections import config_flags
from palivla.train_state import PaliVLATrainState
from palivla.types import TrainingBatch, RolloutBatch

# Restrict this process to a single TPU chip via environment variables.
# Shell equivalent of the assignments below:
#   export TPU_VISIBLE_DEVICES=0
#   export TPU_CHIPS_PER_HOST_BOUNDS=1,1,1
#   export TPU_HOST_BOUNDS=1,1,1
#   export TPU_MESH_CONTROLLER_ADDRESS=localhost:8476
#   export TPU_MESH_CONTROLLER_PORT=8476
# NOTE(review): the shell example selects chip 0 while the code below selects chip '2'.
# NOTE(review): these are set after jax/tensorflow were imported above; presumably they
# must be in place before the TPU backend is first initialized -- confirm.

os.environ['TPU_VISIBLE_DEVICES'] = '2'
os.environ['TPU_CHIPS_PER_HOST_BOUNDS'] = '1,1,1'
os.environ['TPU_HOST_BOUNDS'] = '1,1,1'
os.environ['TPU_MESH_CONTROLLER_ADDRESS'] = 'localhost:8476'
os.environ['TPU_MESH_CONTROLLER_PORT'] = '8476'


  from .autonotebook import tqdm as notebook_tqdm
2024-12-01 07:35:54.097049: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1733038554.117068 2283567 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1733038554.123110 2283567 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-01 07:35:57.033342: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


## Load Datasets
Next, we will load a trajectory from the bridge dataset for testing the model. We will use the publicly available copy in the Open X-Embodiment dataset bucket.

In [2]:
# create RLDS dataset builder
builder = tfds.builder_from_directory(builder_dir='gs://gresearch/robotics/bridge/0.1.0/')
# pick one random episode index in [0, 1000); no seed is set in the cells shown,
# so a different episode is expected on each run
ind = np.random.randint(0, 1000)
ds = builder.as_dataset(split=f'train[{ind}:{ind+1}]')

# sample episode + resize every frame to 224x224
episode = next(iter(ds))
steps = list(episode['steps'])
images = [cv2.resize(np.array(step['observation']['image']), (224, 224)) for step in steps]

# extract goal image (last frame of the episode) & language instruction
# NOTE(review): goal_image is not referenced by any later cell shown in this notebook
goal_image = images[-1]
language_instruction = steps[0]['observation']['natural_language_instruction'].numpy().decode()

# visualize episode
print(f'Instruction: {language_instruction}')
mediapy.show_video(images, fps=10)

Instruction: PICK UP THE VESSEL AND PUT ANOTHER SIDE OF THE TABLE


0
This browser does not support the video tag.


In [3]:


# action_mean = np.array(dataset_statistics[dataset_name]["action"]["mean"])
# action_std = np.array(dataset_statistics[dataset_name]["action"]["std"])
# action_mask = np.array(dataset_statistics[dataset_name]["action"]["mask"])
def make_inference_batch(batch):
        """Pack a tokenized sample dict into a ``TrainingBatch`` for a forward pass.

        Reads the global ``model`` to find which observation keys are model
        modalities. Every non-"text" modality present in ``batch["observation"]``
        gets a leading axis added and is converted with ``.numpy()``; the matching
        pad masks are taken from ``batch["observation"]["pad_mask_dict"]``.
        All RL-specific fields of the returned ``TrainingBatch`` are set to None.
        """
        modalities = model.model_state.model.modality_mappings
        observation = batch["observation"]

        # Sensor inputs: only observation entries that the model knows as a modality.
        sensors = {}
        for name in observation:
            if name == "text" or name not in modalities:
                continue
            sensors[name] = observation[name][None].numpy()

        # Pad masks: one per non-text modality of the model.
        pad_masks = observation["pad_mask_dict"] if any(m != "text" for m in modalities) else None
        sensors_mask = {}
        for name in modalities:
            if name != "text":
                sensors_mask[name] = pad_masks[name].numpy()

        return TrainingBatch(
            sensors=sensors,
            sensors_mask=sensors_mask,
            sensors_next=None,
            sensors_next_mask=None,
            actions=batch["action"],
            tokens=batch["tokens"].numpy(),
            tokens_ar=batch["mask_ar"].numpy(),
            tokens_loss=batch.get("mask_loss"),
            tokens_mask=batch["mask_input"].numpy(),
            rewards=None,
            td_mask=None,
            mc_returns=None,
            next_actions=None,
            next_tokens=None,
            gen_start=None,
        )

# Do inference
def do_inference(images, instructions, action):
        """Run one forward pass and return the value read out for ``action``.

        Uses the notebook globals ``tokenizer`` and ``model``. The inputs are
        tokenized together with the given action, packed with
        ``make_inference_batch``, and fed through ``model.model_state.apply_fn``
        with ``train=False``. The result is ``get_value`` applied to the
        ``info["values"]`` output and the target tokens (tokens shifted by one).

        Args:
            images: image observations; ``len(images)`` sets the pad-mask length.
            instructions: language instruction passed to the tokenizer.
            action: action to tokenize and evaluate.
        """
        data = {
            "observation": {
                "image_primary": images,
                "pad_mask_dict": {"image_primary": tf.ones(len(images), dtype=tf.bool)},
            },
            "task": {"language_instruction": instructions},
            "action": action,
        }

        # Tokenize, then merge the raw fields back in (raw fields win on key clashes).
        prompt_tokens = tokenizer.tokenize_language_instruction(data)
        tokenized = tokenizer.prepare_tokens_for_training(data, prompt_tokens)
        batch = make_inference_batch(tokenized | data)

        # Fixed seed; only the first half of the split is used.
        dropout_key, _ = jax.random.split(jax.random.PRNGKey(0))

        # Model input is every token but the last, with a leading axis added.
        text_inputs = batch.tokens[..., :-1][None]
        all_inputs = batch.sensors | {"text": text_inputs}
        all_masks = batch.sensors_mask | {
            "text": jnp.ones_like(text_inputs, dtype=jnp.bool_)
        }

        _, info = model.model_state.apply_fn(
            {"params": model.model_state.params},
            all_inputs,
            data_masks=all_masks,
            text_ar_mask=batch.tokens_ar[..., :-1][None],
            train=False,
            rngs={"dropout": dropout_key},
        )

        # Targets are the inputs shifted by one position.
        target_tokens = batch.tokens[..., 1:][None]
        return get_value(info["values"], target_tokens, tokenizer.config)

def normalize_action(action, unnormalization_statistics):
    """Rescale an action so that the [p01, p99] range maps onto [-1, 1].

    Args:
        action: array whose last axis holds the action dimensions. It is
            truncated to the first ``len(mask)`` entries of that axis.
        unnormalization_statistics: mapping with "p01" and "p99" entries and an
            optional boolean "mask". Dimensions where the mask is False are
            passed through unchanged.

    Returns:
        The rescaled action as produced by ``jnp.where``.
    """
    stats = unnormalization_statistics
    mask = stats.get("mask")
    if mask is None:
        # Build the all-True default only when it is needed. Passing it as the
        # `dict.get` default evaluated stats["mean"] eagerly and raised KeyError
        # for statistics that carry a mask but no "mean" entry.
        template = stats["mean"] if "mean" in stats else stats["p01"]
        mask = jnp.ones_like(template, dtype=bool)

    action = action[..., : len(mask)]
    low, high = stats["p01"], stats["p99"]
    return jnp.where(mask, 2 * (action - low) / (high - low) - 1, action)

def unnormalize_action_minmax(action, unnormalization_statistics):
    """Map an action from [-1, 1] back onto the [p01, p99] range.

    Args:
        action: array whose last axis holds the action dimensions. It is
            truncated to the first ``len(mask)`` entries of that axis.
        unnormalization_statistics: mapping with "p01" and "p99" entries and an
            optional boolean "mask". Dimensions where the mask is False are
            passed through unchanged.

    Returns:
        The rescaled action as produced by ``jnp.where``.
    """
    stats = unnormalization_statistics
    mask = stats.get("mask")
    if mask is None:
        # Build the all-True default only when it is needed. Passing it as the
        # `dict.get` default evaluated stats["mean"] eagerly and raised KeyError
        # for statistics that carry a mask but no "mean" entry.
        template = stats["mean"] if "mean" in stats else stats["p01"]
        mask = jnp.ones_like(template, dtype=bool)

    action = action[..., : len(mask)]
    low, high = stats["p01"], stats["p99"]
    return jnp.where(mask, (action + 1) / 2 * (high - low) + low, action)

def get_value(pred_values, tokens, tokenizer_config: Tokenizer.TokenizerConfig):
        """Read out the predicted value at the end-of-action token of each sequence.

        The position is the argmax of ``tokens == end_of_action_token`` along the
        last axis, i.e. the first match (index 0 when there is no match). One
        element is sliced from ``pred_values`` at that position for every entry
        of its last axis, and the result is squeezed.

        Assumes ``pred_values`` is (batch, seq, n) and ``tokens`` is (batch, seq)
        -- TODO confirm against the model's ``info["values"]`` output.
        """
        is_end_of_action = tokens == tokenizer_config.end_of_action_token
        value_positions = jnp.argmax(is_end_of_action, axis=-1)

        def _slice_one(sequence, position):
            # Length-1 slice of a single sequence at its value position.
            return jax.lax.dynamic_slice(sequence, (position,), (1,))

        # Inner map runs over the batch; outer map runs over the last axis of
        # pred_values while sharing the same positions.
        over_batch = jax.vmap(_slice_one)
        over_last_axis = jax.vmap(over_batch, in_axes=(-1, None), out_axes=-1)

        return over_last_axis(pred_values, value_positions).squeeze()

In [4]:
# Checkpoint directory to restore the model from.
resume_from_checkpoint_dir = "gs://rail-tpus-mitsuhiko-central2/logs/test/dulcet-leaf-45/"
# NOTE(review): presumably an alternative checkpoint *directory*, not a step:
#   gs://rail-tpus-mitsuhiko-central2/logs/test/bright-thunder-2/
# Checkpoint step to restore.
resume_from_checkpoint_step = 100000
# Key used to look up action statistics in the restored model's dataset_statistics.
dataset_name = "bridge_dataset"

In [5]:
# Sharding: a device mesh with a single axis named "fsdp".
# NOTE(review): the -1 presumably means "use all available devices" -- confirm in scalax.
mesh = MeshShardingHelper([-1], ["fsdp"])

# Parameters are sharded along the "fsdp" axis; data is partitioned along the same axis.
model_sharding = FSDPShardingRule("fsdp", fsdp_axis_size=mesh.mesh.shape["fsdp"])
data_sharding = PartitionSpec("fsdp")
# data_sharding = jax.sharding.SingleDeviceSharding(jax.local_devices()[0])

# Checkpoint manager rooted at the checkpoint directory configured earlier.
restore_checkpoint_manager = ocp.CheckpointManager(
        resume_from_checkpoint_dir,
        item_handlers=PaliVLATrainState.get_checkpoint_handlers(),
)

# Restore the train state at the configured step.
# NOTE(review): load_optimizer=False presumably skips optimizer state -- confirm.
model = PaliVLATrainState.restore(
        checkpoint_manager=restore_checkpoint_manager,
        step=resume_from_checkpoint_step,
        load_optimizer=False,
        mesh=mesh,
        model_sharding=model_sharding,
        data_sharding=data_sharding,
    )
# Convenience handles used by the inference helpers and later cells.
tokenizer = model.tokenizer
decode = model.decode
dataset_statistics = model.dataset_statistics


I1201 07:36:04.858799 2288254 google_auth_provider.cc:181] Running on GCE, using service account 180902422847-compute@developer.gserviceaccount.com
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


In [6]:
# Display the restored tokenizer (its config, token structure and action tokenizer).
tokenizer

Tokenizer(config=Tokenizer.TokenizerConfig(action_vocab_size=256, action_vocab_offset=256000, vocab_size=<bound method SentencepieceTokenizer.vocab_size of <tensorflow_text.python.ops.sentencepiece_tokenizer.SentencepieceTokenizer object at 0x7fee240fbd30>>, num_action_tokens=7, bos_token=2, eos_token=1, pad_token=0, begin_of_action_token=108, end_of_action_token=3, max_pad_length=60, min_action_value=-2, max_action_value=2, prompt_autoregressive=False), language_tokenizer=<tensorflow_text.python.ops.sentencepiece_tokenizer.SentencepieceTokenizer object at 0x7fee240fbd30>, token_structure=FrozenDict({
    causal: ('action', (3,)),
    pad: ((0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),),
    prefix: ((2,), 'prompt', (108,)),
}), action_tokenizer=BinActionTokenizer(
    # attributes
    min_action_value = -2
    max_action_value = 2
    action_dim = 7
 

## Run Inference

Next, we will run inference over the images in the episode using the loaded model.
In this notebook the model is conditioned on the language instruction only, and for every step we record the value it predicts for the ground-truth action.
Note that we need to feed inputs of the correct temporal window size.

In [7]:
WINDOW_SIZE = 1
# run inference loop, this model only uses single image observations for bridge
# collect true actions and the value the model predicts for each of them
# NOTE: pred_actions is never filled in this notebook; only values are queried.
pred_actions, true_actions = [], []
pred_values = []

# Loop-invariant inputs: build them once instead of on every step.
instruction = tf.constant(language_instruction)
action_statistics = dataset_statistics[dataset_name]["action"]

for step in tqdm.trange(len(images) - (WINDOW_SIZE - 1)):
    # Add a leading axis and scale uint8 pixels from [0, 255] to [-1, 1].
    input_image = images[step][None]
    input_image = tf.cast(input_image, tf.float32) / 127.5 - 1.0

    # Ground-truth action as one vector: translation, rotation, gripper.
    step_action = steps[step]['action']
    true_action = np.concatenate(
        (
            step_action['world_vector'],
            step_action['rotation_delta'],
            np.array(step_action['open_gripper']).astype(np.float32)[None],
        ),
        axis=-1,
    )
    action = normalize_action(true_action[None], action_statistics)

    value = do_inference(input_image, instruction, action[None])
    print(value)
    pred_values.append(value)

    true_actions.append(true_action)

I0000 00:00:1733038598.880065 2284568 service.cc:148] XLA service 0x7feb00479bc0 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1733038598.880145 2284568 service.cc:156]   StreamExecutor device (0): Host, Default Version
I0000 00:00:1733038598.901904 2284568 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
  2%|▏         | 1/49 [00:22<17:45, 22.20s/it]

-25.903908


  4%|▍         | 2/49 [00:26<09:18, 11.89s/it]

-25.748095


  6%|▌         | 3/49 [00:31<06:34,  8.58s/it]

-26.174894


  8%|▊         | 4/49 [00:36<05:16,  7.03s/it]

-26.259235


 10%|█         | 5/49 [00:41<04:35,  6.25s/it]

-25.955688


 12%|█▏        | 6/49 [00:45<04:07,  5.75s/it]

-26.429398


 14%|█▍        | 7/49 [00:50<03:46,  5.39s/it]

-26.432339


 16%|█▋        | 8/49 [00:55<03:30,  5.14s/it]

-26.123074


 18%|█▊        | 9/49 [00:59<03:20,  5.01s/it]

-26.008825


 20%|██        | 10/49 [01:04<03:10,  4.88s/it]

-25.432974


 22%|██▏       | 11/49 [01:09<03:03,  4.83s/it]

-24.889908


 24%|██▍       | 12/49 [01:13<02:56,  4.76s/it]

-25.021587


 27%|██▋       | 13/49 [01:18<02:49,  4.72s/it]

-24.81437


 29%|██▊       | 14/49 [01:22<02:43,  4.66s/it]

-24.888048


 31%|███       | 15/49 [01:27<02:37,  4.64s/it]

-24.995699


 33%|███▎      | 16/49 [01:31<02:31,  4.61s/it]

-25.206852


 35%|███▍      | 17/49 [01:36<02:28,  4.66s/it]

-24.904936


 37%|███▋      | 18/49 [01:41<02:23,  4.63s/it]

-25.429453


 39%|███▉      | 19/49 [01:45<02:19,  4.64s/it]

-24.814856


 41%|████      | 20/49 [01:50<02:16,  4.72s/it]

-24.939505


 43%|████▎     | 21/49 [01:55<02:11,  4.70s/it]

-24.984694


 45%|████▍     | 22/49 [02:00<02:05,  4.67s/it]

-25.067413


 47%|████▋     | 23/49 [02:04<02:00,  4.64s/it]

-25.013424


 49%|████▉     | 24/49 [02:09<01:56,  4.65s/it]

-25.278107


 51%|█████     | 25/49 [02:13<01:51,  4.64s/it]

-25.011715


 53%|█████▎    | 26/49 [02:18<01:47,  4.65s/it]

-24.848564


 55%|█████▌    | 27/49 [02:23<01:42,  4.66s/it]

-24.803501


 57%|█████▋    | 28/49 [02:27<01:36,  4.62s/it]

-25.02766


 57%|█████▋    | 28/49 [02:29<01:52,  5.35s/it]


KeyboardInterrupt: 

## Visualize predictions and ground-truth actions

Finally, we will visualize the ground-truth actions of the episode together with the values the model predicted for them.

In [None]:
import matplotlib.pyplot as plt

ACTION_DIM_LABELS = ['x', 'y', 'z', 'yaw', 'pitch', 'roll', 'grasp']

# build image strip to show above actions (every third frame, joined along the width)
img_strip = np.concatenate(np.array(images[::3]), axis=1)

# set up plt figure: image strip on top, one panel per action dimension, value curve below
figure_layout = [
    ['image'] * len(ACTION_DIM_LABELS),
    ACTION_DIM_LABELS,
    ["value"] * len(ACTION_DIM_LABELS),
]
plt.rcParams.update({'font.size': 12})
fig, axs = plt.subplot_mosaic(figure_layout)
fig.set_size_inches([45, 10])

# plot actions
# Reshape instead of squeeze so that an episode with zero or one recorded step
# still yields a (time, action_dim) array that can be indexed below.
true_actions = np.asarray(true_actions).reshape(-1, len(ACTION_DIM_LABELS))
for action_dim, action_label in enumerate(ACTION_DIM_LABELS):
    ax = axs[action_label]
    ax.plot(true_actions[:, action_dim], label='ground truth')
    ax.set_title(action_label)
    ax.set_xlabel('Time in one episode')
    # Attach the legend to the axes that owns the labelled line; a bare
    # plt.legend() only acts on the current axes.
    ax.legend()

# plot value
axs['value'].plot(pred_values, label='predicted value')
axs['value'].set_title('value')
axs['value'].set_xlabel('Time in one episode')
axs['value'].legend()

axs['image'].imshow(img_strip)
axs['image'].set_xlabel('Time in one episode (subsampled)')