In [1]:
import datetime
from functools import partial
import os

import jax.numpy as jnp
from absl import app, flags, logging
import flax
from flax.traverse_util import flatten_dict, unflatten_dict
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from ml_collections import config_flags, ConfigDict
import importlib
from octo.model.components.tokenizers import BinTokenizer, LowdimObsTokenizer, ImageTokenizer, UnsqueezingImageTokenizer, ProjectionTokenizer, SiglipTokenizer
import optax
import tensorflow as tf
import tqdm
import wandb

from octo.model.components.vit_encoders import SmallStem16
from octo.data.dataset import make_single_dataset
from octo.model.octo_model import OctoModel
from octo.utils.jax_utils import initialize_compilation_cache
from octo.utils.spec import ModuleSpec
from octo.utils.train_callbacks import (
    RolloutVisualizationCallback,
    SaveCallback,
    ValidationCallback,
    VisualizationCallback,
)
from octo.utils.train_utils import (
    check_config_diff,
    create_optimizer,
    format_name_with_config,
    merge_params,
    process_text,
    Timer,
    TrainState,
)

# Optional dependency: jax_smi (presumably device-memory tracking for JAX).
# When the package is not installed, tracking is skipped silently.
# NOTE(review): because initialise_tracking() is inside the same try, an
# ImportError raised *by* that call is also swallowed here.
try:
    from jax_smi import initialise_tracking  # type: ignore

    initialise_tracking()
except ImportError:
    pass



2024-06-07 04:36:43.699365: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-07 04:36:43.699426: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-07 04:36:43.701219: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
# FLAGS = flags.FLAGS

# flags.DEFINE_string("name", "experiment", "Experiment name.")
# flags.DEFINE_bool("debug", False, "Debug config (no wandb logging)")

# # default_config_file = os.path.join(
# #     os.path.dirname(__file__), "configs/finetune_config.py"
# # )
# config_flags.DEFINE_config_file(
#     "config",
#     "/home/joshwajones/octo_digit/scripts/configs/josh_finetune_config.py:None",
#     "File path to the training hyperparameter configuration.",
#     lock_config=False,
# )


In [2]:
from ml_collections import ConfigDict
from ml_collections.config_dict import FieldReference, placeholder
from octo.model.components.action_heads import MSEActionHead
from octo.model.components.tokenizers import BinTokenizer, LowdimObsTokenizer, ImageTokenizer, UnsqueezingImageTokenizer, ProjectionTokenizer, SiglipTokenizer
from octo.model.components.vit_encoders import SmallStem16

from octo.utils.spec import ModuleSpec

