<a href="https://colab.research.google.com/github/Schiweppes/My-Deep-Learning-Notebooks/blob/main/Github_Histopathology_Segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%pip install imagecodecs
%pip install patchify

In [None]:
import glob
import os
from typing import Iterator
from typing import Tuple

import imagecodecs
import imageio
import matplotlib.pyplot as plt
import numpy as np
from patchify import patchify,unpatchify
from PIL import Image
from sklearn.model_selection import train_test_split as split

from keras.preprocessing.image import ImageDataGenerator
from keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, \
concatenate, BatchNormalization, Activation, add
from keras.layers.advanced_activations import ELU, LeakyReLU
import keras.backend as K
import tensorflow as tf
import tensorflow.keras as keras

# Functions

In [None]:
def read_data(path:str = None)->Tuple[np.ndarray,np.ndarray,int]:
    r"""
    Return the first image and first label of a dataset as numpy arrays.

    Keyword arguments:
        path (str) : Absolute path to the dataset folder. When omitted, a
            hard-coded sample image/label pair on Google Drive is read.

            ..\Dataset Folder
                |--\Images\
                    |--image1.png
                    ...
                |--\Labels\
                    |--image1_label.png
                    ...

    Returns:
        image (np.ndarray), label (np.ndarray), size (int) : the first image
        and first label in sorted filename order, and the number of images.

    Raises:
        AssertionError: if Images/ and Labels/ hold a different number of files.
        ValueError: if the folders are empty.
    """
    if not path:
        image = np.array(imageio.imread(
            '/content/drive/MyDrive/histopatoloji/lab_crop/ds2_cropped_1.png'))
        label_image = np.array(imageio.imread(
            '/content/drive/MyDrive/histopatoloji/lab_crop/ds2_cropped_1-labels.png'))
        return image, label_image, 1

    image_path = os.path.join(path, "Images")
    label_path = os.path.join(path, "Labels")

    # Sort both listings so the i-th image pairs with the i-th label.
    image_names = sorted(os.listdir(image_path))
    label_names = sorted(os.listdir(label_path))

    size = len(image_names)
    assert size == len(label_names), "image and label mismatch!"
    if size == 0:
        # Fail with a clear message instead of an IndexError on [0] below.
        raise ValueError(f"No images found in {image_path}")

    img = np.array(imageio.imread(os.path.join(image_path, image_names[0])))
    label = np.array(imageio.imread(os.path.join(label_path, label_names[0])))

    return img, label, size

In [None]:
def rgb2gray(rgb:np.ndarray)->np.ndarray:
    """
    Return a grayscale version of an image.

    Colour images are reduced with the weighted sum
    0.2989 * R + 0.5870 * G + 0.1140 * B; channels beyond the third (e.g.
    alpha) are ignored. Inputs that are already single-channel are passed
    through as floats.

    Keyword arguments:
        rgb (np.ndarray) : image of shape (H, W), (H, W, C<3) or (H, W, C>=3)

    Returns:
        gray (np.ndarray) : float grayscale image of shape (H, W)
    """
    rgb = np.asarray(rgb)
    if rgb.ndim == 2:
        # Already a 2-D grayscale image (common for label masks).
        return rgb.astype(float)
    if rgb.shape[-1] < 3:
        # Gray (+ alpha) image such as an "LA" PNG: keep the gray channel.
        return rgb[:, :, 0].astype(float)

    r, g, b = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2]
    return 0.2989 * r + 0.5870 * g + 0.1140 * b

In [None]:
def patch_data(data:np.ndarray,patch_size:Tuple[int],step_size:int = 256)-> Tuple[tuple,np.ndarray]:
    """
    Split ``data`` into non-overlapping patches.

    Keyword Arguments:
        data (np.ndarray) : image of shape (H, W, C) or mask of shape (H, W).
        patch_size (tuple) : patch shape with the same rank as ``data``,
            e.g. (256, 256, 3) for an RGB image or (256, 256) for a mask.
        step_size (int) : stride between patches; must equal patch_size[0]
            so that patches do not overlap.

    Returns:
        (grid_shape, patches):
            grid_shape (tuple) : shape of the patch grid from ``patchify``
                (singleton channel-grid axis removed for 3-D input), e.g.
                (nH, nW, 256, 256, 3) or (nH, nW, 256, 256); pass it to
                ``reconstruct_patch``.
            patches (np.ndarray) : patches stacked in row-major grid order,
                shape (n_patches, *patch_size).

    Raises:
        ValueError: if a 3-D image is given a 2-D patch size.
        AssertionError: if step_size != patch_size[0].
    """
    patch_size = tuple(patch_size)
    if data.ndim == 3 and len(patch_size) == 2:
        raise ValueError("Image should be patched by 256x256x3")

    assert step_size == patch_size[0], "Differency in step_size and patch_size causes overlap!"

    grid = patchify(data, patch_size, step_size)

    if data.ndim == 3 and grid.shape[2] == 1:
        # For 3-D input patchify adds a grid axis over the channels; with a
        # single channel step it is a singleton and can be dropped.
        grid = np.squeeze(grid, axis=2)
    # Use the requested patch shape instead of a hard-coded 256.
    return grid.shape, grid.reshape((-1,) + patch_size)




