In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import os
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms

In [2]:
train_data_path = 'datasets/B-Disease_Grading/1-Original_Images/a-Train_Set'
test_data_path = 'datasets/B-Disease_Grading/1-Original_Images/b-Test_Set'
train_labels = pd.read_csv('datasets/B-Disease_Grading/2-Groundtruths/a-Train_Labels.csv')
test_labels = pd.read_csv('datasets/B-Disease_Grading/2-Groundtruths/b-Test_Labels.csv')

In [3]:
class IDRiDDataset(Dataset):
    def __init__(self, csv_file, img_dir, transform=None):
        self.data = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_name = self.data.iloc[idx]['Image name'] + '.jpg'
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path)#.convert('RGB')

        retinopathy = int(self.data.iloc[idx]['Retinopathy grade'])
        edema = int(self.data.iloc[idx]['Risk of macular edema '])

        if self.transform:
            image = self.transform(image)

        return image, {'retinopathy': retinopathy, 'edema': edema}

In [4]:
def compute_mean_std(dataset):
    loader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=0)

    mean = 0.0
    std = 0.0
    nb_samples = 0

    for images, _ in tqdm(loader):
        batch_samples = images.size(0)  # batch size (32)
        images = images.view(batch_samples, images.size(1), -1)  # [B, C, H*W]
        
        mean += images.mean(2).sum(0)  # sum of channel means
        std += images.std(2).sum(0)    # sum of channel stds
        nb_samples += batch_samples

    mean /= nb_samples
    std /= nb_samples

    return mean.numpy(), std.numpy()

transform_mean_std = transforms.Compose([
    transforms.Resize(512),
    transforms.CenterCrop(448),
    transforms.ToTensor()
])

dataset_mean_std = IDRiDDataset(
    csv_file='datasets/B-Disease_Grading/2-Groundtruths/a-Train_Labels.csv',
    img_dir='datasets/B-Disease_Grading/1-Original_Images/a-Train_Set',
    transform=transform_mean_std
)

mean, std = compute_mean_std(dataset_mean_std)
print(f"Mean: {mean}\nStd: {std}")

  0%|          | 0/13 [00:00<?, ?it/s]

100%|██████████| 13/13 [00:12<00:00,  1.03it/s]

Mean: [0.64398366 0.31709325 0.10248035]
Std: [0.11275874 0.08133845 0.03794106]





In [5]:
transform = transforms.Compose([
    transforms.Resize(512),
    transforms.CenterCrop(448),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=mean,
        std=std
    )
])

train_dataset = IDRiDDataset(
    img_dir='datasets/B-Disease_Grading/1-Original_Images/a-Train_Set',
    csv_file='datasets/B-Disease_Grading/2-Groundtruths/a-Train_Labels.csv',
    transform=transform
)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)