# Perform inference on a test set

In [1]:
%matplotlib inline

import cv2
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import re 

import tensorflow as tf

# Disable GPU (if it is used elsewhere)
# try:
#     # Disable all GPUS
#     tf.config.set_visible_devices([], 'GPU')
#     visible_devices = tf.config.get_visible_devices()
#     for device in visible_devices:
#         assert device.device_type != 'GPU'
# except:
#     # Invalid device or cannot modify virtual devices once initialized.
#     pass

import tensorflow.keras as keras
from tensorflow.keras import Model, layers, models

from tensorflow.keras.utils import Sequence
from tensorflow.keras.preprocessing.image import img_to_array, load_img
from tensorflow.keras.preprocessing.image import Iterator, ImageDataGenerator

from tensorflow.keras.utils import Sequence
from tensorflow.keras.preprocessing.image import Iterator, ImageDataGenerator
import tensorflow.keras.backend as K

import skimage.transform

import napari

# Report the TensorFlow version and whether a GPU device is visible.
print(tf.__version__)

# Query the device name once and reuse it (an empty string means no GPU was found).
gpu_device_name = tf.test.gpu_device_name()
if gpu_device_name:
    print('Default GPU Device: {}'.format(gpu_device_name))
else:
    print("Please install GPU version of TF")

The version installed is 5.9.7. Please report any issues with this specific QT version at https://github.com/Napari/napari/issues.
  warn(message=warn_message)


2.2.0
Default GPU Device: /device:GPU:0


## Load data

Here we load the test set of Bright Field images.

