In [1]:
import numpy as np
from matplotlib import pyplot as plt
import SimpleITK as sitk
import os

In [5]:
def register():
    output_dir = "../HaN-Seg Registered"

    for i in range(1, 43):
        case_num = f"{i:02d}"
        filepath_ct = f"../HaN-Seg/set_1/case_{case_num}/case_{case_num}_IMG_CT.nrrd"
        filepath_mri = f"../HaN-Seg/set_1/case_{case_num}/case_{case_num}_IMG_MR_T1.nrrd"
        ct = sitk.ReadImage(filepath_ct, sitk.sitkFloat32)
        mri = sitk.ReadImage(filepath_mri, sitk.sitkFloat32)
        
        # Registers the MRI image to the CT image
        mri_registered = register_helper(ct, mri)

        # Saves image to the HaN-Seg Registration Folder
        output_path = os.path.join(output_dir, f"MRI_Case_{case_num}.nrrd")
        sitk.WriteImage(mri_registered, output_path)

    return


# Add notation here
def register_helper(fixed_image, moving_image):

    initial_transform = sitk.CenteredTransformInitializer(
        fixed_image,
        moving_image,
        sitk.Euler3DTransform(),
        sitk.CenteredTransformInitializerFilter.GEOMETRY,
    )

    min_value = float(sitk.GetArrayViewFromImage(moving_image).min())
    moving_resampled = sitk.Resample(
        moving_image,
        fixed_image,
        initial_transform,
        sitk.sitkLinear,
        min_value,
        moving_image.GetPixelID(),
    )

    registration_method = sitk.ImageRegistrationMethod()

    # Similarity metric settings.
    registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
    registration_method.SetMetricSamplingPercentage(0.01)

    registration_method.SetInterpolator(sitk.sitkLinear)

    # Optimizer settings.
    registration_method.SetOptimizerAsGradientDescent(
        learningRate=1.0,
        numberOfIterations=100,
        convergenceMinimumValue=1e-6,
        convergenceWindowSize=10,
    )
    registration_method.SetOptimizerScalesFromPhysicalShift()

    # Setup for the multi-resolution framework.
    registration_method.SetShrinkFactorsPerLevel(shrinkFactors=[4, 2, 1])
    registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 1, 0])
    registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

    # Don't optimize in-place, we would possibly like to run this cell multiple times.
    registration_method.SetInitialTransform(initial_transform, inPlace=False)

    final_transform = registration_method.Execute(
        sitk.Cast(fixed_image, sitk.sitkFloat32), sitk.Cast(moving_image, sitk.sitkFloat32)
    )

    moving_resampled = sitk.Resample(
        moving_image,
        fixed_image,
        final_transform,
        sitk.sitkLinear,
        min_value,
        moving_image.GetPixelID(),
    )
    return moving_resampled

In [6]:
register()

In [30]:
def pruneVolume():
    output_dir = "../HaN-Seg Pruned"
    for i in range(1, 43):
        case_num = f"{i:02d}"
        filepath_mri = f"../HaN-Seg Registered/MRI_Case_{case_num}.nrrd"
        mri = sitk.ReadImage(filepath_mri, sitk.sitkFloat32)
        mri = sitk.GetArrayFromImage(mri).astype(np.float32)

        filepath_ct = f"../HaN-Seg/set_1/case_{case_num}/case_{case_num}_IMG_CT.nrrd"
        ct = sitk.ReadImage(filepath_ct, sitk.sitkFloat32)
        ct = sitk.GetArrayFromImage(ct).astype(np.float32)

        filepath = f"../HaN-Seg/set_1/case_{case_num}/case_{case_num}_OAR_Bone_Mandible.seg.nrrd"
        gt = sitk.ReadImage(filepath, sitk.sitkFloat32)
        gt = sitk.GetArrayFromImage(gt).astype(np.float32)

        gt_var = np.var(gt, axis = (1, 2))
        # mask = gt_var > 1e-10
        mask = gt_var > 0.0001
        mri = mri[mask]
        mri = sitk.GetImageFromArray(mri)
        ct = ct[mask]
        ct = sitk.GetImageFromArray(ct)
        gt = gt[mask]
        gt = sitk.GetImageFromArray(gt)

        output_path = os.path.join(output_dir, "MRI", f"MRI_Case_{case_num}.nrrd")
        sitk.WriteImage(mri, output_path)

        output_path = os.path.join(output_dir, "CT", f"CT_Case_{case_num}.nrrd")
        sitk.WriteImage(ct, output_path)

        output_path = os.path.join(output_dir, "GT", f"GT_Case_{case_num}.nrrd")
        sitk.WriteImage(gt, output_path)

