In [1]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Activation, Input, concatenate, BatchNormalization 
from tensorflow.keras.layers import Conv3D, UpSampling3D, Conv3DTranspose
from tensorflow.keras.layers import add
from tensorflow.keras.layers import LeakyReLU, Reshape, Lambda
from tensorflow.keras.initializers import RandomNormal
#import keras
import numpy as np
import psutil
import humanize
import os

def myConv(x_in, nf, strides=1, kernel_size = 3):
    """
    Conv3D -> BatchNormalization -> LeakyReLU(0.2) building block.

    Args:
        x_in: input Keras tensor (the model feeds 5-D volumes built from
            Input(shape=(x, y, z, channels)); other layouts untested).
        nf: number of output filters of the convolution.
        strides: convolution stride; a stride of 2 downsamples every spatial
            axis ('same' padding).
        kernel_size: convolution kernel size. This argument used to be
            ignored (the kernel was hard-coded to 3), so callers requesting
            1x1x1 convolutions silently received 3x3x3 ones.

    Returns:
        The LeakyReLU-activated output tensor with `nf` channels.
    """
    x_out = Conv3D(nf, kernel_size=kernel_size, padding='same',
                   kernel_initializer='he_normal', strides=strides)(x_in)
    x_out = BatchNormalization()(x_out)
    x_out = LeakyReLU(0.2)(x_out)
    return x_out

# When True, Unet3dBlock adds its input back onto its output (identity shortcut).
RESIDUAL = True


def Unet3dBlock(l, n_feat):
    """
    Two stacked myConv layers with an optional identity shortcut.

    Args:
        l: input Keras tensor.
        n_feat: filter count of both convolutions.

    Returns:
        The output of the second convolution, plus `l` when RESIDUAL is set.
        In the residual case `l` must already have the same shape as the
        convolution output (i.e. `n_feat` channels) for `add` to succeed.
    """
    shortcut = l
    out = myConv(myConv(l, n_feat), n_feat)
    if RESIDUAL:
        return add([shortcut, out])
    return out


def UnetUpsample(l, num_filters):
    """
    Upsample `l` spatially, then convolve it to `num_filters` channels.

    UpSampling3D is used with its default size of (2, 2, 2); the following
    myConv uses its default kernel size and stride.
    """
    upsampled = UpSampling3D()(l)
    return myConv(upsampled, num_filters)


BASE_FILTER = 16        # filters at the full-resolution level
FILTER_GROW = True      # double the filter count at every deeper level
DEEP_SUPERVISION = True # add upsampled auxiliary predictions from coarse decoder levels
NUM_CLASS = 1           # output channels

def unet3d(vol_size, depth = 3):
    """
    Build a residual 3-D U-Net as a Keras Model.

    Args:
        vol_size: input shape without the batch axis, e.g. (64, 64, 64, 1).
        depth: number of resolution levels. Every level below the first is
            reached with a stride-2 conv, so each spatial dimension of
            `vol_size` should be divisible by 2**(depth - 1).

    Returns:
        A Model mapping a volume to NUM_CLASS output channels. The final
        activation is sigmoid when NUM_CLASS == 1 and softmax otherwise.
    """
    def level_filters(d):
        # Filter count used at resolution level d.
        return BASE_FILTER * (2 ** d) if FILTER_GROW else BASE_FILTER

    inputs = Input(shape=vol_size)
    filters = []
    down_list = []
    deep_supervision = None
    layer = myConv(inputs, BASE_FILTER)

    # Encoder: one residual block per level, stride-2 conv between levels.
    for d in range(depth):
        num_filters = level_filters(d)
        filters.append(num_filters)
        layer = Unet3dBlock(layer, n_feat=num_filters)
        down_list.append(layer)
        if d != depth - 1:
            # Downsample directly to the next level's width so the residual
            # add in the next Unet3dBlock sees matching channels. The old
            # `num_filters*2` broke this whenever FILTER_GROW was False.
            layer = myConv(layer, level_filters(d + 1), strides=2)

    # Decoder: upsample, concatenate the matching skip connection, refine.
    for d in range(depth - 2, -1, -1):
        layer = UnetUpsample(layer, filters[d])
        layer = concatenate([layer, down_list[d]])
        layer = myConv(layer, filters[d])
        layer = myConv(layer, filters[d], kernel_size=1)

        if DEEP_SUPERVISION and 0 < d < 3:
            # Auxiliary prediction at level d, summed with any coarser one,
            # then upsampled to level d-1 so it lines up with the next step.
            pred = myConv(layer, NUM_CLASS)
            if deep_supervision is None:
                deep_supervision = pred
            else:
                deep_supervision = add([pred, deep_supervision])
            deep_supervision = UpSampling3D()(deep_supervision)

    layer = myConv(layer, NUM_CLASS, kernel_size=1)

    # deep_supervision stays None when no decoder level satisfies 0 < d < 3
    # (depth <= 2); previously add([layer, None]) crashed in that case.
    if DEEP_SUPERVISION and deep_supervision is not None:
        layer = add([layer, deep_supervision])
    layer = myConv(layer, NUM_CLASS, kernel_size=1)

    # Softmax over a single channel is identically 1.0, so the model output
    # was constant and untrainable; use sigmoid for the binary case. The
    # layer name stays 'softmax' so lookups by name keep working.
    final_activation = 'sigmoid' if NUM_CLASS == 1 else 'softmax'
    x = Activation(final_activation, name='softmax')(layer)

    model = Model(inputs=[inputs], outputs=[x])
    return model

