In [1]:
import os

from absl import app, flags, logging
import flax
import jax
import optax
import tensorflow as tf
import tqdm
import wandb

from octo.data.dataset import make_single_dataset
from octo.data.utils.data_utils import NormalizationType
from octo.model.components.action_heads import L1ActionHead
from octo.model.components.tokenizers import LowdimObsTokenizer
from octo.model.octo_model import OctoModel
from octo.utils.jax_utils import initialize_compilation_cache
from octo.utils.spec import ModuleSpec
from octo.utils.train_utils import (
    freeze_weights,
    merge_params,
    process_text,
    TrainState,
)

2024-04-18 19:28:05.468640: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-18 19:28:05.468681: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-18 19:28:05.470843: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
# Location of the ALOHA sim dataset. Set the ALOHA_SIM_DATA_DIR environment
# variable to point elsewhere instead of editing this cell; the default is the
# path this notebook was originally run with.
ALOHA_DATA_DIR = os.environ.get(
    "ALOHA_SIM_DATA_DIR", "/root/autodl-tmp/aloha_sim_dataset"
)
# Action and proprio-state vectors are 14-dimensional for this dataset
# (see the `action` / `state` shapes in the output of this cell).
ACTION_DIM = 14
# Number of actions in one predicted action chunk.
ACTION_HORIZON = 50

dataset = make_single_dataset(
    dataset_kwargs=dict(
        name="aloha_sim_cube_scripted_dataset",
        data_dir=ALOHA_DATA_DIR,
        # Use the "top" camera as the primary image observation.
        image_obs_keys={"primary": "top"},
        state_obs_keys=["state"],
        language_key="language_instruction",
        action_proprio_normalization_type=NormalizationType.NORMAL,
        # One mask entry per action dimension, all set to True.
        absolute_action_mask=[True] * ACTION_DIM,
    ),
    traj_transform_kwargs=dict(
        window_size=1,
        # Current step + (ACTION_HORIZON - 1) future steps = ACTION_HORIZON
        # actions, i.e. one full action chunk.
        future_action_window_size=ACTION_HORIZON - 1,
    ),
    frame_transform_kwargs=dict(
        resize_size={"primary": (256, 256)},
    ),
    train=True,
)
# Keep a handle on the dataset as returned by make_single_dataset for the
# inspection cell at the end; the .repeat()/.shuffle()/... calls below return
# new datasets rather than modifying this one (standard tf.data behavior).
original_dataset = dataset

