<a href="https://colab.research.google.com/github/docuracy/desCartes/blob/main/experiments/cnn-5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title Authenticate GCS, mount Google Drive

# Mount Google Drive so the project folders under /content/drive/MyDrive are readable/writable.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# Application-default credentials; presumably what the google.cloud.storage client and the
# gs:// paths used in the training cell authenticate with.
!gcloud auth application-default login
# NOTE(review): this project id is repeated as `gcs_project_id` in the configuration cell; keep them in sync.
!gcloud config set project descartes-404713

In [None]:
#@title Upgrade TensorFlow

!pip install --upgrade tensorflow


Collecting tensorflow
  Downloading tensorflow-2.15.0.post1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (475.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m475.2/475.2 MB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
Collecting tensorboard<2.16,>=2.15 (from tensorflow)
  Downloading tensorboard-2.15.1-py3-none-any.whl (5.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.5/5.5 MB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting tensorflow-estimator<2.16,>=2.15.0 (from tensorflow)
  Downloading tensorflow_estimator-2.15.0-py2.py3-none-any.whl (441 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m442.0/442.0 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting keras<2.16,>=2.15.0 (from tensorflow)
  Downloading keras-2.15.0-py3-none-any.whl (1.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.7/1.7 MB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
Collecting google-auth-oauth

In [None]:
#@title Initialise directories and global variables

import os
import sys
import subprocess

# Directory containing scripts such as 'map_from_tiles'
# (added to sys.path so modules in it can be imported directly)
scripts_directory = '/content/drive/MyDrive/Colab Notebooks/scripts'
sys.path.append(scripts_directory)

# Directories used by 'map_from_tiles'
# Note: these are built with an explicit '/' separator and have no trailing slash.
temp_directory = f"{scripts_directory}/temp"
cache_directory = f"{scripts_directory}/data/cache"

# Training-data tree on Google Drive. Every *_directory below ends with '/', so file names are
# appended directly, e.g. f"{map_augmented_s1_directory}{map_name}.augmented_s1.npy".
training_data_directory = '/content/drive/MyDrive/desCartes/training_data/'
map_directory = f"{training_data_directory}maps/"
map_classified_s1_directory = f"{map_directory}classified_s1/"
map_one_inch_directory = f"{map_directory}one_inch/"
map_osm_directory = f"{map_directory}osm/"
map_dem_directory = f"{map_directory}dem/"
map_elevation_directory = f"{map_dem_directory}elevation/"
map_slope_directory = f"{map_dem_directory}slope/"
map_augmented_s1_directory = f"{map_directory}augmented_s1/"
map_binary_directory = f"{map_directory}binary/"
map_skeleton_directory = f"{map_directory}skeleton/"
# Inference outputs: RGBA masks, map overlays and GeoTIFFs.
map_output_directory = f"{map_directory}output/"
map_mask_directory = f"{map_output_directory}masks/"
map_overlay_directory = f"{map_output_directory}overlays/"
map_geotiff_directory = f"{map_output_directory}geotiffs/"
labels_directory = f"{map_directory}labels/"
labels_raster_directory = f"{labels_directory}raster/"
labels_overlay_directory = f"{labels_directory}overlay/"

# Tiles live on the Colab VM's local disk (outside /content/drive), not on Google Drive.
tile_directory = '/content/tiles/'
tile_size = 256 # (px)
# Also sets the border trimmed from each predicted tile during inference (min_overlap // 2).
min_overlap = 16 # Minimum tile overlap (px)

# GeoPackage containing map annotations created in QGIS
geopackage_path = '/content/drive/MyDrive/desCartes/templates/labels.gpkg'
linestring_buffer = 3 # (px) Use False for no buffer

# MapTiler API key. Do not hard-code it: this notebook is published on GitHub, so any key
# committed in this cell is public. Provide it through the MAPTILER_KEY environment variable
# (e.g. `%env MAPTILER_KEY=...` in an unsaved scratch cell) or type it at the prompt.
# NOTE(review): the key previously committed here should be treated as leaked and rotated.
import getpass
maptiler_key = os.environ.get('MAPTILER_KEY') or getpass.getpass('MapTiler API key: ')

# UK Great Britain, Ordnance Survey six-inch to the mile (1:10,560), 1888-1913 https://cloud.maptiler.com/tiles/uk-osgb10k1888/
basemap_url = 'https://api.maptiler.com/tiles/uk-osgb10k1888/{z}/{x}/{y}.jpg' + f'?key={maptiler_key}'

# UK Great Britain, Ordnance Survey one-inch to the mile (1:63,360), 1888-1913 https://cloud.maptiler.com/tiles/uk-osgb63k1885/
basemap_url_one_inch = 'https://api.maptiler.com/tiles/uk-osgb63k1885/{z}/{x}/{y}.png' + f'?key={maptiler_key}'

# DEM Tiles - see https://documentation.maptiler.com/hc/en-us/articles/4405444055313-RGB-Terrain-by-MapTiler
dem_tilesource = 'https://api.maptiler.com/tiles/terrain-rgb-v2/{z}/{x}/{y}.webp' + f'?key={maptiler_key}'
dem_max_zoom = 14

# Ilastik model used for Stage 1 pixel classification
ilastik_project_file = "/content/drive/MyDrive/desCartes/ilastik/preprocess.ilp"
ilastik_executable = './ilastik-1.4.0-Linux/run_ilastik.sh'

# Directory for saving trained CNN models
model_directory = "/content/drive/MyDrive/desCartes/models"

# Class names (one per line) and per-class loss weights, both stored alongside the models.
# NOTE(review): the model/inference cells rebuild the class-weights path from model_directory
# rather than reading class_weights_file; keep the two consistent.
label_strings_file = os.path.join(model_directory, 'label_strings.txt')
class_weights_file = os.path.join(model_directory, 'class_weights.json')
num_classes = 5 # Allows for fill (zero) and road classes 1 to 4 (determined by QGIS labelling)

# Google Cloud Services
# NOTE(review): gcs_key_path is not referenced in the cells shown here; the storage client in the
# training cell is created without it, relying on the credentials set up in the first cell.
gcs_key_path = '/content/drive/MyDrive/desCartes/descartes-404713-cccf7c3921aa.json'
# Keep in sync with `gcloud config set project` in the first cell.
gcs_project_id = 'descartes-404713'
gcs_bucket_name = 'descartes'
# Prefix inside the bucket; the training cell reads "<prefix>/train" and "<prefix>/eval".
gcs_data_directory = "training_data"

# Set the split ratios and batch size for training data
# TFRecord_batch_size: number of examples assumed to be stored in each TFRecord file; the training
# cell multiplies it by the number of files to derive the steps per epoch.
# NOTE(review): train_ratio/eval_ratio are not used in the cells shown here; presumably they were
# used when the TFRecords were split into train/eval.
TFRecord_batch_size = 16
train_ratio = 0.85
eval_ratio = 0.15

# Starting learning rate for Adam (ReduceLROnPlateau in the training cell lowers it on plateaus).
initial_learning_rate = 0.0001

# Inference: Color mappings for classes
# RGBA tuples: class 0 is fully transparent; road classes use alpha 180/255 so the map
# remains visible underneath when the mask is alpha-composited over it.
class_colors = {
    0: (0, 0, 0, 0),  # Transparent (background)
    1: (178,24,43,180),  # Red
    2: (239,138,98,180),  # Orange
    3: (84,39,136,180),  # Purple
    4: (153,142,195,180),  # Lilac
}

# Create directories if they do not exist (exist_ok=True makes re-running this cell safe)
directories_to_create = [
    temp_directory,
    cache_directory,
    training_data_directory,
    map_directory,
    map_one_inch_directory,
    map_osm_directory,
    map_dem_directory,
    map_elevation_directory,
    map_slope_directory,
    map_classified_s1_directory,
    map_augmented_s1_directory,
    map_binary_directory,
    map_skeleton_directory,
    map_output_directory,
    map_mask_directory,
    map_overlay_directory,
    map_geotiff_directory,
    labels_directory,
    labels_raster_directory,
    labels_overlay_directory,
    model_directory,
]

for directory in directories_to_create:
    os.makedirs(directory, exist_ok=True)


# Load and Train CNN Model: **must be run on TPU**

In [None]:
#@title Clear Session

# Reset Keras' global state and re-initialise the TPU so that a fresh model can be built
# without carrying over state from a previous run of the model cell.
# The cell is self-contained: it imports TensorFlow and resolves the TPU itself, instead of
# relying on `tf`/`tpu` having been defined by a cell further down the notebook.
import tensorflow as tf
from tensorflow.keras import backend as K

K.clear_session()

tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # raises ValueError if no TPU runtime is attached
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
tpu_strategy = tf.distribute.TPUStrategy(tpu)




In [None]:
#@title v5 Load and Compile Model

import os
import re
import glob
import json
from datetime import datetime
import numpy as np
import random

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Model
from tensorflow.keras.saving import load_model
from tensorflow.keras.regularizers import l2
from tensorflow.keras.utils import Sequence, custom_object_scope
from tensorflow.keras.metrics import Metric, CategoricalAccuracy, MeanIoU
from tensorflow.keras.optimizers.schedules import ExponentialDecay
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import Loss, CategoricalCrossentropy, CategoricalFocalCrossentropy

from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Dropout, GlobalAveragePooling2D, GlobalMaxPooling2D, Dense, MaxPooling2D, Conv2DTranspose, concatenate, Activation
from tensorflow.keras.regularizers import l2

reload_existing = False # @param {type:"boolean"}
reduce_classes = False # @param {type:"boolean"}
ignore_ilastik = False # @param {type:"boolean"}

# Number of model output channels: 2 when classes are reduced, otherwise all 5 classes.
output_classes = 2 if reduce_classes else 5

##############################################################################
# Detect and initialize the TPU (this notebook must run on a TPU runtime)

import tensorflow as tf

try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
except ValueError as e:
    # Fail fast: without a TPU, `tpu` would be undefined and the code below would
    # crash with an unrelated NameError.
    raise RuntimeError('ERROR: Not connected to a TPU runtime.') from e
print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])

# Initialise the TPU system only if it has not already been initialised in this kernel.
if not tf.config.list_logical_devices('TPU'):
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
else:
    print('TPU system already initialized.')

# Always create the strategy so `tpu_strategy` exists even when the TPU system was
# initialised by an earlier cell (load_and_compile_model builds the model in its scope).
tpu_strategy = tf.distribute.TPUStrategy(tpu)
##############################################################################

# Human-readable class names, one per line of label_strings_file (whitespace stripped).
with open(label_strings_file, 'r') as file:
    label_strings = [line.strip() for line in file]

def unet_model(
    input_shape=(256, 256, 28),
    num_classes=output_classes,
    # Previously tried (sizes, filters) configurations, kept with their observed results:
    # sizes = [64, 128, 256, 512, 1024, 2048],
    # filters = [3, 3, 3, 3, 3, 3],
    # sizes = [256, 512, 512, 1024],
    # filters = [3, 3, 5, 7],
    # sizes = [512, 256, 128, 64],
    # filters = [3, 3, 5, 7],
    # sizes = [1024, 1024, 128, 64],
    # filters = [3, 3, 5, 7],
    # sizes = [1024, 1024, 128, 32, 8], # 20 mins per epoch; failed at 30
    # filters = [3, 3, 5, 7, 9],
    # sizes = [1024, 128, 8], # 20 mins per epoch, slow convergence to inadequate categorical crossentropy
    # filters = [3, 5, 9],
    # sizes = [1024, 1024, 64], # 18 mins per epoch
    # filters = [3, 3, 7],
    # sizes = [1024, 1024, 1024], # Crashed memory
    # filters = [3, 3, 3],
    # sizes = [1024, 1024, 128, 8], # 20 mins per epoch, 0.1350 @Epoch 30 (crashed)
    # filters = [3, 3, 5, 9],
    # sizes = [1024, 512, 256, 128],
    # filters = [3, 3, 5, 5],
    # sizes = [1024, 1024, 128, 32, 8], # 20 mins per epoch; failed at 30
    # filters = [3, 3, 5, 7, 9],
    # sizes = [512, 512, 256, 128, 64],
    # filters = [3, 3, 5, 7, 9],
    sizes = [1024, 1024, 512, 256, 128, 64],
    filters = [3, 3, 5, 5, 7, 9],
    ):
    """Build a U-Net-style encoder/decoder for per-pixel classification.

    Args:
        input_shape: (height, width, channels) of each input tile. Height and width must
            be divisible by 2**len(sizes): each encoder level halves them with
            MaxPooling2D and each decoder level doubles them with a stride-2
            Conv2DTranspose before concatenating with the matching encoder output.
        num_classes: number of output channels. Defaults to the module-level
            `output_classes`, captured when this function is defined.
        sizes: number of Conv2D filters at each encoder level; the decoder mirrors them.
        filters: square kernel size for each level. Should have the same length as
            `sizes` (the encoder indexes it from the start, the decoder from the end).

    Returns:
        An uncompiled keras `Model` mapping `input_shape` to a
        (height, width, num_classes) softmax output.

    Notes:
        The final Conv2D applies leaky_relu before the softmax. NOTE(review): a linear
        activation is more usual for logits; changing it would alter the architecture.
        Keras auto-generates layer names from construction order; keep the order
        stable if weights are ever loaded by layer name.
    """

    inputs = Input(shape=input_shape)
    x = inputs

    # (pre-pooling feature map, pooled output) per encoder level; the pre-pooling map
    # is used as the skip connection in the decoder.
    encoders_layers = []
    for i, size in enumerate(sizes):

        # Vary the regularization strength based on the encoder depth or other criteria
        # (weaker L2 of 1e-5 for the first two levels, 1e-4 deeper down)
        regularization_strength = 1e-5 if i < 2 else 1e-4

        print(f"Encoder - size: {size}; filter: {filters[i]}")

        x = Conv2D(size, (filters[i], filters[i]), activation=tf.nn.leaky_relu, padding='same', kernel_regularizer=l2(regularization_strength))(x)
        x = BatchNormalization()(x)
        x = Dropout(0.2)(x)
        x = Conv2D(size, (filters[i], filters[i]), activation=tf.nn.leaky_relu, padding='same', kernel_regularizer=l2(regularization_strength))(x)
        x = BatchNormalization()(x)

        encoder_pool = MaxPooling2D((2, 2))(x)
        encoders_layers.append((x, encoder_pool))
        x = encoder_pool

    # Bottleneck (twice the filters of the deepest encoder level)
    bottleneck_size = sizes[-1] * 2
    x = Conv2D(bottleneck_size, (filters[-1], filters[-1]), activation=tf.nn.leaky_relu, padding='same', kernel_regularizer=l2(1e-5))(x)
    x = Dropout(0.5)(x)

    # Decoder: walk the encoder levels from deepest to shallowest. encoder_pool is unpacked
    # but unused; only the pre-pooling feature map is concatenated.
    for i, (encoder_layer, encoder_pool) in enumerate(reversed(encoders_layers)):
        decoder_size = sizes[-i - 1]
        stride_size = 2
        strides = (stride_size, stride_size)

        print(f"Decoder - size: {decoder_size}; filter: {filters[-i-1]}; strides: {strides}")

        x = Conv2DTranspose(decoder_size, (filters[-i-1], filters[-i-1]), strides=strides, padding='same')(x)
        x = concatenate([x, encoder_layer], axis=-1)
        x = Conv2D(decoder_size, (filters[-i-1], filters[-i-1]), activation=tf.nn.leaky_relu, padding='same', kernel_regularizer=l2(1e-5))(x)

    # Additional Convolutional Layer (projects to one channel per class)
    x = Conv2D(num_classes, (filters[0], filters[0]), activation=tf.nn.leaky_relu, padding='same', kernel_regularizer=l2(1e-5))(x)

    # Output: per-pixel class probabilities
    outputs = Activation('softmax')(x)

    model = Model(inputs, outputs)

    return model

def load_and_compile_model(resume, model_directory):
    """Load the most recent saved model, or build and compile a new U-Net.

    Args:
        resume: if True, load the most recently modified `*.keras` file in
            `model_directory`; if there is none, fall back to building a new model.
        model_directory: directory holding saved models and `class_weights.json`.

    Returns:
        (model, model_filepath): the model, created or loaded inside
        `tpu_strategy.scope()`, and the path it was loaded from or should be saved to
        (`desCartes_<YYYYmmdd_HHMM>.keras` for a new model).

    Raises:
        ValueError: if `class_weights.json` is missing from `model_directory`.
    """
    def load_class_weights(model_directory):
        # Per-class focal-loss alpha weights, optionally merged into 2 classes.
        class_weights_file = os.path.join(model_directory, 'class_weights.json')
        if os.path.exists(class_weights_file):
            with open(class_weights_file, 'r') as json_file:
                class_weights = json.load(json_file)
                class_weights = {int(key): value for key, value in class_weights.items()}  # Convert keys to integers

                # Convert class_weights to a list, matching the number of classes
                num_classes = 5
                class_weights_list = np.array([class_weights[i] for i in range(num_classes)], dtype=np.float32)
                print(f"class_weights_list: {class_weights}")

                if reduce_classes:
                    # Background = classes 0, 3, 4; roads = classes 1, 2.
                    adjusted_class_weights = np.zeros(2, dtype=np.float32)
                    adjusted_class_weights[0] = 1 / (1 / class_weights_list[0] + 1 / class_weights_list[3] + 1 / class_weights_list[4])
                    adjusted_class_weights[1] = 1 / (1 / class_weights_list[1] + 1 / class_weights_list[2])
                    class_weights_list = adjusted_class_weights
                    print(f"Adjusted Class Weights: {class_weights_list}")

                return class_weights_list
        else:
            raise ValueError("Class weights not found. Cannot proceed without class weights.")

    class_weights_list = load_class_weights(model_directory)
    model = None
    model_filepath = None

    with tpu_strategy.scope():

        if resume:
            # Load the most recent model (or its checkpoint) from the model_directory
            model_files = sorted(glob.glob(os.path.join(model_directory, '*.keras')), key=os.path.getmtime)
            print(model_files)
            if model_files:
                model_filepath = model_files[-1]
                print(f"Loading model: {model_filepath}")
                model = load_model(model_filepath)
                print(f"... loaded.")
            else:
                print("No model to resume.")

        # Build a new model when not resuming, and also when resuming found nothing to
        # load; otherwise `model` would be None and `model_filepath` unset below.
        if model is None:
            print("Training a new model.")
            timestamp = datetime.now().strftime("%Y%m%d_%H%M")
            model_filepath = os.path.join(model_directory, f"desCartes_{timestamp}.keras")

            model = unet_model()  # Use default inputs

            model.compile(
                optimizer=Adam(learning_rate = initial_learning_rate),
                loss=CategoricalFocalCrossentropy(alpha=class_weights_list),
                metrics=['categorical_accuracy', 'categorical_crossentropy']
            )

        model.summary()
        return model, model_filepath

# Call load_and_compile_model to load or create and compile the model.
# `model_filepath` is reused by the training cell as the checkpoint/save path.
model, model_filepath = load_and_compile_model(resume=reload_existing, model_directory=model_directory)


Running on TPU  ['10.123.118.114:8470']
TPU system already initialized.
class_weights_list: {0: 0.20887414958955328, 1: 18.736409394308616, 2: 8.051957535094475, 3: 64.33833280169463, 4: 51.7601309518174}
Training a new model.
Encoder - size: 1024; filter: 3
Encoder - size: 1024; filter: 3
Encoder - size: 512; filter: 5
Encoder - size: 256; filter: 5
Encoder - size: 128; filter: 7
Encoder - size: 64; filter: 9
Decoder - size: 64; filter: 9; strides: (2, 2)
Decoder - size: 128; filter: 7; strides: (2, 2)
Decoder - size: 256; filter: 5; strides: (2, 2)
Decoder - size: 512; filter: 5; strides: (2, 2)
Decoder - size: 1024; filter: 3; strides: (2, 2)
Decoder - size: 1024; filter: 3; strides: (2, 2)
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, 256, 256, 28)]       0         []              