In [None]:
def graphify(idxs,fig_size = 15, images = None, labels = None):
    """
    Plot selected image patches next to their label patches, one row each.

    Keyword Arguments:
    idxs (List[int]) : indexes of the patches to plot
    fig_size (int) : width and height of the whole figure
    images (np.ndarray, optional) : image patches indexed as images[i, :, :, :];
        defaults to the module-level ``image_patches``
    labels (np.ndarray, optional) : label patches indexed as labels[i, :, :];
        defaults to the module-level ``label_patches``
    """
    if len(idxs) == 0:
        # plt.subplots cannot create a figure with zero rows.
        return
    if images is None:
        images = globals()["image_patches"]
    if labels is None:
        labels = globals()["label_patches"]

    # squeeze=False keeps ``ax`` 2-D even when only one row is plotted.
    fig, ax = plt.subplots(len(idxs), 2, figsize=(fig_size, fig_size), squeeze=False)
    fig.tight_layout()
    ax[0][0].set_title('Original Image')
    ax[0][1].set_title('Masked Image')
    for row, i in enumerate(idxs):
        ax[row, 0].imshow(images[i, :, :, :], cmap='gray')
        ax[row, 1].imshow(labels[i, :, :], cmap='gray')

In [None]:
def get_masks(label_patches:np.ndarray,mask_mean:float = 0.5)->list:
    """
    Find the patches whose labels contain enough foreground.

    Keyword Arguments:
    label_patches (np.ndarray) : label patches stacked along the first axis
    mask_mean (float) : a patch is selected when its mean value is strictly
        greater than this threshold

    Returns:
    List[int] : indexes of the selected patches, in ascending order
    """
    return [position
            for position, patch in enumerate(label_patches)
            if patch.mean() > mask_mean]

In [None]:
def yield_data(path:str = None)->Iterator[Tuple[np.ndarray,np.ndarray]]:
    r"""
    Lazily yield (image, label) pairs from a dataset folder.

    Keyword arguments:
        path (str) : Absolute path to the dataset folder. When omitted, the
            single hard-coded sample pair on Google Drive is yielded.

            ..\Dataset Folder
                |--\Images\
                    |--image1.png
                    ...
                |--\Labels\
                    |--image1_label.tif
                    ...

    Yields:
        (image (np.ndarray), label (np.ndarray)) : files are paired by their
        position in sorted filename order.

    Raises:
        AssertionError: on the first ``next()`` if Images/ and Labels/ hold a
            different number of files.
    """
    if not path:
        # This is a generator, so ``return value`` would yield nothing;
        # yield the sample pair instead so iterating yield_data() produces it.
        image = np.array(imageio.imread(
            '/content/drive/MyDrive/histopatoloji/lab_crop/ds2_cropped_1.png'))
        label_image = np.array(imageio.imread(
            '/content/drive/MyDrive/histopatoloji/lab_crop/ds2_cropped_1-labels.png'))
        yield (image, label_image)
        return

    image_path = os.path.join(path, "Images")
    label_path = os.path.join(path, "Labels")

    image_names = sorted(os.listdir(image_path))
    label_names = sorted(os.listdir(label_path))
    # zip() would silently drop unmatched files; fail loudly like read_data.
    assert len(image_names) == len(label_names), "image and label mismatch!"

    for i, j in zip(image_names, label_names):
        print("*************\n",i,j)
        img = np.array(imageio.imread(os.path.join(image_path, i)))
        label = np.array(imageio.imread(os.path.join(label_path, j)))
        yield (img, label)

