# Step 1: Minimal CrossFormer Inference Example

This Colab demonstrates how to load a pre-trained / finetuned CrossFormer checkpoint, run inference for a single-arm and bimanual manipulation system, and compare the outputs to the true actions.

First, let's start with a minimal example!

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np


In [None]:
from crossformer.model.crossformer_model import CrossFormerModel

# Load the pre-trained CrossFormer checkpoint from its HuggingFace path.
# The resulting `model` object is what the cells below call into.
model = CrossFormerModel.load_pretrained("hf://rail-berkeley/crossformer")

In [None]:
# New cells to add after the model loading section in inference_pretrained.ipynb

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

def get_attention_weights(model, observation, task, num_layers=None):
    """Extract per-layer attention weights from the model's transformer.

    Runs the ``crossformer_transformer`` method with the ``intermediates``
    collection marked mutable, then reads what each ``encoderblock_<i>``
    stored under ``MultiHeadDotProductAttention_0/attention_weights``.
    Encoder blocks are discovered from the returned tree instead of assuming
    a fixed depth, so models with fewer or more than 12 layers work too.

    Args:
        model: Loaded model exposing ``module.apply`` and ``params``.
        observation: Observation dict; must contain ``"timestep_pad_mask"``.
        task: Task dict, e.g. the result of ``model.create_tasks``.
        num_layers: Optional cap; only blocks with index ``< num_layers`` are
            read. ``None`` (default) reads every encoder block present.

    Returns:
        A list with one attention array per encoder block, ordered by block
        index. The list is empty when no attention weights were stored
        (presumably because the attention modules did not record them —
        confirm against the model definition). ``None`` is returned when
        ``apply`` did not hand back an ``(outputs, variables)`` pair.
    """
    transformer_outputs = model.module.apply(
        {"params": model.params},
        observation,
        task,
        observation["timestep_pad_mask"],
        train=False,
        method="crossformer_transformer",
        mutable=["intermediates"],  # makes apply() also return this collection
    )

    # With a mutable collection the call yields (outputs, variables).
    if not isinstance(transformer_outputs, (tuple, list)) or len(transformer_outputs) != 2:
        return None
    _, variables = transformer_outputs

    # Walk down to the dict holding the encoder blocks; a missing level just
    # means nothing was captured.
    blocks = variables
    for key in (
        "intermediates",
        "crossformer_transformer",
        "BlockTransformer_0",
        "Transformer_0",
    ):
        if key not in blocks:
            return []
        blocks = blocks[key]

    # Sort numerically so that encoderblock_10 comes after encoderblock_9.
    prefix = "encoderblock_"
    indexed_names = sorted(
        (int(name[len(prefix):]), name)
        for name in blocks
        if name.startswith(prefix) and name[len(prefix):].isdecimal()
    )

    attention_weights = []
    for index, name in indexed_names:
        if num_layers is not None and index >= num_layers:
            continue
        attention_module = blocks[name]
        if "MultiHeadDotProductAttention_0" not in attention_module:
            continue
        stored = attention_module["MultiHeadDotProductAttention_0"]
        if "attention_weights" not in stored:
            continue
        # The stored entry is indexed with [0] to take its first element.
        attention_weights.append(stored["attention_weights"][0])

    return attention_weights

