# Data Loading

> Functions and classes for loading and parsing datasets for LLaVA-style training.

In [None]:
#| default_exp data.loading

In [None]:
#| hide
from nbdev.showdoc import *

In [1]:
#| export
import json
from pathlib import Path
from typing import List, Dict, Any, Union, Tuple
from dataclasses import dataclass
import PIL.Image
from functools import partial

from fastai.vision.all import *
from fastai.data.block import DataBlock, TransformBlock, CategoryBlock # Added CategoryBlock just in case
from fastai.data.transforms import parent_label, GrandparentSplitter, RandomSplitter, IntToFloatTensor

from Adaptive_Patching_VIT_fastai.utils import load_config
# Import necessary items from preprocessing notebook
from Adaptive_Patching_VIT_fastai.data.preprocessing import (
    tokenizer, 
    LLaVATextTokenizer, 
    clip_normalize, 
    format_plain_template, # Needed if not using LLaVATextTokenizer directly
    LLaVABatchTransform # Import the custom batch transform
)

## Step 1.1: Data Parsing

We need functions to parse the JSONL files commonly used in LLaVA datasets. Each line typically contains:
- `id`: A unique identifier for the sample.
- `image`: The filename of the image (often relative to an `image_folder`).
- `conversations`: A list of dictionaries, where each dictionary has `from` ('human' or 'gpt') and `value` (the text).

In [2]:
#| export
@dataclass
class LLaVASample:
    """Represents a single sample from a LLaVA-style dataset.

    Attributes:
        sample_id: Unique identifier for the sample.
        image_path: Absolute path to the image file.
        conversations: List of conversation turns (dictionaries with 'from' and 'value').
        data_source: Optional field indicating the source dataset.
    """
    sample_id: str
    image_path: Path
    conversations: List[Dict[str, str]]
    data_source: str | None = None

In [3]:
show_doc(LLaVASample)

```python
#| export
@dataclass
class LLaVASample:
    """Represents a single sample from a LLaVA-style dataset.\n\n    Attributes:\n        sample_id: Unique identifier for the sample.\n        image_path: Absolute path to the image file.\n        conversations: List of conversation turns (dictionaries with 'from' and 'value').\n        data_source: Optional field indicating the source dataset.\n    """
    sample_id: str
    image_path: Path
    conversations: List[Dict[str, str]]
    data_source: str | None = None
```

In [4]:
#| export
def parse_llava_jsonl(jsonl_path: Union[str, Path], image_folder: Union[str, Path]) -> List[LLaVASample]:
    """Parses a LLaVA-style JSONL file and resolves image paths.

    Args:
        jsonl_path: Path to the JSONL file.
        image_folder: Path to the directory containing the images referenced in the JSONL file.

    Returns:
        A list of LLaVASample objects.

    Raises:
        FileNotFoundError: If the JSONL file does not exist.
        json.JSONDecodeError: If a line in the file is not valid JSON.
    """
    jsonl_path = Path(jsonl_path)
    image_folder = Path(image_folder)

    if not jsonl_path.is_file():
        raise FileNotFoundError(f"JSONL file not found: {jsonl_path}")

    samples = []
    with open(jsonl_path, 'r') as f:
        for i, line in enumerate(f):
            try:
                data = json.loads(line)
            except json.JSONDecodeError as e:
                raise json.JSONDecodeError(f"Error decoding JSON on line {i+1} in {jsonl_path}: {e.msg}", e.doc, e.pos)

            # Check for required keys
            if not all(k in data for k in ['id', 'image', 'conversations']):
                print(f"Warning: Skipping line {i+1} due to missing keys ('id', 'image', or 'conversations') in {jsonl_path}.\n",
                      f"Data: {data}")
                continue

            sample_id = data['id']
            # Construct the full image path. Assumes 'image' key holds a relative path or filename.
            # Handle potential nested structure like {'bytes': ..., 'path': ...} if found in parquet conversions
            image_ref = data['image']
            if isinstance(image_ref, dict) and 'path' in image_ref:
                # Handle cases where image info is nested (e.g., from parquet processing)
                image_filename = image_ref['path']
            elif isinstance(image_ref, str):
                image_filename = image_ref
            else:
                 print(f"Warning: Skipping line {i+1} due to unexpected image field format in {jsonl_path}.\n",
                       f"Expected string or dict with 'path', got: {type(image_ref)}")
                 continue

            # Resolve the image path: assume image_filename is relative to image_folder
            # Path.name is used to handle cases where an absolute path might be mistakenly stored
            # but we want it relative to the designated image_folder.
            image_path = image_folder / Path(image_filename).name

            conversations = data['conversations']
            data_source = data.get('data_source') # Optional field

            # Basic validation for conversations format
            if not isinstance(conversations, list) or not all(isinstance(turn, dict) and 'from' in turn and 'value' in turn for turn in conversations):
                 print(f"Warning: Skipping line {i+1} due to invalid 'conversations' format in {jsonl_path}.")
                 continue

            # Check if image file actually exists (optional, can slow down parsing)
            # if not image_path.is_file():
            #     print(f"Warning: Image file not found for sample {sample_id} at {image_path}, skipping.")
            #     continue

            samples.append(LLaVASample(
                sample_id=str(sample_id),
                image_path=image_path,
                conversations=conversations,
                data_source=data_source
            ))

    return samples

