  Author: Ankit Kariryaa, University of Bremen
  
  Modified by Xuehui Pi and Qiuqi Luo

### Getting started
Before running this notebook, define the paths to the dataset and to the trained models in the `notebooks/config/UNetTraining.py` file.  

In [None]:
import os

# Cap the MKL, numexpr and OpenMP thread pools at 16 threads each. This is done in the
# first cell, before the next cell imports numpy/tensorflow, since such libraries
# typically read these variables when they initialise.
os.environ.update(dict.fromkeys(("MKL_NUM_THREADS", "NUMEXPR_NUM_THREADS", "OMP_NUM_THREADS"), '16'))
print(os.environ.get('OMP_NUM_THREADS'))

In [None]:
import tensorflow as tf
import numpy as np
from PIL import Image
import rasterio
import imgaug as ia
from imgaug import augmenters as iaa
from tensorflow.keras.preprocessing.image import ImageDataGenerator

import imageio
import os

import time
import rasterio.warp             # Reproject raster samples
from functools import reduce
from tensorflow.keras.models import load_model

from core.UNet import UNet  #
from core.losses import tversky, accuracy, dice_coef, dice_loss, mIoU,specificity, sensitivity
from core.optimizers import adaDelta, adagrad, adam, nadam
from core.frame_info import FrameInfo
from core.dataset_generator import DataGenerator
from core.split_frames import split_dataset1,split_dataset2,split_dataset3,split_dataset4#,split_dataset5
from core.visualize import display_images

import json
from sklearn.model_selection import train_test_split

%matplotlib inline
import matplotlib.pyplot as plt  # plotting tools
import matplotlib.patches as patches
from matplotlib.patches import Polygon
#matplotlib.use("Agg")

import warnings                  # ignore annoying warnings
# Silence every warning emitted through the `warnings` module for the rest of the session.
warnings.filterwarnings("ignore")
import logging
# Raise the root logger's level to CRITICAL: loggers without a level of their own inherit
# it, so their records below CRITICAL are dropped.
logger = logging.getLogger()
logger.setLevel(logging.CRITICAL)

# IPython autoreload: reload imported modules (e.g. the project's core.* package) before
# each cell executes, so edits to them are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2
from IPython.core.interactiveshell import InteractiveShell
# Display the value of every bare expression in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = "all"

# Mixed precision uses both 16-bit and 32-bit floating-point types in a model during training
# to make it run faster and use less memory. This variable is intended to turn on TF's
# automatic mixed-precision graph rewrite.
# NOTE(review): it is set after `import tensorflow` (earlier in this cell), and the optimizer
# cell also calls enable_mixed_precision_graph_rewrite; confirm the installed TF version still
# honours the variable when it is set at this point.
os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1'

# Report the TensorFlow version in the notebook output for reproducibility.
print(tf.__version__)

In [None]:
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

# TF1-style session options. They are bound to `session_config` rather than `config`: the
# configuration cell binds `config` to the UNetTraining configuration that every later cell
# reads, and re-running this cell must not overwrite it with a ConfigProto.
session_config = ConfigProto(
    # device_count={"CPU": 64},
    allow_soft_placement=True,   # let TF fall back to another device (e.g. CPU) when an op cannot run where requested
    log_device_placement=False,  # do not log the device chosen for every op
)
# Grow GPU memory on demand instead of pre-allocating the whole region up front.
session_config.gpu_options.allow_growth = True
session = InteractiveSession(config=session_config)

In [None]:
# All required settings (including the input and output paths) live in a separate
# configuration module (e.g. config/UNetTraining.py).
# Fill in that file before continuing with this notebook.
from config import UNetTraining
# If your configuration lives in a differently named folder (e.g. configLargeCluster),
# import from that folder instead:
# Eg. from configLargeCluster import UNetTraining
# From here on, `config` is the Configuration object that the following cells read
# paths, channels, patch size, batch size, etc. from.
config = UNetTraining.Configuration()

In [None]:
# Read all images/frames of the first data folder (config.path_to_write1) into memory.
# Each GSW raster is paired with the annotation image whose file name has the
# config.GSW_fn prefix replaced by config.annotation_fn.
frames1 = []

