# Data Preprocessing

> Functions and definitions for preprocessing steps, including normalization stats, tokenization, template formatting, and batch transformations.

In [None]:
#| default_exp data.preprocessing

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import sys
from pathlib import Path
import os

# Locate the project root so that `llava.*` imports resolve.
# The working directory is used as-is unless it lacks a settings.ini while its
# parent directory has one (the typical "running from nbs/" situation), in which
# case the parent is taken as the root instead.
project_root = Path.cwd()
if (project_root.parent / 'settings.ini').exists() and not (project_root / 'settings.ini').exists():
    project_root = project_root.parent

project_root_str = str(project_root.resolve())

# Put the root at the front of sys.path exactly once; report which case applied.
if project_root_str in sys.path:
    print(f"Project root already in sys.path: {project_root_str}")
else:
    print(f"Adding project root to sys.path: {project_root_str}")
    sys.path.insert(0, project_root_str)

Adding project root to sys.path: /workspace/llava


In [None]:
#| export
from transformers import AutoProcessor, AutoTokenizer, AutoImageProcessor
from fastai.vision.augment import Normalize
from fastai.vision.all import *
from fastai.text.all import *
from fastai.data.transforms import Transform
from fastai.torch_core import TensorBase, tensor
import torch
from typing import List, Dict, Union, Tuple, Any, Optional
from torch.nn.utils.rnn import pad_sequence
import copy

# Attempt to import from llava utils, handle potential ImportError if running standalone
try:
    from llava.utils import load_config
except ImportError:
    print("Warning: llava.utils not found. load_config function might be unavailable.")
    def load_config(path): return {}

# Import conversation handling logic (adapt from LLaVA reference or define here)
# For now, let's define a simple structure based on Vicuna v1 description
from llava.conversation import conv_templates, get_conv_template, SeparatorStyle # Assuming this exists based on reference

## Constants

In [None]:
#| export
# Literal marker written into prompt text wherever the image belongs; it is also
# registered with the tokenizer as an additional special token further down.
DEFAULT_IMAGE_TOKEN = "<image>" # Placeholder token for image features
# Value that LLaVABatchTransform writes into `input_ids` in place of the image token id.
IMAGE_TOKEN_INDEX_PLACEHOLDER = -200 # Special marker used internally in input_ids
# Value written into `labels` at every position that must not contribute to the loss.
IGNORE_INDEX = -100 # Standard ignore index for labels in loss calculation

## Load Config and Initialize Processors/Tokenizers

In [None]:
#| export
# --- Configuration Loading --- 
# Best-effort: any failure to load the YAML config falls back to the built-in defaults
# below so that importing this module never fails just because the config is absent.
CONFIG_PATH = 'configs/config.yaml'
config = {}
try:
    config = load_config(CONFIG_PATH)
    print(f"Loaded config from {CONFIG_PATH}")
except FileNotFoundError:
    print(f"Warning: Config file not found at {CONFIG_PATH}. Using default model names.")
except Exception as e:
    print(f"Warning: Error loading config from {CONFIG_PATH}: {e}. Using defaults.")

# A loader may hand back None (e.g. for an empty file) or some other non-mapping value;
# without this guard the `.get` lookups below would raise AttributeError at import time.
if not hasattr(config, 'get'):
    print(f"Warning: Config loaded from {CONFIG_PATH} is not a mapping ({type(config).__name__}). Using defaults.")
    config = {}

# `or {}` also covers a section that is present but empty/null in the file.
_model_cfg = config.get('model') or {}
_data_cfg = config.get('data') or {}

# Get model names from config or use defaults
VISION_ENCODER_NAME = _model_cfg.get('vision_encoder_name_or_path', 'openai/clip-vit-large-patch14-336')
LLM_NAME = _model_cfg.get('llm_name_or_path', 'lmsys/vicuna-7b-v1.5')
TOKENIZER_MAX_LEN = _data_cfg.get('tokenizer_model_max_length', 2048)
TOKENIZER_PADDING_SIDE = _data_cfg.get('tokenizer_padding_side', 'right')

# --- Image Processor and Normalization --- 
# Start from ImageNet statistics; they are only kept if the vision encoder's own
# image processor cannot be loaded.
clip_image_processor = None
image_mean = [0.485, 0.456, 0.406] # Default ImageNet stats
image_std = [0.229, 0.224, 0.225]
try:
    clip_image_processor = AutoImageProcessor.from_pretrained(VISION_ENCODER_NAME)
    # Prefer the per-channel statistics published with the vision encoder.
    image_mean = clip_image_processor.image_mean
    image_std = clip_image_processor.image_std
    print(f"Successfully loaded CLIP image processor for: {VISION_ENCODER_NAME}")
except Exception as e:
    # Deliberately broad: any loading problem degrades to the defaults above
    # instead of aborting module import.
    print(f"Warning: Error loading CLIP image processor for {VISION_ENCODER_NAME}: {e}. Using default ImageNet stats.")