In [None]:
# NOTE(review): hard-coded override of the `model_filepath` set by the model cell. Run this only
# when deliberately continuing into this specific file: with overwrite_checkpoints enabled the
# training cell writes its checkpoints to this path.
model_filepath = '/content/drive/MyDrive/desCartes/models/desCartes_20231207_0723.keras'

In [None]:
#@title Train Model

from google.cloud import storage
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.callbacks import ReduceLROnPlateau
import os
import sys
import glob
import matplotlib.pyplot as plt
import re
import json
import numpy as np

restart_at_epoch = 0 # @param {type:"integer"}
epochs = 150 # @param {type:"integer"}
batch_size = 16 # @param {type:"integer"}
verbose_callbacks = False # @param {type:"boolean"}
overwrite_checkpoints = True # @param {type:"boolean"}
verbose = 1 if verbose_callbacks else 0

# Number of uint8 channels stored per pixel in each TFRecord 'image' feature.
image_classes = 28

# Take the label/input options from the model that is about to be trained instead of
# hard-coding them. Labels must have as many channels as the model outputs (2 when classes
# are reduced), and images as many as it accepts (fewer than image_classes when the ilastik
# channels are dropped). Hard-coded values could disagree with the model cell's form settings.
reduce_classes = model.output_shape[-1] == 2
ignore_ilastik = model.input_shape[-1] < image_classes