def compute_rollout(attention_weights, add_residual=False):
    """Compute attention rollout from per-layer attention weights.

    Each layer is averaged over axis 1, then the layers are chained by matrix
    multiplication (later layers on the left) and every row of the result is
    renormalised to sum to one.

    Args:
        attention_weights: Non-empty sequence of per-layer arrays. Axis 1 is
            averaged away as the head axis — assumes a
            (batch, heads, query, key) layout; TODO confirm against the model.
        add_residual: When True, mix each head-averaged layer with the
            identity (``0.5 * A + 0.5 * I``) and renormalise its rows before
            chaining, which is how attention rollout (Abnar & Zuidema, 2020)
            accounts for residual connections. The default ``False`` keeps
            the plain product of attention matrices.

    Returns:
        ``numpy.ndarray`` with the rollout; same shape as one head-averaged
        layer.

    Raises:
        ValueError: If ``attention_weights`` is ``None`` or empty.
    """
    if attention_weights is None or len(attention_weights) == 0:
        raise ValueError("attention_weights must contain at least one layer")

    # Average attention weights across heads.
    per_layer = [np.mean(np.asarray(layer), axis=1) for layer in attention_weights]

    if add_residual:
        mixed = []
        for attention in per_layer:
            identity = np.eye(attention.shape[-1], dtype=attention.dtype)
            attention = 0.5 * attention + 0.5 * identity
            mixed.append(attention / attention.sum(axis=-1, keepdims=True))
        per_layer = mixed

    # Start from the first layer and propagate through the remaining ones.
    rollout = per_layer[0]
    for attention in per_layer[1:]:
        rollout = np.matmul(attention, rollout)

    # Normalize rows of the rollout.
    return rollout / rollout.sum(axis=-1, keepdims=True)

def visualize_attention(attention_rollout, images, save_path=None, grid_size=None):
    """Show each input image and, below it, its attention rollout overlay.

    The figure has two rows and one column per entry of ``images``: the top
    row shows the raw image, the bottom row shows the same image with the
    attention heatmap drawn over it.

    Args:
        attention_rollout: Array indexed as ``attention_rollout[t]`` for each
            image ``t``. The selected slice is averaged over its first axis
            and flattened to one value per token. Assumes the leading axis
            lines up with ``images`` — TODO confirm against the caller.
        images: Non-empty sequence of image arrays (height, width[, channels]).
        save_path: If given, the figure is also written to this path.
        grid_size: Side length of the square token grid. ``None`` (default)
            infers it from the token count, e.g. 196 tokens -> 14 x 14.

    Raises:
        ValueError: If ``images`` is empty or the token count does not form
            a ``grid_size`` x ``grid_size`` grid.
    """
    num_timesteps = len(images)
    if num_timesteps == 0:
        raise ValueError("images must contain at least one frame")

    # squeeze=False keeps `axes` two-dimensional even for a single timestep,
    # so the [row, column] indexing below is always valid.
    fig, axes = plt.subplots(
        2, num_timesteps, figsize=(4 * num_timesteps, 8), squeeze=False
    )

    heatmap = None
    for t in range(num_timesteps):
        image = np.asarray(images[t])

        # Top row: original image.
        axes[0, t].imshow(image)
        axes[0, t].axis('off')
        axes[0, t].set_title(f'Timestep {t}')

        # Average over the first axis of this slice, then flatten.
        attention_map = np.asarray(attention_rollout[t]).mean(axis=0).reshape(-1)

        side = grid_size
        if side is None:
            side = int(round(float(np.sqrt(attention_map.size))))
        if side * side != attention_map.size:
            raise ValueError(
                f"cannot arrange {attention_map.size} attention values "
                f"into a {side}x{side} grid; pass a matching grid_size"
            )
        attention_grid = attention_map.reshape(side, side)

        # Bottom row: image with the heatmap stretched over it. imshow does
        # the upsampling itself, so no separate resize step is needed.
        height, width = image.shape[:2]
        axes[1, t].imshow(image)
        heatmap = axes[1, t].imshow(
            attention_grid,
            cmap='hot',
            alpha=0.7,
            interpolation='bilinear',
            extent=(-0.5, width - 0.5, height - 0.5, -0.5),
        )
        axes[1, t].axis('off')
        axes[1, t].set_title('Attention Map')

    plt.colorbar(heatmap, ax=axes.ravel().tolist())
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path)
    plt.show()

# Example usage cell:
# Get attention weights for our example sequence
# attention_weights = get_attention_weights(model, observation, task)

