## 1. Save a Trained TensorFlow Model

In [1]:
import tensorflow as tf
import numpy as np
import math
from tensorflow.examples.tutorials.mnist import input_data

# Start from an empty default graph so that stale ops/variables from an
# earlier run of the notebook do not leak into the model built below.
tf.reset_default_graph()

# Load MNIST from the current directory ('.') with one-hot encoded labels.
# The returned object exposes the .train, .validation and .test splits that
# the following cells train and evaluate on.
mnist = input_data.read_data_sets(train_dir='.', one_hot=True)

Extracting ./train-images-idx3-ubyte.gz
Extracting ./train-labels-idx1-ubyte.gz
Extracting ./t10k-images-idx3-ubyte.gz
Extracting ./t10k-labels-idx1-ubyte.gz


In [2]:
# Rebuild the model on a clean default graph so this cell is idempotent:
# re-running it replaces the model instead of appending duplicate variables
# (Variable_2, Variable_3, ...) that the Saver below would also checkpoint.
tf.reset_default_graph()

# The path for saving the checkpoint
save_file = './train_model.ckpt'

# Training parameters
learning_rate = 0.001
n_input = 784    # pixels per flattened MNIST image
n_class = 10     # number of MNIST classes (digits 0-9)
batch_size = 128
n_epochs = 100
valid_every = 10  # report validation accuracy every this many epochs
seed = 42         # makes weight init and mini-batch order reproducible

# The graph-level seed lives on the default graph, so it must be set after
# reset_default_graph() and before any random op is created.
tf.set_random_seed(seed)
# mnist.train.next_batch() shuffles with NumPy's global RNG, so seed it too.
np.random.seed(seed)

# 1. Features and labels (batch dimension left open with None)
features = tf.placeholder(tf.float32, [None, n_input])
labels = tf.placeholder(tf.float32, [None, n_class])

# 2. Weights and bias, randomly initialised
weights = tf.Variable(tf.random_normal([n_input, n_class]))
bias = tf.Variable(tf.random_normal([n_class]))

# 3. Logits: xW + b
logits = tf.add(tf.matmul(features, weights), bias)

# 4. Cost (mean softmax cross-entropy) and plain gradient-descent optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(cost)

# 5. Calculate the accuracy: fraction of rows whose arg-max prediction
#    matches the arg-max of the one-hot label
correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Create the saver only after all variables exist; with no var_list it
# saves every variable in the graph.
saver = tf.train.Saver()

# Mini-batches per epoch; loop-invariant, so compute it once.
total_batch = math.ceil(mnist.train.num_examples / batch_size)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for epoch in range(1, n_epochs + 1):
        for _ in range(total_batch):
            batch_feature, batch_label = mnist.train.next_batch(batch_size)
            sess.run(optimizer, feed_dict={features: batch_feature, labels: batch_label})

        if epoch % valid_every == 0:
            valid_accuracy = sess.run(accuracy,
                                      feed_dict={features: mnist.validation.images,
                                                 labels: mnist.validation.labels})
            print('Epoch: {:<3} - Validation Accuracy: {}'.format(epoch, valid_accuracy))

    # Save the trained variables; this must happen inside the session that holds them.
    saver.save(sess, save_file)
    print('Train model already saved')

Epoch: 10  - Validation Accuracy: 0.3034000098705292
Epoch: 20  - Validation Accuracy: 0.44040000438690186
Epoch: 30  - Validation Accuracy: 0.5296000242233276
Epoch: 40  - Validation Accuracy: 0.5824000239372253
Epoch: 50  - Validation Accuracy: 0.6287999749183655
Epoch: 60  - Validation Accuracy: 0.6593999862670898
Epoch: 70  - Validation Accuracy: 0.6826000213623047
Epoch: 80  - Validation Accuracy: 0.7016000151634216
Epoch: 90  - Validation Accuracy: 0.7174000144004822
Epoch: 100 - Validation Accuracy: 0.7293999791145325
Train model already saved


## 2. Load the Trained Model

In [3]:
# Restore the checkpoint into a brand-new session and evaluate on the test set.
#
# The graph built in the previous cell is still the default graph (nothing is
# reset here), so its existing `saver` is reused. Creating another
# tf.train.Saver() would add a fresh set of save/restore ops to the graph every
# time this cell runs.
#
# No initializer is run in this session: the variable values come only from
# `save_file`, so the reported accuracy reflects the saved checkpoint.
#
# In a separate process you would first rebuild the same graph (or load it,
# e.g. via tf.train.import_meta_graph) and create a tf.train.Saver() before
# calling restore().
with tf.Session() as sess:
    saver.restore(sess, save_file)

    test_accuracy = sess.run(accuracy,
                             feed_dict={features: mnist.test.images,
                                        labels: mnist.test.labels})
    print("Test accuracy: {}".format(test_accuracy))

Test accuracy: 0.7297999858856201
