# Saving and Loading **parts** of TF models
- How do you save a trained model and load only part of it into a new graph?
- How do you load an entire trained model and train a new, modified version of it?

In [1]:
import tensorflow as tf
import numpy as np
from sklearn.datasets import fetch_mldata

We will first train a single layer neural network on the mnist dataset and save that to a file. Then we'll load that pretrained layer into a new graph with a 2-layer neural network, as the first layer. This network will then be fine-tuned on the data.

In [2]:
mnist = fetch_mldata("MNIST original", data_home="~/data/mldata")
mnist["data"] = mnist["data"].astype(np.float32) / 255.
print(mnist["data"].mean())
mnist

0.13092543


{'DESCR': 'mldata.org dataset: mnist-original',
 'COL_NAMES': ['label', 'data'],
 'target': array([0., 0., 0., ..., 9., 9., 9.]),
 'data': array([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)}

In [3]:
# split the data into 2 halves
perm = np.random.permutation(mnist["target"].shape[0])
mnist_1 = {"data": mnist["data"][perm[0:35000]],
           "target": mnist["target"][perm[0:35000]]}
mnist_2 = {"data": mnist["data"][perm[35000:]],
           "target": mnist["target"][perm[35000:]]}

# split both into train/valid sets
perm_1 = np.random.permutation(mnist_1["target"].shape[0])
mnist_1_train = {"data": mnist_1["data"][perm_1[0:31500]],
                 "target": mnist_1["target"][perm_1[0:31500]]}
mnist_1_valid = {"data": mnist_1["data"][perm_1[31500:]],
                 "target": mnist_1["target"][perm_1[31500:]]}

perm_2 = np.random.permutation(mnist_2["target"].shape[0])
mnist_2_train = {"data": mnist_2["data"][perm_2[0:31500]],
                 "target": mnist_2["target"][perm_2[0:31500]]}
mnist_2_valid = {"data": mnist_2["data"][perm_2[31500:]],
                 "target": mnist_2["target"][perm_2[31500:]]}

In [4]:
def create_dataset(X_place, y_place, batch_size, num_epochs):
    dataset = tf.data.Dataset.from_tensor_slices((X_place, y_place))
    dataset = dataset.apply(
        tf.contrib.data.shuffle_and_repeat(5000, count=num_epochs))
    dataset = dataset.apply(
        tf.contrib.data.map_and_batch(lambda x, y: (x, tf.one_hot(tf.cast(y, tf.int32), 10)), 
                                      batch_size,
                                      num_parallel_batches=4))
    dataset = dataset.prefetch(4)
    return dataset    

## Shared variables
Variables have names that can be referenced when saving and restoring them. 
Since the `tf.train.Saver` object takes a list or dictionary of variables to save/restore, 
we can group the variables we want to share in a variable scope:
```python
with tf.variable_scope("shared", reuse=tf.AUTO_REUSE):
```
The variable scope `shared` can be used accessed later via `tf.get_collection(..., scope="shared")`.

(**NOTE**: In this case, the `reuse` argument simply tells tensorflow to reuse variables that already 
exist with the same name. This is only important when calling create_model more than once in the
same graph, not for restoring variables)

In [5]:
def create_model(version, iterator, mode):
    """
    Returns ops for (pred_labels, loss, metrics)
    """
    x, y = iterator.get_next()
    
    with tf.variable_scope("shared", reuse=tf.AUTO_REUSE):

        h = tf.layers.dense(x, 1000, activation=tf.nn.relu, name="dense_1")

    with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
        if version == 2:
            h = tf.layers.dense(h, 500, activation=tf.nn.relu, name="dense_2")

        logits = tf.layers.dense(h, 10, activation=tf.nn.relu, name="logits")

    crossent = tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=logits)
    
    loss = tf.reduce_sum(crossent)

    if mode == "TRAIN":
        opt = tf.train.GradientDescentOptimizer(learning_rate=0.01)
        update = opt.minimize(loss, name="update_op")
        metrics = None
        metrics_updates = None
        pred_labels = None
    # metrics
    elif mode == "EVAL":
        update = None
        total_loss, loss_update = tf.metrics.mean(values=loss)
        pred_labels = tf.argmax(input=logits, axis=-1)
        tgt_labels = tf.argmax(input=y, axis=-1)
        acc, acc_update = tf.metrics.accuracy(predictions=pred_labels,
                                            labels=tgt_labels)

        # summaries
        tf.summary.scalar("validation_loss", total_loss, collections=["eval"])
        tf.summary.scalar("accuracy", acc, collections=["eval"])
        for var in tf.trainable_variables():
            tf.summary.histogram(var.name, var, collections=["eval"])
    
        metrics = [total_loss, acc]
        metrics_updates = [loss_update, acc_update]
    
    return pred_labels, loss, update, metrics, metrics_updates

## Saving and restoring
The `train` function contains the logic for both saving models and restoring
(some) variables from a pretrained model. In this case, we train with one
version of the model, then train again with a different version of the model.

If `save_path` is specified, then variables from the scope `shared` will be
loaded from the checkpoint:
```python
if save_path is not None:
    loader = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="shared"))
    # ...
    loader.restore(sess, save_path)
```
This works because `create_model` will name the variables the same way. 

**NOTE**: Restoring variables will initialize them, but any variables that are not restored
must be initialized as normal. Here, all variables are initialized, and any that are
restored are restored after this initialization. **If variables are initialized after they
are restored, they will be overwritten.**

