In [1]:
# Import
import json
import random
import numpy as np
import rasterio as rio
from pathlib import Path
from math import floor, ceil
from itertools import product
from functools import partial
from keras.models import Model
from keras.utils import Sequence
from keras.optimizers import Adam
from keras.regularizers import l2
from keras.layers import LeakyReLU
from keras.utils import to_categorical
from keras.callbacks import TensorBoard
from keras.callbacks import EarlyStopping
from keras.callbacks import ModelCheckpoint
from rasterio import windows as rio_windows
from keras.utils.vis_utils import plot_model
from keras.layers import Input, Dense, Dropout
from keras.layers import Activation, add, multiply
from keras.layers import MaxPooling2D, SpatialDropout2D
from keras.layers import UpSampling2D, BatchNormalization
from keras.layers import Conv2D, Conv2DTranspose, Concatenate, Add

Using TensorFlow backend.


In [2]:
from keras import backend as kb
from keras.backend import int_shape

def tversky_index(
        y_true,
        y_pred,
        alpha: float = 0.5,
        beta: float = 0.5,
        eps: float = 1e-10,
        preserve_axis=(0, -1)
):
    """
    Compute the (smoothed) Tversky index between a target and a prediction.

    Ref A: https://arxiv.org/abs/1706.05721
    alpha = beta = 0.5 : Dice coefficient
    alpha = beta = 1   : Tanimoto coefficient (also known as Jaccard Index)
    alpha + beta = 1   : Produces set of F*-scores

    Ref B: https://arxiv.org/abs/1707.03237
    The scores should be computed for each voxel in a batch and for each
    class separately. Thus for a 4D tensor the resultant scores should be a
    2D tensor having the batch axis and label axis. Therefore for a typical
    channels last 4D tensor the axis 0 and axis -1 should be preserved (See
    default value for `preserve_axis` parameter)

    T = (TP + eps) / (TP + alpha * FP + beta * FN + eps)

    :param y_true: ground-truth tensor; assumed to have the same shape as
        `y_pred` (one-hot targets in this notebook).
    :param y_pred: predicted per-class probabilities.
    :param alpha: weight of the false-positive term, sum(y_pred * (1 - y_true)).
    :param beta: weight of the false-negative term, sum((1 - y_pred) * y_true).
    :param eps: smoothing constant added once to numerator and denominator, so
        an all-zero class in both tensors scores 1 instead of 0/0.
    :param preserve_axis: int or tuple/list of ints; axes NOT reduced over.
        Negative values are allowed and refer to the same axis as their
        positive counterpart.
    :return: tensor of Tversky indices whose axes are the preserved axes.
    """
    ndim = kb.ndim(y_pred)

    if isinstance(preserve_axis, int):
        preserve_axis = (preserve_axis,)
    assert isinstance(preserve_axis, (tuple, list)) and all(
        isinstance(n, int) and -ndim <= n < ndim for n in preserve_axis
    ), '`preserve_axis`: Illegal value!'

    # Normalise to non-negative axes so e.g. 3 and -1 on a 4D tensor are the
    # same axis, then reduce over every axis that is not preserved.
    kept_axes = {n % ndim for n in preserve_axis}
    reduce_axes = [ax for ax in range(ndim) if ax not in kept_axes]

    p0 = y_pred  # probability that voxels are class i
    p1 = 1.0 - y_pred  # probability that voxels are not class i
    g0 = y_true
    g1 = 1.0 - y_true

    true_pos = kb.sum(p0 * g0, axis=reduce_axes)
    false_pos = kb.sum(p0 * g1, axis=reduce_axes)
    false_neg = kb.sum(p1 * g0, axis=reduce_axes)

    return (true_pos + eps) / (true_pos + alpha * false_pos + beta * false_neg + eps)

