# Data Loading

> Functions and classes for loading and parsing datasets for LLaVA-style training.

In [1]:
#| default_exp data.loading

In [1]:
#| hide
from nbdev.showdoc import *

In [2]:
#| export
import sys
from pathlib import Path
import os

# Assumes the notebook is run from the project root or one level down (e.g., nbs/)
# Navigate up to the project root (where settings.ini or .git likely exists)
project_root = Path(os.getcwd())
# Simple check: If settings.ini is not in cwd, assume we are in nbs/ and go up one level
if not (project_root / 'settings.ini').exists() and (project_root.parent / 'settings.ini').exists():
    project_root = project_root.parent

project_root_str = str(project_root.resolve())

if project_root_str not in sys.path:
    print(f"Adding project root to sys.path: {project_root_str}")
    sys.path.insert(0, project_root_str)
else:
    print(f"Project root already in sys.path: {project_root_str}")

Adding project root to sys.path: /workspace/llava


In [3]:
#| export
import json
from pathlib import Path
from typing import List, Dict, Any, Union, Tuple
from dataclasses import dataclass
import PIL.Image
from functools import partial

from fastai.vision.all import *
from fastai.data.block import DataBlock, TransformBlock, CategoryBlock # Added CategoryBlock just in case
from fastai.data.transforms import parent_label, GrandparentSplitter, RandomSplitter, IntToFloatTensor

from llava.utils import load_config
# Import necessary items from preprocessing notebook
from llava.data.preprocessing import (
    tokenizer, 
    LLaVATextTokenizer, 
    clip_normalize, 
    format_plain_template, # Needed if not using LLaVATextTokenizer directly
    LLaVABatchTransform # Import the custom batch transform
)

Project root already in sys.path: /workspace/llava


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Successfully loaded CLIP image processor for: openai/clip-vit-large-patch14-336
Successfully loaded tokenizer for: lmsys/vicuna-7b-v1.5
Adding special token <image> to tokenizer.
Added 1 token(s). New vocab size: 32001
Using token ID for <image>: 32000
Tokenizer already has pad token: <unk> (ID: 0)


## Step 1.1: Data Parsing

We need functions to parse the JSONL files commonly used in LLaVA datasets. Each line typically contains:
- `id`: A unique identifier for the sample.
- `image`: The filename of the image (often relative to an `image_folder`).
- `conversations`: A list of dictionaries, where each dictionary has `from` ('human' or 'gpt') and `value` (the text).

In [4]:
#| export
@dataclass
class LLaVASample:
    """Represents a single sample from a LLaVA-style dataset.

    Attributes:
        sample_id: Unique identifier for the sample.
        image_path: Absolute path to the image file.
        conversations: List of conversation turns (dictionaries with 'from' and 'value').
        data_source: Optional field indicating the source dataset.
    """
    sample_id: str
    image_path: Path
    conversations: List[Dict[str, str]]
    data_source: str | None = None

In [5]:
show_doc(LLaVASample)

---

### LLaVASample

>      LLaVASample (sample_id:str, image_path:pathlib.Path,
>                   conversations:List[Dict[str,str]],
>                   data_source:str|None=None)

*Represents a single sample from a LLaVA-style dataset.

Attributes:
    sample_id: Unique identifier for the sample.
    image_path: Absolute path to the image file.
    conversations: List of conversation turns (dictionaries with 'from' and 'value').
    data_source: Optional field indicating the source dataset.*

