# 简单实现CNN

### 导入数据和库

本例使用的数据集是 MNIST，可以说是入门必玩的一个数据集。下面导入需要的包。

In [1]:
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

In [2]:
# Read the MNIST data sets from the 'mnist_data/' directory with one-hot encoded labels.
mnist = input_data.read_data_sets('mnist_data/', one_hot=True)
# Use an interactive session so that later cells can call .run() / .eval()
# on ops and tensors without passing a session explicitly.
sess = tf.InteractiveSession()

Extracting mnist_data/train-images-idx3-ubyte.gz
Extracting mnist_data/train-labels-idx1-ubyte.gz
Extracting mnist_data/t10k-images-idx3-ubyte.gz
Extracting mnist_data/t10k-labels-idx1-ubyte.gz


看看数据集的shape

In [3]:
# Show the shapes of the train/test image and label arrays.
print('train images shape', mnist.train.images.shape)
print('train labels shape', mnist.train.labels.shape)
print('test images shape', mnist.test.images.shape)
# Fixed label: this line reports the test *labels*, not the images.
print('test labels shape', mnist.test.labels.shape)

train images shape (55000, 784)
train labels shape (55000, 10)
test images shape (10000, 784)
test images shape (10000, 10)


ok, 数据导入就完成了。下面开始搭建模型。

### 搭建模型

先弄一些方便的函数

In [4]:
def weights_variable(shape):
    """Create a weight variable of the given shape.

    The initial values are drawn from a truncated normal distribution
    (stddev 0.1). Because the network uses ReLU activations, this small
    random noise is used to avoid symmetric parameters.

    Args:
        shape: shape of the weight tensor to create.

    Returns:
        A tf.Variable initialised with truncated-normal noise.
    """
    noise = tf.truncated_normal(shape=shape, stddev=0.1)
    return tf.Variable(noise)


def bias_variable(shape):
    """Create a bias variable of the given shape, initialised to 0.1.

    The constant initial value is wrapped in a tf.Variable so that the
    optimizer can update it; a bare tf.constant would stay fixed at 0.1
    for the whole training run.

    Args:
        shape: shape of the bias tensor to create.

    Returns:
        A tf.Variable whose initial value is 0.1 everywhere.
    """
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)


def conv2d(x, W):
    """Apply a 2-D convolution of `x` with filter `W`.

    Uses a stride of 1 in every dimension and 'SAME' padding.
    """
    unit_strides = [1, 1, 1, 1]
    convolved = tf.nn.conv2d(x, filter=W, strides=unit_strides, padding='SAME')
    return convolved


def max_pool2x2(x):
    """Apply max pooling to `x` with a 2x2 window, stride 2 and 'SAME' padding."""
    window = [1, 2, 2, 1]
    step = [1, 2, 2, 1]
    pooled = tf.nn.max_pool(x, ksize=window, strides=step, padding='SAME')
    return pooled

定义tensor

In [5]:
# Input placeholder: a batch of flattened images, 784 values per row.
x = tf.placeholder(tf.float32, shape=[None, 784])
# Target placeholder: 10 values per row (the labels were loaded one-hot encoded).
y_ = tf.placeholder(tf.float32, shape=[None, 10])
# Reshape each flat 784-vector into a 28x28 image with a single channel
# so it can be fed to the convolution layers.
x_image = tf.reshape(x, shape=[-1, 28, 28, 1])

构建网络模型

In [6]:
# 第一层conv
w_conv1 = weights_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32])
h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1)
h_pool1 = max_pool2x2(h_conv1)

# 第二层conv
w_conv2 = weights_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2)
h_pool2 = max_pool2x2(h_conv2)

# 全连接层
# 需要flatten图像
w_fc1 = weights_variable([7*7*64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, shape=[-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)

# dropout层
keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob=keep_prob)

# 全连接
w_fc2 = weights_variable([1024, 10])
b_fc2 = bias_variable([10])
h_fc2 = tf.matmul(h_fc1_drop, w_fc2) + b_fc2
y = tf.nn.softmax(h_fc2)

定义损失函数

In [7]:
# Mean cross-entropy between the one-hot targets y_ and the network output.
# It is computed from the logits (h_fc2) with the fused op instead of
# -sum(y_ * log(softmax)): tf.log(y) becomes -inf (and the loss NaN) as soon
# as a softmax probability underflows to 0, whereas the fused op is
# numerically stable and yields the same per-example cross-entropy.
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=h_fc2))
# Minimise the loss with Adam, learning rate 1e-4.
train_step = tf.train.AdamOptimizer(1e-4).minimize(loss)

到这里就可以训练了，在此之前我们再定义一个accuracy

In [8]:
# Per-example boolean: True where the arg-max of the target y_ equals the
# arg-max of the network output y (i.e. the predicted class is correct).
prediction = tf.equal(tf.argmax(y_, axis=1), tf.argmax(y, axis=1))
# Accuracy is the mean of those booleans after casting them to float32.
accuracy = tf.reduce_mean(tf.cast(prediction, tf.float32))

开始训练

In [9]:
def training():
    """Initialise all variables and run 2001 training steps.

    Each step draws a mini-batch of 50 training examples and runs one
    optimizer step with a dropout keep probability of 0.5. Every 100 steps
    the accuracy on that same mini-batch is evaluated with dropout disabled
    (keep probability 1.0) and printed.

    The ops are run via .run() / .eval() without an explicit session, so a
    default session must be active when this is called.
    """
    init_op = tf.global_variables_initializer()
    init_op.run()

    for step in range(2001):
        batch = mnist.train.next_batch(batch_size=50)
        images, labels = batch[0], batch[1]

        train_feed = {x: images, y_: labels, keep_prob: 0.5}
        train_step.run(feed_dict=train_feed)

        # Only report progress on every 100th step.
        if step % 100 != 0:
            continue

        eval_feed = {x: images, y_: labels, keep_prob: 1.0}
        train_accuracy = accuracy.eval(feed_dict=eval_feed)
        print('After step %4d   accuracy is %8f' % (step, train_accuracy))

In [10]:
# Run the training loop and report how long it took (CPU and wall time).
%time training()

After step    0   accuracy is 0.040000
After step  100   accuracy is 0.840000
After step  200   accuracy is 0.940000
After step  300   accuracy is 0.880000
After step  400   accuracy is 0.980000
After step  500   accuracy is 0.920000
After step  600   accuracy is 0.980000
After step  700   accuracy is 0.980000
After step  800   accuracy is 0.900000
After step  900   accuracy is 1.000000
After step 1000   accuracy is 0.960000
After step 1100   accuracy is 0.960000
After step 1200   accuracy is 0.960000
After step 1300   accuracy is 0.980000
After step 1400   accuracy is 0.960000
After step 1500   accuracy is 1.000000
After step 1600   accuracy is 0.940000
After step 1700   accuracy is 1.000000
After step 1800   accuracy is 1.000000
After step 1900   accuracy is 1.000000
After step 2000   accuracy is 0.980000
CPU times: user 19min 48s, sys: 2min 22s, total: 22min 11s
Wall time: 6min 54s


最后看看测试集的accuracy

In [11]:
# Evaluate accuracy on the whole test set in a single feed, with dropout
# disabled (keep_prob = 1.0).
print('Test Accuracy', accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

Test Accuracy 0.9768
