In [2]:
import numpy as np
from matplotlib import pyplot as plt
import SimpleITK as sitk
import os

In [3]:
def register():
    output_dir = "../HaN-Seg Registered"

    for i in range(1, 43):
        case_num = f"{i:02d}"
        filepath_ct = f"../HaN-Seg/set_1/case_{case_num}/case_{case_num}_IMG_CT.nrrd"
        filepath_mri = f"../HaN-Seg/set_1/case_{case_num}/case_{case_num}_IMG_MR_T1.nrrd"
        ct = sitk.ReadImage(filepath_ct, sitk.sitkFloat32)
        mri = sitk.ReadImage(filepath_mri, sitk.sitkFloat32)
        
        # Registers the MRI image to the CT image
        mri_registered = register_helper(ct, mri)

        # Saves image to the HaN-Seg Registration Folder
        output_path = os.path.join(output_dir, f"MRI_Case_{case_num}.nrrd")
        sitk.WriteImage(mri_registered, output_path)

    return


# Add notation here
def register_helper(fixed_image, moving_image):

    initial_transform = sitk.CenteredTransformInitializer(
        fixed_image,
        moving_image,
        sitk.Euler3DTransform(),
        sitk.CenteredTransformInitializerFilter.GEOMETRY,
    )

    min_value = float(sitk.GetArrayViewFromImage(moving_image).min())
    moving_resampled = sitk.Resample(
        moving_image,
        fixed_image,
        initial_transform,
        sitk.sitkLinear,
        min_value,
        moving_image.GetPixelID(),
    )

    registration_method = sitk.ImageRegistrationMethod()

    # Similarity metric settings.
    registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
    registration_method.SetMetricSamplingPercentage(0.01)

    registration_method.SetInterpolator(sitk.sitkLinear)

    # Optimizer settings.
    registration_method.SetOptimizerAsGradientDescent(
        learningRate=1.0,
        numberOfIterations=100,
        convergenceMinimumValue=1e-6,
        convergenceWindowSize=10,
    )
    registration_method.SetOptimizerScalesFromPhysicalShift()

    # Setup for the multi-resolution framework.
    registration_method.SetShrinkFactorsPerLevel(shrinkFactors=[4, 2, 1])
    registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 1, 0])
    registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

    # Don't optimize in-place, we would possibly like to run this cell multiple times.
    registration_method.SetInitialTransform(initial_transform, inPlace=False)

    final_transform = registration_method.Execute(
        sitk.Cast(fixed_image, sitk.sitkFloat32), sitk.Cast(moving_image, sitk.sitkFloat32)
    )

    moving_resampled = sitk.Resample(
        moving_image,
        fixed_image,
        final_transform,
        sitk.sitkLinear,
        min_value,
        moving_image.GetPixelID(),
    )
    return moving_resampled

In [6]:
register()

In [4]:
def pruneVolume():
    output_dir = "../HaN-Seg Pruned"
    for i in range(1, 43):
        case_num = f"{i:02d}"
        filepath_mri = f"../HaN-Seg Registered/MRI_Case_{case_num}.nrrd"
        mri = sitk.ReadImage(filepath_mri, sitk.sitkFloat32)
        mri = sitk.GetArrayFromImage(mri).astype(np.float32)

        filepath_ct = f"../HaN-Seg/set_1/case_{case_num}/case_{case_num}_IMG_CT.nrrd"
        ct = sitk.ReadImage(filepath_ct, sitk.sitkFloat32)
        ct = sitk.GetArrayFromImage(ct).astype(np.float32)

        filepath = f"../HaN-Seg/set_1/case_{case_num}/case_{case_num}_OAR_Bone_Mandible.seg.nrrd"
        gt = sitk.ReadImage(filepath, sitk.sitkFloat32)
        gt = sitk.GetArrayFromImage(gt).astype(np.float32)

        gt_var = np.var(gt, axis = (1, 2))
        # mask = gt_var > 1e-10
        mask = gt_var > 0.0001
        mri = mri[mask]
        mri = sitk.GetImageFromArray(mri)
        ct = ct[mask]
        ct = sitk.GetImageFromArray(ct)
        gt = gt[mask]
        gt = sitk.GetImageFromArray(gt)

        output_path = os.path.join(output_dir, "MRI", f"MRI_Case_{case_num}.nrrd")
        sitk.WriteImage(mri, output_path)

        output_path = os.path.join(output_dir, "CT", f"CT_Case_{case_num}.nrrd")
        sitk.WriteImage(ct, output_path)

        output_path = os.path.join(output_dir, "GT", f"GT_Case_{case_num}.nrrd")
        sitk.WriteImage(gt, output_path)

