# Data Preprocessing

> Functions and definitions for preprocessing steps, including normalization stats, tokenization, and template formatting.

In [1]:
#| default_exp data.preprocessing

In [2]:
#| hide
from nbdev.showdoc import *
PARENT_PATH = 

In [8]:
#| export
import sys
from pathlib import Path
import os

# Assumes the notebook is run from the project root or one level down (e.g., nbs/)
# Navigate up to the project root (where settings.ini or .git likely exists)
project_root = Path(os.getcwd())
# Simple check: If settings.ini is not in cwd, assume we are in nbs/ and go up one level
if not (project_root / 'settings.ini').exists() and (project_root.parent / 'settings.ini').exists():
    project_root = project_root.parent

project_root_str = str(project_root.resolve())

if project_root_str not in sys.path:
    print(f"Adding project root to sys.path: {project_root_str}")
    sys.path.insert(0, project_root_str)
else:
    print(f"Project root already in sys.path: {project_root_str}")

Adding project root to sys.path: /workspace/llava


In [37]:
#| export
from transformers import AutoProcessor, AutoTokenizer, AutoImageProcessor
from fastai.vision.augment import Normalize
from fastai.vision.all import *
from fastai.text.all import *
from fastai.data.transforms import Transform
from fastai.torch_core import TensorBase, tensor
import torch
from typing import List, Dict, Union, Tuple, Any

from llava.utils import load_config

## Step 1.2 (Continued): Image Normalization Setup

Load the CLIP image processor to get the correct normalization statistics (mean and standard deviation) required for the vision encoder.

In [10]:
#| export
# Load config to get model name
CONFIG_PATH = 'configs/config.yaml'
try:
    config = load_config(CONFIG_PATH)
    VISION_ENCODER_NAME = config['model']['vision_encoder_name_or_path']
    LLM_NAME = config['model']['llm_name_or_path'] # Added for tokenizer
except FileNotFoundError:
    print(f"Warning: Config file not found at {CONFIG_PATH}. Using default model names.")
    VISION_ENCODER_NAME = 'openai/clip-vit-large-patch14-336' # Fallback
    LLM_NAME = 'lmsys/vicuna-7b-v1.5' # Fallback
except KeyError as e:
    print(f"Warning: Key {e} not found in {CONFIG_PATH}. Using defaults.")
    VISION_ENCODER_NAME = config.get('model', {}).get('vision_encoder_name_or_path', 'openai/clip-vit-large-patch14-336')
    LLM_NAME = config.get('model', {}).get('llm_name_or_path', 'lmsys/vicuna-7b-v1.5')

# Load the CLIP image processor
try:
    clip_image_processor = AutoImageProcessor.from_pretrained(VISION_ENCODER_NAME)
    print(f"Successfully loaded CLIP image processor for: {VISION_ENCODER_NAME}")
except Exception as e:
    print(f"Error loading CLIP image processor for {VISION_ENCODER_NAME}: {e}")
    # Handle error appropriately, maybe raise or use default stats
    clip_image_processor = None

# Get normalization stats
if clip_image_processor:
    image_mean = clip_image_processor.image_mean
    image_std = clip_image_processor.image_std
else:
    print("Warning: Using default ImageNet stats as fallback for normalization.")
    # Default fallback (ImageNet stats often used, but CLIP specific is better)
    image_mean = [0.485, 0.456, 0.406]
    image_std = [0.229, 0.224, 0.225]

# Create the fastai Normalize transform using CLIP stats
clip_normalize = Normalize.from_stats(image_mean, image_std)

preprocessor_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Successfully loaded CLIP image processor for: openai/clip-vit-large-patch14-336


In [11]:
# Example: Print the stats and the transform
print(f"CLIP Mean: {image_mean}")
print(f"CLIP Std: {image_std}")
print(f"Fastai Normalize Transform: {clip_normalize}")

CLIP Mean: [0.48145466, 0.4578275, 0.40821073]
CLIP Std: [0.26862954, 0.26130258, 0.27577711]
Fastai Normalize Transform: Normalize -- {'mean': tensor([[[[0.4815]],

         [[0.4578]],

         [[0.4082]]]], device='cuda:0'), 'std': tensor([[[[0.2686]],

         [[0.2613]],

         [[0.2758]]]], device='cuda:0'), 'axes': (0, 2, 3)}
(enc:1,dec:1)


---

## Step 1.3: Text Tokenization and Template Handling (Stage 1 - Plain)

