In [None]:
"""
TextVQA Dataset Loader and Preprocessing Module
"""

import os
from typing import Dict, List, Optional, Tuple, Callable
from PIL import Image
import io
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch

In [None]:
class TextVQADataset(Dataset):
    """TextVQA Dataset for Visual Question Answering.

    Rows are read from the parquet shards of one split. Each item is a dict
    holding the decoded RGB image plus the question, reference answers and
    OCR tokens stored for that row.
    """

    def __init__(
        self,
        data_dir: str,
        split: str = "train",
        processor: Optional[Callable] = None,
        max_samples: Optional[int] = None,
    ):
        """
        Args:
            data_dir: Path to the data directory containing parquet files
            split: One of 'train', 'validation', or 'test'
            processor: Model processor for preprocessing images and text
                (only stored on the instance; the dataset does not call it)
            max_samples: Maximum number of samples to load (for debugging)

        Raises:
            KeyError: If ``split`` is not one of the known split names.
            ValueError: If no parquet shard of the split exists in ``data_dir``.
        """
        self.data_dir = data_dir
        self.split = split
        self.processor = processor

        # Load data from parquet files
        self.data = self._load_parquet_files()

        if max_samples is not None:
            # Positional row slice; keeps the first ``max_samples`` rows.
            self.data = self.data.iloc[:max_samples]

        print(f"Loaded {len(self.data)} samples for {split} split")

    def _load_parquet_files(self) -> pd.DataFrame:
        """Load and concatenate every available parquet shard of the split.

        Shards that are not present on disk are skipped, so a partially
        downloaded split still loads.

        Raises:
            KeyError: If ``self.split`` is not a known split name.
            ValueError: If none of the expected shard files exists.
        """
        # split name -> (file name prefix, total number of shards)
        split_patterns = {
            "train": ("train-", 20),
            "validation": ("validation-", 3),
            "test": ("test-", 4),
        }

        try:
            prefix, num_files = split_patterns[self.split]
        except KeyError:
            raise KeyError(
                f"Unknown split {self.split!r}; "
                f"expected one of {sorted(split_patterns)}"
            ) from None

        shard_paths = [
            os.path.join(
                self.data_dir,
                f"{prefix}{i:05d}-of-{num_files:05d}.parquet",
            )
            for i in range(num_files)
        ]

        dfs = [pd.read_parquet(path) for path in shard_paths if os.path.exists(path)]

        if not dfs:
            # Fail with an actionable message instead of pandas' generic
            # "No objects to concatenate".
            raise ValueError(
                f"No parquet files found for split {self.split!r} in "
                f"{self.data_dir!r} (expected files named like "
                f"{os.path.basename(shard_paths[0])!r})"
            )

        return pd.concat(dfs, ignore_index=True)

    def __len__(self) -> int:
        """Return the number of loaded samples."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict:
        """Get a single sample.

        Returns:
            Dict with keys ``image_id``, ``question_id``, ``image`` (RGB PIL
            image), ``question``, ``answers`` and ``ocr_tokens``. The last two
            fall back to an empty list when the column is missing or None.
        """
        row = self.data.iloc[idx]

        # Load image from bytes
        image_bytes = row["image"]["bytes"]
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # Normalise absent annotations to [] so consumers can always call
        # len() / iterate on them.
        answers = row["answers"] if "answers" in row else None
        if answers is None:
            answers = []
        ocr_tokens = row["ocr_tokens"] if "ocr_tokens" in row else None
        if ocr_tokens is None:
            ocr_tokens = []

        return {
            "image_id": row["image_id"],
            "question_id": row["question_id"],
            "image": image,
            "question": row["question"],
            "answers": answers,
            "ocr_tokens": ocr_tokens,
        }

    def get_most_common_answer(self, answers: List[str]) -> str:
        """Get the most common answer from the answer list.

        Args:
            answers: Sequence of answer strings. Array-like inputs (e.g. the
                NumPy arrays pandas may return for list columns) are accepted.

        Returns:
            The most frequent answer (first seen wins on ties), or "" when
            ``answers`` is None or empty.
        """
        # Use len() rather than truthiness: ``not array`` raises ValueError
        # for NumPy arrays holding more than one element.
        if answers is None or len(answers) == 0:
            return ""
        from collections import Counter
        return Counter(answers).most_common(1)[0][0]


def collate_fn_qwen(batch: List[Dict], processor) -> Dict:
    """Turn a list of dataset samples into one Qwen2.5-VL inference batch.

    Every sample becomes a single-turn user message (image placeholder
    followed by the question). The messages are rendered with the processor's
    chat template (generation prompt appended) and then encoded together with
    the images.

    Args:
        batch: Samples providing "image", "question", "answers",
            "image_id" and "question_id".
        processor: Object exposing ``apply_chat_template`` and callable as
            ``processor(text=..., images=..., padding=..., return_tensors=...)``.

    Returns:
        The processor output, extended with the metadata entries
        "image_ids", "question_ids", "questions" and "answers".
    """
    images = []
    questions = []
    for sample in batch:
        images.append(sample["image"])
        questions.append(sample["question"])

    # One rendered prompt string per question.
    texts = []
    for question in questions:
        user_turn = {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": question},
            ],
        }
        rendered = processor.apply_chat_template(
            [user_turn],
            tokenize=False,
            add_generation_prompt=True,
        )
        texts.append(rendered)

    inputs = processor(text=texts, images=images, padding=True, return_tensors="pt")

    # Carry identifiers and references alongside the encoded inputs.
    for out_key, sample_key in (("image_ids", "image_id"), ("question_ids", "question_id")):
        inputs[out_key] = [sample[sample_key] for sample in batch]
    inputs["questions"] = questions
    inputs["answers"] = [sample["answers"] for sample in batch]

    return inputs


def _majority_answer(answers) -> str:
    """Return the most frequent entry of ``answers`` ("" if None or empty)."""
    # len() instead of truthiness so array-like answer containers work too.
    if answers is None or len(answers) == 0:
        return ""
    from collections import Counter
    return Counter(answers).most_common(1)[0][0]


def collate_fn_train_qwen(batch: List[Dict], processor, tokenizer) -> Dict:
    """Collate function for Qwen2.5-VL model training with labels.

    Each sample becomes a two-turn conversation: the user turn holds the image
    placeholder and the question, the assistant turn holds the most common
    reference answer. The rendered conversations are encoded together with
    the images, and ``labels`` is added for the language-modelling loss.

    Args:
        batch: Samples providing "image", "question" and "answers".
        processor: Object exposing ``apply_chat_template`` and callable as
            ``processor(text=..., images=..., padding=..., return_tensors=...)``.
        tokenizer: Only consulted for ``pad_token_id`` when the processor
            output has no "attention_mask"; may be None.

    Returns:
        The processor output with an extra "labels" tensor: a copy of
        "input_ids" in which padding positions are set to -100.
    """
    images = [sample["image"] for sample in batch]
    questions = [sample["question"] for sample in batch]

    # Train against the most common reference answer of each sample.
    answers = [_majority_answer(sample["answers"]) for sample in batch]

    # Create conversation format with answer for training
    conversations = [
        [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": q},
                ],
            },
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": a},
                ],
            },
        ]
        for q, a in zip(questions, answers)
    ]

    # Process with Qwen processor
    texts = [
        processor.apply_chat_template(conv, tokenize=False, add_generation_prompt=False)
        for conv in conversations
    ]

    inputs = processor(
        text=texts,
        images=images,
        padding=True,
        return_tensors="pt"
    )

    # Labels are the full token sequence (prompt tokens stay supervised),
    # but padding must not contribute to the loss: -100 is the index the
    # cross-entropy loss ignores by default.
    labels = inputs["input_ids"].clone()
    if "attention_mask" in inputs:
        labels[inputs["attention_mask"] == 0] = -100
    elif tokenizer is not None and getattr(tokenizer, "pad_token_id", None) is not None:
        # Fallback when no mask is available: mask by pad token id.
        labels[labels == tokenizer.pad_token_id] = -100
    inputs["labels"] = labels

    return inputs


def get_dataloader(
    data_dir: str,
    split: str,
    processor,
    batch_size: int = 4,
    shuffle: bool = True,
    num_workers: int = 4,
    max_samples: Optional[int] = None,
    for_training: bool = False,
    tokenizer = None,
) -> DataLoader:
    """Create a DataLoader for the TextVQA dataset.

    Args:
        data_dir: Directory containing the parquet shards.
        split: Dataset split to load.
        processor: Model processor handed to the dataset and collate function.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader shuffles the samples.
        num_workers: Number of DataLoader worker processes.
        max_samples: Optional cap on the number of loaded samples.
        for_training: If True, batches are built with
            ``collate_fn_train_qwen`` (adds labels); otherwise with
            ``collate_fn_qwen``.
        tokenizer: Forwarded to ``collate_fn_train_qwen`` when
            ``for_training`` is True.

    Returns:
        A DataLoader over the requested split with ``pin_memory=True``.
    """
    from functools import partial

    dataset = TextVQADataset(
        data_dir=data_dir,
        split=split,
        processor=processor,
        max_samples=max_samples,
    )

    # partial() over a module-level function can be pickled, unlike a lambda,
    # which matters when worker processes are started with "spawn".
    if for_training:
        collate = partial(collate_fn_train_qwen, processor=processor, tokenizer=tokenizer)
    else:
        collate = partial(collate_fn_qwen, processor=processor)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate,
        pin_memory=True,
    )


if __name__ == "__main__":
    # Smoke test: load a handful of validation rows and print a short
    # summary of the first few samples.
    data_dir = "textvqa_data/data"
    dataset = TextVQADataset(data_dir, split="validation", max_samples=5)

    preview_count = min(3, len(dataset))
    i = 0
    while i < preview_count:
        sample = dataset[i]
        print(f"\nSample {i}:")
        print(f"  Image ID: {sample['image_id']}")
        print(f"  Question: {sample['question']}")
        # Only the first three reference answers are shown.
        print(f"  Answers: {sample['answers'][:3]}...")
        print(f"  OCR Tokens: {sample['ocr_tokens']}")
        print(f"  Image size: {sample['image'].size}")
        i += 1
