In [1]:
from keras.engine.topology import Layer
from keras import backend as K
from keras.layers import Conv2D, BatchNormalization, Concatenate, MaxPooling2D, Activation
from keras.layers import Input, Flatten, Reshape
from keras.models import Model
from keras.layers.core import K as Kc

Using TensorFlow backend.


In [2]:
class MaxoutConv2D(Layer):
    """
    Convolution Layer followed by Maxout activation as described 
    in https://arxiv.org/abs/1505.03540.
    
    PARAMETERS
    ----------
    
    kernel_size: kernel_size parameter for Conv2D
    output_dim: final number of filters after Maxout
    keep_prob: keep probability for Dropout
    nb_features: number of filter maps to take the Maxout over; default=4
    padding: 'same' or 'valid'
    first_layer: True if x is the input_tensor
    input_shape: Required if first_layer=True
    
    """
    
    def __init__(self, kernel_size, output_dim, nb_features=4, padding='valid', **kwargs):
        
        self.kernel_size = kernel_size
        self.output_dim = output_dim
        self.nb_features = nb_features
        self.padding = padding
        super(MaxoutConv2D, self).__init__(**kwargs)

    def call(self, x):

        num_channels = self.output_dim * self.nb_features
        conv_out = Conv2D(num_channels, self.kernel_size, padding=self.padding)(x)
        batch_norm_out = BatchNormalization()(conv_out)
        out_shape = batch_norm_out.get_shape().as_list()
        reshape_out = Reshape((out_shape[1], out_shape[2], 
                               self.nb_features, self.output_dim))(batch_norm_out)
        maxout_out = K.max(reshape_out, axis=-2)

        return maxout_out

    def get_config(self):

        config = {"kernel_size": self.kernel_size,
                  "output_dim": self.output_dim,
                  "nb_features": self.nb_features,
                  "padding": self.padding}

        base_config = super(MaxoutConv2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        input_height= input_shape[1]
        input_width = input_shape[2]
        
        if(self.padding == 'same'):
            output_height = input_height
            output_width = input_width
        
        else:
            output_height = input_height - self.kernel_size[0] + 1
            output_width = input_width - self.kernel_size[1] + 1
        
        return (input_shape[0], output_height, output_width, self.output_dim)

In [3]:
class MultichannelCascadeCNN(object):
    """
    A Multi-channel cascaded CNN architecture introduced in 
    https://arxiv.org/pdf/1505.03540.pdf - "Brain Tumor Segmentation with Deep Neural Networks".
    
    PARAMETERS
    ----------
    patch_size: int, default=33
                Size of the patch fed into the model.
    
    mode: {'input', 'local', 'final'}, default='input'
          The type of cascaded to be done (Refer the README for clarity).
          
    num_channels: int, default=1
                  Number of channels in the input patch.
    
    num_classes: int, default=2
                 Number of possible classes for a pixel.
                 
    num_filters_local_1: int, default=64
                         Number of filters to be used in the first convolutional layer
                         of the local pathway.
                         
    num_filters_local_2: int, default=64
                         Number of filters to be used in the second convolutional layer
                         of the local pathway.
                         
    num_filters_global: int, default=160
                         Number of filters to be used in the convolutional layer of the
                         global pathway.
                         
    kernel_local_1: int, default=7
                    Kernel size to be used in the first convolutional layer of the local
                    pathway.
                    
    kernel_local_2: int, default=3
                    Kernel size to be used in the second convolutional layer of the local
                    pathway.
                
    kernel_global: int, default=13
                   Kernel size to be used in the convolutional layer of the global pathway.
    
    pool_local_1: int, default=4
                  The pooling size for the max pooling layer after the first convolutional 
                  layer of the local pathway.
    
    pool_local_2: int, default=2
                  The pooling size for the max pooling layer after the second convolutional 
                  layer of the local pathway.
    
    
    """
    
    def __init__(self, patch_size=33, mode='input', num_channels=1, num_classes=2, 
                 num_filters_local_1=64, num_filters_local_2=64, num_filters_global=160, 
                 kernel_local_1=7, kernel_local_2=3, kernel_global=13, pool_local_1=4,
                 pool_local_2=2):
        
        self.patch_size = patch_size
        self.mode = mode
        self.num_channels = num_channels
        self.num_classes = num_classes
        self.num_filters_local_1 = num_filters_local_1
        self.num_filters_local_2 = num_filters_local_2
        self.num_filters_global = num_filters_global
        self.kernel_local_1 = kernel_local_1
        self.kernel_local_2 = kernel_local_2
        self.kernel_global = kernel_global
        self.pool_local_1 = pool_local_1
        self.pool_local_2 = pool_local_2
        
        self.classification_kernel_size = self.patch_size - self.kernel_global + 1
        
        if(self.mode == 'input'):
            self.outer_patch_size = 2 * self.patch_size - 1
            
        elif(self.mode == 'local'):
            self.outer_patch_size = self.patch_size + self.classification_kernel_size + \
                self.pool_local_2 + self.kernel_local_2 - 3
        
        else:
            self.outer_patch_size = self.patch_size + self.classification_kernel_size - 1
            
            
    def _forward_pass(self, model_input, stage, append_output=None, prev_output=None):
        
        local_kernel_1 = (self.kernel_local_1, self.kernel_local_1)
        local_kernel_2 = (self.kernel_local_2, self.kernel_local_2)
        global_kernel = (self.kernel_global, self.kernel_global)

        local_output_dim_1 = self.num_filters_local_1
        local_output_dim_2 = self.num_filters_local_2
        global_output_dim = self.num_filters_global

        local_pool_1 = (self.pool_local_1, self.pool_local_1)
        local_pool_2 = (self.pool_local_2, self.pool_local_2)
        
        # InputCascadeCNN
        if(append_output == 'input'): 
            final_input = Concatenate(axis=-1, name='input_concat')([model_input, prev_output])

        else:
            final_input = model_input

        local_output = MaxoutConv2D(local_kernel_1, output_dim=local_output_dim_1, 
                                    name=stage+'_local_conv1')(final_input)

        local_output = MaxPooling2D(pool_size=local_pool_1, strides=(1, 1),
                                    name=stage+'_local_pool1')(local_output)
        
        # LocalCascadeCNN
        if(append_output == 'local'):
            local_output = Concatenate(axis=-1, name='local_concat')([local_output, prev_output])

        local_output = MaxoutConv2D(local_kernel_2, output_dim=local_output_dim_2, 
                                    name=stage+'_local_conv2')(local_output)
        local_output = MaxPooling2D(pool_size=local_pool_2, strides=(1, 1), 
                                    name=stage+'_local_pool2')(local_output)

        global_output = MaxoutConv2D(global_kernel, output_dim=global_output_dim, 
                                     name=stage+'_global')(final_input)

        output = Concatenate(axis=-1, name=stage+'_concat')([local_output, global_output])
        
        # MfCascadeCNN
        if(append_output == 'final'):
            output = Concatenate(axis=-1, name='final_concat')([output, prev_output])

        final_kernel_size = self.classification_kernel_size, self.classification_kernel_size
        output = Conv2D(self.num_classes, final_kernel_size, padding='valid', name=stage+'_conv_last')(output)
        output = Activation('softmax', name=stage+'_softmax')(output)

        return output
        
        
    def build_model(self):
        """
        Builds the MultiCascadedCNN model.
        
        Returns
        -------
        
        model: Keras Model object
               The MultiCascadedCNN model as per the mode.
               
        """
        
        # First stage
        first_cascade_input_shape = (self.outer_patch_size, 
                                     self.outer_patch_size, 
                                     self.num_channels)
    
        first_cascade_input = Input(shape=first_cascade_input_shape)
        first_cascade_output = self._forward_pass(first_cascade_input, stage='1')    
        
        # Second stage
        second_cascade_input_shape = (self.patch_size, self.patch_size, self.num_channels)
        second_cascade_input = Input(shape=second_cascade_input_shape)
        second_cascade_output = self._forward_pass(second_cascade_input, 
                                             stage='2',
                                             append_output=self.mode,
                                             prev_output=first_cascade_output)

        output = Flatten()(second_cascade_output)
        model = Model(inputs=[first_cascade_input, second_cascade_input], outputs=output)
        
        return model

In [4]:
# delete earlier graphs
K.clear_session()                   

In [5]:
# Required for Layers which have different functions for 
# training / testing (e.g. Dropout, BatchNormalization)
Kc.set_learning_phase(True)      

In [6]:
multiCascadeCNN = MultichannelCascadeCNN(mode='final')

In [7]:
model = multiCascadeCNN.build_model()

In [8]:
model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 53, 53, 1)    0                                            
__________________________________________________________________________________________________
1_local_conv1 (MaxoutConv2D)    (None, 47, 47, 64)   0           input_1[0][0]                    
__________________________________________________________________________________________________
input_2 (InputLayer)            (None, 33, 33, 1)    0                                            
__________________________________________________________________________________________________
1_local_pool1 (MaxPooling2D)    (None, 44, 44, 64)   0           1_local_conv1[0][0]              
__________________________________________________________________________________________________
2_local_co