In [2]:
class ImageGenerator(Sequence):
    """
    Keras Sequence serving batches of (possibly multi-channel) images read from disk.

    Only the file names are listed at construction time; the pixels are loaded
    lazily, batch by batch, in `__getitem__`. Each batch can then be
    center-cropped, resized and min-max normalized (in that order).
    We inherit from Sequence (instead of directly using the keras ImageDataGenerator)
    so that the loading and preprocessing steps can be customized in this class.
    See : https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly
    and https://stackoverflow.com/questions/56758592/how-to-customize-imagedatagenerator-in-order-to-modify-the-target-variable-value
    for more details.
    """

    def __init__(self, X_set, # input images
                 n_channels_ims=1, n_channels_masks=1,
                 batch_size: int=4, dim: tuple=(512, 512),
                 n_channels: int=1, # informations
                 normalize=True, reshape=False, crop=None, # preprocessing params
                 restrict_to="", separate_directories=False):
        """
        X_set (str or list of str): path to the directory containing the images.
        If separate_directories is True, a list with one directory per channel.
        n_channels_ims (int): number of channels per sample. With a single
        directory, each run of n_channels_ims consecutive file names (in
        alphanumeric order) forms one sample.
        n_channels_masks (int): stored on the instance, not used by this class.
        batch_size (int): size of the batch
        dim (tuple): (rows, cols) the images are resized to when reshape is True
        n_channels (int): stored on the instance, not used by this class.
        normalize (bool): min-max normalize every image, channel by channel
        reshape (bool): resize the images to (*dim, n_channels_ims)
        crop (tuple or None): (rows, cols) of the centered crop, None to disable
        restrict_to (str): if not empty, keep only the samples whose first file
        name starts with this prefix (single-directory mode only)
        separate_directories (bool): True if each channel lives in its own directory

        Raises ValueError if the directory list does not match n_channels_ims, or
        if restrict_to is combined with separate_directories.
        """
        super().__init__()
        self.dim = dim
        self.im_size = dim
        self.batch_size = batch_size
        self.n_channels = n_channels
        self.n_channels_ims = n_channels_ims
        self.n_channels_masks = n_channels_masks
        self.separate_directories = separate_directories

        self.restrict_to = restrict_to

        # self.X_set holds file names (one row per sample, one column per
        # channel), not pixels; the images are read in __data_generation.
        self.from_directory_X = True
        self.X_dir = X_set # path to the images dir (or list of dirs)
        if self.separate_directories:
            if self.restrict_to != "":
                # This combination used to fail with an obscure TypeError in os.listdir.
                raise ValueError("restrict_to is not supported together with separate_directories=True.")
            self.X_set = self._list_separate_directories()
            print(self.X_set.shape)
        else:
            self.X_set = self._list_single_directory()

        # Preprocessing parameters
        self.normalize = normalize
        self.reshape = reshape
        self.crop = crop

        # Initialize the indices
        self.on_epoch_end()

    def _list_single_directory(self):
        """Group the sorted file names of self.X_dir into samples of n_channels_ims names."""
        directory = self.alphanumeric_sort(os.listdir(self.X_dir))
        # An empty restrict_to matches every name, so one code path serves both cases.
        # NOTE(review): names are grouped BEFORE the prefix filter is applied, so
        # with n_channels_ims > 1 a kept group may contain non-matching names - confirm intended.
        samples = [
            np.array(directory[k:k + self.n_channels_ims])
            for k in range(0, len(directory), self.n_channels_ims)
            if directory[k].startswith(self.restrict_to)
        ]
        return np.array(samples)

    def _list_separate_directories(self):
        """Pair the k-th sorted file name of every channel directory into one sample."""
        if len(self.X_dir) != self.n_channels_ims:
            raise ValueError(
                f"Expected {self.n_channels_ims} directories (one per channel), got {len(self.X_dir)}."
            )
        # List and sort every directory once (not once per sample).
        listings = [self.alphanumeric_sort(os.listdir(folder)) for folder in self.X_dir]
        # The first directory decides the number of samples.
        samples = [
            np.array([listing[k] for listing in listings])
            for k in range(len(listings[0]))
        ]
        return np.array(samples)

    def alphanumeric_sort(self, l):
        """ Sort the given iterable in the way that humans expect (e.g. "im2" before "im10")."""
        def convert(text):
            return int(text) if text.isdigit() else text

        def alphanum_key(key):
            # Splitting on a captured group keeps the digit runs, so the key
            # alternates text / number chunks in fixed positions.
            return [convert(c) for c in re.split(r'([0-9]+)', key)]

        return sorted(l, key=alphanum_key)

    def __len__(self) -> int:
        """
        Number of batches needed to cover every sample. The last batch is
        smaller than batch_size when the number of samples is not a multiple
        of it, so that no sample is silently skipped at inference time.
        """
        return int(np.ceil(self.X_set.shape[0] / self.batch_size))

    def __getitem__(self, index: int):
        """
        Generate one batch of data.

        index (int): batch index; negative values count from the last batch.
        Raises IndexError if index is out of range.
        """
        n_batches = len(self)
        if index < 0:
            index += n_batches
        if not 0 <= index < n_batches:
            raise IndexError(f"Batch index out of range for {n_batches} batches.")

        # Generate indices corresponding to the images in the batch
        indices = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]

        # Generate the batch
        X = self.__data_generation(indices)
        return X

    def get_image_idx(self, im_name):
        """
        Extract the numerical index of an image from its file name, to sort images
        that are not zero-padded (e.g. numbered 1 to 1000 instead of 0001 to 1000).
        The name is assumed to look like "XXX_tINDEX.tiff", where XXX can be anything.
        If the last "_"-separated token contains a "-" (e.g. "XXX_t3-t7.tiff"), the
        index following the last "-" is returned.
        """
        # Last "_"-separated token of the text before the first ".", minus its first character.
        token = im_name.split(".")[0].split("_")[-1][1:]
        if "-" in token:
            # Keep the part after the last "-" and drop its leading character too.
            return int(token.split("-")[-1][1:])
        return int(token)

    def on_epoch_end(self):
        """
        Resets self.indexes, which is used to retrieve the samples and organize
        them into batches. The order is never shuffled.
        """
        self.indexes = np.arange(self.X_set.shape[0])

    def __data_generation(self, list_IDs):
        """
        Loads the samples whose row indices (into self.X_set) are given in
        list_IDs, stacks their channels on the last axis and applies the
        enabled preprocessing steps (crop, then reshape, then normalize).
        """
        if self.from_directory_X:
            batch_X = []
            for im in list_IDs:
                channels = []
                for k in range(self.n_channels_ims):
                    # One directory per channel, or a single shared directory.
                    folder = self.X_dir[k] if self.separate_directories else self.X_dir
                    image = imageio.imread(f"{folder}/{self.X_set[im, k]}")
                    channels.append(np.expand_dims(image, axis=-1)) # add channel axis
                batch_X.append(np.concatenate(channels, axis=-1))
            batch_X = np.array(batch_X)
        else:
            # Not reachable at the moment: from_directory_X is always True.
            batch_X = self.X_set[list_IDs]

        # Preprocessing
        if self.crop is not None:
            batch_X = self.perf_crop(batch_X)

        if self.reshape:
            batch_X = self.perf_reshape(batch_X)

        if self.normalize:
            batch_X = self.perf_normalize(batch_X)

        return batch_X

    # Preprocessing functions
    def perf_crop(self, images):
        """
        Center-crops a batch to (self.crop[0], self.crop[1]) along axes 1 and 2.
        Any trailing axis (channels) is kept untouched.

        images (np.array): batch of images, rows on axis 1 and columns on axis 2
        Returns a new float64 array. Raises ValueError if the crop is larger
        than the images.
        """
        crop_rows, crop_cols = self.crop[0], self.crop[1]
        n_rows, n_cols = images.shape[1], images.shape[2]
        if crop_rows > n_rows or crop_cols > n_cols:
            raise ValueError(
                f"Cannot crop images of size ({n_rows}, {n_cols}) to ({crop_rows}, {crop_cols})."
            )
        top = (n_rows - crop_rows) // 2
        left = (n_cols - crop_cols) // 2
        # Slicing by start + size (instead of start:-start) also works when
        # the margin is zero or odd.
        return images[:, top:top + crop_rows, left:left + crop_cols].astype(np.float64)

    def perf_reshape(self, images):
        """
        Resizes every image of the batch to (*self.im_size, self.n_channels_ims).

        images (np.array): batch of images of shape (batch_size, n_rows, n_cols, n_chans)
        """
        target_shape = (*self.im_size, self.n_channels_ims)
        # Sized on the actual batch, which may be smaller than batch_size.
        new_batch = np.empty((len(images), *target_shape))
        for i, img in enumerate(images): # the resize function normalizes the images anyways...
            new_batch[i] = skimage.transform.resize(img, target_shape, anti_aliasing=True)
        return new_batch

    def perf_normalize(self, images):
        """
        Performs per image, per channel normalization by substracting the min and dividing by (max - min)

        Raises ValueError if a channel of an image is constant, since the
        division by (max - min) would then be a division by zero.
        """
        new_batch = np.empty(images.shape)
        for i, img in enumerate(images):
            channel_min = np.min(img, axis=(0, 1))
            channel_max = np.max(img, axis=(0, 1))
            if (channel_min == channel_max).any():
                raise ValueError(
                    "Cannot normalize an image with a constant-valued channel. "
                    "There is likely an empty image in the data set.\n"
                    "If cropping was used, maybe the cropped region is uniform."
                )
            new_batch[i] = (img - channel_min) / (channel_max - channel_min)
        return new_batch