# if attention_weights:
#     # Compute attention rollout
#     rollout = compute_rollout(attention_weights)
    
#     # Visualize attention for each timestep
#     visualize_attention(rollout, images)
# else:
#     print("No attention weights were captured. Make sure the model is configured to store attention weights.")

# Modified original inference cell to include attention visualization
def run_inference_with_attention(model, images, goal_image=None, language_instruction=None):
    """Sample single-arm actions for a window of images and plot attention.

    Args:
        model: Loaded model providing ``create_tasks`` and ``sample_actions``.
        images: Sequence of frames forming one observation window.
        goal_image: Optional goal frame; takes precedence when both task
            specifications are supplied.
        language_instruction: Optional instruction string, used when no goal
            image is given.

    Returns:
        Tuple ``(actions, attention_weights)``; ``attention_weights`` is
        whatever ``get_attention_weights`` returned.

    Raises:
        ValueError: If neither a goal image nor an instruction is provided.
    """
    if goal_image is None and language_instruction is None:
        raise ValueError("Must provide either goal image or language instruction")

    # Build the task dictionary; the goal image wins if both are present.
    if goal_image is not None:
        task = model.create_tasks(goals={"image_primary": goal_image[None]})
    else:
        task = model.create_tasks(texts=[language_instruction])

    # Stack the frames and prepend a batch axis; every timestep is marked valid.
    frames = np.stack(images)[np.newaxis]
    pad_mask = np.ones((1, frames.shape[1]), dtype=bool)
    observation = {"image_primary": frames, "timestep_pad_mask": pad_mask}

    # Attention is extracted first, then actions are sampled with a fixed key.
    attention_weights = get_attention_weights(model, observation, task)
    actions = model.sample_actions(
        observation,
        task,
        head_name="single_arm",
        rng=jax.random.PRNGKey(0),
    )

    # Only plot when something was actually captured.
    if attention_weights:
        visualize_attention(compute_rollout(attention_weights), images)

    return actions, attention_weights

# Example usage:
# actions, attention_weights = run_inference_with_attention(
#     model,
#     images,
#     language_instruction="pick up the spoon"
# )

# You can also analyze specific attention patterns
def analyze_attention_patterns(attention_weights, layer_idx=None):
    """Plot head- and batch-averaged attention as heatmaps.

    Args:
        attention_weights: Sequence of per-layer attention arrays. Axes 0 and
            1 of each array are averaged away before plotting (treated as
            batch and heads — TODO confirm the layout against the model).
        layer_idx: Index of a single layer to plot. ``None`` (default) plots
            every layer in a two-row grid.

    Raises:
        ValueError: If ``layer_idx`` is ``None`` and ``attention_weights`` is
            empty.
    """
    if layer_idx is not None:
        # Single-layer view.
        layer_attention = attention_weights[layer_idx]
        avg_attention = np.mean(layer_attention, axis=(0, 1))

        plt.figure(figsize=(10, 8))
        sns.heatmap(avg_attention, cmap='viridis')
        plt.title(f'Average Attention Pattern - Layer {layer_idx}')
        plt.show()
        return

    num_layers = len(attention_weights)
    if num_layers == 0:
        raise ValueError("attention_weights must contain at least one layer")

    # Round the column count up so an odd number of layers (or a single
    # layer) still gets one axis each; squeeze=False keeps `axes` an array.
    num_cols = (num_layers + 1) // 2
    fig, axes = plt.subplots(2, num_cols, figsize=(20, 8), squeeze=False)
    axes = axes.ravel()

    for i, layer_attention in enumerate(attention_weights):
        avg_attention = np.mean(layer_attention, axis=(0, 1))
        sns.heatmap(avg_attention, cmap='viridis', ax=axes[i])
        axes[i].set_title(f'Layer {i}')

    # Hide the spare panel left over when the layer count is odd.
    for spare_axis in axes[num_layers:]:
        spare_axis.axis('off')

    plt.tight_layout()
    plt.show()

