Author: Ankit Kariryaa, University of Bremen.

Modified by Xuehui Pi, Qiuqi Luo and Beihui Hu

In [None]:
from tensorflow.keras.models import load_model
from tensorflow.keras import mixed_precision 
mixed_precision.set_global_policy('mixed_float16')
import rasterio                  # I/O raster data (netcdf, height, geotiff, ...)
import rasterio.warp             # Reproject raster samples
from rasterio import windows
import numpy as np               # numerical array manipulation
import os
from tqdm import tqdm
from itertools import product

from core.losses import accuracy, dice_loss, IoU, recall, precision,F1_score
from core.optimizers import adaDelta

%matplotlib inline
import warnings                  # ignore annoying warnings
warnings.filterwarnings("ignore")
import logging
logger = logging.getLogger()
logger.setLevel(logging.CRITICAL)

%reload_ext autoreload
%autoreload 2
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1'
import tensorflow as tf
print(tf.__version__)

In [None]:
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

# Build a TF1-compat session configuration and open an interactive session with it.
# NOTE: the name `config` is rebound to the raster-analysis configuration object in a
# later cell; after that point it no longer refers to this ConfigProto.
config = ConfigProto(
    #device_count={"CPU": 64},
    allow_soft_placement=True, 
    log_device_placement=False)
# Per the TensorFlow GPU options documentation, allow_growth makes the process claim
# GPU memory incrementally instead of reserving it all at start-up.
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)

In [None]:
# Required configuration (including the input and output paths) is stored in a separate
# module; this notebook uses config/RasterAnalysis_bands.py.
# Fill in that file before continuing with this notebook.
 
from config import RasterAnalysis_bands
# If your configuration lives in a different package (e.g. configLargeCluster), import the
# module from there instead,
# e.g. `from configLargeCluster import RasterAnalysis_bands`.
# NOTE: this rebinds `config`, which previously held the TensorFlow ConfigProto.
config = RasterAnalysis_bands.Configuration()

In [None]:
# Load a pretrained model.
# The project optimizer is wrapped in a LossScaleOptimizer before being passed to compile().
OPTIMIZER = adaDelta
OPTIMIZER=mixed_precision.LossScaleOptimizer(OPTIMIZER)
# The model is loaded with compile=False and compiled explicitly on the next line with the
# project's loss and metrics.
# NOTE(review): the custom_objects key 'dice loss' contains a space, unlike the other keys;
# confirm whether 'dice_loss' was intended.
model = load_model(config.trained_model_path, custom_objects={'dice loss': dice_loss, 'accuracy':accuracy ,'recall':recall, 'F1_score':F1_score,'precision':precision,'IoU': IoU}, compile=False)
# F1_score is registered as a custom object above but is not among the compiled metrics.
model.compile(optimizer=OPTIMIZER, loss=dice_loss, metrics=[dice_loss, accuracy,recall, precision, IoU])

In [None]:
# Helper that merges the prediction of one patch into the result for the whole image.
# The operator can be 'min' (useful when there are too many false positives) or 'max' (useful to tackle false negatives);
# any other value simply overwrites the existing values.
# res: full-size mask [rows, cols]; prediction: np.squeeze(prediction[i], axis=-1); (col, row, wi, he) = batch_pos[i]
def addTOResult(res, prediction, row, col, he, wi,ie_width,operator):
    """Merge the interior of one patch prediction into the full-size result.

    Only the part of the patch that lies at least `ie_width` pixels away from
    each patch edge is merged; the outer frame of the patch is ignored.

    Args:
        res: 2-D result array for the whole image; modified in place.
        prediction: 2-D prediction for one patch.
        row, col: Top-left position of the patch inside `res`.
        he, wi: Height and width of the patch.
        ie_width: Width of the ignored frame along every patch edge.
        operator: 'min' keeps the per-pixel minimum, 'max' keeps the per-pixel
            maximum, any other value overwrites the existing values.

    Returns:
        The same `res` array that was passed in.
    """
    # Region of `res` covered by the patch interior.
    target = (
        slice(row + ie_width, row + he - ie_width),
        slice(col + ie_width, col + wi - ie_width),
    )
    # Matching interior of the patch prediction.
    interior = prediction[ie_width:he - ie_width, ie_width:wi - ie_width]

    if operator == 'min':
        # IMPORTANT: 'min' is only meaningful when the mask was initialised with -1
        # rather than 0. The -1 placeholders are turned into 1 here (in place, since
        # `existing` is a view of `res`) so that the first prediction always wins.
        existing = res[target]
        existing[existing == -1] = 1
        merged = np.minimum(existing, interior)
    elif operator == 'max':
        merged = np.maximum(res[target], interior)
    else:
        merged = interior

    res[target] = merged
    return res

