Commit

fix init.
feanorliu committed Jun 19, 2018
1 parent 53ad661 commit bc4a0ec
Showing 1 changed file with 21 additions and 13 deletions.
34 changes: 21 additions & 13 deletions pixel_cnn_pp/nn.py
@@ -157,27 +157,30 @@ def get_name(layer_name, counters):
return name

@add_arg_scope
-def dense(x, num_units, nonlinearity=None, init_scale=1., counters={}, init=False, ema=None, **kwargs):
+def dense(x_, num_units, nonlinearity=None, init_scale=1., counters={}, init=False, ema=None, **kwargs):
''' fully connected layer '''
name = get_name('dense', counters)
with tf.variable_scope(name):
-V = get_var_maybe_avg('V', ema, shape=[int(x.get_shape()[1]),num_units], dtype=tf.float32,
+V = get_var_maybe_avg('V', ema, shape=[int(x_.get_shape()[1]),num_units], dtype=tf.float32,
initializer=tf.random_normal_initializer(0, 0.05), trainable=True)
g = get_var_maybe_avg('g', ema, shape=[num_units], dtype=tf.float32,
initializer=tf.constant_initializer(1.), trainable=True)
b = get_var_maybe_avg('b', ema, shape=[num_units], dtype=tf.float32,
initializer=tf.constant_initializer(0.), trainable=True)

# use weight normalization (Salimans & Kingma, 2016)
-x = tf.matmul(x, V)
+x = tf.matmul(x_, V)
scaler = g / tf.sqrt(tf.reduce_sum(tf.square(V), [0]))
x = tf.reshape(scaler, [1, num_units]) * x + tf.reshape(b, [1, num_units])

if init: # normalize x
m_init, v_init = tf.nn.moments(x, [0])
scale_init = init_scale/tf.sqrt(v_init + 1e-10)
with tf.control_dependencies([g.assign(g*scale_init), b.assign_add(-m_init*scale_init)]):
-x = tf.identity(x)
+# x = tf.identity(x)
+x = tf.matmul(x_, V)
+scaler = g / tf.sqrt(tf.reduce_sum(tf.square(V), [0]))
+x = tf.reshape(scaler, [1, num_units]) * x + tf.reshape(b, [1, num_units])

# apply nonlinearity
if nonlinearity is not None:
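
Note on the dense hunk: the commit renames the input argument from x to x_ so the original input is still available after x is reused for the output, and, inside the init branch, recomputes the output from the freshly assigned g and b instead of returning tf.identity(x). Under tf.control_dependencies the identity is merely ordered after the assignments; the ops that produced x can still read the old g and b, so the init pass would hand un-rescaled activations to downstream layers. A NumPy sketch (illustrative only, not code from this repository) replays the dense-layer arithmetic and shows the output only reaches the target statistics once it is recomputed with the updated parameters:

import numpy as np

rng = np.random.RandomState(0)
x_in = rng.randn(64, 32)                  # a batch of inputs
V = rng.randn(32, 16) * 0.05              # unnormalized weight matrix
g, b = np.ones(16), np.zeros(16)          # weight-norm scale and bias
init_scale = 1.0

def forward():
    # weight normalization (Salimans & Kingma, 2016): effectively W = g * V / ||V||
    x = x_in @ V
    scaler = g / np.sqrt(np.square(V).sum(axis=0))
    return scaler[None, :] * x + b[None, :]

x = forward()                                      # output under the initial g, b
m_init, v_init = x.mean(axis=0), x.var(axis=0)
scale_init = init_scale / np.sqrt(v_init + 1e-10)
g, b = g * scale_init, b - m_init * scale_init     # the data-dependent assignments
print(x.var(axis=0).mean())                        # stale output: variance far from 1
x = forward()                                      # recompute, as the fix does
print(x.mean(axis=0).max(), x.var(axis=0).mean())  # now mean ~0, variance ~init_scale**2
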
@@ -186,11 +189,11 @@ def dense(x, num_units, nonlinearity=None, init_scale=1., counters={}, init=False, ema=None, **kwargs):
return x

@add_arg_scope
-def conv2d(x, num_filters, filter_size=[3,3], stride=[1,1], pad='SAME', nonlinearity=None, init_scale=1., counters={}, init=False, ema=None, **kwargs):
+def conv2d(x_, num_filters, filter_size=[3,3], stride=[1,1], pad='SAME', nonlinearity=None, init_scale=1., counters={}, init=False, ema=None, **kwargs):
''' convolutional layer '''
name = get_name('conv2d', counters)
with tf.variable_scope(name):
-V = get_var_maybe_avg('V', ema, shape=filter_size+[int(x.get_shape()[-1]),num_filters], dtype=tf.float32,
+V = get_var_maybe_avg('V', ema, shape=filter_size+[int(x_.get_shape()[-1]),num_filters], dtype=tf.float32,
initializer=tf.random_normal_initializer(0, 0.05), trainable=True)
g = get_var_maybe_avg('g', ema, shape=[num_filters], dtype=tf.float32,
initializer=tf.constant_initializer(1.), trainable=True)
@@ -201,13 +204,15 @@ def conv2d(x, num_filters, filter_size=[3,3], stride=[1,1], pad='SAME', nonlinearity=None, init_scale=1., counters={}, init=False, ema=None, **kwargs):
W = tf.reshape(g, [1, 1, 1, num_filters]) * tf.nn.l2_normalize(V, [0, 1, 2])

# calculate convolutional layer output
-x = tf.nn.bias_add(tf.nn.conv2d(x, W, [1] + stride + [1], pad), b)
+x = tf.nn.bias_add(tf.nn.conv2d(x_, W, [1] + stride + [1], pad), b)

if init: # normalize x
m_init, v_init = tf.nn.moments(x, [0,1,2])
scale_init = init_scale / tf.sqrt(v_init + 1e-10)
with tf.control_dependencies([g.assign(g * scale_init), b.assign_add(-m_init * scale_init)]):
-x = tf.identity(x)
+# x = tf.identity(x)
+W = tf.reshape(g, [1, 1, 1, num_filters]) * tf.nn.l2_normalize(V, [0, 1, 2])
+x = tf.nn.bias_add(tf.nn.conv2d(x_, W, [1] + stride + [1], pad), b)

# apply nonlinearity
if nonlinearity is not None:
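
Same pattern in conv2d, with one extra step: here the scale g enters through the normalized kernel, so the recomputation first rebuilds W from the updated g before redoing the convolution. A quick NumPy check (illustrative, not repository code) of the weight-norm kernel construction used on both sides of this hunk:

import numpy as np

rng = np.random.RandomState(1)
V = rng.randn(3, 3, 8, 16)          # kernel: [height, width, in_ch, out_ch]
g = rng.rand(16)                    # one scale per output filter

# mirrors W = tf.reshape(g, [1, 1, 1, num_filters]) * tf.nn.l2_normalize(V, [0, 1, 2])
V_hat = V / np.sqrt(np.square(V).sum(axis=(0, 1, 2), keepdims=True))
W = g.reshape(1, 1, 1, 16) * V_hat

# each output filter of W has L2 norm exactly g[k], so rescaling g rescales W
assert np.allclose(np.sqrt(np.square(W).sum(axis=(0, 1, 2))), g)
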
@@ -216,16 +221,16 @@ def conv2d(x, num_filters, filter_size=[3,3], stride=[1,1], pad='SAME', nonlinearity=None, init_scale=1., counters={}, init=False, ema=None, **kwargs):
return x

@add_arg_scope
-def deconv2d(x, num_filters, filter_size=[3,3], stride=[1,1], pad='SAME', nonlinearity=None, init_scale=1., counters={}, init=False, ema=None, **kwargs):
+def deconv2d(x_, num_filters, filter_size=[3,3], stride=[1,1], pad='SAME', nonlinearity=None, init_scale=1., counters={}, init=False, ema=None, **kwargs):
''' transposed convolutional layer '''
name = get_name('deconv2d', counters)
-xs = int_shape(x)
+xs = int_shape(x_)
if pad=='SAME':
target_shape = [xs[0], xs[1]*stride[0], xs[2]*stride[1], num_filters]
else:
target_shape = [xs[0], xs[1]*stride[0] + filter_size[0]-1, xs[2]*stride[1] + filter_size[1]-1, num_filters]
with tf.variable_scope(name):
-V = get_var_maybe_avg('V', ema, shape=filter_size+[num_filters,int(x.get_shape()[-1])], dtype=tf.float32,
+V = get_var_maybe_avg('V', ema, shape=filter_size+[num_filters,int(x_.get_shape()[-1])], dtype=tf.float32,
initializer=tf.random_normal_initializer(0, 0.05), trainable=True)
g = get_var_maybe_avg('g', ema, shape=[num_filters], dtype=tf.float32,
initializer=tf.constant_initializer(1.), trainable=True)
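
Aside on the unchanged target_shape context shown above: the two formulas pick the transposed-convolution output size whose forward convolution maps back to the input size. A tiny check (illustrative):

import math

H, f = 8, 3                         # input height, filter size
for s in (1, 2):                    # stride
    assert math.ceil((H * s) / s) == H                  # 'SAME': ceil(H_t / s) == H
    assert math.ceil((H * s + f - 1 - f + 1) / s) == H  # 'VALID': ceil((H_t - f + 1) / s) == H
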
@@ -236,14 +241,17 @@ def deconv2d(x, num_filters, filter_size=[3,3], stride=[1,1], pad='SAME', nonlinearity=None, init_scale=1., counters={}, init=False, ema=None, **kwargs):
W = tf.reshape(g, [1, 1, num_filters, 1]) * tf.nn.l2_normalize(V, [0, 1, 3])

# calculate convolutional layer output
-x = tf.nn.conv2d_transpose(x, W, target_shape, [1] + stride + [1], padding=pad)
+x = tf.nn.conv2d_transpose(x_, W, target_shape, [1] + stride + [1], padding=pad)
x = tf.nn.bias_add(x, b)

if init: # normalize x
m_init, v_init = tf.nn.moments(x, [0,1,2])
scale_init = init_scale / tf.sqrt(v_init + 1e-10)
with tf.control_dependencies([g.assign(g * scale_init), b.assign_add(-m_init * scale_init)]):
-x = tf.identity(x)
+# x = tf.identity(x)
+W = tf.reshape(g, [1, 1, num_filters, 1]) * tf.nn.l2_normalize(V, [0, 1, 3])
+x = tf.nn.conv2d_transpose(x_, W, target_shape, [1] + stride + [1], padding=pad)
+x = tf.nn.bias_add(x, b)

# apply nonlinearity
if nonlinearity is not None:
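
Why the fix matters for the full network: during the data-dependent initialization pass each layer's output feeds the next layer's statistics. With the old tf.identity(x), downstream layers initialized against pre-rescaling activations, so their g and b ended up mismatched once the upstream assignments took effect. A self-contained NumPy sketch (illustrative, not repository code) of two stacked weight-norm dense layers makes the difference visible:

import numpy as np

rng = np.random.RandomState(2)

class WNDense:
    """Minimal weight-norm dense layer with data-dependent init (sketch)."""
    def __init__(self, n_in, n_out):
        self.V = rng.randn(n_in, n_out) * 0.05
        self.g, self.b = np.ones(n_out), np.zeros(n_out)

    def forward(self, x_):
        h = x_ @ self.V
        scaler = self.g / np.sqrt(np.square(self.V).sum(axis=0))
        return scaler[None, :] * h + self.b[None, :]

    def init_pass(self, x_, recompute):
        x = self.forward(x_)                          # computed with the pre-init g, b
        m, v = x.mean(axis=0), x.var(axis=0)
        s = 1.0 / np.sqrt(v + 1e-10)
        self.g, self.b = self.g * s, self.b - m * s   # the data-dependent assignments
        return self.forward(x_) if recompute else x   # recompute=True is this commit

x0 = rng.randn(256, 32) * 5.0 + 3.0                   # input with non-unit statistics
for recompute in (False, True):
    l1, l2 = WNDense(32, 16), WNDense(16, 16)
    l2.init_pass(l1.init_pass(x0, recompute), recompute)
    out = l2.forward(l1.forward(x0))                  # training-time pass, frozen g, b
    print(f"recompute={recompute}: output variance {out.var(axis=0).mean():.3f}")
    # prints ~1.0 only with recompute=True; otherwise l2 was initialized
    # against l1's stale, un-rescaled activations
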
