Put them here for convenience.
class EqualizeLearningRate(tf.keras.layers.Wrapper):
"""
Reference from WeightNormalization implementation of TF Addons
EqualizeLearningRate wrapper works for keras CNN and Dense (RNN not tested).
```python
net = EqualizeLearningRate(
tf.keras.layers.Conv2D(2, 2, activation='relu'),
input_shape=(32, 32, 3),
data_init=True)(x)
net = EqualizeLearningRate(
tf.keras.layers.Conv2D(16, 5, activation='relu'),
data_init=True)(net)
net = EqualizeLearningRate(
tf.keras.layers.Dense(120, activation='relu'),
data_init=True)(net)
net = EqualizeLearningRate(
tf.keras.layers.Dense(n_classes),
data_init=True)(net)
```
Arguments:
layer: a layer instance.
Raises:
ValueError: If `Layer` does not contain a `kernel` of weights
"""
def __init__(self, layer, **kwargs):
super(EqualizeLearningRate, self).__init__(layer, **kwargs)
self._track_trackable(layer, name='layer')
self.is_rnn = isinstance(self.layer, tf.keras.layers.RNN)
def build(self, input_shape):
"""Build `Layer`"""
input_shape = tf.TensorShape(input_shape)
self.input_spec = tf.keras.layers.InputSpec(
shape=[None] + input_shape[1:])
if not self.layer.built:
self.layer.build(input_shape)
kernel_layer = self.layer.cell if self.is_rnn else self.layer
if not hasattr(kernel_layer, 'kernel'):
raise ValueError('`EqualizeLearningRate` must wrap a layer that'
' contains a `kernel` for weights')
if self.is_rnn:
kernel = kernel_layer.recurrent_kernel
else:
kernel = kernel_layer.kernel
        # Equalized learning rate: precompute the He constant, 1 / sqrt(fan_in).
        self.fan_in, self.fan_out = self._compute_fans(kernel.shape)
        self.he_constant = tf.Variable(1.0 / np.sqrt(self.fan_in),
                                       dtype=tf.float32, trainable=False)
self.v = kernel
self.built = True
def call(self, inputs, training=True):
"""Call `Layer`"""
with tf.name_scope('compute_weights'):
            # Scale the kernel by the He constant.
            kernel = tf.identity(self.v * self.he_constant)
            if self.is_rnn:
                self.layer.cell.recurrent_kernel = kernel
                update_kernel = tf.identity(self.layer.cell.recurrent_kernel)
else:
self.layer.kernel = kernel
update_kernel = tf.identity(self.layer.kernel)
# Ensure we calculate result after updating kernel.
with tf.control_dependencies([update_kernel]):
outputs = self.layer(inputs)
return outputs
def compute_output_shape(self, input_shape):
return tf.TensorShape(
self.layer.compute_output_shape(input_shape).as_list())
def _compute_fans(self, shape, data_format='channels_last'):
"""
        Taken from the official Keras implementation.
Computes the number of input and output units for a weight shape.
# Arguments
shape: Integer shape tuple.
data_format: Image data format to use for convolution kernels.
Note that all kernels in Keras are standardized on the
`channels_last` ordering (even when inputs are set
to `channels_first`).
# Returns
A tuple of scalars, `(fan_in, fan_out)`.
# Raises
ValueError: in case of invalid `data_format` argument.
"""
if len(shape) == 2:
fan_in = shape[0]
fan_out = shape[1]
elif len(shape) in {3, 4, 5}:
# Assuming convolution kernels (1D, 2D or 3D).
# TH kernel shape: (depth, input_depth, ...)
# TF kernel shape: (..., input_depth, depth)
if data_format == 'channels_first':
receptive_field_size = np.prod(shape[2:])
fan_in = shape[1] * receptive_field_size
fan_out = shape[0] * receptive_field_size
elif data_format == 'channels_last':
receptive_field_size = np.prod(shape[:-2])
fan_in = shape[-2] * receptive_field_size
fan_out = shape[-1] * receptive_field_size
else:
raise ValueError('Invalid data_format: ' + data_format)
else:
# No specific assumptions.
fan_in = np.sqrt(np.prod(shape))
fan_out = np.sqrt(np.prod(shape))
return fan_in, fan_out
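As a quick sanity check (my own sketch, not part of the original code, assuming the class and its imports are in scope), wrapping a `Conv2D` makes the numbers easy to verify by hand: a 3×3 kernel over 3 input channels has `fan_in = 3 * 3 * 3 = 27`, so the wrapper scales the kernel by `1 / sqrt(27) ≈ 0.19` on every forward pass:

```python
import tensorflow as tf

x = tf.keras.layers.Input(shape=(32, 32, 3))
# Kernel shape is (3, 3, 3, 16): receptive field 3*3 = 9, input channels 3,
# so fan_in = 27 and he_constant = 1 / sqrt(27) ≈ 0.192.
wrapper = EqualizeLearningRate(tf.keras.layers.Conv2D(16, 3, padding='same'))
net = wrapper(x)
print(wrapper.fan_in, wrapper.he_constant.numpy())  # 27 0.19245...
```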
class PixelNormalization(tf.keras.layers.Layer):
"""
Arguments:
epsilon: a float-point number, the default is 1e-8
"""
def __init__(self, epsilon=1e-8):
super(PixelNormalization, self).__init__()
self.epsilon = epsilon
def call(self, inputs):
return inputs / tf.sqrt(tf.reduce_mean(tf.square(inputs), axis=-1, keepdims=True) + self.epsilon)
def compute_output_shape(self, input_shape):
return input_shape
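To see what the layer does (a sketch of my own, not from the source): every pixel's feature vector comes out with unit root mean square across channels, regardless of its original scale:

```python
import tensorflow as tf

x = tf.random.normal([2, 4, 4, 8]) * 10.0             # NHWC, deliberately large scale
y = PixelNormalization()(x)
rms = tf.sqrt(tf.reduce_mean(tf.square(y), axis=-1))  # per-pixel RMS over channels
print(rms.numpy().round(3))                           # all entries ≈ 1.0
```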
class MinibatchSTDDEV(tf.keras.layers.Layer):
"""
Reference from official pggan implementation
https://github.com/tkarras/progressive_growing_of_gans/blob/master/networks.py
Arguments:
group_size: a integer number, minibatch must be divisible by (or smaller than) group_size.
"""
def __init__(self, group_size=4):
super(MinibatchSTDDEV, self).__init__()
self.group_size = group_size
def call(self, inputs):
group_size = tf.minimum(self.group_size, tf.shape(inputs)[0]) # Minibatch must be divisible by (or smaller than) group_size.
s = inputs.shape # [NHWC] Input shape.
y = tf.reshape(inputs, [group_size, -1, s[1], s[2], s[3]]) # [GMHWC] Split minibatch into M groups of size G.
y = tf.cast(y, tf.float32) # [GMHWC] Cast to FP32.
y -= tf.reduce_mean(y, axis=0, keepdims=True) # [GMHWC] Subtract mean over group.
y = tf.reduce_mean(tf.square(y), axis=0) # [MHWC] Calc variance over group.
y = tf.sqrt(y + 1e-8) # [MHWC] Calc stddev over group.
y = tf.reduce_mean(y, axis=[1,2,3], keepdims=True) # [M111] Take average over fmaps and pixels.
y = tf.cast(y, inputs.dtype) # [M111] Cast back to original data type.
y = tf.tile(y, [group_size, s[1], s[2], 1]) # [NHW1] Replicate over group and pixels.
return tf.concat([inputs, y], axis=-1) # [NHWC] Append as new fmap.
def compute_output_shape(self, input_shape):
return (input_shape[0], input_shape[1], input_shape[2], input_shape[3] + 1)
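A minimal usage sketch (assuming the class above is in scope): the layer appends exactly one statistic feature map, so a `[N, H, W, C]` input becomes `[N, H, W, C + 1]`, and the batch size should be divisible by `group_size`:

```python
import tensorflow as tf

x = tf.random.normal([8, 4, 4, 16])    # batch of 8 is divisible by group_size=4
y = MinibatchSTDDEV(group_size=4)(x)
print(y.shape)                         # (8, 4, 4, 17): one extra feature map
```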