In [1]:
import os
from collections import deque
from dataclasses import dataclass
from pathlib import Path

import draccus
import torch
import torch.distributed as dist
import tqdm
from accelerate import PartialState
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
from transformers.modeling_outputs import CausalLMOutputWithPast

import wandb
from prismatic.models.backbones.llm.prompting import PurePromptBuilder, VicunaV15ChatPromptBuilder
from prismatic.util.data_utils import PaddedCollatorForActionPrediction
from prismatic.vla.action_tokenizer import ActionTokenizer
from prismatic.vla.datasets.rlds.utils.data_utils import save_dataset_statistics

2024-08-30 20:54:48.180068: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-08-30 20:54:48.209138: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-30 20:54:48.209188: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-30 20:54:48.209934: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-08-30 20:54:48.215024: I tensorflow/core/platform/cpu_feature_guar

In [2]:
@dataclass
class FinetuneConfig:
    """Hyperparameters and paths for this OpenVLA fine-tuning notebook.

    Declared as a dataclass so that every annotated field below becomes a keyword
    argument of the generated ``__init__`` (e.g. ``FinetuneConfig(batch_size=8)``)
    and instances get a readable ``repr``. Calling ``FinetuneConfig()`` with no
    arguments yields exactly the defaults listed here.
    """

    # fmt: off
    vla_path: str = "openvla/openvla-7b"                            # Path to OpenVLA model (on HuggingFace Hub)

    # Directory Paths
    data_root_dir: Path = Path("/home/jellyho/tensorflow_datasets")        # Path to Open-X dataset directory
    dataset_name: str = "onearm_clean_joint_pos"                                # Name of fine-tuning dataset (e.g., `droid_wipe`)
    run_root_dir: Path = Path("/home/jellyho/openvla_run")                               # Path to directory to store logs & checkpoints
    adapter_tmp_dir: Path = Path("/home/jellyho/openvla_tmp")                     # Temporary directory for LoRA weights before fusing

    # Fine-tuning Parameters
    batch_size: int = 2                                       # Fine-tuning batch size
    max_steps: int = 200_000                                        # Max number of fine-tuning steps
    save_steps: int = 1000                                          # Interval for checkpoint saving
    learning_rate: float = 2e-5                                     # Fine-tuning learning rate
    grad_accumulation_steps: int = 4                                # Gradient accumulation steps
    image_aug: bool = False                                       # Whether to train with image augmentations
    shuffle_buffer_size: int = 100_000                              # Dataloader shuffle buffer size (can reduce if OOM)

    # LoRA Arguments
    use_lora: bool = True                                           # Whether to use LoRA fine-tuning
    lora_rank: int = 64                                           # Rank of LoRA weight matrix
    lora_dropout: float = 0.0                                       # Dropout applied to LoRA weights
    use_quantization: bool = False                                  # Whether to 4-bit quantize VLA for LoRA fine-tuning
                                                                    #   => CAUTION: Reduces memory but hurts performance

    # Tracking Parameters
    wandb_project: str = "openvla"                                  # Name of W&B project to log to (use default!)
    wandb_entity: str = "onearm_clean"                              # W&B entity value; NOTE(review): not used in the visible cells — confirm it is a valid entity

# Default configuration instance read by the cells below.
cfg = FinetuneConfig()

In [3]:
# Announce which checkpoint / dataset pair this run works with.
print(f"Fine-tuning OpenVLA Model `{cfg.vla_path}` on `{cfg.dataset_name}`")

# [Validate] A CUDA device is mandatory; afterwards create the accelerate process-state handle.
assert torch.cuda.is_available(), "Fine-tuning assumes at least one GPU is available!"
distributed_state = PartialState()

# The device index is hard-coded to GPU 0 here instead of being read from
# `distributed_state.local_process_index`; the CUDA cache is emptied right after selection.
device_id = 0
torch.cuda.set_device(device_id)
torch.cuda.empty_cache()

# Quantization Config =>> stays `None` unless 4-bit loading is requested (LoRA-only)
quantization_config = None
if cfg.use_quantization:
    assert cfg.use_lora, "Quantized training only supported for LoRA fine-tuning!"
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

# Load the OpenVLA processor via the HF AutoClass (only the processor is loaded in this cell)
processor = AutoProcessor.from_pretrained(cfg.vla_path, trust_remote_code=True)

# Build the action tokenizer from the processor's tokenizer
action_tokenizer = ActionTokenizer(processor.tokenizer)

Fine-tuning OpenVLA Model `openvla/openvla-7b` on `onearm_clean_joint_pos`


In [4]:
import importlib

import prismatic.vla.datasets

# Force a re-execution of the datasets package before importing names from it,
# so the two classes below are taken from the reloaded module object.
importlib.reload(prismatic.vla.datasets)

from prismatic.vla.datasets import RLDSBatchTransform, RLDSDataset

