# Using an LLM

The base model so far has been untrained. How can we make use of prior models to accelerate the learning process and make use of the prior learned structure. Can we integrate an LLM? [OpenVLA](https://openvla.github.io/) is an example of this.

Here is a example of this direction.

In [1]:

## Grab a chunk of data for training
import tensorflow_datasets as tfds
import cv2
import numpy as np
image_shape = [64, 64, 3]
num_episodes = 20 ## How many episodes to grab from the dataset for training

builder = tfds.builder_from_directory(builder_dir='gs://gresearch/robotics/bridge/0.1.0/')
datasetRemote = builder.as_dataset(split='train[:' + str(num_episodes) + ']')
dataset = {"img": [], "action": [], "goal": [], "goal_img": [],
                "rotation_delta": [], "open_gripper": [] }
shortest_goal_txt = 10000000000
for episode in datasetRemote:
    episode_ = {'steps': [] }
    episode = list(episode['steps'])
    ## Goal image is just the last image/state/observation in the episode
    goal_img = cv2.resize(np.array(episode[-1]['observation']['image'], dtype=np.float32), (image_shape[0], image_shape[1]))
    for i in range(len(episode)):
        obs = cv2.resize(np.array(episode[i]['observation']['image'], dtype=np.float32), (image_shape[0], image_shape[1]))
        goal = episode[i]['observation']['natural_language_instruction'].numpy().decode()
        dataset["img"].append(obs)
        dataset["action"].append(np.array(np.concatenate((episode[i]['action']['world_vector'], 
                                                          episode[i]['action']['rotation_delta'],
                                                        [episode[i]['action']['open_gripper']]), axis=0)))
         
        dataset["rotation_delta"].append(np.array(episode[i]['action']['rotation_delta']))
        dataset["open_gripper"].append(np.array(episode[i]['action']['open_gripper']))
        dataset["goal"].append(goal)
        dataset["goal_img"].append(goal_img)
        if len(goal) < shortest_goal_txt: shortest_goal_txt = len(goal)

# here are all the unique characters that occur in this text
chars = sorted(list(set([item for row in dataset["goal"] for item in row])))
vocab_size = len(chars)
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode_txt = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode_txy = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string
print("vocab_size:", vocab_size)
print("example text encode:", encode_txt(dataset["goal"][0]))

print("Dataset shape:", len(dataset["img"]))
dataset["img"] = np.array(dataset["img"], dtype=np.uint8)
dataset["action"] = np.array(dataset["action"], dtype=np.float32)
# dataset["goal"] = np.array(encode_txt(dataset["goal"]), dtype=np.float32)
dataset["goal_img"] = np.array(dataset["goal_img"], dtype=np.uint8)


  from .autonotebook import tqdm as notebook_tqdm
2026-02-06 16:16:54.901788: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-06 16:16:54.901837: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-06 16:16:54.903683: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-06 16:16:54.914331: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
202

vocab_size: 31
example text encode: [6, 18, 8, 10, 12, 0, 25, 15, 12, 0, 10, 8, 20, 0, 25, 21, 0, 25, 15, 12, 0, 18, 12, 13, 25, 0, 21, 13, 0, 25, 15, 12, 0, 22, 21, 25, 2]
Dataset shape: 687


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
## This is an encoder head (full attention)
print ("torch version:", torch.__version__)
class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size, n_embd, dropout):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        B,T,C = x.shape
        if mask == None:
            mask = torch.ones((T, ), device=x.device) ## (1, T)
        k = self.key(x)   # (B,T,C)
        q = self.query(x) # (B,T,C)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)
        ### Block masked attention
        wei = wei.masked_fill(mask == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,C)
        out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)
        return out

class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size, n_embd, dropout):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size, n_embd=n_embd, dropout=dropout) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        out = torch.cat([h(x, mask) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x,)

class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head, dropout):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size, n_embd=n_embd, dropout=dropout)
        self.ffwd = FeedFoward(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x, mask=None):
        x = x + self.sa(self.ln1(x), mask)
        x = x + self.ffwd(self.ln2(x))
        return x

