# Data Preprocessing

> Functions and definitions for preprocessing steps, including normalization stats, tokenization, and template formatting.

In [None]:
#| default_exp data.preprocessing

In [None]:
#| hide
from nbdev.showdoc import *

In [1]:
#| export
from transformers import AutoProcessor, AutoTokenizer, AutoImageProcessor
from fastai.vision.augment import Normalize
from fastai.data.transforms import Transform
from fastai.torch_core import TensorBase, tensor, TitledList
import torch
from typing import List, Dict, Union, Tuple, Any

from Adaptive_Patching_VIT_fastai.utils import load_config

## Step 1.2 (Continued): Image Normalization Setup

Load the CLIP image processor to get the correct normalization statistics (mean and standard deviation) required for the vision encoder.

In [2]:
#| export
# Load config to get model name
CONFIG_PATH = 'configs/config.yaml'
try:
    config = load_config(CONFIG_PATH)
    VISION_ENCODER_NAME = config['model']['vision_encoder_name_or_path']
    LLM_NAME = config['model']['llm_name_or_path'] # Added for tokenizer
except FileNotFoundError:
    print(f"Warning: Config file not found at {CONFIG_PATH}. Using default model names.")
    VISION_ENCODER_NAME = 'openai/clip-vit-large-patch14-336' # Fallback
    LLM_NAME = 'lmsys/vicuna-7b-v1.5' # Fallback
except KeyError as e:
    print(f"Warning: Key {e} not found in {CONFIG_PATH}. Using defaults.")
    VISION_ENCODER_NAME = config.get('model', {}).get('vision_encoder_name_or_path', 'openai/clip-vit-large-patch14-336')
    LLM_NAME = config.get('model', {}).get('llm_name_or_path', 'lmsys/vicuna-7b-v1.5')

# Load the CLIP image processor
try:
    clip_image_processor = AutoImageProcessor.from_pretrained(VISION_ENCODER_NAME)
    print(f"Successfully loaded CLIP image processor for: {VISION_ENCODER_NAME}")
except Exception as e:
    print(f"Error loading CLIP image processor for {VISION_ENCODER_NAME}: {e}")
    # Handle error appropriately, maybe raise or use default stats
    clip_image_processor = None

# Get normalization stats
if clip_image_processor:
    image_mean = clip_image_processor.image_mean
    image_std = clip_image_processor.image_std
else:
    print("Warning: Using default ImageNet stats as fallback for normalization.")
    # Default fallback (ImageNet stats often used, but CLIP specific is better)
    image_mean = [0.485, 0.456, 0.406]
    image_std = [0.229, 0.224, 0.225]

# Create the fastai Normalize transform using CLIP stats
clip_normalize = Normalize.from_stats(image_mean, image_std)

Successfully loaded CLIP image processor for: openai/clip-vit-large-patch14-336


In [3]:
# Example: Print the stats and the transform
print(f"CLIP Mean: {image_mean}")
print(f"CLIP Std: {image_std}")
print(f"Fastai Normalize Transform: {clip_normalize}")

CLIP Mean: [0.48145466, 0.4578275, 0.40821073]
CLIP Std: [0.26862954, 0.26130258, 0.27577711]
Fastai Normalize Transform: Normalize -- Tries to normalize batch with `mean` and `std` specified on `axes`


---

## Step 1.3: Text Tokenization and Template Handling (Stage 1 - Plain)

Load the LLM's tokenizer (Vicuna) and define the 'plain' template formatting for Stage 1 pre-training.

In [4]:
#| export
DEFAULT_IMAGE_TOKEN = "<image>" # Placeholder token for image features
IMAGE_TOKEN_INDEX_PLACEHOLDER = -200 # Special marker used internally in input_ids
IGNORE_INDEX = -100 # Standard ignore index for labels in loss calculation

