# Data Loading

> Functions and classes for loading and parsing datasets for LLaVA-style training.

In [None]:
#| default_exp data.loading

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import sys
from pathlib import Path
import os

# --- Project Root Setup ---
def _locate_project_root(start: Path) -> Path:
    """Pick the project root by checking `start` and its parent for `settings.ini`.

    The parent wins when it holds `settings.ini` and either `start` is an
    `nbs`/`scripts` directory or `start` has no `settings.ini` of its own.
    When neither directory holds the file, a warning is printed and `start`
    is returned as the fallback.
    """
    settings_here = (start / 'settings.ini').exists()
    settings_above = (start.parent / 'settings.ini').exists()
    launched_from_subdir = start.name in ('nbs', 'scripts')

    if settings_above and (launched_from_subdir or not settings_here):
        return start.parent
    if not settings_here:
        print("Warning: Could not automatically determine project root. Assuming current dir.")
    return start


project_root = _locate_project_root(Path(os.getcwd()))

# Put the resolved root first on sys.path (only once) so project imports resolve.
project_root_str = str(project_root.resolve())
if project_root_str not in sys.path:
    print(f"Adding project root to sys.path: {project_root_str}")
    sys.path.insert(0, project_root_str)
# --- End Project Root Setup ---

Adding project root to sys.path: /workspace/llava


In [None]:
from llava.data.preprocessing import (
        tokenizer, 
        LLaVATextTokenizer, 
        clip_normalize, 
        format_plain_template, 
        format_v1_template, # Import the V1 formatter
        LLaVABatchTransform # Import the updated batch transform
    )

ModuleNotFoundError: No module named 'llava.llava'

In [None]:
#| export
import json
from pathlib import Path
from typing import List, Dict, Any, Union, Tuple
from dataclasses import dataclass
import PIL.Image
from functools import partial

from fastai.vision.all import *
from fastai.data.block import DataBlock, TransformBlock, CategoryBlock # Added CategoryBlock just in case
from fastai.data.transforms import parent_label, GrandparentSplitter, RandomSplitter, IntToFloatTensor
from fastai.data.core import DataLoaders, DataLoader # Import DataLoaders

# load_config comes from llava.utils; when that import fails, warn and install a
# stub that returns an empty config.
try:
    from llava.utils import load_config
except ImportError:
    print("Warning: llava.utils not found. load_config function might be unavailable.")

    def load_config(path):
        return {}

# Preprocessing helpers come from llava.data.preprocessing; when that import
# fails, warn and define placeholders so the remaining cells can still execute.
try:
    from llava.data.preprocessing import (
        LLaVABatchTransform,  # batch-level transform
        LLaVATextTokenizer,
        clip_normalize,
        format_plain_template,
        format_v1_template,  # V1 formatter
        tokenizer,
    )
except ImportError:
    print("Warning: llava.data.preprocessing not found or incomplete. Data loading might fail.")

    # tokenizer=None makes the DataBlock cells below skip their definitions.
    tokenizer = None

    def clip_normalize(x):
        # Dummy normalize: identity.
        return x

    def format_plain_template(conv, tok):
        return "<image>\nplaceholder"

    def format_v1_template(conv, tok):
        return "USER: <image>\nplaceholder ASSISTANT: response"

    class LLaVATextTokenizer(Transform):
        def __init__(self, *args, **kwargs):
            pass

        def encodes(self, x):
            return [0, 1, 2]

    class LLaVABatchTransform(Transform):
        def __init__(self, *args, **kwargs):
            self.split_idx = None

        def encodes(self, x):
            return {
                'pixel_values': x[0],
                'input_ids': x[1],
                'attention_mask': torch.ones_like(x[1]),
                'labels': x[1],
            }



## Step 1.1: Data Parsing

We need functions to parse the JSONL files commonly used in LLaVA datasets. Each line typically contains:
- `id`: A unique identifier for the sample.
- `image`: The filename of the image (often relative to an `image_folder`).
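- `conversations`: A list of dictionaries, where each dictionary has `from` ('human' or 'gpt') and `value` (the text).
- `caption` (accepted instead of `conversations`): A plain caption string, which the parser turns into a two-turn conversation (an `<image>` prompt followed by the caption).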

In [None]:
#| export
@dataclass
class LLaVASample:
    """Represents a single sample from a LLaVA-style dataset.

    Attributes:
        sample_id: Unique identifier for the sample.
        image_path: Absolute or relative path to the image file.
        conversations: List of conversation turns (dictionaries with 'from' and 'value').
        data_source: Optional field indicating the source dataset.
    """
    sample_id: str
    image_path: Path
    conversations: List[Dict[str, str]]
    data_source: str | None = None
    image_folder: Path | None = None # Store base image folder for resolving relative paths

In [None]:
show_doc(LLaVASample)

---

### LLaVASample

>      LLaVASample (sample_id:str, image_path:pathlib.Path,
>                   conversations:List[Dict[str,str]],
>                   data_source:str|None=None,
>                   image_folder:pathlib.Path|None=None)

*Represents a single sample from a LLaVA-style dataset.

Attributes:
    sample_id: Unique identifier for the sample.
    image_path: Absolute or relative path to the image file.
    conversations: List of conversation turns (dictionaries with 'from' and 'value').
    data_source: Optional field indicating the source dataset.*

In [None]:
#| export
def parse_llava_jsonl(jsonl_path: Union[str, Path], image_folder: Union[str, Path]) -> List[LLaVASample]:
    """Parse a LLaVA-style JSON Lines file (.jsonl) into LLaVASample objects.

    Every non-blank line must be a JSON object with an `image` key (a string)
    and either a `conversations` list or a `caption` string. A caption-only
    record is turned into a two-turn conversation. Lines that are malformed
    or fail validation are reported with `print` and skipped; they never abort
    the parse. Blank lines are skipped silently.

    Args:
        jsonl_path: Path to the JSON Lines file (read as UTF-8).
        image_folder: Path to the directory containing the images. It is
            resolved to an absolute path and stored on every sample.
    Returns:
        A list of LLaVASample objects, in file order. `image_path` holds the
        path exactly as written in the file (not joined with `image_folder`).
    Raises:
        FileNotFoundError: If the JSONL file does not exist.
    """
    jsonl_path = Path(jsonl_path)
    image_folder = Path(image_folder).resolve() # Resolve to absolute path
    if not jsonl_path.is_file():
        raise FileNotFoundError(f"JSON Lines file not found: {jsonl_path}")

    samples = []
    # Explicit UTF-8 so parsing does not depend on the platform's locale encoding.
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for i, raw_line in enumerate(f):
            line = raw_line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) are not errors.
                continue

            try:
                data = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"Error decoding JSON on line {i+1} in {jsonl_path}: {e}. Line content: '{line}'")
                continue # Skip malformed lines

            # Minimum requirement: 'image' and ('conversations' or 'caption').
            # 'id' is optional (e.g., pure caption data).
            is_valid = isinstance(data, dict) and 'image' in data and ('conversations' in data or 'caption' in data)
            if not is_valid:
                print(f"Warning: Skipping line {i+1} due to missing 'image' or 'conversations'/'caption' key in {jsonl_path}. Data: {data}")
                continue

            image_ref = data['image']
            if not isinstance(image_ref, str):
                # Path() would raise on numbers/lists/null; reject them explicitly.
                print(f"Warning: Skipping line {i+1} due to non-string 'image' value in {jsonl_path}. Data: {data}")
                continue

            conversations = data.get('conversations')
            caption = data.get('caption')

            if conversations is None and caption is not None:
                # Caption-only record (e.g., CC3M format): synthesize a conversation.
                conversations = [
                    {"from": "human", "value": "<image>"},
                    {"from": "gpt", "value": caption},
                ]
            elif not isinstance(conversations, list) or not all(
                isinstance(turn, dict) and 'from' in turn and 'value' in turn for turn in conversations
            ):
                print(f"Warning: Skipping line {i+1} due to invalid 'conversations' format in {jsonl_path}.")
                continue

            samples.append(LLaVASample(
                sample_id=str(data.get('id', f"item_{i}")), # Zero-based index when ID is missing
                image_path=Path(image_ref), # Kept as written; resolved later against image_folder
                conversations=conversations,
                data_source=data.get('data_source'),
                image_folder=image_folder, # Base folder for later resolution
            ))
    return samples

In [None]:
show_doc(parse_llava_jsonl)

---

### parse_llava_jsonl

>      parse_llava_jsonl (jsonl_path:Union[str,pathlib.Path],
>                         image_folder:Union[str,pathlib.Path])

*Parses a LLaVA-style JSON Lines file (.jsonl) and prepares LLaVASample objects.
Args:
    jsonl_path: Path to the JSON Lines file.
    image_folder: Path to the directory containing the images.
Returns:
    A list of LLaVASample objects.
Raises:
    FileNotFoundError: If the JSONL file does not exist.
    json.JSONDecodeError: If a line in the file is not valid JSON.*

---

## Step 1.2: Image Loading and Basic Preprocessing

This involves defining the `ImageBlock` and basic `item_tfms` for loading and resizing images. Normalization stats were defined in `11_data_preprocessing.ipynb`.

In [None]:
#| export
# Define the ImageBlock using PILImage for loading images
image_block = ImageBlock(cls=PILImage)

# Basic item transforms for image processing:
# 1. Resize images to 336x336 with padding.
#    method=ResizeMethod.Pad selects fastai's pad-to-size resize.
#    pad_mode=PadMode.Zeros is fastai's zero (black) padding mode. fastai's
#    PadMode offers Zeros/Border/Reflection only, and Resize takes no separate
#    `pad_value` argument, so zero padding is requested via the mode itself.
# 2. Convert the image to a PyTorch tensor.
# Note: Normalization is applied in batch_tfms using clip_normalize.
basic_image_item_tfms = [
    Resize(336, method=ResizeMethod.Pad, pad_mode=PadMode.Zeros),
    ToTensor(),
]

---

## Step 1.4: Define Custom Dataset/DataBlock (Stage 1)

This section defines the fastai `DataBlock` for Stage 1 projector pre-training.

In [None]:
#| export
def _paired_images_rel_path(paths: dict, dataset_key: str, images_key: str) -> str:
    """Return the image-folder entry paired with a plain-string dataset entry.

    The lookup key is `images_key` with `_data` replaced by `_images`. When that
    would be the dataset's own key (a named dataset without a `_data` suffix),
    `<dataset_key>_images` is used instead, so the annotation file path is never
    mistaken for the image folder. A dict entry contributes its `images` value.
    Falls back to '.' (the data base directory) when nothing is configured.
    """
    lookup_key = images_key.replace('_data', '_images')
    if lookup_key == dataset_key:
        lookup_key = f"{dataset_key}_images"
    entry = paths.get(lookup_key)
    if isinstance(entry, dict):
        entry = entry.get('images')
    return entry or '.'


def get_llava_items(config_source: dict, stage: Union[int, str] = 1) -> List[LLaVASample]:
    """Loads LLaVA samples for a specific stage or dataset based on config.

    `config['paths']` is expected to hold `data_base` plus one entry for the
    requested data: either a plain string (annotation path relative to
    `data_base`, with the image folder looked up under the matching `*_images`
    key, defaulting to `data_base` itself) or a dict with `annotations` and
    `images` keys.

    Args:
        config_source: The main configuration dictionary (passed as the first argument).
        stage: The training stage (1 or 2) or a dataset name string (e.g., 'vqav2_test', 'custom_eval').

    Returns:
        A list of LLaVASample objects.

    Raises:
        TypeError: If `config_source` is not a dictionary.
        ValueError: If `stage` is not 1, 2, or a string, if the dataset is
            missing from `config['paths']`, or if its entry is malformed.
        FileNotFoundError: If the annotation file does not exist.
    """
    config = config_source
    if not isinstance(config, dict):
         raise TypeError(f"Expected configuration dictionary as the first argument, but got type {type(config_source)}")

    paths = config['paths']
    data_base_path = Path(paths['data_base'])
    dataset_type = "Training Stage"

    if stage == 1:
        dataset_key, images_key = 'stage1_data', 'stage1_images'
    elif stage == 2:
        dataset_key, images_key = 'stage2_data', 'stage2_images'
    elif isinstance(stage, str):
        dataset_key = images_key = stage
        dataset_type = "Dataset"
    else:
        raise ValueError(f"Invalid stage/dataset identifier: {stage}. Must be 1, 2, or a dataset name string from config.")

    dataset_config = paths.get(dataset_key)
    if dataset_config is None:
        raise ValueError(f"{dataset_type} '{stage}' not found in config paths.")

    # Handle different config structures (simple path or dict with annotations/images)
    if isinstance(dataset_config, str): # e.g., stage1_data: path/to/file.jsonl
        json_rel_path = dataset_config
        images_rel_path = _paired_images_rel_path(paths, dataset_key, images_key)
    elif isinstance(dataset_config, dict): # e.g., vqav2_test: {annotations: ..., images: ...}
        json_rel_path = dataset_config.get('annotations')
        images_rel_path = dataset_config.get('images')
        if json_rel_path is None or images_rel_path is None:
             raise ValueError(f"Configuration for dataset '{stage}' must contain 'annotations' and 'images' keys.")
    else:
        raise ValueError(f"Invalid configuration format for dataset '{stage}' in config paths.")

    json_path = data_base_path / json_rel_path
    base_image_folder = data_base_path / images_rel_path

    print(f"Loading {dataset_type} '{stage}' items from: {json_path}")
    print(f"Assuming image paths relative to: {base_image_folder}")

    if not json_path.exists():
        raise FileNotFoundError(f"{dataset_type} '{stage}' JSON/JSONL file not found: {json_path}")

    # Use the JSONL parsing function
    samples = parse_llava_jsonl(json_path, base_image_folder)
    print(f"Found {len(samples)} samples for {dataset_type} '{stage}'.")
    return samples

def get_image_path(sample: LLaVASample) -> Path:
    """Return the image path for `sample`, joined with its `image_folder` when relative.

    An absolute `image_path`, or a sample without `image_folder`, is returned
    unchanged. Otherwise `image_folder / image_path` is returned if it exists.
    If it does not exist, `data_base` from `configs/config.yaml` under the
    project root is tried as an alternative base. If that also fails, a
    warning is printed and the original joined path is returned anyway, so the
    caller decides how to handle the missing file.
    """
    if not sample.image_folder or sample.image_path.is_absolute():
        return sample.image_path

    candidate = sample.image_folder / sample.image_path
    if candidate.exists():
        return candidate

    # Fallback: interpret the path relative to the configured data_base.
    try:
        fallback_cfg = load_config(project_root / 'configs' / 'config.yaml')
        fallback = Path(fallback_cfg['paths']['data_base']) / sample.image_path
        if fallback.exists():
            return fallback
        print(f"Warning: Image file not found at resolved path: {candidate} or alternative {fallback}")
    except Exception:
        # Config unavailable or malformed: report only the primary candidate.
        print(f"Warning: Image file not found at resolved path: {candidate}")

    # Best guess, even though the file was not found.
    return candidate

def get_conversations(sample: LLaVASample) -> list:
    """Return the `conversations` attribute of `sample` as-is (no copy is made).

    Used as the `get_y` getter of the DataBlocks defined in this module.
    """
    return sample.conversations

In [None]:
#| export
# Stage 1 (Projector Pre-training) DataBlock.
# Stays None unless the tokenizer and both transform classes are available.
LLaVADataBlockStage1 = None
if not (tokenizer and 'LLaVABatchTransform' in globals() and 'LLaVATextTokenizer' in globals()):
    print("Tokenizer, LLaVABatchTransform, or LLaVATextTokenizer not available, LLaVADataBlockStage1 not defined.")
else:
    # Batch transform configured for the 'plain' template.
    llava_batch_tfm_stage1 = LLaVABatchTransform(
        tokenizer=tokenizer,
        normalize_tfm=clip_normalize,
        template='plain',
    )

    LLaVADataBlockStage1 = DataBlock(
        # get_y yields a list of turns, carried by a generic TransformBlock.
        blocks=(ImageBlock(cls=PILImage), TransformBlock),
        get_items=partial(get_llava_items, stage=1),
        get_x=get_image_path,
        get_y=get_conversations,
        splitter=RandomSplitter(valid_pct=0.01, seed=42),  # Example split
        item_tfms=[
            *basic_image_item_tfms,
            # Text tokenization with the 'plain' formatter.
            LLaVATextTokenizer(tokenizer, template_formatter=format_plain_template),
        ],
        batch_tfms=[llava_batch_tfm_stage1],
    )
    print("LLaVADataBlockStage1 defined.")

V1 template assistant role tokens: [1792, 29889]
LLaVABatchTransform initialized. Image Token ID: 32000, Pad Token ID: 0, Template: plain
LLaVADataBlockStage1 defined.


---

## Step 1.6: Create DataLoaders (Stage 1)

This section shows how to create the `DataLoaders` object from the `DataBlock`.

In [None]:
#| export
def get_stage1_dataloaders(config: dict, dblock: DataBlock = LLaVADataBlockStage1) -> DataLoaders:
    """Build the fastai DataLoaders used for Stage 1 training.

    Batch size and worker count are read from `config['data']`
    (`batch_size_per_device_stage1`, default 8; `num_workers`, default 4).
    The config itself is handed to `dblock.dataloaders` as `source`, which is
    how it reaches the DataBlock's `get_items`.

    Args:
        config: The main configuration dictionary.
        dblock: The configured DataBlock for Stage 1 (defaults to LLaVADataBlockStage1).

    Returns:
        A fastai DataLoaders object.

    Raises:
        ValueError: If the DataBlock is not defined.
        FileNotFoundError: If data paths are invalid during DataBlock processing.
    """
    if dblock is None:
        raise ValueError("Stage 1 DataBlock is not defined. Ensure dependencies are available.")

    data_cfg = config.get('data', {})
    bs = data_cfg.get('batch_size_per_device_stage1', 8)
    workers = data_cfg.get('num_workers', 4)

    print(f"Creating Stage 1 DataLoaders with batch size: {bs}, num_workers: {workers}")

    try:
        loaders = dblock.dataloaders(
            source=config,  # forwarded to get_items
            bs=bs,
            num_workers=workers,
            pin_memory=torch.cuda.is_available(),  # pin memory only when a GPU is present
        )
    except FileNotFoundError as err:
        print(f"Error creating DataLoaders: {err}")
        print("Please ensure data paths in config.yaml are correct and data exists.")
        raise err
    except Exception as err:
        import traceback
        print(f"An unexpected error occurred during DataLoaders creation: {err}")
        traceback.print_exc()  # full traceback for debugging
        raise err
    else:
        print("DataLoaders created successfully.")
        return loaders

---

## Step 4.1: Update Data Handling for Stage 2

This section defines the Stage 2 `DataBlock` and the function that creates DataLoaders for Stage 2 instruction tuning, using the Vicuna v1 chat template and corresponding label masking.

In [None]:
#| export
# Stage 2 (Instruction Tuning) DataBlock.
# Stays None unless the tokenizer, both transform classes and the v1 formatter are available.
LLaVADataBlockStage2 = None
if not (tokenizer and 'LLaVABatchTransform' in globals() and 'LLaVATextTokenizer' in globals() and 'format_v1_template' in globals()):
    print("Dependencies missing, LLaVADataBlockStage2 not defined.")
else:
    # Batch transform configured for the 'v1' template.
    llava_batch_tfm_stage2 = LLaVABatchTransform(
        tokenizer=tokenizer,
        normalize_tfm=clip_normalize,
        template='v1',
    )

    LLaVADataBlockStage2 = DataBlock(
        blocks=(ImageBlock(cls=PILImage), TransformBlock),
        get_items=partial(get_llava_items, stage=2),  # stage=2 selects the Stage 2 data
        get_x=get_image_path,
        get_y=get_conversations,
        splitter=RandomSplitter(valid_pct=0.01, seed=42),  # Adjust split as needed
        item_tfms=[
            *basic_image_item_tfms,
            # Text tokenization with the 'v1' formatter.
            LLaVATextTokenizer(tokenizer, template_formatter=format_v1_template),
        ],
        batch_tfms=[llava_batch_tfm_stage2],
    )
    print("LLaVADataBlockStage2 defined.")

# Stage 2 DataLoaders factory
def get_stage2_dataloaders(config: dict, dblock: DataBlock = LLaVADataBlockStage2) -> DataLoaders:
    """Build the fastai DataLoaders used for Stage 2 training (Instruction Tuning).

    Batch size and worker count are read from `config['data']`
    (`batch_size_per_device_stage2`, default 4; `num_workers`, default 4).
    The config itself is handed to `dblock.dataloaders` as `source`, which is
    how it reaches the DataBlock's `get_items`.

    Args:
        config: The main configuration dictionary.
        dblock: The configured DataBlock for Stage 2 (defaults to LLaVADataBlockStage2).

    Returns:
        A fastai DataLoaders object.

    Raises:
        ValueError: If the DataBlock is not defined.
        FileNotFoundError: If data paths are invalid during DataBlock processing.
    """
    if dblock is None:
        raise ValueError("Stage 2 DataBlock is not defined. Ensure dependencies are available.")

    data_cfg = config.get('data', {})
    bs = data_cfg.get('batch_size_per_device_stage2', 4)  # stage 2 batch size
    workers = data_cfg.get('num_workers', 4)

    print(f"Creating Stage 2 DataLoaders with batch size: {bs}, num_workers: {workers}")

    try:
        loaders = dblock.dataloaders(
            source=config,  # forwarded to get_items
            bs=bs,
            num_workers=workers,
            pin_memory=torch.cuda.is_available(),
        )
    except FileNotFoundError as err:
        print(f"Error creating Stage 2 DataLoaders: {err}")
        print("Please ensure Stage 2 data paths in config.yaml are correct and data exists.")
        raise err
    except Exception as err:
        import traceback
        print(f"An unexpected error occurred during Stage 2 DataLoaders creation: {err}")
        traceback.print_exc()
        raise err
    else:
        print("Stage 2 DataLoaders created successfully.")
        return loaders

V1 template assistant role tokens: [1792, 29889]
LLaVABatchTransform initialized. Image Token ID: 32000, Pad Token ID: 0, Template: v1
LLaVADataBlockStage2 defined.


#### Example Usage & Test (Stage 2 DataLoaders)

In [None]:
#| test
# Smoke test for get_stage2_dataloaders: creates a tiny dummy dataset if none
# exists, builds the DataLoaders, and checks the structure of one batch.
# Failures are printed rather than raised.
try:
    # Use relative path from nbs directory for testing
    config_path = '../configs/config.yaml'
    config = load_config(config_path)
    print(f"Loaded config from {config_path}")
    
    # --- Test Setup: Create dummy Stage 2 data ---
    data_base = Path(config['paths']['data_base'])
    stage2_json_rel = Path(config['paths']['stage2_data'])
    # Stage 2 images might reference paths from stage 1 or other datasets
    stage1_img_rel = Path(config['paths']['stage1_images'])
    stage1_img_path = data_base / stage1_img_rel

    stage2_json_path = data_base / stage2_json_rel
    stage2_json_path.parent.mkdir(parents=True, exist_ok=True)
    stage1_img_path.mkdir(parents=True, exist_ok=True)
    
    # Image references written into the JSONL are relative to data_base, so
    # they must carry the full relative stage-1 image directory (not only its
    # last component) to point at the files created below.
    img1_rel_path = (stage1_img_rel / 'dummy_img1.jpg').as_posix()
    img2_rel_path = (stage1_img_rel / 'dummy_img2.png').as_posix()
    if not (stage1_img_path / 'dummy_img1.jpg').exists():
         PIL.Image.new('RGB', (60, 30), color = 'red').save(stage1_img_path / 'dummy_img1.jpg')
    if not (stage1_img_path / 'dummy_img2.png').exists():
         PIL.Image.new('RGB', (60, 30), color = 'green').save(stage1_img_path / 'dummy_img2.png')
         
    if not stage2_json_path.exists() or stage2_json_path.stat().st_size < 10:
        print(f"Creating dummy Stage 2 JSONL: {stage2_json_path}")
        dummy_jsonl_content = [
            {"id": "s2_001", "image": str(img1_rel_path), "conversations": [{"from": "human", "value": "<image>\nDescribe image."}, {"from": "gpt", "value": "It is a red object."}]}, 
            {"id": "s2_002", "image": str(img2_rel_path), "conversations": [{"from": "human", "value": "<image>\nIs it green?"}, {"from": "gpt", "value": "Yes, it appears green."}]},
        ]
        with open(stage2_json_path, 'w') as f:
            for item in dummy_jsonl_content:
                f.write(json.dumps(item) + '\n')
    # --- End Test Setup ---

    # Get Stage 2 DataLoaders
    dls_stage2 = get_stage2_dataloaders(config)
    
    assert isinstance(dls_stage2, DataLoaders)
    n_train, n_valid = len(dls_stage2.train_ds), len(dls_stage2.valid_ds)
    assert n_train > 0
    # The Stage 2 DataBlock splits with valid_pct=0.01; for the two-sample dummy
    # set that leaves no validation items, so only require a non-empty
    # validation split once the dataset is large enough to produce one.
    if n_train + n_valid >= 100:
        assert n_valid > 0
    
    print("\nTesting one_batch for Stage 2 DataLoaders...")
    b = dls_stage2.one_batch()
    print("one_batch() retrieved. Check shapes and content.")
    assert isinstance(b, dict) # Batch transform should output dict
    assert 'pixel_values' in b
    assert 'input_ids' in b
    assert 'attention_mask' in b
    assert 'labels' in b
    print(f"Batch Keys: {b.keys()}")
    print(f"pixel_values shape: {b['pixel_values'].shape}")
    print(f"input_ids shape: {b['input_ids'].shape}")
    print(f"attention_mask shape: {b['attention_mask'].shape}")
    print(f"labels shape: {b['labels'].shape}")

    # Decode one example for visual inspection
    print("\n--- Decoded Example from Batch --- ")
    # Re-create batch transform locally for decode if needed
    temp_batch_tfm = LLaVABatchTransform(tokenizer, clip_normalize, template='v1')
    decoded_batch = temp_batch_tfm.decode(b)
    img_decoded, text_decoded = decoded_batch[0][0], decoded_batch[1][0]
    print(f"Image Shape: {img_decoded.shape}")
    print(f"Decoded Text: {text_decoded}")
    print("--- End Decoded Example ---")
    
    print("\nStage 2 DataLoaders test passed (basic check).")

except FileNotFoundError as e:
    print(f"Skipping test: FileNotFoundError - {e}")
except Exception as e:
    import traceback
    print(f"An error occurred during Stage 2 DataLoaders test: {e}")
    traceback.print_exc()

Loaded config from ../configs/config.yaml
Creating dummy Stage 2 JSONL: /workspace/llava/data/llava_instruct_150k/llava_v1_5_mix665k.jsonl
Creating Stage 2 DataLoaders with batch size: 4, num_workers: 4
Loading Stage 2 items from: /workspace/llava/data/llava_instruct_150k/llava_v1_5_mix665k.jsonl
Assuming image paths relative to: /workspace/llava/data
Found 2 samples for Stage 2.
Stage 2 DataLoaders created successfully.

Testing one_batch for Stage 2 DataLoaders...


  if self.num_workers > 0 and warn:


one_batch() retrieved. Check shapes and content.
Batch Keys: dict_keys(['pixel_values', 'input_ids', 'attention_mask', 'labels'])
pixel_values shape: torch.Size([2, 3, 336, 336])
input_ids shape: torch.Size([2, 68])
attention_mask shape: torch.Size([2, 68])
labels shape: torch.Size([2, 68])

--- Decoded Example from Batch --- 
Image Shape: torch.Size([3, 336, 336])
Decoded Text: USER: <image> Describe image. ASSISTANT: It is a red object.
--- End Decoded Example ---

Stage 2 DataLoaders test passed (basic check).


---

## Test DataLoader Function

This function creates a DataLoader specifically for a test set, using Stage 2 preprocessing but without shuffling.

In [None]:
#| export
def get_test_dataloader(config: dict, dataset_name: str, dblock: DataBlock = LLaVADataBlockStage2) -> DataLoader:
    """Creates a fastai DataLoader for a specific test dataset.

    Works on a deep copy of `dblock` (the original is left untouched), points
    its `get_items` at `dataset_name`, clears its splitter, builds the datasets
    from `config`, and wraps them in a non-shuffling `DataLoader` that keeps
    the last partial batch. The batch size comes from
    `config['evaluation']['eval_batch_size_per_device']`, falling back to
    `config['data']['batch_size_per_device_stage2']` and then 4.

    The returned loader carries two extra attributes: `items` (the datasets'
    items) and `idxs` (`0..len(datasets)-1`).

    Args:
        config: The main configuration dictionary.
        dataset_name: The key corresponding to the test set in config['paths'] 
                      (e.g., 'vqav2_test', 'textvqa_val', 'custom_eval').
        dblock: The configured DataBlock to use (defaults to LLaVADataBlockStage2).

    Returns:
        A fastai DataLoader object for the test set.

    Raises:
        ValueError: If the DataBlock is not defined or the dataset_name is not found in config.
        FileNotFoundError: If data paths are invalid during DataBlock processing.
    """
    # Imported locally under an unambiguous name: this module never imports
    # `copy` itself, and a star import may bind `copy` to the function
    # `copy.copy` rather than the module, which would break `copy.deepcopy`.
    from copy import deepcopy

    if dblock is None:
        raise ValueError("Stage 2 DataBlock is not defined. Cannot create test dataloader.")

    if dataset_name not in config['paths']:
         raise ValueError(f"Test dataset '{dataset_name}' not found in config['paths'].")

    # Use evaluation batch size, fall back to stage 2 size if not specified
    data_cfg = config.get('data', {})
    batch_size = config.get('evaluation', {}).get('eval_batch_size_per_device', 
                    data_cfg.get('batch_size_per_device_stage2', 4))
    num_workers = data_cfg.get('num_workers', 4)

    print(f"Creating Test DataLoader for '{dataset_name}' with batch size: {batch_size}, num_workers: {num_workers}")

    # Work on a private copy so the shared Stage 2 DataBlock is not modified.
    test_dblock = deepcopy(dblock)
    
    # 1. Change get_items to point to the test dataset
    test_dblock.get_items = partial(get_llava_items, stage=dataset_name)
    # 2. Clear the splitter; the loader below is built over the whole
    #    `datasets` object rather than a train/valid subset.
    # NOTE(review): fastai's DataBlock.datasets may substitute a default
    # splitter when this is None -- confirm the resulting splits are irrelevant here.
    test_dblock.splitter = None 

    try:
        # Create datasets object first
        datasets = test_dblock.datasets(source=config)
        
        if len(datasets) == 0:
             print(f"Warning: No items found for dataset '{dataset_name}'. DataLoader will be empty.")
             
        # NOTE(review): this is a plain DataLoader built directly from the
        # datasets; verify that the DataBlock's item_tfms/batch_tfms are
        # actually applied on this path.
        test_dl = DataLoader(datasets,
                             bs=batch_size,
                             num_workers=num_workers,
                             pin_memory=(torch.cuda.is_available()),
                             shuffle=False, # Important for evaluation
                             drop_last=False # Keep all samples
                            )
        
        # Store items in dl for later retrieval if needed (e.g., for sample IDs)
        test_dl.items = datasets.items
        test_dl.idxs = list(range(len(datasets))) # Store original indices
        
        print(f"Test DataLoader for '{dataset_name}' created successfully ({len(datasets)} samples).")
        return test_dl
    except FileNotFoundError as e:
        print(f"Error creating Test DataLoader for {dataset_name}: {e}")
        print("Please ensure test data paths in config.yaml are correct and data exists.")
        raise
    except Exception as e:
        import traceback
        print(f"An unexpected error occurred during Test DataLoader creation for {dataset_name}: {e}")
        traceback.print_exc()
        raise

#### Test Custom Eval DataLoader Creation (Step 7.4)

In [None]:
#| test
import shutil
# Smoke test for get_test_dataloader on the 'custom_eval' dataset: creates dummy
# annotations/images when missing, then checks the resulting loader.
try:
    config_path = '../configs/config.yaml'
    config = load_config(config_path)
    print(f"Loaded config from {config_path}")

    # --- Setup: Create dummy custom_eval data ---
    data_base = Path(config['paths']['data_base'])
    custom_eval_config = config['paths'].get('custom_eval')
    if not custom_eval_config:
        raise ValueError("Missing 'custom_eval' configuration in paths.")

    custom_ann_rel = Path(custom_eval_config['annotations'])
    custom_img_rel = Path(custom_eval_config['images'])

    custom_ann_path = data_base / custom_ann_rel
    custom_img_path = data_base / custom_img_rel

    custom_ann_path.parent.mkdir(parents=True, exist_ok=True)
    custom_img_path.mkdir(parents=True, exist_ok=True)

    # (Re)write the annotations when the file is absent or effectively empty.
    if not custom_ann_path.exists() or custom_ann_path.stat().st_size < 10:
        print(f"Creating dummy custom_eval JSONL: {custom_ann_path}")
        dummy_content = [
            {"id": "cust_001", "image": "custom_img_1.jpg", "conversations": [{"from": "human", "value": "<image>\nWhat is the small text?"}, {"from": "gpt", "value": "Micro text."}]},
            {"id": "cust_002", "image": "custom_img_2.jpg", "conversations": [{"from": "human", "value": "<image>\nIdentify the object."}, {"from": "gpt", "value": "A detailed widget."}]},
        ]
        with open(custom_ann_path, 'w') as f:
            f.writelines(json.dumps(item) + '\n' for item in dummy_content)

    img1_path = custom_img_path / 'custom_img_1.jpg'
    img2_path = custom_img_path / 'custom_img_2.jpg'
    # One solid-colour placeholder per annotation entry, created only if missing.
    for dummy_path, dummy_size, dummy_color in (
        (img1_path, (64, 64), 'blue'),
        (img2_path, (32, 32), 'yellow'),
    ):
        if not dummy_path.exists():
            PIL.Image.new('RGB', dummy_size, color=dummy_color).save(dummy_path)
            print(f"Creating dummy custom_eval image: {dummy_path}")
    # --- End Setup ---

    # Get the custom eval dataloader
    if LLaVADataBlockStage2 is not None:
        dl_custom = get_test_dataloader(config, 'custom_eval', dblock=LLaVADataBlockStage2)
        assert isinstance(dl_custom, DataLoader)
        assert len(dl_custom.dataset) == 2 # Should match dummy data
        assert dl_custom.bs > 0
        print("Custom Eval DataLoader test passed.")
    else:
        print("Warning: LLaVADataBlockStage2 not defined. Skipping custom eval dataloader test.")

    # Clean up dummy data (optional)
    # shutil.rmtree(data_base / 'custom_eval', ignore_errors=True)

except ValueError as e:
    print(f"Skipping test: Configuration missing or invalid - {e}")
except FileNotFoundError as e:
    print(f"Skipping test: FileNotFoundError - {e}")
except Exception as e:
    import traceback
    print(f"An error occurred during Custom Eval DataLoader test: {e}")
    traceback.print_exc()

Loaded config from ../configs/config.yaml
Creating dummy custom_eval JSONL: /workspace/llava/data/custom_eval/annotations.jsonl
Creating dummy custom_eval image: /workspace/llava/data/custom_eval/images/custom_img_1.jpg
Creating dummy custom_eval image: /workspace/llava/data/custom_eval/images/custom_img_2.jpg
Creating Test DataLoader for 'custom_eval' with batch size: 4, num_workers: 4
Loading Dataset 'custom_eval' items from: /workspace/llava/data/custom_eval/annotations.jsonl
Assuming image paths relative to: /workspace/llava/data/custom_eval/images
Found 2 samples for Dataset 'custom_eval'.
Test DataLoader for 'custom_eval' created successfully (2 samples).
Custom Eval DataLoader test passed.


In [None]:
#| hide
import nbdev; nbdev.nbdev_export()