In [None]:
# Execute if running on Colab
try:
    from google.colab import drive
except ImportError:
    # google.colab only exists inside a Colab runtime; elsewhere there is no
    # Drive to mount, so skip the step instead of aborting "Run All".
    print('google.colab not available - skipping Google Drive mount')
else:
    drive.mount('/content/drive')

In [2]:
import os
import cv2
import pickle
import pathlib
import numpy as np
import tensorflow as tf
from sklearn import utils
from tensorflow import keras
import matplotlib.pyplot as plt
from numpy.typing import NDArray
import xml.etree.ElementTree as ET
from warnings import filterwarnings
from sklearn.metrics import classification_report, ConfusionMatrixDisplay

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # Suppress Tensorflow messages
filterwarnings("ignore")
%matplotlib inline

# ##################################
# Global constants - amend as needed
# ##################################

# Set Root Directory...
# Google Drive (mounted by the first cell when running on Colab)
BASE_FOLDER = '/content/drive/My Drive/A.I. Group 6 Project'

# Dataset layout, all derived from BASE_FOLDER
DATASET_FOLDER = f'{BASE_FOLDER}/Datasets/CTC_CCA_dataset'
MODEL_FOLDER = f'{BASE_FOLDER}/Models'
# Annotation (*.xml) files, split into train/validation and test
ANNOTATION_FOLDER = f'{DATASET_FOLDER}/annotations'
ANNOTATION_TRAIN_FOLDER = f'{ANNOTATION_FOLDER}/trainval'
ANNOTATION_TEST_FOLDER = f'{ANNOTATION_FOLDER}/test'
# Full images (<image id>.tiff); read by get_bw_image / get_colour_image
IMG_FOLDER = f'{DATASET_FOLDER}/raw_images_for_model'                            # 1207 images - we use only 1087
IMG_BF_FOLDER = f'{IMG_FOLDER}/brightfield'
IMG_FLR_FOLDER = f'{IMG_FOLDER}/fluorescence'
# Cropped single-cell images, one sub-folder per cell class
CELL_IMG_FOLDER = f'{DATASET_FOLDER}/cell_images'
CELL_IMG_TRAIN_FOLDER =f'{CELL_IMG_FOLDER}/trainval/fluorescence'
CELL_IMG_TEST_FOLDER =f'{CELL_IMG_FOLDER}/test/fluorescence'

# Model input side length (images are resized to IMG_RESIZE x IMG_RESIZE)
# and the training hyper-parameters handed to model.fit
IMG_RESIZE = 224
VALIDATION_SPLIT = 0.15
EPOCHS=10
BATCH_SIZE = 32

# set constant to zero if want to process all the images
IMG_TRAIN_CNT = 0               # includes validation
IMG_TEST_CNT = 0
CELL_CLS_TRAIN_CNT = 3000       # per class, includes validation
CELL_CLS_TEST_CNT = 600         # per class

# cell model - transfer model trained with cell images
# segmentation model - weights and results using transfer models
# Flags: set to True to force regeneration of the cell crops / retraining of
# the respective model even when cached files already exist
CREATE_CELL_IMAGES = False
RETRAIN_CELL_MODEL = False
RETRAIN_IMG_MODEL = False
# Locations of the cached weights (*.weights.h5) and pickled metrics (*.pkl)
model_info = {
    'cell_train_weights'         : f'{MODEL_FOLDER}/cell_train.weights.h5',
    'cell_train_results'         : f'{MODEL_FOLDER}/cell_train.results.pkl',
    'img_train_nontuned_weights' : f'{MODEL_FOLDER}/img_train_nontuned.weights.h5',
    'img_train_nontuned_results' : f'{MODEL_FOLDER}/img_train_nontuned.results.pkl',
    'img_train_tuned_weights'    : f'{MODEL_FOLDER}/img_train_tuned.weights.h5',
    'img_train_tuned_results'    : f'{MODEL_FOLDER}/img_train_tuned.results.pkl',
    'img_test_nontuned_results'  : f'{MODEL_FOLDER}/img_test_nontuned.results.pkl',
    'img_test_tuned_results'     : f'{MODEL_FOLDER}/img_test_tuned.results.pkl'
}

## Transfer Learning and Feature Extraction

A key advantage of this approach is that you only run the base model once on your data, rather than once per epoch of training; it's a lot faster & cheaper.

<ul>
<li>Instantiate a base model and load pre-trained weights into it.</li>
<li>Run the cell image dataset through it and extract the features from an output layer; preferably the last convolution layer.</li>
<li>Use that output as input data for the segmentation model.</li>
</ul>

#### Create cleansed cell images from bounding boxes

In [3]:
def get_annotation_files(path: str, max_files:int=0) -> list[pathlib.Path]:
    ''' Generates the paths to annotation files

    Args:
        path (str): the base folder containing the annotation (xml) files
        max_files (int): the length of the returned annotations - default (0) returns all the annotation files

    Returns:
        A list of annotation paths, sorted by path

    Raises:
        UserWarning: when the folder contains no *.xml file
    '''
    # glob() yields entries in arbitrary file-system order; sorting makes the
    # `max_files` subset (and the order of every derived array) reproducible.
    annotation_files = sorted(pathlib.Path(path).glob('*.xml'))
    if not annotation_files:
        raise UserWarning("No annotation files found - exiting process")
    print(f"There are {len(annotation_files):,} annotations in the XML folder")

    if max_files > 0:
        annotation_files = annotation_files[:max_files]
        print(f'User process limt set to {max_files} files')

    return annotation_files

