# SKNet

Refernence:

+ [https://liaowc.github.io/blog/SKNet-structure/](https://liaowc.github.io/blog/SKNet-structure/)

![](https://i.imgur.com/HvOPnHS.png)

+ M：是分支數，也就是有幾種 kernel size。
+ G：是各分支的卷積層做分組卷積的分組數。
+ r： z 的維度為 d=max(C/r,L)d=max(C/r,L)，r 是控制用的比例（L 是 d 的最小值）。

In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models

## Build Model

![](https://i.imgur.com/6AsDeo9.png)


In [2]:
class SKConv(tf.keras.Model):
    def __init__(self, filters, M, G, r, strides, L=32):
        """ Constructor
        Args:
            M: the number of branchs.
            G: num of convolution groups.
            r: the ratio for compute d, the length of z.
            strides: stride, default 1.
            L: the minimum dim of the vector z in paper, default 32.
        """
        super().__init__()
        self.M = M
        self.G = G
        self.r = r
        self.strides = strides
        self.L = L

        ##### conv1x1_1 #####
        self.conv1x1_1 = layers.Conv2D(filters, kernel_size=1, strides=strides)
        self.bn1 = layers.BatchNormalization()
        self.relu1 = layers.ReLU()

        ##### middle #####
        self.convs = []  # 各分支的卷積層
        for i in range(self.M):
            self.convs.append(
                layers.Conv2D(filters, kernel_size=3, padding='same', groups=self.G,
                              name='group_conv_%d' % i, dilation_rate=i+1, use_bias=False),
            )

        self.gap = layers.GlobalAveragePooling2D()
        self.fc = models.Sequential([
            layers.Dense(max(filters//self.r, self.L)),
            layers.BatchNormalization(),
            layers.Activation('relu')
        ])  # Fuse 的全連結層

        self.fcs = []
        for _ in range(self.M):
            self.fcs.append(layers.Dense(filters))

        self.softmax = layers.Activation('softmax')

        ##### conv1x1_2 #####
        self.conv1x1_2 = layers.Conv2D(filters*2, 1, 1)
        self.bn3 = layers.BatchNormalization()
        self.relu3 = layers.ReLU()

        self.shortcut = models.Sequential([
            layers.Conv2D(filters*2, 1, strides, padding='same'),
            layers.BatchNormalization()
        ])

    def call(self, input):
        x = input
        x = self.conv1x1_1(x)
        x = self.bn1(x)
        x = self.relu1(x)

        # U
        feats_U_list = []
        for i in range(len(self.convs)):
            feat = self.convs[i](x)
            feats_U_list.append(feat)
            if i == 0:
                feats_U = feat
            else:
                feats_U = tf.add(feats_U, feat)

        # 對 U 做平均池化可以得到 s
        feats_s = self.gap(feats_U)

        # s 經過全連結層，BatchNormalization及relu可以得到z
        feats_Z = self.fc(feats_s)

        # 各分支在對 z 做相乘，最後加總
        for i in range(len(self.fcs)):
            z = self.fcs[i](feats_Z)     # shape (batch, filters)
            z = tf.expand_dims(z, axis=1)  # shape (batch, 1, filters)
            z = tf.expand_dims(z, axis=1)  # shape (batch, 1, 1, filters)

            # feats_U_list[i] shape (batch, height, width, filters)
            mul = tf.multiply(z, feats_U_list[i])
            if i == 0:
                att_vec = mul
            else:
                att_vec = tf.add(att_vec, mul)

        # 加總結果對 channels 做 softmax
        att_vec_softmax = tf.nn.softmax(att_vec, axis=-1)

        x = self.conv1x1_2(att_vec_softmax)
        x = self.bn3(x)
        x = self.relu3(x)

        shortcut = self.shortcut(input)

        return tf.nn.relu(tf.add(x, shortcut))


In [3]:
def stage(input, filters, repeat, M, G, r, strides):
    x = input
    x = SKConv(filters, M, G, r, strides)(x)

    for _ in range(1, repeat):
        x = SKConv(filters, M, G, r, strides=1)(x)

    return x


In [4]:
def SKNet(input_shape, outputs=10):
    input = layers.Input(shape=input_shape)
    x = layers.BatchNormalization()(input)
    x = layers.Conv2D(64, 7, strides=2, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.MaxPooling2D((3, 3), strides=2, padding='same')(x)

    x = stage(x, filters=128, repeat=3, M=2, G=32, r=16, strides=1)
    x = stage(x, filters=256, repeat=4, M=2, G=32, r=16, strides=2)
    x = stage(x, filters=512, repeat=6, M=2, G=32, r=16, strides=2)
    x = stage(x, filters=1024, repeat=3, M=2, G=32, r=16, strides=2)

    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(outputs)(x)

    return models.Model(input, x)


In [5]:
m = SKNet((224, 224, 3), outputs=1000)
m.summary()


Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 224, 224, 3)]     0         
_________________________________________________________________
batch_normalization (BatchNo (None, 224, 224, 3)       12        
_________________________________________________________________
conv2d (Conv2D)              (None, 112, 112, 64)      9472      
_________________________________________________________________
batch_normalization_1 (Batch (None, 112, 112, 64)      256       
_________________________________________________________________
activation (Activation)      (None, 112, 112, 64)      0         
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 56, 56, 64)        0         
_________________________________________________________________
sk_conv (SKConv)             (None, 56, 56, 256)       82464 