# Load the Vicuna tokenizer
try:
    tokenizer = AutoTokenizer.from_pretrained(
        LLM_NAME,
        model_max_length=config.get('data', {}).get('tokenizer_model_max_length', 2048),
        padding_side=config.get('data', {}).get('tokenizer_padding_side', 'right'),
        use_fast=True,
    )
    print(f"Successfully loaded tokenizer for: {LLM_NAME}")

    # Add the image token as a special token, if it's not already there.
    # This ensures the tokenizer knows about it, even though we replace its embedding later.
    # Avoids warnings and potential token splitting issues.
    # Check if the token already exists
    current_vocab = tokenizer.get_vocab()
    if DEFAULT_IMAGE_TOKEN not in current_vocab:
        print(f"Adding special token {DEFAULT_IMAGE_TOKEN} to tokenizer.")
        # Add the token as a *special* token. This typically means it won't be split
        # and can be handled distinctly. `additional_special_tokens` is a common way.
        num_added = tokenizer.add_special_tokens({'additional_special_tokens': [DEFAULT_IMAGE_TOKEN]})
        if num_added > 0:
            print(f"Added {num_added} token(s). New vocab size: {len(tokenizer)}")
            # NOTE: If a token is truly added (increasing vocab size),
            # the model's embedding layer MUST be resized later.
            # For LLaVA, we usually map it to an existing ID or rely on the -200 replacement,
            # but adding it as a special token helps tokenizer processing.
    # Retrieve the assigned ID
    IMAGE_TOKEN_ID = tokenizer.convert_tokens_to_ids(DEFAULT_IMAGE_TOKEN)
    print(f"Using token ID for {DEFAULT_IMAGE_TOKEN}: {IMAGE_TOKEN_ID}")
    # Handle case where it might still be UNK if not properly added/found
    if IMAGE_TOKEN_ID == tokenizer.unk_token_id:
         print(f"Warning: {DEFAULT_IMAGE_TOKEN} resolved to UNK token ID ({tokenizer.unk_token_id}). Check tokenizer setup.")


    # Set pad token if missing (common for LLaMA models)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token # Option 1: Use EOS
        # tokenizer.add_special_tokens({'pad_token': '[PAD]'}) # Option 2: Add a new pad token
        print(f"Using EOS token as pad token: {tokenizer.pad_token} (ID: {tokenizer.pad_token_id})")
        # Note: If adding a new pad token, the model's embedding layer needs resizing later!
    else:
        print(f"Tokenizer already has pad token: {tokenizer.pad_token} (ID: {tokenizer.pad_token_id})")

except Exception as e:
    print(f"Error loading tokenizer for {LLM_NAME}: {e}")
    tokenizer = None
    IMAGE_TOKEN_ID = None

Successfully loaded tokenizer for: lmsys/vicuna-7b-v1.5
Using token ID for <image>: 32000
Using EOS token as pad token: </s> (ID: 2)


In [5]:
type(tokenizer)

transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast

In [6]:
#| export
def format_plain_template(conversations: List[Dict[str, str]]) -> str:
    """Formats conversations using the 'plain' template for Stage 1 pre-training.

    The 'plain' template uses the format: <image>\n{caption}\n
    where {caption} is the value of the first 'gpt' turn.

    Args:
        conversations: A list of conversation turns (dictionaries with 'from' and 'value').

    Returns:
        The formatted string. Returns just the image token if no 'gpt' turn is found.
    """
    caption = "" # Default to empty caption if no 'gpt' turn found
    for turn in conversations:
        if turn.get('from', '').lower() == 'gpt':
            caption = turn.get('value', '')
            break # Use the first GPT response as the caption

    # Ensure the <image> token is always first, followed by newline and caption
    # Remove any existing <image> token from caption to avoid duplicates
    caption = caption.replace(DEFAULT_IMAGE_TOKEN, '').strip()

    # Return formatted string, handle empty caption case
    # Add final newline consistent with some implementations?
    # LLaVA reference often has <image>\n{caption}
    # Let's stick to <image>\n{caption} for now, stripping trailing whitespace.
    formatted = f"{DEFAULT_IMAGE_TOKEN}\n{caption}".strip() if caption else f"{DEFAULT_IMAGE_TOKEN}"
    return formatted

