# MultiResUNet : Rethinking the U-Net architecture for multimodal biomedical image segmentation
**Reference** <br/>
*[1] Ibtehaz, Nabil, and M. Sohel Rahman. "MultiResUNet: Rethinking the U-Net architecture for multimodal biomedical image segmentation." Neural Networks 121 (2020): 74-87. <br/>
[2] Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-Net: Convolutional networks for biomedical image segmentation." International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, Cham, 2015. <br/>
[3] Szegedy, Christian, et al. "Going deeper with convolutions." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2015.*

### Related Works : U-Net
- **Skip connections in U-Net** <br/>
<img src="https://user-images.githubusercontent.com/75057952/161746257-f868e1f6-3f78-4b19-9b7e-bbbf910360c4.png" width = "500dp"></img>
- **Extensions of 2D U-Net to 3D U-Net**
    - In particular, the two-dimensional convolution, max-pooling, and transposed convolution operations were replaced by their three-dimensional counterparts.
    - In order to limit the number of parameters, the depth of the network was reduced by one. 
    - Moreover, the number of filters was doubled before the pooling layers to avoid bottlenecks.
    - The original U-Net did not use batch normalization. It was tried in the 3D U-Net, and, surprisingly, the results showed that batch normalization can sometimes even hurt performance.
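The parameter pressure behind these 3D changes is easy to see with a quick count. The sketch below is a back-of-the-envelope calculation, assuming a single 3x3 (2D) or 3x3x3 (3D) convolution with bias and 64 input and output channels; `conv_params` is a hypothetical helper, not part of either paper's code.

```python
def conv_params(kernel_size, in_ch, out_ch, dims=2, bias=True):
    """Parameter count of one conv layer with a square (2D) or cubic (3D) kernel."""
    weights = kernel_size ** dims * in_ch * out_ch
    return weights + (out_ch if bias else 0)

p2d = conv_params(3, 64, 64, dims=2)
p3d = conv_params(3, 64, 64, dims=3)
print(p2d, p3d, round(p3d / p2d, 2))  # 36928 110656 3.0
```

Each 3D convolution costs roughly three times the weights of its 2D counterpart (and every feature map gains an extra spatial axis), which makes dropping one level of depth an attractive trade-off.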

### Motivations and current challenges
**01. Variation of Scale in Medical Images**
- Biomedical segmentation targets cell nuclei, organs, tumors, etc. in images from various modalities.
- The objects of interest are irregular in shape and vary widely in scale.
- E.g., the scale of skin lesions can vary greatly across dermoscopy images! <br/>
<img src = "https://user-images.githubusercontent.com/75057952/161746274-9286abfb-15ed-47bb-b23d-34f0931e0e37.png" width = "500dp"></img>
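- Remedy: an Inception-like module [3], which applies convolutions of different kernel sizes in parallel to capture features at multiple scales?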
    - Remark on Inception
        <img src = "https://user-images.githubusercontent.com/75057952/161746285-c5d73d54-169b-45e1-aa88-c3c418474dbd.png" width = "400dp"></img>
    - Proposed Inception-like module, called the "MultiRes Block" (last panel of the figure below)
        <img src = "https://user-images.githubusercontent.com/75057952/161746297-5d6c6ea7-1328-4da4-bd11-d2b81195544e.png"></img>
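The MultiRes block replaces the parallel 5x5 and 7x7 convolutions with a chain of 3x3 convolutions. The sketch below shows why that works, assuming stride-1 convolutions and, for the weight count, equal input and output channel counts (a simplification: the actual block splits its channels unevenly). Both function names are hypothetical.

```python
def stacked_receptive_field(n_layers, kernel_size=3):
    """Receptive field (pixels per side) of n stacked stride-1 convolutions."""
    # Each extra k x k layer widens the field by k - 1 pixels.
    return 1 + n_layers * (kernel_size - 1)

def weights_per_channel_pair(n_layers):
    """(weights of n chained 3x3 convs, weights of one conv with the same field),
    counted per (input, output) channel pair."""
    rf = stacked_receptive_field(n_layers)
    return n_layers * 3 * 3, rf * rf

for n in (1, 2, 3):
    rf = stacked_receptive_field(n)
    chained, single = weights_per_channel_pair(n)
    print(f"{n} x 3x3 -> {rf}x{rf} receptive field: {chained} vs {single} weights")
# 1 x 3x3 -> 3x3 receptive field: 9 vs 9 weights
# 2 x 3x3 -> 5x5 receptive field: 18 vs 25 weights
# 3 x 3x3 -> 7x7 receptive field: 27 vs 49 weights
```

This is the reasoning behind the `conv3x3`, `conv5x5` and `conv7x7` names in the Keras implementation at the end of this page, even though all three use 3x3 kernels.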

**02. Probable Semantic Gap between the Corresponding Levels of Encoder-Decoder**
- Novelty of U-Net
    - introduction of shortcut connections between the corresponding layers before and after the max-pooling and the deconvolution layers respectively. 
    - This lets the network propagate the spatial information lost during pooling from the encoder to the decoder.
- Flaw of the skip connections?
    - The first shortcut connection bridges the encoder before the first pooling with the decoder after the last deconvolution operation.
    - Do these features match semantically? The encoder features have passed through only a few convolutions and are low-level, whereas the decoder features come from much deeper layers and are high-level.
    - The authors therefore observe a possible semantic gap between the two sets of features being merged!
    - They propose incorporating convolutional layers along the shortcut connections.
    - "Res Path": a chain of 3x3 convolutions, each with a 1x1 residual connection; fewer blocks are used on the inner shortcuts (4, 3, 2, 1 from outermost to innermost), where the gap is expected to be smaller. <br/>
        <img src="https://user-images.githubusercontent.com/75057952/161746308-07d83efc-00ae-4a9a-8a2a-55947e04eaa5.png" width = "600dp"></img>
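To make the Res Path concrete, here is a minimal NumPy forward pass. It is a sketch under simplifying assumptions: random weights, no batch normalization, a single feature map of shape (H, W, C), and hypothetical helper names (`conv2d_same`, `res_path`). The Keras `ResPath` function at the end of this page remains the reference.

```python
import numpy as np

def conv2d_same(x, w):
    """Stride-1, zero-padded ('same') 2D convolution, computed as cross-correlation
    like deep-learning frameworks do. x: (H, W, C_in); w: (k, k, C_in, C_out), k odd."""
    k = w.shape[0]
    p = k // 2
    h, wd, _ = x.shape
    xp = np.pad(x, ((p, p), (p, p), (0, 0)))  # zero padding keeps H x W
    out = np.zeros((h, wd, w.shape[3]))
    for i in range(k):
        for j in range(k):
            # Shifted input window times the (C_in, C_out) weights of kernel tap (i, j)
            out += xp[i:i + h, j:j + wd, :] @ w[i, j]
    return out

def relu(x):
    return np.maximum(x, 0.0)

def res_path(x, filters, length, rng):
    """Simplified Res Path: `length` units of relu(relu(conv3x3(x)) + conv1x1(x))."""
    for _ in range(length):
        c_in = x.shape[-1]
        w3 = rng.standard_normal((3, 3, c_in, filters)) * 0.1  # 3x3 conv weights
        w1 = rng.standard_normal((1, 1, c_in, filters)) * 0.1  # 1x1 residual shortcut
        x = relu(relu(conv2d_same(x, w3)) + conv2d_same(x, w1))
    return x

rng = np.random.default_rng(0)
encoder_features = rng.standard_normal((16, 16, 8))  # stand-in for an encoder feature map
bridged = res_path(encoder_features, filters=8, length=4, rng=rng)
print(bridged.shape)  # (16, 16, 8)
```

The spatial size is preserved, so the output can be concatenated with the decoder features exactly like a plain skip connection; only the features themselves have been processed further.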

**03. Overall Architecture**
- MultiRes Block + Res Path <br/>
    <img src ="https://user-images.githubusercontent.com/75057952/161746322-7e635ab2-263d-4ed9-a6cf-b745254ca1b0.png" width = "600dp"></img>
- Architecture details <br/>
    <img src = "https://user-images.githubusercontent.com/75057952/161746332-bef12bff-61ff-4243-b4cc-c3387ea096ea.png" width = "600dp"></img>
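The Keras code at the end of this page builds the network with U = 32, 64, 128, 256, 512 at the five encoder levels, alpha = 1.67, and Res Paths of length 4, 3, 2, 1. The pure-Python sketch below tabulates the per-level filters, assuming a 128x128 input and mirroring the `int(W*0.167)`, `int(W*0.333)`, `int(W*0.5)` split used there; `multires_filters` and `describe_levels` are hypothetical helpers.

```python
def multires_filters(U, alpha=1.67):
    """Filters of the three chained 3x3 convs in one MultiRes block (same rounding as the Keras code)."""
    W = alpha * U
    return int(W * 0.167), int(W * 0.333), int(W * 0.5)

def describe_levels(height=128, width=128, base=32, res_path_lengths=(4, 3, 2, 1)):
    rows = []
    for level in range(5):
        U = base * 2 ** level
        split = multires_filters(U)
        size = (height // 2 ** level, width // 2 ** level)  # after `level` 2x2 poolings
        # The bottleneck (level 5) has no skip connection, hence no Res Path.
        rp = res_path_lengths[level] if level < len(res_path_lengths) else None
        rows.append((level + 1, U, split, sum(split), size, rp))
    return rows

for row in describe_levels():
    print(row)
# (1, 32, (8, 17, 26), 51, (128, 128), 4)
# (2, 64, (17, 35, 53), 105, (64, 64), 3)
# (3, 128, (35, 71, 106), 212, (32, 32), 2)
# (4, 256, (71, 142, 213), 426, (16, 16), 1)
# (5, 512, (142, 284, 427), 853, (8, 8), None)
```

The decoder blocks reuse U = 256, 128, 64, 32, so their output widths mirror encoder levels 4 down to 1.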

### Result
**Quantitative Results**: MultiResUNet consistently outperforms U-Net<br/>
    <img src = "https://user-images.githubusercontent.com/75057952/161746343-65f6365b-fffe-4d30-b7d1-923ceb031660.png" width = "600dp"></img>

**Qualitative Results**
- MultiResUNet Delineates Faint Boundaries Better<br/>
    <img src = "https://user-images.githubusercontent.com/75057952/161746352-b749ec14-8f06-46aa-9037-d06426a14b9a.png" width = "700dp"></img>
- MultiResUNet is More Immune to Perturbations (Artifacts, Noise) <br/>
   <img src = "https://user-images.githubusercontent.com/75057952/161746365-98f3162e-2637-40fd-ac21-3ca712ded725.png" width = "600dp"></img>

```python
# Layers and the Model class used by the code below
from keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, concatenate, BatchNormalization, Activation, add
from keras.models import Model
```

```python
def conv2d_bn(x, filters, num_row, num_col, padding='same', strides=(1, 1), activation='relu', name=None):
    """Conv2D (no bias, since BN follows) -> BatchNormalization -> optional activation."""
    x = Conv2D(filters, (num_row, num_col), strides=strides, padding=padding, use_bias=False)(x)
    x = BatchNormalization(axis=3, scale=False)(x)

    if activation is None:
        return x
    return Activation(activation, name=name)(x)

def trans_conv2d_bn(x, filters, num_row, num_col, padding='same', strides=(2, 2), name=None):
    
    x = Conv2DTranspose(filters, (num_row, num_col), strides=strides, padding=padding)(x)
    x = BatchNormalization(axis=3, scale=False)(x)
    return x

def MultiResBlock(U, inp, alpha=1.67):
    """MultiRes block: three chained 3x3 convs whose outputs are concatenated, plus a 1x1 residual shortcut."""
    W = alpha * U

    # 1x1 conv so the shortcut has as many channels as the concatenated output
    shortcut = conv2d_bn(inp, int(W*0.167) + int(W*0.333) + int(W*0.5), 1, 1, activation=None, padding='same')

    # Chained 3x3 convs reach 3x3, 5x5 and 7x7 receptive fields (hence the names)
    conv3x3 = conv2d_bn(inp, int(W*0.167), 3, 3, activation='relu', padding='same')
    conv5x5 = conv2d_bn(conv3x3, int(W*0.333), 3, 3, activation='relu', padding='same')
    conv7x7 = conv2d_bn(conv5x5, int(W*0.5), 3, 3, activation='relu', padding='same')

    out = concatenate([conv3x3, conv5x5, conv7x7], axis=3)
    out = BatchNormalization(axis=3)(out)
    out = add([shortcut, out])
    out = Activation('relu')(out)
    out = BatchNormalization(axis=3)(out)

    return out

def ResPath(filters, length, inp):
    """Res Path: `length` residual units, each a 3x3 conv with a 1x1 conv shortcut."""
    out = inp
    for _ in range(length):
        shortcut = conv2d_bn(out, filters, 1, 1, activation=None, padding='same')
        out = conv2d_bn(out, filters, 3, 3, activation='relu', padding='same')
        out = add([shortcut, out])
        out = Activation('relu')(out)
        out = BatchNormalization(axis=3)(out)

    return out


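def MultiResUnet(height, width, n_channels):
    # Four 2x2 poolings follow, so height and width should be divisible by 16;
    # otherwise the upsampled decoder maps will not match the Res Path outputs.
    inputs = Input((height, width, n_channels))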

    mresblock1 = MultiResBlock(32, inputs)
    pool1 = MaxPooling2D(pool_size=(2, 2))(mresblock1)
    mresblock1 = ResPath(32, 4, mresblock1)

    mresblock2 = MultiResBlock(32*2, pool1)
    pool2 = MaxPooling2D(pool_size=(2, 2))(mresblock2)
    mresblock2 = ResPath(32*2, 3, mresblock2)

    mresblock3 = MultiResBlock(32*4, pool2)
    pool3 = MaxPooling2D(pool_size=(2, 2))(mresblock3)
    mresblock3 = ResPath(32*4, 2, mresblock3)

    mresblock4 = MultiResBlock(32*8, pool3)
    pool4 = MaxPooling2D(pool_size=(2, 2))(mresblock4)
    mresblock4 = ResPath(32*8, 1, mresblock4)

    mresblock5 = MultiResBlock(32*16, pool4)

    up6 = concatenate([Conv2DTranspose(32*8, (2, 2), strides=(2, 2), padding='same')(mresblock5), mresblock4], axis=3)
    mresblock6 = MultiResBlock(32*8, up6)

    up7 = concatenate([Conv2DTranspose(32*4, (2, 2), strides=(2, 2), padding='same')(mresblock6), mresblock3], axis=3)
    mresblock7 = MultiResBlock(32*4, up7)

    up8 = concatenate([Conv2DTranspose(32*2, (2, 2), strides=(2, 2), padding='same')(mresblock7), mresblock2], axis=3)
    mresblock8 = MultiResBlock(32*2, up8)

    up9 = concatenate([Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same')(mresblock8), mresblock1], axis=3)
    mresblock9 = MultiResBlock(32, up9)

    conv10 = conv2d_bn(mresblock9, 1, 1, 1, activation='sigmoid')
    
    model = Model(inputs=[inputs], outputs=[conv10])

    return model
   


def main():

    model = MultiResUnet(128, 128, 3)
    model.summary()  # prints the layer table itself; summary() returns None
    
if __name__ == '__main__':
    main()
```