def tversky_loss(
        alpha: float = 0.5,
        beta: float = 0.5,
        eps: float = 1e-10,
        along_axis=(0, -1),
        norm=True
):
    """
    Build a Keras loss function equal to ``1 - TverskyIndex``.

    :param alpha: false-positive weight forwarded to `tversky_index`.
    :param beta: false-negative weight forwarded to `tversky_index`.
    :param eps: smoothing constant forwarded to `tversky_index`.
    :param along_axis: axes preserved when computing the per-item index
        (forwarded as `preserve_axis`).
    :param norm: if True the per-item losses are averaged, otherwise summed.
    :return: a ``(y_true, y_pred) -> scalar tensor`` loss callable.
    """

    def tversky_loss_function(
            y_true,
            y_pred,
    ):
        t_values = tversky_index(
            y_true=y_true,
            y_pred=y_pred,
            alpha=alpha,
            beta=beta,
            eps=eps,
            preserve_axis=along_axis
        )
        losses = 1.0 - t_values
        if norm:
            # previously referenced the undefined name `loses` -> NameError
            return kb.mean(x=losses, axis=None, keepdims=False)
        return kb.sum(x=losses, axis=None, keepdims=False)

    return tversky_loss_function

def focal_tversky_loss(
    alpha: float = 0.5,
    beta: float = 0.5,
    gamma=0.75,
    eps: float = 1e-10,
    along_axis=(0, -1),
    norm=True
):
    """
    Build a Keras loss function equal to ``(1 - TverskyIndex) ** gamma``.

    :param alpha: false-positive weight forwarded to `tversky_index`.
    :param beta: false-negative weight forwarded to `tversky_index`.
    :param gamma: exponent applied to each per-item ``1 - T`` term.
    :param eps: smoothing constant forwarded to `tversky_index`.
    :param along_axis: axes preserved when computing the per-item index.
    :param norm: if True the focal terms are averaged, otherwise summed.
    :return: a ``(y_true, y_pred) -> scalar tensor`` loss callable.
    """
    # The reduction is fixed once the loss is built, so pick it up front.
    aggregate = kb.mean if norm else kb.sum

    def focal_tversky_loss_function(y_true, y_pred):
        scores = tversky_index(
            y_true=y_true,
            y_pred=y_pred,
            alpha=alpha,
            beta=beta,
            eps=eps,
            preserve_axis=along_axis,
        )
        complement = kb.ones_like(x=scores, dtype=scores.dtype) - scores
        focal_terms = kb.pow(complement, gamma)
        return aggregate(x=focal_terms, axis=None, keepdims=False)

    return focal_tversky_loss_function

In [3]:
# Define Parameters
# -- full raster extent (pixels)
image_height = 6000
image_width = 6000

# -- sliding-window tiling (pixels)
window_height = 512
window_width = 512
min_height_overlap = 32
min_width_overlap = 32
boundless_flag = True

# -- input bands (passed to rasterio `read(indexes=...)`) and derived feature count
bands = (1, 2, 3, 4, 5)
image_features = len(bands)

# -- model / training
class_count = 6
droprate = 0.3
data_shuffle = True
batchsize = 1

In [4]:
# Geometric augmentations for channel-first (bands, rows, cols) arrays.
# Each entry is (fn, apply_to_label); all are spatial, so labels get them too.
# The identity transform is not listed here (RasterDataGenerator uses it when augs=None).
geo_augs = tuple(
    [(partial(np.rot90, k=quarter_turns, axes=(1, 2)), True) for quarter_turns in (1, 2, 3)]
    + [(partial(np.flip, axis=flip_axis), True) for flip_axis in (1, 2)]
)

In [5]:
# Input Data: split maps, model artefacts and TensorBoard log locations.
config_dir = Path('Configs')
train_config = config_dir.joinpath('Train_Map.json')
valid_config = config_dir.joinpath('Validation_Map.json')
test_config = config_dir.joinpath('Test_Map.json')

