# Step 1: Minimal CrossFormer Inference Example

This Colab demonstrates how to load a pre-trained / finetuned CrossFormer checkpoint, run inference for single-arm and bimanual manipulation systems, and compare the outputs to the true actions.

First, let's start with a minimal example!

In [None]:
import matplotlib.pyplot as plt
import numpy as np


In [None]:
from crossformer.model.crossformer_model import CrossFormerModel

# Load the CrossFormer checkpoint from the "hf://" path; `model` is reused by the cells below.
model = CrossFormerModel.load_pretrained("hf://rail-berkeley/crossformer")

In [None]:
import jax.numpy as jnp

def analyze_attention(model, observation, task):
    """
    Computes an attention rollout over all transformer tokens for one forward pass.

    Runs the CrossFormer transformer with its intermediates captured, averages
    each layer's attention weights over axis 1, multiplies the per-layer
    matrices together and row-normalizes the product. A token breakdown is
    printed along the way.

    Args:
        model: Loaded model; its ``module``, ``params`` and ``config`` are read.
        observation: Dictionary of observations; must contain "timestep_pad_mask".
        task: Dictionary of task specifications (e.g. from ``model.create_tasks``).
    Returns:
        rollout: Row-normalized product of the averaged attention matrices.
        token_types: List with one group name per token, built from every
            output group whose name contains an underscore, in output order.
    Raises:
        ValueError: If no attention weights were captured for any layer.
    """
    # Run the transformer and keep the "intermediates" collection it records.
    outputs, variables = model.module.apply(
        {"params": model.params},
        observation,
        task,
        observation["timestep_pad_mask"],
        train=False,
        method="crossformer_transformer",
        mutable=["intermediates"],
    )

    print(outputs.keys())

    def count_tokens(group_name):
        # The token count of a group is the second-to-last axis of its tokens.
        return outputs[group_name].tokens.shape[-2]

    # Token breakdown: task (prefix) groups first, then observation groups,
    # then whichever of the known readout heads are present.
    head_names = ['readout_bimanual', 'readout_nav', 'readout_quadruped', 'readout_single_arm']
    token_counts = {}
    for prefix in ("task_", "obs_"):
        for group_name in outputs.keys():
            if group_name.startswith(prefix):
                token_counts[group_name] = count_tokens(group_name)
    for readout_key in head_names:
        if readout_key in outputs:
            token_counts[readout_key] = count_tokens(readout_key)

    # Extract the attention weights recorded by each transformer block.
    num_layers = model.config["model"]["transformer_kwargs"]["num_layers"]
    transformer_intermediates = (
        variables['intermediates']['crossformer_transformer']['BlockTransformer_0']['Transformer_0']
    )
    attention_weights = []
    for layer_idx in range(num_layers):
        block_name = f'encoderblock_{layer_idx}'
        if block_name in transformer_intermediates:
            # The recorded value is a sequence; its first entry is the weight array.
            attention_weights.append(
                transformer_intermediates[block_name]['MultiHeadDotProductAttention_0']['attention_weights'][0]
            )

    print("Token count breakdown:")
    for group_name, n_tokens in token_counts.items():
        print(f"{group_name}: {n_tokens} tokens")
    print(f"Total tokens: {sum(token_counts.values())}")

    if not attention_weights:
        raise ValueError(
            "No attention weights were captured; expected 'encoderblock_<i>' entries "
            "under the transformer intermediates."
        )

    # Average over axis 1 (assumed to be the attention-head axis).
    attention_weights = [jnp.mean(weights, axis=1) for weights in attention_weights]

    # One label per token for every output group whose name contains "_".
    token_types = []
    for group_name in outputs.keys():
        print(group_name)
        if "_" in group_name:
            token_types.extend([group_name] * count_tokens(group_name))

    # Attention rollout: multiply the per-layer matrices from first to last.
    # NOTE(review): no identity/residual term is added to the layers before
    # multiplying - confirm this is the intended rollout variant.
    rollout = attention_weights[0]
    for attention in attention_weights[1:]:
        rollout = jnp.matmul(attention, rollout)

    # Normalize so that every row sums to one.
    rollout = rollout / rollout.sum(axis=-1, keepdims=True)

    return rollout, token_types


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp

