In [1]:
import tensorflow as tf
import numpy as np

tf.__version__  # this notebook uses TF 1.x graph-mode APIs (tf.Session, tf.variable_scope, tf.get_variable)

'1.14.0'

In [2]:
def _l2normalize(v, eps=1e-12):
    """Return ``v`` scaled to unit L2 norm, where the norm is taken over every element.

    ``eps`` is added to the norm itself (outside the square root) so an all-zero
    input does not cause a division by zero.
    """
    squared_total = tf.reduce_sum(v ** 2)
    l2_norm = squared_total ** 0.5
    return v / (l2_norm + eps)

def spectral_norm(weights, name, num_iters=1, update_collection=None, with_sigma=False):
    """Divide ``weights`` by a power-iteration estimate of its largest singular value.

    The tensor is viewed as a matrix of shape ``[-1, w_shape[-1]]`` (leading axes
    flattened, last axis kept). ``num_iters`` steps of power iteration start from a
    persistent, non-trainable vector ``u`` of shape ``[1, w_shape[-1]]``, created or
    reused as ``<name>/u`` via ``tf.AUTO_REUSE``. Every call with the same ``name``
    therefore shares one ``u``.

    Args:
        weights: weight tensor. Its static shape list is reused for ``tf.reshape``,
            so it needs to be fully known.
        name: variable scope under which ``u`` lives.
        num_iters: number of power-iteration steps; must be >= 1.
        update_collection: how the refined ``u`` is written back:
            * ``None``: ``u.assign`` is a control dependency of the returned
              tensor, so evaluating it (or anything downstream) updates ``u``.
            * ``'NO_OPS'``: ``u`` is never updated.
            * any other value: the ``u.assign`` op is added to the graph
              collection of that name, and the caller must run it explicitly.
        with_sigma: if True, also return the estimated singular value.

    Returns:
        ``w_bar`` with the same shape as ``weights``, or ``(w_bar, sigma)`` when
        ``with_sigma`` is True (``sigma`` is a scalar tensor).

    Raises:
        ValueError: if ``num_iters < 1``. Without at least one step there are no
            singular-vector estimates to build ``sigma`` from.
    """
    if num_iters < 1:
        # Previously this surfaced as an UnboundLocalError on `v_` below.
        raise ValueError('spectral_norm requires num_iters >= 1, got {}'.format(num_iters))

    w_shape = weights.shape.as_list()
    w_mat = tf.reshape(weights, [-1, w_shape[-1]])  # [-1, output_channel]
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        # Persistent power-iteration vector: carried across steps, never trained.
        u = tf.get_variable('u', [1, w_shape[-1]],
                            initializer=tf.truncated_normal_initializer(),
                            trainable=False)
        print(u.name)  # construction-time trace: runs once per graph build, not per sess.run

    # Power iteration: v_ <- normalize(u_ W^T), u_ <- normalize(v_ W).
    u_ = u
    for _ in range(num_iters):
        v_ = _l2normalize(tf.matmul(u_, w_mat, transpose_b=True))  # [1, rows of w_mat]
        u_ = _l2normalize(tf.matmul(v_, w_mat))                    # [1, output_channel]

    # NOTE(review): gradients currently flow through u_/v_ back into `weights`; the
    # SN-GAN reference implementation treats u and v as constants (tf.stop_gradient).
    # Confirm which behavior is intended before training with this.
    sigma = tf.squeeze(tf.matmul(tf.matmul(v_, w_mat), u_, transpose_b=True))  # scalar v W u^T
    w_mat = w_mat / sigma

    if update_collection is None:
        # The reshape is created inside this context, so it (and everything
        # downstream of w_bar) depends on the assignment of the refined u.
        with tf.control_dependencies([u.assign(u_)]):
            w_bar = tf.reshape(w_mat, w_shape)
            print('u is updated')
    else:
        w_bar = tf.reshape(w_mat, w_shape)
        print('u is NOT updated')
        if update_collection != 'NO_OPS':
            # The caller is responsible for running the ops in this collection.
            tf.add_to_collection(update_collection, u.assign(u_))

    if with_sigma:
        return w_bar, sigma
    return w_bar