class GRP(nn.Module):
  def __init__(self, dataset, cfg, mlp_ratio=4):
    super(GRP, self).__init__()
    self._dataset = dataset
    self._cfg = cfg

    self.token_embedding_table = nn.Embedding(cfg.vocab_size, cfg.n_embd)
    self.patch_size = (self._cfg.image_shape[0] / self._cfg.n_patches, self._cfg.image_shape[1] / self._cfg.n_patches)

    #Positional embedding
    self.register_buffer('positional_embeddings', calc_positional_embeddings(1 + self._cfg.n_patches ** 2 + self._cfg.block_size + self._cfg.n_patches ** 2, cfg.n_embd), persistent=False)
    
    self.class_tokens = nn.Parameter(torch.rand(1, cfg.n_embd))

    self.input_d = int(self._cfg.image_shape[2] * self.patch_size[0] * self.patch_size[1])

    self.lin_map = nn.Linear(self.input_d, self._cfg.n_embd, bias=False) 

    # 4) Transformer encoder blocks
    self.blocks = nn.ModuleList([Block(self._cfg.n_embd, self._cfg.n_head, dropout=self._cfg.dropout) for _ in range(self._cfg.n_blocks)])

    # 5) Classification MLPk
    self.mlp = nn.Sequential(
        nn.Linear(self._cfg.n_embd, self._cfg.action_bins),
        # nn.Softmax(dim=-1)
    )

  def _init_weights(self, module):
      if isinstance(module, nn.Linear):
          torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
          if module.bias is not None:
              torch.nn.init.zeros_(module.bias)
      elif isinstance(module, nn.Embedding):
          torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

  def forward(self, images, goals, goal_imgs, targets=None):
    # Dividing images into patches
    n, c, h, w = images.shape
    B, T = goals.shape
    patches = get_patches_fast(images)
    patches_g = get_patches_fast(goal_imgs)
    goals_e = self.token_embedding_table(goals)
    
    # Running linear layer tokenization
    # Map the vector corresponding to each patch to the hidden size dimension
    out = self.lin_map(patches)
    out_g = self.lin_map(patches_g)
    
    # Adding classification and goal_img tokens to the tokens
    out = torch.cat((out, goals_e, out_g, self.class_tokens.expand(n, 1, -1)), dim=1)
    
    # Adding positional embedding
    out = out + self.positional_embeddings.repeat(n, 1, 1)

    ## Compute blocked masks
    mask = torch.ones((c + T + c + 1), device=self._cfg.device) ## (1, T)
    if targets is None:
        pass
    elif (torch.rand(1)[0] > 0.66):  
        mask[c: c+ T] = torch.zeros((1,T), device=self._cfg.device) ## Mask goal string
    elif (torch.rand(1)[0] > 0.33):
        mask[c + T: c + T + c] = torch.zeros((1,c), device=self._cfg.device) ## Mask goal image
        
    # Transformer Blocks
    for block in self.blocks:
        out = block(out, mask)

    # Getting the classification token only
    out = out[:, -1]
    out = self.mlp(out)
        
    if targets is None:
        loss = None
    else:
        B, C = out.shape
        loss = F.mse_loss(out, targets) ## B, C
    return (out, loss)


torch version: 2.5.1+cu124


## Vision Language Action (VLA) Model

This section implements a VLA model that combines:
- **T5 Text Encoder**: For processing natural language instructions
- **CNN Vision Encoder**: For processing current and goal images
- **Pretrained LLM**: For action prediction

In [3]:
## Import necessary libraries for VLA model
# !pip install transformers==4.41.0
from transformers import T5Tokenizer, T5EncoderModel, AutoTokenizer, AutoModelForCausalLM
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers

print("Transformers library imported successfully")
print("transformers version:", transformers.__version__)

Transformers library imported successfully
transformers version: 4.41.0


