In [None]:
# Import
import json
import random
import numpy as np
import rasterio as rio
from pathlib import Path
from math import floor, ceil
from itertools import product
from functools import partial
from tensorflow.keras.models import Model
from tensorflow.keras.utils import Sequence
from rasterio import windows as rio_windows
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2
from advanced_losses import DiceLossVariants
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.utils import plot_model
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import TensorBoard
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.layers import Input, Dense, Dropout
from tensorflow.keras.layers import Activation, add, multiply
from tensorflow.keras.layers import MaxPooling2D, SpatialDropout2D
from tensorflow.keras.layers import UpSampling2D, BatchNormalization
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, concatenate


In [None]:
# Define Parameters
# Dropout rate handed to the model builder (`dropout` argument of xnet).
droprate = 0.3
# Raster size in pixels used to lay out the sliding-window grid.
# NOTE(review): assumes every raster in the maps has this size — confirm.
image_height = 6000
image_width = 6000
# Size in pixels of each window (tile) read from the rasters.
window_height = 512
window_width = 512
# Lower bound, in pixels, on the overlap between neighbouring training windows.
min_height_overlap = 32
min_width_overlap = 32
# Forwarded as `boundless=` to the window generator and to the raster reads.
boundless_flag = True
# Number of target classes (depth of the one-hot labels and of the model output).
class_count = 6
# Whether the data generators shuffle their sample list.
data_shuffle = True
# Number of samples per batch produced by the data generators.
batchsize=2
# Band indexes read from each image raster; their count is the model's input depth.
bands = (1, 2, 3, 4, 5)
image_features = len(bands)

In [None]:
# Geometric augmentations handed to the training data generator.
# Each entry is a pair (fn, flag):
#   fn   - callable applied to a channel-first array; axes 1 and 2 are the
#          ones being rotated / flipped below.
#   flag - True means fn is also applied to the label window, keeping image
#          and label aligned.
# NOTE(review): the identity entry is commented out, so the un-augmented
# orientation is not part of this tuple — confirm that this is intended.
geo_augs = (
    # ((lambda m : m), True),
    (partial(np.rot90, k=1, axes=(1, 2)), True),  # one quarter turn
    (partial(np.rot90, k=2, axes=(1, 2)), True),  # two quarter turns
    (partial(np.rot90, k=3, axes=(1, 2)), True),  # three quarter turns
    (partial(np.flip, axis=1), True),  # reverse along axis 1
    (partial(np.flip, axis=2), True)  # reverse along axis 2
)

In [None]:
# Input Data
# Dataset maps (JSON) live under Configs/, model artefacts under Models/,
# and TensorBoard logs under Logs/.
config_dir = Path('Configs')
train_config = config_dir.joinpath('Train_Map.json')
valid_config = config_dir.joinpath('Validation_Map.json')
test_config = config_dir.joinpath('Test_Map.json')
model_dir = Path("Models")
mplot = model_dir.joinpath("Model_Plot.png")
model_max_accuracy = model_dir.joinpath('Model_MaxAccuracy.h5')
model_min_loss = model_dir.joinpath('Model_MinLoss.h5')
log_d = Path('Logs')

# Parse the training map.
with train_config.open(mode='r') as tm:
    train_map = json.load(tm)

# Parse the validation map.
with valid_config.open(mode='r') as tm:
    valid_map = json.load(tm)