def get_config(config_string=None):
    """Build the finetuning ConfigDict for the DIGIT dataset.

    Args:
        config_string: ``"<mode>,<task>"``. ``mode`` selects which parameters
            are frozen during finetuning: ``"full"``, ``"head_only"``,
            ``"head_mlp_only"`` or ``"frozen_transformer"``. ``task`` selects
            the task conditioning: ``"image_conditioned"``,
            ``"language_conditioned"`` or ``"multimodal"``. When ``None``,
            defaults to ``"full,multimodal"``.

    Returns:
        A ``ConfigDict`` holding dataset, optimizer, trajectory/frame transform,
        observation-tokenizer and ``update_config`` settings.

    Raises:
        AssertionError: if ``mode`` or ``task`` is not one of the values above.
        ValueError: if ``config_string`` does not split into exactly two parts.
    """
    if config_string is None:
        # Alternative: "full,language_conditioned"
        config_string = "full,multimodal"
    mode, task = config_string.split(",")
    assert task in ["image_conditioned", "language_conditioned", "multimodal"]
    assert mode in ["full", "head_only", "head_mlp_only", "frozen_transformer"]

    # Fill this in for your own dataset!

    # There should be two image keys
    # first image key should be the third-person view (None if not used)
    # and second image key should be the wrist view (None if not used)
    # When True, only the camera images are used; the DIGIT images and the
    # extra sensor streams below are not added to the dataset kwargs.
    CAMS_ONLY = True

    # Number of future actions predicted per step. Shared by the action head
    # and the trajectory transform so the two cannot drift apart.
    action_horizon = 4

    FINETUNING_KWARGS = {
        "name": "digit_dataset:8.8.0",
        # "data_dir": "gs://619c8f721786ba/",
        "data_dir": "/home/joshwajones/tensorflow_datasets/",
        "image_obs_keys": {
            "primary": "image_0",
            "wrist": "image_1",
        },
        "digit_obs_keys": {},  # TODO: remove this, treating digits just as images
        "proprio_obs_key": None,  # "state",
        "sensor_obs_keys": {},
        "language_key": "language_instruction",
        # We want to avoid normalizing the gripper
        "action_normalization_mask": [True, True, True, True, True, True, False],
        # standardize_fn is dynamically loaded from a file
        # for example: "experiments/kevin/custom_standardization_transforms.py:aloha_dataset_transform"
        "standardize_fn": ModuleSpec.create(
            "octo.data.oxe.oxe_standardization_transforms:bridge_dataset_transform",
        ),
        # If the default data loading speed is too slow, try these:
        # "num_parallel_reads": 8,  # for reading from disk / GCS
        # "num_parallel_calls": 16,  # for initial dataset construction
    }
    if not CAMS_ONLY:
        digit_names = {
            "digit_left": "digit_0",
            "digit_right": "digit_1",
        }
        FINETUNING_KWARGS["image_obs_keys"].update(digit_names)

        sensor_names = {
            "spectro": "mel_spectro",
            "imu": "imu",
            "digit_0_embedding": "digit_0_embedding",
            "digit_1_embedding": "digit_1_embedding",
        }
        # The dict key is "sensor_obs_keys" (plural), as declared above.
        FINETUNING_KWARGS["sensor_obs_keys"].update(sensor_names)

    # Extra observation tokenizers to add on top of the pretrained model.
    # Example entries (a "siglip" entry is converted into a SiglipTokenizer
    # ModuleSpec further below):
    # NEW_OBS_TOKENIZERS = {
    #     "digits":
    #         ModuleSpec.create(
    #             ImageTokenizer,
    #             obs_stack_keys=["image_digit_left", "image_digit_right"],
    #             task_stack_keys=[],
    #             encoder=ModuleSpec.create(SmallStem16),
    #         ),
    #     "spectrogram":
    #         ModuleSpec.create(
    #             UnsqueezingImageTokenizer,
    #             obs_stack_keys = ["spectro"],
    #             task_stack_keys=[],
    #             encoder=ModuleSpec.create(SmallStem16),
    #         ),
    #     "imu": ModuleSpec.create(
    #             ProjectionTokenizer,
    #             num_output_tokens=7,
    #             n_bins=256,
    #             obs_keys=["imu"],
    #         ),
    #    "digit_embeddings": ModuleSpec.create(
    #             ProjectionTokenizer,
    #             num_output_tokens=7,
    #             n_bins=256,
    #             obs_keys=["digit_0_embedding", "digit_1_embedding"],
    #         ),
    #     "siglip": {
    #             "freeze": False, # True,
    #             "config": {
    #                 'encoder_path':'/home/sjosh/nfs/octo_digit/siglip.npz:img', # '/home/joshwajones/octo/siglip.npz:img',
    #                 'image_model': 'vit',
    #                 'image': dict(variant='B/16', pool_type='map')
    #             }
    #         },
    # }
    NEW_OBS_TOKENIZERS = {}

    # Replacement action head; set to None to leave the model's heads unchanged
    # (no "heads" entry is written into update_config in that case).
    NEW_ACTION_HEAD = ModuleSpec.create(
        MSEActionHead,
        readout_key="readout_action",
        use_map=False,  # TODO: confirm whether MAP pooling should be disabled
        action_horizon=action_horizon,
        action_dim=7,
    )

    if mode == "full":
        frozen_keys = None
    elif mode == "head_only":
        frozen_keys = ("octo_transformer.*",)
    elif mode == "head_mlp_only":
        frozen_keys = (
            "octo_transformer.*",
            "heads_*.map_head.probe",
            "heads_*.map_head.MultiHeadDotProductAttention_0.*",
        )
    elif mode == "frozen_transformer":
        frozen_keys = ("octo_transformer.BlockTransformer_0.*",)
    else:
        raise ValueError("Invalid mode")

    max_steps = FieldReference(50000)
    # max_steps = FieldReference(25000)
    window_size = FieldReference(default=2)

    config = dict(
        pretrained_path="hf://rail-berkeley/octo-small",
        batch_size=256,
        shuffle_buffer_size=10000,
        num_steps=max_steps,
        log_interval=100,
        eval_interval=200,
        save_interval=5000,
        # save_dir="/home/joshwajones/octo_save_dir/",
        save_dir="gs://619c8f721786ba/octo_ckpts/",
        seed=42,
        wandb=dict(
            project="octo", group=placeholder(str), entity=placeholder(str)
        ),
        dataset_kwargs=FINETUNING_KWARGS,
        modality=task,
        finetuning_mode=mode,
        window_size=window_size,
        optimizer=dict(
            learning_rate=dict(
                name="cosine",
                init_value=0.0,
                peak_value=3e-4,
                warmup_steps=2000,
                decay_steps=max_steps,
                end_value=0.0,
            ),
            weight_decay=0.01,
            clip_gradient=1.0,
            frozen_keys=frozen_keys,
            grad_accumulation_steps=None,  # if you are using grad accumulation, you need to adjust max_steps accordingly
        ),
        val_kwargs=dict(
            val_shuffle_buffer_size=1000,
            num_val_batches=16,
        ),
        viz_kwargs=dict(
            eval_batch_size=64,
            trajs_for_metrics=100,
            trajs_for_viz=8,
            samples_per_state=8,
        ),
    )

    # A "siglip" entry is a plain {"freeze", "config"} dict; turn it into a
    # SiglipTokenizer spec and optionally freeze its parameters.
    if "siglip" in NEW_OBS_TOKENIZERS:
        should_freeze = NEW_OBS_TOKENIZERS["siglip"]["freeze"]
        siglip_cfg = NEW_OBS_TOKENIZERS["siglip"]["config"]
        config["siglip_config"] = siglip_cfg
        NEW_OBS_TOKENIZERS["siglip"] = ModuleSpec.create(
            SiglipTokenizer,
            image=siglip_cfg["image"],
            image_model=siglip_cfg["image_model"],
            encoder_path=siglip_cfg["encoder_path"],
            n_bins=256,
            obs_keys=['siglip'],
        )

        if should_freeze:
            prev_frozen = frozen_keys if frozen_keys else ()
            config["optimizer"]["frozen_keys"] = prev_frozen + (
                "octo_transformer.observation_tokenizers_siglip.*",
                "*hf_model*",
            )

    if task == "image_conditioned":
        goal_relabeling_strategy = "uniform"
        keep_image_prob = 1.0
    elif task == "language_conditioned":
        goal_relabeling_strategy = None
        keep_image_prob = 0.0
    elif task == "multimodal":
        goal_relabeling_strategy = "uniform"
        keep_image_prob = 0.5
    else:
        raise ValueError("Invalid modality")

    traj_transform_kwargs = dict(
        window_size=window_size,
        action_horizon=action_horizon,
        goal_relabeling_strategy=goal_relabeling_strategy,
        task_augment_strategy="delete_task_conditioning",
        task_augment_kwargs=dict(
            keep_image_prob=keep_image_prob,
        ),
        # If the default data loading speed is too slow, try these:
        # num_parallel_calls=16,  # for less CPU-intensive ops
    )

    workspace_augment_kwargs = dict(
        random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]),
        random_brightness=[0.1],
        random_contrast=[0.9, 1.1],
        random_saturation=[0.9, 1.1],
        random_hue=[0.05],
        augment_order=[
            "random_resized_crop",
            "random_brightness",
            "random_contrast",
            "random_saturation",
            "random_hue",
        ],
    )
    wrist_augment_kwargs = dict(
        random_brightness=[0.1],
        random_contrast=[0.9, 1.1],
        random_saturation=[0.9, 1.1],
        random_hue=[0.05],
        augment_order=[
            "random_brightness",
            "random_contrast",
            "random_saturation",
            "random_hue",
        ],
    )
    digit_augment_kwargs = dict(
        random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]),
        augment_order=[
            "random_resized_crop"
        ],
    )
    frame_transform_kwargs = dict(
        resize_size={
            "primary": (256, 256),  # workspace (3rd person) camera is at 256x256
            "wrist": (128, 128),  # wrist camera is at 128x128
            "digit_left": (256, 256),  # (128, 128),
            "digit_right": (256, 256),  # (128, 128)
        },
        image_augment_kwargs={
            "primary": workspace_augment_kwargs,
            "wrist": wrist_augment_kwargs,
            "digit_left": digit_augment_kwargs,
            "digit_right": digit_augment_kwargs,
        },
    )
    # If the default data loading speed is too slow, try these:
    config[
        "frame_transform_threads"
    ] = 16  # for the most CPU-intensive ops (decoding, resizing, augmenting)

    config["traj_transform_kwargs"] = traj_transform_kwargs
    config["frame_transform_kwargs"] = frame_transform_kwargs
    config["new_obs_tokenizers"] = NEW_OBS_TOKENIZERS

    # Overrides meant to be merged into the pretrained model's config.
    config['update_config'] = {
        "model": {
            "repeat_task_tokens": True,
        }
    }
    if NEW_ACTION_HEAD is not None:
        config['update_config']['model']['heads'] = {
            'action': ConfigDict(NEW_ACTION_HEAD)
        }
    config['update_config'] = ConfigDict(config['update_config'])

    return ConfigDict(config)

