In [None]:
# Mount Google Drive to access files
from google.colab import drive
import os
# Absolute paths keep this cell idempotent. A relative chdir into
# 'drive/My Drive/...' only resolves while the cwd is /content, so a second run
# of the cell (cwd already inside the project folder) would fail.
DRIVE_MOUNT_POINT = '/content/drive'
# **NOTE:** You may need to change the path below to match your specific script location
PROJECT_DIR = os.path.join(DRIVE_MOUNT_POINT, 'My Drive', 'ML3D', '2025')

drive.mount(DRIVE_MOUNT_POINT, force_remount=True)

# Change the current directory to the script folder in Google Drive
os.chdir(PROJECT_DIR)

# List the contents of the current directory to verify the change (sorted for stable output)
sorted(os.listdir('.'))

Mounted at /content/drive


['prediction',
 'test',
 'augmentation.zip',
 'augmentation',
 'model.keras',
 'train.ipynb']

In [None]:
# Unzip the augmentation data
# This assumes you have a zip file named 'augmentation.zip' in the current directory
!unzip augmentation.zip

In [None]:
# Import necessary libraries for building and training the 3D U-Net model
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import *
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler, CSVLogger
from tensorflow.keras import backend as keras
from tensorflow.keras.initializers import *
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tifffile as tif
import glob
import cv2
from skimage.transform import resize

In [None]:
# Define the 3D U-Net model architecture
def get_unet(input_shape=(128, 128, 128, 1), n_classes=3, learning_rate=1e-4):
    """Build and compile a 3D U-Net for voxel-wise multi-class segmentation.

    The network has three 2x2x2 max-pooling stages and three matching
    upsampling stages joined by skip connections. Every spatial dimension of
    ``input_shape`` must therefore be divisible by 8 so the concatenated
    feature maps line up. ``None`` (variable-size) dimensions are allowed.

    Args:
        input_shape: ``(depth, height, width, channels)`` of one input volume.
            Defaults to the 128^3 single-channel volumes used in this notebook.
        n_classes: Number of channels of the final softmax (segmentation classes).
        learning_rate: Learning rate for the Adam optimizer.

    Returns:
        A ``Model`` compiled with Adam, categorical cross-entropy loss and a
        categorical-accuracy metric. It maps a batch of ``input_shape`` volumes
        to per-voxel class probabilities with ``n_classes`` channels.

    Raises:
        ValueError: If a fixed spatial dimension of ``input_shape`` is not
            divisible by 8.
    """
    spatial = tuple(input_shape[:3])
    if any(dim is not None and dim % 8 for dim in spatial):
        raise ValueError(
            f"spatial dimensions of input_shape must be divisible by 8, got {spatial}")

    def _conv3x3(x, filters):
        # 3x3x3 'same' convolution + ReLU; He-normal init is the usual choice for ReLU.
        return Conv3D(filters, 3, activation='relu', padding='same',
                      kernel_initializer='he_normal')(x)

    inputs = Input(input_shape)

    # Encoder path: conv blocks followed by 2x2x2 max pooling.
    conv1 = _conv3x3(_conv3x3(inputs, 64), 64)
    pool1 = MaxPooling3D(pool_size=(2, 2, 2))(conv1)

    conv2 = _conv3x3(_conv3x3(pool1, 128), 128)
    pool2 = MaxPooling3D(pool_size=(2, 2, 2))(conv2)

    conv3 = _conv3x3(_conv3x3(_conv3x3(pool2, 256), 256), 256)
    pool3 = MaxPooling3D(pool_size=(2, 2, 2))(conv3)

    # Bottleneck, regularized with dropout.
    conv4 = _conv3x3(_conv3x3(_conv3x3(pool3, 256), 256), 512)
    drop1 = Dropout(0.5)(conv4)

    # Decoder path: upsample, convolve, then concatenate with the encoder
    # feature map of the same resolution along the channel axis (-1).
    up1 = _conv3x3(UpSampling3D(size=(2, 2, 2))(drop1), 512)
    merge1 = concatenate([conv3, up1], axis=-1)
    conv5 = _conv3x3(_conv3x3(merge1, 256), 256)

    up2 = _conv3x3(UpSampling3D(size=(2, 2, 2))(conv5), 256)
    merge2 = concatenate([conv2, up2], axis=-1)
    conv6 = _conv3x3(_conv3x3(merge2, 128), 128)

    up3 = _conv3x3(UpSampling3D(size=(2, 2, 2))(conv6), 128)
    merge3 = concatenate([conv1, up3], axis=-1)
    conv7 = _conv3x3(_conv3x3(merge3, 64), 64)

    # 1x1x1 convolution to per-voxel class probabilities.
    output = Conv3D(n_classes, 1, activation='softmax')(conv7)

    model = Model(inputs=inputs, outputs=output)
    model.compile(optimizer=Adam(learning_rate=learning_rate),
                  loss='categorical_crossentropy',
                  metrics=['categorical_accuracy'])
    return model

In [None]:
# Load and preprocess the augmented training data
fileList = glob.glob('augmentation/*.tif')
imgs = np.ndarray((len(fileList),128,128,128,1), dtype='float32')
labels = np.ndarray((len(fileList),128,128,128,3), dtype='bool')

i = 0
for name in fileList:
  # Read the TIFF image and its channels
  img = np.array(cv2.imreadmulti(name)[1])

  # Extract the third channel (index 2) as the input image and normalize
  imgs[i] = np.expand_dims(img[:,:,:,2],axis = 3)/255

  # Extract the first two channels (index 0 and 1) as labels for two classes
  labels[i,:,:,:,0] = (img[:,:,:,0]/255).astype('bool')
  labels[i,:,:,:,1] = (img[:,:,:,1]/255).astype('bool')

  # Create the third label channel as the inverse of the union of the first two
  labels[i,:,:,:,2] = np.logical_not(np.logical_or(img[:,:,:,0] , img[:,:,:,1]))

  i+=1
  if (i+1) % 10 == 0:
    print('Done: {0}/{1} images'.format(i+1, len(fileList)))