traj {'action': <tf.Tensor 'args_3:0' shape=(None, 14) dtype=float32>, 'observation': {'state': <tf.Tensor 'args_9:0' shape=(None, 14) dtype=float32>, 'top': <tf.Tensor 'args_10:0' shape=(None,) dtype=string>}, 'discount': <tf.Tensor 'args_4:0' shape=(None,) dtype=float32>, 'is_first': <tf.Tensor 'args_5:0' shape=(None,) dtype=bool>, 'language_instruction': <tf.Tensor 'args_8:0' shape=(None,) dtype=string>, 'is_terminal': <tf.Tensor 'args_7:0' shape=(None,) dtype=bool>, 'reward': <tf.Tensor 'args_11:0' shape=(None,) dtype=float32>, 'is_last': <tf.Tensor 'args_6:0' shape=(None,) dtype=bool>, 'traj_metadata': {'episode_metadata': {'file_path': <tf.Tensor 'args_12:0' shape=(None,) dtype=string>}}, '_len': <tf.Tensor 'args_1:0' shape=(None,) dtype=int32>, '_traj_index': <tf.Tensor 'args_2:0' shape=(None,) dtype=int64>, '_frame_index': <tf.Tensor 'args_0:0' shape=(None,) dtype=int32>}
image_obs_keys {'primary': 'top'}
state_obs_keys ['state']
{
    "action": {
        "max": "Shape: (14,)",

In [3]:
# Endless, shuffled stream of training batches:
#   repeat()       -> cycle over the dataset indefinitely
#   unbatch()      -> split each dataset element along its leading axis
#   shuffle(10000) -> 10k-element shuffle buffer (can reduce this if RAM
#                     consumption is too high)
#   batch(128)     -> regroup into batches of 128
#   iterator()     -> Python iterator consumed with next()/map() below
train_data_iter = dataset.repeat().unbatch().shuffle(10000).batch(128).iterator()

In [4]:
# Checkpoint of the pretrained Octo-Small model. The hard-coded path does not
# exist on every machine (the saved output of this cell is a NotFoundError for
# config.json), so it can be overridden via the OCTO_PRETRAINED_PATH
# environment variable; the default is the original path.
# NOTE(review): OctoModel.load_pretrained presumably also accepts a Hugging Face
# path such as "hf://rail-berkeley/octo-small" — confirm before relying on it.
PRETRAINED_PATH = os.environ.get(
    "OCTO_PRETRAINED_PATH", "/data1/zhuxiaopei/octo-small"
)
pretrained_model = OctoModel.load_pretrained(PRETRAINED_PATH)
# Reuse the model's own text processor to encode language instructions.
text_processor = pretrained_model.text_processor
def process_batch(batch):
    """Prepare one raw training batch for the Octo model.

    Runs ``process_text`` with the pretrained model's ``text_processor`` (the
    processed batches later expose ``input_ids`` / ``attention_mask`` under
    ``batch["task"]["language_instruction"]``, i.e. the instruction gets
    encoded), then removes the top-level ``"dataset_name"`` entry.

    Args:
        batch: nested dict batch drawn from ``train_data_iter``.

    Returns:
        The dict returned by ``process_text``, with ``"dataset_name"`` removed
        if it was present.
    """
    batch = process_text(batch, text_processor)
    # pop(..., None) instead of `del`: a batch without "dataset_name" is left
    # as-is instead of raising KeyError in the middle of iteration.
    batch.pop("dataset_name", None)
    return batch
# Lazily pass every batch through process_batch as it is drawn.
# NOTE(review): this rebinds train_data_iter to a wrapper around its previous
# value, so re-running this cell without first re-running the cell that builds
# train_data_iter would apply process_batch twice to each batch.
train_data_iter = map(process_batch, train_data_iter)

NotFoundError: /data1/zhuxiaopei/octo-small/config.json; No such file or directory

In [None]:
import json
# Pretty-print the loaded model's configuration (presumably the contents of the
# checkpoint's config.json, which load_pretrained reads).
print(
    json.dumps(
        pretrained_model.config,
        indent=4,
    )
)

In [None]:
# Draw one processed batch (already passed through process_batch) to inspect
# its structure in the cells below. Note that this consumes a batch from
# train_data_iter.
example_batch = next(train_data_iter)

In [None]:
import jax.numpy as jnp
import numpy as np
import json

def print_shape_or_value(x):
    """Summarize one pytree leaf for display.

    Despite the name, nothing is printed: JAX/NumPy arrays and TensorFlow
    tensors are mapped to the string ``"Shape: <shape>"``, and every other
    value is returned unchanged.
    """
    is_array_like = isinstance(x, (jnp.ndarray, np.ndarray, tf.Tensor))
    if not is_array_like:
        return x
    return f"Shape: {x.shape}"

def apply_to_nested_dict(func, d):
    """Recursively apply ``func`` to every non-dict leaf of a nested dict.

    Keys, key order and nesting are preserved in a freshly built plain dict.
    A non-dict ``d`` is treated as a single leaf, so ``func(d)`` is returned.
    NOTE(review): appears unused in the visible cells (jax.tree_util.tree_map
    is used for the same purpose).
    """
    if not isinstance(d, dict):
        return func(d)
    mapped = {}
    for key, value in d.items():
        mapped[key] = apply_to_nested_dict(func, value)
    return mapped

# Replace every array/tensor leaf of the processed batch with a "Shape: ..."
# string (other leaves pass through unchanged), then pretty-print the
# resulting tree as JSON.
converted_tree = jax.tree_util.tree_map(
    print_shape_or_value,
    example_batch,
)
formatted_output = json.dumps(
    converted_tree,
    indent=4,
)
print(formatted_output)

In [None]:
# Attention mask of the encoded language instruction for the first example.
print(example_batch["task"]["language_instruction"]["attention_mask"][0])

In [None]:
# Token ids of the encoded language instruction for the first example.
print(example_batch["task"]["language_instruction"]["input_ids"][0])

In [None]:
# Entry of the task's pad_mask_dict for the language instruction, first example.
print(example_batch["task"]["pad_mask_dict"]["language_instruction"][0])

In [None]:
# Show the Python container type of a processed batch (whatever process_batch returned).
print(type(example_batch))

In [None]:
# Inspect a single element of the dataset as returned by make_single_dataset,
# i.e. before the repeat/unbatch/shuffle/batch pipeline and before
# process_batch encodes the language instruction.
take_one = original_dataset.take(1)
for step in take_one:
    # jax.tree_util.tree_map instead of the deprecated jax.tree_map alias,
    # which has been removed in newer JAX releases.
    print(json.dumps(jax.tree_util.tree_map(print_shape_or_value, step), indent=4))
    # Raw (not yet encoded) language instruction tensor for this element.
    print(step['task']['language_instruction'])