Load the LLM's tokenizer (Vicuna) and define the 'plain' template formatting for Stage 1 pre-training.

In [12]:
#| export
DEFAULT_IMAGE_TOKEN = "<image>" # Placeholder token for image features
IMAGE_TOKEN_INDEX_PLACEHOLDER = -200 # Special marker used internally in input_ids
IGNORE_INDEX = -100 # Standard ignore index for labels in loss calculation

# Load the Vicuna tokenizer
try:
    tokenizer = AutoTokenizer.from_pretrained(
        LLM_NAME,
        model_max_length=config.get('data', {}).get('tokenizer_model_max_length', 2048),
        padding_side=config.get('data', {}).get('tokenizer_padding_side', 'right'),
        use_fast=True,
    )
    print(f"Successfully loaded tokenizer for: {LLM_NAME}")

    # Add the image token as a special token, if it's not already there.
    # This ensures the tokenizer knows about it, even though we replace its embedding later.
    # Avoids warnings and potential token splitting issues.
    # Check if the token already exists
    current_vocab = tokenizer.get_vocab()
    if DEFAULT_IMAGE_TOKEN not in current_vocab:
        print(f"Adding special token {DEFAULT_IMAGE_TOKEN} to tokenizer.")
        # Add the token as a *special* token. This typically means it won't be split
        # and can be handled distinctly. `additional_special_tokens` is a common way.
        num_added = tokenizer.add_special_tokens({'additional_special_tokens': [DEFAULT_IMAGE_TOKEN]})
        if num_added > 0:
            print(f"Added {num_added} token(s). New vocab size: {len(tokenizer)}")
            # NOTE: If a token is truly added (increasing vocab size),
            # the model's embedding layer MUST be resized later.
            # For LLaVA, we usually map it to an existing ID or rely on the -200 replacement,
            # but adding it as a special token helps tokenizer processing.
    # Retrieve the assigned ID
    IMAGE_TOKEN_ID = tokenizer.convert_tokens_to_ids(DEFAULT_IMAGE_TOKEN)
    print(f"Using token ID for {DEFAULT_IMAGE_TOKEN}: {IMAGE_TOKEN_ID}")
    # Handle case where it might still be UNK if not properly added/found
    if IMAGE_TOKEN_ID == tokenizer.unk_token_id:
         print(f"Warning: {DEFAULT_IMAGE_TOKEN} resolved to UNK token ID ({tokenizer.unk_token_id}). Check tokenizer setup.")


    # Set pad token if missing (common for LLaMA models)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token # Option 1: Use EOS
        # tokenizer.add_special_tokens({'pad_token': '[PAD]'}) # Option 2: Add a new pad token
        print(f"Using EOS token as pad token: {tokenizer.pad_token} (ID: {tokenizer.pad_token_id})")
        # Note: If adding a new pad token, the model's embedding layer needs resizing later!
    else:
        print(f"Tokenizer already has pad token: {tokenizer.pad_token} (ID: {tokenizer.pad_token_id})")

except Exception as e:
    print(f"Error loading tokenizer for {LLM_NAME}: {e}")
    tokenizer = None
    IMAGE_TOKEN_ID = None

tokenizer_config.json:   0%|          | 0.00/749 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/438 [00:00<?, ?B/s]

Successfully loaded tokenizer for: lmsys/vicuna-7b-v1.5
Adding special token <image> to tokenizer.
Added 1 token(s). New vocab size: 32001
Using token ID for <image>: 32000
Tokenizer already has pad token: <unk> (ID: 0)


In [13]:
type(tokenizer)

transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast

In [15]:
#| export
def format_plain_template(conversations: List[Dict[str, str]]) -> str:
    """Formats conversations using the 'plain' template for Stage 1 pre-training.

    The 'plain' template uses the format: <image>\n{caption}\n
    where {caption} is the value of the first 'gpt' turn.

    Args:
        conversations: A list of conversation turns (dictionaries with 'from' and 'value').

    Returns:
        The formatted string. Returns just the image token if no 'gpt' turn is found.
    """
    caption = "" # Default to empty caption if no 'gpt' turn found
    for turn in conversations:
        if turn.get('from', '').lower() == 'gpt':
            caption = turn.get('value', '')
            break # Use the first GPT response as the caption

    # Ensure the <image> token is always first, followed by newline and caption
    # Remove any existing <image> token from caption to avoid duplicates
    caption = caption.replace(DEFAULT_IMAGE_TOKEN, '').strip()

    # Return formatted string, handle empty caption case
    # Add final newline consistent with some implementations?
    # LLaVA reference often has <image>\n{caption}
    # Let's stick to <image>\n{caption} for now, stripping trailing whitespace.
    formatted = f"{DEFAULT_IMAGE_TOKEN}\n{caption}".strip() if caption else f"{DEFAULT_IMAGE_TOKEN}"
    return formatted

