In [18]:
import os
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset
import pydicom
from pathlib import Path
from typing import Tuple, Dict, Optional, List
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [None]:
def create_pairs_csv(labels_csv: str, output_csv: str, dataset_root: str):
    """Create pairs CSV file from the original labels CSV.
    
    Args:
        labels_csv (str): Path to original labels CSV with columns [caseId, instanceUId, phaseLabel]
        output_csv (str): Path to save the pairs CSV
        dataset_root (str): Root directory containing the batch folders
    """
    # Read labels CSV
    df = pd.read_csv(labels_csv)
    # StudyInstanceUID,SeriesInstanceUID,Label
    # Group by caseId to find pairs
    case_groups = df.groupby('StudyInstanceUID')
    
    pairs = []
    for case_id, group in case_groups:
        # Get phases for this case
        non_contrast = group[group['Label'].str.lower() == 'non-contrast']
        arterial = group[group['Label'].str.lower() == 'aterial']
        portal_venous = group[group['Label'].str.lower() == 'venous']
        
        # Find batch folder containing this case
        case_path = None
        for batch_dir in Path(dataset_root).glob('batch*'):
            if (batch_dir / str(case_id)).exists():
                case_path = batch_dir / str(case_id)
                break
                
        if case_path is None:
            print(f"Warning: Could not find case {case_id} in any batch folder")
            continue
            
        # Create pairs: non-contrast -> arterial and non-contrast -> portal-venous
        if not non_contrast.empty:
            nc_instance = non_contrast.iloc[0]
            nc_path = case_path / str(nc_instance['SeriesInstanceUID'])
            
            # Non-contrast -> Arterial pairs
            if not arterial.empty:
                art_instance = arterial.iloc[0]
                art_path = case_path / str(art_instance['SeriesInstanceUID'])
                
                pairs.append({
                    'caseId': case_id,
                    'InputPath': str(nc_path),
                    'InputLabel': 'non-contrast',
                    'TargetPath': str(art_path),
                    'TargetLabel': 'arterial'
                })
                
            # Non-contrast -> Portal venous pairs
            if not portal_venous.empty:
                pv_instance = portal_venous.iloc[0]
                pv_path = case_path / str(pv_instance['SeriesInstanceUID'])
                
                pairs.append({
                    'caseId': case_id,
                    'InputPath': str(nc_path),
                    'InputLabel': 'non-contrast',
                    'TargetPath': str(pv_path),
                    'TargetLabel': 'venous'
                })
    
    # Create and save pairs DataFrame
    pairs_df = pd.DataFrame(pairs)
    pairs_df.to_csv(output_csv, index=False)
    print(f"✅ Saved {len(pairs_df)} valid pairs to {output_csv}")

if __name__ == "__main__":
    # Example usage
    dataset_root = "/media/disk1/saeedeh_danaei/ncct_cect/vindr_ds"
    labels_csv = os.path.join(dataset_root, "labels.csv")
    pairs_csv = os.path.join(dataset_root, "series_pairs.csv")
    
    create_pairs_csv(labels_csv, pairs_csv, dataset_root)



✅ Saved 175 valid pairs to /media/disk1/saeedeh_danaei/ncct_cect/vindr_ds/series_pairs.csv


In [19]:


