# Drosophila Heart OCM (FlyNet 2.0+) training code

**Author**: Matthew Fishman

**Date**: December 1, 2022

**Description**: This notebook contains the training code to build FlyNet from scratch using the data that can be found on Figshare. A complete description of the model can be found within the Scientific Data manuscript. Data directory paths need to be updated before running.

**Requirements**:
- Python 3.9
- Libraries: cudatoolkit=11.2, cudnn=8.1.0, tensorflow=2.10, numpy, scikit-image, opencv-python
- Local module: `loss_funcs.py` (imported as `lf`) must be on the Python import path

**License**: MIT License


In [None]:
# Import necessary Tensorflow and Keras libraries
import tensorflow as tf
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import ConvLSTM2D, Conv2DTranspose, TimeDistributed, Conv2D, Conv3D
from tensorflow.keras.layers import Input, MaxPooling2D, BatchNormalization, LeakyReLU, Concatenate
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras import backend as K

from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import binary_crossentropy
import loss_funcs as lf

# Import numpy and data management
import numpy as np
import os
import pickle
from random import shuffle
from random import randint
from random import sample
import skimage.io
from skimage import transform
import cv2

# Optional imports for data visualization (tensorboard)
# Import tensorboard
# %load_ext tensorboard

# Set all constants needed for training the model
# IMAGE_SIZE is the filename fragment used by getImg/getMask: they look for files whose
# names contain IMAGE_SIZE + "_img.tiff" and IMAGE_SIZE + "_mask.tiff" in each sample directory
IMAGE_SIZE = "full" # which files to read (full size images are cropped and augmented by default during training)

# Directory where the data is stored (change this to your own directory)
DATA_DIR = "C:/Users/username/Downloads/Drosophila_heart_OCM_dataset"
OUTPUT_DIR = "C:/Users/username/Downloads/Output"

# NOTE(review): these two assignments override the placeholder paths above, so they are
# the values actually used at runtime - edit or remove them for your own machine
DATA_DIR = "D:/Drosophila_heart_OCM_dataset"
OUTPUT_DIR = "D:/FlyNet/Output"

In [None]:
# Functions for reading in and processing the data
def generate_pk(path):
    """
    Generate a pickle file containing the paths to the individual samples

    Only sub-directories of `path` are recorded (stray files are skipped because every
    entry is later opened as a sample directory). Entries are stored in sorted order so
    the pickle content does not depend on the file system's listing order. OUTPUT_DIR is
    created if it does not exist yet.

    Parameters:
    path (str): path to the directory containing the data

    Returns:
    pickle_path (str): path to the generated pickle file in OUTPUT_DIR
    """
    # Collect the path of every sample directory
    candidates = (os.path.join(path, entry) for entry in sorted(os.listdir(path)))
    data = [candidate for candidate in candidates if os.path.isdir(candidate)]

    # Make sure the output directory exists before writing into it
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Save the data to a pickle file
    pickle_path = os.path.join(OUTPUT_DIR, "data.pkl").replace("\\","/")
    with open(pickle_path, mode='wb') as f:
        pickle.dump(data, f)

    return pickle_path

def read_pk(path, training_percent=0.95):
    """
    Read in a pickle file and return the data split into training and validation sets

    The sample indices are shuffled with the `random` module and then split: the first
    int(n * training_percent) shuffled indices are used for training, the rest for
    validation.

    Parameters:
    path (str): path to the pickle file
    training_percent (float): fraction of the samples used for training, between 0 and 1

    Returns:
    data (np.array): array of paths to individual samples
    train_ids (np.array): array of training indices
    val_ids (np.array): array of validation indices

    Raises:
    ValueError: if training_percent is outside [0, 1]
    """
    if not 0.0 <= training_percent <= 1.0:
        raise ValueError(f"training_percent must be between 0 and 1, got {training_percent}")

    # NOTE: pickle.load can execute arbitrary code - only read pickle files you created yourself
    with open(path, mode='rb') as f:
        data = pickle.load(f)

    data = np.array(data)

    # Shuffle a plain list of indices, then convert (dtype=int keeps an empty split integer-typed)
    shuffled = list(range(len(data)))
    shuffle(shuffled)
    available_ids = np.array(shuffled, dtype=int)

    final_train_id = int(len(available_ids) * training_percent)
    train_ids = available_ids[:final_train_id]
    val_ids = available_ids[final_train_id:]

    return data, train_ids, val_ids