In [16]:
show_doc(format_plain_template)

---

### format_plain_template

>      format_plain_template (conversations:List[Dict[str,str]])

*Formats conversations using the 'plain' template for Stage 1 pre-training.

    The 'plain' template uses the format: <image>
{caption}

    where {caption} is the value of the first 'gpt' turn.

    Args:
        conversations: A list of conversation turns (dictionaries with 'from' and 'value').

    Returns:
        The formatted string. Returns just the image token if no 'gpt' turn is found.*

#### Example Usage & Test (Template Formatting)

In [17]:
conv1 = [{'from': 'human', 'value': '<image>\nDescribe.'}, {'from': 'gpt', 'value': 'This is the caption.'}]
conv2 = [{'from': 'human', 'value': 'Describe.'}, {'from': 'gpt', 'value': '<image>Caption with image token removed.'}]
conv3 = [{'from': 'human', 'value': '<image>\nDescribe.'}]
conv4 = []

print(f"Test Case 1 (Standard): {format_plain_template(conv1)}")
print(f"Test Case 2 (Image token in caption): {format_plain_template(conv2)}")
print(f"Test Case 3 (No GPT turn): {format_plain_template(conv3)}")
print(f"Test Case 4 (Empty conversation): {format_plain_template(conv4)}")

Test Case 1 (Standard): <image>
This is the caption.
Test Case 2 (Image token in caption): <image>
Caption with image token removed.
Test Case 3 (No GPT turn): <image>
Test Case 4 (Empty conversation): <image>


In [22]:
#| export
class LLaVATextTokenizer(Transform):
    """A fastai Transform to format and tokenize text data for LLaVA stage 1.

    Applies the 'plain' template formatting and then tokenizes the text, returning only the input IDs.
    """
    def __init__(self, tokenizer, template_formatter=format_plain_template):
        store_attr()
        if self.tokenizer is None:
            raise ValueError("Tokenizer must be provided and loaded successfully.")

    def encodes(self, conversations: list) -> list:
        """Applies formatting and tokenization to conversation data.

        Args:
            conversations: Raw conversation list from the dataset sample.

        Returns:
            A list of input token IDs.
        """
        formatted_text = self.template_formatter(conversations)
        # Tokenize the formatted text.
        # We don't pad here; padding is done at the batch level.
        tokenized_output = self.tokenizer(formatted_text,
                                         return_tensors=None, # Let batch collation handle tensor conversion + padding
                                         add_special_tokens=True, # Add BOS/EOS if tokenizer configured to do so
                                         truncation=False # Truncation can be done later if needed
                                        )
        # Return just the input_ids list for DataBlock item_tfms
        # We use TensorBase here so fastai recognizes it as tensor-like for collation
        # return TensorBase(tokenized_output['input_ids'])
        # Update: Returning raw list might be simpler, collation handles tensor conversion
        return tokenized_output['input_ids']

#### Example Usage & Test (Tokenizer Transform)

In [23]:
if tokenizer:
    # Example conversation
    example_conv = [{'from': 'human', 'value': '<image>\nDescribe this.'}, {'from': 'gpt', 'value': 'It is red.'}]

    # Create the transform instance
    llava_tokenizer_tfm = LLaVATextTokenizer(tokenizer)

    # Apply the transform
    tokenized_ids = llava_tokenizer_tfm(example_conv)

    print(f"Original Conversations: {example_conv}")
    # Re-format to show what was tokenized
    formatted = format_plain_template(example_conv)
    print(f"Formatted Text: {formatted}")
    print(f"Tokenized Output (Input IDs): {tokenized_ids}")

    # Decode for verification
    decoded_tokens = tokenizer.convert_ids_to_tokens(tokenized_ids)
    print(f"Decoded Tokens: {decoded_tokens}")

    # Ensure the output is a list of integers
    assert isinstance(tokenized_ids, list)
    assert all(isinstance(x, int) for x in tokenized_ids)
