In [1]:
import os

from absl import app, flags, logging
import flax
import jax
import optax
import tensorflow as tf
import tqdm
import wandb

from octo.data.dataset import make_single_dataset
from octo.data.utils.data_utils import NormalizationType
from octo.model.components.action_heads import L1ActionHead
from octo.model.components.tokenizers import LowdimObsTokenizer
from octo.model.octo_model import OctoModel
from octo.utils.jax_utils import initialize_compilation_cache
from octo.utils.spec import ModuleSpec
from octo.utils.train_utils import (
    freeze_weights,
    merge_params,
    process_text,
    TrainState,
)

2024-04-13 09:47:51.941767: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-13 09:47:51.941819: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-13 09:47:51.944014: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Directory holding the ALOHA sim TFDS dataset. Set the ALOHA_SIM_DATA_DIR
# environment variable to point elsewhere; the default is the original
# author's local path, so behaviour is unchanged when the variable is unset.
ALOHA_SIM_DATA_DIR = os.environ.get(
    "ALOHA_SIM_DATA_DIR", "/data1/zhuxiaopei/aloha_sim_dataset"
)

dataset = make_single_dataset(
    dataset_kwargs=dict(
        name="aloha_sim_cube_scripted_dataset",
        data_dir=ALOHA_SIM_DATA_DIR,
        # The "top" camera is exposed as the "primary" image observation.
        image_obs_keys={"primary": "top"},
        state_obs_keys=["state"],
        language_key="language_instruction",
        action_proprio_normalization_type=NormalizationType.NORMAL,
        # One flag per action dimension (actions are 14-dim in the batch printout).
        absolute_action_mask=[True] * 14,
    ),
    traj_transform_kwargs=dict(
        window_size=1,
        future_action_window_size=49,  # so we get 50 actions for our action chunk
    ),
    frame_transform_kwargs=dict(
        resize_size={"primary": (256, 256)},
    ),
    train=True,
)
# Keep a handle to the untransformed dataset; it is inspected directly further down.
original_dataset = dataset

In [4]:
# Training stream: repeat the dataset forever, split each element along its
# leading axis into single frames, shuffle those frames, and re-batch them.
SHUFFLE_BUFFER_SIZE = 10000  # can reduce this if RAM consumption too high
BATCH_SIZE = 128

