In [None]:
import numpy as np
from math import ceil
import rasterio as rio
from pathlib import Path
from itertools import product
from tqdm.notebook import tqdm
from keras.models import load_model
from rasterio import windows as rio_windows

In [None]:
# Folder that is scanned for input GeoTIFFs (`*.tif`) to run predictions on.
# NOTE(review): the name says "valid" but the path points into Train_Data — confirm this is intended.
valid_dir = Path(r"Data/Train_Data/Images")
# Saved Keras model file that is loaded for inference.
model_path = Path(r"Models/Model_MaxAccuracy.h5")
# Folder that receives the rendered outputs, one "Prediction_<input name>" file per input image.
pred_dir = Path(r"Data/Train_Data/Prediction")

In [None]:
# Class index -> RGB colour used to render a predicted class map.
# The position of each triple in the palette below is the class index.
# NOTE(review): the class names behind indices 0..5 are not recorded in this notebook — confirm against the training labels.
color_map = {
    class_id: np.array(rgb)
    for class_id, rgb in enumerate(
        (
            [255, 255, 255],  # 0: white
            [0, 0, 255],      # 1: blue
            [0, 255, 255],    # 2: cyan
            [0, 255, 0],      # 3: green
            [255, 255, 0],    # 4: yellow
            [255, 0, 0],      # 5: red
        )
    )
}

In [None]:
def _axis_offsets(img_size, win_size, min_overlap):
    """Return evenly spread window start offsets that cover one image axis."""
    if win_size <= min_overlap:
        raise ValueError(
            "window size ({}) must be larger than the minimum overlap ({})".format(
                win_size, min_overlap
            )
        )

    # Smallest number of windows that covers the axis with at least
    # `min_overlap` pixels shared between neighbours.
    count = ceil((img_size - min_overlap) / (win_size - min_overlap))
    if count <= 1:
        # The axis fits in a single window; there is no neighbour to overlap
        # with (and nothing to divide the slack by).
        return [0]

    # Total surplus of window pixels over image pixels, shared out over the
    # (count - 1) seams between neighbouring windows.
    slack = count * win_size - img_size
    overlap, residue = divmod(slack, count - 1)
    step = win_size - overlap

    # The `residue` pixels that could not be shared out evenly are absorbed
    # by pulling the last `residue` windows back by 1, 2, ..., residue pixels,
    # so the final window ends exactly on the image border.
    first_shifted = count - residue
    return [i * step - max(0, i - first_shifted + 1) for i in range(count)]


def generate_windows(img_height, img_width, win_height, win_width, min_hoverlap, min_woverlap, boundless=False):
    """
    Tile an image with fixed-size, overlapping rasterio windows.

    The number of windows per axis is the smallest one that covers the image
    while keeping at least the requested minimum overlap; the actual overlap
    is then enlarged so that the windows are spread evenly and the last one
    ends on the image border.

    :param img_height: image height in pixels
    :param img_width: image width in pixels
    :param win_height: window height in pixels
    :param win_width: window width in pixels
    :param min_hoverlap: minimum overlap between vertically adjacent windows
    :param min_woverlap: minimum overlap between horizontally adjacent windows
    :param boundless: if True yield the windows as computed; otherwise yield
        their intersection with the full-image window (this only matters when
        the image is smaller than one window along an axis)
    :return: generator of ``((col_index, row_index), window)`` pairs, iterated
        column offset first, then row offset
    :raises ValueError: if a window dimension is not larger than the
        corresponding minimum overlap
    """
    row_offsets = _axis_offsets(img_height, win_height, min_hoverlap)
    # The column offsets depend on the window *width*; the earlier code used
    # `win_height` here, which broke the tiling for non-square windows.
    col_offsets = _axis_offsets(img_width, win_width, min_woverlap)

    big_window = rio_windows.Window(col_off=0, row_off=0, width=img_width, height=img_height)

    for col_index, col_off in enumerate(col_offsets):
        for row_index, row_off in enumerate(row_offsets):
            window = rio_windows.Window(
                col_off=col_off,
                row_off=row_off,
                width=win_width,
                height=win_height
            )
            if boundless:
                yield (col_index, row_index), window
            else:
                yield (col_index, row_index), window.intersection(big_window)