# fastai Normalize transform built from whichever statistics ended up selected.
# In the recorded run below its mean/std tensors were created on cuda:0.
clip_normalize = Normalize.from_stats(image_mean, image_std)
print(f"CLIP Mean: {image_mean}")
print(f"CLIP Std: {image_std}")
print(f"Fastai Normalize Transform: {clip_normalize}")

# --- Tokenizer --- 
# Module-level tokenizer shared by the formatting functions and transforms below.
# Both globals stay None when loading fails; later cells guard on `if tokenizer:`.
tokenizer = None
IMAGE_TOKEN_ID = None
try:
    tokenizer = AutoTokenizer.from_pretrained(
        LLM_NAME,
        model_max_length=TOKENIZER_MAX_LEN,
        padding_side=TOKENIZER_PADDING_SIDE,
        use_fast=True,
    )
    print(f"Successfully loaded tokenizer for: {LLM_NAME}")

    # Snapshot of the vocabulary taken BEFORE the image token is (possibly) added;
    # the UNK fallback further down relies on this pre-add state.
    current_vocab = tokenizer.get_vocab()
    if DEFAULT_IMAGE_TOKEN not in current_vocab:
        print(f"Adding special token {DEFAULT_IMAGE_TOKEN} to tokenizer.")
        # NOTE(review): this grows the vocabulary (32000 -> 32001 in the recorded run);
        # presumably the LLM's embedding matrix is resized elsewhere — confirm.
        num_added = tokenizer.add_special_tokens({'additional_special_tokens': [DEFAULT_IMAGE_TOKEN]})
        if num_added > 0:
            print(f"Added {num_added} token(s). New vocab size: {len(tokenizer)}")
    
    IMAGE_TOKEN_ID = tokenizer.convert_tokens_to_ids(DEFAULT_IMAGE_TOKEN)
    print(f"Using token ID for {DEFAULT_IMAGE_TOKEN}: {IMAGE_TOKEN_ID}")
    if IMAGE_TOKEN_ID == tokenizer.unk_token_id:
         print(f"Warning: {DEFAULT_IMAGE_TOKEN} resolved to UNK token ID ({tokenizer.unk_token_id}). Check tokenizer setup.")
         # Attempt to force it if necessary and if vocab doesn't contain it
         # This is risky if the ID is already used
         if DEFAULT_IMAGE_TOKEN not in current_vocab:
              # NOTE(review): assumes the image token was the last token appended to
              # the vocabulary, so its id is len(tokenizer) - 1 — verify.
              IMAGE_TOKEN_ID = len(tokenizer) - 1 # Use the newly added token ID
              print(f"Using explicitly added token ID: {IMAGE_TOKEN_ID}")

    # Padding needs a pad token; when the tokenizer has none, UNK is reused for it.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.unk_token # Use UNK as pad if no PAD exists (like Llama-2)
        print(f"Set pad token to UNK token: {tokenizer.pad_token} (ID: {tokenizer.pad_token_id})")
    else:
        print(f"Tokenizer already has pad token: {tokenizer.pad_token} (ID: {tokenizer.pad_token_id})")

except Exception as e:
    # Despite the "Fatal" wording, the error is only reported and execution continues
    # with both globals reset to None.
    print(f"Fatal Error: Could not load tokenizer for {LLM_NAME}: {e}")
    # Handle error appropriately in a real application
    # For notebook execution, print warning and continue if possible
    tokenizer = None
    IMAGE_TOKEN_ID = None

Loaded config from configs/config.yaml


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Successfully loaded CLIP image processor for: openai/clip-vit-large-patch14-336
CLIP Mean: [0.48145466, 0.4578275, 0.40821073]
CLIP Std: [0.26862954, 0.26130258, 0.27577711]
Fastai Normalize Transform: Normalize -- {'mean': tensor([[[[0.4815]],

         [[0.4578]],

         [[0.4082]]]], device='cuda:0'), 'std': tensor([[[[0.2686]],

         [[0.2613]],

         [[0.2758]]]], device='cuda:0'), 'axes': (0, 2, 3)}
(enc:2,dec:2)
Successfully loaded tokenizer for: lmsys/vicuna-7b-v1.5
Adding special token <image> to tokenizer.
Added 1 token(s). New vocab size: 32001
Using token ID for <image>: 32000
Tokenizer already has pad token: <unk> (ID: 0)


## Template Formatting Functions

In [None]:
#| export
def format_plain_template(conversations: List[Dict[str, str]], tokenizer: AutoTokenizer = tokenizer) -> str:
    """Formats conversations using the 'plain' template for Stage 1 pre-training.

    The result is ``<image>\\n{caption}`` where ``{caption}`` is the first non-empty
    'gpt' turn, with any ``<image>`` marker removed from it and surrounding whitespace
    stripped. Human turns are ignored; the image marker is always emitted exactly once,
    at the very start.

    Args:
        conversations: A list of conversation turns (dictionaries with 'from' and 'value').
            Missing or ``None`` entries for either key are treated as empty strings.
        tokenizer: Unused; accepted so that all template formatters share one signature.

    Returns:
        The formatted string. Returns just the image token if no usable 'gpt' caption
        is found (no 'gpt' turn, or its text is empty after cleaning).
    """
    caption = ""
    for turn in conversations:
        # `or ''` guards against explicit None values, which `.get(key, '')` lets through.
        if (turn.get('from') or '').lower() != 'gpt':
            continue
        caption = turn.get('value') or ''
        if caption:
            break  # only the first non-empty caption is used

    # The image marker belongs at the start only, never inside the caption.
    caption = caption.replace(DEFAULT_IMAGE_TOKEN, '').strip()

    return f"{DEFAULT_IMAGE_TOKEN}\n{caption}" if caption else DEFAULT_IMAGE_TOKEN


