In [None]:
import os

from absl import app, flags, logging
import flax
import jax
import optax
import tensorflow as tf
import tqdm
import wandb

from octo.data.dataset import make_single_dataset
from octo.data.utils.data_utils import NormalizationType
from octo.model.components.action_heads import L1ActionHead
from octo.model.components.tokenizers import LowdimObsTokenizer
from octo.model.octo_model import OctoModel
from octo.utils.jax_utils import initialize_compilation_cache
from octo.utils.spec import ModuleSpec
from octo.utils.train_utils import (
    freeze_weights,
    merge_params,
    process_text,
    TrainState,
)

In [None]:
# Dataset location. Override it with the ALOHA_SIM_DATA_DIR environment variable
# instead of editing a machine-specific absolute path in the notebook.
ALOHA_SIM_DATA_DIR = os.environ.get(
    "ALOHA_SIM_DATA_DIR", "/root/autodl-fs/aloha_sim_dataset"
)

# Build the single-dataset pipeline for the scripted ALOHA sim cube task:
# the "top" image observation is exposed as "primary" and resized to 256x256,
# the "state" key is used as low-dim observation, and each step carries a
# 50-step action chunk.
dataset = make_single_dataset(
    dataset_kwargs=dict(
        name="aloha_sim_cube_scripted_dataset",
        data_dir=ALOHA_SIM_DATA_DIR,
        image_obs_keys={"primary": "top"},
        state_obs_keys=["state"],
        language_key="language_instruction",
        action_proprio_normalization_type=NormalizationType.NORMAL,
        # One flag per action dimension (14 here); presumably True marks the
        # dimension as an absolute (not delta) action — confirm against octo docs.
        absolute_action_mask=[True] * 14,
    ),
    traj_transform_kwargs=dict(
        window_size=1,
        future_action_window_size=49,  # so we get 50 actions for our action chunk
    ),
    frame_transform_kwargs=dict(
        resize_size={"primary": (256, 256)},
    ),
    train=True,
)
# Alias to the dataset before the training pipeline wraps it in `.repeat()`,
# so it can still be iterated once end-to-end.
original_dataset = dataset

In [None]:
# Training stream: repeat the dataset forever, split each element along its
# leading axis (unbatch), shuffle the resulting items, then re-batch.
SHUFFLE_BUFFER_SIZE = 10000  # can reduce this if RAM consumption too high
BATCH_SIZE = 128

train_data_iter = (
    dataset
    .repeat()
    .unbatch()
    .shuffle(SHUFFLE_BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .iterator()
)

In [None]:
# Load the pretrained checkpoint and reuse its text processor so language
# instructions are tokenized with the model's own tokenizer.
pretrained_model = OctoModel.load_pretrained('./weights/octo-base')
text_processor = pretrained_model.text_processor


def process_batch(batch):
    """Tokenize a batch's language instruction and drop its ``dataset_name``.

    Args:
        batch: one batch dict produced by the training iterator.

    Returns:
        The dict returned by ``process_text`` with the ``"dataset_name"`` entry
        removed (presumably because string-valued entries cannot be fed to the
        model — confirm).
    """
    tokenized_batch = process_text(batch, text_processor)
    del tokenized_batch["dataset_name"]
    return tokenized_batch


# NOTE: not idempotent. Re-running this cell wraps the already-processed
# iterator a second time, and the second pass fails (e.g. KeyError on the
# already-removed "dataset_name"). Re-run the iterator-construction cell first.
train_data_iter = map(process_batch, train_data_iter)

In [None]:
import json
# Pretty-print the pretrained model's configuration.
config_json = json.dumps(pretrained_model.config, indent=4)
print(config_json)

In [None]:
# Pull one processed batch for inspection (this consumes it from the stream).
example_batch = next(iter(train_data_iter))

In [None]:
import jax.numpy as jnp
import numpy as np
import json

def print_shape_or_value(x):
    """Summarize array-like leaves by their shape; pass anything else through.

    Despite the name, nothing is printed here.

    Args:
        x: any leaf value.

    Returns:
        The string ``"Shape: <shape>"`` when ``x`` is a JAX, NumPy or
        TensorFlow array/tensor; otherwise ``x`` unchanged.
    """
    array_types = (jnp.ndarray, np.ndarray, tf.Tensor)
    if not isinstance(x, array_types):
        return x
    return f"Shape: {x.shape}"

def apply_to_nested_dict(func, d):
    """Apply ``func`` to every non-dict leaf of a possibly nested dict.

    Args:
        func: callable applied to each leaf.
        d: a dict (recursed into) or any other value (passed to ``func``).

    Returns:
        A new plain dict with the same keys in the same order and transformed
        leaves, or ``func(d)`` when ``d`` is not a dict.
    """
    if not isinstance(d, dict):
        return func(d)
    mapped = {}
    for key, value in d.items():
        mapped[key] = apply_to_nested_dict(func, value)
    return mapped

# Replace every array leaf of the example batch with its shape string so the
# nested structure can be dumped as indented JSON and read at a glance.
converted_tree = jax.tree_util.tree_map(
    lambda leaf: print_shape_or_value(leaf),
    example_batch,
)
formatted_output = json.dumps(converted_tree, indent=4)
print(formatted_output)

In [None]:
# Attention mask of the first example's tokenized language instruction.
first_attention_mask = example_batch["task"]["language_instruction"]["attention_mask"][0]
print(first_attention_mask)

In [None]:
# Token ids of the first example's tokenized language instruction.
first_input_ids = example_batch["task"]["language_instruction"]["input_ids"][0]
print(first_input_ids)

In [None]:
# Pad-mask entry for the first example's language instruction
# (presumably whether an instruction is present — confirm).
task_pad_masks = example_batch["task"]["pad_mask_dict"]
print(task_pad_masks["language_instruction"][0])

In [None]:
# Iterate `original_dataset` once (assumed finite) to measure its size.
# The training pipeline calls `.unbatch()` on this dataset, which splits every
# component along axis 0, so the leading dimension of any component is the
# number of items (frames) that element contributes.
cnt = 0  # number of dataset elements (trajectories)
total_frames = 0  # number of items produced after `.unbatch()`
for trajectory in original_dataset:
    if cnt == 0:
        # Show the nested structure once; dumping it for every element floods the output.
        print(json.dumps(apply_to_nested_dict(print_shape_or_value, trajectory), indent=4))
    cnt += 1
    total_frames += int(tf.nest.flatten(trajectory)[0].shape[0])
# The previous `400 * cnt` assumed every trajectory had exactly 400 frames
# (presumably the scripted episode length); count the real frames instead.
print(f"trajectories: {cnt}, frames: {total_frames}")