In [5]:
#| export
def parse_llava_json(json_path: Union[str, Path], image_folder: Union[str, Path]) -> list:
    """Parses a LLaVA-style JSON file (containing a list of objects) and resolves image paths.
       Note: This function assumes the JSON file contains a single JSON array, not JSON Lines.

    Args:
        json_path: Path to the JSON file (expected to contain a list).
        image_folder: Path to the directory containing the images referenced in the JSON file.

    Returns:
        A list of LLaVASample objects.

    Raises:
        FileNotFoundError: If the JSON file does not exist.
        json.JSONDecodeError: If the file content is not a valid JSON list.
        TypeError: If the parsed JSON is not a list.
    """
    json_path = Path(json_path)
    image_folder = Path(image_folder)

    if not json_path.is_file():
        raise FileNotFoundError(f"JSON file not found: {json_path}")

    samples = []
    try:
        with open(json_path, 'r') as f:
            content = f.read()
            full_data_list = json.loads(content)
    except json.JSONDecodeError as e:
        # Add context to the original error
        raise json.JSONDecodeError(f"Error decoding JSON file {json_path}: {e.msg}", e.doc, e.pos) from e
    except Exception as e:
        raise RuntimeError(f"Error reading file {json_path}: {e}") from e


    if not isinstance(full_data_list, list):
        raise TypeError(f"Expected JSON file to contain a list of objects, but got type {type(full_data_list)} in {json_path}")

    for i, data in enumerate(full_data_list):
        try:
            # Check for required keys
            if not isinstance(data, dict) or not all(k in data for k in ['id', 'image', 'conversations']):
                print(f"Warning: Skipping item index {i} due to missing keys ('id', 'image', or 'conversations') or incorrect format in {json_path}. Data: {data}")
                continue

            sample_id = data['id']
            # Construct the full image path
            image_ref = data['image']
            if isinstance(image_ref, dict) and 'path' in image_ref:
                image_filename = image_ref['path']
            elif isinstance(image_ref, str):
                image_filename = image_ref
            else:
                 print(f"Warning: Skipping item index {i} due to unexpected image field format in {json_path}. Expected string or dict with 'path', got: {type(image_ref)}")
                 continue

            # Resolve the image path relative to image_folder
            image_path = image_folder / Path(image_filename).name

            conversations = data['conversations']
            data_source = data.get('data_source')

            # Basic validation for conversations format
            if not isinstance(conversations, list) or not all(isinstance(turn, dict) and 'from' in turn and 'value' in turn for turn in conversations):
                 print(f"Warning: Skipping item index {i} due to invalid 'conversations' format in {json_path}.")
                 continue

            # Optional: Check if image exists (can be slow)
            # if not image_path.is_file():
            #     print(f"Warning: Image file not found for sample {sample_id} at {image_path}, skipping item index {i}.")
            #     continue

            samples.append(LLaVASample(
                sample_id=str(sample_id),
                image_path=image_path,
                conversations=conversations,
                data_source=data_source
            ))
        except Exception as e:
            print(f"Error processing item index {i} in {json_path}: {e}. Data: {data}")
            # Optionally re-raise or just continue
            continue


    return samples

In [7]:
show_doc(parse_llava_json)

---

### parse_llava_json

>      parse_llava_json (json_path:Union[str,pathlib.Path],
>                        image_folder:Union[str,pathlib.Path])

*Parses a LLaVA-style JSON file (containing a list of objects) and resolves image paths.
   Note: This function assumes the JSON file contains a single JSON array, not JSON Lines.

Args:
    json_path: Path to the JSON file (expected to contain a list).
    image_folder: Path to the directory containing the images referenced in the JSON file.

Returns:
    A list of LLaVASample objects.

Raises:
    FileNotFoundError: If the JSON file does not exist.
    json.JSONDecodeError: If the file content is not a valid JSON list.
    TypeError: If the parsed JSON is not a list.*

#### Example Usage & Test

In [8]:
#| eval: false
# Create dummy data for testing
dummy_data_dir = Path('./dummy_data')
dummy_img_dir = dummy_data_dir / 'images'
dummy_jsonl_path = dummy_data_dir / 'dummy_llava_data.jsonl'

dummy_img_dir.mkdir(parents=True, exist_ok=True)
print(f"Created dummy directory: {dummy_img_dir}")

# Create dummy image files (attempt, may fail if PIL issues persist)
try:
    PIL.Image.new('RGB', (60, 30), color = 'red').save(dummy_img_dir / 'img1.jpg')
    PIL.Image.new('RGB', (60, 30), color = 'green').save(dummy_img_dir / 'img2.png')
except Exception as e:
    print(f"Note: Could not create dummy image files (PIL might not be fully installed or usable): {e}")