CONFIG = get_config(None)

# Interactive stand-in for absl FLAGS: exposes FLAGS.name, FLAGS.debug and
# FLAGS.config via attribute access without any command-line parsing.
FLAGS = ConfigDict(
    dict(
        name="experiment",
        debug=True,
        config=CONFIG,
    )
)

In [4]:
# Keys are truncated to MAX_KEY_LEN characters; each nesting level is indented
# by INDENT_SIZE spaces so that shapes line up in a column.
MAX_KEY_LEN = 15
INDENT_SIZE = MAX_KEY_LEN + 4
INDENT = " " * INDENT_SIZE


def recursive_dict_print(dictionary, prefix=""):
    """Print a nested mapping of arrays as an indented tree of shapes.

    A nested mapping prints its key on its own line and is recursed into, with
    ``INDENT`` appended to ``prefix``. A leaf prints as
    ``<prefix><key>:<padding><shape>``, where the padding aligns the shapes.

    Args:
        dictionary: Mapping whose leaves are array-like (have ``.shape``).
            Nested values of any ``collections.abc.Mapping`` type are recursed
            into, not only ``dict``; this covers e.g. flax's ``FrozenDict``.
            Leaves without ``.shape`` are printed by value.
        prefix: String printed before every line at this nesting level.
    """
    from collections.abc import Mapping

    for raw_key, val in dictionary.items():
        # str() so non-string keys (e.g. ints) can be truncated and padded.
        key = str(raw_key)[:MAX_KEY_LEN]
        if isinstance(val, Mapping):
            print(f"{prefix}{key}")
            recursive_dict_print(val, prefix + INDENT)
        else:
            padding = " " * (INDENT_SIZE - len(key))
            print(f"{prefix}{key}:{padding}{getattr(val, 'shape', val)}")