# Set up GCS client
client = storage.Client(project=gcs_project_id)
gcs_train_directory = f"{gcs_data_directory}/train"
gcs_eval_directory = f"{gcs_data_directory}/eval"
# Function to parse TFRecord
def parse_tfrecord_fn(example):
    """Decode one serialized training example into an (image, label) pair.

    The record holds two raw uint8 byte strings:
      * 'label' reshaped to (256, 256, 5), one channel per class (presumably one-hot 0/1);
      * 'image' reshaped to (256, 256, image_classes).

    With `reduce_classes`, the label becomes 2 channels: background (classes 0, 3, 4) and
    roads (classes 1, 2). This mirrors the class-weight merge in load_class_weights, and it
    ensures every pixel still has exactly one active class. With `ignore_ilastik`, only the
    first 4 and last 5 image channels are kept.
    """
    feature_description = {
        'label': tf.io.FixedLenFeature([], tf.string),
        'image': tf.io.FixedLenFeature([], tf.string),
    }
    example = tf.io.parse_single_example(example, feature_description)
    label = tf.io.decode_raw(example['label'], tf.uint8)
    label = tf.reshape(label, (256, 256, 5))
    if reduce_classes:
        # Fold classes 3 and 4 into background; otherwise their pixels would end up with an
        # all-zero label vector.
        label_background = tf.bitwise.bitwise_or(
            label[:, :, 0:1],
            tf.bitwise.bitwise_or(label[:, :, 3:4], label[:, :, 4:5]),
        )
        label_roads = tf.bitwise.bitwise_or(label[:, :, 1:2], label[:, :, 2:3])
        label = tf.concat([label_background, label_roads], axis=-1)
    image = tf.io.decode_raw(example['image'], tf.uint8)
    image = tf.reshape(image, (256, 256, image_classes))
    if ignore_ilastik:
        image = tf.concat([image[:, :, :4], image[:, :, -5:]], axis=-1)
    return image, label

