Adapted from: https://github.com/Beckschen/TransUNet/blob/main/datasets/dataset_synapse.py

In [None]:
import os
import random
import h5py
import numpy as np
import torch
from scipy import ndimage
from scipy.ndimage.interpolation import zoom
from torch.utils.data import Dataset

## Auxiliary functions 

In [None]:
def random_rot_flip(image, label):
    """Rotate by a random multiple of 90 degrees, then flip along a random axis.

    The same rotation and flip are applied to ``image`` and ``label`` so they stay
    aligned. Two draws are taken from ``np.random``: first the number of quarter
    turns (0-3), then the flip axis (0 or 1).

    Returns:
        (image, label) as fresh contiguous copies (no views into the inputs).
    """
    quarter_turns = np.random.randint(0, 4)
    flip_axis = np.random.randint(0, 2)

    def _augment(arr):
        # rot90/flip return views; copy so downstream torch.from_numpy gets positive strides.
        return np.flip(np.rot90(arr, quarter_turns), axis=flip_axis).copy()

    return _augment(image), _augment(label)


def random_rotate(image, label, max_angle=20):
    """Rotate image and label by the same random integer angle.

    The angle is drawn uniformly from ``[-max_angle, max_angle]`` degrees, both ends
    inclusive. The previous ``randint(-20, 20)`` silently excluded +20, which made
    the range asymmetric.

    Both arrays use nearest-neighbour interpolation (``order=0``), so label values are
    never blended into new classes. ``reshape=False`` keeps the input shape.
    Areas rotated in from outside the image are filled with SciPy's default
    ``cval`` (0).

    Args:
        image: 2-D image array.
        label: 2-D label array aligned with ``image``.
        max_angle: largest absolute rotation in degrees (default 20).

    Returns:
        (rotated_image, rotated_label)
    """
    angle = np.random.randint(-max_angle, max_angle + 1)  # randint's upper bound is exclusive
    image = ndimage.rotate(image, angle, order=0, reshape=False)
    label = ndimage.rotate(label, angle, order=0, reshape=False)
    return image, label


class RandomGenerator(object):
    """Random augmentation + resize transform for a 2-D ``{'image', 'label'}`` sample.

    Augmentation is chosen with Python's ``random`` module:
      * with probability 0.5: ``random_rot_flip`` (90-degree rotation + flip);
      * otherwise with probability 0.5 (0.25 overall): ``random_rotate``;
      * otherwise (0.25 overall): no augmentation.

    The pair is then resized to ``output_size`` if needed and converted to tensors:
    image -> float32 tensor of shape (1, H, W); label -> int64 tensor of shape (H, W).
    Only the ``'image'`` and ``'label'`` keys are returned.

    Args:
        output_size: target ``(height, width)`` pair, or a single int for a square output.
    """

    def __init__(self, output_size):
        self.output_size = output_size

    def _target_hw(self):
        """Return the target (height, width); a scalar size means a square output."""
        size = self.output_size
        if isinstance(size, (int, np.integer)):
            return size, size
        return size[0], size[1]

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        if random.random() > 0.5:
            image, label = random_rot_flip(image, label)
        elif random.random() > 0.5:
            image, label = random_rotate(image, label)

        x, y = image.shape  # a 2-D (H, W) image is required here
        out_x, out_y = self._target_hw()
        if x != out_x or y != out_y:
            factors = (out_x / x, out_y / y)
            # Cubic spline (order=3) for intensities; nearest neighbour (order=0) for the
            # label so resizing never produces values that are not existing class ids.
            image = zoom(image, factors, order=3)
            label = zoom(label, factors, order=0)
        image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0)  # add channel axis
        label = torch.from_numpy(label.astype(np.float32))
        return {'image': image, 'label': label.long()}

---
## Temp: DRIVE dataset preparation, patch extraction and preprocessing (scratch code)

In [None]:

Nimgs = 20
channels = 3
height = 584
width = 565
save_data_dir = "./DRIVE_datasets_training_testing/"

#------------Path of the images -----------------------------
original_base_dir = '/home/staff/azad/deeplearning/datasets/DRIVE'
#------------------------------------------------------------