def _remove_image_token(text: str) -> str:
    """Deletes every `<image>` marker from `text` and strips the result.

    When a marker sat between two spaces/tabs ("it <image> ?"), the whitespace after it
    is dropped as well so the removal does not leave a doubled space ("it  ?").
    """
    head, *tail = text.split(DEFAULT_IMAGE_TOKEN)
    cleaned = head
    for piece in tail:
        if cleaned[-1:] in (' ', '\t'):
            piece = piece.lstrip(' \t')
        cleaned += piece
    return cleaned.strip()


def format_v1_template(conversations: List[Dict[str, str]], tokenizer: AutoTokenizer = tokenizer) -> str:
    """Formats conversations using the Vicuna v1 template.

    The `<image>` token is normalized so that it appears exactly once, followed by a
    newline, at the beginning of a human message:

    * the first human turn that contains the token gets it moved to its start;
    * the token is removed from every other turn (assistant replies, later human turns);
    * if no human turn contained it, it is prepended to the first human turn.

    Turns are then appended to a copy of `conv_templates['v1']`, followed by an empty
    assistant message, so the returned prompt always ends with the assistant role prompt.

    Args:
        conversations: A list of conversation turns (dicts with 'from' and 'value').
            Missing or ``None`` 'from'/'value' entries are treated as empty strings.
            Turns whose role is neither 'human' nor 'gpt' are skipped with a warning.
        tokenizer: The tokenizer instance (not used by the formatting itself).

    Returns:
        The fully formatted prompt string according to Vicuna v1 template.

    Raises:
        ValueError: If `conv_templates` has no 'v1' entry.
    """
    if 'v1' not in conv_templates:
        raise ValueError("Vicuna v1 conversation template ('v1') not found in conversation_lib.")

    # Work on a copy so the shared template object is never mutated.
    conv = copy.deepcopy(conv_templates['v1'])
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Pass 1: relocate / strip the image token turn by turn.
    processed_conversations = []
    image_token_moved = False
    for turn in conversations:
        new_turn = copy.deepcopy(turn)
        value = turn.get('value') or ''
        from_role = (turn.get('from') or '').lower()

        if DEFAULT_IMAGE_TOKEN in value:
            text_only = _remove_image_token(value)
            if from_role == 'human' and not image_token_moved:
                new_turn['value'] = DEFAULT_IMAGE_TOKEN + '\n' + text_only
                image_token_moved = True
            else:
                # Token found elsewhere (assistant reply or a later human turn): drop it.
                new_turn['value'] = text_only

        processed_conversations.append(new_turn)

    # Pass 2: no turn carried the token, so attach it to the first human turn.
    if not image_token_moved:
        for turn in processed_conversations:
            if (turn.get('from') or '').lower() == 'human':
                turn['value'] = DEFAULT_IMAGE_TOKEN + '\n' + (turn.get('value') or '')
                break
        else:
            # Without any human turn there is nowhere to put the token; the prompt is
            # still built, but the input data is probably malformed.
            print("Warning: No 'human' turn found to prepend <image> token to.")

    # Feed the cleaned turns into the template.
    for turn in processed_conversations:
        role_key = (turn.get('from') or '').lower()
        if role_key not in roles:
            print(f"Warning: Skipping turn with unknown role '{role_key}'.")
            continue
        conv.append_message(roles[role_key], turn.get('value'))

    # Append the assistant prompt
    conv.append_message(roles['gpt'], None)

    return conv.get_prompt()

#### Example Usage & Test (Template Formatting - V1)

In [None]:
#| test
# Case 1: image token already at the start of the first human turn.
conv1 = [
    {'from': 'human', 'value': '<image>\nDescribe this image.'},
    {'from': 'gpt', 'value': 'This is a red object.'},
    {'from': 'human', 'value': 'What shape is it?'},
    {'from': 'gpt', 'value': 'It is round.'},
    {'from': 'human', 'value': 'Thanks!'}
]
# Case 2: image token appears only in a later human turn, mid-sentence.
conv2 = [
    {'from': 'human', 'value': 'Describe this image.'},
    {'from': 'gpt', 'value': 'This is a green object.'},
    {'from': 'human', 'value': 'What shape is it <image> ?'},
    {'from': 'gpt', 'value': 'It is square.'},
    {'from': 'human', 'value': 'Thanks!'}
]
# Case 3: image token in the first human turn AND inside an assistant reply.
conv3 = [
    {'from': 'human', 'value': '<image>Describe.'},
    {'from': 'gpt', 'value': 'This is a blue object.'},
    {'from': 'human', 'value': 'Anything else?'},
    {'from': 'gpt', 'value': 'It might be <image> shiny.'},
    {'from': 'human', 'value': 'Ok'}
]

