In [None]:
!pip install patchify

In [None]:
from skimage import io
import os
import random
import glob
import tifffile as tiff
from patchify import patchify,unpatchify
from tensorflow.keras import backend as K
import numpy as np
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from skimage.util import view_as_blocks

In [None]:
import tensorflow as tf

# Check whether a GPU is available as accelerator and, if so, let TensorFlow allocate
# GPU memory on demand instead of reserving it all up front.
# The setting is applied to EVERY visible GPU: TensorFlow refuses to initialise when the
# memory-growth flag differs between devices, so configuring only the first one fails on
# multi-GPU machines.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        # Raised when the GPUs were already initialised (e.g. this cell is re-run);
        # the existing configuration stays in effect.
        print(err)
    print('Using GPU')
else:
    print('Using CPU')

In [None]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv3D, MaxPooling3D, UpSampling3D, concatenate, Conv3DTranspose, BatchNormalization, Dropout, Lambda
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Activation, MaxPool2D, Concatenate

# Main convolutional block shared by the encoder and the decoder
def conv_block(input, num_filters):
    """Apply two stacked (Conv3D 3x3x3 -> BatchNormalization -> ReLU) stages.

    Both convolutions use "same" padding, so the spatial size of ``input`` is kept
    and the channel count of the result is ``num_filters``.
    """
    x = input
    for _stage in range(2):
        x = Conv3D(num_filters, 3, padding="same")(x)
        x = BatchNormalization()(x)
        x = Activation("relu")(x)
    return x

# Encoder block: features at the current resolution plus a 2x-downsampled copy
def encoder_block(input, num_filters):
    """Run ``conv_block`` and then a 2x2x2 max pooling.

    Returns ``(features, pooled)``: the pre-pooling features (used by ``build_unet``
    as the skip connection) and the pooled tensor that feeds the next level.
    """
    features = conv_block(input, num_filters)
    pooled = MaxPooling3D((2, 2, 2))(features)
    return features, pooled

# Decoder block: upsample, merge the matching encoder features, refine
def decoder_block(input, skip_features, num_filters):
    """Upsample ``input`` with a stride-2 Conv3DTranspose, concatenate ``skip_features``
    along the channel axis, and pass the result through ``conv_block``."""
    upsampled = Conv3DTranspose(num_filters, (2, 2, 2), strides=2, padding="same")(input)
    merged = Concatenate()([upsampled, skip_features])
    return conv_block(merged, num_filters)

# U-Net building
def build_unet(input_shape):
    """Build a 4-level 3-D U-Net with a single sigmoid output channel.

    Encoder widths are 64/128/256/512 filters, the bottleneck has 1024, and the decoder
    mirrors the encoder, consuming the skip tensors deepest-first.

    Args:
        input_shape: shape of one sample, passed straight to ``Input``.

    Returns:
        A Keras ``Model`` named "U-Net".
    """
    inputs = Input(input_shape)

    skips = []
    x = inputs
    for filters in (64, 128, 256, 512):
        skip, x = encoder_block(x, filters)
        skips.append(skip)

    x = conv_block(x, 1024)

    # Deepest skip pairs with the first (widest) decoder stage.
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        x = decoder_block(x, skip, filters)

    outputs = Conv3D(1, 1, padding="same", activation='sigmoid')(x)
    return Model(inputs, outputs, name="U-Net")

In [None]:
# Load the MRI scan used for training (sagittal view) together with its ground-truth mask
original_image = io.imread('data/Scansione2/320imageStack2.tif')
print("Dimensioni originali immagini: ", original_image.shape)

original_mask = io.imread('data/Scansione2/320maskStack2.tif')
print("Dimensioni originali maschere: ", original_mask.shape)

# Image and mask are later cut with the same patch grid, so they must agree voxel-for-voxel.
assert original_image.shape == original_mask.shape, (
    f"image {original_image.shape} and mask {original_mask.shape} shapes differ")

In [None]:
# Load the axial and coronal views used as extra training data, with their ground-truth masks
aug_img1 = io.imread('data/Scansione2/axial2.tif')
aug_img2 = io.imread('data/Scansione2/coronal2.tif')

