In [1]:
import os

# Pin this kernel to GPU 0. This cell runs before jax / tensorflow are imported in the next cell.
# os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [2]:
import cv2
import jax
import tensorflow_datasets as tfds
import tqdm
import mediapy
import numpy as np
import imageio

import datetime
from functools import partial
import os

import jax.numpy as jnp
from absl import app, flags, logging
import flax
from flax.traverse_util import flatten_dict, unflatten_dict
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from ml_collections import config_flags, ConfigDict
import importlib
from octo.model.components.tokenizers import BinTokenizer, LowdimObsTokenizer, ImageTokenizer, UnsqueezingImageTokenizer, ProjectionTokenizer, SiglipTokenizer
import optax
import tensorflow as tf
import tqdm
import wandb
from octo.model.components.vit_encoders import ResNet26, SmallStem32

from octo.model.components.vit_encoders import SmallStem16
from octo.data.dataset import make_single_dataset
from octo.model.octo_model import OctoModel
from octo.utils.jax_utils import initialize_compilation_cache
from octo.utils.spec import ModuleSpec
from octo.utils.train_callbacks import (
    RolloutVisualizationCallback,
    SaveCallback,
    ValidationCallback,
    VisualizationCallback,
    GradCAMVisualizationCallback,
)
from octo.utils.train_utils import (
    check_config_diff,
    create_optimizer,
    format_name_with_config,
    merge_params,
    process_text,
    Timer,
    TrainState,
)

# Optional: start jax_smi tracking (NOTE(review): presumably a device-memory monitor;
# not shown here). The notebook carries on normally if the package is not installed.
try:
    from jax_smi import initialise_tracking  # type: ignore

    initialise_tracking()
except ImportError:
    # Deliberate best-effort: jax_smi is an optional dependency.
    pass


2024-06-07 04:58:26.568149: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-07 04:58:26.568213: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-07 04:58:26.569850: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
from ml_collections import ConfigDict
from ml_collections.config_dict import FieldReference, placeholder
from octo.model.components.action_heads import MSEActionHead
from octo.model.components.tokenizers import BinTokenizer, LowdimObsTokenizer, ImageTokenizer, UnsqueezingImageTokenizer, ProjectionTokenizer, SiglipTokenizer
from octo.model.components.vit_encoders import SmallStem16

from octo.utils.spec import ModuleSpec

