# TinyVLA architecture for PushT

## Starting from https://medium.com/correll-lab/robotic-behavior-cloning-i-auto-regressive-transformers-a7be623f4291 and https://colab.research.google.com/drive/18GIHeOQ5DyjMN8iIRZL2EKZ0745NLIpg?usp=sharing

In [1]:
# !pip install transformers
!pip install pyvips
# !pip install gym-pusht
# !pip install einops
# !pip install "accelerate>=0.26.0"
# !pip install matplotlib
# !python --version
# !pip install datasets
# !pip install gdown
!pip install zarr
# !apt-get install -y libvips libvips-dev
!apt-get install -y -qq libvips libvips-dev > /dev/null 2>&1
!pip install peft -q
!pip install -U bitsandbytes
# !pip install flash-attn -q



In [2]:
!pip list | grep torch

'grep' is not recognized as an internal or external command,
operable program or batch file.


In [1]:
import os # to deal with files
import gdown # to download from google drive
import zipfile # to unzip
import zarr # to load the dataset
import numpy as np
import torch
import matplotlib.pyplot as plt
from transformers import AutoTokenizer
from PIL import Image
import torchvision.transforms as transforms

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [3]:
# download demonstration data from Google Drive
dataset_path = "pusht_cchi_v7_replay.zarr.zip"
extracted_dataset_path = "pusht_cchi_v7_replay.zarr"  # Path to extracted dataset

if not os.path.isfile(dataset_path):
    id = "1KY1InLurpMvJDRb14L9NlXT_fEsCvVUq&confirm=t"
    gdown.download(id=id, output=dataset_path, quiet=False)

# Extract the dataset if it hasn't been extracted yet
if not os.path.isdir(extracted_dataset_path):
    with zipfile.ZipFile(dataset_path, 'r') as zip_ref:
        zip_ref.extractall(extracted_dataset_path)


In [4]:
def create_sample_indices(
        episode_ends:np.ndarray, sequence_length:int,
        pad_before: int=0, pad_after: int=0):
    indices = list()
    for i in range(len(episode_ends)):
        start_idx = 0
        if i > 0:
            start_idx = episode_ends[i-1]
        end_idx = episode_ends[i]
        episode_length = end_idx - start_idx

        min_start = -pad_before
        max_start = episode_length - sequence_length + pad_after

        # range stops one idx before end
        for idx in range(min_start, max_start+1):
            buffer_start_idx = max(idx, 0) + start_idx
            buffer_end_idx = min(idx+sequence_length, episode_length) + start_idx
            start_offset = buffer_start_idx - (idx+start_idx)
            end_offset = (idx+sequence_length+start_idx) - buffer_end_idx
            sample_start_idx = 0 + start_offset
            sample_end_idx = sequence_length - end_offset
            indices.append([
                buffer_start_idx, buffer_end_idx,
                sample_start_idx, sample_end_idx])
    indices = np.array(indices)
    return indices


def sample_sequence(train_data, sequence_length,
                    buffer_start_idx, buffer_end_idx,
                    sample_start_idx, sample_end_idx):
    result = dict()
    for key, input_arr in train_data.items():
        sample = input_arr[buffer_start_idx:buffer_end_idx]
        data = sample
        if (sample_start_idx > 0) or (sample_end_idx < sequence_length):
            data = np.zeros(
                shape=(sequence_length,) + input_arr.shape[1:],
                dtype=input_arr.dtype)
            if sample_start_idx > 0:
                data[:sample_start_idx] = sample[0]
            if sample_end_idx < sequence_length:
                data[sample_end_idx:] = sample[-1]
            data[sample_start_idx:sample_end_idx] = sample
        result[key] = data
    return result

# normalize data
def get_data_stats(data):
    data = data.reshape(-1,data.shape[-1])
    stats = {
        'min': np.min(data, axis=0),
        'max': np.max(data, axis=0)
    }
    return stats

def normalize_data(data, stats):
    # nomalize to [0,1]
    ndata = (data - stats['min']) / (stats['max'] - stats['min'])
    # normalize to [-1, 1]
    ndata = ndata * 2 - 1
    return ndata

def unnormalize_data(ndata, stats):
    ndata = (ndata + 1) / 2
    data = ndata * (stats['max'] - stats['min']) + stats['min']
    return data