def get_cell_cls_names(cell_dir: str) -> list[str]:
    ''' Creates a list of cell names based on the folder names

    Args:
        cell_dir (str): the parent folder containing R, G, U cell folders

    Returns:
        A list of cell folder names, sorted alphabetically
    '''
    # os.scandir yields entries in arbitrary order; sort so that the
    # name -> index mapping is the same on every run and machine. The context
    # manager closes the directory handle as soon as the listing is done.
    with os.scandir(cell_dir) as entries:
        return sorted(entry.name for entry in entries if entry.is_dir())

def parse_xml(xml_file: str, cell_cls_names: list[str]) -> tuple[NDArray, tuple[int, int]]:
    ''' Extracts the bounding box data from an annotation (xml) file

    Args:
         xml_file (str): a string path to an annotation file
         cell_cls_names list[str]: a list of the cell folder names to map the cell class to int values

    Returns:
        A tuple of
          * an integer array of shape (n_boxes, 5) whose rows are
            [x_min, y_min, x_max, y_max, class_index]; shape (0, 5) when the
            file has no <object> element
          * the image dimension tuple (width, height) read from <size>

    Raises:
        ValueError: when an object's <name> is not in `cell_cls_names`
    '''
    root = ET.parse(xml_file).getroot()

    # retrieve image dimensions
    img_width = int(root.find('.//size/width').text) # type: ignore
    img_height = int(root.find('.//size/height').text) # type: ignore

    def _clamp(value: int, upper: int) -> int:
        # Keep a coordinate inside [0, upper]. A negative value would otherwise
        # be interpreted by numpy as "count from the far edge" when the box is
        # later used for slicing/indexing.
        return max(0, min(value, upper))

    # bounding boxes - retrieve label and co-ordinates
    cell_boxes = []
    for elem in root.findall('.//object'):
        cell_cls_idx = cell_cls_names.index(elem.find('name').text)

        x_min = _clamp(int(elem.find('./bndbox/xmin').text), img_width)
        y_min = _clamp(int(elem.find('./bndbox/ymin').text), img_height)
        x_max = _clamp(int(elem.find('./bndbox/xmax').text), img_width)
        y_max = _clamp(int(elem.find('./bndbox/ymax').text), img_height)

        cell_boxes.append([x_min, y_min, x_max, y_max, cell_cls_idx])

    # reshape keeps a well-formed (0, 5) integer array for files without objects
    return np.array(cell_boxes, dtype=int).reshape(-1, 5), (img_width, img_height)

#
# Important: loaded images - shape is height, width, optionally channel(s)
#
def get_colour_image(xml_file: str) -> tuple[NDArray, str]:
    '''Reads a corresponding fluorescence image from an annotation file

    The image id is the annotation file name without its extension; the image
    is looked up as `<IMG_FLR_FOLDER>/<image id>.tiff`.

    Args:
        xml_file: the annotation file path

    Returns:
        a numpy array representation of the respective image (RGB channel
        order) and the image id

    Raises:
        FileNotFoundError: when the image is missing or cannot be decoded
    '''
    img_id = os.path.splitext(os.path.basename(xml_file))[0]
    img_file = f'{IMG_FLR_FOLDER}/{img_id}.tiff'

    # cv2.imread signals failure by returning None rather than raising, which
    # would otherwise surface as an obscure cvtColor assertion error.
    bgr_img = cv2.imread(img_file)
    if bgr_img is None:
        raise FileNotFoundError(f'Fluorescence image not found or unreadable: {img_file}')

    return cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB), img_id

def generate_cell_images(annotation_files: list[pathlib.Path], cell_dir: str, cell_resize: tuple[int, int]) -> None:
    """
    Crops every annotated cell out of its fluorescence image and writes the
    crop to `<cell_dir>/<class>/` as a .tiff file.

    A crop is only written when its brightest channel agrees with the annotated
    class (R, G or U) and is brighter than a small threshold; this filters out
    mislabelled and empty boxes.

    Args:
        annotation_files list(Path): Paths to the XML annotation files.
        cell_dir (str): the parent folder containing R, G, U cell folders
        cell_resize tuple(int, int): The size of the written cell image.

    Returns:
        none
    """
    cell_cls_names = get_cell_cls_names(cell_dir)
    rgb_cls_names = ['R','G','U']   # channel index -> class folder name
    c_threshold = 5                 # minimum peak intensity for a usable crop

    for xml_file in annotation_files:
        cell_boxes, img_size = parse_xml(str(xml_file), cell_cls_names)

        c_img, c_img_name = get_colour_image(str(xml_file))
        c_img = cv2.resize(c_img, img_size)

        # running index of the crops written for this image (part of the file name)
        cell_cnt = 0

        for box in cell_boxes:
            x_min, y_min, x_max, y_max, box_cls = (int(v) for v in box)
            # a negative start would make numpy slice from the far edge
            x_min, y_min = max(x_min, 0), max(y_min, 0)
            cls_dir_name = cell_cls_names[box_cls]

            # image arrays are indexed (row, column) -> (y, x)
            cell_img = c_img[y_min:y_max, x_min:x_max]
            if cell_img.size == 0:
                # degenerate (zero-area) box: nothing to crop, np.max would raise
                continue

            # brightest value per channel, the channel that wins and its value
            channel_max = cell_img.max(axis=(0, 1))
            c_pixel_idx = int(np.argmax(channel_max))
            c_pixel_val = int(channel_max[c_pixel_idx])
            c_pixel_cls = cell_cls_names.index(rgb_cls_names[c_pixel_idx])

            # ensure the corresponding channel(color) is the most prominent in the cell region
            if c_pixel_val > c_threshold and c_pixel_cls == box_cls:
                # cv2.imwrite expects BGR channel order
                cell_img = cv2.cvtColor(cv2.resize(cell_img, cell_resize), cv2.COLOR_RGB2BGR)
                fname = f'{c_img_name}_{cell_cnt}_{cls_dir_name}_{x_min}_{y_min}_{x_max}_{y_max}'
                cv2.imwrite(f"{cell_dir}/{cls_dir_name}/{fname}.tiff", cell_img)
                cell_cnt += 1