# GCS paths for training and evaluation
gcs_train_pattern = 'gs://{}/{}/*.tfrecord'.format(gcs_bucket_name, gcs_train_directory)
gcs_eval_pattern = 'gs://{}/{}/*.tfrecord'.format(gcs_bucket_name, gcs_eval_directory)

# Get the list of file paths matching the pattern
gcs_train_files = tf.io.gfile.glob(gcs_train_pattern)
gcs_eval_files = tf.io.gfile.glob(gcs_eval_pattern)

# Steps are derived from the file count, assuming every TFRecord file holds exactly
# TFRecord_batch_size examples (set in the configuration cell).
steps_per_epoch = len(gcs_train_files) * TFRecord_batch_size // batch_size
validation_steps = len(gcs_eval_files) * TFRecord_batch_size // batch_size
print(f"{steps_per_epoch} steps per epoch and {validation_steps} validation steps.")

# Fail fast: zero steps (no files matched, or batch_size larger than the data) would
# otherwise make model.fit fail obscurely or train on nothing.
if steps_per_epoch == 0 or validation_steps == 0:
    raise ValueError(
        f"Not enough data: {len(gcs_train_files)} training files matching {gcs_train_pattern} "
        f"and {len(gcs_eval_files)} evaluation files matching {gcs_eval_pattern} "
        f"for batch_size={batch_size}."
    )

# Create TFRecord dataset from GCS paths
train_dataset = tf.data.TFRecordDataset(gcs_train_files)
eval_dataset = tf.data.TFRecordDataset(gcs_eval_files)

# Parse and shuffle. num_parallel_calls=AUTOTUNE is left out in an attempt to reduce memory use.
train_dataset = train_dataset.map(parse_tfrecord_fn).shuffle(buffer_size=1000)
# Batch (dropping any incomplete final batch) and repeat indefinitely;
# steps_per_epoch bounds each epoch.
train_dataset = train_dataset.padded_batch(batch_size, drop_remainder=True).repeat()

