  Author: Ankit Kariryaa, University of Bremen
  
  Modified by Xuehui Pi and Qiuqi Luo

### Getting started
Define the paths for the dataset and trained models in the `notebooks/config/UNetTraining.py` file.  

In [None]:
import os
# Give the MKL, NumExpr and OpenMP thread-count variables the same value, '16'.
for _thread_var in ("MKL_NUM_THREADS", "NUMEXPR_NUM_THREADS", "OMP_NUM_THREADS"):
    os.environ[_thread_var] = '16'
print(os.environ.get('OMP_NUM_THREADS'))

In [None]:
import tensorflow as tf
import numpy as np
from PIL import Image
import rasterio
import imgaug as ia
from imgaug import augmenters as iaa
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tensorflow as tf
from tensorflow.keras import mixed_precision 
mixed_precision.set_global_policy('mixed_float16')

import imageio
import os

import time
import rasterio.warp             # Reproject raster samples
from functools import reduce
from tensorflow.keras.models import load_model

from core.UNet import UNet  #
from core.losses import accuracy, dice_loss, IoU, recall, precision,F1_score
from tensorflow.keras.losses import BinaryCrossentropy as bce
from core.optimizers import adaDelta, adagrad, adam, nadam
from core.frame_info import FrameInfo
from core.dataset_generator import DataGenerator
from core.visualize import display_images_2,plot
import json
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
import pickle
import random

%matplotlib inline
import matplotlib.pyplot as plt  # plotting tools
import matplotlib.patches as patches
from matplotlib.patches import Polygon
#matplotlib.use("Agg")

import warnings
# Suppress every Python warning for the rest of the session.
warnings.filterwarnings("ignore")
import logging
# Raise the root logger's threshold to CRITICAL, so lower-severity records
# (including library warnings routed through logging) are not emitted.
logger = logging.getLogger()
logger.setLevel(logging.CRITICAL)

%reload_ext autoreload
%autoreload 2
from IPython.core.interactiveshell import InteractiveShell
# Echo the value of every top-level expression in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = "all"

# Mixed precision uses both 16-bit and 32-bit floating-point types during
# training so the model runs faster and uses less memory.
# NOTE(review): this variable is set after TensorFlow has already been imported
# and after the 'mixed_float16' global policy was selected earlier in this
# cell - confirm that it still has any effect.
os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1'
print(tf.__version__)

In [None]:
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

# TF1-compat session settings. They are kept under their own name so that they
# never collide with `config`, the global that holds the training
# configuration: re-running this cell must not replace that object.
session_config = ConfigProto(
    #device_count={"CPU": 64},
    allow_soft_placement=True,
    log_device_placement=False)
# Grow GPU memory usage on demand instead of reserving all of it up front.
session_config.gpu_options.allow_growth = True
session = InteractiveSession(config=session_config)

In [None]:
# The required settings (including the input and output paths) are stored in a
# separate file such as config/UNetTraining.py. Fill that file in before
# continuing with this notebook.
# hbh: for this scenario a new config named UNetTraining_sequential was created
# to distinguish it from the original one.
# NOTE(review): the import below still loads UNetTraining, not
# UNetTraining_sequential - confirm which module is intended.
from config import UNetTraining
# If the settings live in a differently named folder such as
# configLargeCluster, import from that folder instead,
# e.g. from configLargeCluster import UNetTraining
config = UNetTraining.Configuration()

In [None]:
def readImgs(path_to_write, fn):
    """Load one image raster and its annotation into a FrameInfo.

    The annotation file name is derived from ``fn`` by replacing
    ``config.image_fn`` with ``config.annotation_fn`` and
    ``config.image_type`` with ``config.ann_type``.

    Args:
        path_to_write: Directory that contains both the image and its
            annotation.
        fn: File name of the image raster inside ``path_to_write``.

    Returns:
        A tuple ``(frame, patch_count)``. ``frame`` is the FrameInfo built
        from the image and the annotation. ``patch_count`` is the number of
        ``config.input_shape``-sized tiles the annotation covers; it is
        computed with true division and is therefore a float that may be
        fractional.
    """
    image_path = os.path.join(path_to_write, fn)
    annotation_name = fn.replace(config.image_fn, config.annotation_fn).replace(config.image_type, config.ann_type)
    annotation_path = os.path.join(path_to_write, annotation_name)

    # Context managers close the raster dataset and the annotation file as
    # soon as their pixels are in memory; otherwise one handle per image stays
    # open for the lifetime of the kernel.
    with rasterio.open(image_path) as image:
        # Stored values are scaled down by a factor of 1000.
        read_image = image.read()/1000
    # rasterio returns (bands, rows, cols); move the bands to the last axis.
    comb_img = np.transpose(read_image, axes=(1,2,0))

    with Image.open(annotation_path) as annotation_im:
        annotation = np.array(annotation_im)

    rowNum = annotation.shape[0]/config.input_shape[0]
    colNum = annotation.shape[1]/config.input_shape[1]
    f = FrameInfo(comb_img, annotation)
    return f, rowNum*colNum

