In [None]:
import sys
sys.path.append('/home/olyaeeen/Desktop/Ehsan/jupyter/')
from utils import *
import numpy as np
import glob
import nibabel as nib
import PIL
from IPython.display import Image
import cv2
import imageio
import matplotlib.pyplot as plt
import datetime
import os

import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.backend as K
from tensorflow.keras.layers import Input, Conv3D, MaxPooling3D, UpSampling3D, concatenate
from tensorflow.keras.layers import Activation
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

# An empty CUDA_VISIBLE_DEVICES hides every CUDA device, i.e. this notebook is
# meant to run on the CPU. It is set before the first device query below.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
gpus = tf.config.list_physical_devices('GPU')

# Grow GPU memory on demand rather than reserving it all up front.
# NOTE(review): with the line above `gpus` is presumably empty, so this loop
# only has an effect if GPUs are re-enabled.
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

print(tf.__version__)

In [None]:
# # normal objects
# DATA_DIR='/home/olyaeeen/Desktop/Ehsan/data/normal/'
# image_list = sorted(glob.glob(DATA_DIR+'*-subvolume-normalized.nii.gz'))[-2:]
# gt_list = sorted(glob.glob(DATA_DIR+'*-Segmentation-smoothed-label.nii.gz'))[-2:]

# All objects
DATA_DIR = '/home/olyaeeen/Desktop/Ehsan/data/data_all/'
image_list = sorted(glob.glob(DATA_DIR + '*-subvolume-normalized.nii.gz'))
gt_list = sorted(glob.glob(DATA_DIR + '*-subvolume-split-label-myo-fixed-tformedBack.nii.gz'))

# VolumeDataGenerator receives the two lists side by side (presumably pairing
# them by index), so a missing file on either side would misalign every pair
# after it. An empty list would also make STEP_PER_EPOCH zero later on.
assert image_list, 'no images found in ' + DATA_DIR
assert len(image_list) == len(gt_list), \
    'found {} images but {} label volumes'.format(len(image_list), len(gt_list))

CSV_LOG_DIR = '/home/olyaeeen/Desktop/Ehsan/log/unet_training.log'  # a file path, not a directory
LOG_DIR = '/home/olyaeeen/Desktop/Ehsan/log/'
MODELS_DIR = '/home/olyaeeen/Desktop/Ehsan/models/unet/'

# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
# LOG_DIR must exist too, because the CSV log file is written inside it.
os.makedirs(MODELS_DIR, exist_ok=True)
os.makedirs(LOG_DIR, exist_ok=True)

gt_len = len(gt_list)
# print(len(gt_list))

In [None]:
# INPUT_SHAPE = (144, 144, 144)
INPUT_SHAPE = (64, 64, 64)  # crop size handed to VolumeDataGenerator (dim_crop)

EPOCHS = 10
BATCH_SIZE = 1
# One pass over the file list per epoch; integer division drops a trailing partial batch.
STEP_PER_EPOCH = len(image_list) // BATCH_SIZE
CHANNELS_LAST = True
SAVE_MODEL_PERIOD = 1

# Make Keras' global default layout agree with the tensors the generator produces.
tf.keras.backend.set_image_data_format(
    'channels_last' if CHANNELS_LAST is True else 'channels_first'
)


In [None]:
# Training pipeline over the paired image/label files. The labels are
# presumably one-hot encoded, given to_categorical=True.
ds = VolumeDataGenerator(
    image_list,
    gt_list,
    batch_size=BATCH_SIZE,
    crop=True,
    dim_crop=INPUT_SHAPE,
    to_categorical=True,
    channels_last=CHANNELS_LAST,
    gt_includes_bg=True,
)

# Draw one sample batch to check shapes; later cells reuse `img` / `gt`.
img, gt = next(iter(ds))
print(img.shape, gt.shape)


In [None]:
# Show the first volume of the sample batch (its single input channel)
# next to the decoded label map.
if CHANNELS_LAST is True:
    first_volume = img[0, :, :, :, 0]
else:
    first_volume = img[0, 0, :, :, :]
visualize_volume(first_volume, undo_categorical(gt, channels_last=CHANNELS_LAST is True), 10)


In [None]:
# if CHANNELS_LAST is True:
#     visualize_volume(img[0, :, :, :, 0], gt[0, :, :, :, 0], 10)
# else:
#     visualize_volume(img[0, 0, :, :, :], gt[0, 0, :, :, :], 10)


