In [1]:
from tensorflow.keras.layers import Input, Conv2D, DepthwiseConv2D, \
     Dense, Concatenate, Add, ReLU, BatchNormalization, AvgPool2D, \
     MaxPool2D, GlobalAvgPool2D, Reshape, Permute, Lambda
from tensorflow.keras import Model

In [2]:
def stage(x, channels, repetitions, groups):
    """Stack one ShuffleNet stage.

    The stage opens with a single stride-2 (downsampling) unit, followed by
    ``repetitions`` stride-1 units; every unit outputs ``channels`` channels.

    Args:
        x: Input tensor for the stage.
        channels: Number of output channels of every unit in the stage.
        repetitions: Number of stride-1 units appended after the first unit.
        groups: Number of groups used by the grouped pointwise convolutions.

    Returns:
        The output tensor of the last unit in the stage.
    """
    out = shufflenet_block(x, channels=channels, strides=2, groups=groups)
    for _ in range(repetitions):
        out = shufflenet_block(out, channels=channels, strides=1, groups=groups)
    return out

In [3]:
def shufflenet_block(tensor, channels, strides, groups):
    """Build one ShuffleNet unit (Zhang et al., 2017, Fig. 2b/2c).

    Residual branch: 1x1 group conv (bottleneck, ``channels // 4``) -> BN ->
    ReLU -> channel shuffle -> 3x3 depthwise conv (``strides``) -> BN ->
    1x1 group conv -> BN. The branch is merged with the shortcut, and the
    result is passed through a final ReLU.

    * ``strides == 1``: the shortcut is the identity and the two paths are
      added, so ``tensor`` must already have ``channels`` channels.
    * ``strides == 2``: the shortcut is a 3x3 average pool with stride 2 and
      the two paths are concatenated. The residual branch therefore only
      produces ``channels - input_channels`` channels, so the result has
      ``channels`` channels.

    Args:
        tensor: Input tensor; its last axis is treated as channels.
        channels: Number of output channels of the unit.
        strides: 1 (identity unit) or 2 (downsampling unit).
        groups: Number of groups for both pointwise convolutions.

    Returns:
        Output tensor of the unit.

    Raises:
        ValueError: If ``strides`` is not 1 or 2, or if a downsampling unit is
            asked for no more channels than its input already has.
    """
    if strides not in (1, 2):
        # The shortcut pooling below is hard-wired to stride 2, so any other
        # stride would give mismatched spatial shapes at the merge.
        raise ValueError(f"strides must be 1 or 2, got {strides!r}")
    downsample = strides == 2

    residual_channels = channels
    if downsample:
        input_channels = tensor.get_shape().as_list()[-1]
        residual_channels = channels - input_channels
        if residual_channels <= 0:
            raise ValueError(
                f"channels ({channels}) must exceed the input channel count "
                f"({input_channels}) for a downsampling unit")

    x = gconv(tensor, channels=channels // 4, groups=groups)
    x = BatchNormalization()(x)
    x = ReLU()(x)

    x = channel_shuffle(x, groups)
    x = DepthwiseConv2D(kernel_size=3, strides=strides, padding='same')(x)
    # Intentionally no ReLU after the depthwise conv, as in the paper.
    x = BatchNormalization()(x)

    x = Conv2D(residual_channels, kernel_size=1, strides=(1, 1), padding='same',
               groups=groups)(x)
    x = BatchNormalization()(x)

    if downsample:
        shortcut = AvgPool2D(pool_size=3, strides=2, padding='same')(tensor)
        x = Concatenate()([shortcut, x])
    else:
        x = Add()([tensor, x])

    return ReLU()(x)

In [4]:
def gconv(tensor, channels, groups):
    """Grouped 1x1 convolution built from per-group ``Conv2D`` layers.

    The input channels are split into ``groups`` contiguous slices. Each slice
    goes through its own 1x1 ``Conv2D`` with ``channels // groups`` filters,
    and the results are concatenated back along the channel axis.

    Args:
        tensor: Input tensor; its last axis is treated as channels.
        channels: Total number of output channels.
        groups: Number of groups.

    Returns:
        Tensor with ``channels`` output channels.

    Raises:
        ValueError: If ``groups`` is not positive, or does not evenly divide
            both the input and the output channel counts (otherwise some
            channels would be silently dropped).
    """
    input_ch = tensor.get_shape().as_list()[-1]
    if groups < 1:
        raise ValueError(f"groups must be a positive integer, got {groups!r}")
    if input_ch % groups or channels % groups:
        raise ValueError(
            f"groups ({groups}) must divide both the input channels "
            f"({input_ch}) and the output channels ({channels})")
    group_ch = input_ch // groups
    output_ch = channels // groups

    group_outputs = []
    for i in range(groups):
        group_slice = tensor[:, :, :, i * group_ch:(i + 1) * group_ch]
        group_outputs.append(Conv2D(output_ch, 1)(group_slice))

    # A one-element Concatenate is pointless (and rejected by some Keras
    # versions), so a single group is returned directly.
    if len(group_outputs) == 1:
        return group_outputs[0]
    return Concatenate()(group_outputs)

In [5]:
def channel_shuffle(x, groups):
    """Channel shuffle from ShuffleNet (Zhang et al., 2017, Sec. 3.1).

    With ``channels = groups * n``, the channel axis is reshaped to
    ``(groups, n)``, transposed to ``(n, groups)`` and flattened. As a result,
    consecutive output channels come from different input groups.

    Args:
        x: Channels-last tensor (batch, height, width, channels). Height and
            width must be static (known) because they are passed to ``Reshape``.
        groups: Number of groups produced by the preceding group convolution.

    Returns:
        Tensor with the same shape as ``x`` and permuted channels.

    Raises:
        ValueError: If ``groups`` is not a positive divisor of ``channels``.
    """
    _, height, width, channels = x.get_shape().as_list()
    if groups < 1 or channels % groups:
        raise ValueError(
            f"groups ({groups}) must be a positive divisor of channels ({channels})")
    group_ch = channels // groups

    # Paper's (g, n) reshape, then transpose to (n, g) before flattening.
    x = Reshape([height, width, groups, group_ch])(x)
    x = Permute([1, 2, 4, 3])(x)
    x = Reshape([height, width, channels])(x)
    return x

In [6]:
def build_shuffleNet(input_shape=(224, 224, 3), num_classes=1000, groups=3,
                     initial_channels=240, repetitions=(3, 7, 3)):
    """Build a ShuffleNet 1x image classifier (Zhang et al., 2017, Table 1).

    Called with no arguments, this reproduces the original configuration:
    224x224x3 input, 1000 classes, g = 3, and stages of 240/480/960 channels.

    Args:
        input_shape: Shape of one input image, channels last.
        num_classes: Number of softmax outputs.
        groups: Number of groups for the pointwise group convolutions.
        initial_channels: Output channels of stage 2; each later stage doubles
            it. The paper uses 144/200/240/272/384 for g = 1/2/3/4/8.
        repetitions: Number of stride-1 units per stage, appended after each
            stage's initial stride-2 unit.

    Returns:
        An uncompiled ``tensorflow.keras`` ``Model``.
    """
    inputs = Input(list(input_shape))

    # Stage 1: 3x3 stride-2 conv + BN + ReLU, then 3x3 stride-2 max pooling.
    x = Conv2D(filters=24, kernel_size=3, strides=2, padding='same')(inputs)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = MaxPool2D(pool_size=3, strides=2, padding='same')(x)

    # Stages 2..N: the channel count doubles from one stage to the next.
    for i, reps in enumerate(repetitions):
        x = stage(x, initial_channels * 2 ** i, reps, groups)

    x = GlobalAvgPool2D()(x)
    outputs = Dense(num_classes, activation='softmax')(x)
    return Model(inputs, outputs)

In [7]:
# Build ShuffleNet with its defaults (g = 3, 224x224x3 input, 1000 classes)
# and print the layer-by-layer summary.
model = build_shuffleNet()
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 224, 224, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 112, 112, 24  672         ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 batch_normalization (BatchNorm  (None, 112, 112, 24  96         ['conv2d[0][0]']                 
 alization)                     )                                                             

In [8]:
# from keras.utils import plot_model
# plot_model(model, to_file='model.png', show_shapes=True, show_layer_names=True)