# @ Jani Kuhno
# Model from https://keras.io/examples/vision/oxford_pets_image_segmentation/ 
# by Francois Chollet

In [None]:
import tensorflow as tf
import keras
from keras import layers
from keras import ops
import keras_cv

import os
import numpy as np
import matplotlib.pyplot as plt

# For data preprocessing
from tensorflow import image as tf_image
from tensorflow import data as tf_data
from tensorflow import io as tf_io
from tensorflow.keras.utils import load_img, img_to_array

In [None]:
# ---- Training configuration -------------------------------------------------
# IMAGE_SIZE:       side length of the square model input; wherever a (W, H)
#                   tuple is needed, use (IMAGE_SIZE, IMAGE_SIZE).
# BATCH_SIZE:       smaller models may fit a larger batch in memory; tune it
#                   together with the optimizer learning rate.
# NUM_CLASSES:      number of classes annotated by Replicator, counting the
#                   background and unclassified classes it adds automatically.
# input_dir:        training-data directory (same as output_dir in Replicator).
# NUM_TRAIN_IMAGES: number of samples to train on.
# NUM_VAL_IMAGES:   number of samples to validate on.
# EPOCHS:           number of passes over the training data.

IMAGE_SIZE = 512
BATCH_SIZE = 4
NUM_CLASSES = 4

input_dir = "your_dataset_output_path"

NUM_TRAIN_IMAGES = 9_999
NUM_VAL_IMAGES = 50
EPOCHS = 10

# Passed as num_parallel_calls to the tf.data map() calls below.
AUTOTUNE = tf_data.AUTOTUNE

In [None]:
# Data loading: Keras' built-in directory loaders do not handle image
# segmentation targets, so the image/mask path lists are built by hand.
#
# Images and masks are paired purely by position after sorting, so both groups
# must share the same numbering scheme. The sort is lexicographic ("rgb_10"
# sorts before "rgb_2"), and [:NUM_TRAIN_IMAGES] takes the first names in that
# order.
all_fnames = sorted(os.listdir(input_dir))  # scan the directory once

# Every file whose name starts with "rgb" is an input sample.
all_img_paths = [
    os.path.join(input_dir, fname)
    for fname in all_fnames
    if fname.startswith("rgb")
]
# Every other .png file is treated as a segmentation target.
all_target_paths = [
    os.path.join(input_dir, fname)
    for fname in all_fnames
    if fname.endswith(".png") and not fname.startswith("rgb")
]

input_img_paths = all_img_paths[:NUM_TRAIN_IMAGES]
target_paths = all_target_paths[:NUM_TRAIN_IMAGES]

# Fail early with a clear message instead of deep inside tf.data.
if len(input_img_paths) != len(target_paths):
    raise ValueError(
        f"Found {len(input_img_paths)} training images but {len(target_paths)} "
        f"target masks in {input_dir!r}; they must pair up one-to-one."
    )

# Validation split (currently disabled): the NUM_VAL_IMAGES samples that follow
# the training samples in the same sorted order.
# val_img_paths = all_img_paths[NUM_TRAIN_IMAGES : NUM_TRAIN_IMAGES + NUM_VAL_IMAGES]
# val_target_paths = all_target_paths[NUM_TRAIN_IMAGES : NUM_TRAIN_IMAGES + NUM_VAL_IMAGES]

def read_image(image_path, mask=False, rotate=False):
    """Read an image file and return it resized to IMAGE_SIZE x IMAGE_SIZE x 3.

    The file bytes are read, decoded with 3 channels, and resized with
    tf.image.resize to (IMAGE_SIZE, IMAGE_SIZE).

    Args:
        image_path: path (Python string or string tensor) of the image file.
        mask: unused; kept only for interface compatibility.
        rotate: if True, apply tf.image.rot90 with k=3 (three counter-clockwise
            quarter turns) after resizing, e.g. for iPhone photos used at
            inference time.

    NOTE(review): decode_png is also used for .jpg files elsewhere in this
    notebook. TF documents that decode_png accepts JPEG input too, but
    tf.io.decode_image would be the more explicit choice.
    """
    raw_bytes = tf_io.read_file(image_path)
    decoded = tf_image.decode_png(raw_bytes, channels=3)
    # Pin rank and channel count so downstream ops see a known shape.
    decoded.set_shape([None, None, 3])
    resized = tf_image.resize(images=decoded, size=[IMAGE_SIZE, IMAGE_SIZE])
    if not rotate:
        return resized
    return tf_image.rot90(image=resized, k=3)


