# In [None]:
import numpy as np
from torch.utils.data import Dataset
from pathlib import Path
import cv2
from PIL import Image
from torchvision.transforms import *

class VascularDataset(Dataset):
    """Paired image / segmentation-mask dataset for vascular segmentation.

    Item ``i`` is built from ``list_images_input[i]`` (loaded as RGB) and
    ``list_images_target[i]`` (loaded as 8-bit grayscale, then binarised so
    that every non-zero pixel becomes 1.0).

    Per-item processing order: optional ``crop`` -> resize to ``size`` ->
    optional ``transform`` -> mask binarisation -> ``ToTensor`` on the
    image -> optional ``Normalize``.

    Args:
        list_images_input: Sequence of paths to the input images.
        list_images_target: Sequence of paths to the target masks, aligned
            index-by-index with ``list_images_input``.
        mean_normalization: Per-channel mean for ``Normalize``. ``None``
            disables normalisation.
        std_normalization: Per-channel std for ``Normalize``.
        size: Passed to ``torchvision.transforms.Resize`` for both the image
            and the mask.
        crop: Optional callable invoked as ``crop(image=arr, mask=arr)`` on
            numpy arrays and returning a dict with ``'image'`` and ``'mask'``
            keys (albumentations-style call convention; presumably an
            albumentations transform -- confirm with callers).
        transform: Optional callable with the same call convention as
            ``crop``, applied after resizing.
        return_names: If true, ``__getitem__`` also returns the image and
            mask paths.
        return_classification_labels: Accepted for interface compatibility
            and stored on the instance; not used by ``__getitem__``.
    """

    def __init__(self, list_images_input, list_images_target, mean_normalization=None, std_normalization=None, size=224, crop=None,
                 transform=None, return_names=False, return_classification_labels=False):
        # Explicit import rather than relying on the wildcard import at the
        # top of the file.
        from torchvision.transforms import InterpolationMode

        super().__init__()
        self.images_input = list_images_input
        self.images_target = list_images_target
        self.mean_normalization = mean_normalization
        self.std_normalization = std_normalization
        self.totensor = ToTensor()
        self.normalize = Normalize(self.mean_normalization, self.std_normalization)
        self.crop = crop
        # Input images: torchvision's default (bilinear) interpolation.
        self.resize = Compose([Resize(size)])
        # Masks use nearest-neighbour interpolation. Bilinear resampling would
        # blend label values at object borders, and the ``> 0`` binarisation
        # in __getitem__ would then turn every blended pixel into foreground,
        # systematically dilating thin vessel masks.
        self.resize_mask = Compose([Resize(size, interpolation=InterpolationMode.NEAREST)])
        self.transform = transform
        self.return_names = return_names
        # Stored for introspection only; __getitem__ does not use it.
        self.return_classification_labels = return_classification_labels

    def __len__(self):
        """Return the number of input images."""
        return len(self.images_input)

    def __getitem__(self, index):
        """Load, preprocess and return the sample at ``index``.

        Returns:
            ``(img, seg)``, or ``(img, img_path, seg, seg_path)`` when
            ``return_names`` is set. ``img`` is the output of ``ToTensor``
            (normalised if a mean was given). ``seg`` is a float numpy array
            with values in ``{0.0, 1.0}`` and a leading channel axis added by
            ``np.expand_dims`` (shape ``(1, H, W)`` for a 2-D mask).
            ``img_path`` is the image path as ``str``, and ``seg_path`` is the
            mask path exactly as stored in ``list_images_target``.
        """
        img_path = str(self.images_input[index])
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('RGB')

        seg_path = self.images_target[index]
        with Image.open(seg_path) as raw_seg:
            # RGB first drops any alpha (transparent) channel; 'L' then gives
            # a single-channel 8-bit grayscale mask.
            seg = raw_seg.convert('RGB').convert('L')

        # Optional joint crop, applied to image and mask together on arrays.
        if self.crop:
            cropped = self.crop(image=np.array(img), mask=np.array(seg))
            img = cropped['image']
            seg = cropped['mask']

        # Resize operates on PIL images, so convert back if cropping
        # produced numpy arrays.
        if isinstance(img, np.ndarray):
            img = Image.fromarray(img)
            seg = Image.fromarray(seg)

        img = self.resize(img)
        seg = self.resize_mask(seg)

        if self.transform:
            transformed = self.transform(image=np.array(img), mask=np.array(seg))
            img = transformed['image']
            seg = transformed['mask']

        # Binarise (any non-zero grey level is foreground) and add a channel axis.
        seg = (np.array(seg) > 0).astype(float)
        seg = np.expand_dims(seg, axis=0)
        img = self.totensor(img)

        # ``is not None`` rather than truthiness: a numpy-array mean would
        # raise on bool(), and a scalar mean of 0 would wrongly skip
        # normalisation.
        if self.mean_normalization is not None:
            img = self.normalize(img)

        if self.return_names:
            return img, img_path, seg, seg_path
        return img, seg
