In [1]:
from keras.models import Model, Sequential
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose, Conv3D, MaxPooling3D, Conv3DTranspose

Using TensorFlow backend.


In [0]:
def unet_3d(input_shape, base_filters=32, depth=4):
  ''' Returns 3D UNet model for binary segmentation

  The encoder has `depth` levels. Each level applies two 3x3x3 ReLU
  convolutions followed by 2x2x2 max pooling, doubling the filter count per
  level. A bottleneck block follows. The decoder then has `depth` levels; each
  upsamples with a strided 2x2x2 transposed convolution, concatenates the
  matching encoder features (skip connection), and applies two convolutions.
  A final 1x1x1 sigmoid convolution gives one probability per voxel.

  Args:
    input_shape(tuple): Input shape of the network in (x,y,z,channels),
      e.g. (x,y,z,1) for single-channel volumes. Each spatial size must be
      divisible by 2**depth (16 by default), or be None.
    base_filters(int): Filters in the first encoder level; doubled at every
      deeper level. Defaults to 32.
    depth(int): Number of pooling / upsampling levels. Defaults to 4.

  Returns:
    model: Uncompiled 3D UNet model with a single-channel sigmoid output.

  Raises:
    ValueError: If depth is negative, base_filters is < 1, input_shape does
      not have 4 entries, or a spatial size is not divisible by 2**depth.
  '''
  if depth < 0:
    raise ValueError('depth must be >= 0, got {}'.format(depth))
  if base_filters < 1:
    raise ValueError('base_filters must be >= 1, got {}'.format(base_filters))
  if len(input_shape) != 4:
    raise ValueError(
        'input_shape must be (x, y, z, channels), got {}'.format(input_shape))
  # Every pooling step floors odd sizes. A spatial size that is not a multiple
  # of 2**depth would leave an upsampled tensor smaller than its skip
  # connection, and the concatenate below would fail with an unclear error.
  factor = 2 ** depth
  for dim in input_shape[:-1]:
    if dim is not None and dim % factor != 0:
      raise ValueError(
          'spatial sizes in input_shape must be divisible by {}, got {}'
          .format(factor, input_shape))

  def conv_block(tensor, filters):
    # Two 3x3x3 ReLU convolutions; 'same' padding preserves the spatial size.
    tensor = Conv3D(filters, (3, 3, 3), activation='relu', padding='same')(tensor)
    return Conv3D(filters, (3, 3, 3), activation='relu', padding='same')(tensor)

  inputs = Input(input_shape)

  # Encoder: keep each level's output as a skip connection for the decoder.
  skips = []
  features = inputs
  for level in range(depth):
    features = conv_block(features, base_filters * 2 ** level)
    skips.append(features)
    features = MaxPooling3D(pool_size=(2, 2, 2))(features)

  # Bottleneck at the coarsest resolution.
  features = conv_block(features, base_filters * 2 ** depth)

  # Decoder: upsample x2, then join the matching skip on the last (channel) axis.
  for level in reversed(range(depth)):
    filters = base_filters * 2 ** level
    upsampled = Conv3DTranspose(filters, (2, 2, 2), strides=(2, 2, 2),
                                padding='same')(features)
    features = concatenate([upsampled, skips[level]], axis=-1)
    features = conv_block(features, filters)

  # One sigmoid probability per voxel for binary segmentation.
  outputs = Conv3D(1, (1, 1, 1), activation='sigmoid')(features)

  return Model(inputs=[inputs], outputs=[outputs])

