# **Problem 2: 3D medical image segmentation**
The dataset is the Atrial Segmentation Challenge dataset, comprising 14 training images and 20 test images.

#### **Install Libraries**

In [None]:
!pip install numpy h5py tensorflow keras itk tqdm medpy

#### **Mount Drive if needed**

In [None]:
from google.colab import drive
# Mount Google Drive; the data and result paths used below live under /content/drive/MyDrive.
drive.mount('/content/drive')

#### **Define the dataloader**

In [70]:
import numpy as np
import h5py
import os
import random

class DataLoader:
    """Load paired 3D image/label volumes from ``.h5`` files and yield cropped batches.

    Every ``*.h5`` file in ``directory`` is read in sorted filename order and must
    contain an ``'image'`` and a ``'label'`` dataset of the same 3D shape. Each pair
    is cropped once, at construction, to ``crop_size``: a random window when
    ``random_crop`` is true (training) or the centred window otherwise (testing).
    Because cropping happens only once, a random crop is fixed for the lifetime of
    the loader rather than re-drawn every epoch.

    Iterating yields ``(images, masks)`` numpy arrays of shape ``(n, *crop_size)``
    with ``n <= batch_size``; ``len()`` is the number of such batches.
    """

    def __init__(self, directory, batch_size=4, crop_size=(112, 112, 80), shuffle=True, random_crop=True):
        self.directory = directory
        self.batch_size = batch_size
        self.crop_size = crop_size
        self.shuffle = shuffle
        # Must be set before cropping, because crop_volume() reads this flag.
        self.random_crop = random_crop
        self.images, self.masks = self.load_data()
        self.cropped_images, self.cropped_masks = self.crop_images_and_masks()
        self.indexes = np.arange(len(self.cropped_images))

    def __len__(self):
        """Return the number of batches per epoch, counting a final partial batch."""
        # Ceiling division: 14 samples / batch 2 -> 7 batches, 15 / 2 -> 8.
        return (len(self.cropped_images) + self.batch_size - 1) // self.batch_size

    def load_data(self):
        """Read every ``.h5`` file in ``self.directory``.

        Returns:
            (images, masks): two lists of numpy arrays, in sorted filename order.

        Raises:
            ValueError: if a file's image and label shapes differ.
        """
        images = []
        masks = []

        # os.listdir order is arbitrary; sorting makes sample order (and therefore
        # test-set indices of saved results) reproducible.
        for filename in sorted(os.listdir(self.directory)):
            if not filename.endswith('.h5'):
                continue
            file_path = os.path.join(self.directory, filename)
            with h5py.File(file_path, 'r') as f:
                img_data = np.array(f['image'])  # Load images
                mask_data = np.array(f['label'])  # Load masks

            # The same crop window is applied to both arrays, so their full shapes
            # (not just the first axis) must agree.
            if img_data.shape != mask_data.shape:
                raise ValueError(
                    f'Mismatched samples in {filename}: '
                    f'image {img_data.shape} vs label {mask_data.shape}'
                )

            images.append(img_data)
            masks.append(mask_data)

        return images, masks

    def crop_images_and_masks(self):
        """Crop every loaded (image, mask) pair and stack the results into arrays."""
        cropped_images = []
        cropped_masks = []

        for image, mask in zip(self.images, self.masks):
            cropped_image, cropped_mask = self.crop_volume(image, mask)
            cropped_images.append(cropped_image)
            cropped_masks.append(cropped_mask)

        return np.array(cropped_images), np.array(cropped_masks)

    def crop_volume(self, image, mask):
        """Return the same ``crop_size`` window of a 3D ``image`` and its ``mask``.

        The window is random when ``self.random_crop`` is true, otherwise centred.
        Named ``crop_volume`` (not ``random_crop``) so it is not shadowed by the
        ``random_crop`` flag attribute set in ``__init__``.

        Raises:
            ValueError: if the volume is smaller than ``crop_size`` on any axis.
        """
        z, y, x = image.shape
        cz, cy, cx = self.crop_size
        if z < cz or y < cy or x < cx:
            raise ValueError(f'Volume of shape {image.shape} is smaller than crop_size {self.crop_size}')

        if self.random_crop:
            # Random crop for the train loader.
            start_z = random.randint(0, z - cz)
            start_y = random.randint(0, y - cy)
            start_x = random.randint(0, x - cx)
        else:
            # Fixed centre crop for the test loader.
            start_z = (z - cz) // 2
            start_y = (y - cy) // 2
            start_x = (x - cx) // 2
        window = (slice(start_z, start_z + cz), slice(start_y, start_y + cy), slice(start_x, start_x + cx))
        return image[window], mask[window]

    def __iter__(self):
        """Yield ``(images, masks)`` batches, reshuffling the order first if ``shuffle``."""
        if self.shuffle:
            np.random.shuffle(self.indexes)
        for start in range(0, len(self.cropped_images), self.batch_size):
            end = min(start + self.batch_size, len(self.cropped_images))
            batch_idx = self.indexes[start:end]
            yield self.cropped_images[batch_idx], self.cropped_masks[batch_idx]