def getImg(path):
    """
    Locate the image file of a sample directory

    The first directory entry whose name contains IMAGE_SIZE + "_img.tiff" is used.

    Parameters:
    path (str): path to the directory

    Returns:
    file_path (str): path to the image file (forward slashes), or None if there is no match
    """
    # Filename fragment selected by the module-level IMAGE_SIZE setting
    wanted = IMAGE_SIZE + "_img.tiff"

    # Lazily scan the directory listing and stop at the first hit
    first_hit = next((name for name in os.listdir(path) if wanted in name), None)
    if first_hit is None:
        print("No resize file found")
        return None

    return os.path.join(path, first_hit).replace("\\","/")

def getMask(path):
    """
    Locate the mask file of a sample directory

    The first directory entry whose name contains IMAGE_SIZE + "_mask.tiff" is used.

    Parameters:
    path (str): path to the directory

    Returns:
    file_path (str): path to the mask file (forward slashes), or None if there is no match
    """
    # Filename fragment selected by the module-level IMAGE_SIZE setting
    wanted = IMAGE_SIZE + "_mask.tiff"

    hits = [name for name in os.listdir(path) if wanted in name]
    if hits:
        # Keep the first match in directory-listing order
        return os.path.join(path, hits[0]).replace("\\","/")

    print("No mask file found")
    return None

def getResizeImg(path):
    """
    Locate the resized image file of a sample directory

    The first directory entry whose name contains "resize_img.tiff" is used.

    Parameters:
    path (str): path to the directory

    Returns:
    file_path (str): path to the resized image file (forward slashes), or None if there is no match
    """
    wanted = "resize_img.tiff"

    for entry in os.listdir(path):
        if wanted not in entry:
            continue
        # First match wins
        return os.path.join(path, entry).replace("\\","/")

    print("No resize file found")
    return None

def getResizeMask(path):
    """
    Locate the resized mask file of a sample directory

    The first directory entry whose name contains "resize_mask.tiff" is used.

    Parameters:
    path (str): path to the directory

    Returns:
    file_path (str): path to the resized mask file (forward slashes), or None if there is no match
    """
    wanted = "resize_mask.tiff"

    # Walk the listing by position and stop at the first matching name
    entries = os.listdir(path)
    position = 0
    while position < len(entries):
        candidate = entries[position]
        if wanted in candidate:
            return os.path.join(path, candidate).replace("\\","/")
        position += 1

    print("No mask file found")
    return None

def centerVideo(video):
    """
    Center the video by subtracting the mean of all the frames and dividing by the standard deviation

    The mean and standard deviation are computed over the whole array (all frames and
    pixels together). A video with zero standard deviation (every value identical) is
    returned as all zeros instead of NaN.

    Parameters:
    video (np.array): array of video frames

    Returns:
    centered_video (np.array): array of centered video frames
    """
    mean = np.mean(video)
    std = np.std(video)
    centered_video = video - mean
    # Dividing by a zero std would turn the all-zero centered video into NaN (0/0)
    if std == 0:
        return centered_video
    return centered_video / std
    