all_files = os.listdir(config.path_to_write1)
# NOTE(review): os.listdir returns names in arbitrary order. If split_dataset1 /
# config.frames_json1 identify frames by position, a saved split is only valid for the
# same order; consider sorting once existing JSON splits can be regenerated.
all_files_GSW = [fn for fn in all_files if fn.startswith(config.GSW_fn) and fn.endswith(config.image_type)]  # e.g. occurrence.png
print(len(all_files_GSW))
print(all_files_GSW)
for fn in all_files_GSW:
    # `with` closes the raster file once its pixels have been read.
    with rasterio.open(os.path.join(config.path_to_write1, fn)) as GSW_img:
        read_GSW_img = GSW_img.read()
    # rasterio returns (bands, rows, cols); move the band axis last -> (rows, cols, bands).
    comb_img = np.transpose(read_GSW_img, axes=(1, 2, 0))

    # Swap only the first occurrence, i.e. the GSW_fn prefix guaranteed by the filter above.
    annotation_name = fn.replace(config.GSW_fn, config.annotation_fn, 1)
    with Image.open(os.path.join(config.path_to_write1, annotation_name)) as annotation_im:
        # np.array forces the pixel data to load before the file is closed.
        annotation = np.array(annotation_im)
    frames1.append(FrameInfo(comb_img, annotation))
print(len(frames1))

training_frames1, validation_frames1, testing_frames1 = split_dataset1(frames1, config.frames_json1, config.patch_dir)

In [None]:
# Read all images/frames of the second data folder (config.path_to_write2) into memory.
# Each GSW raster is paired with the annotation image whose file name has the
# config.GSW_fn prefix replaced by config.annotation_fn.
frames2 = []

all_files = os.listdir(config.path_to_write2)
# NOTE(review): os.listdir order is arbitrary; see the same caveat for the saved split in
# config.frames_json2 if frames are referenced by position.
all_files_GSW = [fn for fn in all_files if fn.startswith(config.GSW_fn) and fn.endswith(config.image_type)]
print(len(all_files_GSW))
print(all_files_GSW)
for fn in all_files_GSW:
    # `with` closes the raster file once its pixels have been read.
    with rasterio.open(os.path.join(config.path_to_write2, fn)) as GSW_img:
        read_GSW_img = GSW_img.read()
    # rasterio returns (bands, rows, cols); move the band axis last -> (rows, cols, bands).
    comb_img = np.transpose(read_GSW_img, axes=(1, 2, 0))

    # Swap only the first occurrence, i.e. the GSW_fn prefix guaranteed by the filter above.
    annotation_name = fn.replace(config.GSW_fn, config.annotation_fn, 1)
    with Image.open(os.path.join(config.path_to_write2, annotation_name)) as annotation_im:
        # np.array forces the pixel data to load before the file is closed.
        annotation = np.array(annotation_im)
    frames2.append(FrameInfo(comb_img, annotation))

frames_12 = frames1 + frames2
print(len(frames_12))

# The earlier folder's frames are passed alongside the combined list; presumably so the
# split can address folder-2 frames within frames_12 - confirm in core.split_frames.
training_frames2, validation_frames2, testing_frames2 = split_dataset2(frames1, frames_12, config.frames_json2, config.patch_dir)

In [None]:
# Read all images/frames of the third data folder (config.path_to_write3) into memory.
# Each GSW raster is paired with the annotation image whose file name has the
# config.GSW_fn prefix replaced by config.annotation_fn.
frames3 = []

all_files = os.listdir(config.path_to_write3)
# NOTE(review): os.listdir order is arbitrary; see the same caveat for the saved split in
# config.frames_json3 if frames are referenced by position.
all_files_GSW = [fn for fn in all_files if fn.startswith(config.GSW_fn) and fn.endswith(config.image_type)]
print(len(all_files_GSW))
print(all_files_GSW)
for fn in all_files_GSW:
    # `with` closes the raster file once its pixels have been read.
    with rasterio.open(os.path.join(config.path_to_write3, fn)) as GSW_img:
        read_GSW_img = GSW_img.read()
    # rasterio returns (bands, rows, cols); move the band axis last -> (rows, cols, bands).
    comb_img = np.transpose(read_GSW_img, axes=(1, 2, 0))

    # Swap only the first occurrence, i.e. the GSW_fn prefix guaranteed by the filter above.
    annotation_name = fn.replace(config.GSW_fn, config.annotation_fn, 1)
    with Image.open(os.path.join(config.path_to_write3, annotation_name)) as annotation_im:
        # np.array forces the pixel data to load before the file is closed.
        annotation = np.array(annotation_im)
    frames3.append(FrameInfo(comb_img, annotation))

frames_123 = frames1 + frames2 + frames3
print(len(frames_123))

# Frames of folders 1-2 are passed alongside the combined list; presumably so the split can
# address folder-3 frames within frames_123 - confirm in core.split_frames.
training_frames3, validation_frames3, testing_frames3 = split_dataset3(frames_12, frames_123, config.frames_json3, config.patch_dir)

In [None]:
# Read all images/frames of the fourth data folder (config.path_to_write4) into memory.
# Each GSW raster is paired with the annotation image whose file name has the
# config.GSW_fn prefix replaced by config.annotation_fn.
frames4 = []