In [None]:
def preprocess_augment(x_patch,y_patch,augment:bool = False):
    """
    Split patches into train/validation sets and wrap them in Keras generators.

    Images are rescaled by 1/255 and labels are passed through unscaled. When
    ``augment`` is True, each training image and its label receive the same
    random geometric transform. Validation data is never augmented.

    Reads the module-level settings ``test_split_size``, ``BATCH_SIZE``,
    ``SEED`` and, when augmenting, ``rotation_range``, ``horizontal_flip``,
    ``vertical_flip``, ``width_shift_range`` and ``height_shift_range``.

    Keyword Arguments:
    x_patch (np.ndarray) : image patches, (N, H, W, C)
    y_patch (np.ndarray) : label patches, (N, H, W) or (N, H, W, 1)
    augment (bool) : apply random augmentation to the training set

    Returns:
    train_dataset, valid_dataset, train_step, val_step : batch iterators for
    training and validation, and the batches per epoch for each (at least 1).
    """
    x_train, x_test, y_train, y_test = split(x_patch,
                                             y_patch,
                                             test_size = test_split_size,
                                             random_state = SEED)

    # Rescale-only generator; flow() applies ``rescale`` to x but not to y.
    plain_datagen = ImageDataGenerator(rescale=1./255)

    if augment:
        aug_args = dict(rotation_range = rotation_range,
                        horizontal_flip = horizontal_flip,
                        vertical_flip = vertical_flip,
                        width_shift_range = width_shift_range,
                        height_shift_range = height_shift_range)
        # flow(x, y) transforms only x, which would misalign images and masks.
        # Use two flows with the same seed and batch size instead, so both
        # draw the same shuffle and random transforms (the images-and-masks
        # pattern from the Keras ImageDataGenerator documentation).
        # NOTE(review): masks are interpolated by the transform, so edges may
        # get intermediate values.
        image_flow = ImageDataGenerator(rescale=1./255, **aug_args).flow(
            x_train, batch_size = BATCH_SIZE, seed = SEED)
        # flow() requires rank-4 input, so give 3-D masks a channel axis.
        masks = y_train if y_train.ndim == 4 else y_train[..., np.newaxis]
        mask_flow = ImageDataGenerator(**aug_args).flow(
            masks, batch_size = BATCH_SIZE, seed = SEED)
        train_dataset_patch = zip(image_flow, mask_flow)
    else:
        train_dataset_patch = plain_datagen.flow(x = x_train,
                                                 y = y_train,
                                                 batch_size = BATCH_SIZE,
                                                 seed = SEED)

    valid_dataset_patch = plain_datagen.flow(x = x_test,
                                             y = y_test,
                                             batch_size = BATCH_SIZE,
                                             seed = SEED)
    # Never report 0 steps when a split is smaller than one batch.
    train_step = max(1, len(x_train) // BATCH_SIZE)
    val_step = max(1, len(x_test) // BATCH_SIZE)
    return train_dataset_patch, valid_dataset_patch, train_step, val_step

In [None]:
def main(dataset_folder:str, epochs:int = 2):
    """
    Train MultiResUnet on every image/label pair of a dataset folder.

    For each pair yielded by ``yield_data`` the label is converted to
    grayscale, and both are cut into 256x256 patches. The patches are split
    into train/validation generators by ``preprocess_augment`` and used to
    train the model saved at ``checkpoint_path``. A fresh ``MultiResUnet`` is
    built if that folder is missing or empty. Saving is left to the
    ``callbacks`` list (which contains the ModelCheckpoint).

    Reads the module-level settings ``checkpoint_path``, ``learning_rate``,
    ``loss_function``, ``BATCH_SIZE`` and ``callbacks``.

    Keyword arguments:
        dataset_folder (str) : main path to the dataset folder
        epochs (int) : number of epochs to train on each image (default 2)

    Returns:
        None
    """
    for idx, (img, label) in enumerate(yield_data(path=dataset_folder)):
        label = rgb2gray(label)

        _, image_patches = patch_data(img, (256, 256, 3), 256)
        _, label_patches = patch_data(label, (256, 256), 256)

        # NOTE(review): values >= 50 become 255 but smaller ones are kept
        # as-is, and labels are never rescaled to [0, 1] (the generators only
        # rescale images) - confirm this matches the sigmoid output/loss.
        label_patches = np.where(label_patches<50,label_patches,255.)
        del img, label
        print(f"patch size : {image_patches.shape[0]}")

        assert label_patches.shape[:] == image_patches.shape[:-1]

        train_dataset, valid_dataset, train_step, val_step = preprocess_augment(
            image_patches, label_patches, augment = False)
        del image_patches, label_patches

        # Name -> object map used to deserialize a saved model and to pick
        # the loss selected by ``loss_function``.
        custom_objects = {
                  'FocalTverskyLoss':FocalTverskyLoss,
                  'log_cosh_dice_loss':log_cosh_dice_loss,
                  'sparse_categorical_crossentropy': keras.metrics.SparseCategoricalCrossentropy(from_logits = True),
                  'dice_loss':dice_loss,
                  'jaccard':jaccard,
                  'dice_coef':dice_coef,
                  'sensitivity':sensitivity,
                  'specificity':specificity}

        print(f"Training {idx+1}th image:")
        # The checkpoint folder does not exist before the first save, so
        # check for it before listing it.
        if os.path.isdir(checkpoint_path) and os.listdir(checkpoint_path):
            print("Loading trained model ...")
            model = keras.models.load_model(checkpoint_path,
                                            custom_objects=custom_objects)
            compiled_msg = "Model is loaded and compiled!"
        else:
            print("Creating new model...")
            model = MultiResUnet(256, 256, 3)
            compiled_msg = "Model is created and compiled!"

        # Same metrics for loaded and new models, so callbacks monitoring
        # val_sensitivity / val_specificity work in both cases.
        model.compile(optimizer = keras.optimizers.Adam(
                      learning_rate = learning_rate),
                      loss = custom_objects[loss_function],
                      metrics = [jaccard, dice_coef, sensitivity, specificity],
                      )
        print(compiled_msg)

        model.fit(train_dataset,
                  batch_size = BATCH_SIZE,
                  epochs = epochs,
                  callbacks = callbacks,
                  workers = -1,
                  validation_data = valid_dataset,
                  validation_steps = val_step,
                  steps_per_epoch = train_step)
        tf.keras.backend.clear_session()

In [None]:
def reconstruct_patch(patch:np.ndarray,patch_shape):
    """
    Reassemble non-overlapping patches into the full image.

    Keyword Arguments:
    patch (np.ndarray) : patches stacked along the first axis in the row-major
        grid order produced by ``patch_data``
    patch_shape (tuple) : grid shape returned by ``patch_data``:
        (nH, nW, ph, pw, C) for multi-channel images or (nH, nW, ph, pw)
        for single-channel ones

    Returns:
    Reconstructed array of shape (nH*ph, nW*pw, C) or (nH*ph, nW*pw), with the
    dtype of ``patch``.
    """
    shape = tuple(patch_shape)
    n_rows, n_cols, patch_h, patch_w = shape[:4]
    trailing = shape[4:]  # (C,) for multi-channel input, () otherwise

    grid = np.asarray(patch).reshape(shape)
    # (nH, nW, ph, pw, ...) -> (nH, ph, nW, pw, ...): each output pixel row
    # becomes contiguous, so merging axes yields the tiled image.
    grid = np.swapaxes(grid, 1, 2)
    return grid.reshape((n_rows * patch_h, n_cols * patch_w) + trailing)

In [None]:
def graphify_prediction(img_patch:np.ndarray, pred_patch:np.ndarray,idx_list:list,fig_size:int = 15):
    """
    Plot selected image patches next to the model's predictions, one row each.

    Keyword Arguments:
    img_patch (np.ndarray) : image patches indexed as img_patch[i, :, :, :]
    pred_patch (np.ndarray) : prediction patches, (N, H, W) or (N, H, W, 1);
        a trailing singleton channel axis is squeezed away
    idx_list (list) : indexes of the patches to plot
    fig_size (int) : width and height of the whole figure
    """
    if pred_patch.shape[-1] == 1:
        pred_patch = np.squeeze(pred_patch, axis = -1)

    if len(idx_list) == 0:
        # plt.subplots cannot create a figure with zero rows.
        return

    # squeeze=False keeps ``ax`` 2-D even when only one row is plotted.
    fig, ax = plt.subplots(len(idx_list), 2, figsize=(fig_size, fig_size), squeeze=False)
    fig.tight_layout()
    ax[0][0].set_title('Image')
    ax[0][1].set_title('Prediction')
    for row, i in enumerate(idx_list):
        ax[row, 0].imshow(img_patch[i, :, :, :], cmap='gray')
        ax[row, 1].imshow(pred_patch[i, :, :], cmap='gray')

# Model

In [None]:
def conv2d_bn(x, filters, num_row, num_col, padding='same', strides=(1, 1), activation='relu'):
    """
    Apply Conv2D (no bias) -> BatchNormalization (no scale) -> Activation.

    Keyword Arguments:
    x (tensor) : input Keras tensor
    filters (int) : number of convolution filters
    num_row, num_col (int) : kernel height and width
    padding (str) : Conv2D padding mode
    strides (tuple) : Conv2D strides
    activation (str) : activation applied after normalization

    Returns:
    Output Keras tensor.
    """
    # The bias is omitted because the following BatchNormalization adds its
    # own learned offset.
    convolve = Conv2D(filters, (num_row, num_col), strides=strides,
                      padding=padding, use_bias=False)
    normalize = BatchNormalization(axis=3, scale=False)
    activate = Activation(activation)
    return activate(normalize(convolve(x)))

In [None]:
def trans_conv2d_bn(x, filters, num_row, num_col, padding='same', strides=(2, 2)):
    """
    Apply Conv2DTranspose followed by BatchNormalization (no scale).

    No activation is applied. With the default strides of (2, 2) the layer
    upsamples its input.

    Keyword Arguments:
    x (tensor) : input Keras tensor
    filters (int) : number of filters
    num_row, num_col (int) : kernel height and width
    padding (str) : Conv2DTranspose padding mode
    strides (tuple) : Conv2DTranspose strides

    Returns:
    Output Keras tensor.
    """
    upsample = Conv2DTranspose(filters, (num_row, num_col),
                               strides=strides, padding=padding)
    normalize = BatchNormalization(axis=3, scale=False)
    return normalize(upsample(x))

In [None]:
def MultiResBlock(U, inp, alpha = 1.67):
    """
    MultiRes block: three chained 3x3 conv2d_bn layers whose outputs are
    concatenated, then summed with a 1x1 conv2d_bn projection of the input.

    Keyword Arguments:
    U (int) : base number of filters; the block works with W = alpha * U
    inp (tensor) : input Keras tensor
    alpha (float) : width multiplier applied to U

    Returns:
    Keras tensor with int(W*0.167) + int(W*0.333) + int(W*0.5) channels.
    """
    W = alpha * U

    shortcut = inp

    # 1x1 projection so the shortcut has the same channel count as the
    # concatenation below and the two can be added.
    shortcut = conv2d_bn(shortcut, int(W*0.167) + int(W*0.333) +
                         int(W*0.5), 1, 1, activation='relu', padding='same')

    conv3x3 = conv2d_bn(inp, int(W*0.167), 3, 3,
                        activation='relu', padding='same')

    # Despite their names, conv5x5 and conv7x7 are 3x3 convolutions stacked
    # on the previous output. Chaining stride-1 3x3 convolutions gives 5x5-
    # and 7x7-sized receptive fields.
    conv5x5 = conv2d_bn(conv3x3, int(W*0.333), 3, 3,
                        activation='relu', padding='same')

    conv7x7 = conv2d_bn(conv5x5, int(W*0.5), 3, 3,
                        activation='relu', padding='same')

    # Concatenate along axis 3 (channels, assuming channels-last tensors).
    out = concatenate([conv3x3, conv5x5, conv7x7], axis=3)
    out = BatchNormalization(axis=3)(out)

    # Residual connection, then ReLU and a final BatchNormalization.
    out = add([shortcut, out])
    out = Activation('relu')(out)
    out = BatchNormalization(axis=3)(out)

    return out

In [None]:
def ResPath(filters, length, inp):
    """
    Residual path applied to an encoder skip connection.

    Each residual unit adds a 1x1 conv2d_bn shortcut to a 3x3 conv2d_bn of
    the current tensor, then applies ReLU and BatchNormalization. The unit is
    repeated ``length`` times; values below 1 still run a single unit.

    Keyword Arguments:
    filters (int) : number of filters of every convolution
    length (int) : number of residual units
    inp (tensor) : input Keras tensor

    Returns:
    Keras tensor with ``filters`` channels.
    """
    out = inp
    for _ in range(max(length, 1)):
        residual = conv2d_bn(out, filters, 1, 1,
                             activation='relu', padding='same')
        out = conv2d_bn(out, filters, 3, 3, activation='relu', padding='same')
        out = add([residual, out])
        out = Activation('relu')(out)
        out = BatchNormalization(axis=3)(out)
    return out

In [None]:
def MultiResUnet(height, width, n_channels):
    """
    Build the MultiResUNet segmentation model.

    The network has four parts:
    - Four encoder stages: a MultiResBlock followed by 2x2 max-pooling. The
      pre-pool output is sent through a ResPath to form the skip connection.
    - A bottleneck MultiResBlock.
    - Four decoder stages: a stride-2 Conv2DTranspose, concatenation with
      the matching skip, and a MultiResBlock.
    - A head: a 1x1 conv2d_bn with a sigmoid activation (one output channel).

    Keyword Arguments:
    height (int), width (int) : input spatial size. Presumably each must be
        divisible by 16 (four 2x2 poolings) for the skip concatenations to
        line up - TODO confirm
    n_channels (int) : number of input channels

    Returns:
    An uncompiled keras Model producing a single-channel sigmoid map.
    """
    inputs = Input((height, width, n_channels))

    # Encoder. Each pool is taken from the raw block output; the ResPath
    # result then replaces mresblockN and is used only as the decoder skip.
    mresblock1 = MultiResBlock(32, inputs)
    pool1 = MaxPooling2D(pool_size=(2, 2))(mresblock1)
    mresblock1 = ResPath(32, 4, mresblock1)

    mresblock2 = MultiResBlock(32*2, pool1)
    pool2 = MaxPooling2D(pool_size=(2, 2))(mresblock2)
    mresblock2 = ResPath(32*2, 3, mresblock2)

    mresblock3 = MultiResBlock(32*4, pool2)
    pool3 = MaxPooling2D(pool_size=(2, 2))(mresblock3)
    mresblock3 = ResPath(32*4, 2, mresblock3)

    mresblock4 = MultiResBlock(32*8, pool3)
    pool4 = MaxPooling2D(pool_size=(2, 2))(mresblock4)
    mresblock4 = ResPath(32*8, 1, mresblock4)

    # Bottleneck.
    mresblock5 = MultiResBlock(32*16, pool4)

    # Decoder: upsample x2, concatenate with the matching skip, refine.
    up6 = concatenate([Conv2DTranspose(
        32*8, (2, 2), strides=(2, 2), padding='same')(mresblock5), mresblock4], axis=3)
    mresblock6 = MultiResBlock(32*8, up6)

    up7 = concatenate([Conv2DTranspose(
        32*4, (2, 2), strides=(2, 2), padding='same')(mresblock6), mresblock3], axis=3)
    mresblock7 = MultiResBlock(32*4, up7)

    up8 = concatenate([Conv2DTranspose(
        32*2, (2, 2), strides=(2, 2), padding='same')(mresblock7), mresblock2], axis=3)
    mresblock8 = MultiResBlock(32*2, up8)

    up9 = concatenate([Conv2DTranspose(32, (2, 2), strides=(
        2, 2), padding='same')(mresblock8), mresblock1], axis=3)
    mresblock9 = MultiResBlock(32, up9)

    # One-channel head; conv2d_bn applies BatchNormalization before the sigmoid.
    conv10 = conv2d_bn(mresblock9, 1, 1, 1, activation='sigmoid') ## sigmoid

    model = keras.models.Model(inputs=[inputs], outputs=[conv10])

    return model

# Losses and Metrics
Loss Functions and Metrics

In [None]:
# Module-level loss hyper-parameters.
# NOTE(review): ``epsilon`` and ``smooth`` are not read by the functions
# below (dsc and dice_coef define their own local ``smooth``, and
# FocalTverskyLoss takes a ``smooth`` parameter).
epsilon = 1e-5
smooth = 1
# Default FocalTverskyLoss weights: ALPHA scales false positives, BETA false
# negatives, GAMMA is the focal exponent. With ALPHA = BETA = 0.5 the Tversky
# index equals the Dice coefficient.
ALPHA = 0.5
BETA = 0.5
GAMMA = 1

def dsc(y_true, y_pred):
    """
    Smoothed Dice coefficient over every element of the batch.

    Computes (2 * sum(t * p) + 1) / (sum(t) + sum(p) + 1) on the flattened
    tensors. The +1 smoothing keeps the ratio defined for empty masks.
    """
    smooth = 1.
    truth = K.flatten(y_true)
    guess = K.flatten(y_pred)
    overlap = K.sum(truth * guess)
    total = K.sum(truth) + K.sum(guess)
    return (2. * overlap + smooth) / (total + smooth)

def dice_loss(y_true, y_pred):
    """Return 1 minus the smoothed Dice coefficient computed by ``dsc``."""
    return 1 - dsc(y_true, y_pred)

def log_cosh_dice_loss(y_true, y_pred):
    """
    Return log(cosh(dice_loss)), a smoothed variant of the Dice loss.

    cosh is written out as (exp(d) + exp(-d)) / 2.
    """
    d = dice_loss(y_true, y_pred)
    cosh_d = (tf.exp(d) + tf.exp(-d)) / 2.0
    return tf.math.log(cosh_d)




def FocalTverskyLoss(targets, inputs, alpha=ALPHA, beta=BETA, gamma=GAMMA, smooth=1e-6):
    """
    Focal Tversky loss: (1 - Tversky index) ** gamma.

    The Tversky index is (TP + smooth) / (TP + alpha*FP + beta*FN + smooth),
    with soft counts computed over all flattened elements.

    Keyword Arguments:
    targets : ground-truth tensor (Keras passes y_true first)
    inputs : predicted tensor
    alpha (float) : weight of false positives
    beta (float) : weight of false negatives
    gamma (float) : focal exponent
    smooth (float) : additive smoothing constant
    """
    y_pred = K.flatten(inputs)
    y_true = K.flatten(targets)

    # Soft true positives, false positives and false negatives.
    true_pos = K.sum((y_pred * y_true))
    false_pos = K.sum(((1 - y_true) * y_pred))
    false_neg = K.sum((y_true * (1 - y_pred)))

    tversky = (true_pos + smooth) / (true_pos + alpha*false_pos + beta*false_neg + smooth)
    return K.pow((1 - tversky), gamma)

**Metric Functions**

In [None]:
def dice_coef(y_true, y_pred):
    """
    Unsmoothed Dice coefficient 2*sum(t*p) / (sum(t) + sum(p)) over the batch.

    ``K.epsilon()`` in the denominator only prevents 0/0 = NaN when both
    masks are empty, matching how ``sensitivity``/``specificity`` guard
    their divisions.
    """
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)

    return (2. * intersection) / (K.sum(y_true_f) + K.sum(y_pred_f) + K.epsilon())

def jaccard(y_true, y_pred):
    """
    Soft Jaccard index (IoU) sum(t*p) / sum(t + p - t*p) over the batch.

    ``K.epsilon()`` in the denominator only prevents 0/0 = NaN when both
    masks are empty, matching ``sensitivity``/``specificity``.
    """
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    union = K.sum(y_true_f + y_pred_f - y_true_f * y_pred_f)

    return intersection / (union + K.epsilon())

def sensitivity(y_true, y_pred):
    """
    True-positive rate on rounded values.

    Counts elements where round(clip(t * p, 0, 1)) is 1, divided by the
    number of elements where round(clip(t, 0, 1)) is 1 (plus K.epsilon()).
    """
    hits = K.round(K.clip(y_true * y_pred, 0, 1))
    positives = K.round(K.clip(y_true, 0, 1))
    return K.sum(hits) / (K.sum(positives) + K.epsilon())

def specificity( y_true, y_pred):
    """
    True-negative rate on rounded values.

    Counts elements where round(clip((1 - t) * (1 - p), 0, 1)) is 1, divided
    by the number where round(clip(1 - t, 0, 1)) is 1 (plus K.epsilon()).
    """
    rejections = K.round(K.clip((1 - y_true) * (1 - y_pred), 0, 1))
    negatives = K.round(K.clip(1 - y_true, 0, 1))
    return K.sum(rejections) / (K.sum(negatives) + K.epsilon())

# Parameters


In [None]:
#@title Preprocess Parameters
# Colab form fields (the ``#@param`` comments drive the form UI).
# BATCH_SIZE / SEED: batch size and shuffle seed of the data generators.
# test_split_size: fraction of patches held out for validation.
# NOTE(review): IMG_SIZE is not read by the code shown; patch and model
# sizes are hard-coded to 256 elsewhere.
BATCH_SIZE = 4        #@param {type:"integer"}
IMG_SIZE = 256        #@param {type:"integer"}
SEED = 727            #@param {type:"integer"}
test_split_size = 0.2 #@param {type:"number"}

In [None]:
#@title Augmentation Parameters
# ImageDataGenerator arguments, used only when preprocess_augment is called
# with augment=True (main() passes augment=False) and by the driver-cell
# datagen below.
rotation_range = 135#@param {type:"slider",min:0,max:360,step:45}
horizontal_flip = True#@param{type:"boolean"}
vertical_flip = True#@param{type:"boolean"}
width_shift_range=0.2 #@param {type:"number"}
height_shift_range=0.2 #@param {type:"number"}

In [None]:
#@title Compile Parameters
# learning_rate: Adam learning rate. loss_function: key into the
# ``custom_objects`` dict used when compiling the model.
# NOTE(review): "sparse_categorical_crossentropy" maps to a keras.metrics
# object there, not a loss - confirm it works before selecting it.
learning_rate = 1e-2#@param {type:"number"}
loss_function = "dice_loss" #@param ["sparse_categorical_crossentropy","FocalTverskyLoss","dice_loss","log_cosh_dice_loss"]





In [None]:
#@title Callback Parameters
# ReduceLROnPlateau settings.
reduce_lr_factor = 0.5 #@param {type:"number"}
reduce_lr_monitor = 'val_loss'#@param ["val_loss","val_jaccard","val_dice_coef","val_sensitivity","val_specificity"]
reduce_patience = 2 #@param {type:"integer"}
reduce_min_lr = 1e-5 #@param {type:"number"}


# ModelCheckpoint settings. main() reloads checkpoint_path with
# keras.models.load_model, which presumably needs a full model
# (save_weights_only=False) - confirm before changing.
checkpoint_path = "/content/drive/MyDrive/histopatoloji/auto_model_dummy" #@param ["/content/drive/MyDrive/histopatoloji/Model","/content/drive/MyDrive/histopatoloji/auto_model_dummy"]
checkpoint_monitor = 'val_loss'#@param ["val_loss","val_jaccard","val_dice_coef","val_sensitivity","val_specificity"]
checkpoint_verbose = 1#@param [0,1,2]
checkpoint_save_best_only = True #@param {type:"boolean"}
checkpoint_save_weights_only=False #@param {type:"boolean"}
checkpoint_mode='auto' #@param ["auto","min","max"]

# EarlyStopping settings. NOTE: ``eary_stop_mode`` is a misspelling of
# early_stop_mode; it is referenced under this name by the callbacks cell.
early_stop_monitor = 'val_loss'#@param ["val_loss","val_jaccard","val_dice_coef","val_sensitivity","val_specificity"]
early_stop_patience=3 #@param {type:"integer"}
early_stop_verbose = 0 #@param [0,1,2]
eary_stop_mode = 'min' #@param ["auto","min","max"]


In [None]:
# Multiply the learning rate by ``reduce_lr_factor`` when the monitored
# quantity has not improved for ``reduce_patience`` epochs.
reduce_lr = keras.callbacks.ReduceLROnPlateau(
    monitor=reduce_lr_monitor,
    factor=reduce_lr_factor,
    patience=reduce_patience,
    min_lr=reduce_min_lr,
)

# Save the model to ``checkpoint_path`` at the end of each epoch (only on
# improvement of ``checkpoint_monitor`` when save_best_only is set).
checkpoint = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    monitor=checkpoint_monitor,
    verbose=checkpoint_verbose,
    save_best_only=checkpoint_save_best_only,
    save_weights_only=checkpoint_save_weights_only,
    mode=checkpoint_mode,
    save_freq='epoch',
    options=None,
    initial_value_threshold=None,
)

