In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import flax 
import jax 
import optax 
import tensorflow as tf 
import tqdm 

from octo.data.utils.data_utils import NormalizationType
from octo.model.components.action_heads import L1ActionHead
from octo.model.components.tokenizers import LowdimObsTokenizer
from octo.model.octo_model import OctoModel
from octo.utils.jax_utils import initialize_compilation_cache
from octo.utils.spec import ModuleSpec
from octo.utils.train_utils import (
    freeze_weights,
    merge_params,
    process_text,
    TrainState,
)

2024-04-13 12:57:58.653481: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-13 12:57:58.653502: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-13 12:57:58.654401: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
  from .autonotebook import tqdm as notebook_tqdm


In [3]:
initialize_compilation_cache() 
tf.config.set_visible_devices([], "GPU")

Initialized persistent compilation cache at /home/nick/.jax_compilation_cache


In [4]:
pretrained_model = OctoModel.load_pretrained('hf://rail-berkeley/octo-small')
text_processor = pretrained_model.text_processor

Fetching 8 files: 100%|█████████████████████████████████████████████| 8/8 [00:00<00:00, 98400.09it/s]


In [5]:
import os 
os.getcwd()

'/home/nick/Documents/octo/octo/notebooks'

In [6]:
from octo.data.dataset import make_single_dataset 
dataset = make_single_dataset(
    dataset_kwargs = dict(
        name='cobot', 
        data_dir='../../tmp/cobot_from_directory', 
        image_obs_keys = {"primary" : "image_primary", "wrist": "image_wrist"}, 
        state_obs_keys=["proprio"], 
        action_proprio_normalization_type=NormalizationType.NORMAL, 
        absolute_action_mask=[True]*6, 
    ), 
    traj_transform_kwargs=dict(
        goal_relabeling_strategy='uniform',
        window_size=1, 
        future_action_window_size=1, 
    ), 
    frame_transform_kwargs=dict(
        resize_size={"primary": (256, 256), "wrist": (256, 256)}, 
    ), 
    train=True, 
)

DATASET BEFORE DLIMP
{'train': <_PrefetchDataset element_spec={'steps': DatasetSpec({'action': TensorSpec(shape=(6,), dtype=tf.float64, name=None), 'discount': TensorSpec(shape=(), dtype=tf.float64, name=None), 'is_first': TensorSpec(shape=(), dtype=tf.bool, name=None), 'is_last': TensorSpec(shape=(), dtype=tf.bool, name=None), 'is_terminal': TensorSpec(shape=(), dtype=tf.bool, name=None), 'observation': {'image_primary': TensorSpec(shape=(480, 640, 3), dtype=tf.uint8, name=None), 'image_wrist': TensorSpec(shape=(480, 640, 3), dtype=tf.uint8, name=None), 'proprio': TensorSpec(shape=(6,), dtype=tf.float64, name=None)}, 'reward': TensorSpec(shape=(), dtype=tf.float64, name=None)}, TensorShape([]))}>}


In [7]:
train_data_iter = (
    dataset.repeat()
    .unbatch()
    .shuffle(1000)
    .batch(20)
    .iterator()
)

In [8]:
# def process_batch(batch): 
#     batch = process_text(batch, None)
#     del batch["dataset_name"]
#     return batch

# train_data_iter = map(process_batch, train_data_iter)

In [9]:
example_batch = next(train_data_iter)

In [10]:
example_batch