def plot_attention_rollout(
    observations: dict,
    rollout: np.ndarray,
    token_types: list,
    save_path: str = None,
    patch_size: int = 32,
) -> plt.Figure:
    """
    Overlays observation-token attention on the image of every timestep.

    Args:
        observations: Dictionary of observations; needs "timestep_pad_mask" and
            at least one "image_*" entry indexed as [batch, timestep, ...].
            The first "image_*" entry is the one displayed.
        rollout: Attention rollout matrix [num_timesteps, num_tokens]
        token_types: List of token type names
        save_path: Optional path to save visualization
        patch_size: Side length in pixels of one attention patch (default 32)
    Returns:
        matplotlib Figure
    Raises:
        ValueError: If there is no "image_*" observation or no "obs_*" token.
    """
    # Local import: the surrounding cell only binds `jax.numpy` (as `jnp`),
    # so `jax` itself is not guaranteed to be defined when this runs.
    import jax

    n_timesteps = observations["timestep_pad_mask"].shape[1]

    # Both lookups are the same for every timestep, so do them once.
    image_key = next((key for key in observations if key.startswith("image_")), None)
    if image_key is None:
        raise ValueError("observations contains no 'image_*' entry to plot.")
    obs_token_idxs = np.asarray(
        [i for i, name in enumerate(token_types) if name.startswith("obs_")], dtype=int
    )
    if obs_token_idxs.size == 0:
        raise ValueError("token_types contains no 'obs_*' tokens to visualize.")

    fig = plt.figure(figsize=(15, 5))

    for t in range(n_timesteps):
        plt.subplot(1, n_timesteps, t + 1)

        # Image (first batch element) and observation attention for this timestep
        img = observations[image_key][0, t]
        obs_attention = jnp.asarray(rollout[t, obs_token_idxs])

        plt.imshow(img)

        # The number of patches should be (H/patch_size) * (W/patch_size)
        H, W = img.shape[:2]
        num_patches_h = H // patch_size
        num_patches_w = W // patch_size
        num_patches = num_patches_h * num_patches_w

        # Use the leading tokens as a patch grid when there are enough of them.
        # The size is checked explicitly because a failed reshape raises
        # different exception types for NumPy and JAX arrays.
        leading = obs_attention[:num_patches]
        if leading.size == num_patches:
            attention_grid = leading.reshape(num_patches_h, num_patches_w)
        else:
            # Otherwise interpolate the flat attention vector onto the grid
            attention_grid = jax.image.resize(
                obs_attention.reshape(1, -1), (num_patches_h, num_patches_w), method='bilinear'
            )

        # Resize attention grid to match image size
        attention_resized = jax.image.resize(attention_grid, img.shape[:2], method='bilinear')

        # Plot attention overlay
        plt.imshow(attention_resized, cmap='hot', alpha=0.5)
        plt.axis('off')
        plt.title(f'Timestep {t}')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path)

    return fig

def plot_readout_attention(
    observations: dict,
    rollout: np.ndarray,
    token_types: list,
    save_path: str = None,
) -> plt.Figure:
    """
    Draws the attention rollout matrix as a labelled heatmap.

    Args:
        observations: Dictionary of observations (not read by this function)
        rollout: Attention rollout matrix
        token_types: List of token type names used as tick labels
        save_path: Optional path to save visualization
    Returns:
        matplotlib Figure
    """
    fig, ax = plt.subplots(figsize=(10, 8))
    heatmap = ax.imshow(rollout, cmap='viridis')

    # Label only every `stride`-th token so the axes stay readable.
    token_total = len(token_types)
    stride = max(1, token_total // 20)
    tick_positions = np.arange(0, token_total, stride)
    tick_labels = [token_types[pos] for pos in range(0, token_total, stride)]

    # Same positions and names on both axes; x labels are slanted.
    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    ax.set_xticklabels(tick_labels, rotation=45, ha='right')
    ax.set_yticklabels(tick_labels)

    # Colour scale, title and layout
    plt.colorbar(heatmap)
    plt.title('Attention Rollout')
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path)

    return fig