def get_config(config_string=None):
    """Build the finetuning ConfigDict for the digit dataset.

    Args:
        config_string: ``"<mode>,<task>"``.
            ``mode`` is one of ``full``, ``head_only``, ``head_mlp_only`` or
            ``frozen_transformer`` and selects the optimizer's ``frozen_keys``.
            ``task`` is one of ``image_conditioned``, ``language_conditioned`` or
            ``multimodal`` and selects the goal-relabeling strategy and the
            probability of keeping image goals.
            Defaults to ``"full,multimodal"`` when None. Previously the argument
            was ignored and always overwritten with that default.

    Returns:
        ConfigDict with dataset, transform, optimizer, callback and
        ``update_config`` settings.

    Raises:
        AssertionError: if ``mode`` or ``task`` is not one of the values above.
    """
    if config_string is None:
        # Other option: "full,language_conditioned"
        config_string = "full,multimodal"
    mode, task = config_string.split(",")
    assert task in ["image_conditioned", "language_conditioned", "multimodal"]
    # "frozen_transformer" is handled by the frozen_keys branch below; it used to be
    # rejected here, which made that branch unreachable.
    assert mode in ["full", "head_only", "head_mlp_only", "frozen_transformer"]

    # Shared by the action head and the trajectory transform so the two cannot drift apart.
    action_horizon = 4

    # Fill this in for your own dataset!
    # There should be two image keys:
    # the first image key should be the third-person view (None if not used)
    # and the second image key should be the wrist view (None if not used).
    # With CAMS_ONLY = False the digit_* images and extra sensor streams are added too.
    CAMS_ONLY = True

    FINETUNING_KWARGS = {
        "name": "digit_dataset:19.0.0",
        "data_dir": "/home/joshwajones/tensorflow_datasets/",
        "image_obs_keys": {
            "primary": "image_0",
            "wrist": "image_1",
        },
        "digit_obs_keys": {},  # TODO: remove this, treating digits just as images
        "proprio_obs_key": None,  # "state",
        "sensor_obs_keys": {},
        "language_key": "language_instruction",
        # We want to avoid normalizing the gripper
        "action_normalization_mask": [True, True, True, True, True, True, False],
        # standardize_fn is dynamically loaded from a file
        # for example: "experiments/kevin/custom_standardization_transforms.py:aloha_dataset_transform"
        "standardize_fn": ModuleSpec.create(
            "octo.data.oxe.oxe_standardization_transforms:bridge_dataset_transform",
        ),
        # If the default data loading speed is too slow, try these:
        # "num_parallel_reads": 8,  # for reading from disk / GCS
        # "num_parallel_calls": 16,  # for initial dataset construction
    }
    if not CAMS_ONLY:
        FINETUNING_KWARGS["image_obs_keys"].update(
            {
                "digit_left": "digit_0",
                "digit_right": "digit_1",
            }
        )
        # The key is "sensor_obs_keys" (see above); updating "sensor_obs_key" raised KeyError.
        FINETUNING_KWARGS["sensor_obs_keys"].update(
            {
                "spectro": "mel_spectro",
                "imu": "imu",
                "digit_0_embedding": "digit_0_embedding",
                "digit_1_embedding": "digit_1_embedding",
            }
        )

    # Extra observation tokenizers. Paste entries like these into NEW_OBS_TOKENIZERS to enable them.
    # A "siglip" entry is converted into a SiglipTokenizer spec further below.
    # NEW_OBS_TOKENIZERS = {
    #     "digits":
    #         ModuleSpec.create(
    #             ImageTokenizer,
    #             obs_stack_keys=["image_digit_left", "image_digit_right"],
    #             task_stack_keys=[],
    #             encoder=ModuleSpec.create(SmallStem16),
    #         ),
    #     "spectrogram":
    #         ModuleSpec.create(
    #             UnsqueezingImageTokenizer,
    #             obs_stack_keys = ["spectro"],
    #             task_stack_keys=[],
    #             encoder=ModuleSpec.create(SmallStem16),
    #         ),
    #     "imu": ModuleSpec.create(
    #             ProjectionTokenizer,
    #             num_output_tokens=7,
    #             n_bins=256,
    #             obs_keys=["imu"],
    #         ),
    #    "digit_embeddings": ModuleSpec.create(
    #             ProjectionTokenizer,
    #             num_output_tokens=7,
    #             n_bins=256,
    #             obs_keys=["digit_0_embedding", "digit_1_embedding"],
    #         ),
    #     "siglip": {
    #             "freeze": False, # True,
    #             "config": {
    #                 'encoder_path':'/home/sjosh/nfs/octo_digit/siglip.npz:img', # '/home/joshwajones/octo/siglip.npz:img',
    #                 'image_model': 'vit',
    #                 'image': dict(variant='B/16', pool_type='map')
    #             }
    #         },
    # }
    NEW_OBS_TOKENIZERS = {}

    # Set to None to leave the pretrained model's "heads" config untouched.
    NEW_ACTION_HEAD = ModuleSpec.create(
        MSEActionHead,
        readout_key="readout_action",
        use_map=False,  # should this be disabled?
        action_horizon=action_horizon,
        action_dim=7,
    )

    if mode == "full":
        frozen_keys = None
    elif mode == "head_only":
        frozen_keys = ("octo_transformer.*",)
    elif mode == "head_mlp_only":
        frozen_keys = (
            "octo_transformer.*",
            "heads_*.map_head.probe",
            "heads_*.map_head.MultiHeadDotProductAttention_0.*",
        )
    elif mode == "frozen_transformer":
        frozen_keys = ("octo_transformer.BlockTransformer_0.*",)
    else:
        raise ValueError("Invalid mode")

    # FieldReferences: changing max_steps also changes the LR schedule's decay_steps.
    max_steps = FieldReference(50000)
    # max_steps = FieldReference(25000)
    window_size = FieldReference(default=2)

    config = dict(
        pretrained_path="hf://rail-berkeley/octo-small",
        batch_size=256,
        shuffle_buffer_size=10000,
        num_steps=max_steps,
        log_interval=100,
        eval_interval=2500,
        save_interval=5000,
        # save_dir="/home/joshwajones/octo_save_dir/",
        save_dir="gs://619c8f721786ba/octo_ckpts/",
        seed=42,
        wandb=dict(
            project="octo", group=placeholder(str), entity=placeholder(str)
        ),
        dataset_kwargs=FINETUNING_KWARGS,
        modality=task,
        finetuning_mode=mode,
        window_size=window_size,
        optimizer=dict(
            learning_rate=dict(
                name="cosine",
                init_value=0.0,
                peak_value=3e-4,
                warmup_steps=2000,
                decay_steps=max_steps,
                end_value=0.0,
            ),
            weight_decay=0.01,
            clip_gradient=1.0,
            frozen_keys=frozen_keys,
            grad_accumulation_steps=None,  # if you are using grad accumulation, you need to adjust max_steps accordingly
        ),
        val_kwargs=dict(
            val_shuffle_buffer_size=1000,
            num_val_batches=16,
        ),
        viz_kwargs=dict(
            eval_batch_size=64,
            trajs_for_metrics=100,
            trajs_for_viz=8,
            samples_per_state=8,
        ),
        gradcam_kwargs=dict(
            eval_batch_size=4,
            shuffle_buffer_size=1000,
            train=False,
            # NOTE(review): 'psuedo_loss_type' is spelled as the consumer expects; do not "fix" it here alone.
            gradcam_kwargs_list=(
                ('obs_primary', {'psuedo_loss_type': 'loss'}),
                ('obs_wrist', {'psuedo_loss_type': 'loss'}),
            ),
        ),
    )

    if "siglip" in NEW_OBS_TOKENIZERS:
        # Replace the raw {"freeze", "config"} entry with a real tokenizer spec.
        should_freeze = NEW_OBS_TOKENIZERS["siglip"]["freeze"]
        siglip_cfg = NEW_OBS_TOKENIZERS["siglip"]["config"]
        config["siglip_config"] = siglip_cfg
        NEW_OBS_TOKENIZERS["siglip"] = ModuleSpec.create(
            SiglipTokenizer,
            image=siglip_cfg["image"],
            image_model=siglip_cfg["image_model"],
            encoder_path=siglip_cfg["encoder_path"],
            n_bins=256,
            obs_keys=['siglip'],
        )

        if should_freeze:
            prev_frozen = frozen_keys if frozen_keys else ()
            config["optimizer"]["frozen_keys"] = prev_frozen + (
                "octo_transformer.observation_tokenizers_siglip.*",
                "*hf_model*",
            )

    if task == "image_conditioned":
        goal_relabeling_strategy = "uniform"
        keep_image_prob = 1.0
    elif task == "language_conditioned":
        goal_relabeling_strategy = None
        keep_image_prob = 0.0
    elif task == "multimodal":
        goal_relabeling_strategy = "uniform"
        keep_image_prob = 0.5
    else:
        raise ValueError("Invalid modality")

    traj_transform_kwargs = dict(
        window_size=window_size,
        action_horizon=action_horizon,
        goal_relabeling_strategy=goal_relabeling_strategy,
        task_augment_strategy="delete_task_conditioning",
        task_augment_kwargs=dict(
            keep_image_prob=keep_image_prob,
        ),
        # If the default data loading speed is too slow, try these:
        # num_parallel_calls=16,  # for less CPU-intensive ops
    )

    workspace_augment_kwargs = dict(
        random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]),
        random_brightness=[0.1],
        random_contrast=[0.9, 1.1],
        random_saturation=[0.9, 1.1],
        random_hue=[0.05],
        augment_order=[
            "random_resized_crop",
            "random_brightness",
            "random_contrast",
            "random_saturation",
            "random_hue",
        ],
    )
    wrist_augment_kwargs = dict(
        random_brightness=[0.1],
        random_contrast=[0.9, 1.1],
        random_saturation=[0.9, 1.1],
        random_hue=[0.05],
        augment_order=[
            "random_brightness",
            "random_contrast",
            "random_saturation",
            "random_hue",
        ],
    )
    digit_augment_kwargs = dict(
        random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]),
        augment_order=[
            "random_resized_crop"
        ],
    )
    frame_transform_kwargs = dict(
        resize_size={
            "primary": (256, 256),  # workspace (3rd person) camera is at 256x256
            "wrist": (128, 128),  # wrist camera is at 128x128
            "digit_left": (256, 256),  # (128, 128),
            "digit_right": (256, 256),  # (128, 128)
        },
        image_augment_kwargs={
            "primary": workspace_augment_kwargs,
            "wrist": wrist_augment_kwargs,
            "digit_left": digit_augment_kwargs,
            "digit_right": digit_augment_kwargs,
        },
    )
    # If the default data loading speed is too slow, try these:
    config["frame_transform_threads"] = 16  # for the most CPU-intensive ops (decoding, resizing, augmenting)

    config["traj_transform_kwargs"] = traj_transform_kwargs
    config["frame_transform_kwargs"] = frame_transform_kwargs
    config["new_obs_tokenizers"] = NEW_OBS_TOKENIZERS

    config['update_config'] = {
        "model": {
            "repeat_task_tokens": True,
        }
    }
    if NEW_ACTION_HEAD is not None:
        config['update_config']['model']['heads'] = {
            'action': ConfigDict(NEW_ACTION_HEAD)
        }
    config['update_config'] = ConfigDict(config['update_config'])

    return ConfigDict(config)


In [4]:
CONFIG = get_config(None)

# Stand-in for absl-style FLAGS: the cells below read FLAGS.name and FLAGS.config
# (FLAGS.debug only appears in the disabled wandb.init call).
# NOTE(review): the o_* entries are not read in this notebook; presumably -1 means "no override".
FLAGS = ConfigDict(
    dict(
        name="experiment",
        debug=True,
        config=CONFIG,
        o_window_size=-1,
        o_batch_size=-1,
        o_steps=-1,
    )
)

In [5]:
from collections.abc import Mapping

MAX_KEY_LEN = 15  # keys longer than this are truncated when printed
INDENT_SIZE = MAX_KEY_LEN + 4  # column width of one nesting level
INDENT = " " * INDENT_SIZE