# # Example: Analyze attention patterns
# if attention_weights:
#     # Analyze a specific layer
#     analyze_attention_patterns(attention_weights, layer_idx=0)
    
#     # Analyze all layers
#     analyze_attention_patterns(attention_weights)

In [None]:
# We'll demonstrate how to create an observation and task dictionary for a bimanual task. 
# Then we'll use them to sample an action from the model.

import jax
import numpy as np
# Random stand-in frame with leading batch and observation-history axes
# (CrossFormer accepts a history of up to 5 time-steps).
img = np.random.randint(0, 255, size=(224, 224, 3))[np.newaxis, np.newaxis]

# The bimanual training data has an overhead view and two wrist views; the
# same placeholder frame is fed to all three cameras.
observation = {
    camera: img
    for camera in ("image_high", "image_left_wrist", "image_right_wrist")
}
observation["timestep_pad_mask"] = np.ones((1, 1), dtype=bool)

# Task dictionary for a language-conditioned task.
task = model.create_tasks(texts=["uncap the pen"])

In [None]:
# Get attention weights separately
# (uses the bimanual `observation` / `task` built in the cell above).
attention_weights = get_attention_weights(model, observation, task)
# A falsy result (None or an empty list) means nothing was captured: skip plotting.
if attention_weights:
    # Compute rollout
    # NOTE: `rollout` is only bound inside this branch, so a later cell that
    # reads it raises NameError when no attention weights were returned.
    rollout = compute_rollout(attention_weights)
    
    # Visualize attention maps
    # `img` carries the leading batch and history axes added by img[None, None];
    # squeeze() drops them so one plain image is passed in the list.
    visualize_attention(rollout, [img.squeeze()])
    
    # Analyze attention patterns
    analyze_attention_patterns(attention_weights, layer_idx=0)  # Look at first layer
    analyze_attention_patterns(attention_weights)  # Look at all layers


In [None]:
# Scratch cell: print the rollout shape and introspect the model, its config
# and its module to see which attributes and config keys are available.
# NOTE: `rollout` only exists if the previous cell captured attention weights.
print(rollout.shape)
print(model.__dir__())
print(type(model.config))
print(model.config.keys())
print(model.config["model"].keys())
print(model.config["model"]["token_embedding_size"])
print(type(model.module))
print(model.module.__dir__())
#print(model.module)

# Disabled: per-layer attention shapes and a repeat of the rollout plot.
#for i in range(len(attention_weights)):
#    print(attention_weights[i].shape)
#visualize_attention(rollout, [img.squeeze()])


In [None]:
# NOTE(review): `analyze_attention` is not defined in the visible part of this
# notebook — presumably a method added to CrossFormerModel; confirm it exists.
# This rebinds `rollout`, replacing the value produced by compute_rollout above.
rollout, token_map = model.analyze_attention(observation, task, head_name="single_arm")

In [None]:
# NOTE(review): `plot_readout_attention` is neither defined nor imported in the
# visible part of this notebook — confirm where it comes from before running.
fig = plot_readout_attention(
    model,
    observation,
    task,
    readout_name="readout_single_arm",  # presumably "readout_bimanual" for bimanual actions — confirm
)
plt.show()

In [None]:
# NOTE(review): `plot_attention_rollout` is neither defined nor imported in the
# visible part of this notebook — confirm where it comes from before running.
fig = plot_attention_rollout(
    model,
    observation,
    task,
    readout_name="readout_single_arm"  # or "readout_bimanual"
)
plt.show()


# Step 2: Run Inference on Full Trajectories

That was easy! Now let's try to run inference across a whole single-arm trajectory and visualize the results!

In [None]:
# Install mediapy (used to display the episode as a video) and OpenCV
# (cv2, used to resize the frames)
!pip install mediapy
!pip install opencv-python

In [None]:
import cv2
import jax
import tensorflow_datasets as tfds
import tqdm
import mediapy
import numpy as np