model_dir = Path("Models")
mplot = model_dir.joinpath("Model_Plot.png")
model_max_accuracy = model_dir.joinpath('Model_MaxAccuracy.h5')
model_min_loss = model_dir.joinpath('Model_MinLoss.h5')

log_d = Path('Logs')


def _read_json(path):
    """Parse and return the JSON document stored at `path`."""
    with open(path.as_posix(), 'r') as handle:
        return json.load(handle)


train_map = _read_json(train_config)
valid_map = _read_json(valid_config)

In [6]:
def _axis_offsets(size, win, min_overlap):
    """Return start offsets of length-`win` windows tiling [0, size) with overlap >= `min_overlap`.

    The window count is the smallest one that honours the minimum overlap. The
    total overlap is then spread as evenly as integers allow: the last `slack`
    windows are shifted back by 1, 2, ... pixels, so the final window ends
    exactly at `size`. When a single window suffices, ``[0]`` is returned.
    """
    if win <= min_overlap:
        raise ValueError('window size must be larger than the minimum overlap')
    count = ceil((size - min_overlap) / (win - min_overlap))
    if count <= 1:
        # One window covers the whole axis (it may overhang when size < win);
        # the original formula divided by (count - 1) here.
        return [0]
    overlap, slack = divmod(count * win - size, count - 1)
    step = win - overlap
    offsets = np.arange(count) * step
    if slack > 0:
        offsets[-slack:] -= np.arange(1, slack + 1, 1)
    return offsets.tolist()


def generate_windows(img_height, img_width, win_height, win_width, min_hoverlap, min_woverlap, boundless=False):
    """Yield ``((col_index, row_index), rasterio Window)`` tiles covering an image.

    Windows are `win_height` x `win_width`. Along each axis they overlap by at
    least the given minimum and are spaced as evenly as possible, with the last
    window ending flush with the image edge.

    :param img_height: image height in pixels.
    :param img_width: image width in pixels.
    :param win_height: window height in pixels.
    :param win_width: window width in pixels.
    :param min_hoverlap: minimum overlap between vertically adjacent windows.
    :param min_woverlap: minimum overlap between horizontally adjacent windows.
    :param boundless: if False each window is clipped to the image extent with
        ``Window.intersection``; if True windows are yielded unclipped.
    """
    # Each axis uses its own window size; columns previously used win_height.
    row_offsets = _axis_offsets(img_height, win_height, min_hoverlap)
    col_offsets = _axis_offsets(img_width, win_width, min_woverlap)

    big_window = rio_windows.Window(col_off=0, row_off=0, width=img_width, height=img_height)

    # Columns outer, rows inner; the yielded index is (col_index, row_index).
    for (col_idx, col_off), (row_idx, row_off) in product(enumerate(col_offsets), enumerate(row_offsets)):
        window = rio_windows.Window(
            col_off=col_off,
            row_off=row_off,
            width=win_width,
            height=win_height
        )
        yield (col_idx, row_idx), (window if boundless else window.intersection(big_window))