class DriveDatasetPreparer(object):
    """Load the DRIVE train/test image folders and save them as HDF5 arrays.

    Configuration is stored on the instance under the attribute names the loading
    code uses: ``Nimgs``, ``channels``, ``height``, ``width``, ``save_data_dir``,
    and the ``*_train_dir`` / ``*_test_dir`` folder paths. Folder paths end with "/".
    """

    def __init__(self, original_base_dir=original_base_dir, save_data_dir=save_data_dir,
                 n_imgs=Nimgs, n_channels=channels, img_height=height, img_width=width):
        self.Nimgs = n_imgs
        self.channels = n_channels
        self.height = img_height
        self.width = img_width
        self.save_data_dir = save_data_dir
        self.original_base_dir = original_base_dir
        # train
        self.original_imgs_train_dir = f"{original_base_dir}/training/images/"
        self.groundTruth_imgs_train_dir = f"{original_base_dir}/training/1st_manual/"
        self.borderMasks_imgs_train_dir = f"{original_base_dir}/training/mask/"
        # test
        self.original_imgs_test_dir = f"{original_base_dir}/test/images/"
        self.groundTruth_imgs_test_dir = f"{original_base_dir}/test/1st_manual/"
        self.borderMasks_imgs_test_dir = f"{original_base_dir}/test/mask/"

    def _get_datasets(self, imgs_dir, groundTruth_dir, borderMasks_dir, train_test="null"):
        """Load every image of one split together with its ground truth and border mask.

        For each file in ``imgs_dir``, the first two characters of its name select
        ``<id>_manual1.gif`` in ``groundTruth_dir``. They also select either
        ``<id>_training_mask.gif`` or ``<id>_test_mask.gif`` in ``borderMasks_dir``.

        Returns:
            (imgs, gts, border_masks) with shapes (N, C, H, W), (N, 1, H, W), (N, 1, H, W).

        Raises:
            ValueError: ``train_test`` is not "train"/"test"; ``imgs_dir`` does not hold
                exactly ``self.Nimgs`` files; or ground truth / masks are not 0..255.
        """
        mask_suffix = {"train": "_training_mask.gif", "test": "_test_mask.gif"}.get(train_test)
        if mask_suffix is None:
            raise ValueError(f"specify if train or test!! (got train_test={train_test!r})")

        n_imgs, n_ch, h, w = self.Nimgs, self.channels, self.height, self.width
        # Only direct children of imgs_dir, sorted for a reproducible order
        # (os.walk would also descend into subfolders whose paths are not imgs_dir+name).
        files = sorted(
            name for name in os.listdir(imgs_dir)
            if os.path.isfile(os.path.join(imgs_dir, name))
        )
        if len(files) != n_imgs:
            # Otherwise unfilled rows of np.empty would silently hold garbage (or indexing overflows).
            raise ValueError(f"expected {n_imgs} images in {imgs_dir!r}, found {len(files)}")

        from PIL import Image  # local import: in this notebook PIL is only imported in a later cell

        imgs = np.empty((n_imgs, h, w, n_ch))
        gts = np.empty((n_imgs, h, w))
        border_masks = np.empty((n_imgs, h, w))

        for i, fname in enumerate(files):
            image_id = fname[0:2]
            #original
            print("original image: " + fname)
            with Image.open(os.path.join(imgs_dir, fname)) as img:
                imgs[i] = np.asarray(img)
            #corresponding ground truth
            groundTruth_name = image_id + "_manual1.gif"
            print("ground truth name: " + groundTruth_name)
            with Image.open(os.path.join(groundTruth_dir, groundTruth_name)) as g_truth:
                gts[i] = np.asarray(g_truth)
            #corresponding border masks
            border_masks_name = image_id + mask_suffix
            print("border masks name: " + border_masks_name)
            with Image.open(os.path.join(borderMasks_dir, border_masks_name)) as b_mask:
                border_masks[i] = np.asarray(b_mask)

        print("imgs max: " + str(np.max(imgs)))
        print("imgs min: " + str(np.min(imgs)))
        if not (np.max(gts) == 255 and np.max(border_masks) == 255):
            raise ValueError("ground truth and border masks must reach the maximum value 255")
        if not (np.min(gts) == 0 and np.min(border_masks) == 0):
            raise ValueError("ground truth and border masks must reach the minimum value 0")
        print("ground truth and border masks are correctly withih pixel value range 0-255 (black-white)")

        #reshaping for my standard tensors: NHWC -> NCHW, and a singleton channel for gts/masks
        imgs = np.transpose(imgs, (0, 3, 1, 2))
        gts = np.reshape(gts, (n_imgs, 1, h, w))
        border_masks = np.reshape(border_masks, (n_imgs, 1, h, w))
        return imgs, gts, border_masks

    def prepare_dataset(self):
        """Build the train and test splits and write each array to ``self.save_data_dir``."""
        from help_functions import write_hdf5  # project helper; in this notebook imported only later

        os.makedirs(self.save_data_dir, exist_ok=True)

        splits = (
            ("train", self.original_imgs_train_dir,
             self.groundTruth_imgs_train_dir, self.borderMasks_imgs_train_dir),
            ("test", self.original_imgs_test_dir,
             self.groundTruth_imgs_test_dir, self.borderMasks_imgs_test_dir),
        )
        for split, imgs_dir, gt_dir, mask_dir in splits:
            imgs, gts, masks = self._get_datasets(imgs_dir, gt_dir, mask_dir, split)
            print(f"saving {split} datasets")
            write_hdf5(imgs, os.path.join(self.save_data_dir, f"DRIVE_dataset_imgs_{split}.hdf5"))
            write_hdf5(gts, os.path.join(self.save_data_dir, f"DRIVE_dataset_groundTruth_{split}.hdf5"))
            write_hdf5(masks, os.path.join(self.save_data_dir, f"DRIVE_dataset_borderMasks_{split}.hdf5"))