# dataset
class PushTImageDataset(torch.utils.data.Dataset):
    def __init__(self,
                 dataset_path: str,
                 pred_horizon: int,
                 obs_horizon: int,
                 action_horizon: int,
                 tokenizer = None):

        # read from zarr dataset
        dataset_root = zarr.open(dataset_path, 'r')

        # float32, [0,1], (N,96,96,3)
        train_image_data = dataset_root['data']['img'][:]
        train_image_data = np.moveaxis(train_image_data, -1,1)
        # (N,3,96,96)

        # (N, D)
        train_data = {
            # first two dims of state vector are agent (i.e. gripper) locations
            'agent_pos': dataset_root['data']['state'][:,:2],
            'action': dataset_root['data']['action'][:]
        }
        episode_ends = dataset_root['meta']['episode_ends'][:]

        # compute start and end of each state-action sequence
        # also handles padding
        indices = create_sample_indices(
            episode_ends=episode_ends,
            sequence_length=pred_horizon,
            pad_before=obs_horizon-1,
            pad_after=action_horizon-1)

        # compute statistics and normalized data to [-1,1]
        stats = dict()
        normalized_train_data = dict()
        for key, data in train_data.items():
            stats[key] = get_data_stats(data)
            normalized_train_data[key] = normalize_data(data, stats[key])

        # images are already normalized
        normalized_train_data['image'] = train_image_data

        self.indices = indices
        self.stats = stats
        self.normalized_train_data = normalized_train_data
        self.pred_horizon = pred_horizon
        self.action_horizon = action_horizon
        self.obs_horizon = obs_horizon

        # fixed prompt
        prompt = "Given a (96,96,3) RGB image, where the green T represents the goal state of the gray T block and the blue dot represents the robot's current position, determine the next [x, y] coordinates the robot should move toward. The goal is to push the gray T block onto the goal state in the same position and orientation as the green indication. Return only the next step."
        self.tokenized_prompt = tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True)
        # self.tokenizer = tokenizer

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # get the start/end indices for this datapoint
        buffer_start_idx, buffer_end_idx, \
            sample_start_idx, sample_end_idx = self.indices[idx]

        # get normalized data using these indices
        nsample = sample_sequence(
            train_data=self.normalized_train_data,
            sequence_length=self.pred_horizon,
            buffer_start_idx=buffer_start_idx,
            buffer_end_idx=buffer_end_idx,
            sample_start_idx=sample_start_idx,
            sample_end_idx=sample_end_idx
        )

        # discard unused observations
        nsample['image'] = nsample['image'][:self.obs_horizon,:]
        nsample['agent_pos'] = nsample['agent_pos'][:self.obs_horizon,:]

        # add fixed prompt to every example
        # tokenized_prompt = self.tokenizer(self.prompt, return_tensors="pt", max_length=128, truncation=True)
        nsample['text'] = self.tokenized_prompt["input_ids"].squeeze(0)
        nsample['attn_mask'] = self.tokenized_prompt["attention_mask"].squeeze(0)

        return nsample

In [5]:
def plt_img(image, pos=None, acts=None):
    if hasattr(image, 'permute'):
        image = image.permute(1, 2, 0).numpy()  # [96, 96, 3]

    if image.dtype != 'uint8':
        image = (image / 255.0)

    # Plot the image
    plt.figure(figsize=(10, 10))
    plt.imshow(image, origin='upper')

    if pos is not None:
        # Handle a single 2D point
        if pos.ndim == 1:
            pos = pos.unsqueeze(0) if hasattr(pos, 'unsqueeze') else pos[None, :]
        plt.plot(pos[:, 0], pos[:, 1], 'r+', label="Agent Position")

    if acts is not None:
        # Handle a single 2D point
        if acts.ndim == 1:
            acts = acts.unsqueeze(0) if hasattr(acts, 'unsqueeze') else acts[None, :]
        plt.plot(acts[:, 0], acts[:, 1], 'b*', label="Actions")

    plt.legend()
    plt.grid()
    plt.show()

In [6]:
# revision = "2025-04-14"
revision = "2024-04-02"

md2_tokenizer = AutoTokenizer.from_pretrained(
    "vikhyatk/moondream2",
    revision=revision,
    trust_remote_code=True
)
md2_tokenizer.pad_token = md2_tokenizer.eos_token

In [7]:
from torch.utils.data import random_split
from torch.utils.data import Subset

# parameters
random_split = False
batch_size = 2
obs_horizon = 1
action_horizon = 0
pred_horizon = 1
#|o|o|                             observations: 2
#| |a|a|a|a|a|a|a|a|               actions executed: 8
#|p|p|p|p|p|p|p|p|p|p|p|p|p|p|p|p| actions predicted: 16

# create dataset from file
dataset = PushTImageDataset(
    dataset_path=extracted_dataset_path,
    pred_horizon=pred_horizon,
    obs_horizon=obs_horizon,
    action_horizon=action_horizon,
    tokenizer = md2_tokenizer
)
# save training data statistics (min, max) for each dim
stats = dataset.stats

if random_split:
  train_size = int(0.8 * len(dataset))
  val_size = len(dataset) - train_size
  train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