# Example usage: see the CONFIGURATION cells below.

# The batch size is kept at 1 for every test loader, to avoid discrepancies caused by
# floating-point precision and to keep the test set consistent across all configurations.

#### **Define Model**

In [4]:
from keras.layers import Input, Conv3D, MaxPooling3D, UpSampling3D, concatenate
from keras.models import Model

def unet_3d(input_size=(112, 112, 80, 1)):
    """Build a 3D U-Net for binary segmentation.

    The encoder has four levels (32, 64, 128, 256 filters). Each level applies two
    3x3x3 ReLU convolutions and then 2x2x2 max pooling. A 512-filter bottleneck
    follows. The decoder mirrors the encoder: at each level it upsamples by 2,
    concatenates the matching encoder output (skip connection) and applies two
    convolutions. A final 1x1x1 sigmoid convolution produces a one-channel
    probability map with the same spatial size as the input.

    Args:
        input_size: (d0, d1, d2, channels). Each spatial size must be divisible by
            16 (four poolings) so the skip connections line up.
    """
    def double_conv(tensor, filters):
        # Two 'same'-padded 3x3x3 ReLU convolutions with the same width.
        tensor = Conv3D(filters, (3, 3, 3), activation='relu', padding='same')(tensor)
        return Conv3D(filters, (3, 3, 3), activation='relu', padding='same')(tensor)

    encoder_filters = (32, 64, 128, 256)
    inputs = Input(input_size)

    skips = []
    x = inputs
    for filters in encoder_filters:
        x = double_conv(x, filters)
        skips.append(x)
        x = MaxPooling3D((2, 2, 2))(x)

    x = double_conv(x, 512)  # bottleneck

    for filters, skip in zip(reversed(encoder_filters), reversed(skips)):
        x = UpSampling3D((2, 2, 2))(x)
        x = concatenate([x, skip])
        x = double_conv(x, filters)

    outputs = Conv3D(1, (1, 1, 1), activation='sigmoid')(x)
    return Model(inputs=[inputs], outputs=[outputs])


#### **CONFIGURATION 1**

In [None]:
# CONFIGURATION 1

# First Training configuration dataloader
# Both loaders use the deterministic centre crop (random_crop=False) and no shuffling;
# the training batch size is 4 and the test batch size is 1.
train_loader = DataLoader('/content/drive/MyDrive/train', batch_size=4, shuffle=False, random_crop=False)
test_loader = DataLoader('/content/drive/MyDrive/test', batch_size=1, shuffle=False, random_crop=False)

from tensorflow.keras import backend as K
from tqdm import tqdm
from tensorflow.keras.optimizers import Adam
import tensorflow as tf
import numpy as np

