  Author: Ankit Kariryaa, University of Bremen
  
  Modified by Xuehui Pi, Qiuqi Luo and Beihui Hu

### Getting started
Define the paths for the dataset and trained models in the `notebooks/config/UNetTraining.py` file.  

In [None]:
import os

# Cap the thread pools of MKL, NumExpr and OpenMP at 16 threads each.
# These variables are set before the numeric libraries are imported further down.
os.environ.update({
    "MKL_NUM_THREADS": '16',
    "NUMEXPR_NUM_THREADS": '16',
    "OMP_NUM_THREADS": '16',
})
print(os.environ.get('OMP_NUM_THREADS'))

In [None]:
import tensorflow as tf
import numpy as np
from PIL import Image
import rasterio
from tensorflow.keras import mixed_precision 
mixed_precision.set_global_policy('mixed_float16')

import os

import time
import rasterio.warp             # Reproject raster samples
from functools import reduce
from tensorflow.keras.models import load_model

from core.UNet import UNet 
from core.losses import accuracy, dice_loss, IoU, recall, precision,F1_score
from core.optimizers import adaDelta
from core.frame_info import FrameInfo
from core.dataset_generator import DataGenerator
from core.visualize import display_images


import warnings                  # silence library warnings to keep the notebook output readable
warnings.filterwarnings("ignore")
import logging
# Raise the root logger threshold so that only CRITICAL records are emitted
logger = logging.getLogger()
logger.setLevel(logging.CRITICAL)

# IPython magics: automatically reload edited project modules before each cell is executed
%reload_ext autoreload
%autoreload 2
from IPython.core.interactiveshell import InteractiveShell
# Display the value of every top-level expression in a cell, not only the last one
InteractiveShell.ast_node_interactivity = "all"

# Mixed precision is the use of both 16-bit and 32-bit floating-point types in a model during training to make it run faster and use less memory.
# NOTE(review): the 'mixed_float16' global policy is already set earlier in this cell; this environment flag may be redundant — confirm.
os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1'
print(tf.__version__)

In [None]:
from tensorflow.compat.v1 import ConfigProto, InteractiveSession

# TF1-compat session configuration: soft device placement on, placement logging off.
# (Optionally pass device_count={"CPU": 64} to the constructor.)
# Note: the name `config` is rebound to the training configuration in a later cell.
config = ConfigProto(allow_soft_placement=True, log_device_placement=False)
# Let the GPU allocator grow on demand instead of reserving memory up front
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)

In [None]:
# The required configuration (including the input and output paths) is stored in a separate file (such as config/UNetTraining.py).
# Fill in that file before continuing with this notebook.
# hbh: for this scenario a new config named UNetTraining_sequential was created to distinguish it from the original.
# NOTE(review): the import below still uses UNetTraining — confirm which module is intended.
from config import UNetTraining
# If you are using a different folder name such as configLargeCluster, import from that folder instead,
# e.g. from configLargeCluster import UNetTraining
# This rebinds `config`, which previously held the session ConfigProto, to the training configuration object.
config = UNetTraining.Configuration()

In [None]:
def readImgs(path_to_write, fn):
    """Load one image raster and its annotation and wrap them in a FrameInfo.

    Args:
        path_to_write: Directory holding both the image and its annotation
            (despite the name, this function only reads from it).
        fn: File name of the image inside ``path_to_write``. The annotation
            file name is derived from it by swapping ``config.image_fn`` for
            ``config.annotation_fn`` and ``config.image_type`` for
            ``config.ann_type``.

    Returns:
        A tuple ``(frame, patch_count)``: ``frame`` is the ``FrameInfo`` built
        from the image and annotation arrays; ``patch_count`` is the
        annotation area divided by the area of one ``config.input_shape``
        patch (a float, not rounded).
    """
    # Context managers release the file handles as soon as the pixels are in memory;
    # the original left both the raster dataset and the PIL image open.
    with rasterio.open(os.path.join(path_to_write, fn)) as image:
        read_image = image.read()
    # Move the first (band) axis to the end so the array is rows x cols x bands
    comb_img = np.transpose(read_image, axes=(1, 2, 0))

    annotation_fn = fn.replace(config.image_fn, config.annotation_fn).replace(config.image_type, config.ann_type)
    with Image.open(os.path.join(path_to_write, annotation_fn)) as annotation_im:
        # np.array forces the lazy PIL image to load before the file is closed
        annotation = np.array(annotation_im)

    patch_area = config.input_shape[0] * config.input_shape[1]
    patch_count = annotation.shape[0] * annotation.shape[1] / patch_area
    f = FrameInfo(comb_img, annotation)
    return f, patch_count

