In [1]:
from keras.models import Model
from keras.layers import Input, concatenate, Conv3D, Conv2DTranspose, SpatialDropout3D, ConvLSTM2D, TimeDistributed
from keras.layers.core import Activation, Permute
from keras.layers.convolutional import MaxPooling2D
from keras.layers import BatchNormalization
from keras.optimizers import Adam

In [2]:
# Input geometry: Input() below is built as (img_rows, img_cols, depth, 1).
img_rows = 256
img_cols = 256
depth = 8          # size of the third input axis; pooling/upsampling never change it
smooth = 1.        # NOTE(review): unused in this cell; presumably a smoothing term for a dice metric — confirm
nb_epoch = 30
batch_size = 2
# Kept for backward compatibility with code that reads `Epochs`; tied to nb_epoch so the two
# epoch counts can no longer drift apart.
Epochs = nb_epoch


# Single-channel input volume of shape (img_rows, img_cols, depth, 1); batch axis is implicit.
inputs = Input(shape=(img_rows, img_cols, depth, 1))

# Filters per level, doubling each time: encoder/decoder levels 1-3 use entries 0-2,
# the ConvLSTM bottleneck uses entry 3.
depth_cnn = [32 * 2 ** level for level in range(4)]

## Helpers shared by every encoder/decoder level.
## Layers are created in exactly the same order as the previous hand-unrolled code, so Keras'
## auto-generated names (batch_normalization_N, activation_N, concatenate_N, permute_N, ...)
## are unchanged.

def _conv_unit(x, filters, name):
    """Conv3D(3x3x3, 'same') -> BatchNormalization -> ReLU.

    Args:
        x: 5-D tensor laid out (batch, rows, cols, depth, channels).
        filters: number of Conv3D output channels.
        name: explicit name for the Conv3D layer (e.g. 'conv1_1').
    Returns:
        Tensor with the same rows/cols/depth as x and `filters` channels.
    """
    x = Conv3D(filters, (3, 3, 3), padding='same', name=name)(x)
    x = BatchNormalization()(x)
    return Activation('relu')(x)


def _dense_pair(x, filters, prefix):
    """Two conv units, each fed the block input concatenated (channel axis 4) with the previous unit.

    Computes  c1 = unit(x);  c2 = unit(concat[x, c1]);  out = concat[x, c2]
    with the Conv3D layers named prefix + '_1' and prefix + '_2'.

    Returns:
        (c2, out): the second unit's activation and its concatenation with x.
    """
    c1 = _conv_unit(x, filters, prefix + '_1')
    c2 = _conv_unit(concatenate([x, c1], axis=4), filters, prefix + '_2')
    return c2, concatenate([x, c2], axis=4)


def _pool_rows_cols(x, name):
    """2x2 max-pool over rows and cols of a (rows, cols, depth, C) tensor, leaving depth intact.

    The depth axis is moved first so TimeDistributed applies MaxPooling2D to each depth slice.
    Returns the pooled tensor still in depth-first layout: (depth, rows/2, cols/2, C).
    """
    x = Permute((3, 1, 2, 4))(x)
    return TimeDistributed(MaxPooling2D((2, 2)), name=name)(x)