In [7]:
class RasterDataGenerator(Sequence):
    """Keras ``Sequence`` serving (image, one-hot label) batches cut from raster pairs.

    A sample is one (image/label pair, window, augmentation) combination. The
    full cartesian product of pairs x windows x augmentations is built in
    ``__init__``, optionally shuffled, and reshuffled after every epoch when
    ``shuffle`` is True.
    """

    def __init__(
        self,
        map_dict,
        channels,
        img_height,
        img_width,
        win_height,
        win_width,
        min_hoverlap,
        min_woverlap,
        cls_count,
        augs=None,  # list of tuples like (fn, flag), fn: aug function works on channel first image, flag: whether fn should be applied on labels
        boundless=False,
        shuffle=True,
        batch_size=1,
    ):
        """
        :param map_dict: dict whose values are dicts with exactly the keys
            'IMAGE' and 'LABEL' (paths to the image and label rasters).
        :param channels: band indexes read from each image raster.
        :param img_height: raster height in pixels (used for window tiling).
        :param img_width: raster width in pixels (used for window tiling).
        :param win_height: window height in pixels.
        :param win_width: window width in pixels.
        :param min_hoverlap: minimum vertical overlap between windows.
        :param min_woverlap: minimum horizontal overlap between windows.
        :param cls_count: number of classes for the one-hot label encoding.
        :param augs: sequence of (fn, bool) pairs; defaults to the identity.
        :param boundless: forwarded to window generation and to rasterio reads.
        :param shuffle: shuffle samples at construction and after each epoch.
        :param batch_size: samples per batch.
        """
        assert isinstance(map_dict, dict), 'Invalid type for parameter <map_dict>, expected type `dict`!'
        assert all(
            set(entry.keys()) == {'IMAGE', 'LABEL'} for entry in map_dict.values()
        ), "Invalid map <map_dict>, Key Mismatch!"
        if augs is None:
            augs = (((lambda m: m), True),)
        else:
            assert isinstance(augs, (tuple, list)) and all(
                callable(fn) and isinstance(flag, bool) for (fn, flag) in augs
            ), 'Invalid <augs>, expected a sequence of (callable, bool) pairs!'

        couples = [
            (Path(couple['IMAGE']).as_posix(), Path(couple['LABEL']).as_posix())
            for couple in map_dict.values()
        ]

        windows = list(
            generate_windows(
                img_height=img_height,
                img_width=img_width,
                win_height=win_height,
                win_width=win_width,
                min_hoverlap=min_hoverlap,
                min_woverlap=min_woverlap,
                boundless=boundless
            )
        )
        dat = list(product(couples, windows, augs))
        if shuffle:
            random.shuffle(dat)
        self.data = dat
        self.channels = channels
        self.class_count = cls_count
        self.batch_size = batch_size
        # Reads must honour this instance's setting, not the notebook-global flag.
        self.boundless = boundless
        self.shuffle = shuffle

    def __len__(self):
        """Number of batches per epoch (the last batch may be partial)."""
        return int(np.ceil(len(self.data) / float(self.batch_size)))

    def on_epoch_end(self):
        """Reshuffle the sample order between epochs when shuffling is enabled."""
        if self.shuffle:
            random.shuffle(self.data)

    def __getitem__(self, idx):
        """Return batch `idx` as ``(images, labels)``, both channels-last and stacked on axis 0."""
        current_batch = self.data[idx * self.batch_size:(idx + 1) * self.batch_size]
        islices = []
        lslices = []
        for (im, lb), (_, w), (aug, af) in current_batch:
            with rio.open(im, 'r') as isrc:
                islice = isrc.read(indexes=self.channels, window=w, boundless=self.boundless, masked=False)
            # Augment in channel-first layout, then move bands last for Keras.
            islices.append(np.moveaxis(aug(islice), source=0, destination=-1))

            with rio.open(lb, 'r') as lsrc:
                lslice = lsrc.read(window=w, boundless=self.boundless, masked=False)
            if af:
                lslice = aug(lslice)
            lslice = np.moveaxis(lslice, source=0, destination=-1)
            # NOTE(review): labels appear to be 1-based class ids (hence `- 1`).
            # A 0 value (nodata / boundless fill) becomes -1 (or wraps for
            # unsigned dtypes) before one-hot encoding -- confirm label range.
            lslices.append(to_categorical(y=(lslice - 1), num_classes=self.class_count))
        return np.stack(islices, axis=0), np.stack(lslices, axis=0)

In [8]:
def make_features(input_shape, f=128):
    """Create the model input and a stem of two 3x3, 128-filter convolutions.

    Returns ``(input_layer, stem_output)``.

    NOTE(review): `input_shape` and `f` are accepted but not used. The input
    is always ``(None, None, image_features)`` and both convs use 128 filters,
    and neither conv has an activation. Changing this alters the trained
    architecture, so it is left as-is.
    """
    input_layer = Input(shape=(None, None, image_features))

    stem = input_layer
    for _ in range(2):
        stem = Conv2D(
            filters=128,
            kernel_size=(3, 3),
            padding='same',
            data_format='channels_last',
            kernel_initializer='glorot_uniform',
        )(stem)
    return input_layer, stem

