  Author: Ankit Kariryaa, University of Bremen
  
  Modified by Xuehui Pi and Qiuqi Luo

### Getting started
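Define the paths for the dataset and the trained models in the configuration file under `notebooks/config/` before running the cells below. The original file is `UNetTraining.py`; this notebook imports the `UNetTraining_b3` variant.  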

In [None]:
import os
# Give the MKL, NumExpr and OpenMP thread-count variables the same value.
for _thread_var in ("MKL_NUM_THREADS", "NUMEXPR_NUM_THREADS", "OMP_NUM_THREADS"):
    os.environ[_thread_var] = '16'
# Echo one of them back as a quick sanity check.
print(os.environ.get('OMP_NUM_THREADS'))

In [None]:
import tensorflow as tf
import numpy as np
from PIL import Image
import rasterio
import imgaug as ia
from imgaug import augmenters as iaa
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import mixed_precision 
# Set the Keras global dtype policy to 'mixed_float16'; this runs before the
# model is built further down in the notebook.
mixed_precision.set_global_policy('mixed_float16')

import tensorflow as tf

import imageio
import os

import time
import rasterio.warp             # Reproject raster samples
from functools import reduce
from tensorflow.keras.models import load_model

from core.UNet import UNet  #
from core.losses import tversky, focalTversky, bce_dice_loss, accuracy, dice_loss, IoU, recall, precision
from tensorflow.keras.losses import BinaryCrossentropy as bce
from core.optimizers import adaDelta, adagrad, adam, nadam
from core.frame_info import FrameInfo
from core.dataset_generator import DataGenerator
from core.visualize import display_images,plot
import json
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
import shutil
import pickle
import random

%matplotlib inline
import matplotlib.pyplot as plt  # plotting tools
import matplotlib.patches as patches
from matplotlib.patches import Polygon
#matplotlib.use("Agg")

import warnings                  # used only to switch warnings off below
# Suppress every Python warning for the rest of the session.
warnings.filterwarnings("ignore")
import logging
# Raise the root logger threshold so that only CRITICAL records are emitted.
logger = logging.getLogger()
logger.setLevel(logging.CRITICAL)

# Load (or reload) the autoreload extension and switch it to mode 2.
%reload_ext autoreload
%autoreload 2
from IPython.core.interactiveshell import InteractiveShell
# "all": show the value of every top-level expression in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = "all"

# Mixed precision uses both 16-bit and 32-bit floating-point types in a model
# during training to make it run faster and use less memory.
# NOTE(review): TensorFlow is already imported and the 'mixed_float16' global
# policy is already set above, so this variable may have no effect here — confirm.
os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1'
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
print(tf.__version__)

In [None]:
# Report how many GPU devices TensorFlow detects.
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))
# tf.device('/gpu:1')

In [None]:
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

# Settings for the TF1-compat InteractiveSession created below.
# They live under their own name (`session_config`) because the notebook later
# binds `config` to the project Configuration object; sharing the name would let
# a re-run of this cell overwrite the object that the data-loading cells read.
session_config = ConfigProto(
    #device_count={"CPU": 64},
    allow_soft_placement=True,
    log_device_placement=False)
# Enable the allow_growth GPU option on the session.
session_config.gpu_options.allow_growth = True
session = InteractiveSession(config=session_config)

In [None]:
# Required configurations (including the input and output paths) are stored in a separate module (here `config.UNetTraining_b3`).
# Please provide the required info in that module before continuing with this notebook.
# NOTE(review): an earlier note by "hbh" mentions a config named UNetTraining_sequential, created to keep it apart from the original; the module imported here is UNetTraining_b3 — confirm which one is intended.
from config import UNetTraining_b3
# If you are using a different folder name such as configLargeCluster, import from that folder instead,
# e.g. `from configLargeCluster import UNetTraining`.
# From here on the name `config` refers to this Configuration instance.
config = UNetTraining_b3.Configuration()

In [None]:
def readBands(path_to_write,fn):
    """Read all bands of one raster file and scale the values by 1/1000.

    Args:
        path_to_write: Directory that contains the raster file.
        fn: Name of the raster file inside ``path_to_write``.

    Returns:
        The array produced by ``read()`` on the opened dataset, divided by 1000.
    """
    # The context manager closes the dataset handle as soon as the pixel data
    # has been read into memory, so no file handle is left open per image.
    with rasterio.open(os.path.join(path_to_write, fn)) as src:
        return src.read() / 1000

def readImgs(path_to_write, fn):
    """Load one NDWI/green/SWIR image stack with its annotation as a frame.

    ``fn`` names the NDWI file. The green, SWIR and annotation file names are
    derived from it by replacing ``config.NDWI_fn`` with ``config.green_fn``,
    ``config.swir_fn`` and ``config.annotation_fn`` respectively.

    Args:
        path_to_write: Directory that contains all files of the sample.
        fn: File name of the NDWI raster.

    Returns:
        A tuple ``(frame, patch_count)``. ``frame`` is
        ``FrameInfo(comb_img, annotation)`` with the bands moved to the last
        axis. ``patch_count`` is ``(rows / patch_size[0]) * (cols / patch_size[1])``
        and is a float, because true division is used and nothing is rounded.
    """
    # Close the raster as soon as its data is in memory.
    with rasterio.open(os.path.join(path_to_write, fn)) as ndwi_src:
        read_NDWI_img = ndwi_src.read() / 100

    # Number of patch-sized tiles per axis; callers normalise the product
    # into per-frame percentages.
    rowNum = read_NDWI_img.shape[1] / config.patch_size[0]
    colNum = read_NDWI_img.shape[2] / config.patch_size[1]

    read_green_img = readBands(path_to_write, fn.replace(config.NDWI_fn, config.green_fn))
    read_swir_img = readBands(path_to_write, fn.replace(config.NDWI_fn, config.swir_fn))

    # Stack along the leading (band) axis, then move that axis to the end.
    comb_img = np.concatenate((read_NDWI_img, read_green_img, read_swir_img), axis=0)
    comb_img = np.transpose(comb_img, axes=(1, 2, 0))

    annotation_path = os.path.join(path_to_write, fn.replace(config.NDWI_fn, config.annotation_fn))
    # np.array() is called inside the block so the pixels are loaded before
    # the image file is closed.
    with Image.open(annotation_path) as annotation_im:
        annotation = np.array(annotation_im)

    f = FrameInfo(comb_img, annotation)
    return f, rowNum * colNum

def readFrames(dataType):
    """Collect the frames and patch counts of every type folder of a split.

    For each type index below ``config.type_num`` the folder
    ``<config.dataset_dir>/<dataType>/type<i>`` is listed, and every file whose
    name starts with ``config.NDWI_fn`` and ends with ``config.image_type`` is
    loaded through ``readImgs``.

    Args:
        dataType: Name of the split sub-folder (the notebook passes 'train',
            'val' and 'test').

    Returns:
        A tuple ``(frames, numList)`` of two parallel lists: the loaded frames
        and the patch count that ``readImgs`` reported for each of them.
    """
    print(dataType)
    frames, numList = [], []
    for type_idx in range(config.type_num):
        type_dir = os.path.join(config.dataset_dir, '{}/type{}'.format(dataType, type_idx))
        ndwi_names = [
            name
            for name in os.listdir(type_dir)
            if name.startswith(config.NDWI_fn) and name.endswith(config.image_type)
        ]
        print('type{} image number:{}'.format(type_idx, len(ndwi_names)))
        for name in ndwi_names:
            frame, patch_count = readImgs(type_dir, name)
            frames.append(frame)
            numList.append(patch_count)
    return frames, numList

### 数据集准备

In [None]:
# Load every training frame together with its patch count.
frames, numList = readFrames('train')
train_patch_counts = np.array(numList)
train_patch_total = train_patch_counts.sum()
print(train_patch_total)
# Normalise the counts so they sum to 1; they are passed to random_generator below.
percentages = train_patch_counts / train_patch_total
print(f'total training img count:{len(frames)}')
train_data = DataGenerator(
    config.input_image_channel,
    config.patch_size,
    frames,
    config.input_label_channel,
    augmenter='iaa',
)
train_generator = train_data.random_generator(config.BATCH_SIZE, percentages)  # ,normalize = config.normalize

In [None]:
# frames=readFrames('train')
# train_patches = DataGenerator(config.input_image_channel, config.patch_size, frames, config.input_label_channel, augmenter = 'iaa').all_sequential_patches(config.step_size)
# print('train patchs number:',len(train_patches[0]))

In [None]:
# Load every validation frame together with its patch count.
frames, numList = readFrames('val')
val_patch_counts = np.array(numList)
val_patch_total = val_patch_counts.sum()
print(val_patch_total)
# Normalise the counts so they sum to 1; they are passed to random_generator below.
percentages = val_patch_counts / val_patch_total
print(f'total validation img count:{len(frames)}')
# No augmenter is used for the validation data.
val_data = DataGenerator(
    config.input_image_channel,
    config.patch_size,
    frames,
    config.input_label_channel,
    augmenter=None,
)
val_generator = val_data.random_generator(config.BATCH_SIZE, percentages)  # , normalize = config.normalize

In [None]:
# Drop the notebook-level names `frames` and `percentages`.
del frames
del percentages

In [None]:
# Draw one batch from the validation generator and show inputs and labels
# side by side (labels appended along the last axis).
for _ in range(1):
    val_images, val_label = next(val_generator)
    print(val_images.shape)
    val_panel = np.concatenate([val_images, val_label], axis=-1)
    display_images(val_panel)

In [None]:
# Draw one batch from the training generator and show inputs and labels
# side by side (labels appended along the last axis).
for _ in range(1):
    train_images, real_label = next(train_generator)
    train_panel = np.concatenate([train_images, real_label], axis=-1)
    display_images(train_panel)

### 参数初始化

In [None]:
# Optimizer for this run: AdaDelta wrapped in a LossScaleOptimizer.
# OPTIMIZER_NAME is the tag that ends up in the model and history file names.
OPTIMIZER_NAME = 'AdaDelta'
OPTIMIZER = mixed_precision.LossScaleOptimizer(adaDelta)

# Alternative: Adam
# OPTIMIZER_NAME = 'adam'
# OPTIMIZER = mixed_precision.LossScaleOptimizer(adam)

In [None]:
# Loss for this run, and the tag that ends up in the model and history file names.
# Alternatives (swap the active line for one of these):
#   LOSS, LOSS_NAME = tversky, 'tversky'
#   LOSS, LOSS_NAME = focalTversky, 'focalTversky'
#   LOSS, LOSS_NAME = tf.keras.losses.BinaryCrossentropy(), 'bce'
#   LOSS, LOSS_NAME = bce_dice_loss, 'bce_dice_loss'
LOSS, LOSS_NAME = dice_loss, 'dice_loss'

### 模型训练

In [None]:
# Timestamp (minute resolution) used in the file names of this run.
timestr = time.strftime("%Y%m%d-%H%M")

# Tag built by concatenating all input and label channel indices,
# e.g. [0, 1, 2] + [3] -> '0123'.
chf = config.input_image_channel + config.input_label_channel
chs = ''.join(str(ch) for ch in chf)

# exist_ok=True creates the folder only when needed, without a separate
# check-then-create step.
os.makedirs(config.model_path, exist_ok=True)

# `model_name` is reused later for the TensorBoard log folder.
model_name = '_{}_{}_{}_{}_{}.h5'.format(timestr, OPTIMIZER_NAME, LOSS_NAME, chs, config.input_shape[0])
model_path = os.path.join(config.model_path, 'lakes' + model_name)
print(model_path)

In [None]:
# Define the model and compile it  
# Build the U-Net for inputs of shape [BATCH_SIZE, *input_shape] and compile it
# with the optimizer and loss chosen above.
net_input_shape = [config.BATCH_SIZE, *config.input_shape]
tracked_metrics = [dice_loss, accuracy, recall, precision, IoU]
model = UNet(net_input_shape, config.input_label_channel)
model.compile(optimizer=OPTIMIZER, loss=LOSS, metrics=tracked_metrics)

In [None]:
# Define callbacks      for the early stopping of training, LearningRateScheduler and model checkpointing 
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler, EarlyStopping, ReduceLROnPlateau, TensorBoard

# Save the whole model (not only the weights) each time val_loss reaches a new minimum.
checkpoint = ModelCheckpoint(model_path, monitor='val_loss', verbose=1,
                             save_best_only=True, mode='min', save_weights_only=False)

# ReduceLROnPlateau: can be useful when Adam is the optimizer.
# It lowers the learning rate once the monitored metric has stopped improving
# for `patience` epochs (new_lr = lr * factor, here factor = 0.33).
# cooldown: number of epochs to wait before resuming normal operation after the
# learning rate has been reduced.
reduceLROnPlat = ReduceLROnPlateau(monitor='val_loss', factor=0.33,
                                   patience=4, verbose=1, mode='min',
                                   min_delta=0.0001, cooldown=4, min_lr=1e-16)

# Stop training when val_loss has not improved for 20 epochs.
early = EarlyStopping(monitor="val_loss", mode="min", verbose=2, patience=20)

log_dir = os.path.join('./logs', 'UNet' + model_name)
# Only arguments of the TF2 TensorBoard callback are passed. The TF1-era
# options write_grads, embeddings_layer_names and embeddings_data are ignored
# by TF2 and rejected by newer Keras releases, and they were only ever given
# their "off" values here.
tensorboard = TensorBoard(log_dir=log_dir, histogram_freq=0, write_graph=True,
                          write_images=False, embeddings_freq=0,
                          embeddings_metadata=None, update_freq='epoch')

callbacks_list = [checkpoint, tensorboard, early]  # reduceLROnPlat is not required with adaDelta

In [None]:
# Keyword arguments for model.fit, collected in one place.
# Options left disabled on purpose: shuffle=True, max_queue_size=60 and
# use_multiprocessing=True (the generator is not very thread safe).
fit_options = dict(
    steps_per_epoch=config.MAX_TRAIN_STEPS,
    epochs=config.NB_EPOCHS,
    validation_data=val_generator,
    validation_steps=config.VALID_IMG_COUNT,
    callbacks=callbacks_list,
    workers=1,
)
loss_history = model.fit(train_generator, **fit_options)

# Pickle the per-epoch history dict next to the notebook, then plot it.
h = loss_history.history
history_file = 'history_{}_{}_{}_{}_{}.txt'.format(
    timestr, OPTIMIZER_NAME, LOSS_NAME, chs, config.input_shape[0])
with open(history_file, 'wb') as file_pi:
    pickle.dump(h, file_pi)
plot(h, timestr, OPTIMIZER_NAME, LOSS_NAME,
     config.patch_size[0], config.NB_EPOCHS, config.BATCH_SIZE, chs)

In [None]:
# # 读取现有history文件
# with open('.txt','rb')as file_pi:
#     h=pickle.load(file_pi)
# print(h)

In [None]:
# Plot the training history `h` again (for example after loading it from file).
plot(
    h,
    timestr,
    OPTIMIZER_NAME,
    LOSS_NAME,
    config.patch_size[0],
    config.NB_EPOCHS,
    config.BATCH_SIZE,
    chs,
)

### 模型预测

In [None]:
# Load every test frame together with its patch count.
frames, numList = readFrames('test')
# Normalise the patch counts so they sum to 1; they are passed to random_generator below.
percentages = np.array(numList)
percentages = percentages / percentages.sum()
# These frames come from the 'test' split, so the label says "test".
print('total test img count:' + str(len(frames)))
# No augmenter is used for the test data.
test_generator = DataGenerator(
    config.input_image_channel,
    config.patch_size,
    frames,
    config.input_label_channel,
    augmenter=None,
).random_generator(config.BATCH_SIZE, percentages)
print('done')

In [None]:
# Print one batch on the training/test data! 
for i in range(1):
    test_images, real_label = next(test_generator)
    #3 images per row: GSW, label, prediction
    prediction = model.predict(test_images, steps=1)
    prediction[prediction>0.5]=1
    prediction[prediction<=0.5]=0
    display_images(np.concatenate((test_images, real_label, prediction), axis = -1))# test_images( NDWI), real_label(label), prediction

### 模型精度评价

In [None]:
# Load model after training 
model_path=r'D:\lakemapping\U_Net\saved_models/UNet\lakes_20231129-0018_AdaDelta_dice_loss_0123_512.h5'
# model_path=r'D:\lakemapping\U_Net\saved_models\UNet\lakes_20231109-1134_AdaDelta_dice_loss_0123_512.h5 '
# model_path=r'D:\lakemapping\U_Net\saved_models\UNet\lakes_area550_20231113-0337_AdaDelta_dice_loss_0123_512_percentages.h5'
model = load_model(model_path, custom_objects={'dice loss': LOSS, 'accuracy':accuracy ,'recall':recall, 'precision':precision,'IoU': IoU}, compile=False) 
model.compile(optimizer=OPTIMIZER, loss=LOSS, metrics=[dice_loss, accuracy,recall, precision, IoU])

#### 总体精度评价

In [None]:
# Load all test frames, shuffle their order in place, and cut them into
# sequential patches with stride config.step_size.
frames, numList = readFrames('test')
random.shuffle(frames)
testDG = DataGenerator(
    config.input_image_channel,
    config.patch_size,
    frames,
    config.input_label_channel,
    augmenter=None,
)
test_patches = testDG.all_sequential_patches(config.step_size)
# test_patches[0] holds the image patches, test_patches[1] the matching labels.
print('test patches number:', len(test_patches[0]))

In [None]:
# Show batch number `i` of the sequential test patches next to its labels.
i = 0
shown_span = slice(i * config.BATCH_SIZE, i * config.BATCH_SIZE + config.BATCH_SIZE)
shown_images = test_patches[0][shown_span]
shown_labels = test_patches[1][shown_span]
print(shown_images.shape)
display_images(np.concatenate([shown_images, shown_labels], axis=-1))

In [None]:
# Predict batch number `i` of the sequential test patches and display
# input bands, labels and thresholded prediction together.
i = 1
eval_span = slice(i * config.BATCH_SIZE, i * config.BATCH_SIZE + config.BATCH_SIZE)
eval_images = test_patches[0][eval_span]
eval_labels = test_patches[1][eval_span]
prediction = model.predict(eval_images, steps=1)
# Binarise in place: values above 0.5 become 1, values <= 0.5 become 0.
prediction[prediction > 0.5] = 1
prediction[prediction <= 0.5] = 0
display_images(np.concatenate([eval_images, eval_labels, prediction], axis=-1))

In [None]:
# Evaluate the compiled loss and metrics on all sequential test patches.
model.evaluate(
    x=test_patches[0],
    y=test_patches[1],
    batch_size=config.BATCH_SIZE,
)

In [None]:
# Drop the notebook-level names of the overall test data.
del frames
del testDG
del test_patches

In [None]:
# Load the test frames of a single type (index `j`) and cut them into
# sequential patches for a closer look.
j = 0
# Join the path components separately so the path works on every OS;
# a hard-coded backslash separator only works on Windows.
path_to_write = os.path.join(config.dataset_dir, 'test', 'type' + str(j))
all_files = os.listdir(path_to_write)
# Keep only the NDWI images; readImgs derives the other file names from them.
all_files_NDWI = [fn for fn in all_files if fn.startswith(config.NDWI_fn) and fn.endswith(config.image_type)]
# readImgs returns (frame, patch_count); only the frame is needed here.
# Iterating without enumerate keeps `j` equal to the selected type index.
frames = [readImgs(path_to_write, fn)[0] for fn in all_files_NDWI]
test_DGT = DataGenerator(config.input_image_channel, config.patch_size, frames, config.input_label_channel, augmenter=None)
test_patches_type = test_DGT.all_sequential_patches(config.step_size)

In [None]:
# Batch index used by the following cell, which increments it after every run
# so that re-running that cell steps through the batches.
i=1

In [None]:
# Predict batch number `i` of the single-type patches and display input bands,
# labels and thresholded prediction together.
# One slice is shared by images and labels so both always cover the same
# patches, whatever config.BATCH_SIZE is.
type_span = slice(i * config.BATCH_SIZE, (i + 1) * config.BATCH_SIZE)
type_images = test_patches_type[0][type_span]
type_labels = test_patches_type[1][type_span]
prediction = model.predict(type_images, steps=1)
# Binarise in place: values above 0.5 become 1, values <= 0.5 become 0.
prediction[prediction > 0.5] = 1
prediction[prediction <= 0.5] = 0
display_images(np.concatenate((type_images, type_labels, prediction), axis=-1), titles='i=' + str(i))
# Advance so that re-running this cell shows the next batch.
i = i + 1

#### 分类别精度评价

In [None]:
# Evaluate the model separately on the test patches of every type folder.
for i in range(config.type_num):
    # Join the path components separately so the path works on every OS;
    # a hard-coded backslash separator only works on Windows.
    path_to_write = os.path.join(config.dataset_dir, 'test', 'type' + str(i))
    all_files = os.listdir(path_to_write)
    # Keep only the NDWI images; readImgs derives the other file names from them.
    all_files_NDWI = [fn for fn in all_files if fn.startswith(config.NDWI_fn) and fn.endswith(config.image_type)]
    # readImgs returns (frame, patch_count); only the frame is needed here.
    frames = [readImgs(path_to_write, fn)[0] for fn in all_files_NDWI]
    test_DGT = DataGenerator(config.input_image_channel, config.patch_size, frames, config.input_label_channel, augmenter=None)
    test_patches_type = test_DGT.all_sequential_patches(config.step_size)
    print('type{} patches number:{}'.format(i, len(test_patches_type[0])))
    model.evaluate(test_patches_type[0], test_patches_type[1], config.BATCH_SIZE)

In [1]:
%load_ext tensorboard

In [3]:
%tensorboard --logdir=logs 

Reusing TensorBoard on port 6006 (pid 12364), started 0:58:25 ago. (Use '!kill 12364' to kill it.)