In [None]:
import os
import matplotlib.pyplot as plt
import rasterio
import pandas as pd
import numpy as np
from pathlib import Path
import math
import maxflow as mf
import dask
import scipy.ndimage
import skimage.morphology
from concurrent.futures import ThreadPoolExecutor

#import settings.by_dop80c_1312.detectree_r1_deepforest_r1.setting as setting
import settings.opendata_luftbild_dop60_1312.detectree_r1_deepforest_r1.setting as setting

In [None]:
# --- run switches ---
CONTINUE_MODE = True  # ignore already processed images (existing outputs are not rewritten)
SAVE_PROBA = False    # write the fused probability raster per tile
SAVE_IMAGE = True     # write the refined tree / non-tree label raster per tile

# --- inputs: detectree probability tiles ('p') and DeepForest bbox pickles ('b') ---
DETECTREE_PROBA_RESULT_PATH = setting.POSTPROCESS_SRC_DETECTREE_PREDICT_RESULT_PATH.joinpath('p')
DEEPFOREST_BBOX_RESULT_PATH = setting.POSTPROCESS_SRC_DEEPFOREST_PREDICT_RESULT_PATH.joinpath('b')

# --- outputs: fused probabilities ('p') and refined labels ('r') ---
OUTPUT_PROBA_RESULT_PATH = setting.DATASET_PREDICT_RESULT_PATH.joinpath('p')
OUTPUT_IMAGE_RESULT_PATH = setting.DATASET_PREDICT_RESULT_PATH.joinpath('r')

# --- model / fusion parameters taken from the selected settings module ---
MODEL_PARAMS_CLASSIFIER = setting.MODEL_PARAMS_CLASSIFIER
FUSE_CONFIG = setting.POSTPROCESS_CONFIG

# number of worker threads used to process tiles in parallel
CONCURRENCY = 4

for _output_dir in (OUTPUT_PROBA_RESULT_PATH, OUTPUT_IMAGE_RESULT_PATH):
  os.makedirs(_output_dir, exist_ok=True)

# echo the effective configuration so the run log records it
for _label, _value in (
    ("detectree:", DETECTREE_PROBA_RESULT_PATH),
    ("deepforest:", DEEPFOREST_BBOX_RESULT_PATH),
    ("output:", setting.DATASET_PREDICT_RESULT_PATH),
    ("fuse:", FUSE_CONFIG),
    ("clf:", MODEL_PARAMS_CLASSIFIER),
):
  print(_label, _value)

In [None]:
# dilation -> closing -> fill holes -> erosion
def refine_morph(image, morph_size, tree_val, nontree_val):
  """Smooth a tree / non-tree label raster with binary morphology.

  Pixels of band 0 equal to `tree_val` form the foreground, which is then:
    1. dilated with a disk of radius `morph_size`,
    2. closed with a disk of radius `2 * morph_size`,
    3. hole-filled using a disk of radius `int(morph_size / 2)` (at least 1)
       as the connectivity structure,
    4. eroded with a disk of radius `morph_size`.

  Args:
    image: label raster indexed as (band, row, col); only band 0 is used.
    morph_size: base disk radius in pixels.
    tree_val: label marking tree pixels in `image`; also written for the
      refined foreground.
    nontree_val: label written for every other pixel.

  Returns:
    Array of shape (1, rows, cols) containing only `tree_val` / `nontree_val`.
  """
  foreground = image[0] == tree_val
  refined = scipy.ndimage.binary_dilation(
      foreground, structure=skimage.morphology.disk(morph_size))
  refined = scipy.ndimage.binary_closing(
      refined, structure=skimage.morphology.disk(morph_size * 2))
  # binary_fill_holes grows the background inward from the image border using
  # this structure. A 1x1 structure (radius 0, i.e. morph_size < 2) cannot grow
  # at all, so every pixel would be treated as a hole and the whole image would
  # become foreground. Radius 1 is the 4-connected cross, scipy's default.
  fill_radius = max(int(morph_size / 2), 1)
  refined = scipy.ndimage.binary_fill_holes(
      refined, structure=skimage.morphology.disk(fill_radius))
  refined = scipy.ndimage.binary_erosion(
      refined, structure=skimage.morphology.disk(morph_size))
  return np.where(refined, tree_val, nontree_val)[np.newaxis]