In [5]:
pruneVolume()

In [6]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers, models
from sklearn.model_selection import train_test_split

BATCH_SIZE = 16
BUFFER_SIZE = 1000
RANDOM_SEED = 42
IMAGE_RESIZE = 256

In [7]:
def normalize(img_path, mean, std, output):
    for i in range(len(img_path)):
        img = sitk.ReadImage(img_path[i], sitk.sitkFloat32)
        img = sitk.GetArrayFromImage(img).astype(np.float32)

        normalized = (img - mean) / std
        normalized = sitk.GetImageFromArray(normalized)

        output_path = output + f"{i}.nrrd"
        sitk.WriteImage(normalized, output_path)

In [8]:
def computeNormal(img_path):
    img_sum = 0.0
    img_sq_sum = 0.0
    total = 0

    for path in img_path:
        img = sitk.ReadImage(path, sitk.sitkFloat32)
        img = sitk.GetArrayFromImage(img).astype(np.float32)
        img_sum += np.sum(img)
        img_sq_sum += np.sum(img ** 2)
        total += img.size

    train_mean = img_sum/total
    train_std = np.sqrt(img_sq_sum/total - (train_mean**2))

    return train_mean, train_std       


In [None]:
def load_images(mri_path, ct_path, mask_path):
    mri_path = mri_path.numpy().decode("utf-8")
    ct_path = ct_path.numpy().decode("utf-8")
    mask_path = mask_path.numpy().decode("utf-8")

    mri = sitk.ReadImage(mri_path, sitk.sitkFloat32)
    ct = sitk.ReadImage(ct_path, sitk.sitkFloat32)
    gt = sitk.ReadImage(mask_path, sitk.sitkFloat32)

    slices_mri = sitk.GetArrayFromImage(mri).astype(np.float32)
    slices_ct = sitk.GetArrayFromImage(ct).astype(np.float32)
    slices_gt = sitk.GetArrayFromImage(gt).astype(np.float32)

    slices_gt = (slices_gt > 0.5).astype(np.float32)

    slices_mri_resize = np.expand_dims(slices_mri, axis=-1)
    slices_ct_resize = np.expand_dims(slices_ct, axis=-1)
    slices_gt_resize = np.expand_dims(slices_gt, axis=-1)

    slices_mri_resize = tf.image.resize(slices_mri_resize, [IMAGE_RESIZE, IMAGE_RESIZE], method=tf.image.ResizeMethod.BILINEAR)
    slices_ct_resize = tf.image.resize(slices_ct_resize, [IMAGE_RESIZE, IMAGE_RESIZE], method=tf.image.ResizeMethod.BILINEAR)
    slices_gt_resize = tf.image.resize(slices_gt_resize, [IMAGE_RESIZE, IMAGE_RESIZE], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    slices = np.concatenate([slices_mri_resize, slices_ct_resize], axis=-1)

    return slices, slices_gt_resize

def lammy_func(mri_path, ct_path, mask_path):
    img, mask = tf.py_function(
        func=load_images,
        inp=[mri_path, ct_path, mask_path],
        Tout=[tf.float32, tf.float32]
    )
    img.set_shape([None, IMAGE_RESIZE, IMAGE_RESIZE, 2])
    mask.set_shape([None, IMAGE_RESIZE, IMAGE_RESIZE, 1])
    return img, mask

In [18]:
mri_path = []
ct_path = []
gt_path = []

# mri_train_mean, mri_train_sd, ct_train_mean, ct_train_sd = normalize()

for i in range(1, 43):
    case_num = f"{i:02d}"
    filepath_mri = f"../HaN-Seg Pruned/MRI/MRI_Case_{case_num}.nrrd"
    mri_path.append(filepath_mri)
    filepath_ct = f"../HaN-Seg Pruned/CT/CT_Case_{case_num}.nrrd"
    ct_path.append(filepath_ct)
    filepath_gt = f"../HaN-Seg Pruned/GT/GT_Case_{case_num}.nrrd"
    gt_path.append(filepath_gt)

mri_train, mri_test, ct_train, ct_test, gt_train, gt_test = train_test_split(mri_path, ct_path, gt_path, test_size=0.2, random_state=RANDOM_SEED)

mri_mean, mri_std = computeNormal(mri_train)
ct_mean, ct_std = computeNormal(ct_train)

normalize(mri_train, mri_mean, mri_std, "../HaN-Seg Pruned/Train/MRI/")
normalize(ct_train, ct_mean, ct_std, "../HaN-Seg Pruned/Train/CT/")

mri_norm_train = []
ct_norm_train = []

for i in range(len(mri_train)):
    mri_norm_train.append(f"../HaN-Seg Pruned/Train/MRI/{i}.nrrd")
    ct_norm_train.append(f"../HaN-Seg Pruned/Train/CT/{i}.nrrd")

train_dataset = tf.data.Dataset.from_tensor_slices((mri_norm_train, ct_norm_train, gt_train))

train_dataset = (
    train_dataset
    .map(lammy_func, num_parallel_calls=tf.data.AUTOTUNE)
    .unbatch()
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
    )

normalize(mri_test, mri_mean, mri_std, "../HaN-Seg Pruned/Test/MRI/")
normalize(ct_test, ct_mean, ct_std, "../HaN-Seg Pruned/Test/CT/")

mri_norm_test = []
ct_norm_test = []

for i in range(len(mri_test)):
    mri_norm_test.append(f"../HaN-Seg Pruned/Test/MRI/{i}.nrrd")
    ct_norm_test.append(f"../HaN-Seg Pruned/Test/CT/{i}.nrrd")

test_dataset = tf.data.Dataset.from_tensor_slices((mri_norm_test, ct_norm_test, gt_test))

test_dataset = (
    test_dataset
    .map(lammy_func, num_parallel_calls=tf.data.AUTOTUNE)
    .unbatch()
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
    )

In [19]:
## A single Encoding step in the Contracting path of a U-Net CNN
## @Inputs:
##       inputs: image of size (nxn) with k feature channels
##       num_channels: number of channels to have in output image (i.e. depth of output tensor)
## @Outputs: 
##       x: image of size (n/2 x n/2) with num_channels feature channels
def encode_block(inputs, num_channels):
    # Extract num_channels feature channels from image
    x = tf.keras.layers.Conv2D(num_channels, 3, padding='same')(inputs)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Activation('relu')(x)

    x = tf.keras.layers.Conv2D(num_channels, 3, padding='same')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    skip = tf.keras.layers.Activation('relu')(x)

    # Downsample each channels feature map by a factor of 2
    x = tf.keras.layers.MaxPool2D(pool_size=(2,2), strides=2)(skip)

    return skip, x

## A single Decoding step in the Expanding path of a U-Net CNN
## @Inputs:
##       inputs: image of size (nxn) with k feature channels
##       skip_connection: tensor of corresponding encoding block
##       num_channels: number of channels to have in output image (i.e. depth of output tensor)
## @Outputs: 
##       x: image of size (2nx2n) with num_channels feature channels
def decode_block(inputs, skip_connection, num_channels):
    # Upsample image by doubling feature space while changing feature channels to num_channels
    x = tf.keras.layers.Conv2DTranspose(num_channels, (2,2), strides=2, padding='same')(inputs)

    # Concatenate the skip_channel and the upsampled image (doubles the feature channels)
    # Might need to resize skip_connection, but should be fine b/c same padding in encoding
    x = tf.keras.layers.Concatenate()([x, skip_connection])
    
    # Merge feature channels from the skip_connection and upsampled input image
    x = tf.keras.layers.Conv2D(num_channels, 3, padding='same')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Activation('relu')(x)

    x = tf.keras.layers.Conv2D(num_channels, 3, padding='same')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Activation('relu')(x)

    return x

# Metrics
# Source: https://medium.com/mastering-data-science/understanding-evaluation-metrics-in-medical-image-segmentation-d289a373a3f
def dice_coeff(y_true, y_pred):
    epsilon = 1e-6
    threshold = 0.5
    y_pred = tf.cast(y_pred > threshold, tf.float32)
    # y_pred = tf.cast(y_pred, tf.float32)
    tp = tf.reduce_sum(tf.cast(y_true * y_pred, tf.float32))
    fp = tf.reduce_sum(tf.cast((1 - y_true) * y_pred, tf.float32))
    fn = tf.reduce_sum(tf.cast(y_true * (1 - y_pred), tf.float32))
    return (2*tp+epsilon)/(2*tp + fp + fn+epsilon)

def rand_index(y_true, y_pred):
    y_pred = tf.cast(y_pred, tf.float32)
    tp = tf.reduce_sum(tf.cast(y_true * y_pred, tf.float32))
    fp = tf.reduce_sum(tf.cast((1 - y_true) * y_pred, tf.float32))
    fn = tf.reduce_sum(tf.cast(y_true * (1 - y_pred), tf.float32))
    return tp / (tp + fp + fn)

def jaccard_index(y_true, y_pred):
    y_pred = tf.cast(y_pred, tf.float32)
    tp = tf.reduce_sum(tf.cast(y_true * y_pred, tf.float32))
    fp = tf.reduce_sum(tf.cast((1 - y_true) * y_pred, tf.float32))
    fn = tf.reduce_sum(tf.cast(y_true * (1 - y_pred), tf.float32))
    tn = tf.reduce_sum(tf.cast((1 - y_true) * (1 - y_pred), tf.float32))
    return (tp + tn) / (tp + tn + fn + fp)

def specificity(y_true, y_pred):
    y_pred = tf.cast(y_pred, tf.float32)
    true_negatives = tf.reduce_sum(tf.cast((1 - y_true) * (1 - y_pred), tf.float32))
    possible_negatives = tf.reduce_sum(tf.cast(1 - y_true, tf.float32))
    return true_negatives / (possible_negatives + tf.keras.backend.epsilon())

In [20]:
def custom_loss1(y_true, y_pred):
    dice_weight = 50.0
    bce_weight = 1.0
    bce = tf.keras.losses.BinaryFocalCrossentropy(alpha=0.75)
    loss_dice = 1.0 - dice_coeff(y_true, y_pred)
    loss_bce = bce(y_true, y_pred)
    loss = dice_weight * loss_dice + bce_weight * loss_bce
    return loss

def custom_loss2(y_true, y_pred):
    dice_weight = 50.0
    bce_weight = 1.0
    bce = tf.keras.losses.BinaryFocalCrossentropy(alpha=0.25)
    loss_dice = 1.0 - dice_coeff(y_true, y_pred)
    loss_bce = bce(y_true, y_pred)
    loss = dice_weight * loss_dice + bce_weight * loss_bce
    return loss

In [21]:
## Defining the model

input = tf.keras.layers.Input(shape=(IMAGE_RESIZE, IMAGE_RESIZE, 2))

# Do 5 calls of encode_block to end up with a 32x32x512 tensor
s1, e1 = encode_block(input, 32)
s2, e2 = encode_block(e1, 64)
s3, e3 = encode_block(e2, 128)
s4, e4 = encode_block(e3, 256)
# s5, e5 = encode_block(e4, 512)

# Bottleneck
# b1 = tf.keras.layers.Conv2D(1024, 3, padding='same')(e5)
# b1 = tf.keras.layers.Activation('relu')(b1)
# b1 = tf.keras.layers.Conv2D(1024, 3, padding='same')(b1)
# b1 = tf.keras.layers.Activation('relu')(b1)

b1 = tf.keras.layers.Conv2D(512, 3, padding='same')(e4)
b1 = tf.keras.layers.Activation('relu')(b1)
b1 = tf.keras.layers.Conv2D(512, 3, padding='same')(b1)
b1 = tf.keras.layers.Activation('relu')(b1)

# Do 5 calls of decode_block
# d1 = decode_block(b1, s5, 512)
d2 = decode_block(b1, s4, 256)
d3 = decode_block(d2, s3, 128)
d4 = decode_block(d3, s2, 64)
d5 = decode_block(d4, s1, 32)

# Play around with activation
output = tf.keras.layers.Conv2D(1, 1, padding='same', activation='sigmoid')(d5)


In [None]:
EPOCH = 150
LEARNING_RATE = 1e-4

#######################
# TODO BEFORE NEXT RUN:
# RENAME THE CSV
# SAVE MODELS
#######################

lr_reducer = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_dice_coeff',
    factor = 0.5,
    patience = 5,
    mode='max',
    verbose = 1
)