else:
    print("Tokenizer not loaded, skipping tokenizer transform test.")

Original Conversations: [{'from': 'human', 'value': '<image>\nDescribe this.'}, {'from': 'gpt', 'value': 'It is red.'}]
Formatted Text: <image>
It is red.
Tokenized Output (Input IDs): [1, 32000, 13, 3112, 338, 2654, 29889]
Decoded Tokens: ['<s>', '<image>', '<0x0A>', 'It', '▁is', '▁red', '.']


---

## Step 1.5: Implement Custom Batch Transform / Collate Function (Stage 1)

This transform handles batch-level operations: padding, masking, image token replacement, and normalization.

In [44]:
#| export
class LLaVABatchTransform(Transform):
    """ Custom batch transform for LLaVA stage 1.
        Handles image normalization, text padding reconstruction, attention mask creation,
        image token marker replacement, label creation, and label masking for the 'plain' template.
        Operates on the default collated batch tuple (images, list_of_positional_token_tensors).
    """
    def __init__(self, tokenizer, normalize_tfm: Normalize, image_token_id=None):
        store_attr() # Stores tokenizer, normalize_tfm
        if self.tokenizer is None:
            raise ValueError("Tokenizer must be provided.")

        if image_token_id is None:
            self.image_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_IMAGE_TOKEN)
            if self.image_token_id == tokenizer.unk_token_id:
                print(f"Warning: {DEFAULT_IMAGE_TOKEN} not found in tokenizer vocab. Using UNK ID: {self.image_token_id}")
        else:
            self.image_token_id = image_token_id

        self.pad_token_id = tokenizer.pad_token_id
        if self.pad_token_id is None:
            print(f"Warning: Tokenizer has no pad_token_id. Assuming fastai default padding (often 0) or using eos_token_id ({self.tokenizer.eos_token_id}) if specified.")
            # We will reconstruct the mask based on the actual pad ID used by fastai/tokenizer
            # Let's assume the tokenizer's pad_token_id IS defined now after previous setup steps
            # If not, add error handling here.
            if self.tokenizer.pad_token_id is None:
                 raise ValueError("Tokenizer must have a pad_token_id after setup.")
            self.pad_token_id = self.tokenizer.pad_token_id

        print(f"LLaVABatchTransform initialized. Image Token ID: {self.image_token_id}, Pad Token ID: {self.pad_token_id}")

    def encodes(self, collated_batch: tuple) -> dict:
        """Applies normalization, reconstructs padded tensors, applies masking.

        Args:
            collated_batch: A tuple containing:
                (collated_image_tensors: torch.Tensor,
                 list_of_positional_token_tensors: List[torch.Tensor])
                as produced by default fastai collation for sequences.

        Returns:
            A dictionary containing the fully processed batch ready for the model.
        """
        # 1. Unpack the collated batch
        if not isinstance(collated_batch, tuple) or len(collated_batch) != 2:
             print(f"Warning: LLaVABatchTransform received unexpected input type: {type(collated_batch)}. Skipping.")
             return collated_batch

        collated_images, list_of_positional_tensors = collated_batch

        # Ensure correct types
        if not isinstance(collated_images, torch.Tensor):
             raise TypeError(f"Expected first element of collated batch to be a Tensor, got {type(collated_images)}")
        if not isinstance(list_of_positional_tensors, list) or not all(isinstance(t, torch.Tensor) for t in list_of_positional_tensors):
             raise TypeError(f"Expected second element of collated batch to be a list of Tensors, got {type(list_of_positional_tensors)}")
        if not list_of_positional_tensors: # Handle empty batch case if it occurs
            print("Warning: Received empty list of positional tensors.")
            # Return empty dict or handle appropriately
            return {}


        # 2. Normalize images
        normalized_images = self.normalize_tfm(collated_images) # Use the stored normalize_tfm

        # 3. Reconstruct padded input_ids tensor from the list of positional tensors
        # Stack along a new dimension (dim=0) -> (seq_len, batch_size)
        # Transpose to get (batch_size, seq_len)
        try:
            input_ids = torch.stack(list_of_positional_tensors, dim=0).T
        except RuntimeError as e:
             print("Error stacking positional tensors. Check consistency.")
             # Print shapes for debugging
             for i, t in enumerate(list_of_positional_tensors): print(f"Tensor {i} shape: {t.shape}")
             raise e


        # 4. Create attention mask based on the reconstructed input_ids and pad_token_id
        # Assumes fa_collate used the tokenizer's pad_token_id (or default 0 if none specified - check!)
        attention_mask = (input_ids != self.pad_token_id).long()

        # 5. Create labels by cloning input_ids BEFORE replacement
        labels = input_ids.clone()

        # 6. Find image token ID and replace with placeholder -200 in input_ids
        input_ids[input_ids == self.image_token_id] = IMAGE_TOKEN_INDEX_PLACEHOLDER

        # 7. Apply label masking for 'plain' template
        for i in range(labels.shape[0]):
            image_token_indices = torch.where(labels[i] == self.image_token_id)[0]
            if len(image_token_indices) > 0:
                image_token_idx = image_token_indices[0].item()
                mask_until_idx = image_token_idx + 1
                labels[i, :mask_until_idx] = IGNORE_INDEX
            else:
                print(f"Warning: Image token ID {self.image_token_id} not found in labels for sample {i}. Masking all.")
                labels[i, :] = IGNORE_INDEX

            # Also mask padding tokens based on the attention mask we created
            labels[i][attention_mask[i] == 0] = IGNORE_INDEX

            # Specific check: Mask the BOS token if present
            if labels.shape[1] > 0 and labels[i, 0] == self.tokenizer.bos_token_id:
                labels[i, 0] = IGNORE_INDEX

        # 8. Return the prepared batch as a dictionary
        return {
            'pixel_values': normalized_images,
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }

    # decodes method remains the same as the last fix
    def decodes(self, batch: dict) -> tuple:
        """Decodes a batch dictionary back into a tuple of (images, texts)."""
        decoded_images = []
        decoded_texts = []
        if not isinstance(batch, dict) or 'pixel_values' not in batch or 'input_ids' not in batch:
             print(f"Decode expected dict with 'pixel_values' and 'input_ids', got {type(batch)}")
             return ([], []) # Return empty tuple of lists

        imgs = batch['pixel_values']
        ids = batch['input_ids']
        bs = imgs.shape[0]

        for i in range(bs):
            # Decode image
            img_decoded = self.normalize_tfm.decode(imgs[i].unsqueeze(0).cpu())[0]

            # Decode text
            ids_i = ids[i].clone()
            ids_i[ids_i == IMAGE_TOKEN_INDEX_PLACEHOLDER] = self.image_token_id
            # Use the *known* pad token ID for filtering
            actual_ids = ids_i[ids_i != self.pad_token_id].tolist()
            text_decoded = self.tokenizer.decode(actual_ids, skip_special_tokens=True)

            decoded_images.append(img_decoded)
            decoded_texts.append(TitledStr(text_decoded))

        return (decoded_images, decoded_texts)