aug_mask1 = io.imread('data/Scansione2/axial2_GT.tif')
aug_mask2 = io.imread('data/Scansione2/coronal2_GT.tif')

# Each view is cut into patches together with its mask, so the shapes must match.
assert aug_img1.shape == aug_mask1.shape, (
    f"axial: image {aug_img1.shape} and mask {aug_mask1.shape} shapes differ")
assert aug_img2.shape == aug_mask2.shape, (
    f"coronal: image {aug_img2.shape} and mask {aug_mask2.shape} shapes differ")

In [None]:
# Sub-volume creator function
def get_subvolumes(image, mask, step, patch_size=64):
    """Cut a volume and its ground-truth mask into matching cubic sub-volumes.

    Both arrays are split with ``patchify`` on the same grid, so the n-th image
    patch and the n-th mask patch cover exactly the same voxels.

    Args:
        image: 3-D image volume.
        mask: 3-D mask volume; must have exactly the same shape as ``image``.
        step: stride in voxels between consecutive patches along every axis.
        patch_size: edge length of the cubic patches (default 64, the previous
            hard-coded value).

    Returns:
        ``(input_img, input_mask)``, each of shape
        ``(n_patches, patch_size, patch_size, patch_size)``.

    Raises:
        ValueError: if ``image`` and ``mask`` shapes differ, because the resulting
            patches would not line up.

    Note:
        patchify only yields windows that fit completely inside the volume. Along an
        axis where ``(size - patch_size)`` is not a multiple of ``step``, the trailing
        voxels are not part of any patch.
    """
    if image.shape != mask.shape:
        raise ValueError(
            f"image shape {image.shape} does not match mask shape {mask.shape}")

    window = (patch_size,) * 3
    img_patches = patchify(image, window, step=step)
    mask_patches = patchify(mask, window, step=step)

    # Flatten the (i, j, k) patch grid into one sample axis, keeping each patch's 3-D shape.
    input_img = img_patches.reshape(-1, *img_patches.shape[3:])
    input_mask = mask_patches.reshape(-1, *mask_patches.shape[3:])

    return input_img, input_mask

In [None]:
# Cut every training volume into sub-volumes with a 32-voxel step (patches overlap by half).
PATCH_STEP = 32
original_input_img, original_input_mask = get_subvolumes(original_image, original_mask, PATCH_STEP)
aug_input_img1, aug_input_mask1 = get_subvolumes(aug_img1, aug_mask1, PATCH_STEP)
aug_input_img2, aug_input_mask2 = get_subvolumes(aug_img2, aug_mask2, PATCH_STEP)

print(original_input_img.shape, original_input_mask.shape)
print(aug_input_img1.shape, aug_input_mask1.shape)
print(aug_input_img2.shape, aug_input_mask2.shape)

# A single concatenate per array: chaining two calls materialises a throw-away
# intermediate copy of these large arrays.
input_img = np.concatenate((original_input_img, aug_input_img1, aug_input_img2), axis=0)
input_mask = np.concatenate((original_input_mask, aug_input_mask1, aug_input_mask2), axis=0)

# Every image sub-volume needs exactly one mask sub-volume at the same index.
assert input_img.shape == input_mask.shape, (input_img.shape, input_mask.shape)

print(input_img.shape)
# NOTE: despite its name, this is the number of sub-volumes (samples), not a tensor shape;
# the model's input shape is built from patch_size/channels in the U-Net cell.
input_shape = input_img.shape[0]

In [None]:
# Normalization and reshaping, then train/validation split
# Append a channel axis: the network input is (samples, 64, 64, 64, 1).
# Casting to float32 before dividing avoids the implicit float64 result of `/ 255.`
# (Keras computes in float32 anyway), halving this array's memory footprint.
# NOTE(review): dividing by 255 assumes 8-bit intensities. TODO: confirm the TIFF dtype.
train_img = np.expand_dims(input_img.astype(np.float32) / 255., axis=-1)
train_mask = np.expand_dims(input_mask, axis=-1)

# binary_crossentropy with a sigmoid output expects targets in [0, 1]. A mask stored as
# 0/255 would train silently on the wrong scale, so fail loudly instead.
assert train_mask.min() >= 0 and train_mask.max() <= 1, (
    f"mask values must lie in [0, 1], got [{train_mask.min()}, {train_mask.max()}]")

