In [1]:
from absl import app, flags, logging
import flax
import jax
import optax
import tensorflow as tf
import tqdm
import wandb

from octo.data.dataset import make_single_dataset
from octo.model.components.action_heads import L1ActionHead
from octo.model.components.tokenizers import LowdimObsTokenizer
from octo.model.octo_model import OctoModel
from octo.utils.jax_utils import initialize_compilation_cache
from octo.utils.spec import ModuleSpec
from octo.utils.train_utils import (
    freeze_weights,
    merge_params,
    process_text,
    TrainState,
)







In [12]:
for name in list(flags.FLAGS):
  delattr(flags.FLAGS, name)

In [13]:
DATA_PATH = "C:/Users/willi/tensorflow_datasets/"    # UPDATE WITH PATH TO RLDS DATASETS
EXP_LOG_PATH = "C:/workspace/deligrasp_policy_learning/logs/octo" # UPDATE WITH PATH TO DESIRED LOGGING DIRECTORY
OCTO_CKPT_SMALL = "C:/Users/willi/.cache/huggingface/hub/models--rail-berkeley--octo-small-1.5/snapshots/dc9aa3019f764726c770814b27e4ab0fc6e32a58"
OCTO_CKPT_BASE = "C:/Users/willi/.cache/huggingface/hub/models--rail-berkeley--octo-base-1.5/snapshots/ee3c10e8edd6ce2e8b1e8744d3c6fba4097bed48"
FLAGS = flags.FLAGS

flags.DEFINE_string("f", "", "notebook path hack.")
flags.DEFINE_string(
    "pretrained_path",OCTO_CKPT_SMALL, "Path to pre-trained Octo checkpoint directory."
)
flags.DEFINE_string("data_dir", DATA_PATH, "Path to finetuning dataset, in RLDS format.")
flags.DEFINE_string("save_dir", EXP_LOG_PATH, "Directory for saving finetuning checkpoints.")
flags.DEFINE_integer("batch_size", 16, "Batch size for finetuning.")
flags.DEFINE_bool(
    "freeze_transformer",
    True,
    "Whether pre-trained transformer weights should be frozen.",
)
import sys
FLAGS(sys.argv)

['C:\\Users\\willi\\AppData\\Roaming\\Python\\Python310\\site-packages\\ipykernel_launcher.py']

In [29]:
# setup wandb for logging
wandb.init(name="octo_sm_dg", project="jaf")

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mbadinkajink[0m ([33mcorrelllab[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [9]:
# main training loop
initialize_compilation_cache()
# prevent tensorflow from using GPU memory since it's only used for data loading
tf.config.set_visible_devices([], "GPU")

# load pre-trained model
logging.info("Loading pre-trained model...")
pretrained_model = OctoModel.load_pretrained(FLAGS.pretrained_path)


For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


In [16]:
# make finetuning dataset
# apply Gaussian normalization, load chunks of 50 actions since we'll train with action chunking
# delete goal images in the data loader since we will train a language-conditioned-only policy
# TODO: directly load this from raw data to make it less opaque?
logging.info("Loading finetuning dataset...")
dataset = make_single_dataset(
    dataset_kwargs=dict(
        name="deligrasp_dataset",
        data_dir=FLAGS.data_dir,
        image_obs_keys={"primary": "image", "wrist": "wrist_image"},
        proprio_obs_key="state",
        language_key="language_instruction",
    ),
    traj_transform_kwargs=dict(
        window_size=2,
        action_horizon=15,
    ),
    frame_transform_kwargs=dict(
        resize_size={"primary": (256, 256), "wrist": (128, 128)},
    ),
    train=True,
)
train_data_iter = (
    dataset.repeat()
    .unbatch()
    .shuffle(100)  # can reduce this if RAM consumption too high
    .batch(FLAGS.batch_size)
    .iterator()
)


In [17]:
iterator = dataset.iterator()
traj = next(iterator)
print("Top-level keys: ", traj.keys())
print("Observation keys: ", traj["observation"].keys())
print("Task keys: ", traj["task"].keys())

Top-level keys:  dict_keys(['observation', 'task', 'action', 'dataset_name', 'action_pad_mask'])
Observation keys:  dict_keys(['image_primary', 'image_wrist', 'proprio', 'timestep', 'pad_mask_dict', 'timestep_pad_mask', 'task_completed'])
Task keys:  dict_keys(['language_instruction', 'pad_mask_dict'])


In [28]:
text_processor = pretrained_model.text_processor

def process_batch(batch):
    batch = process_text(batch, text_processor)
    del batch["dataset_name"]
    return batch

train_data_iter = map(process_batch, train_data_iter)
example_batch = next(train_data_iter)


In [27]:
# load pre-training config and modify --> remove wrist cam, add proprio input, change action head
# following Zhao et al. we use "action chunks" of length 50 and L1 loss for ALOHA
config = pretrained_model.config
config["model"]["heads"]["action"]["kwargs"]["action_dim"] = 8


In [30]:
# initialize weights for modified Octo model, then merge in all applicable pre-trained weights
# new position encodings for proprio inputs & weights for new action head will remain "from scratch"
logging.info("Updating model for new observation & action space...")
model = OctoModel.from_config(
    config,
    example_batch,
    text_processor,
    verbose=True,
    dataset_statistics=dataset.dataset_statistics,
)
merged_params = merge_params(model.params, pretrained_model.params)
# can perform any additional parameter surgery here...
# ...
model = model.replace(params=merged_params)
del pretrained_model

# create optimizer & train_state, optionally freeze keys for pre-trained transformer
# train_state bundles parameters & optimizers
learning_rate = optax.join_schedules(
    [optax.linear_schedule(0, 3e-5, 100), optax.constant_schedule(3e-5)], [100]
)
tx = optax.adamw(learning_rate)
frozen_keys = model.config["optimizer"]["frozen_keys"]
if FLAGS.freeze_transformer:
    frozen_keys.append("BlockTransformer_0")
tx = freeze_weights(tx, model.params, frozen_keys)
train_state = TrainState.create(
    rng=jax.random.PRNGKey(1234),
    model=model,
    tx=tx,
)


    task_*: <AttentionRule.CAUSAL: 'other.timestep <= self.timestep'>,
})
    task_*: <AttentionRule.CAUSAL: 'other.timestep <= self.timestep'>,
    obs_*: <AttentionRule.CAUSAL: 'other.timestep <= self.timestep'>,
})
    task_*: <AttentionRule.CAUSAL: 'other.timestep <= self.timestep'>,
    obs_*: <AttentionRule.CAUSAL: 'other.timestep <= self.timestep'>,
})
    task_*: <AttentionRule.CAUSAL: 'other.timestep <= self.timestep'>,
    obs_*: <AttentionRule.CAUSAL: 'other.timestep <= self.timestep'>,
})
    task_*: <AttentionRule.CAUSAL: 'other.timestep <= self.timestep'>,
    obs_*: <AttentionRule.CAUSAL: 'other.timestep <= self.timestep'>,
    readout_action: <AttentionRule.CAUSAL: 'other.timestep <= self.timestep'>,
})



