In [None]:
# Matthew Fishman (6/24/2022)

# Tensorflow and Keras
import tensorflow as tf
from tensorflow.keras.models import Sequential, Model, load_model, save_model
from tensorflow.keras.layers import ConvLSTM2D, TimeDistributed, Conv2D, Conv3D, BatchNormalization, Conv3DTranspose
from tensorflow.keras.layers import Activation, MaxPooling2D, MaxPool3D, Dropout, Flatten, Dense, Input, LeakyReLU, Bidirectional

from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input
from tensorflow.keras.models import Model

from tensorflow.keras.optimizers import RMSprop, Adam
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras import backend as K
from tensorflow.keras.losses import binary_crossentropy
import loss_funcs as lf


# Import numpy and data management
import numpy as np
import os
import pickle
from random import shuffle
from random import randint
from random import sample
import skimage.io
from skimage import transform
import cv2

# Import tensorboard
%load_ext tensorboard


# Filename prefix used by getImg/getMask to locate "<type>_img.tiff" and
# "<type>_mask.tiff" inside each sample directory.
# NOTE(review): this name shadows the builtin `type` for the rest of the notebook;
# renaming it requires updating getImg/getMask, which read this module-level name.
type = "full"

In [None]:
# Read the pickle file and return the ids for the samples in the training and validation sets
def read_pk(path):
    """Load sample paths from a pickle file and randomly split their indices.

    95% of the indices go to training when there are at least 70 samples,
    otherwise 90%; the remaining indices form the validation set.

    Args:
        path: pickle file holding a sequence of sample directory paths.

    Returns:
        (data, train_ids, val_ids): ``data`` converted to a numpy array, plus two
        disjoint integer index arrays into it that together cover every index.

    Security: ``pickle.load`` can execute arbitrary code -- only load trusted files.
    """
    with open(path, mode='rb') as f:
        data = pickle.load(f)
    data = np.array(data)

    # Shuffle indices with the stdlib RNG (1-D integer arrays shuffle safely in place).
    available_ids = np.arange(len(data))
    shuffle(available_ids)

    multiplier = 0.95 if len(available_ids) >= 70 else 0.9
    final_train_id = int(len(available_ids) * multiplier)
    train_ids = available_ids[:final_train_id]
    val_ids = available_ids[final_train_id:]
    return data, train_ids, val_ids


def getImg(path, kind=None):
    """Return the path of the image TIFF named ``*<kind>_img.tiff`` inside ``path``.

    Args:
        path: directory to search.
        kind: filename prefix to match; defaults to the module-level ``type``
            setting (e.g. "full").

    Returns:
        The first matching file (in sorted filename order, so the choice is
        deterministic), joined to ``path`` with backslashes turned into forward
        slashes, or ``None`` after printing a message when nothing matches.
    """
    if kind is None:
        kind = type  # module-level prefix setting; only read, so no `global` needed
    search_string = kind + "_img.tiff"

    for file in sorted(os.listdir(path)):
        if search_string in file:
            return os.path.join(path, file).replace("\\", "/")
    print(f"No {search_string} file found in {path}")
    return None


def getMask(path, kind=None):
    """Return the path of the mask TIFF named ``*<kind>_mask.tiff`` inside ``path``.

    Args:
        path: directory to search.
        kind: filename prefix to match; defaults to the module-level ``type``
            setting (e.g. "full").

    Returns:
        The first matching file (in sorted filename order, so the choice is
        deterministic), joined to ``path`` with backslashes turned into forward
        slashes, or ``None`` after printing a message when nothing matches.
    """
    if kind is None:
        kind = type  # module-level prefix setting; only read, so no `global` needed
    search_string = kind + "_mask.tiff"

    for file in sorted(os.listdir(path)):
        if search_string in file:
            return os.path.join(path, file).replace("\\", "/")
    print(f"No {search_string} file found in {path}")
    return None

def getResizeImg(path):
    """Find the resized image TIFF (name containing "resize_img.tiff") in ``path``.

    Returns the first match in directory-listing order, joined to ``path`` with
    backslashes converted to forward slashes; prints a notice and returns None
    when there is no match.
    """
    matches = [name for name in os.listdir(path) if "resize_img.tiff" in name]
    if not matches:
        print("No resize file found")
        return None
    return os.path.join(path, matches[0]).replace("\\","/")

def getResizeMask(path):
    """Find the resized mask TIFF (name containing "resize_mask.tiff") in ``path``.

    Returns the first match in directory-listing order, joined to ``path`` with
    backslashes converted to forward slashes; prints a notice and returns None
    when there is no match.
    """
    matches = [name for name in os.listdir(path) if "resize_mask.tiff" in name]
    if not matches:
        print("No mask file found")
        return None
    return os.path.join(path, matches[0]).replace("\\","/")

