In [1]:
import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import numpy as np
import os
import nibabel as nib
from PIL import Image
import torchvision.transforms as transforms

In [12]:
class RIGADataset(Dataset):

    def __init__(self, dataset_location='/vol/biodata/data/RIGA/MESSIDOR/', transform=None):
        self.transform = transform
        self.dataset_location = dataset_location
        self.data = self.load_data()

    def load_data(self):

        data = {}

        for file in os.listdir(self.dataset_location)[:500]:
            filename = os.fsdecode(file)

            # isolate image number for dictionary storing
            if '-' in filename:
                image_name = filename.split('-')[0]
            elif 'prime' in filename:
                image_name = filename.split('prime')[0]
            else:
                pass

            # load image and convert to tensor
            img = Image.open(self.dataset_location + filename)
            img_tensor = transforms.ToTensor()(img)

            if self.transform:
                img_tensor = self.transform(img_tensor)

            if image_name in data.keys():
                if 'prime' not in filename:
                    data[image_name]['masks'].append(img_tensor)
                else:
                    data[image_name]['image'] = img_tensor
            else:
                data.update({image_name: {'image': [], 'masks': []}})


        # ensure that each image has 6 masks
        for key in data.keys():
            assert len(data[key]['masks']) == 6

        del img
        del img_tensor

        return data

    def __len__(self):
        return len(self.data.keys())
    
    def __getitem__(self, idx):
        key = list(self.data.keys())[idx]
        image = self.data[key]['image']
        masks = self.data[key]['masks']
        if self.transform:
            image = self.transform(image)
        # if self.target_transform:
        #     masks = self.target_transform(masks)
        return image, masks

In [13]:
data = RIGADataset()

KeyboardInterrupt: 

In [None]:
dataloader = DataLoader(data, batch_size=16, shuffle=True)

In [None]:
train_features, train_labels = next(iter(dataloader))

{'image': [], 'masks': [tensor([[[0.0000, 0.0000, 0.0000,  ..., 0.0078, 0.0078, 0.0078],
         [0.0000, 0.0000, 0.0000,  ..., 0.0039, 0.0039, 0.0039],
         [0.0000, 0.0000, 0.0000,  ..., 0.0039, 0.0039, 0.0039],
         ...,
         [0.0000, 0.0000, 0.0039,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0039, 0.0039,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0039, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0078, 0.0078],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0039, 0.0039],
         [0.0000, 0.0000, 0.0000,  ..., 0.0039, 0.0039, 0.0039],
         ...,
         [0.0039, 0.0078, 0.0118,  ..., 0.0078, 0.0078, 0.0039],
         [0.0078, 0.0118, 0.0118,  ..., 0.0078, 0.0039, 0.0039],
         [0.0039, 0.0078, 0.0078,  ..., 0.0118, 0.0078, 0.0039]],

        [[0.0000, 0.0000, 0.0000,  ..., 0.0039, 0.0078, 0.0078],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0039, 0.0039],
         [0.0000, 0.0000, 0.0000, 

RuntimeError: each element in list of batch should be of equal size

In [None]:
train_labels

NameError: name 'train_labels' is not defined