# Same treatment for validation; validation_steps bounds each evaluation pass.
val_dataset = eval_dataset.map(parse_tfrecord_fn).shuffle(buffer_size=1000)
val_dataset = val_dataset.padded_batch(batch_size, drop_remainder=True).repeat()

# For a quick smoke test, uncomment to train/validate on a handful of batches:
# train_dataset = train_dataset.take(8)
# val_dataset = val_dataset.take(2)

assert isinstance(train_dataset, tf.data.Dataset), "train_dataset should be a tf.data.Dataset"
assert isinstance(val_dataset, tf.data.Dataset), "val_dataset should be a tf.data.Dataset"

def update_filepath(model_filepath, epoch):
    """Return `model_filepath` tagged with an epoch count.

    If the path already contains `_epoch<N>.keras`, `epoch` is added to N
    (e.g. `run_epoch10.keras` with 5 -> `run_epoch15.keras`). Otherwise `_epoch<epoch>`
    is inserted before a trailing `.keras` extension, so the final save goes to a new file
    instead of overwriting the checkpoint that shares the original name. Paths that do
    not end in `.keras` are returned unchanged.
    """
    pattern = r'_epoch(\d+)\.keras'
    if re.search(pattern, model_filepath):
        return re.sub(pattern, lambda match: f'_epoch{int(match.group(1)) + epoch}.keras', model_filepath)
    if model_filepath.endswith('.keras'):
        return f"{model_filepath[:-len('.keras')]}_epoch{epoch}.keras"
    return model_filepath

def convert_np_floats(obj):
    """`json.dump(..., default=...)` hook converting NumPy values to built-in types.

    NumPy floating/integer scalars become `float`/`int`, and arrays become (nested) lists.

    Raises:
        TypeError: for any other object, which is what `json` expects from a `default`
            hook. Returning the object unchanged instead makes `json` fail with a
            misleading "Circular reference detected" error.
    """
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

# Define the ModelCheckpoint callback: keep the model with the lowest validation
# categorical cross-entropy, either in model_filepath itself or in per-epoch files.
if overwrite_checkpoints:
    checkpoint_filepath = model_filepath
else:
    checkpoint_filepath = model_filepath.replace('.keras', '-{epoch:02d}-{val_categorical_crossentropy:.2f}_checkpoint.keras')
checkpoint_callback = ModelCheckpoint(
    filepath=checkpoint_filepath,
    monitor='val_categorical_crossentropy',
    save_best_only=True,
    save_weights_only=False,
    mode='auto',
    verbose=verbose,  # honour the verbose_callbacks form setting
)

# Define the Early Stopping callback
# NOTE(review): with Keras 2.x, restore_best_weights only takes effect when training
# actually stops early; after a full run the in-memory model holds the final-epoch weights
# (the best ones are in the checkpoint file).
early_stopping = EarlyStopping(
    monitor='val_categorical_crossentropy',
    patience=16,
    restore_best_weights=True,
    verbose=verbose,
)

# Define the Learning Rate reduction callback
reduce_lr = ReduceLROnPlateau(
    monitor='val_categorical_crossentropy',
    factor=0.4,
    patience=5,
    min_lr=1e-9,
    verbose=verbose,
)

# Train the model (the datasets repeat indefinitely, so the step counts bound each epoch;
# shuffling is done in the tf.data pipeline, as fit's `shuffle` is ignored for datasets)
history = model.fit(
    train_dataset,
    epochs=epochs,
    initial_epoch=restart_at_epoch,
    steps_per_epoch=steps_per_epoch,
    validation_data=val_dataset,
    validation_steps=validation_steps,
    callbacks=[checkpoint_callback, reduce_lr, early_stopping],
)

if early_stopping.stopped_epoch > 0:
    model_filepath = update_filepath(model_filepath, early_stopping.stopped_epoch)
    print(f"Training stopped at epoch {early_stopping.stopped_epoch} due to early stopping.")
else:
    model_filepath = update_filepath(model_filepath, epochs)
    print("Training completed all epochs.")

try:
    model.save(model_filepath)
    print(f"Model saved successfully at: {model_filepath}")
except Exception as e:
    print(f"Error saving the model: {e}")

# Save the complete training history as a JSON file next to the model
history_filepath = model_filepath.replace('.keras', '.history.json')
try:
    with open(history_filepath, 'w') as json_file:
        json.dump(history.history, json_file, default=convert_np_floats)
    print(f"History saved successfully at: {history_filepath}")
except Exception as e:
    print(f"Error saving the history: {e}")




421 steps per epoch and 75 validation steps.
Epoch 1/150
Epoch 2/150
Epoch 3/150
Epoch 4/150
Epoch 5/150

In [None]:
#@title Plot history

import json
import os
import re
import matplotlib.pyplot as plt
from matplotlib import colormaps
from sklearn.preprocessing import MinMaxScaler

import numpy as np
from scipy.optimize import curve_fit

def load_most_recent_history(directory):
    """Load the most recently written `*.history.json` file in `directory`.

    Files are ranked by modification time (`os.path.getmtime`), consistent with how the
    model loaders pick the latest model. `getctime` is deliberately not used: on Linux it
    is the inode-change time (bumped by renames/permission changes), not the creation time.

    Returns:
        The decoded JSON history (metric name -> per-epoch values, as saved by the training
        cell), or None after printing a message if the directory has no history files.
    """
    history_files = [f for f in os.listdir(directory) if f.endswith('.history.json')]

    if not history_files:
        print("No history files found in the directory.")
        return None

    most_recent_file = max(history_files, key=lambda name: os.path.getmtime(os.path.join(directory, name)))

    with open(os.path.join(directory, most_recent_file), 'r') as file:
        return json.load(file)