train_data_iter = (
    dataset
    .repeat()
    .unbatch()
    .shuffle(SHUFFLE_BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .iterator()
)

In [6]:
# Pretrained Octo checkpoint to fine-tune from. Set OCTO_CHECKPOINT_PATH to
# load a different checkpoint; the default is the original author's local copy.
OCTO_CHECKPOINT_PATH = os.environ.get(
    "OCTO_CHECKPOINT_PATH", "/data1/zhuxiaopei/octo-small"
)
pretrained_model = OctoModel.load_pretrained(OCTO_CHECKPOINT_PATH)
# Text processor bundled with the checkpoint; process_batch uses it to tokenize instructions.
text_processor = pretrained_model.text_processor
def process_batch(batch):
    """Turn one raw training batch into model-ready form.

    Runs ``process_text`` with the checkpoint's ``text_processor`` (which
    replaces the raw instruction strings with token ids / attention masks,
    per the batch printout), then removes the ``"dataset_name"`` entry.

    Raises:
        KeyError: if the processed batch has no ``"dataset_name"`` key.
    """
    processed = process_text(batch, text_processor)
    # pop() without a default raises KeyError just like `del`; the value is discarded.
    processed.pop("dataset_name")
    return processed
# Lazily apply process_batch to every batch drawn from the training iterator.
# NOTE(review): this rebinds train_data_iter to a wrapped version of itself, so
# running this cell twice stacks two process_batch calls and the second
# `del batch["dataset_name"]` will likely raise KeyError. Re-run the iterator
# cell above first if you need to re-run this one.
train_data_iter = map(process_batch, train_data_iter)

In [7]:
import json

# Pretty-print the pretrained checkpoint's config to see how it was set up.
config_text = json.dumps(pretrained_model.config, indent=4)
print(config_text)

{
    "seed": 42,
    "num_steps": 300000,
    "save_dir": null,
    "model": {
        "observation_tokenizers": {
            "primary": {
                "module": "octo.model.components.tokenizers",
                "name": "ImageTokenizer",
                "args": [],
                "kwargs": {
                    "obs_stack_keys": [
                        "image_primary"
                    ],
                    "task_stack_keys": [
                        "image_primary"
                    ],
                    "encoder": {
                        "module": "octo.model.components.vit_encoders",
                        "name": "SmallStem16",
                        "args": [],
                        "kwargs": {}
                    }
                }
            },
            "wrist": {
                "module": "octo.model.components.tokenizers",
                "name": "ImageTokenizer",
                "args": [],
                "kwargs": {
                    "obs_stac

In [8]:
# Draw one processed batch (text tokenized, "dataset_name" removed by
# process_batch) so its structure can be inspected in the cells below.
example_batch = next(train_data_iter)

In [9]:
import jax.numpy as jnp
import numpy as np
import json

def print_shape_or_value(x):
    """Summarize a single pytree leaf for display.

    Despite its name, this does not print anything. It returns the string
    ``"Shape: <shape>"`` for any array-like leaf, meaning any object with a
    non-None ``.shape``. That covers NumPy arrays and scalars, JAX arrays and
    ``tf.Tensor``. Every other value is returned unchanged.

    It is meant to be used with ``jax.tree_util.tree_map`` so a nested batch
    can be rendered with ``json.dumps``.
    """
    # Duck-type on `.shape` instead of checking an explicit type list. Besides
    # np.ndarray / jax arrays / tf.Tensor, this also catches NumPy scalar
    # leaves (np.bool_, np.float32, ...), which json.dumps cannot serialize.
    shape = getattr(x, "shape", None)
    if shape is not None:
        return f"Shape: {shape}"
    return x

def apply_to_nested_dict(func, d):
    """Recursively apply ``func`` to every non-dict value of a nested dict.

    Keys and dict nesting are preserved. Any value that is not a ``dict``,
    including lists and tuples, is handed to ``func`` as a whole. A non-dict
    ``d`` simply yields ``func(d)``.

    Note: the cell below uses ``jax.tree_util.tree_map`` rather than this helper.
    """
    if not isinstance(d, dict):
        return func(d)
    mapped = {}
    for key, value in d.items():
        mapped[key] = apply_to_nested_dict(func, value)
    return mapped

# Swap every array leaf of the processed batch for a "Shape: ..." string so
# the whole nested structure can be shown as indented JSON.
converted_tree = jax.tree_util.tree_map(
    print_shape_or_value,
    example_batch,
)
formatted_output = json.dumps(
    converted_tree,
    indent=4,
)
print(formatted_output)

{
    "absolute_action_mask": "Shape: (128, 14)",
    "action": "Shape: (128, 50, 14)",
    "observation": {
        "image_primary": "Shape: (128, 1, 256, 256, 3)",
        "pad_mask": "Shape: (128, 1)",
        "pad_mask_dict": {
            "image_primary": "Shape: (128, 1)",
            "proprio": "Shape: (128, 1)",
            "timestep": "Shape: (128, 1)"
        },
        "proprio": "Shape: (128, 1, 14)",
        "timestep": "Shape: (128, 1)"
    },
    "task": {
        "language_instruction": {
            "attention_mask": "Shape: (128, 16)",
            "input_ids": "Shape: (128, 16)"
        },
        "pad_mask_dict": {
            "language_instruction": "Shape: (128,)"
        }
    }
}


In [10]:
# Attention mask of the first example's tokenized instruction.
attention_mask_0 = example_batch["task"]["language_instruction"]["attention_mask"][0]
print(attention_mask_0)

[1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0]


In [11]:
# Token ids of the first example's tokenized instruction.
input_ids_0 = example_batch["task"]["language_instruction"]["input_ids"][0]
print(input_ids_0)

[1432   95    8  123  346   11  609   34  147    1    0    0    0    0
    0    0]


In [12]:
# Task-level pad-mask flag for the first example's language instruction.
language_pad_flag_0 = example_batch["task"]["pad_mask_dict"]["language_instruction"][0]
print(language_pad_flag_0)

True


In [14]:
# Check the container type of the processed batch.
print(example_batch.__class__)

<class 'dict'>


In [22]:
# Inspect one element of the untransformed dataset (no unbatch/shuffle/batch,
# no text tokenization), to compare with the processed training batch above.
take_one = original_dataset.take(1)
for step in take_one:
    # jax.tree_map is deprecated; use jax.tree_util.tree_map, matching the
    # call used for the processed batch.
    print(json.dumps(jax.tree_util.tree_map(print_shape_or_value, step), indent=4))
    print(step['task']['language_instruction'])

{
    "absolute_action_mask": "Shape: (400, 14)",
    "action": "Shape: (400, 50, 14)",
    "dataset_name": "Shape: (400,)",
    "observation": {
        "image_primary": "Shape: (400, 1, 256, 256, 3)",
        "pad_mask": "Shape: (400, 1)",
        "pad_mask_dict": {
            "image_primary": "Shape: (400, 1)",
            "proprio": "Shape: (400, 1)",
            "timestep": "Shape: (400, 1)"
        },
        "proprio": "Shape: (400, 1, 14)",
        "timestep": "Shape: (400, 1)"
    },
    "task": {
        "language_instruction": "Shape: (400,)",
        "pad_mask_dict": {
            "language_instruction": "Shape: (400,)"
        }
    }
}
tf.Tensor(
[b'pick up the cube and hand it over' b'pick up the cube and hand it over'
 b'pick up the cube and hand it over' b'pick up the cube and hand it over'
 b'pick up the cube and hand it over' b'pick up the cube and hand it over'
 b'pick up the cube and hand it over' b'pick up the cube and hand it over'
 b'pick up the cube and hand i