# Stop after ``early_stop_patience`` epochs without improvement.
early_stop = keras.callbacks.EarlyStopping(
    monitor=early_stop_monitor,
    min_delta=0,
    patience=early_stop_patience,
    verbose=early_stop_verbose,
    mode=eary_stop_mode,
    baseline=None,
    restore_best_weights=False,
)

# Passed to every model.fit call.
callbacks = [
    reduce_lr,
    checkpoint,
    early_stop,
]

In [None]:
gpus = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpus))

if gpus:
    try:
        # Restrict TensorFlow to the first GPU only.
        tf.config.set_visible_devices(gpus[0], 'GPU')
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPU")
    except RuntimeError as err:
        # Visible devices must be set before the GPUs are initialized.
        print(err)

# ABSOLUTE
Run the main function to execute the whole training pipeline at once.

In [None]:
# Train on every image/label pair in the folder; the model is checkpointed
# to checkpoint_path by the ModelCheckpoint callback.
main(dataset_folder = '/content/drive/MyDrive/histopatoloji/ev_crop_3') ## my_data

**Prediction**

In [None]:
# Custom losses/metrics the saved model was compiled with, looked up by name
# when deserializing it.
# NOTE(review): 'sparse_categorical_crossentropy' maps to a metric object,
# not a loss function - confirm this is intended.
custom_objects = {'FocalTverskyLoss':FocalTverskyLoss,
                  'log_cosh_dice_loss':log_cosh_dice_loss,
                  'sparse_categorical_crossentropy': keras.metrics.SparseCategoricalCrossentropy(from_logits = True),
                  'dice_loss':dice_loss,
                  'jaccard':jaccard,
                  'dice_coef':dice_coef,
                  'sensitivity':sensitivity,
                  'specificity':specificity}