# Make the transform usable in fastai pipelines
LLaVABatchTransform.split_idx = None

#### Example Usage & Test (Batch Transform)

In [46]:
#| eval: false
if tokenizer and 'llava_tokenizer_tfm' in locals(): # Ensure tokenizer and text transform are available
    # 1. Create dummy item samples (output of item transforms)
    try:
        sample1_img = torch.rand(3, 336, 336) # Dummy image tensor
        sample1_text_ids = llava_tokenizer_tfm([{'from': 'human', 'value': '<image>'}, {'from': 'gpt', 'value': 'It is red.'}])
        sample1 = (sample1_img, sample1_text_ids)

        sample2_img = torch.rand(3, 336, 336)
        # Example with slightly longer text
        sample2_text_ids = llava_tokenizer_tfm([{'from': 'human', 'value': '<image>'}, {'from': 'gpt', 'value': 'Is it green?'}])
        sample2 = (sample2_img, sample2_text_ids)

        dummy_samples = [sample1, sample2]

        # --- Simulate Collation ---
        # This is what fastai's default collate function would roughly do
        collated_images = torch.stack([s[0] for s in dummy_samples])
        list_of_token_ids_lists = [s[1] for s in dummy_samples]
        simulated_collated_batch = (collated_images, list_of_token_ids_lists)
        # --------------------------

        # 2. Instantiate the batch transform
        batch_transform = LLaVABatchTransform(tokenizer, normalize_tfm=clip_normalize)

        # 3. Apply the transform to the *simulated collated batch*
        processed_batch = batch_transform(simulated_collated_batch)

        # 4. Inspect the output
        print("--- Input Samples (Individual Items) ---")
        for i, s in enumerate(dummy_samples):
            print(f"Sample {i}:")
            print(f"  Image shape: {s[0].shape}")
            print(f"  Token IDs: {s[1]}")

        print("\n--- Simulated Collated Batch (Input to Transform) ---")
        print(f"Images Tensor Shape: {simulated_collated_batch[0].shape}")
        print(f"List of Token Lists Length: {len(simulated_collated_batch[1])}")

        print("\n--- Processed Batch (Output of Transform) ---")
        if processed_batch and isinstance(processed_batch, dict): # Check if transform returned a dict
            print(f"Pixel Values Shape: {processed_batch.get('pixel_values', torch.Tensor()).shape}")
            print(f"Input IDs Shape: {processed_batch.get('input_ids', torch.Tensor()).shape}")
            print(f"Input IDs:\n{processed_batch.get('input_ids')}")
            print(f"Attention Mask Shape: {processed_batch.get('attention_mask', torch.Tensor()).shape}")
            print(f"Attention Mask:\n{processed_batch.get('attention_mask')}")
            print(f"Labels Shape: {processed_batch.get('labels', torch.Tensor()).shape}")
            print(f"Labels:\n{processed_batch.get('labels')}")

            print("\n--- Decoded Labels (showing non-masked tokens) ---")
            if 'labels' in processed_batch:
                for i in range(processed_batch['labels'].shape[0]):
                    label_ids = processed_batch['labels'][i]
                    valid_label_ids = label_ids[label_ids != IGNORE_INDEX].tolist()
                    decoded_labels = tokenizer.convert_ids_to_tokens(valid_label_ids)
                    cleaned_labels = [t for t in decoded_labels if t not in tokenizer.all_special_tokens]
                    print(f"Sample {i} Labels: {cleaned_labels}")
            else:
                print("Labels key missing in processed batch.")


            # Basic Assertions
            assert 'pixel_values' in processed_batch
            assert 'input_ids' in processed_batch
            assert 'attention_mask' in processed_batch
            assert 'labels' in processed_batch
            assert processed_batch['input_ids'].shape == processed_batch['attention_mask'].shape == processed_batch['labels'].shape
            # Check if image token replacement happened
            assert torch.any(processed_batch['input_ids'] == IMAGE_TOKEN_INDEX_PLACEHOLDER)
            # Check if labels are masked
            assert torch.any(processed_batch['labels'] == IGNORE_INDEX)
        else:
            print("Batch processing failed or returned unexpected type.")
            print(f"Output: {processed_batch}")

    except Exception as e:
        import traceback
        print("\n--- ERROR DURING TEST ---")
        traceback.print_exc()

