<a href="https://colab.research.google.com/github/fjadidi2001/AD_Prediction/blob/main/hippocampus_segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# Install required libraries
!pip install montage

# Import libraries
import os
import random
import numpy as np
import pandas as pd
import cv2
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from skimage.io import imread, imshow
from skimage.transform import resize
from skimage.util.montage import montage2d
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from tensorflow.keras import backend as K

# Unzip the dataset
!unzip -q "/content/hippocampus segmentation dataset.zip" -d "/content/hippocampus_dataset"

# Constants
IMG_WIDTH = 128
IMG_HEIGHT = 128
IMG_CHANNELS = 3

# Seeding
seed = 42
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)

# Define dataset paths (adjust based on unzipped structure)
data = "/content/hippocampus_dataset/seghipp0/images"
data_left = "/content/hippocampus_dataset/seghipp0/masks/left"
data_right = "/content/hippocampus_dataset/seghipp0/masks/right"


def list_jpg_files(root):
    """Return the paths of all ``.jpg`` files under *root* in a stable order.

    Directories are visited in sorted order, and the file names inside each
    directory are sorted as well. ``os.walk`` yields file names in arbitrary
    filesystem order, so sorting only the walk tuples does not guarantee that
    the i-th image and the i-th mask refer to the same slice.

    The extension check is case-insensitive and must match the suffix, so a
    name such as ``scan.jpg.bak`` is skipped. A missing *root* yields ``[]``.
    """
    paths = []
    for dir_name, _, file_list in sorted(os.walk(root)):
        for filename in sorted(file_list):
            if filename.lower().endswith(".jpg"):
                paths.append(os.path.join(dir_name, filename))
    return paths


# Images and masks are paired by list position when loaded, so all three
# lists must share the same deterministic ordering.
train_data = list_jpg_files(data)
mask_left = list_jpg_files(data_left)
mask_right = list_jpg_files(data_right)

# The loading loops below pair images and masks by index, so mismatched
# counts would silently misalign (or crash on) the data.
if not (len(train_data) == len(mask_left) == len(mask_right)):
    raise ValueError(
        f"Expected one left and one right mask per image, got "
        f"{len(train_data)} images, {len(mask_left)} left masks and "
        f"{len(mask_right)} right masks."
    )

# Initialize arrays for training data
X_train = np.zeros((len(train_data), IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS), dtype=np.float32)
Y_train = np.zeros((len(train_data), IMG_HEIGHT, IMG_WIDTH, 1), dtype=np.float32)


def _to_unit_range(arr):
    """Return *arr* as float64 scaled to [0, 1].

    Integer images are divided by their dtype's maximum (255 for uint8).
    Float images are returned unchanged, on the assumption that they are
    already in [0, 1].
    """
    if np.issubdtype(arr.dtype, np.integer):
        return arr.astype(np.float64) / np.iinfo(arr.dtype).max
    return arr.astype(np.float64)


# Load and preprocess images
for file_index, path in enumerate(tqdm(train_data, desc="Loading images")):
    img = imread(path)
    if img.ndim == 3 and img.shape[-1] == 1:
        img = img[..., 0]
    if img.ndim == 2:
        # Grayscale slice: replicate it into the channels the network expects
        # instead of letting resize() interpolate along the channel axis.
        img = np.stack([img] * IMG_CHANNELS, axis=-1)
    else:
        img = img[..., :IMG_CHANNELS]  # drop e.g. an alpha channel
    img = resize(_to_unit_range(img), (IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS), mode='constant', preserve_range=True)
    X_train[file_index] = img

# Load and preprocess masks
for n, (left_path, right_path) in enumerate(
    tqdm(zip(mask_left, mask_right), total=len(mask_right), desc="Loading masks")
):
    # imread(as_gray=True) converts colour files to floats in [0, 1] but leaves
    # already-grayscale files as integers, so normalise by dtype rather than
    # always dividing by 255.
    maskl = _to_unit_range(imread(left_path, as_gray=True))
    maskr = _to_unit_range(imread(right_path, as_gray=True))
    mask = np.maximum(maskl, maskr)  # union of the left and right masks
    mask = resize(mask, (IMG_HEIGHT, IMG_WIDTH, 1), mode='constant', preserve_range=True)
    # JPEG compression and resizing leave soft/noisy edges; binarise so the
    # ground truth is a {0, 1} mask.
    Y_train[n] = (mask > 0.5).astype(np.float32)