Done: 10/420 images
Done: 20/420 images
Done: 30/420 images
Done: 40/420 images
Done: 50/420 images
Done: 60/420 images
Done: 70/420 images
Done: 80/420 images
Done: 90/420 images
Done: 100/420 images
Done: 110/420 images
Done: 120/420 images
Done: 130/420 images
Done: 140/420 images
Done: 150/420 images
Done: 160/420 images
Done: 170/420 images
Done: 180/420 images
Done: 190/420 images
Done: 200/420 images
Done: 210/420 images
Done: 220/420 images
Done: 230/420 images
Done: 240/420 images
Done: 250/420 images
Done: 260/420 images
Done: 270/420 images
Done: 280/420 images
Done: 290/420 images
Done: 300/420 images
Done: 310/420 images
Done: 320/420 images
Done: 330/420 images
Done: 340/420 images
Done: 350/420 images
Done: 360/420 images
Done: 370/420 images
Done: 380/420 images
Done: 390/420 images
Done: 400/420 images
Done: 410/420 images
Done: 420/420 images


In [None]:
# Training configuration.
SEED = 42
EPOCHS = 25
VALIDATION_SPLIT = 0.1

# Seed the Python, NumPy and TensorFlow RNGs so weight init, dropout and shuffling repeat.
tf.keras.utils.set_random_seed(SEED)

# Create the U-Net model and display its summary
model = get_unet()
model.summary()

# Save the model only when validation accuracy improves, so 'model.keras' holds
# the best epoch. With save_best_only=False the file would be overwritten every
# epoch and `monitor` would be ignored.
model_checkpoint = ModelCheckpoint('model.keras', monitor='val_categorical_accuracy',
                                   mode='max', verbose=1, save_best_only=True)

# Train the model.
# Batch size 1 due to memory constraints with 3D data.
# NOTE(review): Keras takes validation_split from the LAST samples of `imgs`
# before shuffling, so augmented copies of one source volume may end up in both
# the training and validation sets. Confirm how augmentation/ files are grouped.
history = model.fit(imgs, labels, batch_size=1, epochs=EPOCHS, verbose=1, shuffle=True,
                    validation_split=VALIDATION_SPLIT, callbacks=[model_checkpoint])

Epoch 1/50
[1m378/378[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 283ms/step - categorical_accuracy: 0.9187 - loss: 0.3984
Epoch 1: saving model to model.keras
[1m378/378[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m169s[0m 357ms/step - categorical_accuracy: 0.9187 - loss: 0.3979 - val_categorical_accuracy: 0.9847 - val_loss: 0.0368
Epoch 2/50
[1m378/378[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 283ms/step - categorical_accuracy: 0.9663 - loss: 0.0817
Epoch 2: saving model to model.keras
[1m378/378[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m112s[0m 297ms/step - categorical_accuracy: 0.9663 - loss: 0.0817 - val_categorical_accuracy: 0.9836 - val_loss: 0.0398
Epoch 3/50
[1m378/378[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 283ms/step - categorical_accuracy: 0.9730 - loss: 0.0671
Epoch 3: saving model to model.keras
[1m378/378[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m112s[0m 297ms/step - categorical_accuracy: 0.9730 - loss: 0.0671 - va

<keras.src.callbacks.history.History at 0x7a2ca34e2250>

In [None]:
# Segment every test volume with the saved model and write one prediction TIFF per input.
MODEL_INPUT_SHAPE = (128, 128, 128)  # spatial shape get_unet() was built with
PREDICTION_DIR = 'prediction'

# Load the trained model
model = load_model('model.keras')
os.makedirs(PREDICTION_DIR, exist_ok=True)

# Sorted for a deterministic processing order
nameList = sorted(glob.glob('test/*.tif'))

for name in nameList:
    # imreadmulti returns (success_flag, list_of_pages); fail loudly on a bad file.
    ok, pages = cv2.imreadmulti(name)
    if not ok:
        raise IOError(f'Could not read multi-page TIFF: {name}')

    # NOTE(review): assumes test stacks are single-channel (z, y, x) volumes,
    # unlike the 3-channel training stacks. Confirm against the test data.
    volume = np.array(pages) / 255
    orig_shape = volume.shape[:3]
    needs_resize = orig_shape != MODEL_INPUT_SHAPE

    # Bring the volume to the model's input size, comparing all three spatial
    # dims rather than assuming a cube.
    if needs_resize:
        volume = resize(volume, MODEL_INPUT_SHAPE, mode='constant', anti_aliasing=True)

    # Add batch and channel axes, predict, then drop the batch axis so the saved
    # array is always (z, y, x, classes) whether or not a resize happened.
    pred = model.predict(volume[np.newaxis, ..., np.newaxis], batch_size=1, verbose=0)[0]

    # Resize the class probabilities back to the original grid; the class axis is kept.
    if needs_resize:
        pred = resize(pred, orig_shape, mode='constant', anti_aliasing=True)

    # Clip after any interpolation so the uint8 conversion below cannot wrap.
    pred = np.clip(pred, 0, 1)

    # Save under the same file name as the input, inside PREDICTION_DIR.
    out_path = os.path.join(PREDICTION_DIR, os.path.basename(name))
    tif.imwrite(out_path, (pred * 255).astype(np.uint8))

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms