# @ Jani Kuhno
# Model from https://keras.io/examples/vision/deeplabv3_plus/
# by Soumik Rakshit

In [None]:
import tensorflow as tf
import keras
from keras import layers
from keras import ops
import keras_cv

import os
import numpy as np
import matplotlib.pyplot as plt

# For data preprocessing
from tensorflow import image as tf_image
from tensorflow import data as tf_data
from tensorflow import io as tf_io
from tensorflow.keras.utils import load_img, img_to_array

In [None]:
# Training configuration.
#
# IMAGE_SIZE: side length of the square images; wherever a (width, height)
#             tuple is needed, use (IMAGE_SIZE, IMAGE_SIZE).
# BATCH_SIZE: samples per batch. Smaller models might fit in memory with a
#             larger batch size; experiment with batch size together with the
#             optimizer learning rate.
# NUM_CLASSES: number of classes annotated by replicator, including the
#              background and unclassified classes that are added automatically.
# input_dir: directory holding the training data (the same directory as
#            output_dir in replicator).
# NUM_TRAIN_IMAGES: how many samples to train on.
# NUM_VAL_IMAGES: how many samples to validate on.
# EPOCHS: number of training epochs passed to model.fit.


IMAGE_SIZE = 512
BATCH_SIZE = 4
NUM_CLASSES = 4
input_dir = "your_dataset_output_path"
NUM_TRAIN_IMAGES = 5000
NUM_VAL_IMAGES = 50
EPOCHS = 10

# Lets tf.data choose the parallelism level for the dataset map calls.
AUTOTUNE = tf_data.AUTOTUNE

In [None]:
# Data loading: the built-in Keras loaders do not work for image segmentation
# targets, so the file lists are assembled by hand.

# Training samples: every file in input_dir whose name starts with "rgb",
# sorted by path and cut to the first NUM_TRAIN_IMAGES entries.
input_img_paths = sorted(
    os.path.join(input_dir, name)
    for name in os.listdir(input_dir)
    if name.startswith("rgb")
)[:NUM_TRAIN_IMAGES]

# Training targets: every ".png" file in input_dir that is not an "rgb"
# sample, sorted by path and cut to the first NUM_TRAIN_IMAGES entries.
target_paths = sorted(
    os.path.join(input_dir, name)
    for name in os.listdir(input_dir)
    if not name.startswith("rgb") and name.endswith(".png")
)[:NUM_TRAIN_IMAGES]

# NOTE: the validation split below is disabled by wrapping it in a bare string
# literal, so none of it is executed.
"""
# sorted list of validation sample file paths, as many as NUM_VAL_IMAGES, from the first file not included in training images
val_img_paths = sorted(
 [os.path.join(input_dir, fname)
 for fname in os.listdir(input_dir)
 if fname.startswith("rgb")])[NUM_TRAIN_IMAGES : NUM_VAL_IMAGES + NUM_TRAIN_IMAGES]

# sorted list of validation target file paths, as many as NUM_VAL_IMAGES
val_target_paths = sorted(
 [os.path.join(input_dir, fname)
 for fname in os.listdir(input_dir)
 if fname.endswith(".png") and not fname.startswith("rgb")])[NUM_TRAIN_IMAGES : NUM_VAL_IMAGES + NUM_TRAIN_IMAGES]
"""

def read_image(image_path, mask=False, rotate=False):
    """Load an image file as a 3-channel tensor of size IMAGE_SIZE x IMAGE_SIZE.

    The file is read, decoded with ``decode_png`` into three channels and
    resized to ``[IMAGE_SIZE, IMAGE_SIZE]``.

    Args:
        image_path: path of the image file to read.
        mask: not in use; kept so existing callers keep working.
        rotate: when true, the resized image is passed through ``rot90`` with
            ``k=3`` (meant for rotating iPhone pictures at inference time).

    Returns:
        The resized (and optionally rotated) image tensor.
    """
    raw_bytes = tf_io.read_file(image_path)
    decoded = tf_image.decode_png(raw_bytes, channels=3)
    # The decoded height/width are unknown ahead of time; only the channel
    # count is pinned.
    decoded.set_shape([None, None, 3])
    resized = tf_image.resize(images=decoded, size=[IMAGE_SIZE, IMAGE_SIZE])
    if not rotate:
        return resized
    return tf_image.rot90(image=resized, k=3)


