# Data Preprocessing

> Functions and definitions for preprocessing steps, including normalization stats, tokenization, and template formatting.

In [None]:
#| default_exp data.preprocessing

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from transformers import AutoProcessor, AutoTokenizer, AutoImageProcessor
from fastai.vision.augment import Normalize
from fastai.data.transforms import Transform
import torch
from typing import List, Dict, Union

from Adaptive_Patching_VIT_fastai.utils import load_config

## Step 1.2 (Continued): Image Normalization Setup

Load the CLIP image processor to get the normalization statistics (mean and standard deviation) expected by the vision encoder. If the processor cannot be loaded, hard-coded default statistics are used instead and a warning is printed.

In [None]:
#| export
# Load config to get model names (and optional settings read by later cells).
CONFIG_PATH = 'configs/config.yaml'
# Fallback model names used when the config file or its keys are missing.
_DEFAULT_VISION_ENCODER_NAME = 'openai/clip-vit-large-patch14-336'
_DEFAULT_LLM_NAME = 'lmsys/vicuna-7b-v1.5'
# OpenAI CLIP normalization stats, used only if the image processor cannot be loaded.
_OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
_OPENAI_CLIP_STD = [0.26862954, 0.26130258, 0.27577711]

try:
    config = load_config(CONFIG_PATH)
    VISION_ENCODER_NAME = config['model']['vision_encoder_name_or_path']
    LLM_NAME = config['model']['llm_name_or_path']  # Used for the tokenizer
except FileNotFoundError:
    print(f"Warning: Config file not found at {CONFIG_PATH}. Using default model names.")
    # Bind `config` so later cells that read optional settings via `config.get(...)`
    # fall back to their defaults instead of failing with NameError.
    config = {}
    VISION_ENCODER_NAME = _DEFAULT_VISION_ENCODER_NAME
    LLM_NAME = _DEFAULT_LLM_NAME
except KeyError as e:
    print(f"Warning: Key {e} not found in {CONFIG_PATH}. Using defaults.")
    _model_cfg = config.get('model', {})
    VISION_ENCODER_NAME = _model_cfg.get('vision_encoder_name_or_path', _DEFAULT_VISION_ENCODER_NAME)
    LLM_NAME = _model_cfg.get('llm_name_or_path', _DEFAULT_LLM_NAME)

# Load the vision encoder's image processor to read its normalization stats.
try:
    clip_image_processor = AutoImageProcessor.from_pretrained(VISION_ENCODER_NAME)
    print(f"Successfully loaded CLIP image processor for: {VISION_ENCODER_NAME}")
except Exception as e:
    # Best-effort: do not abort on a load failure; default stats are used below.
    print(f"Error loading CLIP image processor for {VISION_ENCODER_NAME}: {e}")
    clip_image_processor = None

# Get normalization stats
if clip_image_processor is not None:
    image_mean = clip_image_processor.image_mean
    image_std = clip_image_processor.image_std
else:
    # The default vision encoder is an OpenAI CLIP model, so its published stats are a
    # closer fallback than ImageNet's. Copy so callers cannot mutate the constants.
    print("Warning: Using default OpenAI CLIP stats as fallback for normalization.")
    image_mean = list(_OPENAI_CLIP_MEAN)
    image_std = list(_OPENAI_CLIP_STD)

# Create the fastai Normalize transform from the selected stats
clip_normalize = Normalize.from_stats(image_mean, image_std)

Successfully loaded CLIP image processor for: openai/clip-vit-large-patch14-336


In [None]:
# Example: display the normalization stats and the fastai transform built from them.
print(
    *(f"{label}: {value}" for label, value in (
        ("CLIP Mean", image_mean),
        ("CLIP Std", image_std),
        ("Fastai Normalize Transform", clip_normalize),
    )),
    sep="\n",
)

CLIP Mean: [0.48145466, 0.4578275, 0.40821073]
CLIP Std: [0.26862954, 0.26130258, 0.27577711]
Fastai Normalize Transform: Normalize -- Tries to normalize batch with `mean` and `std` specified on `axes`


---

## Step 1.3: Text Tokenization and Template Handling (Stage 1 - Plain)

Load the LLM's tokenizer (Vicuna), define the 'plain' template formatting for Stage 1 pre-training, and wrap both in a fastai `Transform` (`LLaVATextTokenizer`) that formats and tokenizes each sample.

In [None]:
#| export
DEFAULT_IMAGE_TOKEN = "<image>" # Placeholder token for image features