In [4]:
## CNN Image Encoder for processing visual observations
class CNNImageEncoder(nn.Module):
    """
    Convolutional Neural Network for encoding images into feature vectors.
    Processes both current observation and goal images.
    """
    def __init__(self, image_shape=[64, 64, 3], output_dim=512):
        super(CNNImageEncoder, self).__init__()
        self.image_shape = image_shape
        self.output_dim = output_dim
        
        # Convolutional layers
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1)  # 64x64 -> 32x32
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1) # 32x32 -> 16x16
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1) # 16x16 -> 8x8
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1) # 8x8 -> 4x4
        
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)
        
        # Calculate flattened size: 256 channels * 4 * 4 = 4096
        self.fc = nn.Linear(256 * 4 * 4, output_dim)
        self.dropout = nn.Dropout(0.1)
        
    def forward(self, x):
        # Input: (B, C, H, W) where C=3, H=64, W=64
        # Expect input in range [-1, 1]
        
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        
        # Flatten with reshape
        x = x.reshape(x.size(0), -1)  # (B, 256*4*4)
        #
        # x = x.view(x.size(0), -1)  # (B, 256*4*4)
        
        # Project to output dimension
        x = self.dropout(F.relu(self.fc(x)))
        
        return x  # (B, output_dim)

print("CNN Image Encoder defined")

CNN Image Encoder defined


In [6]:
## Vision Language Action (VLA) Model
class VLAModel(nn.Module):
    """
    Vision Language Action Model that combines:
    - T5 text encoder for processing language instructions
    - CNN vision encoder for processing current and goal images
    - Pretrained LLM backbone for multimodal fusion and action prediction
    """
    def __init__(self, cfg, t5_model_name='google/t5-v1_1-small', llm_hidden_dim=512):
        super(VLAModel, self).__init__()
        self.cfg = cfg
        self.llm_hidden_dim = llm_hidden_dim
        
        # 1. T5 Text Encoder (frozen by default, can be fine-tuned)
        print(f"Loading T5 model: {t5_model_name}")
        self.tokenizer = T5Tokenizer.from_pretrained(t5_model_name)
        self.text_encoder = T5EncoderModel.from_pretrained(t5_model_name)
        self.t5_hidden_dim = self.text_encoder.config.d_model
        
        # Option to freeze T5 parameters
        if cfg.get('freeze_text_encoder', True):
            for param in self.text_encoder.parameters():
                param.requires_grad = False
            print("T5 text encoder frozen")
        
        # 2. CNN Vision Encoder
        self.vision_encoder = CNNImageEncoder(
            image_shape=cfg.image_shape,
            output_dim=llm_hidden_dim
        )
        
        # 3. Projection layers to align dimensions
        self.text_projection = nn.Linear(self.t5_hidden_dim, llm_hidden_dim)
        self.vision_projection = nn.Linear(llm_hidden_dim, llm_hidden_dim)
        
        # 4. Cross-modal fusion transformer
        self.fusion_layers = nn.ModuleList([
            Block(llm_hidden_dim, n_head=cfg.n_head, dropout=cfg.dropout)
            for _ in range(cfg.get('n_fusion_blocks', 4))
        ])
        
        # 5. Action prediction head
        self.action_head = nn.Sequential(
            nn.Linear(llm_hidden_dim, llm_hidden_dim * 2),
            nn.ReLU(),
            nn.Dropout(cfg.dropout),
            nn.Linear(llm_hidden_dim * 2, cfg.action_bins)
        )
        
        # Learnable aggregation token for pooling multimodal features
        self.aggregation_token = nn.Parameter(torch.randn(1, 1, llm_hidden_dim))
        
    def encode_text(self, text_instructions):
        """
        Encode text instructions using T5 encoder
        Args:
            text_instructions: List of text strings
        Returns:
            text_embeddings: (B, T, hidden_dim)
        """
        # Tokenize text
        encoded = self.tokenizer(
            text_instructions,
            padding=True,
            truncation=True,
            max_length=self.cfg.get('max_text_length', 64),
            return_tensors='pt'
        ).to(next(self.text_encoder.parameters()).device)
        
        # Get T5 embeddings
        with torch.no_grad() if self.cfg.get('freeze_text_encoder', True) else torch.enable_grad():
            outputs = self.text_encoder(**encoded)
            text_embeddings = outputs.last_hidden_state  # (B, seq_len, d_model)
        
        # Project to LLM hidden dimension
        text_embeddings = self.text_projection(text_embeddings)  # (B, seq_len, llm_hidden_dim)
        
        return text_embeddings
    
    def encode_vision(self, current_obs, goal_obs):
        """
        Encode current and goal images using CNN
        Args:
            current_obs: (B, C, H, W) current observation images
            goal_obs: (B, C, H, W) goal images
        Returns:
            vision_embeddings: (B, 2, hidden_dim) - 2 tokens for current and goal
        """
        # Encode both images
        current_features = self.vision_encoder(current_obs)  # (B, hidden_dim)
        goal_features = self.vision_encoder(goal_obs)  # (B, hidden_dim)
        
        # Project features
        current_features = self.vision_projection(current_features)
        goal_features = self.vision_projection(goal_features)
        
        # Stack to create sequence: [current, goal]
        vision_embeddings = torch.stack([current_features, goal_features], dim=1)  # (B, 2, hidden_dim)
        
        return vision_embeddings
    
    def forward(self, images, goal_texts, goal_images, targets=None):
        """
        Forward pass
        Args:
            images: (B, C, H, W) current observation images (in [-1, 1])
            goal_texts: List of B text strings describing the goal
            goal_images: (B, C, H, W) goal images (in [-1, 1])
            targets: (B, action_dim) target actions for training
        Returns:
            action_predictions: (B, action_dim)
            loss: scalar loss if targets provided
        """
        batch_size = images.shape[0]
        device = images.device
        
        # 1. Encode text instructions
        text_embeddings = self.encode_text(goal_texts)  # (B, T_text, hidden_dim)
        
        # 2. Encode vision (current + goal images)
        ## Need to update image to channel first
        # images = images.permute(0, 3, 1, 2)  # (B, C, H, W)
        # goal_images = goal_images.permute(0, 3, 1, 2)  # (B, C, H, W)
        vision_embeddings = self.encode_vision(images, goal_images)  # (B, 2, hidden_dim)
        
        # 3. Concatenate all tokens: [text_tokens, vision_tokens, aggregation_token]
        agg_token = self.aggregation_token.expand(batch_size, -1, -1)  # (B, 1, hidden_dim)
        multimodal_tokens = torch.cat([
            text_embeddings,
            vision_embeddings,
            agg_token
        ], dim=1)  # (B, T_text + 2 + 1, hidden_dim)
        
        # 4. Apply fusion transformer layers
        fused_features = multimodal_tokens
        for fusion_layer in self.fusion_layers:
            fused_features = fusion_layer(fused_features)
        
        # 5. Extract aggregation token for action prediction
        action_features = fused_features[:, -1, :]  # (B, hidden_dim)
        
        # 6. Predict actions
        action_predictions = self.action_head(action_features)  # (B, action_dim)
        
        # 7. Compute loss if targets provided
        loss = None
        if targets is not None:
            loss = F.mse_loss(action_predictions, targets)
        
        return action_predictions, loss