def res_block(p_layer, f=128):
    """Residual unit: three dilated (rate 2) 3x3 conv+BN layers added to the input, then ReLU.

    The first two convs use ReLU, the third is linear before the addition.
    `p_layer` must already have `f` channels for the ``Add`` to be valid.
    """
    branch = p_layer
    for conv_activation in ('relu', 'relu', None):
        branch = Conv2D(
            filters=f,
            kernel_size=(3, 3),
            padding='same',
            activation=conv_activation,
            dilation_rate=2,
            data_format='channels_last',
            kernel_initializer='glorot_uniform',
        )(branch)
        branch = BatchNormalization()(branch)

    merged = Add()([p_layer, branch])
    return Activation('relu')(merged)
    
def branch_block(parent_layer, f=128):
    """Halve the spatial size with 2x2 max pooling, then apply two 3x3 ReLU conv+BN layers."""
    features = MaxPooling2D(
        pool_size=(2, 2),
        padding='same',
        data_format='channels_last',
    )(parent_layer)
    for _ in range(2):
        features = Conv2D(
            filters=f,
            kernel_size=(3, 3),
            activation='relu',
            padding='same',
            data_format='channels_last',
            kernel_initializer='glorot_uniform',
        )(features)
        features = BatchNormalization()(features)
    return features

def downsample(p_layer, f=128, scale=2):
    """Three 3x3 ReLU conv+BN layers; the first uses stride `scale` to shrink the input."""
    features = p_layer
    for stride in (scale, 1, 1):
        features = Conv2D(
            filters=f,
            kernel_size=(3, 3),
            padding='same',
            strides=stride,
            activation='relu',
            data_format='channels_last',
            kernel_initializer='glorot_uniform',
        )(features)
        features = BatchNormalization()(features)
    return features

def stack_block(main_layer, parent_layer, f=128, scale=2):
    """Encoder stage that sums a pooled branch and a strided branch, then applies a residual unit.

    The pooled branch runs on `parent_layer` (the previous stage). The strided
    branch downsamples `main_layer` by `scale`.
    """
    pooled = branch_block(parent_layer=parent_layer, f=f)
    strided = downsample(p_layer=main_layer, f=f, scale=scale)
    return res_block(p_layer=Add()([pooled, strided]), f=f)

def stack_mid(parent):
    """Bottleneck: three residual units, then four parallel dilated conv branches.

    The branches are 3x3 ReLU conv+BN with dilation rates 1, 2, 4 and 8. Their
    outputs are concatenated on the channel axis, giving 4x the input channels.
    """
    width = int_shape(parent)[-1]

    trunk = parent
    for _ in range(3):
        trunk = res_block(p_layer=trunk, f=width)

    branches = []
    for rate in (1, 2, 4, 8):
        branch = Conv2D(
            filters=width,
            kernel_size=(3, 3),
            padding='same',
            activation='relu',
            dilation_rate=rate,
            data_format='channels_last',
            kernel_initializer='glorot_uniform',
        )(trunk)
        branches.append(BatchNormalization()(branch))

    return Concatenate(axis=-1)(branches)

def up_merge(xx_lyr, yy_lyr, f=128):
    """Decoder stage: two 3x3 ReLU conv+BN on `xx_lyr`, 2x bilinear upsampling, concat with skip `yy_lyr`."""
    refined = xx_lyr
    for _ in range(2):
        refined = Conv2D(
            filters=f,
            kernel_size=(3, 3),
            padding='same',
            activation='relu',
            data_format='channels_last',
            kernel_initializer='glorot_uniform',
        )(refined)
        refined = BatchNormalization()(refined)

    upsampled = UpSampling2D(size=2, interpolation='bilinear')(refined)
    return Concatenate(axis=-1)([upsampled, yy_lyr])