def centerVideo(video):
    """Standardize a video to zero mean and unit standard deviation.

    Mean and standard deviation are taken over all elements (every frame and
    pixel together). A constant video (std == 0) is returned mean-subtracted,
    i.e. all zeros, instead of being divided by zero, which produced NaN/inf
    values that poison training.
    """
    mean = np.mean(video)
    std = np.std(video)
    centered_video = video - mean
    if std == 0:
        return centered_video
    return centered_video / std

def shift(arr, num, axis):
    """Return a copy of ``arr`` shifted by ``num`` positions along ``axis``, zero-filled.

    Positive ``num`` moves values toward higher indices and negative ``num``
    toward lower indices. ``num == 0`` returns an unshifted copy and
    ``abs(num) >= arr.shape[axis]`` returns all zeros. (Previously ``num == 0``
    raised a shape error, negative ``num`` scrambled the data, and 1-D arrays
    could not be shifted.)

    Raises:
        ValueError: if ``axis`` is not 0, 1 or 2.
    """
    if axis not in (0, 1, 2):
        raise ValueError("axis should be 0, 1 or 2")
    result = np.zeros_like(arr)
    length = arr.shape[axis]
    # Clamp so slice bounds never go negative (negative bounds would wrap around).
    k = max(-length, min(length, num))

    dst = [slice(None)] * arr.ndim
    src = [slice(None)] * arr.ndim
    if k >= 0:
        dst[axis] = slice(k, length)
        src[axis] = slice(0, length - k)
    else:
        dst[axis] = slice(0, length + k)
        src[axis] = slice(-k, length)
    result[tuple(dst)] = arr[tuple(src)]
    return result
    

In [None]:
def generateData(data, available_ids, batch_size, videos_per_batch=3, augment=True):
    """Endlessly yield randomly sampled, cropped and augmented training batches.

    Each batch is built from ``videos_per_batch`` distinct IDs drawn from
    ``available_ids``. For every ID:
      * the image and mask TIFFs found by ``getImg`` / ``getMask`` in
        ``data[id]`` are read;
      * a random window of ``batch_size`` consecutive frames is taken;
      * a random crop around the mask pixels of that window is cut out and
        resized to 128x128;
      * the same random flip is optionally applied to clip and mask;
      * the clip is standardized with ``centerVideo``.

    Args:
        data: indexable of sample directory paths (e.g. from ``read_pk``).
        available_ids: indices into ``data`` to sample from.
        batch_size: number of consecutive frames per clip. This is the clip
            length, not the number of clips per batch.
        videos_per_batch: clips stacked into each yielded batch (default 3,
            the previously hard-coded value).
        augment: apply a random flip when True (default, as before).

    Yields:
        ``(X, Y)`` float32 arrays of shape
        ``(videos_per_batch, batch_size, 128, 128, 1)``. This assumes
        single-channel videos -- TODO confirm. Mask values lie in [0, 1].
    """
    while True:
        clips = []
        masks = []
        for sample_id in sample(list(available_ids), videos_per_batch):
            video, mask = _load_random_window(data[sample_id], batch_size)
            video, mask = _crop_around_mask(video, mask)
            if augment:
                video, mask = _random_flip(video, mask)
            clips.append(centerVideo(video.astype('float32')))
            masks.append(mask.astype('float32'))
        yield (np.array(clips), np.array(masks))


def _load_random_window(dir_path, clip_len):
    """Read one sample's image/mask TIFFs and return a random ``clip_len``-frame window of each.

    Both arrays get a trailing channel axis. The mask is binarized at 0.5; only
    its first channel is used when the mask file has more than three dimensions.
    """
    mask_path = getMask(dir_path)
    img_path = getImg(dir_path)
    video = np.squeeze(np.array(skimage.io.imread(img_path)))[..., np.newaxis]
    mask = np.array(skimage.io.imread(mask_path))
    if mask.ndim > 3:
        mask = mask[:, :, :, 0]
    mask = ((mask > 0.5) * 1.0)[..., np.newaxis]

    last_start = video.shape[0] - clip_len
    if last_start < 0:
        raise ValueError(
            f"{dir_path}: video has {video.shape[0]} frames, fewer than the clip length {clip_len}"
        )
    start = randint(0, last_start)
    return video[start:start + clip_len], mask[start:start + clip_len]