In [20]:
# CHANGE DATASET PATH HERE
# One directory per channel, in channel order. Raw strings keep the Windows
# backslashes literal ("\H" and "\D" are invalid escape sequences otherwise).
test_path = [r"D:\Hugo\Data\H449.1/f0_RFP", r"D:\Hugo\Data\H449.1/f0_BF"]

# Empty prefix: no file-name filtering.
restrict_to = ""
bat_size, nc_ims, nc_masks = 1, 2, 1 # SPECIFY HERE THE NUMBER OF CHANNELS
crop, reshape, target_dim, normalize = None, True, (512, 512), True

# len(test_path) must match nc_ims since every channel has its own directory.
test_set = ImageGenerator(test_path, batch_size=bat_size, dim=target_dim, n_channels_ims=nc_ims,
                          n_channels_masks=nc_masks, normalize=normalize,
                          reshape=reshape, crop=crop, restrict_to=restrict_to, separate_directories=True)
        
def visualize_data(bf, nc_ims=1):
    """
    Open a napari viewer on a batch of images.

    bf (np.array): images, indexed as (image, row, col, channel)
    nc_ims (int): number of channels of the images. A single channel is shown
    as one layer; otherwise channel 0 opens the viewer and every other channel
    is added on top of it with additive blending.
    """
    if nc_ims == 1:
        # Drop the trailing channel axis so napari receives a plain image stack.
        viewer = napari.view_image(bf[:, :, :, :].squeeze(-1))
        return

    viewer = napari.view_image(bf[:, :, :, 0])  # bf
    for channel in range(1, nc_ims):
        viewer.add_image(bf[:, :, :, channel], blending="additive")

# Set to False to skip the sanity-check preview of the first batch.
plot = True
if plot:
    n_batches = len(test_set)
    print(f"# Batches : {n_batches}")
    # Load the first batch only, report its shape and display it.
    bf = np.array(test_set[0])
    batch_shape = bf.shape
    print(batch_shape)
    visualize_data(bf, nc_ims=nc_ims)

(689, 2)
# Batches : 689
(1, 512, 512, 2)


## Load model and perform inference

In [15]:
# CHANGE MODEL PATH HERE
# os.chdir("D:/Hugo/Python_Scripts/Tools/unet/models")