# Load from the same folder that checkpoint_path defaults to.
loaded_model = keras.models.load_model('/content/drive/MyDrive/histopatoloji/auto_model_dummy',custom_objects = custom_objects)

In [None]:
## read data for prediction
# read_data returns only the first image/label pair (sorted by filename).
img,label,dataset_size = read_data('/content/drive/MyDrive/histopatoloji/buse')
label = rgb2gray(label)
img_shape = img.shape
label_shape = label.shape

img_patch_shape, image_patches = patch_data(img,(256,256,3),256) # use for reconstruction
label_patch_shape, label_patches = patch_data(label,(256,256),256)
# Same label post-processing as main(): values >= 50 become 255.
label_patches = np.where(label_patches<50,label_patches,255.)

del img,label
# Each image patch must have a label patch with the same leading dimensions.
assert label_patches.shape[:] == image_patches.shape[:-1]

In [None]:
## evaluate your data
# Training fed images rescaled by 1/255 (ImageDataGenerator(rescale=1./255)),
# so apply the same scaling here; raw 0..255 inputs are out of distribution.
evaluation = loaded_model.evaluate(image_patches.astype(np.float32) / 255.,label_patches,batch_size = 4)

In [None]:
## predict your data
# Rescale by 1/255 to match the preprocessing used during training.
predictions = loaded_model.predict(image_patches.astype(np.float32) / 255.,batch_size = 4, verbose = 1)
# Drop the single output channel: (N, H, W, 1) -> (N, H, W).
predictions = np.squeeze(predictions,axis = -1)