In [5]:
show_doc(parse_llava_jsonl)

```python
#| export
def parse_llava_jsonl(jsonl_path: str | Path, image_folder: str | Path) -> List[LLaVASample]:
    """Parses a LLaVA-style JSONL file and resolves image paths.\n\n    Args:\n        jsonl_path: Path to the JSONL file.\n        image_folder: Path to the directory containing the images referenced in the JSONL file.\n\n    Returns:\n        A list of LLaVASample objects.\n\n    Raises:\n        FileNotFoundError: If the JSONL file does not exist.\n        json.JSONDecodeError: If a line in the file is not valid JSON.\n    """
    jsonl_path = Path(jsonl_path)
    image_folder = Path(image_folder)

    if not jsonl_path.is_file():
        raise FileNotFoundError(f"JSONL file not found: {jsonl_path}")

    samples = []
    with open(jsonl_path, 'r') as f:
        for i, line in enumerate(f):
            try:
                data = json.loads(line)
            except json.JSONDecodeError as e:
                raise json.JSONDecodeError(f"Error decoding JSON on line {i+1} in {jsonl_path}: {e.msg}", e.doc, e.pos)

            # Check for required keys
            if not all(k in data for k in ['id', 'image', 'conversations']):
                print(f"Warning: Skipping line {i+1} due to missing keys ('id', 'image', or 'conversations') in {jsonl_path}.\n",
                      f"Data: {data}")
                continue

            sample_id = data['id']
            # Construct the full image path. Assumes 'image' key holds a relative path or filename.
            # Handle potential nested structure like {'bytes': ..., 'path': ...} if found in parquet conversions
            image_ref = data['image']
            if isinstance(image_ref, dict) and 'path' in image_ref:
                # Handle cases where image info is nested (e.g., from parquet processing)
                image_filename = image_ref['path']
            elif isinstance(image_ref, str):
                image_filename = image_ref
            else:
                 print(f"Warning: Skipping line {i+1} due to unexpected image field format in {jsonl_path}.\n",
                       f"Expected string or dict with 'path', got: {type(image_ref)}")
                 continue
            
            # Resolve the image path: assume image_filename is relative to image_folder
            # Path.name is used to handle cases where an absolute path might be mistakenly stored
            # but we want it relative to the designated image_folder.
            image_path = image_folder / Path(image_filename).name 

            conversations = data['conversations']
            data_source = data.get('data_source') # Optional field

            # Basic validation for conversations format
            if not isinstance(conversations, list) or not all(isinstance(turn, dict) and 'from' in turn and 'value' in turn for turn in conversations):
                 print(f"Warning: Skipping line {i+1} due to invalid 'conversations' format in {jsonl_path}.")
                 continue
            
            # Check if image file actually exists (optional, can slow down parsing)
            # if not image_path.is_file():
            #     print(f"Warning: Image file not found for sample {sample_id} at {image_path}, skipping.")
            #     continue

            samples.append(LLaVASample(
                sample_id=str(sample_id),
                image_path=image_path,
                conversations=conversations,
                data_source=data_source
            ))

    return samples
```

#### Example Usage & Test