if not tokenizer:
    print("Tokenizer not loaded, skipping v1 template test.")
else:
    # Each entry pairs the banner to print with the conversation to format.
    v1_examples = (
        ("--- Test Case 1: Standard --- ", conv1),
        ("\n--- Test Case 2: Image token later --- ", conv2),
        ("\n--- Test Case 3: Image token in GPT response (removed) --- ", conv3),
    )
    for banner, conversation in v1_examples:
        print(banner)
        print(format_v1_template(conversation, tokenizer))

--- Test Case 1: Standard --- 
A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>
Describe this image. ASSISTANT: This is a red object.</s> USER: What shape is it? ASSISTANT: It is round.</s> USER: Thanks! ASSISTANT:

--- Test Case 2: Image token later --- 
A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: Describe this image. ASSISTANT: This is a green object.</s> USER: <image>
What shape is it  ? ASSISTANT: It is square.</s> USER: Thanks! ASSISTANT:

--- Test Case 3: Image token in GPT response (removed) --- 
A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>
Describe. ASSISTANT: This is a blue object.</s> USER: Anything else? ASSISTANT: It mi

## Text Tokenization Transform

In [None]:
#| export
class LLaVATextTokenizer(Transform):
    """fastai Transform that renders a conversation with a template and tokenizes it.

    The conversation is first turned into a single prompt string by `template_formatter`
    (e.g. `format_plain_template` or `format_v1_template`); that string is then passed
    through `tokenizer` and only the resulting input IDs are kept.
    """
    def __init__(self, tokenizer, template_formatter):
        # Saves both constructor arguments as attributes of the same name.
        store_attr()
        if self.tokenizer is None:
            raise ValueError("Tokenizer must be provided and loaded successfully.")

    def encodes(self, conversations: list) -> list:
        """Formats `conversations` and returns the token IDs of the resulting prompt.

        Args:
            conversations: Raw conversation list from the dataset sample.

        Returns:
            A list of input token IDs (special tokens added, no truncation applied).
        """
        prompt = self.template_formatter(conversations, self.tokenizer)
        encoding = self.tokenizer(
            prompt,
            add_special_tokens=True,
            truncation=False,
            return_tensors=None,
        )
        return encoding['input_ids']

## Batch Transformation