all_files = os.listdir(config.path_to_write4)
# NOTE(review): os.listdir order is arbitrary; see the same caveat for the saved split in
# config.frames_json4 if frames are referenced by position.
all_files_GSW = [fn for fn in all_files if fn.startswith(config.GSW_fn) and fn.endswith(config.image_type)]
print(len(all_files_GSW))
print(all_files_GSW)
for fn in all_files_GSW:
    # `with` closes the raster file once its pixels have been read.
    with rasterio.open(os.path.join(config.path_to_write4, fn)) as GSW_img:
        read_GSW_img = GSW_img.read()
    # rasterio returns (bands, rows, cols); move the band axis last -> (rows, cols, bands).
    comb_img = np.transpose(read_GSW_img, axes=(1, 2, 0))

    # Swap only the first occurrence, i.e. the GSW_fn prefix guaranteed by the filter above.
    annotation_name = fn.replace(config.GSW_fn, config.annotation_fn, 1)
    with Image.open(os.path.join(config.path_to_write4, annotation_name)) as annotation_im:
        # np.array forces the pixel data to load before the file is closed.
        annotation = np.array(annotation_im)
    frames4.append(FrameInfo(comb_img, annotation))

frames_1234 = frames1 + frames2 + frames3 + frames4
print(len(frames_1234))

# Frames of folders 1-3 are passed alongside the combined list; presumably so the split can
# address folder-4 frames within frames_1234 - confirm in core.split_frames.
training_frames4, validation_frames4, testing_frames4 = split_dataset4(frames_123, frames_1234, config.frames_json4, config.patch_dir)

In [None]:
# Pool the per-folder splits into one training, one validation and one test list.
training_frames = training_frames1 + training_frames2 + training_frames3 + training_frames4
validation_frames = validation_frames1 + validation_frames2 + validation_frames3 + validation_frames4
testing_frames = testing_frames1 + testing_frames2 + testing_frames3 + testing_frames4

annotation_channels = config.input_label_channel


def _patch_generator(frames, augmenter):
    """Build a DataGenerator over `frames` and return its random batch generator.

    Every generator also receives the full frame list `frames_1234`; presumably the
    split lists reference frames inside it - confirm in core.dataset_generator.
    """
    data_gen = DataGenerator(config.input_image_channel, config.patch_size, frames,
                             frames_1234, annotation_channels, augmenter=augmenter)
    return data_gen.random_generator(config.BATCH_SIZE, normalize=config.normalize)


# Only the training stream gets an augmenter ('iaa', presumably imgaug-based); the
# validation and test streams pass None.
train_generator = _patch_generator(training_frames, 'iaa')
val_generator = _patch_generator(validation_frames, None)
test_generator = _patch_generator(testing_frames, None)

In [None]:
# Sanity check: draw one training batch and show the images next to their labels
# (concatenated along the channel axis).
train_images, real_label = next(train_generator)
display_images(np.concatenate((train_images, real_label), axis=-1))

In [None]:
# Optimizer and loss used for training.
OPTIMIZER = adaDelta
# Wrap the optimizer for TF's automatic mixed-precision graph rewrite; the returned object
# replaces OPTIMIZER (presumably a loss-scaling wrapper - confirm for the installed TF).
# NOTE(review): tf.train.experimental.enable_mixed_precision_graph_rewrite is deprecated in
# newer TF 2.x releases (see tf.compat.v1.mixed_precision / the Keras mixed-precision API);
# check availability before upgrading TensorFlow.
OPTIMIZER = tf.train.experimental.enable_mixed_precision_graph_rewrite(OPTIMIZER)
# Alternative with the Keras API (wraps the original optimizer in a LossScaleOptimizer):
# OPTIMIZER = mixed_precision.LossScaleOptimizer(OPTIMIZER, loss_scale='dynamic')
LOSS = tversky
# OPTIMIZER.minimize(LOSS)  # 'minimize' applies loss scaling to the loss and updates the loss scale.

# Human-readable names, used only to build the model file name.
OPTIMIZER_NAME = 'AdaDelta'
LOSS_NAME = 'weightmap_tversky'

# Path where the trained model is stored (written by the checkpoint callback).
# To retrain an existing model instead, change the cell where the model is declared.
timestr = time.strftime("%Y%m%d-%H%M")
chf = config.input_image_channel + config.input_label_channel
# Concatenate the channel entries into one string, e.g. [0, 1] -> '01'.
chs = ''.join(str(channel) for channel in chf)

# exist_ok=True replaces the racy "check if it exists, then create" pattern.
os.makedirs(config.model_path, exist_ok=True)
model_path = os.path.join(config.model_path, 'lakes_{}_{}_{}_{}_{}.h5'.format(timestr, OPTIMIZER_NAME, LOSS_NAME, chs, config.input_shape[0]))
print(model_path)