In [6]:
#| eval: false
# Create dummy data for testing
dummy_data_dir = Path('./dummy_data')
dummy_img_dir = dummy_data_dir / 'images'
dummy_jsonl_path = dummy_data_dir / 'dummy_llava_data.jsonl'

dummy_img_dir.mkdir(parents=True, exist_ok=True)
print(f"Created dummy directory: {dummy_img_dir}")

# Create dummy image files (attempt, may fail if PIL issues persist)
try:
    PIL.Image.new('RGB', (60, 30), color = 'red').save(dummy_img_dir / 'img1.jpg')
    PIL.Image.new('RGB', (60, 30), color = 'green').save(dummy_img_dir / 'img2.png')
except Exception as e:
    print(f"Note: Could not create dummy image files (PIL might not be fully installed or usable): {e}")

# Create dummy jsonl content
dummy_jsonl_content = [
    {"id": "sample1", "image": "img1.jpg", "conversations": [{"from": "human", "value": "<image>\nDescribe this."}, {"from": "gpt", "value": "It is red."}]}, 
    {"id": "sample2", "image": "img2.png", "conversations": [{"from": "human", "value": "<image>\nWhat color?"}, {"from": "gpt", "value": "Green."}]},
    {"id": "sample3_missing_keys", "conversations": []}, # Missing image/id
    {"id": "sample4_bad_conv", "image": "img1.jpg", "conversations": "not a list"}, # Bad conversation format
    {"id": "sample5_missing_img_file", "image": "nonexistent.jpg", "conversations": [{"from": "human", "value": "<image>"}, {"from": "gpt", "value": "..."}]},
]

with open(dummy_jsonl_path, 'w') as f:
    for item in dummy_jsonl_content:
        f.write(json.dumps(item) + '\n')
print(f"Created dummy jsonl file: {dummy_jsonl_path}")

# Test parsing
try:
    parsed_samples = parse_llava_jsonl(dummy_jsonl_path, dummy_img_dir)
    print(f"Successfully parsed {len(parsed_samples)} samples:")
    for sample in parsed_samples:
        print(sample)

    # Basic checks (adjust expected length based on warnings/skips)
    # Check based on the print output, 3 samples are expected now (1, 2, 5)
    assert len(parsed_samples) == 3 
    assert parsed_samples[0].sample_id == 'sample1'
    # Resolve paths for comparison
    assert parsed_samples[0].image_path.resolve() == (dummy_img_dir / 'img1.jpg').resolve()
    assert parsed_samples[1].sample_id == 'sample2'
    assert parsed_samples[1].image_path.resolve() == (dummy_img_dir / 'img2.png').resolve()
    assert parsed_samples[0].conversations[0]['from'] == 'human'
    assert parsed_samples[2].sample_id == 'sample5_missing_img_file'

except FileNotFoundError as e:
    print(f"Error: {e}")
except json.JSONDecodeError as e:
    print(f"JSON Parsing Error: {e}")
except Exception as e:
    print(f"An unexpected error occurred: {e}")

# Clean up dummy data (optional)
# import shutil
# if dummy_data_dir.exists():
#    shutil.rmtree(dummy_data_dir)
#    print(f"Cleaned up dummy data directory: {dummy_data_dir}")

Created dummy directory: dummy_data/images
Note: Could not create dummy image files (PIL might not be fully installed or usable): name 'Image' is not defined
Created dummy jsonl file: dummy_data/dummy_llava_data.jsonl
Data: {'id': 'sample3_missing_keys', 'conversations': []}
Successfully parsed 3 samples:
LLaVASample(sample_id='sample1', image_path=PosixPath('dummy_data/images/img1.jpg'), conversations=[{'from': 'human', 'value': '<image>\nDescribe this.'}, {'from': 'gpt', 'value': 'It is red.'}], data_source=None)
LLaVASample(sample_id='sample2', image_path=PosixPath('dummy_data/images/img2.png'), conversations=[{'from': 'human', 'value': '<image>\nWhat color?'}, {'from': 'gpt', 'value': 'Green.'}], data_source=None)
LLaVASample(sample_id='sample5_missing_img_file', image_path=PosixPath('dummy_data/images/nonexistent.jpg'), conversations=[{'from': 'human', 'value': '<image>'}, {'from': 'gpt', 'value': '...'}], data_source=None)


---

## Step 1.2: Image Loading and Basic Preprocessing

