In [None]:
import json
from pathlib import Path
import tensorflow as tf
import numpy as np
from PIL import Image
import os
import time
import rasterio as rio
from functools import reduce
from tensorflow.keras.models import load_model

from core.UNet import UNet
from core.losses import (
    tversky,
    accuracy,
    dice_coef,
    dice_loss,
    specificity,
    sensitivity,
)
from core.optimizers import adaDelta

from core.dataset_generator import DataGenerator
from core.split_frames import split_dataset
from core.visualize import display_images

import warnings  # ignore annoying warnings

warnings.filterwarnings("ignore")
import logging

logger = logging.getLogger()
logger.setLevel(logging.CRITICAL)

%reload_ext autoreload
%autoreload 2
from IPython.core.interactiveshell import InteractiveShell

InteractiveShell.ast_node_interactivity = "all"

print(tf.__version__)

In [None]:
from importlib import reload
import core.split_frames as split
import core.frame_info as fram
import core.dataset_generator as dg
import config.conf as conf

# Pick up edits to the project modules without restarting the kernel: first
# rebuild the configuration from a freshly reloaded `conf`, then reload the
# frame / split / generator modules (in that order).
conf = reload(conf)
config = conf.Configuration()
fram, split, dg = (reload(module) for module in (fram, split, dg))

In [None]:
# Read all images/frames into memory.
# The paths are sorted so frame order, and with it the index-based
# train/val/test selection further down, does not depend on filesystem order.
norm_images = sorted(config.image_dir.glob(f"*{config.image_type}"))
frames = []
dims = {}
for image_path in norm_images:
    annotation_path = config.annotation_dir / image_path.name

    # Context managers release the raster / image file handles even if a read fails.
    with rio.open(image_path) as norm_img:
        # Record the raster size of every image, including those skipped below.
        dims[image_path.name] = {
            "width": norm_img.profile["width"],
            "height": norm_img.profile["height"],
        }

        # Only keep frames whose annotation JSON lists at least one tree.
        with open(annotation_path.with_suffix(".json")) as f:
            trees = len(json.load(f)["Trees"])
        if not trees:
            continue

        # rasterio reads band-first (bands, rows, cols); move the bands to the
        # last axis so the array is channels-last.
        norm_array = np.transpose(norm_img.read(), axes=(1, 2, 0))

    with Image.open(annotation_path) as annotation_img:
        annotation_array = np.array(annotation_img)
    with Image.open(config.boundary_dir / image_path.name) as weight_img:
        weight_array = np.array(weight_img)

    frames.append(
        fram.FrameInfo(
            norm_array,
            annotation_array,
            weight_array,
            image_path.name,
        )
    )

In [None]:
# Show the raster width/height (from each file's rasterio profile) for every image that was read.
dims

In [None]:
# Let the project helper split the frames into training / validation / testing
# selections (it is also given the frames JSON and patch directory from config).
# NOTE: the "manually select" cell below overrides all three results.
(
    training_frames,
    validation_frames,
    testing_frames,
) = split.split_dataset(frames, config.frames_json, config.patch_dir)

In [None]:
# Inspect the array shape of the first loaded frame.
# NOTE(review): `.img` is presumably the channels-last norm_array passed to FrameInfo — confirm.
frames[0].img.shape

In [None]:
# Get frames id
# [frame.id for frame in frames]

In [None]:
# Manually override the split above: train on the last two frames, validate on
# frame 1 and test on frame 2 (all values are indices into `frames`).
# Indices are computed directly instead of via `frames.index(...)`, which scans
# the list and calls FrameInfo equality on every earlier element.
n_frames = len(frames)
if n_frames < 3:
    raise ValueError(f"Manual split needs at least 3 frames, got {n_frames}")
training_frames = list(range(n_frames - 2, n_frames))
validation_frames = [1]
testing_frames = [2]