In [31]:
pruneVolume()

In [None]:
def histEqual():
    output_dir = "../HaN-Seg Hist Eq"
    for i in range(1, 43):
        case_num = f"{i:02d}"
        filepath_mri = f"../HaN-Seg Pruned/MRI/MRI_Case_{case_num}.nrrd"
        mri = sitk.ReadImage(filepath_mri, sitk.sitkFloat32)
        mri_arr = sitk.GetArrayFromImage(mri)
        np.linspace(min(mri_arr), max(mri_arr), )

        matcher = sitk.HistogramMatchingImageFilter()
        matcher.SetNumberOfHistogramLevels(256)
        matcher.SetNumberOfMatchPoints(256)
        
                
        mri = matcher.Execute(mri, mri)

        output_path = os.path.join(output_dir, f"MRI_Case_{case_num}.nrrd")
        sitk.WriteImage(mri, output_path)

In [97]:
histEqual()

In [32]:
# Check what you're actually working with
for i in range(1, 43):
    case_num = f"{i:02d}"
    filepath = f"../HaN-Seg/set_1/case_{case_num}/case_{case_num}_OAR_Bone_Mandible.seg.nrrd"
    gt = sitk.ReadImage(filepath, sitk.sitkFloat32)
    gt = sitk.GetArrayFromImage(gt).astype(np.float32)
    
    total_slices = gt.shape[0]
    slices_with_mandible = np.sum(np.sum(gt, axis=(1,2)) > 0)
    total_mandible_pixels = np.sum(gt)
    
    print(f"Case {case_num}:")
    print(f"  Total slices: {total_slices}")
    print(f"  Slices with mandible: {slices_with_mandible}")
    print(f"  Mandible pixel ratio: {total_mandible_pixels / gt.size * 100:.4f}%")
    print(f"  Variance range: {np.var(gt, axis=(1,2)).min():.6f} - {np.var(gt, axis=(1,2)).max():.6f}")
    print()

Case 01:
  Total slices: 202
  Slices with mandible: 43
  Mandible pixel ratio: 0.0620%
  Variance range: 0.000000 - 0.007484

Case 02:
  Total slices: 204
  Slices with mandible: 43
  Mandible pixel ratio: 0.0338%
  Variance range: 0.000000 - 0.003865

Case 03:
  Total slices: 194
  Slices with mandible: 42
  Mandible pixel ratio: 0.0308%
  Variance range: 0.000000 - 0.003596

Case 04:
  Total slices: 184
  Slices with mandible: 44
  Mandible pixel ratio: 0.0292%
  Variance range: 0.000000 - 0.003253

Case 05:
  Total slices: 182
  Slices with mandible: 34
  Mandible pixel ratio: 0.0444%
  Variance range: 0.000000 - 0.004947

Case 06:
  Total slices: 141
  Slices with mandible: 29
  Mandible pixel ratio: 0.0322%
  Variance range: 0.000000 - 0.003588

Case 07:
  Total slices: 132
  Slices with mandible: 24
  Mandible pixel ratio: 0.0473%
  Variance range: 0.000000 - 0.006957

Case 08:
  Total slices: 135
  Slices with mandible: 28
  Mandible pixel ratio: 0.0581%
  Variance range: 0.000

In [52]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers, models
from sklearn.model_selection import train_test_split