print("VLA Model defined successfully")

VLA Model defined successfully


In [8]:
## Helper function to use VLA model instead of GRP
def create_vla_model(cfg, device='cuda'):
    """
    Create and initialize VLA model to replace GRP
    """
    # Add VLA-specific config if not present
    if not hasattr(cfg, 'freeze_text_encoder'):
        cfg.freeze_text_encoder = True  # Freeze T5 by default
    if not hasattr(cfg, 'n_fusion_blocks'):
        cfg.n_fusion_blocks = 4  # Number of fusion transformer blocks
    if not hasattr(cfg, 'max_text_length'):
        cfg.max_text_length = 64  # Max text sequence length
    
    # Create VLA model
    model = VLAModel(
        cfg=cfg,
        t5_model_name='google/t5-v1_1-small',  # Can change to 'base' or 'large'
        llm_hidden_dim=512
    )
    
    model = model.to(device)
    
    # Print model info
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params/1e6:.2f}M")
    print(f"Trainable parameters: {trainable_params/1e6:.2f}M")
    
    return model

## Modified training function for VLA
def train_vla_model(model, dataset, cfg, device='cuda'):
    """
    Training loop for VLA model
    Note: VLA expects text strings, not encoded characters
    """
    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=float(cfg.learning_rate)
    )
    
    for iter in range(cfg.max_iters):
        # Evaluate loss periodically
        if iter % cfg.eval_interval == 0 or iter == cfg.max_iters - 1:
            model.eval()
            with torch.no_grad():
                # Sample validation batch
                indices = np.random.choice(len(dataset["img"]), cfg.batch_size, replace=False)
                val_images = dataset["image_enc"][indices].permute(0, 3, 1, 2)  # (B, C, H, W)
                val_goal_texts = [dataset["goal"][i] for i in indices]
                val_goal_images = dataset["goal_image_enc"][indices].permute(0, 3, 1, 2)
                val_actions = dataset["action_enc"][indices]
                
                _, val_loss = model(val_images, val_goal_texts, val_goal_images, val_actions)
                print(f"step {iter}: val loss {val_loss.item():.4f}")
            model.train()
        
        # Sample training batch
        indices = np.random.choice(len(dataset["img"]), cfg.batch_size, replace=False)
        images = dataset["image_enc"][indices].permute(0, 3, 1, 2)  # (B, C, H, W)
        goal_texts = [dataset["goal"][i] for i in indices]  # List of strings
        goal_images = dataset["goal_image_enc"][indices].permute(0, 3, 1, 2)
        actions = dataset["action_enc"][indices]
        
        # Forward pass
        predictions, loss = model(images, goal_texts, goal_images, actions)
        
        # Backward pass
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        if iter % 100 == 0:
            print(f"step {iter}: train loss {loss.item():.4f}")
    
    return model