In [None]:
# We'll demonstrate how to create an observation and task dictionary for a bimanual task.
# Then we'll use them to analyze the model's attention.

import jax
import numpy as np
# create a random uint8 RGB image; randint's upper bound is exclusive, so 256 covers 0..255
img = np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8)
# add batch and observation history dimension (CrossFormer accepts a history of up to 5 time-steps)
img = img[None, None]
# our bimanual training data has an overhead view and two wrist views
observation = {
    "image_high": img,
    "image_left_wrist": img,
    "image_right_wrist": img,
    "timestep_pad_mask": np.array([[True]]),
}
# create a task dictionary for a language task
task = model.create_tasks(texts=["uncap the pen"])

In [None]:
# Compute the attention rollout for the observation/task built above and report its shape.
rollout, token_types = analyze_attention(model, observation, task)
print(rollout.shape)

In [None]:
# Number of per-token labels returned by analyze_attention.
print(len(token_types))

In [None]:
def get_observation_image(observation, observation_type):
    """
    Returns the squeezed value of the first observation whose key contains
    ``observation_type``, or None when no key matches.

    Args:
        observation: Dictionary of observations
        observation_type: Substring to look for in the observation keys
    Returns:
        The matching entry with singleton dimensions removed, or None
    """
    matching_key = next(
        (key for key in observation.keys() if observation_type in key),
        None,
    )
    if matching_key is None:
        return None
    return observation[matching_key].squeeze()

def plot_readout_attention2(
    rollouts,
    token_types,
    head,
    observations,
    observation_type = "_high",
    save_path: str = None,
    patch_size: int = 32,
) -> plt.Figure:
    """
    Plots readout-token attention as a darkening overlay on an observation image.

    The first panel shows the plain observation image. Every following panel
    shows the same image under a black overlay whose opacity is
    ``1 - attention / max(attention)``: weakly attended patches are darkened
    and the most attended patch is left untouched.

    Args:
        rollouts: Attention rollout array; one panel is drawn per entry along
            axis 0 and each entry is indexed as [readout_token, observation_token].
        token_types: List of token type names, one per token index
        head: Token type name of the readout tokens to visualize
        observations: Dictionary of observations containing the image
        observation_type: Substring selecting both the observation image key
            and the observation tokens to visualize
        save_path: Optional path to save visualization
        patch_size: Side length in pixels of the square painted per
            observation token (default 32)
    Returns:
        matplotlib Figure
    Raises:
        ValueError: If no token in token_types equals ``head`` or no
            observation key contains ``observation_type``.
    """
    indexes_readout = [i for i, name in enumerate(token_types) if name == head]
    indexes_obs = [i for i, name in enumerate(token_types) if observation_type in name]
    if not indexes_readout:
        raise ValueError(f"No token of type {head!r} found in token_types.")

    observation_image = get_observation_image(observations, observation_type)
    if observation_image is None:
        raise ValueError(f"No observation key contains {observation_type!r}.")
    height, width = observation_image.shape[:2]

    num_timesteps = rollouts.shape[0]
    # squeeze=False keeps `axs` two-dimensional for any number of panels, so it
    # can always be indexed as axs[0, column].
    fig, axs = plt.subplots(1, num_timesteps + 1, squeeze=False)
    axs[0, 0].imshow(observation_image, cmap='viridis')

    for t in range(num_timesteps):
        rollout = np.asarray(rollouts[t], dtype=np.float64)

        # Mean attention of the selected readout tokens on each observation token.
        patch_attention = rollout[np.ix_(indexes_readout, indexes_obs)].mean(axis=0)

        # Paint one square per observation token. Tokens advance along the
        # first image axis and wrap onto the second; squares that fall
        # outside the image are dropped by the slicing.
        # NOTE(review): confirm this fill order matches the tokenizer's patch order.
        attention_map = np.zeros((height, width))
        x, y = 0, 0
        for attention in patch_attention:
            attention_map[x:x + patch_size, y:y + patch_size] = attention
            x += patch_size
            if x >= height:
                x = 0
                y += patch_size

        # Plot original image
        axs[0, t + 1].imshow(observation_image)

        # Overlay opacity: 0 at the attention maximum, 1 where attention is
        # zero. Guard the division so an all-zero map does not produce NaNs.
        peak = attention_map.max()
        normalized = attention_map / peak if peak > 0 else attention_map
        alpha = 1 - normalized

        # Create a dark overlay
        dark_overlay = np.zeros_like(observation_image)
        axs[0, t + 1].imshow(dark_overlay, alpha=alpha, cmap='gray')

        axs[0, t + 1].axis('off')  # Remove axes
        axs[0, t + 1].set_title(f'Timestep {t}')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path)

    return fig