BATCH_SIZE = 16
BUFFER_SIZE = 1000
RANDOM_SEED = 42
IMAGE_RESIZE = 256

In [35]:
def normalize(img_path, mean, std, output):
    for i in range(len(img_path)):
        img = sitk.ReadImage(img_path[i], sitk.sitkFloat32)
        img = sitk.GetArrayFromImage(img).astype(np.float32)

        normalized = (img - mean) / std
        normalized = sitk.GetImageFromArray(normalized)

        output_path = output + f"{i}.nrrd"
        sitk.WriteImage(normalized, output_path)

In [53]:
def computeNormal(img_path):
    img_sum = 0.0
    img_sq_sum = 0.0
    total = 0

    for path in img_path:
        img = sitk.ReadImage(path, sitk.sitkFloat32)
        img = sitk.GetArrayFromImage(img).astype(np.float32)
        img_sum += np.sum(img)
        img_sq_sum += np.sum(img ** 2)
        total += img.size

    train_mean = img_sum/total
    train_std = np.sqrt(img_sq_sum/total - (train_mean**2))

    return train_mean, train_std       


In [37]:
def load_images(mri_path, ct_path, mask_path):
    mri_path = mri_path.numpy().decode("utf-8")
    ct_path = ct_path.numpy().decode("utf-8")
    mask_path = mask_path.numpy().decode("utf-8")

    mri = sitk.ReadImage(mri_path, sitk.sitkFloat32)
    ct = sitk.ReadImage(ct_path, sitk.sitkFloat32)
    gt = sitk.ReadImage(mask_path, sitk.sitkFloat32)

    slices_mri = sitk.GetArrayFromImage(mri).astype(np.float32)
    slices_ct = sitk.GetArrayFromImage(ct).astype(np.float32)
    slices_gt = sitk.GetArrayFromImage(gt).astype(np.float32)

    slices_mri_resize = np.expand_dims(slices_mri, axis=-1)
    slices_ct_resize = np.expand_dims(slices_ct, axis=-1)
    slices_gt_resize = np.expand_dims(slices_gt, axis=-1)

    slices_mri_resize = tf.image.resize(slices_mri_resize, [IMAGE_RESIZE, IMAGE_RESIZE], method=tf.image.ResizeMethod.BILINEAR)
    slices_ct_resize = tf.image.resize(slices_ct_resize, [IMAGE_RESIZE, IMAGE_RESIZE], method=tf.image.ResizeMethod.BILINEAR)
    slices_gt_resize = tf.image.resize(slices_gt_resize, [IMAGE_RESIZE, IMAGE_RESIZE], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    slices = np.concatenate([slices_mri_resize, slices_ct_resize], axis=-1)

    return slices, slices_gt_resize

def lammy_func(mri_path, ct_path, mask_path):
    img, mask = tf.py_function(
        func=load_images,
        inp=[mri_path, ct_path, mask_path],
        Tout=[tf.float32, tf.float32]
    )
    img.set_shape([None, IMAGE_RESIZE, IMAGE_RESIZE, 2])
    mask.set_shape([None, IMAGE_RESIZE, IMAGE_RESIZE, 1])
    return img, mask

In [65]:
mri_path = []
ct_path = []
gt_path = []

# mri_train_mean, mri_train_sd, ct_train_mean, ct_train_sd = normalize()

for i in range(1, 43):
    case_num = f"{i:02d}"
    filepath_mri = f"../HaN-Seg Pruned/MRI/MRI_Case_{case_num}.nrrd"
    mri_path.append(filepath_mri)
    filepath_ct = f"../HaN-Seg Pruned/CT/CT_Case_{case_num}.nrrd"
    ct_path.append(filepath_ct)
    filepath_gt = f"../HaN-Seg Pruned/GT/GT_Case_{case_num}.nrrd"
    gt_path.append(filepath_gt)

mri_train, mri_test, ct_train, ct_test, gt_train, gt_test = train_test_split(mri_path, ct_path, gt_path, test_size=0.2, random_state=RANDOM_SEED)

#################################################
# Use only first 5-10 cases for quick experiments
mri_train = mri_train[:10]
ct_train = ct_train[:10]
gt_train = gt_train[:10]
#################################################

mri_mean, mri_std = computeNormal(mri_train)
ct_mean, ct_std = computeNormal(ct_train)

train_dataset = tf.data.Dataset.from_tensor_slices((mri_train, ct_train, gt_train))

normalize(mri_train, mri_mean, mri_std, "../HaN-Seg Pruned/Train/MRI/")
normalize(ct_train, ct_mean, ct_std, "../HaN-Seg Pruned/Train/CT/")

train_dataset = (
    train_dataset
    .map(lammy_func, num_parallel_calls=tf.data.AUTOTUNE)
    .unbatch()
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
    )

normalize(mri_test, mri_mean, mri_std, "../HaN-Seg Pruned/Test/MRI/")
normalize(ct_test, ct_mean, ct_std, "../HaN-Seg Pruned/Test/CT/")

test_dataset = tf.data.Dataset.from_tensor_slices((mri_test, ct_test, gt_test))

test_dataset = (
    test_dataset
    .map(lammy_func, num_parallel_calls=tf.data.AUTOTUNE)
    .unbatch()
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
    )

In [73]:
## A single Encoding step in the Contracting path of a U-Net CNN
## @Inputs:
##       inputs: image of size (nxn) with k feature channels
##       num_channels: number of channels to have in output image (i.e. depth of output tensor)
## @Outputs: 
##       x: image of size (n/2 x n/2) with num_channels feature channels
def encode_block(inputs, num_channels):
    # Extract num_channels feature channels from image
    x = tf.keras.layers.Conv2D(num_channels, 3, padding='same')(inputs)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Activation('relu')(x)

    x = tf.keras.layers.Conv2D(num_channels, 3, padding='same')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    skip = tf.keras.layers.Activation('relu')(x)

    # Downsample each channels feature map by a factor of 2
    x = tf.keras.layers.MaxPool2D(pool_size=(2,2), strides=2)(skip)

    return skip, x

## A single Decoding step in the Expanding path of a U-Net CNN
## @Inputs:
##       inputs: image of size (nxn) with k feature channels
##       skip_connection: tensor of corresponding encoding block
##       num_channels: number of channels to have in output image (i.e. depth of output tensor)
## @Outputs: 
##       x: image of size (2nx2n) with num_channels feature channels
def decode_block(inputs, skip_connection, num_channels):
    # Upsample image by doubling feature space while changing feature channels to num_channels
    x = tf.keras.layers.Conv2DTranspose(num_channels, (2,2), strides=2, padding='same')(inputs)

    # Concatenate the skip_channel and the upsampled image (doubles the feature channels)
    # Might need to resize skip_connection, but should be fine b/c same padding in encoding
    x = tf.keras.layers.Concatenate()([x, skip_connection])
    
    # Merge feature channels from the skip_connection and upsampled input image
    x = tf.keras.layers.Conv2D(num_channels, 3, padding='same')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Activation('relu')(x)

    x = tf.keras.layers.Conv2D(num_channels, 3, padding='same')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Activation('relu')(x)

    return x

# Metrics
# Source: https://medium.com/mastering-data-science/understanding-evaluation-metrics-in-medical-image-segmentation-d289a373a3f
def dice_coeff(y_true, y_pred):
    epsilon = 1e-6
    # threshold = 0.9
    # y_pred = tf.cast(y_pred > threshold, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    tp = tf.reduce_sum(tf.cast(y_true * y_pred, tf.float32))
    fp = tf.reduce_sum(tf.cast((1 - y_true) * y_pred, tf.float32))
    fn = tf.reduce_sum(tf.cast(y_true * (1 - y_pred), tf.float32))
    return (2*tp+epsilon)/(2*tp + fp + fn+epsilon)

def rand_index(y_true, y_pred):
    y_pred = tf.cast(y_pred, tf.float32)
    tp = tf.reduce_sum(tf.cast(y_true * y_pred, tf.float32))
    fp = tf.reduce_sum(tf.cast((1 - y_true) * y_pred, tf.float32))
    fn = tf.reduce_sum(tf.cast(y_true * (1 - y_pred), tf.float32))
    return tp / (tp + fp + fn)

def jaccard_index(y_true, y_pred):
    y_pred = tf.cast(y_pred, tf.float32)
    tp = tf.reduce_sum(tf.cast(y_true * y_pred, tf.float32))
    fp = tf.reduce_sum(tf.cast((1 - y_true) * y_pred, tf.float32))
    fn = tf.reduce_sum(tf.cast(y_true * (1 - y_pred), tf.float32))
    tn = tf.reduce_sum(tf.cast((1 - y_true) * (1 - y_pred), tf.float32))
    return (tp + tn) / (tp + tn + fn + fp)

def specificity(y_true, y_pred):
    y_pred = tf.cast(y_pred, tf.float32)
    true_negatives = tf.reduce_sum(tf.cast((1 - y_true) * (1 - y_pred), tf.float32))
    possible_negatives = tf.reduce_sum(tf.cast(1 - y_true, tf.float32))
    return true_negatives / (possible_negatives + tf.keras.backend.epsilon())

In [74]:
def custom_loss(y_true, y_pred):
    dice_weight = 10.0
    bce_weight = 1.0
    bce = tf.keras.losses.BinaryFocalCrossentropy(alpha=0.9)
    loss_dice = 1.0 - dice_coeff(y_true, y_pred)
    loss_bce = bce(y_true, y_pred)
    loss = dice_weight * loss_dice + bce_weight * loss_bce
    return loss

In [75]:
## Defining the model

input = tf.keras.layers.Input(shape=(IMAGE_RESIZE, IMAGE_RESIZE, 2))

# Do 5 calls of encode_block to end up with a 32x32x512 tensor
s1, e1 = encode_block(input, 32)
s2, e2 = encode_block(e1, 64)
s3, e3 = encode_block(e2, 128)
s4, e4 = encode_block(e3, 256)
# s5, e5 = encode_block(e4, 512)

# Bottleneck
# b1 = tf.keras.layers.Conv2D(1024, 3, padding='same')(e5)
# b1 = tf.keras.layers.Activation('relu')(b1)
# b1 = tf.keras.layers.Conv2D(1024, 3, padding='same')(b1)
# b1 = tf.keras.layers.Activation('relu')(b1)

b1 = tf.keras.layers.Conv2D(512, 3, padding='same')(e4)
b1 = tf.keras.layers.Activation('relu')(b1)
b1 = tf.keras.layers.Conv2D(512, 3, padding='same')(b1)
b1 = tf.keras.layers.Activation('relu')(b1)

# Do 5 calls of decode_block
# d1 = decode_block(b1, s5, 512)
d2 = decode_block(b1, s4, 256)
d3 = decode_block(d2, s3, 128)
d4 = decode_block(d3, s2, 64)
d5 = decode_block(d4, s1, 32)

# Play around with activation
output = tf.keras.layers.Conv2D(1, 1, padding='same', activation='sigmoid')(d5)

model = tf.keras.models.Model(inputs=input, outputs=output, name='U-Net')

model.compile(
    optimizer = 'adam',
    loss=custom_loss,
    metrics=['accuracy', dice_coeff, specificity, rand_index, jaccard_index]
)

In [76]:
EPOCH = 10
model.fit(
    train_dataset, 
    validation_data=test_dataset,
    epochs = EPOCH
    )

Epoch 1/10
[1m22/22[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m55s[0m 2s/step - accuracy: 0.9268 - dice_coeff: 0.0092 - jaccard_index: 0.6206 - loss: 9.9916 - rand_index: 0.0046 - specificity: 0.6202 - val_accuracy: 0.9946 - val_dice_coeff: 2.3035e-10 - val_jaccard_index: 0.9448 - val_loss: 12.7854 - val_rand_index: 0.0000e+00 - val_specificity: 0.9470
Epoch 2/10
[1m22/22[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m49s[0m 2s/step - accuracy: 0.9870 - dice_coeff: 0.0167 - jaccard_index: 0.7580 - loss: 9.8473 - rand_index: 0.0084 - specificity: 0.7577 - val_accuracy: 0.9971 - val_dice_coeff: 4.0234e-10 - val_jaccard_index: 0.9469 - val_loss: 11.0747 - val_rand_index: 0.0000e+00 - val_specificity: 0.9491
Epoch 3/10
[1m22/22[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m49s[0m 2s/step - accuracy: 0.9913 - dice_coeff: 0.0254 - jaccard_index: 0.8244 - loss: 9.7464 - rand_index: 0.0129 - specificity: 0.8242 - val_accuracy: 0.9960 - val_dice_coeff: 1.2032e-10 - val_jaccard_index:

<keras.src.callbacks.history.History at 0x213ccfd9a00>

# STUFF BELOW FROM AI FOR TESTING DO NOT KEEP

In [77]:
print("\nChecking model predictions...")
for img, mask in train_dataset.take(3):
    pred = model.predict(img, verbose=0)
    
    for i in range(min(8, img.shape[0])):
        mask_sum = tf.reduce_sum(mask[i]).numpy()
        pred_mean = np.mean(pred[i])
        pred_max = np.max(pred[i])
        pred_sum_thresholded = np.sum(pred[i] > 0.5)
        
        print(f"Slice {i}: GT={mask_sum:.0f} pixels | "
              f"Pred mean={pred_mean:.6f} | "
              f"Pred max={pred_max:.6f} | "
              f"Pred>0.5={pred_sum_thresholded:.0f}")
    print()


Checking model predictions...
Slice 0: GT=105 pixels | Pred mean=0.001328 | Pred max=0.732827 | Pred>0.5=2
Slice 1: GT=295 pixels | Pred mean=0.001283 | Pred max=0.152752 | Pred>0.5=0
Slice 2: GT=410 pixels | Pred mean=0.001320 | Pred max=0.775415 | Pred>0.5=3
Slice 3: GT=453 pixels | Pred mean=0.001275 | Pred max=0.268983 | Pred>0.5=0
Slice 4: GT=459 pixels | Pred mean=0.001249 | Pred max=0.073918 | Pred>0.5=0
Slice 5: GT=450 pixels | Pred mean=0.001240 | Pred max=0.144696 | Pred>0.5=0
Slice 6: GT=426 pixels | Pred mean=0.001232 | Pred max=0.020290 | Pred>0.5=0
Slice 7: GT=368 pixels | Pred mean=0.001230 | Pred max=0.005894 | Pred>0.5=0

Slice 0: GT=47 pixels | Pred mean=0.001232 | Pred max=0.012042 | Pred>0.5=0
Slice 1: GT=35 pixels | Pred mean=0.001232 | Pred max=0.121998 | Pred>0.5=0
Slice 2: GT=17 pixels | Pred mean=0.001227 | Pred max=0.064439 | Pred>0.5=0
Slice 3: GT=12 pixels | Pred mean=0.001225 | Pred max=0.189278 | Pred>0.5=0
Slice 4: GT=9 pixels | Pred mean=0.001219 | Pred

In [63]:
# Test different thresholds
for threshold in [0.5, 0.6, 0.7, 0.8, 0.9]:
    pred_thresholded = (pred > threshold).astype(float)
    dice = dice_coeff(mask, pred_thresholded)
    print(f"Threshold {threshold}: Dice = {dice}")
    

Threshold 0.5: Dice = 0.45445767045021057
Threshold 0.6: Dice = 0.4636862277984619
Threshold 0.7: Dice = 0.4714776575565338
Threshold 0.8: Dice = 0.48029953241348267
Threshold 0.9: Dice = 0.49106302857398987