def path_to_target(path):
    """Load the segmentation mask stored at *path* as a uint8 array.

    The file is loaded in grayscale mode at IMAGE_SIZE x IMAGE_SIZE, converted
    with img_to_array, and cast to uint8. The result is stored into an
    (IMAGE_SIZE, IMAGE_SIZE, 1) slot by data_generator. Pixel values are used
    as class ids.
    """
    size = (IMAGE_SIZE, IMAGE_SIZE)
    pil_mask = load_img(path, target_size=size, color_mode="grayscale")
    return img_to_array(pil_mask).astype("uint8")


def load_data(image_list, mask_list):
    """tf.data map function: decode the image and pass the target through.

    Despite the plural parameter names, each call receives a single dataset
    element: one image path and its already-loaded target array. Only the
    image goes through read_image; the target is returned unchanged.
    """
    return read_image(image_list), mask_list


def to_dict(image, label):
    """Wrap an (image, label) pair in the dict layout keras_cv CutMix consumes.

    Only needed if CutMix augmentation is enabled. keras_cv's CutMix layer
    takes dicts keyed by "images" and "labels".
    """
    return dict(images=image, labels=label)


# keras_cv RandomCutout layer; the positional 0.1, 0.1 are presumably the
# height and width factors of the cut-out patch (confirm against keras_cv docs).
random_cutout = keras_cv.layers.RandomCutout(0.1, 0.1)


def apply_augment(samples, targets):
    """Apply random cutout to a batch of images; targets are left untouched."""
    augmented_samples = random_cutout(samples)
    return augmented_samples, targets


def data_generator(image_list, mask_list):
    """Build a shuffled, batched, augmented tf.data pipeline of (image, target).

    All targets are loaded eagerly into one uint8 array of shape
    (len(mask_list), IMAGE_SIZE, IMAGE_SIZE, 1). Images are decoded lazily by
    load_data inside the pipeline.

    Args:
        image_list: image file paths, paired positionally with mask_list.
        mask_list: target mask file paths.

    Returns:
        A tf.data.Dataset yielding batches of BATCH_SIZE (the incomplete final
        batch is dropped) with random cutout applied to the images.

    Raises:
        ValueError: if image_list and mask_list differ in length.
    """
    if len(image_list) != len(mask_list):
        raise ValueError(
            f"image_list has {len(image_list)} entries but mask_list has "
            f"{len(mask_list)}; they must pair up one-to-one."
        )

    targets = np.zeros((len(mask_list), IMAGE_SIZE, IMAGE_SIZE, 1), dtype="uint8")
    for i, mask_path in enumerate(mask_list):
        targets[i] = path_to_target(mask_path)

    dataset = tf_data.Dataset.from_tensor_slices((image_list, targets))
    # Shuffle the lightweight (path, uint8 target) pairs *before* decoding, so
    # the 400-element shuffle buffer does not hold 400 decoded, resized images.
    dataset = dataset.shuffle(400)
    dataset = dataset.map(load_data, num_parallel_calls=AUTOTUNE)
    # dataset = dataset.map(to_dict, num_parallel_calls=AUTOTUNE)  # for CutMix
    dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
    dataset = dataset.map(apply_augment, num_parallel_calls=AUTOTUNE)
    # Prepare the next batches while the current one is being trained on.
    return dataset.prefetch(AUTOTUNE)


# Build the training pipeline; the validation pipeline is currently disabled.
train_dataset = data_generator(image_list=input_img_paths, mask_list=target_paths)
# val_dataset = data_generator(val_img_paths, val_target_paths)

print("Train Dataset:", train_dataset)
# print("Val Dataset:", val_dataset)


In [None]:
def get_model(img_size, num_classes):
    """Build the Xception-style U-Net-like segmentation model (Keras example).

    Args:
        img_size: (height, width) tuple; the model input shape is
            img_size + (3,).
        num_classes: number of channels of the final per-pixel softmax.

    Returns:
        An uncompiled keras.Model producing per-pixel class probabilities with
        num_classes channels. The encoder halves the spatial size four times
        and the decoder doubles it four times, so for side lengths divisible by
        16 (e.g. 512) the output has the same height and width as the input.
    """
    inputs = keras.Input(shape=img_size + (3,))

    ### [First half of the network: downsampling inputs] ###

    # Entry block: the stride-2 convolution halves the spatial size.
    x = layers.Conv2D(32, 3, strides=2, padding="same")(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    previous_block_activation = x  # Set aside residual

    # Blocks 1, 2, 3 are identical apart from the feature depth.
    # Each applies two (ReLU -> SeparableConv2D -> BatchNorm) units and then a
    # stride-2 max pool that halves the spatial size. On the first pass the
    # leading ReLU directly follows the entry block's ReLU.
    for filters in [64, 128, 256]:
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)

        # Project residual: a 1x1 stride-2 conv brings the previous block's
        # activation to the same spatial size and depth as x.
        residual = layers.Conv2D(filters, 1, strides=2, padding="same")(
            previous_block_activation
        )
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    ### [Second half of the network: upsampling inputs] ###

    # Each stage applies two (ReLU -> Conv2DTranspose -> BatchNorm) units. No
    # strides argument is given, so these keep the spatial size. UpSampling2D(2)
    # then doubles height and width.
    for filters in [256, 128, 64, 32]:
        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.UpSampling2D(2)(x)

        # Project residual: upsample the previous activation the same way and
        # match the depth with a 1x1 conv so it can be added to x.
        residual = layers.UpSampling2D(2)(previous_block_activation)
        residual = layers.Conv2D(filters, 1, padding="same")(residual)
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    # Add a per-pixel classification layer: 3x3 conv to num_classes channels
    # followed by softmax.
    outputs = layers.Conv2D(num_classes, 3, activation="softmax", padding="same")(x)

    # Define the model
    model = keras.Model(inputs, outputs)
    return model

