<a href="https://colab.research.google.com/github/kiplimock/colab-notebooks/blob/main/semantic_segmentation_with_attention_unet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Attention Augmented U-Net for Tree Crown Segmentation

Based on the following papers:

1. Jodas, D. S., Velasco, G. D. N., de Lima, R. A., Machado, A. R., & Papa, J. P. (2023). Deep Learning Semantic Segmentation Models for Detecting the Tree Crown Foliage. In VISIGRAPP (4: VISAPP) (pp. 143-150). https://www.scitepress.org/PublishedPapers/2023/116046/116046.pdf

2. Woo, S., Park, J., Lee, J.-Y., Kweon, I. S. (2018). CBAM: Convolutional Block Attention Module. In: Ferrari, V., Hebert, M., Sminchisescu, C., Weiss, Y. (eds) Computer Vision - ECCV 2018. ECCV 2018. Lecture Notes in Computer Science, vol 11211. Springer, Cham. https://doi.org/10.1007/978-3-030-01234-2_1

### Model Architecture

<p align="center">
  <img src="https://camo.githubusercontent.com/c40a3febddbb349098cf67e237a46f09489a098907772edc30619877f2980039/68747470733a2f2f64726976652e676f6f676c652e636f6d2f75633f6578706f72743d766965772669643d31307532665a6c2d4f4761364a45435f433852493038576933484356574f727057" alt="model architecture">
</p>

### CBAM Module Architecture

<p align="center">
  <img src="https://media.springernature.com/full/springer-static/image/chp%3A10.1007%2F978-3-030-01234-2_1/MediaObjects/474212_1_En_1_Fig1_HTML.gif?as=webp" alt="cbam architecture">
</p>

### Submodules of CBAM

<p align="center">
  <img src="https://media.springernature.com/full/springer-static/image/chp%3A10.1007%2F978-3-030-01234-2_1/MediaObjects/474212_1_En_1_Fig2_HTML.gif?as=webp" alt="cbam submodules">
</p>

### Setup

In [1]:
import tensorflow as tf
from tensorflow.keras.layers import GlobalAveragePooling2D, GlobalMaxPooling2D, Reshape, Dense, Input
from tensorflow.keras.layers import Activation, Concatenate, Conv2D, Multiply

### Attention Module

In [2]:
"""
Implementation of CBAM (Convolutional Block Attention Module) in TensorFlow 2.5.
Paper: https://arxiv.org/pdf/1807.06521
Code: https://github.com/nikhilroxtomar/Attention-Mechanism-Implementation/blob/main/TensorFlow/cbam.py
"""