## Load Model Checkpoint
First, we will load the pre-trained checkpoint using the `load_pretrained()` function. You can specify the path to a checkpoint directory or a HuggingFace path.

Below, we are loading directly from HuggingFace.


In [None]:
from crossformer.model.crossformer_model import CrossFormerModel

# Same HuggingFace checkpoint as in Step 1; rebinding `model` here lets Step 2
# be run on its own.
model = CrossFormerModel.load_pretrained("hf://rail-berkeley/crossformer")

## Load Data
Next, we will load a trajectory from the Bridge dataset for testing the model. We will use the publicly available copy in the Open X-Embodiment dataset bucket.

In [None]:
import os
import certifi
# Point both CA-bundle environment variables at certifi's certificate file
# (presumably so the remote dataset read below can verify TLS — confirm).
os.environ['SSL_CERT_FILE'] = certifi.where()
os.environ['CURL_CA_BUNDLE'] = certifi.where()


# create RLDS dataset builder
builder = tfds.builder_from_directory(
    builder_dir="gs://gresearch/robotics/bridge/0.1.0/"
)
# "train[:1]" restricts the split to its first episode
ds = builder.as_dataset(split="train[:1]")

# sample episode and resize to 224x224 (default third-person cam resolution)
episode = next(iter(ds))
steps = list(episode["steps"])  # materialise the steps so they can be indexed below
images = [
    cv2.resize(np.array(step["observation"]["image"]), (224, 224)) for step in steps
]

# extract goal image and language instruction
# (goal = last frame of the episode; instruction = the one stored on the first step)
goal_image = images[-1]
language_instruction = (
    steps[0]["observation"]["natural_language_instruction"].numpy().decode()
)

# visualize episode
print(f"Instruction: {language_instruction}")
mediapy.show_video(images, fps=10)

## Run Inference

Next, we will run inference over the images in the episode using the loaded model. 
Below we demonstrate how to build the task for both goal-conditioned and language-conditioned inference.
Note that we need to feed inputs of the correct temporal window size.

In [None]:
# Number of consecutive frames fed to the model per inference call.
WINDOW_SIZE = 5

# create task dictionary
# NOTE: both statements below assign `task`, so the second one overwrites the
# first and the following cells run language-conditioned. Comment out the
# second assignment to use the goal image instead.
task = model.create_tasks(
    goals={"image_primary": goal_image[None]}
)  # for goal-conditioned
task = model.create_tasks(texts=[language_instruction])  # for language conditioned

In [None]:
# run inference loop, the model only uses 3rd person image observations for bridge

# collect predicted and true actions
pred_actions, true_actions = [], []
# One iteration per full window of WINDOW_SIZE consecutive frames.
for step in tqdm.trange(len(images) - (WINDOW_SIZE - 1)):
    # Stack the window's frames and prepend a batch axis.
    input_images = np.stack(images[step : step + WINDOW_SIZE])[None]
    observation = {
        "image_primary": input_images,
        # every timestep in the window is marked as valid (no padding)
        "timestep_pad_mask": np.full((1, input_images.shape[1]), True, dtype=bool),
    }

    # we need to pass in the dataset statistics to unnormalize the actions
    # (the same PRNGKey(0) is passed on every step)
    actions = model.sample_actions(
        observation,
        task,
        head_name="single_arm",
        unnormalization_statistics=model.dataset_statistics["bridge_dataset"]["action"],
        rng=jax.random.PRNGKey(0),
    )
    actions = actions[0]  # remove batch

    pred_actions.append(actions)
    # Ground truth is taken from the last step covered by this window.
    final_window_step = step + WINDOW_SIZE - 1
    # Concatenate translation, rotation and gripper state into one vector
    # (presumably 7-D, in the order of ACTION_DIM_LABELS used when plotting — confirm).
    true_actions.append(
        np.concatenate(
            (
                steps[final_window_step]["action"]["world_vector"],
                steps[final_window_step]["action"]["rotation_delta"],
                np.array(steps[final_window_step]["action"]["open_gripper"]).astype(
                    np.float32
                )[None],
            ),
            axis=-1,
        )
    )