In [0]:
def unet_3d_multiclass(input_shape, num_classes, base_filters=32, depth=4):
  ''' Returns 3D UNet model for multi-class segmentation

  The encoder has `depth` levels. Each level applies two 3x3x3 ReLU
  convolutions followed by 2x2x2 max pooling, doubling the filter count per
  level. A bottleneck block follows. The decoder then has `depth` levels; each
  upsamples with a strided 2x2x2 transposed convolution, concatenates the
  matching encoder features (skip connection), and applies two convolutions.
  A final 1x1x1 convolution with softmax over the last axis gives a class
  distribution per voxel.

  Args:
    input_shape(tuple): Input shape of the network in (x,y,z,channels),
      e.g. (x,y,z,1) for single-channel volumes. Each spatial size must be
      divisible by 2**depth (16 by default), or be None.
    num_classes(int): Number of classes (output channels). With 1 the softmax
      output is constant 1.0; use unet_3d for binary segmentation instead.
    base_filters(int): Filters in the first encoder level; doubled at every
      deeper level. Defaults to 32.
    depth(int): Number of pooling / upsampling levels. Defaults to 4.

  Returns:
    model: Uncompiled 3D UNet model with a num_classes-channel softmax output.

  Raises:
    ValueError: If num_classes is < 1, depth is negative, base_filters is < 1,
      input_shape does not have 4 entries, or a spatial size is not divisible
      by 2**depth.
  '''
  if num_classes < 1:
    raise ValueError('num_classes must be >= 1, got {}'.format(num_classes))
  if depth < 0:
    raise ValueError('depth must be >= 0, got {}'.format(depth))
  if base_filters < 1:
    raise ValueError('base_filters must be >= 1, got {}'.format(base_filters))
  if len(input_shape) != 4:
    raise ValueError(
        'input_shape must be (x, y, z, channels), got {}'.format(input_shape))
  # Every pooling step floors odd sizes. A spatial size that is not a multiple
  # of 2**depth would leave an upsampled tensor smaller than its skip
  # connection, and the concatenate below would fail with an unclear error.
  factor = 2 ** depth
  for dim in input_shape[:-1]:
    if dim is not None and dim % factor != 0:
      raise ValueError(
          'spatial sizes in input_shape must be divisible by {}, got {}'
          .format(factor, input_shape))

  def conv_block(tensor, filters):
    # Two 3x3x3 ReLU convolutions; 'same' padding preserves the spatial size.
    tensor = Conv3D(filters, (3, 3, 3), activation='relu', padding='same')(tensor)
    return Conv3D(filters, (3, 3, 3), activation='relu', padding='same')(tensor)

  inputs = Input(input_shape)

  # Encoder: keep each level's output as a skip connection for the decoder.
  skips = []
  features = inputs
  for level in range(depth):
    features = conv_block(features, base_filters * 2 ** level)
    skips.append(features)
    features = MaxPooling3D(pool_size=(2, 2, 2))(features)

  # Bottleneck at the coarsest resolution.
  features = conv_block(features, base_filters * 2 ** depth)

  # Decoder: upsample x2, then join the matching skip on the last (channel) axis.
  for level in reversed(range(depth)):
    filters = base_filters * 2 ** level
    upsampled = Conv3DTranspose(filters, (2, 2, 2), strides=(2, 2, 2),
                                padding='same')(features)
    features = concatenate([upsampled, skips[level]], axis=-1)
    features = conv_block(features, filters)

  # Per-voxel softmax over num_classes output channels.
  outputs = Conv3D(num_classes, (1, 1, 1), activation='softmax')(features)

  return Model(inputs=[inputs], outputs=[outputs])

In [0]:
def unet_2d(input_shape, base_filters=32, depth=4):
  ''' Returns 2D UNet model for binary segmentation

  The encoder has `depth` levels. Each level applies two 3x3 ReLU
  convolutions followed by 2x2 max pooling, doubling the filter count per
  level. A bottleneck block follows. The decoder then has `depth` levels; each
  upsamples with a strided 2x2 transposed convolution, concatenates the
  matching encoder features (skip connection), and applies two convolutions.
  A final 1x1 sigmoid convolution gives one probability per pixel.

  Args:
    input_shape(tuple): Input shape of the network in (x,y,channels),
      e.g. (x,y,1) for single-channel images. Each spatial size must be
      divisible by 2**depth (16 by default), or be None.
    base_filters(int): Filters in the first encoder level; doubled at every
      deeper level. Defaults to 32.
    depth(int): Number of pooling / upsampling levels. Defaults to 4.

  Returns:
    model: Uncompiled 2D UNet model with a single-channel sigmoid output.

  Raises:
    ValueError: If depth is negative, base_filters is < 1, input_shape does
      not have 3 entries, or a spatial size is not divisible by 2**depth.
  '''
  if depth < 0:
    raise ValueError('depth must be >= 0, got {}'.format(depth))
  if base_filters < 1:
    raise ValueError('base_filters must be >= 1, got {}'.format(base_filters))
  if len(input_shape) != 3:
    raise ValueError(
        'input_shape must be (x, y, channels), got {}'.format(input_shape))
  # Every pooling step floors odd sizes. A spatial size that is not a multiple
  # of 2**depth would leave an upsampled tensor smaller than its skip
  # connection, and the concatenate below would fail with an unclear error.
  factor = 2 ** depth
  for dim in input_shape[:-1]:
    if dim is not None and dim % factor != 0:
      raise ValueError(
          'spatial sizes in input_shape must be divisible by {}, got {}'
          .format(factor, input_shape))

  def conv_block(tensor, filters):
    # Two 3x3 ReLU convolutions; 'same' padding preserves the spatial size.
    tensor = Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)
    return Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)

  inputs = Input(input_shape)

  # Encoder: keep each level's output as a skip connection for the decoder.
  skips = []
  features = inputs
  for level in range(depth):
    features = conv_block(features, base_filters * 2 ** level)
    skips.append(features)
    features = MaxPooling2D(pool_size=(2, 2))(features)

  # Bottleneck at the coarsest resolution.
  features = conv_block(features, base_filters * 2 ** depth)

  # Decoder: upsample x2, then join the matching skip on the last (channel) axis.
  for level in reversed(range(depth)):
    filters = base_filters * 2 ** level
    upsampled = Conv2DTranspose(filters, (2, 2), strides=(2, 2),
                                padding='same')(features)
    features = concatenate([upsampled, skips[level]], axis=-1)
    features = conv_block(features, filters)

  # One sigmoid probability per pixel for binary segmentation.
  outputs = Conv2D(1, (1, 1), activation='sigmoid')(features)

  return Model(inputs=[inputs], outputs=[outputs])