else:
  train_dataset = Subset(dataset, range(int(0.8 * len(dataset))))
  val_dataset = Subset(dataset, range(int(0.8 * len(dataset)), len(dataset)))

# create dataloader
dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
    # num_workers=1,
    # persistent_workers=True
)

val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    pin_memory=True,
    # num_workers=1,
    # persistent_workers=True
)

In [8]:
# visualize data in batch
batch = next(iter(dataloader))
print("batch['image'].shape:", batch['image'].shape)
print("batch['agent_pos'].shape:", batch['agent_pos'].shape)
print("batch['action'].shape", batch['action'].shape)
print("batch['text'].shape", batch['text'].shape)
print("batch['attn_mask'].shape", batch['attn_mask'].shape)

batch['image'].shape: torch.Size([2, 1, 3, 96, 96])
batch['agent_pos'].shape: torch.Size([2, 1, 2])
batch['action'].shape torch.Size([2, 1, 2])
batch['text'].shape torch.Size([2, 81])
batch['attn_mask'].shape torch.Size([2, 81])


In [9]:
vis_data = False
# Visualizing dataset:

if vis_data:
  B=35
  stats = {'min': 0, 'max': 96}
  image = batch['image'][B][0]
  pos = unnormalize_data(batch['agent_pos'][B], stats)
  acts = unnormalize_data(batch['action'][B], stats)
  plt_img(image, pos, acts)

  # image = batch['image'][B][1]
  # plt_img(image, pos, acts)

In [10]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from PIL import Image
flash_attn = False
fourbit = False


# https://huggingface.co/vikhyatk/moondream2/tree/2025-04-14
# https://github.com/vikhyat/moondream
# https://github.com/vikhyat/moondream/blob/main/notebooks/RepEng.ipynb
# https://github.com/vikhyat/moondream/blob/main/moondream/finetune/finetune_region.py

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
    "vikhyatk/moondream2",
    revision=revision,
    trust_remote_code=True,
    device_map= {"": "cuda"}, # "auto",
    quantization_config=bnb_config if fourbit else None,
    attn_implementation="flash_attention_2" if flash_attn else None
)
model.text_model.lm_head = torch.nn.Linear(in_features=2048, out_features=2)
model.to(device);

PhiForCausalLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.


### Example inference on base model before finetuning

In [11]:
vis_inf = False

if vis_inf:
  prompt = "Given a (96,96,3) RGB image, where the green T represents the goal state of the gray T block and the blue dot represents the robot's current position, determine the next [x, y] coordinates the robot should move toward. The goal is to push the gray T block onto the goal state in the same position and orientation as the green indication. Return only the next step."

  img = Image.fromarray(image.permute(1, 2, 0).byte().numpy())

  encoded_image = model.encode_image(img)
  action = model.query(encoded_image, prompt)
  print(action)

In [12]:
if vis_inf:
  # acts=[0.25,0.42]
  acts=[0.14*96,0.14*96]
  # for i, val in enumerate(acts):
  #   acts[i] = unnormalize_data(val, stats)
  acts = torch.tensor(acts)
  plt_img(image, acts)

### Finetuning with LoRA:

https://www.youtube.com/watch?v=5rH_VjKXuzg


In [13]:
parameter_verbosity = "none" # unique # all
# Look at model parameter names to determine target_modules

if parameter_verbosity == "unique":
  unique_layers = set()
  for name, _ in model.named_parameters():
      layer_type = name.split(".")[-2]  # Extract second-to-last part (usually "mlp", "ln", etc.)
      unique_layers.add(layer_type)
  for layer in sorted(unique_layers):
      print(layer)
elif parameter_verbosity == "all":
  for name, _ in model.named_parameters():
    print(name)

In [14]:
import re
target_modules = set()

vision_suffixes = {'fc1', 'fc2', 'proj', 'qkv', 'proj_mlp.'}
for name, _ in model.named_modules():
    if "visual." in name: # or ".region."
        if any(name.endswith(suffix) for suffix in vision_suffixes):
            target_modules.add(name)

# target_modules.add("lm_head")

# print("Final target_modules for vision-only LoRA:")
# target_modules = sorted(target_modules)
# for tm in target_modules:
#     print(" ", tm)
len(target_modules)

108

In [15]:
from peft import prepare_model_for_kbit_training, get_peft_model, LoraConfig

# 4 bit quantization
if fourbit:
  model.gradient_checkpointing_enable()
  model = prepare_model_for_kbit_training(model)

lora = True
lora_alpha = 32 # affects contribution of LoRA updates
lora_rank = 64 # how much the LoRA layers compress the full-rank weight updates

if lora:
  lora_config = LoraConfig(
      r = lora_rank,
      lora_alpha = lora_alpha,
      target_modules= target_modules,
      lora_dropout = 0.1, # regularization
      bias = "none",
      task_type="CAUSAL_LM",
  )