def recursive_dict_print(dictionary, prefix=""):
    """Print a nested mapping as an indented tree that shows each leaf's shape.

    Keys are converted to ``str`` and truncated to ``MAX_KEY_LEN`` characters.
    Each nesting level is indented by ``INDENT_SIZE`` spaces. Any ``Mapping`` is
    recursed into: a plain dict, or e.g. a flax ``FrozenDict``, which is not a ``dict``.
    Leaves with a ``.shape`` print that shape. Other leaves print
    ``<type name>`` instead of raising ``AttributeError``.

    Args:
        dictionary: the (possibly nested) mapping to print, e.g. a data batch.
        prefix: leading whitespace for the current nesting level.
    """
    for key, val in dictionary.items():
        key = str(key)[:MAX_KEY_LEN]
        if isinstance(val, Mapping):
            print(f"{prefix}{key}")
            recursive_dict_print(val, prefix + INDENT)
        else:
            padding = " " * (INDENT_SIZE - len(key))
            shape = getattr(val, "shape", None)
            summary = shape if shape is not None else f"<{type(val).__name__}>"
            print(f"{prefix}{key}:{padding}{summary}")



# Enable the persistent JAX compilation cache.
initialize_compilation_cache()
devices = jax.devices()
# Log a summary of the run. `logging` here is absl's logging (imported in the first import cell).
logging.info(
    f"""
    Octo Finetuning Script
    ======================
    Pretrained model: {FLAGS.config.pretrained_path}
    Finetuning Dataset: {FLAGS.config.dataset_kwargs.name}
    Data dir: {FLAGS.config.dataset_kwargs.data_dir}
    Task Modality: {FLAGS.config.modality}
    Finetuning Mode: {FLAGS.config.finetuning_mode}

    # Devices: {jax.device_count()}
    Batch size: {FLAGS.config.batch_size} ({FLAGS.config.batch_size // len(devices) } per device)
    # Steps: {FLAGS.config.num_steps}
"""
)

#########
#
# Setup Jax Data Parallelism
#
#########

# Both the training batch and the eval batch are split evenly across devices.
n_devices = len(devices)
for batch_label, batch_total in (
    ("Batch size", FLAGS.config.batch_size),
    ("Eval batch size", FLAGS.config.viz_kwargs.eval_batch_size),
):
    assert batch_total % n_devices == 0, (
        f"{batch_label} ({batch_total}) must be divisible by the number of devices ({n_devices})"
    )

# 1-D device mesh with a single axis named "batch".
mesh = Mesh(jax.devices(), axis_names="batch")
# Data parallelism: batches are sharded along "batch", so each device gets a slice ...
dp_sharding = NamedSharding(mesh, PartitionSpec("batch"))
# ... while the model is replicated on every device (no model parallelism).
replicated_sharding = NamedSharding(mesh, PartitionSpec())

# TensorFlow is only used for data loading, so keep it off the GPU.
tf.config.set_visible_devices([], "GPU")

#########
#
# Setup WandB
#
#########

# Run name built from FLAGS.name and the full config dict
# (passed as the wandb run name in the disabled wandb.init call below).
name = format_name_with_config(
    FLAGS.name,
    FLAGS.config.to_dict(),
)
# wandb logging is disabled in this notebook; uncomment to log the run
# (mode="disabled" while FLAGS.debug is True).
# wandb_id = "{name}_{time}".format(
#     name=name,
#     time=datetime.datetime.now().strftime("%Y%m%d_%H%M%S"),
# )
# wandb.init(
#     config=FLAGS.config.to_dict(),
#     id=wandb_id,
#     name=name,
#     mode="disabled" if FLAGS.debug else None,
#     **FLAGS.config.wandb,
# )

#########
#
# Load Pretrained model + optionally modify config
#
#########
# Finetuned Octo checkpoint to inspect.
# NOTE(review): absolute, machine-specific path; consider deriving it from a config constant.
CHECKPOINT_PATH = "/home/joshwajones/tpu_octo_ckpts/standard_largebatch_20240605_015537"
# The previous `for step in range(5000, 50001, 5000)` loop only reassigned "step" and
# loaded nothing inside it, so it always ended at 50000. Set that step explicitly;
# change it to load a different checkpoint.
CHECKPOINT_STEP = 50000

pretrained_model_kwargs = {
    "checkpoint_path": CHECKPOINT_PATH,
    # "checkpoint_path": FLAGS.config.pretrained_path
    "step": CHECKPOINT_STEP,
}
pretrained_model = OctoModel.load_pretrained(**pretrained_model_kwargs)

rng = jax.random.PRNGKey(FLAGS.config.seed)
# init_rng is currently unused; the split is kept so `rng` (used by TrainState below) is unchanged.
rng, init_rng = jax.random.split(rng)
model = pretrained_model

# Round-trip the pretrained config through flatten/unflatten + ConfigDict into a plain nested dict.
flat_config = flax.traverse_util.flatten_dict(
    pretrained_model.config, keep_empty_nodes=True
)
config = ConfigDict(flax.traverse_util.unflatten_dict(flat_config))
# config.update(FLAGS.config.get("update_config", ConfigDict()))
config = config.to_dict()
# check_config_diff(config, pretrained_model.config)
#########
#
# Setup Data Loader
#
#########

# Text processor declared in the pretrained model's config (may be None).
text_processor = (
    None
    if config["text_processor"] is None
    else ModuleSpec.instantiate(config["text_processor"])()
)


def process_batch(batch):
    """Run ``process_text`` with ``text_processor`` over ``batch`` and drop its ``dataset_name`` entry."""
    batch = process_text(batch, text_processor)
    batch.pop("dataset_name")  # KeyError if absent, same as `del`
    return batch

# When the finetuning config leaves frozen_keys unset, reuse the pretrained model's setting.
if FLAGS.config.optimizer.frozen_keys is None:
    FLAGS.config.optimizer.frozen_keys = model.config["optimizer"]["frozen_keys"]
params = model.params

optimizer_kwargs = FLAGS.config.optimizer.to_dict()
tx, lr_callable, param_norm_callable = create_optimizer(params, **optimizer_kwargs)
train_state = TrainState.create(model=model, tx=tx, rng=rng)

# Evaluation modes (as passed to the visualization callback) for each finetuning modality.
# get_config() calls the language-only task "language_conditioned"; it used to fall
# through to ["base"] here because only the spelling "text_conditioned" was matched.
_MODES_BY_MODALITY = {
    "image_conditioned": ["image_conditioned"],
    "text_conditioned": ["text_conditioned"],
    "language_conditioned": ["text_conditioned"],
    "multimodal": ["image_conditioned", "text_conditioned"],
}
modes_to_evaluate = list(_MODES_BY_MODALITY.get(FLAGS.config.modality, ["base"]))

dataset_kwargs_list = [FLAGS.config.dataset_kwargs]

# viz_callback = VisualizationCallback(
#     text_processor=text_processor,
#     val_dataset_kwargs_list=dataset_kwargs_list,
#     dataset_kwargs=FLAGS.config,
#     modes_to_evaluate=modes_to_evaluate,
#     **FLAGS.config.viz_kwargs,
# )



Initialized persistent compilation cache at /home/joshwajones/.jax_compilation_cache