In [4]:
if CREATE_CELL_IMAGES:
    print('Creating new cell images from fluorescent full images...')
    # (annotation folder, output folder) for the train/validation and test splits
    cell_image_jobs = (
        (ANNOTATION_TRAIN_FOLDER, CELL_IMG_TRAIN_FOLDER),
        (ANNOTATION_TEST_FOLDER, CELL_IMG_TEST_FOLDER),
    )
    for annotation_dir, cell_img_dir in cell_image_jobs:
        files = get_annotation_files(annotation_dir)
        generate_cell_images(files, cell_img_dir, (IMG_RESIZE, IMG_RESIZE))

#### Process new cell images

---

In [5]:
def get_cell_images(cell_dir: str, img_resize: tuple[int, int], channels:int=3, max_imgs:int=0) -> tuple[NDArray, NDArray, list[str]]:
    ''' Creates the input data (x) and corresponding labels (y) from a folder

    Every sub-folder of `cell_dir` is one class. Sub-folders and image files
    are visited in sorted order so that the class -> label mapping and the
    `max_imgs` subset are identical on every run (directory listings come back
    in arbitrary order otherwise).

    Args:
        cell_dir (str): the parent folder containing R, G, U cell folders
        img_resize (int, int): the cell image resized dimension
        channels int: the number of channels in the generated images
        max_imgs int: maximum number of images taken per class - default (0) takes all

    Returns:
        a tuple of input cell images, corresponding labels (shape (n, 1),
        uint8) and the class names; images and labels are shuffled together

    Raises:
        ValueError: when an image file cannot be read
    '''
    with os.scandir(cell_dir) as entries:
        cls_dirs = sorted((entry for entry in entries if entry.is_dir()), key=lambda entry: entry.name)

    cell_cls_names = []
    x = []
    y = []
    for cls_idx, sub_dir in enumerate(cls_dirs):
        print(f'Processing {sub_dir.name} folder...')
        cell_cls_names.append(sub_dir.name)
        img_files = sorted(pathlib.Path(sub_dir.path).glob('*.tiff'))
        if max_imgs > 0:
            img_files = img_files[:max_imgs]

        for img_file in img_files:
            # cv2.imread returns None (no exception) when a file cannot be decoded
            img = cv2.imread(str(img_file), cv2.IMREAD_GRAYSCALE)
            if img is None:
                raise ValueError(f'Unable to read cell image: {img_file}')
            img = cv2.resize(img, img_resize)
            # replicate the single grayscale plane into `channels` planes
            x.append(np.stack((img,)*channels, axis=-1))
            y.append(cls_idx)

    print(f'Cell Classes: {cell_cls_names}')

    x = np.array(x)
    y = np.array(y, dtype=np.uint8).reshape(-1, 1)

    # randomise images/labels
    x, y = utils.shuffle(x, y) # type: ignore

    return x, y, cell_cls_names

#### Build Transfer Model using pre-trained weights

In [None]:
def load_weights(model, file, desc):
    ''' Restores previously saved weights into a model

    Args:
        model: a keras template model whose architecture matches the weights
        file: the corresponding file with the model's weights
        desc: a simple text description used in the console messages

    Returns:
        the same model instance, now carrying the stored weights

    Raises:
        Exception: whatever `model.load_weights` raised, after reporting it
    '''
    try:
        model.load_weights(file)
    except Exception as err:
        print(f'Error loading {desc} weights:', err)
        raise
    else:
        print(f'{desc} weights loaded successfully.')
    return model

def load_results(file, desc):
    ''' Reads pickled metrics back from disk

    Args:
        file: the corresponding file with the model's metrics
        desc: a simple text description used in the console messages

    Returns:
        the unpickled metrics object

    Raises:
        Exception: whatever opening or unpickling the file raised, after reporting it
    '''
    try:
        with open(file , 'rb') as handle:
            results = pickle.load(handle)
    except Exception as err:
        print(f'error loading {desc} results:', err)
        raise
    else:
        print(f'{desc} results loaded successfully.')
    return results

def save_results(results, file, desc):
    ''' Pickles a model's metrics for future use instead of retraining it

    Args:
        results: the models's results
        file: the path to save the results
        desc: a simple text description used in the console messages

    Raises:
        Exception: whatever opening or pickling raised, after reporting it
    '''
    try:
        with open(file, 'wb') as f:
            pickle.dump(results, f)
    except Exception as e:
        # this is the SAVE path: report it as such
        print(f'Error saving {desc} results:', e)
        raise
    # reported only once the file has been written and closed
    print(f'{desc} results saved successfully.')