In [2]:
# def network(input_img, n_filters=16, dropout=0.5, batchnorm=True):
#    outputs = inception_block(input_img, n_filters=n_filters, batchnorm=batchnorm, strides=1, recurrent=2)
#    model = Model(inputs=[input_img], outputs=[outputs])
#    return model

# img = np.random.rand(1,256,256,256,1).astype(np.float32)

In [3]:
# Cubic 64^3 single-channel input volumes and a 3-level U-Net.
xs, ys, zs = 64, 64, 64
depth = 3

m = unet3d(
    (xs, ys, zs, 1),
    depth,
)
m.summary()

2022-01-18 12:40:07.801958: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 64, 64, 64,  0           []                               
                                 1)]                                                              
                                                                                                  
 conv3d (Conv3D)                (None, 64, 64, 64,   448         ['input_1[0][0]']                
                                16)                                                               
                                                                                                  
 batch_normalization (BatchNorm  (None, 64, 64, 64,   64         ['conv3d[0][0]']                 
 alization)                     16)                                                           

In [4]:
def printm():
    """Print available system RAM and the resident set size of this process."""
    proc = psutil.Process(os.getpid())
    free_ram = humanize.naturalsize(psutil.virtual_memory().available)
    proc_rss = humanize.naturalsize(proc.memory_info().rss)
    # Two print arguments -> joined by a single space, as before.
    print("Gen RAM Free: " + free_ram, " | Proc size: " + proc_rss)


printm()

Gen RAM Free: 6.2 GB  | Proc size: 317.9 MB


In [5]:
# Synthetic smoke-test data: random volumes with all-ones labels, shaped to
# the model input (xs, ys, zs, 1). numpy is already imported in the first cell.
N_SAMPLES = 10
SEED = 0
rng = np.random.default_rng(SEED)  # seeded so Restart & Run All is reproducible

# float32 matches Keras' default float type and halves memory vs float64.
x = rng.random((N_SAMPLES, xs, ys, zs, 1), dtype=np.float32)
y = np.ones((N_SAMPLES, xs, ys, zs, 1), dtype=np.float32)

x = tf.convert_to_tensor(x)
y = tf.convert_to_tensor(y)

dataset = tf.data.Dataset.from_tensor_slices((x, y))
dataset = dataset.batch(1)

In [6]:
# printm() is defined in the memory-report cell above; call it again here
# instead of redefining an identical copy (a later duplicate silently shadows it).
printm()

Gen RAM Free: 6.2 GB  | Proc size: 360.1 MB


In [7]:
def surface_loss(true,pred):
    """
    Signed loss on channel 0 of the labels and predictions.

    With t = true[..., 0] and p = pred[..., 0], returns
    mean((1 - p) * (t - (1 - t))). For binary labels the second factor is
    +1 where t == 1 and -1 where t == 0, so the loss decreases as p moves
    toward t.
    """
    t = true[..., 0]
    inv_pred = 1 - pred[..., 0]
    signed_map = t - (1 - t)
    return tf.math.reduce_mean(inv_pred * signed_map)

def sdice(true, pred):
    """
    Metric on channel 0: mean((1 - p) * ((1 - t) - t)).

    With t = true[..., 0] and p = pred[..., 0], this is the same product as
    in surface_loss with the sign map flipped, i.e. its negation.
    """
    t = true[..., 0]
    inv_pred = 1 - pred[..., 0]
    flipped_map = (1 - t) - t
    return tf.math.reduce_mean(inv_pred * flipped_map)

from tensorflow.keras import optimizers

LEARNING_RATE = 1e-5

adamlr = optimizers.Adam(
    learning_rate=LEARNING_RATE,
    beta_1=0.9,
    beta_2=0.999,
    epsilon=1e-08,
    amsgrad=True)

log_dir = "logs/fit/{}_{}_{}_{}".format(xs, ys, zs, depth)
# The TensorBoard callback counts train batches from 1 and expects a pair of
# positive integers, so the previous '0,9' range never started a trace.
# (1, 10) profiles the first ten batches as originally intended.
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir, histogram_freq=1, profile_batch=(1, 10))

m.compile(
    loss=surface_loss,
    optimizer=adamlr,
    metrics=[sdice])

In [8]:
# Train for five epochs over the batches of `dataset`, logging to TensorBoard.
history = m.fit(dataset, epochs=5, callbacks=[tensorboard_callback], verbose=1)

Epoch 1/5


  layer_config = serialize_layer_fn(layer)


Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