In [None]:
def _axis_offsets(img_len, win_len, min_overlap):
    """Return the start offsets of evenly spread windows along one image axis.

    Args:
        img_len: Length of the image along the axis, in pixels.
        win_len: Length of a window along the axis, in pixels.
        min_overlap: Minimum overlap required between neighbouring windows.

    Returns:
        list[int]: Window start offsets in ascending order. A single ``[0]``
        is returned when one window already covers the whole axis.

    Raises:
        ValueError: If ``win_len`` is not larger than ``min_overlap``.
    """
    if win_len <= min_overlap:
        raise ValueError(
            f"window length ({win_len}) must be greater than the minimum overlap ({min_overlap})"
        )

    # Fewest windows that cover the axis while overlapping by at least `min_overlap`.
    count = ceil((img_len - min_overlap) / (win_len - min_overlap))
    if count <= 1:
        # One window is enough; there is no seam to distribute overlap over
        # (the original formula divided by `count - 1 == 0` here).
        return [0]

    # `count` abutting windows overshoot the axis by this many pixels; the
    # overshoot is shared as overlap between the (count - 1) seams.
    overlap, slack = divmod((count * win_len) - img_len, count - 1)
    step = win_len - overlap

    offsets = np.arange(0, (img_len - overlap), step)
    if slack > 0:
        # Pixels that did not divide evenly: pull the last `slack` windows back
        # by 1, 2, ... pixels so the final window ends exactly at the image edge.
        offsets[-slack:] -= np.arange(1, (slack + 1), 1)
    return offsets.tolist()


def generate_windows(img_height, img_width, win_height, win_width, min_hoverlap, min_woverlap, boundless=False):
    """Yield a grid of overlapping rasterio windows covering an image.

    The number of windows per axis is the smallest one that keeps at least the
    requested minimum overlap; the actual overlap is then evened out so the
    first window starts at 0 and the last one ends at the image edge.

    Args:
        img_height: Image height in pixels.
        img_width: Image width in pixels.
        win_height: Window height in pixels.
        win_width: Window width in pixels.
        min_hoverlap: Minimum overlap between vertically adjacent windows.
        min_woverlap: Minimum overlap between horizontally adjacent windows.
        boundless: If True, windows are yielded as computed; otherwise each
            window is intersected with the full-image window first.

    Yields:
        tuple: ``((col_index, row_index), window)``; the row index varies
        fastest.

    Raises:
        ValueError: If a window dimension is not larger than its minimum
            overlap. Being a generator, the error surfaces on first iteration.
    """
    row_offsets = _axis_offsets(img_height, win_height, min_hoverlap)
    # The column axis must use the window *width* (the previous code used the
    # window height here, which broke non-square windows).
    col_offsets = _axis_offsets(img_width, win_width, min_woverlap)

    big_window = rio_windows.Window(col_off=0, row_off=0, width=img_width, height=img_height)

    grid = product(enumerate(col_offsets), enumerate(row_offsets))
    for (col_idx, col_off), (row_idx, row_off) in grid:
        window = rio_windows.Window(
            col_off=col_off,
            row_off=row_off,
            width=win_width,
            height=win_height
        )
        if not boundless:
            window = window.intersection(big_window)
        yield (col_idx, row_idx), window