In [None]:
# Build the U-Net and compile it with the chosen optimizer, loss and evaluation metrics.
# The first UNet argument is [BATCH_SIZE, *input_shape]; presumably the full batch input
# shape - confirm in core.UNet.
unet_batch_shape = [config.BATCH_SIZE, *config.input_shape]
model = UNet(unet_batch_shape, config.input_label_channel)
training_metrics = [dice_coef, dice_loss, specificity, sensitivity, accuracy, mIoU]
model.compile(optimizer=OPTIMIZER, loss=LOSS, metrics=training_metrics)

In [None]:
# Training callbacks: model checkpointing, (optional) learning-rate reduction, early stopping
# and TensorBoard logging.
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler, EarlyStopping, ReduceLROnPlateau, TensorBoard

# Save the full model to model_path whenever val_loss reaches a new minimum.
checkpoint = ModelCheckpoint(
    model_path,
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    save_weights_only=False,
    verbose=1,
)

# ReduceLROnPlateau: once val_loss has not improved by more than min_delta for `patience`
# epochs, multiply the learning rate by `factor` (new_lr = lr * 0.33), then wait `cooldown`
# epochs before resuming normal operation. Mainly useful with optimizers such as Adam; it is
# created here but left out of callbacks_list because it is not required with AdaDelta.
reduceLROnPlat = ReduceLROnPlateau(
    monitor='val_loss',
    mode='min',
    factor=0.33,
    patience=4,
    min_delta=0.0001,
    cooldown=4,
    min_lr=1e-16,
    verbose=1,
)

# Stop training after 100 epochs without a val_loss improvement.
early = EarlyStopping(monitor="val_loss", mode="min", patience=100, verbose=2)

run_name = 'UNet_{}_{}_{}_{}_{}'.format(timestr, OPTIMIZER_NAME, LOSS_NAME, chs, config.input_shape[0])
log_dir = os.path.join('./logs', run_name)
# write_grads / embeddings_layer_names / embeddings_metadata / embeddings_data are not passed:
# False/None are already their defaults (and TF2's TensorBoard callback ignores the
# write_grads / embeddings_layer_names / embeddings_data arguments).
tensorboard = TensorBoard(
    log_dir=log_dir,
    histogram_freq=0,
    write_graph=True,
    write_images=False,
    embeddings_freq=0,
    update_freq='epoch',
)

callbacks_list = [checkpoint, tensorboard, early]

In [None]:
# Train from the patch generators; the returned History object keeps per-epoch metrics.
# NOTE(review): Keras counts steps_per_epoch and validation_steps in batches, while the
# config name VALID_IMG_COUNT suggests an image count - confirm the intended value.
loss_history = model.fit(
    train_generator,
    epochs=config.NB_EPOCHS,
    steps_per_epoch=config.MAX_TRAIN_STEPS,
    validation_data=val_generator,
    validation_steps=config.VALID_IMG_COUNT,
    callbacks=callbacks_list,
    # One worker only: the generator is not very thread-safe, so use_multiprocessing=True
    # and a larger max_queue_size (e.g. 60) are deliberately not used.
    workers=1,
)

In [None]:
# Reload the model file written by the checkpoint callback during training.
# Loading a model saved under a different Python version may fail, see
# https://github.com/keras-team/keras/issues/9595#issue-303471777
custom_objects = {
    'tversky': LOSS,
    'dice_coef': dice_coef,
    'dice_loss': dice_loss,
    'accuracy': accuracy,
    'mIoU': mIoU,
    'specificity': specificity,
    'sensitivity': sensitivity,
}
# compile=False: the model is loaded uncompiled and compiled explicitly below.
model = load_model(model_path, custom_objects=custom_objects, compile=False)

# To use multiple GPUs, uncomment the following lines.
# strategy = tf.distribute.MirroredStrategy(devices=["/gpu:0", "/gpu:1"], cross_device_ops=tf.distribute.ReductionToOneDevice())
# print('Number of devices: %d' % strategy.num_replicas_in_sync)

# NOTE(review): this metric order differs from the training-time compile
# (dice_coef, dice_loss, specificity, sensitivity, accuracy, mIoU), so positional results
# such as those from model.evaluate() are ordered differently - align if that matters.
model.compile(optimizer=OPTIMIZER, loss=LOSS, metrics=[dice_coef, dice_loss, accuracy, mIoU, specificity, sensitivity])

In [None]:
# Visual check on one test batch: 3 images per row - input (GSW), label, thresholded prediction.
test_images, real_label = next(test_generator)
prediction = model.predict(test_images, steps=1)
# Binarise at 0.5. Both masks are taken before assigning; they are disjoint, so this equals
# thresholding in sequence. Values that fail both comparisons (NaN) are left unchanged.
above = prediction > 0.5
at_or_below = prediction <= 0.5
prediction[above] = 1
prediction[at_or_below] = 0
display_images(np.concatenate((test_images, real_label, prediction), axis=-1))