In [None]:
# Overlay the attention of the "readout_single_arm" tokens on the default ("_high") observation image.
fig = plot_readout_attention2(rollout, token_types, "readout_single_arm", observation)
#plot_attention_rollout(observation, rollout, token_types)

# Step 2: Run Inference on Full Trajectories

That was easy! Now let's try to run inference across a whole single-arm trajectory and visualize the results!

In [None]:
import jax
import tensorflow_datasets as tfds
import tqdm
import mediapy
import numpy as np
import cv2

## Load Model Checkpoint
First, we will load the pre-trained checkpoint using the `load_pretrained()` function. You can specify the path to a checkpoint directory or a HuggingFace path.

Below, we are loading directly from HuggingFace.


In [None]:
from crossformer.model.crossformer_model import CrossFormerModel

# Load the checkpoint directly from HuggingFace via the "hf://" path.
model = CrossFormerModel.load_pretrained("hf://rail-berkeley/crossformer")

## Load Data
Next, we will load a trajectory from the Bridge dataset for testing the model. We will use the publicly available copy in the Open X-Embodiment dataset bucket.

In [None]:
import os
import certifi

# Point both certificate environment variables at certifi's CA bundle.
os.environ.update(
    {
        'SSL_CERT_FILE': certifi.where(),
        'CURL_CA_BUNDLE': certifi.where(),
    }
)

# create RLDS dataset builder for the Bridge data and take the first training episode
builder = tfds.builder_from_directory(builder_dir="gs://gresearch/robotics/bridge/0.1.0/")
ds = builder.as_dataset(split="train[:1]")
episode = next(iter(ds))
steps = list(episode["steps"])

# resize every frame to 224x224 (default third-person cam resolution)
images = [
    cv2.resize(np.array(frame["observation"]["image"]), (224, 224))
    for frame in steps
]

# the last frame serves as the goal image; the instruction comes from the first step
goal_image = images[-1]
first_observation = steps[0]["observation"]
language_instruction = first_observation["natural_language_instruction"].numpy().decode()

# visualize episode
print(f"Instruction: {language_instruction}")
#mediapy.show_video(images, fps=10)

In [None]:
# Shape of the first resized frame.
print(images[0].shape)


## Run Inference

Next, we will run inference over the images in the episode using the loaded model. 
Below we demonstrate setups for both goal-conditioned and language-conditioned inference.
Note that we need to feed inputs of the correct temporal window size.

In [None]:
WINDOW_SIZE = 1

# create both kinds of task dictionary and keep each under its own name
goal_task = model.create_tasks(
    goals={"image_primary": goal_image[None]}
)  # for goal-conditioned
language_task = model.create_tasks(texts=[language_instruction])  # for language conditioned

# The cells below read `task`. It is the language-conditioned task, as before;
# set `task = goal_task` to run goal-conditioned inference instead.
task = language_task

