In [None]:
!pip install opencv-python-headless numpy pandas tensorflow keras matplotlib scikit-learn

In [None]:
import re
import os
import cv2
import shutil
import random
import numpy as np
import pandas as pd
import tensorflow as tf
from keras import layers
from google.colab import drive
from keras.models import Model
from keras import backend as K
import matplotlib.pyplot as plt
from keras.preprocessing import image
from keras.utils import to_categorical
from keras.applications import ResNet50
from keras.callbacks import ModelCheckpoint
from tensorflow.keras.optimizers.legacy import Adam
from sklearn.model_selection import train_test_split
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.resnet50 import preprocess_input
from keras.layers import Input, Conv2D, Concatenate, UpSampling2D, Conv2DTranspose

In [None]:
# Helper functions for data processing
def one_hot_encode(label, label_values):
    """Turn a colour-coded label image into a stack of per-class binary planes.

    Args:
        label: label image as a NumPy array. A 2-D (single-channel) array is
            first replicated into three identical channels.
        label_values: iterable of colours; each one is compared element-wise
            against the last axis of the (possibly replicated) label.

    Returns:
        Float array with one plane per entry of ``label_values`` stacked on a
        new last axis (``(H, W, K)`` for an ``(H, W)`` or ``(H, W, C)`` label).
        Plane ``k`` is 1.0 where every channel equals ``label_values[k]`` and
        0.0 elsewhere.
    """
    rgb_label = np.stack([label] * 3, axis=-1) if len(label.shape) == 2 else label
    class_planes = [np.all(np.equal(rgb_label, colour), axis=-1) for colour in label_values]
    return np.stack(class_planes, axis=-1).astype('float')

def resize_mask(mask, target_size):
    """Resize a label mask without inventing new label values.

    Nearest-neighbour interpolation copies an existing input pixel into every
    output pixel, so class colours stay exact (blending interpolation would
    create colours that match no class).

    Args:
        mask: mask image as a NumPy array.
        target_size: forwarded to ``cv2.resize`` as ``dsize`` (OpenCV's
            ``(width, height)`` convention).

    Returns:
        The resized mask.
    """
    return cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST)

def reverse_one_hot(image):
    """Collapse per-class scores on the last axis into a class-index map.

    Args:
        image: array-like whose last axis holds one score per class.

    Returns:
        Integer array with the last axis removed, holding the index of the
        largest entry along it (the first index wins on ties).
    """
    return np.argmax(image, axis=-1)

def colour_code_segmentation(image, label_values):
    """Map a class-index image back to display colours.

    Args:
        image: NumPy array of class indices (cast to ``int`` before lookup).
        label_values: sequence of colours, one per class index.

    Returns:
        Array where each index is replaced by its colour from
        ``label_values``; shape is ``image.shape`` plus the colour shape.
    """
    palette = np.array(label_values)
    class_indices = image.astype(int)
    return palette[class_indices]

In [None]:
# ResNet-based UNet model
def build_resnet_unet(input_shape, num_classes):
    """Assemble a U-Net style segmentation network on a frozen ResNet50 encoder.

    The ImageNet-pretrained ResNet50 encoder is frozen, so only the decoder is
    trainable. Starting from ``conv4_block6_out``, four decoder stages each
    upsample by 2, concatenate a matching encoder skip connection (none for
    the last stage) and apply two 3x3 ReLU convolutions. A final 1x1 sigmoid
    convolution produces one channel per class.

    NOTE(review): the skip concatenations presumably require input
    height/width compatible with ResNet50's downsampling (the notebook uses
    256x256) — confirm before using other sizes.

    Args:
        input_shape: shape of one input image, passed to ``Input(shape=...)``.
        num_classes: number of sigmoid output channels.

    Returns:
        An uncompiled ``keras.Model``.
    """
    input_tensor = Input(shape=input_shape)
    encoder = ResNet50(weights='imagenet', include_top=False, input_tensor=input_tensor)

    # Keep the pretrained encoder weights fixed during training.
    for encoder_layer in encoder.layers:
        encoder_layer.trainable = False

    # (skip connection, conv filters) per decoder stage, deepest first.
    # Layer creation order matches the original hand-unrolled version.
    decoder_plan = [
        (encoder.get_layer('conv3_block4_out').output, 1024),
        (encoder.get_layer('conv2_block3_out').output, 512),
        (encoder.get_layer('conv1_relu').output, 256),
        (None, 128),
    ]

    features = encoder.get_layer('conv4_block6_out').output
    for skip, filters in decoder_plan:
        features = UpSampling2D(size=(2, 2))(features)
        if skip is not None:
            features = Concatenate()([features, skip])
        for _ in range(2):
            features = Conv2D(filters, (3, 3), padding='same', activation='relu')(features)

    # 1x1 projection to one sigmoid channel per class.
    outputs = Conv2D(num_classes, (1, 1), activation='sigmoid')(features)

    # Wrap the graph in a Model (compilation is left to the caller).
    return Model(inputs=input_tensor, outputs=outputs)