In [7]:
show_doc(format_plain_template)

```python
#| export
def format_plain_template(conversations: List[Dict[str, str]]) -> str:
    """Formats conversations using the 'plain' template for Stage 1 pre-training.\n\nThe 'plain' template uses the format: <image>\n{caption}\n\nwhere {caption} is the value of the first 'gpt' turn.\n\nArgs:\n    conversations: A list of conversation turns (dictionaries with 'from' and 'value').\n\nReturns:\n    The formatted string. Returns just the image token if no 'gpt' turn is found.\n    """
    caption = "" # Default to empty caption if no 'gpt' turn found
    for turn in conversations:
        if turn.get('from', '').lower() == 'gpt':
            caption = turn.get('value', '')
            break # Use the first GPT response as the caption

    # Ensure the <image> token is always first, followed by newline and caption
    # Remove any existing <image> token from caption to avoid duplicates
    caption = caption.replace(DEFAULT_IMAGE_TOKEN, '').strip()
    
    # Return formatted string, handle empty caption case
    # Add final newline consistent with some implementations?
    # LLaVA reference often has <image>\n{caption}\n
    # Let's stick to <image>\n{caption} for now, stripping trailing whitespace.
    formatted = f"{DEFAULT_IMAGE_TOKEN}\n{caption}".strip() if caption else f"{DEFAULT_IMAGE_TOKEN}"
    return formatted
```

#### Example Usage & Test (Template Formatting)

In [8]:
conv1 = [{'from': 'human', 'value': '<image>\nDescribe.'}, {'from': 'gpt', 'value': 'This is the caption.'}]
conv2 = [{'from': 'human', 'value': 'Describe.'}, {'from': 'gpt', 'value': '<image>Caption with image token removed.'}]
conv3 = [{'from': 'human', 'value': '<image>\nDescribe.'}]
conv4 = []

print(f"Test Case 1 (Standard): {format_plain_template(conv1)}")
print(f"Test Case 2 (Image token in caption): {format_plain_template(conv2)}")
print(f"Test Case 3 (No GPT turn): {format_plain_template(conv3)}")
print(f"Test Case 4 (Empty conversation): {format_plain_template(conv4)}")

Test Case 1 (Standard): <image>
This is the caption.
Test Case 2 (Image token in caption): <image>
Caption with image token removed.
Test Case 3 (No GPT turn): <image>
Test Case 4 (Empty conversation): <image>


In [9]:
#| export
class LLaVATextTokenizer(Transform):
    """A fastai Transform to format and tokenize text data for LLaVA stage 1.

    Applies the 'plain' template formatting and then tokenizes the text, returning only the input IDs.
    """
    def __init__(self, tokenizer, template_formatter=format_plain_template):
        store_attr()
        if self.tokenizer is None:
            raise ValueError("Tokenizer must be provided and loaded successfully.")

    def encodes(self, conversations: List[Dict[str, str]]) -> List[int]:
        """Applies formatting and tokenization to conversation data.

        Args:
            conversations: Raw conversation list from the dataset sample.

        Returns:
            A list of input token IDs.
        """
        formatted_text = self.template_formatter(conversations)
        # Tokenize the formatted text.
        # We don't pad here; padding is done at the batch level.
        tokenized_output = self.tokenizer(formatted_text,
                                         return_tensors=None, # Let batch collation handle tensor conversion + padding
                                         add_special_tokens=True, # Add BOS/EOS if tokenizer configured to do so
                                         truncation=False # Truncation can be done later if needed
                                        )
        # Return just the input_ids list for DataBlock item_tfms
        # We use TensorBase here so fastai recognizes it as tensor-like for collation
        # return TensorBase(tokenized_output['input_ids'])
        # Update: Returning raw list might be simpler, collation handles tensor conversion
        return tokenized_output['input_ids']

