In [1]:
import os

# Set the visibility mask before torch is imported so the restriction to GPU 0
# is already in place when CUDA is first queried.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch

# Query CUDA availability once and reuse the answer below.
cuda_ok = torch.cuda.is_available()
if not cuda_ok:
    print("CUDA is not available, using CPU.")
else:
    gpu_count = torch.cuda.device_count()
    print(f"Number of GPUs available: {gpu_count}")
    for gpu_idx in range(gpu_count):
        print(f"GPU {gpu_idx}: {torch.cuda.get_device_name(gpu_idx)}")

# Fall back to the CPU when no CUDA device is visible.
device = torch.device("cuda:0") if cuda_ok else torch.device("cpu")
print(f"Using device: {device}")

Number of GPUs available: 1
GPU 0: NVIDIA L40S
Using device: cuda:0


In [2]:
from spatialvla.datasets import RLDSBatchTransform, RLDSDataset
from spatialvla.mobilevlm.model.mobilevlm import load_pretrained_vlm_for_vla, load_vla
from scripts.spatialvla_config import ModelArguments, TrainingArguments
import transformers
from spatialvla.datasets.rlds.utils.data_utils import PaddedCollatorForActionPrediction
from torch.utils.data import DataLoader
from spatialvla.mobilevlm.action_tokenizer import ActionTokenizer

2024-12-24 15:48:47.143379: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-12-24 15:48:47.172918: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-12-24 15:48:47.172960: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-12-24 15:48:47.173710: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-24 15:48:47.179253: I tensorflow/core/platform/cpu_feature_guar

[2024-12-24 15:48:55,755] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [3]:
# Model configuration with whatever defaults ModelArguments defines; the
# state-input override below is left disabled.
model_args = ModelArguments()
# model_args.use_state_input = True

# Alternative loader (disabled): restore a fine-tuned VLA checkpoint instead
# of starting from the pretrained VLM.
# tokenizer, model, image_processor, _ = load_vla(
#     'checkpoints/pick_tape_single_br_v7_ema',
#     load_8bit=False, 
#     load_4bit=False,
#     device='cuda',
# )

# Load the pretrained VLM un-quantised, on the GPU, in bfloat16.
vlm_load_kwargs = dict(
    load_8bit=False,
    load_4bit=False,
    device='cuda',
    dtype=torch.bfloat16,
)
tokenizer, model, image_processor, _ = load_pretrained_vlm_for_vla(model_args, **vlm_load_kwargs)

# Action tokenizer built on top of the text tokenizer; later cells use it to
# detokenize action tokens and to locate the action-token id range.
at = ActionTokenizer(tokenizer)

# The batch transform and the dataset must agree on the windowing, so the
# values are written once and shared.
window_kwargs = dict(window_size=1, future_action_window_size=7)

batch_transform = RLDSBatchTransform(
    tokenizer,
    image_processor,
    use_state_input=False,
    action_tokenizer=at,
    **window_kwargs,
)

vla_dataset = RLDSDataset(
    data_root_dir='/home/shared/vla_benchmark_rlds',
    data_mix='bm_pick_tape_single',
    batch_transform=batch_transform,
    shuffle_buffer_size=100,
    train=True,
    use_state_input=False,
    **window_kwargs,
)

# Collator configured with right-side padding up to the tokenizer's maximum
# length; labels are included and state input is disabled.
collator = PaddedCollatorForActionPrediction(
    tokenizer.model_max_length,
    tokenizer.pad_token_id,
    padding_side='right',
    use_state_input=False,
    use_label=True,
)

Loading with torch.bfloat16


You are using a model of type mobilevlm to instantiate a model of type spatialvla. This is not supported for all configurations of models and can yield errors.


number of parameters: 3.688986e+07
number of parameters: 3.688986e+07
number of parameters: 3.688986e+07