NameError: name 'example_batch' is not defined

In [38]:
# NOTE(review): duplicates the text-processor / process_batch setup of the earlier cell;
# process_batch is redefined again (with MUSE) in the next cell.
text_processor = None
if config["text_processor"] is not None:
    text_processor = ModuleSpec.instantiate(config["text_processor"])()


def process_batch(batch):
    """Run ``process_text`` with ``text_processor`` over ``batch`` and drop its ``dataset_name`` entry."""
    processed = process_text(batch, text_processor)
    processed.pop("dataset_name")
    return processed


# Training split of the finetuning dataset, with the trajectory/frame transforms from the config.
# Batching and iteration happen in a later cell.
dataset = make_single_dataset(
    FLAGS.config.dataset_kwargs,
    train=True,
    traj_transform_kwargs=FLAGS.config.traj_transform_kwargs,
    frame_transform_kwargs=FLAGS.config.frame_transform_kwargs,
)
# example_batch = next(train_data_iter)

In [39]:
from octo.data.utils.text_processing import MuseEmbedding

embedder = MuseEmbedding()


def process_batch(batch):
    """Run ``process_text`` over ``batch`` with the MUSE ``embedder``.

    Unlike the earlier ``process_batch`` definitions, this one keeps ``dataset_name``.
    """
    return process_text(batch, embedder)
    

In [40]:
# Endless (repeat), reshuffled stream of training batches: unbatch -> shuffle -> rebatch.
shuffled_batches = (
    dataset.repeat()
    .unbatch()
    .shuffle(FLAGS.config.shuffle_buffer_size)
    .batch(FLAGS.config.batch_size)
    .iterator()
)
train_data_iter = map(process_batch, shuffled_batches)
example_batch = next(train_data_iter)


W0000 00:00:1717740454.661080 2344783 op_level_cost_estimator.cc:699] Error in PredictCost() for the op: op: "CropAndResize" attr { key: "T" value { type: DT_FLOAT } } attr { key: "extrapolation_value" value { f: 0 } } attr { key: "method" value { s: "bilinear" } } inputs { dtype: DT_FLOAT shape { dim { size: 1 } dim { size: 256 } dim { size: 256 } dim { size: -15 } } } inputs { dtype: DT_FLOAT shape { dim { size: -2 } dim { size: 4 } } } inputs { dtype: DT_INT32 shape { dim { size: -2 } } } inputs { dtype: DT_INT32 shape { dim { size: 2 } } } device { type: "CPU" vendor: "GenuineIntel" model: "101" frequency: 2300 num_cores: 64 environment { key: "cpu_instruction_set" value: "AVX SSE, SSE2, SSE3, SSSE3, SSE4.1, SSE4.2" } environment { key: "eigen" value: "3.4.90" } l1_cache_size: 32768 l2_cache_size: 1048576 l3_cache_size: 23068672 memory_size: 268435456 } outputs { dtype: DT_FLOAT shape { dim { size: -2 } dim { size: -24 } dim { size: -25 } dim { size: -15 } } }
W0000 00:00:171774045

In [59]:
# Add a length-1 window axis to the language embeddings, then repeat it twice.
# This mirrors how ResnetModule tiles the task embedding over the observation window.
language_embedding = example_batch["task"]["language_instruction"]
lang = language_embedding[:, None, ...]
print(lang.shape)
rep = jnp.repeat(lang, 2, axis=1)  # equals jnp.tile(lang, (1, 2, 1)) since that axis has length 1
print(rep.shape)

(256, 1, 512)
(256, 2, 512)


In [58]:
jnp.array_equal(lang[:, 0, :], rep[:, 1, :])  # the repeated copy matches the original

Array(True, dtype=bool)

In [72]:
# These names were previously only imported by cells further down the notebook,
# so this cell failed under Restart Kernel -> Run All.
from typing import Sequence

import flax.linen as nn
import jax.numpy as jnp
from jax.random import PRNGKey

from octo.model.components.vit_encoders import StdConv, ViTResnet
from octo.utils.typing import Data


class ResnetModule(nn.Module):
    """Prototype ResNet + MLP policy operating on a full data batch.

    Each ``(observation_key, stages)`` pair gets its own ``ViTResnet``. That is
    followed by a 3x3 ``StdConv`` to ``image_embedding_size`` channels and
    average pooling over axes (-2, -3). The pooled embeddings are concatenated
    with the task's ``language_key`` embedding, which is repeated over the window
    axis. The result is flattened per batch element and mapped through
    ReLU hidden layers and a linear action head.

    NOTE(review): this prototype is redefined by the later ``ResnetModule`` cell.
    """

    image_encoder_stages: Sequence[tuple[str, tuple]] = (("image_primary", (2, 2, 2, 2)), ("image_wrist", (2, 2, 2, 2)))
    image_embedding_size: int = 512
    mlp_widths: tuple[int, ...] = ()
    language_key: str = "language_instruction"
    action_dim: int = 7
    action_pred_horizon: int = 1

    @nn.compact
    def __call__(
        self,
        batch: Data,
    ):
        """Predict actions of shape ``(batch, action_pred_horizon, action_dim)``."""
        observations = batch['observation']
        # Leading (batch, window) dims, read from the first camera.
        b, w = observations[self.image_encoder_stages[0][0]].shape[:2]
        embeddings = []
        for observation_key, encoder_stages in self.image_encoder_stages:
            embedding = ViTResnet(num_layers=encoder_stages)(observations[observation_key])
            embedding = StdConv(self.image_embedding_size, (3, 3))(embedding)
            embedding = jnp.mean(embedding, axis=(-2, -3))  # GAP
            embeddings.append(embedding)

        # Repeat the task embedding over the window.
        lang = jnp.tile(batch['task'][self.language_key][:, None, ...], (1, w, 1))
        embeddings.append(lang)

        x = jnp.concatenate(embeddings, axis=-1)
        x = jnp.reshape(x, (b, -1))
        for width in self.mlp_widths:
            # Without a nonlinearity the stacked Dense layers collapse to one linear map.
            x = nn.relu(nn.Dense(width)(x))
        x = nn.Dense(self.action_dim * self.action_pred_horizon)(x)
        return jnp.reshape(x, (-1, self.action_pred_horizon, self.action_dim))


resnet = ResnetModule(mlp_widths=(100, 24))
varis = resnet.init(PRNGKey(0), example_batch)

out = resnet.apply(varis, example_batch)


In [73]:
out.shape  # (batch, action_pred_horizon, action_dim), from ResnetModule's final reshape

(256, 1, 7)

In [15]:
from octo.model.components.vit_encoders import ResNet26
from jax.random import PRNGKey

# Smoke test: initialise a standalone ResNet26 on the primary-camera images.
# NOTE(review): rebinds `resnet`, which previously held the ResnetModule prototype.
img_batch = example_batch['observation']['image_primary']
resnet = ResNet26()
variables = resnet.init(PRNGKey(0), img_batch) 

In [16]:
# Forward pass of the bare ResNet26 (rebinds `out` from the prototype cell).
out = resnet.apply(variables, img_batch)

In [21]:
import jax.numpy as jnp 
import flax.linen as nn 
# Un-pooled feature map; the leading (batch, window) axes are kept from img_batch.
out.shape