#### Example Usage & Test (Tokenizer Transform)

In [10]:
if tokenizer:
    # Example conversation
    example_conv = [{'from': 'human', 'value': '<image>\nDescribe this.'}, {'from': 'gpt', 'value': 'It is red.'}]

    # Create the transform instance
    llava_tokenizer_tfm = LLaVATextTokenizer(tokenizer)

    # Apply the transform
    tokenized_ids = llava_tokenizer_tfm(example_conv)

    print(f"Original Conversations: {example_conv}")
    # Re-format to show what was tokenized
    formatted = format_plain_template(example_conv)
    print(f"Formatted Text: {formatted}")
    print(f"Tokenized Output (Input IDs): {tokenized_ids}")

    # Decode for verification
    decoded_tokens = tokenizer.convert_ids_to_tokens(tokenized_ids)
    print(f"Decoded Tokens: {decoded_tokens}")

    # Ensure the output is a list of integers
    assert isinstance(tokenized_ids, list)
    assert all(isinstance(x, int) for x in tokenized_ids)
else:
    print("Tokenizer not loaded, skipping tokenizer transform test.")

Original Conversations: [{'from': 'human', 'value': '<image>\nDescribe this.'}, {'from': 'gpt', 'value': 'It is red.'}]
Formatted Text: <image>
It is red.
Tokenized Output (Input IDs): [1, 32000, 29871, 13, 490, 338, 2307, 29889, 2]
Decoded Tokens: ['<s>', '<image>', '\n', 'It', ' is', ' red', '.', '</s>']


---

## Step 1.5: Implement Custom Batch Transform / Collate Function (Stage 1)

This transform handles batch-level operations: padding, masking, image token replacement, and normalization.