In [None]:
## indexes of patches whose label mean is above 25 (i.e. containing mask)
masks = get_masks(label_patches = label_patches,mask_mean=25)

Plot predictions

In [None]:
# Stitch the image patches back into the full image.
reconstructed_image = reconstruct_patch(image_patches,img_patch_shape)

In [None]:
# Show the reconstructed input image.
plt.imshow(reconstructed_image)

In [None]:
# Stitch the per-patch predictions into a full-size mask (label grid shape).
reconstructed_pred = reconstruct_patch(predictions,label_patch_shape)

In [None]:
# Show the predicted mask.
plt.imshow(reconstructed_pred,cmap='gray')

In [None]:
# Free the large prediction-section arrays.
del image_patches,label_patches,reconstructed_pred,predictions

# Driver Functions
Test section for checking that the individual functions work. (Do not run.)

In [None]:
# Without a path, read_data returns the hard-coded sample pair and size 1.
img,label,dataset_size = read_data() ## Sample image and label

In [None]:
# Collapse the label image to a single grayscale channel.
label = rgb2gray(label)

In [None]:
# patch_data returns (grid_shape, patches); unpack both so image_patches is
# the patch array (keep the grid shape for reconstruct_patch).
img_patch_shape, image_patches = patch_data(img,(256,256,3),256)