losses = []
# Training loop
def train_model(model, data_loader, num_epochs=50):
    """Train ``model`` for ``num_epochs`` passes over ``data_loader`` with ``train_on_batch``.

    Each ``(images, masks)`` batch gets a trailing channel axis and is converted to a
    tensor. Element 0 of the ``train_on_batch`` result is recorded as the loss; this
    assumes the model was compiled with at least one metric, so the result is a list.
    Every batch loss is appended to the module-level ``losses`` list (read by the
    plotting cell), and the mean loss is printed after each epoch.
    """
    for epoch_no in range(1, num_epochs + 1):
        per_batch = []
        progress = tqdm(enumerate(data_loader), total=len(data_loader), leave=False)
        for _, (images, masks) in progress:
            # Add the channel dimension: (batch_size, cz, cy, cx) -> (..., 1).
            x = tf.convert_to_tensor(images[..., np.newaxis])
            y = tf.convert_to_tensor(masks[..., np.newaxis])
            step_loss = model.train_on_batch(x, y)[0]
            per_batch.append(step_loss)
            losses.append(step_loss)
            progress.set_description(f"Epoch [{epoch_no}/{num_epochs}]")
            progress.set_postfix(loss=step_loss)

        print(f'Epoch {epoch_no}/{num_epochs}, Loss: {np.mean(per_batch)}')


model = unet_3d()
# Adam (lr 1e-3) with voxel-wise binary cross-entropy. Accuracy is tracked as a metric,
# so train_on_batch returns a list and train_model reads the loss as loss[0].
model.compile(optimizer=Adam(learning_rate=0.001), loss= 'binary_crossentropy', metrics=['accuracy'])
# The batch size (4) comes from train_loader above, not from this call.
train_model(model, train_loader, num_epochs=50)

#### **CONFIGURATION 2 (BEST ONE)**

In [None]:
# CONFIGURATION 2

# Second Training configuration
# Centre crop (random_crop=False) and no shuffling for both loaders;
# training batch size 2, test batch size 1.
train_loader = DataLoader('/content/drive/MyDrive/train', batch_size=2, shuffle=False, random_crop=False)
test_loader = DataLoader('/content/drive/MyDrive/test', batch_size=1, shuffle=False, random_crop=False)

from tensorflow.keras import backend as K
from tqdm import tqdm
import numpy as np
import tensorflow as tf
from tensorflow.keras.optimizers import Adam