def prepare_dataset(preparer=None):
    """Run the DRIVE preparation with ``preparer`` (default settings if omitted)."""
    (preparer if preparer is not None else DriveDatasetPreparer()).prepare_dataset()




In [None]:
# save patches

from help_functions import *
from extract_patches import *

#function to obtain data for training/testing (validation)
from extract_patches import get_data_training

#========= Load settings from Config file
#patch to the datasets
#========= Settings
#path to the datasets (written by the DRIVE preparation cell above)
path_data = './DRIVE_datasets_training_testing/'

print('extracting patches')
patches_imgs_train, patches_masks_train = get_data_training(
    DRIVE_train_imgs_original = path_data + 'DRIVE_dataset_imgs_train.hdf5',
    DRIVE_train_groudTruth    = path_data + 'DRIVE_dataset_groundTruth_train.hdf5',  #masks
    patch_height = 64,
    patch_width  = 64,
    N_subimgs    = 200000,
    # Select patches only inside the FOV. Pass a real boolean: the string 'True'
    # only acts like True if the callee tests truthiness; an `== True` / `is True`
    # check would treat it as False. NOTE(review): confirm against extract_patches.
    inside_FOV = True,
)

# np.save appends the ".npy" extension.
np.save('patches_imgs_train', patches_imgs_train)
np.save('patches_masks_train', patches_masks_train)

In [None]:
# preprocessing functions

from __future__ import division
###################################################
#
#   Script to pre-process the original imgs
#
##################################################


import numpy as np
from PIL import Image
import cv2

from help_functions import *


#My pre processing (use for both training and testing!)
def my_PreProc(data):
    """Preprocessing pipeline shared by training and testing.

    Expects a 4-D batch whose axis 1 holds the 3 channels of the original images.
    Steps, in order:
      1. grayscale conversion (``rgb2gray``);
      2. dataset-wide normalization;
      3. CLAHE;
      4. gamma correction with gamma 1.2;
      5. division by 255 to bring values into the 0-1 range.
    """
    assert len(data.shape) == 4
    assert data.shape[1] == 3  # the original (3-channel) images are required
    pipeline = (
        dataset_normalized,
        clahe_equalized,
        lambda batch: adjust_gamma(batch, 1.2),
    )
    processed = rgb2gray(data)
    for step in pipeline:
        processed = step(processed)
    return processed / 255.


#============================================================
#========= PRE PROCESSING FUNCTIONS ========================#
#============================================================

#==== histogram equalization
def histo_equalized(imgs):
    """Apply OpenCV global histogram equalization to each image of a batch.

    ``imgs`` must be 4-D with exactly one channel in axis 1. Each image is cast to
    uint8 before ``cv2.equalizeHist``. The result is a float array of the same shape.
    """
    assert len(imgs.shape) == 4  # 4D arrays
    assert imgs.shape[1] == 1  # single channel only
    equalized = np.empty(imgs.shape)
    for idx, single in enumerate(imgs[:, 0]):
        equalized[idx, 0] = cv2.equalizeHist(np.array(single, dtype=np.uint8))
    return equalized