def readFrames(dataType):
    """Read every image/annotation pair of one dataset split.

    Args:
        dataType: Name of the split sub-directory under ``config.dataset_dir``
            (the notebook uses 'train', 'val' and 'test').

    Returns:
        A tuple ``(frames, patch_count_list)`` of two lists with matching
        order: the frame objects returned by ``readImgs`` and their patch
        counts.
    """
    print(dataType)
    dataset_dir = os.path.join(config.dataset_dir, str(dataType))
    # os.listdir returns entries in arbitrary order; sorting makes the frame
    # order (and therefore the index-aligned percentages) reproducible.
    all_files_image = sorted(
        fn for fn in os.listdir(dataset_dir)
        if fn.startswith(config.image_fn) and fn.endswith(config.image_type)
    )
    frames = []
    patch_count_list = []
    for fn in all_files_image:
        f, pc = readImgs(dataset_dir, fn)
        frames.append(f)
        patch_count_list.append(pc)
    return frames, patch_count_list

In [None]:
# Load the training frames and compute each frame's share of the total patch count.
# The shares are handed to random_generator as the per-frame sampling percentages.
frames, patch_count_list = readFrames('train')
patch_count_list = np.asarray(patch_count_list)
train_patch_count = np.sum(patch_count_list)
percentages = np.divide(patch_count_list, train_patch_count)
print(f'total training image count:{len(frames)}')
# Training is the only split that uses augmentation ('iaa')
train_generator = DataGenerator(
    config.input_image_channel,
    config.patch_size,
    frames,
    config.input_label_channel,
    augmenter='iaa',
).random_generator(config.BATCH_SIZE, percentages)

In [None]:
# Load the validation frames and compute each frame's share of the total patch count.
# The shares are handed to random_generator as the per-frame sampling percentages.
frames, patch_count_list = readFrames('val')
patch_count_list = np.asarray(patch_count_list)
val_patch_count = np.sum(patch_count_list)
percentages = np.divide(patch_count_list, val_patch_count)
print(f'total validation image count:{len(frames)}')
# No augmentation for validation data
val_generator = DataGenerator(
    config.input_image_channel,
    config.patch_size,
    frames,
    config.input_label_channel,
    augmenter=None,
).random_generator(config.BATCH_SIZE, percentages)

In [None]:
# Load the test frames and compute each frame's share of the total patch count.
# The shares are handed to random_generator as the per-frame sampling percentages.
frames, patch_count_list = readFrames('test')
patch_count_list = np.asarray(patch_count_list)
test_patch_count = np.sum(patch_count_list)
percentages = np.divide(patch_count_list, test_patch_count)
print(f'total test image count:{len(frames)}')
# No augmentation for test data
test_generator = DataGenerator(
    config.input_image_channel,
    config.patch_size,
    frames,
    config.input_label_channel,
    augmenter=None,
).random_generator(config.BATCH_SIZE, percentages)

In [None]:
# Sanity check: draw one batch from the training generator and display the
# input channels with the label appended along the last axis.
# Increase the range to preview more batches.
for _ in range(1):
    train_images, real_label = next(train_generator) 
    display_images(np.concatenate((train_images,real_label), axis = -1))

In [None]:
# Sanity check: draw one batch from the validation generator and display the
# input channels with the label appended along the last axis.
# Increase the range to preview more batches.
for _ in range(1):
    val_images, val_label = next(val_generator) 
    display_images(np.concatenate((val_images,val_label), axis = -1))

In [None]:
# Sanity check: draw one batch from the test generator and display the
# input channels with the label appended along the last axis.
# Note: this rebinds `real_label`, which the training preview cell also assigns.
for _ in range(1):
    test_images, real_label = next(test_generator) 
    display_images(np.concatenate((test_images,real_label), axis = -1))

In [None]:
# Optimizer: the project's adaDelta, wrapped for loss scaling because the
# global mixed-precision policy is 'mixed_float16'.
OPTIMIZER = mixed_precision.LossScaleOptimizer(adaDelta)
OPTIMIZER_NAME = 'AdaDelta'

LOSS = dice_loss
LOSS_NAME = 'dice_loss'

# Declare the path to the final model.
# This path is used for storing the model after training; to retrain an
# existing model, change the cell where the model is declared instead.
timestr = time.strftime("%Y%m%d-%H%M")
# Concatenate all input and label channel indices into one string, e.g. [0, 1] + [2] -> '012'
chf = config.input_image_channel + config.input_label_channel
chs = ''.join(str(ch) for ch in chf)