In [None]:
from keras import backend as kb
from keras.backend import int_shape

def tversky_index(
        y_true,
        y_pred,
        alpha: float = 0.5,
        beta: float = 0.5,
        eps: float = 1e-10,
        preserve_axis=(0, -1)
):
    """
    Compute the Tversky index between ground truth and prediction.

    Ref A: https://arxiv.org/abs/1706.05721
    alpha = beta = 0.5 : Dice coefficient
    alpha = beta = 1   : Tanimoto coefficient (also known as Jaccard Index)
    alpha + beta = 1   : Produces set of F*-scores

    Ref B: https://arxiv.org/abs/1707.03237
    The scores should be computed for each voxel in a batch and for each
    class separately. Thus for a 4D tensor the resultant scores should be a
    2D tensor having the batch axis and label axis. Therefore for a typical
    channels last 4D tensor the axis 0 and axis -1 should be preserved (See
    default value for `preserve_axis` parameter)

    :param y_true: ground-truth tensor, same shape as `y_pred`
    :param y_pred: predicted probabilities
    :param alpha: weight of the false-positive term (`y_pred * (1 - y_true)`)
    :param beta: weight of the false-negative term (`(1 - y_pred) * y_true`)
    :param eps: small constant added to keep the ratio finite
    :param preserve_axis: axis, or tuple/list of axes, that are NOT summed
        over; negative values count from the end
    :return: tensor of index values with only the preserved axes remaining
    """

    # assert int_shape(y_true) == int_shape(y_pred), "Shape Mismatch"

    once = kb.ones(kb.shape(y_true))
    p0 = y_pred  # probability that voxels are class i
    p1 = once - y_pred  # probability that voxels are not class i
    g0 = y_true
    g1 = once - y_true

    ndim = kb.ndim(p0)
    if isinstance(preserve_axis, int):
        preserve_axis = (preserve_axis,)
    assert isinstance(
        preserve_axis, (tuple, list)
    ) and all(
        isinstance(n, int) and (-ndim <= n < ndim)
        for n in preserve_axis
    ), '`preserve_axis`: Illegal value!'

    # Normalise to non-negative axis numbers first so that aliases such as
    # 0 and -ndim collapse to one entry, then reduce over every other axis.
    # (Deleting entries from the axis list by position one after another
    # shifts the remaining positions and picks the wrong axes.)
    kept = {n % ndim for n in preserve_axis}
    dims = [d for d in range(ndim) if d not in kept]

    # NOTE(review): `eps` enters the denominator twice (once through
    # `numerator`, once explicitly), so a class that is absent from both
    # y_true and y_pred scores eps / (2 * eps) = 0.5 rather than 1.
    # Left as is to keep loss values identical — confirm this is intended.
    numerator = kb.sum(
        x=p0 * g0,
        axis=dims
    ) + eps
    denominator = numerator + alpha * kb.sum(
        x=p0 * g1,
        axis=dims
    ) + beta * kb.sum(
        x=p1 * g0,
        axis=dims
    ) + eps

    t = numerator / denominator
    return t

def tversky_loss(
        alpha: float = 0.5,
        beta: float = 0.5,
        eps: float = 1e-10,
        along_axis=(0, -1),
        norm=True
):
    """
    Build a Keras-compatible Tversky loss, ``sum(1 - tversky_index)``.

    :param alpha: forwarded to `tversky_index` (false-positive weight)
    :param beta: forwarded to `tversky_index` (false-negative weight)
    :param eps: forwarded to `tversky_index` (stabilising constant)
    :param along_axis: forwarded to `tversky_index` as `preserve_axis`
    :param norm: if True the summed loss is divided by the batch size
        (size of axis 0 of `y_true`); if False the raw sum is returned
    :return: a function ``(y_true, y_pred) -> scalar loss tensor``
    """

    def tversky_loss_function(
            y_true,
            y_pred,
    ):
        t_values = tversky_index(
            y_true=y_true,
            y_pred=y_pred,
            alpha=alpha,
            beta=beta,
            eps=eps,
            preserve_axis=along_axis
        )
        losses = kb.ones_like(x=t_values, dtype=t_values.dtype) - t_values
        agg_loss = kb.sum(x=losses, axis=None, keepdims=False)
        if not norm:
            return agg_loss
        # Use the runtime batch size: the static shape reports None for a
        # dynamic batch dimension, and dividing a tensor by None fails.
        batch_size = kb.cast(kb.shape(y_true)[0], kb.dtype(agg_loss))
        return agg_loss / batch_size
    return tversky_loss_function