This involves defining the `ImageBlock` and basic `item_tfms` for loading and resizing images. Normalization stats were defined in `11_data_preprocessing.ipynb`.

In [7]:
#| export
# Define the ImageBlock using PILImage for loading images
image_block = ImageBlock(cls=PILImage)

# Define basic item transforms for image processing
# 1. Resize images to 336x336, padding if necessary
#    method='pad': Pads the image to the target size.
#    pad_mode='const': Uses a constant value for padding.
#    pad_value=0: Uses black for padding (common for vision models).
# 2. Convert the image to a PyTorch tensor.
# Note: Normalization will be applied in batch_tfms using clip_normalize.
basic_image_item_tfms = [
    Resize(336, method='pad', pad_mode=PadMode.Constant, pad_value=0),
    ToTensor(),
]

---

## Step 1.4: Define Custom Dataset/DataBlock (Stage 1)

This section defines the fastai `DataBlock` for Stage 1 projector pre-training.

In [8]:
#| export
def get_llava_items(config: dict, stage: int = 1) -> List[LLaVASample]:
    """Loads LLaVA samples for a specific stage based on config.

    Args:
        config: The main configuration dictionary.
        stage: The training stage (1 or 2).

    Returns:
        A list of LLaVASample objects.

    Raises:
        ValueError: If stage is not 1 or 2.
        KeyError: If required path configurations are missing.
    """
    data_base_path = Path(config['paths']['data_base'])
    if stage == 1:
        jsonl_rel_path = config['paths']['stage1_data']
        images_rel_path = config['paths']['stage1_images']
    elif stage == 2:
        jsonl_rel_path = config['paths']['stage2_data']
        # Stage 2 images might be relative to base path or absolute paths might be in JSONL
        images_rel_path = config['paths'].get('stage2_images', '.') # Assume relative to base if specified
    else:
        raise ValueError(f"Invalid stage specified: {stage}. Must be 1 or 2.")

    jsonl_path = data_base_path / jsonl_rel_path
    image_folder = data_base_path / images_rel_path

    print(f"Loading Stage {stage} items from: {jsonl_path}")
    print(f"Assuming images relative to: {image_folder}")

    # Check if paths exist
    if not jsonl_path.exists():
        raise FileNotFoundError(f"Stage {stage} JSONL file not found: {jsonl_path}")
    if not image_folder.is_dir() and stage == 1: # Only strictly check image folder for stage 1
        print(f"Warning: Stage {stage} image folder not found or not a directory: {image_folder}")
        # Proceed cautiously, parsing might still work if JSONL contains absolute paths
        # or if image existence check is disabled in parse_llava_jsonl

    samples = parse_llava_jsonl(jsonl_path, image_folder)
    print(f"Found {len(samples)} samples for Stage {stage}.")
    return samples

def get_image_path(sample: LLaVASample) -> Path:
    """Extracts the image path from a LLaVASample."""
    return sample.image_path

def get_conversations(sample: LLaVASample) -> List[Dict[str, str]]:
    """Extracts the conversations from a LLaVASample."""
    return sample.conversations

In [9]:
#| export
# Define the DataBlock for Stage 1 (Projector Pre-training)
if tokenizer:
    # Instantiate the custom batch transform (needs tokenizer)
    llava_batch_tfm = LLaVABatchTransform(tokenizer=tokenizer)
    
    LLaVADataBlockStage1 = DataBlock(
        blocks=(ImageBlock(cls=PILImage), TransformBlock), # Input: Image, Target: Processed Conversations (token IDs)
        get_items=partial(get_llava_items, stage=1), # Use partial to pass stage=1 to the item getter
        get_x=get_image_path, # Function to get image path from sample
        get_y=get_conversations, # Function to get conversations for text processing
        splitter=RandomSplitter(valid_pct=0.01, seed=42), # Small validation set for monitoring
        item_tfms=[
            *basic_image_item_tfms, # Apply resize/ToTensor to images (x)
            LLaVATextTokenizer(tokenizer, template_formatter=format_plain_template) # Apply template+tokenization to conversations (y)
        ],
        # batch_tfms are applied after items are collated into a batch
        batch_tfms=[ 
            IntToFloatTensor(div_mask=torch.BoolTensor([True,False])), # Convert image tensor to float (but not text IDs)
            clip_normalize,   # Apply normalization (to image tensor)
            llava_batch_tfm   # Apply custom batch padding, masking, etc.
        ]
    )
    print("LLaVADataBlockStage1 defined.")
