# 初始化

In [0]:
#@markdown - **挂载** 
from google.colab import drive
drive.mount('GoogleDrive')

In [0]:
# #@markdown - **卸载**
# !fusermount -u GoogleDrive

# 代码区

In [0]:
#@title 逻辑回归 { display-mode: "both" }
# logistic regression
import numpy as np
import os
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

tf.logging.set_verbosity(tf.logging.ERROR)

#@markdown - **只获取0或1的图像**
def extraction_fn(data): 
    index_list = []
    for idx in range(data.shape[0]):
        if data[idx] == 0 or data[idx] == 1:
            index_list.append(idx)
    return index_list

In [2]:
events_path = 'GoogleDrive/My Drive/Colab Notebooks/Tensorboard'
checkpoints_path = 'GoogleDrive/My Drive/Colab Notebooks/Checkpoints'

max_num_checkpoints = 3 #@param {type: "integer"}
num_classes = 2 #@param {type: "integer"}
batch_size = 64 #@param {type: "integer"}
num_epochs = 100 #@param {type: "integer"}

learning_rate = 5e-3 #@param {type: "number"}

print('The events_path is: ', events_path)
print('The checkpoints_path is: ', checkpoints_path)

The events_path is:  GoogleDrive/My Drive/Colab Notebooks/Tensorboard
The checkpoints_path is:  GoogleDrive/My Drive/Colab Notebooks/Checkpoints


In [3]:
#@markdown - **所需训练图像和标签的获取**

mnist = input_data.read_data_sets("sample_data/MNIST", reshape=True, one_hot=False)
data = {}
data['train_image'] = mnist.train.images
data['train_label'] = mnist.train.labels
data['test_image'] = mnist.test.images
data['test_label'] = mnist.test.labels

index_list_train = extraction_fn(data['train_label'])
index_list_test = extraction_fn(data['test_label'])

data['train_image'] = mnist.train.images[index_list_train]
data['train_label'] = mnist.train.labels[index_list_train]
data['test_image'] = mnist.test.images[index_list_test]
data['test_label'] = mnist.test.labels[index_list_test]

data['train_image_label'] = np.c_[data['train_image'], data['train_label']]
num_samples, num_features = data['train_image'].shape

Extracting sample_data/MNIST/train-images-idx3-ubyte.gz
Extracting sample_data/MNIST/train-labels-idx1-ubyte.gz
Extracting sample_data/MNIST/t10k-images-idx3-ubyte.gz
Extracting sample_data/MNIST/t10k-labels-idx1-ubyte.gz


In [0]:
#@markdown - **网络的设置**

with tf.name_scope('Inputs'):
    image_place = tf.placeholder(tf.float32, shape=[None, num_features], name='images')
    label_place = tf.placeholder(tf.int32, shape=[None,], name='labels')
    label_one_hot = tf.one_hot(label_place, depth=num_classes, axis=-1)

with tf.name_scope('Loss'):
    logits = tf.contrib.layers.fully_connected(inputs=image_place, num_outputs=num_classes, scope='fc')
    loss_tensor = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=label_one_hot), name='loss_tensor')

with tf.name_scope('Accuracy'):
    predition = tf.equal(tf.argmax(logits, 1), tf.arg_max(label_one_hot, 1))
    accuracy = tf.reduce_mean(tf.cast(predition, tf.float32), name='accuracy')

with tf.name_scope('Train'):
    train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss_tensor)

In [0]:
#@markdown - **summary 的设置**
loss_sum = tf.summary.scalar('loss_summary', loss_tensor) # 关于loss的summary
acc_sum = tf.summary.scalar('acc_summary', accuracy) # 关于accuracy的summary
image_sum = tf.summary.image('input_images', tf.reshape(image_place, [-1, 28, 28, 1]), max_outputs=3) #选取3个输入图像展示

train_sum = tf.summary.merge_all()