# Create dummy jsonl content
dummy_jsonl_content = [
    {"id": "sample1", "image": "img1.jpg", "conversations": [{"from": "human", "value": "<image>\nDescribe this."}, {"from": "gpt", "value": "It is red."}]}, 
    {"id": "sample2", "image": "img2.png", "conversations": [{"from": "human", "value": "<image>\nWhat color?"}, {"from": "gpt", "value": "Green."}]},
    {"id": "sample3_missing_keys", "conversations": []}, # Missing image/id
    {"id": "sample4_bad_conv", "image": "img1.jpg", "conversations": "not a list"}, # Bad conversation format
    {"id": "sample5_missing_img_file", "image": "nonexistent.jpg", "conversations": [{"from": "human", "value": "<image>"}, {"from": "gpt", "value": "..."}]},
]

with open(dummy_jsonl_path, 'w') as f:
    for item in dummy_jsonl_content:
        f.write(json.dumps(item) + '\n')
print(f"Created dummy jsonl file: {dummy_jsonl_path}")

# Test parsing
try:
    parsed_samples = parse_llava_json(dummy_jsonl_path, dummy_img_dir)
    print(f"Successfully parsed {len(parsed_samples)} samples:")
    for sample in parsed_samples:
        print(sample)

    # Basic checks (adjust expected length based on warnings/skips)
    # Check based on the print output, 3 samples are expected now (1, 2, 5)
    assert len(parsed_samples) == 3 
    assert parsed_samples[0].sample_id == 'sample1'
    # Resolve paths for comparison
    assert parsed_samples[0].image_path.resolve() == (dummy_img_dir / 'img1.jpg').resolve()
    assert parsed_samples[1].sample_id == 'sample2'
    assert parsed_samples[1].image_path.resolve() == (dummy_img_dir / 'img2.png').resolve()
    assert parsed_samples[0].conversations[0]['from'] == 'human'
    assert parsed_samples[2].sample_id == 'sample5_missing_img_file'

except FileNotFoundError as e:
    print(f"Error: {e}")
except json.JSONDecodeError as e:
    print(f"JSON Parsing Error: {e}")
except Exception as e:
    print(f"An unexpected error occurred: {e}")

# Clean up dummy data (optional)
# import shutil
# if dummy_data_dir.exists():
#    shutil.rmtree(dummy_data_dir)
#    print(f"Cleaned up dummy data directory: {dummy_data_dir}")

Created dummy directory: dummy_data/images
Created dummy jsonl file: dummy_data/dummy_llava_data.jsonl
JSON Parsing Error: Error decoding JSON file dummy_data/dummy_llava_data.jsonl: Extra data: line 2 column 1 (char 153)


---

## Step 1.2: Image Loading and Basic Preprocessing

This involves defining the `ImageBlock` and basic `item_tfms` for loading and resizing images. Normalization stats were defined in `11_data_preprocessing.ipynb`.

In [6]:
#| export
# Define the ImageBlock using PILImage for loading images
image_block = ImageBlock(cls=PILImage)

# Define basic item transforms for image processing
# 1. Resize images to 336x336, padding if necessary
#    method='pad': Pads the image to the target size.
#    pad_mode='const': Uses a constant value for padding.
#    pad_value=0: Uses black for padding (common for vision models).
# 2. Convert the image to a PyTorch tensor.
# Note: Normalization will be applied in batch_tfms using clip_normalize.
basic_image_item_tfms = [
    Resize(336),
    ToTensor(),
]

---

## Step 1.4: Define Custom Dataset/DataBlock (Stage 1)

This section defines the fastai `DataBlock` for Stage 1 projector pre-training.