[3m                              OctoModule Summary                               [0m
┌───────────────┬───────────────┬──────────────┬───────────────┬──────────────┐
│[1m [0m[1mpath         [0m[1m [0m│[1m [0m[1mmodule       [0m[1m [0m│[1m [0m[1minputs      [0m[1m [0m│[1m [0m[1moutputs      [0m[1m [0m│[1m [0m[1mparams      [0m[1m [0m│
├───────────────┼───────────────┼──────────────┼───────────────┼──────────────┤
│               │ OctoModule    │ -            │ - obs:        │              │
│               │               │ image_prima… │     mask:     │              │
│               │               │ [2muint8[0m[1,2,2… │ [2mbool[0m[1,2,336] │              │
│               │               │   image_wri… │     tokens:   │              │
│               │               │ [2muint8[0m[1,2,1… │ [2mfloat32[0m[1,2,… │              │
│               │               │   pad_mask_… │   obs_primar… │              │
│               │               │     i

In [32]:
# define loss function and train step
def loss_fn(params, batch, rng, train=True):
    bound_module = model.module.bind({"params": params}, rngs={"dropout": rng})
    transformer_embeddings = bound_module.octo_transformer(
        batch["observation"],
        batch["task"],
        batch["observation"]["timestep_pad_mask"],
        train=train,
    )
    action_loss, action_metrics = bound_module.heads["action"].loss(
        transformer_embeddings,  # Action head knows to pull out the action readout_key
        batch["action"],
        batch["observation"]["timestep_pad_mask"],
        batch["action_pad_mask"],
        train=train,
    )
    return action_loss, action_metrics

@jax.jit
def train_step(state, batch):
    rng, dropout_rng = jax.random.split(state.rng)
    (loss, info), grads = jax.value_and_grad(loss_fn, has_aux=True)(
        state.model.params, batch, dropout_rng, train=True
    )
    new_state = state.apply_gradients(grads=grads, rng=rng)
    return new_state, info


In [33]:
n_epochs = 30
n_steps = 100
total_steps = n_epochs * n_steps
save_every_n_epochs = 10
save_every_n_steps = save_every_n_epochs * n_steps
logging.info("Starting finetuning...")
for i in tqdm.tqdm(range(total_steps), total=total_steps, dynamic_ncols=True):
    batch = next(train_data_iter)
    train_state, update_info = train_step(train_state, batch)
    if (i + 1) % n_steps == 0:
        update_info = jax.device_get(update_info)
        wandb.log(
            flax.traverse_util.flatten_dict({"training": update_info}, sep="/"),
            step=i,
        )
    if (i + 1) % save_every_n_steps == 0:
        # save checkpoint
        train_state.model.save_pretrained(step=i, checkpoint_path=FLAGS.save_dir)


  0%|          | 0/3000 [00:15<?, ?it/s]


ScopeParamShapeError: Initializer expected to generate shape (448, 256) but got shape (551, 256) instead for parameter "kernel" in "/heads_action/diffusion_model/reverse_network/Dense_0". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)