In [None]:
def generateData(data, available_ids, batch_size):
    """
    Generate training data for the model

    Endless generator. Every step draws two random samples, takes a random window of
    batch_size consecutive frames from each, crops the window at random around the mask,
    resizes every frame to 128x128, optionally flips the clip and normalizes the image.

    Parameters:
    data (np.array): array of paths to individual samples
    available_ids (np.array): array of available indices
    batch_size (int): number of consecutive frames taken from each sample

    Yield:
    outputX (np.array): array of training samples
    outputY (np.array): array of training labels

    Raises:
    ValueError: if a video has fewer than batch_size frames, or if the mask is empty in
                the selected window
    """
    #generate train data
    augment = True
    while True:
        # Choose two random IDs from the available IDs
        # INCREASE NUMBER OF SAMPLES HERE IF GPU MEMORY ALLOWS
        s = sample(list(available_ids), 2)
        outputX = []
        outputY = []
        for sample_id in s:
            # Read the image at that ID and convert it to a numpy array
            dir_path = data[sample_id]
            mask_path = getMask(dir_path)
            resize_path = getImg(dir_path)
            img = np.squeeze(np.array(skimage.io.imread(resize_path)))

            # Read the mask file and binarize it (multi-channel masks use channel 0)
            img_mask = np.array(skimage.io.imread(mask_path))
            if len(img_mask.shape) > 3:
                img_mask = (img_mask[:, :, :, 0] > 0.5) * 1.0
            else:
                img_mask = (img_mask > 0.5) * 1.0

            # Add a singleton dimension so that all images have a color channel
            train = img[..., np.newaxis]
            y = img_mask[..., np.newaxis]

            # Pick a random window of batch_size consecutive frames
            last_start = train.shape[0] - batch_size
            if last_start < 0:
                raise ValueError(
                    f"{dir_path} has {train.shape[0]} frames, fewer than the requested {batch_size}"
                )
            start_loc = randint(0, last_start)
            end_loc = start_loc + batch_size
            train = train[start_loc:end_loc]
            y = y[start_loc:end_loc]

            # The shape of y is [frame, height, width, 1]
            # y is a mask with value 0 or 1, find the bounding box of the mask (computed once)
            mask_coords = np.where(y == 1)
            mask_rows = mask_coords[1]
            mask_cols = mask_coords[2]
            if mask_cols.size == 0:
                raise ValueError(
                    f"mask of {dir_path} is empty in frames {start_loc}-{end_loc}"
                )
            max_x = np.max(mask_cols)
            min_x = np.min(mask_cols)
            max_y = np.max(mask_rows)
            min_y = np.min(mask_rows)

            # Pad the bounding box by 12 pixels, clamped near the borders
            # NOTE(review): the literal bounds (124/125 for x, 596 for y) presumably assume
            # 600x128 full-size frames - confirm against the dataset
            min_x = 3 if min_x <= 15 else min_x - 12
            min_y = 3 if min_y <= 15 else min_y - 12
            max_x = 124 if max_x >= 112 else max_x + 12
            max_y = 596 if max_y >= 584 else max_y + 12

            # Allow up to 30 extra rows above and below the padded box, inside the frame
            low = min_y - 30 if min_y - 30 > 1 else 1
            high = max_y + 30 if max_y + 30 < y.shape[1] - 1 else y.shape[1] - 1

            # Now select the boundaries for the crop
            # Make sure the whole mask is in the frame
            crop_x_min = randint(2, min_x)
            crop_x_max = randint(max_x, 125)
            crop_y_min = randint(low, min_y)
            crop_y_max = randint(max_y, high)

            # Crop the image and mask
            train_roi = train[:, crop_y_min:crop_y_max, crop_x_min:crop_x_max]
            y_roi = y[:, crop_y_min:crop_y_max, crop_x_min:crop_x_max]

            # Resize every cropped frame to the 128x128 network input size
            # NOTE(review): the mask is resized with bicubic interpolation too, so its values
            # are not guaranteed to stay exactly 0/1 - confirm this is intended
            train = np.array([
                cv2.resize(frame, (128, 128),interpolation=cv2.INTER_CUBIC)
                for frame in train_roi
            ])
            y = np.array([
                cv2.resize(frame, (128, 128),interpolation=cv2.INTER_CUBIC)
                for frame in y_roi
            ])

            if augment:
                # Apply random flip augmentation
                # NOTE(review): random.randint is inclusive on both ends, so aug is 0, 1 or 2
                # and the aug == 3 (time reversal) branch never runs - confirm whether
                # randint(0, 3) was intended
                aug = randint(0, 2)
                if aug == 1:
                    # Flip along axis 1 (rows)
                    train = np.flip(train, 1)
                    y = np.flip(y, 1)
                elif aug == 2:
                    # Flip along axis 2 (columns)
                    train = np.flip(train, 2)
                    y = np.flip(y, 2)
                elif aug == 3:
                    # Flip along axis 0 (frame order)
                    train = np.flip(train, 0)
                    y = np.flip(y, 0)

            # Cast to float32 and normalize the image (the mask is only cast)
            train = centerVideo(train.astype('float32'))
            y = y.astype('float32')

            outputX.append(train)
            outputY.append(y)

        outputX = np.array(outputX)
        outputY = np.array(outputY)

        yield (outputX, outputY)