# out2 = jnp.reshape(out, (*out.shape[:2], -1))
# out2.shape

(256, 2, 8, 8, 2048)

In [30]:
import flax.linen as nn 
from octo.model.components.vit_encoders import StdConv

class ResNetEncoder(nn.Module):
    """ResNet26 trunk -> 3x3 StdConv projection -> average pool over the spatial axes.

    Attributes:
        embedding_size: output channels of the StdConv projection. Defaults to the
            previously hard-coded 69.
    """

    embedding_size: int = 69

    @nn.compact
    def __call__(self, obs):
        x = ResNet26()(obs)
        x = StdConv(self.embedding_size, (3, 3))(x)
        # Pool the two axes just before channels. For the (batch, window, H, W, C)
        # feature maps seen in this notebook this equals the old axis=(2, 3). Indexing
        # from the end also works if there is no window axis.
        return jnp.mean(x, axis=(-3, -2))

        

In [31]:
# Initialise the pooled encoder on the primary-camera batch (rebinds `varis`).
rnet = ResNetEncoder()
varis = rnet.init(PRNGKey(0), img_batch )

In [32]:
# Forward pass of the pooled encoder.
out3 = rnet.apply(varis, img_batch)

In [33]:
out3.shape  # spatial axes pooled away; last axis is the StdConv projection width (69)

(256, 2, 69)

In [34]:
from octo.data.utils.text_processing import MuseEmbedding

# Quick check of the MUSE text encoder on a single string.
muse = MuseEmbedding()

enc = muse.encode("yo what's good")

In [36]:
enc.shape  # one row for the single input string

(1, 512)

In [38]:
# Observation keys produced by the data pipeline for this dataset.
example_batch['observation'].keys()

dict_keys(['image_primary', 'image_wrist', 'timestep', 'pad_mask_dict', 'timestep_pad_mask', 'task_completed'])

In [134]:
from functools import partial
import json
import logging
from typing import Any, Optional, Tuple, Dict

import flax
from flax import struct
from flax.training import orbax_utils
import jax
from jax.experimental import multihost_utils
import jax.numpy as jnp
from jax.typing import ArrayLike
import numpy as np
import orbax.checkpoint
import tensorflow as tf
import flax.linen as nn

from octo.data.utils.text_processing import TextProcessor
from octo.model.components.action_heads import ActionHead
from octo.model.octo_module import OctoModule
from octo.utils.spec import ModuleSpec
from octo.utils.typing import Config, Data, Params, PRNGKey, Perturbations, Sequence
from octo.model.components.vit_encoders import StdConv, ViTResnet

class ResnetModule(nn.Module):
    """ResNet image encoders + language embedding + MLP action regressor.

    Each ``(observation_key, stages)`` entry in ``image_encoder_stages`` gets its
    own ``ViTResnet(num_layers=stages)``. That is followed by a 3x3 ``StdConv`` to
    ``image_embedding_size`` channels and average pooling over axes (-2, -3).
    The pooled image embeddings and the ``tasks[language_key]`` embedding, repeated
    over the window, are concatenated per timestep and flattened across the window.
    They then pass through the ``mlp_widths`` ReLU hidden layers and a linear head.

    Attributes:
        mlp_widths: widths of the hidden Dense layers.
        image_embedding_size: channels of each per-camera projection.
        image_encoder_stages: ``(observation_key, num_layers)`` pairs, one per camera.
        language_key: key in ``tasks`` holding the language embedding.
        action_dim: size of one action.
        action_pred_horizon: number of actions predicted per sample.
    """

    mlp_widths: Tuple[int, ...]
    image_embedding_size: int
    image_encoder_stages: Sequence[Tuple[str, tuple]]
    # No trailing commas: `= "language_instruction",` etc. made these defaults 1-tuples.
    language_key: Optional[str] = "language_instruction"
    action_dim: int = 7
    action_pred_horizon: int = 1

    @nn.compact
    def __call__(
        self,
        observations,
        tasks,
    ):
        """Predict actions.

        Args:
            observations: dict keyed by the observation keys in ``image_encoder_stages``;
                leading dims are (batch, window).
            tasks: dict holding ``language_key``. The ``[:, None, ...]`` + tile
                below expects a (batch, dim) embedding.

        Returns:
            Array of shape ``(batch, action_pred_horizon, action_dim)``.
        """
        b, w = observations[self.image_encoder_stages[0][0]].shape[:2]
        embeddings = []
        for observation_key, encoder_stages in self.image_encoder_stages:
            embedding = ViTResnet(num_layers=encoder_stages)(observations[observation_key])
            embedding = StdConv(self.image_embedding_size, (3, 3))(embedding)
            embedding = jnp.mean(embedding, axis=(-2, -3))  # GAP
            embeddings.append(embedding)

        # Repeat the task embedding over the window.
        lang = jnp.tile(tasks[self.language_key][:, None, ...], (1, w, 1))
        embeddings.append(lang)
        x = jnp.concatenate(embeddings, axis=-1)
        # Flatten (batch, window, features) -> (batch, window * features).
        # The old `jnp.reshape(b, -1)` discarded x and reshaped the batch size itself.
        x = jnp.reshape(x, (b, -1))
        for width in self.mlp_widths:
            # Without a nonlinearity the stacked Dense layers collapse to one linear map.
            x = nn.relu(nn.Dense(width)(x))
        x = nn.Dense(self.action_dim * self.action_pred_horizon)(x)
        return jnp.reshape(x, (-1, self.action_pred_horizon, self.action_dim))

    @classmethod
    def create(
        cls,
        mlp_widths: Tuple[int, ...],
        image_embedding_size: int,
        image_encoder_stages: Optional[Sequence[Tuple[str, tuple]]] = None,
        language_key: Optional[str] = "language_instruction",
        action_dim: int = 7,
        action_pred_horizon: int = 1,
    ) -> "ResnetModule":
        """Build a module; ``image_encoder_stages`` defaults to primary + wrist cameras, (2, 2, 2, 2) each."""
        if image_encoder_stages is None:
            image_encoder_stages = (("image_primary", (2, 2, 2, 2)), ("image_wrist", (2, 2, 2, 2)))

        return cls(
            mlp_widths=mlp_widths,
            image_embedding_size=image_embedding_size,
            image_encoder_stages=image_encoder_stages,
            language_key=language_key,
            action_dim=action_dim,
            action_pred_horizon=action_pred_horizon,
        )