# Build the model for the configured input size and class count. Use
# NUM_CLASSES (not a literal) so the output depth always matches the class ids
# used by the loss, MeanIoU and mask decoding.
model = get_model(img_size=(IMAGE_SIZE, IMAGE_SIZE), num_classes=NUM_CLASSES)
model.summary()

In [None]:
# Targets are integer class ids per pixel, hence the sparse variant of
# categorical cross-entropy.
model.compile(
    optimizer=keras.optimizers.Adam(0.001), loss="sparse_categorical_crossentropy"
)

# No validation data is passed to fit(), so "val_loss" (ModelCheckpoint's
# default monitor) is never logged and save_best_only has nothing to compare.
# Monitor the training loss instead. Switch back to "val_loss" if
# validation_data is enabled below.
os.makedirs("models", exist_ok=True)
callbacks = [
    keras.callbacks.ModelCheckpoint(
        "models/unet_xception.keras", monitor="loss", save_best_only=True
    ),
    # keras.callbacks.TensorBoard(log_dir="tensorboard", write_images=True),
]

# Train the model; validation on a dataset split is unnecessary here.
epochs = EPOCHS
history = model.fit(
    train_dataset,
    epochs=epochs,
    # validation_data=val_dataset,
    callbacks=callbacks,
    verbose=1,
)

In [None]:
# Custom colormap used to colorize predicted class ids as a segmentation mask.
# Row index = class id; columns are (R, G, B).
colormap = np.array(
    [
        (0, 99, 0),   # class id 0
        (99, 0, 0),   # class id 1
        (0, 0, 99),   # class id 2
        (66, 0, 33),  # class id 3
    ],
    dtype=np.uint8,
)


def infer(model, image_tensor):
    """Predict a per-pixel class-id map for a single image.

    Args:
        model: object whose predict() takes a batch of images and returns
            per-pixel class scores with the class axis last.
        image_tensor: one image without a batch axis, e.g. (H, W, 3).

    Returns:
        np.ndarray of shape (H, W) holding the argmax class id of each pixel.
    """
    # Add the batch axis that predict() expects.
    batch = np.expand_dims(image_tensor, axis=0)
    predictions = model.predict(batch)
    # Drop only the batch axis. np.squeeze would also drop any other size-1
    # axis (a one-pixel-high image, a single-class output), which would
    # misalign the class axis for argmax.
    scores = np.asarray(predictions)[0]
    # Highest-scoring class along the last (class) axis for each pixel.
    return np.argmax(scores, axis=-1)


def decode_segmentation_masks(mask, colormap, n_classes):
    """Colorize a 2-D map of class ids (e.g. the output of infer) with *colormap*.

    Args:
        mask: 2-D array of integer class ids.
        colormap: array indexed as colormap[class_id, channel] giving R, G, B.
        n_classes: ids 0 .. n_classes-1 are colored; any other value stays
            black (0, 0, 0).

    Returns:
        uint8 array of shape mask.shape + (3,).
    """
    # One uint8 plane per color channel (R, G, B), all starting at 0.
    planes = [np.zeros_like(mask).astype(np.uint8) for _ in range(3)]

    for class_id in range(n_classes):
        # Boolean selector of every pixel assigned to this class.
        selected = mask == class_id
        for channel, plane in enumerate(planes):
            plane[selected] = colormap[class_id, channel]

    # Interleave the three planes into an (H, W, 3) image.
    return np.stack(planes, axis=2)


