In [1]:
import os

# Restrict PyTorch to only see GPU 0. The variable is set ahead of the torch
# import so the restriction is already in place when CUDA gets initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch

# Report the visible accelerators (if any) and pick the compute device.
if not torch.cuda.is_available():
    print("CUDA is not available, using CPU.")
    device = torch.device("cpu")
else:
    gpu_count = torch.cuda.device_count()
    print(f"Number of GPUs available: {gpu_count}")
    for gpu_index in range(gpu_count):
        print(f"GPU {gpu_index}: {torch.cuda.get_device_name(gpu_index)}")
    device = torch.device("cuda:0")
print(f"Using device: {device}")

Number of GPUs available: 1
GPU 0: NVIDIA L40S
Using device: cuda:0


In [2]:
from spatialvla.datasets import RLDSBatchTransform, RLDSDataset

2024-12-11 10:35:50.174649: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-12-11 10:35:50.204144: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-12-11 10:35:50.204193: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-12-11 10:35:50.205024: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-11 10:35:50.210359: I tensorflow/core/platform/cpu_feature_guar

In [26]:
from spatialvla.mobilevlm.model.mobilevlm import load_pretrained_vlm_for_vla, load_vla
from scripts.spatialvla_config import ModelArguments, TrainingArguments
import transformers

# Start from the default model arguments and switch on state input.
model_args = ModelArguments()
model_args.use_state_input = True

# Loader flags: both the 8-bit and 4-bit loading switches are off and the
# target device string is 'cuda'.
load_options = dict(load_8bit=False, load_4bit=False, device='cuda')

# The loader returns four values; the last one is not needed here.
tokenizer, model, image_processor, _ = load_pretrained_vlm_for_vla(model_args, **load_options)

# Alternative (disabled): restore a previously saved VLA checkpoint instead.
# tokenizer, model, image_processor, _ = load_vla('/home/jellyho/Bimanual_Imitation/MobileVLM-VLA/checkpoints/libero_object_octo_full')

Loading with torch.bfloat16


You are using a model of type mobilevlm to instantiate a model of type spatialvla. This is not supported for all configurations of models and can yield errors.
Some weights of SpatialVLAForCausalLM were not initialized from the model checkpoint at remyxai/SpaceLLaVA-lite and are newly initialized: ['action_head.diffusion_model.reverse_network.layers.1.dense2.weight', 'action_head.diffusion_model.reverse_network.layers.1.dense_residual.weight', 'action_head.diffusion_model.reverse_network.layers.2.dense2.weight', 'action_head.diffusion_model.reverse_network.layers.2.dense1.bias', 'action_head.diffusion_model.cond_encoder.mlp.1.bias', 'action_head.diffusion_model.reverse_network.out_dense.weight', 'action_head.diffusion_model.reverse_network.layers.2.layer_norm.weight', 'action_head.diffusion_model.reverse_network.layers.0.dense2.bias', 'action_head.diffusion_model.reverse_network.layers.0.dense2.weight', 'action_head.diffusion_model.reverse_network.layers.0.dense1.weight', 'action_head.

In [27]:
## RLDS Dataset loading
# Build the batch transform from the tokenizer and image processor that the
# model-loading cell produced.
batch_transform = RLDSBatchTransform(tokenizer, image_processor)
# Init complete

In [28]:
# Instantiate TrainingArguments with no overrides (constructor defaults only).
cfg = TrainingArguments()

In [29]:
# Construct the RLDS dataset for the 'libero_object_no_noops' mix, applying
# `batch_transform` to its samples.
vla_dataset = RLDSDataset(
    data_root_dir='/home/shared/rlds_datasets',
    data_mix='libero_object_no_noops',
    batch_transform=batch_transform,
    shuffle_buffer_size=100,
    # Single-step window with no future actions.
    window_size=1,
    future_action_window_size=0,
    # Same value as `model_args.use_state_input` set in the model-loading cell.
    use_state_input=True,
)

{'name': 'libero_object_no_noops', 'data_dir': '/home/shared/rlds_datasets', 'image_obs_keys': {'primary': 'image'}, 'state_obs_keys': ['EEF_state', 'gripper_state'], 'absolute_action_mask': [False, False, False, False, False, False, True], 'action_normalization_mask': [True, True, True, True, True, True, False], 'action_proprio_normalization_type': <NormalizationType.NORMAL: 'normal'>, 'language_key': 'language_instruction', 'standardize_fn': <function libero_dataset_transform at 0x148fb0cee440>}
['EEF_state', 'gripper_state']


2024-12-11 10:40:41.314177: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2024-12-11 10:40:41.435512: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization


['EEF_state', 'gripper_state']

######################################################################################
# Loading the following 1 datasets (incl. sampling weight):                         #
######################################################################################

Threads per Dataset:  [1]
Reads per Dataset:  [1]
Constructing datasets...
['EEF_state', 'gripper_state']
Applying frame transforms on dataset...


2024-12-11 10:40:41.572008: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization


In [30]:
from spatialvla.datasets.rlds.utils.data_utils import PaddedCollatorForActionPrediction

# Collator configured from the tokenizer's maximum length and pad token id,
# padding on the right-hand side.
collator = PaddedCollatorForActionPrediction(
    tokenizer.model_max_length,
    tokenizer.pad_token_id,
    padding_side='right',
)

In [31]:
from torch.utils.data import DataLoader
# Wrap the RLDS dataset in a PyTorch DataLoader that yields collated batches
# of 128 samples.
dataloader = DataLoader(
    vla_dataset,
    batch_size=128,
    sampler=None,
    collate_fn=collator,
    # Important =>> Set to 0 if using RLDS; TFDS rolls its own parallelism!
    num_workers=0,
)

In [32]:
# Pull batches until one has at least one zero/False entry in its attention
# mask. `batch` stays bound to that batch, or to the last batch seen if the
# loader runs out first.
for d in dataloader:
    batch = d
    if (d['attention_mask'] == 0).any():
        break

In [34]:
# Display the `use_state_input` flag stored on the loaded model's config.
model.config.use_state_input

True

In [22]:
# CUDA device index that the batch tensors are moved to.
device_id = 0

# Values not taken from the batch. `use_cache` is not consumed in this cell.
use_cache = True
past_key_values = None
labels = None

# NOTE(review): unlike the three tensors below, `states` is not moved to
# `device_id` here - confirm that this is intentional.
states = batch['proprio']

# Move the token ids, attention mask and image tensor onto the target device.
input_ids = batch['input_ids'].to(device_id)
attention_mask = batch['attention_mask'].to(device_id)
images = batch['pixel_values'].to(device_id)

# Run the model's multimodal input preparation without gradient tracking and
# under bfloat16 autocast. The call's results rebind input_ids,
# attention_mask, past_key_values and labels, and introduce inputs_embeds.
with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
    (
        input_ids,
        attention_mask,
        past_key_values,
        inputs_embeds,
        labels,
    ) = model.prepare_inputs_labels_for_multimodal(
        input_ids, attention_mask, past_key_values, labels, images
    )

In [20]:
# Display the shape of whichever `attention_mask` is currently bound in the
# notebook namespace.
attention_mask.shape

torch.Size([128, 209])

In [23]:
# Display `states` (bound to batch['proprio'] in the input-preparation cell).
states