# Data Loading

> Functions and classes for loading and parsing datasets for LLaVA-style training.

In [None]:
#| default_exp data.loading

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import json
from pathlib import Path
from typing import List, Dict, Any, Union
from dataclasses import dataclass
import PIL.Image

from fastai.vision.all import *
from fastai.data.block import DataBlock, TransformBlock
from fastai.data.transforms import parent_label, GrandparentSplitter # Example imports, adjust as needed

from Adaptive_Patching_VIT_fastai.utils import load_config

## Step 1.1: Data Parsing

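We need functions to parse the JSONL files commonly used in LLaVA datasets. Each line is one JSON object that typically contains:
- `id`: A unique identifier for the sample.
- `image`: The image reference, usually a path relative to an `image_folder` (some parquet conversions store it as a dict with a `path` key instead).
- `conversations`: A list of dictionaries, where each dictionary has `from` ('human' or 'gpt') and `value` (the text).
- `data_source` (optional): The name of the dataset the sample came from.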

In [None]:
#| export
@dataclass
class LLaVASample:
    """One record of a LLaVA-style dataset.

    Attributes:
        sample_id: Identifier of the record (the JSONL ``id`` field, stored as a string).
        image_path: Location of the image file. The parser builds it by joining the
            image reference onto ``image_folder``, so it is absolute only when that
            folder is.
        conversations: Ordered conversation turns; each turn is a dict with a
            ``'from'`` key (the speaker) and a ``'value'`` key (the text).
        data_source: Optional name of the dataset the record came from.
    """

    sample_id: str
    image_path: Path
    conversations: List[Dict[str, str]]
    data_source: Union[str, None] = None

In [None]:
# Render the API documentation for LLaVASample in the notebook (this cell is not exported).
show_doc(LLaVASample)

```python
#| export
@dataclass
class LLaVASample:
    """Represents a single sample from a LLaVA-style dataset.

    Attributes:
        sample_id: Unique identifier for the sample (stored as a string).
        image_path: Path to the image file, built by joining the image reference onto
            `image_folder`; it is absolute only if that folder is.
        conversations: List of conversation turns (dictionaries with 'from' and 'value').
        data_source: Optional field indicating the source dataset.
    """
    sample_id: str
    image_path: Path
    conversations: List[Dict[str, str]]
    data_source: str | None = None
```

In [None]:
#| export
def parse_llava_jsonl(
    jsonl_path: Union[str, Path],
    image_folder: Union[str, Path],
    *,
    check_image_exists: bool = False,
) -> List[LLaVASample]:
    """Parses a LLaVA-style JSONL file and resolves image paths.

    Each non-blank line must be a JSON object with the keys ``id``, ``image`` and
    ``conversations`` (``data_source`` is optional). Lines that are valid JSON but do
    not match this format are skipped with a printed warning; blank lines are ignored.

    Args:
        jsonl_path: Path to the JSONL file. It is read as UTF-8.
        image_folder: Path to the directory containing the images referenced in the JSONL file.
            Relative references keep their sub-directories (``coco/train2017/x.jpg`` ->
            ``image_folder/coco/train2017/x.jpg``). For an absolute reference only the
            file name is kept, so every resolved path lives under ``image_folder``.
        check_image_exists: If True, skip (with a warning) samples whose resolved image
            file does not exist. Off by default because it touches the filesystem for
            every sample and slows down parsing of large files.

    Returns:
        A list of LLaVASample objects, in file order.

    Raises:
        FileNotFoundError: If the JSONL file does not exist.
        json.JSONDecodeError: If a non-blank line in the file is not valid JSON.
    """
    jsonl_path = Path(jsonl_path)
    image_folder = Path(image_folder)

    if not jsonl_path.is_file():
        raise FileNotFoundError(f"JSONL file not found: {jsonl_path}")

    samples = []
    # JSONL is UTF-8. Relying on the platform default encoding (e.g. cp1252 on Windows)
    # would mis-decode or reject non-ASCII conversation text.
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            # Blank lines (e.g. a trailing empty line) carry no record; json.loads would reject them.
            if not line.strip():
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError as e:
                # Prepend the 1-based line number and file name; keep the original error as the cause.
                raise json.JSONDecodeError(
                    f"Error decoding JSON on line {i+1} in {jsonl_path}: {e.msg}", e.doc, e.pos
                ) from e

            # Valid JSON that is not an object (list, string, number) cannot be a sample.
            # The key check below would do substring tests on strings and raise TypeError on numbers.
            if not isinstance(data, dict):
                print(f"Warning: Skipping line {i+1} in {jsonl_path}: expected a JSON object, "
                      f"got {type(data).__name__}.")
                continue

            # Check for required keys
            if not all(k in data for k in ['id', 'image', 'conversations']):
                print(f"Warning: Skipping line {i+1} due to missing keys ('id', 'image', or 'conversations') in {jsonl_path}.\n",
                      f"Data: {data}")
                continue

            sample_id = data['id']

            # The image reference is normally a string. Parquet conversions may nest it as
            # {'bytes': ..., 'path': ...}, where 'path' can be missing or None.
            image_ref = data['image']
            image_filename = image_ref.get('path') if isinstance(image_ref, dict) else image_ref
            if not isinstance(image_filename, str) or not image_filename:
                print(f"Warning: Skipping line {i+1} due to unexpected image field format in {jsonl_path}.\n",
                      f"Expected a non-empty string or a dict with a non-empty string 'path', got: {type(image_ref)}")
                continue

            # Resolve relative to image_folder, keeping the sub-directories of relative references.
            # A reference with an anchor (absolute path or drive) is treated as a mistake and reduced
            # to its file name, because joining it with `/` would discard image_folder entirely.
            image_ref_path = Path(image_filename)
            image_path = image_folder / (image_ref_path.name if image_ref_path.anchor else image_ref_path)

            conversations = data['conversations']
            # Basic validation for conversations format
            if not isinstance(conversations, list) or not all(
                isinstance(turn, dict) and 'from' in turn and 'value' in turn for turn in conversations
            ):
                print(f"Warning: Skipping line {i+1} due to invalid 'conversations' format in {jsonl_path}.")
                continue

            # Opt-in because stat-ing every image slows down parsing of large files.
            if check_image_exists and not image_path.is_file():
                print(f"Warning: Image file not found for sample {sample_id} at {image_path}, skipping.")
                continue

            samples.append(LLaVASample(
                sample_id=str(sample_id),
                image_path=image_path,
                conversations=conversations,
                data_source=data.get('data_source'),  # Optional field
            ))

    return samples

In [None]:
# Render the API documentation for parse_llava_jsonl in the notebook (this cell is not exported).
show_doc(parse_llava_jsonl)

```python
#| export
def parse_llava_jsonl(jsonl_path: str | Path, image_folder: str | Path) -> List[LLaVASample]:
    """Parses a LLaVA-style JSONL file and resolves image paths.

    Every line is decoded with json.loads. Records missing 'id', 'image' or
    'conversations', or whose 'image'/'conversations' fields have an unexpected
    format, are skipped with a printed warning.

    Args:
        jsonl_path: Path to the JSONL file.
        image_folder: Path to the directory containing the images referenced in the JSONL file.

    Returns:
        A list of LLaVASample objects, in file order.

    Raises:
        FileNotFoundError: If the JSONL file does not exist.
        json.JSONDecodeError: If a line in the file is not valid JSON
            (NOTE(review): this includes blank lines, e.g. a trailing empty line).
    """
    jsonl_path = Path(jsonl_path)
    image_folder = Path(image_folder)

    if not jsonl_path.is_file():
        raise FileNotFoundError(f"JSONL file not found: {jsonl_path}")

    samples = []
    # NOTE(review): no `encoding=` is passed, so the platform default encoding is used;
    # consider encoding='utf-8' for non-ASCII conversation text.
    with open(jsonl_path, 'r') as f:
        for i, line in enumerate(f):
            try:
                data = json.loads(line)
            except json.JSONDecodeError as e:
                # Re-raise with the 1-based line number and file name prepended to the message.
                raise json.JSONDecodeError(f"Error decoding JSON on line {i+1} in {jsonl_path}: {e.msg}", e.doc, e.pos)

            # Check for required keys
            if not all(k in data for k in ['id', 'image', 'conversations']):
                print(f"Warning: Skipping line {i+1} due to missing keys ('id', 'image', or 'conversations') in {jsonl_path}.\n",
                      f"Data: {data}")
                continue

            sample_id = data['id']
            # Construct the full image path. Assumes 'image' key holds a relative path or filename.
            # Handle potential nested structure like {'bytes': ..., 'path': ...} if found in parquet conversions
            image_ref = data['image']
            if isinstance(image_ref, dict) and 'path' in image_ref:
                # Handle cases where image info is nested (e.g., from parquet processing)
                image_filename = image_ref['path']
            elif isinstance(image_ref, str):
                image_filename = image_ref
            else:
                 print(f"Warning: Skipping line {i+1} due to unexpected image field format in {jsonl_path}.\n",
                       f"Expected string or dict with 'path', got: {type(image_ref)}")
                 continue

            # Ensure image_filename is treated as relative to image_folder
            # Path(image_filename).name gets the final component if it was an absolute path by mistake
            # NOTE(review): .name also drops the sub-directories of relative references
            # (e.g. 'coco/train2017/x.jpg' -> image_folder/'x.jpg').
            image_path = image_folder / Path(image_filename).name 

            conversations = data['conversations']
            data_source = data.get('data_source') # Optional field

            # Basic validation for conversations format
            if not isinstance(conversations, list) or not all(isinstance(turn, dict) and 'from' in turn and 'value' in turn for turn in conversations):
                 print(f"Warning: Skipping line {i+1} due to invalid 'conversations' format in {jsonl_path}.")
                 continue

            # Check if image file actually exists (optional, can slow down parsing)
            # if not image_path.is_file():
            #     print(f"Warning: Image file not found for sample {sample_id} at {image_path}, skipping.")
            #     continue

            samples.append(LLaVASample(
                sample_id=str(sample_id),
                image_path=image_path,
                conversations=conversations,
                data_source=data_source
            ))

    return samples
```

#### Example Usage & Test

In [None]:
#| eval: false
# Create dummy data for testing
dummy_data_dir = Path('./dummy_data')
dummy_img_dir = dummy_data_dir / 'images'
dummy_jsonl_path = dummy_data_dir / 'dummy_llava_data.jsonl'

dummy_img_dir.mkdir(parents=True, exist_ok=True)

# Create dummy image files. This is best effort: the parser never opens images,
# so the checks below still hold if image creation fails.
try:
    PIL.Image.new('RGB', (60, 30), color='red').save(dummy_img_dir / 'img1.jpg')
    PIL.Image.new('RGB', (60, 30), color='green').save(dummy_img_dir / 'img2.png')
except Exception as e:
    print(f"Note: Could not create dummy image files (PIL might not be fully installed or usable): {e}")

# Create dummy jsonl content: two valid records plus one record per malformed-input case.
dummy_jsonl_content = [
    {"id": "sample1", "image": "img1.jpg", "conversations": [{"from": "human", "value": "<image>\nDescribe this."}, {"from": "gpt", "value": "It is red."}]},
    {"id": "sample2", "image": "img2.png", "conversations": [{"from": "human", "value": "<image>\nWhat color?"}, {"from": "gpt", "value": "Green."}]},
    {"id": "sample3_missing_keys", "conversations": []},  # Missing 'image' key -> skipped
    {"id": "sample4_bad_conv", "image": "img1.jpg", "conversations": "not a list"},  # Bad conversation format -> skipped
    {"id": "sample5_missing_img_file", "image": "nonexistent.jpg", "conversations": [{"from": "human", "value": "<image>"}, {"from": "gpt", "value": "..."}]},  # Well-formed -> kept
]

with open(dummy_jsonl_path, 'w', encoding='utf-8') as f:
    f.writelines(json.dumps(item) + '\n' for item in dummy_jsonl_content)

# Test parsing. The assertions are deliberately NOT wrapped in `except Exception`:
# AssertionError subclasses Exception, so a broad handler would print a failed check
# and let the cell "pass".
parsed_samples = parse_llava_jsonl(dummy_jsonl_path, dummy_img_dir)
print(f"Successfully parsed {len(parsed_samples)} samples:")
for sample in parsed_samples:
    print(sample)

# sample3 and sample4 are malformed and skipped. sample5 is well-formed and kept,
# because the parser does not check that image files exist by default.
assert [s.sample_id for s in parsed_samples] == ['sample1', 'sample2', 'sample5_missing_img_file']
# Resolve paths for comparison
assert parsed_samples[0].image_path.resolve() == (dummy_img_dir / 'img1.jpg').resolve()
assert parsed_samples[1].image_path.resolve() == (dummy_img_dir / 'img2.png').resolve()
assert parsed_samples[2].image_path.resolve() == (dummy_img_dir / 'nonexistent.jpg').resolve()
assert parsed_samples[0].conversations[0]['from'] == 'human'

# Clean up dummy data (optional)
# import shutil
# shutil.rmtree(dummy_data_dir)

---

## Step 1.2: Image Loading and Basic Preprocessing

This involves defining the `ImageBlock` and basic `item_tfms` for loading and resizing images. Normalization stats will be defined in `11_data_preprocessing.ipynb`.

In [None]:
#| export
# Define the ImageBlock using PILImage for loading images
image_block = ImageBlock(cls=PILImage)

# Define basic item transforms for image processing
# 1. Resize images to 336x336 with padding:
#    ResizeMethod.Pad scales the image to fit inside the target size, keeping its
#    aspect ratio, and pads the remainder.
#    PadMode.Zeros fills the padding with zeros (black, common for vision models).
#    fastai's PadMode only has Zeros/Border/Reflection (there is no `Constant`), and
#    Resize takes no `pad_value` argument, so zero padding is how to get a black border.
# 2. Convert the image to a PyTorch tensor.
# Note: Normalization will be defined separately and likely applied in batch_tfms.
basic_image_item_tfms = [
    Resize(336, method=ResizeMethod.Pad, pad_mode=PadMode.Zeros),
    ToTensor(),
]

---

## Step 1.4: Define Custom Dataset/DataBlock (Stage 1 - Placeholder)

This section will define the fastai `DataBlock` or a custom `Datasets` class to integrate the parsing, image loading, and text processing steps.

In [None]:
# Placeholder for get_items function for DataBlock
# def get_llava_items(config):
#     # Load config, parse jsonl using parse_llava_jsonl
#     # Return list of LLaVASample objects or similar structure
#     pass

# Placeholder for DataBlock definition
# llava_datablock_stage1 = DataBlock(
#     blocks=(image_block, TransformBlock), # TextBlock or custom block needed for text
#     get_items=get_llava_items,
#     get_x=lambda sample: sample.image_path, # Function to get image path from sample
#     get_y=lambda sample: sample.conversations, # Function to get conversations from sample
#     splitter=RandomSplitter(),
#     item_tfms=basic_image_item_tfms + [YourTextTokenizerTransform], # Add text transforms later
#     batch_tfms=[YourCustomBatchTransform] # Custom batch transform for padding, masking, normalization
# )

---

## Step 1.6: Create DataLoaders (Stage 1 - Placeholder)

This section will show how to create the `DataLoaders` object from the `DataBlock` or custom `Datasets`.

In [None]:
# Placeholder for creating DataLoaders
# def get_stage1_dataloaders(config):
#     # Load config, define datablock/dataset
#     # items = get_llava_items(config)
#     # dls = llava_datablock_stage1.dataloaders(items, bs=config['data']['batch_size_per_device_stage1'])
#     # return dls
#     pass

# Example: Test show_batch
# config = load_config('configs/config.yaml') # Load config first
# dls = get_stage1_dataloaders(config)
# dls.show_batch()

---

## Step 4.1: Update Data Handling for Stage 2 (Placeholder)

This section will contain functions to create DataLoaders specifically for Stage 2 instruction tuning, using the appropriate chat template and label masking.

In [None]:
# Placeholder for get_stage2_dataloaders function
# def get_stage2_dataloaders(config):
#     # Similar to stage 1 but uses different template/masking
#     # Define llava_datablock_stage2 or custom dataset logic
#     pass

---

## Step 7.4: Implement Custom Evaluation Set Handling (Placeholder)

In [None]:
# Placeholder for loading custom eval set
# def get_custom_eval_dataloaders(config):
#    pass

In [None]:
#| hide
import nbdev

# Regenerate the library modules from this notebook's `#| export` cells.
nbdev.nbdev_export()