# NOTE(review): patches overlap by 50% and all training volumes come from Scansione2, so a
# random patch-level split lets validation patches share voxels with training patches;
# the validation loss is therefore likely optimistic.
X_train, X_test, Y_train, Y_test = train_test_split(
    train_img, train_mask, test_size=0.20, random_state=42)

In [None]:
# U-Net: hyper-parameters, construction and compilation
patch_size = 64
channels = 1

LR = 0.0001
# Same class as tf.keras.optimizers.Adam; use the name imported with the other layers.
optim = Adam(learning_rate=LR)

model = build_unet((patch_size, patch_size, patch_size, channels))

# NOTE(review): when foreground is a small fraction of the voxels, plain accuracy is
# dominated by the background; a Dice/IoU metric would be more informative.
model.compile(optimizer=optim, loss='binary_crossentropy', metrics=['accuracy'])
# summary() prints the table itself and returns None; wrapping it in print() added a stray "None".
model.summary()

In [None]:
# Sanity check: data tensors vs. network input/output before training
print(model.input_shape)
print(X_train.shape)
print(model.output_shape)
print(Y_train.shape)
print("-------------------")
# Should be <= 1.0 after the /255 normalization.
print(X_train.max())

# Model shapes carry a leading batch axis (None), so compare everything after it.
assert tuple(X_train.shape[1:]) == tuple(model.input_shape[1:]), (
    "input patches do not match the model input shape")
assert tuple(Y_train.shape[1:]) == tuple(model.output_shape[1:]), (
    "mask patches do not match the model output shape")

In [None]:
# Checkpoint callback: write the model to disk whenever validation loss improves
from tensorflow.keras.callbacks import ModelCheckpoint

# The achieved val_loss is embedded in each file name, so every improvement gets its own file.
best_mcp_save = ModelCheckpoint(
    'ckp/best_model-{val_loss:.4f}.keras',
    monitor='val_loss',
    mode='min',
    save_best_only=True,
)

In [None]:
# Model training: 100 epochs of shuffled mini-batches, validating on the held-out patches
history = model.fit(
    X_train,
    Y_train,
    validation_data=(X_test, Y_test),
    batch_size=10,
    epochs=100,
    shuffle=True,
    callbacks=[best_mcp_save],
    verbose=1,
)

In [None]:
# Learning curves: training vs. validation loss, then training vs. validation accuracy
per_epoch = history.history
loss, val_loss = per_epoch['loss'], per_epoch['val_loss']
acc, val_acc = per_epoch['accuracy'], per_epoch['val_accuracy']
epochs = range(1, len(loss) + 1)

curve_specs = [
    (loss, val_loss, 'Training loss', 'Validation loss', 'Training and validation loss', 'Loss'),
    (acc, val_acc, 'Training Acc', 'Validation Acc', 'Training and validation Acc', 'Accuracy'),
]
for train_curve, val_curve, train_label, val_label, title, y_label in curve_specs:
    plt.plot(epochs, train_curve, 'y', label=train_label)
    plt.plot(epochs, val_curve, 'r', label=val_label)
    plt.title(title)
    plt.xlabel('Epochs')
    plt.ylabel(y_label)
    plt.legend()
    plt.show()

In [None]:
# Choose the network used for inference.
# By default this is the in-memory `model` as it stands after the LAST training epoch, which is
# not necessarily the lowest-val_loss checkpoint written by ModelCheckpoint. To use a saved model
# instead, set BEST_MODEL_PATH to a .keras file (e.g. one of the 'ckp/best_model-<val_loss>.keras'
# checkpoints).
# Import from tensorflow.keras, like the rest of the notebook, rather than the standalone `keras`.
from tensorflow.keras.models import load_model

BEST_MODEL_PATH = None
my_model = load_model(BEST_MODEL_PATH) if BEST_MODEL_PATH else model

In [None]:
# Prediction on a new MRI scan: cut it with the same 64^3 / step-32 grid used for training
import warnings

large_image = io.imread('data/Scansione1/320imageStack.tif')
print(large_image.shape)
patches = patchify(large_image, (64, 64, 64), step=32)

