In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Dropout, Activation, Flatten
from tensorflow.keras.layers import Conv3D, UpSampling2D, MaxPool3D, Conv3DTranspose
from tensorflow.keras.layers import BatchNormalization, Add, Reshape
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.backend import squeeze, transpose, reshape
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
import numpy as np
import os
import random
from tqdm import tqdm

%matplotlib inline

#### 3D U-Net Model

In [2]:
# Input volume geometry: in-plane width/height, depth, and channel count.
INPUT_SIZE, INPUT_DEPTH, INPUT_CHANNEL = 64, 64, 1
# Output volume geometry, laid out the same way as the input.
OUTPUT_SIZE, OUTPUT_DEPTH, OUTPUT_CHANNEL = 64, 64, 1
# Number of output classes in the dataset.
OUTPUT_CLASSES = 4

# Filter count of the first convolution; the other layers use multiples of it.
base_filt = 32
# Rate passed to every Dropout layer in the network.
dropout_rate = 0.15

In [3]:
def conv_batch_relu(tensor, filters, name, kernel = [3,3,3], stride = [1,1,1]):
    """Apply Conv3D -> BatchNormalization -> ReLU to `tensor`.

    Args:
        tensor: Input tensor fed to the convolution.
        filters: Number of output filters of the convolution.
        name: Name given to the Conv3D layer. The BatchNormalization and
            Activation layers keep their auto-generated names.
        kernel: Convolution kernel size (default 3x3x3).
        stride: Convolution strides (default 1 in every dimension).

    Returns:
        The output tensor of the ReLU activation.
    """
    weight_init = tf.keras.initializers.TruncatedNormal(stddev=0.1)
    weight_penalty = tf.keras.regularizers.l2(0.1)
    # 'same' padding keeps the spatial size unchanged when stride is 1.
    conv_layer = Conv3D(
        filters,
        kernel_size=kernel,
        strides=stride,
        padding='same',
        kernel_initializer=weight_init,
        kernel_regularizer=weight_penalty,
        name=name,
    )
    features = conv_layer(tensor)
    normalised = BatchNormalization()(features)
    return Activation('relu')(normalised)

def upconvolve(tensor, filters, name, kernel = 2, stride = 2, activation = None):
    """Upsample `tensor` with a bias-free 3D transposed convolution.

    Args:
        tensor: Input tensor to upsample.
        filters: Number of output filters.
        name: Name given to the Conv3DTranspose layer.
        kernel: Kernel size of the transposed convolution (default 2).
        stride: Strides of the transposed convolution (default 2).
        activation: Optional activation applied by the layer. The default
            ``None`` leaves the output linear, as before.

    Returns:
        The output tensor of the transposed convolution.
    """
    conv = Conv3DTranspose(
        filters,
        kernel_size=kernel,
        strides=stride,
        padding='same',
        use_bias=False,
        # Previously `activation` was accepted but never used; forward it so
        # callers that pass one actually get it.
        activation=activation,
        # Instantiate the initializer explicitly (default arguments) instead
        # of handing Keras the class object.
        kernel_initializer=tf.keras.initializers.TruncatedNormal(),
        kernel_regularizer=tf.keras.regularizers.l2(0.1),
        name=name,
    )(tensor)
    return conv