In [None]:
def _read_rgb_image(path):
    """Read an image file with OpenCV and return it in RGB channel order.

    ``cv2.imread`` yields BGR data and returns ``None`` (instead of raising)
    when a file is missing or unreadable; both are handled here.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``.
    """
    bgr = cv2.imread(path, cv2.IMREAD_COLOR)
    if bgr is None:
        raise FileNotFoundError(f"OpenCV could not read image: {path}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


def data_generator(image_paths, mask_paths, batch_size, input_shape,
                   image_dir=None, mask_dir=None, label_values=None, shuffle=True):
    """Yield ``(images, masks)`` batches forever, for ``model.fit`` / ``evaluate``.

    Every pass visits each sample exactly once (reshuffled at the start of the
    pass when ``shuffle`` is true); the last batch of a pass may be smaller
    than ``batch_size``.

    Args:
        image_paths: image file names, relative to ``image_dir``.
        mask_paths: mask file names, relative to ``mask_dir``;
            ``mask_paths[i]`` is the mask for ``image_paths[i]``.
        batch_size: maximum number of samples per batch (>= 1).
        input_shape: model input shape ``(height, width, channels)``; images
            and masks are resized to ``height x width``.
        image_dir: folder holding the images. Defaults to the notebook global
            ``x_train_dir`` (what the original version read), else
            ``TRAIN_IMG_DIR``.
        mask_dir: folder holding the masks. Defaults to ``y_train_dir``, else
            ``TRAIN_MASK_DIR``.
        label_values: RGB colours of the classes to one-hot encode. Defaults
            to the notebook global ``select_class_rgb_values``.
        shuffle: whether to reshuffle the sample order on every pass.

    Yields:
        ``(images, masks)``: RGB images passed through ResNet50
        ``preprocess_input``, and float one-hot masks with one channel per
        entry of ``label_values``.

    Raises (on the first ``next()``, since this is a generator):
        ValueError: mismatched/empty path lists, ``batch_size < 1``, or
            directories/labels that are neither passed nor defined globally.
        FileNotFoundError: if an image or mask file cannot be read.
    """
    if len(image_paths) != len(mask_paths):
        raise ValueError(
            f"image_paths ({len(image_paths)}) and mask_paths ({len(mask_paths)}) "
            "must have the same length"
        )
    num_samples = len(image_paths)
    if num_samples == 0:
        # Otherwise the `while True` loop below would spin forever without yielding.
        raise ValueError("data_generator needs at least one sample")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Backward-compatible defaults: the original implementation read globals.
    notebook_globals = globals()
    if image_dir is None:
        image_dir = notebook_globals.get('x_train_dir', notebook_globals.get('TRAIN_IMG_DIR'))
    if mask_dir is None:
        mask_dir = notebook_globals.get('y_train_dir', notebook_globals.get('TRAIN_MASK_DIR'))
    if label_values is None:
        label_values = notebook_globals.get('select_class_rgb_values')
    if image_dir is None or mask_dir is None or label_values is None:
        raise ValueError(
            "image_dir, mask_dir and label_values must be passed explicitly "
            "or defined as notebook globals"
        )

    # input_shape is (height, width, channels); OpenCV's dsize is (width, height).
    dsize = (int(input_shape[1]), int(input_shape[0]))
    order = list(range(num_samples))

    while True:
        if shuffle:
            random.shuffle(order)
        for start in range(0, num_samples, batch_size):
            batch_images = []
            batch_masks = []
            for idx in order[start:start + batch_size]:
                # preprocess_input expects RGB input, hence the BGR->RGB conversion.
                photo = _read_rgb_image(os.path.join(image_dir, image_paths[idx]))
                batch_images.append(preprocess_input(cv2.resize(photo, dsize)))

                # Masks are converted to RGB to match the r,g,b class colours, and
                # resized with nearest-neighbour so colours stay exact.
                mask = _read_rgb_image(os.path.join(mask_dir, mask_paths[idx]))
                mask = resize_mask(mask, dsize)
                batch_masks.append(one_hot_encode(mask, label_values))

            yield np.stack(batch_images, axis=0), np.stack(batch_masks, axis=0)

In [None]:
if __name__ == "__main__":
    # Script entry point: configure paths, split the data, build or restore the
    # model, optionally train and export it, then evaluate on the test split and
    # save Image / Ground Truth / Predicted figures.

    # ============================
    # --------- CONFIG -----------
    # ============================

    # Mount Google Drive.
    # NOTE(review): ROOT_DIR is '' so every path below is relative to the current
    # working directory, not to /content/gdrive — confirm the notebook runs from
    # inside the mounted folder, or point ROOT_DIR at it.
    drive.mount('/content/gdrive')

    ROOT_DIR = ''                                # Root path
    DATASET_DIR = os.path.join(ROOT_DIR, 'combined-dataset/tiff/')
    LABEL_DICT_PATH = os.path.join(ROOT_DIR, 'combined-dataset/label_class_dict.csv')

    # Model configuration
    TF_LITE_MODEL_NAME = 'map_buildings.tflite'
    SHOULD_TRAIN = False
    INPUT_SHAPE = (256, 256, 3)
    EPOCHS = 50
    LEARNING_RATE = 1e-4

    # A fixed seed keeps the train/val/test split identical across sessions, so a
    # model restored from a checkpoint is never evaluated on images it trained on.
    SEED = 42
    random.seed(SEED)
    np.random.seed(SEED)
    tf.random.set_seed(SEED)

    # Output folders
    MODEL_CKPT_DIR = os.path.join(ROOT_DIR, 'output/model_checkpoints')
    SAVED_MODEL_DIR = os.path.join(ROOT_DIR, 'output/saved_model')
    INFERENCE_OUTPUT_DIR = os.path.join(ROOT_DIR, 'output/inference_images')

    # Dataset folders
    TRAIN_IMG_DIR = os.path.join(DATASET_DIR, 'train')
    TRAIN_MASK_DIR = os.path.join(DATASET_DIR, 'train_labels')

    # NOTE(review): the VAL_* / TEST_* folders are never read — validation and
    # test sets are carved out of the train folders below. Confirm this is intended.
    VAL_IMG_DIR = os.path.join(DATASET_DIR, 'val')
    VAL_MASK_DIR = os.path.join(DATASET_DIR, 'val_labels')

    TEST_IMG_DIR = os.path.join(DATASET_DIR, 'test')
    TEST_MASK_DIR = os.path.join(DATASET_DIR, 'test_labels')

    # data_generator() resolves files against these global names when no
    # directories are passed in; every split below comes from the train folders.
    x_train_dir = TRAIN_IMG_DIR
    y_train_dir = TRAIN_MASK_DIR

    # ============================
    # ----- LOAD CLASS INFO ------
    # ============================

    class_dict = pd.read_csv(LABEL_DICT_PATH)
    class_names = class_dict['name'].tolist()
    class_rgb_values = class_dict[['r', 'g', 'b']].values.tolist()

    select_classes = ['building']
    select_class_indices = [class_names.index(cls.lower()) for cls in select_classes]
    select_class_rgb_values = np.array(class_rgb_values)[select_class_indices]
    num_classes = len(select_classes)

    # ============================
    # ----- TRAIN/VAL/TEST SPLIT -
    # ============================

    train_image_paths = sorted(os.listdir(TRAIN_IMG_DIR))
    train_mask_paths = sorted(os.listdir(TRAIN_MASK_DIR))

    # Images and masks are paired purely by sorted file-name order, so both
    # folders must hold the same number of files.
    if len(train_image_paths) != len(train_mask_paths):
        raise ValueError(
            f"{len(train_image_paths)} images but {len(train_mask_paths)} masks in "
            f"{TRAIN_IMG_DIR!r} / {TRAIN_MASK_DIR!r}"
        )

    # 80 % train, then the remaining 20 % split evenly into validation and test.
    train_image_paths, val_image_paths, train_mask_paths, val_mask_paths = train_test_split(
        train_image_paths, train_mask_paths, test_size=0.2, random_state=SEED
    )

    val_image_paths, test_image_paths, val_mask_paths, test_mask_paths = train_test_split(
        val_image_paths, val_mask_paths, test_size=0.5, random_state=SEED
    )

    # ============================
    # ---- DATA GENERATORS -------
    # ============================

    batch_size = 4

    def _num_batches(num_samples):
        # Ceiling division: data_generator yields a final partial batch, so this
        # many steps covers every sample once (floor division skipped the
        # remainder and gave 0 steps for sets smaller than one batch).
        return -(-num_samples // batch_size)

    steps_per_epoch = _num_batches(len(train_image_paths))
    validation_steps = _num_batches(len(val_image_paths))

    train_generator = data_generator(train_image_paths, train_mask_paths, batch_size, INPUT_SHAPE)
    val_generator = data_generator(val_image_paths, val_mask_paths, batch_size, INPUT_SHAPE)

    # ============================
    # ----- CHECKPOINT SETUP -----
    # ============================

    # Make sure the folder exists so a fresh run does not crash on lookup.
    os.makedirs(MODEL_CKPT_DIR, exist_ok=True)

    checkpoint_path = os.path.join(MODEL_CKPT_DIR, 'cp-{epoch:04d}.ckpt')
    model_callback = ModelCheckpoint(
        checkpoint_path,
        save_weights_only=True,
        save_best_only=True,
        verbose=1
    )

    # ============================
    # ----- MODEL LOADING --------
    # ============================

    # Reads the 'checkpoint' index file written next to the weights; returns
    # None when no usable checkpoint exists.
    latest_checkpoint_path = tf.train.latest_checkpoint(MODEL_CKPT_DIR)

    if latest_checkpoint_path:
        model = build_resnet_unet(INPUT_SHAPE, num_classes)
        model.load_weights(latest_checkpoint_path)
        model.compile(optimizer=Adam(LEARNING_RATE), loss='binary_crossentropy', metrics=['accuracy'])

        print(f"Loaded weights from checkpoint: {latest_checkpoint_path}")

    elif os.path.exists(SAVED_MODEL_DIR):
        model = tf.keras.models.load_model(SAVED_MODEL_DIR)
        print("Loaded full TensorFlow SavedModel")

    else:
        model = build_resnet_unet(INPUT_SHAPE, num_classes)
        model.compile(optimizer=Adam(LEARNING_RATE), loss='binary_crossentropy', metrics=['accuracy'])
        print("Training new model..." if SHOULD_TRAIN else "Built new (untrained) model")

    # ============================
    # -------- TRAINING ----------
    # ============================

    if SHOULD_TRAIN:
        history = model.fit(
            train_generator,
            steps_per_epoch=steps_per_epoch,
            validation_data=val_generator,
            validation_steps=validation_steps,
            epochs=EPOCHS,
            callbacks=[model_callback]
        )

        # With save_best_only=True a checkpoint is written only when val_loss
        # improves, so the newest one holds this run's best weights. Export
        # those rather than whatever the final epoch left in memory.
        best_checkpoint_path = tf.train.latest_checkpoint(MODEL_CKPT_DIR)
        if best_checkpoint_path:
            model.load_weights(best_checkpoint_path)

        # Save TF SavedModel
        model.save(SAVED_MODEL_DIR)

        # Convert to TFLite
        converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_DIR)
        tflite_model = converter.convert()

        # Save TFLite model
        tflite_path = os.path.join(ROOT_DIR, 'output', TF_LITE_MODEL_NAME)
        with open(tflite_path, 'wb') as f:
            f.write(tflite_model)

        print(f"Saved TensorFlow Lite model: {tflite_path}")

    # ============================
    # -------- EVALUATION --------
    # ============================

    test_generator = data_generator(test_image_paths, test_mask_paths, batch_size, INPUT_SHAPE)
    test_steps = _num_batches(len(test_image_paths))

    test_loss, test_accuracy = model.evaluate(test_generator, steps=test_steps)
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_accuracy:.4f}")

    # ============================
    # -------- INFERENCE ---------
    # ============================

    os.makedirs(INFERENCE_OUTPUT_DIR, exist_ok=True)

    # Keras' ResNet50 preprocess_input ('caffe' mode) flips RGB->BGR and
    # subtracts this per-channel BGR mean. Undo it so the figure shows the
    # actual photo instead of the zero-centred model input.
    imagenet_bgr_mean = np.array([103.939, 116.779, 123.68])

    for idx, (image_batch, mask_batch) in enumerate(test_generator):
        if idx >= test_steps:
            break

        predicted_masks_batch = model.predict(image_batch)

        for i in range(len(image_batch)):
            display_image = np.clip(image_batch[i] + imagenet_bgr_mean, 0, 255)[..., ::-1].astype(np.uint8)
            true_mask = mask_batch[i]
            predicted_mask = (predicted_masks_batch[i] > 0.5).astype(float)

            fig, axes = plt.subplots(1, 3, figsize=(12, 4))
            panels = [
                (display_image, None, 'Image'),
                (true_mask.squeeze(), 'gray', 'Ground Truth'),
                (predicted_mask.squeeze(), 'gray', 'Predicted'),
            ]
            for ax, (data, cmap, title) in zip(axes, panels):
                ax.imshow(data, cmap=cmap)
                ax.set_title(title)
                ax.axis('off')

            result_path = os.path.join(INFERENCE_OUTPUT_DIR, f'result_{idx}_{i}.png')
            fig.savefig(result_path)
            plt.close(fig)  # avoid accumulating open figures across the loop

    print(f"Inference images saved in: {INFERENCE_OUTPUT_DIR}")