In [4]:
#!/usr/bin/env python
# coding: utf-8

# In[2]:


import numpy as np
import pandas as pd
import os
%run transforms.ipynb

import  PIL
from PIL import Image
#import gdcm
import tensorflow as tf
#import pydicom as dicom
from pathlib import Path

#from keras.preprocessing.image import ImageDataGenerator

from skimage.transform import resize
import tensorflow.keras
import matplotlib.pyplot as plt
from scipy.ndimage import zoom
from sklearn.utils import class_weight
import scipy
import skimage
import time
from time import perf_counter
import datetime
import logging
import random
from tensorflow.keras.callbacks import ModelCheckpoint
import sys
import datetime
import imageio

# In[3]:


### data generator


class DataGenerator(tensorflow.keras.utils.Sequence):
    """Keras ``Sequence`` yielding (image, binary mask) batches read from disk.

    Sample ``k`` is the pair ``(list_IDs_inputs[k], list_IDs_masks[k])``.
    Images are min-max normalized to [0, 1]; masks are thresholded at 0.5
    to {0, 1}. When ``augmentation`` > 0, each pair is optionally passed
    through ``randaugm`` (presumably provided by ``%run transforms.ipynb``)
    with one seed shared by the image and its mask.

    Args:
        list_IDs_inputs: paths of the input images, read with ``imageio.imread``.
        list_IDs_masks: paths of the masks, aligned index-for-index with
            ``list_IDs_inputs``.
        epochs: accepted for API compatibility; not used by this class.
        batch_size: number of samples per batch.
        dim: shape of one input sample; image batches are
            ``(batch, *dim)`` and mask batches ``(batch, dim[0], dim[1], 1)``.
        n_channels: stored as an attribute; not used by this class.
        shuffle: if true, the sample order is reshuffled at every epoch end.
        num_output_nodes: stored as an attribute; not used by this class.
        augmentation: chance (0 to 1) that a sample pair is augmented;
            0 (or less) disables augmentation.
        crop_pad: stored as an attribute; not used by this class.
    """

    def __init__(self, list_IDs_inputs, list_IDs_masks, epochs, batch_size, dim, n_channels,shuffle, num_output_nodes,augmentation,crop_pad):
        self.dim = dim
        self.batch_size = batch_size
        self.list_IDs_inputs = list_IDs_inputs
        self.list_IDs_masks = list_IDs_masks
        self.n_channels = n_channels
        self.shuffle = shuffle
        self.num_output_nodes = num_output_nodes
        self.augmentation = augmentation
        self.crop_pad = crop_pad
        # Builds (and optionally shuffles) self.indexes for the first epoch.
        self.on_epoch_end()
        # Cursor state for manual iteration via next(generator).
        self.n = 0
        self.max = self.__len__()

    def __next__(self):
        """Return the next batch, wrapping back to batch 0 after the last one."""
        if self.n >= self.max:
            self.n = 0
        result = self.__getitem__(self.n)
        self.n += 1
        return result

    def __len__(self):
        """Return the number of full batches per epoch (a partial tail is dropped)."""
        return len(self.list_IDs_inputs) // self.batch_size

    def __getitem__(self, index):
        """Return batch ``index`` as ``(X, Y)`` using the current epoch ordering."""
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]

        list_IDs_inputs = [self.list_IDs_inputs[k] for k in indexes]
        list_IDs_masks = [self.list_IDs_masks[k] for k in indexes]

        return self.__data_generation(list_IDs_inputs, list_IDs_masks)

    def on_epoch_end(self):
        """Reset the sample ordering, shuffling it in place when ``shuffle`` is set."""
        self.indexes = np.arange(len(self.list_IDs_inputs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def standardize(self, pixel_array):
        """Return ``pixel_array`` shifted to zero mean and scaled to unit std.

        A constant array (std == 0) is returned as all zeros instead of the
        NaNs a plain 0/0 division would produce.
        """
        mean = np.mean(pixel_array)
        std = np.std(pixel_array)
        centered = pixel_array - mean
        if std == 0:
            return centered
        return centered / std

    def normalize(self, pixel_array):
        """Return ``pixel_array`` min-max rescaled to [0, 1].

        A constant array (max == min) is returned as all zeros instead of the
        NaNs a plain 0/0 division would produce.
        """
        low = pixel_array.min()
        value_range = pixel_array.max() - low
        if value_range == 0:
            return np.zeros(np.shape(pixel_array))
        return (pixel_array - low) / value_range

    def extract_and_augment(self, image_path, random_seed_aug, case):
        """Load one image or mask from disk and optionally augment it.

        Args:
            image_path: file path passed to ``imageio.imread``.
            random_seed_aug: seed shared by an image and its mask, or None
                when augmentation is disabled.
            case: ``"image"`` or ``"mask"``; forwarded to ``randaugm``, and a
                mask additionally gets a new axis inserted at position 2.

        Returns:
            The (possibly augmented) pixel array.
        """
        pixel_array = imageio.imread(image_path)

        # The augment/skip decision comes from an RNG seeded with the shared
        # seed, so an image and its mask always get the same decision. Drawing
        # it from the global ``random`` module let the two calls disagree,
        # pairing augmented images with un-augmented masks (or vice versa).
        roll = random.Random(random_seed_aug).randrange(0, 100, 1)
        if roll < self.augmentation * 100:
            pixel_array = randaugm(pixel_array, random_seed_aug, case)

        if case == "mask":
            pixel_array = np.expand_dims(pixel_array, axis=2)

        return pixel_array

    def image_processing(self, ID, standardize, normalize, random_seed_aug, case):
        """Load/augment ``ID`` and apply the requested intensity scaling.

        Args:
            ID: file path of the image or mask.
            standardize: if truthy, apply zero-mean/unit-std scaling.
            normalize: if truthy, apply min-max scaling to [0, 1] (after
                standardization when both are requested).
            random_seed_aug: seed forwarded to ``extract_and_augment``.
            case: ``"image"`` or ``"mask"``, forwarded to ``extract_and_augment``.
        """
        array = self.extract_and_augment(ID, random_seed_aug, case)
        if standardize:
            array = self.standardize(array)
        if normalize:
            array = self.normalize(array)
        return array

    def __data_generation(self, list_IDs_inputs, list_IDs_masks):
        """Build one batch from aligned lists of image and mask paths.

        Returns:
            ``X`` of shape ``(n, *dim)`` holding normalized images, and ``Y``
            of shape ``(n, dim[0], dim[1], 1)`` holding integer {0, 1}
            masks, where ``n`` is the number of IDs given.
        """
        # Size by the IDs actually received (== batch_size for valid batch
        # indexes) so a short list never returns uninitialized rows.
        n_samples = len(list_IDs_inputs)
        X = np.empty((n_samples, *self.dim))
        Y = np.empty((n_samples, *self.dim[:2], 1))

        for i, (input_path, mask_path) in enumerate(zip(list_IDs_inputs, list_IDs_masks)):
            # One seed per pair so the image and its mask receive the same
            # augmentation; None when augmentation is disabled.
            if self.augmentation > 0:
                random_seed_aug = random.randint(1, 1000000000)
            else:
                random_seed_aug = None

            X[i] = self.image_processing(input_path, False, True, random_seed_aug, "image")
            Y[i] = self.image_processing(mask_path, False, False, random_seed_aug, "mask")

        # Binarize once, after every mask has been written. Doing this inside
        # the loop turned Y into an int array after the first sample, which
        # truncated later fractional mask values instead of thresholding them.
        Y = np.where(Y > 0.5, 1, 0)
        return X, Y