# Presumably enables JAX's compilation cache (octo.utils.jax_utils helper).
initialize_compilation_cache()
# Snapshot of the available JAX devices; used below for the batch-size checks
# and the data-parallel mesh.
devices = jax.devices()
# Log a summary of the run configuration.
# NOTE(review): absl `logging.info` is likely below absl's default verbosity when
# absl flags are never parsed (no `app.run` in a notebook), so this banner may
# not be shown — confirm, or call `logging.set_verbosity(logging.INFO)` first.
logging.info(
    f"""
    Octo Finetuning Script
    ======================
    Pretrained model: {FLAGS.config.pretrained_path}
    Finetuning Dataset: {FLAGS.config.dataset_kwargs.name}
    Data dir: {FLAGS.config.dataset_kwargs.data_dir}
    Task Modality: {FLAGS.config.modality}
    Finetuning Mode: {FLAGS.config.finetuning_mode}

    # Devices: {jax.device_count()}
    Batch size: {FLAGS.config.batch_size} ({FLAGS.config.batch_size // len(devices) } per device)
    # Steps: {FLAGS.config.num_steps}
"""
)

#########
#
# Setup Jax Data Parallelism
#
#########

# Data parallelism splits each batch evenly over devices, so both the training
# and the eval batch size must be multiples of the device count.
assert FLAGS.config.batch_size % len(devices) == 0, (
    f"Batch size ({FLAGS.config.batch_size}) must be divisible by the number of devices ({len(devices)})"
)
assert FLAGS.config.viz_kwargs.eval_batch_size % len(devices) == 0, (
    f"Eval batch size ({FLAGS.config.viz_kwargs.eval_batch_size}) must be divisible by the number of devices ({len(devices)})"
)