print("VLA training utilities defined")

VLA training utilities defined


In [9]:
## Example: Create and test VLA model
## Uncomment to run VLA instead of GRP


# Create VLA model
from box import Box
import yaml

# Load config
with open('./conf/config.yaml', 'r') as f:
    cfg_dict = yaml.safe_load(f)    
cfg = Box(cfg_dict)

# Prepare data encodings
a_std, a_mean = (dataset["action"].std(axis=0) + 0.001) * 1.5, dataset["action"].mean(axis=0)
cfg.action_bins = len(a_mean)
encode_action = lambda af: (((af - a_mean)/(a_std))).astype(np.float32)
encode_state = lambda af: ((af/(255.0)*2.0)-1.0).astype(np.float32)

# Encode dataset for VLA
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dataset["image_enc"] = torch.tensor(encode_state(dataset["img"])).to(device)
dataset["goal_image_enc"] = torch.tensor(encode_state(dataset["goal_img"])).to(device)
dataset["action_enc"] = torch.tensor(encode_action(dataset["action"]), dtype=torch.float).to(device)

# Initialize VLA model
vla_model = create_vla_model(cfg, device=device)

# Test forward pass with a small batch
test_batch_size = 4
test_indices = np.random.choice(len(dataset["img"]), test_batch_size, replace=False)
test_images = dataset["image_enc"][test_indices].permute(0, 3, 1, 2)
test_goal_texts = [dataset["goal"][i] for i in test_indices]
test_goal_images = dataset["goal_image_enc"][test_indices].permute(0, 3, 1, 2)
test_actions = dataset["action_enc"][test_indices]

print(f"Test batch shapes:")
print(f"  Images: {test_images.shape}")
print(f"  Goal texts: {len(test_goal_texts)} strings")
print(f"  Goal images: {test_goal_images.shape}")
print(f"  Actions: {test_actions.shape}")

# Forward pass
predictions, loss = vla_model(test_images, test_goal_texts, test_goal_images, test_actions)
print(f"\\nVLA Model forward pass successful!")
print(f"  Predictions shape: {predictions.shape}")
print(f"  Loss: {loss.item():.4f}")

# Train VLA model (uncomment to train)
vla_model = train_vla_model(vla_model, dataset, cfg, device=device)


print("VLA example code ready (uncomment to run)")

Loading T5 model: google/t5-v1_1-small


You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


T5 text encoder frozen
Total parameters: 51.48M
Trainable parameters: 16.15M
Test batch shapes:
  Images: torch.Size([4, 3, 64, 64])
  Goal texts: 4 strings
  Goal images: torch.Size([4, 3, 64, 64])
  Actions: torch.Size([4, 7])
\nVLA Model forward pass successful!
  Predictions shape: torch.Size([4, 7])
  Loss: 0.2945