def binary_focal_loss_fixed(y_true, y_pred, gamma=2., alpha=.25):
    """
    Binary focal loss, FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t).

    gamma and alpha used to be read from undefined module-level names (any call
    raised a NameError); they are now keyword arguments whose defaults match
    those of `binary_focal_loss`.

    :param y_true: A tensor of the same shape as `y_pred`
    :param y_pred:  A tensor resulting from a sigmoid
    :param gamma: Focusing exponent applied to (1 - p_t)
    :param alpha: Weight of the positive class (negatives get 1 - alpha)
    :return: Output tensor.
    """
    y_true = tf.cast(y_true, tf.float32)
    # Define epsilon so that the back-propagation will not result in NaN for 0 divisor case
    epsilon = K.epsilon()
    # Clip the prediction value so that the log below never receives 0
    y_pred = K.clip(y_pred, epsilon, 1.0 - epsilon)
    # Calculate p_t: the predicted probability of the true class
    is_positive = K.equal(y_true, 1)
    p_t = tf.where(is_positive, y_pred, 1 - y_pred)
    # Calculate alpha_t
    alpha_factor = K.ones_like(y_true) * alpha
    alpha_t = tf.where(is_positive, alpha_factor, 1 - alpha_factor)
    # Calculate cross entropy
    cross_entropy = -K.log(p_t)
    weight = alpha_t * K.pow((1 - p_t), gamma)
    # Calculate focal loss
    loss = weight * cross_entropy
    # Sum the losses along axis 1, then average what remains
    loss = K.mean(K.sum(loss, axis=1))
    return loss

def binary_focal_loss(gamma=2., alpha=.25):
    """
    Build a binary focal loss with fixed hyper-parameters.
    FL(p_t) = -alpha * (1 - p_t)**gamma * log(p_t)
    where p = sigmoid(x), p_t = p or 1 - p depending on if the label is 1 or 0, respectively.
    References:
        https://arxiv.org/pdf/1708.02002.pdf
    Usage:
    model.compile(loss=[binary_focal_loss(alpha=.25, gamma=2)], metrics=["accuracy"], optimizer=adam)
    """

    def binary_focal_loss_fixed(y_true, y_pred):
        """
        :param y_true: A tensor of the same shape as `y_pred`
        :param y_pred:  A tensor resulting from a sigmoid
        :return: Output tensor.
        """
        targets = tf.cast(y_true, tf.float32)
        # Keep the predictions away from 0 and 1 so the log stays finite.
        eps = K.epsilon()
        probs = K.clip(y_pred, eps, 1.0 - eps)
        # p_t: probability assigned to the true class of each element.
        p_t = tf.where(K.equal(targets, 1), probs, 1 - probs)
        # alpha_t: alpha where the label is 1, 1 - alpha elsewhere.
        alpha_factor = K.ones_like(targets) * alpha
        alpha_t = tf.where(K.equal(targets, 1), alpha_factor, 1 - alpha_factor)
        # Focal weighting of the cross entropy.
        focal_weight = alpha_t * K.pow((1 - p_t), gamma)
        per_element = focal_weight * (-K.log(p_t))
        # Sum along axis 1, then average what remains.
        return K.mean(K.sum(per_element, axis=1))

    return binary_focal_loss_fixed

metrics_name = "IoU_metric"
def IoU_metric(y_true, y_pred):
    """
    Intersection over union between the target and the prediction binarized
    at 0.5, computed per sample over axes (1, 2, 3) and averaged over the batch.
    """
    cutoff = tf.constant(0.5, dtype=tf.float32)
    binarized = K.cast(tf.math.greater(y_pred, cutoff), dtype="float32")
    overlap = K.sum(y_true * binarized, axis=(1, 2, 3))
    total = K.sum(y_true + binarized, axis=(1, 2, 3))
    # The overlap is counted twice in the sum, so remove it once.
    combined = total - overlap
    # Epsilon on both sides keeps the ratio defined when both are empty.
    return K.mean((overlap + K.epsilon()) / (combined + K.epsilon()), axis=0)

def jaccard_distance(smooth=50):
    """
    Build a smoothed Jaccard-distance loss.

    smooth: constant added to both the intersection and the union, and used
    to scale the resulting distance.
    """

    def jaccard_distance_fixed(y_true, y_pred):
        """
        Calculates mean of Jaccard distance as a loss function
        """
        overlap = tf.reduce_sum(y_true * y_pred, axis=(1,2))
        total = tf.reduce_sum(y_true + y_pred, axis=(1,2))
        jaccard_index = (overlap + smooth) / (total - overlap + smooth)
        scaled_distance = (1 - jaccard_index) * smooth
        return tf.reduce_mean(scaled_distance)

    return jaccard_distance_fixed

