## 导入依赖包

In [1]:
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import Sequential, layers, losses, optimizers, datasets, regularizers
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dropout, Dense
from tensorflow.keras.callbacks import LearningRateScheduler
import numpy as np

## 加载数据

通用图像分类公开的标准数据集常用的有CIFAR、ImageNet、COCO等，常用的细粒度图像分类数据集包括CUB-200-2011、Stanford Dog、Oxford-flowers等。其中ImageNet数据集规模相对较大，如模型概览一章所讲，大量研究成果基于ImageNet。ImageNet数据从2010年来稍有变化，常用的是ImageNet-2012数据集，该数据集包含1000个类别：训练集包含1,281,167张图片，每个类别数据732至1300张不等，验证集包含50,000张图片，平均每个类别50张图片。

由于ImageNet数据集较大，下载和训练较慢，为了方便大家学习，我们使用CIFAR-10数据集。CIFAR-10数据集包含60,000张32x32的彩色图片，共10个类别，每个类别包含6,000张。其中50,000张图片作为训练集，10,000张作为测试集。
这里采用的CIFAR-10是一个稍复杂的数据集，TensorFlow同样为该数据集提供了官方的加载方式。

In [2]:
# Load CIFAR-10 through the Keras datasets helper: (x, y) is the training split
# and (x_test, y_test) is the test split.
(x, y), (x_test, y_test) = datasets.cifar10.load_data()

In [3]:
# Show the shape of every array returned by the loader (train images/labels, test images/labels).
print(*(arr.shape for arr in (x, y, x_test, y_test)))

(50000, 32, 32, 3) (50000, 1) (10000, 32, 32, 3) (10000, 1)


## 数据预处理并转化为Datasets

In [4]:
def preprocess(x, y):
    """Convert raw images and integer labels into model-ready tensors.

    Works both on a single sample (as used with ``Dataset.map`` before
    ``batch``) and on an already batched input.

    Args:
        x: Raw image data; it is cast to float32 and divided by 255.
        y: Integer class labels whose last axis has size 1
            (shape ``[1]`` per sample or ``[b, 1]`` per batch).

    Returns:
        A tuple ``(x, y)`` where ``x`` is the scaled float32 image tensor and
        ``y`` is the label one-hot encoded with depth 10.
    """
    x = tf.cast(x, dtype=tf.float32) / 255
    y = tf.cast(y, tf.int32)
    # Drop only the trailing singleton label axis: [1] => [] or [b, 1] => [b].
    # A bare tf.squeeze(y) would also remove a batch axis of size 1.
    y = tf.squeeze(y, axis=-1)
    # [] => [10] or [b] => [b, 10]
    y = tf.one_hot(y, depth=10)
    return x, y

In [5]:
# ---- Hyperparameters ----
epoch_num = 50        # number of epochs the learning-rate schedule is planned for
batch_size = 128      # samples per batch for both pipelines
weight_decay = 5e-4   # L2 regularization factor
learning_rate = 1e-2  # initial learning rate
dropout_rate = 0.5    # dropout probability intended for the fully connected layers

# ---- Training pipeline ----
# Order is map -> shuffle -> batch: map must come before batch, otherwise the
# author observed "ValueError: Cannot take the length of shape with unknown rank."
train_db = (
    tf.data.Dataset.from_tensor_slices((x, y))
    .map(preprocess)
    .shuffle(50000)
    .batch(batch_size)
)

# ---- Test pipeline ---- (same preprocessing, no shuffling)
test_db = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .map(preprocess)
    .batch(batch_size)
)

In [6]:
# Pull the first batch from each pipeline as a sanity check.
sample_train = next(iter(train_db))
sample_test = next(iter(test_db))

# Each sample is an (images, labels) pair; print the shape of both members.
print('batch_train:', *(part.shape for part in sample_train))
print('batch_test:', *(part.shape for part in sample_test))

batch_train: (128, 32, 32, 3) (128, 10)
batch_test: (128, 32, 32, 3) (128, 10)


## 创建网络层
网上关于VGG论文的解读非常多，因此这里对网络结构和参数不多赘述，可以像下面这样简单地搭建好。由于我们所用的数据是CIFAR-10，所以最终网络的输出维度设为10。并且超参数的设置遵循原文，即 weight_decay = 5e-4，dropout_rate = 0.5。