In [11]:
#| export
class LLaVABatchTransform(Transform):
    """ Custom batch transform for LLaVA stage 1.
        Handles image normalization, text padding, attention mask creation,
        image token marker replacement, label creation, and label masking for the 'plain' template.
    """
    def __init__(self, tokenizer, image_token_id=None):
        store_attr()
        if self.tokenizer is None:
            raise ValueError("Tokenizer must be provided.")

        if image_token_id is None:
            self.image_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_IMAGE_TOKEN)
            if self.image_token_id == tokenizer.unk_token_id:
                print(f"Warning: {DEFAULT_IMAGE_TOKEN} not found in tokenizer vocab. Using UNK ID: {self.image_token_id}")
        else:
            self.image_token_id = image_token_id

        self.pad_token_id = tokenizer.pad_token_id
        if self.pad_token_id is None:
            raise ValueError("Tokenizer must have a pad_token_id defined.")

        print(f"LLaVABatchTransform initialized. Image Token ID: {self.image_token_id}, Pad Token ID: {self.pad_token_id}")

    def encodes(self, samples: List[Tuple[torch.Tensor, List[int]]]) -> Dict[str, torch.Tensor]:
        """Collates samples into a batch, applies padding, masking, and normalization.

        Args:
            samples: A list of tuples, where each tuple contains:
                (image_tensor: torch.Tensor, token_ids: List[int])
                as produced by the item transforms.

        Returns:
            A dictionary containing the collated batch:
            { 'pixel_values': torch.Tensor, 'input_ids': torch.Tensor,
              'attention_mask': torch.Tensor, 'labels': torch.Tensor }
        """
        if not samples:
            return {}

        # 1. Separate images and text token IDs
        image_tensors = [s[0] for s in samples]
        text_token_ids_list = [s[1] for s in samples]

        # 2. Stack and normalize image tensors
        # Assume images are already tensors via ToTensor in item_tfms
        # Apply normalization (expects BCHW format)
        # Make sure images are float first (IntToFloatTensor might be needed before this transform)
        images_stacked = torch.stack(image_tensors)
        # Need to ensure clip_normalize is available in the scope or passed in
        # For now, assume it's globally available from earlier cell
        normalized_images = clip_normalize(images_stacked)

        # 3. Pad text sequences using tokenizer
        # We provide the list of lists directly to tokenizer.pad
        padded_texts = self.tokenizer.pad(
            {'input_ids': text_token_ids_list},
            padding='longest', # Pad to the longest sequence in the batch
            return_tensors='pt', # Return PyTorch tensors
            return_attention_mask=True
        )
        input_ids = padded_texts['input_ids']
        attention_mask = padded_texts['attention_mask']

        # 4. Create labels by cloning input_ids
        labels = input_ids.clone()

        # 5. Find image token ID and replace with placeholder -200 in input_ids
        # We do this *after* padding and *before* masking labels
        input_ids[input_ids == self.image_token_id] = IMAGE_TOKEN_INDEX_PLACEHOLDER

        # 6. Apply label masking for 'plain' template
        # Mask everything up to and including the image token, and padding tokens.
        # The caption starts *after* the image token (and the newline following it).
        for i in range(labels.shape[0]): # Iterate through each sample in the batch
            # Find the index of the image token (before it was replaced in input_ids)
            # Note: We search for the *original* ID in the cloned labels tensor.
            image_token_indices = torch.where(labels[i] == self.image_token_id)[0]
            if len(image_token_indices) > 0:
                image_token_idx = image_token_indices[0].item()
                # Mask everything up to and including the image token.
                # For 'plain' template (<image>\n{caption}), the caption starts at index image_token_idx + 1
                # (assuming newline is one token, which might not be true). Be careful.
                # Let's find the first non-special token *after* the image token.
                # We assume the template is <image>\n{caption} or similar.
                # The image token is expected at the beginning (or after BOS).
                mask_until_idx = image_token_idx + 1 # Start masking after image token
                
                # Heuristic: Assume caption starts after image token and potential newline.
                # We mask the image token itself and tokens before it.
                labels[i, :mask_until_idx] = IGNORE_INDEX
            else:
                # If image token wasn't found (shouldn't happen with plain template),
                # mask the entire sequence as a safety measure.
                print(f"Warning: Image token ID {self.image_token_id} not found in labels for sample {i}. Masking all.")
                labels[i, :] = IGNORE_INDEX

            # Also mask padding tokens
            labels[i][attention_mask[i] == 0] = IGNORE_INDEX
            
            # Specific check: Mask the BOS token if present (usually token ID 1 for Llama)
            if labels[i, 0] == self.tokenizer.bos_token_id:
                labels[i, 0] = IGNORE_INDEX

        # 7. Return the prepared batch as a dictionary
        return {
            'pixel_values': normalized_images,
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }

# Make the transform usable in fastai pipelines
LLaVABatchTransform.split_idx = None # Apply to both train and validation

LLaVABatchTransform initialized. Image Token ID: 32000, Pad Token ID: 2


#### Example Usage & Test (Batch Transform)

