## 保存&恢复模型
设置 `require_improvement = 1000`（提前停止的阈值）。

如果连续 1000 次迭代内验证集准确率都没有提高，则停止优化；每当验证集准确率创下新高时，就把模型保存到 checkpoint，之后可以恢复这个最佳模型。

In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from sklearn.metrics import confusion_matrix
import time
from datetime import timedelta
import math
import os

# Use PrettyTensor to simplify Neural Network construction.
import prettytensor as pt

In [2]:
from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST data-set from data/MNIST/ with one-hot encoded labels.
data = input_data.read_data_sets('data/MNIST/', one_hot=True)

Extracting data/MNIST/train-images-idx3-ubyte.gz
Extracting data/MNIST/train-labels-idx1-ubyte.gz
Extracting data/MNIST/t10k-images-idx3-ubyte.gz
Extracting data/MNIST/t10k-labels-idx1-ubyte.gz


In [3]:
# The labels are one-hot encoded, so the index of the largest element in
# each row is the integer class-number. Store these for the test-set and
# the validation-set; they are compared with the predictions later on.
data.test.cls = np.argmax(data.test.labels, axis=1)
data.validation.cls = np.argmax(data.validation.labels, axis=1)

In [4]:
# Peek at the first ten class-numbers of the validation-set.
data.validation.cls[0:10]

array([5, 0, 4, 1, 9, 2, 1, 3, 1, 4], dtype=int64)

In [5]:
# MNIST images are square, 28 pixels along each side.
img_size = 28

# (height, width) tuple, used to reshape flat arrays back into images.
img_shape = (img_size, img_size)

# Length of the one-dimensional array that stores a single image.
img_size_flat = img_size * img_size

# Gray-scale images have a single colour channel.
num_channels = 1

# One class for each of the 10 digits.
num_classes = 10

In [6]:
def plot_images(images, cls_true, cls_pred=None):
    """Show nine images in a 3x3 grid, each labelled with its class.

    Args:
        images: exactly nine images; each is reshaped to img_shape
            before it is drawn.
        cls_true: the nine true class-numbers.
        cls_pred: optional predicted class-numbers. When given, the
            prediction is shown next to the true class in each label.
    """
    assert len(images) == len(cls_true) == 9

    fig, axes = plt.subplots(3, 3)
    fig.subplots_adjust(hspace=0.3, wspace=0.3)

    for index, (ax, image) in enumerate(zip(axes.flat, images)):
        ax.imshow(image.reshape(img_shape), cmap='binary')

        # The label shows the prediction only when one was supplied.
        xlabel = (
            "True: {0}".format(cls_true[index])
            if cls_pred is None
            else "True: {0}, Pred: {1}".format(cls_true[index], cls_pred[index])
        )
        ax.set_xlabel(xlabel)

        # Hide the tick marks; they carry no information for images.
        ax.set_xticks([])
        ax.set_yticks([])

    plt.show()

In [7]:
# Placeholder for the input images; each row is one flattened image.
x = tf.placeholder(tf.float32, shape=[None, img_size_flat], name='x')

# Reshape to [num_images, img_size, img_size, num_channels] for the
# convolutional layers.
x_image = tf.reshape(x, shape=[-1, img_size, img_size, num_channels])

# Placeholder for the true labels, one row of num_classes values per image.
y_true = tf.placeholder(tf.float32, shape=[None, num_classes], name='y_true')

# True class-number of each image. `axis` replaces the deprecated
# `dimension` keyword of tf.argmax.
y_true_cls = tf.argmax(y_true, axis=1)

In [8]:
# Wrap the image tensor in a PrettyTensor object so that the network
# layers can be chained onto it.
x_pretty = pt.wrap(x_image)

In [9]:
# Build the network as one chain of layers. The defaults scope sets ReLU
# as the default activation function for the layers created inside it.
with pt.defaults_scope(activation_fn=tf.nn.relu):
    y_pred, loss = (
        x_pretty
        .conv2d(kernel=5, depth=16, name='layer_conv1')
        .max_pool(kernel=2, stride=2)
        .conv2d(kernel=5, depth=36, name='layer_conv2')
        .max_pool(kernel=2, stride=2)
        .flatten()
        .fully_connected(size=128, name='layer_fc1')
        .softmax_classifier(num_classes=num_classes, labels=y_true)
    )