In [3]:
class SNDense(tf.keras.layers.Layer):
    """Fully connected layer whose kernel is spectrally normalized on every call.

    ``kernel`` and ``bias`` are created with ``tf.get_variable`` under
    ``tf.variable_scope(self.name, reuse=tf.AUTO_REUSE)``. ``spectral_norm`` stores
    its ``u`` vector under the same scope name. Separate SNDense instances with the
    same ``name`` therefore share all three variables.

    Args:
        units: output dimension.
        kernel_initializer: initializer instance, identifier string, or serialized
            config (resolved with ``tf.keras.initializers.get``).
        bias_initializer: same forms as ``kernel_initializer``.
        update_collection: forwarded to ``spectral_norm``; controls whether and how
            the power-iteration vector ``u`` is updated.
    """

    def __init__(self, units, kernel_initializer, bias_initializer, update_collection=None, **kwargs):
        super(SNDense, self).__init__(**kwargs)
        self.units = units
        # `get` returns initializer instances unchanged, and also accepts strings or
        # serialized dicts, which `from_config` needs when rebuilding from get_config().
        self.kernel_initializer = tf.keras.initializers.get(kernel_initializer)
        self.bias_initializer = tf.keras.initializers.get(bias_initializer)
        self.update_collection = update_collection

    def build(self, input_shape):
        '''
        input_shape = (batch_size, ..., input_dim); the last dimension must be known.
        Creates kernel [input_dim, units] and bias [units].
        '''
        self.input_dim = int(input_shape[-1])
        with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
            self.kernel = tf.get_variable("kernel",
                                          shape=(self.input_dim, self.units),
                                          initializer=self.kernel_initializer)
            self.bias = tf.get_variable("bias", shape=(self.units,), initializer=self.bias_initializer)
        super(SNDense, self).build(input_shape)  # Be sure to call this at the end

    def call(self, inputs):
        """Return ``inputs @ spectral_norm(kernel) + bias`` over the last axis of ``inputs``."""
        w_bar = spectral_norm(self.kernel, name=self.name, update_collection=self.update_collection)
        rank = inputs.shape.ndims
        if rank is not None and rank > 2:
            # tf.matmul in TF 1.x does not broadcast a rank-2 kernel over extra leading
            # axes, so contract the last input axis explicitly (as Keras Dense does).
            x = tf.tensordot(inputs, w_bar, [[rank - 1], [0]])
            x.set_shape(inputs.shape[:-1].concatenate([self.units]))
        else:
            x = tf.matmul(inputs, w_bar)
        x = tf.nn.bias_add(x, self.bias)
        return x

    def get_config(self):
        """Serializable constructor arguments, merged with the base Layer config."""
        config = {
            'units': self.units,
            'kernel_initializer': tf.keras.initializers.serialize(self.kernel_initializer),
            'bias_initializer': tf.keras.initializers.serialize(self.bias_initializer),
            'update_collection': self.update_collection,
        }
        base_config = super(SNDense, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

In [4]:
def make_model(update_collection, input_dim=10, units=2, name='fc1', stddev=0.02):
    """Build a one-layer Keras model ``Input(input_dim) -> SNDense(units)``.

    SNDense creates its variables with ``tf.AUTO_REUSE`` under the layer ``name``.
    Models built with the same ``name`` therefore reuse the same ``<name>/kernel``,
    ``<name>/bias`` and ``<name>/u``, which lets models that differ only in
    ``update_collection`` be compared on a single shared ``u``.

    Args:
        update_collection: forwarded to SNDense / spectral_norm.
        input_dim: size of the last input axis.
        units: output dimension of the SNDense layer.
        name: layer (and variable-scope) name.
        stddev: stddev of the zero-mean truncated-normal kernel/bias initializers.

    Returns:
        A ``tf.keras.models.Model``.
    """
    inputs = tf.keras.layers.Input(shape=(input_dim,))
    outputs = SNDense(units,
                      kernel_initializer=tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=stddev),
                      bias_initializer=tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=stddev),
                      name=name, update_collection=update_collection)(inputs)
    return tf.keras.models.Model(inputs, outputs)

In [5]:
# update_collection=None: u.assign is a control dependency of the layer output,
# so every forward pass through `model` refreshes fc1/u.
model = make_model(None)