In [None]:
# if CHANNELS_LAST is True:
#     gifs = make_gif(img[0, :, :, :, 0], gt[0, :, :, :, 0])
# else:
#     gifs = make_gif(img[0, 0, :, :, :], gt[0, 0, :, :, :])

# # display(gifs[0])
# # display(gifs[1])
# # display(gifs[2])

In [None]:
class unet(object):
    """3D U-Net segmentation model together with its loss and metric functions.

    The Keras model is built in ``__init__`` and stored in ``self.model``.
    ``self.optimizer``, ``self.loss`` and ``self.metrics`` are meant to be
    passed to ``self.model.compile``. ``self.custom_objects`` maps the function
    names Keras stores in a checkpoint back to the bound methods, for use with
    ``keras.models.load_model``.
    """

    def __init__(self, use_upsampling=False, learning_rate=0.001,
                 n_cl_in=1, n_cl_out=1, feature_maps=16,
                 dropout=0.2, print_summary=False,
                 channels_last=True):
        """Configure the network and build ``self.model``.

        Args:
            use_upsampling: decode with ``UpSampling3D`` if True, otherwise with
                learned ``Conv3DTranspose`` layers.
            learning_rate: learning rate for the Adam optimizer.
            n_cl_in: number of input channels.
            n_cl_out: number of output channels (an independent sigmoid each).
            feature_maps: filters in the first convolution block; doubled at
                every encoder level.
            dropout: stored as ``self.dropout``; not used by ``unet_3d``.
            print_summary: call ``model.summary()`` after building.
            channels_last: True for (batch, s1, s2, s3, channels) tensors,
                False for (batch, channels, s1, s2, s3).
        """
        self.channels_last = channels_last
        if channels_last:
            self.concat_axis = -1
            self.data_format = "channels_last"
        else:
            self.concat_axis = 1
            self.data_format = "channels_first"

        self.fms = feature_maps  # 16 or 32 feature maps in the first convolutional layer

        self.use_upsampling = use_upsampling
        self.dropout = dropout  # NOTE: not applied anywhere in unet_3d()
        self.print_summary = print_summary
        self.n_cl_in = n_cl_in
        self.n_cl_out = n_cl_out

        # Alternative: self.loss = self.dice_coef_loss
        self.loss = self.combined_dice_ce_loss

        self.learning_rate = learning_rate
        # `lr` is a deprecated alias that newer Keras versions reject.
        self.optimizer = keras.optimizers.Adam(learning_rate=self.learning_rate)

        self.metrics = [self.dice_coef, self.soft_dice_coef, "accuracy",
                        self.sensitivity, self.specificity]

        self.custom_objects = {
            "combined_dice_ce_loss": self.combined_dice_ce_loss,
            "dice_coef_loss": self.dice_coef_loss,
            "dice_coef": self.dice_coef,
            "soft_dice_coef": self.soft_dice_coef,
            "sensitivity": self.sensitivity,
            "specificity": self.specificity}

        self.model = self.unet_3d()

    def _reduce_axes(self, axis):
        """Return ``axis`` if given, else the spatial axes of a 5-D batch in this layout."""
        if axis is not None:
            return axis
        # Reduce over the three spatial axes only. The batch and channel axes
        # survive, so each metric's final reduce_mean averages per-sample,
        # per-channel scores.
        return (1, 2, 3) if self.channels_last else (2, 3, 4)

    def dice_coef(self, target, prediction, axis=None, smooth=0.01):
        r"""Sorensen-Dice coefficient on thresholded predictions.

        .. math:: \frac{2 \left| T \cap P \right|}{\left| T \right| + \left| P \right|}

        T is the ground-truth mask and P is the prediction rounded to 0/1.
        This is ``soft_dice_coef`` applied to the rounded prediction.
        """
        return self.soft_dice_coef(target, tf.round(prediction), axis, smooth)

    def soft_dice_coef(self, target, prediction, axis=None, smooth=0.01):
        r"""Soft Sorensen-Dice coefficient (predictions are not rounded).

        .. math:: \frac{2 \left| T \cap P \right|}{\left| T \right| + \left| P \right|}

        Sums run over ``axis`` (default: the spatial axes for this data
        format). Both numerator and denominator get ``smooth`` added. The
        resulting per-sample, per-channel scores are averaged.
        """
        axis = self._reduce_axes(axis)
        intersection = tf.reduce_sum(target * prediction, axis=axis)
        union = tf.reduce_sum(target + prediction, axis=axis)
        numerator = tf.constant(2.) * intersection + smooth
        denominator = union + smooth
        coef = numerator / denominator

        return tf.reduce_mean(coef)

    def dice_coef_loss(self, target, prediction, axis=None, smooth=0.1):
        """Soft Dice loss expressed as -log(Dice).

        Returns ``log(mean(|T| + |P| + smooth)) - log(2 * mean(|T intersect P| + smooth))``.
        The sums run over ``axis`` (default: spatial axes), and the means run
        over the remaining batch/channel axes before the logs are taken. Using
        a difference of logs avoids an explicit division, which can help when
        the sums are very small.
        """
        axis = self._reduce_axes(axis)
        intersection = tf.reduce_sum(prediction * target, axis=axis)
        p = tf.reduce_sum(prediction, axis=axis)
        t = tf.reduce_sum(target, axis=axis)
        numerator = tf.reduce_mean(intersection + smooth)
        denominator = tf.reduce_mean(t + p + smooth)
        dice_loss = -tf.math.log(2.*numerator) + tf.math.log(denominator)

        return dice_loss

    def combined_dice_ce_loss(self, target, prediction, axis=None,
                              smooth=0.1, weight=0.7):
        """Combined loss: ``weight * dice_coef_loss + (1 - weight) * binary cross-entropy``."""
        return weight*self.dice_coef_loss(target, prediction, axis, smooth) + \
            (1-weight)*keras.losses.binary_crossentropy(target, prediction)

    def unet_3d(self):
        """Build the 3D U-Net.

        The encoder has four 2x max-pool levels plus a bottleneck. The decoder
        mirrors it, with skip connections concatenated on the channel axis.
        The output is a 1x1x1 convolution with ``n_cl_out`` sigmoid channels.
        Spatial input sizes are left as ``None``. They presumably need to be
        divisible by 16 for the skip concatenations to line up.
        """
        def ConvolutionBlock(x, name, fms, params):
            """Two back-to-back (Conv3D -> BatchNorm -> ReLU) stages; the last layer is named `name`."""
            # Normalise over the channel axis for either data format (the
            # BatchNormalization default of -1 is wrong for channels_first).
            x = keras.layers.Conv3D(filters=fms, **params, name=name+"_conv0")(x)
            x = keras.layers.BatchNormalization(axis=self.concat_axis, name=name+"_bn0")(x)
            x = keras.layers.Activation("relu", name=name+"_relu0")(x)

            x = keras.layers.Conv3D(filters=fms, **params, name=name+"_conv1")(x)
            x = keras.layers.BatchNormalization(axis=self.concat_axis, name=name+"_bn1")(x)
            x = keras.layers.Activation("relu", name=name)(x)

            return x

        def UpBlock(x, suffix, fms):
            """Double the spatial size of x, by upsampling or by a learned transposed convolution."""
            if self.use_upsampling:
                # UpSampling3D has no `interpolation` argument (nearest-neighbour
                # repeat only). It also needs the data format to pick the spatial axes.
                return keras.layers.UpSampling3D(name="up" + suffix, size=(2, 2, 2),
                                                 data_format=self.data_format)(x)
            return keras.layers.Conv3DTranspose(name="transconv" + suffix, filters=fms,
                                                **params_trans)(x)

        if self.channels_last:
            input_shape = [None, None, None, self.n_cl_in]
        else:
            input_shape = [self.n_cl_in, None, None, None]

        inputs = keras.layers.Input(shape=input_shape, name="MRImages")

        params = dict(kernel_size=(3, 3, 3), activation=None,
                      padding="same", data_format=self.data_format,
                      kernel_initializer="he_uniform")

        # Transposed convolution parameters
        params_trans = dict(data_format=self.data_format,
                            kernel_size=(2, 2, 2), strides=(2, 2, 2),
                            padding="same")

        # BEGIN - Encoding path
        encodeA = ConvolutionBlock(inputs, "encodeA", self.fms, params)
        poolA = keras.layers.MaxPooling3D(name="poolA", pool_size=(2, 2, 2))(encodeA)

        encodeB = ConvolutionBlock(poolA, "encodeB", self.fms*2, params)
        poolB = keras.layers.MaxPooling3D(name="poolB", pool_size=(2, 2, 2))(encodeB)

        encodeC = ConvolutionBlock(poolB, "encodeC", self.fms*4, params)
        poolC = keras.layers.MaxPooling3D(name="poolC", pool_size=(2, 2, 2))(encodeC)

        encodeD = ConvolutionBlock(poolC, "encodeD", self.fms*8, params)
        poolD = keras.layers.MaxPooling3D(name="poolD", pool_size=(2, 2, 2))(encodeD)

        encodeE = ConvolutionBlock(poolD, "encodeE", self.fms*16, params)
        # END - Encoding path

        # BEGIN - Decoding path
        up = UpBlock(encodeE, "E", self.fms*8)
        concatD = keras.layers.concatenate(
            [up, encodeD], axis=self.concat_axis, name="concatD")
        decodeC = ConvolutionBlock(concatD, "decodeC", self.fms*8, params)

        up = UpBlock(decodeC, "C", self.fms*4)
        concatC = keras.layers.concatenate(
            [up, encodeC], axis=self.concat_axis, name="concatC")
        decodeB = ConvolutionBlock(concatC, "decodeB", self.fms*4, params)

        up = UpBlock(decodeB, "B", self.fms*2)
        concatB = keras.layers.concatenate(
            [up, encodeB], axis=self.concat_axis, name="concatB")
        decodeA = ConvolutionBlock(concatB, "decodeA", self.fms*2, params)

        up = UpBlock(decodeA, "A", self.fms)
        concatA = keras.layers.concatenate(
            [up, encodeA], axis=self.concat_axis, name="concatA")
        # END - Decoding path

        convOut = ConvolutionBlock(concatA, "convOut", self.fms, params)

        prediction = keras.layers.Conv3D(name="PredictionMask",
                                         filters=self.n_cl_out, kernel_size=(1, 1, 1),
                                         data_format=self.data_format,
                                         activation="sigmoid")(convOut)

        model = keras.models.Model(inputs=[inputs], outputs=[prediction])

        if self.print_summary:
            model.summary()

        return model

    def sensitivity(self, target, prediction, axis=None, smooth=0.0001):
        """Sensitivity (recall), TP / (TP + FN), on predictions rounded to 0/1.

        Computed per sample and channel over ``axis`` (default: spatial axes),
        then averaged. Assumes a binary ``target``, so sum(target) = TP + FN.
        """
        axis = self._reduce_axes(axis)
        prediction = tf.round(prediction)

        intersection = tf.reduce_sum(prediction * target, axis=axis)
        coef = (intersection + smooth) / (tf.reduce_sum(target, axis=axis) + smooth)
        return tf.reduce_mean(coef)

    def specificity(self, target, prediction, axis=None, smooth=0.0001):
        """Specificity, TN / (TN + FP), on predictions rounded to 0/1.

        Computed per sample and channel over ``axis`` (default: spatial axes),
        then averaged. Assumes a binary ``target``, so sum(1 - target) = TN + FP.
        """
        # The earlier version divided TP by sum(prediction), which is precision,
        # not specificity.
        axis = self._reduce_axes(axis)
        prediction = tf.round(prediction)

        true_negatives = tf.reduce_sum((1. - prediction) * (1. - target), axis=axis)
        actual_negatives = tf.reduce_sum(1. - target, axis=axis)
        coef = (true_negatives + smooth) / (actual_negatives + smooth)
        return tf.reduce_mean(coef)