# CLAHE (Contrast Limited Adaptive Histogram Equalization)
# Adaptive histogram equalization divides the image into small blocks called "tiles" (an 8x8 tile grid, tileGridSize, by default in OpenCV) and equalizes the histogram of each tile separately. Within a small area the histogram is confined to a narrow range, so any noise present would be amplified. To prevent this, contrast limiting is applied: any histogram bin above the clip limit (40 by default in OpenCV; clipLimit=2.0 is used below) is clipped and the excess is distributed uniformly over the other bins before equalization. Finally, bilinear interpolation is applied to remove artifacts at tile borders.
def clahe_equalized(imgs):
    """Apply CLAHE to every image of a 4-D single-channel batch.

    One CLAHE object (clip limit 2.0, 8x8 tile grid) is shared by all images. Each
    image is cast to uint8 before ``clahe.apply``. Returns a float array of the same
    shape as ``imgs``.
    """
    assert len(imgs.shape) == 4  # 4D arrays
    assert imgs.shape[1] == 1  # single channel only
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    result = np.empty(imgs.shape)
    for idx, single in enumerate(imgs[:, 0]):
        result[idx, 0] = equalizer.apply(np.array(single, dtype=np.uint8))
    return result


# ===== normalize over the dataset
def dataset_normalized(imgs):
    """Standardize with dataset-wide statistics, then min-max scale each image to [0, 255].

    Args:
        imgs: 4-D array with one channel in axis 1, i.e. (N, 1, H, W).

    Returns:
        Float array of the same shape. Each image spans exactly [0, 255]. An image
        with constant values, or a constant dataset (std 0), maps to all zeros
        instead of NaN (the previous version divided 0 by 0 in those cases).
    """
    assert (len(imgs.shape)==4)  #4D arrays
    assert (imgs.shape[1]==1)  #check the channel is 1
    imgs_std = np.std(imgs)
    imgs_mean = np.mean(imgs)
    centered = imgs - imgs_mean
    # std == 0 means every pixel equals the mean, so `centered` is already all zeros.
    standardized = centered / imgs_std if imgs_std > 0 else centered
    # Per-image min/max over (C, H, W), vectorized instead of a Python loop.
    axes = tuple(range(1, standardized.ndim))
    mins = standardized.min(axis=axes, keepdims=True)
    ranges = standardized.max(axis=axes, keepdims=True) - mins
    # Constant images have zero range; their numerator is 0 too, so dividing by 1 yields 0.
    safe_ranges = np.where(ranges > 0, ranges, 1.0)
    return ((standardized - mins) / safe_ranges) * 255


def adjust_gamma(imgs, gamma=1.0):
    """Gamma-correct a 4-D single-channel batch through a 256-entry lookup table.

    Each pixel is cast to uint8 and mapped to ``uint8(((v / 255) ** (1 / gamma)) * 255)``.

    Args:
        imgs: 4-D array with one channel in axis 1.
        gamma: correction exponent; must be > 0.

    Returns:
        Float64 array of the same shape as ``imgs``.

    Raises:
        ValueError: if ``gamma <= 0``. Zero divides by zero, and a negative value
            inverts the curve and overflows the uint8 table.
    """
    assert (len(imgs.shape)==4)  #4D arrays
    assert (imgs.shape[1]==1)  #check the channel is 1
    if gamma <= 0:
        raise ValueError(f"gamma must be positive, got {gamma}")
    # build a lookup table mapping the pixel values [0, 255] to
    # their adjusted gamma values
    invGamma = 1.0 / gamma
    table = np.array([((i / 255.0) ** invGamma) * 255 for i in np.arange(0, 256)]).astype("uint8")
    # Fancy indexing performs the same per-pixel lookup as cv2.LUT on single-channel
    # uint8 data, for the whole batch at once.
    return table[np.asarray(imgs, dtype=np.uint8)].astype(np.float64)

In [None]:
# def preprocess_dataset():


## Dataset