In [6]:
def train(mnist_train, mnist_valid, logdir, save_path=None):
    X_placeholder = tf.placeholder(mnist_train["data"].dtype, shape=mnist_train["data"].shape)
    y_placeholder = tf.placeholder(mnist_train["target"].dtype, shape=mnist_train["target"].shape)
    train_dataset = create_dataset(X_placeholder, y_placeholder, 32, 3)
    train_iterator = train_dataset.make_initializable_iterator()

    version = 1 if save_path is None else 2
    
    _, _, train_op, _, _ = create_model(version, train_iterator, "TRAIN")

    valid_X_placeholder = tf.placeholder(mnist_valid["data"].dtype, shape=mnist_valid["data"].shape)
    valid_y_placeholder = tf.placeholder(mnist_valid["target"].dtype, shape=mnist_valid["target"].shape)
    valid_dataset = create_dataset(valid_X_placeholder, valid_y_placeholder, 64, 1)
    valid_iterator = valid_dataset.make_initializable_iterator()
    
    pred_y_op, loss_op, _, metrics, metrics_updates = create_model(version, valid_iterator, "EVAL")
    loss_update_op = metrics_updates[0]
    acc_update_op = metrics_updates[1]

    init_op = tf.global_variables_initializer()
    local_init_op = tf.local_variables_initializer()
    eval_summaries = tf.summary.merge_all("eval")

    saver = tf.train.Saver()
    summary_writer = tf.summary.FileWriter(logdir, tf.get_default_graph())
    
    if save_path is not None:
        loader = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="shared"))

    with tf.Session() as sess:
        sess.run([init_op])
        
        if save_path is not None:
            loader.restore(sess, save_path)
        
        sess.run(train_iterator.initializer, 
                 feed_dict={X_placeholder: mnist_train["data"],
                            y_placeholder: mnist_train["target"]})
        step = 0
        while True:
            try:
                step += 1
                _ = sess.run(train_op)
                if step % 50 == 0:
                    # VALIDATION
                    sess.run(valid_iterator.initializer,
                             feed_dict={valid_X_placeholder: mnist_valid["data"],
                                        valid_y_placeholder: mnist_valid["target"]})
                    sess.run([local_init_op])
                    while True:
                        try:
                            avg_loss, acc = sess.run([loss_update_op, acc_update_op])
                        except tf.errors.OutOfRangeError:
                            print("Step %d: Loss %3.3f, Accuracy: %.4f" % 
                                 (step, avg_loss, acc))
                            summ = sess.run(eval_summaries)
                            summary_writer.add_summary(summ, step)
                            break
            except tf.errors.OutOfRangeError:
                print("Training done.")
                summary_writer.flush()
                summary_writer.close()
                break

        # save the model
        if version == 1:
            ckpt_path = saver.save(sess, "/tmp/models/model.ckpt")
            return ckpt_path
    return None

In [7]:
# train the first model
ckpt_path = train(mnist_1_train, mnist_1_valid, "/tmp/models/shared/model1/")

INFO:tensorflow:Summary name shared/dense_1/kernel:0 is illegal; using shared/dense_1/kernel_0 instead.
INFO:tensorflow:Summary name shared/dense_1/bias:0 is illegal; using shared/dense_1/bias_0 instead.
INFO:tensorflow:Summary name model/logits/kernel:0 is illegal; using model/logits/kernel_0 instead.
INFO:tensorflow:Summary name model/logits/bias:0 is illegal; using model/logits/bias_0 instead.
Step 50: Loss 113.589, Accuracy: 0.4329
Step 100: Loss 96.301, Accuracy: 0.5286
Step 150: Loss 79.204, Accuracy: 0.5934
Step 200: Loss 76.312, Accuracy: 0.6126
Step 250: Loss 84.846, Accuracy: 0.5677
Step 300: Loss 76.181, Accuracy: 0.6097
Step 350: Loss 62.318, Accuracy: 0.6797
Step 400: Loss 53.678, Accuracy: 0.7489
Step 450: Loss 47.602, Accuracy: 0.8074
Step 500: Loss 44.714, Accuracy: 0.8209
Step 550: Loss 46.082, Accuracy: 0.8117
Step 600: Loss 30.508, Accuracy: 0.8340
Step 650: Loss 27.289, Accuracy: 0.8489
Step 700: Loss 27.659, Accuracy: 0.8446
Step 750: Loss 27.034, Accuracy: 0.8469


In [8]:
# now create a new, 2-layer network and load the weights from the first
tf.reset_default_graph()

train(mnist_2_train, mnist_2_valid, "/tmp/models/shared/model2", ckpt_path)

INFO:tensorflow:Summary name shared/dense_1/kernel:0 is illegal; using shared/dense_1/kernel_0 instead.
INFO:tensorflow:Summary name shared/dense_1/bias:0 is illegal; using shared/dense_1/bias_0 instead.
INFO:tensorflow:Summary name model/dense_2/kernel:0 is illegal; using model/dense_2/kernel_0 instead.
INFO:tensorflow:Summary name model/dense_2/bias:0 is illegal; using model/dense_2/bias_0 instead.
INFO:tensorflow:Summary name model/logits/kernel:0 is illegal; using model/logits/kernel_0 instead.
INFO:tensorflow:Summary name model/logits/bias:0 is illegal; using model/logits/bias_0 instead.
INFO:tensorflow:Restoring parameters from /tmp/models/model.ckpt
Step 50: Loss 52.512, Accuracy: 0.6817
Step 100: Loss 51.127, Accuracy: 0.6891
Step 150: Loss 45.964, Accuracy: 0.7423
Step 200: Loss 34.847, Accuracy: 0.7891
Step 250: Loss 34.135, Accuracy: 0.7911
Step 300: Loss 34.662, Accuracy: 0.7897
Step 350: Loss 34.207, Accuracy: 0.7917
Step 400: Loss 35.034, Accuracy: 0.7880
Step 450: Loss 3