step 0: val loss 0.6108
step 0: train loss 0.4853
step 100: val loss 0.4716
step 100: train loss 0.3479
step 200: val loss 0.2215
step 200: train loss 0.1673
step 300: val loss 0.1098
step 300: train loss 0.1197
step 400: val loss 0.0907
step 400: train loss 0.0824
Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/home/mila/g/glen.berseth/.conda/envs/roble/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_3222185/551402741.py", line 50, in <module>
    vla_model = train_vla_model(vla_model, dataset, cfg, device=device)
  File "/tmp/ipykernel_3222185/1077614655.py", line 70, in train_vla_model
    loss.backward()
  File "/home/mila/g/glen.berseth/.conda/envs/roble/lib/python3.10/site-packages/torch/_tensor.py", line 581, in backward
    torch.autograd.backward(
  File "/home/mila/g/glen.berseth/.conda/envs/roble/lib/python3.10/site-packages/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
  File "/home/mila/g/glen.berseth/.conda/envs/roble/lib/python3.10/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run 

## Key Differences: VLA vs GRP

### VLA Model Architecture:
1. **Text Encoding**: Uses pretrained T5 encoder instead of character-level embeddings
   - Better semantic understanding of instructions
   - Transfer learning from T5's pretraining
   - Can be frozen to save compute

2. **Vision Encoding**: Uses CNN instead of patch-based ViT approach
   - More parameter efficient
   - Better for smaller image sizes (64x64)
   - Separate encoding of current and goal images

3. **Multimodal Fusion**: Uses cross-attention transformer blocks
   - Learns interactions between vision and language
   - Aggregation token for pooling multimodal information
   - More flexible than simple concatenation

4. **Action Prediction**: MLP head on fused features
   - Direct regression to continuous actions
   - No vocabulary binning needed

### Advantages:
- ✅ Leverages pretrained language models (T5)
- ✅ More parameter efficient with frozen encoders
- ✅ Better language understanding
- ✅ Modular design (easy to swap components)
- ✅ Can handle variable-length text naturally

### Usage:
Replace `model = GRP(dataset, cfg)` with `model = create_vla_model(cfg)` in the training loop.

# A VLA from a VLM

Finally, recent models are basing their structures off of Vision Language Models.

- The challenges with using an LLM and having to decode the action space and figure out how to put in the images and which token in the language space to replace with the new image tokens. These don't scale well.

An example of [Pi0](https://www.pi.website/blog/pi0).

In [18]:
## PaliGemma3B-based VLA model (image + goal image + text -> continuous actions)
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

class PaliGemmaVLA(nn.Module):
    """
    VLA model using PaliGemma backbone with a separate action head.
    It fuses text + concatenated current/goal images, then predicts continuous actions.
    """
    def __init__(self, cfg, model_name='google/paligemma-3b-pt-224', freeze_backbone=True):
        super().__init__()
        self.cfg = cfg
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.backbone = PaliGemmaForConditionalGeneration.from_pretrained(model_name)
        if freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad = False
        hidden_size = self.backbone.config.hidden_size
        self.action_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 2),
            nn.ReLU(),
            nn.Dropout(cfg.dropout),
            nn.Linear(hidden_size * 2, cfg.action_bins),
        )

    def _prepare_images(self, images, goal_images):
        # Convert [-1, 1] float tensors to uint8 HWC and concatenate along width.
        images_u8 = ((images + 1.0) * 127.5).clamp(0, 255).to(torch.uint8)
        goals_u8 = ((goal_images + 1.0) * 127.5).clamp(0, 255).to(torch.uint8)
        images_u8 = images_u8.permute(0, 2, 3, 1).cpu().numpy()
        goals_u8 = goals_u8.permute(0, 2, 3, 1).cpu().numpy()
        combined = [np.concatenate([img, goal], axis=1) for img, goal in zip(images_u8, goals_u8)]
        return combined

    def forward(self, images, goal_texts, goal_images, targets=None):
        combined_images = self._prepare_images(images, goal_images)
        inputs = self.processor(
            text=goal_texts,
            images=combined_images,
            return_tensors='pt',
            padding=True,
        ).to(self.backbone.device)
        outputs = self.backbone(**inputs, output_hidden_states=True, return_dict=True)
        last_hidden = outputs.hidden_states[-1]  # (B, seq_len, hidden)
        pooled = last_hidden[:, -1, :]
        action_predictions = self.action_head(pooled)
        loss = None
        if targets is not None:
            loss = F.mse_loss(action_predictions, targets)
        return action_predictions, loss