In [0]:
#@markdown - **saver 的设置**
saver = tf.train.Saver(max_to_keep=max_num_checkpoints)
max_acc = 99.2 # 高于此精度的模型将被saved
min_cross = 0.2

In [7]:
#@markdown - **训练网络**

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    checkpoints_prefix = 'model.ckpt'

    writer = tf.summary.FileWriter(events_path, sess.graph)

    for epoch_num in range(num_epochs):
      
        #@markdown - **提供两种基于 mini-batch 的训练方式**
        # 该段用于整个batch的按序提取及训练，速度较慢，梯度波动较小
        # num_batches = int(num_samples/batch_size)
        # for batch_num in range(num_batches):
        #     index_list_start = batch_num*batch_size
        #     index_list_end = (batch_num+1)*batch_size
        #     image_batch = data['train_image'][index_list_start:index_list_end,:]
        #     label_batch = data['train_label'][index_list_start:index_list_end]
        #     batch_loss, batch_accuracy, _ = sess.run([loss_tensor, accuracy, train_op],
        #                                             feed_dict={image_place: image_batch, label_place: label_batch})
        #该段用于随机batch的训练，速度较快，梯度波动较大
        #--------------------------------------------------------------------------------------------------------
        np.random.shuffle(data['train_image_label'])
        image_batch = data['train_image_label'][:batch_size,:-1]
        label_batch = data['train_image_label'][:batch_size,-1]
        #--------------------------------------------------------------------------------------------------------

        batch_loss, batch_accuracy, _ = sess.run([loss_tensor, accuracy, train_op],
                                                feed_dict={image_place: image_batch, label_place: label_batch})
        if (epoch_num + 1) % 5 == 0 or (epoch_num + 1) == 1:
            batch_accuracy *= 100
            print("Epoch " + str(epoch_num + 1) + ", Training Loss is " + \
                  "{:.5f}, ".format(batch_loss) + "batch_accuracy is " + "{:.2f}%".format(batch_accuracy))

        if (batch_loss <= min_cross) & (batch_accuracy > max_acc): # 按照要求保存网络模型
            min_cross = batch_loss
            max_acc = batch_accuracy
            saver.save(sess, os.path.join(checkpoints_path, checkpoints_prefix), global_step=epoch_num+1)
            print("Model restored...")
        rs = sess.run(train_sum, feed_dict={image_place: image_batch, label_place: label_batch})
        writer.add_summary(rs, epoch_num)
    
    # 测试集精度
    test_accuracy = sess.run([accuracy], feed_dict={image_place: data['test_image'],
                                                    label_place: data['test_label']})
    test_acc = test_accuracy[0]*100
    print("Final Test Accuracy is %.2f%%" % test_acc)
writer.close()
sess.close()

Epoch 1, Training Loss is 0.61157, batch_accuracy is 78.12%
Epoch 5, Training Loss is 0.21343, batch_accuracy is 96.88%
Epoch 10, Training Loss is 0.06927, batch_accuracy is 100.00%
Model restored...
Epoch 15, Training Loss is 0.03819, batch_accuracy is 100.00%
Epoch 20, Training Loss is 0.02246, batch_accuracy is 100.00%
Epoch 25, Training Loss is 0.02498, batch_accuracy is 100.00%
Epoch 30, Training Loss is 0.01780, batch_accuracy is 100.00%
Epoch 35, Training Loss is 0.02615, batch_accuracy is 98.44%
Epoch 40, Training Loss is 0.00894, batch_accuracy is 100.00%
Epoch 45, Training Loss is 0.01427, batch_accuracy is 100.00%
Epoch 50, Training Loss is 0.00615, batch_accuracy is 100.00%
Epoch 55, Training Loss is 0.01431, batch_accuracy is 100.00%
Epoch 60, Training Loss is 0.00532, batch_accuracy is 100.00%
Epoch 65, Training Loss is 0.00850, batch_accuracy is 100.00%
Epoch 70, Training Loss is 0.00365, batch_accuracy is 100.00%
Epoch 75, Training Loss is 0.01522, batch_accuracy is 100