In [None]:
# patch_data returns (grid_shape, patches); unpack both before np.where,
# which would otherwise operate on the tuple.
label_patch_shape, label_patches = patch_data(label,(256,256),256)
label_patches = np.where(label_patches<50,label_patches,255.)

In [None]:
# The full-size image and label are no longer needed once patched.
del img,label

In [None]:
# Each image patch must have a label patch with the same leading dimensions.
assert label_patches.shape[:] == image_patches.shape[:-1]

**Preprocess**

In [None]:
# Hold out test_split_size of the patches for validation.
# NOTE: no random_state is passed, so the split differs between runs.
x_train,x_test,y_train,y_test = split(image_patches,
                                      label_patches,
                                      test_size = test_split_size)

In [None]:
# Rescale images by 1/255 and apply random geometric augmentation.
# NOTE(review): when used with flow(x, y) below, only x is transformed, so
# the labels are not augmented alongside their images - confirm before
# training on these generators.
datagen = ImageDataGenerator(rescale=1./255,
                             rotation_range = rotation_range,
                             horizontal_flip =horizontal_flip,
                             vertical_flip = vertical_flip,
                             width_shift_range=width_shift_range,
                             height_shift_range=height_shift_range)

In [None]:
# Batch iterators over the train and validation patches (both augmented by
# the same datagen).
train_dataset_patch = datagen.flow(x = x_train,
                                   y = y_train,
                                   batch_size = BATCH_SIZE,
                                   seed = SEED)

