https://github.com/Beckschen/TransUNet/blob/main/datasets/dataset_synapse.py

In [1]:
import os
import random
import h5py
import numpy as np
import torch
from scipy import ndimage
from scipy.ndimage.interpolation import zoom
from torch.utils.data import Dataset

## Auxiliary functions 

In [2]:
def random_rot_flip(image, label):
    """Apply the same random quarter-turn rotation and flip to image and label.

    A rotation count in {0, 1, 2, 3} (multiples of 90 degrees) and a flip axis
    in {0, 1} are drawn from NumPy's global RNG, then applied identically to
    both arrays so they stay aligned.

    Args:
        image: array to augment (at least 2-D, as required by ``np.rot90``).
        label: array augmented with exactly the same transform.

    Returns:
        Tuple ``(image, label)`` of freshly copied arrays.
    """
    quarter_turns = np.random.randint(0, 4)
    flip_axis = np.random.randint(0, 2)

    def _rotate_then_flip(arr):
        # rot90/flip return views (possibly negatively strided); the copy
        # materialises them into a standalone array.
        return np.flip(np.rot90(arr, quarter_turns), axis=flip_axis).copy()

    return _rotate_then_flip(image), _rotate_then_flip(label)


def random_rotate(image, label):
    """Rotate image and label by one shared random angle.

    The angle is an integer number of degrees drawn from ``[-20, 20)`` using
    NumPy's global RNG. Both arrays are rotated with ``order=0`` and
    ``reshape=False``, so the output shape equals the input shape.

    Args:
        image: array to rotate.
        label: array rotated by the same angle.

    Returns:
        Tuple ``(image, label)`` of rotated arrays.
    """
    degrees = np.random.randint(-20, 20)
    return tuple(
        ndimage.rotate(arr, degrees, order=0, reshape=False)
        for arr in (image, label)
    )


class RandomGenerator(object):
    """Callable transform: random augmentation, resize, and tensor conversion.

    Args:
        output_size: two-element sequence ``(height, width)`` that every
            sample is resized to.
    """

    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, sample):
        """Augment and convert one ``{'image', 'label'}`` sample.

        Args:
            sample: dict holding 2-D arrays under ``'image'`` and ``'label'``.

        Returns:
            Dict with ``'image'`` as a float32 tensor with a leading axis of
            size 1 added, and ``'label'`` as an int64 tensor.
        """
        img = sample['image']
        mask = sample['label']

        # At most one augmentation is applied: rotate/flip when the first draw
        # exceeds 0.5; otherwise a second draw decides on a small-angle rotation.
        if random.random() > 0.5:
            img, mask = random_rot_flip(img, mask)
        elif random.random() > 0.5:
            img, mask = random_rotate(img, mask)

        height, width = img.shape
        target_h = self.output_size[0]
        target_w = self.output_size[1]
        if not (height == target_h and width == target_w):
            scale = (target_h / height, target_w / width)
            # Cubic spline for intensities; order 0 for the label so that no
            # new, interpolated class values are introduced.
            img = zoom(img, scale, order=3)
            mask = zoom(mask, scale, order=0)

        image_tensor = torch.from_numpy(img.astype(np.float32)).unsqueeze(0)
        label_tensor = torch.from_numpy(mask.astype(np.float32)).long()
        return {'image': image_tensor, 'label': label_tensor}

## Dataset