# patchify only yields windows lying entirely inside the volume. Along an axis where
# (size - 64) is not a multiple of 32, the trailing voxels belong to no patch and will stay
# background in the reconstructed volume.
uncovered = [(size - 64) % 32 for size in large_image.shape[:3]]
if any(uncovered):
    warnings.warn(f"{uncovered} trailing voxels per axis are not covered by any patch")

In [None]:
# Prediction on all patches in batched model.predict calls.
# Calling predict() once per patch in a triple loop pays its per-call overhead (and prints a
# progress bar) for every single patch; one call with a batch_size does the same work at once.
# The C-order reshape flattens the (i, j, k) grid in the same i -> j -> k order as nested loops,
# so the reconstruction cell can reshape the predictions back onto the patch grid.
patch_batch = patches.reshape(-1, *patches.shape[3:], 1).astype(np.float32)
patch_batch /= 255.
probabilities = my_model.predict(patch_batch, batch_size=10)

# Threshold at 0.5 and keep the (n_patches, 1, 64, 64, 64, 1) layout of one-sample predictions.
predicted_patches = (probabilities >= 0.5)[:, np.newaxis]

In [None]:
# Final image reconstruction: put the per-patch predictions back onto the patch grid
predicted_patches = np.array(predicted_patches)
print(predicted_patches.shape)

# patches.shape is (i, j, k, 64, 64, 64); a C-order reshape restores that grid because the
# predictions were produced in i -> j -> k order.
predicted_patches_reshaped = predicted_patches.reshape(patches.shape)
print(predicted_patches_reshaped.shape)

In [None]:
# Reconstruct the 3-D volume from the predicted patches, using overlap
def overlap_and_add(blocks, output_shape, block_shape, overlap):
    """Reassemble a 3-D volume from a grid of overlapping blocks.

    Block ``(a, b, c)`` of the grid is placed at voxel offset
    ``(a, b, c) * (block_shape - overlap)``, i.e. the layout produced by
    ``patchify(volume, block_shape, step=block_shape - overlap)``. A block whose end
    would pass the volume border is shifted back so that it ends exactly on the border.

    Overlapping regions are combined with ``+=`` in ``blocks.dtype``: for boolean blocks
    this is a logical OR, for numeric blocks the values are summed (not averaged).

    Args:
        blocks: 6-D array ``(n_a, n_b, n_c, block_a, block_b, block_c)``.
        output_shape: 3-tuple, shape of the reconstructed volume.
        block_shape: 3-tuple, edge lengths of one block.
        overlap: number of voxels shared by neighbouring blocks along every axis.

    Returns:
        Array of ``output_shape`` with dtype ``blocks.dtype``. Voxels covered by no block
        stay 0 / False.

    Raises:
        ValueError: if ``blocks`` is not 6-D, ``output_shape`` is not 3-D, or
            ``overlap`` is not smaller than the block size on every axis.
    """
    if blocks.ndim != 6:
        raise ValueError(f"blocks must be 6-D, got shape {blocks.shape}")
    if len(output_shape) != 3:
        raise ValueError(f"output_shape must have 3 entries, got {output_shape}")
    steps = tuple(block_shape[axis] - overlap for axis in range(3))
    if min(steps) <= 0:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than every block dimension {tuple(block_shape)}")

    reconstructed_volume = np.zeros(output_shape, dtype=blocks.dtype)

    for grid_idx in np.ndindex(*blocks.shape[:3]):
        dest, src = [], []
        for axis in range(3):
            start = grid_idx[axis] * steps[axis]
            # Clamp so a block never extends past the volume: shift it back to end on the border.
            lo = min(start, output_shape[axis] - block_shape[axis])
            hi = min(start + block_shape[axis], output_shape[axis])
            dest.append(slice(lo, hi))
            src.append(slice(0, hi - lo))
        reconstructed_volume[tuple(dest)] += blocks[grid_idx][tuple(src)]

    return reconstructed_volume

# Stitch the thresholded patches into a volume the size of the input scan
# (step = 64 - 32 = 32, matching the patchify call on large_image).
reconstructed_image = overlap_and_add(
    predicted_patches_reshaped,
    output_shape=large_image.shape[:3],
    block_shape=(64, 64, 64),
    overlap=32,
)
print(reconstructed_image.shape)