# Step 1: Minimal OCTO Inference Example

This Colab demonstrates how to load a pre-trained or fine-tuned OCTO checkpoint, run inference on offline images, and compare the predicted actions to the ground-truth actions.

First, let's start with a minimal example!

In [None]:
from octo.model.octo_model import OCTOModel

model = OCTOModel.load_pretrained("hf://rail-berkeley/octo-small")

In [None]:
# download one example BridgeV2 image
!wget https://rail.eecs.berkeley.edu/datasets/bridge_release/raw/bridge_data_v2/datacol1_toykitchen1/many_skills/0/2023-03-15_14-35-28/raw/traj_group0/traj0/images0/im_0.jpg

In [None]:
import cv2
import matplotlib.pyplot as plt

img = cv2.imread("im_0.jpg")
img = cv2.resize(img, (256, 256))[..., ::-1]
plt.imshow(img)
plt.show()

In [None]:
# create obs & task dict, run inference
import jax

observation = {"image_primary": img[None, None]}   # add batch + time horizon 1
task = model.create_tasks(texts=["pick up the fork".encode()])
action = model.sample_actions(observation, task, rng=jax.random.PRNGKey(0))
print(action)   # [batch, action_chunk, action_dim]
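
The `img[None, None]` indexing in the cell above prepends two axes of length 1, one for the batch and one for the time horizon. A minimal NumPy sketch of that reshaping, using an all-zeros array as a hypothetical stand-in for the downloaded image:

```python
import numpy as np

# Hypothetical stand-in for the resized RGB frame (pixel values are arbitrary).
img = np.zeros((256, 256, 3), dtype=np.uint8)

# Indexing with None inserts a new axis of length 1 at that position,
# so two leading Nones add a batch axis followed by a time-horizon axis.
batched = img[None, None]

print(img.shape)      # (256, 256, 3)
print(batched.shape)  # (1, 1, 256, 256, 3)
```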

# Step 2: Run Inference on Full Trajectories

That was easy! Now let's try to run inference across a whole trajectory and visualize the results!

In [None]:
import os
os.environ['JAX_PLATFORMS'] = 'cpu' # Force on CPU

import cv2
import jax
import tensorflow as tf
import tensorflow_datasets as tfds
import tqdm
import rlds
import mediapy as media
import numpy as np
from PIL import Image
from IPython import display

## Load Model Checkpoint
First, we will load the pre-trained checkpoint using the `load_pretrained()` function. You can pass either the path to a checkpoint directory or the HuggingFace path of your OCTO model of choice.

Below, we are loading directly from HuggingFace.


In [None]:
from octo.model.octo_model import OCTOModel

model = OCTOModel.load_pretrained("hf://rail-berkeley/octo-small")

## Load Datasets
Next, we will load a trajectory from the Bridge dataset to test the model. We will use the publicly available copy in the Open X-Embodiment dataset bucket.

In [None]:
# create RLDS dataset builder
builder = tfds.builder_from_directory(builder_dir='gs://gresearch/robotics/bridge/0.1.0/')
ds = builder.as_dataset(split='train[:1]')

# sample episode + resize to 256x256 (default third-person cam resolution)
episode = next(iter(ds))
steps = list(episode['steps'])
images = [cv2.resize(np.array(step['observation']['image']), (256, 256)) for step in steps]

# extract goal image & language instruction
goal_image = images[-1]
language_instruction = steps[0]['observation']['natural_language_instruction'].numpy().decode()

# visualize episode
print(f'Instruction: {language_instruction}')
media.show_video(images, fps=10)

## Run Inference

Next, we will run inference over the images in the episode using the loaded model.
Below we demonstrate setups for both goal-conditioned and language-conditioned inference.
Note that the inputs must span the correct temporal window size: each observation is a stack of `WINDOW_SIZE` consecutive frames.
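
To make the window requirement concrete, here is a minimal NumPy sketch of how consecutive frames can be grouped into overlapping windows with a leading batch axis. The tiny dummy frames and the helper `make_window` are hypothetical and only illustrate the slicing pattern; they are not part of the OCTO API.

```python
import numpy as np

WINDOW_SIZE = 2  # number of consecutive frames per observation

# Hypothetical episode: 5 tiny dummy frames; frame i is filled with the value i.
frames = [np.full((4, 4, 3), i, dtype=np.uint8) for i in range(5)]


def make_window(frames, start, window_size):
    """Stack `window_size` consecutive frames and prepend a batch axis."""
    return np.stack(frames[start : start + window_size])[None]


# An episode of N frames yields N - WINDOW_SIZE + 1 full windows.
num_windows = len(frames) - WINDOW_SIZE + 1
windows = [make_window(frames, s, WINDOW_SIZE) for s in range(num_windows)]

print(num_windows)       # 4
print(windows[0].shape)  # (1, 2, 4, 4, 3) = [batch, window, height, width, channels]
```

Each window shares all but one frame with its neighbor, so the policy sees a sliding history of the episode.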

In [None]:
WINDOW_SIZE = 2

# Jit the sample_actions function for speed
policy_fn = jax.jit(model.sample_actions)

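# create `task` dict: keep only ONE of the two assignments below
# (as written, the second assignment overrides the first)
task = model.create_tasks(goals={"image_primary": goal_image[None]})   # for goal-conditioned
task = model.create_tasks(texts=[language_instruction])                  # for language-conditioned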

In [None]:
# run inference loop, this model only uses single image observations for bridge
# collect predicted and true actions
pred_actions, true_actions = [], []
for step in tqdm.tqdm(range(len(images) - WINDOW_SIZE + 1)):
    input_images = np.stack(images[step : step + WINDOW_SIZE])[None]
    observation = {
        'image_primary': input_images,
        'pad_mask': np.full((1, WINDOW_SIZE), True),   # all frames in the window are real (no padding)
    }
    
    # this returns *normalized* actions --> we need to unnormalize using the dataset statistics
    norm_actions = policy_fn(observation, task, rng=jax.random.PRNGKey(0))
    norm_actions = norm_actions[0]   # remove batch
    
    actions = (
        norm_actions * model.dataset_statistics['action']['std']
        + model.dataset_statistics['action']['mean']
    )
    
    pred_actions.append(actions)
    true_actions.append(np.concatenate(
        (
            steps[step+1]['action']['world_vector'], 
            steps[step+1]['action']['rotation_delta'], 
            np.array(steps[step+1]['action']['open_gripper']).astype(np.float32)[None]
        ), axis=-1
    ))
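
The unnormalization and the ground-truth concatenation in the loop above can be checked in isolation. The sketch below uses made-up statistics and action values (all hypothetical, not taken from `model.dataset_statistics` or the dataset) and assumes a 7-dimensional action with a chunk of 4 predictions.

```python
import numpy as np

# Hypothetical per-dimension statistics for a 7-D action.
mean = np.array([0.0, 0.1, -0.1, 0.0, 0.0, 0.0, 0.5])
std = np.array([0.01, 0.02, 0.02, 0.05, 0.05, 0.05, 0.5])


def unnormalize(norm_actions, mean, std):
    """Map zero-mean, unit-variance actions back to the dataset's scale."""
    return norm_actions * std + mean


def normalize(actions, mean, std):
    """Inverse of `unnormalize`."""
    return (actions - mean) / std


# Hypothetical chunk of 4 normalized actions: first row all ones, rest zeros.
norm_actions = np.zeros((4, 7))
norm_actions[0] = 1.0
actions = unnormalize(norm_actions, mean, std)  # broadcasts over the chunk axis

# Ground-truth action: 3-D translation, 3-D rotation, and a scalar gripper flag.
world_vector = np.array([0.01, 0.0, -0.02], dtype=np.float32)
rotation_delta = np.zeros(3, dtype=np.float32)
open_gripper = np.array(True)
true_action = np.concatenate(
    # [None] turns the scalar into a length-1 vector so it can be concatenated
    (world_vector, rotation_delta, open_gripper.astype(np.float32)[None]),
    axis=-1,
)

print(actions.shape)      # (4, 7)
print(true_action.shape)  # (7,)
```

A normalized action of zero maps to the dataset mean, and a normalized action of one maps to one standard deviation above it.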

## Visualize predictions and ground-truth actions

Finally, we will plot the predicted actions against the ground-truth actions for each action dimension.

In [None]:
import matplotlib.pyplot as plt

ACTION_DIM_LABELS = ['x', 'y', 'z', 'yaw', 'pitch', 'roll', 'grasp']

# build image strip to show above actions
img_strip = np.concatenate(np.array(images[::3]), axis=1)

# set up plt figure
figure_layout = [
    ['image'] * len(ACTION_DIM_LABELS),
    ACTION_DIM_LABELS
]
plt.rcParams.update({'font.size': 12})
fig, axs = plt.subplot_mosaic(figure_layout)
fig.set_size_inches([45, 10])

# plot actions
pred_actions = np.array(pred_actions).squeeze()
true_actions = np.array(true_actions).squeeze()
for action_dim, action_label in enumerate(ACTION_DIM_LABELS):
    axs[action_label].plot(pred_actions[:, action_dim], label='predicted action')
    axs[action_label].plot(true_actions[:, action_dim], label='ground truth')
    axs[action_label].set_title(action_label)
    axs[action_label].set_xlabel('Time in one episode')

axs['image'].imshow(img_strip)
axs['image'].set_xlabel('Time in one episode (subsampled)')
plt.legend()
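
As a numerical complement to the plot, a per-dimension error can summarize how closely predictions track the ground truth. This is a minimal sketch that assumes both arrays have shape `[T, 7]`; the random arrays `pred` and `true` are hypothetical stand-ins for `pred_actions` and `true_actions`, and mean absolute error is one possible metric, not one prescribed by this notebook.

```python
import numpy as np

ACTION_DIM_LABELS = ['x', 'y', 'z', 'yaw', 'pitch', 'roll', 'grasp']

# Hypothetical stand-ins for true_actions / pred_actions with shape [T, 7].
rng = np.random.default_rng(0)
true = rng.normal(size=(10, 7))
pred = true + 0.1  # predictions offset by a constant, for illustration only

# Mean absolute error per action dimension, averaged over time steps.
mae = np.abs(pred - true).mean(axis=0)

for label, err in zip(ACTION_DIM_LABELS, mae):
    print(f"{label}: {err:.3f}")
```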