# max-flow / min-cut refinement
def refine_mfmc(img_shape, p_tree, refine_int_rescale, refine_beta, tree_val, nontree_val,
                proba_eps=1e-6):
  """Label pixels tree / non-tree with a max-flow/min-cut over a pixel grid.

  Every pixel becomes a node of an integer-capacity PyMaxflow grid graph.
  - Terminal capacities: the class log-probabilities, scaled by
    `refine_int_rescale` and truncated to int. `log(1 - p_tree)` is the source
    capacity and `log(p_tree)` the sink capacity.
  - Pairwise edges: weight `refine_beta`, laid out with a half 3x3 Moore
    structure (right, down-left, down, down-right), so in 2-D each
    8-neighbour pair is linked once.
  - Labels: nodes that `get_grid_segments` reports True get `tree_val`; all
    others get `nontree_val`.

  Args:
    img_shape: grid shape; must match `p_tree.shape`. The caller passes the
      raster shape (bands, rows, cols).
      NOTE(review): confirm how PyMaxflow applies the 2-D structure to that
      3-D grid.
    p_tree: per-pixel tree probability, expected in [0, 1].
    refine_int_rescale: factor applied to the log-probabilities before the int
      conversion (more scale keeps more precision after truncation).
    refine_beta: pairwise edge weight.
    tree_val, nontree_val: output label values.
    proba_eps: probabilities are clipped to [proba_eps, 1 - proba_eps] before
      taking logs. Without the clip, p == 0 or p == 1 gives log(0) = -inf,
      whose int conversion is undefined and yields garbage capacities.
      Keep 0 < proba_eps < 0.5, and well above float32 resolution near 1
      (~6e-8) so that `1 - p` cannot round to 0.

  Returns:
    Array of shape `img_shape` filled with `tree_val` / `nontree_val`.
  """
  p = np.clip(p_tree, proba_eps, 1.0 - proba_eps)
  D_tree = (refine_int_rescale * np.log(1.0 - p)).astype(int)
  D_nontree = (refine_int_rescale * np.log(p)).astype(int)
  # Half of the 3x3 neighbourhood around the centre, so each pair is added once.
  MOORE_NEIGHBORHOOD_ARR = np.array([[0, 0, 0], [0, 0, 1], [1, 1, 1]])

  g = mf.Graph[int]()
  node_ids = g.add_grid_nodes(img_shape)
  g.add_grid_edges(node_ids, refine_beta,
                    structure=MOORE_NEIGHBORHOOD_ARR)
  g.add_grid_tedges(node_ids, D_tree, D_nontree)
  g.maxflow()

  refined = np.full(img_shape, nontree_val)
  refined[g.get_grid_segments(node_ids)] = tree_val
  return refined
  
def _bbox_score_mask(bbox_df, shape, dtype, radius_min_thres, radius_max_thres):
  """Paint each box of `bbox_df` as a filled disc carrying the box's score.

  Args:
    bbox_df: DataFrame with `xmin`, `xmax`, `ymin`, `ymax` (pixel col/row
      coordinates) and `score` columns.
    shape: (bands, rows, cols) shape of the returned mask.
    dtype: dtype of the returned mask.
    radius_min_thres, radius_max_thres: a box is drawn only when its disc
      radius, ceil(min(height, width) / 2), lies strictly between these values.

  Returns:
    Array of `shape` and `dtype`: 0 outside every drawn disc, and the box's
    score inside it (in all bands). Rows are painted in DataFrame order, so
    where discs overlap the later row wins.
  """
  mask_score = np.zeros(shape, dtype=dtype)
  n_rows, n_cols = shape[1], shape[2]
  for box in bbox_df.itertuples(index=False):
    center_r = int((box.ymin + box.ymax) / 2)
    center_c = int((box.xmin + box.xmax) / 2)
    r = math.ceil(min(box.ymax - box.ymin, box.xmax - box.xmin) / 2)
    if not (radius_min_thres < r < radius_max_thres):
      continue
    # Only pixels within r rows/cols of the centre can lie in the disc, so the
    # distance test is evaluated on that window (clipped to the image) instead
    # of allocating a full-image grid for every box.
    r0, r1 = max(center_r - r, 0), min(center_r + r + 1, n_rows)
    c0, c1 = max(center_c - r, 0), min(center_c + r + 1, n_cols)
    if r0 >= r1 or c0 >= c1:
      continue  # disc lies entirely outside the image
    y, x = np.ogrid[r0 - center_r:r1 - center_r, c0 - center_c:c1 - center_c]
    disc = x * x + y * y <= r * r
    window = mask_score[:, r0:r1, c0:c1]  # basic slice -> view into mask_score
    window[:, disc] = box.score
  return mask_score