In [10]:
def get_weights_variable(layer_name):
    """Return the existing variable named 'weights' in scope layer_name.

    Args:
        layer_name: name of the variable scope that holds the weights,
            e.g. one of the layer names given when the network was built.

    Returns:
        The variable that tf.get_variable() finds under that scope.
    """
    # Opening the scope with reuse=True makes tf.get_variable() fetch the
    # variable that already exists instead of creating a new one. This is
    # a slightly awkward use of a function meant for another purpose.
    with tf.variable_scope(layer_name, reuse=True):
        return tf.get_variable('weights')

In [11]:
# Training operation: Adam with a learning rate of 1e-4, minimising the loss.
optimizer = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss)

In [12]:
# Predicted class-number of each image: the index of the largest output.
# `axis` replaces the deprecated `dimension` keyword of tf.argmax.
y_pred_cls = tf.argmax(y_pred, axis=1)

# Boolean per image: does the predicted class equal the true class?
correct_prediction = tf.equal(y_pred_cls, y_true_cls)

# Accuracy is the mean of the booleans after casting them to floats.
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

## Helper function

In [13]:
# Number of images fed to the network at once when predicting classes.
batch_size = 256

def predict_cls(images, labels, cls_true):
    """Predict the class-number of every image, working in batches.

    Args:
        images: array of input images, sliced as images[i:j, :].
        labels: array of labels matching images, sliced the same way and
            fed to the y_true placeholder.
        cls_true: true class-numbers, one per image.

    Returns:
        A tuple (correct, cls_pred). cls_pred is an integer array with
        the predicted class-number of each image; correct is the
        element-wise result of cls_true == cls_pred.
    """
    num_images = len(images)

    # The builtin int is used as dtype: the np.int alias it replaces was
    # removed in NumPy 1.24 and referred to the very same type.
    cls_pred = np.zeros(shape=num_images, dtype=int)

    # Feed the images in slices of at most batch_size so that large
    # data-sets do not have to be processed in one go.
    for start in range(0, num_images, batch_size):
        end = min(start + batch_size, num_images)

        feed_dict = {x: images[start:end, :],
                     y_true: labels[start:end, :]}

        cls_pred[start:end] = session.run(y_pred_cls, feed_dict=feed_dict)

    correct = (cls_true == cls_pred)

    return correct, cls_pred

In [14]:
def predict_cls_test():
    """Predict the classes of the test-set; see predict_cls() for the result."""
    test_set = data.test
    return predict_cls(test_set.images, test_set.labels, test_set.cls)

In [15]:
def predict_cls_validation():
    """Predict the classes of the validation-set; see predict_cls() for the result."""
    validation_set = data.validation
    return predict_cls(validation_set.images,
                       validation_set.labels,
                       validation_set.cls)

In [16]:
def cls_accuray(correct):
    """Compute the classification accuracy from a boolean array.

    Args:
        correct: array of booleans, one per image, for example
            [True, False, True, True].

    Returns:
        A tuple (acc, correct_sum): the fraction of True entries as a
        float, and the number of True entries.
    """
    num_correct = correct.sum()
    return float(num_correct) / len(correct), num_correct

In [17]:
def validation_accuracy():
    """Return (accuracy, number correct) for the whole validation-set."""
    # Only the boolean array is needed here, not the predicted classes.
    return cls_accuray(predict_cls_validation()[0])

## Saver

In [18]:
# Saver object, used below to write the variables to a checkpoint file
# and to restore them from it.
saver = tf.train.Saver()

In [19]:
# Directory in which the checkpoint is stored.
save_dir = 'checkpoints/'

# Create the directory if it does not exist yet.
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

In [20]:
# Checkpoint path; optimize() saves here whenever the validation accuracy
# improves, so it holds the variables with the best validation accuracy.
save_path = os.path.join(save_dir, 'best_validation')

In [21]:
# TensorFlow session used to run every operation in this notebook.
session = tf.Session()

In [22]:
def init_variables():
    """Run the global variables initializer in the session.

    Kept as a function because it is called again later on to reset the
    variables before restoring them from the checkpoint.
    """
    init_op = tf.global_variables_initializer()
    session.run(init_op)

In [23]:
# Initialise the variables before the first optimisation run.
init_variables()

In [24]:
# Number of training images used in each optimisation iteration.
train_batch_size = 64

# Best validation accuracy seen so far; updated by optimize().
best_validation_accuracy = 0.0

# Value of total_iterations at the last improvement in validation accuracy.
last_improvement = 0

# Stop optimising when more than this many iterations have passed since
# the last improvement.
require_improvement = 1000