def plot_samples_matplotlib(display_list, figsize=(5, 3)):
    """Show arrays side by side in one row, e.g. a photo and its colorized mask.

    Float 3-channel arrays, such as the resized photo from read_image, go
    through keras.utils.array_to_img so they are rescaled to a displayable
    range. Everything else is shown as-is, including the uint8 colorized masks
    from decode_segmentation_masks. This keeps mask colors exactly as defined
    in the colormap instead of rescaling them per image.

    Args:
        display_list: images/arrays to display, one subplot each.
        figsize: matplotlib figure size in inches.
    """
    # squeeze=False keeps axs 2-D even for a single item, so indexing works
    # when len(display_list) == 1.
    _, axs = plt.subplots(
        nrows=1, ncols=len(display_list), figsize=figsize, squeeze=False
    )

    for ax, item in zip(axs[0], display_list):
        ax.axis("off")
        arr = np.asarray(item)
        if arr.shape[-1] == 3 and not np.issubdtype(arr.dtype, np.integer):
            ax.imshow(keras.utils.array_to_img(arr))
        else:
            ax.imshow(arr)
    plt.show()

def plot_predictions(images_list, colormap, model, rotate=False):
    """Predict, colorize and display the segmentation of each image file.

    Args:
        images_list: image file paths to run inference on.
        colormap: colormap forwarded to decode_segmentation_masks.
        model: trained segmentation model.
        rotate: forwarded to read_image. Set True for images that need the
            rot90(k=3) rotation (e.g. iPhone photos). Defaults to False.
    """
    for image_file in images_list:
        image_tensor = read_image(image_file, rotate=rotate)
        prediction_mask = infer(image_tensor=image_tensor, model=model)
        prediction_colormap = decode_segmentation_masks(
            prediction_mask, colormap, NUM_CLASSES
        )
        plot_samples_matplotlib(
            [image_tensor, prediction_colormap], figsize=(18, 14)
        )

In [None]:
# Directory holding the test images ("rgb*" files) to visualize predictions for.
# A single variable avoids the mismatched quoting that previously left the
# first path literal unterminated (a syntax error).
test_data_dir = "path_to_test_data"
input_paths_list = sorted(
    os.path.join(test_data_dir, fname)
    for fname in os.listdir(test_data_dir)
    if fname.startswith("rgb")
)

plot_predictions(input_paths_list, colormap, model=model)

In [None]:
def correct_classes(target_arr):
    """Remap hand-labelled ground-truth class ids to the training class ids.

    The hand-labelled masks use 0=background, 1=table, 2=props. The
    synthetic-data (SD) training labels use background=0, props=2, table=3.

    Args:
        target_arr: ground-truth mask, either (H, W) or (H, W, C). For 3-D
            input only channel 0 is used, e.g. the (H, W, 1) output of
            path_to_target.

    Returns:
        uint8 array of shape (H, W) with remapped class ids. Pixels whose label
        name has no SD counterpart stay 0.

    Raises:
        KeyError: if the mask contains an id with no hand-labelled name.
    """
    # Labels in the hand-labelled ground truth: id -> class name.
    hand_labels = {0: "background", 1: "table", 2: "props"}
    # Matching labels in the SD training data: class name -> id.
    sd_ids = {"background": 0, "props": 2, "table": 3}

    source_ids = target_arr if target_arr.ndim == 2 else target_arr[:, :, 0]
    corrected_class_ids = np.zeros(source_ids.shape[:2], dtype=np.uint8)
    for _id in np.unique(source_ids):
        obj_label = hand_labels[int(_id)].lower()
        if obj_label in sd_ids:
            corrected_class_ids[source_ids == _id] = sd_ids[obj_label]

    print("Corrected label id's: " + str(np.unique(corrected_class_ids, return_counts=True)))
    return corrected_class_ids

In [None]:
# Mean IoU metric to evaluate the predictions numerically.
#
# For every test image rgb<suffix>.<ext> in input_paths_list, the ground truth
# is read from mask<suffix>.png in the same directory (e.g. rgb_5.jpg ->
# mask_5.png). It is remapped with correct_classes and compared with the
# model's prediction. The mask path is derived from the image actually being
# evaluated, rather than assuming the files are numbered 0..N-1.

results_list = []
m = keras.metrics.MeanIoU(num_classes=NUM_CLASSES)

for metrics_img_path in input_paths_list:
    img_dir, img_name = os.path.split(metrics_img_path)
    suffix = os.path.splitext(img_name)[0][len("rgb"):]
    ground_truth_path = os.path.join(img_dir, f"mask{suffix}.png")

    # remember to rotate if needed
    metrics_img = read_image(metrics_img_path, rotate=False)
    metrics_pred = infer(model, metrics_img)
    ground_truth = correct_classes(path_to_target(ground_truth_path))

    # Per-image mean of the class IoUs (metric state is reset for every image).
    m.reset_state()
    m.update_state(ground_truth, metrics_pred)
    results_list.append(m.result().numpy())

for metrics_img_path, iou in zip(input_paths_list, results_list):
    print(f"{os.path.basename(metrics_img_path)}: {iou}")
if results_list:
    print(f"Mean over {len(results_list)} images: {np.mean(results_list)}")