In [7]:
#| export
def get_llava_items(config_source: dict, stage: int = 1) -> list:
    """Loads LLaVA samples for a specific stage based on config passed as source.

    Args:
        config_source: The main configuration dictionary (passed as the first argument).
        stage: The training stage (1 or 2).

    Returns:
        A list of LLaVASample objects.
    # ... (rest of docstring/implementation) ...
    """
    config = config_source
    if not isinstance(config, dict):
         raise TypeError(f"Expected configuration dictionary as the first argument, but got type {type(config_source)}")

    data_base_path = Path(config['paths']['data_base'])
    if stage == 1:
        json_rel_path = config['paths']['stage1_data'] # Renamed variable for clarity
        images_rel_path = config['paths']['stage1_images']
    elif stage == 2:
        json_rel_path = config['paths']['stage2_data']
        images_rel_path = config['paths'].get('stage2_images', '.')
    else:
        raise ValueError(f"Invalid stage specified: {stage}. Must be 1 or 2.")

    json_path = data_base_path / json_rel_path # Path to the JSON file
    image_folder = data_base_path / images_rel_path

    print(f"Loading Stage {stage} items from: {json_path}")
    print(f"Assuming images relative to: {image_folder}")

    if not json_path.exists():
        raise FileNotFoundError(f"Stage {stage} JSON file not found: {json_path}") # Updated error message
    if not image_folder.is_dir() and stage == 1:
        print(f"Warning: Stage {stage} image folder not found or not a directory: {image_folder}")

    # Use the updated parsing function
    samples = parse_llava_json(json_path, image_folder) # <--- Use parse_llava_json
    print(f"Found {len(samples)} samples for Stage {stage}.")
    return samples

# Keep get_image_path and get_conversations as they are
def get_image_path(sample: LLaVASample) -> Path:
    """Extracts the image path from a LLaVASample."""
    return sample.image_path

def get_conversations(sample: LLaVASample) -> list:
    """Extracts the conversations from a LLaVASample."""
    return sample.conversations

In [8]:
#| export
# Define the DataBlock for Stage 1 (Projector Pre-training)
if tokenizer:
    # Instantiate the custom batch transform (needs tokenizer)
    llava_batch_tfm = LLaVABatchTransform(tokenizer=tokenizer, normalize_tfm=clip_normalize)

    LLaVADataBlockStage1 = DataBlock(
        blocks=(ImageBlock(cls=PILImage), TransformBlock),
        # get_items expects config as first arg (source), use partial only for stage
        get_items=partial(get_llava_items, stage=1),
        get_x=get_image_path,
        get_y=get_conversations,
        splitter=RandomSplitter(valid_pct=0.01, seed=42),
        item_tfms=[
            *basic_image_item_tfms,
            LLaVATextTokenizer(tokenizer, template_formatter=format_plain_template)
        ],
        batch_tfms=[
            llava_batch_tfm
        ]
    )
    print("LLaVADataBlockStage1 defined.")
else:
    LLaVADataBlockStage1 = None
    print("Tokenizer not available, LLaVADataBlockStage1 not defined.")

LLaVABatchTransform initialized. Image Token ID: 32000, Pad Token ID: 0
LLaVADataBlockStage1 defined.


#### Example Usage & Test (DataBlock Definition)

In [12]:
#| eval: false
if LLaVADataBlockStage1:
    try:
        config = load_config('../configs/config.yaml')
        print("Attempting DataBlock summary...")
        # Pass the config dictionary AS the source argument
        LLaVADataBlockStage1.summary(source=config, bs=4) # Pass config to source
    except FileNotFoundError as e:
        print(f"\nSkipping summary: FileNotFoundError: {e}")
        print("Please ensure 'paths.data_base', 'paths.stage1_data', and 'paths.stage1_images' are correctly set in configs/config.yaml and point to existing data.")
    except Exception as e:
        import traceback
        print(f"\nSkipping summary: Exception occurred: {e}")
        traceback.print_exc() # Print stack trace for debugging
        print("Check paths in config.yaml and ensure data files are accessible and correctly formatted.")
else:
    print("LLaVADataBlockStage1 not defined, cannot show summary.")

