# SEResNeXt50

Implemented with TensorFlow (Keras API).

SEResNeXt = ResNeXt + SE 模組

![](https://i.imgur.com/Mzhp8Bu.png)


In [22]:
import gc

import tensorflow as tf
from tensorflow.keras import layers, models

## SEBlock

Reference:

+ [https://github.com/titu1994/keras-squeeze-excite-network](https://github.com/titu1994/keras-squeeze-excite-network)

In [23]:
# reference https://github.com/titu1994/keras-squeeze-excite-network

class SEBlock(tf.keras.Model):
    """Squeeze-and-Excitation block: rescales each channel of its input.

    The input is average-pooled over its spatial axes, pushed through a
    two-layer bottleneck (ReLU, then sigmoid), and the resulting per-channel
    gates are multiplied back onto the input.

    Args:
        ratio: Reduction factor of the bottleneck. The hidden layer gets
            ``channels // ratio`` units, but never fewer than one.
    """

    def __init__(self, ratio=16):
        super().__init__()
        self.ratio = ratio
        self.gap = layers.GlobalAveragePooling2D()

    def build(self, input_shape):
        """Create the layers whose sizes depend on the input channel count.

        Args:
            input_shape: Shape of the tensor the block will receive; its last
                entry is used as the number of channels.
        """
        filters = input_shape[-1]
        # Clamp to one unit: when the input has fewer channels than `ratio`
        # the floor division yields 0, which is not a usable layer width.
        bottleneck_units = max(1, filters // self.ratio)

        # The pooled vector is reshaped to (1, 1, C) so that it broadcasts
        # over the spatial axes in the final multiply.
        self.reshape = layers.Reshape((1, 1, filters))
        self.fc1 = layers.Dense(
            bottleneck_units, kernel_initializer='he_normal', use_bias=False, activation='relu')
        self.fc2 = layers.Dense(
            filters, kernel_initializer='he_normal', use_bias=False, activation='sigmoid')

    def call(self, input):
        """Return `input` scaled channel-wise by the learned sigmoid gates."""
        x = self.gap(input)
        x = self.reshape(x)
        x = self.fc1(x)
        x = self.fc2(x)
        return tf.multiply(x, input)


## SEResNeXt Unit (a.k.a. ResNeXt bottleneck)

In [24]:
# reference: https://github.com/calmisential/ResNeXt_TensorFlow2/blob/master/resnext_block.py

class SEResNeXtUnit(tf.keras.Model):
    """ResNeXt bottleneck unit with a Squeeze-and-Excitation block.

    The residual branch is 1x1 conv -> grouped 3x3 conv -> 1x1 conv, each
    followed by batch normalization and ReLU, then an `SEBlock`. The shortcut
    branch is always a 1x1 projection convolution plus batch normalization.
    Both branches are added and passed through a final ReLU.

    Args:
        filters: Width of the first 1x1 and the grouped 3x3 convolution. The
            unit outputs ``filters * 2`` channels.
        strides: Stride of the grouped 3x3 convolution and of the shortcut
            projection.
        cardinality: Number of groups in the 3x3 convolution.
    """

    def __init__(self, filters, strides, cardinality=32):
        super().__init__()
        out_channels = filters * 2

        # 1x1 reduction.
        self.conv1x1_1 = layers.Conv2D(filters=filters, kernel_size=1, strides=1)
        self.bn1 = layers.BatchNormalization()
        self.relu1 = layers.ReLU()

        # Grouped 3x3 convolution; this is where the unit may downsample.
        self.conv3x3 = layers.Conv2D(
            filters=filters,
            kernel_size=3,
            strides=strides,
            groups=cardinality,
            padding='same',
        )
        self.bn2 = layers.BatchNormalization()
        self.relu2 = layers.ReLU()

        # 1x1 expansion to the unit's output width.
        self.conv1x1_2 = layers.Conv2D(filters=out_channels, kernel_size=1, strides=1)
        self.bn3 = layers.BatchNormalization()
        self.relu3 = layers.ReLU()

        self.seBlock = SEBlock()

        # Projection shortcut. The attribute name is kept exactly as it was
        # so that anything referring to it by name keeps working.
        self.sortcut = models.Sequential([
            layers.Conv2D(out_channels, 1, strides, padding='same'),
            layers.BatchNormalization()
        ])

    def call(self, input):
        """Run both branches on `input` and return ReLU(residual + shortcut)."""
        conv_bn_relu_steps = (
            (self.conv1x1_1, self.bn1, self.relu1),
            (self.conv3x3, self.bn2, self.relu2),
            (self.conv1x1_2, self.bn3, self.relu3),
        )

        residual = input
        for conv, norm, activation in conv_bn_relu_steps:
            residual = activation(norm(conv(residual)))

        # NOTE(review): relu3 has already been applied at this point, so the
        # SE block and the addition below see a rectified residual. The
        # SE-ResNet module in the SENet paper gates the un-rectified BN
        # output -- confirm this ordering is intended.
        residual = self.seBlock(residual)

        projected = self.sortcut(input)
        return tf.nn.relu(tf.add(residual, projected))


## Stage

![](https://i.imgur.com/oU305Wc.png)

In [25]:
def stage(inputs, filters, strides, units):
    """Chain `units` SEResNeXt units, all built with the same `filters`.

    Args:
        inputs: Tensor fed to the first unit.
        filters: `filters` argument passed to every unit.
        strides: Stride of the first unit only; every later unit uses 1.
        units: Number of units in the stage. At least one unit is always
            created, even when `units` is smaller than 1.

    Returns:
        The output of the last unit in the chain.
    """
    # The first unit takes the caller's stride; the remaining ones use 1.
    per_unit_strides = [strides] + [1] * (units - 1)

    output = inputs
    for unit_strides in per_unit_strides:
        output = SEResNeXtUnit(filters=filters, strides=unit_strides)(output)
    return output


## SEResNeXt50

In [26]:
def SEResNeXt50(input_shape, outputs=1000):
    """Assemble the SE-ResNeXt-50 classifier as a functional Keras model.

    Layout: batch-normalized input -> 7x7/2 conv stem with max pooling ->
    four stages of 3, 4, 6 and 3 units -> global average pooling -> softmax.

    Args:
        input_shape: Shape of a single input sample, without the batch axis.
        outputs: Number of units in the final softmax layer.

    Returns:
        A `models.Model` mapping the input tensor to the softmax output.
    """
    # (units, filters, strides) for each of the four stages, in order.
    stage_specs = (
        (3, 128, 1),
        (4, 256, 2),
        (6, 512, 2),
        (3, 1024, 2),
    )

    image = layers.Input(shape=input_shape)

    # Stem: the raw input is batch-normalized before the first convolution.
    features = layers.BatchNormalization()(image)
    features = layers.Conv2D(64, 7, strides=2, padding='same')(features)
    features = layers.BatchNormalization()(features)
    features = layers.Activation('relu')(features)
    features = layers.MaxPooling2D((3, 3), strides=2, padding='same')(features)

    for n_units, width, stride in stage_specs:
        features = stage(features, units=n_units, filters=width, strides=stride)

    # Classification head.
    pooled = layers.GlobalAveragePooling2D()(features)
    probabilities = layers.Dense(outputs, activation='softmax')(pooled)

    return models.Model(image, probabilities)


## Preview

In [27]:
# Reset Keras' global state before building the preview model, then run a
# garbage-collection pass to release objects left over from earlier cells.
# (`gc` is imported in the notebook's import cell at the top.)
tf.keras.backend.clear_session()
gc.collect()

30128

In [28]:
# Build the network for 224x224x3 inputs with a 1000-way softmax output and
# print its layer-by-layer summary.
m = SEResNeXt50((224, 224, 3), outputs=1000)
m.summary()


Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 224, 224, 3)]     0         
_________________________________________________________________
batch_normalization (BatchNo (None, 224, 224, 3)       12        
_________________________________________________________________
conv2d (Conv2D)              (None, 112, 112, 64)      9472      
_________________________________________________________________
batch_normalization_1 (Batch (None, 112, 112, 64)      256       
_________________________________________________________________
activation (Activation)      (None, 112, 112, 64)      0         
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 56, 56, 64)        0         
_________________________________________________________________
se_res_ne_xt_unit (SEResNeXt (None, 56, 56, 256)       73984 