In [34]:
# Number of optimisation iterations performed so far, counted across all
# calls to optimize().
total_iterations = 0

def optimize(num_iterations):
    """Run up to num_iterations training iterations with early stopping.

    Every 100 iterations (counted over all calls) and on the last
    iteration of this call, the accuracy on the training batch and on the
    whole validation-set is measured and printed. When the validation
    accuracy beats the best one seen so far, the variables are saved to
    save_path. The loop stops early once more than require_improvement
    iterations have passed since the last improvement.

    Args:
        num_iterations: maximum number of iterations to run in this call.

    Side effects:
        Updates the globals total_iterations, best_validation_accuracy
        and last_improvement, writes the checkpoint, and prints progress
        and the time used.
    """
    global total_iterations
    global best_validation_accuracy
    global last_improvement

    start_time = time.time()

    for i in range(num_iterations):

        # Counted up front because the value is needed several times below.
        total_iterations += 1

        # Fetch a batch of training images and their true labels.
        x_batch, y_true_batch = data.train.next_batch(train_batch_size)

        feed_dict_train = {x: x_batch,
                           y_true: y_true_batch}

        # One optimisation step on this batch.
        session.run(optimizer, feed_dict=feed_dict_train)

        # Report every 100 iterations and on the last iteration of the call.
        if (total_iterations % 100 == 0) or (i == (num_iterations - 1)):

            # Accuracy on the training batch that was just used.
            acc_train = session.run(accuracy, feed_dict=feed_dict_train)

            # Accuracy on the whole validation-set, computed separately.
            acc_validation, _ = validation_accuracy()

            if acc_validation > best_validation_accuracy:

                best_validation_accuracy = acc_validation

                # Remember the iteration at which the improvement happened.
                last_improvement = total_iterations

                # Save the model's variables to file.
                saver.save(sess=session, save_path=save_path)

                # Marks the log line as an improvement.
                improved_str = '*'
            else:
                improved_str = ''

            # Status-message for printing.
            msg = "Iter: {0:>6}, Train-Batch Accuracy: {1:>6.1%}, Validation Acc: {2:>6.1%} {3}"

            # Print the cumulative iteration count. It is the counter that
            # drives both the logging interval and the early stopping, so
            # the log stays correct across repeated calls to optimize().
            print(msg.format(total_iterations, acc_train, acc_validation, improved_str))

        # No improvement within the required number of iterations:
        # further optimisation is not expected to help.
        if total_iterations - last_improvement > require_improvement:
            print("No improvement found in a while, stopping optimization.")

            # Break out from the for-loop.
            break

    end_time = time.time()

    time_diff = end_time - start_time

    # Print the time-usage.
    print("Time usage: " + str(timedelta(seconds=int(round(time_diff)))))

In [26]:
def print_test_accuracy():
    """Print the classification accuracy on the test-set.

    The line shows the accuracy as a percentage, followed by the number
    of correctly classified images and the total number of images.
    """
    correct, _ = predict_cls_test()
    acc, num_correct = cls_accuray(correct)

    print("Accuracy on Test-Set: {0:.1%} ({1} / {2})".format(
        acc, num_correct, len(correct)))

In [27]:
# Test-set accuracy before any optimisation has been run.
print_test_accuracy()

Accuracy on Test-Set: 10.8% (1084 / 10000)


In [35]:
# Run 100 optimisation iterations; the checkpoint is written whenever the
# validation accuracy improves.
optimize(num_iterations=100)

Iter:    100, Train-Batch Accuracy:  92.2%, Validation Acc:  92.9% *
Time usage: 0:00:15


In [36]:
# Run the initializer again, overwriting the trained variable values that
# are held in the session.
init_variables()

In [37]:
# Test-set accuracy after the variables were re-initialised.
print_test_accuracy()

Accuracy on Test-Set: 10.6% (1056 / 10000)


重新初始化变量之后，训练得到的权重被新的初始值覆盖，测试集准确率变低了，回落到迭代之前的水平（约 10%，相当于在 10 个类别中随机猜测）。

In [38]:
# Restore the variables that were saved to the checkpoint earlier.
saver.restore(sess=session, save_path=save_path)

INFO:tensorflow:Restoring parameters from checkpoints/best_validation


In [39]:
print_test_accuracy()
# The accuracy is back above 90% after restoring the checkpoint.

Accuracy on Test-Set: 93.1% (9309 / 10000)