W1126 16:02:57.545933 11740 deprecation.py:506] From f:\anaconda3\envs\tensorflow1.14\lib\site-packages\tensorflow\python\keras\initializers.py:94: calling TruncatedNormal.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
W1126 16:02:57.628952 11740 ag_logging.py:145] Entity <bound method SNDense.call of <__main__.SNDense object at 0x000001BD62601D30>> could not be transformed and will be executed as-is. Please report this to the AutgoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: converting <bound method SNDense.call of <__main__.SNDense object at 0x000001BD62601D30>>: AssertionError: Bad argument number for Name: 3, expecting 4


fc1/u:0
u is updated


In [6]:
# All global variables in the default graph (the GLOBAL_VARIABLES collection).
tf.global_variables()

[<tf.Variable 'fc1/kernel:0' shape=(10, 2) dtype=float32_ref>,
 <tf.Variable 'fc1/bias:0' shape=(2,) dtype=float32_ref>,
 <tf.Variable 'fc1/u:0' shape=(1, 2) dtype=float32_ref>]

In [7]:
# update_collection='NO_OPS': u is read by the power iteration but never written back.
model_no_op = make_model('NO_OPS')

W1126 16:02:57.881875 11740 ag_logging.py:145] Entity <bound method SNDense.call of <__main__.SNDense object at 0x000001BD65D7C4E0>> could not be transformed and will be executed as-is. Please report this to the AutgoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: converting <bound method SNDense.call of <__main__.SNDense object at 0x000001BD65D7C4E0>>: AssertionError: Bad argument number for Name: 3, expecting 4


fc1/u:0
u is NOT updated


In [8]:
# Still three variables: the second 'fc1' layer reused kernel, bias and u via AUTO_REUSE.
tf.global_variables()

[<tf.Variable 'fc1/kernel:0' shape=(10, 2) dtype=float32_ref>,
 <tf.Variable 'fc1/bias:0' shape=(2,) dtype=float32_ref>,
 <tf.Variable 'fc1/u:0' shape=(1, 2) dtype=float32_ref>]

In [9]:
output_no_op = model_no_op(model_no_op.input)
output = model(model.input)

init = tf.global_variables_initializer()

# Alternate between the two models that share fc1/u; only the update_collection=None
# model should move u. Each entry: (message format, fetch, input placeholder, batch).
run_schedule = [
    ('After run NO_OPS : {}', output_no_op, model_no_op.input, 2*np.ones((5,10))),
    ('After run None : {}', output, model.input, 2*np.ones((5,10))),
    ('After run NO_OPS : {}', output_no_op, model_no_op.input, 2*np.ones((5,10))),
    ('After run None : {}', output, model.input, np.ones((5,10))),
]

with tf.Session() as sess:
    sess.run(init)
    print('Initial value: {}'.format(sess.run('fc1/u:0')))
    for message, fetch, placeholder, batch in run_schedule:
        sess.run(fetch, feed_dict={placeholder: batch})
        print(message.format(sess.run('fc1/u:0')))

W1126 16:02:58.184012 11740 ag_logging.py:145] Entity <bound method SNDense.call of <__main__.SNDense object at 0x000001BD65D7C4E0>> could not be transformed and will be executed as-is. Please report this to the AutgoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: converting <bound method SNDense.call of <__main__.SNDense object at 0x000001BD65D7C4E0>>: AssertionError: Bad argument number for Name: 3, expecting 4
W1126 16:02:58.261214 11740 ag_logging.py:145] Entity <bound method SNDense.call of <__main__.SNDense object at 0x000001BD62601D30>> could not be transformed and will be executed as-is. Please report this to the AutgoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: converting <bound method SNDense.call of <__main__.SNDense object at 0x000001BD62601D30>>: AssertionError: Bad argument number for Name: 3, ex

fc1/u:0
u is NOT updated
fc1/u:0
u is updated
Initial value: [[-0.6487929  -0.45307902]]
After run NO_OPS : [[-0.6487929  -0.45307902]]
After run None : [[-0.8592447 -0.5115648]]
After run NO_OPS : [[-0.8592447 -0.5115648]]
After run None : [[-0.88483095 -0.4659123 ]]