In [None]:
class SynapseDataset(Dataset):
    """Synapse dataset in the TransUNet file layout.

    ``<list_dir>/<split>.txt`` lists one case name per line.

    * ``split == "train"``: each case is a 2-D slice stored as ``<base_dir>/<name>.npz``.
    * any other split: each case is stored as ``<base_dir>/<name>.npy.h5``.

    Both file kinds must contain ``image`` and ``label`` arrays. ``__getitem__`` returns
    ``{'image', 'label', 'case_name'}``, after applying ``transform`` (if given) to the
    image/label dict.
    """

    def __init__(self, base_dir, list_dir, split, transform=None):
        self.transform = transform  # optional callable applied to {'image', 'label'}
        self.split = split
        # `with` closes the list file; the previous bare open() leaked the handle.
        with open(os.path.join(list_dir, self.split + '.txt')) as list_file:
            self.sample_list = list_file.readlines()
        self.data_dir = base_dir

    def __len__(self):
        return len(self.sample_list)

    def _case_name(self, idx):
        """Case name at ``idx`` with the line ending removed (LF or CRLF)."""
        return self.sample_list[idx].strip('\r\n')

    def __getitem__(self, idx):
        case_name = self._case_name(idx)
        if self.split == "train":
            data_path = os.path.join(self.data_dir, case_name + '.npz')
            # Indexing an NpzFile reads the member fully, so closing afterwards is safe.
            with np.load(data_path) as data:
                image, label = data['image'], data['label']
        else:
            filepath = self.data_dir + "/{}.npy.h5".format(case_name)
            # Explicit read-only mode (older h5py defaulted to 'a'); `[:]` copies the
            # datasets into memory before the context manager closes the file.
            with h5py.File(filepath, 'r') as data:
                image, label = data['image'][:], data['label'][:]

        sample = {'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        sample['case_name'] = case_name
        return sample

## Test

In [None]:
import sys
sys.path.append('..')
from utils import show_sbs
from torch.utils.data import DataLoader, Subset
from torchvision import transforms



# # ------------------- params --------------------
INPUT_SIZE = 256
# NOTE(review): these were left blank (a SyntaxError). The values below are believed
# to match TransUNet's train.py defaults (--root_path / --list_dir / --seed) —
# verify and adjust to your setup.
BASE_DIR = '../data/Synapse/train_npz'
LIST_DIR = './lists/lists_Synapse'
SEED = 1234  # base seed for the DataLoader workers (see worker_init_fn)

TR_BATCH_SIZE = 8
TR_DL_SHUFFLE = True
TR_DL_WORKER = 1

# VL_BATCH_SIZE = 12
# VL_DL_SHUFFLE = False
# VL_DL_WORKER = 1

# TE_BATCH_SIZE = 12
# TE_DL_SHUFFLE = False
# TE_DL_WORKER = 1
# # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<


# # ----------------- transform ------------------
transform = transforms.Compose([
    RandomGenerator(output_size=[INPUT_SIZE, INPUT_SIZE])
])
# # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<


# # ----------------- dataset --------------------
# preparing training dataset
tr_dataset = SynapseDataset(
    base_dir=BASE_DIR,
    list_dir=LIST_DIR,
    split="train",
    transform=transform,
)
print("The length of train set is: {}".format(len(tr_dataset)))


# # NOTE(review): the split below is left over from an ISIC2018 notebook; the
# # training set here is `tr_dataset`, not `train_dataset`.
# # We consider 1815 samples for training, 259 samples for validation and 520 samples for testing
# # !cat ~/deeplearning/skin/Prepare_ISIC2018.py

# indices = list(range(len(train_dataset)))

# # split indices to: -> train, validation, and test
# tr_indices = indices[0:1815]
# vl_indices = indices[1815:1815+259]
# te_indices = indices[1815+259:2594]

# # create new datasets from train dataset as training, validation, and test
# tr_dataset = Subset(train_dataset, tr_indices)
# vl_dataset = Subset(train_dataset, vl_indices)
# te_dataset = Subset(train_dataset, te_indices)

import random


def worker_init_fn(worker_id):
    """Give each DataLoader worker its own deterministic seed.

    Both RNGs used by the augmentations are seeded. RandomGenerator draws from
    Python's ``random``, while random_rot_flip / random_rotate draw from ``np.random``.
    Previously only ``random`` was seeded, from an undefined ``args.seed``.
    """
    random.seed(SEED + worker_id)
    np.random.seed(SEED + worker_id)


# # prepare train dataloader
trainloader = DataLoader(
    tr_dataset,
    batch_size=TR_BATCH_SIZE,
    shuffle=TR_DL_SHUFFLE,
    num_workers=TR_DL_WORKER,
    pin_memory=True,
    worker_init_fn=worker_init_fn,
)

# # prepare validation dataloader
# vl_loader = DataLoader(
#     vl_dataset,
#     batch_size=VL_BATCH_SIZE,
#     shuffle=VL_DL_SHUFFLE,
#     num_workers=VL_DL_WORKER,
#     pin_memory=True
# )

# # prepare test dataloader
# te_loader = DataLoader(
#     te_dataset,
#     batch_size=TE_BATCH_SIZE,
#     shuffle=TE_DL_SHUFFLE,
#     num_workers=TE_DL_WORKER,
#     pin_memory=True
# )

# -------------- test -----------------
# test and visualize the first training batch
for batch in trainloader:
    print("Training")
    img = batch['image']
    msk = batch['label']
    show_sbs(img[0], msk[0])
    break

# for img, msk in vl_loader:
#     print("Validation")
#     show_sbs(img[0], msk[0])
#     break