In [1]:
import os

# Expose only GPU 0 to this process. The variable is set before torch is
# imported so that it is already in place when CUDA gets initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch

# Query CUDA once and reuse the answer for both the report and the device choice.
has_cuda = torch.cuda.is_available()
if not has_cuda:
    print("CUDA is not available, using CPU.")
else:
    gpu_count = torch.cuda.device_count()
    print(f"Number of GPUs available: {gpu_count}")
    for gpu_idx in range(gpu_count):
        print(f"GPU {gpu_idx}: {torch.cuda.get_device_name(gpu_idx)}")

# Notebook-wide default device: first visible GPU, otherwise the CPU.
device = torch.device("cuda:0") if has_cuda else torch.device("cpu")
print(f"Using device: {device}")

Number of GPUs available: 1
GPU 0: NVIDIA L40S
Using device: cuda:0


In [5]:
import tqdm
import tensorflow as tf
import torch
import wandb
import os
import copy
import json
import logging
import pathlib
from tqdm import tqdm
import transformers
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup, get_scheduler
from transformers.trainer import ALL_LAYERNORM_LAYERS, get_parameter_names
from PIL import Image
from accelerate import PartialState
from dataclasses import dataclass, asdict
from typing import Dict, Optional, Sequence, List
from torch.nn import MSELoss, L1Loss, SmoothL1Loss
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast

from spatialvla.datasets import RLDSDataset, RLDSBatchTransform
from spatialvla.datasets.rlds.utils.data_utils import save_dataset_statistics
from spatialvla.datasets.rlds.utils.data_utils import PaddedCollatorForActionPrediction

from spatialvla.mobilevlm.model.mobilevlm import load_pretrained_vlm_for_vla
from spatialvla.mobilevlm.train.train import find_all_linear_names, get_peft_state_maybe_zero_3, get_peft_state_non_lora_maybe_zero_3, find_all_names_from_module

from scripts.spatialvla_config import ModelArguments, TrainingArguments, HEAD_ARGS

In [6]:

# Numeric precision and target GPU index used when loading the weights.
dtype = torch.bfloat16
device_id = 0

# Default model configuration, switched to the 'Diffusion' action head and
# its matching head arguments from HEAD_ARGS.
model_args = ModelArguments()
model_args.action_head = 'Diffusion'
model_args.head_args = HEAD_ARGS['Diffusion']

# Default training configuration (not modified in this cell).
training_args = TrainingArguments()

# Load tokenizer, model and image processor without 8-bit / 4-bit
# quantisation; the fourth return value is not needed and is discarded.
tokenizer, model, image_processor, _ = load_pretrained_vlm_for_vla(
    model_args,
    load_8bit=False,
    load_4bit=False,
    device=device_id,
    dtype=dtype,
)

Loading with torch.bfloat16