def centre_crop_and_concat(prev_conv, up_conv):
    """Centre-crop `prev_conv` to `up_conv`'s spatial size and concatenate them.

    Axes 1-3 of `prev_conv` are cropped symmetrically to the sizes of the same
    axes of `up_conv`; axis 0 and axis 4 are kept whole. The cropped tensor is
    then concatenated with `up_conv` along axis 4, with the cropped tensor
    first. Cropping only matters when the two tensors differ in size on
    axes 1-3 (e.g. with 'valid' padding); equal sizes give zero offsets.

    Args:
        prev_conv: 5-D tensor to crop.
        up_conv: 5-D tensor that defines the target size on axes 1-3.

    Returns:
        The concatenation of the cropped `prev_conv` and `up_conv` on axis 4.
    """
    skip_shape = prev_conv.get_shape()
    up_shape = up_conv.get_shape()
    spatial_axes = (1, 2, 3)

    # Start of the crop window: half of the size difference on each cropped axis.
    begin = np.array(
        [0] + [(skip_shape[ax] - up_shape[ax]) // 2 for ax in spatial_axes] + [0],
        dtype=np.int32,
    )
    # Extent of the crop window; -1 keeps every remaining element on axis 0.
    extent = np.array(
        [-1] + [up_shape[ax] for ax in spatial_axes] + [skip_shape[4]],
        dtype=np.int32,
    )

    cropped_skip = tf.slice(prev_conv, begin, extent)
    return tf.concat((cropped_skip, up_conv), 4)


# Build the U-Net graph with the Keras functional API.
# The input shape (batch axis excluded) is (INPUT_DEPTH, INPUT_SIZE, INPUT_SIZE, INPUT_CHANNEL).
model_input = Input(shape=(INPUT_DEPTH, INPUT_SIZE, INPUT_SIZE, INPUT_CHANNEL), name='input_img')

# ---- Contracting path: two conv blocks per level, then 2x2x2 max pooling with stride 2 ----
# Level zero (base_filt -> base_filt*2 filters)
conv_0_1 = conv_batch_relu(model_input, base_filt, name='conv_0_1')
conv_0_2 = conv_batch_relu(conv_0_1, base_filt*2, name='conv_0_2')
# Level one (base_filt*2 -> base_filt*4 filters, then dropout)
max_1_1 = MaxPool3D([2,2,2], [2,2,2], name='max_1_1')(conv_0_2) 
conv_1_1 = conv_batch_relu(max_1_1, base_filt*2, name='conv_1_1')
conv_1_2 = conv_batch_relu(conv_1_1, base_filt*4, name='conv_1_2')
conv_1_2 = Dropout(rate = dropout_rate, name='conv_1_2_dropout')(conv_1_2)
# Level two (base_filt*4 -> base_filt*8 filters, then dropout)
max_2_1 = MaxPool3D([2,2,2], [2,2,2], name='max_2_1')(conv_1_2) 
conv_2_1 = conv_batch_relu(max_2_1, base_filt*4, name='conv_2_1')
conv_2_2 = conv_batch_relu(conv_2_1, base_filt*8, name='conv_2_2')
conv_2_2 = Dropout(rate = dropout_rate, name='conv_2_2_dropout')(conv_2_2)
# Level three, the deepest level (base_filt*8 -> base_filt*16 filters, then dropout)
max_3_1 = MaxPool3D([2,2,2], [2,2,2], name='max_3_1')(conv_2_2) 
conv_3_1 = conv_batch_relu(max_3_1, base_filt*8, name='conv_3_1')
conv_3_2 = conv_batch_relu(conv_3_1, base_filt*16, name='conv_3_2')
conv_3_2 = Dropout(rate = dropout_rate, name='conv_3_2_dropout')(conv_3_2)

# ---- Expanding path: transposed conv (kernel 2, stride 2), concatenation with the
# ---- same-level contracting output, then two conv blocks and dropout ----
# Level two (expanding): skip connection from conv_2_2
up_conv_3_2 = upconvolve(conv_3_2, base_filt*16, kernel = 2, stride = [2,2,2], name='up_conv_3_2')  
concat_2_1 = centre_crop_and_concat(conv_2_2, up_conv_3_2)
conv_2_3 = conv_batch_relu(concat_2_1, base_filt*8, name='conv_2_3')
conv_2_4 = conv_batch_relu(conv_2_3, base_filt*8, name='conv_2_4')
conv_2_4 = Dropout(rate = dropout_rate, name='conv_2_4_dropout')(conv_2_4)
# Level one (expanding): skip connection from conv_1_2
up_conv_2_1 = upconvolve(conv_2_4, base_filt*8, kernel = 2, stride = [2,2,2], name='up_conv_2_1')
concat_1_1 = centre_crop_and_concat(conv_1_2, up_conv_2_1)
conv_1_3 = conv_batch_relu(concat_1_1, base_filt*4, name='conv_1_3')
conv_1_4 = conv_batch_relu(conv_1_3, base_filt*4, name='conv_1_4')
conv_1_4 = Dropout(rate = dropout_rate, name='conv_1_4_dropout')(conv_1_4)
# Level zero (expanding): skip connection from conv_0_2
# NOTE(review): this layer is named 'conv_1_0', unlike the other 'up_conv_*' layers.
# Renaming it would change the layer name stored in saved models, so it is left as is.
up_conv_1_0 = upconvolve(conv_1_4, base_filt*4, kernel = 2, stride = [2,2,2], name='conv_1_0') 
concat_0_1 = centre_crop_and_concat(conv_0_2, up_conv_1_0)
conv_0_3 = conv_batch_relu(concat_0_1, base_filt*2, name='conv_0_3')
conv_0_4 = conv_batch_relu(conv_0_3, base_filt*2, name='conv_0_4')
conv_0_4 = Dropout(rate = dropout_rate, name='conv_0_4_dropout')(conv_0_4)
# Output head: 1x1x1 convolution to OUTPUT_CLASSES channels. No activation is passed,
# so the layer output is linear (per-class scores rather than probabilities).
conv_out = Conv3D(OUTPUT_CLASSES, [1,1,1], [1,1,1], padding = 'same', name='conv_out')(conv_0_4)


In [4]:
# Wrap the graph from `model_input` to `conv_out` in a Keras Model named 'unet3d'.
unet3d = Model(inputs=model_input, outputs=conv_out, name='unet3d')

In [5]:
# Print the per-layer table: layer type, output shape, parameter count and connections.
unet3d.summary()

Model: "unet3d"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_img (InputLayer)          [(None, 64, 64, 64,  0                                            
__________________________________________________________________________________________________
conv_0_1 (Conv3D)               (None, 64, 64, 64, 3 896         input_img[0][0]                  
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 64, 64, 64, 3 128         conv_0_1[0][0]                   
__________________________________________________________________________________________________
activation (Activation)         (None, 64, 64, 64, 3 0           batch_normalization[0][0]        
_____________________________________________________________________________________________

#### Loss function

In [6]:

def soft_max_loss(model_labels, conv_out, num_classes=OUTPUT_CLASSES, loss_weights=[1, 150, 100, 1.0], do_weight=True):
    """Mean softmax cross-entropy over all elements, optionally class-weighted.

    Args:
        model_labels: Integer class ids. The tensor must have a size-1 last
            axis, because the one-hot axis is appended after it and the
            size-1 axis is then squeezed away.
        conv_out: Unnormalised class scores (logits) with `num_classes`
            entries on the last axis.
        num_classes: Number of classes used for the one-hot encoding.
        loss_weights: One weight per class; must contain `num_classes` values.
        do_weight: If True, scale each element's loss by the weight of its
            true class before averaging.

    Returns:
        A scalar tensor: the mean (optionally weighted) cross-entropy.
    """
    # Use a plain cast. tf.image.convert_image_dtype rescales values between
    # the value ranges of the source and target dtypes, which would turn class
    # ids into unrelated numbers unless the labels were already int32.
    model_labels = tf.cast(model_labels, tf.int32)
    conv_out = tf.cast(conv_out, tf.float32)

    # One-hot on a new last axis, then drop the size-1 label axis before it.
    labels_one_hot = tf.squeeze(
        tf.one_hot(model_labels, num_classes, axis=-1, dtype=tf.float32), axis=-2)

    ce_loss = tf.nn.softmax_cross_entropy_with_logits(logits=conv_out, labels=labels_one_hot)
    if do_weight:
        # Force float32 so an all-integer weight list does not clash with the
        # float one-hot tensor; the reshape lets it broadcast over the class axis.
        class_weights = tf.reshape(tf.constant(loss_weights, dtype=tf.float32), [1, 1, num_classes])
        # Select, for every element, the weight of its true class.
        per_element_weight = tf.reduce_sum(class_weights * labels_one_hot, axis=-1)
        ce_loss = ce_loss * per_element_weight
    loss = tf.reduce_mean(ce_loss)

    return loss

#### Compile

In [None]:
# Adam optimiser with a learning rate of 0.001.
op = tf.keras.optimizers.Adam(learning_rate=0.001)

# Train against the custom cross-entropy loss and report accuracy as a metric.
unet3d.compile(
    optimizer=op,
    loss=soft_max_loss,
    metrics=['accuracy'],
)

#### Checkpoints

In [None]:
# Checkpoint file pattern; `{epoch:04d}` is a placeholder for the zero-padded epoch number.
checkpoint_path = "./checkpoints/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Save the full model (not just the weights) every 10 epochs, whether or not
# the monitored value improved.
# NOTE(review): `period` is deprecated in newer TensorFlow/Keras releases in
# favour of `save_freq` -- confirm against the installed version.
cp_callback = ModelCheckpoint(
    filepath=checkpoint_path,
    monitor='val_loss',
    save_best_only=False,
    save_weights_only=False,
    period=10,
    verbose=1,
)

In [None]:
# Train for 100 epochs with batches of 3, holding out 20% of the training data
# for validation and checkpointing through `cp_callback`.
# Disabled alternative to validation_split: validation_data=(X_test, X_test)
# NOTE(review): the second element there looks like it should be labels, not X_test.
hist = unet3d.fit(
    x=X_train,
    y=Y_train,
    batch_size=3,
    epochs=100,
    validation_split=0.2,
    shuffle=True,
    callbacks=[cp_callback],
)