In [None]:
# ---------- Example DataLoading Datasets MIMIC/CHEXPERT ----------
import torch
from torch.utils.data import ConcatDataset, DataLoader, WeightedRandomSampler
from utils.data.chexpert_dataset import CHEXPERTDataset
from utils.data.mimic_dataset import MIMICDataset
from utils.processing import image_transform, loader

# MODELS
MODELS_DIR = "models/"
SEGMENTER_MODEL_PATH = f"{MODELS_DIR}dino_unet_decoder_finetuned.pth"
save_path = f"{MODELS_DIR}complete_model.pth"
checkpoint_path = f"{MODELS_DIR}model_checkpoint.pth"

# CheXpert
CHEXPERT_DIR = "Datasets/CheXpertPlus"
chexpert_paths = {
    "chexpert_data_path": f"{CHEXPERT_DIR}/PNG",  # base PNG folder
    "chexpert_data_csv": f"{CHEXPERT_DIR}/df_chexpert_plus_240401.csv",
}

# MIMIC
MIMIC_DIR = "Datasets/MIMIC"
mimic_paths = {
    "mimic_data_path": MIMIC_DIR,
    "mimic_splits_csv": f"{MIMIC_DIR}/mimic-cxr-2.0.0-split.csv.gz",
    "mimic_metadata_csv": f"{MIMIC_DIR}/mimic-cxr-2.0.0-metadata.csv",
    "mimic_reports_path": f"{MIMIC_DIR}/cxr-record-list.csv.gz",  # must contain 'path'
}

# Build the training-split dataframes with the utils.processing `loader` helper.
# NOTE: the name `loader` is rebound to the DataLoader further down, so the
# helper is only reachable under this name before that point.
MIMIC_df, CHEXPERT_df = loader(chexpert_paths, mimic_paths, split="train")

# Example dataset usage.
# Derived from MIMIC_DIR so these stay in sync with `mimic_paths` above.
mimic_images_dir = f"{MIMIC_DIR}/matched_images_and_masks_mimic_224/images"
mimic_reports_dir = MIMIC_DIR  # base; dataset derives "<rel_dir>.txt"

# NOTE(review): the MIMIC image folder name suggests 224px images while the
# transform targets 512px — confirm the upscaling is intended.
transform = image_transform(img_size=512)
mimic_ds = MIMICDataset(MIMIC_df, mimic_images_dir, mimic_reports_dir, transform=transform)

chexpert_ds = CHEXPERTDataset(
    CHEXPERT_df,
    chexpert_paths["chexpert_data_path"],
    split="train",
    transform=transform,
)

# Indices [0, n1) address MIMIC and [n1, n1 + n2) address CheXpert, in the
# order the datasets are concatenated; the weight vector below follows the same order.
mixed = ConcatDataset([mimic_ds, chexpert_ds])
n1, n2 = len(mimic_ds), len(chexpert_ds)
if n1 + n2 == 0:
    # Fail with a readable message instead of an opaque sampler error.
    raise ValueError(
        "Both MIMIC and CheXpert datasets are empty; check the dataset paths."
    )
p1, p2 = 0.7, 0.3  # desired sampling ratio (MIMIC : CheXpert)

# Per-sample weights: higher weight → sampled more often.
# Every sample of a dataset gets p / n, so each non-empty dataset's weights
# sum to its target share p and the two sources are drawn in a p1:p2 ratio
# regardless of their sizes. max(n, 1) avoids division by zero when a dataset
# is empty (its weight tensor is then simply empty).
w1 = torch.full((n1,), fill_value=p1 / max(n1, 1), dtype=torch.float)
w2 = torch.full((n2,), fill_value=p2 / max(n2, 1), dtype=torch.float)
weights = torch.cat([w1, w2])

# One pass over the loader draws len(mixed) indices *with replacement*, so
# within a pass some samples repeat and others are not visited.
sampler = WeightedRandomSampler(weights, num_samples=n1 + n2, replacement=True)

# Dataloader tuning for cloud I/O (the num_workers line needs `import os`).
train_loader = DataLoader(
    mixed,
    batch_size=32,
    sampler=sampler,  # a custom sampler is mutually exclusive with shuffle=True
    # num_workers=os.cpu_count() // 2 if os.cpu_count() else 4,  # adjust on your VM
    # persistent_workers=True,           # reuses workers between iterations
    # prefetch_factor=4,                 # each worker prefetches batches
    # pin_memory=True,                   # if using CUDA
    # drop_last=False
)
# Kept under the original name for downstream cells.
# NOTE: this shadows the utils.processing `loader` helper used above; prefer
# `train_loader` in new code.
loader = train_loader

print("DataLoader created with dataset length:", len(mixed))
print("MIMIC dataset size: ", len(mimic_ds))
print("CheXpert dataset size: ", len(chexpert_ds))
# Example read: pull a single batch to sanity-check shapes.
images, findings, image_paths, _ = next(iter(train_loader))
print("Batch image tensor shape:", getattr(images, "shape", "N/A"))
print("Batch findings shape:", getattr(findings, "shape", len(findings)))
print("Batch image paths shape:", getattr(image_paths, "shape", len(image_paths)))

Local images_dir detected; filtering rows with missing PNGs...
[INFO] Kept 82440/188960 rows with existing PNGs
DataLoader created with dataset length: 320412
MIMIC dataset size:  237972
CheXpert dataset size:  82440
Batch image tensor shape: torch.Size([32, 3, 512, 512])
Batch findings shape: 32
Batch image paths shape: 32
