In [5]:
%load_ext autoreload
%autoreload 2

In [108]:
import tensorflow as tf
import unittest
import re
import numpy as np
from train_ops import create_train_ops

In [148]:
def assert_grad_bufs_zero(bufs=None, session=None):
    """Assert that every gradient-buffer variable currently holds all zeros.

    Args:
        bufs: mapping of buffer name -> fetchable variable. Defaults to the
            notebook-global ``grad_bufs``.
        session: object with a ``run(fetch)`` method (e.g. a ``tf.Session``).
            Defaults to the notebook-global ``sess``.

    Raises:
        AssertionError: if ``bufs`` is empty (the check would otherwise pass
            without testing anything), or if any buffer has a nonzero element;
            the failure message names the offending buffer.
    """
    if bufs is None:
        bufs = grad_bufs
    if session is None:
        session = sess
    # An empty mapping would make this check pass vacuously.
    assert bufs, "no gradient buffers to check"
    for name, buf in bufs.items():
        # Fetch the variable itself (not the (name, var) tuple from .items()).
        val = session.run(buf)
        # Compare against zeros of the buffer's own shape instead of a
        # hard-coded 2-element array.
        np.testing.assert_equal(val, np.zeros_like(val), err_msg=name)

In [125]:
# Initial values used for the w1/w2 variables created in every scope.
inits = {
    'w1': np.array([10.0, 20.0], dtype=np.float32),
    'w2': np.array([5.0, 10.0], dtype=np.float32),
}

# Variable scopes handed to create_train_ops (update scope first, apply scope second).
scopes = ['update_scope', 'apply_scope']

In [158]:
# Start from a clean graph and a new session so re-running this cell does not
# pile up duplicate variables.
tf.reset_default_graph()
sess = tf.Session()

# Build an identical (w1, w2) pair plus an elementwise loss w1 + 2*w2 in each scope.
# NOTE(review): `vars` shadows the builtin vars(); later cells read it by this name.
vars, losses = {}, {}
for scope in scopes:
    with tf.variable_scope(scope):
        w1, w2 = [tf.Variable(inits[key], name=key) for key in ('w1', 'w2')]
        losses[scope] = w1 + 2 * w2
        vars[scope] = dict(w1=w1, w2=w2)

# learning_rate=1: a gradient-descent step subtracts exactly the gradient it is given.
o = tf.train.GradientDescentOptimizer(learning_rate=1)
# Only update_scope's loss is passed in; apply_scope is presumably the target
# that receives the accumulated updates (confirm in train_ops.create_train_ops).
update_ops, apply_ops = create_train_ops(
    losses['update_scope'], o, 'update_scope', 'apply_scope')

sess.run(tf.global_variables_initializer())

In [159]:
"""
Check that no extra trainable variables have been introduced.
"""
# Only the two variables (w1, w2) built in each of the two scopes should be
# trainable; any extra one was presumably added by create_train_ops.
expected_n_trainable = 4
trainable_names = [v.name for v in tf.trainable_variables()]
assert len(trainable_names) == expected_n_trainable, \
    "unexpected trainable variables: %s" % trainable_names

In [160]:
# Gradient-accumulator variables, located by the 'grad_buf' substring in their
# name and keyed by their full TF variable name.
grad_bufs = {v.name: v for v in tf.global_variables() if 'grad_buf' in v.name}
# If the naming in create_train_ops changes and nothing matches, every later
# buffer check would loop over an empty dict and pass vacuously.
assert grad_bufs, "no variables with 'grad_buf' in their name were found"

In [161]:
# Before any update step, every gradient buffer must be all zeros.
assert_grad_bufs_zero()

In [162]:
# First accumulation step.
sess.run(fetches=update_ops)

In [163]:
# Running update_ops alone must not modify any trainable variable, in either scope:
# each one still equals its initial value.
for scope in scopes:
    scope_vars = vars[scope]
    for var_name in scope_vars:
        np.testing.assert_equal(sess.run(scope_vars[var_name]), inits[var_name])

In [164]:
"""
Confirm that the gradient buffers look reasonable.
"""
# loss = w1 + 2*w2 elementwise, so one update step should leave 1 per element
# in w1's buffer and 2 per element in w2's buffer.
grad_per_step = {'w1': 1., 'w2': 2.}
for buf_name, buf in grad_bufs.items():
    actual = sess.run(buf)
    # First matching variable name wins ('w1' before 'w2'). An unmatched name is
    # an error instead of silently reusing the previous iteration's `expected`.
    key = next((k for k in ('w1', 'w2') if k in buf_name), None)
    assert key is not None, "unrecognised gradient buffer: %s" % buf_name
    # Build the expectation with the variable's shape so shape mismatches still fail.
    expected = np.full(inits[key].shape, grad_per_step[key])
    np.testing.assert_equal(actual, expected, err_msg=buf_name)

In [165]:
# Second accumulation step.
sess.run(fetches=update_ops)

In [166]:
"""
Confirm that the gradient buffers still look reasonable.
"""
# Buffers accumulate across update steps: after n steps they should hold
# n * (per-step gradient) -- 1 per element for w1, 2 per element for w2.
n_update_steps = 2
grad_per_step = {'w1': 1., 'w2': 2.}
for buf_name, buf in grad_bufs.items():
    actual = sess.run(buf)
    # First matching variable name wins ('w1' before 'w2'). An unmatched name is
    # an error instead of silently reusing the previous iteration's `expected`.
    key = next((k for k in ('w1', 'w2') if k in buf_name), None)
    assert key is not None, "unrecognised gradient buffer: %s" % buf_name
    # Build the expectation with the variable's shape so shape mismatches still fail.
    expected = np.full(inits[key].shape, n_update_steps * grad_per_step[key])
    np.testing.assert_equal(actual, expected, err_msg=buf_name)

In [167]:
# Apply step: the checks that follow expect this to zero the buffers and update apply_scope.
sess.run(fetches=apply_ops)

In [168]:
# apply_ops should have reset every gradient buffer back to zero.
assert_grad_bufs_zero()

In [169]:
# apply_ops must leave update_scope untouched: each of its variables still
# equals its initial value (dict keys are exactly 'w1' / 'w2').
for var_name, var in vars['update_scope'].items():
    np.testing.assert_equal(sess.run(var), inits[var_name])

In [170]:
"""
Confirm that changes _have_ been made to the variables in apply_scope.
"""
for var_name, var in vars['apply_scope'].items():
    actual = sess.run(var)
    # NOTE(review): this only checks that the value moved. If apply_ops applies
    # the summed buffers with learning_rate=1, the exact result would be
    # init - accumulated (w1 -> [8, 18], w2 -> [1, 6]); confirm against
    # create_train_ops before tightening this to an equality check.
    assert not np.array_equal(actual, inits[var_name]), \
        "apply_scope/%s still equals its initial value" % var_name