# Load the Vicuna tokenizer; optional settings come from the config's `data` section.
try:
    tokenizer = AutoTokenizer.from_pretrained(
        LLM_NAME,
        model_max_length=config.get('data', {}).get('tokenizer_model_max_length', 2048),
        padding_side=config.get('data', {}).get('tokenizer_padding_side', 'right'),
        use_fast=True,
    )
    print(f"Successfully loaded tokenizer for: {LLM_NAME}")

    # LLaMA-family tokenizers (incl. Vicuna) may ship without a pad token.
    # Reuse an existing special token rather than adding a new one: a new token grows
    # the vocabulary, which would require resizing the model's embedding layer.
    # Prefer UNK (a common LLaVA choice) so padding positions stay distinguishable from
    # real EOS tokens; fall back to EOS, and add '[PAD]' only as a last resort.
    if tokenizer.pad_token is None:
        if tokenizer.unk_token is not None:
            tokenizer.pad_token = tokenizer.unk_token
            print(f"Using UNK token as pad token: {tokenizer.pad_token} (ID: {tokenizer.pad_token_id})")
        elif tokenizer.eos_token is not None:
            tokenizer.pad_token = tokenizer.eos_token
            print(f"Using EOS token as pad token: {tokenizer.pad_token} (ID: {tokenizer.pad_token_id})")
        else:
            tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            # NOTE: the model's embedding layer must be resized to len(tokenizer) later.
            print(f"Added pad token: {tokenizer.pad_token} (ID: {tokenizer.pad_token_id})")

except Exception as e:
    # Best-effort: leave `tokenizer` as None so dependent cells can skip gracefully.
    print(f"Error loading tokenizer for {LLM_NAME}: {e}")
    tokenizer = None

Successfully loaded tokenizer for: lmsys/vicuna-7b-v1.5


In [None]:
# Quick check of what was loaded; shows NoneType if the tokenizer cell failed.
tokenizer.__class__

<class 'NoneType'>

In [None]:
#| export
def format_plain_template(conversations: List[Dict[str, str]]) -> str:
    """Formats conversations using the 'plain' template for Stage 1 pre-training.

    The 'plain' template uses the format: <image>\n{caption}
    where {caption} is the value of the first 'gpt' turn, with any image tokens
    removed and surrounding whitespace stripped.

    Args:
        conversations: A list of conversation turns (dictionaries with 'from' and 'value').
            `None` is treated as an empty conversation, and turns whose 'from' or
            'value' is present but null are treated as empty strings.

    Returns:
        The formatted string. Returns just the image token if no 'gpt' turn is found
        or the caption is empty after cleaning.
    """
    caption = ""  # Default to empty caption if no 'gpt' turn found
    for turn in conversations or []:
        # `or ''` also covers keys that exist with a null value (e.g. JSON `null`),
        # which `.get(key, '')` would return as None and then crash on.
        if (turn.get('from') or '').lower() == 'gpt':
            caption = turn.get('value') or ''
            break  # Use the first GPT response as the caption

    # The image token is always placed first; drop any copies inside the caption
    # to avoid duplicates.
    caption = caption.replace(DEFAULT_IMAGE_TOKEN, '').strip()

    # `caption` is already stripped, so no further stripping of the result is needed.
    return f"{DEFAULT_IMAGE_TOKEN}\n{caption}" if caption else DEFAULT_IMAGE_TOKEN

In [None]:
# Render nbdev API documentation for `format_plain_template` in the notebook/docs.
show_doc(format_plain_template)

```python
#| export
def format_plain_template(conversations: List[Dict[str, str]]) -> str:
    """Formats conversations using the 'plain' template for Stage 1 pre-training.

    The 'plain' template uses the format: <image>\n{caption}
    where {caption} is the value of the first 'gpt' turn.

    Args:
        conversations: A list of conversation turns (dictionaries with 'from' and 'value').

    Returns:
        The formatted string. Returns just the image token if no 'gpt' turn is found.
    """
    caption = "" # Default to empty caption if no 'gpt' turn found
    for turn in conversations:
        if turn.get('from', '').lower() == 'gpt':
            caption = turn.get('value', '')
            break # Use the first GPT response as the caption

    # Ensure the <image> token is always first, followed by newline and caption
    # Remove any existing <image> token from caption to avoid duplicates
    caption = caption.replace(DEFAULT_IMAGE_TOKEN, '').strip()
    
    # Return formatted string, handle empty caption case
    return f"{DEFAULT_IMAGE_TOKEN}\n{caption}".strip() if caption else f"{DEFAULT_IMAGE_TOKEN}"
```

#### Example Usage & Test (Template Formatting)

In [None]:
conv1 = [{'from': 'human', 'value': '<image>\nDescribe.'}, {'from': 'gpt', 'value': 'This is the caption.'}]
conv2 = [{'from': 'human', 'value': 'Describe.'}, {'from': 'gpt', 'value': '<image>Caption with image token removed.'}]
conv3 = [{'from': 'human', 'value': '<image>\nDescribe.'}]
conv4 = []

print(f"Test Case 1 (Standard): {format_plain_template(conv1)}")
print(f"Test Case 2 (Image token in caption): {format_plain_template(conv2)}")
print(f"Test Case 3 (No GPT turn): {format_plain_template(conv3)}")
print(f"Test Case 4 (Empty conversation): {format_plain_template(conv4)}")

