<a href="https://colab.research.google.com/github/sohom21d/Segmentation-models/blob/master/Unet3D.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv3D, MaxPool3D, Conv3DTranspose, concatenate, BatchNormalization

In [2]:
def Conv3DNorm2X(input_tensor, filters, kernel_size=(3,3,3)):
  '''
  Apply two (Conv3D -> BatchNormalization) stages to `input_tensor`.

  Each Conv3D uses 'same' padding, a ReLU activation and he_normal
  initialisation, so the ReLU is applied *before* batch normalisation:
  (conv3d + relu) -> batchnorm -> (conv3d + relu) -> batchnorm.

  Args:
    input_tensor: Keras tensor to transform.
    filters: number of output channels of both Conv3D layers.
    kernel_size: kernel size of both Conv3D layers; defaults to (3,3,3).

  Returns:
    The output Keras tensor with `filters` channels. Because of 'same'
    padding and the default unit strides, spatial dimensions are unchanged.
  '''
  output_tensor = input_tensor
  # Two identical stages; layers are created in the same order as
  # conv, bn, conv, bn.
  for _ in range(2):
    output_tensor = Conv3D(filters, kernel_size=kernel_size, padding='same', activation='relu', kernel_initializer='he_normal')(output_tensor)
    output_tensor = BatchNormalization()(output_tensor)
  return output_tensor

In [3]:
def unet3d(input_shape=(128,128,16,1), n_classes=1, start_filters=32, *, output_activation='sigmoid'):
  '''
  Build a 3D U-Net segmentation model.

  The encoder has four Conv3DNorm2X + MaxPool3D stages followed by a
  Conv3DNorm2X bottleneck. The decoder mirrors it: each stage upsamples with
  a stride-2 Conv3DTranspose, concatenates the matching encoder output (skip
  connection) and applies Conv3DNorm2X. A final 1x1x1 Conv3D produces the
  per-voxel outputs.

  Args:
    input_shape: (H, W, D, C) shape of one input volume, batch axis excluded.
      H, W and D must each be divisible by 16 (or be None). The four
      2x2x2 max-pools halve them four times, and the decoder has to upsample
      back to exactly the skip-connection shapes.
    n_classes: number of classes, i.e. output channels.
    start_filters: filters of the first encoder stage. Stage k (1-based)
      uses k * start_filters filters. This is linear growth, not the doubling
      used in the original U-Net papers.
    output_activation: activation of the final 1x1x1 convolution (keyword
      only). The default 'sigmoid' scores each channel independently, which
      suits binary or multi-label output. Pass 'softmax' for mutually
      exclusive classes when n_classes > 1.

  Returns:
    An uncompiled tf.keras Model whose output has shape
    (batch, H, W, D, n_classes).

  Raises:
    ValueError: if input_shape does not have 4 entries, or if H, W or D is
      not divisible by 16.
  '''

  f = [start_filters, start_filters*2, start_filters*3, start_filters*4, start_filters*5]

  # Each encoder level except the bottleneck is followed by a 2x pool.
  n_pools = len(f) - 1
  factor = 2 ** n_pools

  if len(input_shape) != 4:
    raise ValueError(
        'input_shape must be (H, W, D, C), got {!r}'.format(input_shape))
  # Fail early with a clear message instead of an opaque shape mismatch
  # inside `concatenate` further down the decoder.
  bad_dims = [d for d in input_shape[:3] if d is not None and d % factor != 0]
  if bad_dims:
    raise ValueError(
        'spatial dimensions of input_shape must be divisible by {}, '
        'got {!r}'.format(factor, tuple(input_shape)))

  inputs = Input(input_shape)

  # Encoder: keep each level's output for the skip connections.
  skips = []
  x = inputs
  for filters in f[:-1]:
    x = Conv3DNorm2X(x, filters)
    skips.append(x)
    x = MaxPool3D()(x)

  # Bottleneck.
  x = Conv3DNorm2X(x, f[-1])

  # Decoder: deepest level first, pairing each upsampling with its skip.
  for filters, skip in zip(reversed(f[:-1]), reversed(skips)):
    x = Conv3DTranspose(filters, kernel_size=(2,2,2), strides=(2,2,2), padding='same')(x)
    x = concatenate([x, skip])
    x = Conv3DNorm2X(x, filters)

  outputs = Conv3D(n_classes, kernel_size=(1,1,1), padding='same', activation=output_activation, kernel_initializer='he_normal')(x)

  model = Model(inputs=[inputs], outputs=[outputs])

  return model

In [6]:
# Build the network for single-channel 128x128x16 volumes, producing one
# output channel and starting from 32 filters in the first encoder stage.
model = unet3d(
  input_shape=(128, 128, 16, 1),
  n_classes=1,
  start_filters=32,
)

In [7]:
# Print a layer-by-layer table of output shapes, parameter counts and inputs.
model.summary()

Model: "functional_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(None, 128, 128, 16 0                                            
__________________________________________________________________________________________________
conv3d_19 (Conv3D)              (None, 128, 128, 16, 896         input_2[0][0]                    
__________________________________________________________________________________________________
batch_normalization_18 (BatchNo (None, 128, 128, 16, 128         conv3d_19[0][0]                  
__________________________________________________________________________________________________
conv3d_20 (Conv3D)              (None, 128, 128, 16, 27680       batch_normalization_18[0][0]     
_______________________________________________________________________________________