def _crop_around_mask(video, mask, out_size=128):
    """Randomly crop ``video``/``mask`` around the mask's bounding box and resize the frames.

    The bounding box of all mask pixels in the window is padded by 12 px and
    clamped a few pixels inside the frame. The crop edges are then drawn at
    random so that the padded box is contained in the crop. The edge constants
    come from the frame size: for width 128 they equal the previous hard-coded
    124/112/125, and for height 600 they equal 596/584. If the window contains
    no mask pixels, the full frame is used; the previous code crashed taking
    the max of an empty array.
    """
    height, width = mask.shape[1], mask.shape[2]
    _, rows, cols = np.nonzero(mask[..., 0] == 1)

    if rows.size == 0:
        top, bottom, left, right = 0, height, 0, width
    else:
        min_x, max_x = int(cols.min()), int(cols.max())
        min_y, max_y = int(rows.min()), int(rows.max())

        min_x = 3 if min_x <= 15 else min_x - 12
        min_y = 3 if min_y <= 15 else min_y - 12
        max_x = width - 4 if max_x >= width - 16 else max_x + 12
        max_y = height - 4 if max_y >= height - 16 else max_y + 12

        # Vertical crop edges may reach up to 30 px beyond the padded box.
        low = max(min_y - 30, 1)
        high = min(max_y + 30, height - 1)

        left = randint(2, min_x)
        right = randint(max_x, width - 3)
        top = randint(low, min_y)
        bottom = randint(max_y, high)

    size = (out_size, out_size)
    video_roi = video[:, top:bottom, left:right]
    mask_roi = mask[:, top:bottom, left:right]
    resized_video = np.array([cv2.resize(frame, size, interpolation=cv2.INTER_CUBIC) for frame in video_roi])
    resized_mask = np.array([cv2.resize(frame, size, interpolation=cv2.INTER_CUBIC) for frame in mask_roi])
    # Cubic interpolation overshoots near edges; keep the mask targets in [0, 1].
    resized_mask = np.clip(resized_mask, 0.0, 1.0)

    # Restore an explicit channel axis so the shapes match readFiles' output.
    n_frames = len(resized_video)
    return (resized_video.reshape(n_frames, out_size, out_size, 1),
            resized_mask.reshape(n_frames, out_size, out_size, 1))


def _random_flip(video, mask):
    """Apply the same random flip to clip and mask, each option with probability 1/3.

    The options are: no flip, flip along the height (axis 1), or flip along the
    width (axis 2). The former time-reversal branch (axis 0) could never be
    selected by ``randint(0, 2)`` and is omitted.
    """
    axis = randint(0, 2)
    if axis == 0:
        return video, mask
    return np.flip(video, axis), np.flip(mask, axis)

