# Step 1: Minimal CrossFormer Inference Example

This Colab demonstrates how to load a pre-trained or fine-tuned CrossFormer checkpoint, run inference for single-arm and bimanual manipulation systems, and compare the outputs to the true actions.

First, let's start with a minimal example!

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import cv2
import jax
import tensorflow_datasets as tfds
import tqdm
import mediapy

In [None]:
from crossformer.model.crossformer_model import CrossFormerModel

# Load the CrossFormer checkpoint named by the "hf://" URI (presumably resolved
# against the Hugging Face Hub -- confirm). The resulting `model` object is what
# the later cells call `create_tasks` and `analyze_attention` on.
model = CrossFormerModel.load_pretrained("hf://rail-berkeley/crossformer")

In [None]:
import os
import certifi

# Point both certificate-bundle variables at certifi's CA file.
# NOTE(review): presumably required so the gs:// read below can verify TLS
# certificates in this environment -- confirm.
os.environ.update(
    SSL_CERT_FILE=certifi.where(),
    CURL_CA_BUNDLE=certifi.where(),
)

# Build the RLDS dataset directly from the bucket directory and request only
# the first example of the train split.
builder = tfds.builder_from_directory(builder_dir="gs://gresearch/robotics/bridge/0.1.0/")
ds = builder.as_dataset(split="train[:1]")

# Take that single episode and materialize its steps so they can be indexed.
episode = next(iter(ds))
steps = list(episode["steps"])

# The language instruction is read from the first step's observation.
language_instruction = (
    steps[0]["observation"]["natural_language_instruction"].numpy().decode()
)

# Resize every frame to 224x224 (default third-person cam resolution).
images = [
    cv2.resize(np.array(frame["observation"]["image"]), (224, 224))
    for frame in steps
]

# The last frame of the episode serves as the goal image.
goal_image = images[-1]

# Show the instruction; uncomment the last line to preview the episode as video.
print(f"Instruction: {language_instruction}")
# mediapy.show_video(images, fps=10)


In [None]:
# Number of consecutive frames stacked into one observation window
# (used to slice `images` when the observation is built).
WINDOW_SIZE = 1

# create task dictionary
# NOTE: both conditioning variants are shown, but the second assignment
# overwrites the first, so the language-conditioned task is the one that the
# following cells actually use. Keep only the variant you want.
# `goal_image[None]` prepends a length-1 leading axis to the goal image.
task = model.create_tasks(
    goals={"image_primary": goal_image[None]}
)  # for goal-conditioned
task = model.create_tasks(texts=[language_instruction])  # for language conditioned

In [None]:
# Analyze attention for a single observation window.
#
# The previous version opened a `for step in tqdm.trange(...)` loop but left
# its body unindented, which is an IndentationError, while the body itself
# always sliced the window starting at frame 0. Since the plotting cell consumes
# one `rollout` / `observation` pair, the window is built exactly once here.
# Change `step` to inspect a different window; valid starts are
# 0 .. len(images) - WINDOW_SIZE.
step = 0

# Stack WINDOW_SIZE frames and prepend a length-1 leading axis.
input_images = np.stack(images[step : step + WINDOW_SIZE])[None]
observation = {
    "image_primary": input_images,
    # All-True mask with one entry per frame in the window (no padded timesteps).
    "timestep_pad_mask": np.full((1, input_images.shape[1]), True, dtype=bool),
}

# NOTE(review): `analyze_attention` is project-specific; its two return values
# are passed unchanged to `plot_readout_attention` in the next cell.
rollout, token_types = model.analyze_attention(observation, task)

In [None]:
from crossformer.utils import visualization_utils

# Uncomment to pick up local edits to visualization_utils without restarting
# the kernel:
# import importlib
# importlib.reload(visualization_utils)

# Plot the attention of the "readout_nav" readout over the primary camera
# image. `[0, 0]` selects the first entry along each of the two leading axes
# of the stacked image array, i.e. a single frame.
fig = visualization_utils.plot_readout_attention(
    rollout,
    token_types,
    "readout_nav",
    observation,
    observation_type="_primary",
    observation_image=observation["image_primary"][0, 0],
)