def path_to_target(path):
    """Load the segmentation target stored at ``path``.

    The file is opened in grayscale at (IMAGE_SIZE, IMAGE_SIZE), converted to
    an array and cast to uint8.

    Args:
        path: path of the target (mask) image file.

    Returns:
        The target as a uint8 array.
    """
    mask_image = load_img(
        path,
        target_size=(IMAGE_SIZE, IMAGE_SIZE),
        color_mode="grayscale",
    )
    return img_to_array(mask_image).astype("uint8")


def load_data(image_list, mask_list):
    """Load the sample image while passing its target through untouched.

    Only the first argument goes through ``read_image``; the target is already
    loaded and must not be decoded or resized again.

    Args:
        image_list: path of the sample image to load.
        mask_list: the target belonging to that sample.

    Returns:
        A tuple ``(image, mask_list)``.
    """
    return read_image(image_list), mask_list


def to_dict(image, label):
    """Pack a sample and its label into a dict with keys "images" and "labels".

    Intended for the case where CutMix augmentation is needed: the keras_cv
    CutMix layer takes dict inputs with exactly these two keys.
    """
    return dict(images=image, labels=label)


# Shared keras_cv Cutout augmentation layer; it is created once at module level
# and reused for every batch.
random_cutout = keras_cv.layers.RandomCutout(0.1, 0.1)


def apply_augment(samples, targets):
    """Run the samples through the Cutout layer; targets are left unchanged.

    Args:
        samples: the sample images to augment.
        targets: the matching targets, returned as they came in.

    Returns:
        A tuple ``(augmented_samples, targets)``.
    """
    augmented_samples = random_cutout(samples)
    return augmented_samples, targets


def data_generator(image_list, mask_list):
    """Build the shuffled, batched and augmented training dataset.

    All targets are loaded up front into one uint8 array of shape
    ``(len(mask_list), IMAGE_SIZE, IMAGE_SIZE, 1)``. The sample images are
    loaded lazily from their paths inside the tf.data pipeline.

    Args:
        image_list: sequence of sample image paths.
        mask_list: sequence of target file paths, in the same order as
            ``image_list``.

    Returns:
        A tf.data dataset of ``(samples, targets)`` batches of size BATCH_SIZE
        (incomplete final batches are dropped) with Cutout applied per batch.

    Raises:
        ValueError: if ``image_list`` and ``mask_list`` differ in length.
    """
    # Fail early with a readable message instead of a shape error raised from
    # inside from_tensor_slices.
    if len(image_list) != len(mask_list):
        raise ValueError(
            f"Got {len(image_list)} images but {len(mask_list)} targets; "
            "the two lists must have the same length."
        )

    targets = np.zeros((len(mask_list), IMAGE_SIZE, IMAGE_SIZE, 1), dtype="uint8")
    for index, mask_path in enumerate(mask_list):
        targets[index] = path_to_target(mask_path)

    dataset = tf_data.Dataset.from_tensor_slices((image_list, targets))
    dataset = dataset.map(load_data, num_parallel_calls=AUTOTUNE)
    # If CutMix is ever needed, map to_dict here before batching:
    # dataset = dataset.map(to_dict, num_parallel_calls=AUTOTUNE)
    dataset = dataset.shuffle(400)
    dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
    dataset = dataset.map(apply_augment, num_parallel_calls=AUTOTUNE)
    # Overlap input preparation with training steps.
    return dataset.prefetch(AUTOTUNE)


# Build the training pipeline from the sample and target path lists.
train_dataset = data_generator(image_list=input_img_paths, mask_list=target_paths)
# val_dataset = data_generator(val_img_paths, val_target_paths)


print("Train Dataset:", train_dataset)
# print("Val Dataset:", val_dataset)