def _upsample_rows_cols(x, filters, name, depth_first=False):
    """Per-depth-slice 2x2-stride Conv2DTranspose; returns (rows*2, cols*2, depth, filters).

    Args:
        x: (rows, cols, depth, C) tensor, or (depth, rows, cols, C) when depth_first=True.
        filters: number of Conv2DTranspose output channels.
        name: layer name. It is set on the TimeDistributed wrapper (the layer the Model
            registers), so model.get_layer(name) works for every upsampling step. Previously
            'up1'/'up3' were set on the inner layer and could not be looked up.
            NOTE(review): weight files loaded with by_name=True under the old names will not
            match 'up1'/'up3'.
        depth_first: skip the initial permute when x is already depth-first.
    """
    if not depth_first:
        x = Permute((3, 1, 2, 4))(x)
    x = TimeDistributed(
        Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same'), name=name)(x)
    return Permute((2, 3, 1, 4))(x)


##start of encoder block
## Each level: dense conv pair -> 2x2 pooling over (rows, cols) -> SpatialDropout3D(0.1).

# encoder block1
conv12, conc12 = _dense_pair(inputs, depth_cnn[0], 'conv1')
pool1 = _pool_rows_cols(conc12, 'pool1')
pool1 = Permute((2, 3, 1, 4))(pool1)  # back to (rows, cols, depth, C)
pool1 = SpatialDropout3D(0.1)(pool1)

# encoder block2
conv22, conc22 = _dense_pair(pool1, depth_cnn[1], 'conv2')
pool2 = _pool_rows_cols(conc22, 'pool2')
pool2 = Permute((2, 3, 1, 4))(pool2)
pool2 = SpatialDropout3D(0.1)(pool2)

# encoder block3: output stays depth-first, because ConvLSTM2D treats axis 1 as the sequence axis
conv32, conc32 = _dense_pair(pool2, depth_cnn[2], 'conv3')
pool3 = _pool_rows_cols(conc32, 'pool3')
pool3 = SpatialDropout3D(0.1)(pool3)

##end of encoder block

# ConvLSTM bottleneck: three ConvLSTM2D -> BatchNormalization stages over the depth sequence
x = pool3
for _ in range(3):
    x = ConvLSTM2D(filters=depth_cnn[3], kernel_size=(3, 3), padding='same',
                   return_sequences=True)(x)
    x = BatchNormalization()(x)

# start of decoder block
# NOTE(review): the skip sources are inconsistent. Level 1 concatenates conc32 (activation + block
# input), while levels 2 and 3 use conv22/conv12 (activation only). This is preserved as-is
# (the printed summary shows concatenate_9 fed by activation_3); confirm it is intended.

# decoder block1 (input x is already depth-first)
up1 = _upsample_rows_cols(x, depth_cnn[2], 'up1', depth_first=True)
up6 = concatenate([up1, conc32], axis=4)
_, conc62 = _dense_pair(up6, depth_cnn[2], 'conv4')

# decoder block2
up2 = _upsample_rows_cols(conc62, depth_cnn[1], 'up2')
up7 = concatenate([up2, conv22], axis=4)
_, conc72 = _dense_pair(up7, depth_cnn[1], 'conv5')

# decoder block3
up3 = _upsample_rows_cols(conc72, depth_cnn[0], 'up3')
up8 = concatenate([up3, conv12], axis=4)
_, conc82 = _dense_pair(up8, depth_cnn[0], 'conv6')

##end of decoder block
# 1x1x1 conv to a single sigmoid channel at full input resolution
conv10 = Conv3D(1, (1, 1, 1), activation='sigmoid', name='final')(conc82)

model = Model(inputs=[inputs], outputs=[conv10])

# NOTE(review): compilation is disabled in this cell. `dice_coef` is not defined in any visible
# cell, and newer Keras Adam versions expect `learning_rate` instead of `lr` and may reject
# `decay`. Update the call before re-enabling it.
#model.compile(optimizer=Adam(lr=1e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.000000199), loss = 'binary_crossentropy', metrics=[dice_coef])

In [3]:
# Print each layer's output shape, parameter count and inbound connections.
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 256, 256, 8  0           []                               
                                , 1)]                                                             
                                                                                                  
 conv1_1 (Conv3D)               (None, 256, 256, 8,  896         ['input_1[0][0]']                
                                 32)                                                              
                                                                                                  
 batch_normalization (BatchNorm  (None, 256, 256, 8,  128        ['conv1_1[0][0]']                
 alization)                      32)                                                          

                                                                                                  
 activation_4 (Activation)      (None, 64, 64, 8, 1  0           ['batch_normalization_4[0][0]']  
                                28)                                                               
                                                                                                  
 concatenate_4 (Concatenate)    (None, 64, 64, 8, 2  0           ['spatial_dropout3d_1[0][0]',    
                                25)                               'activation_4[0][0]']           
                                                                                                  
 conv3_2 (Conv3D)               (None, 64, 64, 8, 1  777728      ['concatenate_4[0][0]']          
                                28)                                                               
                                                                                                  
 batch_nor

                                 64)                                                              
                                                                                                  
 permute_7 (Permute)            (None, 128, 128, 8,  0           ['up2[0][0]']                    
                                 64)                                                              
                                                                                                  
 concatenate_9 (Concatenate)    (None, 128, 128, 8,  0           ['permute_7[0][0]',              
                                 128)                             'activation_3[0][0]']           
                                                                                                  
 conv5_1 (Conv3D)               (None, 128, 128, 8,  221248      ['concatenate_9[0][0]']          
                                 64)                                                              
          