In [None]:
# Methods that actually make the predictions
def predict_using_model(model, batch, batch_pos, mask,ie_width,operator):
    """Run the model on a batch of patches and merge each prediction into `mask`.

    Args:
        model: Object exposing `predict`, called once on the stacked batch.
        batch: List of equally shaped patch arrays.
        batch_pos: List of (col, row, width, height) tuples, one per patch,
            giving the position of each patch in `mask`.
        mask: Full-size result array that the predictions are merged into.
        ie_width: Width of the patch frame ignored when merging.
        operator: Merge operator forwarded to `addTOResult`.

    Returns:
        The mask after every prediction of the batch has been merged.
    """
    stacked = np.stack(batch, axis=0)
    predictions = model.predict(stacked)
    for idx, (col, row, wi, he) in enumerate(batch_pos):
        # Drop the trailing channel axis so the prediction is 2-D like the mask.
        patch_pred = np.squeeze(predictions[idx], axis=-1)
        # Merge with the user-specified operator instead of blindly overwriting.
        mask = addTOResult(mask, patch_pred, row, col, he, wi, ie_width, operator)
    return mask

def detect_lake(image ,width=576, height=576, stride = 376,ie_width=100,bandNum=5):
    """Predict a full-size mask for an opened raster by tiling it into patches.

    The raster is covered with `width` x `height` windows placed every `stride`
    pixels, plus one extra row/column of windows aligned to the far edges.
    Patches are predicted in batches of `config.BATCH_SIZE` with the
    module-level `model` and merged into the result with the 'max' operator.

    Note: only the interior of every patch (at least `ie_width` pixels from the
    patch edge) is merged, so pixels closer than `ie_width` to the image border
    keep their initial value of 0.

    Args:
        image: Opened rasterio dataset to analyse.
        width: Patch width in pixels.
        height: Patch height in pixels.
        stride: Step between consecutive patch origins.
        ie_width: Width of the patch frame that is ignored when merging.
        bandNum: Number of bands per patch; must match what `image.read` returns.

    Returns:
        Tuple `(mask, meta)`: the 2-D prediction array for the whole image and
        a copy of the image metadata adjusted to a single float band.
    """
    nols, nrows = image.meta['width'], image.meta['height']
    meta = image.meta.copy()
    # The prediction is a float, so keep the output float to stay consistent with it.
    if 'float' not in meta['dtype']:
        meta['dtype'] = np.float32
    meta['count'] = 1

    # Regular grid of patch origins plus one final origin flush with the far edge.
    col_index = list(range(0, nols - width, stride))
    col_index.append(nols - width)
    row_index = list(range(0, nrows - height, stride))
    row_index.append(nrows - height)

    big_window = windows.Window(col_off=0, row_off=0, width=nols, height=nrows)
    print('the size of current image',nrows, nols)
    # The mask starts at 0, which is why only 'max' (not 'min') merging is used below;
    # 'min' would require initialising the mask with -1 (see addTOResult).
    mask = np.zeros((nrows, nols), dtype=meta['dtype'])

    batch = []
    batch_pos = []
    # itertools.product has no length, so pass the patch count explicitly to get a
    # meaningful progress bar.
    n_patches = len(col_index) * len(row_index)
    for col_off, row_off in tqdm(product(col_index, row_index), total=n_patches):
        # Clip the window to the image so reads never go out of bounds.
        window = windows.Window(col_off=col_off, row_off=row_off, width=width, height=height).intersection(big_window)
        # -1 padding fills whatever part of the patch the clipped window does not cover.
        patch = np.full((height, width, bandNum), -1.0)
        # Scale the raw values by 1/1000 (presumably to match the preprocessing used
        # at training time -- TODO confirm).
        read_img = image.read(window=window) / 1000
        # Move the band axis last so the data matches the patch layout.
        read_img = np.transpose(read_img, axes=(1, 2, 0))

        patch[:window.height, :window.width] = read_img
        batch.append(patch)
        batch_pos.append((window.col_off, window.row_off, window.width, window.height))

        if len(batch) == config.BATCH_SIZE:
            mask = predict_using_model(model, batch, batch_pos, mask, ie_width, 'max')
            batch = []
            batch_pos = []

    # The number of patches is generally not a multiple of the batch size, so flush
    # whatever is left over.
    if batch:
        mask = predict_using_model(model, batch, batch_pos, mask, ie_width, 'max')

    return (mask, meta)