def readFrames(dataType):
    """Load every image of one dataset split, across all type folders.

    For each ``i`` in ``range(config.type_num)`` the folder
    ``<config.dataset_dir>/<dataType>/type<i>`` is scanned for files whose
    name starts with ``config.image_fn`` and ends with ``config.image_type``;
    each match is loaded with ``readImgs``.

    Args:
        dataType: Name of the split sub-folder (this notebook uses 'train',
            'val' and 'test').

    Returns:
        ``(frames, numList)``: the loaded frames and, in the same order, the
        patch count that ``readImgs`` reported for each of them.
    """
    print(dataType)
    frames = []
    numList = []
    for type_idx in range(config.type_num):
        type_dir = os.path.join(config.dataset_dir, '{}/type{}'.format(dataType, type_idx))
        image_names = [
            name for name in os.listdir(type_dir)
            if name.startswith(config.image_fn) and name.endswith(config.image_type)
        ]
        print('type{} image number :{}'.format(type_idx, len(image_names)))
        for name in image_names:
            frame, patch_count = readImgs(type_dir, name)
            frames.append(frame)
            numList.append(patch_count)
    return frames, numList


## 参数初始化

In [None]:
# Optimizer used for training: adaDelta wrapped in a loss-scaling optimizer
# (the notebook runs under the 'mixed_float16' policy). The name is used in
# output file names.
OPTIMIZER_NAME = 'AdaDelta'
OPTIMIZER = mixed_precision.LossScaleOptimizer(adaDelta)

In [None]:
# Training objective and the label that identifies it in output file names.
LOSS, LOSS_NAME = dice_loss, 'dice_loss'

## 数据集准备

In [None]:
# Load the training frames together with the patch count of every image.
frames, numList = readFrames('train')
percentages = np.array(numList)
train_sum = percentages.sum()
# Turn the per-image counts into fractions that sum to 1; they are handed to
# random_generator (presumably as per-frame sampling weights - verify in
# core.dataset_generator).
percentages = percentages / train_sum
print(f'total training img count:{len(frames)}')
print(f'total training patches count:{train_sum}')
train_generator = DataGenerator(
    config.input_image_channel,
    config.patch_size,
    frames,
    config.input_label_channel,
    augmenter='iaa',
).random_generator(config.BATCH_SIZE, percentages)

In [None]:
# Load the validation frames together with the patch count of every image.
frames, numList = readFrames('val')
percentages = np.array(numList)
val_sum = percentages.sum()
# Turn the per-image counts into fractions that sum to 1 before passing them
# to random_generator.
percentages = percentages / val_sum
# These messages describe the validation split (they used to say "training").
print('total validation img count:'+str(len(frames)))
print('total validation patches count:{}'.format(val_sum))
# No augmentation for validation data.
val_generator = DataGenerator(config.input_image_channel, config.patch_size, frames, config.input_label_channel, augmenter = None).random_generator(config.BATCH_SIZE,percentages)

In [None]:
# Draw one training batch and show the images next to their labels
# (labels are appended along the last axis).
for _ in range(1):
    train_images, real_label = next(train_generator)
    train_preview = np.concatenate((train_images, real_label), axis=-1)
    display_images_2(train_preview, pad=100, output_size=376)

In [None]:
# Draw one validation batch, report its shape and show the images next to
# their labels (labels are appended along the last axis).
for _ in range(1):
    val_images, val_label = next(val_generator)
    print(val_images.shape)
    val_preview = np.concatenate((val_images, val_label), axis=-1)
    display_images_2(val_preview, pad=100, output_size=376)

### 训练参数设置

In [None]:
# Timestamp that tags every artefact produced by this training run.
timestr = time.strftime("%Y%m%d-%H%M")