def create_paligemma_vla_model(cfg, device='cuda', freeze_backbone=True):
    if not hasattr(cfg, 'max_text_length'):
        cfg.max_text_length = 64
    model = PaliGemmaVLA(
        cfg=cfg,
        model_name='google/paligemma-3b-pt-224',
        freeze_backbone=freeze_backbone,
    )
    model = model.to(device)
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params/1e6:.2f}M")
    print(f"Trainable parameters: {trainable_params/1e6:.2f}M")
    return model

def train_paligemma_vla_model(model, dataset, cfg, device='cuda'):
    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=float(cfg.learning_rate),
    )
    for iter in range(cfg.max_iters):
        if iter % cfg.eval_interval == 0 or iter == cfg.max_iters - 1:
            model.eval()
            with torch.no_grad():
                indices = np.random.choice(len(dataset["img"]), cfg.batch_size, replace=False)
                val_images = dataset["image_enc"][indices].permute(0, 3, 1, 2)
                val_goal_texts = [dataset["goal"][i] for i in indices]
                val_goal_images = dataset["goal_image_enc"][indices].permute(0, 3, 1, 2)
                val_actions = dataset["action_enc"][indices]
                _, val_loss = model(val_images, val_goal_texts, val_goal_images, val_actions)
                print(f"step {iter}: val loss {val_loss.item():.4f}")
            model.train()
        indices = np.random.choice(len(dataset["img"]), cfg.batch_size, replace=False)
        images = dataset["image_enc"][indices].permute(0, 3, 1, 2)
        goal_texts = [dataset["goal"][i] for i in indices]
        goal_images = dataset["goal_image_enc"][indices].permute(0, 3, 1, 2)
        actions = dataset["action_enc"][indices]
        predictions, loss = model(images, goal_texts, goal_images, actions)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if iter % 100 == 0:
            print(f"step {iter}: train loss {loss.item():.4f}")
    return model

print("PaliGemma VLA utilities defined")

PaliGemma VLA utilities defined


In [20]:
## Example: Create and test VLA model
## Uncomment to run VLA instead of GRP


# Create VLA model
from box import Box
import yaml

# Load config
with open('./conf/config.yaml', 'r') as f:
    cfg_dict = yaml.safe_load(f)    
cfg = Box(cfg_dict)

# Prepare data encodings
a_std, a_mean = (dataset["action"].std(axis=0) + 0.001) * 1.5, dataset["action"].mean(axis=0)
cfg.action_bins = len(a_mean)
encode_action = lambda af: (((af - a_mean)/(a_std))).astype(np.float32)
encode_state = lambda af: ((af/(255.0)*2.0)-1.0).astype(np.float32)

# Encode dataset for VLA
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dataset["image_enc"] = torch.tensor(encode_state(dataset["img"])).to(device)
dataset["goal_image_enc"] = torch.tensor(encode_state(dataset["goal_img"])).to(device)
dataset["action_enc"] = torch.tensor(encode_action(dataset["action"]), dtype=torch.float).to(device)

# Initialize VLA model
vla_model = create_paligemma_vla_model(cfg, device=device)

# Test forward pass with a small batch
test_batch_size = 4
test_indices = np.random.choice(len(dataset["img"]), test_batch_size, replace=False)
test_images = dataset["image_enc"][test_indices].permute(0, 3, 1, 2)
test_goal_texts = [dataset["goal"][i] for i in test_indices]
test_goal_images = dataset["goal_image_enc"][test_indices].permute(0, 3, 1, 2)
test_actions = dataset["action_enc"][test_indices]

print(f"Test batch shapes:")
print(f"  Images: {test_images.shape}")
print(f"  Goal texts: {len(test_goal_texts)} strings")
print(f"  Goal images: {test_goal_images.shape}")
print(f"  Actions: {test_actions.shape}")

# Forward pass
predictions, loss = vla_model(test_images, test_goal_texts, test_goal_images, test_actions)
print(f"  VLA Model forward pass successful!")
print(f"  Predictions shape: {predictions.shape}")
print(f"  Loss: {loss.item():.4f}")

# Train VLA model (uncomment to train)
vla_model = train_paligemma_vla_model(vla_model, dataset, cfg, device=device)


