In [25]:
# from src.data.chexpert_datamodule_parquet import CheXpertDataset
from datasets import load_dataset
from torch.utils.data import DataLoader
import os
import pandas as pd
from torch.utils.data import Dataset, DataLoader, random_split
import torch
import numpy as np

In [None]:
class CheXpertDataset(Dataset):
    """
    CheXpert dataset for multi-label classification of chest X-rays.

    Annotations and image data are read from a single Parquet file that is
    loaded fully into memory when the dataset is constructed. Each item is an
    ``(image, labels)`` pair, see ``__getitem__``.
    """

    # Replacement written over uncertain labels for each supported policy;
    # ``None`` means the value is left untouched.
    # NOTE(review): assumes the usual CheXpert encoding where an uncertain
    # label is stored as -1 -- confirm against how the Parquet file was built.
    _UNCERTAIN_FILL = {"ones": 1.0, "zeros": 0.0, "ignore": None}

    def __init__(self, 
                 parquet_file,
                 base_dir,
                 transform=None,
                 use_frontal_only=True,
                 policy='ones',
                 class_index=None,
                 use_metadata=True):
        """
        Args:
            parquet_file: Path to the Parquet file with annotations. The file
                must also provide an "image" column (read in ``__getitem__``).
            base_dir: Base directory containing train/valid folders. Only
                stored on the instance; this class does not read from it.
            transform: Optional transform to be applied on a sample.
            use_frontal_only: If True, keep only rows whose "Frontal/Lateral"
                column equals 1.
            policy: How to handle uncertain labels: 'ones' maps them to 1,
                'zeros' maps them to 0, 'ignore' leaves them unchanged.
            class_index: Indices into the 14 target names selecting which
                classes to use (defaults to all 14).
            use_metadata: Whether to append metadata values to the labels.

        Raises:
            ValueError: If ``policy`` is not 'ones', 'zeros' or 'ignore'.
        """
        # Validate before the (potentially slow) file load so that a typo in
        # the policy name fails fast instead of silently producing raw labels.
        if policy not in self._UNCERTAIN_FILL:
            raise ValueError(
                f"Unknown policy {policy!r}; expected one of "
                f"{sorted(self._UNCERTAIN_FILL)}"
            )

        self.parquet_file = parquet_file  # Store the Parquet file path
        self.base_dir = base_dir
        self.transform = transform
        self.policy = policy
        self.use_metadata = use_metadata

        # Define target (labels) list
        self.target = [
            "No Finding", "Enlarged Cardiomediastinum", "Cardiomegaly", 
            "Lung Opacity", "Lung Lesion", "Edema", "Consolidation", 
            "Pneumonia", "Atelectasis", "Pneumothorax", "Pleural Effusion", 
            "Pleural Other", "Fracture", "Support Devices"
        ]
        
        # Define metadata fields (only if use_metadata is True)
        self.metadata = ["Sex", "Age", "Frontal/Lateral", "AP/PA"] if use_metadata else []

        # Load the entire Parquet file into memory as a Pandas DataFrame
        frame = pd.read_parquet(parquet_file)  # Load all columns

        # Filter for frontal views if requested
        if use_frontal_only:
            frame = frame[frame["Frontal/Lateral"] == 1]

        self.data_frame = frame
        self.num_samples = len(frame)

        # Allow filtering for specific classes
        if class_index is None:
            self.class_index = list(range(len(self.target)))
        else:
            self.class_index = class_index
            
        self.classes = [self.target[i] for i in self.class_index]
        
    def __len__(self):
        """Return the number of rows left after the optional frontal filter."""
        return self.num_samples
    
    def __getitem__(self, idx):
        """
        Return the ``(image, labels)`` pair for the row at position ``idx``.

        The image is decoded from the row's "image" column as an 8-bit
        grayscale image and passed through ``transform`` when one is set.
        ``labels`` is a 1-D float32 tensor holding the selected class labels,
        followed by the metadata values when ``use_metadata`` is True.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        # Access the row directly from the preloaded DataFrame (positional,
        # so it is unaffected by the index gaps left by the frontal filter).
        row = self.data_frame.iloc[idx]
        
        # Extract the image data directly from the "image" column
        image_data = row["image"]  # Assuming the "image" column contains raw image data
        
        # Convert the image data to a PIL Image.
        # NOTE(review): `Image` is used as PIL's `Image` module but is not
        # imported in the visible import cell -- confirm it is in scope.
        try:
            image = Image.fromarray(np.array(image_data, dtype=np.uint8)).convert('L')
        except Exception as exc:
            # Deliberate best-effort: one undecodable sample should not abort
            # a whole epoch, but it must not go unnoticed either.
            import warnings

            warnings.warn(
                f"Could not decode image at index {idx}: {exc!r}; "
                "using a blank placeholder image instead.",
                RuntimeWarning,
                stacklevel=2,
            )
            image = Image.new('L', (224, 224), color=128)  # Fallback to a blank image
        
        # Transform the image if needed
        if self.transform:
            image = self.transform(image)
        
        return image, self._labels_for_row(row)

    def _labels_for_row(self, row):
        """Build the float32 label tensor (plus optional metadata) for one row."""
        # `astype` returns a fresh array, so the in-place replacement below
        # never writes back into the shared DataFrame.
        labels = row[self.classes].values.astype(np.float32)

        # Apply the uncertainty policy to the class labels only; metadata
        # values are appended afterwards and are never remapped.
        fill = self._UNCERTAIN_FILL.get(self.policy)
        if fill is not None:
            labels[labels == -1] = fill

        # Add metadata to labels only if use_metadata is True
        if self.use_metadata:
            metadata = row[self.metadata].values.astype(np.float32)
            labels = np.concatenate([labels, metadata])

        return torch.tensor(labels)