## Visualize predictions and ground-truth actions

Finally, we will visualize the predicted actions in comparison to the ground-truth actions.

In [None]:
# One subplot per action dimension, in the order the dimensions are indexed below.
ACTION_DIM_LABELS = ['x', 'y', 'z', 'yaw', 'pitch', 'roll', 'grasp']

# build image strip to show above actions
# (every third frame, concatenated side by side along the width axis)
img_strip = np.concatenate(np.array(images[::3]), axis=1)

# set up plt figure
# Mosaic layout: one wide 'image' panel on top, one panel per action dimension below.
figure_layout = [
    ['image'] * len(ACTION_DIM_LABELS),
    ACTION_DIM_LABELS
]
plt.rcParams.update({'font.size': 12})
fig, axs = plt.subplot_mosaic(figure_layout)
fig.set_size_inches([45, 10])

# plot actions
# NOTE: this rebinds pred_actions / true_actions from lists to squeezed arrays.
pred_actions = np.array(pred_actions).squeeze()
true_actions = np.array(true_actions).squeeze()
for action_dim, action_label in enumerate(ACTION_DIM_LABELS):
  # actions have batch, horizon, dim, in this example we just take the first action for simplicity
  axs[action_label].plot(pred_actions[:, 0, action_dim], label='predicted action')
  axs[action_label].plot(true_actions[:, action_dim], label='ground truth')
  axs[action_label].set_title(action_label)
  axs[action_label].set_xlabel('Time in one episode')

axs['image'].imshow(img_strip)
axs['image'].set_xlabel('Time in one episode (subsampled)')
# plt.legend() draws the legend on the current axes only.
plt.legend()

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import jax
import jax.numpy as jnp
import numpy as np

