# ResNeXt50

Implemented with TensorFlow (Keras API).

References:
+ [https://zhuanlan.zhihu.com/p/26276020](https://zhuanlan.zhihu.com/p/26276020)
+ [https://github.com/calmisential/ResNeXt_TensorFlow2](https://github.com/calmisential/ResNeXt_TensorFlow2)

![](https://i.imgur.com/yl3kvSM.png)


In [7]:
import numpy
import matplotlib.pyplot as plt
from tensorflow.keras import layers, models
import tensorflow as tf


## ResNeXt Unit (aka ResNeXt Bottleneck)

![](https://i.imgur.com/1e5Upbt.png)

In [8]:
# reference: https://github.com/calmisential/ResNeXt_TensorFlow2/blob/master/resnext_block.py

class ResNeXtUnit(tf.keras.Model):
    """ResNeXt bottleneck unit with a projection shortcut.

    Main branch: 1x1 conv -> BN -> ReLU -> grouped 3x3 conv -> BN -> ReLU
    -> 1x1 conv -> BN. The shortcut branch is a 1x1 conv followed by BN.
    The two branches are summed and a single ReLU is applied to the sum.

    Args:
        filters: Number of filters of the first 1x1 convolution and of the
            grouped 3x3 convolution. The unit outputs ``filters * 2``
            channels.
        strides: Stride of the grouped 3x3 convolution and of the 1x1
            shortcut convolution, so both branches are downsampled alike.
        cardinality: Passed as ``groups`` to the 3x3 ``Conv2D``.
    """

    def __init__(self, filters, strides, cardinality=32):
        super().__init__()
        # 1x1 convolution at stride 1.
        self.conv1x1_1 = layers.Conv2D(filters, 1, 1)
        self.bn1 = layers.BatchNormalization()
        self.relu1 = layers.ReLU()

        # Grouped 3x3 convolution; this is where `strides` is applied on the
        # main branch.
        self.conv3x3 = layers.Conv2D(filters, 3, strides, groups=cardinality, padding='same')
        self.bn2 = layers.BatchNormalization()
        self.relu2 = layers.ReLU()

        # 1x1 convolution that doubles the channel count.
        self.conv1x1_2 = layers.Conv2D(filters*2, 1, 1)
        self.bn3 = layers.BatchNormalization()
        # Output activation of the unit: applied to the residual sum in
        # `call`, not to the main branch on its own.
        self.relu3 = layers.ReLU()

        # Projection shortcut matching the main branch in channels and stride.
        # NOTE(review): attribute spelling ("sortcut") is kept as-is; renaming
        # it may change object-based checkpoint keys - confirm before fixing.
        self.sortcut = models.Sequential([
            layers.Conv2D(filters*2, 1, strides,
                          padding='same'),
            layers.BatchNormalization()
        ])

    def call(self, input):
        """Apply the unit to `input` and return ReLU(main_branch + shortcut)."""
        x = self.relu1(self.bn1(self.conv1x1_1(input)))
        x = self.relu2(self.bn2(self.conv3x3(x)))

        # No activation after the last BN: applying ReLU here as well as after
        # the addition would clamp the residual branch to non-negative values
        # before it is merged with the shortcut.
        x = self.bn3(self.conv1x1_2(x))

        shortcut = self.sortcut(input)

        # Plain tensor addition avoids constructing a new Add layer on every
        # call; the single ReLU follows the sum.
        return self.relu3(x + shortcut)


## Stage

![](https://i.imgur.com/oU305Wc.png)

In [9]:
def stage(inputs, filters, strides, units):
    """Chain ResNeXt units into one stage and return the final tensor.

    Args:
        inputs: Tensor fed to the first unit.
        filters: ``filters`` argument given to every ``ResNeXtUnit``.
        strides: Stride of the first unit only; later units use stride 1.
        units: Number of units in the stage (at least one unit is always
            built).

    Returns:
        The output of the last unit.
    """
    # Stride schedule: the requested stride once, then 1 for the remainder.
    stride_schedule = [strides] + [1] * (units - 1)

    tensor = inputs
    for unit_stride in stride_schedule:
        tensor = ResNeXtUnit(filters=filters, strides=unit_stride)(tensor)
    return tensor


## ResNeXt50

In [10]:
def ResNeXt50(input_shape, outputs=1000):
    """Assemble the ResNeXt50 classifier as a Keras functional model.

    Args:
        input_shape: Shape handed to ``layers.Input`` (without the batch
            dimension).
        outputs: Number of units in the final softmax ``Dense`` layer.

    Returns:
        A ``models.Model`` mapping the input tensor to the softmax output.
    """
    image = layers.Input(shape=input_shape)

    # Stem: BN on the raw input, 7x7 stride-2 convolution, BN + ReLU, then a
    # 3x3 stride-2 max-pool. Layers are created in this order and applied
    # in sequence.
    stem_layers = (
        layers.BatchNormalization(),
        layers.Conv2D(64, 7, strides=2, padding='same'),
        layers.BatchNormalization(),
        layers.Activation('relu'),
        layers.MaxPooling2D((3, 3), strides=2, padding='same'),
    )
    features = image
    for stem_layer in stem_layers:
        features = stem_layer(features)

    # (units, filters, strides) for the four stages, applied in order.
    stage_specs = (
        (3, 128, 1),
        (4, 256, 2),
        (6, 512, 2),
        (3, 1024, 2),
    )
    for n_units, n_filters, stride in stage_specs:
        features = stage(features, units=n_units, filters=n_filters, strides=stride)

    # Head: global average pooling followed by the softmax classifier.
    pooled = layers.GlobalAveragePooling2D()(features)
    probabilities = layers.Dense(outputs, activation='softmax')(pooled)

    return models.Model(image, probabilities)


## Preview

In [11]:
# Clear the global state Keras keeps between cells, then run a garbage
# collection pass before the model is built below. The value returned by
# gc.collect() is what the notebook shows as this cell's output.
import gc
tf.keras.backend.clear_session()
gc.collect()

317

In [12]:
# Build the model for 224x224 three-channel inputs with 1000 output classes
# and print its layer-by-layer summary.
m = ResNeXt50((224, 224, 3), outputs=1000)
m.summary()


Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 224, 224, 3)]     0         
_________________________________________________________________
batch_normalization (BatchNo (None, 224, 224, 3)       12        
_________________________________________________________________
conv2d (Conv2D)              (None, 112, 112, 64)      9472      
_________________________________________________________________
batch_normalization_1 (Batch (None, 112, 112, 64)      256       
_________________________________________________________________
activation (Activation)      (None, 112, 112, 64)      0         
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 56, 56, 64)        0         
_________________________________________________________________
res_ne_xt_unit (ResNeXtUnit) (None, 56, 56, 256)       65792 