In [0]:
def unet_2d_multiclass(input_shape, num_classes, base_filters=32, depth=4):
  ''' Returns 2D UNet model for multi-class segmentation

  The encoder has `depth` levels. Each level applies two 3x3 ReLU
  convolutions followed by 2x2 max pooling, doubling the filter count per
  level. A bottleneck block follows. The decoder then has `depth` levels; each
  upsamples with a strided 2x2 transposed convolution, concatenates the
  matching encoder features (skip connection), and applies two convolutions.
  A final 1x1 convolution with softmax over the last axis gives a class
  distribution per pixel.

  Args:
    input_shape(tuple): Input shape of the network in (x,y,channels),
      e.g. (x,y,1) for single-channel images. Each spatial size must be
      divisible by 2**depth (16 by default), or be None.
    num_classes(int): Number of classes (output channels). With 1 the softmax
      output is constant 1.0; use unet_2d for binary segmentation instead.
    base_filters(int): Filters in the first encoder level; doubled at every
      deeper level. Defaults to 32.
    depth(int): Number of pooling / upsampling levels. Defaults to 4.

  Returns:
    model: Uncompiled 2D UNet model with a num_classes-channel softmax output.

  Raises:
    ValueError: If num_classes is < 1, depth is negative, base_filters is < 1,
      input_shape does not have 3 entries, or a spatial size is not divisible
      by 2**depth.
  '''
  if num_classes < 1:
    raise ValueError('num_classes must be >= 1, got {}'.format(num_classes))
  if depth < 0:
    raise ValueError('depth must be >= 0, got {}'.format(depth))
  if base_filters < 1:
    raise ValueError('base_filters must be >= 1, got {}'.format(base_filters))
  if len(input_shape) != 3:
    raise ValueError(
        'input_shape must be (x, y, channels), got {}'.format(input_shape))
  # Every pooling step floors odd sizes. A spatial size that is not a multiple
  # of 2**depth would leave an upsampled tensor smaller than its skip
  # connection, and the concatenate below would fail with an unclear error.
  factor = 2 ** depth
  for dim in input_shape[:-1]:
    if dim is not None and dim % factor != 0:
      raise ValueError(
          'spatial sizes in input_shape must be divisible by {}, got {}'
          .format(factor, input_shape))

  def conv_block(tensor, filters):
    # Two 3x3 ReLU convolutions; 'same' padding preserves the spatial size.
    tensor = Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)
    return Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)

  inputs = Input(input_shape)

  # Encoder: keep each level's output as a skip connection for the decoder.
  skips = []
  features = inputs
  for level in range(depth):
    features = conv_block(features, base_filters * 2 ** level)
    skips.append(features)
    features = MaxPooling2D(pool_size=(2, 2))(features)

  # Bottleneck at the coarsest resolution.
  features = conv_block(features, base_filters * 2 ** depth)

  # Decoder: upsample x2, then join the matching skip on the last (channel) axis.
  for level in reversed(range(depth)):
    filters = base_filters * 2 ** level
    upsampled = Conv2DTranspose(filters, (2, 2), strides=(2, 2),
                                padding='same')(features)
    features = concatenate([upsampled, skips[level]], axis=-1)
    features = conv_block(features, filters)

  # Per-pixel softmax over num_classes output channels.
  outputs = Conv2D(num_classes, (1, 1), activation='softmax')(features)

  return Model(inputs=[inputs], outputs=[outputs])