# With fewer than 5 frames the "last two" include frame 1 and/or 2, so
# validation/test data leaks into training. `warnings` are silenced in this
# notebook, so report it with print.
leaked_frames = set(training_frames) & set(validation_frames + testing_frames)
if leaked_frames:
    print(
        f"WARNING: frames {sorted(leaked_frames)} are in both the training "
        "and the validation/testing selection"
    )

In [None]:
# Target channels for every generator: the label channel(s) followed by the
# weight channel(s).
annotation_channels = config.input_label_channel + config.input_weight_channel

# Training batches, with "iaa" augmentation enabled.
train_generator = (
    dg.DataGenerator(
        config.input_image_channel,
        config.patch_size,
        training_frames,
        frames,
        annotation_channels,
        augmenter="iaa",
    )
    .random_generator(config.BATCH_SIZE, normalize=config.normalize)
)
# To train/validate/test on every frame instead, set this BEFORE building the generators:
# training_frames = validation_frames = testing_frames  = list(range(len(frames)))


In [None]:
# Validation batches: same sampling call as training, but with augmentation disabled.
val_generator = (
    dg.DataGenerator(
        config.input_image_channel,
        config.patch_size,
        validation_frames,
        frames,
        annotation_channels,
        augmenter=None,
    )
    .random_generator(config.BATCH_SIZE, normalize=config.normalize)
)

In [None]:
# Testing batches: same sampling call as training, but with augmentation disabled.
test_generator = (
    dg.DataGenerator(
        config.input_image_channel,
        config.patch_size,
        testing_frames,
        frames,
        annotation_channels,
        augmenter=None,
    )
    .random_generator(config.BATCH_SIZE, normalize=config.normalize)
)

In [None]:
# Sanity-check training batches: show the input bands, the label/weight
# channels, and their sum (label + weight) appended as one extra channel.
# The loop variable is not `_`: assigning `_` in the notebook shadows IPython's
# output-cache variable and stops IPython from updating it.
N_PREVIEW_BATCHES = 1
for _batch_idx in range(N_PREVIEW_BATCHES):
    train_images, real_label = next(train_generator)
    ann = real_label[:, :, :, 0]
    wei = real_label[:, :, :, 1]
    overlay = (ann + wei)[:, :, :, np.newaxis]
    display_images(np.concatenate((train_images, real_label, overlay), axis=-1))

In [None]:
# Optimizer and loss used to compile the model.
OPTIMIZER = adaDelta
LOSS = tversky

# Human-readable names, only used to build the model / log file names.
OPTIMIZER_NAME = "AdaDelta"
LOSS_NAME = "weightmap_tversky"

# Path for storing the model after training.
# To retrain an existing model instead, change the cell where the model is declared.
timestr = time.strftime("%Y%m%d-%H%M")
chf = config.input_image_channel + config.input_label_channel
# Concatenate the str() of every channel entry into one tag, e.g. [0, 1, 2, 3] -> "0123".
chs = "".join(str(channel) for channel in chf)

model_path = (
    config.model_dir
    / f"trees_{timestr}_{OPTIMIZER_NAME}_{LOSS_NAME}_{chs}_{config.input_shape[0]}.h5"
)
# Ensure the checkpoint directory exists before training tries to write into it.
config.model_dir.mkdir(parents=True, exist_ok=True)

# Saving only the weights (without the architecture) is more efficient; to do
# that, set save_weights_only=True on the ModelCheckpoint callback.

In [None]:
# Re-read the configuration so edits made to the `config.conf` module since the
# start of the notebook reach the model cell below. The generators and
# `model_path` above were built from the PREVIOUS config, so report every
# setting that changed: a mismatch means the model and its input pipeline
# disagree (re-run the cells above). `warnings` are silenced here, so print.
_previous_config = config
conf = reload(conf)
config = conf.Configuration()
for _setting in (
    "patch_size",
    "input_shape",
    "input_image_channel",
    "input_label_channel",
    "input_weight_channel",
    "BATCH_SIZE",
    "normalize",
    "model_dir",
):
    _old_value = getattr(_previous_config, _setting, None)
    _new_value = getattr(config, _setting, None)
    # repr() comparison works uniformly for ints, tuples, lists, Paths and arrays.
    if repr(_old_value) != repr(_new_value):
        print(
            f"WARNING: config.{_setting} changed on reload: "
            f"{_old_value!r} -> {_new_value!r}; re-run the cells above."
        )