In [3]:
class SynapseDataset(Dataset):
    """Dataset that reads samples named in a ``<split>.txt`` list file.

    For ``split == "train"`` each list entry names a ``<name>.npz`` file in
    ``base_dir``; for any other split each entry names a ``<name>.npy.h5``
    file. Both kinds of file are read through their ``'image'`` and
    ``'label'`` keys.

    Args:
        base_dir: directory containing the data files.
        list_dir: directory containing ``<split>.txt`` (one sample name per line).
        split: split name; selects the list file and the on-disk format.
        transform: optional callable applied to the ``{'image', 'label'}`` dict.
    """

    def __init__(self, base_dir, list_dir, split, transform=None):
        self.transform = transform  # using transform in torch!
        self.split = split
        # Read the list inside a context manager so the file handle is closed
        # deterministically instead of being left to the garbage collector.
        with open(os.path.join(list_dir, self.split + '.txt')) as list_file:
            self.sample_list = list_file.readlines()
        self.data_dir = base_dir

    def __len__(self):
        """Return the number of entries in the split list."""
        return len(self.sample_list)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return it as a dict.

        Returns:
            Dict with ``'image'`` and ``'label'`` (after ``transform`` if one
            was given) plus ``'case_name'``, the list entry without its
            trailing newline.
        """
        case_name = self.sample_list[idx].strip('\n')

        if self.split == "train":
            data_path = os.path.join(self.data_dir, case_name + '.npz')
            # The archive is closed on exit; indexing has already produced
            # in-memory arrays by then.
            with np.load(data_path) as data:
                image, label = data['image'], data['label']
        else:
            filepath = os.path.join(self.data_dir, "{}.npy.h5".format(case_name))
            # Open read-only and copy the datasets out with [:] before the
            # file is closed, so no HDF5 handle outlives this call.
            with h5py.File(filepath, 'r') as data:
                image, label = data['image'][:], data['label'][:]

        sample = {'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        sample['case_name'] = case_name
        return sample

## Test

In [None]:
import sys
sys.path.append('..')
from utils import show_sbs
from torch.utils.data import DataLoader, Subset
from torchvision import transforms



# # ------------------- params --------------------
INPUT_SIZE = 256
# TODO: placeholder paths -- these were left blank in the original cell.
# Point them at the directory holding the data files and the directory holding
# the `<split>.txt` list files before running.
BASE_DIR = "./data/Synapse/train_npz"
LIST_DIR = "./lists/lists_Synapse"

# Base seed for DataLoader workers (stands in for the undefined `args.seed`).
SEED = 1234

# The training loader below needs these, so they must be live definitions.
TR_BATCH_SIZE = 8
TR_DL_SHUFFLE = True
TR_DL_WORKER = 1

# VL_BATCH_SIZE = 12
# VL_DL_SHUFFLE = False
# VL_DL_WORKER = 1

# TE_BATCH_SIZE = 12
# TE_DL_SHUFFLE = False
# TE_DL_WORKER = 1
# # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<


# # ----------------- transform ------------------
transform = transforms.Compose([
    RandomGenerator(output_size=[INPUT_SIZE, INPUT_SIZE])
])
# # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<


# # ----------------- dataset --------------------
# preparing training dataset
tr_dataset = SynapseDataset(
    base_dir=BASE_DIR,
    list_dir=LIST_DIR,
    split="train",
    transform=transform
)
print("The length of train set is: {}".format(len(tr_dataset)))


# # We consider 1815 samples for training, 259 samples for validation and 520 samples for testing
# # !cat ~/deeplearning/skin/Prepare_ISIC2018.py

# indices = list(range(len(train_dataset)))

# # split indices to: -> train, validation, and test
# tr_indices = indices[0:1815]
# vl_indices = indices[1815:1815+259]
# te_indices = indices[1815+259:2594]

# # create new datasets from train dataset as training, validation, and test
# tr_dataset = Subset(train_dataset, tr_indices)
# vl_dataset = Subset(train_dataset, vl_indices)
# te_dataset = Subset(train_dataset, te_indices)

import random


def worker_init_fn(worker_id):
    """Seed the RNGs of one DataLoader worker.

    The augmentations in this file draw from both `random` and `np.random`,
    so both are seeded with a per-worker value.
    """
    random.seed(SEED + worker_id)
    np.random.seed(SEED + worker_id)


# # prepare train dataloader
trainloader = DataLoader(
    tr_dataset,
    batch_size=TR_BATCH_SIZE,
    shuffle=TR_DL_SHUFFLE,
    num_workers=TR_DL_WORKER,
    pin_memory=True,
    worker_init_fn=worker_init_fn
)

# # prepare validation dataloader
# vl_loader = DataLoader(
#     vl_dataset,
#     batch_size=VL_BATCH_SIZE,
#     shuffle=VL_DL_SHUFFLE,
#     num_workers=VL_DL_WORKER,
#     pin_memory=True
# )

# # prepare test dataloader
# te_loader = DataLoader(
#     te_dataset,
#     batch_size=TE_BATCH_SIZE,
#     shuffle=TE_DL_SHUFFLE,
#     num_workers=TE_DL_WORKER,
#     pin_memory=True
# )

# -------------- test -----------------
# test and visualize the input data
# Iterate the loader defined above (the original loop used the undefined name
# `tr_loader`); only the first batch is shown.
for batch in trainloader:
    print("Training")
    img = batch['image']
    msk = batch['label']
    show_sbs(img[0], msk[0])
    break

# for img, msk in vl_loader:
#     print("Validation")
#     show_sbs(img[0], msk[0])
#     break