class VindrCTDataset(Dataset):
    """Dataset class for loading paired non-contrast and contrast-enhanced CT scans from Vindr dataset."""
    
    def __init__(
        self,
        pairs_csv: str,
        transform: Optional[A.Compose] = None,
        phase: str = "train",
        slice_selection: str = "middle",  # Options: "middle", "all", or int
        max_slices: Optional[int] = None
    ):
        """
        Args:
            pairs_csv (str): Path to CSV file containing paired scan information
            transform (Optional[A.Compose]): Albumentations transformations to apply
            phase (str): Dataset phase ('train', 'val', or 'test')
            slice_selection (str): How to select slices from the 3D volume
                - "middle": Select middle slice
                - "all": Use all slices (returns 3D volume)
                - int: Select specific number of evenly spaced slices
            max_slices (Optional[int]): Maximum number of slices to load (for memory efficiency)
        """
        self.pairs_df = pd.read_csv(pairs_csv)
        self.transform = transform if transform is not None else self._get_default_transform(phase)
        self.phase = phase
        self.slice_selection = slice_selection
        self.max_slices = max_slices
        
    def __len__(self) -> int:
        return len(self.pairs_df)
    
    def _load_dicom_series(self, series_path: str) -> np.ndarray:
        """Load a DICOM series and return as a normalized numpy array."""
        series_path = Path(series_path)
        dicom_files = sorted(series_path.glob('*.dcm'), 
                           key=lambda x: float(pydicom.dcmread(str(x)).ImagePositionPatient[2]))
        
        if self.max_slices and len(dicom_files) > self.max_slices:
            # Sample evenly spaced slices
            indices = np.linspace(0, len(dicom_files)-1, self.max_slices, dtype=int)
            dicom_files = [dicom_files[i] for i in indices]
        
        # Load first slice to get series metadata
        first_slice = pydicom.dcmread(str(dicom_files[0]))
        series_shape = (len(dicom_files), first_slice.Rows, first_slice.Columns)
        series_array = np.zeros(series_shape, dtype=np.float32)
        
        # Load all slices
        for idx, dcm_path in enumerate(dicom_files):
            dcm = pydicom.dcmread(str(dcm_path))
            pixel_array = dcm.pixel_array.astype(np.float32)
            
            # Apply rescaling if available
            if hasattr(dcm, 'RescaleSlope') and hasattr(dcm, 'RescaleIntercept'):
                pixel_array = pixel_array * dcm.RescaleSlope + dcm.RescaleIntercept
            
            series_array[idx] = pixel_array
        
        # Select slices based on slice_selection parameter
        if isinstance(self.slice_selection, int):
            indices = np.linspace(0, len(series_array)-1, self.slice_selection, dtype=int)
            series_array = series_array[indices]
        elif self.slice_selection == "middle":
            mid_idx = len(series_array) // 2
            series_array = series_array[mid_idx:mid_idx+1]
        # For "all", keep the full array
        
        # Normalize to [0, 1] using percentile-based normalization
        p1, p99 = np.percentile(series_array, (1, 99))
        series_array = np.clip(series_array, p1, p99)
        series_array = (series_array - p1) / (p99 - p1)
        
        return series_array
    
    def _get_default_transform(self, phase: str) -> A.Compose:
        """Get default transformation pipeline based on dataset phase."""
        if phase == "train":
            return A.Compose([
                A.RandomRotate90(p=0.5),
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.5),
                A.OneOf([
                    A.GaussNoise(var_limit=(10.0, 50.0)),
                    A.GaussianBlur(),
                    A.RandomBrightnessContrast(),
                ], p=0.3),
                A.Normalize(mean=[0.5], std=[0.5]),
                ToTensorV2(),
            ])
        else:
            return A.Compose([
                A.Normalize(mean=[0.5], std=[0.5]),
                ToTensorV2(),
            ])
    
    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get a paired sample of non-contrast and contrast-enhanced CT scans."""
        row = self.pairs_df.iloc[idx]
        
        # Load input (non-contrast) and target (contrast) series
        input_series = self._load_dicom_series(row['InputPath'])
        target_series = self._load_dicom_series(row['TargetPath'])
        
        # Add channel dimension if needed
        if input_series.ndim == 2:
            input_series = input_series[None, ...]  # Add channel dim
            target_series = target_series[None, ...]
        elif input_series.ndim == 3:
            input_series = input_series[:, None, ...]  # Add channel dim after slice dim
            target_series = target_series[:, None, ...]
        
        # Apply transformations
        if self.transform:
            if input_series.ndim == 3:  # Multiple slices
                transformed_slices_input = []
                transformed_slices_target = []
                for i in range(len(input_series)):
                    transformed = self.transform(
                        image=input_series[i, 0],  # Remove channel dim
                        mask=target_series[i, 0]
                    )
                    transformed_slices_input.append(transformed['image'])
                    transformed_slices_target.append(transformed['mask'])
                input_series = torch.stack(transformed_slices_input)
                target_series = torch.stack(transformed_slices_target)
            else:  # Single slice
                transformed = self.transform(
                    image=input_series[0],  # Remove channel dim
                    mask=target_series[0]
                )
                input_series = transformed['image']
                target_series = transformed['mask']
        
        return {
            'input': input_series,
            'target': target_series,
            'input_label': row['InputLabel'],
            'target_label': row['TargetLabel'],
            'case_id': row['caseId']
        }

def get_vindr_dataloader(
    pairs_csv: str,
    batch_size: int = 4,
    num_workers: int = 4,
    phase: str = "train",
    transform: Optional[A.Compose] = None,
    slice_selection: str = "middle",
    max_slices: Optional[int] = None
) -> torch.utils.data.DataLoader:
    """Create a dataloader for the Vindr CT dataset.
    
    Args:
        pairs_csv (str): Path to CSV containing paired scan information
        batch_size (int): Batch size for the dataloader
        num_workers (int): Number of worker processes for data loading
        phase (str): Dataset phase ('train', 'val', or 'test')
        transform (Optional[A.Compose]): Custom transformations to apply
        slice_selection (str): How to select slices ("middle", "all", or int)
        max_slices (Optional[int]): Maximum number of slices to load
        
    Returns:
        torch.utils.data.DataLoader: DataLoader for the dataset
    """
    dataset = VindrCTDataset(
        pairs_csv=pairs_csv,
        transform=transform,
        phase=phase,
        slice_selection=slice_selection,
        max_slices=max_slices
    )
    
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(phase == "train"),
        num_workers=num_workers,
        pin_memory=True
    )

In [None]:
train_loader = get_vindr_dataloader(
    pairs_csv="path/to/series_pairs.csv",
    batch_size=4,
    phase="train",
    slice_selection="middle"
)

In [None]:
train_loader = get_vindr_dataloader(
    pairs_csv="path/to/series_pairs.csv",
    batch_size=2,
    phase="train",
    slice_selection=16  # Load 16 evenly spaced slices
)

In [21]:
train_loader = get_vindr_dataloader(
    pairs_csv="/media/disk1/saeedeh_danaei/ncct_cect/vindr_ds/series_pairs.csv",
    batch_size=1,
    phase="train",
    slice_selection="all",
    max_slices=32  # Optional: limit slices for memory
)

  A.GaussNoise(var_limit=(10.0, 50.0)),