In [None]:
def convolution_block(
    block_input,
    num_filters=256,
    kernel_size=3,
    dilation_rate=1,
    use_bias=False,
):
    """Apply Conv2D -> BatchNormalization -> ReLU to ``block_input``.

    Args:
        block_input: input tensor of the block.
        num_filters: number of convolution filters.
        kernel_size: convolution kernel size.
        dilation_rate: dilation rate of the convolution.
        use_bias: whether the convolution uses a bias term.

    Returns:
        The activated output tensor.
    """
    conv_layer = layers.Conv2D(
        num_filters,
        kernel_size=kernel_size,
        dilation_rate=dilation_rate,
        padding="same",
        use_bias=use_bias,
        kernel_initializer=keras.initializers.HeNormal(),
    )
    features = conv_layer(block_input)
    normalized = layers.BatchNormalization()(features)
    return ops.nn.relu(normalized)


def DilatedSpatialPyramidPooling(dspp_input):
    """Combine a pooled branch with four convolution branches of ``dspp_input``.

    The pooled branch average-pools over the whole spatial extent, applies a
    1x1 convolution block and upsamples back to the input's spatial size. The
    other branches are convolution blocks with (kernel_size, dilation_rate) of
    (1, 1), (3, 6), (3, 12) and (3, 18). All five are concatenated on the last
    axis and fused by a final 1x1 convolution block.

    Args:
        dspp_input: feature tensor whose axes -3 and -2 are the spatial axes.

    Returns:
        The fused feature tensor.
    """
    input_shape = dspp_input.shape
    height, width = input_shape[-3], input_shape[-2]

    # Pooled branch: pool over the full spatial extent, then scale back up.
    pooled = layers.AveragePooling2D(pool_size=(height, width))(dspp_input)
    pooled = convolution_block(pooled, kernel_size=1, use_bias=True)
    out_pool = layers.UpSampling2D(
        size=(height // pooled.shape[1], width // pooled.shape[2]),
        interpolation="bilinear",
    )(pooled)

    # Convolution branches as (kernel_size, dilation_rate) pairs.
    branch_specs = ((1, 1), (3, 6), (3, 12), (3, 18))
    conv_branches = [
        convolution_block(dspp_input, kernel_size=kernel, dilation_rate=rate)
        for kernel, rate in branch_specs
    ]

    merged = layers.Concatenate(axis=-1)([out_pool, *conv_branches])
    return convolution_block(merged, kernel_size=1)

In [None]:
def DeeplabV3Plus(image_size, num_classes):
    """Assemble the DeepLabV3+ segmentation model on an ImageNet ResNet50.

    Features from the "conv4_block6_2_relu" layer go through
    DilatedSpatialPyramidPooling and are upsampled to a quarter of the input
    size. There they are concatenated with a 48-filter 1x1 projection of the
    "conv2_block3_2_relu" features. Two convolution blocks and a final
    upsampling to ``image_size`` follow.

    Args:
        image_size: side length of the square RGB input.
        num_classes: number of output channels of the final 1x1 convolution.

    Returns:
        A keras.Model mapping ``(image_size, image_size, 3)`` inputs to
        per-pixel class scores; the last layer has no activation.
    """
    model_input = keras.Input(shape=(image_size, image_size, 3))
    preprocessed = keras.applications.resnet50.preprocess_input(model_input)
    backbone = keras.applications.ResNet50(
        weights="imagenet", include_top=False, input_tensor=preprocessed
    )

    # Encoder: deep backbone features through the pyramid pooling module.
    encoder_features = DilatedSpatialPyramidPooling(
        backbone.get_layer("conv4_block6_2_relu").output
    )

    # Decoder: bring the encoder output up to 1/4 of the input resolution.
    quarter_size = image_size // 4
    input_a = layers.UpSampling2D(
        size=(
            quarter_size // encoder_features.shape[1],
            quarter_size // encoder_features.shape[2],
        ),
        interpolation="bilinear",
    )(encoder_features)

    # Shallow backbone features, projected to 48 channels.
    low_level_features = backbone.get_layer("conv2_block3_2_relu").output
    input_b = convolution_block(low_level_features, num_filters=48, kernel_size=1)

    decoder = layers.Concatenate(axis=-1)([input_a, input_b])
    decoder = convolution_block(decoder)
    decoder = convolution_block(decoder)
    decoder = layers.UpSampling2D(
        size=(image_size // decoder.shape[1], image_size // decoder.shape[2]),
        interpolation="bilinear",
    )(decoder)

    model_output = layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same")(decoder)
    return keras.Model(inputs=model_input, outputs=model_output)


# Instantiate the network for the configured image size and class count, then
# print its layer summary.
model = DeeplabV3Plus(IMAGE_SIZE, NUM_CLASSES)
model.summary()

In [None]:
# The targets hold integer class ids and the model's last layer has no
# activation, hence sparse categorical cross-entropy on logits.
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss=loss,
    metrics=["accuracy"],
)

callbacks = [
    # Validation is disabled in the fit call below, so "val_loss" (the
    # checkpoint's default monitor) would never be reported and
    # save_best_only would never save. Track the training loss instead;
    # switch back to "val_loss" once validation_data is passed to fit.
    keras.callbacks.ModelCheckpoint(
        "models/deeplab.keras",
        monitor="loss",
        save_best_only=True,
    ),
    # keras.callbacks.TensorBoard(log_dir="tensorboard", write_images=True,)
]

# calling model checkpoint callback has error with keras ResNet preprocessing layer, TODO: fix
# (for that reason the callbacks list is not handed to fit yet)
history = model.fit(train_dataset,
                    #validation_data=val_dataset,
                    epochs=EPOCHS,
                    #callbacks=callbacks,
                    verbose=1)


In [None]:
# Custom colormap used to colorize predictions as a segmentation mask.
# Row index is the class id; the columns are the R, G and B values.
colormap = np.array(
    [
        [0, 99, 0],   # class id 0
        [99, 0, 0],   # class id 1
        [0, 0, 99],   # class id 2
        [66, 0, 33],  # class id 3
    ],
    dtype=np.uint8,
)


def infer(model, image_tensor):
    """Predict a class id for every pixel of a single image.

    Args:
        model: object with a ``predict`` method that takes a batch of images
            and returns per-pixel class scores with the classes on the last
            axis.
        image_tensor: one image without a batch axis.

    Returns:
        An integer array with one class id per pixel.
    """
    # add a batch axis and convert to a numpy array
    batch = np.expand_dims(image_tensor, axis=0)
    predictions = model.predict(batch)

    # Remove the batch axis only. A bare squeeze would also drop any other
    # size-1 axis (a height/width of 1, or a single class channel) and break
    # the argmax below.
    predictions = np.squeeze(predictions, axis=0)

    # Index of the highest score along the class axis. This gives the same
    # result for raw scores and for normalised probabilities.
    return np.argmax(predictions, axis=-1)


def decode_segmentation_masks(mask, colormap, n_classes):
    """Colorize a matrix of class ids into an RGB image.

    Args:
        mask: prediction matrix of per-pixel class ids (e.g. 512 x 512), as
            returned by ``infer``.
        colormap: array indexed as ``colormap[class_id, channel]`` with
            channels in R, G, B order.
        n_classes: class ids ``0 .. n_classes - 1`` are colorized; pixels with
            any other id stay black.

    Returns:
        A uint8 array with the three color planes stacked on axis 2.
    """
    # One zero-filled uint8 plane per color channel (R, G, B).
    planes = [np.zeros_like(mask).astype(np.uint8) for _ in range(3)]

    for class_id in range(n_classes):
        # Boolean matrix marking the pixels predicted as this class.
        selected = mask == class_id
        for channel, plane in enumerate(planes):
            plane[selected] = colormap[class_id, channel]

    # Stack the planes into a (height, width, 3) image.
    return np.stack(planes, axis=2)


def plot_samples_matplotlib(display_list, figsize=(5, 3)):
    """Show the given images side by side in one row.

    Args:
        display_list: images to display, typically the real-world picture and
            its predicted, RGB-colorized mask. An empty list shows nothing.
        figsize: size of the matplotlib figure.
    """
    # A figure with zero columns cannot be created; nothing to draw.
    if not display_list:
        return

    # squeeze=False keeps the axes in a 2-D array even for a single image, so
    # a one-element display_list is handled the same way as longer lists.
    _, axs = plt.subplots(
        nrows=1, ncols=len(display_list), figsize=figsize, squeeze=False
    )

    for ax, item in zip(axs[0], display_list):
        ax.axis('off')
        # A last axis of size 3 marks an RGB image; anything else is drawn
        # as-is.
        if item.shape[-1] == 3:
            ax.imshow(keras.utils.array_to_img(item))
        else:
            ax.imshow(item)
    plt.show()

def plot_predictions(images_list, colormap, model):
    """Predict, colorize and plot every image in ``images_list``.

    Ties the support functions together: each file is loaded, segmented by
    ``model``, colorized with ``colormap`` and shown next to the input image.

    Args:
        images_list: paths of the pictures to predict.
        colormap: colormap used to colorize the predictions.
        model: trained model used for inference.
    """
    for image_file in images_list:
        # remember to adjust the rotate boolean here when needed
        image_tensor = read_image(image_file, rotate=False)
        class_ids = infer(model, image_tensor)
        colorized = decode_segmentation_masks(class_ids, colormap, NUM_CLASSES)
        plot_samples_matplotlib([image_tensor, colorized], figsize=(18, 14))

In [None]:
# Collect the test images: every file in the test directory whose name starts
# with "rgb", sorted by path.
test_data_dir = 'path_to_test_data'
input_paths_list = sorted(
    os.path.join(test_data_dir, fname)
    for fname in os.listdir(test_data_dir)
    if fname.startswith("rgb")
)

# Predict and display a segmentation mask for each test image.
plot_predictions(input_paths_list, colormap, model=model)

In [None]:
def correct_classes(target_arr):
    """Remap hand-labelled class ids to the ids used by the model.

    Each id in the first channel of ``target_arr`` is looked up by name in the
    hand-labelling scheme and rewritten to the id that name has in
    CUSTOM_LABELS. Ids that are unknown to either mapping are left at 0
    (background) instead of raising.

    Args:
        target_arr: ground-truth array indexed as (row, column, channel); only
            channel 0 is read.

    Returns:
        A uint8 array of shape (rows, columns) holding the remapped ids.
    """
    class_ids = target_arr[:, :, 0]

    # labels in hand labelled ground truth
    labels = {"0": {"class": "background"},
              "1": {"class": "table"},
              "2": {"class": "props"}}
    # matching labels with SD (presumably the synthetic training data; confirm)
    CUSTOM_LABELS = {
        "background": 0,
        "props": 2,
        "table": 3,
    }

    corrected_class_ids = np.zeros(class_ids.shape[:2], dtype=np.uint8)
    for class_id in np.unique(class_ids):
        label_entry = labels.get(str(class_id))
        if label_entry is None:
            # Id not part of the hand-labelling scheme: keep it as background.
            continue
        obj_label = label_entry["class"].lower()
        if obj_label in CUSTOM_LABELS:
            corrected_class_ids[class_ids == class_id] = CUSTOM_LABELS[obj_label]

    print("Corrected label id's: " + str(np.unique(corrected_class_ids, return_counts=True)))
    return corrected_class_ids

In [None]:
# Mean IoU metric to evaluate the predictions numerically, one score per test
# image.

results_list = []
m = keras.metrics.MeanIoU(num_classes=NUM_CLASSES)

num_test_images = len(input_paths_list)
for i in range(num_test_images):
    # Test images and their masks are addressed by running index.
    metrics_img_path = f'path_to_test_data/rgb_{i}.jpg'
    ground_truth_path = f'path_to_test_data/mask_{i}.png'

    # remember to rotate if needed
    metrics_img = read_image(metrics_img_path, rotate=False)
    metrics_pred = infer(model, metrics_img)

    # Load the hand-labelled mask and remap its ids to the model's class ids.
    ground_truth = correct_classes(path_to_target(ground_truth_path))

    # Reset first so each entry is the mean IoU over all classes for this
    # image alone.
    m.reset_state()
    m.update_state(ground_truth, metrics_pred)
    results_list.append(m.result().numpy())

for mean_iou in results_list:
    print(mean_iou)