def read_validation_files(data, ids, frames_per_clip=32, max_frames=4000):
    """
    Read in the validation files, preprocess them and return the validation set

    For every sample the resized image and mask are read, the mask is binarized, at most
    max_frames frames are kept, and the result is cut into consecutive non-overlapping
    clips of frames_per_clip frames (a trailing partial clip is dropped). Each image clip
    is normalized with centerVideo.

    Parameters:
    data (np.array): array of paths to individual samples
    ids (np.array): array of available indices
    frames_per_clip (int): number of frames in every validation clip
    max_frames (int): maximum number of leading frames used from each sample

    Returns:
    image (np.array): array of validation samples
    y (np.array): array of validation labels
    """
    image_list = []
    mask_list = []
    for sample_id in ids:
        dir_path = data[sample_id]
        mask_path = getResizeMask(dir_path)
        resize_path = getResizeImg(dir_path)
        img = np.squeeze(np.array(skimage.io.imread(resize_path)))

        # Read the mask file and binarize it (multi-channel masks use channel 0)
        img_mask = np.array(skimage.io.imread(mask_path))
        if len(img_mask.shape) > 3:
            img_mask = (img_mask[:, :, :, 0] > 0.5) * 1.0
        else:
            img_mask = (img_mask > 0.5) * 1.0

        # Add a singleton dimension so that all images have a color channel
        image = img[:max_frames, :, :, np.newaxis]
        y = img_mask[:max_frames, :, :, np.newaxis]

        # Cut into whole clips; the clip count is derived from the image length
        for clip_index in range(len(image) // frames_per_clip):
            start = clip_index * frames_per_clip
            stop = start + frames_per_clip
            image_list.append(centerVideo(image[start:stop, ...].astype('float32')))
            mask_list.append(y[start:stop, ...].astype('float32'))

    # Convert the lists into numpy arrays combining the first dimension
    image = np.array(image_list)
    y = np.array(mask_list)

    return image, y

In [None]:
# Define the different metrics for measuring the performance of the model
def dice_coeff(y_true, y_pred):
    """
    Smoothed dice coefficient between the true and predicted masks

    Both inputs are flattened, so the score is computed over all elements at once.
    A smoothing term of 1 is added to numerator and denominator.

    Parameters:
    y_true: true masks (cast to float32 after flattening)
    y_pred: predicted masks

    Returns:
    score: dice coefficient
    """
    smooth = 1.
    truth = tf.cast(K.flatten(y_true), tf.float32)
    prediction = K.flatten(y_pred)

    overlap = K.sum(truth * prediction)
    numerator = 2. * overlap + smooth
    denominator = K.sum(truth) + K.sum(prediction) + smooth
    return numerator / denominator


def dice_loss(y_true, y_pred):
    """
    Dice loss between the true and predicted masks, i.e. one minus the dice coefficient

    Parameters:
    y_true: true masks
    y_pred: predicted masks

    Returns:
    loss: dice loss
    """
    return 1 - dice_coeff(y_true, y_pred)


def bce_dice_loss(y_true, y_pred):
    """
    Sum of the binary cross entropy and the dice loss between the true and predicted masks

    Parameters:
    y_true: true masks
    y_pred: predicted masks

    Returns:
    loss: binary cross entropy plus dice loss
    """
    cross_entropy = binary_crossentropy(y_true, y_pred)
    overlap_penalty = dice_loss(y_true, y_pred)
    return cross_entropy + overlap_penalty

In [None]:
# Define all the different blocks and arrange them into the model
# The end_block contains the LSTM layers and can be moved around in the model
# Currently there is only LSTM on the first and last layers

def conv_block(input, num_filters):
    """
    Two rounds of per-frame 5x5 convolution + batch normalization, then a LeakyReLU

    Parameters:
    input: input tensor of the block
    num_filters (int): number of filters of both convolutions

    Returns:
    x: output tensor of the block
    """
    x = input
    for _ in range(2):
        # Conv2D is wrapped in TimeDistributed so it is applied to every frame
        x = TimeDistributed(Conv2D(num_filters, 5, padding="same"))(x)
        x = BatchNormalization()(x)
    return LeakyReLU(alpha=0.2)(x)

def deconv_block(input, num_filters):
    """
    Convolution stack used on the decoder side: two per-frame 5x5 convolutions, each
    followed by batch normalization, with a single LeakyReLU at the end

    Parameters:
    input: input tensor of the block
    num_filters (int): number of filters of both convolutions

    Returns:
    x: output tensor of the block
    """
    first_conv = TimeDistributed(Conv2D(num_filters, 5, padding="same"))
    x = first_conv(input)
    x = BatchNormalization()(x)

    second_conv = TimeDistributed(Conv2D(num_filters, 5, padding="same"))
    x = second_conv(x)
    x = BatchNormalization()(x)

    activation = LeakyReLU(alpha=0.2)
    return activation(x)

def end_block(input, num_filters):
    """
    Recurrent block: a ConvLSTM2D layer (returning the full sequence) and a per-frame
    5x5 convolution, each followed by batch normalization, then a LeakyReLU

    Parameters:
    input: input tensor of the block
    num_filters (int): number of filters of the ConvLSTM2D and Conv2D layers

    Returns:
    x: output tensor of the block
    """
    recurrent = ConvLSTM2D(num_filters, 5, padding="same", return_sequences=True)
    x = recurrent(input)
    x = BatchNormalization()(x)

    per_frame_conv = TimeDistributed(Conv2D(num_filters, 5, padding="same"))
    x = per_frame_conv(x)
    x = BatchNormalization()(x)

    return LeakyReLU(alpha=0.2)(x)

def encoder_block(input, num_filters):
    """
    Encoder stage: conv_block followed by per-frame 2x2 max pooling

    Parameters:
    input: input tensor of the stage
    num_filters (int): number of filters of the convolutions

    Returns:
    x: features before pooling (used as skip connection)
    p: pooled features passed on to the next stage
    """
    features = conv_block(input, num_filters)
    downsample = TimeDistributed(MaxPooling2D((2, 2)))
    pooled = downsample(features)
    return features, pooled

def first_layer(input, num_filters):
    """
    First encoder stage: end_block (with the ConvLSTM2D layer) followed by per-frame
    2x2 max pooling

    Parameters:
    input: input tensor of the stage
    num_filters (int): number of filters of the block

    Returns:
    x: features before pooling (used as skip connection)
    p: pooled features passed on to the next stage
    """
    features = end_block(input, num_filters)
    downsample = TimeDistributed(MaxPooling2D((2, 2)))
    pooled = downsample(features)
    return features, pooled

def decoder_block(input, skip_features, num_filters):
    """
    Decoder stage: per-frame 2x2 transposed convolution with stride 2, concatenation
    with the skip features, then deconv_block

    Parameters:
    input: input tensor of the stage
    skip_features: encoder features concatenated after upsampling
    num_filters (int): number of filters of the stage

    Returns:
    x: output tensor of the stage
    """
    upsample = TimeDistributed(Conv2DTranspose(num_filters, (2, 2), strides=(2,2), padding="same"))
    upsampled = upsample(input)
    merged = Concatenate()([upsampled, skip_features])
    return deconv_block(merged, num_filters)

def last_layer(input, skip_features, num_filters):
    """
    Last decoder stage: per-frame 2x2 transposed convolution with stride 2, concatenation
    with the skip features, then end_block (with the ConvLSTM2D layer)

    Parameters:
    input: input tensor of the stage
    skip_features: encoder features concatenated after upsampling
    num_filters (int): number of filters of the stage

    Returns:
    x: output tensor of the stage
    """
    upsample = TimeDistributed(Conv2DTranspose(num_filters, (2, 2), strides=(2,2), padding="same"))
    upsampled = upsample(input)
    merged = Concatenate()([upsampled, skip_features])
    return end_block(merged, num_filters)

def create_model(input_shape=(None, 128, 128, 1)):
    """
    Assemble the encoder/decoder model from the blocks above

    Encoder: first_layer (40 filters), then two encoder_blocks (64 and 128 filters).
    Bottleneck: conv_block with 256 filters.
    Decoder: two decoder_blocks (128 and 64 filters), then last_layer (32 filters).
    Output: 1x1x1 Conv3D with a single filter and sigmoid activation.

    Parameters:
    input_shape (tuple): shape of one input sample, without the batch dimension

    Returns:
    model (Model): the uncompiled Keras model
    """
    inputs = Input(shape=input_shape)

    # Encoder path, keeping the pre-pooling features for the skip connections
    skip_1, pooled_1 = first_layer(inputs, 40)
    skip_2, pooled_2 = encoder_block(pooled_1, 64)
    skip_3, pooled_3 = encoder_block(pooled_2, 128)

    bottleneck = conv_block(pooled_3, 256)

    # Decoder path, consuming the skip connections in reverse order
    up_3 = decoder_block(bottleneck, skip_3, 128)
    up_2 = decoder_block(up_3, skip_2, 64)
    up_1 = last_layer(up_2, skip_1, 32)

    classifier = Conv3D(1, (1, 1, 1), padding="same", activation='sigmoid')
    classify = classifier(up_1)

    return Model(inputs=inputs, outputs=classify)

In [None]:
# Define the main function that will train the model
def main():
    """
    Train the model from scratch

    Builds the sample index pickle from DATA_DIR, splits it into training and validation
    ids, loads the validation clips, compiles a fresh model and fits it for 40 epochs,
    checkpointing to OUTPUT_DIR.

    Returns:
    data (np.array): array of paths to individual samples
    val_ids (np.array): array of validation indices
    """
    # Start by generating the pickle file (fill in path below if you already have a pickle file)
    pickle_path = generate_pk(DATA_DIR)

    # Read in the pickle file and prepare validation data
    data, train_ids, val_ids = read_pk(pickle_path)
    K.set_image_data_format('channels_last')
    batch_size = 32
    steps_per_epoch = (len(train_ids) * 40) // batch_size
    val, y_val = read_validation_files(data, val_ids)
    print(f"Length of training set: {len(train_ids)}, Length of validation set: {len(val_ids)}")

    # Optional: load in an existing model to fine tune rather than training from scratch
    model = create_model()
    # model = load_model("FlyNet.h5", custom_objects = {'dice_coeff': dice_coeff, 'bce_dice_loss': bce_dice_loss, "focal_tversky": lf.Semantic_loss_functions().focal_tversky, "log_cosh_dice_loss": lf.Semantic_loss_functions().log_cosh_dice_loss})

    # Set all remaining parameters and compile the model
    learning_rate = 1e-4
    loss_function = lf.Semantic_loss_functions().log_cosh_dice_loss
    tracked_metrics = [
        tf.keras.metrics.MeanIoU(num_classes=2),
        tf.keras.metrics.Recall(),
        tf.keras.metrics.Precision(),
        dice_coeff,
    ]
    model.compile(loss=loss_function, optimizer=Adam(learning_rate), metrics=tracked_metrics)

    # Optional: print out the model summary
    # print(model.summary(line_length=150))

    # Checkpoint whenever the validation loss improves
    model_checkpoint = ModelCheckpoint(
        OUTPUT_DIR + "/FlyNet_{epoch:02d}.h5",
        monitor='val_loss',
        save_best_only=True,
    )
    # Optional: uncomment the line below to use tensorboard and add to list of callbacks in model.fit
    # tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=OUTPUT_DIR, histogram_freq=1)

    # Run the model
    training_batches = generateData(data, train_ids, batch_size)
    model.fit(
        training_batches,
        steps_per_epoch=steps_per_epoch,
        epochs=40,
        verbose=1,
        validation_data=(val, y_val),
        callbacks=[model_checkpoint],
    )

    return data, val_ids

# Run the training pipeline and keep the returned sample paths and validation indices
# at module level
data, val_ids = main()