<a href="https://colab.research.google.com/github/ivalencius/classification-4u/blob/main/3D_semantic.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Edited By Ilan Valencius (valencig@bc.edu)**
---
<font size = 4> **Some Notes**
* This notebook performs semantic segmentation of 3-channel RGB datasets. It is based on the [ZeroCostDL4Mic UNet (2D) Multilabel Notebook](https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki).
* This notebook has been adapted for **local use**, with the primary aim of training on large datasets with local high-performance GPUs. Creating predictions and performing quality control remain feasible in the hosted Google Colab notebook, because these steps are short enough not to encounter the runtime timeouts imposed by Google Colab.
* Quick command for running locally after installing the required packages:
  `jupyter notebook --NotebookApp.allow_origin='https://colab.research.google.com' --port=8888 --NotebookApp.port_retries=0`
---

In [None]:
#@title Run if using Google Drive
!pip install fpdf
!pip install segmentation-models
!pip install imagecodecs
import imagecodecs
from google.colab import drive
drive.mount('/content/drive')

# **1.3. Load key dependencies**
---
<font size = 4> 

<font size = 4>**`tmp_folder`:** This folder replaces the `/content` root folder found in Google Colab. Temporary files are saved here, including the images added to the PDF training reports and the created patches. **Files in this folder are not permanent and will be removed after running the script.**

In [None]:
#@markdown ### Path to temporary folder
import os
# Colors for the warning messages
class bcolors:
  WARNING = '\033[31m'  # ANSI red
  ENDC = '\033[0m'      # ANSI reset, so output printed afterwards is not left red
tmp_folder = '/content/temp'#@param {type: "string"}

# Create the scratch folder if needed. os.makedirs also creates missing parent
# folders (e.g. `/content` when running on a local machine), which os.mkdir cannot.
if os.path.isdir(tmp_folder):
  print('Folder Exists')
else:
  print(bcolors.WARNING+'!!! Folder does not exist !!!'+bcolors.ENDC)
  print('Creating Folder')
  os.makedirs(tmp_folder, exist_ok=True)

In [None]:
# Notebook metadata. `Network` is the label used in the title of the training report.
Notebook_version = '1.13'
Network = 'valencig_3d_ML'

!python --version

import sys
# Snapshot of the module names loaded before the imports below.
# NOTE(review): presumably compared with sys.modules later to list newly imported packages — not visible here.
before = [str(m) for m in sys.modules]

#@markdown ##Load key U-Net dependencies

#As this notebokk depends mostly on keras which runs a tensorflow backend (which in turn is pre-installed in colab)
#only the data library needs to be additionally installed.
#%tensorflow_version 1.x
import tensorflow as tf
#import tensorflow_addons as tfa
# print(tensorflow.__version__)
# print("Tensorflow enabled.")


# Keras imports
from keras import models
from keras.models import Model, load_model
from keras.layers import Input, Conv2D, MaxPooling2D, Dropout, concatenate, UpSampling2D
from keras.utils.vis_utils import plot_model
from tensorflow.keras.optimizers import Adam, SGD
#from keras.optimizers import Adam
# from keras.callbacks import ModelCheckpoint, LearningRateScheduler, CSVLogger # we currently don't use any other callbacks from ModelCheckpoints
from keras.callbacks import ModelCheckpoint
from keras.callbacks import ReduceLROnPlateau
from keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img
from keras import backend as keras
from keras.callbacks import Callback

# General import
from __future__ import print_function
import numpy as np
import pandas as pd
import os
import glob
from skimage import img_as_ubyte, io, transform
import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib.pyplot import imread
from pathlib import Path
import shutil
import random
import time
import csv
import sys
from math import ceil
#from fpdf import FPDF, HTMLMixin
from pip._internal.operations.freeze import freeze
import subprocess
# Imports for QC
from PIL import Image
from scipy import signal
from scipy import ndimage
from sklearn.linear_model import LinearRegression
from skimage.util import img_as_uint
from skimage.metrics import structural_similarity
from skimage.metrics import peak_signal_noise_ratio as psnr
from sklearn.preprocessing import normalize
from sklearn.utils import class_weight

# For sliders and dropdown menu and progress bar
from ipywidgets import interact
import ipywidgets as widgets
# from tqdm import tqdm
from tqdm.notebook import tqdm

from sklearn.feature_extraction import image
from skimage import img_as_ubyte, io, transform
from skimage.util.shape import view_as_windows

from datetime import datetime
import math
from fpdf import FPDF,HTMLMixin

import segmentation_models as sm
sm.set_framework('tf.keras')
from segmentation_models import Unet
from segmentation_models.losses import JaccardLoss, DiceLoss

# Suppressing some warnings
import warnings
warnings.filterwarnings('ignore')

def net3d(model_architecture, backbone, pretrained_weights = None, input_size = (256,256,1), learning_rate = 1e-4, verbose=True, labels=2, optimizer='Adam', imagenet_weights = True):
    """Build and compile a segmentation_models network for multi-class segmentation.

    Parameters
    ----------
    model_architecture : str
        One of 'Unet', 'Linknet', 'FPN' or 'PSPNet'.
    backbone : str
        Encoder name, passed to segmentation_models as ``backbone_name``.
    pretrained_weights : str or None
        Optional path to a weights file that is loaded into the model.
    input_size : tuple
        Input shape (height, width, channels).
    learning_rate : float
        Initial learning rate of the optimizer.
    verbose : bool
        Currently unused; kept for backward compatibility.
    labels : int
        Number of output classes (softmax channels).
    optimizer : str
        'Adam' or 'SGD'.
    imagenet_weights : bool
        If True the encoder starts from ImageNet weights, otherwise from a random init.

    Returns
    -------
    The compiled Keras model, using the categorical focal + Jaccard loss and the IoU metric.

    Raises
    ------
    ValueError
        If ``model_architecture`` or ``optimizer`` is not supported.

    Notes
    -----
    Reads the notebook globals ``weight_decay`` (both optimizers) and ``momentum``
    (SGD only); they must be defined before this function is called.
    """
    architectures = {
        'Unet': sm.Unet,
        'Linknet': sm.Linknet,
        'FPN': sm.FPN,
        'PSPNet': sm.PSPNet,
    }
    if model_architecture not in architectures:
        raise ValueError('Unsupported model_architecture {!r}; expected one of {}'.format(
            model_architecture, sorted(architectures)))
    # Fail before building the (possibly large) model rather than returning it uncompiled.
    if optimizer not in ('Adam', 'SGD'):
        raise ValueError("Unsupported optimizer {!r}; expected 'Adam' or 'SGD'".format(optimizer))
    net = architectures[model_architecture]

    # segmentation_models expects the string 'imagenet' (pretrained encoder) or None (random init).
    encoder_weights = 'imagenet' if imagenet_weights else None

    model = net(backbone_name=backbone,
                encoder_weights=encoder_weights,
                input_shape=input_size,
                classes=labels,
                activation='softmax',
                weights=pretrained_weights)

    loss = sm.losses.categorical_focal_jaccard_loss
    metrics = [sm.metrics.iou_score]

    if optimizer == 'Adam':
        opt = Adam(learning_rate=learning_rate, decay=weight_decay)
    else:  # 'SGD'
        # A polynomial decay schedule can be used instead of a constant rate, e.g.
        # tf.keras.optimizers.schedules.PolynomialDecay(
        #     initial_learning_rate=0.01, decay_steps=number_of_epochs, power=0.9)
        opt = SGD(learning_rate=learning_rate, momentum=momentum, decay=weight_decay)
    model.compile(optimizer=opt, loss=loss, metrics=metrics)

    if pretrained_weights:
        # NOTE(review): `weights=` above presumably already loads this file; loading it
        # explicitly again guarantees the weights are applied after compiling.
        model.load_weights(pretrained_weights)
    return model
# Generator streaming paired (image, mask) batches from two folders for Keras training.
def buildDoubleGenerator3d(image_datagen, mask_datagen, image_folder_path, mask_folder_path, subset, batch_size, target_size):
  """Yield preprocessed (image batch, mask batch) tuples read from two folders.

  Images are read as RGB and masks as grayscale via ``flow_from_directory``.
  Both flows use seed=1, so shuffling/augmentation stays aligned between each
  image and its mask.

  image_datagen, mask_datagen: ImageDataGenerator instances for images and masks.
  image_folder_path, mask_folder_path: folders that directly contain the files;
      each is passed as the single "class" sub-folder of its parent directory.
  subset: 'training' or 'validation' (forwarded to flow_from_directory).
  batch_size: number of samples per batch.
  target_size: accepted but not forwarded, so Keras' default target size applies.

  Each raw batch pair is passed through the notebook-level
  ``preprocess_data(img, mask, labels)`` before being yielded.
  """
  def _flow(datagen, folder, color_mode):
    # flow_from_directory takes a parent folder plus the names of class sub-folders.
    return datagen.flow_from_directory(
        os.path.dirname(folder),
        classes=[os.path.basename(folder)],
        class_mode=None,
        color_mode=color_mode,
        batch_size=batch_size,
        subset=subset,
        seed=1)

  image_flow = _flow(image_datagen, image_folder_path, "rgb")
  mask_flow = _flow(mask_datagen, mask_folder_path, "grayscale")

  for image_batch, mask_batch in zip(image_flow, mask_flow):
    processed_img, processed_mask = preprocess_data(image_batch, mask_batch, labels)
    yield (processed_img, processed_mask)

def weighted_binary_crossentropy(class_weights):
    """Return a Keras loss computing class-weighted binary cross-entropy.

    class_weights: indexable pair; ``class_weights[0]`` weights elements whose
        target is 0 and ``class_weights[1]`` weights elements whose target is 1
        (soft targets get a linear blend of the two).

    The returned ``loss(y_true, y_pred)`` multiplies the Keras-backend binary
    cross-entropy by these weights and returns the mean.
    """
    def _weighted_binary_crossentropy(y_true, y_pred):
        # y_true == 1 -> class_weights[1]; y_true == 0 -> class_weights[0].
        pixel_weights = (1. - y_true) * class_weights[0] + y_true * class_weights[1]
        return keras.mean(pixel_weights * keras.binary_crossentropy(y_true, y_pred))

    return _weighted_binary_crossentropy

def predict_as_tiles(Image_path, model, smoothing=True):
  """Predict a per-pixel class map for one image file by tiling it into model-sized patches.

  Image_path: path of the image to segment (read with ``io.imread``).
  model: Keras model whose ``predict`` returns class scores with the class axis last.
  smoothing: if True, use ``predict_img_with_smooth_windowing`` (overlapping,
      blended windows). Otherwise tile the image with non-overlapping patches;
      the last row/column of patches is shifted back so it ends at the image border.

  Returns a 2D uint8 array of class indices (argmax over the class axis).

  Raises ValueError (non-smoothing path) if the image is smaller than one patch.

  Relies on the notebook globals ``preprocess_data``, ``Input_size`` (patch size)
  and ``labels`` (number of classes).
  """
  image_raw = io.imread(Image_path)
  image = preprocess_data(image_raw)

  # The patch size comes from the notebook setting, not from the model's input layer.
  patch_size = Input_size

  if smoothing:
    prediction_smooth = predict_img_with_smooth_windowing(
        image,
        window_size=patch_size[0],
        subdivisions=2,  # Minimal amount of overlap for windowing. Must be an even number.
        nb_classes=labels,
        pred_func=(lambda img_batch_subdiv: model.predict((img_batch_subdiv)))
        )
    return np.argmax(prediction_smooth, axis=2).astype('uint8')

  height, width = image.shape[0], image.shape[1]
  patch_h, patch_w = patch_size[0], patch_size[1]
  if height < patch_h or width < patch_w:
    raise ValueError('Image {} ({}x{}) is smaller than the patch size ({}x{})'.format(
        Image_path, height, width, patch_h, patch_w))

  prediction = np.zeros((height, width), dtype='uint8')
  for x in range(ceil(height / patch_h)):
    # Clamp the last patch so it ends exactly at the border (it then overlaps its neighbour).
    xi = min(patch_h * x, height - patch_h)
    for y in range(ceil(width / patch_w)):
      yi = min(patch_w * y, width - patch_w)
      patch = image[xi:xi + patch_h, yi:yi + patch_w]
      # Add the batch axis the model expects.
      predicted_patch = model.predict(patch[np.newaxis], batch_size = 1)
      pred_patch = np.argmax(np.squeeze(predicted_patch), axis=2)
      prediction[xi:xi + patch_h, yi:yi + patch_w] = pred_patch.astype(np.uint8)

  # Crop once, after every patch has been written, to the raw input size.
  return prediction[0:image_raw.shape[0], 0:image_raw.shape[1]]

def save_augment(datagen,orig_img,dir_augmented_data=None):
  """
  Save a few augmented versions of one image so the augmentation can be inspected.

  This is adapted from: https://fairyonice.github.io/Learn-about-ImageDataGenerator.html

  datagen: ImageDataGenerator (anything with a Keras-style ``flow`` method).
  orig_img: image accepted by ``img_to_array``.
  dir_augmented_data: output folder. It defaults to ``tmp_folder + "/augment"``,
      resolved at call time. The folder (and any missing parents) is created if
      needed; files already in it are deleted first (sub-folders are left alone).

  Five augmented images are written as .tif files; the visualisation in section 3
  currently expects 5.
  """
  if dir_augmented_data is None:
    # Resolve at call time so a later change of tmp_folder is picked up.
    dir_augmented_data = tmp_folder + "/augment"

  if os.path.isdir(dir_augmented_data):
    # Remove previous previews so only images from this call remain.
    for item in os.listdir(dir_augmented_data):
      item_path = os.path.join(dir_augmented_data, item)
      if os.path.isfile(item_path):
        os.remove(item_path)
  else:
    os.makedirs(dir_augmented_data)

  # Add a leading batch axis: (rows, cols, channels) -> (1, rows, cols, channels).
  x = img_to_array(orig_img)
  x = x.reshape((1,) + x.shape)

  Nplot = 5
  # flow() never stops on its own and writes each generated image to disk as it is drawn.
  for i, _batch in enumerate(datagen.flow(x, batch_size=1,
                                          save_to_dir=dir_augmented_data,
                                          save_format='tif',
                                          seed=42), start=1):
    if i >= Nplot:
      break

# Percentile and min/max-reference normalisation helpers, adapted from Martin Weigert.
def normalizePercentile(x, pmin=1, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):
    """Percentile-based image normalization (adapted from Martin Weigert).

    Maps the ``pmin``-th percentile of ``x`` to 0 and the ``pmax``-th percentile
    to 1. Percentiles are computed over ``axis`` (all values when None).
    ``clip``, ``eps`` and ``dtype`` are forwarded to ``normalize_mi_ma``.
    """
    low, high = (np.percentile(x, q, axis=axis, keepdims=True) for q in (pmin, pmax))
    return normalize_mi_ma(x, low, high, clip=clip, eps=eps, dtype=dtype)


def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):
    """Rescale ``x`` linearly so that ``mi`` maps to 0 and ``ma`` maps to 1 (adapted from Martin Weigert).

    x: array to normalise.
    mi, ma: lower / upper reference values (scalars or broadcastable arrays).
    clip: if True, clip the result to [0, 1].
    eps: added to the denominator to avoid division by zero.
    dtype: if not None, ``x``, ``mi``, ``ma`` and ``eps`` are cast to it first.

    Uses ``numexpr`` when it is importable and plain NumPy otherwise.
    """
    if dtype is not None:
        def _cast(value):
            return dtype(value) if np.isscalar(value) else value.astype(dtype, copy=False)
        x = x.astype(dtype, copy=False)
        mi = _cast(mi)
        ma = _cast(ma)
        eps = dtype(eps)

    try:
        import numexpr
        # numexpr looks up x, mi, ma and eps among this function's local variables.
        x = numexpr.evaluate("(x - mi) / ( ma - mi + eps )")
    except ImportError:
        x = (x - mi) / (ma - mi + eps)

    return np.clip(x, 0, 1) if clip else x



# Simple normalization to min/max for the Mask
def normalizeMinMax(x, dtype=np.float32):
  """Linearly rescale ``x`` to [0, 1] using its own minimum and maximum.

  x: numeric array.
  dtype: dtype of the returned array.

  A constant array (max == min) has no range to scale by, so it is mapped to all
  zeros instead of producing NaNs from 0/0. An empty array is returned (cast)
  unchanged, because np.amin/np.amax reject empty input.
  """
  x = x.astype(dtype,copy=False)
  if x.size == 0:
    return x
  x_min = np.amin(x)
  value_range = np.amax(x) - x_min
  if value_range == 0:
    return np.zeros_like(x)
  return (x - x_min) / value_range


def create_patches3d(source, target, patch_width, patch_height, min_fraction, training=True):
  """
  Cut every RGB image in ``source`` and its same-named mask in ``target`` into
  non-overlapping patch_width x patch_height patches and save them as .tif files.

  Patches are written to folders in ``tmp_folder``:
    - training=True:  training_img_patches / training_mask_patches
    - training=False: val_img_patches / val_mask_patches
  A matching, currently unused ``*_rejected`` folder is also created. Existing
  patch folders are deleted first, so each call starts from scratch.

  Rows/columns that do not fill a whole patch are cropped away. Images smaller
  than one patch are skipped with a message; non-file entries in ``source`` are ignored.

  min_fraction: currently NOT used -- every patch is saved and no foreground
      filtering is applied.

  Raises ValueError if an image and its mask differ in (rows, cols).

  Returns: - Two paths to where the patches are now saved (images, masks)
  """
  patch_width, patch_height = int(patch_width), int(patch_height)
  prefix = 'training' if training else 'val'
  # Build the paths from tmp_folder alone. Prefixing `source`/`target` only works when
  # tmp_folder is absolute (os.path.join then drops the earlier parts); with a relative
  # tmp_folder the patch folders would be created inside `source` and be read back
  # as input images by os.listdir(source) below.
  Patch_source = os.path.join(tmp_folder, prefix + '_img_patches')
  Patch_target = os.path.join(tmp_folder, prefix + '_mask_patches')
  Patch_rejected = os.path.join(tmp_folder, prefix + '_rejected')
  print(Patch_source)

  # The patches live in tmp_folder because they are usually not needed after training.
  for folder in (Patch_source, Patch_target, Patch_rejected):
    if os.path.exists(folder):
      shutil.rmtree(folder)
    os.makedirs(folder)

  patch_num = 0
  # Sorted so that patch numbering is reproducible across runs and platforms.
  for file in tqdm(sorted(os.listdir(source))):
    img_path = os.path.join(source, file)
    if not os.path.isfile(img_path):
      continue
    img = io.imread(img_path)
    mask = io.imread(os.path.join(target, file), as_gray=True)
    x, y, _channels = img.shape  # expects (rows, cols, channels)
    if (x, y) != mask.shape:
      raise ValueError('Image {} has shape {} but its mask has shape {}'.format(file, (x, y), mask.shape))

    x_max = (x // patch_width) * patch_width
    y_max = (y // patch_height) * patch_height
    if x_max == 0 or y_max == 0:
      print('Skipping {}: smaller than one {}x{} patch'.format(file, patch_width, patch_height))
      continue
    img = img[0:x_max, 0:y_max]
    mask = mask[0:x_max, 0:y_max]

    # Window step equal to the window size, so the patches do not overlap.
    patches_img = view_as_windows(img, (patch_width, patch_height, 3), (patch_width, patch_height, 3))
    patches_mask = view_as_windows(mask, (patch_width, patch_height), (patch_width, patch_height))

    patches_img = patches_img.reshape(patches_img.shape[0]*patches_img.shape[1], patch_width, patch_height, 3)
    patches_mask = patches_mask.reshape(patches_mask.shape[0]*patches_mask.shape[1], patch_width, patch_height)

    for img_patch, mask_patch in zip(patches_img, patches_mask):
      patch_file = 'patch_' + str(patch_num) + '.tif'
      patch_num += 1
      io.imsave(os.path.join(Patch_source, patch_file), img_as_ubyte(img_patch), photometric='rgb')
      io.imsave(os.path.join(Patch_target, patch_file), mask_patch)

  return Patch_source, Patch_target

# Custom callback showing sample prediction
class SampleImageCallback(Callback):
    """Keras callback that plots the prediction for a fixed sample batch after every epoch.

    model: model used for the prediction (stored as ``self.model``).
    sample_data: batch of input images, indexed as (batch, rows, cols, channels).
    model_path: folder where the per-epoch figures are written when ``save`` is True.
    save: if True, save each figure as ``<model_path>/epoch_<n>.png`` (n starts at 1).
    """

    def __init__(self, model, sample_data, model_path, save=False):
        super().__init__()
        self.model = model
        self.sample_data = sample_data
        self.model_path = model_path
        self.save = save

    def on_epoch_end(self, epoch, logs=None):
        sample_predict = self.model.predict_on_batch(self.sample_data)

        fig = plt.figure(figsize=(16,10))
        plt.subplot(1,2,1)
        plt.imshow(self.sample_data[0,:,:,:])
        plt.title('Sample source (not normalized)')
        plt.axis('off')

        plt.subplot(1,2,2)
        # Highest-scoring class for every pixel of the first sample.
        plt.imshow(np.squeeze(np.argmax(sample_predict[0], axis=-1)), interpolation='nearest')
        plt.title('Semantic segmentation')
        plt.axis('off')

        # Save before showing: with the inline backend plt.show() closes the figure,
        # so a plt.savefig() afterwards would write a new, empty figure.
        if self.save:
            fig.savefig(self.model_path + '/epoch_' + str(epoch+1) + '.png')
        plt.show()
        # Release the figure so memory use does not grow with the number of epochs.
        plt.close(fig)


def saveResult(save_path, nparray, source_dir_list, prefix='',file_ext='.tif'):
  """Save each array in ``nparray`` under the name of the matching source file.

  The output name is ``prefix + <source stem> + file_ext`` inside ``save_path``.
  Pairs are formed with zip, so iteration stops at the shorter of the two sequences.
  Arrays are written as-is (no dtype conversion here).
  """
  for source_name, array in zip(source_dir_list, nparray):
    stem = os.path.splitext(source_name)[0]
    out_path = os.path.join(save_path, prefix + stem + file_ext)
    io.imsave(out_path, array)


def convert2Mask(image, threshold):
  """Binarise ``image`` into a 0/255 mask.

  The image is first converted with ``img_as_ubyte`` on a copy (the input is not
  modified). Pixels strictly above ``threshold`` become 255; pixels at or below it become 0.
  """
  binary = img_as_ubyte(image, force_copy=True)
  above = binary > threshold
  at_or_below = binary <= threshold
  binary[above] = 255
  binary[at_or_below] = 0
  return binary

# -------------- Other definitions -----------
# ANSI escape sequences for coloured console messages.
R  = '\033[31m' # red
W  = '\033[0m'  # white (normal): resets the colour
# File-name prefix for saved predictions.
# NOTE(review): presumably passed as saveResult(prefix=...) in a later cell.
prediction_prefix = 'Predicted_'


print('-------------------')
print('U-Net and dependencies installed.')

def _installed_package_versions():
  """Return a {lower-case package name: version} dict built from `pip freeze` output."""
  versions = {}
  for requirement in freeze(local_only=True):
    # Only pinned 'name==version' lines carry a version; editable/URL installs are skipped.
    name, sep, version = requirement.partition('==')
    if sep:
      versions[name.strip().lower()] = version.strip()
  return versions


def _cuda_version():
  """Return the CUDA compiler release reported by `nvcc --version` (e.g. '11.2.152'), or 'unknown'."""
  import re
  try:
    result = subprocess.run(['nvcc', '--version'], stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
  except OSError:  # nvcc is not installed / not on PATH
    return 'unknown'
  match = re.search(r', V(\S+)', result.stdout.decode('utf-8', errors='replace'))
  return match.group(1) if match else 'unknown'


def _gpu_name():
  """Return the name of the first GPU listed by nvidia-smi (any model, not only Tesla cards), or 'unknown'."""
  try:
    result = subprocess.run(['nvidia-smi', '--query-gpu=name', '--format=csv,noheader'],
                            stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
  except OSError:  # NVIDIA driver tools are not available
    return 'unknown'
  names = [line.strip() for line in result.stdout.decode('utf-8', errors='replace').splitlines() if line.strip()]
  return names[0] if names else 'unknown'


def pdf_export(trained = False, augmentation = False, pretrained_model = False):
  """Write the training report to <model_path>/<model_name>/<model_name>_training_report.pdf.

  trained: if True, include the training time (read from the globals hour, mins, sec).
  augmentation: currently unused; the augmentation section is driven by the
      global ``Use_Data_augmentation`` and the augmentation parameter globals.
  pretrained_model: if True, describe the run as re-training of a pretrained model.

  Everything else (names, paths, hyper-parameters) is read from notebook globals,
  and the example figure is read from ``tmp_folder/TrainingDataExample_Unet2D.png``.
  """
  class MyFPDF(FPDF, HTMLMixin):
    pass

  pdf = MyFPDF()
  pdf.add_page()
  pdf.set_right_margin(-1)
  pdf.set_font("Arial", size = 11, style='B')

  datetime_str = str(datetime.now())[0:10]

  Header = 'Training report for '+Network+' model ('+model_name+')\nDate: '+datetime_str
  pdf.multi_cell(180, 5, txt = Header, align = 'L')

  if trained:
    training_time = "Training time: "+str(hour)+" hour(s) "+str(mins)+" min(s) "+str(round(sec))+" sec(s)"
    pdf.cell(190, 5, txt = training_time, ln = 1, align='L')
  pdf.ln(1)

  # Exact, case-insensitive package lookup. A substring search misses 'keras' when
  # asked for 'Keras' and can match a different package that contains the name.
  versions = _installed_package_versions()
  tf_version = versions.get('tensorflow', tf.__version__)
  keras_version = versions.get('keras', 'unknown')
  numpy_version = versions.get('numpy', np.__version__)
  cuda_version = _cuda_version()
  gpu_name = _gpu_name()

  # Use the first image (sorted) so a folder with a single image also works.
  first_image = sorted(os.listdir(Training_source))[0]
  shape = io.imread(os.path.join(Training_source, first_image)).shape

  training_summary = (str(number_of_epochs)+' epochs on '+str(number_of_training_dataset)
                      +' paired image patches (image dimensions: '+str(shape)+', patch size: ('
                      +str(patch_width)+','+str(patch_height)+')) with a batch size of '+str(batch_size)+'.')
  if pretrained_model:
    text = 'The '+Network+' model was trained for '+training_summary+' The model was re-trained from a pretrained model.'
  else:
    text = 'The '+Network+' model was trained from scratch for '+training_summary
  text += (' Key python packages used include tensorflow (v '+tf_version+'), Keras (v '+keras_version
           +'), numpy (v '+numpy_version+'), cuda (v '+cuda_version+'). The training was accelerated using a '
           +gpu_name+' GPU.')

  pdf.set_font('')
  pdf.set_font_size(10.)
  pdf.multi_cell(180, 5, txt = text, align='L')
  pdf.set_font('')
  pdf.set_font('Arial', size = 10, style = 'B')
  pdf.ln(1)
  pdf.cell(28, 5, txt='Augmentation: ', ln=1)
  pdf.set_font('')
  if Use_Data_augmentation:
    aug_text = 'The dataset was augmented by'
    if rotation_range != 0:
      aug_text = aug_text+'\n- rotation'
    if horizontal_flip or vertical_flip:
      aug_text = aug_text+'\n- flipping'
    if zoom_range != 0:
      aug_text = aug_text+'\n- random zoom magnification'
    if horizontal_shift != 0 or vertical_shift != 0:
      aug_text = aug_text+'\n- shifting'
    if shear_range != 0:
      aug_text = aug_text+'\n- image shearing'
  else:
    aug_text = 'No augmentation was used for training.'
  pdf.multi_cell(190, 5, txt=aug_text, align='L')
  pdf.set_font('Arial', size = 11, style = 'B')
  pdf.ln(1)
  pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)
  pdf.set_font('')
  pdf.set_font_size(10.)
  pdf.cell(200, 5, txt='The following parameters were used for training:')
  pdf.ln(1)
  html = """ 
  <table width=40% style="margin-left:0px;">
    <tr>
      <th width = 50% align="left">Parameter</th>
      <th width = 50% align="left">Value</th>
    </tr>
    <tr>
      <td width = 50%>number_of_epochs</td>
      <td width = 50%>{0}</td>
    </tr>
    <tr>
      <td width = 50%>patch_size</td>
      <td width = 50%>{1}</td>
    </tr>
    <tr>
      <td width = 50%>batch_size</td>
      <td width = 50%>{2}</td>
    </tr>
    <tr>
      <td width = 50%>number_of_steps</td>
      <td width = 50%>{3}</td>
    </tr>
    <tr>
      <td width = 50%>percentage_validation</td>
      <td width = 50%>{4}</td>
    </tr>
    <tr>
      <td width = 50%>initial_learning_rate</td>
      <td width = 50%>{5}</td>
    </tr>
    <tr>
      <td width = 50%>min_fraction</td>
      <td width = 50%>{6}</td>
    </tr>
    <tr>
      <td width = 50%>model_architecture</td>
      <td width = 50%>{7}</td>
    </tr>
    <tr>
      <td width = 50%>backbone</td>
      <td width = 50%>{8}</td>
    </tr>
  </table>
  """.format(number_of_epochs, str(patch_width)+'x'+str(patch_height), batch_size, number_of_steps, percentage_validation, initial_learning_rate, min_fraction, model_architecture, backbone)
  pdf.write_html(html)

  pdf.set_font("Arial", size = 11, style='B')
  pdf.ln(1)
  pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)
  pdf.set_font('')
  pdf.set_font('Arial', size = 10, style = 'B')
  pdf.cell(29, 5, txt= 'Training_source:', align = 'L', ln=0)
  pdf.set_font('')
  pdf.multi_cell(170, 5, txt = Training_source, align = 'L')
  pdf.set_font('')
  pdf.set_font('Arial', size = 10, style = 'B')
  pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)
  pdf.set_font('')
  pdf.multi_cell(170, 5, txt = Training_target, align = 'L')
  pdf.ln(1)
  pdf.set_font('')
  pdf.set_font('Arial', size = 10, style = 'B')
  pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)
  pdf.set_font('')
  pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')
  pdf.ln(1)
  pdf.cell(60, 5, txt = 'Example Training pair', ln=1)
  pdf.ln(1)
  exp_size = io.imread(tmp_folder+"/TrainingDataExample_Unet2D.png").shape
  pdf.image(tmp_folder+'/TrainingDataExample_Unet2D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))
  pdf.ln(1)
  ref_1 = 'References:\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. "Democratising deep learning for microscopy with ZeroCostDL4Mic." Nature Communications (2021).'
  pdf.multi_cell(190, 5, txt = ref_1, align='L')
  pdf.ln(3)
  pdf.output(model_path+'/'+model_name+'/'+model_name+'_training_report.pdf')

  print('------------------------------')
  print('PDF report exported in '+model_path+'/'+model_name+'/')

def qc_pdf_export():
  """Write the quality-control report PDF for the model in ``full_QC_model_path``.

  Reads the notebook globals ``QC_model_name`` and ``full_QC_model_path`` and these
  files in ``<full_QC_model_path>/Quality Control/``:
    - QC_example_data.png (required; its size also scales the loss-curve image),
    - lossCurvePlots.png (optional),
    - QC_metrics_<QC_model_name>.csv (header row, then one row per image; the
      first column is the image name and the last column the IoU).
  Writes ``<QC_model_name>_QC_report.pdf`` to the same folder.
  """
  class MyFPDF(FPDF, HTMLMixin):
    pass

  qc_dir = full_QC_model_path+'/Quality Control/'

  pdf = MyFPDF()
  pdf.add_page()
  pdf.set_right_margin(-1)
  pdf.set_font("Arial", size = 11, style='B')

  Network = 'Unet 2D'
  datetime_str = str(datetime.now())[0:10]

  Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\nDate: '+datetime_str
  pdf.multi_cell(180, 5, txt = Header, align = 'L')

  pdf.set_font('')
  pdf.set_font('Arial', size = 11, style = 'B')
  pdf.ln(2)
  pdf.cell(190, 5, txt = 'Loss curves', ln=1, align='L')
  pdf.ln(1)
  # Read the example figure once; its size is used for both images below.
  exp_size = io.imread(qc_dir+'QC_example_data.png').shape
  if os.path.exists(qc_dir+'lossCurvePlots.png'):
    pdf.image(qc_dir+'lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/12), h = round(exp_size[0]/3))
  else:
    pdf.set_font('')
    pdf.set_font('Arial', size=10)
    pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.',align='L')
  pdf.ln(2)
  pdf.set_font('')
  pdf.set_font('Arial', size = 10, style = 'B')
  pdf.ln(3)
  pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)
  pdf.ln(1)
  pdf.image(qc_dir+'QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))
  pdf.ln(1)
  pdf.set_font('')
  pdf.set_font('Arial', size = 11, style = 'B')
  pdf.ln(1)
  pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)
  pdf.set_font('')
  pdf.set_font_size(10.)

  pdf.ln(1)
  html = """
  <body>
  <font size="10" face="Courier New" >
  <table width=60% style="margin-left:0px;">"""
  with open(qc_dir+'QC_metrics_'+QC_model_name+'.csv', 'r', newline='') as csvfile:
    metrics = csv.reader(csvfile)
    header = next(metrics)
    html += """
    <tr>
    <th width = 33% align="center">{0}</th>
    <th width = 33% align="center">{1}</th>
    </tr>""".format(header[0], header[-1])
    for row in metrics:
      if not row:  # tolerate blank lines, e.g. a trailing newline
        continue
      html += """
        <tr>
          <td width = 33% align="center">{0}</td>
          <td width = 33% align="center">{1}</td>
        </tr>""".format(row[0], str(round(float(row[-1]),3)))
  # Close the tags in reverse order of opening.
  html += """</table></font></body>"""

  pdf.write_html(html)

  pdf.ln(1)
  pdf.set_font('')
  pdf.set_font_size(10.)
  ref_1 = 'References:\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. "Democratising deep learning for microscopy with ZeroCostDL4Mic." Nature Communications (2021).'
  pdf.multi_cell(190, 5, txt = ref_1, align='L')
  ref_2 = '- Unet: Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image segmentation." International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015.'
  pdf.multi_cell(190, 5, txt = ref_2, align='L')

  pdf.ln(3)
  reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'

  pdf.set_font('Arial', size = 11, style='B')
  pdf.multi_cell(190, 5, txt=reminder, align='C')

  pdf.output(qc_dir+QC_model_name+'_QC_report.pdf')

  print('------------------------------')
  print('QC PDF report exported as '+qc_dir+QC_model_name+'_QC_report.pdf')

# **2. Complete the Colab session**




---







## **2.1. Check for GPU access**
---

By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:

<font size = 4>Go to **Runtime -> Change the Runtime type**

<font size = 4>**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*

<font size = 4>**Accelerator: GPU** *(Graphics processing unit)*


In [None]:
#@markdown ##Run this cell to check if you have GPU access

# tf.test.gpu_device_name() returns an empty string when TensorFlow sees no GPU device.
if tf.test.gpu_device_name()=='':
  print('You do not have GPU access.') 
  print('Did you change your runtime ?') 
  print('If the runtime setting is correct then Google did not allocate a GPU for your session')
  print('Expect slow performance. To access GPU try reconnecting later')

else:
  print('You have GPU access')
  # Show the GPU status table reported by the NVIDIA driver.
  !nvidia-smi

# print the tensorflow version
print('Tensorflow version is ' + str(tf.__version__))


# **3. Select your parameters and paths**
---

## **3.1. Setting main training parameters**
---
<font size = 4> 

<font size = 5> **Paths for training data and models**

<font size = 4>**`Training_source`, `Training_target`:** These are the folders containing your source files (e.g. RGB images) and target files (semantic segmentation masks). Each mask should be a single-channel 2D image with values 0, 1, 2, ..., each value corresponding to one semantic class of the image content. The values should be consecutive from the lowest to the highest, without skipping any value (unless that class is absent from the image). Enter the paths to the source and target images for training. **These should be located in the same parent folder.**

<font size = 4>**`Validation_source`, `Validation_target`:** Similar to `training_source` and `training_target` except data in these folders will be used as a validation dataset during training. **These should be located in the same parent folder.**

<font size = 4>**`model_name`:** Use underscores rather than hyphens (e.g. `my_model`, not `my-model`). If you want to use a previously trained model, enter the name of that pretrained model (which should be contained in the trained_model folder after training).

<font size = 4>**`model_path`**: Enter the path of the folder where you want to save your model.

<font size = 4>**`labels`**: The number of different labels that the network needs to learn, which also includes the background. For example: to segment two different kind of objects in an image (cats and dogs), labels = 3 (2 labels for the two kinds and one more label for the background=0).


<font size = 5> **Select training parameters**

<font size = 4>**`number_of_epochs`**: Choose more epochs for larger training sets. Observing how much the loss reduces between epochs during training may help determine the optimal value. **Default: 20**

<font size = 5>**Advanced parameters - experienced users only**

<font size = 4>**`batch_size`**: This parameter sets the number of images that are loaded into the network per step. Smaller batch sizes may improve training performance slightly but may increase training time. If the notebook crashes while loading the dataset, the batch size may be too large; decrease it in that case. **Default: 16**

<font size = 4>**`number_of_steps`**: This number should be equivalent to the number of samples in the training set divided by the batch size, to ensure the training iterates through the entire training set. The default value is calculated to ensure this. This behaviour can also be obtained by setting it to 0. Other values can be used for testing.

<font size = 4>**`patch_width`** and **`patch_height`:** The notebook crops the data into patches of fixed size prior to training. The dimensions of the patches can be defined here. Patches larger than 512x512 should **NOT** be selected, for network stability.

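<font size = 4>**`min_fraction`:** Minimum fraction of pixels that must be foreground for a selected patch to be considered valid. It should be between 0 and 1. **Default value: 0.02** (2%). Note: the current patch-creation code (`create_patches3d`) does not apply this filter; all patches are kept.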


In [None]:
# ------------- Initial user input ------------

#@markdown ###Path to training images:

Training_source = 'C:\\Users\\valencig\\Desktop\\Example_Seg_Pipeline\\Data\\train_source' #@param {type:"string"}
Training_target = 'C:\\Users\\valencig\\Desktop\\Example_Seg_Pipeline\\Data\\train_target' #@param {type:"string"}

Validation_source = 'C:\\Users\\valencig\\Desktop\\Example_Seg_Pipeline\\Data\\validation_source' #@param {type:"string"}
Validation_target = 'C:\\Users\\valencig\\Desktop\\Example_Seg_Pipeline\\Data\\validation_target' #@param {type:"string"}


model_name = 'landcoverai-1' #@param {type:"string"}
model_path = 'C:\\Users\\valencig\\Desktop\\Example_Seg_Pipeline\\Models' #@param {type:"string"}


#@markdown ###Classes
labels =  5#@param {type:"number"}

#@markdown ###Training parameters:
number_of_epochs =  15#@param {type:"number"}

batch_size =  16#@param {type:"integer"}
number_of_steps =  0#@param {type:"number"}
 

#@markdown ###Image Properties
patch_width =  256#@param{type:"number"}
patch_height =  256#@param{type:"number"}
min_fraction = 0.02#@param{type:"number"}

# ------------- Initialising folder, variables and failsafes ------------
# Folder that will hold this model's weights and QC files.
full_model_path = os.path.join(model_path, model_name)
if os.path.exists(full_model_path):
  print(R+'!! WARNING: Folder already exists and will be overwritten !!'+W)


# create_patches3d returns the folders the image and mask patches were written
# to; those folders are the actual training / validation data from here on.
print('Training on patches of size (x,y): ('+str(patch_width)+','+str(patch_height)+')')

# Create patches
print('Creating patches...')
Train_source, Train_target = create_patches3d(Training_source, Training_target, patch_width, patch_height, min_fraction, training=True)
Val_source, Val_target = create_patches3d(Validation_source, Validation_target, patch_width, patch_height, min_fraction, training=False)

number_of_training_dataset = len(os.listdir(Train_source))
number_of_validation_dataset = len(os.listdir(Val_source))
print('Total number of valid training patches: '+str(number_of_training_dataset))
print('Total number of valid validation patches: '+str(number_of_validation_dataset))
# Fraction (between 0 and 1, not a percentage) of all patches used for validation.
percentage_validation = number_of_validation_dataset/(number_of_training_dataset+number_of_validation_dataset)
print('Validation percentage: '+str(percentage_validation))

# number_of_steps == 0 means "one full pass over the training patches per epoch".
if number_of_steps == 0:
  number_of_steps = ceil(number_of_training_dataset/batch_size)
print('Number of training steps: '+str(number_of_steps))

# Validation steps are rounded up, like the training steps, so every validation
# patch is seen each epoch. Truncating with int() yielded 0 steps (i.e. no
# validation at all) whenever there were fewer validation patches than batch_size.
validation_steps = ceil(number_of_validation_dataset/batch_size)
print('Number of validation steps: %d'%validation_steps)
# Defaults so later cells still work if the optional cells are skipped:
# pre-trained weights (section 3.3) and data augmentation (section 3.2) are off.
Use_pretrained_model = False
Use_Data_augmentation = False
# The ImageDataGenerator arguments (data_gen_args) are built in the data
# augmentation cell, not here.

# ------------- Display ------------
# Show one random training patch next to its mask (image and mask patches
# share the same filename in their respective folders).
random_choice = random.choice(os.listdir(Train_source))
x = io.imread(os.path.join(Train_source, random_choice))
print(x.dtype)

y = io.imread(os.path.join(Train_target, random_choice), as_gray=True)
f=plt.figure(figsize=(16,8))
plt.subplot(1,2,1)
plt.imshow(x)
plt.title('Training image patch')
plt.axis('off');

plt.subplot(1,2,2)
plt.imshow(y)
plt.title('Training mask patch')
plt.axis('off');

# os.path.join keeps the figure inside tmp_folder on every OS; the previous
# hard-coded '\\' separator only worked on Windows (on Colab/Linux it created a
# file literally named 'temp\TrainingDataExample_Unet2D.png' next to the folder).
# NOTE(review): if pdf_export reads this image via a '\\'-joined path, switch it
# to os.path.join as well.
plt.savefig(os.path.join(tmp_folder, 'TrainingDataExample_Unet2D.png'),bbox_inches='tight',pad_inches=0)



In [None]:
#@markdown ###Run to visualize images in the training dataset
# Pick a random training patch; its mask has the same filename in Train_target.
random_choice = random.choice(os.listdir(Train_source))
x = io.imread(os.path.join(Train_source, random_choice))
y = io.imread(os.path.join(Train_target, random_choice))

f=plt.figure(figsize=(16,8))
plt.subplot(1,2,1)
# Nearest-neighbour interpolation shows the pixels without display smoothing.
plt.imshow(x, interpolation='nearest')
plt.title('Training image patch')
plt.axis('off');

plt.subplot(1,2,2)
plt.imshow(y, interpolation='nearest')

plt.title('Training mask patch')
plt.axis('off');

# os.path.join instead of a hard-coded '\\' so the file lands inside tmp_folder
# on every OS.
plt.savefig(os.path.join(tmp_folder, 'TrainingDataExample_Unet3D.png'),bbox_inches='tight',pad_inches=0)


##**3.2. Data augmentation**

---

<font size = 4> Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if the dataset is large the values can be set to 0.

<font size = 4> The augmentation options below are to be used as follows:

* <font size = 4> **`shift`**: a translation of the image by a fraction of the image size (width or height), **default: 10%**
* **`zoom_range`**: Increasing or decreasing the field of view. E.g. 10% will result in a zoom range of (0.9 to 1.1), with pixels added or interpolated, depending on the transformation, **default: 10%**
* **`shear_range`**: Shear angle in counter-clockwise direction, **default: 10%**
* **`flip`**: creating a mirror image along specified axis (horizontal or vertical), **default: True**
* **`rotation_range`**: range of allowed rotation angles in degrees (from 0 to *value*), **default: 180**

**NOTE** To preserve the spatial properties of satellite imagery, `shear_range` and `zoom_range` should not be used (they are more practical in microscopy, where cells can be stretched and sheared). 

In [None]:
#@markdown ##**Augmentation options**
# Colab form: every `#@param` line is rendered as a widget.
Use_Data_augmentation = False #@param {type:"boolean"}
Use_Default_Augmentation_Parameters = True #@param {type:"boolean"}

if Use_Data_augmentation:
  if Use_Default_Augmentation_Parameters:
    # NOTE(review): the markdown above documents a 10% default shift, but the
    # vertical default here is 20; zoom and shear are also enabled by default
    # although the note above advises against them for satellite imagery.
    horizontal_shift =  10 
    vertical_shift =  20 
    zoom_range =  10
    shear_range =  10
    horizontal_flip = True
    vertical_flip = True
    rotation_range =  180
#@markdown ###If you are not using the default settings, please provide the values below:

#@markdown ###**Image shift, zoom, shear and flip (%)**
  else:
    horizontal_shift =  7 #@param {type:"slider", min:0, max:100, step:1}
    vertical_shift =  5 #@param {type:"slider", min:0, max:100, step:1}
    zoom_range =  6 #@param {type:"slider", min:0, max:100, step:1}
    shear_range =  3 #@param {type:"slider", min:0, max:100, step:1}
    horizontal_flip = True #@param {type:"boolean"}
    vertical_flip = True #@param {type:"boolean"}

#@markdown ###**Rotate image within angle range (degrees):**
    rotation_range =  90 #@param {type:"slider", min:0, max:180, step:1}

# Shift, zoom and shear are in percent (divided by 100 when data_gen_args is
# built); rotation_range is in degrees.

else:
  # Augmentation disabled: all transforms are neutral.
  horizontal_shift =  0 
  vertical_shift =  0 
  zoom_range =  0
  shear_range =  0
  horizontal_flip = False
  vertical_flip = False
  rotation_range =  0


# Keyword arguments for Keras' ImageDataGenerator. The form values for shift,
# zoom and shear are percentages, so they are rescaled to fractions here;
# rotation_range is passed through in degrees.
# NOTE(review): Keras documents shear_range as an angle in degrees, so dividing
# by 100 yields a very small shear — confirm this is intended.
data_gen_args = {
    'width_shift_range': horizontal_shift / 100.,
    'height_shift_range': vertical_shift / 100.,
    'rotation_range': rotation_range,
    'zoom_range': zoom_range / 100.,
    'shear_range': shear_range / 100.,
    'horizontal_flip': horizontal_flip,
    'vertical_flip': vertical_flip,
    'fill_mode': 'reflect',
}



# ------------- Display ------------
dir_augmented_data_imgs=tmp_folder+"/augment_img"
dir_augmented_data_masks=tmp_folder+"/augment_mask"
random_choice = random.choice(os.listdir(Train_source))
orig_img = io.imread(os.path.join(Train_source,random_choice))
orig_mask = io.imread(os.path.join(Train_target,random_choice))

augment_view = ImageDataGenerator(**data_gen_args)

# Each row of the 2x6 grid holds the original plus at most 5 augmented samples.
# Capping the listing keeps add_subplot inside its row even if the folders
# contain more files (e.g. left over from an earlier run).
max_augmented_shown = 5

if Use_Data_augmentation:
  print("Parameters enabled")
  print("Here is what a subset of your augmentations looks like:")
  save_augment(augment_view, orig_img, dir_augmented_data=dir_augmented_data_imgs)
  save_augment(augment_view, orig_mask, dir_augmented_data=dir_augmented_data_masks)

  fig = plt.figure(figsize=(15, 7))
  fig.subplots_adjust(hspace=0.0,wspace=0.1,left=0,right=1.1,bottom=0, top=0.8)

  # Row 1: original image, then augmented images in cells 2..6.
  ax = fig.add_subplot(2, 6, 1,xticks=[],yticks=[])
  ax.imshow(orig_img)
  ax.set_title('Original Image')
  img_names = sorted(os.listdir(dir_augmented_data_imgs))[:max_augmented_shown]
  for i, imgnm in enumerate(img_names, start=2):
    ax = fig.add_subplot(2, 6, i,xticks=[],yticks=[])
    img = io.imread(os.path.join(dir_augmented_data_imgs, imgnm))
    ax.imshow(img)

  # Row 2: original mask, then augmented masks in cells 8..12.
  ax = fig.add_subplot(2, 6, 7,xticks=[],yticks=[])
  ax.imshow(orig_mask)
  ax.set_title('Original Mask')
  mask_names = sorted(os.listdir(dir_augmented_data_masks))[:max_augmented_shown]
  for j, imgnm in enumerate(mask_names, start=2):
    ax = fig.add_subplot(2, 6, j+6,xticks=[],yticks=[])
    mask = io.imread(os.path.join(dir_augmented_data_masks, imgnm))
    ax.imshow(mask)
  plt.show()

else:
  print(bcolors.WARNING+"No augmentation will be used")

  


## **3.3. Using weights from a pre-trained model as initial weights**
---
<font size = 4>  Here, you can set the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a U-Net model**. 

<font size = 4> This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.

<font size = 4> In order to continue training from the point where the pre-trained model left off, it is advisable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 

In [None]:
# @markdown ##Loading weights from a pre-trained network

Use_pretrained_model = False #@param {type:"boolean"}
pretrained_model_choice = "Model_from_file" #@param ["Model_from_file"]
Weights_choice = "best" #@param ["last", "best"]


#@markdown ###If you chose "Model_from_file", please provide the path to the model folder:
pretrained_model_path = "F:\\Snyder_UNet_spring_2022\\UNet_data\\LoveDa\\models\\LoveDA_2_21_22" #@param {type:"string"}

# Fallback learning rate for when the pretrained model's QC CSV cannot provide
# one. `initial_learning_rate` is defined in the model configuration cell
# (section 4.1), which runs after this one, so on a fresh kernel it does not
# exist yet; use the default documented there (0.01) instead of a NameError.
try:
  default_learning_rate = initial_learning_rate
except NameError:
  default_learning_rate = 0.01

# Stays None unless a branch below resolves a weights file.
h5_file_path = None

# --------------------- Check if we load a previously trained model ------------------------
if Use_pretrained_model:

# --------------------- Load the model from the chosen path ------------------------
  if pretrained_model_choice == "Model_from_file":
    h5_file_path = os.path.join(pretrained_model_path, "weights_"+Weights_choice+".hdf5")

# --------------------- Download a published model (template) ------------------------
  # Template only: the download URLs are empty placeholders and this choice is
  # not offered by the pretrained_model_choice form field.
  if pretrained_model_choice == "Model_name":
    pretrained_model_name = "Model_name"
    pretrained_model_path = os.path.join(tmp_folder, pretrained_model_name)
    print("Downloading the UNET_Model_from_")
    if os.path.exists(pretrained_model_path):
      shutil.rmtree(pretrained_model_path)
    os.makedirs(pretrained_model_path)
    wget.download("", pretrained_model_path)
    wget.download("", pretrained_model_path)
    wget.download("", pretrained_model_path)
    wget.download("", pretrained_model_path)
    h5_file_path = os.path.join(pretrained_model_path, "weights_"+Weights_choice+".hdf5")

# --------------------- Add additional pre-trained models here ------------------------

# --------------------- Check the model exists ------------------------
  # Without a weights file, pretrained loading is disabled.
  if h5_file_path is None or not os.path.exists(h5_file_path):
    print(R+'WARNING: pretrained model does not exist')
    Use_pretrained_model = False

  # Otherwise try to recover the learning rate from the model's QC folder.
  else:
    qc_csv_path = os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')
    if os.path.exists(qc_csv_path):
      csvRead = pd.read_csv(qc_csv_path, sep=',')

      # NOTE: the CSV written by this notebook's training cell has no
      # "learning rate" column, so the default is used for models trained here.
      if "learning rate" in csvRead.columns: #Here we check that the learning rate column exists (compatibility with models trained in ZeroCostDL4Mic below 1.4)
        print("pretrained network learning rate found")
        # Learning rate of the final epoch
        lastLearningRate = csvRead["learning rate"].iloc[-1]
        # Learning rate of the epoch with the lowest validation loss (last one on ties)
        min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]
        bestLearningRate = min_val_loss['learning rate'].iloc[-1]

        if Weights_choice == "last":
          print('Last learning rate: '+str(lastLearningRate))

        if Weights_choice == "best":
          print('Learning rate of best validation loss: '+str(bestLearningRate))

      else: # the column does not exist, so the default learning rate is used instead
        bestLearningRate = default_learning_rate
        lastLearningRate = default_learning_rate
        print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)

    # Compatibility with models trained outside ZeroCostDL4Mic; the default learning rate is used
    else:
      print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(default_learning_rate)+' will be used instead'+ W)
      bestLearningRate = default_learning_rate
      lastLearningRate = default_learning_rate


# Display info about the pretrained model to be loaded (or not)
if Use_pretrained_model:
  print('Weights found in:')
  print(h5_file_path)
  print('will be loaded prior to training.')

else:
  print(R+'No pretrained network will be used.')




# **4. Train the network**
---
####**Troubleshooting:** If you receive a time-out or resource-exhausted error, try reducing the batch size of your training set. This reduces the amount of data loaded into the model at one point in time. 

## **4.1. Prepare the training data and model for training**
---
<font size = 4>Here, we create the model architecture used for training. All model architectures are from the [segmentation_models Github](https://github.com/qubvel/segmentation_models). 

**Note:** Patch size of Linknet and FPN must be divisible by 32, PSPNet must be divisible by 6* downsample factor. See: [[here]](https://segmentation-models.readthedocs.io/en/latest/api.html).

**Note:** For tasks other than semantic segmentation, please change loss type from categorical focal jaccard loss.

---


<font size = 4> **`model_architecture`**: [documentation](https://segmentation-models.readthedocs.io/en/latest/api.html#) **default: Unet**

<font size = 4> **`backbone`**: name of classification model used as feature extractor to build segmentation model **default: resnet__**

<font size = 4>**`initial_learning_rate`:** the initial value to be used as learning rate. Use a smaller learning rate for larger datasets to prevent overfitting. **Default value: 0.01**

<font size = 4> **`use_imagenet_weights`**: whether the backbone should be initialized with weights pretrained on the [imagenet dataset](https://www.image-net.org/)

<font size = 4> **`verbose`**: display model outline (text) **default:False**

<font size = 4> **`verbose_graph`**: display model outline (graph) **default:True**

<font size = 4> **`optimizer`**: method guiding the training of the model. More info can be found [[here]](https://keras.io/api/optimizers/) **default:Adam**

<font size = 4> **`use_decay`**: penalizes large weights to reduce overfitting. [[info]](https://towardsdatascience.com/this-thing-called-weight-decay-a7cd4bcfccab) **default: False**

<font size = 4> **`momentum`**: for the stochastic gradient descent (SGD) optimizer; float hyperparameter >= 0 that accelerates gradient descent in the relevant direction and dampens oscillations **default: 0.9**


In [None]:
#@markdown # 3D Model Configuration
# Colab form fields; keep each `#@param` assignment on one line.

model_architecture = 'Unet'#@param ["Unet","Linknet","PSPNet","FPN"]{type: "string"}
backbone = 'resnet50' #@param ['vgg16', 'vgg19', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152','seresnet18', 'seresnet34' ,'seresnet50' ,'seresnet101', 'seresnet152','resnext50', 'resnext101','seresnext50', 'seresnext101','senet154','densenet121','densenet169', 'densenet201','inceptionv3', 'inceptionresnetv2','mobilenet', 'mobilenetv2','efficientnetb0', 'efficientnetb1', 'efficientnetb2', 'efficientnetb3', 'efficientnetb4' ,'efficientnetb5' ,'efficientnetb6' ,'efficientnetb7']{type:'string'}
# May be replaced below by the learning rate recovered from a pretrained model.
initial_learning_rate = 0.01 #@param {type:"number"}
# (Spelled "imagnet" here; passed to net3d as imagenet_weights.)
use_imagnet_weights = True #@param [True, False]{type: "boolean"}
verbose = False #@param {type:"boolean"}
verbose_graph = False #@param {type:"boolean"}
#@markdown ### For other optimizers please alter the model code
optimizer = "Adam" #@param ["Adam", "SGD"]{type: "string"}
#@markdown ###If using weight decay
use_decay = False #@param [True, False]{type: "boolean"}
if use_decay:
  weight_decay = 0.01 #@param {type:"number"}
else:
  weight_decay = 0

#@markdown ###If using SGD
# NOTE(review): momentum and weight_decay are not among the arguments passed to
# net3d in this cell — confirm whether they are consumed anywhere.
momentum= 0.9 #@param {type:"number"}

# ------------------ Set the generators, model and logger ------------------
# The patch size used by the generators and by the model input comes from
# patch_width / patch_height set in section 3.1.

# preprocess_data (defined below) prepares one image/mask pair: min-max
# scaling + backbone preprocessing for the image, one-hot encoding for the mask.

from sklearn.preprocessing import MinMaxScaler
# Shared scaler; preprocess_data re-fits it on every image it processes.
scaler = MinMaxScaler()
from tensorflow.keras.utils import to_categorical
# Backbone-specific input preprocessing provided by segmentation_models.
preprocess_input = sm.get_preprocessing(backbone)
def preprocess_data(img, mask, num_class):
    """Prepare one image/mask pair for the network.

    The image is flattened to (pixels, channels) and min-max scaled column by
    column (i.e. per channel) with the shared ``scaler``, which is re-fitted on
    every call. It is then passed through the backbone-specific
    ``preprocess_input``. The mask is one-hot encoded into ``num_class`` classes.

    Args:
        img: image array whose last axis is the channel axis.
        mask: integer label array.
        num_class: number of segmentation classes.

    Returns:
        Tuple ``(processed_image, one_hot_mask)``.
    """
    channel_count = img.shape[-1]
    flat_pixels = img.reshape(-1, channel_count)
    scaled = scaler.fit_transform(flat_pixels).reshape(img.shape)
    processed = preprocess_input(scaled)
    one_hot = to_categorical(mask, num_class)
    return processed, one_hot

# Non-augmenting generators: always used for validation, and for training when
# data augmentation is disabled.
mask_datagen = ImageDataGenerator(dtype='uint8')
image_datagen = ImageDataGenerator(dtype='float64')

if Use_Data_augmentation:
  # Augmenting generators built from the section 3.2 settings.
  train_mask_datagen = ImageDataGenerator(**data_gen_args, dtype='uint8')
  train_image_datagen = ImageDataGenerator(**data_gen_args, dtype='float64')
  chosen_image_gen, chosen_mask_gen = train_image_datagen, train_mask_datagen
else:
  chosen_image_gen, chosen_mask_gen = image_datagen, mask_datagen

train_datagen = buildDoubleGenerator3d(chosen_image_gen,
                                       chosen_mask_gen,
                                       Train_source,
                                       Train_target,
                                       'training',
                                       batch_size,
                                       target_size=(patch_width, patch_height))

# NOTE(review): the validation generator is also built with subset 'training';
# confirm buildDoubleGenerator3d ignores the subset when no validation_split is set.
validation_datagen = buildDoubleGenerator3d(image_datagen,
                                            mask_datagen,
                                            Val_source,
                                            Val_target,
                                            'training',
                                            batch_size,
                                            target_size=(patch_width, patch_height))


# Make sure the model folder exists. Section 3.1 only checks for it (it never
# creates it), and plot_model below writes into it before any checkpoint
# has been saved.
os.makedirs(full_model_path, exist_ok=True)

# This ModelCheckpoint only saves the best model with respect to validation loss.
model_checkpoint = ModelCheckpoint(os.path.join(full_model_path, 'weights_best.hdf5'), monitor='val_loss',verbose=1, save_best_only=True)

# --------------------- Using pretrained model ------------------------
# Here we ensure that the learning rate is set correctly when using pre-trained models
if Use_pretrained_model:
  if Weights_choice == "last":
    initial_learning_rate = lastLearningRate

  if Weights_choice == "best":
    initial_learning_rate = bestLearningRate
else:
  h5_file_path = None
# --------------------- ---------------------- ------------------------

# --------------------- Reduce learning rate on plateau ------------------------
# Halves the learning rate after 2 epochs without val_loss improvement.
# NOTE(review): reduce_lr is currently not in the callback list passed to model.fit.
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, verbose=1, mode='auto',
                              patience=2, min_lr=0)
# --------------------- ---------------------- ------------------------

# NOTE(review): momentum and weight_decay from the form are not passed here —
# confirm whether net3d reads them some other way.
model = net3d(model_architecture,
               backbone,
               pretrained_weights = h5_file_path,
               input_size = (patch_width,patch_height,3),
               learning_rate = initial_learning_rate,
               labels = labels,
               optimizer = optimizer,
               verbose = verbose,
               imagenet_weights=use_imagnet_weights)
config_model= model.optimizer.get_config()
print(config_model)
if verbose:
  model.summary()
if verbose_graph:
  graph_file = os.path.join(full_model_path, model_architecture+'_'+backbone+'.png')
  plot_model(model, to_file=graph_file, show_shapes=True, show_layer_names=True, show_layer_activations=True)


## **4.2. Start Training**
---
<font size = 4>When playing the cell below you should see updates after each epoch (round). Network training can take some time.

<font size = 4>* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way to circumvent this is to save the parameters of the model after training and start training again from this point.

<font size = 4>Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder from Google Drive as all data can be erased at the next training if using the same folder.

In [None]:
#@markdown ## Rerun to visualize training patch shown in the sample image callback. After each epoch of training, this image will be segmented and logged to the console to provide a visual inspection of the model. 
#@markdown ## **!!! make it an interesting patch !!!**
random_choice = random.choice(os.listdir(Train_source))
x = io.imread(os.path.join(Train_source, random_choice))
y = io.imread(os.path.join(Train_target, random_choice))
# Model-ready version of the patch (scaled + backbone preprocessing, one-hot mask).
# NOTE(review): assumes the training generators apply preprocess_data too — confirm.
x2,y2 = preprocess_data(x,y, labels)
f=plt.figure(figsize=(16,8))
plt.subplot(1,2,1)
plt.imshow(x)
plt.title('Training image patch (not normalized)')
plt.axis('off');

plt.subplot(1,2,2)
plt.imshow(y)
plt.title('Training mask patch')
plt.axis('off');
# The callback feeds this batch to the model, so it must be the preprocessed
# image (x2) rather than the raw patch shown above; previously x2 was computed
# but never used.
sample_batch = np.expand_dims(x2, axis = 0)

sample_img = SampleImageCallback(model, sample_batch, os.path.join(model_path, model_name))

In [None]:
#@markdown ##Start training

start = time.time()
history = model.fit(train_datagen,
                    steps_per_epoch = number_of_steps,
                    epochs = number_of_epochs,
                    callbacks=[model_checkpoint, sample_img], # reduce_lr removed from the callbacks
                    validation_data = validation_datagen,
                    validation_steps = validation_steps,
                    shuffle=True,
                    verbose=1)
# Save the last model
model.save(os.path.join(full_model_path, 'weights_last.hdf5'))


# convert the history.history dict to a pandas DataFrame:
lossData = pd.DataFrame(history.history)

# Displaying the time elapsed for training
print("------------------------------------------")
dt = time.time() - start
mins, sec = divmod(dt, 60)
hour, mins = divmod(mins, 60)
print("Time elapsed:", hour, "hour(s)", mins,"min(s)",round(sec),"sec(s)")
print("------------------------------------------")

# The training evaluation CSV is saved (overwrites the file if needed).
# os.path.join is used so the CSV lands inside 'Quality Control' on every OS;
# a hard-coded '\\' only worked on Windows, while the readers of this file
# (sections 3.3 and 5.1) join the path with os.path.join.
qc_folder = os.path.join(full_model_path, 'Quality Control')
if os.path.isdir(qc_folder):
  print('QC Folder Exists')
else:
  print('Making QC Folder')
  os.makedirs(qc_folder)
lossDataCSVpath = os.path.join(qc_folder, 'training_evaluation.csv')
with open(lossDataCSVpath, 'w',newline='') as f:
  writer = csv.writer(f)
  writer.writerow(['loss','val_loss', 'iou_score', 'val_iou_score'])
  for i in range(len(history.history['loss'])):
    writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['iou_score'][i], history.history['val_iou_score'][i]])

#Create a pdf document with training summary
pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)


# **5. Evaluate your model**
---

<font size = 4>This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 

<font size = 4>**We highly recommend to perform quality control on all newly trained models.**



In [None]:
#@markdown ###Do you want to assess the model you just trained ?

Use_the_current_trained_model = False #@param {type:"boolean"}

#@markdown ###If not, please provide the path to the model folder:

QC_model_folder = "C:\\Users\\valencig\\Desktop\\Example_Seg_Pipeline\\Models\\landcoverai-1" #@param {type:"string"}

# Decide which model folder is evaluated. When an existing folder is chosen,
# model_name / model_path are repointed at it because section 6 reads them.
if Use_the_current_trained_model:
  print("Using current trained network")
  QC_model_name, QC_model_path = model_name, model_path
else:
  QC_model_name = os.path.basename(QC_model_folder)
  QC_model_path = os.path.dirname(QC_model_folder)
  model_name, model_path = QC_model_name, QC_model_path

full_QC_model_path = os.path.join(QC_model_path, QC_model_name)
best_weights_present = os.path.exists(os.path.join(full_QC_model_path, 'weights_best.hdf5'))
if not best_weights_present:
  print(R+'!! WARNING: The chosen model does not exist !!'+W)
  print('Please make sure you provide a valid model path and model name before proceeding further.')
else:
  print("The "+QC_model_name+" network will be evaluated")



## **5.1. Inspection of the loss function**
---

<font size = 4>First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*

<font size = 4>**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.

<font size = 4>**Validation loss** describes the same error value between the model's prediction on a validation image and its target.

<font size = 4>During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.

<font size = 4>Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again and the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words, the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased.

In [None]:
#@markdown ##Play the cell to show a plot of training errors vs. epoch number

lossDataFromCSV = []
vallossDataFromCSV = []

with open(os.path.join(full_QC_model_path, 'Quality Control', 'training_evaluation.csv'),'r', newline='') as csvfile:
    csvRead = csv.reader(csvfile, delimiter=',')
    next(csvRead)  # skip the header row
    for row in csvRead:
      # Skip empty rows instead of unconditionally discarding the row after the
      # header. The training cell writes the CSV with newline='', so it has no
      # blank lines and the old second next() dropped the first epoch. Files
      # that do contain blank lines are still handled.
      if not row:
        continue
      lossDataFromCSV.append(float(row[0]))
      vallossDataFromCSV.append(float(row[1]))

epochNumber = range(len(lossDataFromCSV))

plt.figure(figsize=(15,10))

plt.subplot(2,1,1)
plt.plot(epochNumber,lossDataFromCSV, label='Training loss')
plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')
plt.title('Training loss and validation loss vs. epoch number (linear scale)')
plt.ylabel('Loss')
plt.xlabel('Epoch number')
plt.legend()

plt.subplot(2,1,2)
plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')
plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')
plt.title('Training loss and validation loss vs. epoch number (log scale)')
plt.ylabel('Loss')
plt.xlabel('Epoch number')
plt.legend()
plt.savefig(os.path.join(full_QC_model_path, 'Quality Control', 'lossCurvePlots.png'),bbox_inches='tight',pad_inches=0)
plt.show()



## **5.2. Error mapping and quality metrics estimation**
---
<font size = 4>This section will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_QC_folder. The result for one of the images will also be displayed.

<font size = 4>The **Intersection over Union** metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** 

<font size = 4> To increase accuracy, smooth blending can be implemented according to methods described [here](https://github.com/Vooban/Smoothly-Blend-Image-Patches). This increases segmentation accuracy but **will be more time intensive**. This method also **requires square patches**.

<font size = 4> The average IOU score is determined by weighting each class IOU score by the proportion of area it occupies in the image.

In [None]:
# MIT License
# Copyright (c) 2017 Vooban Inc.
# Coded by: Guillaume Chevalier
# Source to original code and license:
#     https://github.com/Vooban/Smoothly-Blend-Image-Patches
#     https://github.com/Vooban/Smoothly-Blend-Image-Patches/blob/master/LICENSE


"""Do smooth predictions on an image from tiled prediction patches."""

'''import pkg_resources
pkg_resources.require("numpy==`1.12.0")  # modified to use specific numpy
pkg_resources.require("scipy==`0.19.1")  # modified to use specific numpy'''
import numpy as np
import scipy.signal
from tqdm import tqdm

import gc
import gc

#@markdown ## Code for smoothing (check to see progress)
plot_progress = False #@param {type:"boolean"}
if plot_progress:
    import matplotlib.pyplot as plt
    

def _spline_window(window_size, power=2):
    """
    Squared spline (power=2) window function:
    https://www.wolframalpha.com/input/?i=y%3Dx**2,+y%3D-(x-2)**2+%2B2,+y%3D(x-4)**2,+from+y+%3D+0+to+2
    """
    intersection = int(window_size/4)
    wind_outer = (abs(2*(scipy.signal.triang(window_size))) ** power)/2
    wind_outer[intersection:-intersection] = 0

    wind_inner = 1 - (abs(2*(scipy.signal.triang(window_size) - 1)) ** power)/2
    wind_inner[:intersection] = 0
    wind_inner[-intersection:] = 0

    wind = wind_inner + wind_outer
    wind = wind / np.average(wind)
    return wind


# Memo of already-built 2-D windows, keyed by "<window_size>_<power>".
cached_2d_windows = {}
def _window_2D(window_size, power=2):
    """
    Return a memoised 2-D smoothing window of shape (window_size, window_size, 1).

    The 1-D spline window is reshaped to a (window_size, 1, 1) column and
    multiplied by its (1, window_size, 1) transpose, i.e. the outer product of
    the 1-D window with itself. Could be generalized to more dimensions.
    """
    key = "{}_{}".format(window_size, power)
    cached = cached_2d_windows.get(key)
    if cached is not None:
        return cached

    # SREENI: expand along axes 1, 1 (originally 3, 3)
    column = _spline_window(window_size, power)[:, np.newaxis, np.newaxis]
    wind = column * column.transpose(1, 0, 2)
    if plot_progress:
        # For demo purpose, let's look once at the window:
        plt.imshow(wind[:, :, 0], cmap="viridis")
        plt.title("2D Windowing Function for a Smooth Blending of "
                  "Overlapping Patches")
        plt.show()
    cached_2d_windows[key] = wind
    return wind


def _pad_img(img, window_size, subdivisions):
    """
    Reflect-pad the two spatial axes of `img` so tiled patches also cover its borders.

    Each side of the first two axes grows by
    round(window_size * (1 - 1/subdivisions)) pixels; the last (channel)
    axis is not padded. `_unpad_img` removes the same border.
    Image is an np array of shape (x, y, nb_channels).
    """
    border = int(round(window_size * (1 - 1.0/subdivisions)))
    padded = np.pad(img,
                    pad_width=((border, border), (border, border), (0, 0)),
                    mode='reflect')

    if not plot_progress:
        return padded

    # For demo purpose, let's look once at the window:
    plt.imshow(padded)
    plt.title("Padded Image for Using Tiled Prediction Patches\n"
              "(notice the reflection effect on the padded borders)")
    plt.show()
    return padded


def _unpad_img(padded_img, window_size, subdivisions):
    """
    Undo what's done in the `_pad_img` function.
    Image is an np array of shape (x, y, nb_channels).
    """
    aug = int(round(window_size * (1 - 1.0/subdivisions)))
    ret = padded_img[
        aug:-aug,
        aug:-aug,
        :
    ]
    # gc.collect()
    return ret


def _rotate_mirror_do(im):
    """
    Duplicate an np array (image) of shape (x, y, nb_channels) 8 times, in order
    to have all the possible rotations and mirrors of that image that fits the
    possible 90 degrees rotations.
    It is the D_4 (D4) Dihedral group:
    https://en.wikipedia.org/wiki/Dihedral_group
    """
    mirrs = []
    mirrs.append(np.array(im))
    mirrs.append(np.rot90(np.array(im), axes=(0, 1), k=1))
    mirrs.append(np.rot90(np.array(im), axes=(0, 1), k=2))
    mirrs.append(np.rot90(np.array(im), axes=(0, 1), k=3))
    im = np.array(im)[:, ::-1]
    mirrs.append(np.array(im))
    mirrs.append(np.rot90(np.array(im), axes=(0, 1), k=1))
    mirrs.append(np.rot90(np.array(im), axes=(0, 1), k=2))
    mirrs.append(np.rot90(np.array(im), axes=(0, 1), k=3))
    return mirrs


def _rotate_mirror_undo(im_mirrs):
    """
    merges a list of 8 np arrays (images) of shape (x, y, nb_channels) generated
    from the `_rotate_mirror_do` function. Each images might have changed and
    merging them implies to rotated them back in order and average things out.
    It is the D_4 (D4) Dihedral group:
    https://en.wikipedia.org/wiki/Dihedral_group
    """
    origs = []
    origs.append(np.array(im_mirrs[0]))
    origs.append(np.rot90(np.array(im_mirrs[1]), axes=(0, 1), k=3))
    origs.append(np.rot90(np.array(im_mirrs[2]), axes=(0, 1), k=2))
    origs.append(np.rot90(np.array(im_mirrs[3]), axes=(0, 1), k=1))
    origs.append(np.array(im_mirrs[4])[:, ::-1])
    origs.append(np.rot90(np.array(im_mirrs[5]), axes=(0, 1), k=3)[:, ::-1])
    origs.append(np.rot90(np.array(im_mirrs[6]), axes=(0, 1), k=2)[:, ::-1])
    origs.append(np.rot90(np.array(im_mirrs[7]), axes=(0, 1), k=1)[:, ::-1])
    return np.mean(origs, axis=0)


def _windowed_subdivs(padded_img, window_size, subdivisions, nb_classes, pred_func):
    """
    Cut `padded_img` into overlapping square tiles, predict them in one batch
    with `pred_func`, and weight each prediction by the 2-D spline window.

    Tiles are `window_size` wide and start every `window_size / subdivisions`
    pixels along both spatial axes. `pred_func` receives a 4-D batch
    (nb_tiles, window_size, window_size, nb_input_channels) and is expected to
    return one (window_size, window_size, nb_classes) prediction per tile.

    Returns:
        5D numpy array of shape = (
            nb_patches_along_X,
            nb_patches_along_Y,
            patches_resolution_along_X,
            patches_resolution_along_Y,
            nb_output_channels
        )
    Note:
        patches_resolution_along_X == patches_resolution_along_Y == window_size
    """
    spline_2d = _window_2D(window_size=window_size, power=2)

    stride = int(window_size/subdivisions)
    # SREENI: the second axis must iterate over its own length (bug in original code).
    x_starts = range(0, padded_img.shape[0] - window_size + 1, stride)
    y_starts = range(0, padded_img.shape[1] - window_size + 1, stride)

    # `gc.collect()` between steps keeps peak RAM down on large images; the
    # calls can be dropped for speed when enough memory is available.
    gc.collect()
    tiles = np.array([
        [padded_img[x:x + window_size, y:y + window_size, :] for y in y_starts]
        for x in x_starts
    ])
    gc.collect()
    n_rows, n_cols, tile_h, tile_w, n_chan = tiles.shape
    tiles = tiles.reshape(n_rows * n_cols, tile_h, tile_w, n_chan)
    gc.collect()

    tiles = pred_func(tiles)
    gc.collect()
    tiles = np.array([tile * spline_2d for tile in tiles])
    gc.collect()

    # Back to a (rows, cols, h, w, classes) grid of windowed predictions.
    tiles = tiles.reshape(n_rows, n_cols, tile_h, tile_w, nb_classes)
    gc.collect()
    return tiles


def _recreate_from_subdivs(subdivs, window_size, subdivisions, padded_out_shape):
    """
    Merge tiled overlapping patches smoothly.
    """
    step = int(window_size/subdivisions)
    padx_len = padded_out_shape[0]
    pady_len = padded_out_shape[1]

    y = np.zeros(padded_out_shape)

    a = 0
    for i in range(0, padx_len-window_size+1, step):
        b = 0
        for j in range(0, pady_len-window_size+1, step):                #SREENI: Changed padx to pady (Bug in original code)
            windowed_patch = subdivs[a, b]
            y[i:i+window_size, j:j+window_size] = y[i:i+window_size, j:j+window_size] + windowed_patch
            b += 1
        a += 1
    return y / (subdivisions ** 2)


def predict_img_with_smooth_windowing(input_img, window_size, subdivisions, nb_classes, pred_func):
    """
    Apply the `pred_func` function to square patches of the image, and overlap
    the predictions to merge them smoothly.
    See 6th, 7th and 8th idea here:
    http://blog.kaggle.com/2017/05/09/dstl-satellite-imagery-competition-3rd-place-winners-interview-vladimir-sergey/

    Pipeline: reflect-pad the image, build its 8 rotated/mirrored variants,
    predict and spline-blend overlapping tiles for each variant, map the
    variants back and average them, then strip the padding.

    Returns:
        Array of shape (input_img.shape[0], input_img.shape[1], nb_classes)
        holding the blended per-class predictions.

    Design notes (from the original implementation):
      * Merging `_windowed_subdivs` and `_recreate_from_subdivs` into a single
        in-place loop would avoid the temporary 5D array and save memory.
      * Impure window functions that do not tile well could be supported by
        accumulating their weights in a second matrix and dividing by it.
      * A network whose output window is smaller than its input window could
        be handled with a full-size window that is non-zero only in the centre,
        possibly with `subdivisions` raised from 2 to 4.
    """
    padded = _pad_img(input_img, window_size, subdivisions)

    per_variant = []
    for variant in _rotate_mirror_do(padded):
        # For every rotation/mirror: predict overlapping tiles and blend them.
        patches = _windowed_subdivs(variant, window_size, subdivisions, nb_classes, pred_func)
        blended = _recreate_from_subdivs(
            patches, window_size, subdivisions,
            padded_out_shape=list(variant.shape[:-1]) + [nb_classes])
        per_variant.append(blended)

    # Merge after rotations, then drop the reflected border.
    merged = _rotate_mirror_undo(per_variant)
    prd = _unpad_img(merged, window_size, subdivisions)
    prd = prd[:input_img.shape[0], :input_img.shape[1], :]

    if plot_progress:
        plt.imshow(np.argmax(prd, axis=2))
        plt.title("Smoothly Merged Patches that were Tiled Tighter")
        plt.show()
    return prd

In [None]:
# ------------- User input ------------
#@markdown #Choose the folders that contain your Quality Control dataset
Source_QC_folder = "C:\\Users\\valencig\\Desktop\\Example_Seg_Pipeline\\Data\\test_source" #@param{type:"string"}
Target_QC_folder = "C:\\Users\\valencig\\Desktop\\Example_Seg_Pipeline\\Data\\test_target" #@param{type:"string"}

#@markdown ## File extension of quality control images
file_ext_source =  ".tif"#@param{type:"string"}
smoothing = False #@param{type:"boolean"}
#@markdown ## Image properties
labels =  5#@param{type:"number"}
patch_width =  256#@param{type:"number"}
patch_height =  256#@param{type:"number"}
weight_decay=0
# ------------- Initialise folders ------------
# Predictions for this QC run go into a fresh sub-folder of 'Quality Control'
# (a separate one when smooth blending is on); old results are deleted first.
prediction_QC_folder = os.path.join(full_QC_model_path, 'Quality Control',
                                    'Prediction_smooth' if smoothing else 'Prediction')
if os.path.exists(prediction_QC_folder):
  shutil.rmtree(prediction_QC_folder)

os.makedirs(prediction_QC_folder)


# ------------- Prepare the model and run predictions ------------

# Load the model
# NOTE(review): when Use_the_current_trained_model is True this cell does not
# define `model` or `backbone`; both must still exist from the training section.
if not Use_the_current_trained_model:
  model_architecture = 'Unet'#@param ["Unet","Linknet","PSPNet","FPN"]{type: "string"}
  backbone = 'resnet50' #@param ['vgg16', 'vgg19', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152','seresnet18', 'seresnet34' ,'seresnet50' ,'seresnet101', 'seresnet152','resnext50', 'resnext101','seresnext50', 'seresnext101','senet154','densenet121','densenet169', 'densenet201','inceptionv3', 'inceptionresnetv2','mobilenet', 'mobilenetv2','efficientnetb0', 'efficientnetb1', 'efficientnetb2', 'efficientnetb3', 'efficientnetb4' ,'efficientnetb5' ,'efficientnetb6' ,'efficientnetb7']{type:'string'}
  model = net3d(model_architecture,
               backbone,
               pretrained_weights = os.path.join(full_QC_model_path, 'weights_best.hdf5'),
               input_size = (patch_width,patch_height,3), 
               learning_rate = 0.01, 
               labels = labels,
               optimizer = 'Adam',
               verbose = False,
               imagenet_weights=False)

# Input preprocessing matching the chosen backbone (segmentation_models package)
from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler()
from tensorflow.keras.utils import to_categorical
preprocess_input = sm.get_preprocessing(backbone)
def preprocess_data(img):
    """Min-max scale each channel of `img` (last axis) over the whole image, then apply the backbone preprocessing."""
    img = scaler.fit_transform(img.reshape(-1, img.shape[-1])).reshape(img.shape)
    img = preprocess_input(img)  #Preprocess based on the pretrained backbone...
    return img

Input_size = [patch_height, patch_width]
print('Model input size: '+str(Input_size[0])+'x'+str(Input_size[1]))
# Create a list of sources: only files are predicted (sub-folders are skipped,
# as in the IoU cell), sorted so the run is reproducible.
source_dir_list = sorted(name for name in os.listdir(Source_QC_folder)
                         if not os.path.isdir(os.path.join(Source_QC_folder, name)))
number_of_dataset = len(source_dir_list)
print('Number of dataset found in the folder: '+str(number_of_dataset))


### Save Predictions To List ###
predictions = []
for source_name in tqdm(source_dir_list):
  predictions.append(predict_as_tiles(os.path.join(Source_QC_folder, source_name), model, smoothing=smoothing))

print('Saving Predictions')
# Save the predicted masks; `predictions` is in the same order as `source_dir_list`.
saveResult(prediction_QC_folder, predictions, source_dir_list, prefix=prediction_prefix, file_ext=file_ext_source)
 

In [None]:
#@markdown ##Determine IOU Scores
from fpdf import FPDF,HTMLMixin

#-----------------------------Calculate Metrics----------------------------------------#

def _label_iou(ground_truth, prediction, label):
    """
    Intersection over Union of one label between two label images.

    Returns:
        (iou, gt_pixels): `iou` is NaN when the label is absent from both
        images (empty union); `gt_pixels` is the number of pixels equal to
        `label` in `ground_truth`, used as the label's weight.
    """
    gt_mask = ground_truth == label
    pred_mask = prediction == label
    union = np.sum(np.logical_or(gt_mask, pred_mask))
    gt_pixels = np.sum(gt_mask)
    if union == 0:
        # Undefined IoU; checked explicitly to avoid a 0/0 RuntimeWarning.
        return float('nan'), gt_pixels
    return np.sum(np.logical_and(gt_mask, pred_mask)) / union, gt_pixels

# Here we start testing the differences between GT and predicted masks
qc_metrics_path = QC_model_path+'/'+QC_model_name+'/Quality Control/QC_metrics_'+QC_model_name+".csv"
with open(qc_metrics_path, "w", newline='') as csv_file:
    writer = csv.writer(csv_file, delimiter=",")
    stats_columns = ["image"]

    for l in range(labels):
        stats_columns.append("Prediction v. GT IoU label = {}".format(l))
    stats_columns.append("Prediction v. GT averaged IoU")
    writer.writerow(stats_columns)
    # Initialise the lists 
    filename_list = []
    iou_score_list = []
    print('Running QC on files')
    for filename in tqdm(sorted(os.listdir(Source_QC_folder))):
        if os.path.isdir(os.path.join(Source_QC_folder, filename)):
            continue
        test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, filename), as_gray=True)
        test_prediction = io.imread(os.path.join(prediction_QC_folder, prediction_prefix + filename), as_gray=True)
        n_pix = test_ground_truth_image.shape[0]*test_ground_truth_image.shape[1]
        iou_labels = [filename]
        # Average IoU weighted by each label's share of ground-truth pixels;
        # labels with undefined (NaN) IoU are skipped.
        iou_score = 0.
        for l in range(labels):
            iou, pix_gt = _label_iou(test_ground_truth_image, test_prediction, l)
            iou_labels.append(str(iou))
            if not math.isnan(iou):
                iou_score += (pix_gt/n_pix) * iou
        filename_list.append(filename)
        iou_score_list.append(iou_score)
        iou_labels.append(str(iou_score))
        writer.writerow(iou_labels)

## Create a display of the results

# Table with metrics as dataframe output
pdResults = pd.DataFrame(index = filename_list)
pdResults["IoU"] = iou_score_list

# ------------- For display ------------
print('--------------------------------------------------------------')
@interact
def show_QC_results(file=filename_list):
  """Plot input, ground truth, prediction and overlay (titled with the weighted IoU) for one scored QC image and save the figure for the QC report."""
  plt.figure(figsize=(25,5))
  #Input
  plt.subplot(1,4,1)
  plt.axis('off')
  plt.imshow(io.imread(os.path.join(Source_QC_folder, file)), aspect='equal', interpolation='nearest')
  plt.title('Input')

  #Ground-truth
  plt.subplot(1,4,2)
  plt.axis('off')
  test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, file),as_gray=True)
  plt.imshow(test_ground_truth_image, aspect='equal', cmap='Greens')
  plt.title('Ground Truth')

  #Prediction
  plt.subplot(1,4,3)
  plt.axis('off')
  test_prediction = plt.imread(os.path.join(prediction_QC_folder, prediction_prefix+file))
  plt.imshow(test_prediction, aspect='equal', cmap='Purples')
  plt.title('Prediction')

  #Overlay
  plt.subplot(1,4,4)
  plt.axis('off')
  plt.imshow(test_ground_truth_image, cmap='Greens')
  plt.imshow(test_prediction, alpha=0.5, cmap='Purples')
  metrics_title = 'Overlay (IoU: ' + str(round(pdResults.loc[file, "IoU"],3)) + ')'
  plt.title(metrics_title)
  plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)

qc_pdf_export()

pdResults.head()

# **6. Using the trained model**

---
<font size = 4>In this section, unseen data is processed using the trained model (from section 4). First, your unseen images are loaded and prepared for prediction. Then the trained model is applied to them, and the resulting predictions are saved to your **Results_folder**.

## **6.1 Generate prediction(s) from unseen dataset**
---

<font size = 4>The current trained model (from section 4.1) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the path to that model's weights file in **Prediction_model_file**. Predicted output images are saved in your **Results_folder** folder.

<font size = 4>**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.

<font size = 4>**`Results_folder`:** This folder will contain the predicted output images.

<font size = 4> Once the predictions are complete, run the following cell to pick any source image from a drop-down menu and display it beside its predicted mask for visual inspection.

<font size = 4> **Troubleshooting:** If there is a low contrast image warning when saving the images, this may be due to overfitting of the model to the data. It may result in images containing only a single colour. Train the network again with different network hyperparameters.

In [None]:


# ------------- Initial user input ------------
#@markdown ###Provide the path to your dataset and to the folder where the predicted masks will be saved (Result folder), then play the cell to predict the output on your unseen images and store it.
# @markdown  `file_ext_save` = Extension you want results saved as (ex: .tif)
Data_folder = '/content/drive/MyDrive/Example_Seg_Pipeline - 3D/Data/test_source' #@param {type:"string"}
Results_folder = '/content/temp' #@param {type:"string"}
file_ext_save =  ".tif"#@param{type:"string"}
#@markdown ## Image properties
labels =  5#@param{type:"number"}
patch_width =  256#@param{type:"number"}
patch_height =  256#@param{type:"number"}
#@markdown ###Do you want to use the current trained model?
Use_the_current_trained_model = False #@param {type:"boolean"}
model_architecture = 'Unet'#@param ["Unet","Linknet","PSPNet","FPN"]{type: "string"}
backbone = 'resnet50' #@param ['vgg16', 'vgg19', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152','seresnet18', 'seresnet34' ,'seresnet50' ,'seresnet101', 'seresnet152','resnext50', 'resnext101','seresnext50', 'seresnext101','senet154','densenet121','densenet169', 'densenet201','inceptionv3', 'inceptionresnetv2','mobilenet', 'mobilenetv2','efficientnetb0', 'efficientnetb1', 'efficientnetb2', 'efficientnetb3', 'efficientnetb4' ,'efficientnetb5' ,'efficientnetb6' ,'efficientnetb7']{type:'string'}
weight_decay=0
#@markdown ###Do you want to use smooth blending?
smoothing = False #@param {type:"boolean"}

#@markdown ###If not, please provide the path to the model file:

Prediction_model_file = "/content/drive/MyDrive/Example_Seg_Pipeline - 3D/Models/landcoverai-1/weights_best.hdf5" #@param {type:"string"}

Input_size = [patch_height, patch_width]

# ------------- Prepare the model ------------
if Use_the_current_trained_model:
  # NOTE(review): relies on `model` still being defined by the training section.
  print("Using current trained network")
else:
  # Fail early with a clear message instead of building a model without weights.
  if not os.path.isfile(Prediction_model_file):
    raise FileNotFoundError('The chosen model file does not exist: ' + Prediction_model_file)
  model = net3d(model_architecture,
               backbone,
               pretrained_weights = Prediction_model_file,
               input_size = (patch_width,patch_height,3), 
               learning_rate = 0.01, 
               labels = labels,
               optimizer = 'Adam',
               verbose = False,
               imagenet_weights=False)

# Input preprocessing matching the chosen backbone (segmentation_models package)
from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler()
from tensorflow.keras.utils import to_categorical
preprocess_input = sm.get_preprocessing(backbone)
def preprocess_data(img):
    """Min-max scale each channel of `img` (last axis) over the whole image, then apply the backbone preprocessing."""
    img = scaler.fit_transform(img.reshape(-1, img.shape[-1])).reshape(img.shape)
    img = preprocess_input(img)  #Preprocess based on the pretrained backbone...
    return img

# Create a list of sources: only files are predicted (sub-folders are skipped),
# sorted so the run is reproducible.
source_dir_list = sorted(name for name in os.listdir(Data_folder)
                         if not os.path.isdir(os.path.join(Data_folder, name)))
number_of_dataset = len(source_dir_list)
print('Number of dataset found in the folder: '+str(number_of_dataset))

# ------------- Run predictions and save them ------------
os.makedirs(Results_folder, exist_ok=True)
predictions = []
for source_name in tqdm(source_dir_list):
  predictions.append(predict_as_tiles(os.path.join(Data_folder, source_name), model, smoothing=smoothing))
# `predictions` is in the same order as `source_dir_list`.
saveResult(Results_folder, predictions, source_dir_list, prefix=prediction_prefix, file_ext=file_ext_save)

In [None]:
#@markdown ## Run cell to visualize predicted patches
# ------------- For display ------------

from pathlib import Path
def show_prediction_mask(file=sorted(name for name in os.listdir(Data_folder)
                                     if not os.path.isdir(os.path.join(Data_folder, name)))):
  """
  Plot one source image from Data_folder next to its saved predicted mask.

  The mask is read from Results_folder as prediction_prefix + <source stem> + file_ext_save.
  Only files (not sub-folders) of Data_folder are offered in the drop-down.
  """
  plt.figure(figsize=(10,6))
  # Source image
  plt.subplot(1,2,1)
  plt.axis('off')
  img_Source = plt.imread(os.path.join(Data_folder, file))
  plt.imshow(img_Source, cmap='gray')
  plt.title('Source image',fontsize=15)
  # Prediction
  plt.subplot(1,2,2)
  plt.axis('off')
  img_Prediction = plt.imread(os.path.join(Results_folder, prediction_prefix+Path(file).stem+file_ext_save))
  plt.imshow(img_Prediction)
  plt.title('Prediction',fontsize=15)

interact(show_prediction_mask);