In [None]:
# Copyright (c) 2020 ZZH

In [None]:
import tensorflow as tf
from matplotlib import pyplot as plt
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, AveragePooling2D, MaxPool2D
from tensorflow.keras.layers import GlobalAveragePooling2D, Dropout, Flatten, Dense, DepthwiseConv2D
from tensorflow.keras.models import Sequential
from tensorflow.keras import Model
import os
import numpy as np
import math

In [None]:
# List the GPUs TensorFlow can see and enable memory growth on each of them.
gpus = tf.config.experimental.list_physical_devices('GPU')
print(gpus)
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, enable=True)

In [None]:
# Training hyper-parameters and output locations shared by the cells below.
epochs = 100          # total number of training epochs (also the decay horizon in scheduler)
lr = 0.1              # base learning rate passed to the optimizer and the schedule
batch_size = 128      # mini-batch size for the augmented training iterator
REGULARIZER = 1e-4    # L2 factor applied to the conv / dense kernels
checkpoint_save_path = './Model/MobileNetV3/'
log_dir = os.path.join("Model", "MobileNetV3_logs")

In [None]:
# Data loading, standardization and augmentation
cifar10 = tf.keras.datasets.cifar10
(x_train,y_train),(x_test,y_test) = cifar10.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)

# Per-channel statistics used for standardization
# (annotated as np.mean() / np.std() in the original notebook).
mean = [125.307, 122.95, 113.865]
std = [62.9932, 62.0887, 66.7048]
# Standardize in place; the 3-element vectors broadcast over the trailing
# channel axis, so no per-channel Python loop is needed.
for images in (x_train, x_test):
    images -= np.asarray(mean, dtype=images.dtype)
    images /= np.asarray(std, dtype=images.dtype)

# Random geometric augmentation applied to training batches only.
# ImageDataGenerator.fit() is intentionally not called: it is only required
# when featurewise_center, featurewise_std_normalization or zca_whitening is
# enabled, and none of them is used here.
DataGenTrain = tf.keras.preprocessing.image.ImageDataGenerator(
               rotation_range = 15,
               width_shift_range = 0.1,
               height_shift_range = 0.1,
               horizontal_flip = True,
               vertical_flip = False,
               shear_range=0.1,
               zoom_range = 0.1)

In [None]:
def scheduler(epoch):  #HTD(-6,3) with WarmingUp
    """Return the learning rate for a given (0-based) epoch index.

    The first five epochs ramp up linearly from ``lr / 5`` to ``lr``. After
    that the rate follows a hyperbolic-tangent decay whose tanh argument goes
    from ``start`` (-6) at epoch 0 towards ``end`` (3) at epoch ``epochs``.

    The warm-up is expressed relative to the module-level ``lr`` so that
    changing ``lr`` keeps the warm-up and the decay consistent; for
    ``lr = 0.1`` it yields 0.02, 0.04, ..., 0.10 as before.

    Args:
        epoch: epoch index as passed by the LearningRateScheduler callback.

    Returns:
        The learning rate to use for that epoch, as a float.
    """
    start = -6.0
    end = 3.0
    warmup_epochs = 5
    if epoch < warmup_epochs:
        # Linear warm-up: lr/5, 2*lr/5, ..., lr.
        return lr * (epoch + 1) / warmup_epochs
    return lr / 2.0 * (1 - math.tanh((end - start) * epoch / epochs + start))

In [None]:
def relu6(x):
    """Apply ReLU capped at 6, i.e. min(max(x, 0), 6), element-wise.

    Uses the TensorFlow op directly so that no new Keras layer object is
    instantiated each time the function is called.
    """
    return tf.nn.relu6(x)

In [None]:
def hard_sigmoid(x):
    """Piecewise-linear sigmoid approximation: relu6(x + 3) / 6."""
    shifted = x + 3.0
    return relu6(shifted) / 6.0

In [None]:
def hard_swish(x):
    """Hard-swish activation: the input scaled by its hard-sigmoid gate."""
    gate = hard_sigmoid(x)
    return x * gate

In [None]:
class BNAct(Model):
    """Batch normalization (momentum 0.9) followed by a named activation.

    Args:
        activation: either 'relu6' or 'hard_swish'.

    Raises:
        ValueError: if `activation` is not one of the supported names.
    """
    def __init__(self,activation):
        super(BNAct,self).__init__()
        # Fail early with a clear message instead of at the first forward pass.
        if activation not in ('relu6', 'hard_swish'):
            raise ValueError(
                "Unsupported activation {!r}; expected 'relu6' or 'hard_swish'".format(activation))
        self.activation = activation
        self.bn = BatchNormalization(momentum=0.9)
    def call(self,inputs):
        """Normalize `inputs`, then apply the configured activation."""
        x = self.bn(inputs)
        if self.activation == 'relu6':
            outputs = relu6(x)
        elif self.activation == 'hard_swish':
            outputs = hard_swish(x)
        else:
            # Only reachable if self.activation was reassigned after construction.
            raise ValueError(
                "Unsupported activation {!r}; expected 'relu6' or 'hard_swish'".format(self.activation))
        return outputs

In [None]:
class SEBlock(Model):
    """Channel-gating block: pool, two bias-free dense layers, then rescale.

    Args:
        channels: number of channels of the input feature map; the bottleneck
            dense layer uses ``channels // 4`` units.
    """
    def __init__(self,channels):
        super().__init__()
        reduction = 4
        dense_kwargs = dict(activation=None, kernel_initializer="he_normal", use_bias=False)
        self.channels = channels
        self.p1 = GlobalAveragePooling2D()
        self.d1 = Dense(channels // reduction,
                        kernel_regularizer=tf.keras.regularizers.l2(REGULARIZER),
                        **dense_kwargs)
        self.d2 = Dense(channels,
                        kernel_regularizer=tf.keras.regularizers.l2(REGULARIZER),
                        **dense_kwargs)
        self.m1 = tf.keras.layers.Multiply()
    def call(self,inputs):
        """Return `inputs` multiplied by a per-channel gate computed from it."""
        squeezed = self.p1(inputs)
        bottleneck = relu6(self.d1(squeezed))
        gate = hard_sigmoid(self.d2(bottleneck))
        # Reshape to (batch, 1, 1, channels) so the gate broadcasts over H and W.
        gate = tf.reshape(gate, [-1,1,1,self.channels])
        return self.m1([inputs, gate])

In [None]:
class MobileBlock(Model):
    """Expand (1x1 conv) -> depthwise conv -> optional SE -> project (1x1 conv).

    Args:
        channels: number of filters of the final 1x1 convolution (block output).
        exp: number of filters of the first 1x1 convolution (expansion width);
            also the channel count given to the SE block.
        kernel_size: kernel size of the depthwise convolution.
        strides: strides of the depthwise convolution.
        SE: if truthy, the SE block is applied after the depthwise stage.
        NL: non-linearity selector, 'RE' (relu6) or 'HS' (hard_swish).

    Raises:
        ValueError: if `NL` is neither 'RE' nor 'HS'.
    """
    def __init__(self,channels,exp,kernel_size,strides,SE,NL):
        super(MobileBlock,self).__init__()
        nonlinearities = {'RE': 'relu6', 'HS': 'hard_swish'}
        if NL not in nonlinearities:
            # Previously an unknown NL left `activation` unbound and failed with
            # an UnboundLocalError a few lines further down.
            raise ValueError("Unsupported NL {!r}; expected 'RE' or 'HS'".format(NL))
        activation = nonlinearities[NL]
        self.strides = strides
        self.SE = SE
        self.c1 = Conv2D(filters=exp, kernel_size=1, strides=1, padding='same',use_bias=False,
                         kernel_initializer="he_normal",kernel_regularizer=tf.keras.regularizers.l2(REGULARIZER))
        self.b1 = BNAct(activation=activation)
        self.c2 = DepthwiseConv2D(kernel_size=kernel_size, strides=strides, padding='same', use_bias=False,
                                  depthwise_initializer="he_normal",depthwise_regularizer=tf.keras.regularizers.l2(REGULARIZER))
        # NOTE(review): the depthwise stage is followed by plain batch norm (no
        # activation) while the projection stage (b3) is activated - confirm
        # this ordering is intended.
        self.b2 = BatchNormalization(momentum=0.9)
        # Constructed unconditionally; only applied in call() when self.SE is truthy.
        self.se = SEBlock(channels=exp)
        self.c3 = Conv2D(filters=channels, kernel_size=1, strides=1, padding='same', use_bias=False,
                         kernel_initializer="he_normal",kernel_regularizer=tf.keras.regularizers.l2(REGULARIZER))
        self.b3 = BNAct(activation=activation)

    def call(self,inputs):
        """Run the block; add a residual connection when shapes allow it."""
        x = self.c1(inputs)
        x = self.b1(x)
        x = self.c2(x)
        x = self.b2(x)
        if self.SE:
            x = self.se(x)
        x = self.c3(x)
        outputs = self.b3(x)
        # Skip connection only when neither resolution nor channel count changes.
        if self.strides == 1 and inputs.shape[-1] == outputs.shape[-1]:
            outputs += inputs
        return outputs

In [None]:
class MobileNetV3_Small(Model):
    """Small network variant: stem conv, nine MobileBlocks, 10-way softmax head."""
    def __init__(self):
        super().__init__()
        # Stem: 3x3 convolution followed by BN + relu6.
        self.c1 = Conv2D(
            filters=24, kernel_size=3, strides=1, padding='same', use_bias=False,
            kernel_initializer="he_normal",
            kernel_regularizer=tf.keras.regularizers.l2(REGULARIZER),
        )
        self.b1 = BNAct(activation='relu6')

        # Output / expansion widths of the blocks, in order.
        self.channels = [24,40,40,40,48,48,96,96,96]
        self.exp = [88,96,240,240,120,144,288,576,576]
        # Remaining per-block settings: (kernel_size, strides, SE, NL).
        block_specs = [
            (3, 1, False, 'RE'),
            (5, 2, True, 'HS'),
            (5, 1, True, 'HS'),
            (5, 1, True, 'HS'),
            (5, 1, True, 'HS'),
            (5, 1, True, 'HS'),
            (5, 2, True, 'HS'),
            (5, 1, True, 'HS'),
            (5, 1, True, 'HS'),
        ]
        self.blocks = Sequential()
        for out_channels, exp_channels, (kernel, stride, use_se, nl) in zip(
                self.channels, self.exp, block_specs):
            self.blocks.add(MobileBlock(channels=out_channels, exp=exp_channels,
                                        kernel_size=kernel, strides=stride, SE=use_se, NL=nl))

        # Head: 1x1 conv + BN/hard_swish, global pooling, 1x1 conv, classifier.
        self.c2 = Conv2D(
            filters=576, kernel_size=1, strides=1, padding='same', use_bias=False,
            kernel_initializer="he_normal",
            kernel_regularizer=tf.keras.regularizers.l2(REGULARIZER),
        )
        self.b2 = BNAct(activation='hard_swish')
        self.p1 = GlobalAveragePooling2D()
        self.c3 = Conv2D(
            filters=1024, kernel_size=1, strides=1, padding='same',
            kernel_initializer="he_normal",
            kernel_regularizer=tf.keras.regularizers.l2(REGULARIZER),
        )
        self.flatten = Flatten()
        self.f1 = Dense(
            10, activation='softmax', kernel_initializer="he_normal",
            kernel_regularizer=tf.keras.regularizers.l2(REGULARIZER),
        )
    def call(self,inputs):
        """Map a batch of images to 10 softmax class probabilities."""
        stem = self.b1(self.c1(inputs))
        features = self.b2(self.c2(self.blocks(stem)))
        # Restore a 1x1 spatial grid so the pooled vector can go through a 1x1 conv.
        pooled = tf.reshape(self.p1(features), [-1,1,1,576])
        head = hard_swish(self.c3(pooled))
        return self.f1(self.flatten(head))

In [None]:
class MobileNetV3_Large(Model):
    """Large network variant: stem conv, eleven MobileBlocks, 10-way softmax head."""
    def __init__(self):
        super().__init__()
        # Stem: 3x3 convolution followed by BN + relu6.
        self.c1 = Conv2D(
            filters=40, kernel_size=3, strides=1, padding='same', use_bias=False,
            kernel_initializer="he_normal",
            kernel_regularizer=tf.keras.regularizers.l2(REGULARIZER),
        )
        self.b1 = BNAct(activation='relu6')

        # Output / expansion widths of the blocks, in order.
        self.channels = [40,40,80,80,80,80,112,112,160,160,160]
        self.exp = [120,120,240,200,184,184,480,672,672,960,960]
        # Remaining per-block settings: (kernel_size, strides, SE, NL).
        block_specs = [
            (5, 1, True, 'RE'),
            (5, 1, True, 'RE'),
            (3, 2, False, 'HS'),
            (3, 1, False, 'HS'),
            (3, 1, False, 'HS'),
            (3, 1, False, 'HS'),
            (3, 1, True, 'HS'),
            (3, 1, True, 'HS'),
            (5, 2, True, 'HS'),
            (5, 1, True, 'HS'),
            (5, 1, True, 'HS'),
        ]
        self.blocks = Sequential()
        for out_channels, exp_channels, (kernel, stride, use_se, nl) in zip(
                self.channels, self.exp, block_specs):
            self.blocks.add(MobileBlock(channels=out_channels, exp=exp_channels,
                                        kernel_size=kernel, strides=stride, SE=use_se, NL=nl))

        # Head: 1x1 conv + BN/hard_swish, global pooling, 1x1 conv, classifier.
        self.c2 = Conv2D(
            filters=960, kernel_size=1, strides=1, padding='same', use_bias=False,
            kernel_initializer="he_normal",
            kernel_regularizer=tf.keras.regularizers.l2(REGULARIZER),
        )
        self.b2 = BNAct(activation='hard_swish')
        self.p1 = GlobalAveragePooling2D()
        self.c3 = Conv2D(
            filters=1280, kernel_size=1, strides=1, padding='same',
            kernel_initializer="he_normal",
            kernel_regularizer=tf.keras.regularizers.l2(REGULARIZER),
        )
        self.flatten = Flatten()
        self.f1 = Dense(
            10, activation='softmax', kernel_initializer="he_normal",
            kernel_regularizer=tf.keras.regularizers.l2(REGULARIZER),
        )
    def call(self,inputs):
        """Map a batch of images to 10 softmax class probabilities."""
        stem = self.b1(self.c1(inputs))
        features = self.b2(self.c2(self.blocks(stem)))
        # Restore a 1x1 spatial grid so the pooled vector can go through a 1x1 conv.
        pooled = tf.reshape(self.p1(features), [-1,1,1,960])
        head = hard_swish(self.c3(pooled))
        return self.f1(self.flatten(head))

In [None]:
# Build the large variant and train it with Nesterov SGD and label smoothing.
model = MobileNetV3_Large()

model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=lr, momentum=0.9, nesterov=True, clipnorm=2.),
    loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1),
    metrics=['accuracy'],
)

# Per-epoch learning-rate schedule (warm-up followed by tanh decay).
lr_schedule_cb = tf.keras.callbacks.LearningRateScheduler(scheduler)
# Save the full model whenever validation accuracy improves.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    save_weights_only=False,
    monitor='val_accuracy',
    save_best_only=True,
)
# TensorBoard logging, including the computation graph and per-epoch histograms.
tensorboard_cb = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir, histogram_freq=1, write_graph=True, write_images=False)

# Disabled alternatives kept for reference:
#   tf.keras.callbacks.ReduceLROnPlateau(monitor='val_accuracy', factor=0.1, min_lr=0.0001, patience=10, cooldown=0)
#   tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=15, baseline=None)
callbacks = [lr_schedule_cb, checkpoint_cb, tensorboard_cb]

hist = model.fit(
    DataGenTrain.flow(x_train, y_train, batch_size=batch_size, shuffle=True),
    epochs=epochs,
    validation_data=(x_test, y_test),
    validation_freq=1,
    callbacks=callbacks,
)

model.summary()

In [None]:
#结果可视化
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
plt.style.use({'figure.figsize':(6,4)})

plt.plot(hist.history['loss'], label='loss')
plt.plot(hist.history['val_loss'], label='val_loss')
plt.legend()
plt.show()
plt.plot(hist.history['val_accuracy'], label='val_accuracy')
plt.legend()
plt.show()

In [None]:
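# TensorBoard visualization: uncomment the command below to launch TensorBoard,
# then open the URL that follows it in a browser.
#!tensorboard --logdir=./Model/MobileNetV3_logs
#http://localhost:6006/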

In [None]:
# Report the best validation accuracy and the (1-based) epoch where it first occurred.
val_acc_history = hist.history['val_accuracy']
best_val_acc = max(val_acc_history)
best_epoch = val_acc_history.index(best_val_acc) + 1
print('best result: {:.2f}%  ({}epochs)'.format(100*best_val_acc, best_epoch))
# best result: 94.85%  (95epochs)