# Concatenate the input and label channel ids into one tag for file names
# (the checkpoint name used later in this notebook contains '01234').
chf = config.input_image_channel + config.input_label_channel
chs = ''.join(str(channel) for channel in chf)

# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(config.model_path, exist_ok=True)

# Checkpoint file: lakes_<time>_<optimizer>_<loss>_<channels>_<input size>.h5
model_name = '_{}_{}_{}_{}_{}.h5'.format(timestr, OPTIMIZER_NAME, LOSS_NAME, chs, config.input_shape[0])
model_path = os.path.join(config.model_path, 'lakes' + model_name)
print(model_path)

In [None]:
# Build the U-Net for batches of config.BATCH_SIZE inputs of shape
# config.input_shape, then compile it with the optimizer, loss and metrics
# chosen above.
unet_input_shape = [config.BATCH_SIZE, *config.input_shape]
model = UNet(unet_input_shape, config.input_label_channel)
model.compile(
    optimizer=OPTIMIZER,
    loss=LOSS,
    metrics=[accuracy, recall, precision, F1_score, IoU],
)

In [None]:
# Callbacks for model checkpointing, learning-rate reduction and TensorBoard logging
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler, EarlyStopping, ReduceLROnPlateau, TensorBoard

# Keep only the best model (lowest validation loss), saved as a full model
# rather than weights only.
checkpoint = ModelCheckpoint(model_path, monitor='val_loss', verbose=1,
                             save_best_only=True, mode='min', save_weights_only = False)

# ReduceLROnPlateau can be useful when Adam is the optimizer.
# It lowers the learning rate once the monitored metric has stopped improving:
# after `patience` epochs without improvement, new_lr = lr * factor (0.33 here).
# cooldown: number of epochs to wait after a reduction before resuming normal operation.
reduceLROnPlat = ReduceLROnPlateau(monitor='val_loss', factor=0.33,
                                   patience=4, verbose=1, mode='min',
                                   min_delta=0.0001, cooldown=4, min_lr=1e-16)

#early = EarlyStopping(monitor="val_loss", mode="min", verbose=2, patience=40)

log_dir = os.path.join('./logs','UNet'+model_name)
# The TF1-only arguments write_grads, embeddings_layer_names and
# embeddings_data were dropped: the TF2 callback ignores them.
tensorboard = TensorBoard(log_dir=log_dir, histogram_freq=0, write_graph=True,
                          write_images=False, embeddings_freq=0,
                          embeddings_metadata=None, update_freq='epoch')

# reduceLROnPlat (not required with AdaDelta) and early stopping are left out on purpose.
callbacks_list = [checkpoint, tensorboard]

## 训练模型
保存history文件并绘图