@struct.dataclass
class ResnetModel:
    """A ``ResnetModule`` bundled with its parameters and the metadata needed to reload it.

    ``module``, ``text_processor`` and ``config`` are static (non-pytree) fields.
    ``params``, ``perturbations``, ``example_batch`` and ``dataset_statistics``
    are pytree leaves, so the whole object can be the ``self`` argument of the
    ``jax.jit``-ed ``run_resnet`` / ``sample_actions``.
    """

    module: ResnetModule = struct.field(pytree_node=False)
    text_processor: TextProcessor = struct.field(pytree_node=False)
    config: Config = struct.field(pytree_node=False)
    params: Params
    # None when the module defines no "perturbations" variable collection.
    perturbations: Perturbations
    # Reference batch: used for shape checks and to re-initialise the module on load.
    example_batch: Data
    dataset_statistics: Optional[Data]

    def create_tasks(
        self, texts: Sequence[str] = None
    ):
        """Build a task dict from language instructions.

        Args:
            texts: one instruction string per batch element.

        Returns:
            ``{"language_instruction": text_processor.encode(texts),
            "pad_mask_dict": {"language_instruction": all-True bool mask}}``.

        Raises:
            AssertionError: if no text processor is configured, or if the tasks'
                shapes do not match ``example_batch["task"]``.
            ValueError: if ``texts`` is None.
        """
        assert self.text_processor is not None
        if texts is None:
            raise ValueError("create_tasks requires a sequence of language instructions.")

        # The old code started from `{}` and then indexed tasks["pad_mask_dict"] (KeyError).
        tasks = {
            "language_instruction": self.text_processor.encode(texts),
            "pad_mask_dict": {"language_instruction": np.ones(len(texts), dtype=bool)},
        }

        _verify_shapes(tasks, "tasks", self.example_batch["task"], starting_dim=1)
        return tasks

    @jax.jit
    def run_resnet(
        self,
        observations: Data,
        tasks: Data,
    ):
        """Apply ``self.module`` with ``self.params`` to ``observations`` and ``tasks``.

        Shape verification against ``example_batch`` is currently disabled here.
        """
        return self.module.apply(
            {"params": self.params},
            observations,
            tasks,
        )

    @jax.jit
    def sample_actions(
        self,
        observations: Data,
        tasks: Data,
        unnormalization_statistics: Optional[Data] = None,
    ):
        """Predict actions and optionally undo action normalization.

        Args:
            observations: observation dict passed to the module.
            tasks: task dict passed to the module.
            unnormalization_statistics: optional dict with ``mean``, ``std`` and an
                optional boolean ``mask`` (all-True if absent). The action's last
                axis is truncated to ``len(mask)``. Entries where ``mask`` is True
                become ``action * std + mean``.

        Returns:
            The module's actions, un-normalized as described above.
        """
        action = self.run_resnet(
            observations, tasks
        )
        if unnormalization_statistics is not None:
            mask = unnormalization_statistics.get(
                "mask", jnp.ones_like(unnormalization_statistics["mean"], dtype=bool)
            )
            action = action[..., : len(mask)]
            action = jnp.where(
                mask,
                (action * unnormalization_statistics["std"])
                + unnormalization_statistics["mean"],
                action,
            )
        return action

    @classmethod
    def load_pretrained(
        cls,
        checkpoint_path: str,
        step: Optional[int] = None,
    ) -> "ResnetModel":
        """Loads a model from a checkpoint that was saved via `save_pretrained`.

        Args:
            checkpoint_path (str): A path to either a directory of checkpoints or a single checkpoint.
            step (int, optional): If multiple checkpoints are present, which one to load. Defaults to the latest.
        """

        # load config
        with tf.io.gfile.GFile(
            tf.io.gfile.join(checkpoint_path, "config.json"), "r"
        ) as f:
            config = json.load(f)

        # load example batch
        with tf.io.gfile.GFile(
            tf.io.gfile.join(checkpoint_path, "example_batch.msgpack"), "rb"
        ) as f:
            example_batch = flax.serialization.msgpack_restore(f.read())
        # shim for migrating from "tasks" to "task"
        if "tasks" in example_batch:
            example_batch["task"] = example_batch.pop("tasks")

        logging.debug(
            "Model was trained with observations: %s",
            flax.core.pretty_repr(
                jax.tree_util.tree_map(jnp.shape, example_batch["observation"])
            ),
        )
        logging.debug(
            "Model was trained with tasks: %s",
            flax.core.pretty_repr(jax.tree_util.tree_map(jnp.shape, example_batch["task"])),
        )

        # load dataset statistics
        with tf.io.gfile.GFile(
            tf.io.gfile.join(checkpoint_path, "dataset_statistics.json"), "r"
        ) as f:
            dataset_statistics = json.load(f)
            dataset_statistics = jax.tree_util.tree_map(
                np.array, dataset_statistics, is_leaf=lambda x: not isinstance(x, dict)
            )

        module = ResnetModule.create(**config["model"])

        init_args = (
            example_batch['observation'],
            example_batch['task'],
        )
        # Abstract init: infers variable shapes without doing any computation.
        variables_shape = jax.eval_shape(
            partial(module.init), jax.random.PRNGKey(0), *init_args
        )
        params_shape = variables_shape["params"]
        # ResnetModule may not create a "perturbations" collection. The old code indexed it
        # unconditionally (KeyError) and always paid for a concrete init to do so.
        if "perturbations" in variables_shape:
            perturbations = module.init(jax.random.PRNGKey(0), *init_args)["perturbations"]
        else:
            perturbations = None

        # restore params, checking to make sure the shape matches
        checkpointer = orbax.checkpoint.CheckpointManager(
            checkpoint_path, orbax.checkpoint.PyTreeCheckpointer()
        )
        step = step if step is not None else checkpointer.latest_step()
        params = checkpointer.restore(step, params_shape)

        if config["text_processor"] is not None:
            text_processor = ModuleSpec.instantiate(config["text_processor"])()
        else:
            text_processor = None

        return cls(
            module=module,
            params=params,
            perturbations=perturbations,
            text_processor=text_processor,
            example_batch=example_batch,
            config=config,
            dataset_statistics=dataset_statistics,
        )

    def save_pretrained(
        self,
        step: int,
        checkpoint_path: Optional[str] = None,
        checkpoint_manager: Optional[orbax.checkpoint.CheckpointManager] = None,
    ):
        """Saves params plus the metadata needed by `load_pretrained`.

        Exactly one of ``checkpoint_path`` / ``checkpoint_manager`` must be given.
        Saving from a bare ``checkpoint_path`` is currently disabled. Pass a
        pre-existing checkpoint manager, whose directory is also where config.json,
        example_batch.msgpack and dataset_statistics.json are written (once, by process 0).

        Args:
            step (int): Step number.
            checkpoint_path (str, optional): Path to save the checkpoint (currently unsupported alone).
            checkpoint_manager (optional): Checkpoint manager to save the checkpoint.

        Raises:
            ValueError: if both or neither of the two targets are given.
            NotImplementedError: if only ``checkpoint_path`` is given (a RuntimeError
                subclass, so existing RuntimeError handlers still catch it).
        """
        if (checkpoint_path is None) == (checkpoint_manager is None):
            raise ValueError(
                "Must provide exactly one of checkpoint_path or checkpoint_manager."
            )
        if checkpoint_manager is None:
            # checkpoint_manager = orbax.checkpoint.CheckpointManager(
            #     checkpoint_path, orbax.checkpoint.PyTreeCheckpointer()
            # )
            raise NotImplementedError(
                "ResnetModel.save_pretrained currently requires a checkpoint_manager; "
                "saving from a bare checkpoint_path is disabled."
            )
        if checkpoint_path is None:
            checkpoint_path = str(checkpoint_manager._directory)

        # save params
        checkpoint_manager.save(
            step,
            self.params,
            {"save_args": orbax_utils.save_args_from_target(self.params)},
        )

        if jax.process_index() == 0:
            # save config
            config_path = tf.io.gfile.join(checkpoint_path, "config.json")
            if not tf.io.gfile.exists(config_path):
                with tf.io.gfile.GFile(config_path, "w") as f:
                    json.dump(self.config, f)

            # save example batch
            example_batch_path = tf.io.gfile.join(
                checkpoint_path, "example_batch.msgpack"
            )
            if not tf.io.gfile.exists(example_batch_path):
                with tf.io.gfile.GFile(example_batch_path, "wb") as f:
                    f.write(flax.serialization.msgpack_serialize(self.example_batch))

            # save dataset statistics
            dataset_statistics_path = tf.io.gfile.join(
                checkpoint_path, "dataset_statistics.json"
            )
            if not tf.io.gfile.exists(dataset_statistics_path):
                with tf.io.gfile.GFile(dataset_statistics_path, "w") as f:
                    json.dump(
                        jax.tree_util.tree_map(lambda x: x.tolist(), self.dataset_statistics),
                        f,
                    )

    @classmethod
    def from_config(
        cls,
        config: Config,
        example_batch: Data,
        text_processor: Optional[Any] = None,
        verbose: bool = False,
        rng: Optional[PRNGKey] = None,
        dataset_statistics: Optional[Data] = None,
    ):
        """Initialise a fresh model from ``config["model"]`` using one element of ``example_batch``.

        Args:
            config: must contain ``"model"`` kwargs for ``ResnetModule.create``.
            example_batch: batch gathered across processes; only its first element is kept.
            text_processor: stored on the model for ``create_tasks``.
            verbose: print ``module.tabulate`` output.
            rng: init key (defaults to ``PRNGKey(0)``).
            dataset_statistics: stored on the model as-is.
        """
        module = ResnetModule.create(**config["model"])
        rng = rng if rng is not None else jax.random.PRNGKey(0)
        example_batch = multihost_utils.process_allgather(example_batch)
        example_batch = jax.tree_util.tree_map(lambda x: x[:1], example_batch)

        init_args = (
            example_batch['observation'],
            example_batch['task'],
        )

        if verbose:
            print(
                module.tabulate(rng, *init_args, verbose=True, depth=2)
            )  # Prints out the parameter count of our model

        @jax.jit
        def _init(rng):
            return module.init(rng, *init_args)

        variables = _init(rng)
        params = variables["params"]
        perturbations = variables.get('perturbations', None)

        return cls(
            module=module,
            params=params,
            perturbations=perturbations,
            text_processor=text_processor,
            example_batch=example_batch,
            config=config,
            dataset_statistics=dataset_statistics,
        )