def plot_best_fit_curves(history):
    """Plot each training metric's fitted trend against its validation curve.

    Every metric other than 'loss' and 'lr' (and its `val_` counterpart) is min-max
    normalised to [0, 1], so different metrics share one axis. Training values are
    summarised by a least-squares fit of a * exp(-b * x) + c (dotted line), falling back
    to the raw normalised values if the fit fails. Validation values are plotted as-is
    (dashed line). Non-finite entries are dropped together with their epoch numbers, so
    x and y always have the same length.

    Args:
        history: mapping of metric name -> list of per-epoch values, as saved from
            `model.fit(...).history`.
    """
    plt.figure(figsize=(10, 6))

    def fit_function(x, a, b, c):
        # Exponential decay towards an asymptote c.
        return a * np.exp(-b * x) + c

    scaler = MinMaxScaler()
    metrics = [metric for metric in history if not metric.startswith('val_') and metric not in ('loss', 'lr')]
    colors = colormaps.get_cmap('Dark2')

    def finite_normalised(metric):
        """Return (epochs, normalised values) for the finite entries of `metric`, or None."""
        values = np.asarray(history[metric], dtype=float)
        epochs = np.arange(1, len(values) + 1)
        mask_finite = np.isfinite(values)
        if not mask_finite.any():
            return None
        normalized_values = scaler.fit_transform(values[mask_finite].reshape(-1, 1)).ravel()
        return epochs[mask_finite], normalized_values

    for index, metric in enumerate(metrics):
        color = colors(index / len(metrics))

        training = finite_normalised(metric)
        if training is None:
            continue  # Skip the metric if there are no valid values
        epochs, normalized_values = training
        try:
            popt, _ = curve_fit(fit_function, epochs, normalized_values, maxfev=10000)
            plt.plot(epochs, fit_function(epochs, *popt), linestyle=':', label=f'Training {metric} (best fit)', color=color)
        except (RuntimeError, TypeError):
            # curve_fit raises RuntimeError when it does not converge and TypeError when
            # there are fewer points than parameters; show the raw curve instead.
            plt.plot(epochs, normalized_values, linestyle=':', label=f'Training {metric}', color=color)

        val_metric = 'val_' + metric
        if val_metric in history:
            validation = finite_normalised(val_metric)
            if validation is not None:
                plt.plot(*validation, linestyle='--', label=f'Validation {metric}', color=color)

    plt.title('Training history (normalised metrics)')
    plt.xlabel('Epoch')
    plt.ylabel('Normalized Metric Value')
    plt.legend()
    plt.show()

# Load the most recent history file
# (load_most_recent_history prints a message and returns None when there is none)
most_recent_history = load_most_recent_history(model_directory)

if most_recent_history is not None:
    # Plot the best-fit curves
    plot_best_fit_curves(most_recent_history)


In [None]:
#@title Load Trained Model from Last Checkpoint
# File name inside model_directory; leave empty to pick the most recent saved checkpoint.
model_filename = "" # @param {type:"string"}

import importlib
import tensorflow as tf

import os
import sys
import json
from tensorflow.keras.models import load_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import Loss, CategoricalCrossentropy, CategoricalFocalCrossentropy
import cv2
import numpy as np
import math
from osgeo import gdal
from osgeo import osr
from tqdm.notebook import tqdm
import contextlib
import shutil

# Options used by this cell: load_class_weights (below) reads reduce_classes.
reduce_classes = False
ignore_ilastik = False

# Class names, one per line of label_strings_file, without surrounding whitespace.
with open(label_strings_file, 'r') as labels_handle:
    label_strings = [raw_line.strip() for raw_line in labels_handle]

def load_class_weights(model_directory):
    """Read per-class loss weights from `<model_directory>/class_weights.json`.

    The JSON maps class index (string keys) to weight. The weights for classes 0-4 are
    returned as a float32 array. When the module-level `reduce_classes` is set, they are
    merged into two weights: background (classes 0, 3, 4) and roads (classes 1, 2). Each
    merged weight is the reciprocal of the summed reciprocals of its members.

    Raises:
        ValueError: if the JSON file does not exist.
    """
    weights_path = os.path.join(model_directory, 'class_weights.json')
    if not os.path.exists(weights_path):
        raise ValueError("Class weights not found. Cannot proceed without class weights.")

    with open(weights_path, 'r') as json_file:
        # JSON object keys are strings; convert them back to integer class ids.
        weights_by_class = {int(class_id): weight for class_id, weight in json.load(json_file).items()}

        num_classes = 5
        weights = np.array([weights_by_class[class_id] for class_id in range(num_classes)], dtype=np.float32)
        print(f"class_weights_list: {weights_by_class}")

        if reduce_classes:
            reciprocals = 1 / weights
            weights = np.array(
                [
                    1 / (reciprocals[0] + reciprocals[3] + reciprocals[4]),  # background
                    1 / (reciprocals[1] + reciprocals[2]),  # roads
                ],
                dtype=np.float32,
            )
            print(f"Adjusted Class Weights: {weights}")

        return weights

# Module-level so the compile step below can pass them to the focal loss as `alpha`.
class_weights_list = load_class_weights(model_directory)

def load_trained_model(model_filename="", model_directory=model_directory):
    """Load a saved Keras model, uncompiled, from `model_directory`.

    Args:
        model_filename: file name inside `model_directory`. If empty (or None), the most
            recently modified `*_checkpoint.keras` file is used. If there is none, the most
            recently modified `*.keras` file is used instead: checkpoints written with
            overwrite_checkpoints=True carry no `_checkpoint` suffix.
        model_directory: directory to search; defaults to the configured model directory.

    Returns:
        The loaded model, loaded with `compile=False` so the caller must compile it.

    Raises:
        FileNotFoundError: if `model_directory` holds no `*.keras` file.
        Exception: if Keras fails to load the file (chained to the original error).
    """
    if not model_filename:
        saved_models = [f for f in os.listdir(model_directory) if f.endswith('.keras')]
        checkpoints = [f for f in saved_models if f.endswith('_checkpoint.keras')]
        candidates = checkpoints or saved_models
        if not candidates:
            raise FileNotFoundError("No saved models found in the specified directory.")

        most_recent_model = max(candidates, key=lambda name: os.path.getmtime(os.path.join(model_directory, name)))
        model_path = os.path.join(model_directory, most_recent_model)
    else:
        # If a model filename is provided, use that
        model_path = os.path.join(model_directory, model_filename)

    print(f"Loading model from: {model_path}")
    try:
        model = load_model(model_path, compile=False)
    except Exception as e:
        raise Exception(f"Error loading the model: {str(e)}") from e
    print(f"... Loaded model from: {model_path}")
    return model

# The model is loaded with compile=False, so recompile it with the same optimizer, loss and
# metrics as the model cell before any further training or evaluation.
# NOTE(review): optimizer state is not restored, because a fresh Adam instance is created here.
model = load_trained_model(model_filename)
model.summary()
model.compile(
    optimizer=Adam(learning_rate = initial_learning_rate),
    loss=CategoricalFocalCrossentropy(alpha=class_weights_list),
    metrics=['categorical_accuracy', 'categorical_crossentropy']
)