# Pin the exact outputs so nbdev's test runner catches regressions.
assert format_plain_template(conv1) == f"{DEFAULT_IMAGE_TOKEN}\nThis is the caption."
assert format_plain_template(conv2) == f"{DEFAULT_IMAGE_TOKEN}\nCaption with image token removed."
assert format_plain_template(conv3) == DEFAULT_IMAGE_TOKEN
assert format_plain_template(conv4) == DEFAULT_IMAGE_TOKEN


Test Case 1 (Standard): <image>
This is the caption.
Test Case 2 (Image token in caption): <image>
Caption with image token removed.
Test Case 3 (No GPT turn): <image>
Test Case 4 (Empty conversation): <image>


In [None]:
#| export
class LLaVATextTokenizer(Transform):
    """A fastai Transform to format and tokenize text data for LLaVA stage 1.

    Applies a template formatter (the 'plain' template by default) to a conversation,
    then tokenizes the result without padding, truncation or tensor conversion.
    """
    def __init__(self, tokenizer, template_formatter=format_plain_template):
        # Assign explicitly instead of calling fastcore's `store_attr()`: that helper is
        # not imported by this module, so it raises NameError in the exported code.
        self.tokenizer = tokenizer
        self.template_formatter = template_formatter
        if self.tokenizer is None:
            raise ValueError("Tokenizer must be provided and loaded successfully.")

    # Builtin `list`/`dict` annotations are used deliberately: fastai Transforms
    # dispatch on the first argument's annotation and cast results to the return
    # annotation, and subscripted typing generics (e.g. `List[Dict[str, str]]`)
    # cannot be used in `issubclass`/`isinstance` checks.
    def encodes(self, conversations: list) -> dict:
        """Applies formatting and tokenization to conversation data.

        Args:
            conversations: Raw conversation list from the dataset sample
                (dicts with 'from' and 'value' keys).

        Returns:
            A dict with a single key, 'input_ids', holding a plain Python list of
            token ids. Padding, attention masks and tensor conversion are left to
            batch collation.
        """
        formatted_text = self.template_formatter(conversations)
        tokenized_output = self.tokenizer(
            formatted_text,
            return_tensors=None,      # Keep a plain list; batch collation builds tensors + padding
            add_special_tokens=True,  # Add special tokens (e.g. BOS) as configured by the tokenizer
            truncation=False,         # Truncation can be done later if needed
        )
        # Attention mask will be created during batch collation.
        return {'input_ids': tokenized_output['input_ids']}

#### Example Usage & Test (Tokenizer Transform)

In [None]:
if tokenizer is not None:
    # Example conversation
    example_conv = [{'from': 'human', 'value': '<image>\nDescribe this.'}, {'from': 'gpt', 'value': 'It is red.'}]

    # Create the transform instance
    llava_tokenizer_tfm = LLaVATextTokenizer(tokenizer)

    # Apply the transform
    tokenized_result = llava_tokenizer_tfm(example_conv)

    print(f"Original Conversations: {example_conv}")
    # Re-format to show what was tokenized
    formatted = format_plain_template(example_conv)
    print(f"Formatted Text: {formatted}")
    print(f"Tokenized Output: {tokenized_result}")

    # Decode for verification
    decoded_tokens = tokenizer.convert_ids_to_tokens(tokenized_result['input_ids'])
    print(f"Decoded Tokens: {decoded_tokens}")

    # The output must be a dict whose 'input_ids' is an (unbatched) list of ints.
    assert isinstance(tokenized_result, dict)
    assert 'input_ids' in tokenized_result
    example_input_ids = tokenized_result['input_ids']
    assert isinstance(example_input_ids, list)  # Before batching, should be list
    assert all(isinstance(token_id, int) for token_id in example_input_ids)
    # convert_ids_to_tokens yields exactly one token string per id.
    assert len(decoded_tokens) == len(example_input_ids)
else:
    print("Tokenizer not loaded, skipping tokenizer transform test.")


Original Conversations: [{'from': 'human', 'value': '<image>\nDescribe this.'}, {'from': 'gpt', 'value': 'It is red.'}]
Formatted Text: <image>
It is red.
Tokenized Output: {'input_ids': [1, 32000, 29871, 13, 490, 338, 2307, 29889, 2]}
Decoded Tokens: ['<s>', '<image>', '\n', 'It', ' is', ' red', '.', '</s>']


---

## Step 1.5: Implement Custom Batch Transform / Collate Function (Stage 1 - Placeholder)

This section will be implemented later. It will handle batch-level padding, attention-mask creation, replacement of the `<image>` token with the image-token marker index (-200), and label masking.

In [None]:
# Placeholder for LLaVABatchTransform class definition

---

## Step 4.1: Update Data Handling for Stage 2 (Placeholder)

This section will adapt text processing for the Vicuna v1 template and update label masking logic.

In [None]:
# Placeholder for format_v1_template function
# Placeholder for updated LLaVABatchTransform logic

In [None]:
#| hide
# Export all `#| export` cells of this notebook to the library module.
import nbdev
nbdev.nbdev_export()