In [None]:
# Copyright (c) 2020 ZZH

In [None]:
import tensorflow as tf
from matplotlib import pyplot as plt
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, AveragePooling2D, MaxPool2D
from tensorflow.keras.layers import GlobalAveragePooling2D, Dropout, Flatten, Dense
from tensorflow.keras import Model
import os
import numpy as np

In [None]:
# Discover the GPUs TensorFlow can see and switch each one to memory-growth
# mode, so GPU memory is allocated as needed rather than reserved up front.
gpus = tf.config.experimental.list_physical_devices('GPU')
print(gpus)
for device in gpus:
    tf.config.experimental.set_memory_growth(device, enable=True)

In [None]:
# Training hyperparameters and output location.
epochs = 100          # number of passes over the training generator
lr = 0.1              # learning rate given to SGD (the LearningRateScheduler callback sets the per-epoch rate)
batch_size = 128      # mini-batch size used by the augmentation generator
REGULARIZER = 1e-4    # L2 coefficient applied to every Conv2D kernel
checkpoint_save_path = './Model/NIN/'  # target path for ModelCheckpoint

In [None]:
# Data loading and data augmentation
cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')

mean = [125.307, 122.95, 113.865]  # per-channel means (original note: np.mean())
std = [62.9932, 62.0887, 66.7048]  # per-channel standard deviations (original note: np.std())

# Standardize train and test images in place, one color channel at a time,
# using the same statistics for both splits.
for i, (channel_mean, channel_std) in enumerate(zip(mean, std)):
    for images in (x_train, x_test):
        images[:, :, :, i] = (images[:, :, :, i] - channel_mean) / channel_std

# Random geometric augmentations applied to training batches only.
augmentation_options = dict(
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.1,
    zoom_range=0.1,
    horizontal_flip=True,
    vertical_flip=False,
)
DataGenTrain = tf.keras.preprocessing.image.ImageDataGenerator(**augmentation_options)
# NOTE(review): no featurewise/ZCA option is enabled above, so this fit() call
# is presumably redundant -- confirm against the Keras docs before removing.
DataGenTrain.fit(x_train)

In [None]:
# Learning-rate decay schedule
def scheduler(epoch):
    """Return the learning rate to use for the given epoch index.

    The schedule is piecewise constant: 0.1 through epoch 25, 0.05 through
    epoch 50, 0.01 through epoch 75, and 0.002 for every later epoch.
    """
    # (last epoch the rate applies to, rate), checked in ascending order
    steps = ((25, 0.1), (50, 0.05), (75, 0.01))
    for last_epoch, rate in steps:
        if epoch <= last_epoch:
            return rate
    return 0.002

In [None]:
class ConvBNRelu(Model):
    """Building block: Conv2D followed by BatchNormalization and ReLU.

    Args:
        channels: number of convolution filters.
        kernel_size: convolution kernel size.
        strides: convolution stride.
        padding: padding mode passed to Conv2D.
    """

    def __init__(self, channels, kernel_size, strides, padding):
        super().__init__()
        # He-normal initialization plus L2 weight decay on the kernel.
        conv = Conv2D(
            filters=channels,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            kernel_initializer="he_normal",
            kernel_regularizer=tf.keras.regularizers.l2(REGULARIZER),
        )
        batch_norm = BatchNormalization(momentum=0.9)
        relu = Activation('relu')
        self.model = tf.keras.models.Sequential([conv, batch_norm, relu])

    def call(self, inputs):
        """Run the inputs through conv -> batch norm -> ReLU."""
        return self.model(inputs)

In [None]:
# Network construction
class NIN(Model):
    """Network-in-Network style classifier ending in a 10-way softmax.

    Three stacks of ConvBNRelu layers (one spatial convolution followed by
    two 1x1 convolutions). The first two stacks are each followed by max
    pooling and dropout; the last one feeds global average pooling and a
    softmax activation.
    """

    def __init__(self):
        super().__init__()
        self.model = tf.keras.models.Sequential()

        # One tuple per stack; each entry is (channels, kernel_size).
        stacks = (
            ((192, 5), (160, 1), (96, 1)),
            ((192, 5), (192, 1), (192, 1)),
            ((192, 3), (192, 1), (10, 1)),
        )
        final_stack = len(stacks) - 1
        for position, stack in enumerate(stacks):
            for channels, kernel_size in stack:
                self.model.add(ConvBNRelu(channels=channels, kernel_size=kernel_size,
                                          strides=1, padding='same'))
            if position != final_stack:
                # Downsample and regularize between stacks.
                self.model.add(MaxPool2D(pool_size=3, strides=2, padding='same'))
                self.model.add(Dropout(0.5))

        # The last stack emits 10 feature maps; average each map and
        # normalize the result with softmax.
        self.model.add(GlobalAveragePooling2D())
        self.model.add(Activation('softmax'))

    def call(self, inputs):
        """Return the softmax output of the wrapped Sequential model."""
        return self.model(inputs)

In [None]:
model = NIN()

# SGD with Nesterov momentum; labels are integer class ids, hence the sparse loss.
optimizer = tf.keras.optimizers.SGD(learning_rate=lr, momentum=0.9, nesterov=True)
model.compile(
    optimizer=optimizer,
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

log_dir = os.path.join("Model", "NIN_logs")

# Learning-rate decay schedule
lr_schedule_cb = tf.keras.callbacks.LearningRateScheduler(scheduler)
# Model saving: keep only the model with the best validation accuracy
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    save_weights_only=False,
    monitor='val_accuracy',
    save_best_only=True,
)
# Early stopping (disabled)
# early_stopping_cb = tf.keras.callbacks.EarlyStopping(
#     monitor='val_accuracy',
#     patience=15,
#     baseline=None)
# TensorBoard logging, including the computation graph
tensorboard_cb = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,
    write_graph=True,
    write_images=False,
)
callbacks = [lr_schedule_cb, checkpoint_cb, tensorboard_cb]

# Train on augmented batches; the test split is used for validation every epoch.
train_batches = DataGenTrain.flow(x_train, y_train, batch_size=batch_size, shuffle=True)
hist = model.fit(
    train_batches,
    epochs=epochs,
    validation_data=(x_test, y_test),
    validation_freq=1,
    callbacks=callbacks,
)

model.summary()

In [None]:
#结果可视化
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
plt.style.use({'figure.figsize':(6,4)})

plt.plot(hist.history['loss'], label='loss')
plt.plot(hist.history['val_loss'], label='val_loss')
plt.legend()
plt.show()
plt.plot(hist.history['val_accuracy'], label='val_accuracy')
plt.legend()
plt.show()

In [None]:
# Report the highest validation accuracy and the (1-based) epoch that reached it first.
val_acc_history = hist.history['val_accuracy']
best_val_acc = max(val_acc_history)
best_epoch = val_acc_history.index(best_val_acc) + 1
print('best result: {:.2f}%  ({}epochs)'.format(100*best_val_acc, best_epoch))

In [None]:
#TensorBoard visualization (uncomment the command below, then open the URL)
#!tensorboard --logdir=./Model/NIN_logs
#http://localhost:6006/