# SKNet

+ M = 2
+ kernel_size = 3
+ kernel_size = 5

Refernence:

+ [https://liaowc.github.io/blog/SKNet-structure/](https://liaowc.github.io/blog/SKNet-structure/)

![](https://i.imgur.com/HvOPnHS.png)

+ M：是分支數，也就是有幾種 kernel size。
+ G：是各分支的卷積層做分組卷積的分組數。
+ r： z 的維度為 d=max(C/r,L)d=max(C/r,L)，r 是控制用的比例（L 是 d 的最小值）。

In [8]:
import tensorflow as tf
from tensorflow.keras import layers, models

## Build Model

![](https://i.imgur.com/6AsDeo9.png)


In [9]:
class SKUnit(tf.keras.Model):
    def __init__(self, filters, strides, M=2, G=32, r=16, L=32):
        """ Constructor
        Args:
            M: the number of branchs.
            G: num of convolution groups.
            r: the ratio for compute d, the length of z.
            L: the minimum dim of the vector z in paper, default 32.
        """
        super().__init__()
        self.M = M
        self.filters = filters

        self.convs = []  # 各分支的卷積層
        for i in range(M):
            self.convs.append(models.Sequential([
                layers.Conv2D(filters, 3+2*i, strides,
                              padding='same', groups=G),
                layers.BatchNormalization(),
                layers.Activation('relu'),
            ]))

        self.gap = layers.GlobalAveragePooling2D()
        self.fc = layers.Dense(max(filters//r, L))
        self.fcs = []
        for i in range(M):
          self.fcs.append(layers.Dense(filters))

    def call(self, input):
        # 計算不同分支的 U
        for i in range(self.M):
            feat = self.convs[i](input)
            feat = tf.expand_dims(feat, axis=-1)
            feats_U = feat if i == 0 else tf.concat([feats_U, feat], axis=-1)

        # feats_U (H, W, filters, M)

        # 對 U 做全局平均池化得到 s
        feats_s = self.gap(tf.reduce_sum(feats_U, axis=-1))

        # s 經過全連結層可以得到 z
        feats_Z = self.fc(feats_s)

        for i in range(self.M):
            fcs = self.fcs[i](feats_Z)
            att_vec = fcs if i == 0 else tf.concat([att_vec, fcs], axis=-1)

        att_vec = layers.Reshape((1, 1, self.filters, self.M))(att_vec)
        att_vec = tf.nn.softmax(att_vec, axis=-1)

        # att_vec (1, 1, filters, M)

        mul = tf.multiply(feats_U, att_vec)

        return tf.reduce_sum(mul, axis=-1)


In [10]:
class SKConv(tf.keras.Model):
    def __init__(self, filters, strides, M=2, G=32, r=16, L=32):
        super().__init__()
        self.filters = filters
        self.strides = strides

        #----------------------------- conv1x1_1 -----------------------------
        self.conv1x1_1 = layers.Conv2D(filters, 1, 1)
        self.bn1 = layers.BatchNormalization()

        #------------------------------ middle -------------------------------
        self.skunit = SKUnit(filters, strides, M, G, r, L)
        self.bn2 = layers.BatchNormalization()

        #----------------------------- conv1x1_2 -----------------------------
        self.conv1x1_2 = layers.Conv2D(filters*2, 1, 1)
        self.bn3 = layers.BatchNormalization()

    def build(self, input_shape):
        if input_shape[-1] != self.filters*2:
            self.shortcut = models.Sequential([
                layers.Conv2D(self.filters*2, 1, self.strides),
                layers.BatchNormalization()
            ])
        else:
            self.shortcut = models.Sequential()

    def call(self, input):
        x = input
        x = self.conv1x1_1(x)
        x = self.bn1(x)
        x = tf.nn.relu(x)

        x = self.skunit(x)
        x = self.bn2(x)

        x = self.conv1x1_2(x)
        x = self.bn3(x)
        x = tf.nn.relu(x)

        shortcut = self.shortcut(input)

        return tf.nn.relu(tf.add(x, shortcut))


In [11]:
def stage(input, filters, strides, repeat, M=2, G=32, r=16, L=32):
    x = SKConv(filters, strides, M, G, r, L)(input)

    for _ in range(1, repeat):
        x = SKConv(filters, 1, M, G, r, L)(x)

    return x


In [12]:
def SKNet(input_shape, outputs=10):
    input = layers.Input(shape=input_shape)
    x = layers.BatchNormalization()(input)
    x = layers.Conv2D(64, 7, strides=2, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.MaxPooling2D((3, 3), strides=2, padding='same')(x)

    x = stage(x, filters=128, strides=1, repeat=3)
    x = stage(x, filters=256, strides=2, repeat=4)
    x = stage(x, filters=512, strides=2, repeat=6)
    x = stage(x, filters=1024, strides=2, repeat=3)

    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(outputs, activation = 'softmax')(x)

    return models.Model(input, x)


In [13]:
import gc
tf.keras.backend.clear_session()
gc.collect()


6890

In [14]:
m = SKNet((224, 224, 3), outputs=1000)
m.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 224, 224, 3)]     0         
_________________________________________________________________
batch_normalization (BatchNo (None, 224, 224, 3)       12        
_________________________________________________________________
conv2d (Conv2D)              (None, 112, 112, 64)      9472      
_________________________________________________________________
batch_normalization_1 (Batch (None, 112, 112, 64)      256       
_________________________________________________________________
activation (Activation)      (None, 112, 112, 64)      0         
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 56, 56, 64)        0         
_________________________________________________________________
sk_conv (SKConv)             (None, 56, 56, 256)       92320 