In [None]:
# Build and compile the network. Every argument is spelled out so the
# configuration is visible in one place.
unet_model = unet(use_upsampling=False,
                  learning_rate=0.01,
                  n_cl_in=1,        # single input channel (greyscale volume)
                  n_cl_out=9,       # nine output channels, one sigmoid mask per label
                  feature_maps=16,
                  dropout=0.2,      # NOTE(review): stored by `unet` but not used by its architecture
                  print_summary=True,
                  channels_last=CHANNELS_LAST)  # channels first or last

unet_model.model.compile(optimizer=unet_model.optimizer,
                         loss=unet_model.loss,
                         metrics=unet_model.metrics)

In [None]:

# TensorBoard: one timestamped run directory per training session.
log_dir = LOG_DIR + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
# tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1, profile_batch = '2, 20')
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)

# File writer for custom image summaries (also used by DisplayCallback).
file_writer = tf.summary.create_file_writer(log_dir)

# tf.summary.image needs a 4-D [k, h, w, c] tensor, but `img` is a 5-D volume
# batch. Log the middle slice along the last spatial axis instead, min-max
# rescaled to [0, 1] so it is visible.
img_array = np.asarray(img, dtype=np.float32)
if CHANNELS_LAST:
    img_slices = img_array[:, :, :, img_array.shape[3] // 2, :]
else:
    img_slices = np.moveaxis(img_array[:, :, :, :, img_array.shape[4] // 2], 1, -1)
value_range = np.ptp(img_slices)
if value_range > 0:
    img_slices = (img_slices - img_slices.min()) / value_range

with file_writer.as_default():
    tf.summary.image("Training data", img_slices, step=0)


# Model checkpoints / resume.
saved_models = sorted(glob.glob(MODELS_DIR + '*.h5'))  # zero-padded names sort by epoch
initial_epoch = 0
if saved_models:
    last_epoch_name = saved_models[-1]
    model = tf.keras.models.load_model(last_epoch_name, custom_objects=unet_model.custom_objects)
    # Swap in the restored model so that the model trained by fit() actually
    # starts from the checkpoint.
    unet_model.model = model
    # Checkpoints are named "<epoch:06d>.h5", so the file stem is the last saved epoch.
    initial_epoch = int(os.path.splitext(os.path.basename(last_epoch_name))[0])
filepath = MODELS_DIR + "{epoch:06d}.h5"
# save_freq counts batches: save once every SAVE_MODEL_PERIOD epochs.
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath, monitor='loss', verbose=1, save_best_only=False,
    save_freq=STEP_PER_EPOCH * SAVE_MODEL_PERIOD)

# csv logger
csv_logger = tf.keras.callbacks.CSVLogger(CSV_LOG_DIR, append=True)


# Keep reducing the learning rate on a plateau. Training has no validation
# data, so monitor the training loss ("val_loss" would never be available).
reduce_lr = keras.callbacks.ReduceLROnPlateau(monitor="loss", factor=0.2,
                                              patience=5, min_lr=0.0001)




In [None]:
def visualize_gt(gt, num_slices=7, num_labels=None):
    """Display evenly spaced slices of each label channel of a label/probability volume.

    Args:
        gt: 4-D array indexed as ``gt[slice, h, w, label]`` (no batch axis,
            channels last).
        num_slices: approximate number of slices shown per label (slices are
            taken every ``gt.shape[0] // num_slices``).
        num_labels: how many label channels to show; defaults to all of them
            (``gt.shape[-1]``). The previous version hard-coded 9.
    """
    # Flip along axis 2 before slicing. The name "coronal" comes from the
    # original code; presumably this matches the viewing orientation used elsewhere.
    coronal_gt = np.flip(gt, axis=2)
    if num_labels is None:
        num_labels = coronal_gt.shape[-1]
    # Guard against a zero step (ValueError in range) when there are fewer
    # slices than requested.
    step = max(1, coronal_gt.shape[0] // num_slices)
    for lbl in range(num_labels):
        slices = [coronal_gt[slc, :, :, lbl] for slc in range(0, coronal_gt.shape[0], step)]
        display(concat_h(slices, mode='L'))

class DisplayCallback(tf.keras.callbacks.Callback):
    """After every epoch, predict on the sample batch `img`, display it and log it to TensorBoard.

    Relies on notebook globals: `img` (the sample batch, assumed channels-last
    here), `file_writer`, and the helpers `visualize_volume`,
    `undo_categorical` and `visualize_gt`.
    """

    def on_epoch_end(self, epoch, logs=None):
        # `self.model` is the model fit() is training. This stays correct even
        # if `unet_model.model` was replaced, e.g. by a reloaded checkpoint.
        out = self.model.predict(img)

        print('\nSample Prediction after epoch {}\n'.format(epoch + 1))
        print('out', out.shape)
        visualize_volume(img[0, :, :, :, 0], undo_categorical(out, channels_last=True), 10)
        visualize_gt(out[0, :, :, :, :])

        # tf.summary.image accepts only 4-D [k, h, w, c] tensors with 1, 3 or 4
        # channels, but `out` is a 5-D multi-channel volume. Log the arg-max
        # label map of the middle slice along the last spatial axis instead,
        # scaled to [0, 1], at this epoch's step.
        label_map = np.argmax(out, axis=-1)
        mid_slice = label_map[:, :, :, label_map.shape[3] // 2]
        scale = max(out.shape[-1] - 1, 1)
        prediction_image = (mid_slice / scale).astype(np.float32)[..., np.newaxis]
        with file_writer.as_default():
            tf.summary.image("Prediction", prediction_image, step=epoch)


In [None]:
# Train. `initial_epoch` is 0 for a fresh run, or the last checkpointed epoch
# when resuming, so the epoch numbering (and checkpoint file names) continue.
# NOTE(review): the config cell defines EPOCHS = 10, but this trains up to epoch 1000 — reconcile.
history = unet_model.model.fit(ds,
                               steps_per_epoch=STEP_PER_EPOCH,
                               epochs=1000,
                               initial_epoch=initial_epoch,
                               callbacks=[tensorboard_callback,
                                          model_checkpoint,
                                          csv_logger,
                                          DisplayCallback(),
                                          reduce_lr])