In [5]:
import os
import torch
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm  # Progress bar
import pandas as pd
import pydicom
from PIL import Image
from torch.utils.data import Dataset, DataLoader


# Absolute path to the CSV that lists the samples. PneumoniaDataset reads the
# 'patientId' and 'Target' columns from it.
# NOTE(review): machine-specific location — adjust when running elsewhere.
csv_path = r"G:\Meine Ablage\Universität\Master Thesis\Pneumonia\training\grouped_data.csv"
# Folder that holds the DICOM images; PneumoniaDataset looks up files named
# "<patientId>.dcm" inside it.
image_folder = r"C:\Users\Admin\Documents\rsna-pneumonia-detection-challenge\stage_2_train_images"


# Load the sample table once at module level; it is handed to the dataset below.
data = pd.read_csv(csv_path)

# Dataset class for Pneumonia
class PneumoniaDataset(Dataset):
    """Map-style dataset that yields ``(image, label)`` pairs from DICOM files.

    Each row of the supplied dataframe describes one sample: the
    ``patientId`` column names the DICOM file (``<patientId>.dcm`` inside
    ``image_folder``) and the ``Target`` column holds its label.
    """

    def __init__(self, dataframe, image_folder, transform=None):
        """Store the sample table, the image directory and an optional transform.

        Args:
            dataframe: Table with ``patientId`` and ``Target`` columns,
                indexed positionally via ``.iloc``.
            image_folder: Directory containing the ``.dcm`` files.
            transform: Optional callable applied to the RGB PIL image.
        """
        self.transform = transform
        self.image_folder = image_folder
        self.data = dataframe

    def __len__(self):
        """Return the number of rows in the sample table."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load the sample at position ``idx``.

        Returns:
            A tuple ``(image, label)`` where ``image`` is the RGB-converted
            DICOM pixel data (after ``transform`` if one was given) and
            ``label`` is a ``torch.long`` tensor built from ``Target``.
        """
        record = self.data.iloc[idx]
        dicom_file = os.path.join(self.image_folder, f"{record['patientId']}.dcm")
        target_value = record['Target']

        # Decode the DICOM pixel data and expand it to three channels.
        # NOTE(review): presumably 8-bit grayscale pixel data — confirm for
        # other bit depths, which PIL may handle differently.
        pixels = pydicom.dcmread(dicom_file).pixel_array
        image = Image.fromarray(pixels).convert("RGB")

        transform = self.transform
        if transform:
            image = transform(image)

        target = torch.tensor(target_value, dtype=torch.long)
        return image, target


# Define dataset without normalization: the statistics computed below are the
# ones a later Normalize step would need, so the raw [0, 1] values are used.
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert to tensor (scales to [0,1])
])

# Load dataset
dataset = PneumoniaDataset(dataframe=data, image_folder=image_folder, transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

# Per-channel running totals of x and x^2 over every pixel of every image.
# They are kept in float64: the totals grow to roughly 1e10, where float32
# can no longer represent the contribution of a single batch accurately.
channel_sum = torch.zeros(3, dtype=torch.float64)
channel_sq_sum = torch.zeros(3, dtype=torch.float64)
num_pixels = 0

for images, _ in tqdm(dataloader):
    # images is indexed as (batch, channel, height, width) below.
    num_pixels += images.size(0) * images.size(2) * images.size(3)  # B x H x W
    # dtype=float64 makes the reduction itself accumulate in double precision
    # without materialising a float64 copy of the whole batch.
    channel_sum += images.sum(dim=[0, 2, 3], dtype=torch.float64)
    channel_sq_sum += (images ** 2).sum(dim=[0, 2, 3], dtype=torch.float64)

if num_pixels == 0:
    # Dividing by zero would silently yield NaN statistics.
    raise ValueError("Dataset is empty; cannot compute mean and std.")

mean = channel_sum / num_pixels
# Population variance via E[x^2] - E[x]^2; clamp at zero because rounding can
# push the difference slightly negative for (near-)constant channels, which
# would turn the square root into NaN.
variance = (channel_sq_sum / num_pixels - mean ** 2).clamp(min=0.0)
std = torch.sqrt(variance)

# Cast back to float32 so downstream code sees the same dtype as before.
mean = mean.to(torch.float32)
std = std.to(torch.float32)

print(f"Mean: {mean.tolist()}")
print(f"Std: {std.tolist()}")


100%|██████████| 834/834 [10:12<00:00,  1.36it/s]

Mean: [0.4901120066642761, 0.4901120066642761, 0.4901120066642761]
Std: [0.24817368388175964, 0.24817368388175964, 0.24817368388175964]