In [None]:
#@title Inference

import importlib
import tensorflow as tf

# Check if Rasterio is installed; install it only when the import fails, then import it.
try:
    import rasterio
except ImportError:
    !pip install rasterio
    import rasterio

import os
import sys
from tensorflow.keras.models import load_model
import cv2
import numpy as np
import math
from osgeo import gdal
from osgeo import osr
from tqdm.notebook import tqdm
import contextlib
import shutil
from PIL import Image
import matplotlib.pyplot as plt

# When False, maps that already have a raster label file are skipped by process_map.
include_labelled = False
# When True, only the first 4 and last 5 channels of each augmented map are fed to the model,
# so the loaded model must accept that many input channels.
ignore_ilastik = False # @param {type:"boolean"}

def load_trained_model(model_filename=None, model_directory=model_directory):
    """Load a saved Keras model, uncompiled, for inference.

    NOTE(review): this re-defines `load_trained_model` from the "Load Trained Model" cell;
    consider keeping a single definition.

    Args:
        model_filename: file name inside `model_directory`. If None or empty, the most
            recently modified `*.keras` file in `model_directory` is used.
        model_directory: directory to search; defaults to the configured model directory.

    Returns:
        The loaded model (loaded with `compile=False`).

    Raises:
        FileNotFoundError: if `model_directory` holds no `*.keras` file.
        Exception: if Keras fails to load the file (chained to the original error).
    """
    if not model_filename:
        # "" is treated like None: a blank Colab string form parameter yields "".
        model_files = [f for f in os.listdir(model_directory) if f.endswith('.keras')]
        if not model_files:
            raise FileNotFoundError("No saved models found in the specified directory.")

        most_recent_model = max(model_files, key=lambda name: os.path.getmtime(os.path.join(model_directory, name)))
        model_path = os.path.join(model_directory, most_recent_model)
    else:
        # If a model filename is provided, use that
        model_path = os.path.join(model_directory, model_filename)

    try:
        model = load_model(model_path, compile=False)
    except Exception as e:
        raise Exception(f"Error loading the model: {str(e)}") from e
    print(f"Loaded model from: {model_path}")
    return model

# Load model if necessary. At notebook top level locals() is the global namespace, so a
# `model` left by an earlier cell is reused; only a fresh kernel loads the latest saved model.
# NOTE(review): make sure any existing `model` is the one you intend to run inference with.
if 'model' not in locals():
    model = load_trained_model()

def post_process(predictions, colors=None):
    """Convert per-pixel class scores into an RGBA colour mask.

    Args:
        predictions: array of shape (height, width, n_classes). Each pixel is assigned its
            highest-scoring class (argmax over the last axis).
        colors: mapping of class id -> RGBA tuple. Defaults to the module-level
            `class_colors`. Classes without an entry stay (0, 0, 0, 0).

    Returns:
        uint8 array of shape (height, width, 4).
    """
    if colors is None:
        colors = class_colors

    class_ids = np.argmax(predictions, axis=-1)

    # One lookup through a per-class palette instead of one boolean-mask pass per class.
    num_prediction_classes = predictions.shape[-1]
    palette = np.zeros((num_prediction_classes, 4), dtype=np.uint8)
    for class_id, color in colors.items():
        if 0 <= class_id < num_prediction_classes:  # ids outside argmax's range can never match
            palette[class_id] = color

    return palette[class_ids]

def calculate_overlaps(map_height, map_width, tile_size, min_overlap):
    """Work out how many tiles cover each axis and how much neighbouring tiles overlap.

    For each dimension, the count is the smallest number of `tile_size` tiles that spans it
    while overlapping by at least `min_overlap`. The spare length is then spread evenly, so
    the first tile starts at 0 and the last tile ends exactly at the map edge.

    Returns:
        (horizontal_count, horizontal_overlap, vertical_count, vertical_overlap). Overlaps
        are in pixels (floats). A dimension that fits in a single tile gets a count of 1 and
        an overlap of 0.0.
    """
    def _count_and_overlap(length):
        # Smallest n with n * tile_size - (n - 1) * min_overlap >= length (at least 1).
        count = max(math.ceil((length - min_overlap) / (tile_size - min_overlap)), 1)
        if count == 1:
            return count, 0.0  # a single tile has no neighbour to overlap
        return count, (tile_size * count - length) / (count - 1)

    horizontal_count, horizontal_overlap = _count_and_overlap(map_width)
    vertical_count, vertical_overlap = _count_and_overlap(map_height)

    return horizontal_count, horizontal_overlap, vertical_count, vertical_overlap

def perform_sliding_window_inference(map_name):
    """Classify a whole map by predicting overlapping tiles and stitching the results.

    Loads `{map_augmented_s1_directory}{map_name}.augmented_s1.npy` and slides a
    `tile_size` window across it, positioned by `calculate_overlaps`. Each tile is
    predicted with the global `model` and coloured with `post_process`. Where a tile
    overlaps a neighbour, `min_overlap // 2` pixels are trimmed from that side to discard
    the less reliable border predictions. Sides lying on the map edge are kept, otherwise
    the outermost pixels of the map would never be filled.

    Returns:
        uint8 RGBA array of shape (map_height, map_width, 4).
    """
    augmented_map = np.load(f"{map_augmented_s1_directory}{map_name}.augmented_s1.npy")
    if ignore_ilastik:
        # Keep only the first 4 and last 5 channels.
        augmented_map = np.concatenate([augmented_map[:, :, :4], augmented_map[:, :, -5:]], axis=-1)

    map_height, map_width = augmented_map.shape[:2]

    # Calculate the number of tiles and overlaps
    horizontal_count, horizontal_overlap, vertical_count, vertical_overlap = calculate_overlaps(map_height, map_width, tile_size, min_overlap)

    result_mask = np.zeros((map_height, map_width, 4), dtype=np.uint8)  # 4 channels for RGBA
    trim_size = min_overlap // 2

    # The last tile on each axis is clamped so float rounding cannot make it stop short of the edge.
    last_x_start = max(map_width - tile_size, 0)
    last_y_start = max(map_height - tile_size, 0)

    with tqdm(total=horizontal_count * vertical_count, dynamic_ncols=True, desc=f"Processing {map_name}", position=0, leave=True) as patches:
        for h in range(horizontal_count):
            for v in range(vertical_count):

                x_start = min(round(h * (tile_size - horizontal_overlap)), last_x_start)
                y_start = min(round(v * (tile_size - vertical_overlap)), last_y_start)
                x_end = min(x_start + tile_size, map_width)
                y_end = min(y_start + tile_size, map_height)

                tile = augmented_map[y_start:y_end, x_start:x_end]
                predictions = model.predict(np.expand_dims(tile, axis=0), verbose=0)
                color_mask = post_process(predictions[0])

                # Trim only the sides shared with a neighbouring tile (offsets relative to the tile).
                top = trim_size if y_start > 0 else 0
                left = trim_size if x_start > 0 else 0
                bottom = (y_end - y_start) - (trim_size if y_end < map_height else 0)
                right = (x_end - x_start) - (trim_size if x_end < map_width else 0)

                result_mask[y_start + top:y_start + bottom, x_start + left:x_start + right] = color_mask[top:bottom, left:right]

                patches.update(1)

    return result_mask