In [16]:
if lora:
  model = get_peft_model(model, lora_config) # adds LoRA layers and freezes other layers
  model.print_trainable_parameters()

trainable params: 30,799,872 || all params: 1,783,373,682 || trainable%: 1.7271


In [17]:
num_epochs = 1
grad_accum_steps = 4
lr = 1.53e-5
if lora:
  lr_scaling = lora_alpha / (lora_rank**0.5)
  print("Learning rate scaling of", lr_scaling, "for  LoRA adapters")

Learning rate scaling of 4.0 for  LoRA adapters


In [18]:
# Optimizer
if lora:
  lora_params = [p for n, p in model.named_parameters() if "lora" in n and p.requires_grad]
  optimizer = torch.optim.AdamW(lora_params, lr=lr)
else:
  optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

In [19]:
from torchvision.transforms.functional import to_pil_image
loss_fn = torch.nn.MSELoss()

def compute_loss(batch):
    images = batch["image"].to(device).squeeze(1)
    text = batch["text"].to(device)
    labels = batch["action"].to(device).squeeze(1)
    attn_mask = batch["attn_mask"].to(device)
    
    pil_images = [to_pil_image(img) for img in images]
    img_emb = model.vision_encoder(pil_images)
    
    text_emb = model.text_model.get_input_embeddings()(text)
    
    # input_embs = torch.cat((text_emb[:, 0:1, :], img_emb, text_emb[:, 1:, :]), dim=1)
    input_embs = torch.cat((img_emb, text_emb[:, 1:, :], text_emb[:, 0:1, :]), dim=1) # [CLS]-esk token at the end
    
    outputs = model.text_model(
        inputs_embeds=input_embs,
        # labels=labels,
        # attention_mask = attn_mask
    )

    print(outputs.logits[:,-1])

    return loss_fn(outputs.logits[:,-1], labels)

In [20]:
from torchvision.transforms.functional import to_pil_image
loss_fn = torch.nn.MSELoss()


images = batch["image"].to(device).squeeze(1)
text = batch["text"].to(device)
labels = batch["action"].to(device).squeeze(1)
attn_mask = batch["attn_mask"].to(device)

pil_images = [to_pil_image(img) for img in images]
img_emb = model.vision_encoder(pil_images)

text_emb = model.text_model.get_input_embeddings()(text)

input_embs = torch.cat((img_emb, text_emb[:, 1:, :], text_emb[:, 0:1, :]), dim=1) # [CLS]-esk token at the end

outputs = model.text_model(inputs_embeds=input_embs)
print(outputs.logits[:,-1])
loss = loss_fn(outputs.logits[:,-1], labels)

loss

tensor([[ 3.4885, -1.2377],
        [ 3.3657, -1.1319]], device='cuda:0')


tensor(9.4760, device='cuda:0')

In [21]:
print(img_emb.shape)
print(text_emb.shape)
print(input_embs.shape)
print(outputs.logits.shape)

torch.Size([2, 729, 2048])
torch.Size([2, 81, 2048])
torch.Size([2, 810, 2048])
torch.Size([2, 810, 2])


In [22]:
compute_loss(batch)

tensor([[ 3.4885, -1.2377],
        [ 3.3657, -1.1319]], device='cuda:0')


tensor(9.4760, device='cuda:0')

In [23]:
model.text_model.transformer.gradient_checkpointing_enable()
# model.vision_encoder.transformer.gradient_checkpointing_enable()
model.vision_encoder;
# model.text_model
# model.transformer.gradient_checkpointing_enable()

In [None]:
from tqdm import tqdm
from torchvision import transforms


model.train()
model.text_model.transformer.gradient_checkpointing_enable()

i = 0
for epoch in range(num_epochs):
  for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
      i += 1

      loss = compute_loss(batch)
      loss.backward()

      if i % grad_accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

      lr = lr_schedule(i / grad_accum_steps, total_steps)
      for param_group in optimizer.param_groups:
        if lora and param_group['params'] == lora_params:
          param_group['lr'] = lr * lr_scaling
        else:
          param_group['lr'] = lr

      if i % eval_steps == 0:
        val_loss = 0
        for val_batch in tqdm(val_dataloader, desc="Validatioon"):
          with torch.no_grad():
            val_loss += compute_loss(val_batch).item()
        val_loss /= len(val_dataloader)
        print(val_loss)

        # Save model
        rand_id = random.randint(10000, 99999)
        filename = f"model_{rand_id}.pt"
        torch.save(model.state_dict(), filename)
        print(f"Saved model to {filename}")

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...   | 0/10178 [00:00<?, ?it/s]