else:
    LLaVADataBlockStage1 = None
    print("Tokenizer not available, LLaVADataBlockStage1 not defined.")

LLaVABatchTransform initialized. Image Token ID: 32000, Pad Token ID: 2
LLaVADataBlockStage1 defined.


#### Example Usage & Test (DataBlock Definition)

In [10]:
#| eval: false
if LLaVADataBlockStage1:
    try:
        # Load config to pass to get_items via the DataBlock
        config = load_config('configs/config.yaml') 
        print("Attempting DataBlock summary...")
        # This will call get_items(config) internally
        # It requires valid paths in config.yaml and actual data to succeed fully.
        LLaVADataBlockStage1.summary(config=config, bs=4)
    except FileNotFoundError as e:
        print(f"\nSkipping summary: FileNotFoundError: {e}")
        print("Please ensure 'paths.data_base', 'paths.stage1_data', and 'paths.stage1_images' are correctly set in configs/config.yaml and point to existing data.")
    except Exception as e:
        print(f"\nSkipping summary: Exception occurred during gathering samples: {e}")
        print("Check paths in config.yaml and ensure data files are accessible and correctly formatted.")
else:
    print("LLaVADataBlockStage1 not defined, cannot show summary.")

Attempting DataBlock summary...
DataBlock Summary:
  Setting up Pipeline: partial -> get_image_path
  Setting up Pipeline: partial -> get_conversations -> LLaVATextTokenizer
  
  Building validation fold (1%)
    Setting up Pipeline: partial -> get_image_path
    Setting up Pipeline: partial -> get_conversations -> LLaVATextTokenizer
    

Skipping summary: FileNotFoundError: [Errno 2] No such file or directory: '/path/to/your/datasets/llava_pretrain/llava_pretrain.jsonl'
Please ensure 'paths.data_base', 'paths.stage1_data', and 'paths.stage1_images' are correctly set in configs/config.yaml and point to existing data.


---

## Step 1.6: Create DataLoaders (Stage 1)

This section shows how to create the `DataLoaders` object from the `DataBlock`.

In [11]:
#| export
def get_stage1_dataloaders(config: dict, dblock: DataBlock = LLaVADataBlockStage1) -> DataLoaders:
    """Creates fastai DataLoaders for Stage 1 training.

    Args:
        config: The main configuration dictionary.
        dblock: The configured DataBlock for Stage 1 (defaults to LLaVADataBlockStage1).

    Returns:
        A fastai DataLoaders object.

    Raises:
        ValueError: If the DataBlock is not defined.
        FileNotFoundError: If data paths are invalid during DataBlock processing.
    """
    if dblock is None:
        raise ValueError("Stage 1 DataBlock is not defined. Ensure tokenizer loaded correctly.")

    batch_size = config.get('data', {}).get('batch_size_per_device_stage1', 8)
    num_workers = config.get('data', {}).get('num_workers', 4)

    print(f"Creating Stage 1 DataLoaders with batch size: {batch_size}, num_workers: {num_workers}")

    # The DataBlock's get_items function needs the config,
    # we pass it here when calling dataloaders()
    try:
        # Pass config explicitly to dataloaders, which passes it down to get_items
        dls = dblock.dataloaders(source=None, # Source is determined by get_items
                                 config=config, # Pass config to be used by get_items
                                 bs=batch_size,
                                 num_workers=num_workers)
        print("DataLoaders created successfully.")
        return dls
    except FileNotFoundError as e:
        print(f"Error creating DataLoaders: {e}")
        print("Please ensure data paths in config.yaml are correct and data exists.")
        raise e
    except Exception as e:
        import traceback
        print(f"An unexpected error occurred during DataLoaders creation: {e}")
        traceback.print_exc() # Print full traceback for debugging
        raise e