In [None]:
# Train the model. train_sum and val_sum are NumPy floats (the patch counts
# come from true division), so the floor-divided step counts are converted to
# plain ints, which is what Keras documents for these arguments.
loss_history = model.fit(train_generator,
                         steps_per_epoch=int(train_sum//config.BATCH_SIZE),
                         epochs=config.NB_EPOCHS,
                         validation_data=val_generator,
                         validation_steps=int(val_sum//config.BATCH_SIZE),
                         callbacks=callbacks_list,
                         workers=1
                        )
h=loss_history.history
# The history dict is pickled; despite the .txt suffix the file is binary and
# must be read back with pickle.load.
with open('history_{}_{}_{}_{}_{}.txt'.format(timestr,OPTIMIZER_NAME,LOSS_NAME, chs,config.input_shape[0]), 'wb') as file_pi:
    pickle.dump(h, file_pi)
# Plot the history just recorded.
plot(h,timestr, OPTIMIZER_NAME,LOSS_NAME, config.patch_size[0], config.NB_EPOCHS, config.BATCH_SIZE,chs)

In [None]:
##### Read an existing history file
# NOTE(review): '.txt' is a placeholder - replace it with the path of a history
# file written by the training cell before running this cell.
# NOTE(review): pickle.load can execute arbitrary code from the file; only load
# history files that you created yourself.
with open(r'.txt','rb')as file_pi:
    h=pickle.load(file_pi)
# print(h)

## 模型精度评价

In [None]:
# Load a trained model for evaluation.
# NOTE: the file name is hard-coded to one specific training run; change it to
# the checkpoint you want to evaluate.
model_path=os.path.join(config.model_path,r'lakes_20240219-2347_AdaDelta_dice_loss_01234_576.h5')
# The loss is registered under its real name 'dice_loss' (the key used to be
# 'dice loss', which matches nothing). With compile=False the model is
# compiled explicitly below.
model = load_model(model_path, custom_objects={'dice_loss': LOSS, 'accuracy':accuracy ,'recall':recall, 'precision':precision,'F1_score':F1_score,'IoU': IoU,}, compile=False)
# The metric order must match the training cell and the columns of the report
# table printed at the end (accuracy, recall, precision, F1_score, IoU);
# model.evaluate returns its values in this order.
model.compile(optimizer=OPTIMIZER, loss=LOSS, metrics=[accuracy, recall, precision, F1_score, IoU])

In [None]:
# titles=['ndwi','rgb','swir','annotation','prediction']
# j=1
# frames=[]
# path_to_write=os.path.join(config.dataset_dir,'test/type'+str(j))
# all_files = os.listdir(path_to_write)
# all_image_files = [fn for fn in all_files if fn.startswith(config.image_fn) and fn.endswith(config.image_type)]#ndwi.png
# for i, fn in enumerate(all_image_files):
#     f,nums = readImgs(path_to_write,fn)
#     frames.append(f)
# test_DGT=DataGenerator(config.input_image_channel, config.patch_size, frames, config.input_label_channel, augmenter = None)
# test_patches_type = test_DGT.all_sequential_patches(config.step_size)
# print(len(test_patches_type[0]))
# step=8
# for i in range(len(test_patches_type[0])//step+1):
#     prediction = model.predict(test_patches_type[0][i*step:(i+1)*step], steps=1)
#     prediction[prediction>0.5]=1
#     prediction[prediction<=0.5]=0
#     fn=r'D:\sample746\image\type{}_{}_576_2.png'.format(j,i)
#     image=np.concatenate((test_patches_type[0][i*step:(i+1)*step], test_patches_type[1][i*step:(i+1)*step], prediction), axis = -1)
#     display_images_2(image,92,388,fn,titles=titles)

## 总体精度评价

In [None]:
# Load the test frames and cut them into sequential patches, stepping by
# config.step_size.
frames, numList = readFrames('test')
test_patches = DataGenerator(
    config.input_image_channel,
    config.patch_size,
    frames,
    config.input_label_channel,
    augmenter=None,
).all_sequential_patches(config.step_size)
print('test patchs number:', len(test_patches[0]))

In [None]:
# Evaluate on the complete test set; the patch count is appended to the
# returned list so it can be reported alongside the metrics.
test_inputs, test_labels = test_patches[0], test_patches[1]
ev = model.evaluate(test_inputs, test_labels, batch_size=config.BATCH_SIZE)
ev.append(len(test_inputs))

## 分类别精度评价

In [None]:
# Evaluate each type folder of the test split separately. Each entry of
# ev_list holds the evaluation results followed by that type's patch count.
ev_list = []
for type_idx in range(config.type_num):
    type_dir = os.path.join(config.dataset_dir, 'test/type{}'.format(type_idx))
    image_names = [
        name for name in os.listdir(type_dir)
        if name.startswith(config.image_fn) and name.endswith(config.image_type)
    ]
    # Rebinding first drops any previous frame list before new images are loaded.
    frames = []
    for name in image_names:
        frames.append(readImgs(type_dir, name)[0])
    test_patchs_type = DataGenerator(
        config.input_image_channel,
        config.patch_size,
        frames,
        config.input_label_channel,
        augmenter=None,
    ).all_sequential_patches(config.step_size)
    patch_total = len(test_patchs_type[0])
    print('type{} patches number:{}'.format(type_idx, patch_total))
    ev2 = model.evaluate(test_patchs_type[0], test_patchs_type[1], config.BATCH_SIZE, verbose=1)
    ev2.append(patch_total)
    ev_list.append(ev2)
    # Release this type's data before the next one is loaded.
    del frames, test_patchs_type

In [None]:
# Report table: one row for the whole test set, then one row per type.
# In every result list the patch count is the last element and the loss and
# the five metrics are the first six.
row_format = '{:^9}{:^9.3f}{:^9.3f}{:^9.3f} {:^9.3f} {:^9.3f} {:^9.3f}'
print('      patch_num   loss accuracy   recall  precision F1_score     IoU')
print('total:' + row_format.format(ev[-1], *ev[:6]))
for type_idx in range(config.type_num):
    type_result = ev_list[type_idx]
    print('type{}:'.format(type_idx) + row_format.format(type_result[-1], *type_result[:6]))