def final_blk(xx_lyr, f=128, count=4):
    """Classification head: two 3x3 ReLU conv+BN layers, then a 1x1 softmax conv with `count` outputs.

    :param xx_lyr: input tensor (last decoder stage).
    :param f: filters in the two 3x3 convolutions.
    :param count: number of output classes (softmax channels).
    :return: per-pixel class-probability tensor.
    """
    c11 = Conv2D(
        filters=f,
        kernel_size=(3, 3),
        padding='same',
        activation='relu',
        data_format='channels_last',
        kernel_initializer='glorot_uniform',
    )(xx_lyr)
    c11 = BatchNormalization()(c11)

    c12 = Conv2D(
        filters=f,
        kernel_size=(3, 3),
        padding='same',
        activation='relu',
        data_format='channels_last',
        kernel_initializer='glorot_uniform',
    )(c11)
    c12 = BatchNormalization()(c12)

    # The classifier consumes the second conv block. Feeding it c11 left c12
    # disconnected from the graph (dead layers). Note: this adds a layer, so
    # previously saved .h5 weights no longer match this architecture.
    cx = Conv2D(
        filters=count,
        kernel_size=(1, 1),
        padding='same',
        activation='softmax',
        data_format='channels_last',
        kernel_initializer='glorot_uniform',
    )(c12)
    return cx

In [9]:
# Stem -> three encoder stages -> dilated bottleneck -> three decoder stages -> softmax head.
xx, yy = make_features(input_shape=(512, 512, image_features), f=64)

# Encoder: each stage pools the previous stage and also strides the raw input `xx`.
encoder_outputs = []
stage_out = yy
for stage_filters, stage_scale in ((128, 2), (256, 4), (512, 8)):
    stage_out = stack_block(main_layer=xx, parent_layer=stage_out, f=stage_filters, scale=stage_scale)
    encoder_outputs.append(stage_out)
lyr1, lyr2, lyr3 = encoder_outputs

mid = stack_mid(parent=lyr3)

# Decoder: upsample and concatenate with the matching encoder feature map.
decoder_outputs = []
stage_out = mid
for skip, stage_filters in ((lyr2, 512), (lyr1, 256), (yy, 128)):
    stage_out = up_merge(xx_lyr=stage_out, yy_lyr=skip, f=stage_filters)
    decoder_outputs.append(stage_out)
lyr4, lyr5, lyr6 = decoder_outputs

zz = final_blk(xx_lyr=lyr6, f=64, count=class_count)

In [10]:
# Wrap the functional graph (input `xx` -> softmax `zz`) in a Model and print its layers.
xnet_classifier = Model(outputs=zz, inputs=xx)
xnet_classifier.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, None, None, 5 0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, None, None, 1 5888        input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, None, None, 1 147584      conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, None, None, 1 5888        input_1[0][0]                    
____________________________________________________________________________________________

In [11]:
# xnet_classifier = xnet(
#     input_shape=(None, None, image_features),
#     num_classes=class_count,
#     activation="relu",
#     use_batch_norm=True,
#     upsample_mode="deconv",  # 'deconv' or 'simple'
#     dropout=droprate,
#     dropout_change_per_layer=0.0,
#     dropout_type="standard",
#     use_dropout_on_upsampling=False,
#     use_attention=True,
#     filters=64,
#     num_layers=4,
#     output_activation="softmax",
# )
# xnet_classifier.summary(line_length=116)
# plot_model(xnet_classifier, to_file=mplot, show_shapes=True, show_layer_names=True)

In [12]:
# Stop when validation loss has not improved for 5 epochs.
es_val_loss = EarlyStopping(monitor='val_loss', mode='min', patience=5, verbose=1)
# es_val_accu = EarlyStopping(monitor='val_accuracy', mode='max', min_delta=0.001)

