References:
- optimistic_restore: https://github.com/tensorflow/tensorflow/issues/1823, https://github.com/tensorflow/tensorflow/issues/312
- get_uninit_vars: http://stackoverflow.com/questions/35164529/in-tensorflow-is-there-any-way-to-just-initialize-uninitialised-variables

# Partial restore + partial initialization

In [1]:
import tensorflow as tf

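### Only `a` is saved and restored

The checkpoint contains `a` (the global step) but not `b`, so after restoring into a fresh graph `b` is still uninitialized.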

In [2]:
tf.reset_default_graph()
a = tf.train.get_or_create_global_step()
b = tf.Variable(3)
sess = tf.Session()

# Checkpoint only `a` (the global step); `b` is deliberately left out.
saver = tf.train.Saver([a])

# Initialize every variable so both have a value before saving.
sess.run(tf.global_variables_initializer())

# Save
saver.save(sess, save_path='./restore_reinitialize')

print(sess.run([a, b]))

# This session is not reused: later cells reset the graph and open a new one,
# so release its resources instead of leaking it when `sess` is rebound.
sess.close()

[0, 3]


In [3]:
def optimistic_restore(sess, save_path, restore_vars=None, name=None):
    """Restore every variable found in a checkpoint with a matching shape.

    Variables that are missing from the checkpoint, or whose shape differs
    from the saved one, are left untouched. They must still be initialized
    separately.

    Reference: https://github.com/tensorflow/tensorflow/issues/312

    Args:
        sess: `tf.Session` to restore into.
        save_path: Checkpoint path prefix, e.g. './restore_reinitialize'.
        restore_vars: Optional iterable of candidate variables. Defaults to
            `tf.global_variables()`. Only candidates present in the checkpoint
            with an identical shape are restored.
        name: Optional name for the internal `tf.train.Saver`.

    Returns:
        List of the variables that were restored (empty if none matched).
    """
    reader = tf.train.NewCheckpointReader(save_path)

    # Dict: checkpoint variable name => shape, e.g.
    # {'global_step': [],
    #  'resnet_v1_101/block1/unit_1/bottleneck_v1/conv1/BatchNorm/beta': [64], .. }
    saved_shapes = reader.get_variable_to_shape_map()

    if restore_vars is None:
        restore_vars = tf.global_variables()

    # Checkpoint keys are the tensor name without the ':0' output suffix,
    # e.g. 'global_step:0' -> 'global_step'. Sort by that name, because
    # checkpoint entries are already sorted by name.
    candidates = sorted(
        ((var.name.split(':')[0], var) for var in restore_vars),
        key=lambda pair: pair[0])

    matched = [var for saved_name, var in candidates
               if saved_name in saved_shapes
               and var.get_shape().as_list() == saved_shapes[saved_name]]

    # tf.train.Saver raises ValueError("No variables to save") for an empty
    # var_list, so skip the restore when nothing in the checkpoint matches.
    if not matched:
        return matched

    saver = tf.train.Saver(matched, name=name)
    saver.restore(sess, save_path)
    return matched

In [4]:
tf.reset_default_graph()
a = tf.train.get_or_create_global_step()
b = tf.Variable(3)
sess = tf.Session()

# Only `global_step` is in the checkpoint, so `b` stays uninitialized and
# reading it is expected to fail. Catch just that error so that any other
# failure still surfaces instead of being printed and ignored.
optimistic_restore(sess, './restore_reinitialize')
try:
    print(sess.run([a, b]))
except tf.errors.FailedPreconditionError as e:
    print(e)

INFO:tensorflow:Restoring parameters from ./restore_reinitialize
Attempting to use uninitialized value Variable
	 [[Node: _retval_Variable_0_0 = _Retval[T=DT_INT32, index=0, _device="/job:localhost/replica:0/task:0/cpu:0"](Variable)]]


### Initialize only the variables left uninitialized (here: `b`)

In [5]:
def get_uninit_vars(sess, variables=None):
    """Return the variables that are not yet initialized in `sess`.

    Reference: http://stackoverflow.com/questions/35164529

    Note: each call adds new `is_variable_initialized` ops to the default graph.

    Args:
        sess: `tf.Session` whose variable state is queried.
        variables: Optional iterable of variables to check. Defaults to
            `tf.global_variables()`.

    Returns:
        List of the variables from `variables` that are uninitialized,
        in their original order.
    """
    if variables is None:
        variables = tf.global_variables()
    # Materialize once: the sequence is traversed twice (building the ops,
    # then zipping with the flags). A generator would be exhausted by the
    # first pass and silently yield an empty result.
    variables = list(variables)
    # Nothing to check; avoid building and running an op for an empty list.
    if not variables:
        return []
    init_flag = sess.run(
        tf.stack([tf.is_variable_initialized(v) for v in variables]))
    return [v for v, f in zip(variables, init_flag) if not f]

In [6]:
# Find whatever the partial restore left uninitialized...
uninit_vars = get_uninit_vars(sess)
print(uninit_vars)

# ...and initialize only those, leaving the restored `a` as it is.
init_op = tf.variables_initializer(var_list=uninit_vars)
sess.run(init_op)
print(sess.run(uninit_vars))

[<tf.Variable 'Variable:0' shape=() dtype=int32_ref>]
[3]


In [7]:
# Running the initializers of `a` and `b` again assigns their initial values,
# overwriting whatever they held (including anything restored).
reinit_op = tf.variables_initializer(var_list=[a, b])
sess.run(reinit_op)
print(sess.run([a, b]))

[0, 3]
