In [1]:
import tensorflow as tf
import numpy as np
from datetime import datetime

def reset_graph(seed=42):
    tf.reset_default_graph()
    tf.set_random_seed(seed)
    np.random.seed(seed)

# Some training configs
n_epochs = 10
batch_size = 200

# Data input

In [2]:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/")

Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz


# Get handles of the relevant ops
* To check what are the available operations
```
for op in tf.get_default_graph().get_operations():
    print(op.name)
```
For example
```
evaluate/Mean
train/GradientDescent
```

In [3]:
saver = tf.train.import_meta_graph("./my_model_final.ckpt.meta")
#for op in tf.get_default_graph().get_operations():
#    print(op.name)

In [4]:
# Placeholder for input data 
X = tf.get_default_graph().get_tensor_by_name("X:0")
y = tf.get_default_graph().get_tensor_by_name("y:0")

# Relevant ops for training/evaluate
accuracy = tf.get_default_graph().get_tensor_by_name("evaluate/accuracy:0")
training_op = tf.get_default_graph().get_operation_by_name("train/minimizeOptimizer")

# Continue the training on _exact_ pretrained model
* As the header says, the training process continues **exact** model that was trained earlier. 
* Change the batch size, but continue to train on the same data set. 
* Same prediction output size.

In [5]:
with tf.Session() as sess:
    # VERY IMPORTANT: Don't init() but instead restore from last model
    saver.restore(sess, "./my_model_final.ckpt")

    for epoch in range(n_epochs):
        for iteration in range(mnist.train.num_examples // batch_size):
            X_batch, y_batch = mnist.train.next_batch(batch_size)
            sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
        accuracy_val = accuracy.eval(feed_dict={X: mnist.test.images,
                                                y: mnist.test.labels})
        print(epoch, "Test accuracy:", accuracy_val)

    save_path = saver.save(sess, "./my_new_model_final.ckpt") 

INFO:tensorflow:Restoring parameters from ./my_model_final.ckpt
0 Test accuracy: 0.9732
1 Test accuracy: 0.9733
2 Test accuracy: 0.9734
3 Test accuracy: 0.9739
4 Test accuracy: 0.9732
5 Test accuracy: 0.9734
6 Test accuracy: 0.9729
7 Test accuracy: 0.9736
8 Test accuracy: 0.9736
9 Test accuracy: 0.9733


# Next, modify the old network
* Copy fewer number of layers than the full set: Only reuse `hidden_1, hidden_2` and **REPLACE** `hidden_3` by `hidden_3b` (eg. change the number of weights in that layers - effectively with new set of weights) 
    - Specify from which point in the last graph you would like to continue on. In this example, hidden_2 output, or hidden_3 output. 
    - Pay attention on `/Relu:0` as the output of Relu layer. 

* Only train particular layers (in the end) but freeze the lower layers

In [6]:
n_hidden4 = 15  # new layer
n_outputs = 10  # new layer
learning_rate = 0.001
n_epochs = 100

In [7]:
reset_graph()
saver = tf.train.import_meta_graph("./my_model_final.ckpt.meta")

# Placeholder for input data 
X = tf.get_default_graph().get_tensor_by_name("X:0")
y = tf.get_default_graph().get_tensor_by_name("y:0")

# Relevant ops for training/evaluate
accuracy = tf.get_default_graph().get_tensor_by_name("evaluate/accuracy:0")
training_op = tf.get_default_graph().get_operation_by_name("train/minimizeOptimizer")

In [8]:
# Add a new hidden layer on top
output_hidden_3 = tf.get_default_graph().get_tensor_by_name("dnn/hidden_3/Relu:0")
new_hidden4 = tf.layers.dense(output_hidden_3, n_hidden4, activation=tf.nn.relu, name="new_hidden4")
new_logits = tf.layers.dense(new_hidden4, n_outputs, name="new_outputs")

with tf.name_scope("new_loss"):
    xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=new_logits)
    loss = tf.reduce_mean(xentropy, name="loss")

with tf.name_scope("new_eval"):
    correct = tf.nn.in_top_k(new_logits, y, 1)
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32), name="accuracy")

with tf.name_scope("new_train"):
    train_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="hidden[234]|new_outputs")
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    training_op = optimizer.minimize(loss, var_list=train_variables)

In [9]:
init = tf.global_variables_initializer()
new_saver = tf.train.Saver()

In [10]:
now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
root_logdir = "tf_logs"
logdir = "{}/run-{}/".format(root_logdir, now)
loss_summary = tf.summary.scalar('loss', loss)
file_writer = tf.summary.FileWriter(logdir, tf.get_default_graph())

In [11]:
with tf.Session() as sess:
    init.run()
    saver.restore(sess, "./my_model_final.ckpt")

    for epoch in range(n_epochs):
        for iteration in range(mnist.train.num_examples // batch_size):
            X_batch, y_batch = mnist.train.next_batch(batch_size)
            sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
        accuracy_val = accuracy.eval(feed_dict={X: mnist.test.images,
                                                y: mnist.test.labels})
        if epoch%10==0: 
            print(epoch, "Test accuracy:", accuracy_val)
            summary_str = loss_summary.eval(feed_dict={X: X_batch, y: y_batch})  # Record training loss
            step = epoch * batch_size + iteration
            file_writer.add_summary(summary_str, step)

    save_path = new_saver.save(sess, "./my_new_model_final.ckpt")
file_writer.close()

INFO:tensorflow:Restoring parameters from ./my_model_final.ckpt
0 Test accuracy: 0.1845
10 Test accuracy: 0.841
20 Test accuracy: 0.8851
30 Test accuracy: 0.8976
40 Test accuracy: 0.9061
50 Test accuracy: 0.9103
60 Test accuracy: 0.9129
70 Test accuracy: 0.9146
80 Test accuracy: 0.9171
90 Test accuracy: 0.9183


### Check if the freezing was actually frozen

In [12]:
tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

[<tf.Variable 'dnn/hidden_1/kernel:0' shape=(784, 100) dtype=float32_ref>,
 <tf.Variable 'dnn/hidden_1/bias:0' shape=(100,) dtype=float32_ref>,
 <tf.Variable 'dnn/hidden_2/kernel:0' shape=(100, 50) dtype=float32_ref>,
 <tf.Variable 'dnn/hidden_2/bias:0' shape=(50,) dtype=float32_ref>,
 <tf.Variable 'dnn/hidden_3/kernel:0' shape=(50, 30) dtype=float32_ref>,
 <tf.Variable 'dnn/hidden_3/bias:0' shape=(30,) dtype=float32_ref>,
 <tf.Variable 'dnn/outputs/kernel:0' shape=(30, 10) dtype=float32_ref>,
 <tf.Variable 'dnn/outputs/bias:0' shape=(10,) dtype=float32_ref>,
 <tf.Variable 'new_hidden4/kernel:0' shape=(30, 15) dtype=float32_ref>,
 <tf.Variable 'new_hidden4/bias:0' shape=(15,) dtype=float32_ref>,
 <tf.Variable 'new_outputs/kernel:0' shape=(15, 10) dtype=float32_ref>,
 <tf.Variable 'new_outputs/bias:0' shape=(10,) dtype=float32_ref>]

### Visualize the plot of the new checkpoint
* Both `new_eval` and `new_loss` are on top of `logits` (represented by `new_outputs`)
* `new_train` works on `new_loss`, `new_outputs` and `hidden_4`. 
![](img/node_io.png)

![Some note](img/tf_transfer.png)