def _verify_shapes(
    pytree,
    name: str,
    example_pytree,
    starting_dim: int = 0,
    strict: bool = False,
    raise_error: bool = True,
    silent: bool = False,
):
    weak_fail, fail = False, False
    pytree_flat = flax.traverse_util.flatten_dict(pytree)
    example_pytree_flat = flax.traverse_util.flatten_dict(example_pytree)

    # Check that all elements are present
    if set(pytree_flat.keys()) != set(example_pytree_flat.keys()):
        if not silent:
            extra = set(pytree_flat.keys()) - set(example_pytree_flat.keys())
            if extra:
                logging.warning(
                    "'%s' contains extra items compared to example_batch: %s",
                    name,
                    {"/".join(x) for x in extra},
                )
            missing = set(example_pytree_flat.keys()) - set(pytree_flat.keys())
            if missing:
                logging.warning(
                    "'%s' is missing items compared to example_batch: %s",
                    name,
                    {"/".join(x) for x in missing},
                )
        weak_fail = True

    mismatched_keys = {
        k: f"{pytree_flat[k].shape} != {example_pytree_flat[k].shape}"
        for k in pytree_flat
        if k in example_pytree_flat
        and pytree_flat[k].shape[starting_dim:]
        != example_pytree_flat[k].shape[starting_dim:]
    }
    if mismatched_keys:
        if not silent:
            logging.error(
                "'%s' contains mismatched shapes compared to example_batch: %s",
                name,
                flax.core.pretty_repr(
                    {"/".join(k): v for k, v in mismatched_keys.items()}
                ),
            )
        fail = True

    if raise_error and (fail or (weak_fail and strict)):
        raise AssertionError(f"{name} does not match example batch.")

    return weak_fail or fail




# Human-readable model description, filled in with str.format using the keys
# window_size, action_dim, action_horizon, observation_space and task_space.
SPEC_TEMPLATE = (
    "\n"
    "This model is trained with a window size of {window_size}, predicting "
    "{action_dim} dimensional actions {action_horizon} steps into the future.\n"
    "Observations and tasks conform to the following spec:\n"
    "\n"
    "Observations: {observation_space}\n"
    "Tasks: {task_space}\n"
    "\n"
    "At inference, you may pass in any subset of these observation and task keys, "
    "with a history window up to {window_size} timesteps.\n"
)


In [135]:
# Minimal model config: the "model" entry is forwarded as keyword arguments to
# ResnetModule.create by ResnetModel.from_config. `model` is used by later cells.
config = {
    "model": {
        "mlp_widths": (),
        "image_embedding_size": 512,
    }
}

model = ResnetModel.from_config(config, example_batch, text_processor)



In [90]:
# List every parameter path of the freshly initialised model.
# flatten_dict is already imported in the setup cell at the top of the notebook.
for param_path in flatten_dict(model.params):
    print(param_path)