def build_cell_model(input_dims: tuple[int, int, int], n_class):
    ''' Creates the cell classification (transfer) model

    A frozen, ImageNet-initialised VGG16 convolutional base is followed by a
    trainable 3x3 convolution, global average pooling and a softmax layer.

    NOTE: the input is now really passed through VGG16's `preprocess_input`
    (previously the pre-processed tensor was computed and then discarded, so
    the base received raw pixel values). Weights cached from the earlier
    behaviour were trained on raw inputs - retrain (RETRAIN_CELL_MODEL = True)
    before relying on them.

    Args:
        input_dims: the shape of the input (height, width, channels)
        n_class: the number of the model's output classes

    Returns:
        A compiled cell model
    '''
    inputs = keras.layers.Input(shape=input_dims, name='input')
    x = keras.applications.vgg16.preprocess_input(inputs)

    base_model = keras.applications.VGG16 (
        weights='imagenet',
        include_top=False,
        input_shape=input_dims
        )
    base_model.trainable = False
    # feed the PRE-PROCESSED tensor, not the raw input, into the base
    x = base_model(x)

    # replace fully connected layers
    x = keras.layers.Conv2D(filters=1024, kernel_size=3, padding="same", activation='relu', kernel_initializer="he_normal")(x)
    x = keras.layers.GlobalAveragePooling2D()(x)

    output = keras.layers.Dense(n_class, activation='softmax', name='output')(x)

    model = keras.Model(inputs=inputs, outputs=output,  name='cell_model')
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    return model

# The cache is only usable when BOTH the weights and the metrics file exist;
# otherwise loading the metrics below would fail on a half-written cache.
cell_cache_ready = (
    os.path.exists(model_info['cell_train_weights'])
    and os.path.exists(model_info['cell_train_results'])
)

# The test set is needed whether we train or load
x_cell_test, y_cell_test, cell_cls_names = get_cell_images(CELL_IMG_TEST_FOLDER, (IMG_RESIZE, IMG_RESIZE), max_imgs=CELL_CLS_TEST_CNT)
print(f'Test data: X={x_cell_test.shape} Y={y_cell_test.shape}')

# build cell model
input_dims = (IMG_RESIZE, IMG_RESIZE, 3)
cell_model = build_cell_model(input_dims, len(cell_cls_names))

if RETRAIN_CELL_MODEL or not cell_cache_ready:
    # Only load the training set if training
    x_cell_train, y_cell_train, train_cls_names = get_cell_images(CELL_IMG_TRAIN_FOLDER, (IMG_RESIZE, IMG_RESIZE), max_imgs=CELL_CLS_TRAIN_CNT)
    print(f'Training data: X={x_cell_train.shape} Y={y_cell_train.shape}')

    # labels are positions in the class list, so both splits must agree on it
    if train_cls_names != cell_cls_names:
        raise ValueError(f'Class mismatch between train {train_cls_names} and test {cell_cls_names} folders')

    # train model - save results and weights
    train_results = cell_model.fit(
        x=x_cell_train,
        y=y_cell_train,
        validation_split=VALIDATION_SPLIT,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS
        )

    cell_model.save_weights(model_info['cell_train_weights'])
    train_metrics = train_results.history
    save_results(train_metrics, model_info['cell_train_results'],  'cell training')

    cell_model.summary(show_trainable=True)
else:
    cell_model = load_weights(cell_model, model_info['cell_train_weights'], 'cell training')
    train_metrics = load_results(model_info['cell_train_results'], 'training')

#### Performance Metrics for cell model

In [None]:
# Display training metrics
print(f"Best Train accuracy: {max(train_metrics['accuracy']):.3f}")
print(f"Best Train loss: {min(train_metrics['loss']):.3f}")
print(f"Best validation accuarcy: {max(train_metrics['val_accuracy']):.3f}")
print(f"Best validation loss: {min(train_metrics['val_loss']):.3f}")

fig, ax = plt.subplots(1, 2, figsize=(9, 6), sharex=True, sharey=False)
ax[0].plot(train_metrics['loss'], label='Train Loss')
ax[0].plot(train_metrics['val_loss'], label='Validation Loss')
ax[0].set(xlabel='Epoch', ylabel='Loss')
ax[1].plot(train_metrics['accuracy'], label='Train Accuracy')
ax[1].plot(train_metrics['val_accuracy'], label='Validation Accuracy')
ax[1].set(xlabel='Epoch', ylabel='Accuracy')

# plt.legend() only decorates the current (last) axes; give each panel its
# own legend built from the line labels set above.
for panel in ax:
    panel.legend()
plt.tight_layout()
plt.show()

In [None]:
# class probabilities -> index of the most likely class per test image
y_cell_pred = np.argmax(cell_model.predict(x_cell_test), axis=1)

# per-class precision / recall / f1
labels = cell_cls_names
labels_idx = list(range(len(labels)))
print(classification_report(y_cell_test, y_cell_pred, digits=4, labels=labels_idx, target_names=labels))

# confusion matrix of true vs predicted classes
conf_matrix = ConfusionMatrixDisplay.from_predictions(y_cell_test, y_cell_pred, colorbar=False, display_labels=labels)
plt.show()

### Image Segmentation and Classification

#### Load training data and pre-process full images

