Skip to content
Permalink
Browse files

-

  • Loading branch information
iperov committed Jan 8, 2020
1 parent b8182ae commit bbe81b20af3492728886442bef7a89178e870adf
Showing with 13 additions and 2 deletions.
  1. +13 −2 nnlib/nnlib.py
@@ -853,15 +853,20 @@ def get_config(self):
nnlib.BilinearInterpolation = BilinearInterpolation

class WScaleConv2DLayer(KL.Conv2D):
def __init__(self, *args, **kwargs):
def __init__(self, *args, gain=None, **kwargs):
kwargs['kernel_initializer'] = keras.initializers.random_normal()

if gain is None:
gain = np.sqrt(2)

self.gain = gain

super(WScaleConv2DLayer,self).__init__(*args,**kwargs)

def build(self, input_shape):
super().build(input_shape)
kernel_shape = K.int_shape(self.kernel)
std = np.sqrt(2) / np.sqrt( np.prod(kernel_shape[:-1]) )
std = np.sqrt(self.gain) / np.sqrt( np.prod(kernel_shape[:-1]) )
self.wscale = K.constant(std, dtype=K.floatx() )

def call(self, input, **kwargs):
@@ -870,6 +875,12 @@ def call(self, input, **kwargs):
x = super().call(input,**kwargs)
self.kernel = k
return x

def get_config(self):
config = {"gain": self.gain}
base_config = super(WScaleConv2DLayer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

nnlib.WScaleConv2DLayer = WScaleConv2DLayer

class SelfAttention(KL.Layer):

0 comments on commit bbe81b2

Please sign in to comment.
You can’t perform that action at this time.