In [None]:
#| export
class LLaVABatchTransform(ItemTransform):
    """Custom batch transform for LLaVA stages.

    Takes the collated pair ``(images, list_of_token_tensors)`` and returns the dict fed
    to the model: normalized ``pixel_values``, right-padded ``input_ids`` (with the image
    token id swapped for ``IMAGE_TOKEN_INDEX_PLACEHOLDER``), ``attention_mask`` and
    ``labels`` masked according to the selected template ('plain' or 'v1').

    The base class is ``ItemTransform`` on purpose: a plain ``Transform`` is applied to
    every element of a tuple separately, so ``encodes`` would never receive the
    (images, tokens) pair as a whole and the batch would pass through unprocessed.
    """
    # Token id expected right after the image token when the prompt contains "<image>\n"
    # (id 13 follows the image token in the Vicuna-tokenized examples of this notebook).
    _NEWLINE_TOKEN_ID = 13

    def __init__(self, tokenizer, normalize_tfm: Normalize, template: str = 'plain', image_token_id: Optional[int] = None):
        """
        Args:
            tokenizer: Loaded tokenizer; must define a pad token.
            normalize_tfm: Transform applied to the image tensor (and used to undo it in `decodes`).
            template: 'plain' or 'v1'; selects the label-masking strategy.
            image_token_id: Id of the `<image>` token. Looked up from `tokenizer` when omitted.

        Raises:
            ValueError: If `tokenizer` is None or has no pad token id.
        """
        store_attr() # Stores tokenizer, normalize_tfm, template, image_token_id
        if self.tokenizer is None:
            raise ValueError("Tokenizer must be provided.")

        if image_token_id is None:
            self.image_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_IMAGE_TOKEN)
            if self.image_token_id == tokenizer.unk_token_id:
                # Fall back to the added-token table, which maps token string -> id.
                added_vocab = tokenizer.get_added_vocab()
                if DEFAULT_IMAGE_TOKEN in added_vocab:
                    self.image_token_id = added_vocab[DEFAULT_IMAGE_TOKEN]
            if self.image_token_id == tokenizer.unk_token_id:
                print(f"Warning: {DEFAULT_IMAGE_TOKEN} not found in tokenizer vocab. Using UNK ID: {self.image_token_id}")
        else:
            self.image_token_id = image_token_id

        self.pad_token_id = tokenizer.pad_token_id
        if self.pad_token_id is None:
            raise ValueError("Tokenizer must have a pad_token_id defined.")

        # Template-specific data; only filled in for the 'v1' template below.
        self.assistant_role_token_ids = None
        self.eos_token_id = tokenizer.eos_token_id
        self.bos_token_id = tokenizer.bos_token_id
        self.sep = None
        self.sep2 = None

        if self.template == 'v1':
            conv_v1 = conv_templates.get('v1')
            if conv_v1:
                assistant_role_str = conv_v1.roles[1] # Typically 'ASSISTANT'
                # Token ids of "<assistant role>:"; used to locate where each reply starts.
                self.assistant_role_token_ids = self.tokenizer.encode(f"{assistant_role_str}:", add_special_tokens=False)
                self.sep = conv_v1.sep
                self.sep2 = conv_v1.sep2
                print(f"V1 template assistant role tokens: {self.assistant_role_token_ids}")
            else:
                print("Warning: Vicuna v1 template not found, v1 masking might not work correctly.")

        print(f"LLaVABatchTransform initialized. Image Token ID: {self.image_token_id}, Pad Token ID: {self.pad_token_id}, Template: {self.template}")

    def encodes(self, collated_batch):
        """Applies normalization, reconstructs padded tensors, applies masking based on template.

        The parameter is intentionally left un-annotated: `encodes` is selected by the
        annotated type of its argument, and `ItemTransform` hands the item over as a list,
        so a `tuple` annotation would prevent this method from ever being called.

        Args:
            collated_batch: A 2-element tuple/list
                (collated_image_tensors, list_of_positional_token_tensors).

        Returns:
            A dict with 'pixel_values', 'input_ids', 'attention_mask' and 'labels'.
            If the token list is empty, a dict with only 'pixel_values'.
            Input that is not a 2-element tuple/list is returned unchanged (with a warning).

        Raises:
            TypeError: If the first element is not a Tensor, or the second is not a list of Tensors.
        """
        if not isinstance(collated_batch, (tuple, list)) or len(collated_batch) != 2:
            # Best-effort: pass unexpected input through rather than break the pipeline.
            print(f"Warning: LLaVABatchTransform received unexpected input type: {type(collated_batch)}. Expected tuple of 2 elements. Skipping.")
            return collated_batch

        collated_images, list_of_positional_tensors = collated_batch

        if not isinstance(collated_images, torch.Tensor):
            raise TypeError(f"Expected first element of collated batch to be a Tensor, got {type(collated_images)}")

        # An absent/empty text part is tolerated: only the images are processed.
        text_is_empty = list_of_positional_tensors is None or (
            isinstance(list_of_positional_tensors, (list, tuple)) and len(list_of_positional_tensors) == 0
        )
        if text_is_empty:
            print("Warning: Received empty list of positional tensors. Outputting only image data if available.")
            return {'pixel_values': self.normalize_tfm(collated_images)}

        if not isinstance(list_of_positional_tensors, list) or not all(isinstance(t, torch.Tensor) for t in list_of_positional_tensors):
            raise TypeError(f"Expected second element of collated batch to be a list of Tensors, got {type(list_of_positional_tensors)}")

        # 1. Normalize images
        normalized_images = self.normalize_tfm(collated_images)

        # 2. Right-pad the variable-length token sequences into one (batch, max_len) tensor
        try:
            device = list_of_positional_tensors[0].device
            input_ids = pad_sequence([t.to(device) for t in list_of_positional_tensors],
                                     batch_first=True,
                                     padding_value=self.pad_token_id)
        except Exception:
            print("Error padding positional tensors. Check consistency.")
            for i, t in enumerate(list_of_positional_tensors): print(f"Tensor {i} shape: {t.shape}")
            raise

        # 3. Attention mask from the true sequence lengths. Comparing against the pad id
        #    would also blank out genuine tokens whenever the pad token doubles as a real
        #    token (here the pad token may be the UNK token).
        seq_lengths = torch.tensor([t.shape[0] for t in list_of_positional_tensors], device=device)
        positions = torch.arange(input_ids.shape[1], device=device)
        attention_mask = (positions[None, :] < seq_lengths[:, None]).long()

        # 4. Labels start as a copy of the ids; model inputs get the image placeholder.
        labels = input_ids.clone()
        input_ids_processed = input_ids.clone()
        input_ids_processed[input_ids_processed == self.image_token_id] = IMAGE_TOKEN_INDEX_PLACEHOLDER

        # 5. Template-specific label masking (both helpers modify `labels` in place)
        if self.template == 'plain':
            self._apply_plain_masking(labels, attention_mask)
        elif self.template == 'v1':
            self._apply_v1_masking(labels, attention_mask)
        else:
            print(f"Warning: Unknown template '{self.template}'. Defaulting to plain masking.")
            self._apply_plain_masking(labels, attention_mask)

        return {
            'pixel_values': normalized_images,
            'input_ids': input_ids_processed,
            'attention_mask': attention_mask,
            'labels': labels
        }

    def _apply_plain_masking(self, labels, attention_mask):
        """Masks labels in place for the 'plain' template.

        Everything up to and including the first image token (plus an immediately
        following newline token) is ignored, as are padding positions and a leading BOS.
        A sample without an image token is masked entirely.
        """
        for i in range(labels.shape[0]):
            image_token_indices = torch.where(labels[i] == self.image_token_id)[0]
            if len(image_token_indices) > 0:
                image_token_idx = image_token_indices[0].item()
                mask_until_idx = image_token_idx + 1
                if mask_until_idx < labels.shape[1] and labels[i, mask_until_idx] == self._NEWLINE_TOKEN_ID:
                    mask_until_idx += 1
            else:
                print(f"Warning: Image token ID {self.image_token_id} not found in labels for sample {i} (Plain template). Masking all.")
                mask_until_idx = labels.shape[1]

            if mask_until_idx > 0:
                labels[i, :mask_until_idx] = IGNORE_INDEX

            labels[i, attention_mask[i] == 0] = IGNORE_INDEX

            if labels.shape[1] > 0 and labels[i, 0] == self.bos_token_id: # BOS
                labels[i, 0] = IGNORE_INDEX

    def _find_subsequence(self, main_tensor, sub_tensor):
        """Returns the list of start indices at which 1-D `sub_tensor` occurs in 1-D `main_tensor`."""
        n = main_tensor.size(0)
        m = sub_tensor.size(0)
        if m > n or m == 0: return []
        sub_tensor = sub_tensor.to(main_tensor.device)
        # All length-m sliding windows, shape (n - m + 1, m), compared in one shot.
        windows = main_tensor.unfold(0, m, 1)
        matches = (windows == sub_tensor).all(dim=1)
        return torch.nonzero(matches, as_tuple=True)[0].tolist()

    def _apply_v1_masking(self, labels, attention_mask):
        """Masks labels in place for the Vicuna 'v1' template.

        Only the tokens of each assistant reply are kept as targets: from just after an
        "<assistant role>:" marker up to (not including) the next EOS token, or up to the
        end of the unpadded sequence when no EOS follows. All other positions, the first
        image token, padding and a leading BOS are set to IGNORE_INDEX.

        NOTE(review): the EOS token closing each reply is excluded from the targets, so
        the loss never covers EOS — confirm this is intended.
        """
        if self.assistant_role_token_ids is None:
            print("Warning: Assistant role tokens not initialized for v1 template. Cannot perform v1 masking.")
            labels[attention_mask == 0] = IGNORE_INDEX
            return

        assistant_token_ids_tensor = torch.tensor(self.assistant_role_token_ids, dtype=torch.long, device=labels.device)
        len_assistant_prompt = len(self.assistant_role_token_ids)

        for i in range(labels.shape[0]):
            current_labels = labels[i]
            # Start fully masked, then copy back only the reply spans.
            current_mask = torch.full_like(current_labels, IGNORE_INDEX)

            for start_idx in self._find_subsequence(current_labels, assistant_token_ids_tensor):
                response_start_idx = start_idx + len_assistant_prompt
                eos_indices = torch.where(current_labels[response_start_idx:] == self.eos_token_id)[0]
                if len(eos_indices) > 0:
                    response_end_idx = response_start_idx + eos_indices[0].item()
                else:
                    # No EOS after this marker: the reply runs to the end of the real tokens.
                    response_end_idx = torch.sum(attention_mask[i]).item()

                if response_start_idx < response_end_idx:
                    current_mask[response_start_idx:response_end_idx] = current_labels[response_start_idx:response_end_idx]

            image_token_indices = torch.where(current_labels == self.image_token_id)[0]
            if len(image_token_indices) > 0:
                current_mask[image_token_indices[0].item()] = IGNORE_INDEX

            current_mask[attention_mask[i] == 0] = IGNORE_INDEX

            if current_labels.shape[0] > 0 and current_labels[0] == self.bos_token_id: # BOS
                current_mask[0] = IGNORE_INDEX

            labels[i] = current_mask

    def decodes(self, batch: dict) -> tuple:
        """Decodes a batch dictionary back into a tuple of (images, texts).

        Args:
            batch: Dict with at least 'pixel_values' and 'input_ids'. When it also holds
                'attention_mask', that mask selects the real tokens; otherwise every
                token different from the pad id is kept.

        Returns:
            (list of de-normalized image tensors, list of TitledStr); two empty lists
            if `batch` does not have the expected structure.
        """
        decoded_images = []
        decoded_texts = []
        if not isinstance(batch, dict) or 'pixel_values' not in batch or 'input_ids' not in batch:
            print(f"Decode expected dict with 'pixel_values' and 'input_ids', got {type(batch)}\nContent: {batch}")
            return ([], [])

        imgs = batch['pixel_values']
        ids = batch['input_ids']
        attn = batch.get('attention_mask')

        # String the tokenizer produces for the image token; the same for every sample.
        image_token_str_decoded = self.tokenizer.decode([self.image_token_id], skip_special_tokens=False)

        for i in range(imgs.shape[0]):
            img_tensor_cpu = imgs[i].cpu()
            img_decoded = self.normalize_tfm.decode(img_tensor_cpu.unsqueeze(0))[0]

            ids_i = ids[i].clone().cpu()
            # Restore the real image token id in place of the internal placeholder.
            ids_i[ids_i == IMAGE_TOKEN_INDEX_PLACEHOLDER] = self.image_token_id

            keep = attn[i].bool().cpu() if attn is not None else (ids_i != self.pad_token_id)
            actual_ids = ids_i[keep].tolist()

            # NOTE(review): with skip_special_tokens=True a token registered as special is
            # presumably dropped here, leaving the replace below nothing to restore — confirm.
            text_decoded = self.tokenizer.decode(actual_ids, skip_special_tokens=True)
            text_decoded = text_decoded.replace(image_token_str_decoded, DEFAULT_IMAGE_TOKEN)

            decoded_images.append(img_decoded)
            decoded_texts.append(TitledStr(text_decoded))

        return (decoded_images, decoded_texts)