('Dense_0', 'bias')
('Dense_0', 'kernel')
('StdConv_0', 'bias')
('StdConv_0', 'kernel')
('StdConv_1', 'bias')
('StdConv_1', 'kernel')
('ViTResnet_0', 'block1', 'unit1', 'conv1', 'kernel')
('ViTResnet_0', 'block1', 'unit1', 'conv2', 'kernel')
('ViTResnet_0', 'block1', 'unit1', 'conv3', 'kernel')
('ViTResnet_0', 'block1', 'unit1', 'conv_proj', 'kernel')
('ViTResnet_0', 'block1', 'unit1', 'gn1', 'bias')
('ViTResnet_0', 'block1', 'unit1', 'gn1', 'scale')
('ViTResnet_0', 'block1', 'unit1', 'gn2', 'bias')
('ViTResnet_0', 'block1', 'unit1', 'gn2', 'scale')
('ViTResnet_0', 'block1', 'unit1', 'gn3', 'bias')
('ViTResnet_0', 'block1', 'unit1', 'gn3', 'scale')
('ViTResnet_0', 'block1', 'unit1', 'gn_proj', 'bias')
('ViTResnet_0', 'block1', 'unit1', 'gn_proj', 'scale')
('ViTResnet_0', 'block1', 'unit2', 'conv1', 'kernel')
('ViTResnet_0', 'block1', 'unit2', 'conv2', 'kernel')
('ViTResnet_0', 'block1', 'unit2', 'conv3', 'kernel')
('ViTResnet_0', 'block1', 'unit2', 'gn1', 'bias')
('ViTResnet_0', 'block

In [138]:
# Run inference on the example batch; `out` is inspected in the next cell.
# (numpy is already imported as np in the setup cell.)
out = model.sample_actions(
    example_batch["observation"],
    example_batch["task"],
)

In [140]:
# Shape of the sampled actions; the recorded run printed (1, 1, 7).
# NOTE(review): presumably (batch, action_horizon, action_dim) — confirm against ResnetModule.
out.shape

(1, 1, 7)

In [7]:
# # MAX_KEY_LEN = 15
# # INDENT_SIZE = MAX_KEY_LEN + 4
# # INDENT = ''.join([' ' for _ in range(INDENT_SIZE)])
# # def recursive_dict_print(dictionary, prefix=""): 
# #     for key, val in dictionary.items(): 
# #         key = key[:MAX_KEY_LEN]
# #         if isinstance(val, dict): 
# #             print(f'{prefix}{key}')
# #             new_prefix = prefix + INDENT
# #             recursive_dict_print(val, new_prefix)
# #         else: 
# #             indent = ''.join([' ' for _ in range(INDENT_SIZE - len(key))])
# #             print(f'{prefix}{key}:{indent}{val.shape}')


# # initialize_compilation_cache()
# # devices = jax.devices()
# # if FLAGS.o_batch_size > 0: 
# #     FLAGS.config.batch_size = FLAGS.o_batch_size
# # if FLAGS.o_window_size > 0: 
# #     FLAGS.config.window_size = FLAGS.o_window_size

# # logging.info(
# #     f"""
# #     Octo Finetuning Script
# #     ======================
# #     Pretrained model: {FLAGS.config.pretrained_path}
# #     Finetuning Dataset: {FLAGS.config.dataset_kwargs.name}
# #     Data dir: {FLAGS.config.dataset_kwargs.data_dir}
# #     Task Modality: {FLAGS.config.modality}
# #     Finetuning Mode: {FLAGS.config.finetuning_mode}

# #     # Devices: {jax.device_count()}
# #     Batch size: {FLAGS.config.batch_size} ({FLAGS.config.batch_size // len(devices) } per device)
# #     # Steps: {FLAGS.config.num_steps}
# #     # Window size: {FLAGS.config.window_size}
# # """
# # )

# # #########
# # #
# # # Setup Jax Data Parallelism
# # #
# # #########

# # assert (
# #     FLAGS.config.batch_size % len(devices) == 0
# # ), f"Batch size ({FLAGS.config.batch_size}) must be divisible by the number of devices ({len(devices)})"
# # assert (
# #     FLAGS.config.viz_kwargs.eval_batch_size % len(devices) == 0
# # ), f"Eval batch size ({FLAGS.config.viz_kwargs.eval_batch_size}) must be divisible by the number of devices ({len(devices)})"

# # # create a 1D mesh with a single axis named "batch"
# # mesh = Mesh(jax.devices(), axis_names="batch")
# # # Our batches will be data-parallel sharded -- each device will get a slice of the batch
# # dp_sharding = NamedSharding(mesh, PartitionSpec("batch"))
# # # Our model will be replicated across devices (we are only doing data parallelism, not model parallelism)
# # replicated_sharding = NamedSharding(mesh, PartitionSpec())

# # # prevent tensorflow from using GPU memory since it's only used for data loading
# # tf.config.set_visible_devices([], "GPU")


# #########
# #
# # Setup Data Loader
# #
# #########

# # # create text processor
# if config["text_processor"] is None:
#     text_processor = None
# else:
#     text_processor = ModuleSpec.instantiate(config["text_processor"])()

# def process_batch(batch):
#     # batch = process_text(batch, text_processor)
#     del batch["dataset_name"]
#     return batch

# dataset = make_single_dataset(
#     FLAGS.config.dataset_kwargs,
#     traj_transform_kwargs=FLAGS.config.traj_transform_kwargs,
#     frame_transform_kwargs=FLAGS.config.frame_transform_kwargs,
#     train=True,
# )
# train_data_iter = (
#     dataset.repeat()
#     .unbatch()
#     .shuffle(FLAGS.config.shuffle_buffer_size)
#     .batch(FLAGS.config.batch_size)
#     .iterator()
# )
# train_data_iter = map(process_batch, train_data_iter)
# example_batch = next(train_data_iter)


# # print example batch 
# print("############################################")
# print('Example batch:')
# print('\n\n')
# recursive_dict_print(example_batch)
# # if FLAGS.debug: 
# #    from PIL import Image 
# #    img_digit_l = Image.fromarray(example_batch['observation']['image_digit_left'][0][0])
# #    img_digit_l.save('./debug_example_batch/img_digit_l.jpeg')

# #    img_primary = Image.fromarray(example_batch['observation']['image_primary'][0][0])
# #    img_primary.save('./debug_example_batch/img_primary.jpeg')

# #    img_wrist = Image.fromarray(example_batch['observation']['image_wrist'][0][0])
# #    img_wrist.save('./debug_example_batch/img_wrist.jpeg')

# print('\n\n')
# print("############################################")

Cause: Unable to locate the source code of <function _gcd_import at 0x7f64ff82b400>. Note that functions defined in certain environments, like the interactive Python shell, do not expose their source code. If that is the case, you should define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.experimental.do_not_convert. Original error: could not get source code


Cause: Unable to locate the source code of <function _gcd_import at 0x7f64ff82b400>. Note that functions defined in certain environments, like the interactive Python shell, do not expose their source code. If that is the case, you should define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.experimental.do_not_convert. Original error: could not get source code


Cause: Unable to locate the source code of <function _gcd_import at 0x7f64ff82b400>. Note that functions defined in certain environments, like the interactive Python shell, do not expose their source code. If that is the case, you should define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.experimental.do_not_convert. Original error: could not get source code


W0000 00:00:1717736599.767041 2128416 op_level_cost_estimator.cc:699] Error in PredictCost() for the op: op: "CropAndResize" attr { key: "T" value { type: DT_FLOAT } } attr { key: "extrapolation_value" value { f: 0 } } attr { key: "method" value { s: "bilinear" } } inputs { dtype: DT_FLOAT shape { dim { size: 1 } dim { size: 256 } dim { size: 256 } dim { size: -15 } } } inputs { dtype: DT_FLOAT shape { dim { size: -2 } dim { size: 4 } } } inputs { dtype: DT_INT32 shape { dim { size: -2 } } } inputs { dtype: DT_INT32 shape { dim { size: 2 } } } device { type: "CPU" vendor: "GenuineIntel" model: "101" frequency: 2300 num_cores: 64 environment { key: "cpu_instruction_set" value: "AVX SSE, SSE2, SSE3, SSSE3, SSE4.1, SSE4.2" } environment { key: "eigen" value: "3.4.90" } l1_cache_size: 32768 l2_cache_size: 1048576 l3_cache_size: 23068672 memory_size: 268435456 } outputs { dtype: DT_FLOAT shape { dim { size: -2 } dim { size: -24 } dim { size: -25 } dim { size: -15 } } }
W0000 00:00:171773659

############################################
Example batch:



observation
                   image_primary:      (256, 2, 256, 256, 3)
                   image_wrist:        (256, 2, 128, 128, 3)
                   timestep:           (256, 2)
                   pad_mask_dict
                                      image_primary:      (256, 2)
                                      image_wrist:        (256, 2)
                                      timestep:           (256, 2)
                   timestep_pad_ma:    (256, 2)
                   task_completed:     (256, 2, 4)
task
                   language_instru:    (256,)
                   pad_mask_dict
                                      language_instru:    (256,)
                                      image_primary:      (256,)
                                      image_wrist:        (256,)
                                      timestep:           (256,)
                   image_primary:      (256, 256, 256, 3)
                   i