In [8]:
def display_images(display_list: list[NDArray], title: list[str]) -> None:
    ''' Shows a list of images side by side in a single row

    Args:
        display_list list[NDArray]: the images to draw, left to right
        title list[str]: one caption per image, in the same order
    '''
    n_panels = len(display_list)
    plt.figure(figsize=(15, 15))

    for position, image in enumerate(display_list, start=1):
        plt.subplot(1, n_panels, position)
        plt.title(title[position - 1])
        plt.imshow(tf.keras.utils.array_to_img(image))
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Important: loaded images - shape is height, width, optionally channel(s)
def get_bw_image(xml_file: str) -> tuple[NDArray, str] :
    ''' Reads the brightfield image that belongs to an annotation file as grayscale

    The image id is the annotation file name without its extension; the image
    is looked up as `<IMG_BF_FOLDER>/<image id>.tiff`.

    Args:
        xml_file str: the annotation file path

    Returns:
        a numpy array representation of the respective image and the image id

    Raises:
        FileNotFoundError: when the image is missing or cannot be decoded
    '''
    img_id = os.path.splitext(os.path.basename(xml_file))[0]
    img_file = f'{IMG_BF_FOLDER}/{img_id}.tiff'

    # cv2.imread returns None instead of raising when the file cannot be read
    img = cv2.imread(img_file, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise FileNotFoundError(f'Brightfield image not found or unreadable: {img_file}')
    return img, img_id

def _clamp_image_size(img_size: tuple[int, int], limit: int = 540) -> tuple[int, int]:
    '''Caps an annotated (width, height) at `limit` pixels per side (corrects wrong metadata).'''
    return (min(img_size[0], limit), min(img_size[1], limit))


def _matching_pixels(c_img: NDArray, box, cell_cls_names: list[str], rgb_cls_names: list[str], c_threshold: int):
    '''Finds the pixels of a bounding box whose brightest channel agrees with the box class.

    Args:
        c_img: the fluorescence image, indexed (row, column, channel) in RGB order
        box: [x_min, y_min, x_max, y_max, class_index]
        cell_cls_names: class names; `class_index` is a position in this list
        rgb_cls_names: class name represented by channel 0, 1 and 2
        c_threshold: a pixel only counts when its brightest channel exceeds this value

    Returns:
        (rows, cols, channels): image coordinates of the matching pixels and,
        for each of them, the index of its brightest channel
    '''
    x_min, y_min, x_max, y_max, box_cls = (int(v) for v in box)
    height, width = c_img.shape[:2]

    # The box covers rows [y_min-1, y_max-1) and columns [x_min-1, x_max-1).
    # The start is floored at 0: a negative index would silently address the
    # opposite edge of the image. The end is capped at the image size.
    y0, y1 = max(y_min - 1, 0), max(min(y_max, height) - 1, 0)
    x0, x1 = max(x_min - 1, 0), max(min(x_max, width) - 1, 0)

    region = c_img[y0:y1, x0:x1, :3]
    if region.size == 0:
        nothing = np.empty(0, dtype=np.intp)
        return nothing, nothing, nothing

    dominant = region.argmax(axis=-1)   # brightest channel per pixel (first one wins ties)
    peak = region.max(axis=-1)          # ... and its intensity

    # class index represented by each colour channel; -1 (never equal to a
    # box class) when that colour has no class in `cell_cls_names`
    channel_cls = np.array([cell_cls_names.index(name) if name in cell_cls_names else -1
                            for name in rgb_cls_names])

    # assumes background is (near) zero, hence the intensity threshold
    hit = (channel_cls[dominant] == box_cls) & (peak > c_threshold)
    rows, cols = np.nonzero(hit)
    return rows + y0, cols + x0, dominant[rows, cols]


def generate_rgb_mask(annotation_files: list[pathlib.Path], cell_cls_names: list[str], img_resize: tuple[int, int]) -> NDArray:
    '''Creates an RGB based mask from annotation bounding boxes

    Inside every bounding box, a pixel whose brightest fluorescence channel
    matches the class of the box gets the value 255 in that channel; all other
    values stay 0.

    Args:
        annotation_files list[Path]: a list of annotation file paths
        cell_cls_names list[str]: a list of the cell folder names to map the cell class to int values
        img_resize (int, int): the mask resized dimension

    Returns:
        An array of masks, one (height, width, 3) uint8 mask per annotation file
    '''
    masks = []
    c_threshold = 1
    rgb_cls_names = ['R','G','U']

    for xml_file in annotation_files:
        # get the annotation data
        cell_boxes, img_size = parse_xml(str(xml_file), cell_cls_names)

        # correct wrong metadata
        img_size = _clamp_image_size(img_size)

        c_img,_ = get_colour_image(str(xml_file))
        c_img = cv2.resize(c_img, img_size)

        # note: image format is height x width
        mask = np.zeros((c_img.shape[0], c_img.shape[1], 3), dtype=np.uint8)

        for box in cell_boxes:
            rows, cols, channels = _matching_pixels(c_img, box, cell_cls_names, rgb_cls_names, c_threshold)
            mask[rows, cols, channels] = 255

        # nearest-neighbour keeps the mask binary (no blended edge values)
        masks.append(cv2.resize(mask, img_resize, interpolation=cv2.INTER_NEAREST))

    return  np.array(masks)

def generate_dataset(annotation_files: list[pathlib.Path], cell_cls_names: list[str], img_resize: tuple[int, int]) -> tuple[NDArray, NDArray, NDArray]:
    '''Creates the segmentation dataset from annotation files

    For every annotation file the brightfield image becomes the model input,
    and a label mask is derived from the fluorescence image: inside every
    bounding box, a pixel whose brightest channel matches the class of the box
    is labelled with that class index; every other pixel keeps label 0.

    Args:
        annotation_files list[Path]: a list of annotation file paths
        cell_cls_names list[str]: a list of the cell folder names to map the cell class to int values
        img_resize (int, int): the resized dimension of images and masks

    Returns:
        A tuple of grayscale images replicated to 3 channels, corresponding
        single-channel label masks and the fluorescence (colour) images
    '''
    bw_imgs = []
    c_imgs=[]
    masks = []
    c_threshold = 1
    rgb_cls_names = ['R','G','U']

    for xml_file in annotation_files:
        # get the annotation data
        cell_boxes, img_size = parse_xml(str(xml_file), cell_cls_names)

        # correct wrong metadata
        img_size = _clamp_image_size(img_size)

        # get associated brightfield/fluorescent images
        # resize to annotated size
        bw_img,_ = get_bw_image(str(xml_file))
        bw_img = cv2.resize(bw_img, img_size)

        c_img,_ = get_colour_image(str(xml_file))
        c_img = cv2.resize(c_img, img_size)

        # assign labels to each pixel
        # note: image format is height x width
        mask = np.zeros((bw_img.shape[0], bw_img.shape[1]), dtype=np.uint8)

        for box in cell_boxes:
            rows, cols, _ = _matching_pixels(c_img, box, cell_cls_names, rgb_cls_names, c_threshold)
            mask[rows, cols] = int(box[4])

        # resize and convert to 3 channel for vgg16
        bw_img = cv2.resize(bw_img, img_resize)
        bw_img = np.stack((bw_img,)*3, axis=-1)

        # Class ids are categorical: the default (bilinear) interpolation would
        # average neighbouring ids into labels of unrelated classes along every
        # boundary, so the label mask must be resized with nearest-neighbour.
        mask = cv2.resize(mask, img_resize, interpolation=cv2.INTER_NEAREST)
        mask = np.stack((mask,), axis=-1)

        # just resize colour image
        c_img = cv2.resize(c_img, img_resize)

        # save processed data
        bw_imgs.append(bw_img)
        masks.append(mask)
        c_imgs.append(c_img)

    return np.array(bw_imgs), np.array(masks), np.array(c_imgs)

In [None]:
# add background avoid negative bias
img_cls_names = ['Background'] + cell_cls_names
print(f'Image classes: {img_cls_names}')

# load training/validation dataset
# The number of files is controlled by IMG_TRAIN_CNT (0 = all); the list is
# passed on unsliced so that setting is actually honoured.
files = get_annotation_files(ANNOTATION_TRAIN_FOLDER, IMG_TRAIN_CNT)
x_img_train, y_img_train, x_img_train_colour = generate_dataset(files, img_cls_names, (IMG_RESIZE, IMG_RESIZE))
print(x_img_train.shape, y_img_train.shape, x_img_train_colour.shape)

# show the first training sample next to its masks
title = ['Input Image(Brightfield)', 'Input Image(Fluorescence)', 'True Mask', 'RGB Mask']
display_images([x_img_train[0], x_img_train_colour[0], y_img_train[0], generate_rgb_mask(files[:1], img_cls_names, (IMG_RESIZE, IMG_RESIZE))[0]], title)

# load test dataset
# generate_dataset iterates over a LIST of paths (files[1] is a single Path
# and cannot be iterated); the number of files is controlled by IMG_TEST_CNT.
files =  get_annotation_files(ANNOTATION_TEST_FOLDER, IMG_TEST_CNT)
x_img_test, y_img_test, x_img_test_colour = generate_dataset(files, img_cls_names, (IMG_RESIZE, IMG_RESIZE))
print(x_img_test.shape, y_img_test.shape, x_img_test_colour.shape)

img_data = {
    'x_train': x_img_train,
    'y_train': y_img_train,
    'x_test': x_img_test,
    'y_test': y_img_test,
    'x_train_colour': x_img_train_colour,
    'x_test_colour': x_img_test_colour
}

#### Train Segmentation Model

In [9]:
def build_img_model(input_dims: tuple[int, int], n_class: int, vanilla: bool):
    ''' Builds an image segmentation model

    The encoder re-uses the layers of the 'vgg16' sub-model of a freshly built
    cell model, followed by the Conv2D layer(s) found directly in that cell
    model. The decoder mirrors the encoder: one transpose convolution per
    max-pooling layer, each concatenated with the convolution output recorded
    just before that pooling layer.

    Reads the module-level `cell_cls_names` and `model_info`.

    Args: 
        input_dims tuple(int, int): the shape of the model's input
        n_class int: the number of the classes of the model's output layer
        vanilla bool: when True the cell-model weights file is NOT loaded, so the
            encoder keeps the weights build_cell_model gives it; when False the
            weights stored in model_info['cell_train_weights'] are loaded first
        
    Returns: 
        a compiled keras based image segmentation model (per-pixel softmax over
        n_class classes, sparse categorical cross-entropy loss, mean IoU metric)
    '''
    
    inputs = keras.layers.Input(shape=input_dims, name='input')
    x = keras.applications.vgg16.preprocess_input(inputs)

    # pre-processing steps
    # NOTE(review): these random flip/rotation layers only see the image tensor;
    # nothing in this function applies the same transform to the label masks
    # passed to fit() - confirm that inputs and masks stay aligned in training.
    x = keras.layers.RandomFlip('horizontal', name='horizontal_flip')(x)
    x = keras.layers.RandomRotation(0.1, name='rotation')(x)

    # retrieve layers from cell/transfer model
    # (the class count comes from the global cell_cls_names, not from n_class)
    cell_model = build_cell_model(input_dims, len(cell_cls_names))
    if not vanilla:
        cell_model = load_weights(cell_model, model_info['cell_train_weights'], 'Cell Model')
    base_model = cell_model.get_layer('vgg16')

    # Re-apply the base model's layers one by one (skipping its own input
    # layer) and remember, for every max-pooling layer, the output of the
    # convolution that preceded it together with its channel count. These
    # tensors are concatenated into the decoder further down.
    residuals=[]
    filters = []
    conv_layer =None
    for layer in base_model.layers[1:]:
        layer.trainable = False
        x = layer(x)
        if isinstance(layer, keras.layers.MaxPooling2D):
            residuals.append(conv_layer)
            filters.append(conv_layer.shape[3])
        if isinstance(layer, keras.layers.Conv2D):
            conv_layer = x

    # Apply the Conv2D layer(s) that sit directly in the cell model (i.e. not
    # inside the 'vgg16' sub-model); unlike the base layers they stay trainable.
    for layer in cell_model.layers:
        if isinstance(layer, keras.layers.Conv2D):
            layer.trainable = True
            x = layer(x)

    # add upsample layers - walk the recorded channel counts from the deepest
    # block back to the first one (note: `filter` shadows the Python builtin)
    for filter in reversed(filters):
        # Transpose convolution
        x = keras.layers.Conv2DTranspose(filter, kernel_size=2, strides=2, activation='relu', padding="same", kernel_initializer="he_normal")(x)
        # residuals.pop() returns the most recently recorded (deepest remaining) tensor
        x = keras.layers.concatenate([x, residuals.pop()])
        # Two convolutions
        x = keras.layers.Conv2D(filter, kernel_size=3, activation='relu', padding="same", kernel_initializer="he_normal")(x)
        x = keras.layers.Conv2D(filter, kernel_size=3, activation='relu', padding="same", kernel_initializer="he_normal")(x)
        # a third convolution for the wider (more than 128 channel) stages
        if filter > 128:
            x = keras.layers.Conv2D(filter, kernel_size=3, activation='relu', padding="same", kernel_initializer="he_normal")(x)

    # Output layer
    outputs = keras.layers.Conv2D(filters=n_class, kernel_size=1, activation="softmax", name='per-pixel_clsf')(x)

    # Create intersection over union metric
    # (labels are class ids, predictions are per-class probabilities)
    seg_iou = tf.keras.metrics.MeanIoU(num_classes=n_class, sparse_y_true=True, sparse_y_pred=False, name='iou')

    # Compile the model
    model = tf.keras.Model(inputs=inputs, outputs=outputs, name='segmentation_model')
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=[seg_iou])

    return model