# The model is loaded by name, relative to the models directory.
os.chdir("D:/Hugo/BiSeg/Models")
model_name = "BS300_84ims"
# "jaccard_distance_fixed" is the name of the closure returned by
# jaccard_distance(), so it must map to a built loss (y_true, y_pred) and not
# to the factory itself, which only accepts `smooth`.
# NOTE(review): the default smooth is used here - confirm it matches the value
# used at training time if the model is ever re-trained or evaluated.
unet = keras.models.load_model(
    model_name,
    custom_objects={
        "jaccard_distance_fixed": jaccard_distance(),
        "jaccard_distance": jaccard_distance,
        "IoU_metric": IoU_metric,
    },
)

In [21]:
# NOTHING TO CHANGE HERE
# Alternative kept for models expecting 3 input channels: each batch is
# replicated three times along the last axis before being predicted.
# For 3-channels models
# predictions = []
# for im in test_set:
#     ccc = np.concatenate([im, im, im], axis=-1)
#     predictions.append(unet.predict(ccc))
# predictions=np.array(predictions)
# print(predictions.shape)
# Run the model over every batch served by the generator.
predictions = unet.predict(test_set)
print(predictions.shape)

(689, 512, 512, 1)


## View results with Napari

In [22]:
# CHANGE RFP PATH HERE
def visualize_data_and_predictions(bf, pred, nc_ims=1, nc_masks=1):
    """
    Open a napari viewer showing the input images with the predictions on top.

    bf (np.array): input images, indexed as (image, row, col, channel)
    pred (np.array): predictions, indexed as (image, row, col, channel)
    nc_ims (int): number of channels of the input images
    nc_masks (int): number of channels of the predictions

    Every layer after the first one is added with additive blending. Colormaps
    are reused cyclically when there are more channels than colormaps, instead
    of failing with an IndexError.
    """
    image_cmaps = ["red", "gray", "blue", "bop purple"]
    mask_cmaps = ["red", "gray", "bop orange", "blue", "bop purple"]

    if nc_ims == 1:
        # Drop the trailing channel axis so napari receives a plain image stack.
        viewer = napari.view_image(bf[:, :, :, :].squeeze(-1))
    else:
        viewer = napari.view_image(bf[:, :, :, 0], colormap=image_cmaps[0])  # bf
        for k in range(1, nc_ims):
            viewer.add_image(bf[:, :, :, k], blending="additive",
                             colormap=image_cmaps[k % len(image_cmaps)])

    if nc_masks == 1:
        viewer.add_image(pred[:, :, :, :].squeeze(-1), blending="additive")
    else:
        for k in range(nc_masks):
            viewer.add_image(pred[:, :, :, k], blending="additive",
                             colormap=mask_cmaps[k % len(mask_cmaps)])

# Stack every batch of the generator along the first axis so that the inputs
# line up with the predictions.
n_test_batches = len(test_set)
whole_test_set = np.concatenate(
    [test_set[batch_idx] for batch_idx in range(n_test_batches)],
    axis=0,
)
print(whole_test_set.shape)

visualize_data_and_predictions(
    whole_test_set,
    predictions,
    nc_ims=nc_ims,
    nc_masks=nc_masks,
)

(689, 512, 512, 2)


## Save predictions

In [23]:
from skimage.transform import resize

# CHANGE SAVE PATH HERE
# Raw string: keeps the Windows backslashes literal ("\H", "\A" and "\I" are
# invalid escape sequences in a regular string).
save_path = r"D:\Hugo\Anaphase\Inter_Div_Correlation/H449.1"
name = "BS300.84_H449.1_f0"
extension = "tif"

# Rescale image
target_dim, nc_masks = (512, 512), 1

# Only the first `max_frames` predictions are written to disk.
max_frames = 400

predictions_to_save = []
for im in predictions[:max_frames]:
    # Single-channel predictions are saved without their channel axis.
    frame = im.squeeze(-1) if nc_masks == 1 else im
    if target_dim != (im.shape[0], im.shape[1]):
        frame = resize(frame, target_dim, order=3) # order = 3 :bicubic interpolation
    predictions_to_save.append(frame)
predictions_to_save = np.array(predictions_to_save)

if nc_masks == 1:
    imageio.volwrite(f"{save_path}/{name}.{extension}", predictions_to_save)
else:
    # One file per output channel, suffixed with the channel name.
    channel_names = ["mask", "voronoi"]
    for k in range(nc_masks):
        imageio.volwrite(f"{save_path}/{name}_{channel_names[k]}.{extension}", predictions_to_save[:, :, :, k])