# One-dimensional device mesh whose single axis is called "batch".
mesh = Mesh(devices, axis_names="batch")
# Batches are split along "batch": every device receives its own slice.
dp_sharding = NamedSharding(mesh, PartitionSpec("batch"))
# Parameters are fully replicated (pure data parallelism, no model parallelism).
replicated_sharding = NamedSharding(mesh, PartitionSpec())

# TensorFlow only drives the input pipeline; hide GPUs from it so it does not
# allocate accelerator memory.
tf.config.set_visible_devices([], "GPU")

#########
#
# Setup WandB
#
#########

# Experiment/run name built from FLAGS.name and the full config (as a plain dict).
name = format_name_with_config(FLAGS.name, FLAGS.config.to_dict())
# wandb_id = "{name}_{time}".format(
#     name=name,
#     time=datetime.datetime.now().strftime("%Y%m%d_%H%M%S"),
# )
# wandb.init(
#     config=FLAGS.config.to_dict(),
#     id=wandb_id,
#     name=name,
#     mode="disabled" if FLAGS.debug else None,
#     **FLAGS.config.wandb,
# )

#########
#
# Load Pretrained model + optionally modify config
#
#########
# Checkpoint to load: a local run directory instead of the public weights in
# FLAGS.config.pretrained_path (kept commented below for switching back).
# Only a single step is loaded. To sweep several checkpoints, call
# OctoModel.load_pretrained *inside* a loop such as
# `for step in range(5000, 50001, 5000):` — this assumes checkpoints were saved
# every 5000 steps, which is unverified.
CHECKPOINT_STEP = 50000
pretrained_model_kwargs = {
    "checkpoint_path": "/home/joshwajones/tpu_octo_ckpts/standard_largebatch_20240605_015537",
    # "checkpoint_path": FLAGS.config.pretrained_path
    "step": CHECKPOINT_STEP,
}
pretrained_model = OctoModel.load_pretrained(**pretrained_model_kwargs)
rng = jax.random.PRNGKey(FLAGS.config.seed)
rng, init_rng = jax.random.split(rng)  # init_rng is not used further in this cell
model = pretrained_model

# Copy the pretrained model config into a plain nested dict (empty sub-dicts are
# preserved via keep_empty_nodes); later cells index into it, e.g. "text_processor".
flat_config = flax.traverse_util.flatten_dict(
    pretrained_model.config, keep_empty_nodes=True
)
config = ConfigDict(flax.traverse_util.unflatten_dict(flat_config))
# NOTE: FLAGS.config.update_config (e.g. the replacement action head) is not
# applied here; the loaded checkpoint's own config is used unchanged.
# config.update(FLAGS.config.get("update_config", ConfigDict()))
config = config.to_dict()
# check_config_diff(config, pretrained_model.config)
#########
#
# Setup Data Loader
#
#########

# Text processor from the loaded model's config; None when the config has none.
text_processor = (
    None
    if config["text_processor"] is None
    else ModuleSpec.instantiate(config["text_processor"])()
)


