In [1]:
import pydicom, cv2, torch
from torchvision import transforms



In [None]:
# `df` is assumed to be a DataFrame with the columns ['img_path','race_label'].
# Preprocessing pipeline applied to each image, in the order listed below.
transform = transforms.Compose(
    [
        # Convert the incoming array/tensor into a PIL image.
        transforms.ToPILImage(),
        # Resize to a fixed 224x224 input -- validate the size.
        transforms.Resize(size=(224, 224)),
        # Convert the PIL image back into a tensor.
        transforms.ToTensor(),
        # Example normalization for a single-channel image (mean, std).
        transforms.Normalize([0.5], [0.5]),
    ]
)

class CXRDataset(torch.utils.data.Dataset):
    """Map-style dataset that loads DICOM images listed in a DataFrame.

    Each item is a pair ``(image, label)``. The image is read from the
    path stored in the row's ``'img_path'`` column, min-max scaled to
    uint8, and then passed through ``transform`` when one is given. The
    label is the row's ``'race_label'`` value, returned unchanged.
    """

    def __init__(self, df, transform=None):
        """Store the table and the optional per-image transform.

        Args:
            df: Table indexed positionally via ``.iloc``; rows must provide
                the ``'img_path'`` and ``'race_label'`` columns.
            transform: Optional callable applied to the uint8 image array.
                When ``None`` the raw uint8 array is returned.
        """
        self.df = df
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the underlying table."""
        return len(self.df)

    @staticmethod
    def _to_uint8(pixels, invert=False):
        """Min-max scale a pixel array into the uint8 range [0, 255].

        Args:
            pixels: Array of raw pixel values.
            invert: When true, flip the intensities around the array's own
                min/max before scaling, so the brightest input becomes the
                darkest output.

        Returns:
            A uint8 array with the same shape as ``pixels``.
        """
        img = pixels.astype('float32')
        lo, hi = img.min(), img.max()
        if invert:
            # Mirror the values inside [lo, hi]; min and max are unchanged,
            # so the scaling below stays the same.
            img = hi + lo - img
        # The small epsilon keeps a constant image from dividing by zero.
        img = (img - lo) / (hi - lo + 1e-6) * 255.0  # normalize to [0,255]
        return img.astype('uint8')

    def __getitem__(self, idx):
        """Load, scale and transform the image at position ``idx``.

        Returns:
            Tuple ``(img, label)`` where ``img`` is the transformed image
            (or the uint8 array when no transform is set) and ``label`` is
            the row's ``'race_label'`` value (e.g. 0,1,2 for categories).
        """
        row = self.df.iloc[idx]
        # read image (DICOM to pixel array then to uint8)
        dicom = pydicom.dcmread(row['img_path'])
        # MONOCHROME1 stores the minimum value as white, the opposite of
        # MONOCHROME2. Invert those files so every image reaches the model
        # with the same polarity. A missing tag is treated as not inverted.
        invert = getattr(dicom, 'PhotometricInterpretation', None) == 'MONOCHROME1'
        img = self._to_uint8(dicom.pixel_array, invert=invert)
        if self.transform is not None:
            img = self.transform(img)
        label = row['race_label']  # e.g. 0,1,2 for categories
        return img, label