In [None]:
def writeMaskToDisk(detected_mask, detected_meta, wp, write_as_type = 'uint8', th = 0.5, create_countors = False):
    """Write a predicted mask to `wp` as an LZW-compressed single-band raster.

    When the metadata dtype is a float type and `write_as_type` is an integer
    type, the mask is binarised first: pixels with a value >= `th` become 1,
    all others 0. Neither `detected_mask` nor `detected_meta` is modified.

    Args:
        detected_mask: 2-D prediction array.
        detected_meta: Raster metadata used to create the output file.
        wp: Output path.
        write_as_type: dtype name to write, e.g. 'uint8'.
        th: Threshold used when converting a float mask to an integer type.
        create_countors: Unused; kept for backward compatibility.
    """
    # Work on a copy so the caller's metadata is left untouched.
    out_meta = dict(detected_meta)

    if 'float' in str(out_meta['dtype']) and 'int' in write_as_type:
        print(f'Converting prediction from {out_meta["dtype"]} to {write_as_type}, using threshold of {th}')#float32 to uint8
        # A single comparison builds a new array, so the caller's mask is not changed,
        # and it stays correct for any threshold (zeroing first and then testing
        # `>= th` turns the zeroed pixels into 1 when th is negative).
        detected_mask = (detected_mask >= th).astype(write_as_type)
        out_meta['dtype'] = write_as_type

    # Compress the output GeoTIFF.
    out_meta['compress'] = 'lzw'

    with rasterio.open(wp, 'w', **out_meta) as outds:
        outds.write(detected_mask, 1)

In [None]:
# Collect every input image below the input directory whose name has the configured
# prefix and file type.
all_files = []
for root, dirs, files in os.walk(config.input_image_dir):
    for file in files:
        if file.endswith(config.input_image_type) and file.startswith(config.image_fn_st):
            all_files.append((os.path.join(root, file), file))
# print(all_files)

for fullPath, filename in all_files:
    # Swap only the leading prefix (count=1); a plain replace would also rewrite any
    # later occurrence of the prefix text inside the file name.
    # NOTE: the output name is built from the base name only, so equally named inputs
    # in different sub-folders map to the same output file.
    outputFile = os.path.join(config.output_dir, filename.replace(config.image_fn_st, config.output_prefix, 1))
    if not os.path.isfile(outputFile) or config.overwrite_analysed_files:
        with rasterio.open(fullPath) as image:
            print(fullPath)
            detectedMask, detectedMeta = detect_lake(image,width = config.WIDTH, height = config.HEIGHT, stride = config.STRIDE,ie_width=config.ignore_edge_width,bandNum=config.band_num)
            writeMaskToDisk(detectedMask, detectedMeta, outputFile, write_as_type = config.output_dtype, th = 0.5, create_countors = False)
    else:
        # Report skipped files only; this used to be printed after every processed file.
        print('File already analysed!', fullPath)

print('finish')