Attempting DataBlock summary...
Setting-up type transforms pipelines
Collecting items from {'project_name': 'llava', 'version': 1.0, 'paths': {'data_base': '/workspace/llava/data/', 'output_dir': '/workspace/llava/output/', 'stage1_data': 'llava_pretrain/llava_pretrain.jsonl', 'stage1_images': 'llava_pretrain/images', 'stage2_data': 'llava_instruct_150k/llava_v1_5_mix665k.jsonl', 'stage2_images': '/', 'vqav2_test': 'vqav2/test2015', 'vqav2_test_annotations': 'vqav2/v2_mscoco_test2015_annotations.json', 'textvqa_val': 'textvqa/TextVQA_0.5.1_val.json', 'textvqa_images': 'textvqa/train_images', 'stage1_projector_weights': 'stage1_projector.pth', 'stage2_model_weights': 'stage2_full_model'}, 'model': {'llm_name_or_path': 'lmsys/vicuna-7b-v1.5', 'vision_encoder_name_or_path': 'openai/clip-vit-large-patch14-336', 'vision_feature_layer': -2, 'image_token': '<image>', 'image_token_index_marker': -200, 'projector': {'type': 'mlp_2x', 'input_dim': 1024, 'output_dim': 4096}, 'peft': {'use_lora': 

---

## Step 1.6: Create DataLoaders (Stage 1)

This section shows how to create the `DataLoaders` object from the `DataBlock`.

In [9]:
#| export
def get_stage1_dataloaders(config: dict, dblock: DataBlock = LLaVADataBlockStage1) -> DataLoaders:
    """Creates fastai DataLoaders for Stage 1 training.

    Args:
        config: The main configuration dictionary.
        dblock: The configured DataBlock for Stage 1 (defaults to LLaVADataBlockStage1).

    Returns:
        A fastai DataLoaders object.

    Raises:
        ValueError: If the DataBlock is not defined.
        FileNotFoundError: If data paths are invalid during DataBlock processing.
    """
    if dblock is None:
        raise ValueError("Stage 1 DataBlock is not defined. Ensure tokenizer loaded correctly.")

    batch_size = config.get('data', {}).get('batch_size_per_device_stage1', 8)
    num_workers = config.get('data', {}).get('num_workers', 4)

    print(f"Creating Stage 1 DataLoaders with batch size: {batch_size}, num_workers: {num_workers}")

    # The DataBlock's get_items function needs the config,
    # we pass it here when calling dataloaders()
    try:
        # Pass config explicitly to dataloaders, which passes it down to get_items
        dls = dblock.dataloaders(source=config, # Pass config to be used by get_items
                                 bs=batch_size,
                                 num_workers=num_workers)
        print("DataLoaders created successfully.")
        return dls
    except FileNotFoundError as e:
        print(f"Error creating DataLoaders: {e}")
        print("Please ensure data paths in config.yaml are correct and data exists.")
        raise e
    except Exception as e:
        import traceback
        print(f"An unexpected error occurred during DataLoaders creation: {e}")
        traceback.print_exc() # Print full traceback for debugging
        raise e

In [10]:
#| eval: false
# Example usage: Create DataLoaders (requires config.yaml and actual data)
try:
    config_path = '../configs/config.yaml'
    config = load_config(config_path)
    print(f"Loaded config from {config_path}")
    dls = get_stage1_dataloaders(config)

    # print(f"DataLoaders created. Testing show_batch...")
    # Now that batch transforms are complete, show_batch should work if data exists
    # Ensure you have data available at the paths specified in config.yaml
    # dls.show_batch(max_n=4, figsize=(12, 8))
    print("\nTesting one_batch...")
    b = dls.one_batch()
    print("one_batch() retrieved. Check keys and shapes.")
    # Example: print shapes of tensors in the batch
    print(f'len(b): {len(b)}')
    
    print(f'b[0]: {b[0]}')
    print(f'b[1]: {b[1]}')
    
    print(f'len(b[0]): {len(b[0])}')
    print(f'len(b[1]): {len(b[1])}')
    
except FileNotFoundError as e:
    print(f"\nError creating DataLoaders: FileNotFoundError - {e}")
    print("Please ensure 'paths.data_base', 'paths.stage1_data', and 'paths.stage1_images' are correctly set in configs/config.yaml and point to existing data.")
except Exception as e:
    import traceback
    print(f"\nAn unexpected error occurred during DataLoaders creation/testing: {e}")
    traceback.print_exc()


Loaded config from ../configs/config.yaml
Creating Stage 1 DataLoaders with batch size: 8, num_workers: 4
Loading Stage 1 items from: /workspace/llava/data/llava_pretrain/llava_pretrain.jsonl
Assuming images relative to: /workspace/llava/data/llava_pretrain/images
Found 595375 samples for Stage 1.
DataLoaders created successfully.

Testing one_batch...
one_batch() retrieved. Check keys and shapes.
len(b): 2
b[0]: TensorImage([[[[0.0353, 0.0549, 0.0745,  ..., 0.3529, 0.3569, 0.3098],
               [0.0784, 0.0588, 0.0471,  ..., 0.4078, 0.3922, 0.3294],
               [0.1529, 0.0863, 0.0275,  ..., 0.4471, 0.4078, 0.3255],
               ...,
               [0.0510, 0.1686, 0.3137,  ..., 0.1059, 0.1059, 0.1137],
               [0.0706, 0.1137, 0.1765,  ..., 0.0745, 0.0784, 0.0824],
               [0.0902, 0.0824, 0.0706,  ..., 0.0471, 0.0510, 0.0549]],

              [[0.0863, 0.1059, 0.1255,  ..., 0.3608, 0.3647, 0.3176],
               [0.1294, 0.1098, 0.0980,  ..., 0.4157, 0.4000, 0.

---

## Step 4.1: Update Data Handling for Stage 2 (Placeholder)

This section will contain functions to create DataLoaders specifically for Stage 2 instruction tuning, using the appropriate chat template and label masking.

In [26]:
#| export
# Placeholder for get_stage2_dataloaders function
def get_stage2_dataloaders(config: dict):
    """Creates fastai DataLoaders for Stage 2 training (Placeholder)."""
    # Similar to stage 1 but uses different template/masking
    # Define llava_datablock_stage2 or custom dataset logic
    print("get_stage2_dataloaders - Placeholder: Not implemented yet.")
    # Example structure:
    # from .preprocessing import format_v1_template # Need v1 formatter
    # llava_tokenizer_tfm_stage2 = LLaVATextTokenizer(tokenizer, template_formatter=format_v1_template)
    # llava_batch_tfm_stage2 = LLaVABatchTransform(tokenizer=tokenizer, template='v1') # Need to adapt batch tfm for v1 masking
    # LLaVADataBlockStage2 = DataBlock(
    #     blocks=(ImageBlock(cls=PILImage), TransformBlock),
    #     get_items=partial(get_llava_items, stage=2),
    #     get_x=get_image_path,
    #     get_y=get_conversations,
    #     splitter=RandomSplitter(valid_pct=0.01, seed=42),
    #     item_tfms=[
    #         *basic_image_item_tfms,
    #         llava_tokenizer_tfm_stage2
    #     ],
    #     batch_tfms=[
    #         IntToFloatTensor(div_mask=torch.BoolTensor([True,False])),
    #         clip_normalize,
    #         llava_batch_tfm_stage2
    #     ]
    # )
    # batch_size = config.get('data', {}).get('batch_size_per_device_stage2', 4)
    # num_workers = config.get('data', {}).get('num_workers', 4)
    # dls = LLaVADataBlockStage2.dataloaders(config=config, bs=batch_size, num_workers=num_workers)
    # return dls
    raise NotImplementedError("Stage 2 DataLoaders are not yet implemented.")
    return None

---

## Step 7.4: Implement Custom Evaluation Set Handling (Placeholder)

In [27]:
#| export
# Placeholder for loading custom eval set
def get_custom_eval_dataloaders(config: dict):
    """Creates fastai DataLoaders for the custom evaluation set (Placeholder)."""
    print("get_custom_eval_dataloaders - Placeholder: Not implemented yet.")
    # Similar logic to get_stage1/2_dataloaders but points to custom eval data paths
    # Might use stage 2 templates/transforms or custom ones.
    raise NotImplementedError("Custom Eval DataLoaders are not yet implemented.")
    return None

In [16]:
#| hide
import nbdev; nbdev.nbdev_export()