In [None]:
class RasterDataGenerator(Sequence):
    """Keras ``Sequence`` serving batches of raster windows and one-hot labels.

    One sample is a combination of (image/label pair, window, augmentation);
    the full cross product of the three is built once in ``__init__``.

    Args:
        map_dict: dict whose values are dicts with exactly the keys
            ``'IMAGE'`` and ``'LABEL'``, each holding a raster path.
        channels: Band indexes passed as ``indexes`` when reading the image.
        img_height: Raster height used to lay out the window grid.
        img_width: Raster width used to lay out the window grid.
        win_height: Window height.
        win_width: Window width.
        min_hoverlap: Minimum vertical overlap between windows.
        min_woverlap: Minimum horizontal overlap between windows.
        cls_count: Number of classes of the one-hot encoded labels.
        augs: Sequence of ``(fn, flag)`` tuples. ``fn`` is an augmentation
            working on a channel-first array; ``flag`` tells whether ``fn``
            must also be applied to the label. ``None`` means identity only.
        boundless: Forwarded to ``generate_windows`` and to the raster reads.
        shuffle: Shuffle the sample list once at construction time.
        batch_size: Number of samples per batch.
    """

    def __init__(
        self,
        map_dict,
        channels,
        img_height,
        img_width,
        win_height,
        win_width,
        min_hoverlap,
        min_woverlap,
        cls_count,
        augs=None,  # list of tuples like (fn, flag), fn: aug function works on channel first image, flag: whether fn should be applied on labels
        boundless=False,
        shuffle=True,
        batch_size=1,
    ):
        # Let the Keras base class set up whatever state it needs.
        super().__init__()

        assert isinstance(map_dict, dict), 'Invalid type for parameter <map_dict>, expected type `dict`!'
        assert all(
            set(entry.keys()) == {'IMAGE', 'LABEL'} for entry in map_dict.values()
        ), "Invalid map <map_dict>, Key Mismatch!"
        if augs is None:
            # Identity augmentation, applied to the labels as well.
            augs = (((lambda m: m), True),)
        else:
            assert isinstance(augs, (tuple, list)) and all(
                [callable(fn) and isinstance(flag, bool) for (fn, flag) in augs]
            )

        couples = [
            (Path(couple['IMAGE']).as_posix(), Path(couple['LABEL']).as_posix())
            for couple in map_dict.values()
        ]

        # Each element is ((col_index, row_index), window).
        windows = list(
            generate_windows(
                img_height=img_height,
                img_width=img_width,
                win_height=win_height,
                win_width=win_width,
                min_hoverlap=min_hoverlap,
                min_woverlap=min_woverlap,
                boundless=boundless
            )
        )
        dat = list(product(couples, windows, augs))
        if shuffle:
            random.shuffle(dat)
        self.data = dat
        self.channels = channels
        self.class_count = cls_count
        self.batch_size = batch_size
        # Kept so that reads use the same setting the windows were built with,
        # instead of depending on a notebook-level global.
        self.boundless = boundless

    def __len__(self):
        """Return the number of batches (the last one may be smaller)."""
        return int(np.ceil(len(self.data) / float(self.batch_size)))

    def __getitem__(self, idx):
        """Read, augment and stack the samples of batch ``idx``.

        Returns:
            tuple: ``(ibatch, lbatch)`` — image windows with the band axis
            moved last, and the matching one-hot encoded label windows, both
            stacked along a new leading batch axis.
        """
        current_batch = self.data[idx * self.batch_size:(idx + 1) * self.batch_size]
        islices = list()
        lslices = list()
        for (im, lb), (_, w), (aug, af) in current_batch:
            with rio.open(im, 'r') as isrc:
                islice = isrc.read(indexes=self.channels, window=w, boundless=self.boundless, masked=False)
                # Augmentations work on the channel-first layout returned by read().
                islice = aug(islice)
                islice = np.moveaxis(a=islice, source=0, destination=-1)
                islices.append(islice)
            with rio.open(lb, 'r') as lsrc:
                lslice = lsrc.read(window=w, boundless=self.boundless, masked=False)
                if af is True:
                    lslice = aug(lslice)
                lslice = np.moveaxis(a=lslice, source=0, destination=-1)
                # Labels are shifted by -1 before encoding — presumably they are
                # stored 1-based. NOTE(review): a label value of 0 (e.g. fill
                # pixels of a boundless read) does not map to a valid class
                # index after the shift — confirm the rasters never contain it.
                lslice = to_categorical(
                    y=(lslice - 1),
                    num_classes=self.class_count
                )
                lslices.append(lslice)
        ibatch = np.stack(islices, axis=0)
        lbatch = np.stack(lslices, axis=0)
        return ibatch, lbatch

In [None]:
def upsample_conv(filters, kernel_size, strides, padding):
    """Build the transposed-convolution layer used as a learnable upsampling step."""
    layer_options = {"strides": strides, "padding": padding}
    return Conv2DTranspose(filters, kernel_size, **layer_options)


def upsample_simple(filters, kernel_size, strides, padding):
    """Build a parameter-free upsampling layer.

    The signature mirrors ``upsample_conv`` so both can be used
    interchangeably, but ``UpSampling2D`` only takes an upsampling factor:
    ``strides`` is used as that factor, while ``filters``, ``kernel_size`` and
    ``padding`` are accepted and ignored. (Forwarding them, as the previous
    code did, makes the layer constructor fail.)
    """
    return UpSampling2D(size=strides)