def plot_full_evaluation(model, images, pred_actions, true_actions, tasks, steps_to_show=None, window_size=5):
    """
    Creates a combined figure with episode images, token-similarity maps and
    predicted vs. ground-truth actions.

    The "attention" row is a proxy computed here, not attention read out of
    the model: for each shown timestep the transformer is run on the window
    ending at that timestep, the output tokens under ``'obs'`` for the last
    timestep are compared with each other by scaled dot product, softmaxed,
    averaged over the query axis and laid out on a square grid.

    Args:
        model: CrossFormer model instance (must provide ``run_transformer``)
        images: Sequence of images from the episode
        pred_actions: Predicted actions array, indexed ``[:, 0, action_dim]``
        true_actions: Ground truth actions array, indexed ``[:, action_dim]``
        tasks: Task dict passed to ``run_transformer``
        steps_to_show: Optional iterable of timesteps to show (default: every 3rd step)
        window_size: Number of frames per observation window (default: 5)

    Returns:
        The matplotlib figure.

    Raises:
        ValueError: If there are no timesteps to show.
    """
    if steps_to_show is None:
        steps_to_show = range(0, len(images), 3)
    # Materialise once: the steps are iterated twice below, which would
    # silently exhaust a generator.
    steps_to_show = list(steps_to_show)
    if not steps_to_show:
        raise ValueError("steps_to_show must contain at least one timestep")

    # Build image strip
    img_strip = np.concatenate([images[i] for i in steps_to_show], axis=1)

    # Action dimension labels
    ACTION_DIM_LABELS = ['x', 'y', 'z', 'yaw', 'pitch', 'roll', 'grasp']

    # Figure layout: images, attention maps, then one panel per action dimension
    figure_layout = [
        ['image'] * len(ACTION_DIM_LABELS),
        ['attention'] * len(ACTION_DIM_LABELS),
        ACTION_DIM_LABELS
    ]
    plt.rcParams.update({'font.size': 12})
    fig, axs = plt.subplot_mosaic(figure_layout, figsize=(45, 15))

    # Plot image strip
    axs['image'].imshow(img_strip)
    axs['image'].set_xlabel('Time in one episode (subsampled)')
    axs['image'].set_title('Episode Images')

    attention_maps = []
    for t in steps_to_show:
        # Window of up to `window_size` frames ending at timestep t. list()
        # guarantees the `+` below concatenates instead of adding arrays
        # element-wise when `images` is an ndarray.
        start_t = max(0, t - window_size + 1)
        obs_images = list(images[start_t:t + 1])

        # Left-pad short windows by repeating the earliest available frame.
        # NOTE(review): padded frames are still marked valid in the pad mask
        # below — confirm that is what the model expects.
        if len(obs_images) < window_size:
            padding = [obs_images[0]] * (window_size - len(obs_images))
            obs_images = padding + obs_images

        observation = {
            "image_primary": np.stack(obs_images)[None],  # Add batch dim
            "timestep_pad_mask": np.ones((1, window_size), dtype=bool)
        }

        # Get transformer outputs
        transformer_outputs = model.run_transformer(
            observation, tasks, observation["timestep_pad_mask"], train=False
        )

        # Tokens of batch element 0 at the last timestep of the window.
        primary_tokens = transformer_outputs['obs'].tokens[0, -1]
        d_k = primary_tokens.shape[-1]
        similarity = jnp.einsum('nd,md->nm', primary_tokens, primary_tokens) / jnp.sqrt(d_k)
        similarity = jax.nn.softmax(similarity, axis=-1)

        # Keep the largest square number of tokens, then average over the
        # query axis so there is one value per token. Reshaping the full
        # (g*g, g*g) matrix straight to (g, g) cannot work for g > 1.
        n_tokens = similarity.shape[0]
        grid_size = int(np.sqrt(n_tokens))
        n_grid = grid_size * grid_size
        per_token = np.asarray(similarity[:n_grid, :n_grid]).mean(axis=0)
        attention_maps.append(per_token.reshape(grid_size, grid_size))

    # Plot attention strip
    attention_strip = np.concatenate(attention_maps, axis=1)
    axs['attention'].imshow(attention_strip, cmap='viridis')
    axs['attention'].set_xlabel('Time in one episode (subsampled)')
    axs['attention'].set_title('Attention Maps')

    # Plot actions
    for action_dim, action_label in enumerate(ACTION_DIM_LABELS):
        axs[action_label].plot(pred_actions[:, 0, action_dim], label='predicted action')
        axs[action_label].plot(true_actions[:, action_dim], label='ground truth')
        axs[action_label].set_title(action_label)
        axs[action_label].set_xlabel('Time in one episode')

    plt.legend()
    return fig

# Example usage in the evaluation section:
"""
# Run evaluation with attention visualization
pred_actions = np.array(pred_actions).squeeze()
true_actions = np.array(true_actions).squeeze()

# Create task dict
task = model.create_tasks(texts=[language_instruction])

# Create visualization
fig = plot_full_evaluation(
    model,
    images,
    pred_actions,
    true_actions,
    task
)
plt.show()
"""

In [None]:
# NOTE(review): `plot_readout_attention` is neither defined nor imported in the
# visible part of this notebook — confirm where it comes from before running.
# If the cells were run top to bottom, `observation` is the last window built
# by the inference loop and `task` is the language-conditioned task.
fig = plot_readout_attention(
    model,
    observation,
    task,
    readout_name="readout_single_arm"
)
plt.show()



In [None]:
# NOTE(review): `plot_attention_rollout` is neither defined nor imported in the
# visible part of this notebook — confirm where it comes from before running.
# If the cells were run top to bottom, `observation` is the last window built
# by the inference loop and `task` is the language-conditioned task.
fig = plot_attention_rollout(
    model,
    observation,
    task,
    readout_name="readout_single_arm"  # or "readout_bimanual"
)
plt.show()