In [7]:
def VGG16(num_classes=10, input_shape=(32, 32, 3)):
    """Build a VGG16-style convolutional network as a Keras ``Sequential`` model.

    The network has five blocks of 3x3 'same'-padded ReLU convolutions
    (64x2, 128x2, 256x3, 512x3, 512x3). The first four blocks are each followed
    by 2x2 max pooling; the fifth is not, so a 32x32 input reaches ``Flatten``
    as 2x2x512. The classifier head is Dense(4096) -> Dropout -> Dense(4096)
    -> Dropout -> Dense(num_classes).

    Every convolution uses L2 kernel regularization with the module-level
    ``weight_decay``; both dropout layers use the module-level ``dropout_rate``.
    The final Dense layer has no activation, i.e. the model outputs logits.

    Args:
        num_classes: Number of units in the output layer (default 10).
        input_shape: Shape of one input image (default ``(32, 32, 3)``).

    Returns:
        The uncompiled ``Sequential`` model.
    """
    # (filters, number of conv layers, whether 2x2 max pooling follows the block)
    block_specs = [
        (64, 2, True),     # (32, 32, 3)   => (16, 16, 64)
        (128, 2, True),    # (16, 16, 64)  => (8, 8, 128)
        (256, 3, True),    # (8, 8, 128)   => (4, 4, 256)
        (512, 3, True),    # (4, 4, 256)   => (2, 2, 512)
        (512, 3, False),   # (2, 2, 512)   => (2, 2, 512), no pooling
    ]

    model = Sequential()
    is_first_layer = True
    for filters, num_convs, pool_after in block_specs:
        for _ in range(num_convs):
            conv_kwargs = dict(
                activation='relu',
                padding='same',
                kernel_regularizer=regularizers.l2(weight_decay),
            )
            if is_first_layer:
                # Only the very first layer declares the input shape.
                conv_kwargs['input_shape'] = input_shape
                is_first_layer = False
            model.add(Conv2D(filters, (3, 3), **conv_kwargs))
        if pool_after:
            model.add(MaxPooling2D((2, 2)))

    model.add(Flatten())  # 2*2*512 for the default 32x32 input
    for _ in range(2):
        model.add(Dense(4096, activation='relu'))
        # Use the configured hyperparameter instead of a hard-coded 0.5.
        model.add(Dropout(dropout_rate))
    model.add(Dense(num_classes))

    return model

## 查看网络结构

In [8]:
# Instantiate the network and print its layer-by-layer summary
# (output shapes and parameter counts).
model = VGG16()
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 32, 32, 64)        1792      
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 32, 32, 64)        36928     
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 16, 16, 64)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 16, 16, 128)       73856     
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 16, 16, 128)       147584    
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 8, 8, 128)         0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 8, 8, 256)         2

## 动态学习率
下面介绍动态（可变）学习率的设置方式。要用到的是model.fit中的callbacks参数，从参数名可以理解，我们需要写一个回调函数来实现学习率随训练轮数增加而减小。VGG原文中采用带动量的SGD，初始学习率为0.01，每次下降为原来的十分之一。这里我们让网络训练50个epoch，即epoch_num = 50，其中前20个epoch采用0.01，中间20个采用0.001，最后10个采用0.0001。

In [9]:
def scheduler(epoch):
    """Return the learning rate to use for the given epoch index.

    Step schedule driven by the module-level ``epoch_num`` and
    ``learning_rate``: the base rate for the first 40% of the epochs,
    one tenth of it up to 80%, and one hundredth afterwards.

    Args:
        epoch: Zero-based epoch index.

    Returns:
        The learning rate for that epoch.
    """
    first_drop, second_drop = epoch_num * 0.4, epoch_num * 0.8
    if epoch < first_drop:
        rate = learning_rate
    elif epoch < second_drop:
        rate = learning_rate * 0.1
    else:
        rate = learning_rate * 0.01
    return rate

# SGD with Nesterov momentum, starting at the configured initial learning rate.
# Use the `learning_rate` keyword: the old `lr` alias is deprecated and is
# rejected by newer Keras releases.
sgd = optimizers.SGD(learning_rate=learning_rate, momentum=0.9, nesterov=True)
# Callback that asks `scheduler` for the learning rate of each epoch; pass it to
# model.fit through the `callbacks` argument when training.
change_lr = LearningRateScheduler(scheduler)

## 模型训练

In [10]:
# The network outputs raw logits and the labels are one-hot encoded, hence
# categorical cross-entropy with from_logits=True.
model.compile(
    optimizer=sgd,
    loss=losses.CategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

In [None]:
# Train for the full `epoch_num` epochs: the thresholds in `scheduler` are
# fractions (40% / 80%) of `epoch_num`, so a shorter run would never reach the
# learning-rate decay steps.
model.fit(
    train_db,
    epochs=epoch_num,
    validation_data=test_db,
    validation_freq=2,  # evaluate on test_db every 2nd epoch
    callbacks=[change_lr]
)