def attention_gate(inp_1, inp_2, n_intermediate_filters, k_init='he_normal'):
    """Attention gate: re-weights ``inp_1`` with a single-channel sigmoid mask.

    Both inputs are first projected to ``n_intermediate_filters`` channels by
    1x1 convolutions, summed and passed through ReLU; a further 1x1 convolution
    followed by a sigmoid produces the mask that multiplies ``inp_1``.
    Implemented as proposed by Oktay et al. in their Attention U-net,
    see: https://arxiv.org/abs/1804.03999.
    """
    # Settings shared by every 1x1 convolution of the gate.
    pointwise = dict(
        kernel_size=1,
        strides=1,
        padding="same",
        kernel_initializer=k_init,
    )

    projected_1 = Conv2D(n_intermediate_filters, **pointwise)(inp_1)
    projected_2 = Conv2D(n_intermediate_filters, **pointwise)(inp_2)

    relu = Activation("relu")
    combined = relu(add([projected_1, projected_2]))

    mask_logits = Conv2D(filters=1, **pointwise)(combined)
    mask = Activation("sigmoid")(mask_logits)
    return multiply([inp_1, mask])


def attention_concat(conv_below, skip_connection):
    """Concatenate the upsampled ``conv_below`` with the attention-gated skip connection.

    The gate compresses both tensors to as many filters as ``conv_below`` has
    channels.
    """
    channel_count = conv_below.get_shape().as_list()[-1]
    gated_skip = attention_gate(
        inp_1=skip_connection,
        inp_2=conv_below,
        n_intermediate_filters=channel_count,
    )
    return concatenate([conv_below, gated_skip])


def conv2d_block(
    inputs,
    use_batch_norm=True,
    dropout=0.3,
    dropout_type="spatial",
    filters=16,
    kernel_size=(3, 3),
    activation="relu",
    kernel_initializer="he_normal",
    padding="same",
    momentum=0.95
):
    """Two stacked convolutions with optional batch norm and dropout in between.

    Layout: conv -> activation -> [batch norm] -> [dropout] -> conv ->
    activation -> [batch norm].

    Args:
        inputs: Input tensor.
        use_batch_norm: Add a ``BatchNormalization`` after each convolution
            (the convolutions then drop their bias).
        dropout: Dropout rate applied between the two convolutions; values
            ``<= 0`` disable it.
        dropout_type: ``"spatial"`` (``SpatialDropout2D``) or ``"standard"``
            (``Dropout``).
        filters: Number of filters of both convolutions.
        kernel_size: Kernel size of both convolutions.
        activation: Activation name, ``None`` for linear, or a callable such as
            an activation layer instance. A callable is applied exactly once
            after each convolution; the same object is reused for both, so it
            should not carry shape-dependent weights.
        kernel_initializer: Kernel initializer of both convolutions.
        padding: Padding mode of both convolutions.
        momentum: Momentum used by both batch normalization layers.

    Returns:
        Output tensor of the block.

    Raises:
        ValueError: If ``dropout_type`` is not one of the supported names.
        TypeError: If ``activation`` is neither a string, ``None`` nor callable.
    """
    if dropout_type == "spatial":
        DO = SpatialDropout2D
    elif dropout_type == "standard":
        DO = Dropout
    else:
        raise ValueError(
            f"dropout_type must be one of ['spatial', 'standard'], got {dropout_type}"
        )

    if activation is None or isinstance(activation, str):
        # Names (and None) are handled by the convolution itself.
        conv_activation, post_activation = activation, None
    elif callable(activation):
        # Layers / functions are applied after a linear convolution so they
        # run exactly once (previously the first convolution applied them
        # twice: once inside Conv2D and once around it).
        conv_activation, post_activation = None, activation
    else:
        raise TypeError(
            f"activation must be a string, None or a callable, got {type(activation).__name__}"
        )

    def conv_step(tensor):
        """Apply one convolution followed by the configured activation."""
        out = Conv2D(
            filters,
            kernel_size,
            activation=conv_activation,
            kernel_initializer=kernel_initializer,
            padding=padding,
            use_bias=not use_batch_norm,
        )(tensor)
        if post_activation is not None:
            out = post_activation(out)
        return out

    c = conv_step(inputs)
    if use_batch_norm:
        # Same momentum as the second normalization; it was previously left at
        # the Keras default here, so `momentum` only reached one of the two.
        c = BatchNormalization(momentum=momentum)(c)
    if dropout > 0.0:
        c = DO(dropout)(c)

    c = conv_step(c)
    if use_batch_norm:
        c = BatchNormalization(momentum=momentum)(c)
    return c