In [12]:
#| eval: false
# Example usage: Create DataLoaders (requires config.yaml and actual data)
try:
    config_path = 'configs/config.yaml'
    config = load_config(config_path)
    print(f"Loaded config from {config_path}")
    dls = get_stage1_dataloaders(config)

    print(f"DataLoaders created. Testing show_batch...")
    # Now that batch transforms are complete, show_batch should work if data exists
    # Ensure you have data available at the paths specified in config.yaml
    # dls.show_batch(max_n=4, figsize=(12, 8))
    # print("\nTesting one_batch...")
    # b = dls.one_batch()
    # print("one_batch() retrieved. Check keys and shapes.")
    # # Example: print shapes of tensors in the batch
    # for k, v in b.items():
    #     if isinstance(v, torch.Tensor):
    #         print(f"  {k}: {v.shape}")
    #     else:
    #         print(f"  {k}: {type(v)}")

    print("\nNote: Full testing requires valid data paths and files.")

except FileNotFoundError as e:
    print(f"\nError creating DataLoaders: FileNotFoundError - {e}")
    print("Please ensure 'paths.data_base', 'paths.stage1_data', and 'paths.stage1_images' are correctly set in configs/config.yaml and point to existing data.")
except Exception as e:
    import traceback
    print(f"\nAn unexpected error occurred during DataLoaders creation/testing: {e}")
    traceback.print_exc()


Loaded config from configs/config.yaml
Creating Stage 1 DataLoaders with batch size: 8, num_workers: 4
Loading Stage 1 items from: /path/to/your/datasets/llava_pretrain/llava_pretrain.jsonl
Assuming images relative to: /path/to/your/datasets/llava_pretrain/images


FileNotFoundError: [Errno 2] No such file or directory: '/path/to/your/datasets/llava_pretrain/llava_pretrain.jsonl'

---

## Step 4.1: Update Data Handling for Stage 2 (Placeholder)

This section will contain functions to create DataLoaders specifically for Stage 2 instruction tuning, using the appropriate chat template and label masking.

In [13]:
#| export
# Placeholder for get_stage2_dataloaders function
def get_stage2_dataloaders(config: dict):
    """Creates fastai DataLoaders for Stage 2 training (Placeholder)."""
    # Similar to stage 1 but uses different template/masking
    # Define llava_datablock_stage2 or custom dataset logic
    print("get_stage2_dataloaders - Placeholder: Not implemented yet.")
    # Example structure:
    # from .preprocessing import format_v1_template # Need v1 formatter
    # llava_tokenizer_tfm_stage2 = LLaVATextTokenizer(tokenizer, template_formatter=format_v1_template)
    # llava_batch_tfm_stage2 = LLaVABatchTransform(tokenizer=tokenizer, template='v1') # Need to adapt batch tfm for v1 masking
    # LLaVADataBlockStage2 = DataBlock(
    #     blocks=(ImageBlock(cls=PILImage), TransformBlock),
    #     get_items=partial(get_llava_items, stage=2),
    #     get_x=get_image_path,
    #     get_y=get_conversations,
    #     splitter=RandomSplitter(valid_pct=0.01, seed=42),
    #     item_tfms=[
    #         *basic_image_item_tfms,
    #         llava_tokenizer_tfm_stage2
    #     ],
    #     batch_tfms=[
    #         IntToFloatTensor(div_mask=torch.BoolTensor([True,False])),
    #         clip_normalize,
    #         llava_batch_tfm_stage2
    #     ]
    # )
    # batch_size = config.get('data', {}).get('batch_size_per_device_stage2', 4)
    # num_workers = config.get('data', {}).get('num_workers', 4)
    # dls = LLaVADataBlockStage2.dataloaders(config=config, bs=batch_size, num_workers=num_workers)
    # return dls
    raise NotImplementedError("Stage 2 DataLoaders are not yet implemented.")
    return None

---

## Step 7.4: Implement Custom Evaluation Set Handling (Placeholder)

In [14]:
#| export
# Placeholder for loading custom eval set
def get_custom_eval_dataloaders(config: dict):
    """Creates fastai DataLoaders for the custom evaluation set (Placeholder)."""
    print("get_custom_eval_dataloaders - Placeholder: Not implemented yet.")
    # Similar logic to get_stage1/2_dataloaders but points to custom eval data paths
    # Might use stage 2 templates/transforms or custom ones.
    raise NotImplementedError("Custom Eval DataLoaders are not yet implemented.")
    return None

In [15]:
#| hide
import nbdev; nbdev.nbdev_export()