def fuse(src_proba_path, src_bbox_path, output_proba_path, output_img_path, 
         radius_min_thres, radius_max_thres, score_min_thres, morph_size,
         bias_tree_factor, bias_nontree_factor):
  """Fuse one detectree probability tile with its DeepForest boxes and write results.

  Steps:
    1. Decide what to write:
       - the fused probability raster when SAVE_PROBA is set;
       - the refined label raster when SAVE_IMAGE is set.
       With CONTINUE_MODE, an output that already exists is not rewritten.
       If nothing is left to write, print "PASS:" and return.
    2. Read all bands of the probability raster at `src_proba_path`.
    3. If the box pickle at `src_bbox_path` exists:
       - paint each box as a disc of its score, painting higher scores last
         (pixels outside every disc count as score 0);
       - subtract `score_min_thres`;
       - where the result is >= 0, move each pixel's probability towards the
         tile maximum, scaled by `bias_tree_factor`;
       - where it is < 0, move it towards the tile minimum, scaled by
         `bias_nontree_factor`;
       - clip to [0, 1].
    4. Write the probabilities (float, 1 band) and/or the `refine_mfmc`
       labelling (uint8, 1 band). Both are GeoTIFFs with nodata=0 and the
       source transform and CRS.

  Args:
    src_proba_path: detectree probability GeoTIFF.
    src_bbox_path: pickled DataFrame of boxes (see `_bbox_score_mask`); may be
      missing, in which case the probabilities are used unchanged.
    output_proba_path, output_img_path: destination GeoTIFF paths.
    radius_min_thres, radius_max_thres: boxes whose disc radius is not strictly
      between these are ignored.
    score_min_thres: box score at which the shift changes from lowering to
      raising the probability.
    morph_size: currently unused; it only feeds the disabled `refine_morph` step.
    bias_tree_factor, bias_nontree_factor: strength of the upward / downward shift.

  NOTE(review): assumes the probability raster is floating point in [0, 1].
  An integer raster would truncate the painted scores.
  """
  # With CONTINUE_MODE, outputs that already exist are left untouched.
  save_proba = SAVE_PROBA and not (CONTINUE_MODE and output_proba_path.exists())
  save_img = SAVE_IMAGE and not (CONTINUE_MODE and output_img_path.exists())
  if not (save_proba or save_img):
    print("PASS:", src_proba_path.stem)
    return

  # DeepForest boxes, ascending by score so higher-scoring discs are painted
  # last and win where they overlap.
  # NOTE(review): read_pickle can execute arbitrary code; only point it at
  # pickles produced by this pipeline.
  bbox_df = None
  if src_bbox_path.exists():
    bbox_df = pd.read_pickle(src_bbox_path).sort_values(by='score', ascending=True)

  # Per-pixel tree probability from detectree, plus georeferencing for the outputs.
  with rasterio.open(src_proba_path) as src:
    img_transform = src.transform
    img_crs = src.crs
    img_proba = src.read()

  if bbox_df is not None:
    mask_score = _bbox_score_mask(bbox_df, img_proba.shape, img_proba.dtype,
                                  radius_min_thres, radius_max_thres)
    # Signed margin over the score threshold: >= 0 pushes towards tree,
    # < 0 pushes towards non-tree.
    mask_score -= score_min_thres
    img_proba = np.where(mask_score >= 0,
                         img_proba + (img_proba.max() - img_proba) * mask_score * bias_tree_factor,
                         img_proba + (img_proba - img_proba.min()) * mask_score * bias_nontree_factor)
    np.clip(img_proba, 0.0, 1.0, out=img_proba)

  if save_proba:  # float 1 channel image
    with rasterio.open(output_proba_path, 'w', driver='GTiff',
                       width=img_proba.shape[2], height=img_proba.shape[1],
                       count=1, dtype=img_proba.dtype, nodata=0,
                       transform=img_transform, crs=img_crs) as dst:
      dst.write(img_proba)

  if save_img:  # uint8 1 channel image
    img_refined = refine_mfmc(img_shape=img_proba.shape, p_tree=img_proba,
                              refine_int_rescale=MODEL_PARAMS_CLASSIFIER['refine_int_rescale'],
                              refine_beta=MODEL_PARAMS_CLASSIFIER['refine_beta'],
                              tree_val=MODEL_PARAMS_CLASSIFIER['tree_val'],
                              nontree_val=MODEL_PARAMS_CLASSIFIER['nontree_val'])
    # Optional morphological smoothing (disabled; it is what `morph_size` is for):
    #img_refined = refine_morph(img_refined, morph_size=morph_size,
    #                  tree_val=MODEL_PARAMS_CLASSIFIER['tree_val'], 
    #                  nontree_val=MODEL_PARAMS_CLASSIFIER['nontree_val'])
    with rasterio.open(output_img_path, 'w', driver='GTiff',
                       width=img_proba.shape[2], height=img_proba.shape[1],
                       count=1, dtype=rasterio.uint8, nodata=0,
                       transform=img_transform, crs=img_crs) as dst:
      dst.write(img_refined.astype(np.uint8))

  print("PROCESSED:", src_proba_path.stem)

# One lazy `fuse` task per detectree probability tile. Sorting the glob gives a
# stable submission order across runs.
tasks = []
for src_proba_path in sorted(DETECTREE_PROBA_RESULT_PATH.glob('*.tiff')):
  stem = src_proba_path.stem
  src_bbox_path = DEEPFOREST_BBOX_RESULT_PATH.joinpath(stem + ".pkl")
  output_proba_path = OUTPUT_PROBA_RESULT_PATH.joinpath(stem + ".tiff")
  output_img_path = OUTPUT_IMAGE_RESULT_PATH.joinpath(stem + ".tiff")
  tasks.append(dask.delayed(fuse)(src_proba_path, src_bbox_path, output_proba_path,
                                  output_img_path, **FUSE_CONFIG))

# Run the tasks on CONCURRENCY threads. Entering the executor as a context
# manager shuts its worker threads down once compute returns, so re-running
# this cell does not pile up idle thread pools in the kernel.
with ThreadPoolExecutor(CONCURRENCY) as pool, dask.config.set(pool=pool):
  dask.compute(*tasks)