def xnet(
    input_shape=(None, None, image_features),
    num_classes=class_count,
    activation="relu",
    use_batch_norm=True,
    upsample_mode="deconv",  # 'deconv' or 'simple'
    dropout=droprate,
    dropout_change_per_layer=0.0,
    dropout_type="spatial",
    use_dropout_on_upsampling=False,
    use_attention=True,
    filters=32,
    num_layers=4,
    output_activation="softmax",
):
    """Build a U-Net style encoder/decoder model with optional attention gates.

    Args:
        input_shape: Shape of one input sample (height, width, channels).
        num_classes: Number of filters of the final 1x1 convolution.
        activation: Activation used inside every convolution block.
        use_batch_norm: Enable batch normalization in the convolution blocks.
        upsample_mode: ``"deconv"`` uses transposed convolutions; any other
            value falls back to plain upsampling.
        dropout: Dropout rate of the first encoder block.
        dropout_change_per_layer: Amount added to the dropout rate after each
            encoder level (and subtracted again per decoder level when dropout
            on upsampling is enabled).
        dropout_type: ``"spatial"`` or ``"standard"``.
        use_dropout_on_upsampling: When False, the decoder blocks use no
            dropout at all.
        use_attention: Gate the skip connections with attention before
            concatenating them.
        filters: Number of filters of the first encoder block; doubled at each
            encoder level and halved at each decoder level.
        num_layers: Number of encoder (and decoder) levels.
        output_activation: Activation of the output convolution.

    Returns:
        The assembled, uncompiled Keras ``Model``.
    """
    if upsample_mode == "deconv":
        upsample = upsample_conv
    else:
        upsample = upsample_simple

    # Build U-Net model
    inputs = Input(input_shape)
    x = inputs

    # Encoder: keep every block output for the skip connections.
    down_layers = []
    for _ in range(num_layers):
        x = conv2d_block(
            inputs=x,
            filters=filters,
            use_batch_norm=use_batch_norm,
            dropout=dropout,
            dropout_type=dropout_type,
            activation=activation,
        )
        down_layers.append(x)
        x = MaxPooling2D((2, 2))(x)
        dropout += dropout_change_per_layer
        filters = filters * 2  # double the number of filters with each layer

    # Bottleneck.
    x = conv2d_block(
        inputs=x,
        filters=filters,
        use_batch_norm=use_batch_norm,
        dropout=dropout,
        dropout_type=dropout_type,
        activation=activation,
    )

    if not use_dropout_on_upsampling:
        # Disable dropout in the decoder. This used to be reset to a
        # hard-coded 0.3, which kept dropout on regardless of the flag and of
        # the `dropout` argument.
        dropout = 0.0
        dropout_change_per_layer = 0.0

    # Decoder: walk the skip connections from the deepest to the shallowest.
    for conv in reversed(down_layers):
        filters //= 2  # decreasing number of filters with each layer
        dropout -= dropout_change_per_layer
        x = upsample(filters, (2, 2), strides=(2, 2), padding="same")(x)
        if use_attention:
            x = attention_concat(conv_below=x, skip_connection=conv)
        else:
            x = concatenate([x, conv])
        x = conv2d_block(
            inputs=x,
            filters=filters,
            use_batch_norm=use_batch_norm,
            dropout=dropout,
            dropout_type=dropout_type,
            activation=activation,
        )

    outputs = Conv2D(num_classes, (1, 1), activation=output_activation)(x)

    model = Model(inputs=[inputs], outputs=[outputs])
    return model