You are using a model of type mobilevlm to instantiate a model of type spatialvla. This is not supported for all configurations of models and can yield errors.
Some weights of SpatialVLAForCausalLM were not initialized from the model checkpoint at remyxai/SpaceLLaVA-lite and are newly initialized: ['action_head.diffusion_model.reverse_network.layers.2.dense2.weight', 'action_head.diffusion_model.cond_encoder.mlp.0.bias', 'action_head.diffusion_model.reverse_network.layers.2.dense_residual.bias', 'action_head.diffusion_model.reverse_network.in_dense.bias', 'action_head.diffusion_model.reverse_network.in_dense.weight', 'action_head.diffusion_model.reverse_network.layers.2.dense_residual.weight', 'action_pos', 'action_head.diffusion_model.reverse_network.layers.0.dense_residual.weight', 'action_head.diffusion_model.reverse_network.layers.0.dense1.bias', 'action_head.diffusion_model.reverse_network.layers.0.dense2.bias', 'action_head.diffusion_model.reverse_network.layers.1.layer_norm.bias

In [29]:
from PIL import Image
from spatialvla.mobilevlm.model.mobilevlm import load_vla, load_pretrained_model
from spatialvla.mobilevlm.conversation import conv_templates, SeparatorStyle
from spatialvla.mobilevlm.utils import disable_torch_init, process_images, tokenizer_image_token, KeywordsStoppingCriteria
from spatialvla.mobilevlm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, IGNORE_INDEX
# Build one image + conversation example and turn it into model inputs:
# image tensor, token ids, labels, attention mask and a stopping criterion.

# Number of trailing prompt tokens that stay supervised in `labels`.
# NOTE(review): 5 presumably covers the short reply plus separator tokens —
# confirm against the tokenizer output.
SUPERVISED_TAIL_TOKENS = 5

# Image.open is lazy and keeps the file open; decode inside a context manager
# so the handle is released, and normalise to RGB so palette / alpha PNGs are
# handled the same way as plain RGB files.
with Image.open('lasagna.png') as image_file:
    image = image_file.convert('RGB')
images = [image]
images_tensor = process_images(images, image_processor, {'image_aspect_ratio' : 'pad'}).to(model.device, dtype=torch.bfloat16)

# Conversation: the first role sends the image placeholder plus a question,
# the second role answers with a short fixed reply.
conv = conv_templates['v1'].copy()
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + 'How many layers it have?')
conv.append_message(conv.roles[1], 'hello')
prompt = conv.get_prompt()
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2

# Input
# Tokenise the prompt, add a leading batch dimension, and place the ids on the
# same device as the model (matching how `images_tensor` is placed) instead of
# whatever the current default CUDA device happens to be.
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
input_ids = input_ids.unsqueeze(0).to(model.device)

# Labels are a copy of the ids with everything except the last
# SUPERVISED_TAIL_TOKENS positions set to IGNORE_INDEX.
labels = input_ids.clone()
labels[:, :-SUPERVISED_TAIL_TOKENS] = IGNORE_INDEX

# NOTE(review): the mask is derived by comparing against IGNORE_INDEX rather
# than a padding token id — confirm this is intended.
attention_mask = input_ids.ne(IGNORE_INDEX)
stopping_criteria = KeywordsStoppingCriteria([stop_str],tokenizer, input_ids)

In [30]:
# Hand-set stand-ins for the optional arguments used by the cells below.
# NOTE(review): this presumably mirrors the keyword defaults of the model's
# forward() so its steps can be run one by one — confirm.
actions = None
return_dict = None
use_cache = None
inputs_embeds = None
past_key_values = None
attention_mask = None  # replaces the mask computed in the input-building cell
# `labels` is left untouched here, so the targets built earlier are kept.
output_attentions = False
output_hidden_states = False

# Any flag still set to None falls back to the value stored in the model config.
if output_attentions is None:
    output_attentions = model.config.output_attentions
if output_hidden_states is None:
    output_hidden_states = model.config.output_hidden_states
if return_dict is None:
    return_dict = model.config.use_return_dict

In [31]:
# Combine the text inputs with the image tensor via the model's multimodal
# preparation step. The call rebinds input_ids, attention_mask,
# past_key_values and labels, so running this cell a second time feeds the
# already-prepared values back in; re-run the input-building cells first.
(
    input_ids,
    attention_mask,
    past_key_values,
    inputs_embeds,
    labels,
) = model.prepare_inputs_labels_for_multimodal(
    input_ids,
    attention_mask,
    past_key_values,
    labels,
    images_tensor,
)

In [32]:
# Call the `model.model` submodule directly on the prepared inputs.
# Every flag is taken from the variables resolved in the argument-setup cell;
# `output_attentions` previously bypassed that variable with a hard-coded
# False, so changing the flag upstream had no effect here.
outputs = model.model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    past_key_values=past_key_values,
    inputs_embeds=inputs_embeds,
    use_cache=use_cache,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
    return_dict=return_dict,
)

In [34]:
# Size of the first element of the forward output.
outputs[0].size()

torch.Size([1, 194, 2048])