# Save georeferenced outputs
def save_outputs(result_mask, map_path, map_name):
    """Write one map's inference mask as a mask PNG, an overlay PNG and a GeoTIFF.

    Outputs (any pre-existing copies are deleted first):
      - ``{map_mask_directory}{map_name}.png``: the RGBA mask itself;
      - ``{map_overlay_directory}{map_name}.png``: the mask alpha-composited
        over the original map image (also shown inline via ``display``);
      - ``{map_geotiff_directory}{map_name}.tif``: the mask as a 4-band,
        LZW-compressed GeoTIFF with nodata=0, using the source map's CRS and
        transform.
    Both PNGs receive a ``.aux.xml`` sidecar copied from the georeferencing
    written for the mask PNG.

    Args:
        result_mask: RGBA mask array as built by perform_sliding_window_inference,
            i.e. (map_height, map_width, 4) uint8.
        map_path: Path to the source map image (opened with rasterio and PIL).
        map_name: Base name used for all output files.
    """
    mask_output_path = f"{map_mask_directory}{map_name}.png"
    overlay_output_path = f"{map_overlay_directory}{map_name}.png"
    geotiff_output_path = f"{map_geotiff_directory}{map_name}.tif"

    # Remove stale outputs from a previous run
    for path in (mask_output_path, overlay_output_path, geotiff_output_path):
        if os.path.exists(path):
            os.remove(path)

    with rasterio.open(map_path) as src:
        transform = src.transform
        crs = src.crs

        # Writing the source raster through rasterio's PNG driver is relied upon to
        # produce the georeferencing sidecar "{mask_output_path}.aux.xml". The PNG
        # pixels written here are placeholders and are overwritten by the mask below.
        with rasterio.open(mask_output_path, 'w', driver='PNG', width=src.width, height=src.height, count=src.count, dtype=src.dtypes[0], crs=crs, transform=transform) as dst:
            dst.write(src.read())

        # The overlay PNG shares the same pixel grid, so it can reuse the sidecar
        shutil.copyfile(f"{mask_output_path}.aux.xml", f"{overlay_output_path}.aux.xml")

        # Save the result_mask as a PNG image (replaces the placeholder pixels)
        mask = Image.fromarray(result_mask)
        mask.save(mask_output_path)

        # Composite the mask over the original map. The context manager closes
        # the source image's file handle once the composite has been built.
        with Image.open(map_path) as map_image:
            overlay = Image.alpha_composite(map_image.convert("RGBA"), mask)
        overlay.save(overlay_output_path, "PNG")

        display(overlay)

        # GeoTIFF wants bands-first (4, H, W); its dtype comes from the mask being
        # written rather than from the source map, so the two always agree.
        with rasterio.open(geotiff_output_path, 'w', nodata=0, driver='GTiff', width=src.width, height=src.height, count=4, dtype=result_mask.dtype.name, crs=crs, transform=transform, compress='lzw') as dst_gt:
            dst_gt.write(result_mask.transpose(2, 0, 1))

def process_map(jpg_filename):
    """Run sliding-window inference on one map and save its outputs.

    A map that already has a raster label file
    (``{labels_raster_directory}{map_name}.label.npy``) is skipped unless the
    global ``include_labelled`` flag is true.

    Args:
        jpg_filename: Name of a ``.jpg`` file inside ``map_directory``.

    Returns:
        dict with ``'map_name'`` (the filename minus its trailing ``.jpg``) and
        ``'processed'`` (False if the map was skipped, True otherwise).
    """
    # Strip only the trailing extension; str.replace('.jpg', '') would also
    # delete any '.jpg' appearing earlier in the name.
    suffix = '.jpg'
    map_name = jpg_filename[:-len(suffix)] if jpg_filename.endswith(suffix) else jpg_filename

    label_path = f"{labels_raster_directory}{map_name}.label.npy"
    if os.path.exists(label_path) and not include_labelled:
        return {'map_name': map_name, 'processed': False}

    map_path = os.path.join(map_directory, jpg_filename)

    # Perform sliding window inference
    result_mask = perform_sliding_window_inference(map_name)

    # Save outputs
    save_outputs(result_mask, map_path, map_name)

    return {'map_name': map_name, 'processed': True}

# Run inference over every .jpg map. os.listdir returns entries in arbitrary
# order, so sort them to make processing (and the printed report) reproducible.
map_filenames = sorted(name for name in os.listdir(map_directory) if name.endswith('.jpg'))
inference = [process_map(jpg_filename) for jpg_filename in map_filenames]

# Maps that were skipped because a raster label already exists
skipped_maps = [result for result in inference if not result['processed']]
if skipped_maps:
    print("Skipped the following pre-labelled maps:")
    for skipped_map in skipped_maps:
        print(skipped_map['map_name'])


In [None]:
#@title Post-Processing
# Placeholder cell: no logic is implemented here yet. The string literal below
# is a list of planned next steps (TODOs). Because it is the cell's only and
# last expression, Jupyter displays it as the cell's output when run.
'''
Convert Pixel-Level Masks to Vector Representations
Develop code to skeletonise to junction points
Add code to post-process the pixel-level masks and convert them to vector representations

Step 7: Model Evaluation

Assess Model Performance
Add code to evaluate the model using metrics and visual inspection

Step 8: Model Refinement

Fine-Tune the Model
Add code to fine-tune the model based on evaluation results
'''