def process_batch(batch):
    """Apply `process_text` with the model's text processor and drop `dataset_name`.

    The batch returned by `process_text` is modified in place (its
    "dataset_name" entry is deleted) and returned.
    """
    processed = process_text(batch, text_processor)
    del processed["dataset_name"]
    return processed

params = model.params

# Unless the finetuning config pins them, inherit the frozen-parameter patterns
# from the loaded model's own optimizer config.
if FLAGS.config.optimizer.frozen_keys is None:
    FLAGS.config.optimizer.frozen_keys = model.config["optimizer"]["frozen_keys"]

tx, lr_callable, param_norm_callable = create_optimizer(
    params, **FLAGS.config.optimizer.to_dict()
)
train_state = TrainState.create(rng=rng, model=model, tx=tx)

# Evaluation modes (intended for VisualizationCallback's modes_to_evaluate).
# get_config names the language task "language_conditioned", while the eval
# mode is called "text_conditioned"; accept both spellings so language-only
# runs do not fall through to "base".
if FLAGS.config.modality == "image_conditioned":
    modes_to_evaluate = ["image_conditioned"]
elif FLAGS.config.modality in ("text_conditioned", "language_conditioned"):
    modes_to_evaluate = ["text_conditioned"]
elif FLAGS.config.modality == "multimodal":
    modes_to_evaluate = ["image_conditioned", "text_conditioned"]
else:
    modes_to_evaluate = ["base"]

dataset_kwargs_list = [FLAGS.config.dataset_kwargs]

# viz_callback = VisualizationCallback(
#     text_processor=text_processor,
#     val_dataset_kwargs_list=dataset_kwargs_list,
#     dataset_kwargs=FLAGS.config,
#     modes_to_evaluate=modes_to_evaluate,
#     **FLAGS.config.viz_kwargs,
# )





In [None]:

# Build the training input pipeline and pull one example batch.
# `text_processor` and `process_batch` come from the setup cell above (this cell
# already depends on it for `config`), so they are not redefined here.
dataset = make_single_dataset(
    FLAGS.config.dataset_kwargs,
    traj_transform_kwargs=FLAGS.config.traj_transform_kwargs,
    frame_transform_kwargs=FLAGS.config.frame_transform_kwargs,
    train=True,
)
# Infinite stream: flatten trajectories into frames, shuffle, re-batch.
train_data_iter = (
    dataset.repeat()
    .unbatch()
    .shuffle(FLAGS.config.shuffle_buffer_size)
    .batch(FLAGS.config.batch_size)
    .iterator()
)
train_data_iter = map(process_batch, train_data_iter)
example_batch = next(train_data_iter)


In [5]:

# FLAGS.config.viz_kwargs['trajs_for_viz'] = 8
# FLAGS.config.viz_kwargs['samples_per_state'] = 8
# FLAGS.config.viz_kwargs['trajs_for_metrics'] = 1
# viz_callback = VisualizationCallback(
#     text_processor=text_processor,
#     val_dataset_kwargs_list=dataset_kwargs_list,
#     dataset_kwargs=FLAGS.config,
#     modes_to_evaluate=modes_to_evaluate,
#     **FLAGS.config.viz_kwargs,
# )
# viz_metrics = viz_callback(train_state, 0)

Cause: Unable to locate the source code of <function _gcd_import at 0x7f2f1fa6b400>. Note that functions defined in certain environments, like the interactive Python shell, do not expose their source code. If that is the case, you should define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.experimental.do_not_convert. Original error: could not get source code


Cause: Unable to locate the source code of <function _gcd_import at 0x7f2f1fa6b400>. Note that functions defined in certain environments, like the interactive Python shell, do not expose their source code. If that is the case, you should define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.experimental.do_not_convert. Original error: could not get source code


Cause: Unable to locate the source code of <function _gcd_import at 0x7f2f1fa6b400>. Note that functions defined in certain environments, like the interactive Python shell, do not expose their source code. If that is the case, you should define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.experimental.do_not_convert. Original error: could not get source code


  0%|          | 0/1 [00:15<?, ?it/s]


TypeError: unnormalize() missing 1 required positional argument: 'mask'