In [None]:
def readFiles(data, ids, max_frames=4000, clip_len=32, stride=64):
    """Load fixed (non-random) clips from the "resize" TIFFs, e.g. for validation.

    For each id, the image and mask found by ``getResizeImg`` / ``getResizeMask``
    in ``data[id]`` are read, with the mask binarized at 0.5. Both are truncated
    to the first ``max_frames`` frames and cut into ``clip_len``-frame clips
    starting every ``stride`` frames. Each image clip is standardized with
    ``centerVideo``. Only clips that fit completely inside the video are kept,
    so shorter videos no longer produce ragged arrays. For videos of at least
    ``max_frames - stride + clip_len`` frames, the output matches the previous
    behavior.

    Args:
        data: indexable of sample directory paths (e.g. from ``read_pk``).
        ids: indices into ``data`` to load.
        max_frames: frames considered per video (default 4000, as before).
        clip_len: frames per clip (default 32, as before).
        stride: distance between clip starts (default 64, as before).

    Returns:
        ``(X, Y)`` float32 arrays of shape ``(n_clips, clip_len, H, W, 1)``,
        assuming (frames, H, W) videos -- TODO confirm.
    """
    clips = []
    masks = []
    # Candidate clip starts: max_frames // stride of them, as before.
    starts = range(0, (max_frames // stride) * stride, stride)

    for sample_id in ids:
        dir_path = data[sample_id]
        mask_path = getResizeMask(dir_path)
        img_path = getResizeImg(dir_path)
        video = np.squeeze(np.array(skimage.io.imread(img_path)))

        mask = np.array(skimage.io.imread(mask_path))
        if mask.ndim > 3:
            mask = mask[:, :, :, 0]
        mask = (mask > 0.5) * 1.0

        # Add a singleton channel axis so every clip is (frames, H, W, 1).
        video = video[:max_frames, :, :, np.newaxis]
        mask = mask[:max_frames, :, :, np.newaxis]
        n_frames = min(video.shape[0], mask.shape[0])

        for start in starts:
            stop = start + clip_len
            if stop > n_frames:
                break  # starts are increasing, so no later clip fits either
            clips.append(centerVideo(video[start:stop].astype('float32')))
            masks.append(mask[start:stop].astype('float32'))

    return np.array(clips), np.array(masks)

# Define the different metrics for measuring the performance of the model
def dice_coeff(y_true, y_pred):
    """Smoothed Dice coefficient between a ground-truth mask and a prediction.

    Both tensors are flattened, and ``y_true`` is cast to float32 before being
    multiplied with ``y_pred``. The smoothing term of 1 keeps the score defined
    (and equal to 1) when both inputs are all zeros.
    """
    smooth = 1.
    truth = tf.cast(K.flatten(y_true), tf.float32)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    denominator = K.sum(truth) + K.sum(pred) + smooth
    return (2. * overlap + smooth) / denominator


def dice_loss(y_true, y_pred):
    """Dice loss: one minus the smoothed Dice coefficient from ``dice_coeff``."""
    return 1 - dice_coeff(y_true, y_pred)


def bce_dice_loss(y_true, y_pred):
    """Combined loss: Keras binary cross-entropy plus the Dice loss."""
    bce = binary_crossentropy(y_true, y_pred)
    dice = dice_loss(y_true, y_pred)
    return bce + dice

def step_decay(epoch, *, initial_lrate=0.00001, drop=0.5, epochs_drop=5):
    """Step learning-rate schedule: multiply the rate by ``drop`` every ``epochs_drop`` epochs.

    The rate is ``initial_lrate * drop ** floor((1 + epoch) / epochs_drop)``.
    The defaults reproduce the previously hard-coded schedule (1e-5, halved
    every 5 epochs). The tuning knobs are keyword-only on purpose: Keras'
    ``LearningRateScheduler`` first tries ``schedule(epoch, lr)``, and a
    positional second parameter would silently receive the current lr.

    Returns:
        The learning rate for ``epoch`` as a Python float.
    """
    lrate = initial_lrate * np.power(drop, np.floor((1 + epoch) / epochs_drop))
    return float(lrate)

In [None]:
def conv_block(input, num_filters):
    """Two time-distributed 5x5 'same' convolutions, each followed by batch norm, then LeakyReLU(0.2).

    No activation sits between the two convolutions; the only nonlinearity is
    the final LeakyReLU.
    """
    x = input
    for _ in range(2):
        x = TimeDistributed(Conv2D(num_filters, 5, padding="same"))(x)
        x = BatchNormalization()(x)
    return LeakyReLU(alpha=0.2)(x)

def deconv_block(input, num_filters):
    """Decoder-side conv stack: (TimeDistributed 5x5 Conv2D -> BatchNorm) twice, then LeakyReLU(0.2).

    Structurally the same layer sequence as ``conv_block``.
    """
    features = input
    for _ in range(2):
        conv = TimeDistributed(Conv2D(num_filters, 5, padding="same"))
        features = BatchNormalization()(conv(features))
    return LeakyReLU(alpha=0.2)(features)

def end_block(input, num_filters):
    """Recurrent block used at the network's outermost level.

    ConvLSTM2D (5x5, 'same', full output sequence) + batch norm, followed by a
    time-distributed 5x5 Conv2D + batch norm and a final LeakyReLU(0.2).
    """
    recurrent = ConvLSTM2D(num_filters, 5, padding="same", return_sequences=True)(input)
    recurrent = BatchNormalization()(recurrent)

    conv = TimeDistributed(Conv2D(num_filters, 5, padding="same"))(recurrent)
    conv = BatchNormalization()(conv)
    return LeakyReLU(alpha=0.2)(conv)

def encoder_block(input, num_filters):
    """Encoder stage: ``conv_block`` followed by time-distributed 2x2 max pooling.

    Returns:
        (features, pooled): the full-resolution features (used as a skip
        connection) and their 2x2 max-pooled version.
    """
    features = conv_block(input, num_filters)
    pooled = TimeDistributed(MaxPooling2D((2, 2)))(features)
    return features, pooled

def first_layer(input, num_filters):
    """First encoder stage: the ConvLSTM-based ``end_block`` followed by 2x2 max pooling.

    Returns:
        (features, pooled): the full-resolution features (used as a skip
        connection) and their time-distributed 2x2 max-pooled version.
    """
    features = end_block(input, num_filters)
    pooled = TimeDistributed(MaxPooling2D((2, 2)))(features)
    return features, pooled

def decoder_block(input, skip_features, num_filters):
    """Decoder stage: 2x transposed-conv upsampling, concatenation with the skip features, then ``deconv_block``."""
    upsampled = TimeDistributed(
        Conv2DTranspose(num_filters, (2, 2), strides=(2, 2), padding="same")
    )(input)
    merged = Concatenate()([upsampled, skip_features])
    return deconv_block(merged, num_filters)

def last_layer(input, skip_features, num_filters):
    """Final decoder stage: 2x transposed-conv upsampling, skip concatenation, then the ConvLSTM-based ``end_block``."""
    upsampled = TimeDistributed(
        Conv2DTranspose(num_filters, (2, 2), strides=(2, 2), padding="same")
    )(input)
    merged = Concatenate()([upsampled, skip_features])
    return end_block(merged, num_filters)

def create_model(input_shape=(None, 128, 128, 1)):
    """Build the encoder-decoder segmentation model with skip connections.

    The encoder has three pooling stages (40, 64 and 128 filters; the first is
    ConvLSTM-based), followed by a 256-filter bottleneck. The decoder has three
    upsampling stages (128, 64 and 32 filters), each concatenating the matching
    encoder features. A 1x1x1 Conv3D with a sigmoid produces one probability per
    pixel and frame. Because there are three 2x poolings and three 2x upsamplings,
    height and width should be divisible by 8.
    """
    inputs = Input(shape=input_shape)

    # Encoder
    skip1, pooled = first_layer(inputs, 40)
    skip2, pooled = encoder_block(pooled, 64)
    skip3, pooled = encoder_block(pooled, 128)

    # Bottleneck
    x = conv_block(pooled, 256)

    # Decoder consumes the skips in reverse order
    x = decoder_block(x, skip3, 128)
    x = decoder_block(x, skip2, 64)
    x = last_layer(x, skip1, 32)

    outputs = Conv3D(1, (1, 1, 1), padding="same", activation='sigmoid')(x)
    return Model(inputs=inputs, outputs=outputs)

In [None]:
# Define the main function that will train the model
def main(directory="C:/Python/Matthew/Training", epochs=40, batch_size=32):
    """Fine-tune a saved segmentation model and return the data split.

    Steps:
      * Read the sample split from ``full_90_804.pk`` and build validation clips
        with ``readFiles``.
      * Reload ``LSTM_sci_809_540.h5`` and recompile it with the log-cosh Dice
        loss.
      * Train on batches from ``generateData``, writing a checkpoint per epoch
        and TensorBoard logs.

    All files are read from and written to ``directory``.

    Args:
        directory: folder holding the pickle and the starting checkpoint, and
            receiving new checkpoints and logs (default: the previously
            hard-coded training folder).
        epochs: number of training epochs (default 40, as before).
        batch_size: frames per training clip passed to ``generateData``
            (default 32, as before).

    Returns:
        (data, val_ids) as produced by ``read_pk``.
    """
    data, train_ids, val_ids = read_pk(directory + "/full_90_804.pk")
    K.set_image_data_format('channels_last')
    # NOTE(review): the factor 40 is an unexplained heuristic for epoch length.
    steps_per_epoch = (len(train_ids) * 40) // batch_size
    print(len(val_ids))
    val, y_val = readFiles(data, val_ids)

    # One loss-helper instance serves both compile() and load_model's custom_objects.
    losses = lf.Semantic_loss_functions()
    loss_function = losses.log_cosh_dice_loss

    # To train from scratch instead of fine-tuning, use `model = create_model()`.
    model = load_model(
        directory + "/LSTM_sci_809_540.h5",
        custom_objects={
            'dice_coeff': dice_coeff,
            'bce_dice_loss': bce_dice_loss,
            "focal_tversky": losses.focal_tversky,
            "log_cosh_dice_loss": losses.log_cosh_dice_loss,
        },
    )
    lr = 1e-4
    model.compile(
        loss=loss_function,
        optimizer=Adam(lr),
        metrics=[
            # NOTE(review): MeanIoU expects class labels; fed sigmoid
            # probabilities it may under-report -- consider thresholding first.
            tf.keras.metrics.MeanIoU(num_classes=2),
            tf.keras.metrics.Recall(),
            tf.keras.metrics.Precision(),
            dice_coeff,
        ],
    )
    # save_best_only=False keeps a checkpoint for every epoch ({epoch:02d} in the name).
    model_checkpoint = ModelCheckpoint(directory + "/LSTM_sci_809_6{epoch:02d}.h5", monitor='val_loss', save_best_only=False)
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=directory + "/LSTMLog6", histogram_freq=1)
    model.fit(
        generateData(data, train_ids, batch_size),
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        verbose=1,
        validation_data=(val, y_val),
        callbacks=[model_checkpoint, tensorboard_callback],
    )

    return data, val_ids


data, val_ids = main()