Some weights of SpatialVLAForCausalLM were not initialized from the model checkpoint at remyxai/SpaceLLaVA-lite and are newly initialized: ['si.net.s_net.mid_modules.0.blocks.0.block.0.weight', 'si.net.b_net.down_modules.2.1.blocks.0.block.0.weight', 'si.net.s_net.down_modules.0.1.blocks.0.block.0.weight', 'si.net.s_net.mid_modules.1.blocks.0.block.1.bias', 'si.net.s_net.final_conv.0.block.1.weight', 'si.net.v_net.up_modules.1.1.blocks.0.block.0.weight', 'si.net.v_net.up_modules.1.1.cond_encoder.1.weight', 'condition_projector.map_head.p', 'si.net.v_net.down_modules.1.2.conv.weight', 'si.net.s_net.mid_modules.0.blocks.1.block.0.weight', 'si.net.b_net.diffusion_step_encoder.3.weight', 'si.net.b_net.down_modules.2.1.blocks.0.block.1.bias', 'si.net.s_net.up_modules.1.1.blocks.0.block.1.bias', 'si.net.b_net.down_modules.1.0.cond_encoder.1.bias', 'si.net.b_net.down_modules.0.1.blocks.1.block.0.bias', 'si.net.s_net.up_modules.1.1.cond_encoder.1.bias', 'si.net.v_net.up_modules.1.0.blocks.0.bl

{'name': 'bm_pick_tape_single', 'data_dir': '/home/shared/vla_benchmark_rlds', 'image_obs_keys': {'primary': 'image'}, 'absolute_action_mask': [False, False, False, False, False, False, True], 'action_normalization_mask': [True, True, True, True, True, True, False], 'action_proprio_normalization_type': <NormalizationType.NORMAL: 'normal'>, 'language_key': 'language_instruction', 'standardize_fn': <function lg_delta_ee_transform at 0x14b13637beb0>}


2024-12-24 15:49:13.779675: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2024-12-24 15:49:14.506331: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization



######################################################################################
# Loading the following 1 datasets (incl. sampling weight):                         #
######################################################################################

Threads per Dataset:  [1]
Reads per Dataset:  [1]
Constructing datasets...


2024-12-24 15:49:14.923645: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization


Applying frame transforms on dataset...


In [4]:
# Print every state-dict key whose name contains the substring 'ema'.
ema_keys = (name for name in model.state_dict() if 'ema' in name)
for ema_key in ema_keys:
    print(ema_key)

In [5]:
# Important: keep num_workers at 0 when using RLDS, because TFDS rolls its own
# parallelism.
loader_kwargs = dict(
    batch_size=16,
    sampler=None,
    collate_fn=collator,
    num_workers=0,
)
dataloader = DataLoader(vla_dataset, **loader_kwargs)

In [6]:
# Instantiate the training arguments with their defaults.
# NOTE(review): `cfg` is not referenced by any other cell visible in this
# section; confirm it is still needed.
cfg = TrainingArguments()

In [7]:
# Take only the first batch the dataloader yields.
batch = next(iter(dataloader))

# Integer CUDA device index used by every `.to(...)` call in the notebook.
device_id = 0

# Move the tensors needed by later cells onto the GPU in one pass.
input_ids, images, attention_mask, actionss, labels = (
    batch[key].to(device_id)
    for key in ('input_ids', 'pixel_values', 'attention_mask', 'action', 'labels')
)
# states=batch['proprio'].to(device_id)

# Generation-related placeholders kept as notebook globals.
use_cache = True
past_key_values = None

In [8]:
# Forward pass on the current batch under bfloat16 autocast.
#
# `loss` holds the full model output object, not a scalar: later cells read
# `loss.loss` (a tuple of loss terms) and `loss.logits` from it.
#
# Labels are only passed for the 'BR' head type.
is_br_head = model.config.head_args['head_type'] == 'BR'

with torch.autocast('cuda', dtype=torch.bfloat16):
    # Call the module instance rather than `.forward` so that any registered
    # forward hooks are honoured.
    loss = model(
        input_ids=batch['input_ids'].to(device_id),
        images=batch['pixel_values'].to(device_id),
        attention_mask=batch['attention_mask'].to(device_id),
        actions=batch['action'].to(device_id),
        states=None,
        # Labels must live on the same device as the other inputs.
        labels=batch['labels'].to(device_id) if is_br_head else None,
    )

In [9]:
# Display the tuple of loss terms attached to the forward-pass output.
loss.loss

(tensor(-2.2192, device='cuda:0', grad_fn=<AddBackward0>),
 tensor(0.0812, device='cuda:0'),
 tensor(-22.8653, device='cuda:0', grad_fn=<MeanBackward0>),
 tensor(0.0003, device='cuda:0', grad_fn=<MeanBackward0>),
 tensor(-23.1423, device='cuda:0', grad_fn=<MeanBackward0>),
 tensor(-46.0073, device='cuda:0', grad_fn=<AddBackward0>))

In [9]:
import sys
import torch
import argparse
from PIL import Image
from pathlib import Path
import numpy as np

from spatialvla.mobilevlm.model.mobilevlm import load_vla, load_pretrained_model
from spatialvla.mobilevlm.conversation import conv_templates, SeparatorStyle
from spatialvla.mobilevlm.utils import disable_torch_init, process_images, tokenizer_image_token, KeywordsStoppingCriteria
from spatialvla.mobilevlm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN

In [10]:
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Tuple, Type

import numpy as np
import torch
import copy
from PIL import Image

from torch.utils.data import Dataset, IterableDataset

from transformers import PreTrainedTokenizerBase
from transformers import AutoTokenizer, BitsAndBytesConfig

from spatialvla.mobilevlm.utils import disable_torch_init, process_images, tokenizer_image_token, KeywordsStoppingCriteria
# from prismatic.models.backbones.llm.prompting import PromptBuilder
# from prismatic.models.backbones.vision import ImageTransform

from spatialvla.mobilevlm.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from spatialvla.mobilevlm.conversation import conv_templates, SeparatorStyle

from spatialvla.datasets.rlds.utils.data_utils import tree_map
# from prismatic.vla.action_tokenizer import ActionTokenizer
from spatialvla.datasets.rlds import make_interleaved_dataset, make_single_dataset
from spatialvla.datasets.rlds.oxe import OXE_NAMED_MIXTURES, get_oxe_dataset_kwargs_and_weights
from spatialvla.datasets.rlds.utils.data_utils import NormalizationType
from transformers import PreTrainedTokenizerBase

In [28]:
# Predict actions with the BR head: no autograd tracking, bfloat16 autocast.
# The trailing 8 tokens are stripped from the prompt; presumably these are the
# ground-truth action tokens (they are what `at.detokenize` is fed elsewhere
# in this notebook) -- TODO confirm.
with torch.inference_mode(), torch.autocast('cuda', dtype=torch.bfloat16):
    actions_br = model.predict_action_br(
        input_ids=input_ids[:, :-8],
        images=images[:],
        num_denoise_steps=100,
    )

torch.Size([16, 64]) torch.Size([16, 64]) None torch.Size([16, 1, 3, 336, 336])
torch.Size([16, 65]) torch.Size([16, 65]) None torch.Size([16, 1, 3, 336, 336])
torch.Size([16, 66]) torch.Size([16, 66]) None torch.Size([16, 1, 3, 336, 336])
torch.Size([16, 67]) torch.Size([16, 67]) None torch.Size([16, 1, 3, 336, 336])
torch.Size([16, 68]) torch.Size([16, 68]) None torch.Size([16, 1, 3, 336, 336])
torch.Size([16, 69]) torch.Size([16, 69]) None torch.Size([16, 1, 3, 336, 336])
torch.Size([16, 70]) torch.Size([16, 70]) None torch.Size([16, 1, 3, 336, 336])
torch.Size([16, 71]) torch.Size([16, 71]) None torch.Size([16, 1, 3, 336, 336])


In [29]:
import torch.nn as nn

# Mean-squared error between the ground-truth actions and the BR predictions,
# averaged over every element of the batch.
nn.functional.mse_loss(input=actionss, target=actions_br, reduction='mean')

tensor(0.8740, device='cuda:0')

In [31]:
# Show the BR-head prediction for the first sample in the batch.
actions_br[0]

tensor([[ 1.5019, -0.7926, -1.8367, -0.1677,  0.2403, -1.0333,  0.2909],
        [ 1.6887,  0.2262, -1.9474,  0.4048,  0.4033, -0.8842,  0.4572],
        [ 1.9272,  0.6293, -1.9488,  0.0117,  0.5113, -0.9435,  0.2222],
        [ 1.4732,  0.5903, -1.9944, -1.1992, -0.1808, -1.9376,  0.9984],
        [ 2.3378,  0.1889, -1.1431, -1.1910, -0.1030, -0.7218,  0.2110],
        [ 1.1323,  0.1495, -1.6374, -0.4839, -1.4024, -1.4258,  0.6771],
        [ 1.7719, -0.0337,  0.0914, -0.6590, -1.5853, -0.7280,  0.5479],
        [ 0.1736, -0.0557, -1.1310,  0.1668, -1.5495, -1.4959,  0.7174]],
       device='cuda:0')

In [37]:
# Mean-squared error for the sample at batch index 1 only.
nn.functional.mse_loss(
    input=actions_br[1],
    target=actionss[1],
    reduction='mean',
)

tensor(0.3622, device='cuda:0')

In [21]:
# Mean-squared error over the first three samples only.
# Both operands are sliced to three samples so the subtraction is well-defined
# even when `actions_br` covers the whole batch, and `.mean()` replaces the
# nested Python `sum` and the hard-coded 3 * 7 * 8 element count.
((actions_br[:3] - actionss[:3]) ** 2).mean()

tensor(0.9945, device='cuda:0')

In [14]:
# Decode the trailing 8 token ids of the first three samples back into
# continuous values with the action tokenizer, then wrap them in a tensor.
prior_token_ids = input_ids[0:3, -8:].cpu().numpy()
prior_action = torch.tensor(at.detokenize(prior_token_ids))

In [15]:
# Predict actions for the first three samples, passing the detokenized prior
# actions along; no autograd tracking, bfloat16 autocast.
with torch.inference_mode(), torch.autocast('cuda', dtype=torch.bfloat16):
    actions = model.predict_action(
        input_ids=input_ids[0:3],
        images=images[0:3],
        prior_actions=prior_action,
    )

# Disabled manual sampling path, kept for reference:
# condition = model.condition_projector(action_hidden)
# predicted_action = model.si.sample(
#     x_prior=prior_action.cuda().to(dtype=torch.bfloat16),
#     cond=condition.float().flatten(1),
#     diffuse_step=5
# )

In [16]:
# Display the actions returned by `model.predict_action`.
actions

tensor([[[ 1.0144e+00,  3.4179e-01,  7.2654e-01,  3.9015e-01, -9.3770e-01,
          -4.5149e-01, -7.6228e-01],
         [ 2.0379e+00,  8.4566e-01,  9.0550e-02, -8.2488e-01, -1.2245e+00,
          -6.4813e-01, -6.4755e-02],
         [ 8.8245e-01,  8.7198e-02,  4.4365e-01, -2.2315e-01, -1.5250e+00,
          -4.1966e-01, -4.1730e-01],
         [ 1.3478e+00,  8.5450e-01,  1.0696e+00, -8.3486e-01, -1.7778e-01,
          -2.8512e-01, -8.4387e-01],
         [ 1.1205e+00,  1.0850e-01,  6.1234e-01, -1.0628e+00, -1.1799e+00,
          -4.9026e-01, -5.8264e-01],
         [ 1.3364e+00,  1.9319e+00,  1.2419e+00, -8.2427e-01, -9.5564e-02,
          -8.9882e-01, -4.8563e-01],
         [ 6.5008e-01,  1.5054e+00,  1.4593e+00, -1.1284e+00, -4.1123e-01,
          -1.5325e-01, -4.8267e-01],
         [ 9.0568e-01,  1.9906e+00,  1.4715e+00, -1.0812e+00,  3.9649e-04,
          -3.5556e-01, -7.1115e-01]],

        [[-1.3972e+00,  5.5643e-01, -9.5746e-01,  3.8262e-01, -8.7173e-01,
          -3.8574e-01,  4.7

In [17]:
# Draw a fresh first batch from the dataloader and move its tensors to the GPU
# exactly once; the forward call below reuses these device tensors instead of
# copying every tensor host-to-device a second time.
batch = next(iter(dataloader))
device_id = 0
input_ids = batch['input_ids'].to(device_id)
images = batch['pixel_values'].to(device_id)
attention_mask = batch['attention_mask'].to(device_id)
actionss = batch['action'].to(device_id)
use_cache = True
past_key_values = None
labels = batch['labels'].to(device_id)

# Labels are only passed for the 'BR' head type; proprioceptive state is only
# passed when the model arguments enable state input.
is_br_head = model.config.head_args['head_type'] == 'BR'

with torch.inference_mode(), torch.autocast('cuda', dtype=torch.bfloat16):
    # Call the module instance rather than `.forward` so that any registered
    # forward hooks are honoured. `loss` is the full output object (it exposes
    # `.loss` and `.logits`), not a scalar.
    loss = model(
        input_ids=input_ids,
        images=images,
        attention_mask=attention_mask,
        actions=actionss,
        # State and label tensors are moved to the same device as the rest.
        states=batch['proprio'].to(device_id) if model_args.use_state_input else None,
        labels=labels if is_br_head else None,
    )

# Display element 4 of the loss tuple (extracted as `b_loss` in a later cell).
loss.loss[4]

tensor(3.2844, device='cuda:0')

In [18]:
# Pull the individual terms out of the loss tuple by position.
loss_terms = loss.loss
tot_loss = loss_terms[0]
ce_loss = loss_terms[1]
v_loss = loss_terms[2]
s_loss = loss_terms[3]
b_loss = loss_terms[4]
eps_loss = loss_terms[5]

# Token-level accuracy over the last 50 positions. The logits window is
# shifted one step left of the label window (next-token alignment).
action_logits = loss.logits[:, -51:-1]
action_preds = torch.argmax(action_logits, dim=2)
action_gt = batch['labels'][:, -50:].to(action_logits.device)

# Only positions whose label id lies above the action-token start index count
# towards the accuracy.
mask = action_gt > at.action_token_begin_idx
correct_preds = mask & (action_preds == action_gt)
action_accuracy = correct_preds.sum().float() / mask.sum().float()

In [19]:
# Display `b_loss`, i.e. element 4 of the loss tuple.
b_loss

tensor(3.2844, device='cuda:0')

In [20]:
# Last 8 predicted token ids per sample (argmax over the logits).
action_preds[:, -8:]

tensor([[29878, 30992, 31223, 31223, 31141, 31114, 31033, 31033],
        [30053, 30053, 30053, 30053, 30053, 30052, 29971, 29970],
        [29878, 31223, 31223, 31141, 31114, 31033, 31033, 31060],
        [29880, 29881, 29880, 29853, 29853, 29854, 30190, 30163],
        [31060, 31060, 31060, 31060, 31060, 31060, 31060, 31030],
        [ 4989, 31030, 31111, 31111, 31138, 31138, 31139, 31139],
        [29890, 29961, 29961, 29880, 29880, 29880, 29880, 29880],
        [29999, 30268, 30025, 30025, 30025, 30025, 30026, 30026],
        [29881, 30974, 30992, 31223, 31223, 31141, 31114, 31033],
        [29926, 29853, 29853, 29854, 30190, 30163, 30163, 30244],
        [29890, 29961, 29961, 29961, 29880, 29880, 29880, 29880],
        [29880, 29880, 29880, 29880, 29880, 29880, 29881, 29881],
        [29872, 30190, 30163, 30163, 30244, 30245, 30974, 30974],
        [29853, 30432, 30432, 30189, 30270, 30267, 30267, 30268],
        [29881, 29881, 29881, 29880, 29853, 29853, 29854, 30190],
        [2

In [21]:
# Last 8 ground-truth label token ids per sample, for comparison with the
# predicted ids.
action_gt[:, -8:]

tensor([[30974, 30992, 31223, 31223, 31141, 31114, 31033, 31033],
        [30026, 30053, 30053, 30053, 30053, 30052, 29971, 29970],
        [30992, 31223, 31223, 31141, 31114, 31033, 31033, 31060],
        [29881, 29881, 29880, 29853, 29853, 29854, 30190, 30163],
        [31060, 31060, 31060, 31060, 31060, 31060, 31060, 31030],
        [31030, 31030, 31111, 31111, 31138, 31138, 31139, 31139],
        [29961, 29961, 29961, 29880, 29880, 29880, 29880, 29880],
        [30267, 30268, 30025, 30025, 30025, 30025, 30026, 30026],
        [30974, 30974, 30992, 31223, 31223, 31141, 31114, 31033],
        [29880, 29853, 29853, 29854, 30190, 30163, 30163, 30244],
        [29961, 29961, 29961, 29961, 29880, 29880, 29880, 29880],
        [29880, 29880, 29880, 29880, 29880, 29880, 29881, 29881],
        [29854, 30190, 30163, 30163, 30244, 30245, 30974, 30974],
        [30432, 30432, 30432, 30189, 30270, 30267, 30267, 30268],
        [29881, 29881, 29881, 29880, 29853, 29853, 29854, 30190],
        [2

In [22]:
# Last 8 input token ids per sample; in the recorded run these match the
# ground-truth label ids shown for `action_gt`.
input_ids[:, -8:]

tensor([[30974, 30992, 31223, 31223, 31141, 31114, 31033, 31033],
        [30026, 30053, 30053, 30053, 30053, 30052, 29971, 29970],
        [30992, 31223, 31223, 31141, 31114, 31033, 31033, 31060],
        [29881, 29881, 29880, 29853, 29853, 29854, 30190, 30163],
        [31060, 31060, 31060, 31060, 31060, 31060, 31060, 31030],
        [31030, 31030, 31111, 31111, 31138, 31138, 31139, 31139],
        [29961, 29961, 29961, 29880, 29880, 29880, 29880, 29880],
        [30267, 30268, 30025, 30025, 30025, 30025, 30026, 30026],
        [30974, 30974, 30992, 31223, 31223, 31141, 31114, 31033],
        [29880, 29853, 29853, 29854, 30190, 30163, 30163, 30244],
        [29961, 29961, 29961, 29961, 29880, 29880, 29880, 29880],
        [29880, 29880, 29880, 29880, 29880, 29880, 29881, 29881],
        [29854, 30190, 30163, 30163, 30244, 30245, 30974, 30974],
        [30432, 30432, 30432, 30189, 30270, 30267, 30267, 30268],
        [29881, 29881, 29881, 29880, 29853, 29853, 29854, 30190],
        [2

In [23]:
# Re-display `actions`. NOTE(review): no visible cell reassigns it after the
# `model.predict_action` call, so this still shows the first-three-sample
# prediction, not a result for the re-fetched batch.
actions

tensor([[[ 1.0144e+00,  3.4179e-01,  7.2654e-01,  3.9015e-01, -9.3770e-01,
          -4.5149e-01, -7.6228e-01],
         [ 2.0379e+00,  8.4566e-01,  9.0550e-02, -8.2488e-01, -1.2245e+00,
          -6.4813e-01, -6.4755e-02],
         [ 8.8245e-01,  8.7198e-02,  4.4365e-01, -2.2315e-01, -1.5250e+00,
          -4.1966e-01, -4.1730e-01],
         [ 1.3478e+00,  8.5450e-01,  1.0696e+00, -8.3486e-01, -1.7778e-01,
          -2.8512e-01, -8.4387e-01],
         [ 1.1205e+00,  1.0850e-01,  6.1234e-01, -1.0628e+00, -1.1799e+00,
          -4.9026e-01, -5.8264e-01],
         [ 1.3364e+00,  1.9319e+00,  1.2419e+00, -8.2427e-01, -9.5564e-02,
          -8.9882e-01, -4.8563e-01],
         [ 6.5008e-01,  1.5054e+00,  1.4593e+00, -1.1284e+00, -4.1123e-01,
          -1.5325e-01, -4.8267e-01],
         [ 9.0568e-01,  1.9906e+00,  1.4715e+00, -1.0812e+00,  3.9649e-04,
          -3.5556e-01, -7.1115e-01]],

        [[-1.3972e+00,  5.5643e-01, -9.5746e-01,  3.8262e-01, -8.7173e-01,
          -3.8574e-01,  4.7