def focal_tversky_loss(
    alpha: float = 0.5,
    beta: float = 0.5,
    gamma=0.75,
    eps: float = 1e-10,
    along_axis=(0, -1),
    norm=True
):
    """
    Build a Keras-compatible focal Tversky loss, ``(1 - tversky_index) ** gamma``.

    :param alpha: forwarded to `tversky_index` (false-positive weight)
    :param beta: forwarded to `tversky_index` (false-negative weight)
    :param gamma: exponent applied to ``1 - tversky_index``
    :param eps: forwarded to `tversky_index` (stabilising constant)
    :param along_axis: forwarded to `tversky_index` as `preserve_axis`
    :param norm: if True return the mean over all entries, otherwise their sum
    :return: a function ``(y_true, y_pred) -> scalar loss tensor``
    """
    # The inner function's __name__ is the key under which this notebook
    # registers the loss in `custom_objects`, so the name must stay as is.
    def focal_tversky_loss_function(
            y_true,
            y_pred,
    ):
        index = tversky_index(
            y_true=y_true,
            y_pred=y_pred,
            alpha=alpha,
            beta=beta,
            eps=eps,
            preserve_axis=along_axis
        )
        complement = kb.ones_like(x=index, dtype=index.dtype) - index
        focal = kb.pow(complement, gamma)
        if not norm:
            return kb.sum(x=focal, axis=None, keepdims=False)
        return kb.mean(x=focal, axis=None, keepdims=False)
    return focal_tversky_loss_function

In [None]:
# Stop here with a clear failure if the model file is missing.
assert model_path.is_file()
# Register the custom loss under its function name via `custom_objects`.
# NOTE(review): presumably the model was saved with this loss, in which case a plain
# `load_model(str(model_path))` cannot deserialise it — confirm against the training notebook.
cl = focal_tversky_loss()
trained_model = load_model(str(model_path), custom_objects={cl.__name__: cl})

In [None]:
# Class-index -> RGB lookup table, built once for all windows. It has exactly
# as many rows as there are classes in `color_map`, so a class index the
# palette does not know fails loudly (IndexError) instead of being rendered
# in some arbitrary colour.
lookup = np.zeros((max(color_map) + 1, 3), dtype=np.uint8)
for class_id, rgb in color_map.items():
    lookup[class_id] = rgb

# rasterio does not create missing folders when opening a file for writing.
pred_dir.mkdir(parents=True, exist_ok=True)

for imf in tqdm(list(valid_dir.glob('*.tif'))):
    with rio.open(imf, 'r') as src:
        # Overlapping 512x512 tiles covering the whole image.
        wins = list(
            generate_windows(
                img_height=src.height,
                img_width=src.width,
                win_height=512,
                win_width=512,
                min_hoverlap=32,
                min_woverlap=32,
            )
        )
        dst_path = pred_dir / ('Prediction_' + imf.name)
        # Reuse the source profile, but write a 3-band uint8 RGB rendering.
        # (For raw class indices instead: count=1 and write pred_array[np.newaxis].)
        # NOTE(review): any `nodata` value is inherited from the source profile —
        # confirm it is valid/meaningful for a uint8 RGB output.
        meta = src.meta.copy()
        meta['count'] = 3
        meta['dtype'] = np.uint8
        with rio.open(dst_path, 'w', **meta) as dst:
            # Where tiles overlap, the tile written last wins.
            for _, w in tqdm(wins):
                # (bands, rows, cols) -> (1, rows, cols, bands): a batch of one
                # channels-last tile.
                im_array = np.moveaxis(src.read(window=w), 0, -1)[np.newaxis, ...]
                pred = trained_model.predict(
                    im_array
                )
                # Most likely class per pixel.
                pred_array = np.argmax(pred[0], axis=-1).astype(np.uint8)
                # (rows, cols) -> (rows, cols, 3) -> (3, rows, cols) for rasterio.
                rgb_img = np.moveaxis(lookup[pred_array], -1, 0)
                dst.write(rgb_img, window=w)