In [None]:
# run inference loop, the model only uses 3rd person image observations for bridge

# predicted and ground-truth actions, one entry per window
pred_actions = []
true_actions = []
num_windows = len(images) - (WINDOW_SIZE - 1)
for step in tqdm.trange(num_windows):
    # stack the frames of this window and add the batch dimension
    window_frames = images[step : step + WINDOW_SIZE]
    input_images = np.stack(window_frames)[None]
    observation = {
        "image_primary": input_images,
        "timestep_pad_mask": np.ones((1, input_images.shape[1]), dtype=bool),
    }

    # the dataset statistics are passed in so the sampled actions get unnormalized
    actions = model.sample_actions(
        observation,
        task,
        head_name="single_arm",
        unnormalization_statistics=model.dataset_statistics["bridge_dataset"]["action"],
        rng=jax.random.PRNGKey(0),
    )
    actions = actions[0]  # remove batch
    pred_actions.append(actions)

    # ground truth comes from the last step of the window:
    # world vector, rotation delta, then the gripper flag as one float
    final_window_step = step + WINDOW_SIZE - 1
    recorded_action = steps[final_window_step]["action"]
    gripper = np.array(recorded_action["open_gripper"]).astype(np.float32)[None]
    true_actions.append(
        np.concatenate(
            (
                recorded_action["world_vector"],
                recorded_action["rotation_delta"],
                gripper,
            ),
            axis=-1,
        )
    )

## Visualize predictions and ground-truth actions

Finally, we will visualize the predicted actions in comparison to the ground-truth actions.

In [None]:
ACTION_DIM_LABELS = ['x', 'y', 'z', 'yaw', 'pitch', 'roll', 'grasp']

# image strip shown above the action plots: every third frame, side by side
subsampled_frames = np.array(images[::3])
img_strip = np.concatenate(subsampled_frames, axis=1)

# figure layout: the image strip spans the top row, one panel per action dimension below
figure_layout = [
    ['image'] * len(ACTION_DIM_LABELS),
    ACTION_DIM_LABELS
]
plt.rcParams.update({'font.size': 12})
fig, axs = plt.subplot_mosaic(figure_layout)
fig.set_size_inches([45, 10])

# drop singleton dimensions before plotting
pred_actions = np.squeeze(np.array(pred_actions))
true_actions = np.squeeze(np.array(true_actions))
for action_dim, action_label in enumerate(ACTION_DIM_LABELS):
    panel = axs[action_label]
    # actions have batch, horizon, dim, in this example we just take the first action for simplicity
    panel.plot(pred_actions[:, 0, action_dim], label='predicted action')
    panel.plot(true_actions[:, action_dim], label='ground truth')
    panel.set_title(action_label)
    panel.set_xlabel('Time in one episode')

image_panel = axs['image']
image_panel.imshow(img_strip)
image_panel.set_xlabel('Time in one episode (subsampled)')
plt.legend()

In [None]:
# Analyze attention for the first observation window of the episode only.
# (Previously written as a loop that always stopped after its first iteration.)
step = 0
input_images = np.stack(images[step : step + WINDOW_SIZE])[None]
observation = {
    "image_primary": input_images,
    "timestep_pad_mask": np.full((1, input_images.shape[1]), True, dtype=bool),
}

# compute the attention rollout and per-token labels for this window
rollout, token_types = analyze_attention(model, observation, task)

In [None]:
# Inspect the rollout: its shape, its values and its largest entry.
print(rollout.shape)
print(rollout)
print(rollout.max())

In [None]:
import importlib
from crossformer.utils import visualization_utils

# Import first, then reload: `importlib.reload` needs an already-imported module
# object, so reloading before the import fails with a NameError on a fresh kernel.
# The reload picks up edits made to the module since it was first imported.
visualization_utils = importlib.reload(visualization_utils)
fig = visualization_utils.plot_readout_attention(rollout, token_types, "readout_single_arm", observation, observation_type="_primary", observation_image=observation["image_primary"][0,0])