early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_dice_coeff',
    patience = 10,
    verbose = 1,
    mode='max',
    restore_best_weights=True
)

model1 = tf.keras.models.Model(inputs=input, outputs=output, name='U-Net')

model1.compile(
    optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    loss=custom_loss1,
    metrics=['accuracy', dice_coeff, specificity, rand_index, jaccard_index])

csv_logger = tf.keras.callbacks.CSVLogger('model1_150epoch.csv', append=False)

history1 = model1.fit(
    train_dataset, 
    validation_data=test_dataset,
    epochs = EPOCH,
    callbacks=[csv_logger, lr_reducer, early_stop]
    )

model2 = tf.keras.models.Model(inputs=input, outputs=output, name='U-Net')

model2.compile(
    optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    loss=custom_loss2,
    metrics=['accuracy', dice_coeff, specificity, rand_index, jaccard_index])

csv_logger = tf.keras.callbacks.CSVLogger('model2_150epoch.csv', append=False)

history2 = model2.fit(
    train_dataset, 
    validation_data=test_dataset,
    epochs = EPOCH,
    callbacks=[csv_logger, lr_reducer, early_stop]
    )

model1.save("model1_good_150epoch.keras")
model2.save("model2_good_150epoch.keras")

Epoch 1/20
     75/Unknown [1m145s[0m 2s/step - accuracy: 0.9871 - dice_coeff: 0.1926 - jaccard_index: 0.6633 - loss: 40.4197 - rand_index: 0.0031 - specificity: 0.6636



[1m75/75[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m156s[0m 2s/step - accuracy: 0.9918 - dice_coeff: 0.2489 - jaccard_index: 0.6783 - loss: 37.3363 - rand_index: 0.0032 - specificity: 0.6787 - val_accuracy: 0.9843 - val_dice_coeff: 0.1111 - val_jaccard_index: 0.5677 - val_loss: 44.0321 - val_rand_index: 0.0028 - val_specificity: 0.5680 - learning_rate: 1.0000e-04
Epoch 2/20
[1m75/75[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m152s[0m 2s/step - accuracy: 0.9976 - dice_coeff: 0.3318 - jaccard_index: 0.7514 - loss: 33.0875 - rand_index: 0.0038 - specificity: 0.7520 - val_accuracy: 0.9860 - val_dice_coeff: 0.1760 - val_jaccard_index: 0.6367 - val_loss: 40.4123 - val_rand_index: 0.0041 - val_specificity: 0.6369 - learning_rate: 1.0000e-04
Epoch 3/20
[1m75/75[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m151s[0m 2s/step - accuracy: 0.9980 - dice_coeff: 0.3527 - jaccard_index: 0.7815 - loss: 32.0164 - rand_index: 0.0043 - specificity: 0.7822 - val_accuracy: 0.9937 - val_dice_coe

In [None]:
model1.save("model1_good_150epoch.keras")
model2.save("model2_good_150epoch.keras")