# exist_ok avoids the check-then-create race of a separate os.path.exists test
os.makedirs(config.model_path, exist_ok=True)
model_name = '{}_{}_{}_{}_{}.h5'.format(timestr, OPTIMIZER_NAME, LOSS_NAME, chs, config.input_shape[0])
model_path = os.path.join(config.model_path, 'lakes_' + model_name)
print(model_path)

In [None]:
# Define the model and compile it.
# UNet receives the full input shape (batch size followed by config.input_shape)
# and the label channel list; the metrics come from core.losses.
model = UNet([config.BATCH_SIZE, *config.input_shape],config.input_label_channel)
model.compile(optimizer=OPTIMIZER, loss=LOSS, metrics=[accuracy, recall, precision,F1_score, IoU])

In [None]:
# Define callbacks for model checkpointing, learning-rate reduction, early stopping and TensorBoard logging
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler, EarlyStopping, ReduceLROnPlateau, TensorBoard

# Keep only the best full model (lowest validation loss) at model_path
checkpoint = ModelCheckpoint(model_path, monitor='val_loss', verbose=1,
                             save_best_only=True, mode='min', save_weights_only=False)

# ReduceLROnPlateau: can be useful when using Adam as the optimizer.
# Reduces the learning rate when the monitored metric has stopped improving:
# after `patience` epochs without improvement, new_lr = lr * factor (0.33 here).
# cooldown: number of epochs to wait before resuming normal operation after the lr has been reduced.
reduceLROnPlat = ReduceLROnPlateau(monitor='val_loss', factor=0.33,
                                   patience=4, verbose=1, mode='min',
                                   min_delta=0.0001, cooldown=4, min_lr=1e-16)

# Stop training after 50 epochs without validation-loss improvement
early = EarlyStopping(monitor="val_loss", mode="min", verbose=2, patience=50)

log_dir = os.path.join('./logs', 'UNet_' + model_name)
# The TF1-only arguments write_grads, embeddings_layer_names and embeddings_data
# were dropped: the TF2 callback ignores them, so they only added noise.
tensorboard = TensorBoard(log_dir=log_dir, histogram_freq=0, write_graph=True,
                          write_images=False, embeddings_freq=0,
                          embeddings_metadata=None, update_freq='epoch')

callbacks_list = [checkpoint, tensorboard, early]  # reduceLROnPlat is not required with adaDelta

In [None]:
# Train the model. Both generators are used with explicit step counts taken
# from the configuration; the returned History object is kept in loss_history.
loss_history = model.fit(train_generator, 
                         steps_per_epoch=config.steps_per_epoch,
                         epochs=config.NB_EPOCHS, 
                         validation_data=val_generator,
                         validation_steps=config.validation_steps,
                         callbacks=callbacks_list,
                         workers=1
                        )

In [None]:
# Load the checkpointed model after training.
# The custom_objects keys must match the names the loss/metric functions were
# saved under; the loss key was previously 'dice loss' (with a space), which
# could never match the dice_loss function.
# With compile=False the training configuration is not restored, so the model
# is compiled again explicitly below.
model = load_model(
    model_path,
    custom_objects={
        'dice_loss': LOSS,
        'accuracy': accuracy,
        'recall': recall,
        'precision': precision,
        'F1_score': F1_score,
        'IoU': IoU,
    },
    compile=False,
)

# # In case you want to use multiple GPUs you can uncomment the following lines.
# strategy = tf.distribute.MirroredStrategy(devices=["/gpu:0", "/gpu:1"], cross_device_ops=tf.distribute.ReductionToOneDevice())
# print('Number of devices: %d' % strategy.num_replicas_in_sync)

model.compile(optimizer=OPTIMIZER, loss=LOSS, metrics=[accuracy, recall, F1_score, precision, IoU])

In [None]:
# Visual check of the trained model on one test batch.
# NOTE(review): the five titles presumably correspond to the concatenated channels
# (inputs, annotation, prediction) — confirm against the configured channel count.
titles=['ndwi','rgb','swir','annotation','prediction']
for i in range(1):
    test_images, real_label = next(test_generator)
    prediction = model.predict(test_images, steps=1)
    # Binarise the predictions in place at a 0.5 threshold
    prediction[prediction>0.5]=1
    prediction[prediction<=0.5]=0
    # Show inputs, ground truth and prediction stacked along the last axis
    display_images(np.concatenate((test_images, real_label, prediction), axis = -1),titles=titles)