else:
    print("Tokenizer not loaded or llava_tokenizer_tfm not defined, skipping batch transform test.")

LLaVABatchTransform initialized. Image Token ID: 32000, Pad Token ID: 0
--- Input Samples (Individual Items) ---
Sample 0:
  Image shape: torch.Size([3, 336, 336])
  Token IDs: [1, 32000, 13, 3112, 338, 2654, 29889]
Sample 1:
  Image shape: torch.Size([3, 336, 336])
  Token IDs: [1, 32000, 13, 3624, 372, 7933, 29973]

--- Simulated Collated Batch (Input to Transform) ---
Images Tensor Shape: torch.Size([2, 3, 336, 336])
List of Token Lists Length: 2

--- Processed Batch (Output of Transform) ---
Batch processing failed or returned unexpected type.
Output: (tensor([[[[0.8321, 0.0670, 0.9668,  ..., 0.9698, 0.0592, 0.2165],
          [0.5144, 0.4863, 0.4180,  ..., 0.1336, 0.7944, 0.5811],
          [0.1692, 0.7337, 0.7632,  ..., 0.6975, 0.9847, 0.7347],
          ...,
          [0.7079, 0.1073, 0.6360,  ..., 0.4958, 0.9614, 0.1172],
          [0.6989, 0.4944, 0.5440,  ..., 0.0147, 0.0686, 0.7959],
          [0.6766, 0.0987, 0.3949,  ..., 0.7271, 0.8536, 0.5163]],

         [[0.5814, 0.853

---

## Step 4.1: Update Data Handling for Stage 2 (Placeholder)

This section will adapt text processing for the Vicuna v1 template and update label masking logic.

In [28]:
# Placeholder for format_v1_template function
# Placeholder for updated LLaVABatchTransform logic for v1 template

In [47]:
#| hide
import nbdev; nbdev.nbdev_export()