def train_img_model(model, data, val_split, epochs, batch_size, w_file, r_file, desc):
    ''' A custom function to train the image segmentation model and save the results

    The best weights (lowest monitored loss) are written to `w_file` during
    training and loaded back into `model` afterwards, so the returned model is
    exactly the one a later run restores from the cache.

    Args:
        model: the segmentation model
        data: a dict holding the training arrays under 'x_train' and 'y_train'
        val_split: the fraction of the input used for validation
        epochs: the number of times the models trains with all the input data
        batch_size: the number of input data to pass to the model at a time
        w_file: the file path to save the best model weights
        r_file: the file path to save the training metrics
        desc: a simple text description used in the console messages

    Returns:
        the training history (dict of metric name -> list of per-epoch values)
    '''
    # Both callbacks must judge "best" by the same quantity. Validation loss is
    # used whenever a validation split exists; training loss is the fallback.
    monitor = 'val_loss' if val_split else 'loss'

    model_chkptr = tf.keras.callbacks.ModelCheckpoint(
            w_file,
            monitor=monitor,
            verbose=0,
            save_best_only=True,      # save trained weights (best values)
            save_weights_only=True,   #
            mode='min',
            save_freq='epoch')

    model_earlystp = tf.keras.callbacks.EarlyStopping(
        monitor=monitor,
        min_delta=0.001,
        patience=10,
        verbose=0,
        mode="min",
        restore_best_weights=True)

    # Fit the model (best weights saved)
    history = model.fit(
        x=data['x_train'],
        y=data['y_train'],
        epochs=epochs,
        validation_split=val_split,
        batch_size=batch_size,
        callbacks=[model_chkptr,model_earlystp],
        verbose=2
        )

    # Training may end on an epoch that is not the best one; make the in-memory
    # model identical to what has been persisted.
    if os.path.exists(w_file):
        model.load_weights(w_file)

    # Save results
    results = history.history
    save_results(results, r_file, desc)

    return results