print("VLA example code ready (uncomment to run)")

Loading checkpoint shards: 100%|██████████| 3/3 [00:04<00:00,  1.46s/it]
You're using a GemmaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Total parameters: 2931.89M
Trainable parameters: 8.42M
Test batch shapes:
  Images: torch.Size([4, 3, 64, 64])
  Goal texts: 4 strings
  Goal images: torch.Size([4, 3, 64, 64])
  Actions: torch.Size([4, 7])
  VLA Model forward pass successful!
  Predictions shape: torch.Size([4, 7])
  Loss: 0.7738
Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/home/mila/g/glen.berseth/.conda/envs/roble/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_1859699/988130187.py", line 50, in <module>
    vla_model = train_paligemma_vla_model(vla_model, dataset, cfg, device=device)
  File "/tmp/ipykernel_1859699/3479941056.py", line 84, in train_paligemma_vla_model
    _, val_loss = model(val_images, val_goal_texts, val_goal_images, val_actions)
  File "/home/mila/g/glen.berseth/.conda/envs/roble/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/mila/g/glen.berseth/.conda/envs/roble/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/tmp/ipykernel_1859699/3479941056.py", line 46, in forward
    outputs = self.backbon

In [None]:
# pip install python-boximport hydra, json

def my_main():
    from hydra import compose, initialize
    # initialize(config_path="./conf", job_name="test_app")
    from box import Box
    import yaml
    with open('./conf/config.yaml', 'r') as f:
        cfg_dict = yaml.safe_load(f)    
    cfg = Box(cfg_dict)
    torch.manual_seed(cfg.r_seed)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("Using device: ", device, f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "")

    cfg.block_size = shortest_goal_txt = min([len(txt) for txt in dataset["goal"]])

    # here are all the unique characters that occur in this text
    chars = sorted(list(set([item for row in dataset["goal"] for item in row]))) ## Flatten to a long string
    cfg.vocab_size = len(chars)
    # create a mapping from characters to integers
    stoi = { ch:i for i,ch in enumerate(chars) }
    itos = { i:ch for i,ch in enumerate(chars) }
    encode_txt = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
    decode_txy = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string
    print("vocab_size:", cfg.vocab_size)
    print("example text encode:", encode_txt(dataset["goal"][0]))

    a_std, a_mean = (dataset["action"].std(axis=0) + 0.001) * 1.5, dataset["action"].mean(axis=0)
    cfg.action_bins = len(a_mean)
    encode_action = lambda af:   (((af - a_mean)/(a_std))).astype(np.float32) # encoder: take a float, output an integer

    ## Get the actions and encode them to map to [-1, 1]
    encode_state = lambda af:   ((af/(255.0)*2.0)-1.0).astype(np.float32) # encoder: take a float, output an integer
    resize_state = lambda sf:   cv2.resize(np.array(sf, dtype=np.float32), (cfg.image_shape[0], cfg.image_shape[1]))  # resize state
    decode_action = lambda binN: (binN * a_std) + a_mean  # Undo mapping to [-1, 1]

    dataset["goal_enc"] = torch.tensor([encode_txt(goal[:cfg.block_size]) for goal in dataset["goal"]]).to(device)
    dataset["image_enc"] = torch.tensor(encode_state(dataset["img"])).to(device)
    dataset["goal_image_enc"] = torch.tensor(encode_state(dataset["goal_img"])).to(device)
    dataset["action_enc"] = torch.tensor(encode_action(dataset["action"]), dtype=torch.float).to(device)

    print("Dataset shape:", len(dataset["img"]))

    model = GRP(dataset, cfg)
    m = model.to(device)
    # print the number of parameters in the model
    print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

    # create a PyTorch optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=float(cfg.learning_rate))

    for iter in range(cfg.max_iters):

        # every once in a while evaluate the loss on train and val sets
        if iter % cfg.eval_interval == 0 or iter == cfg.max_iters - 1:
            losses = estimate_loss(model)
            print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

        # sample a batch of data
        xb, xg, xgi, yb = get_batch_grp('train', dataset, cfg.batch_size)

        # evaluate the loss
        logits, loss = model(xb, xg, xgi, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

if __name__ == "__main__":
    results = my_main()
    print("results:", results)