# Visualize a sample (index 10, or the last one for very small datasets).
# Named sample_idx so the builtin id() is not shadowed; plotting goes through
# matplotlib directly because skimage.io.imshow is deprecated.
sample_idx = min(10, len(X_train) - 1)
print(X_train[sample_idx].shape)
plt.figure()
plt.imshow(X_train[sample_idx])
plt.title("Sample Image")
plt.show()
plt.figure()
plt.imshow(Y_train[sample_idx][:, :, 0], cmap='gray')
plt.title("Sample Mask")
plt.show()

# Split off a 10% hold-out test set. train_test_split already shuffles
# (shuffle=True by default) with a fixed seed, so no separate shuffle step is
# needed.
X_train, X_test, Y_train, Y_test = train_test_split(X_train, Y_train, test_size=0.1, random_state=seed)

# Visualize random sample: the image, its mask, and the mask drawn
# semi-transparently over the image.
image_x = random.randint(0, len(X_train) - 1)
sample_img = X_train[image_x]
sample_mask = np.squeeze(Y_train[image_x])

fig, ax = plt.subplots(1, 3, figsize=(16, 12))
img_ax, mask_ax, overlay_ax = ax

img_ax.imshow(sample_img, cmap='gray')
img_ax.set_title("Image")

mask_ax.imshow(sample_mask, cmap='gray')
mask_ax.set_title("Mask")

# Overlay: image first, then the mask on top with alpha blending.
overlay_ax.imshow(sample_img, cmap='gray', interpolation='none')
overlay_ax.imshow(sample_mask, cmap='jet', interpolation='none', alpha=0.7)
overlay_ax.set_title("Overlay")
plt.show()

# Montage visualization: tile the first channel of every training image, and
# every ground-truth mask, into one grid each.
# skimage.util.montage is the current replacement for the removed
# skimage.util.montage.montage2d.
from skimage.util import montage

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
ax1.imshow(montage(X_train[:, :, :, 0]), cmap='gray')
ax1.set_title('MRI Input Images Samples')
ax2.imshow(montage(Y_train[:, :, :, 0]), cmap='gray')
ax2.set_title('Ground Truth Masks Samples')
plt.show()

# Define Dice coefficient and loss
smooth = 1.  # additive smoothing; keeps the ratio defined when both masks are empty


def dice_coef(y_true, y_pred):
    """Soft Dice coefficient between *y_true* and *y_pred*.

    Every element of the batch is flattened into a single vector, so the
    score is pooled over the whole batch:
    (2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth).
    """
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    denominator = K.sum(truth) + K.sum(pred) + smooth
    return (2. * overlap + smooth) / denominator


def dice_coef_loss(y_true, y_pred):
    """Dice loss: 1 minus dice_coef, so perfect overlap gives a loss near 0."""
    score = dice_coef(y_true, y_pred)
    return 1.0 - score

# Build U-Net model.
# NOTE: images are already divided by 255 when they are loaded, so the network
# consumes its input as-is. An extra x / 255 layer here would squash the inputs
# into [0, 1/255] and starve the first convolutions of signal.