# Apply the transform regardless of the train/valid split index.
LLaVABatchTransform.split_idx = None

#### Example Usage & Test (Batch Transform - V1 Masking)

In [None]:
#| test
# End-to-end check of LLaVABatchTransform with the 'v1' template on two hand-written
# conversations. Skipped entirely when the tokenizer could not be loaded.
# NOTE(review): the whole body sits inside a try/except that only prints the error and
# traceback, so a failing assert below does not make this test cell fail.
if tokenizer:
    try:
        # 1. Create Sample Data (Tokenized using v1 template)
        conv_a = [{'from': 'human', 'value': '<image>\nDescribe image.'}, {'from': 'gpt', 'value': 'It is a red object.'}]
        conv_b = [{'from': 'human', 'value': '<image>What color?'}, {'from': 'gpt', 'value': 'It is green.'}, {'from':'human', 'value': 'And shape?'}, {'from':'gpt', 'value':'It is round'}]
        
        tokenizer_tfm_v1 = LLaVATextTokenizer(tokenizer, template_formatter=format_v1_template)
        
        token_ids_a = tokenizer_tfm_v1(conv_a)
        token_ids_b = tokenizer_tfm_v1(conv_b)
        
        # 2. Simulate Collation (Get padded tensor)
        # This padded copy is kept only as a reference for locating the image token
        # positions in the assertions at the end; it is not fed to the transform.
        collated_ids_unprocessed = pad_sequence([torch.tensor(token_ids_a), torch.tensor(token_ids_b)], 
                                                batch_first=True, 
                                                padding_value=tokenizer.pad_token_id)
        print("\n--- Original Collated Input IDs (Padded) ---")
        print(collated_ids_unprocessed)
        
        # Dummy images
        dummy_images = torch.rand(2, 3, 336, 336)
        # The LLaVABatchTransform.encodes expects a tuple: (collated_images, list_of_positional_tensors)
        simulated_collated_batch_for_transform = (dummy_images, [torch.tensor(token_ids_a), torch.tensor(token_ids_b)])

        # 3. Instantiate and Apply Batch Transform for V1
        batch_transform_v1 = LLaVABatchTransform(tokenizer, normalize_tfm=clip_normalize, template='v1')
        raw_processed_batch = batch_transform_v1(simulated_collated_batch_for_transform)

        # 4. Inspect Output
        print("\n--- Processed Batch (Output of Transform) ---")
        print(f"Raw processed batch type: {type(raw_processed_batch)}")
        # Debug dump of the structure, used when the transform hands back a tuple
        # instead of the expected dict.
        if isinstance(raw_processed_batch, tuple): # Print details if it's a tuple
            print(f"  Length of tuple: {len(raw_processed_batch)}")
            for i_debug, item_debug in enumerate(raw_processed_batch):
                print(f"  Element {i_debug} type: {type(item_debug)}")
                if isinstance(item_debug, torch.Tensor):
                    print(f"    Element {i_debug} shape: {item_debug.shape}")
                elif isinstance(item_debug, list) and all(isinstance(li, torch.Tensor) for li in item_debug):
                    print(f"    Element {i_debug} is a list of {len(item_debug)} tensors.")
                    for j_debug, sub_item_debug in enumerate(item_debug):
                         print(f"      Sub-element {j_debug} shape: {sub_item_debug.shape}")


        # Accept either a dict or a 1-tuple wrapping a dict; anything else is an error.
        processed_batch_dict = None
        if isinstance(raw_processed_batch, dict):
            processed_batch_dict = raw_processed_batch
            print("Processed batch is a dictionary.")
        elif isinstance(raw_processed_batch, tuple) and len(raw_processed_batch) == 1 and isinstance(raw_processed_batch[0], dict):
            processed_batch_dict = raw_processed_batch[0]
            print("Processed batch was a tuple containing a single dictionary. Using the dictionary.")
        # elif isinstance(raw_processed_batch, tuple) and len(raw_processed_batch) == 4: # Assuming order: pv, ids, attn_mask, labels
        #     print("Processed batch is a tuple of 4 elements. Reconstructing dictionary.")
        #     processed_batch_dict = {
        #         'pixel_values': raw_processed_batch[0],
        #         'input_ids': raw_processed_batch[1],
        #         'attention_mask': raw_processed_batch[2],
        #         'labels': raw_processed_batch[3]
        #     }
        else:
            # This path will be taken if it's a 2-element tuple as seen in the error
            # Or any other unexpected tuple structure
            raise TypeError(f"Unexpected type or structure for processed_batch: {type(raw_processed_batch)}. Content (first few elements if tuple): {str(raw_processed_batch)[:500]}...")


        assert isinstance(processed_batch_dict, dict), "processed_batch_dict should be a dictionary"
        
        print(f"Pixel Values Shape: {processed_batch_dict.get('pixel_values').shape}")
        print(f"Input IDs Shape (with -200): {processed_batch_dict.get('input_ids').shape}")
        print(f"Input IDs:\n{processed_batch_dict.get('input_ids')}")
        print(f"Attention Mask Shape: {processed_batch_dict.get('attention_mask').shape}")
        print(f"Attention Mask:\n{processed_batch_dict.get('attention_mask')}")
        print(f"Labels Shape: {processed_batch_dict.get('labels').shape}")
        print(f"Labels:\n{processed_batch_dict.get('labels')}")

        # Show, per sample, exactly which tokens remain as loss targets after masking.
        print("\n--- Decoded Labels (Showing Loss Calculation Targets) ---")
        if 'labels' in processed_batch_dict:
            for i in range(processed_batch_dict['labels'].shape[0]):
                label_ids = processed_batch_dict['labels'][i]
                valid_label_ids = label_ids[label_ids != IGNORE_INDEX].tolist()
                decoded_tokens_list = tokenizer.convert_ids_to_tokens(valid_label_ids)
                print(f"Sample {i} Target Tokens: {decoded_tokens_list}")
        
        # The first position of every sample must be masked.
        assert processed_batch_dict['labels'][0, 0] == IGNORE_INDEX
        assert processed_batch_dict['labels'][1, 0] == IGNORE_INDEX
        
        # The position of the (first) image token must be masked in each sample.
        original_img_pos_0 = torch.where(collated_ids_unprocessed[0] == IMAGE_TOKEN_ID)[0]
        if len(original_img_pos_0) > 0:
            assert processed_batch_dict['labels'][0, original_img_pos_0[0].item()] == IGNORE_INDEX
        original_img_pos_1 = torch.where(collated_ids_unprocessed[1] == IMAGE_TOKEN_ID)[0]
        if len(original_img_pos_1) > 0:
            assert processed_batch_dict['labels'][1, original_img_pos_1[0].item()] == IGNORE_INDEX
        
        print("\nV1 Masking Test Passed (Check decoded labels above).")
            
    except Exception as e:
        # Report the failure (message + traceback) without re-raising.
        print(f"Error during V1 masking test: {e}")
        import traceback
        traceback.print_exc()
