In [None]:
# Copyright (c) 2020 ZZH

In [None]:
import tensorflow as tf
from matplotlib import pyplot as plt
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, AveragePooling2D, MaxPool2D
from tensorflow.keras.layers import GlobalAveragePooling2D, Dropout, Flatten, Dense
from tensorflow.keras.models import Sequential
from tensorflow.keras import Model
import os
import numpy as np
import math

In [None]:
# Training configuration.
epochs = 150                # total training epochs; also the horizon of the HTD LR schedule
lr = 0.1                    # base learning rate for SGD and the LR schedule
batch_size = 128
REGULARIZER = 0.0001        # L2 factor for every conv/dense kernel regularizer
checkpoint_save_path = './Model/SE_WRN/'            # ModelCheckpoint target (best val_accuracy)
log_dir = os.path.join("Model", "SE_WRN_logs")      # TensorBoard log directory

# Seed the global NumPy RNG (used by ImageDataGenerator for augmentation and
# shuffling) and TensorFlow's global RNG (weight initializers, dropout) so runs
# are repeatable. GPU kernels may still introduce some nondeterminism.
SEED = 2020
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [None]:
# Data loading and augmentation.
cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
# One-hot labels, as required by CategoricalCrossentropy.
y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)

# Per-channel mean / std in raw 0-255 pixel units (originally obtained with np.mean / np.std).
mean = [125.307, 122.95, 113.865]
std = [62.9932, 62.0887, 66.7048]
# Standardize all channels at once: the (3,) vectors broadcast over the last
# (channel) axis of the NHWC image arrays. float32 constants keep the arrays float32.
channel_mean = np.asarray(mean, dtype=np.float32)
channel_std = np.asarray(std, dtype=np.float32)
x_train = (x_train - channel_mean) / channel_std
x_test = (x_test - channel_mean) / channel_std

# On-the-fly geometric augmentation for the training set only.
DataGenTrain = tf.keras.preprocessing.image.ImageDataGenerator(
    rotation_range=15,
    width_shift_range=0.125,
    height_shift_range=0.125,
    horizontal_flip=True,
    vertical_flip=False,
    shear_range=0.125,
    zoom_range=0.125)
# DataGenTrain.fit(x_train) is intentionally not called. fit() only computes
# statistics for featurewise_center / featurewise_std_normalization /
# zca_whitening, none of which are enabled, so it would only copy the whole
# training set. Call it if any of those options are turned on.

In [None]:
def scheduler(epoch, *, base_lr=None, total_epochs=None, warmup_epochs=10,
              start=-6.0, end=3.0):
    """Learning-rate schedule: linear warm-up followed by HTD(start, end) decay.

    During the first ``warmup_epochs`` epochs the rate ramps linearly from
    ``base_lr / warmup_epochs`` up to ``base_lr``. Afterwards it follows the
    hyperbolic-tangent decay
    ``base_lr / 2 * (1 - tanh((end - start) * epoch / total_epochs + start))``.

    The extra parameters are keyword-only, so Keras' ``LearningRateScheduler``
    (which first tries ``schedule(epoch, lr)`` and falls back to
    ``schedule(epoch)`` on TypeError) keeps calling this as ``scheduler(epoch)``.

    Args:
        epoch: 0-based epoch index.
        base_lr: peak learning rate; defaults to the module-level ``lr``.
        total_epochs: decay horizon; defaults to the module-level ``epochs``.
        warmup_epochs: number of warm-up epochs (0 disables warm-up).
        start, end: lower/upper bounds of the HTD tanh argument.

    Returns:
        The learning rate (float) to use for ``epoch``.
    """
    if base_lr is None:
        base_lr = lr
    if total_epochs is None:
        total_epochs = epochs
    if epoch < warmup_epochs:
        # Derived from base_lr, so changing lr no longer causes a jump when
        # warm-up ends. With lr=0.1 and 10 epochs this is 0.01, 0.02, ..., 0.1.
        return base_lr * (epoch + 1) / warmup_epochs
    return base_lr / 2.0 * (1.0 - math.tanh((end - start) * epoch / total_epochs + start))

In [None]:
def swish(x):
    """Swish (SiLU) activation, ``x * sigmoid(x)``, computed by ``tf.nn.swish``."""
    activated = tf.nn.swish(x)
    return activated

In [None]:
class SEBlock(Model):
    """Squeeze-and-Excitation (channel attention) block.

    Global-average-pools the input to one value per channel, then passes that
    through a bottleneck ``Dense(channels // reduction) -> swish ->
    Dense(channels, sigmoid)``. The input feature map is multiplied by the
    resulting per-channel gates.

    Args:
        channels: channel count of the input feature map (its last axis).
        reduction: bottleneck reduction ratio (default 16).
    """
    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()
        self.channels = channels
        # max(1, ...) avoids a zero-unit Dense layer when channels < reduction.
        squeezed = max(1, channels // reduction)
        self.p1 = GlobalAveragePooling2D()
        self.d1 = Dense(squeezed, activation=None, kernel_initializer="he_normal",
                        kernel_regularizer=tf.keras.regularizers.l2(REGULARIZER))
        self.d2 = Dense(channels, activation='sigmoid', kernel_initializer="he_normal",
                        kernel_regularizer=tf.keras.regularizers.l2(REGULARIZER))
        self.m1 = tf.keras.layers.Multiply()

    def call(self, inputs):
        # Squeeze: spatial global average -> (batch, channels).
        x = self.p1(inputs)
        # Excitation: bottleneck with swish, then sigmoid gates.
        x = self.d1(x)
        x = swish(x)
        y = self.d2(x)
        # Reshape gates to (batch, 1, 1, channels) so they broadcast over H and W.
        y = tf.reshape(y, [-1, 1, 1, self.channels])
        return self.m1([inputs, y])

In [None]:
class WideResnetBlock(Model):
    """Pre-activation wide residual block with an SE block on the residual branch.

    Residual branch: BN -> swish -> 3x3 conv (``strides``) -> BN -> swish ->
    Dropout -> 3x3 conv -> SE recalibration. Shortcut: identity, or a 1x1 conv
    with the same ``strides`` when ``increase`` is True. The output is their sum.

    Args:
        channels: base channel count; the block outputs ``channels * k`` channels.
        k: widening factor.
        strides: stride of the first 3x3 conv and of the projection shortcut.
        increase: if True, project the shortcut with a 1x1 conv (needed whenever
            the width or spatial size changes).
        dropout_rate: dropout rate between the two convs (default 0.2).
    """
    def __init__(self, channels, k, strides, increase, dropout_rate=0.2):
        super(WideResnetBlock, self).__init__()
        self.increase = increase
        filters = channels * k
        # Attribute names are kept stable so saved checkpoints still load.
        self.b1 = BatchNormalization(momentum=0.9)
        self.c1 = Conv2D(filters=filters, kernel_size=3, strides=strides, padding='same',
                         kernel_initializer="he_normal", kernel_regularizer=tf.keras.regularizers.l2(REGULARIZER))
        self.b2 = BatchNormalization(momentum=0.9)
        self.d1 = Dropout(dropout_rate)
        self.c2 = Conv2D(filters=filters, kernel_size=3, strides=1, padding='same',
                         kernel_initializer="he_normal", kernel_regularizer=tf.keras.regularizers.l2(REGULARIZER))
        self.se = SEBlock(channels=filters)
        if self.increase:
            # 1x1 projection so the shortcut matches the residual branch's width and stride.
            self.c3 = Conv2D(filters=filters, kernel_size=1, strides=strides, padding='same',
                             kernel_initializer="he_normal", kernel_regularizer=tf.keras.regularizers.l2(REGULARIZER))

    def call(self, inputs, training=None):
        # Residual branch, pre-activation ordering (BN -> swish -> conv).
        # `training` is passed explicitly so BN/Dropout behave correctly even
        # when the block is called directly.
        x = self.b1(inputs, training=training)
        x = swish(x)
        x = self.c1(x)
        x = self.b2(x, training=training)
        x = swish(x)
        x = self.d1(x, training=training)
        x = self.c2(x)
        # Shortcut branch.
        # NOTE(review): the projection is applied to the raw `inputs`, not to the
        # BN/swish-activated tensor. Some pre-activation WRN implementations
        # project the activated tensor instead; confirm this is intended.
        if self.increase:
            proj = self.c3(inputs)
        else:
            proj = inputs
        # SE-recalibrated residual plus shortcut.
        scale = self.se(x)
        return scale + proj

In [None]:
class WideResNet(Model):
    """Wide ResNet (WRN-depth-k) built from SE-augmented pre-activation blocks.

    Layout: 3x3 stem conv (64 filters) -> BN -> swish -> three stages of ``N``
    WideResnetBlocks (base widths 16, 32, 64, each multiplied by ``k``; stages 2
    and 3 downsample with stride 2) -> BN -> swish -> global average pooling ->
    softmax Dense.

    Args:
        depth: total depth; must be ``6 * N + 4`` with ``N >= 1`` (``N`` blocks per stage).
        k: widening factor.
        num_classes: number of output classes (default 10, for CIFAR-10).

    Raises:
        ValueError: if ``depth`` is not of the form ``6 * N + 4`` with ``N >= 1``.
    """
    def __init__(self, depth, k, num_classes=10):
        super(WideResNet, self).__init__()
        if depth < 10 or (depth - 4) % 6 != 0:
            raise ValueError(
                "WideResNet depth must be 6 * N + 4 with N >= 1 (e.g. 10, 16, 22, 28), "
                "got {}".format(depth))
        N = (depth - 4) // 6
        channels = [16, 32, 64]
        stage_strides = [1, 2, 2]
        self.c1 = Conv2D(filters=64, kernel_size=3, strides=1, padding='same', use_bias=False,
                         kernel_initializer="he_normal", kernel_regularizer=tf.keras.regularizers.l2(REGULARIZER))
        self.b1 = BatchNormalization(momentum=0.9)

        # Blocks are added in the same order as before (stage by stage), so layer
        # naming and weight order are unchanged.
        self.blocks = Sequential()
        for stage_channels, strides in zip(channels, stage_strides):
            # The first block of each stage changes the width (and, in stages 2-3,
            # the resolution), so its shortcut needs the 1x1 projection.
            self.blocks.add(WideResnetBlock(channels=stage_channels, k=k, strides=strides, increase=True))
            # The remaining N - 1 blocks preserve shape and use identity shortcuts.
            for _ in range(1, N):
                self.blocks.add(WideResnetBlock(channels=stage_channels, k=k, strides=1, increase=False))

        self.b2 = BatchNormalization(momentum=0.9)
        self.p1 = GlobalAveragePooling2D()
        self.f1 = Dense(num_classes, activation='softmax', kernel_initializer="he_normal",
                        kernel_regularizer=tf.keras.regularizers.l2(REGULARIZER))

    def call(self, inputs, training=None):
        # Stem.
        x = self.c1(inputs)
        x = self.b1(x, training=training)
        x = swish(x)
        # Residual stages.
        x = self.blocks(x, training=training)
        # Final pre-activation, pooling and classifier.
        x = self.b2(x, training=training)
        x = swish(x)
        x = self.p1(x)
        return self.f1(x)

In [None]:
# Build and compile SE-WRN-28-10.
model = WideResNet(depth=28, k=10)

sgd = tf.keras.optimizers.SGD(learning_rate=lr, momentum=0.9, nesterov=True, clipnorm=2.0)
smoothed_ce = tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1)
model.compile(optimizer=sgd, loss=smoothed_ce, metrics=['accuracy'])

# Per-epoch learning-rate schedule (warm-up + HTD decay from `scheduler`).
lr_callback = tf.keras.callbacks.LearningRateScheduler(scheduler)
# Save the full model whenever val_accuracy improves.
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    save_weights_only=False,
    monitor='val_accuracy',
    save_best_only=True)
# TensorBoard logging: graph plus per-epoch histograms.
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir, histogram_freq=1, write_graph=True, write_images=False)
# Previously tried and disabled: ReduceLROnPlateau(monitor='val_accuracy', factor=0.1,
# min_lr=0.0001, patience=10) and EarlyStopping(monitor='val_accuracy', patience=15).
callbacks = [lr_callback, checkpoint_callback, tensorboard_callback]

train_flow = DataGenTrain.flow(x_train, y_train, batch_size=batch_size, shuffle=True)
hist = model.fit(train_flow,
                 epochs=epochs,
                 validation_data=(x_test, y_test),
                 validation_freq=1,
                 callbacks=callbacks)

model.summary()

In [None]:
#结果可视化
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
plt.style.use({'figure.figsize':(6,4)})

plt.plot(hist.history['loss'], label='loss')
plt.plot(hist.history['val_loss'], label='val_loss')
plt.legend()
plt.show()
plt.plot(hist.history['val_accuracy'], label='val_accuracy')
plt.legend()
plt.show()

In [None]:
# Export the final (post-training) model as a TensorFlow SavedModel
# (saved_model.pb + variables/). It goes to a dedicated directory rather than
# './' so the export does not clutter the notebook directory or collide with
# other files there.
final_model_dir = os.path.join('Model', 'SE_WRN_final')
model.save(final_model_dir)

In [None]:
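# TensorBoard visualization: run the command below (it points at the TensorBoard
# callback's log directory), then open http://localhost:6006/ in a browser.
#!tensorboard --logdir=./Model/SE_WRN_logs
#http://localhost:6006/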

In [None]:
val_acc_history = hist.history['val_accuracy']
# Pair each value with its 1-based epoch. max() keeps the first of several
# equal maxima, i.e. the earliest best epoch.
best_epoch, best_val_acc = max(enumerate(val_acc_history, start=1), key=lambda item: item[1])
print('best result: {:.2f}%  ({}epochs)'.format(100 * best_val_acc, best_epoch))
# Previously recorded best result: 96.60% (epoch 140), obtained with mixup;
# the training pipeline in this notebook does not apply mixup.