def conv_block(x, filters, dropout_rate):
    """Apply two 3x3 ReLU convolutions with dropout in between (one U-Net stage)."""
    x = tf.keras.layers.Conv2D(filters, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(x)
    x = tf.keras.layers.Dropout(dropout_rate)(x)
    x = tf.keras.layers.Conv2D(filters, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(x)
    return x


def build_unet(input_shape=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS),
               encoder_stages=((16, 0.1), (32, 0.1), (64, 0.2), (128, 0.2)),
               bottleneck=(256, 0.3)):
    """Return an uncompiled U-Net that outputs a 1-channel sigmoid mask.

    Args:
        input_shape: (height, width, channels) of the input images. Height and
            width must be divisible by 2 ** len(encoder_stages) so the skip
            connections line up.
        encoder_stages: (filters, dropout_rate) for each contraction stage.
            The expansion path mirrors these stages in reverse order.
        bottleneck: (filters, dropout_rate) for the deepest stage.
    """
    inputs = tf.keras.layers.Input(input_shape)

    # Contraction path: remember each stage's output for the skip connections.
    skips = []
    x = inputs
    for filters, rate in encoder_stages:
        x = conv_block(x, filters, rate)
        skips.append(x)
        x = tf.keras.layers.MaxPooling2D((2, 2))(x)

    x = conv_block(x, *bottleneck)

    # Expansion path: upsample, concatenate the matching skip, convolve.
    for (filters, rate), skip in zip(reversed(encoder_stages), reversed(skips)):
        x = tf.keras.layers.Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
        x = tf.keras.layers.concatenate([x, skip])
        x = conv_block(x, filters, rate)

    outputs = tf.keras.layers.Conv2D(1, (1, 1), activation='sigmoid')(x)
    return tf.keras.Model(inputs=[inputs], outputs=[outputs])


# Compile model
model = build_unet()
model.compile(optimizer='adam', loss=dice_coef_loss, metrics=[dice_coef])
model.summary()

# Callbacks. The checkpoint saves the model with the lowest val_loss to disk.
# restore_best_weights=True makes EarlyStopping roll the in-memory model back
# to its best epoch when it stops training, so later predictions do not use
# the weights from the final, non-improving epochs.
# NOTE(review): the checkpoint filename looks inherited from a nuclei
# segmentation example; rename it if nothing else loads this file.
checkpointer = tf.keras.callbacks.ModelCheckpoint('model_for_nuclei.h5', verbose=1, save_best_only=True)
callbacks = [
    tf.keras.callbacks.EarlyStopping(patience=2, monitor='val_loss', restore_best_weights=True),
    tf.keras.callbacks.TensorBoard(log_dir='logs'),
    checkpointer,
]

# Train model. validation_split=0.1 holds out the last 10% of X_train (taken
# before shuffling), which is the slice the prediction code treats as "val".
results = model.fit(X_train, Y_train, validation_split=0.1, shuffle=True, batch_size=16, epochs=10, callbacks=callbacks)

# Predictions. validation_split=0.1 in fit() held out the last 10% of
# X_train, so split at the same index to predict the two portions separately.
split_at = int(X_train.shape[0] * 0.9)
preds_train = model.predict(X_train[:split_at], verbose=1)
preds_val = model.predict(X_train[split_at:], verbose=1)
preds_test = model.predict(X_test, verbose=1)

# Threshold the sigmoid outputs at 0.5 into {0, 1} uint8 masks.
threshold = 0.5
preds_train_t, preds_val_t, preds_test_t = (
    (probs > threshold).astype(np.uint8)
    for probs in (preds_train, preds_val, preds_test)
)

# Visualize one random test sample: image, ground truth and thresholded
# prediction. Named sample_idx so the builtin id() is not shadowed.
sample_idx = random.randint(0, len(preds_test_t) - 1)
test_img = X_test[sample_idx]
test_truth = np.squeeze(Y_test[sample_idx])
test_pred = np.squeeze(preds_test_t[sample_idx])

fig, ax = plt.subplots(1, 3, figsize=(12, 4))
for axis, panel, title in zip(ax, (test_img, test_truth, test_pred), ("X_test", "Y_test", "Prediction")):
    axis.imshow(panel, cmap='gray')
    axis.set_title(title)
    axis.axis('off')
plt.show()

# Visualize overlay: image, predicted mask, ground truth, and the predicted
# mask blended over the image.
fig, ax = plt.subplots(1, 4, figsize=(16, 12))
for axis, panel, title in zip(ax[:3], (test_img, test_pred, test_truth), ("Image", "Predicted Mask", "Ground Truth Mask")):
    axis.imshow(panel, cmap='gray')
    axis.set_title(title)
ax[3].imshow(test_img, cmap='gray', interpolation='none')
ax[3].imshow(test_pred, cmap='jet', interpolation='none', alpha=0.7)
ax[3].set_title("Overlay")
plt.show()

# Montage of test data: tile the first channel of every test image, and every
# test mask, into one grid each.
# skimage.util.montage is the current replacement for the removed
# skimage.util.montage.montage2d.
from skimage.util import montage

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
ax1.imshow(montage(X_test[:, :, :, 0]), cmap='gray')
ax1.set_title('MRI Input Images Samples')
ax2.imshow(montage(Y_test[:, :, :, 0]), cmap='gray')
ax2.set_title('Ground Truth Masks Samples')
plt.show()

# Plot predictions function
def _plot_triplet(image, truth, pred, x_title, y_title):
    """Plot an image, its ground-truth mask and the predicted mask side by side."""
    fig, ax = plt.subplots(1, 3, figsize=(12, 4))
    panels = (image, np.squeeze(truth), np.squeeze(pred))
    for axis, panel, title in zip(ax, panels, (x_title, y_title, "Prediction")):
        axis.imshow(panel, cmap='gray')
        axis.set_title(title)
        axis.axis('off')
    plt.show()


def plotPredictions(a, b, c, d, e, train_fraction=0.9):
    """Show one random training prediction and one random validation prediction.

    *a*/*b* are split at ``int(len(a) * train_fraction)``. The default of 0.9
    mirrors ``validation_split=0.1`` in ``model.fit``, which holds out the
    last 10% of the training arrays.

    Args:
        a: training images (e.g. X_train).
        b: training masks aligned with *a*.
        c: test images. Unused; kept for call compatibility.
        d: test masks. Unused; kept for call compatibility.
        e: trained Keras model used for prediction.
        train_fraction: fraction of *a* treated as the training portion.
    """
    model = e
    split_at = int(a.shape[0] * train_fraction)
    portions = (
        (a[:split_at], b[:split_at], "X_train", "Y_train"),
        (a[split_at:], b[split_at:], "X_val", "Y_val"),
    )
    for images, masks, x_title, y_title in portions:
        if len(images) == 0:
            continue  # nothing to show, e.g. for a very small dataset
        preds = (model.predict(images, verbose=1) > 0.5).astype(np.uint8)
        ix = random.randint(0, len(preds) - 1)
        _plot_triplet(images[ix], masks[ix], preds[ix], x_title, y_title)


plotPredictions(X_train, Y_train, X_test, Y_test, model)

# Plot training history. The Dice metric is logged under 'dice_coef'; fall
# back to the 'acc' keys if the model was compiled with accuracy instead.
history = results.history
acc = history.get('dice_coef', history.get('acc'))
val_acc = history.get('val_dice_coef', history.get('val_acc'))
loss = history['loss']
val_loss = history['val_loss']
epochs = range(len(acc))

curve_specs = (
    (acc, val_acc, 'Training Dice Coef', 'Validation Dice Coef',
     'Training and Validation Dice Coefficient', 'Dice Coefficient'),
    (loss, val_loss, 'Training Loss', 'Validation Loss',
     'Training and Validation Loss', 'Loss'),
)
for train_curve, val_curve, train_label, val_label, title, ylabel in curve_specs:
    plt.figure()
    plt.plot(epochs, train_curve, 'r', label=train_label)
    plt.plot(epochs, val_curve, 'b', label=val_label)
    plt.title(title)
    plt.ylabel(ylabel)
    plt.xlabel('Epoch')
    plt.legend()
    plt.show()

Collecting montage
  Downloading montage-0.3.6.tar.gz (2.8 kB)
  Preparing metadata (setup.py) ... done
Collecting decoupage>=0.8 (from montage)
  Downloading decoupage-0.16.0.tar.gz (14 kB)
  Preparing metadata (setup.py) ... done
Collecting cropresize (from montage)
  Downloading cropresize-0.2.0.tar.gz (3.0 kB)
  Preparing metadata (setup.py) ... done
Collecting WebOb (from decoupage>=0.8->montage)
  Downloading WebOb-1.8.9-py2.py3-none-any.whl.metadata (11 kB)
Collecting Paste (from decoupage>=0.8->montage)
  Downloading Paste-3.10.1-py3-none-any.whl.metadata (5.3 kB)
Collecting PasteScript (from decoupage>=0.8->montage)
  Downloading PasteScript-3.7.0-py2.py3-none-any.whl.metadata (3.0 kB)
Collecting genshi (from decoupage>=0.8->montage)
  Downloading Genshi-0.7.9-py3-none-any.whl.metadata (1.5 kB)
Collecting martINI>=0.4.2 (from decoupage>=0.8->montage)
  Downloading martINI-0.7.tar.gz (15 kB)
  Preparing metadata (setup.py) ... done

ModuleNotFoundError: No module named 'skimage.util.montage'