else:
    print("Tokenizer not loaded, skipping batch transform v1 test.")


--- Original Collated Input IDs (Padded) ---
tensor([[    1,   319, 13563,  1546,   263, 12758,  1404,   322,   385, 23116,
         21082, 20255, 29889,   450, 20255,  4076,  8444, 29892, 13173, 29892,
           322,  1248,   568,  6089,   304,   278,  1404, 29915, 29879,  5155,
         29889,  3148,  1001, 29901, 29871, 32000,    13,  4002, 29581,  1967,
         29889,   319,  1799,  9047, 13566, 29901,   739,   338,   263,  2654,
          1203, 29889,     2,   319,  1799,  9047, 13566, 29901,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
        [    1,   319, 13563,  1546,   263, 12758,  1404,   322,   385, 23116,
         21082, 20255, 29889,   450, 20255,  4076,  8444, 29892, 13173, 29892,
           322,  1248,   568,  6089,   304,   278,  1404, 29915, 29879,  5155,
         29889,  3148,  1001, 29901, 29871, 32000,    13,  5618,  2927, 29973,
           319,  1799,  9047, 13566, 29901,   739,   338,  7933, 29889,     2,
     

Traceback (most recent call last):
  File "/tmp/ipykernel_2664/2423525345.py", line 62, in <module>
    raise TypeError(f"Unexpected type or structure for processed_batch: {type(raw_processed_batch)}. Content (first few elements if tuple): {str(raw_processed_batch)[:500]}...")
TypeError: Unexpected type or structure for processed_batch: <class 'tuple'>. Content (first few elements if tuple): (tensor([[[[4.4813e-01, 7.1428e-01, 2.1634e-01,  ..., 5.7698e-01,
           7.9262e-01, 2.5049e-02],
          [7.2263e-01, 7.7891e-01, 1.6354e-01,  ..., 2.7791e-01,
           2.4072e-01, 6.1231e-01],
          [3.6167e-01, 6.1093e-01, 3.1450e-01,  ..., 6.3710e-01,
           4.6449e-01, 4.6196e-02],
          ...,
          [9.4213e-02, 3.1170e-01, 7.2084e-01,  ..., 1.6722e-01,
           7.3435e-01, 8.8533e-01],
          [2.6279e-01, 6.3955e-01, 9.8602e-01,  ..., 9.9407e-01,
           1.03...


---

In [None]:
#| hide
# Run nbdev's export step; presumably this writes the `#| export` cells to the module
# named by the `default_exp` directive at the top of the notebook — verify against nbdev docs.
import nbdev; nbdev.nbdev_export()