# Windowing settings; the same two values go to both the batch transform and the dataset.
window_size = 1
future_action_window_size = 29
windowing_kwargs = {
    "window_size": window_size,
    "future_action_window_size": future_action_window_size,
}

# Checkpoint paths containing "v01" get the Vicuna chat prompt builder; all others the pure one.
if "v01" in cfg.vla_path:
    prompt_builder_cls = VicunaV15ChatPromptBuilder
else:
    prompt_builder_cls = PurePromptBuilder

# Per-sample transform handed to the dataset.
batch_transform = RLDSBatchTransform(
    action_tokenizer,
    processor.tokenizer,
    image_transform=processor.image_processor.apply_transform,
    prompt_builder_fn=prompt_builder_cls,
    **windowing_kwargs,
)

# RLDS dataset for `cfg.dataset_name` under `cfg.data_root_dir`, resized to 224x224.
vla_dataset = RLDSDataset(
    cfg.data_root_dir,
    cfg.dataset_name,
    batch_transform,
    resize_resolution=(224, 224),
    shuffle_buffer_size=cfg.shuffle_buffer_size,
    image_aug=cfg.image_aug,
    **windowing_kwargs,
)

# Collator configured with the tokenizer's max length and pad token id, padding on the right.
collator = PaddedCollatorForActionPrediction(
    processor.tokenizer.model_max_length,
    processor.tokenizer.pad_token_id,
    padding_side="right",
)

dataloader = DataLoader(
    vla_dataset,
    batch_size=cfg.batch_size,
    collate_fn=collator,
    sampler=None,
    num_workers=0,  # Important =>> Set to 0 if using RLDS; TFDS rolls its own parallelism!
)

2024-08-30 20:55:02.508896: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization


2024-08-30 20:55:02.956838: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization



######################################################################################
# Loading the following 1 datasets (incl. sampling weight):                         #
######################################################################################



2024-08-30 20:55:03.397565: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization


In [5]:
# Pull exactly one collated batch for inspection. `next(iter(...))` consumes only the
# first batch; the check below fails in this cell (instead of a later NameError on `b`)
# if the loader produced nothing.
batch = next(iter(dataloader), None)
assert batch is not None, "`dataloader` yielded no batches; check the dataset configuration"
b = batch

In [6]:
# Display the `labels` tensor of the sampled batch; in the printed output the leading
# positions hold -100 and the remaining positions hold token ids.
b["labels"]

tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100, 31864, 31786, 31901, 31953, 31870,
         31785, 31744, 31864, 31786, 31901, 31953, 31870, 31785, 31744, 31864,
         31785, 31900, 31953, 31870, 31785, 31744, 31864, 31785, 31900, 31953,
         31871, 31785, 31744, 31864, 31784, 31899, 31953, 31871, 31785, 31744,
         31864, 31784, 31899, 31953, 31872, 31785, 31744, 31864, 31784, 31898,
         31953, 31872, 31785, 31744, 31864, 31783, 31897, 31953, 31873, 31785,
         31744, 31864, 31783, 31896, 31953, 31874, 31785, 31744, 31864, 31783,
         31896, 31953, 31876, 31785, 31744, 31864, 31782, 31895, 31953, 31877,
         31785, 31744, 31864, 31782, 31894, 31953, 31878, 31785, 31744, 31864,
         31781, 31894, 31953, 31879, 31785, 31744, 31864, 31781, 31893, 31953,
         31881, 31785, 31744, 31864, 31781, 31892, 3

In [7]:
# Display the `input_ids` tensor of the sampled batch (the full token sequence fed to the model).
b["input_ids"]

tensor([[    1,   512, 29901,  1724,  3158,   881,   278, 19964,  2125,   304,
          5839,   701,   278, 18002,   322,  1925,   372,   297,   278, 25972,
         29973,    13,  3744, 29901, 29871, 31864, 31786, 31901, 31953, 31870,
         31785, 31744, 31864, 31786, 31901, 31953, 31870, 31785, 31744, 31864,
         31785, 31900, 31953, 31870, 31785, 31744, 31864, 31785, 31900, 31953,
         31871, 31785, 31744, 31864, 31784, 31899, 31953, 31871, 31785, 31744,
         31864, 31784, 31899, 31953, 31872, 31785, 31744, 31864, 31784, 31898,
         31953, 31872, 31785, 31744, 31864, 31783, 31897, 31953, 31873, 31785,
         31744, 31864, 31783, 31896, 31953, 31874, 31785, 31744, 31864, 31783,
         31896, 31953, 31876, 31785, 31744, 31864, 31782, 31895, 31953, 31877,
         31785, 31744, 31864, 31782, 31894, 31953, 31878, 31785, 31744, 31864,
         31781, 31894, 31953, 31879, 31785, 31744, 31864, 31781, 31893, 31953,
         31881, 31785, 31744, 31864, 31781, 31892, 3