# Keep two "best so far" checkpoints: highest val accuracy and lowest val loss.
checkpoint_options = dict(verbose=1, save_best_only=True)
mc_val_accu = ModelCheckpoint(
    str(model_max_accuracy.absolute()),
    monitor='val_accuracy',
    mode='max',
    **checkpoint_options
)
mc_val_loss = ModelCheckpoint(
    str(model_min_loss.absolute()),
    monitor='val_loss',
    mode='min',
    **checkpoint_options
)

tb = TensorBoard(
    log_dir=log_d,
    histogram_freq=1,
    write_graph=True,
    write_images=True,
    update_freq='batch',
    embeddings_freq=0,
    embeddings_metadata=None,
)

# Load Pre Trained Weights if any
# if model_max_accuracy.is_file():
#     xnet_classifier.load_weights(str(model_max_accuracy))

xnet_classifier.compile(
    optimizer=Adam(learning_rate=1e-4),
    loss=focal_tversky_loss(),  # alternative: 'categorical_crossentropy'
    metrics=['accuracy'],
)

In [13]:
# Settings shared by the training and validation generators.
shared_generator_kwargs = dict(
    channels=bands,
    img_height=image_height,
    img_width=image_width,
    win_height=window_height,
    win_width=window_width,
    cls_count=class_count,
    boundless=boundless_flag,
    shuffle=data_shuffle,
    batch_size=batchsize,
)

train_generator = RasterDataGenerator(
    map_dict=train_map,
    min_hoverlap=min_height_overlap,
    min_woverlap=min_width_overlap,
    augs=None,  # set to geo_augs to enable geometric augmentation
    **shared_generator_kwargs
)

# Validation tiles use a 1-pixel minimum overlap instead of the training overlap.
valid_generator = RasterDataGenerator(
    map_dict=valid_map,
    min_hoverlap=1,
    min_woverlap=1,
    augs=None,
    **shared_generator_kwargs
)

In [14]:
# Train. Callback order matters for the checkpoint log order: loss first, then accuracy.
# (es_val_accu can be appended here to also early-stop on accuracy.)
training_callbacks = [tb, es_val_loss, mc_val_loss, mc_val_accu]

xnet_classifier.fit_generator(
    generator=train_generator,
    validation_data=valid_generator,
    epochs=50,
    use_multiprocessing=False,
    callbacks=training_callbacks,
)

Epoch 1/50

Epoch 00001: val_loss improved from inf to 0.82369, saving model to D:\UNet\Models\Model_MinLoss.h5

Epoch 00001: val_accuracy improved from -inf to 0.74333, saving model to D:\UNet\Models\Model_MaxAccuracy.h5
Epoch 2/50

Epoch 00002: val_loss did not improve from 0.82369

Epoch 00002: val_accuracy improved from 0.74333 to 0.75719, saving model to D:\UNet\Models\Model_MaxAccuracy.h5
Epoch 3/50

Epoch 00003: val_loss did not improve from 0.82369

Epoch 00003: val_accuracy improved from 0.75719 to 0.76395, saving model to D:\UNet\Models\Model_MaxAccuracy.h5
Epoch 4/50

Epoch 00004: val_loss improved from 0.82369 to 0.78862, saving model to D:\UNet\Models\Model_MinLoss.h5

Epoch 00004: val_accuracy improved from 0.76395 to 0.80164, saving model to D:\UNet\Models\Model_MaxAccuracy.h5
Epoch 5/50

Epoch 00005: val_loss did not improve from 0.78862

Epoch 00005: val_accuracy improved from 0.80164 to 0.83545, saving model to D:\UNet\Models\Model_MaxAccuracy.h5
Epoch 6/50

Epoch 000

<keras.callbacks.callbacks.History at 0x20f02dad548>