def do_training(data, w_file, r_file, cls_names: list[str], vanilla:bool=False):
    ''' Builds the segmentation model and either trains it or restores it from cache

    Training is skipped only when RETRAIN_IMG_MODEL is off and both cache
    files already exist.

    Args:
        data: the dataset dict handed on to train_img_model
        w_file: the file path of the model weights (written when training, read otherwise)
        r_file: the file path of the training metrics (written when training, read otherwise)
        cls_names list(str): the segmentation class labels; their count sets the model's output classes
        vanilla bool:  passed to build_img_model - True builds the non-tuned variant

    Returns:
        a tuple of the segmentation model and its training metrics
    '''
    variant = 'non-tuned' if vanilla else 'tuned'
    desc = f'Segmentation training ({variant} transfer model)'

    img_model = build_img_model((IMG_RESIZE, IMG_RESIZE, 3), len(cls_names), vanilla) # type: ignore

    cache_complete = os.path.exists(w_file) and os.path.exists(r_file)
    if cache_complete and not RETRAIN_IMG_MODEL:
        # nothing to train: restore the weights and the stored metrics
        load_weights(img_model, w_file, desc)
        return img_model, load_results(r_file, desc)

    metrics = train_img_model(img_model, data, VALIDATION_SPLIT, EPOCHS, BATCH_SIZE, w_file, r_file, desc)
    return img_model, metrics

In [None]:
# perform segmentation using non-tuned transfer models
img_model1, train_m1 = do_training(
    img_data,
    w_file = model_info['img_train_nontuned_weights'],
    r_file = model_info['img_train_nontuned_results'],
    cls_names = img_cls_names,
    vanilla = True
    )
print(f'Segmentation Mode (non-tuned transfer model) - training metrics...\n{train_m1}\n')
img_model1.summary()

# perform segmentation using tuned transfer models
img_model2, train_m2 = do_training(
    img_data,
    w_file = model_info['img_train_tuned_weights'],
    r_file = model_info['img_train_tuned_results'],
    cls_names = img_cls_names
    )
