# Particle segmentation Notebook

This is an example of how to train on our custom particle dataset.  
In this example we show how we trained on the ABT 10Shape dataset.  
If you want to use it for the Raw dataset or the 5Shape one, there are
some commented-out cells which showcase the use of those other variants.   

# 00 Initial Steps

In [None]:
!pip install keras-unet-collection
!pip install -U -q segmentation-models

In [2]:
import os
os.environ["SM_FRAMEWORK"] = "tf.keras"
import cv2
import json
import matplotlib.pyplot as plt
import numpy as np
from glob import glob
from google.colab import drive
from keras import backend as K
from keras.metrics import Precision, Recall, AUC, Accuracy
from keras_unet_collection import losses, models, utils
import segmentation_models as sm
from skimage import measure
from sklearn.utils import shuffle
from tensorflow import keras
from tensorflow.keras.callbacks import EarlyStopping
import tensorflow as tf
from pycocotools import mask
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

Segmentation Models: using `tf.keras` framework.


### Helper functions used later

In [3]:
%matplotlib inline
def ax_decorate_box(ax):
    """Strip every spine and all ticks / tick labels from an axes object.

    Returns the same ``ax`` so the call can be chained.
    """
    for spine in ax.spines.values():
        spine.set_linewidth(0)
    hidden = dict(bottom=False, top=False, labelbottom=False,
                  left=False, right=False, labelleft=False)
    ax.tick_params(axis="both", which="both", **hidden)
    return ax

In [4]:
# Plotting function
def plot_metrics(history, metrics_list):
    for metric in metrics_list:
        plt.figure()

        metric_values = history.history[metric]
        val_metric_values = history.history['val_'+metric]

        epochs = range(1, len(metric_values) + 1)
        plt.plot(epochs, metric_values, 'y', label='Training '+metric)
        plt.plot(epochs, val_metric_values, 'r', label='Validation '+metric)

        plt.title('Training and Validation '+metric)
        plt.xlabel('Epochs')
        plt.ylabel(metric)
        plt.legend()
        plt.savefig(metric+'_plot.jpg')
        plt.show()

In [5]:
def make_coco_eval_data(image_directory, mask_directory, change_diff,
                        class_ids=range(0, 10)):
    """Convert per-image label masks into COCO-style image and annotation records.

    Every ``*.jpg`` file in ``image_directory`` is treated as an image. Its label
    mask is read (grayscale) from ``mask_directory`` under the same file name,
    with ``'.jpg'`` replaced by ``change_diff`` (e.g. ``'0_.jpg'`` -> ``'0_gt.png'``).
    For each id in ``class_ids`` that occurs in the mask, one polygon annotation
    is emitted. Its polygons are the contours of that class's pixels.

    Args:
        image_directory: folder containing the ``.jpg`` images.
        mask_directory: folder containing the label masks.
        change_diff: replacement for ``'.jpg'`` that yields the mask file name.
        class_ids: pixel values exported as categories. Defaults to 0..9, which
            matches the ``categories`` list written by the evaluation cell.
            NOTE(review): with ``num_CLASSES = 11`` this includes 0 (presumably
            background) and skips 10; confirm which ids should be scored.

    Returns:
        ``(images, annotations)``: lists of JSON-serialisable dicts. Image ids
        follow sorted file-name order, so two calls on the same directory
        (ground truth and detections) assign matching ids.

    Raises:
        FileNotFoundError: if an image or its mask cannot be read.
    """
    images = []
    annotations = []

    # Only .jpg files are images. The results folder also holds the *_gt.png /
    # *_dt.png masks; treating those as images would pair each mask with itself.
    image_names = sorted(n for n in os.listdir(image_directory) if n.endswith('.jpg'))

    for image_id, image_name in enumerate(image_names):
        image_path = os.path.join(image_directory, image_name)
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f'Could not read image: {image_path}')
        height, width = image.shape[:2]

        images.append({
            "id": image_id,
            "file_name": image_name,
            "width": width,
            "height": height,
        })

        mask_path = os.path.join(mask_directory, image_name.replace('.jpg', change_diff))
        label_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if label_mask is None:
            raise FileNotFoundError(f'Could not read mask: {mask_path}')

        for class_id in class_ids:
            binary_mask = (label_mask == class_id).astype(np.uint8)
            # Pixel count of this class; int() so json.dump accepts it.
            area = int(np.sum(binary_mask))
            if area == 0:
                continue

            polygons = []
            kept_points = []
            for contour in measure.find_contours(binary_mask, 0.5):
                xy = np.flip(contour, axis=1)  # (row, col) -> (x, y)
                flat = xy.ravel().tolist()
                # Keep polygons with at least three vertices (six coordinates).
                if len(flat) > 4:
                    polygons.append(flat)
                    kept_points.append(xy)
            if not polygons:
                continue

            # The bounding box spans every kept polygon of this class, not just the last one.
            all_xy = np.concatenate(kept_points, axis=0)
            x_min, y_min = all_xy.min(axis=0)
            x_max, y_max = all_xy.max(axis=0)

            annotations.append({
                "id": len(annotations) + 1,
                "image_id": image_id,
                "category_id": int(class_id),
                "width": width,
                "height": height,
                "score": 0.0,
                "bbox": [float(x_min), float(y_min),
                         float(x_max - x_min), float(y_max - y_min)],
                "area": area,
                "segmentation": polygons,
                "iscrowd": 0,
            })
    return images, annotations

# 01 Preprocessing

In [None]:
# Mount Google Drive at /content/gdrive/ (Colab only; `drive` is imported from google.colab above).
drive.mount('/content/gdrive/', force_remount=True)

In [22]:
!unzip 10S_raw_abt.zip

In [6]:
# img_SIZE: the `size` passed to utils.image_to_array for images and masks.
# num_CLASSES: number of label channels (one-hot depth and model n_labels).
img_SIZE, num_CLASSES = 256, 11

### Basic Normalisation

In [7]:
def input_data_process(input_array):
    """Scale 8-bit pixel values into the [0, 1] range by dividing by 255."""
    max_pixel_value = 255.
    return input_array / max_pixel_value


def target_data_process(target_array):
    """One-hot encode an integer label mask into ``num_CLASSES`` channels."""
    encoded = keras.utils.to_categorical(target_array, num_classes=num_CLASSES)
    return encoded

Uncomment the next cell if you want to use it with the 5S (5Shape) dataset.

In [14]:
# ## this one is for our 5 Shape dataset where some of the particles class lables
# ## need to be remapped


# def input_data_process(input_array):
#     '''converting pixel vales to [0, 1]'''
#     return input_array/255

# def target_data_process(target_array):
#     target_array[target_array == 8] = 4
#     target_array[target_array == 9] = 5
#     return keras.utils.to_categorical(target_array, num_classes=num_CLASSES)

### Keras VGG16 Normalisation

In [None]:
# def input_data_process(input_array):
#     return tf.keras.applications.vgg16.preprocess_input(input_array, data_format=None)

# def target_data_process(target_array):
#     return keras.utils.to_categorical(target_array, num_classes=num_CLASSES)

In [None]:
# def input_data_process_vis(input_array):
#     '''converting pixel vales to [0, 1]'''
#     return input_array/255

# def target_data_process_vis(target_array):
#     return keras.utils.to_categorical(target_array, num_classes=num_CLASSES)



### Load images

In [None]:
# 'abt' trains on raw + augmented (iabt_*) data; set to 'raw' to train on the raw data only.
dataset_mode = "abt"

In [None]:
# Paths to the extracted dataset folders: raw and augmented ("iabt_") training
# sets, plus validation and test. Images are .jpg, label masks are .png.
path_train_img_raw = '/content/img_train/'
path_train_mask_raw = '/content/msk_train/'
path_train_img_a = '/content/iabt_img_train/'
path_train_mask_a = '/content/iabt_msk_train/'

path_valid_img = '/content/img_val/'
path_valid_mask = '/content/msk_val/'

path_test_img = '/content/img_test/'
path_test_mask = '/content/msk_test/'

if dataset_mode not in ('abt', 'raw'):
    raise ValueError(f"dataset_mode must be 'abt' or 'raw', got {dataset_mode!r}")


def _sorted_paths(pattern):
    """Return the files matching ``pattern`` as a sorted numpy array.

    Images and masks are paired purely by sorted position, so the image list
    and the mask list of a split must line up one-to-one.
    """
    return np.array(sorted(glob(pattern)))


train_input_names_raw = _sorted_paths(path_train_img_raw + '*.jpg')
train_label_names_raw = _sorted_paths(path_train_mask_raw + '*.png')

if dataset_mode == 'raw':
    # The raw set also fills the second half, so every raw sample appears twice
    # per epoch. NOTE(review): presumably to keep epoch length comparable to 'abt'.
    train_input_names_a = train_input_names_raw
    train_label_names_a = train_label_names_raw
else:
    train_input_names_a = _sorted_paths(path_train_img_a + '*.jpg')
    train_label_names_a = _sorted_paths(path_train_mask_a + '*.png')

train_input_names = np.concatenate((train_input_names_raw,
                                    train_input_names_a),
                                   axis=0)
train_label_names = np.concatenate((train_label_names_raw,
                                    train_label_names_a),
                                   axis=0)

valid_input_names = _sorted_paths(path_valid_img + '*.jpg')
valid_label_names = _sorted_paths(path_valid_mask + '*.png')
test_input_names = _sorted_paths(path_test_img + '*.jpg')
test_label_names = _sorted_paths(path_test_mask + '*.png')
L_train = len(train_input_names)

# Fail early if the archive was not extracted or the image/mask counts disagree
# (a mismatch would silently pair images with the wrong masks).
assert L_train > 0, 'no training images found - was the dataset extracted to /content?'
assert len(train_input_names) == len(train_label_names), 'train: image/mask count mismatch'
assert len(valid_input_names) == len(valid_label_names), 'validation: image/mask count mismatch'
assert len(test_input_names) == len(test_label_names), 'test: image/mask count mismatch'

print("Training:validation:test = {}:{}:{}".format(len(train_input_names),
                                                   len(valid_input_names),
                                                   len(test_input_names)))

In [20]:
# Shuffle image and mask paths with one shared permutation, then load them.
# A fixed seed keeps the sample order reproducible between runs.
SHUFFLE_SEED = 42
a_shuffled, b_shuffled = shuffle(train_input_names, train_label_names,
                                 random_state=SHUFFLE_SEED)
train_input = input_data_process(utils.image_to_array(a_shuffled,
                                                      size=img_SIZE,
                                                      channel=3))
train_label = target_data_process(utils.image_to_array(b_shuffled,
                                                       size=img_SIZE,
                                                       channel=1))

In [None]:
# Validation and test sets are loaded in sorted file-name order (no shuffling).
valid_input = input_data_process(utils.image_to_array(valid_input_names,
                                                      size=img_SIZE,
                                                      channel=3))
valid_label = target_data_process(utils.image_to_array(valid_label_names,
                                                       size=img_SIZE,
                                                       channel=1))
# The "04 Evaluate" cells (model.evaluate / model.predict / export) need these arrays.
test_input = input_data_process(utils.image_to_array(test_input_names,
                                                     size=img_SIZE,
                                                     channel=3))
test_label = target_data_process(utils.image_to_array(test_label_names,
                                                      size=img_SIZE,
                                                      channel=1))

### Visualise some of the training data

In [None]:


def plot_random_images(input_images, label_images, num_images):
    """Show random (image, label) pairs side by side and print their class ids.

    Args:
        input_images: image array indexable by sample, e.g. ``train_input``.
        label_images: matching one-hot label array with the class axis last,
            e.g. ``train_label``.
        num_images: number of distinct pairs to draw (capped at the sample count).
    """
    n_samples = len(input_images)
    # replace=False so the same sample is not drawn twice.
    indices = np.random.choice(n_samples, min(num_images, n_samples), replace=False)
    for idx in indices:
        label_class = np.argmax(label_images[idx], axis=-1)
        fig, (ax_img, ax_lbl) = plt.subplots(1, 2, figsize=(10, 5))
        ax_img.imshow(input_images[idx])
        ax_img.set_title('Input Image')
        ax_lbl.imshow(label_class, cmap='gray')
        ax_lbl.set_title('Label')
        plt.show()
        unique_classes = np.unique(label_class)
        print(f"Unique classes represented in the label image: {unique_classes}")


# Pass the loaded arrays, not the file-name arrays.
plot_random_images(train_input, train_label, num_images=10)


### Calculate class weights to counter class imbalance


In [None]:
import numpy as np

num_classes = train_label.shape[-1]
# Pixels per class, summed over the sample, row and column axes of the one-hot labels.
class_counts = np.sum(train_label, axis=(0, 1, 2))
print('Pixel per class:')
print(class_counts)

# Inverse-frequency weights, normalised so the rarest present class gets 1.
# Classes with no pixels at all (e.g. ids left empty by the 5S remapping) get
# weight 0. Otherwise 1/0 = inf would turn every weight into nan after normalising.
present = class_counts > 0
if not present.any():
    raise ValueError('train_label contains no labelled pixels')
class_weights = np.zeros(class_counts.shape, dtype=np.result_type(1., class_counts))
class_weights[present] = 1. / class_counts[present]
class_weights = class_weights / np.max(class_weights)
print('Class weights for Dice Loss')
print(class_weights)

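# 02 Model selection

Run only **one** of the three model cells below (U-Net, Attention U-Net, UNet 3+). Each of them assigns `model`, so the last one executed is the one that gets compiled and trained.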

In [None]:
## Option 1 of 3: U-Net with a frozen, ImageNet-pretrained VGG16 backbone.
model = models.unet_2d(
    (img_SIZE, img_SIZE, 3),
    filter_num=[64, 128, 256, 512, 1024],
    n_labels=num_CLASSES,
    stack_num_down=2,
    stack_num_up=2,
    activation='ReLU',
    output_activation='Softmax',
    batch_norm=False,
    pool=True,
    unpool=True,
    backbone='VGG16',
    weights='imagenet',
    freeze_backbone=True,
    freeze_batch_norm=True,
    name='unet',
)

In [None]:
## Option 2 of 3: Attention U-Net with a frozen, ImageNet-pretrained VGG16 backbone.
model = models.att_unet_2d(
    (img_SIZE, img_SIZE, 3),
    filter_num=[64, 128, 256, 512, 1024],
    n_labels=num_CLASSES,
    stack_num_down=2,
    stack_num_up=2,
    activation='ReLU',
    atten_activation='ReLU',
    attention='add',
    output_activation='Softmax',
    batch_norm=False,
    pool=True,
    unpool=True,
    backbone='VGG16',
    weights='imagenet',
    freeze_backbone=True,
    freeze_batch_norm=True,
    name='attunet',
)

In [None]:
## Option 3 of 3: UNet 3+ with a frozen, ImageNet-pretrained VGG16 backbone.
model = models.unet_3plus_2d(
    (img_SIZE, img_SIZE, 3),
    n_labels=num_CLASSES,
    filter_num_down=[64, 128, 256, 512, 1024],
    filter_num_skip='auto',
    filter_num_aggregate='auto',
    stack_num_down=2,
    stack_num_up=2,
    activation='ReLU',
    output_activation='Softmax',
    batch_norm=False,
    pool=True,
    unpool=True,
    deep_supervision=False,
    backbone='VGG16',
    weights='imagenet',
    freeze_backbone=True,
    freeze_batch_norm=True,
    name='unet3plus',
)



### Compile the model

In [None]:
# Class-weighted Dice loss is the whole training objective, so both names refer to the same object.
dice_loss = total_loss = sm.losses.DiceLoss(class_weights=class_weights)

In [None]:
# IoU metric; Keras logs it as `iou_score` / `val_iou_score` (monitored by the checkpoints below).
metrics = [sm.metrics.IOUScore(threshold=0.5)]

In [None]:
# Adam at 2.5e-4 is the stage-1 learning rate; later training stages lower it.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=2.5e-4),
    loss=total_loss,
    metrics=metrics,
)

# 03 Train


In [None]:
# Three-stage schedule. Stage 1 trains on the arrays loaded above (raw + augmented
# in 'abt' mode). Stages 2 and 3 fine-tune on the raw images only, each with a
# 10x smaller learning rate. Every stage keeps its epoch with the best val_iou_score.
N_batch = 8
STAGE1_CHECKPOINT = '/content/tmp1/'
FINETUNE_CHECKPOINT = '/content/tmp/'


def report_validation_iou():
    """Predict on the validation set and print the IoU loss, overall and per class.

    Values come from ``losses.iou_seg`` and are reported as an IoU *loss*.
    Returns ``(y_pred, val_iou, iou_per_class)``.
    """
    predictions = model.predict(valid_input)
    overall = float(np.mean(losses.iou_seg(valid_label, predictions)))
    per_class = [
        float(np.mean(losses.iou_seg(valid_label[..., c], predictions[..., c])))
        for c in range(num_CLASSES)
    ]
    print('IoU loss (all classes):', overall)
    print('IoU loss per class:', per_class)
    return predictions, overall, per_class


def train_stage(x, y, learning_rate, n_epoch, checkpoint_dir, initial_best=None):
    """Fit ``model`` on (x, y) and reload the weights with the best ``val_iou_score``.

    The learning rate is set explicitly, so re-running the cell repeats the
    same schedule instead of continuing from whatever rate was left over.

    Args:
        initial_best: if given, a checkpoint is only written when
            ``val_iou_score`` beats this value. Use it when ``checkpoint_dir``
            already holds a better model from an earlier stage; otherwise the
            first epoch of this stage would overwrite it.
            NOTE(review): needs a TF/Keras version whose ModelCheckpoint
            supports ``initial_value_threshold``.
    """
    K.set_value(model.optimizer.learning_rate, learning_rate)
    checkpoint_kwargs = {}
    if initial_best is not None:
        checkpoint_kwargs['initial_value_threshold'] = initial_best
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_dir,
        save_weights_only=True,
        monitor='val_iou_score',
        mode='max',
        save_best_only=True,
        **checkpoint_kwargs)
    history = model.fit(
        x=x,
        y=y,
        batch_size=N_batch,
        epochs=n_epoch,
        shuffle=True,
        validation_data=(valid_input, valid_label),
        callbacks=[checkpoint])
    # Load the best performing weights of this stage (or of an earlier stage
    # if this one never improved on initial_best).
    model.load_weights(checkpoint_dir)
    return history


# Stage 1: training data as loaded above.
history1 = train_stage(train_input, train_label, 0.00025, 64, STAGE1_CHECKPOINT)
y_pred, val_iou, iou_per_class = report_validation_iou()

# Stages 2 and 3: fine-tune on the raw training images only.
train_input = input_data_process(utils.image_to_array(
                                                    train_input_names_raw,
                                                    size=img_SIZE, channel=3))
train_label = target_data_process(utils.image_to_array(
                                                    train_label_names_raw,
                                                    size=img_SIZE, channel=1))

history2 = train_stage(train_input, train_label, 0.000025, 32, FINETUNE_CHECKPOINT)
y_pred, val_iou, iou_per_class = report_validation_iou()

# Stage 3 shares stage 2's checkpoint directory, so it must only save when it
# beats stage 2's best epoch.
history3 = train_stage(train_input, train_label, 0.0000025, 16, FINETUNE_CHECKPOINT,
                       initial_best=max(history2.history['val_iou_score']))
y_pred, val_iou, iou_per_class = report_validation_iou()

In [None]:
# Stage-1 learning curves.
plot_metrics(history1, ['loss', 'iou_score'])

In [None]:
# Stage-2 learning curves.
plot_metrics(history2, ['loss', 'iou_score'])

In [None]:
# Stage-3 learning curves.
plot_metrics(history3, ['loss', 'iou_score'])

# 04 Evaluate

In [None]:
# Loss and compiled metrics on the held-out test set.
test_metrics = model.evaluate(test_input, test_label, batch_size=N_batch)

print("\nEvaluation results:")
for metric_name, metric_value in zip(model.metrics_names, test_metrics):
    print("{}: {}".format(metric_name, metric_value))

In [None]:

# IoU loss on the test set, overall and per class. y_pred is reused by the next cell.
y_pred = model.predict(test_input)
test_iou = float(np.mean(losses.iou_seg(test_label, y_pred)))
iou_per_class = [
    float(np.mean(losses.iou_seg(test_label[..., c], y_pred[..., c])))
    for c in range(num_CLASSES)
]
print('IoU loss test (all classes):', test_iou)
print('IoU loss per class test:', iou_per_class)

In [None]:

# Export each test image with its predicted (_dt) and ground-truth (_gt) label
# masks into results/; the COCO evaluation cells read them back from there.
results_dir = 'results'
# Make sure the output folder exists before anything is written into it.
os.makedirs(results_dir, exist_ok=True)

prediction = y_pred
for i in range(len(test_input)):
    # Collapse the class axis into single-channel label masks. Class ids are
    # < num_CLASSES, so they fit in uint8 for 8-bit PNG output.
    pred_mask = np.argmax(prediction[i], axis=-1).astype(np.uint8)
    gt_mask = np.argmax(test_label[i], axis=-1).astype(np.uint8)
    plt.imsave(os.path.join(results_dir, f'{i}_.jpg'), test_input[i])
    cv2.imwrite(os.path.join(results_dir, f'{i}_dt.png'), pred_mask)
    cv2.imwrite(os.path.join(results_dir, f'{i}_gt.png'), gt_mask)

Use the next (commented-out) cell instead if you are using VGG normalisation.

In [None]:
### use this one if you are using VGG normalization


# # Predict the segmentation for this sample
# prediction = model.predict(test_input)
# test_input_vis = input_data_process_vis(utils.image_to_array(test_input_names, size=img_SIZE, channel=3))
# for i in range(len(test_input_vis)):
#   # Convert the prediction to a single-channel segmentation mask
#   pred_mask = np.argmax(prediction[i], axis=-1)
#   # Convert the ground truth to a single-channel segmentation mask
#   gt_mask = np.argmax(test_label[i], axis=-1)
#   plt.imsave('results/'+str(i) +'_.jpg',test_input_vis[i])
#   cv2.imwrite('results/'+str(i) +'_dt.png', pred_mask) #pred_mask
#   cv2.imwrite('results/'+str(i) +'_gt.png', gt_mask) #pred_mask

### Preparing for MS-COCO evaluation
Since the COCO library needs a specific annotation format in order to evaluate,
this is a "poor man's implementation" for producing it.

In [None]:

# Ground truth (*_gt.png) and predictions (*_dt.png) both sit next to the
# exported test images in results/.
image_directory_GT = 'results'
mask_directory_GT = 'results'
output_file_GT = 'GT_anno.json'
change_diff_GT = 'gt.png'


image_directory_DT = 'results'
mask_directory_DT = 'results'
output_file_DT = 'results_GT_anno.json'
change_diff_DT = 'dt.png'

# Ground truth: a complete COCO dataset file (images, annotations, categories).
images, annotations = make_coco_eval_data(image_directory_GT, mask_directory_GT, change_diff_GT)
with open(output_file_GT, 'w') as f:
    json.dump({
        "images": images,
        "annotations": annotations,
        "categories": [{"id": i, "name": str(i)} for i in range(0, 10)],
    }, f)

# Detections: same conversion applied to the predicted masks.
images, annotations2 = make_coco_eval_data(image_directory_DT, mask_directory_DT, change_diff_DT)
with open(output_file_DT, 'w') as f:
    json.dump({"annotations": annotations2}, f)

# COCO.loadRes expects a bare JSON list of annotations. Dump the list directly
# instead of slicing the {"annotations": ...} wrapper off the text above, which
# relied on json.dump's exact spacing and on an absolute /content path.
with open('new_results_GT_anno.json', 'w') as f:
    json.dump(annotations2, f)

In [None]:
# Same relative paths the previous cell wrote to, so this works from any working directory.
cocoGt = COCO('GT_anno.json')
cocoDt = cocoGt.loadRes('new_results_GT_anno.json')

In [None]:
# Run pycocotools' segmentation ('segm') evaluation over all images and print the summary.
cocoEval = COCOeval(cocoGt, cocoDt, iouType='segm')
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()