In [None]:
# Instantiate the segmentation model. Height and width are left as None in
# the input shape; only the number of input bands is fixed.
# Compared with the defaults of xnet, this uses standard (non-spatial) dropout
# and 64 filters in the first block.
xnet_classifier = xnet(
    input_shape=(None, None, image_features),
    num_classes=class_count,
    activation="relu",
    use_batch_norm=True,
    upsample_mode="deconv",  # 'deconv' or 'simple'
    dropout=droprate,
    dropout_change_per_layer=0.0,
    dropout_type="standard",
    use_dropout_on_upsampling=False,
    use_attention=True,
    filters=64,
    num_layers=4,
    output_activation="softmax",
)
# Print the layer table; line_length widens the printed lines.
xnet_classifier.summary(line_length=116)
# plot_model(xnet_classifier, to_file=mplot, show_shapes=True, show_layer_names=True)

In [None]:
# Stop training once val_loss has not improved for 5 consecutive epochs.
es_val_loss = EarlyStopping(monitor='val_loss', mode='min', patience=5, verbose=1)
# es_val_accu = EarlyStopping(monitor='val_accuracy', mode='max', min_delta=0.001)
# Two checkpoints, each keeping only the best model seen so far: one by
# highest validation accuracy, one by lowest validation loss.
mc_val_accu = ModelCheckpoint(str(model_max_accuracy.absolute()), monitor='val_accuracy', mode='max', verbose=1, save_best_only=True)
mc_val_loss = ModelCheckpoint(str(model_min_loss.absolute()), monitor='val_loss', mode='min', verbose=1, save_best_only=True)

# TensorBoard logging into the Logs directory.
tb = TensorBoard(
    log_dir=log_d, 
    histogram_freq=1, 
    write_graph=True, 
    write_images=True,
    update_freq='batch', 
    embeddings_freq=0,
    embeddings_metadata=None
)

# Load Pre Trained Weights if any
# Resumes from the best-accuracy checkpoint when that file already exists.
if model_max_accuracy.is_file():
    xnet_classifier.load_weights(str(model_max_accuracy))

# The loss comes from the local `advanced_losses` module.
# NOTE(review): name='tversky' presumably selects the Tversky variant — confirm
# against advanced_losses.
xnet_classifier.compile(
    optimizer=Adam(learning_rate=1e-4), 
    # loss='categorical_crossentropy', 
    loss=DiceLossVariants(name='tversky'),
    metrics=['accuracy']
)

In [None]:
# Training data: every window of every training raster, once per entry of
# geo_augs, with at least the configured minimum overlap between windows.
train_generator = RasterDataGenerator( 
    map_dict=train_map,
    channels=bands,
    img_height=image_height,
    img_width=image_width,
    win_height=window_height,
    win_width=window_width,
    min_hoverlap=min_height_overlap,
    min_woverlap=min_width_overlap,
    cls_count=class_count,
    boundless=boundless_flag,
    augs=geo_augs,
    # augs=None,
    shuffle=data_shuffle,
    batch_size=batchsize
)
# Validation data: no augmentation (augs=None falls back to the identity) and
# a minimum overlap of only 1 pixel between windows.
valid_generator = RasterDataGenerator(
    map_dict=valid_map,
    channels=bands,
    img_height=image_height,
    img_width=image_width,
    win_height=window_height,
    win_width=window_width,
    min_hoverlap=1,
    min_woverlap=1,
    cls_count=class_count,
    augs=None,
    boundless=boundless_flag,
    shuffle=data_shuffle,
    batch_size=batchsize
)

In [None]:
# Train
# Runs for at most 50 epochs; the early-stopping callback may end training
# sooner, and the checkpoint callbacks save the best models along the way.
# NOTE(review): `use_multiprocessing` is not accepted by `fit` in every Keras
# version — confirm against the installed one.
xnet_classifier.fit(
    x=train_generator, 
    epochs=50, 
    validation_data=valid_generator,
    use_multiprocessing=False,
    callbacks=[
        tb,
        es_val_loss,
        mc_val_loss,
#         es_val_accu,
        mc_val_accu,
        
    ]
)