In [1]:
import tqdm
import tensorflow as tf
import torch
import wandb
import os
import copy
import json
import logging
import pathlib
from tqdm import tqdm
import transformers
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup, get_scheduler
from transformers.trainer import ALL_LAYERNORM_LAYERS, get_parameter_names
from PIL import Image
from accelerate import PartialState
from dataclasses import dataclass, asdict
from typing import Dict, Optional, Sequence, List
from torch.nn import MSELoss, L1Loss, SmoothL1Loss
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast

from spatialvla.datasets import RLDSDataset, RLDSBatchTransform
from spatialvla.datasets.rlds.utils.data_utils import save_dataset_statistics
from spatialvla.datasets.rlds.utils.data_utils import PaddedCollatorForActionPrediction

from spatialvla.mobilevlm.model.mobilevlm import load_pretrained_vlm_for_vla
from spatialvla.mobilevlm.train.train import find_all_linear_names, get_peft_state_maybe_zero_3, get_peft_state_non_lora_maybe_zero_3, find_all_names_from_module

from scripts.spatialvla_config import ModelArguments, TrainingArguments, HEAD_ARGS

2024-11-24 14:26:44.974831: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-11-24 14:26:45.002832: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-11-24 14:26:45.002882: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-11-24 14:26:45.003557: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-11-24 14:26:45.008880: I tensorflow/core/platform/cpu_feature_guar

[2024-11-24 14:26:48,530] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [2]:
# Experiment configuration: DiT action head on the pretrained VLM, bf16 compute on GPU 0.
SEED = 0
# Seed RNGs so the sampled DiT noise/timesteps (presumably drawn from torch's RNG) and
# therefore the loss printed at the end are reproducible across Restart & Run All.
torch.manual_seed(SEED)  # seeds the generators on all devices (CPU + CUDA)
# Global TF seed; presumably governs RLDS shuffling unless those ops set their own seed — TODO confirm.
tf.random.set_seed(SEED)

model_args = ModelArguments()
model_args.action_head = 'DiT'
# Deep-copy so in-place edits to head_args (here or inside the loader) cannot leak back
# into the shared HEAD_ARGS table and silently persist across cell re-runs.
model_args.head_args = copy.deepcopy(HEAD_ARGS[model_args.action_head])
training_args = TrainingArguments()
dtype = torch.bfloat16
device_id = 0

In [3]:
# Load tokenizer, model and image processor in `dtype` (bf16) on GPU `device_id`,
# without 8-bit / 4-bit quantization.
# The fourth return value is unused here. It is bound to a named throwaway rather than `_`
# because assigning `_` in an IPython kernel stops IPython from auto-updating its
# `_` / `__` / `___` output-history variables.
tokenizer, model, image_processor, _unused_load_output = load_pretrained_vlm_for_vla(
    model_args,
    load_8bit=False,
    load_4bit=False,
    device=device_id,
    dtype=dtype,
)

Loading with torch.bfloat16