In [12]:
#| eval: false
if tokenizer:
    # 1. Create dummy item samples (output of item transforms)
    sample1_img = torch.rand(3, 336, 336) # Dummy image tensor
    sample1_text_ids = llava_tokenizer_tfm([{'from': 'human', 'value': '<image>'}, {'from': 'gpt', 'value': 'It is red.'}])
    sample1 = (sample1_img, sample1_text_ids)

    sample2_img = torch.rand(3, 336, 336)
    # Example with slightly longer text
    sample2_text_ids = llava_tokenizer_tfm([{'from': 'human', 'value': '<image>'}, {'from': 'gpt', 'value': 'Is it green?'}])
    sample2 = (sample2_img, sample2_text_ids)

    dummy_samples = [sample1, sample2]

    # 2. Instantiate the batch transform
    batch_transform = LLaVABatchTransform(tokenizer)

    # 3. Apply the transform to the dummy samples
    collated_batch = batch_transform(dummy_samples)

    # 4. Inspect the output
    print("--- Input Samples ---")
    for i, s in enumerate(dummy_samples):
        print(f"Sample {i}:")
        print(f"  Image shape: {s[0].shape}")
        print(f"  Token IDs: {s[1]}")
        
    print("\n--- Collated Batch ---")
    if collated_batch:
        print(f"Pixel Values Shape: {collated_batch['pixel_values'].shape}")
        print(f"Input IDs Shape: {collated_batch['input_ids'].shape}")
        print(f"Input IDs:\n{collated_batch['input_ids']}")
        print(f"Attention Mask Shape: {collated_batch['attention_mask'].shape}")
        print(f"Attention Mask:\n{collated_batch['attention_mask']}")
        print(f"Labels Shape: {collated_batch['labels'].shape}")
        print(f"Labels:\n{collated_batch['labels']}")
        
        print("\n--- Decoded Labels (showing non-masked tokens) ---")
        for i in range(collated_batch['labels'].shape[0]):
            label_ids = collated_batch['labels'][i]
            # Filter out ignored indices and decode
            valid_label_ids = label_ids[label_ids != IGNORE_INDEX].tolist()
            decoded_labels = tokenizer.convert_ids_to_tokens(valid_label_ids)
            # Clean up special tokens for display if needed
            cleaned_labels = [t for t in decoded_labels if t not in tokenizer.all_special_tokens]
            print(f"Sample {i} Labels: {cleaned_labels}")

        # Basic Assertions
        assert 'pixel_values' in collated_batch
        assert 'input_ids' in collated_batch
        assert 'attention_mask' in collated_batch
        assert 'labels' in collated_batch
        assert collated_batch['input_ids'].shape == collated_batch['attention_mask'].shape == collated_batch['labels'].shape
        # Check if image token replacement happened
        assert torch.any(collated_batch['input_ids'] == IMAGE_TOKEN_INDEX_PLACEHOLDER)
        # Check if labels are masked
        assert torch.any(collated_batch['labels'] == IGNORE_INDEX)
    else:
        print("Batch collation failed.")
else:
    print("Tokenizer not loaded, skipping batch transform test.")

--- Input Samples ---Sample 0:  Image shape: torch.Size([3, 336, 336])  Token IDs: [1, 32000, 29871, 13, 490, 338, 2307, 29889, 2]Sample 1:  Image shape: torch.Size([3, 336, 336])  Token IDs: [1, 32000, 29871, 13, 1243, 526, 29918, 6388, 29889, 2]--- Collated Batch ---Pixel Values Shape: torch.Size([2, 3, 336, 336])Input IDs Shape: torch.Size([2, 10])Input IDs:tensor([[    1,  -200, 29871,    13,   490,   338,  2307, 29889,     2,     2],        [    1,  -200, 29871,    13,  1243,   526, 29918,  6388, 29889,     2]])Attention Mask Shape: torch.Size([2, 10])Attention Mask:tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])Labels Shape: torch.Size([2, 10])Labels:tensor([[ -100,  -100,  -100,    13,   490,   338,  2307, 29889,     2,  -100],        [ -100,  -100,  -100,    13,  1243,   526, 29918,  6388, 29889,     2]])--- Decoded Labels (showing non-masked tokens) ---Sample 0 Labels: [' It', ' is', ' red', '.', '</s>']Sample 1 Labels: [' is', ' green', '?', '<

---

## Step 4.1: Update Data Handling for Stage 2 (Placeholder)

This section will adapt text processing for the Vicuna v1 template and update label masking logic.

In [None]:
# Placeholder for format_v1_template function
# Placeholder for updated LLaVABatchTransform logic for v1 template

In [13]:
#| hide
import nbdev; nbdev.nbdev_export()