# Train a U-Net Model from scratch

## Load data

In [23]:
%matplotlib inline

import cv2
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os

import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import Model, layers, models

from tensorflow.keras.utils import Sequence
from tensorflow.keras.preprocessing.image import img_to_array, load_img
from tensorflow.keras.preprocessing.image import Iterator, ImageDataGenerator

from tensorflow.keras.utils import Sequence
from tensorflow.keras.preprocessing.image import Iterator, ImageDataGenerator
import tensorflow.keras.backend as K

# Report the TensorFlow version, whether this build was compiled with CUDA
# support, and which GPU devices TensorFlow can see.
print(tf.__version__)
print(tf.test.is_built_with_cuda()) 
print(tf.config.list_physical_devices('GPU'))

import skimage.transform

import napari

# tf.config.gpu.set_per_process_memory_fraction(0.75)
# tf.config.gpu.set_per_process_memory_growth(True)

2.2.0
True
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [345]:
class ImageMaskGenerator(Sequence):
    """
    Keras Sequence yielding batches of (images, masks) with joint augmentation.

    We inherit from Sequence (instead of directly using the Keras
    ImageDataGenerator) because the random transform drawn for an image must
    also be applied to its mask (target). An ImageDataGenerator is still used
    internally to draw and apply those transforms.
    See : https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly
    and https://stackoverflow.com/questions/56758592/how-to-customize-imagedatagenerator-in-order-to-modify-the-target-variable-value
    for more details.
    """

    def __init__(self, X_set, Y_set=None, # input images and masks
                 batch_size: int=32, dim: tuple=(512, 512),
                 n_channels_ims: int=1, n_channels_masks: int=1, # informations
                 shuffle: bool=True, normalize=True, reshape=False, crop=None,# preprocessing params
                 **kwargs): # data augmentation params
        """
        X_set (list, array or str): the images. If str, it is taken as the path
            of a directory; its (sorted) file names are grouped by runs of
            n_channels_ims consecutive files, each run forming one sample.
        Y_set (list, array or str): the masks (targets), handled like X_set but
            grouped by n_channels_masks. The default None is rejected with a
            TypeError: masks are mandatory.
        batch_size (int): number of samples per batch
        dim (tuple): (rows, cols) of the samples after resizing (see reshape)
        n_channels_ims (int): number of channels (files) per image
        n_channels_masks (int): number of channels (files) per mask
        shuffle (bool): shuffle the sample order at the start of each epoch
        normalize (bool): min-max normalize every image and mask, per channel
        reshape (bool): resize images and masks to (*dim, n_channels)
        crop (tuple): (rows, cols) of the central crop applied before resizing,
            or None to disable cropping
        **kwargs: forwarded unchanged to keras' ImageDataGenerator, which draws
            and applies the augmentation transforms

        Raises TypeError if X_set or Y_set is not a list, array or str,
        ValueError if a directory's file count is not a multiple of the channel
        count, and AssertionError if the dataset is empty or the numbers of
        images and masks differ.
        """
        super().__init__()
        self.dim = dim
        self.im_size = dim
        self.batch_size = batch_size
        self.n_channels_ims = n_channels_ims
        self.n_channels_masks = n_channels_masks

        # When a directory is given, X_set / Y_set hold file names rather than
        # pixel data; the files are only read when a batch is requested.
        self.from_directory_X, self.X_set = self._resolve_source(X_set, n_channels_ims, "X_set")
        if self.from_directory_X:
            self.X_dir = X_set # path to the images dir
        self.from_directory_Y, self.Y_set = self._resolve_source(Y_set, n_channels_masks, "Y_set")
        if self.from_directory_Y:
            self.Y_dir = Y_set # path to the masks dir

        # Check that there are samples, and as many images as masks
        n_images, n_masks = self.X_set.shape[0], self.Y_set.shape[0]
        assert n_images != 0 and n_masks != 0, f"Empty dataset: found {n_images} images and {n_masks} masks."
        assert n_images == n_masks, f"{n_images} images != {n_masks} masks"

        self.shuffle = shuffle

        # Preprocessing parameters
        self.normalize = normalize
        self.reshape = reshape
        self.crop = crop

        # The Keras generator that will be used to perform data augmentation
        self.generator = ImageDataGenerator(**kwargs)

        # Initialize the indices (shuffle if asked)
        self.on_epoch_end()

    @staticmethod
    def _resolve_source(data, n_channels, label):
        """
        Turn a list, array or directory path into (from_directory, array).
        For a directory the array has shape (n_samples, n_channels) and holds
        file names.
        """
        if isinstance(data, np.ndarray):
            return False, data
        if isinstance(data, list):
            return False, np.array(data)
        if isinstance(data, str): # assuming a path
            # os.listdir returns names in arbitrary order: sort them so that
            # channel grouping and image/mask pairing are deterministic.
            names = sorted(os.listdir(data))
            if len(names) % n_channels != 0:
                raise ValueError(f"{label}: directory '{data}' contains {len(names)} files, "
                                 f"which is not a multiple of {n_channels} channels")
            groups = [names[k:k + n_channels] for k in range(0, len(names), n_channels)]
            return True, np.array(groups)
        raise TypeError(f"{label} should be list, array or path")

    @staticmethod
    def _read_stack(directory, file_names, ids, n_channels):
        """
        Read the n_channels files of every sample in ids and stack them along
        a new last (channel) axis. Returns the batch as one array.
        """
        batch = []
        for sample in ids:
            planes = [imageio.imread(os.path.join(directory, file_names[sample, k]))
                      for k in range(n_channels)]
            batch.append(np.stack(planes, axis=-1))
        return np.array(batch)

    def __len__(self) -> int:
        """
        Number of batches per epoch. Only full batches are produced: trailing
        samples that do not fill a batch are left out.
        """
        return int(self.X_set.shape[0] // self.batch_size)

    def __getitem__(self, index: int):
        """
        Generate the batch number `index` as a tuple (images, masks).
        Raises IndexError when index >= number of batches.
        """
        if index >= len(self):
            raise IndexError(f"Batch index {index} out of range for {len(self)} batches")

        # Indices of the samples belonging to this batch
        indices = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]

        # Generate the batch
        X, Y = self.__data_generation(indices)
        return X, Y

    def on_epoch_end(self):
        """
        Rebuild self.indexes, the sample order used to cut the dataset into
        batches. If shuffle is set the order is randomized, giving different
        batches at each epoch.
        """
        self.indexes = np.arange(self.X_set.shape[0])
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, list_IDs: [int]):
        """
        Build one batch from the sample indices list_IDs: load the data (from
        disk or from the in-memory arrays), apply crop / reshape / normalize as
        configured, then augment each image and its mask with the same random
        transform.
        """
        # Load data (from directory or from the given arrays)
        if self.from_directory_X:
            batch_X = self._read_stack(self.X_dir, self.X_set, list_IDs, self.n_channels_ims)
        else:
            batch_X = self.X_set[list_IDs]

        if self.from_directory_Y:
            batch_Y = self._read_stack(self.Y_dir, self.Y_set, list_IDs, self.n_channels_masks)
        else:
            batch_Y = self.Y_set[list_IDs]

        # Preprocessing
        if self.crop is not None:
            batch_X = self.perf_crop(batch_X)
            batch_Y = self.perf_crop(batch_Y)

        if self.reshape:
            batch_X = self.perf_reshape(batch_X, is_images=True)
            batch_Y = self.perf_reshape(batch_Y, is_images=False)

        if self.normalize:
            batch_X = self.perf_normalize(batch_X)
            batch_Y = self.perf_normalize(batch_Y)

        # Perform the SAME transformation on the image and on the mask
        for i, (img, mask) in enumerate(zip(batch_X, batch_Y)):
            transform_params = self.generator.get_random_transform(img.shape)
            batch_X[i] = self.generator.apply_transform(img, transform_params)
            batch_Y[i] = self.generator.apply_transform(mask, transform_params)

        return batch_X, batch_Y

    # Preprocessing functions
    def perf_crop(self, images):
        """
        Central crop of a batch (batch, rows, cols, ...) to self.crop.
        Returns a new float64 array. Raises AssertionError when the target is
        larger than the input.
        """
        target_rows, target_cols = self.crop[0], self.crop[1]
        crop_X = int((images.shape[1] - target_rows) // 2)
        crop_Y = int((images.shape[2] - target_cols) // 2)
        assert (crop_X >= 0 and crop_Y >= 0), f"Target size after cropping {self.crop} should be lower than the initial shape {(images.shape[1], images.shape[2])}."
        # Slice by explicit length so that an odd size difference still yields
        # exactly the requested size (a symmetric [c:-c] slice would be one
        # pixel too large).
        cropped = images[:, crop_X:crop_X + target_rows, crop_Y:crop_Y + target_cols]
        return cropped.astype(np.float64)

    def perf_reshape(self, images, is_images=True):
        """
        Resize every sample of the batch to (*self.im_size, n_channels).

        images (np.array): batch of shape (batch_size, n_rows, n_cols, n_chans)
        is_images (bool): is it a batch of images (True) or masks (False);
            selects which channel count is used for the target shape
        """
        n_channels = self.n_channels_ims if is_images else self.n_channels_masks
        target_shape = (*self.im_size, n_channels)
        new_batch = np.empty((len(images), *target_shape))
        for i, img in enumerate(images): # the resize function normalizes the images anyways...
            new_batch[i] = skimage.transform.resize(img, target_shape, anti_aliasing=True)
        return new_batch

    def perf_normalize(self, images):
        """
        Per image, per channel min-max normalization: subtract the min and
        divide by (max - min). Returns a new float64 array.
        Raises AssertionError if some image is constant in one of its channels
        (the division would be by zero).
        """
        lows = np.min(images, axis=(1, 2), keepdims=True)
        highs = np.max(images, axis=(1, 2), keepdims=True)
        assert (lows != highs).all(), ("Cannot normalize a constant image (min == max in at least one channel). "
                                       "There is likely an empty image or mask in the set.\n"
                                       "If cropping was used, maybe the mask doesn't contain any white pixel in the cropped region.")
        return ((images - lows) / (highs - lows)).astype(np.float64)

In [349]:
# Root of the training data; images and masks live in sibling sub-directories.
data_path = "D:/Hugo/BiSeg/Train_Set"
bf_dir, mask_dir = f"{data_path}/images/", f"{data_path}/masks/"

# See the Keras documentation for everything that can be configured here:
# https://keras.io/api/preprocessing/image/
# See also the imgaug or albumentations libraries to implement new transforms.
augmentation_params = dict(zoom_range=[0.9, 1.5],
                           rotation_range=360,
                           height_shift_range=0.2,
                           width_shift_range=0.2,
                           fill_mode="constant", cval=0)

# augmentation_params = {}

bat_size, nc_ims, nc_masks, shuffle = 4, 2, 1, True  # SPECIFY HERE THE NUMBER OF CHANNELS
crop, reshape, target_dim, normalize = None, True, (512, 512), True

# Training generator.
generator = ImageMaskGenerator(bf_dir, mask_dir, 
                               batch_size=bat_size, dim=target_dim, n_channels_ims=nc_ims, n_channels_masks=nc_masks, 
                               shuffle=shuffle, normalize=normalize, reshape=reshape, crop=crop,
                                **augmentation_params)

# Validation generator (one sample per batch).
# NOTE(review): it receives the same shuffle flag and augmentation parameters
# as the training generator, so validation samples are randomly transformed
# too -- confirm this is intended.
val_generator = ImageMaskGenerator(f"{data_path}/val_images/", f"{data_path}/val_masks/", 
                               batch_size=1, dim=target_dim, n_channels_ims=nc_ims, n_channels_masks=nc_masks,
                               shuffle=shuffle, normalize=normalize, reshape=reshape, crop=crop,
                                **augmentation_params)

def visualize_data(bf, masks, nc_ims):
    """
    Open a napari viewer showing a batch of images with its masks overlaid.

    bf (np.array): batch of images, indexed as (batch, rows, cols, channels)
    masks (np.array): batch of masks whose last axis has length 1
    nc_ims (int): number of image channels. With 1 channel the image is shown
        as a single layer; otherwise channel 1 is shown first and channel 0 is
        added on top of it.
    All layers after the first one use additive blending.
    """
    with napari.gui_qt():
        single_channel = (nc_ims == 1)
        if not single_channel:
            viewer = napari.view_image(bf[:, :, :, 1])  # bf
            viewer.add_image(bf[:, :, :, 0], blending="additive")
        else:
            viewer = napari.view_image(bf[:, :, :, :].squeeze(-1))
        # The mask overlay is the same in both cases
        viewer.add_image(masks[:, :, :, :].squeeze(-1), blending="additive")

In [350]:
# Sanity check: display one augmented training batch in napari.
plot = True
if plot:
    n_batches = len(generator)
    print(f"# Batches : {n_batches}")
    if n_batches > 0:
        # Show the last batch; derive its index from the generator rather than
        # hard-coding it, so the cell keeps working when the dataset changes.
        bf, masks = generator[n_batches - 1]
        bf, masks = np.array(bf), np.array(masks)

        visualize_data(bf, masks, nc_ims=nc_ims)

# Batches : 12


## Define model

In [351]:
def get_unet(nbr, x, y, n_channels_imgs=1, n_channels_masks=1):
    """
    Build a U-Net with four down-sampling / up-sampling stages.

    Every stage is two 3x3 'same' convolutions (ReLU), each followed by batch
    normalization. The encoder outputs (before pooling) are concatenated to the
    up-sampled decoder tensors (skip connections). The output layer is a 1x1
    convolution with a sigmoid activation.

    nbr (int): number of convolution filters of the first stage; deeper stages
        use 2*nbr, 4*nbr and 8*nbr filters
    x (int): image height
    y (int): image width
        x and y should be divisible by 16 (four 2x poolings), otherwise the
        up-sampled tensors no longer match the skip connections.
    n_channels_imgs (int): number of channels of the input images
    n_channels_masks (int): number of channels of the predicted masks

    Returns the (uncompiled) keras Model.
    """
    print(f"# input channels : {n_channels_imgs}.")
    print(f"# output channels : {n_channels_masks}.")

    # NOTE(review): stddev=1 is a wide initialization for a deep ReLU network;
    # kept as-is, but worth comparing with the Keras default initializer.
    initializer = tf.keras.initializers.RandomNormal(mean=0., stddev=1.)

    def conv_bn(tensor, n_filters):
        # One 3x3 convolution (ReLU, 'same' padding) followed by batch norm.
        tensor = layers.Conv2D(n_filters, 3, activation='relu', padding='same', kernel_initializer=initializer)(tensor)
        return layers.BatchNormalization()(tensor)

    def double_conv(tensor, n_first, n_second):
        # The two-convolution unit used at every stage of the network.
        return conv_bn(conv_bn(tensor, n_first), n_second)

    # NOTE(review): the input is declared float16 while the layers keep their
    # default dtype; the notebook output shows Keras' autocast warning about
    # this. Consider dtype='float32' -- confirm before changing.
    entree = layers.Input(shape=(x, y, n_channels_imgs), dtype='float16')

    # Encoder: keep each stage's output for the skip connections
    skips = []
    result = entree
    for n_filters in (nbr, 2*nbr, 4*nbr, 4*nbr):
        result = double_conv(result, n_filters, n_filters)
        skips.append(result)
        result = layers.MaxPool2D()(result)

    # Bottleneck
    result = double_conv(result, 8*nbr, 4*nbr)

    # Decoder: up-sample, concatenate the matching encoder output, convolve
    decoder_filters = ((8*nbr, 4*nbr), (4*nbr, 2*nbr), (2*nbr, nbr), (nbr, nbr))
    for skip, (n_first, n_second) in zip(reversed(skips), decoder_filters):
        result = layers.UpSampling2D()(result)
        result = layers.Concatenate(axis=3)([result, skip])
        result = double_conv(result, n_first, n_second)

    sortie = layers.Conv2D(n_channels_masks, 1, activation='sigmoid', padding='same', kernel_initializer=initializer)(result)

    model = models.Model(inputs=entree, outputs=sortie)
    return model

## Loss function : Weighted binary crossentropy

This step is important because we are facing a class imbalance problem: pixels of class 0 (i.e. background) are far more numerous than pixels of class 1 (i.e. yeast pixels).

In [352]:
# not necessary for whole cells
# TODO: try it on mating or buds

class WeightedBinaryCrossEntropy():
    """
    Binary cross entropy in which the foreground (label 1) and background
    (label 0) terms are multiplied by separate weights. Instances are callable
    as loss(Y_true, Y_pred).
    """

    def __init__(self, class_weight=None):
        """
        class_weight: object indexable by 0 and 1 (e.g. a dict) giving the
            weight of the background term ([0]) and of the foreground term
            ([1]). Defaults to {0: 0.5, 1: 0.5}.
        """
        if class_weight is None:
            # Built per instance so that instances never share one mutable dict
            class_weight = {0: 0.5, 1: 0.5}
        self.class_weight = class_weight
        self.__name__ = "binary_cross_entropy"

    def __call__(self, Y_true, Y_pred):
        """
        Compute the weighted binary cross entropy for a given mask Y_true and a
        given prediction Y_pred. Returns the mean of the loss over the last
        axis.
        """
        # Both tensors are clipped away from 0 and 1 (the clip on the
        # predictions keeps the logarithms finite).
        y_true = K.clip(Y_true, K.epsilon(), 1-K.epsilon())
        y_pred = K.clip(Y_pred, K.epsilon(), 1-K.epsilon())
        foreground_term = y_true * K.log(y_pred) * self.class_weight[1]
        background_term = (1 - y_true) * K.log(1 - y_pred) * self.class_weight[0]
        logloss = -(foreground_term + background_term)
        return K.mean(logloss, axis=-1)

# Weight the foreground term (label 1) 100 times more than the background term
# (label 0) to counter the class imbalance.
weights = {0: 1, 1: 100}
binary_cross_entropy = WeightedBinaryCrossEntropy(class_weight=weights)

def jaccard_distance(smooth=20):
    """
    Build a Jaccard-distance loss function.

    smooth: constant added to the numerator and the denominator of the Jaccard
        index (avoids a 0/0 on empty masks); the distance is also multiplied by
        it.

    Returns a function loss(y_true, y_pred) usable as a Keras loss.
    """

    def jaccard_distance_fixed(y_true, y_pred):
        """
        Calculates mean of Jaccard distance as a loss function.
        Sums are taken over axes 1 and 2 (rows and columns of the images), then
        the distance (1 - jaccard) * smooth is averaged over what remains.
        """
        # Align the dtypes so that the loss can also be called directly with
        # integer or float64 masks (mixed-dtype arithmetic raises in TensorFlow).
        y_pred = tf.convert_to_tensor(y_pred)
        y_true = tf.cast(y_true, y_pred.dtype)

        intersection = tf.reduce_sum(y_true * y_pred, axis=(1,2))
        sum_ = tf.reduce_sum(y_true + y_pred, axis=(1,2))
        # union = sum - intersection
        jac = (intersection + smooth) / (sum_ - intersection + smooth)
        jd =  (1 - jac) * smooth
        return tf.reduce_mean(jd)

    return jaccard_distance_fixed

def binary_focal_loss(gamma=2., alpha=.25):
    """
    Factory for the binary focal loss.

        FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t)

    with p_t = p where the label is 1 and p_t = 1 - p elsewhere, p being the
    sigmoid output of the network. alpha_t is alpha where the label is 1 and
    1 - alpha elsewhere.
    References:
        https://arxiv.org/pdf/1708.02002.pdf
    Usage:
    model.compile(loss=[binary_focal_loss(alpha=.25, gamma=2)], metrics=["accuracy"], optimizer=adam)
    """

    def binary_focal_loss_fixed(y_true, y_pred):
        """
        :param y_true: A tensor of the same shape as `y_pred`
        :param y_pred:  A tensor resulting from a sigmoid
        :return: Output tensor: the loss summed over axis 1, then averaged.
        """
        labels = tf.cast(y_true, tf.float32)

        # Keep the predictions strictly inside (0, 1) so the log stays finite
        eps = K.epsilon()
        probs = K.clip(y_pred, eps, 1.0 - eps)

        positives = K.equal(labels, 1)

        # Probability assigned to the true class of each pixel
        p_t = tf.where(positives, probs, 1 - probs)

        # Class balancing factor
        alpha_pos = K.ones_like(labels) * alpha
        alpha_t = tf.where(positives, alpha_pos, 1 - alpha_pos)

        # Focal modulation of the cross entropy
        focal = alpha_t * K.pow((1 - p_t), gamma) * (-K.log(p_t))

        return K.mean(K.sum(focal, axis=1))

    return binary_focal_loss_fixed

## Train the model

In [None]:
# Halve the learning rate (down to 1e-6) when the training loss has stopped
# improving for 10 epochs. Keras expects a list of callbacks.
callbacks = [keras.callbacks.ReduceLROnPlateau(monitor="loss", factor=0.5, verbose=1, patience=10, min_lr=1e-6)]

# Smoothing constant of the Jaccard loss; 0 selects binary cross entropy.
# This must be a scalar: a list would never compare equal to 0 below.
smooth = 50

fig, ax = plt.subplots(1, 2, figsize=(20, 8))

# NOTE(review): MeanIoU is documented for class-index predictions; confirm that
# it is meaningful on the raw sigmoid outputs of the network (thresholding the
# predictions first may be needed).
IoU = tf.keras.metrics.MeanIoU(num_classes=2, name="mean_IoU")

print(f"Smoothing : {smooth}")

if smooth == 0:
    loss = keras.losses.BinaryCrossentropy()
    loss_label = "binary cross entropy"
else:
    loss = jaccard_distance(smooth=smooth)
    loss_label = f"Jaccard, smooth=={smooth}"
n_filters, init_lr = 64, 0.005
# The network input size must match the size produced by the generators.
unet = get_unet(n_filters, *target_dim, n_channels_imgs=nc_ims, n_channels_masks=nc_masks)
unet.compile(
    optimizer=keras.optimizers.Adam(learning_rate=init_lr),
    loss=loss,
    metrics=[IoU],
)
n_epochs = 300
history = unet.fit(generator, validation_data=val_generator, 
                   epochs=n_epochs, verbose=1, callbacks=callbacks)

# Learning curves: loss on the left, IoU on the right
ax[0].plot(history.history["loss"][:], "orange", label="loss")
ax[0].plot(history.history["val_loss"][:], "b", label="validation loss")
ax[0].legend()
ax[0].set_title("Training curves")
ax[0].set_xlabel("Epochs")
ax[0].set_ylabel(f"Loss ({loss_label})")

ax[1].plot(history.history["mean_IoU"][:], "orange", label="IoU")
ax[1].plot(history.history["val_mean_IoU"][:], "b", label="validation IoU")
ax[1].legend()
ax[1].set_title("Training curves")
ax[1].set_xlabel("Epochs")
ax[1].set_ylabel("IoU")

model_name = f"BS{n_epochs}"

# The model, the figure and the history file are written relative to this
# directory (the working directory stays changed afterwards).
os.chdir("D:/Hugo/BiSeg/Models")
unet.save(model_name)

# Save through the figure object rather than pyplot's implicit current figure
fig.savefig(f"{model_name}/{model_name}_learning_curve.png")

list_training_imgs = "\n".join(os.listdir(generator.X_dir))
list_val_imgs = "\n".join(os.listdir(val_generator.X_dir))

# Keep a plain-text record of the training configuration next to the model
with open(f"{model_name}/history.txt", "w") as hist_file:
    hist_file.write(f"Model {model_name} trained for {n_epochs} epochs."
                    f"\nNumber of training images : {len(generator)} * {generator.batch_size}, from directory {generator.X_dir}, masks from {generator.Y_dir}."
                    f"\n\nNumber of validation images : {len(val_generator)} * {val_generator.batch_size}, from directory {val_generator.X_dir}, masks from {val_generator.Y_dir}."
                    f"\nLoss : {loss}, smoothing: {smooth}."
                    f"\nNumber of filters : {n_filters}."
                    f"\nInitial learning rate: {init_lr}."
                    f"\n\nList of the training images:\n{list_training_imgs}."
                    f"\n\nList of the validation images:\n{list_val_imgs}."
                   )

Smoothing : [50]
# input channels : 2.
# output channels : 1.


To change all layers to have dtype float16 by default, call `tf.keras.backend.set_floatx('float16')`. To change just this layer, pass dtype='float16' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.





To change all layers to have dtype float16 by default, call `tf.keras.backend.set_floatx('float16')`. To change just this layer, pass dtype='float16' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.



Epoch 1/300
Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 18/300
Epoch 19/300
Epoch 20/300
Epoch 21/300
Epoch 22/300
Epoch 23/300
Epoch 24/300
Epoch 25/300
Epoch 26/300
Epoch 27/300
Epoch 28/300
Epoch 29/300
Epoch 30/300
Epoch 31/300
Epoch 32/300
Epoch 33/300
Epoch 34/300
Epoch 35/300
Epoch 36/300
Epoch 37/300
Epoch 38/300
Epoch 39/300
Epoch 40/300
Epoch 41/300
Epoch 42/300
Epoch 43/300
Epoch 44/300
Epoch 45/300
Epoch 46/300
Epoch 47/300
Epoch 48/300
Epoch 49/300
Epoch 50/300
Epoch 51/300
Epoch 52/300
Epoch 53/300
Epoch 54/300
Epoch 55/300
Epoch 56/300
Epoch 57/300
Epoch 58/300
Epoch 59/300
Epoch 60/300
Epoch 61/300
Epoch 62/300
Epoch 63/300
Epoch 64/300
Epoch 65/300
Epoch 66/300
Epoch 67/300
Epoch 68/300

## Perform inference

Now we will use the model to make predictions on the test dataset to check if it can generalize well.

In [325]:
class ImageGenerator(Sequence):
    """
    Keras Sequence yielding batches of preprocessed images for inference.

    Unlike a training generator, it yields images only (no masks), never
    shuffles and performs no augmentation: batches follow the order of the
    input set. The optional preprocessing steps are crop, resize and min-max
    normalization.
    """

    def __init__(self, X_set, # input images
                 n_channels_ims=1, n_channels_masks=1,
                 batch_size: int=4, dim: tuple=(512, 512),
                 n_channels: int=1, # informations
                 normalize=True, reshape=False, crop=None, # preprocessing params
                 restrict_to=""):
        """
        X_set (list, array or str): the images. If str, it is taken as the path
            of a directory; its (sorted) file names are grouped by runs of
            n_channels_ims consecutive files, each run forming one sample.
        n_channels_ims (int): number of channels (files) per image
        n_channels_masks (int): stored on the instance, not used by this class
        batch_size (int): number of samples per batch
        dim (tuple): (rows, cols) of the samples after resizing (see reshape)
        n_channels (int): stored on the instance, not used by this class
        normalize (bool): min-max normalize every image, per channel
        reshape (bool): resize the images to (*dim, n_channels_ims)
        crop (tuple): (rows, cols) of the central crop applied before resizing,
            or None to disable cropping
        restrict_to (str): when X_set is a directory, keep only the samples
            whose first file name starts with this prefix ("" keeps all)

        Raises TypeError if X_set is not a list, array or str, and ValueError
        if a selected sample does not have n_channels_ims files.
        """
        super().__init__()
        self.dim = dim
        self.im_size = dim
        self.batch_size = batch_size
        self.n_channels = n_channels
        self.n_channels_ims = n_channels_ims
        self.n_channels_masks = n_channels_masks

        self.restrict_to = restrict_to

        # When a directory is given, X_set holds file names rather than pixel
        # data; the files are only read when a batch is requested.
        if isinstance(X_set, np.ndarray):
            self.from_directory_X = False
            self.X_set = X_set
        elif isinstance(X_set, list):
            self.from_directory_X = False
            self.X_set = np.array(X_set)
        elif isinstance(X_set, str): # assuming a path
            self.from_directory_X = True
            self.X_dir = X_set # path to the images dir
            # os.listdir returns names in arbitrary order: sort them so that
            # the channel grouping and the batch order are deterministic.
            file_names = sorted(os.listdir(X_set))
            step = self.n_channels_ims
            # An empty prefix matches every name, so no special case is needed
            groups = [file_names[k:k + step]
                      for k in range(0, len(file_names), step)
                      if file_names[k].startswith(self.restrict_to)]
            if any(len(group) != step for group in groups):
                raise ValueError(f"Directory '{X_set}' contains {len(file_names)} files, "
                                 f"which is not a multiple of {step} channels")
            self.X_set = np.array(groups)
        else:
            raise TypeError("X_set should be list, array or path")

        # Preprocessing parameters
        self.normalize = normalize
        self.reshape = reshape
        self.crop = crop

        # Initialize the indices
        self.on_epoch_end()

    def __len__(self) -> int:
        """
        Number of batches. Only full batches are produced: trailing samples
        that do not fill a batch are left out.
        """
        return int(self.X_set.shape[0] // self.batch_size)

    def __getitem__(self, index: int):
        """
        Generate the batch number `index` (images only).
        Raises IndexError when index >= number of batches.
        """
        if index >= len(self):
            raise IndexError(f"Batch index {index} out of range for {len(self)} batches")

        # Indices of the samples belonging to this batch
        indices = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]

        # Generate the batch
        X = self.__data_generation(indices)
        return X

    def get_image_idx(self, im_name):
        """
        Extract the numerical index of an image from its file name, e.g. to
        sort files that are numbered 1 to 1000 instead of 0001 to 1000.

        The part of the name before the first "." is split on "_" and the first
        character of the last piece is dropped; the rest is the index
        ("XXX_t12.tiff" -> 12). If that rest contains "-", the index is what
        follows the last "-", again without its first character.
        Raises ValueError if the extracted text is not an integer.
        """
        token = im_name.split(".")[0].split("_")[-1][1:]
        if "-" in token:
            token = token.split("-")[-1][1:]
        return int(token)

    def on_epoch_end(self):
        """
        Rebuild self.indexes, the sample order used to cut the set into
        batches. The order is never shuffled.
        """
        self.indexes = np.arange(self.X_set.shape[0])

    def __data_generation(self, list_IDs: [int]):
        """
        Build one batch from the sample indices list_IDs: load the images (from
        disk or from the in-memory array) and apply crop / reshape / normalize
        as configured.
        """
        # Load data (from directory or from X_set depending on the given data)
        if self.from_directory_X:
            batch_X = []
            for sample in list_IDs:
                # One file per channel, stacked along a new last axis
                planes = [imageio.imread(os.path.join(self.X_dir, self.X_set[sample, k]))
                          for k in range(self.n_channels_ims)]
                batch_X.append(np.stack(planes, axis=-1))
            batch_X = np.array(batch_X)
        else:
            batch_X = self.X_set[list_IDs]

        # Preprocessing
        if self.crop is not None:
            batch_X = self.perf_crop(batch_X)

        if self.reshape:
            batch_X = self.perf_reshape(batch_X)

        if self.normalize:
            batch_X = self.perf_normalize(batch_X)

        return batch_X

    # Preprocessing functions
    def perf_crop(self, images):
        """
        Central crop of a batch (batch, rows, cols, ...) to self.crop.
        Trailing axes (channels) are kept. Returns a new float64 array.
        Raises AssertionError when the target is larger than the input.
        """
        target_rows, target_cols = self.crop[0], self.crop[1]
        crop_X = int((images.shape[1] - target_rows) // 2)
        crop_Y = int((images.shape[2] - target_cols) // 2)
        assert (crop_X >= 0 and crop_Y >= 0), f"Target size after cropping {self.crop} should be lower than the initial shape {(images.shape[1], images.shape[2])}."
        # Slice by explicit length so that an odd size difference still yields
        # exactly the requested size.
        cropped = images[:, crop_X:crop_X + target_rows, crop_Y:crop_Y + target_cols]
        return cropped.astype(np.float64)

    def perf_reshape(self, images):
        """
        Resize every image of the batch to (*self.im_size, n_channels_ims).

        images (np.array): batch of images of shape (batch_size, n_rows, n_cols, n_chans)
        """
        target_shape = (*self.im_size, self.n_channels_ims)
        new_batch = np.empty((len(images), *target_shape))
        for i, img in enumerate(images): # the resize function normalizes the images anyways...
            new_batch[i] = skimage.transform.resize(img, target_shape, anti_aliasing=True)
        return new_batch

    def perf_normalize(self, images):
        """
        Per image, per channel min-max normalization: subtract the min and
        divide by (max - min). Returns a new float64 array.
        Raises AssertionError if some image is constant in one of its channels
        (the division would be by zero).
        """
        lows = np.min(images, axis=(1, 2), keepdims=True)
        highs = np.max(images, axis=(1, 2), keepdims=True)
        assert (lows != highs).all(), ("Cannot normalize a constant image (min == max in at least one channel). "
                                       "There is likely an empty image in the set.\n"
                                       "If cropping was used, maybe the cropped region is uniform.")
        return ((images - lows) / (highs - lows)).astype(np.float64)

In [342]:
# CHANGE DATASET PATH HERE
# Raw string: the backslashes of the Windows path are not escape sequences.
test_path = r"D:\Hugo\BiSeg\Test_Set/H449_1-120"
restrict_to = ""
bs, n_chan_ims, n_chan_ms = 1, 2, 1
test_set = ImageGenerator(test_path, batch_size=bs, dim=(512, 512),
                          n_channels_ims=n_chan_ims, n_channels_masks=n_chan_ms, crop=None, normalize=True, reshape=True, restrict_to=restrict_to)

# Preview the first batch with an all-zero placeholder mask. The batch is
# loaded once, and the placeholder has the same (batch, rows, cols) as the
# images plus a single channel.
first_batch = test_set[0]
empty_masks = np.zeros((*first_batch.shape[:3], 1))
visualize_data(first_batch, empty_masks, n_chan_ims)

In [343]:
# Run the trained network on every batch of the test set.
predictions = unet.predict(test_set)

# Reload all the test batches and join them along the batch axis, so that the
# inputs can be displayed next to the predictions.
whole_test_set = np.concatenate([test_set[i] for i in range(len(test_set))], axis=0)
print(whole_test_set.shape)

def visualize_data_and_predictions(bf, preds, nc_ims=1):
    """
    Open a napari viewer showing the input images with the predictions overlaid.

    bf (np.array): images, indexed as (batch, rows, cols, channels)
    preds (np.array): predictions whose last axis has length 1
    nc_ims (int): number of image channels. With 1 channel the image is shown
        as a single "BF" layer; otherwise channel 1 is shown as "BF" and
        channel 0 as "Fluo" (red).
    The predictions are always added as a layer named "Predictions".
    """
    with napari.gui_qt():
        if nc_ims == 1:
            viewer = napari.view_image(bf[:, :, :, :].squeeze(-1), name="BF")
            # This layer holds the predictions, so label it accordingly
            viewer.add_image(preds[:, :, :, :].squeeze(-1), blending="additive", name="Predictions")
        else:
            viewer = napari.view_image(bf[:, :, :, 1], name="BF")  # bf
            viewer.add_image(bf[:, :, :, 0], blending="additive", name="Fluo", colormap="red")
            viewer.add_image(preds[:, :, :, :].squeeze(-1), blending="additive", name="Predictions", colormap="blue")

plot = True
if plot:
    # Use the channel count the test set was built with instead of a
    # hard-coded value.
    visualize_data_and_predictions(whole_test_set, predictions, nc_ims=n_chan_ims)

(120, 512, 512, 2)


In [344]:
# Optionally write the predictions to disk as a single volume file.
save_predictions = True
if save_predictions:
    # REPLACE TEST SAVE PREDICTIONS PATH
    save_path = "D:/Hugo/BiSeg/Predictions/BS200-H449_pos2_1-120.tif"
    # Drop the trailing single-channel axis before writing
    predicted_nochan = predictions.squeeze(-1)

    imageio.volwrite(save_path, predicted_nochan)

In [32]:
# CHANGE SAVE PATH HERE
# The model is saved under the name "S1" relative to this directory; note that
# the working directory stays changed for the cells executed afterwards.
os.chdir("D:/Hugo/Whole_Cell/Models/")
unet.save("S1")

INFO:tensorflow:Assets written to: S1\assets


## Save the model

In [None]:
# REPLACE SAVE MODEL PATH
# Save the trained model at the given location.
save_model_path = "/content/gdrive/MyDrive/CYBERSCOPE/Migration/Models"
unet.save(save_model_path)