You are using a model of type mobilevlm to instantiate a model of type spatialvla. This is not supported for all configurations of models and can yield errors.
Some weights of SpatialVLAForCausalLM were not initialized from the model checkpoint at remyxai/SpaceLLaVA-lite and are newly initialized: ['action_head.action_proj.2.weight', 'action_head.noise_pos', 'action_head.time_net.out_net.0.bias', 'action_head.time_net.w', 'action_head.time_net.out_net.0.weight', 'action_head.eps_net.linear.bias', 'action_head.eps_net.adaLN_modulation.1.weight', 'action_head.action_proj.0.bias', 'action_head.action_proj.2.bias', 'action_head.time_net.out_net.2.bias', 'action_head.timestep_pos', 'action_head.eps_net.adaLN_modulation.1.bias', 'action_head.action_proj.0.weight', 'action_head.eps_net.linear.weight', 'action_head.time_net.out_net.2.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
You are resizing the embedding layer witho

In [4]:
# Ensure the model lives on the target GPU (likely a no-op, since the loader was given device=device_id).
# nn.Module.to moves parameters in place and returns the same module, so no rebinding is needed;
# the trailing `;` suppresses the module repr.
model.to(device_id);

In [5]:
# Freeze every parameter of the model (requires_grad_ walks all of model.parameters());
# the trailing `;` suppresses the returned module's repr.
model.requires_grad_(False);

In [6]:
# Input pipeline: per-sample transform -> RLDS dataset -> padding collator -> DataLoader.
batch_transform = RLDSBatchTransform(tokenizer, image_processor)

# One observation frame per sample plus `action_len - 1` future actions
# (presumably an action chunk of length model_args.action_len — confirm against RLDSDataset).
dataset = RLDSDataset(
    data_root_dir=training_args.data_root_dir,
    data_mix=training_args.data_mix,
    batch_transform=batch_transform,
    shuffle_buffer_size=training_args.shuffle_buffer_size,
    train=True,
    window_size=1,
    future_action_window_size=model_args.action_len - 1,
)

# Right-side padding using the tokenizer's max length and pad token id.
collator = PaddedCollatorForActionPrediction(
    tokenizer.model_max_length,
    tokenizer.pad_token_id,
    padding_side='right',
)

dataloader = DataLoader(
    dataset,
    batch_size=training_args.batch_size,
    sampler=None,
    collate_fn=collator,
    # Keep at 0 for RLDS: TFDS manages its own parallelism.
    num_workers=0,
)

2024-11-24 14:27:18.128313: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
2024-11-24 14:27:18.527325: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization



######################################################################################
# Loading the following 1 datasets (incl. sampling weight):                         #
######################################################################################

Threads per Dataset: %s [1]
Reads per Dataset: %s [1]
Constructing datasets...


2024-11-24 14:27:18.926053: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization


Applying frame transforms on dataset...


In [7]:
# Pull a single batch for inspection; as the first item yielded, its index is always 0.
batch_idx, batch = 0, next(iter(dataloader))

In [8]:
# Fail fast if the collated batch lacks any field the next cells index into;
# 'dataset_names' is carried along but not needed for the forward pass.
missing_keys = {'pixel_values', 'input_ids', 'attention_mask', 'action'} - set(batch)
assert not missing_keys, f"collated batch is missing keys: {sorted(missing_keys)}"
batch.keys()

dict_keys(['pixel_values', 'input_ids', 'attention_mask', 'action', 'dataset_names'])

In [9]:
# Move the model-facing tensors onto the GPU ('dataset_names' stays behind);
# no KV cache and no text labels are used for this forward pass.
input_ids, images, attention_mask, actions = (
    batch[key].to(device_id)
    for key in ('input_ids', 'pixel_values', 'attention_mask', 'action')
)
past_key_values = labels = None

In [10]:
# Forward pass through the VLM backbone under `dtype` autocast.
# Each stage's outputs get fresh names (mm_* / dit_*) instead of overwriting the raw
# tensors from the previous cell. Reusing input_ids/attention_mask/past_key_values/labels
# would feed already-transformed tensors back in on a re-run, so the cell would not be idempotent.
with torch.autocast('cuda', dtype=dtype):
    # Presumably splices the image features into the token embeddings — see the model code.
    mm_input_ids, mm_attention_mask, mm_past_key_values, mm_inputs_embeds, mm_labels = (
        model.prepare_inputs_labels_for_multimodal(
            input_ids, attention_mask, past_key_values, labels, images
        )
    )
    # Adds the DiT action inputs. It also returns `time_enc` (fed to eps_net later)
    # and `noise` (the regression target of the loss later).
    (
        dit_input_ids,
        dit_attention_mask,
        dit_past_key_values,
        dit_inputs_embeds,
        dit_labels,
        time_enc,
        noise,
    ) = model.action_head.prepare_inputs_for_DiT(
        actions, mm_input_ids, mm_attention_mask, mm_past_key_values, mm_inputs_embeds, mm_labels
    )
    outputs = model.model(
        input_ids=dit_input_ids,
        attention_mask=dit_attention_mask,
        past_key_values=dit_past_key_values,
        inputs_embeds=dit_inputs_embeds,
        use_cache=False,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    )
    # Element 0 of the backbone output: presumably last-layer hidden states of shape
    # (batch, seq, hidden), judging by how they are sliced afterwards — TODO confirm.
    hidden = outputs[0].contiguous()

In [11]:
# Shape of the timestep encoding returned by prepare_inputs_for_DiT.
time_enc.size()

torch.Size([32, 1, 2048])

In [14]:
# Keep only the hidden states at the last `model.config.action_len` sequence positions
# (the action tokens). Pass them with the timestep encoding to the DiT eps network,
# whose output is regressed against `noise` in the loss cell.
output_tokens = hidden[:, -model.config.action_len:]
eps_out = model.action_head.eps_net(output_tokens, time_enc)

torch.Size([32, 1, 2048]) torch.Size([32, 1, 2048])
torch.Size([32, 1, 2048])


In [15]:
# Shape of the predicted noise from eps_net.
eps_out.size()

torch.Size([32, 1, 7])

In [18]:
import torch.nn as nn

# Denoising loss: element-wise squared error between the predicted noise (eps_out) and the
# sampled noise. It is summed over dim 1, the action-chunk axis sliced as the last action_len
# tokens, and then averaged over all remaining elements (batch and per-action dimensions).
loss = nn.functional.mse_loss(eps_out, noise, reduction='none').sum(dim=1).mean()

In [19]:
# Display the scalar denoising loss (fully reduced by .mean()) for this single batch.
loss

tensor(0.8999, device='cuda:0')