# these are the metrics of the TUNED model - label them accordingly
print(f'Segmentation Mode (tuned transfer model) - training metrics...\n{train_m2}\n')

#### Performance Metrics for segmentation model

<p>
The pixel accuracy metric calculates the percentage of pixels that were correctly classified according to the segmentation mask. Unfortunately, with very sparse class representations the pixel accuracy will be biased towards negative cases. We have addressed this issue by assigning a class to the background as well as opting for mean intersection over union (IOU). For our ground truth and predicted segmentation masks, we count the number of pixels that overlap and divide by the area of their union (the total area of both masks minus the overlap); the value will be 0-1, and the higher the value, the better.
</p>
<p>&emsp;IOU = <u>number of overlapping pixels (intersection)</u><br/>&emsp;&emsp;&emsp;&emsp;True area + Predicted area − intersection</p>

The Dice coefficient leads to the same conclusion but generally not to the same value; we will use IOU.

In [None]:

def evaluate_model(model, data: dict[str, NDArray], r_file: str, r_desc: str, vanilla: bool=False):
    ''' Evaluates a model on the test set, or returns the cached evaluation

    The model is only evaluated when RETRAIN_IMG_MODEL is set or no results
    file exists yet; the fresh metrics are then pickled to `r_file`.

    Args:
        model: the compiled model to evaluate
        data: the dataset dict; the 'x_test' and 'y_test' entries are used
        r_file: the file path of the cached test metrics
        r_desc: a simple text description used in the console messages
        vanilla: not used by this function; kept for call compatibility

    Returns:
        the value returned by `model.evaluate`, or the unpickled cached value
    '''
    if not RETRAIN_IMG_MODEL and os.path.exists(r_file):
        return load_results(r_file, r_desc)

    metrics = model.evaluate(x=data['x_test'], y=data['y_test'])
    save_results(metrics, r_file, r_desc)
    return metrics

# plot loss and IOU metrics
def display_metrics(train_m, test_m, title):
    ''' Prints and plots the metrics of the trained/loaded model

    Args:
        train_m: training history with the keys 'loss', 'iou', 'val_loss' and 'val_iou'
        test_m: test metrics as a sequence [loss, iou]
        title: heading printed above the figures
    '''

    print(f'{title}...')
    print(f"Best Train loss: {min(train_m['loss']):.3f}")
    print(f"Best Train seg iou: {max(train_m['iou']):.3f}")
    print(f"Best validation loss: {min(train_m['val_loss']):.3f}")
    print(f"Best validation seg iou: {max(train_m['val_iou']):.3f}")
    print(f"Test loss: {test_m[0]:.3f}")
    print(f"Test seg iou: {test_m[1]:.3f}")

    _, ax = plt.subplots(1, 2, figsize=(8, 6), sharex=True, sharey=False)

    # the single test value is drawn as a point just after the last epoch
    ax[0].plot(train_m['loss'], label='Train Loss')
    ax[0].plot(train_m['val_loss'], label='Validation Loss')
    ax[0].scatter(len(train_m['loss']), test_m[0], label='Test Loss', color='green')
    ax[0].set(xlabel='Epoch', ylabel='Loss')

    ax[1].plot(train_m['iou'], label='Train IOU')
    ax[1].plot(train_m['val_iou'], label='Validation IOU')
    ax[1].scatter(len(train_m['iou']), test_m[1], label='Test IoU', color='green')
    ax[1].set(xlabel='Epoch', ylabel='Seg IOU')

    # plt.legend() would only decorate the last panel and, with two explicit
    # names, leave the test point unlabelled; build each legend from the labels.
    for panel in ax:
        panel.legend()
    plt.tight_layout()
    plt.show()

# perform evaluation and display results for tuned/untuned transfer models
# non-tuned transfer model: evaluate on the test set (or load the cached
# evaluation) and show it next to the training curves
desc = 'Segmenation Metrics (non-tuned transfer model)'
test_m1 = evaluate_model(
    model=img_model1,
    data=img_data,
    r_file=model_info['img_test_nontuned_results'],
    r_desc=desc,
    vanilla=True,
)
display_metrics(train_m=train_m1, test_m=test_m1, title=desc)

# tuned transfer model: same procedure
desc = 'Segmenation Metrics (tuned transfer model)'
test_m2 = evaluate_model(
    model=img_model2,
    data=img_data,
    r_file=model_info['img_test_tuned_results'],
    r_desc=desc,
)
display_metrics(train_m=train_m2, test_m=test_m2, title=desc)

##### Testing the Segmentation Models

In [None]:
# per-pixel class probabilities for the test images, one array per model
y_img_pred1 = img_model1.predict(x_img_test)
y_img_pred2 = img_model2.predict(x_img_test)

# collapse the class axis: keep the most likely class id for every pixel
y_img_pred_classes1 = [probabilities.argmax(axis=-1) for probabilities in y_img_pred1]
y_img_pred_classes2 = [probabilities.argmax(axis=-1) for probabilities in y_img_pred2]

In [None]:
# first test sample: ground truth next to the prediction of each model
_, ax = plt.subplots(1, 3, figsize=(15, 15))
ax = ax.ravel()

mask_panels = (
    (y_img_test[0], 'True Mask'),
    (y_img_pred_classes1[0], 'Non-tuned Weights'),
    (y_img_pred_classes2[0], 'Tuned Weights'),
)
for panel, (mask_img, caption) in zip(ax, mask_panels):
    panel.imshow(mask_img)
    panel.set_title(caption)
    panel.axis('off')