In [None]:
# Build the UNet for a fixed batch size and compile it with the optimizer,
# loss and segmentation metrics chosen above.
unet_input_shape = [config.BATCH_SIZE, *config.input_shape]
model = UNet(unet_input_shape, config.input_label_channel)

tracked_metrics = [dice_coef, dice_loss, specificity, sensitivity, accuracy]
model.compile(optimizer=OPTIMIZER, loss=LOSS, metrics=tracked_metrics)

In [None]:
# Define callbacks for early stopping, learning-rate reduction, TensorBoard
# logging and model checkpointing.
from tensorflow.keras.callbacks import (
    ModelCheckpoint,
    LearningRateScheduler,
    EarlyStopping,
    ReduceLROnPlateau,
    TensorBoard,
)

# Save the full model (not only the weights) to `model_path` each time
# val_loss reaches a new minimum.
checkpoint = ModelCheckpoint(
    model_path,
    monitor="val_loss",
    verbose=1,
    save_best_only=True,
    mode="min",
    save_weights_only=False,
)

# ReduceLROnPlateau: can be useful when using Adam as the optimizer.
# When val_loss stops improving (after `patience` epochs), the learning rate is
# multiplied by `factor` (new_lr = lr * 0.33); `cooldown` is the number of
# epochs to wait before resuming normal operation after a reduction, and
# `min_lr` is the lower bound on the learning rate.
# NOTE(review): min_lr=0.01 is high compared with common learning rates — confirm.
reduceLROnPlat = ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.33,
    patience=4,
    verbose=1,
    mode="min",
    min_delta=0.0001,
    cooldown=4,
    min_lr=0.01,
)

# Defined for convenience; not included in callbacks_list below.
early = EarlyStopping(monitor="val_loss", mode="min", verbose=2, patience=20)

log_dir = os.path.join(
    "./logs",
    "UNet_{}_{}_{}_{}_{}".format(
        timestr, OPTIMIZER_NAME, LOSS_NAME, chs, config.input_shape[0]
    ),
)
# Only arguments accepted by the current TensorBoard callback are passed. The
# TF1-era `write_grads`, `embeddings_layer_names` and `embeddings_data` were
# already ignored by TF2 and are rejected by Keras 3; they were set to their
# defaults anyway, so dropping them does not change behavior.
tensorboard = TensorBoard(
    log_dir=log_dir,
    histogram_freq=0,
    write_graph=True,
    write_images=False,
    embeddings_freq=0,
    embeddings_metadata=None,
    update_freq="epoch",
)

callbacks_list = [
    checkpoint,
    tensorboard,
]  # reduceLROnPlat is not required with adaDelta

In [None]:
# Train the model. The History object returned by `fit` is wrapped in a list,
# presumably so later training runs can be appended to `loss_history`.
loss_history = [
    model.fit(
        train_generator,
        steps_per_epoch=config.MAX_TRAIN_STEPS,
        epochs=config.NB_EPOCHS,
        validation_data=val_generator,
        # NOTE(review): validation_steps counts batches; confirm VALID_IMG_COUNT
        # is meant as a batch count rather than an image count.
        validation_steps=config.VALID_IMG_COUNT,
        callbacks=callbacks_list,
        # `workers` / `use_multiprocessing` are left at their defaults (a single
        # worker, no multiprocessing): the generator is not thread safe, and
        # Keras 3 no longer accepts these arguments in `fit`.
    )
]