In [None]:
import numpy as np
from torch.utils.data import Dataset
from pathlib import Path
import cv2
from PIL import Image
from torchvision.transforms import *

# Integer class labels assigned by VascularClassificationDataset according to
# which substring ('ccRCC' or 'pRCC') appears in an image path.
# NOTE(review): presumably clear-cell vs. papillary renal cell carcinoma.
ccRCC, pRCC = 0, 1

class VascularClassificationDataset(Dataset):
    """Grayscale image-classification dataset for the vascular images.

    Each item is read from disk, converted to single-channel ``'L'`` mode,
    optionally cropped, resized, optionally transformed, converted to a
    tensor and optionally normalized.

    The class label is inferred from the image path: a path containing
    ``'ccRCC'`` gets label ``ccRCC``, otherwise a path containing ``'pRCC'``
    gets label ``pRCC``. ``'ccRCC'`` is tested first, so a path containing
    both substrings (e.g. through a parent directory name) is labelled
    ``ccRCC``.

    Args:
        list_images_input: Indexable sequence of image paths. Each entry is
            passed through ``str()``, so ``pathlib.Path`` objects work.
        mean_normalization: Mean passed to ``Normalize``. Normalization is
            applied only when this is not ``None``.
        std_normalization: Standard deviation passed to ``Normalize``.
        size: Size argument passed to ``Resize``.
        crop: Optional callable invoked as ``crop(image=<ndarray>)`` that
            returns a mapping with an ``'image'`` entry. Applied before
            resizing. NOTE(review): this is the albumentations calling
            convention; confirm against callers.
        transform: Optional callable with the same calling convention as
            ``crop``. Applied after resizing.
        return_names: If true, ``__getitem__`` returns
            ``(img, img_path, label)`` instead of ``(img, label)``.
    """

    def __init__(self, list_images_input, mean_normalization=None, std_normalization=None, size=224, crop=None,
                 transform=None, return_names=False):
        super().__init__()
        self.images_input = list_images_input
        self.mean_normalization = mean_normalization
        self.std_normalization = std_normalization
        self.totensor = ToTensor()
        self.normalize = Normalize(self.mean_normalization, self.std_normalization)
        self.crop = crop
        self.resize = Compose([Resize(size)])
        self.transform = transform
        self.return_names = return_names

    def __len__(self):
        return len(self.images_input)

    @staticmethod
    def _label_from_path(img_path):
        """Return the class label encoded in ``img_path``.

        Raises:
            ValueError: If the path contains neither ``'ccRCC'`` nor
                ``'pRCC'``.
        """
        # Order matters: 'ccRCC' is checked before 'pRCC'.
        if 'ccRCC' in img_path:
            return ccRCC
        if 'pRCC' in img_path:
            return pRCC
        raise ValueError(
            f"Cannot infer class label from path {img_path!r}: "
            "expected it to contain 'ccRCC' or 'pRCC'."
        )

    def __getitem__(self, index):
        img_path = str(self.images_input[index])

        # Resolve the label before any image I/O. A path without a class
        # marker then fails fast with a clear error, instead of leaving
        # `label` unbound.
        label = self._label_from_path(img_path)

        # convert() returns an independent image, so the source file can be
        # closed deterministically.
        with Image.open(img_path) as src:
            img = src.convert('L')

        if self.crop:
            img = self.crop(image=np.array(img))['image']
            # The crop callable may return a NumPy array. torchvision's Resize
            # only accepts PIL images or tensors, so convert back first.
            if isinstance(img, np.ndarray):
                img = Image.fromarray(img)

        img = self.resize(img)

        if self.transform:
            img = self.transform(image=np.array(img))['image']

        # ToTensor accepts both PIL images and NumPy arrays.
        img = self.totensor(img)

        # Compare against None: a scalar mean of 0 must still normalize, and a
        # NumPy/tensor mean has no unambiguous truth value.
        if self.mean_normalization is not None:
            img = self.normalize(img)

        if self.return_names:
            return img, img_path, label
        return img, label