# Define Dice loss function with @tf.function
@tf.function
def dice_loss(y_true, y_pred, smooth=1e-6):
    """Soft Dice loss computed over the whole batch at once.

    Both tensors are flattened, so all voxels of all samples are pooled. Returns
    ``1 - (2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)``, where ``t`` are
    the labels and ``p`` the predicted probabilities. ``smooth`` keeps the ratio
    defined when both sums are zero.
    """
    # TF does not promote mixed dtypes in arithmetic. Labels may arrive with an
    # integer dtype (depending on the stored 'label' data), so match y_pred first.
    y_true = tf.cast(y_true, y_pred.dtype)
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return 1 - (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

losses = []

# Training loop
def train_model(model, data_loader, num_epochs=75):
    """Run ``num_epochs`` epochs of ``model.train_on_batch`` over ``data_loader``.

    Images and masks each receive a trailing channel axis before conversion to
    tensors. Element 0 of the ``train_on_batch`` result is taken as the batch loss;
    this assumes the model was compiled with a metric, so the result is a list.
    Each batch loss goes into the module-level ``losses`` list used by the plotting
    cell, and the epoch's mean loss is printed.
    """
    for epoch in range(num_epochs):
        header = f"Epoch [{epoch+1}/{num_epochs}]"
        running = []
        bar = tqdm(enumerate(data_loader), total=len(data_loader), leave=False)
        for _step, (imgs, lbls) in bar:
            # Explicit channel axis: (batch_size, cz, cy, cx) -> (batch_size, cz, cy, cx, 1).
            imgs_t = tf.convert_to_tensor(imgs[..., np.newaxis])
            lbls_t = tf.convert_to_tensor(lbls[..., np.newaxis])

            batch_loss = model.train_on_batch(imgs_t, lbls_t)[0]
            running.append(batch_loss)
            losses.append(batch_loss)
            bar.set_description(header)
            bar.set_postfix(loss=batch_loss)

        print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {np.mean(running)}')

# Model compilation
model = unet_3d()

# Dice loss with a small learning rate (1e-5); accuracy is tracked so train_model can read loss[0].
model.compile(optimizer=Adam(learning_rate=0.00001), loss=dice_loss, metrics=['accuracy'])

# The batch size (2) comes from train_loader above, not from this call.
train_model(model, train_loader, num_epochs=75)

#### **CONFIGURATION 3**

In [None]:
# CONFIGURATION 3

# Third Training configuration
# Same loaders as configuration 2 (centre crop, no shuffling, batch 2 / batch 1);
# this configuration adds rotation/flip augmentation during training instead.
train_loader = DataLoader('/content/drive/MyDrive/train', batch_size=2, shuffle=False, random_crop=False)
test_loader = DataLoader('/content/drive/MyDrive/test', batch_size=1, shuffle=False, random_crop=False)

import numpy as np
import random
from scipy.ndimage import rotate
from tqdm import tqdm
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import backend as K
import tensorflow as tf

# Define Dice loss function
def dice_loss(y_true, y_pred, smooth=1e-6):
    """Soft Dice loss computed over the whole batch at once.

    Both tensors are flattened, so all voxels of all samples are pooled. Returns
    ``1 - (2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)``, where ``t`` are
    the labels and ``p`` the predicted probabilities. ``smooth`` keeps the ratio
    defined when both sums are zero.
    """
    # TF does not promote mixed dtypes in arithmetic. Labels may arrive with an
    # integer dtype (depending on the stored 'label' data), so match y_pred first.
    y_true = tf.cast(y_true, y_pred.dtype)
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return 1 - (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

# Function to perform random rotation
def random_rotate(images, masks):
    """Rotate every volume of a batch, and its mask, by one shared random angle.

    The angle is drawn uniformly from [-20, 20] degrees. Each volume is rotated with
    ``scipy.ndimage.rotate`` using its default axes and ``reshape=False``, so shapes
    are preserved.

    Returns:
        (rotated_images, rotated_masks) as numpy arrays.
    """
    angle = random.uniform(-20, 20)  # Random angle between -20 and 20 degrees
    rotated_images = np.array([rotate(image, angle, reshape=False) for image in images])
    # Nearest-neighbour (order=0) keeps label values exact. The default cubic spline
    # would blur masks into non-binary values and can overshoot the label range.
    rotated_masks = np.array([rotate(mask, angle, reshape=False, order=0) for mask in masks])
    return rotated_images, rotated_masks

# Function to perform random flipping
def random_flip(images, masks):
    """With probability 1/2, mirror a batch of volumes and masks along array axis 2.

    For batches shaped (batch, cz, cy, cx), as yielded by DataLoader, axis 2 is the
    second spatial axis (cy), not the first. NOTE(review): an earlier comment called
    this the depth axis; confirm which axis was intended.
    Otherwise the inputs are returned unchanged (same objects).
    """
    if random.random() <= 0.5:
        return images, masks
    return np.flip(images, axis=2), np.flip(masks, axis=2)
losses = []
# Update your training loop to use the manual augmentation
def train_model_with_manual_augmentation(model, data_loader, num_epochs=50, batch_size=8):
    """Train ``model`` with on-the-fly rotation and flip augmentation.

    Each ``(images, masks)`` batch from ``data_loader`` is passed through
    ``random_rotate`` and then ``random_flip``. It then gets a trailing channel
    axis, is converted to float32 tensors and is fed to ``model.train_on_batch``.
    Element 0 of the result is recorded as the loss; this assumes the model was
    compiled with at least one metric. Batch losses are appended to the
    module-level ``losses`` list (used by the plotting cell), and the mean per
    epoch is printed.

    Args:
        model: compiled Keras model.
        data_loader: iterable of ``(images, masks)`` numpy batches that supports ``len()``.
        num_epochs: number of passes over ``data_loader``.
        batch_size: unused; batching is controlled by ``data_loader``. Kept so
            existing calls keep working.
    """
    for epoch in range(num_epochs):
        epoch_losses = []
        train_loop = tqdm(data_loader, total=len(data_loader), leave=False)

        for batch_images, batch_masks in train_loop:
            # Apply manual augmentations.
            batch_images, batch_masks = random_rotate(batch_images, batch_masks)
            batch_images, batch_masks = random_flip(batch_images, batch_masks)

            # Add the channel axis and convert once; these tensors are what we train on.
            batch_images_tensor = tf.convert_to_tensor(batch_images[..., np.newaxis], dtype=tf.float32)
            batch_masks_tensor = tf.convert_to_tensor(batch_masks[..., np.newaxis], dtype=tf.float32)

            # Train on the augmented batch.
            loss = model.train_on_batch(batch_images_tensor, batch_masks_tensor)
            batch_loss = loss[0]
            epoch_losses.append(batch_loss)
            losses.append(batch_loss)
            train_loop.set_description(f"Epoch [{epoch + 1}/{num_epochs}]")
            train_loop.set_postfix(loss=batch_loss)

        avg_loss = np.mean(epoch_losses)
        print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss}')

# Build, compile and train. The batch size actually used is train_loader's (2);
# train_model_with_manual_augmentation ignores its batch_size argument.
model = unet_3d()
model.compile(optimizer=Adam(learning_rate=0.00001), loss=dice_loss, metrics=['accuracy'])
train_model_with_manual_augmentation(model, train_loader, num_epochs=75, batch_size=2)

#### **Plot training loss**

In [None]:
import matplotlib.pyplot as plt

# One x-position per recorded training batch. Deriving it from ``losses`` keeps the
# axis in step with the data for any epoch count or batch size. A hand-computed
# epochs * batches value makes plt.plot fail with a length mismatch whenever it is wrong.
train_index = list(range(1, len(losses) + 1))
val_index = []  # no validation losses are recorded by the training cells

plt.figure(figsize=(12, 5))

# Plot loss per batch
plt.plot(train_index, losses, label='Training Loss', color='blue')
plt.title('Loss per batch')
plt.xlabel('Batch')
plt.ylabel('Loss')
plt.legend()
plt.grid()

# Show the plot
plt.tight_layout()
plt.show()

#### **Evaluation metrics**

In [30]:
# Run this if you want to load any saved weights
# model = unet_3d()
# model.load_weights('/content/model_config2.weights.h5')

In [None]:
import numpy as np
from sklearn.metrics import jaccard_score
from scipy.spatial.distance import directed_hausdorff
from medpy.metric.binary import assd
from medpy.metric.binary import dc, hd95, jc

def evaluate_model(model, data_loader, voxelspacing=None):
    """Predict on every batch of ``data_loader`` and compute segmentation metrics.

    Predictions are binarised at 0.5. Per-sample Dice (medpy ``dc``) and Jaccard
    (``jc``) are averaged over all samples. ASD (``assd``) and 95% Hausdorff
    distance (``hd95``) need non-empty volumes, so samples where the prediction or
    the mask is empty are skipped for those two metrics only.

    Args:
        model: Keras model whose ``predict`` takes inputs with a trailing channel axis.
        data_loader: iterable of ``(images, masks)`` numpy batches.
        voxelspacing: optional per-axis voxel spacing forwarded to ``assd``/``hd95``.
            ``None`` (the default) measures distances in voxels.

    Returns:
        (dice, jaccard, asd, hd95) means. ``asd``/``hd95`` are NaN when every sample
        was skipped.
    """
    all_predictions = []
    all_masks = []
    for test_images, test_masks in data_loader:
        # Add the channel dimension expected by the model, then predict.
        all_predictions.append(model.predict(test_images[..., np.newaxis]))
        all_masks.append(test_masks[..., np.newaxis])

    # Concatenate all batches and binarise the predicted probabilities.
    all_predictions = (np.concatenate(all_predictions, axis=0) > 0.5).astype(np.uint8)
    all_masks = np.concatenate(all_masks, axis=0)

    dice_scores = []
    jaccard_scores = []
    asd_list = []
    hd_list = []
    for pred, mask in zip(all_predictions, all_masks):
        dice_scores.append(dc(pred, mask))
        jaccard_scores.append(jc(pred, mask))
        # Surface distances are undefined when either volume has no foreground.
        if np.count_nonzero(pred) == 0 or np.count_nonzero(mask) == 0:
            print("Skipping empty prediction or mask")
            continue
        asd_list.append(assd(pred, mask, voxelspacing=voxelspacing))
        hd_list.append(hd95(pred, mask, voxelspacing=voxelspacing))

    dice_coeff = np.mean(dice_scores)
    jaccard = np.mean(jaccard_scores)
    # Explicit NaN avoids numpy's "mean of empty slice" warning.
    avg_asd = np.mean(asd_list) if asd_list else float('nan')
    avg_hd = np.mean(hd_list) if hd_list else float('nan')  # 95% Hausdorff Distance

    return dice_coeff, jaccard, avg_asd, avg_hd

# Evaluate on test_loader as defined by the most recently run configuration cell
# (each configuration cell builds its own centre-cropped, batch-size-1 test_loader).
# test_loader = DataLoader('./datas/test/', batch_size=4)

# Evaluate the model using the DataLoader
dice_coeff, jaccard, avg_asd, avg_hd = evaluate_model(model, test_loader)

# evaluate_model(model, test_loader)
print(f'Dice Coefficient: {dice_coeff}, Jaccard Index: {jaccard}, ASD: {avg_asd}, 95HD: {avg_hd}')

#### **Plotting 2D segmentation results**

In [32]:
def obtain_testing_data(model, data_loader):
    """Run ``model`` over ``data_loader`` and collect predictions, labels and inputs.

    Returns:
        (predictions, masks, images): numpy arrays concatenated over all batches,
        each with a trailing channel axis. ``predictions`` are the model outputs
        binarised at 0.5, as uint8.
    """
    all_predictions = []
    all_masks = []
    all_images = []

    for test_images, test_masks in data_loader:
        images = test_images[..., np.newaxis]  # Add channel dimension
        all_predictions.append(model.predict(images))
        all_masks.append(test_masks[..., np.newaxis])
        all_images.append(images)

    predictions = (np.concatenate(all_predictions, axis=0) > 0.5).astype(np.uint8)
    # np.concatenate already returns fresh arrays; wrapping them in np.array() would
    # only duplicate these large volumes in memory.
    masks = np.concatenate(all_masks, axis=0)
    images = np.concatenate(all_images, axis=0)
    return predictions, masks, images

In [None]:
# Binarised predictions, labels and inputs for every test sample, each with a trailing channel axis.
predictions, ground_truth, original_images = obtain_testing_data(model, test_loader)

In [None]:
# All three should have shape (20, 112, 112, 80, 1): 20 test volumes, the default crop, one channel.
print(predictions.shape)
print(ground_truth.shape)
print(original_images.shape)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Expected shapes (as returned by obtain_testing_data):
#   original_images: (num_samples, d0, d1, d2, 1) image volumes
#   predictions:     (num_samples, d0, d1, d2, 1) binarised predictions (0/1)
#   ground_truth:    (num_samples, d0, d1, d2, 1) label masks
# Slices are taken along axis 2 of each sample (80 voxels with the default crop).

# Select the sample (change the index to visualise other samples)
sample_index = 1
original_image = original_images[sample_index]
predicted_mask = predictions[sample_index]
true_mask = ground_truth[sample_index]

# Slices to display; indices outside the volume are dropped instead of raising IndexError.
slice_indices = [0, 20, 35, 39, 41, 52, 61, 69]  # Adjust as needed
slice_indices = [s for s in slice_indices if 0 <= s < original_image.shape[2]]

num_slices = len(slice_indices)
# squeeze=False keeps ``axs`` 2-D, so axs[i, j] also works for a single slice.
# Scaling the height with the row count keeps each panel readable.
fig, axs = plt.subplots(num_slices, 4, figsize=(16, 4 * num_slices), squeeze=False)

for i, slice_index in enumerate(slice_indices):
    # Squeeze away the trailing channel axis -> 2-D slices.
    original_slice = np.squeeze(original_image[:, :, slice_index, :])
    predicted_slice = np.squeeze(predicted_mask[:, :, slice_index, :])
    true_slice = np.squeeze(true_mask[:, :, slice_index, :])

    # Display the original image slice
    axs[i, 0].imshow(original_slice, cmap='gray')
    axs[i, 0].set_title(f'Original Image Slice {slice_index}')
    axs[i, 0].axis('off')

    # Display the predicted mask slice
    axs[i, 1].imshow(predicted_slice, cmap='gray')
    axs[i, 1].set_title(f'Predicted Mask {slice_index}')
    axs[i, 1].axis('off')

    # Display the ground truth mask slice
    axs[i, 2].imshow(true_slice, cmap='gray')
    axs[i, 2].set_title(f'Ground Truth Mask {slice_index}')
    axs[i, 2].axis('off')

    # Overlay both masks on the original image: prediction in red, ground truth in
    # green, overlap in yellow. Alpha is zero where neither mask is set.
    pred_on = predicted_slice > 0
    true_on = true_slice > 0
    overlay = np.zeros(original_slice.shape + (4,))
    overlay[..., 0] = pred_on
    overlay[..., 1] = true_on
    overlay[..., 3] = 0.5 * (pred_on | true_on)
    axs[i, 3].imshow(original_slice, cmap='gray')
    axs[i, 3].imshow(overlay)
    axs[i, 3].set_title(f'Overlay Slice {slice_index} (red=pred, green=GT)')
    axs[i, 3].axis('off')

plt.tight_layout()
plt.show()

#### **Save images, labels, and predictions**

In [38]:
import nibabel as nib

def save_nifti(image, filename, affine=None):
    """Write ``image`` to ``filename`` as a NIfTI-1 file.

    Args:
        image: numpy array to store.
        filename: output path, e.g. ``.nii.gz``.
        affine: optional 4x4 voxel-to-world matrix. Defaults to the identity
            (unit voxel spacing, no orientation). DataLoader reads only 'image' and
            'label' from the h5 files, so no real spacing is available here; pass
            one if you have it.
    """
    if affine is None:
        affine = np.eye(4)
    img = nib.Nifti1Image(image, affine)
    nib.save(img, filename)


# Write each test image, its ground-truth label and the binarised prediction as NIfTI
# files (one set per sample). Change the output directory as needed.
results_dir = "/content/drive/MyDrive/results"
os.makedirs(results_dir, exist_ok=True)  # writing fails if the folder does not exist
for i in range(len(predictions)):
    save_nifti(np.squeeze(original_images[i]), f"{results_dir}/image{i}.nii.gz")  # Save image
    save_nifti(np.squeeze(ground_truth[i]), f"{results_dir}/label{i}.nii.gz")  # Save label
    save_nifti(np.squeeze(predictions[i]), f"{results_dir}/prediction{i}.nii.gz")  # Save prediction