{'observation': {'image_primary': array([[[[[ 77, 137, 165],
            [ 77, 137, 165],
            [ 75, 135, 163],
            ...,
            [ 51,  76,  75],
            [ 49,  76,  76],
            [ 43,  74,  82]],
  
           [[ 76, 136, 164],
            [ 77, 137, 165],
            [ 75, 135, 163],
            ...,
            [ 49,  75,  73],
            [ 45,  72,  74],
            [ 42,  76,  84]],
  
           [[ 75, 135, 163],
            [ 76, 136, 164],
            [ 75, 135, 163],
            ...,
            [ 48,  73,  72],
            [ 43,  71,  75],
            [ 38,  75,  86]],
  
           ...,
  
           [[111, 117, 103],
            [112, 118, 106],
            [113, 119, 107],
            ...,
            [  4,  23,  37],
            [ 24,  42,  55],
            [ 16,  34,  40]],
  
           [[111, 117, 103],
            [112, 118, 106],
            [113, 119, 107],
            ...,
            [ 15,  36,  50],
            [ 19,  37,  49],
       

In [11]:
# Load pre-training config and modify -> add proprio input, change action head 
config = pretrained_model.config 
config["model"]["observation_tokenizers"]["proprio"] = ModuleSpec.create(
    LowdimObsTokenizer, 
    n_bins = 256, 
    bin_type = "normal", 
    low = -2.0, 
    high = 2.0, 
    obs_keys = ["proprio"]
)

In [12]:
# Fully override the old action head with a new one 
config["model"]["heads"]["action"] = ModuleSpec.create(
    L1ActionHead, 
    pred_horizon=1, 
    action_dim=6, 
    readout_key="readout_action", 
)

In [13]:
# Initialize weights for the modified octo model 
model = OctoModel.from_config(
    config, 
    example_batch, 
    text_processor, 
    verbose=True, 
    dataset_statistics=dataset.dataset_statistics, 
)

    task_*: <AttentionRule.CAUSAL: 'other.timestep <= self.timestep'>,
    obs_*: <AttentionRule.CAUSAL: 'other.timestep <= self.timestep'>,
})
    task_*: <AttentionRule.CAUSAL: 'other.timestep <= self.timestep'>,
    obs_*: <AttentionRule.CAUSAL: 'other.timestep <= self.timestep'>,
})
    task_*: <AttentionRule.CAUSAL: 'other.timestep <= self.timestep'>,
    obs_*: <AttentionRule.CAUSAL: 'other.timestep <= self.timestep'>,
})
    task_*: <AttentionRule.CAUSAL: 'other.timestep <= self.timestep'>,
    obs_*: <AttentionRule.CAUSAL: 'other.timestep <= self.timestep'>,
    readout_action: <AttentionRule.CAUSAL: 'other.timestep <= self.timestep'>,
})



[3m                               OctoModule Summary                               [0m
┏━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓
┃[1m [0m[1mpath         [0m[1m [0m┃[1m [0m[1mmodule       [0m[1m [0m┃[1m [0m[1minputs       [0m[1m [0m┃[1m [0m[1moutputs      [0m[1m [0m┃[1m [0m[1mparams      [0m[1m [0m┃
┡━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩
│               │ OctoModule    │ -             │ - obs:        │              │
│               │               │ image_primar… │     mask:     │              │
│               │               │ [2muint8[0m[1,1,25… │ [2mbool[0m[1,1,518] │              │
│               │               │   image_wris… │     tokens:   │              │
│               │               │ [2muint8[0m[1,1,25… │ [2mfloat32[0m[1,1,… │              │
│               │               │   pad_mask:   │   obs_primar… │              │
│               │            



In [14]:
merged_params = merge_params(model.params, pretrained_model.params) 
model = model.replace(params=merged_params) 

In [20]:
del pretrained_model

In [22]:
# Create optimizer and train state, optionally freeze keys for pre-trained transformer 
learning_rate = optax.join_schedules(
    [optax.linear_schedule(0, 3e-5, 100), optax.constant_schedule(3e-4)], [100]
)
tx = optax.adamw(learning_rate) 
frozen_keys = model.config["optimizer"]["frozen_keys"]
freeze_transformer = False 
if freeze_transformer: 
    frozen_keys.append("BlockTransformer_0")
tx = freeze_weights(tx, model.params, frozen_keys)

In [23]:
train_state = TrainState.create(
    rng=jax.random.PRNGKey(1234), 
    model=model, 
    tx=tx,
)

In [24]:
# Define loss function and train step 
def loss_fn(params, batch, rng, train=True):
    bound_module = model.module.bind({"params": params}, rngs={"dropout": rng})
    transformer_embeddings = bound_module.octo_transformer(
        batch["observation"], 
        batch["task"], 
        batch["observation"]["pad_mask"], 
        train=train, 
    )
    action_loss, action_metrics = bound_module.heads["action"].loss(
        transformer_embeddings, 
        batch["action"], 
        pad_mask=batch["observation"]["pad_mask"], 
        train=train, 
    )
    return action_loss, action_metrics


In [25]:
def train_step(state, batch):
    rng, dropout_rng = jax.random.split(state.rng) 
    (loss, info), grads = jax.value_and_grad(loss_fn, has_aux=True)(
        state.model.params, batch, dropout_rng, train=True
    )
    new_state = state.apply_gradients(grads=grads, rng=rng)
    return new_state, info 

In [26]:
# Run finetuning loop 
print(f"Starting finetuning") 
for i in tqdm.tqdm(range(5000), total=5000, dynamic_ncols=True): 
    batch = next(train_data_iter) 
    train_state, update_info = train_step(train_state, batch) 

    update_info = jax.device_get(update_info)
    
    print(f"Training on epoch {i}") 
    print(f"With update info: {update_info}")

Starting finetuning


  0%|                                              | 0/5000 [00:00<?, ?it/s]


TypeError: Dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.