In [None]:
# Copyright (c) 2020 ZZH

In [None]:
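# Dilated (atrous) convolution is mostly used for semantic segmentation; here it is only
# implemented as a functional demonstration, used in place of downsampling.
# The network relies on a residual structure whose inner layers are replaced by
# Hybrid Dilated Convolution (HDC) modules, which are stacked repeatedly.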

In [None]:
import tensorflow as tf
from matplotlib import pyplot as plt
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, AveragePooling2D, MaxPool2D
from tensorflow.keras.layers import GlobalAveragePooling2D, Dropout, Flatten, Dense
from tensorflow.keras.models import Sequential
from tensorflow.keras import Model
import os
import numpy as np
import math

In [None]:
# Ask TensorFlow to allocate GPU memory on demand instead of reserving it all
# up front. The list of detected GPUs is printed for reference.
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
print(gpus)
for gpu in gpus:
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        # Memory growth can only be changed before the GPU has been
        # initialized. Report the problem instead of aborting the cell so the
        # notebook stays usable when this cell is run late in a live kernel.
        print(err)

In [None]:
# Training hyper-parameters and output location.
checkpoint_save_path = './Model/DilatedConvolution/'  # target path handed to ModelCheckpoint
REGULARIZER = 1e-4  # L2 factor used by the kernel regularizers
batch_size = 128    # mini-batch size of the training generator
lr = 0.1            # base learning rate (optimizer and LR schedule)
epochs = 100        # number of training epochs

In [None]:
# Data loading, normalization and augmentation.
cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# One-hot encode the 10 class labels.
y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)

# Per-channel statistics used to standardize the images.
mean = [125.307, 122.95, 113.865]  # np.mean()
std = [62.9932, 62.0887, 66.7048]  # np.std()
channel_mean = np.asarray(mean, dtype='float32')
channel_std = np.asarray(std, dtype='float32')

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
# Standardize in place; the length-3 statistics broadcast over the trailing
# (channel) axis, so no per-channel Python loop or full-size temporary is needed.
for images in (x_train, x_test):
    images -= channel_mean
    images /= channel_std

# Random augmentation is applied to the training images only.
DataGenTrain = tf.keras.preprocessing.image.ImageDataGenerator(
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    vertical_flip=False,
    shear_range=0.1,
    zoom_range=0.1,
)
# The test generator performs no augmentation; it only batches the data.
DataGenTest = tf.keras.preprocessing.image.ImageDataGenerator()
# NOTE: ImageDataGenerator.fit() is only required when featurewise_center,
# featurewise_std_normalization or zca_whitening is enabled. None of them is
# used here, so the fit() calls (which copy the whole dataset) are omitted.

In [None]:
def scheduler(epoch):
    """Hyperbolic-tangent decay, HTD(-6, 3), of the learning rate.

    Args:
        epoch: Index of the epoch the learning rate is requested for.

    Returns:
        ``lr / 2 * (1 - tanh(x))`` where ``x`` moves linearly from -6.0 at
        epoch 0 to 3.0 at epoch ``epochs``. ``lr`` and ``epochs`` are read
        from the module-level hyper-parameters.
    """
    lower, upper = -6.0, 3.0
    offset = (upper - lower) * epoch / epochs
    return lr / 2.0 * (1 - math.tanh(offset + lower))

In [None]:
class BNRelu(Model):
    """Batch normalization (momentum 0.9) followed by a ReLU activation."""

    def __init__(self):
        super().__init__()
        self.bn = BatchNormalization(momentum=0.9)
        self.relu = Activation('relu')

    def call(self, inputs):
        """Normalize ``inputs`` and return the ReLU-activated result."""
        return self.relu(self.bn(inputs))

In [None]:
class DilatedConv(Model):
    """A single ``Conv2D`` layer with a configurable dilation rate.

    The kernel uses He-normal initialization and an L2 regularizer whose
    factor is the module-level ``REGULARIZER``.

    Args:
        channels: Number of output filters.
        kernel_size: Convolution kernel size.
        strides: Convolution stride.
        dilation_rate: Dilation rate of the kernel.
        padding: Padding mode forwarded to ``Conv2D``.
    """

    def __init__(self, channels, kernel_size, strides, dilation_rate, padding):
        super().__init__()
        conv_options = dict(
            filters=channels,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            dilation_rate=dilation_rate,
            kernel_initializer="he_normal",
            kernel_regularizer=tf.keras.regularizers.l2(REGULARIZER),
        )
        self.model = Conv2D(**conv_options)

    def call(self, inputs):
        """Apply the dilated convolution to ``inputs``."""
        return self.model(inputs)

In [None]:
class DilatedConvBNRelu(Model):
    """Dilated convolution followed by batch normalization and ReLU.

    All constructor arguments are forwarded unchanged to ``DilatedConv``.
    """

    def __init__(self, channels, kernel_size, strides, dilation_rate, padding):
        super().__init__()
        self.DC = DilatedConv(
            channels=channels,
            kernel_size=kernel_size,
            strides=strides,
            dilation_rate=dilation_rate,
            padding=padding,
        )
        self.b = BNRelu()

    def call(self, inputs):
        """Run ``inputs`` through the convolution, then through BN + ReLU."""
        return self.b(self.DC(inputs))

In [None]:
class HybridDilatedConv(Model):
    """Residual block built from three stacked 3x3 dilated convolutions.

    The main path uses dilation rates 1, 2 and 3; the first two convolutions
    are each followed by BN + ReLU. The shortcut is a 1x1 convolution of the
    block input. Both paths are summed and passed through a final BN + ReLU.

    Args:
        channels: Number of filters used by every convolution in the block.
    """

    def __init__(self, channels):
        super().__init__()
        # Settings shared by the three dilated convolutions of the main path.
        shared = dict(channels=channels, kernel_size=3, strides=1, padding='same')
        self.HDC1 = DilatedConvBNRelu(dilation_rate=1, **shared)
        self.HDC2 = DilatedConvBNRelu(dilation_rate=2, **shared)
        self.HDC3 = DilatedConv(dilation_rate=3, **shared)
        self.b = BNRelu()
        # 1x1 projection so the shortcut has the same channel count as the main path.
        self.c = Conv2D(
            filters=channels,
            kernel_size=1,
            strides=1,
            padding='same',
            kernel_initializer="he_normal",
            kernel_regularizer=tf.keras.regularizers.l2(REGULARIZER),
        )

    def call(self, inputs):
        """Return BN + ReLU of (main path + projected shortcut)."""
        main_path = self.HDC3(self.HDC2(self.HDC1(inputs)))
        shortcut = self.c(inputs)
        return self.b(main_path + shortcut)

In [None]:
class HDCNet(Model):
    """Classifier made of five stacked ``HybridDilatedConv`` blocks.

    The blocks use 16, 32, 64, 128 and 256 channels. Their output is reduced
    with global average pooling and fed to a 10-way softmax dense layer.
    """

    def __init__(self):
        super().__init__()
        self.channels = [16, 32, 64, 128, 256]
        self.blocks = Sequential()
        for width in self.channels:
            self.blocks.add(HybridDilatedConv(channels=width))
        self.p1 = GlobalAveragePooling2D()
        self.f1 = Dense(
            10,
            activation='softmax',
            kernel_initializer="he_normal",
            kernel_regularizer=tf.keras.regularizers.l2(REGULARIZER),
        )

    def call(self, inputs):
        """Return the softmax class probabilities for ``inputs``."""
        features = self.blocks(inputs)
        pooled = self.p1(features)
        return self.f1(pooled)

In [None]:
model = HDCNet()

# Nesterov-momentum SGD with gradient-norm clipping, trained on a
# label-smoothed categorical cross-entropy.
sgd = tf.keras.optimizers.SGD(learning_rate=lr, momentum=0.9, nesterov=True, clipnorm=2.)
smoothed_ce = tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1)
model.compile(optimizer=sgd, loss=smoothed_ce, metrics=['accuracy'])

log_dir = os.path.join("Model", "DilatedConvolution_logs")

# Learning-rate schedule driven by scheduler().
lr_schedule_cb = tf.keras.callbacks.LearningRateScheduler(scheduler)
# Disabled alternative to the fixed schedule:
# tf.keras.callbacks.ReduceLROnPlateau(monitor='val_accuracy', factor=0.1, min_lr=0.0001, patience=10, cooldown=0)

# Save the full model whenever the validation accuracy improves.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    save_weights_only=False,
    monitor='val_accuracy',
    save_best_only=True,
)
# Disabled early stopping:
# tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=15, baseline=None)

# TensorBoard logging (graph plus per-epoch histograms).
tensorboard_cb = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,
    write_graph=True,
    write_images=False,
)

callbacks = [lr_schedule_cb, checkpoint_cb, tensorboard_cb]

# Augmented training batches; the test set is served in batches of 1000 and
# evaluated after every epoch.
train_batches = DataGenTrain.flow(x_train, y_train, batch_size=batch_size, shuffle=True)
val_batches = DataGenTest.flow(x_test, y_test, batch_size=1000)

hist = model.fit(
    train_batches,
    epochs=epochs,
    validation_data=val_batches,
    validation_freq=1,
    callbacks=callbacks,
)

model.summary()

In [None]:
#结果可视化
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
plt.style.use({'figure.figsize':(6,4)})

plt.plot(hist.history['loss'], label='loss')
plt.plot(hist.history['val_loss'], label='val_loss')
plt.legend()
plt.show()
plt.plot(hist.history['val_accuracy'], label='val_accuracy')
plt.legend()
plt.show()

In [None]:
# TensorBoard visualization
#!tensorboard --logdir=./Model/DilatedConvolution_logs
#http://localhost:6006/

In [None]:
# Report the highest validation accuracy and the 1-based epoch that first reached it.
val_acc_history = hist.history['val_accuracy']
best_val_acc = max(val_acc_history)
best_epoch = val_acc_history.index(best_val_acc) + 1
print('best result: {:.2f}%  ({}epochs)'.format(100 * best_val_acc, best_epoch))
# Previously recorded results:
# best result: 92.58%  (95epochs)  Baseline  No residual
#              93.22%  (86epochs)  with residual