valid_dataset_patch = datagen.flow(x = x_test,
                                   y = y_test,
                                   batch_size = BATCH_SIZE,
                                   seed = SEED)

**Graph of some masks**

In [None]:
# get_masks requires the label patches; calling it without them raises TypeError.
masks = get_masks(label_patches)

In [None]:
# Plot five of the selected patches; graphify reads the module-level
# image_patches and label_patches.
graphify(masks[10:15],fig_size = 14)

In [None]:
# ``model`` only exists inside main(). Build and compile a fresh one here,
# with the same settings main() uses for a new model, so this cell can run.
# ``custom_objects`` is the dict defined in the Prediction section above.
model = MultiResUnet(IMG_SIZE, IMG_SIZE, 3)
model.compile(optimizer = keras.optimizers.Adam(
              learning_rate = learning_rate),
              loss = custom_objects[loss_function],
              metrics = [jaccard,dice_coef,sensitivity,specificity],
              )

history = model.fit(train_dataset_patch,
                         batch_size = BATCH_SIZE,
                         epochs= 11,
                         callbacks = callbacks,
                         workers = -1,
                         validation_data = valid_dataset_patch,
                         validation_steps = (len(x_test)//BATCH_SIZE),
                         steps_per_epoch=(len(x_train)//BATCH_SIZE))