def channel_attention_module(x, ratio=8):
    """Apply CBAM channel attention to a rank-4 feature map.

    The input is summarised per channel with both global average pooling and
    global max pooling. Each summary goes through the same two-layer
    bottleneck MLP, the two results are added, squashed with a sigmoid and
    used to rescale ``x``.

    Args:
        x: Rank-4 Keras tensor whose last axis is the channel axis; the
            channel size must be statically known.
        ratio: Reduction factor of the bottleneck. The hidden layer has
            ``channel // ratio`` units, but never fewer than one.

    Returns:
        The input re-weighted by the learned channel attention weights.

    Raises:
        ValueError: If the channel dimension of ``x`` is not statically known.
    """
    # Unpacking four values also rejects inputs that are not rank 4.
    _, _, _, channel = x.shape
    if channel is None:
        raise ValueError(
            "channel_attention_module requires a statically known channel "
            "dimension."
        )

    ## Shared layers
    # Clamp to one unit so that inputs with fewer channels than `ratio` do not
    # request a zero-width Dense layer.
    hidden_units = max(1, channel // ratio)
    l1 = Dense(hidden_units, activation="relu", use_bias=False)
    l2 = Dense(channel, use_bias=False)

    ## Global Average Pooling
    x1 = GlobalAveragePooling2D()(x)
    x1 = l1(x1)
    x1 = l2(x1)

    ## Global Max Pooling
    x2 = GlobalMaxPooling2D()(x)
    x2 = l1(x2)
    x2 = l2(x2)

    ## Add both the features and pass through sigmoid
    feats = x1 + x2
    feats = Activation("sigmoid")(feats)
    # The per-channel weights are multiplied into the full feature map.
    feats = Multiply()([x, feats])

    return feats

def spatial_attention_module(x):
    """Apply CBAM spatial attention to a feature map.

    The channel axis is reduced with a mean and with a max, the two
    single-channel maps are concatenated, and a 7x7 convolution with a
    sigmoid produces a one-channel attention map that rescales ``x``.

    Args:
        x: Keras tensor whose last axis is the channel axis.

    Returns:
        The input multiplied by the learned spatial attention map.
    """
    # Raw TensorFlow ops are wrapped in Lambda layers so that they are applied
    # as Keras layers to the symbolic tensor; newer Keras versions refuse to
    # pass symbolic tensors straight into `tf.*` functions.
    ## Average Pooling
    x1 = tf.keras.layers.Lambda(
        lambda t: tf.reduce_mean(t, axis=-1, keepdims=True)
    )(x)

    ## Max Pooling
    x2 = tf.keras.layers.Lambda(
        lambda t: tf.reduce_max(t, axis=-1, keepdims=True)
    )(x)

    ## Concatenate both the features
    feats = Concatenate()([x1, x2])
    ## Conv layer
    feats = Conv2D(1, kernel_size=7, padding="same", activation="sigmoid")(feats)
    feats = Multiply()([x, feats])

    return feats

def cbam(x):
    """Run channel attention and then spatial attention on ``x``."""
    channel_refined = channel_attention_module(x)
    return spatial_attention_module(channel_refined)

### Depthwise Block

In [3]:
def depthwise_block(input, num_filters):
    """Depthwise-separable convolution followed by CBAM, batch norm and ReLU.

    Args:
        input: Keras tensor to transform.
        num_filters: Number of output channels of the 1x1 pointwise convolution.

    Returns:
        The activated feature map.
    """
    layers = tf.keras.layers

    # 3x3 depthwise convolution, then 1x1 pointwise projection to num_filters.
    depthwise = layers.DepthwiseConv2D(
        (3, 3), padding='same', depthwise_initializer='he_normal'
    )
    pointwise = layers.Conv2D(
        num_filters,
        (1, 1),
        strides=(1, 1),
        padding='same',
        kernel_initializer='he_normal',
    )
    features = pointwise(depthwise(input))

    # Attention is applied before normalisation and activation.
    features = cbam(features)
    features = layers.BatchNormalization()(features)
    return layers.Activation('relu')(features)

### Conv2D Block

In [10]:
def conv2d_block(input, num_filters):
    """Residual block: two stacked depthwise blocks plus a projected shortcut.

    Args:
        input: Keras tensor to transform.
        num_filters: Number of output channels of both branches.

    Returns:
        Element-wise sum of the shortcut branch and the main branch.
    """
    layers = tf.keras.layers

    # Main branch: two consecutive depthwise blocks.
    main_branch = input
    for _ in range(2):
        main_branch = depthwise_block(main_branch, num_filters)

    # Shortcut branch: 1x1 convolution to num_filters channels, then batch norm.
    projection = layers.Conv2D(
        num_filters,
        (1, 1),
        strides=(1, 1),
        padding='same',
        kernel_initializer='he_normal',
    )(input)
    projection = layers.BatchNormalization()(projection)

    return layers.add([projection, main_branch])

In [5]:
# Sanity check: build one residual block on a symbolic 128x128, 3-channel
# input and inspect the resulting output shape.
inputs = Input(shape=(128, 128, 3))
conv2d_block(inputs, 64).shape

(None, 128, 128, 64)
(None, 128, 128, 64)
(None, 128, 128, 64)


TensorShape([None, 128, 128, 64])

### Contraction Path

In [6]:
# Height, width and channel count of the images fed to the model's Input layer.
IMG_HEIGHT = 224
IMG_WIDTH = 224
IMG_CHANNELS = 3

In [8]:

# Building the model
# Keep a separate handle on the raw Input tensor: rebinding the same name to
# the scaled tensor would lose the tensor needed to assemble the final
# `tf.keras.Model(inputs=model_input, ...)`.
model_input = tf.keras.layers.Input(shape=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
# Divide pixel values by 255. The result keeps the name `input` because the
# contraction path reads it under that name.
input = tf.keras.layers.Lambda(lambda x: x / 255.0)(model_input)

In [22]:
# contraction path
# Four encoder stages: a residual attention block followed by 2x2 max pooling.
# Both the pre-pooling and the pooled tensor of every stage are collected and
# then bound to the usual c1..c4 / p1..p4 names.
encoder_outputs = []
encoder_pooled = []
stage_input = input
for stage_filters in (64, 128, 256, 512):
    stage_output = conv2d_block(stage_input, stage_filters)
    stage_input = tf.keras.layers.MaxPool2D(pool_size=(2, 2))(stage_output)
    encoder_outputs.append(stage_output)
    encoder_pooled.append(stage_input)

c1, c2, c3, c4 = encoder_outputs
p1, p2, p3, p4 = encoder_pooled

# Bottleneck block at the lowest resolution, followed by dropout.
b5 = conv2d_block(p4, 1024)
d5 = tf.keras.layers.Dropout(0.3)(b5)

<p align="center">
  <img src="https://camo.githubusercontent.com/c40a3febddbb349098cf67e237a46f09489a098907772edc30619877f2980039/68747470733a2f2f64726976652e676f6f676c652e636f6d2f75633f6578706f72743d766965772669643d31307532665a6c2d4f4761364a45435f433852493038576933484356574f727057" alt="model architecture">
</p>

In [36]:
# expansion path
# Each stage upsamples by 2 with a transposed convolution and concatenates the
# encoder output of matching resolution.
# The first stage starts from d5, the dropout output of the bottleneck;
# starting from b5 would leave the Dropout layer disconnected from the network.
u6 = tf.keras.layers.Conv2DTranspose(512, (3, 3), strides=(2, 2), padding='same')(d5)
u6 = tf.keras.layers.concatenate([u6, c4])
# u6 shape = (None, 28, 28, 1024)

u7 = depthwise_block(u6, 512)
u7 = tf.keras.layers.Conv2DTranspose(256, (3, 3), strides=(2, 2), padding='same')(u7)
u7 = tf.keras.layers.concatenate([u7, c3])
# u7 shape = (None, 56, 56, 512)

u8 = depthwise_block(u7, 256)
u8 = tf.keras.layers.Conv2DTranspose(128, (3, 3), strides=(2, 2), padding='same')(u8)
u8 = tf.keras.layers.concatenate([u8, c2])
# u8 shape = (None, 112, 112, 256)

u9 = depthwise_block(u8, 128)
u9 = tf.keras.layers.Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same')(u9)
u9 = tf.keras.layers.concatenate([u9, c1])
# u9 shape = (None, 224, 224, 128)

u10 = depthwise_block(u9, 64)
# Final 3x3 convolution down to a single output channel; no activation is
# applied here.
u10 = tf.keras.layers.Conv2D(1, (3, 3), strides=(1, 1), padding='same', kernel_initializer='he_normal')(u10)
print(u10.shape)

(None, 224, 224, 1)
