In [16]:
import tensorflow as tf
class MyInstanceNorm(tf.keras.layers.Layer):
    """Instance normalization with optional learned per-feature scale and offset.

    Every example is normalized on its own: mean and variance are computed over
    all axes except the batch axis (0) and the feature axis ``axis``. For a
    channels-last 4D input ``(N, H, W, C)`` with the default ``axis=-1`` these
    are the spatial axes ``[1, 2]``. The normalized values are then multiplied
    by ``gamma`` (if ``scale``) and shifted by ``beta`` (if ``center``); both
    weights have shape ``(input_shape[axis],)``.

    Args:
        axis: Feature axis; each entry along it gets its own gamma/beta.
        epsilon: Added to the variance to avoid division by zero.
        center: Whether to create and add the offset ``beta``.
        scale: Whether to create and apply the scale ``gamma``.
        beta_initializer, gamma_initializer: Anything accepted by
            ``tf.keras.initializers.get``.
        beta_regularizer, gamma_regularizer: Optional weight regularizers.
        beta_constraint, gamma_constraint: Optional weight constraints.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self,
                 axis=-1,
                 epsilon=1e-3,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        super(MyInstanceNorm, self).__init__(**kwargs)
        self.supports_masking = True
        self.axis = axis
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = tf.keras.initializers.get(beta_initializer)
        self.gamma_initializer = tf.keras.initializers.get(gamma_initializer)
        self.beta_regularizer = tf.keras.regularizers.get(beta_regularizer)
        self.gamma_regularizer = tf.keras.regularizers.get(gamma_regularizer)
        self.beta_constraint = tf.keras.constraints.get(beta_constraint)
        self.gamma_constraint = tf.keras.constraints.get(gamma_constraint)

    def build(self, input_shape):
        """Create ``gamma`` and ``beta`` sized to ``input_shape[self.axis]``."""
        self._add_gamma_weight(input_shape)
        self._add_beta_weight(input_shape)
        # Layer.build records the input shape and sets self.built = True.
        super(MyInstanceNorm, self).build(input_shape)

    def call(self, inputs):
        """Normalize ``inputs`` per example and per feature, then scale/shift."""
        input_shape = tf.keras.backend.int_shape(inputs)
        ndim = len(input_shape)
        feature_axis = self.axis % ndim
        # Reduce over every axis except batch (0) and the feature axis; for an
        # NHWC input with axis=-1 this is [1, 2].
        reduction_axes = [i for i in range(1, ndim) if i != feature_axis]
        mean, variance = tf.nn.moments(inputs, reduction_axes, keepdims=True)
        weight_shape = self._create_broadcast_shape(input_shape)
        # _get_reshaped_weights returns (gamma, beta): gamma is the scale,
        # beta the offset. Unpack in exactly that order.
        expanded_gamma, expanded_beta = self._get_reshaped_weights(
            input_shape, weight_shape, broadcast=False)
        return tf.nn.batch_normalization(inputs, mean, variance,
                                         offset=expanded_beta,
                                         scale=expanded_gamma,
                                         variance_epsilon=self.epsilon)

    def _get_reshaped_weights(self, input_shape, weight_shape, broadcast=False):
        """Return ``(gamma, beta)`` reshaped to ``weight_shape``.

        An entry is ``None`` when its weight is disabled (``scale`` / ``center``
        False). ``input_shape`` and ``broadcast`` are unused and kept only so
        existing callers keep working.
        """
        gamma = tf.reshape(self.gamma, weight_shape) if self.scale else None
        beta = tf.reshape(self.beta, weight_shape) if self.center else None
        return gamma, beta

    def _create_broadcast_shape(self, input_shape):
        """All-ones shape of the input's rank, except the feature axis.

        Using the full input rank means both negative and non-negative
        ``axis`` values index the list correctly, and the reshaped weights
        broadcast against the input on the right axis.
        """
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]
        return broadcast_shape

    def _add_gamma_weight(self, input_shape):
        """Create the scale weight ``gamma`` (or set it to None if disabled)."""
        dim = input_shape[self.axis]
        shape = (dim,)

        if self.scale:
            self.gamma = self.add_weight(
                shape=shape,
                name='gamma',
                initializer=self.gamma_initializer,
                regularizer=self.gamma_regularizer,
                constraint=self.gamma_constraint)
        else:
            self.gamma = None

    def _add_beta_weight(self, input_shape):
        """Create the offset weight ``beta`` (or set it to None if disabled)."""
        dim = input_shape[self.axis]
        shape = (dim,)

        if self.center:
            self.beta = self.add_weight(
                shape=shape,
                name='beta',
                initializer=self.beta_initializer,
                regularizer=self.beta_regularizer,
                constraint=self.beta_constraint)
        else:
            self.beta = None

    def get_config(self):
        """Serializable constructor arguments, merged over the base Layer config."""
        config = {
            'axis':
            self.axis,
            'epsilon':
            self.epsilon,
            'center':
            self.center,
            'scale':
            self.scale,
            'beta_initializer':
            tf.keras.initializers.serialize(self.beta_initializer),
            'gamma_initializer':
            tf.keras.initializers.serialize(self.gamma_initializer),
            'beta_regularizer':
            tf.keras.regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer':
            tf.keras.regularizers.serialize(self.gamma_regularizer),
            'beta_constraint':
            tf.keras.constraints.serialize(self.beta_constraint),
            'gamma_constraint':
            tf.keras.constraints.serialize(self.gamma_constraint)
        }
        base_config = super(MyInstanceNorm, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))



class ReflectionPadding2D(tf.keras.layers.Layer):
    """Reflect-pad the two spatial axes of a channels-last 4D tensor.

    Args:
        padding: ``(pad_height, pad_width)``. Each value is applied to both
            sides of its axis, so ``(1, 1)`` grows H and W by 2 each. This is
            the same ordering as ``tf.keras.layers.ZeroPadding2D``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, padding=(1, 1), **kwargs):
        # Initialize the base Layer first: Layer.__init__ resets the input
        # spec, so assigning input_spec beforehand silently drops the check.
        super(ReflectionPadding2D, self).__init__(**kwargs)
        self.padding = tuple(padding)
        self.input_spec = [tf.keras.layers.InputSpec(ndim=4)]

    def compute_output_shape(self, s):
        """Output shape for a channels-last input; unknown dims stay None."""
        h_pad, w_pad = self.padding

        def _grow(dim, pad):
            return None if dim is None else dim + 2 * pad

        return (s[0], _grow(s[1], h_pad), _grow(s[2], w_pad), s[3])

    def call(self, x, mask=None):
        # padding[0] pads height (axis 1) and padding[1] pads width (axis 2),
        # matching compute_output_shape above.
        h_pad, w_pad = self.padding
        return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')

    def get_config(self):
        """Serializable constructor arguments, merged over the base Layer config."""
        config = {
            'padding':
            self.padding
        }
        base_config = super(ReflectionPadding2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

def normalization(intput_tensor, method='instance'):
  """Apply a freshly created normalization layer to ``intput_tensor``.

  Args:
    intput_tensor: Keras tensor to normalize. (The misspelled parameter name is
      kept so keyword callers keep working.)
    method: ``'instance'`` for ``MyInstanceNorm`` with gamma/beta initialized
      from ``"random_uniform"``, or ``'batch'`` for
      ``tf.keras.layers.BatchNormalization`` with default settings.

  Returns:
    The normalized Keras tensor.

  Raises:
    ValueError: if ``method`` is neither ``'instance'`` nor ``'batch'``.
  """
  if method == 'instance':
    return MyInstanceNorm(center=True, scale=True,
                          beta_initializer="random_uniform",
                          gamma_initializer="random_uniform")(intput_tensor)
  if method == 'batch':
    return tf.keras.layers.BatchNormalization()(intput_tensor)
  # Fail loudly instead of silently substituting batch norm for a typo.
  raise ValueError(
      "Unknown normalization method %r; expected 'instance' or 'batch'." % (method,))

def conv_w_reflection(input_tensor,
                      kernel_size,
                      filters,
                      stride,
                      norm='batch'):
  """Reflection-pad, convolve, normalize, then ReLU.

  The input is reflect-padded by ``kernel_size // 2`` on every spatial side
  and convolved without further padding. For an odd kernel at stride 1 this
  preserves H and W; at stride 2 they become ceil(H / 2), ceil(W / 2).

  Args:
    input_tensor: channels-last 4D Keras tensor.
    kernel_size: square kernel size (int); expected to be odd.
    filters: number of output channels.
    stride: convolution stride.
    norm: normalization method passed to ``normalization``
      (``'batch'`` by default, or ``'instance'``).

  Returns:
    The ReLU-activated Keras tensor.
  """
  p = kernel_size // 2
  x = ReflectionPadding2D(padding=(p, p))(input_tensor)
  x = tf.keras.layers.Conv2D(filters, kernel_size, strides=stride, use_bias=True)(x)
  x = normalization(x, method=norm)
  return tf.keras.layers.Activation(tf.nn.relu)(x)

def conv_block(input_tensor, filters, norm='batch'):
  """Two reflection-padded 3x3 convolutions, the body of a residual block.

  Layout: pad -> conv -> norm -> ReLU -> pad -> conv -> norm. There is no
  activation after the second normalization. Spatial size is preserved
  (1-pixel reflection pad + 3x3 conv at stride 1).

  Args:
    input_tensor: channels-last 4D Keras tensor.
    filters: output channels of both convolutions.
    norm: normalization method passed to ``normalization``
      (``'batch'`` by default, or ``'instance'``).

  Returns:
    Keras tensor with ``filters`` channels and the input's H and W.
  """
  x = ReflectionPadding2D(padding=(1, 1))(input_tensor)
  x = tf.keras.layers.Conv2D(filters, kernel_size=3, strides=(1, 1), use_bias=True)(x)
  x = normalization(x, method=norm)
  x = tf.keras.layers.Activation(tf.nn.relu)(x)

  x = ReflectionPadding2D(padding=(1, 1))(x)
  x = tf.keras.layers.Conv2D(filters, kernel_size=3, strides=(1, 1), use_bias=True)(x)
  return normalization(x, method=norm)

def residual_block(input_tensor, filters):
  """Residual unit: ``input_tensor + conv_block(input_tensor, filters)``.

  ``filters`` must equal the channel count of ``input_tensor`` so the
  element-wise ``Add`` sees matching shapes.
  """
  residual = conv_block(input_tensor, filters)
  return tf.keras.layers.Add()([input_tensor, residual])

def upsample_conv(input_tensor, kernel_size, filters, stride, norm='batch'):
  """Transposed convolution, normalization, then ReLU.

  With ``padding='same'`` the transposed convolution multiplies H and W by
  ``stride``.

  Args:
    input_tensor: channels-last 4D Keras tensor.
    kernel_size: kernel size of the transposed convolution.
    filters: number of output channels.
    stride: upsampling factor.
    norm: normalization method passed to ``normalization``
      (``'batch'`` by default, or ``'instance'``).

  Returns:
    The ReLU-activated, upsampled Keras tensor.
  """
  x = tf.keras.layers.Conv2DTranspose(filters, kernel_size, strides=stride,
                                      padding='same', use_bias=True)(input_tensor)
  x = normalization(x, method=norm)
  return tf.keras.layers.Activation(tf.nn.relu)(x)

def create_generator(shape=(256, 256, 3), num_residual_blocks=9):
    """Build a ResNet-style encoder / residual / decoder image generator.

    Layout:
      * a 7x7 conv with 64 filters at stride 1;
      * two stride-2 3x3 convs with 128 and 256 filters;
      * ``num_residual_blocks`` residual blocks at 256 channels;
      * two stride-2 transposed convs with 128 and 64 filters;
      * a reflection-padded 7x7 conv to 3 channels with tanh.

    The tanh output is rescaled from [-1, 1] to [0, 1]. The output H and W
    equal the input's when both are divisible by 4 (two stride-2 downsamples,
    then two x2 upsamples).

    Args:
        shape: input shape ``(H, W, C)``, channels last.
        num_residual_blocks: number of residual blocks in the bottleneck
            (default 9).

    Returns:
        A ``tf.keras.Model`` mapping inputs of ``shape`` to 3-channel images
        in [0, 1].
    """
    inputs = tf.keras.layers.Input(shape=shape)
    x = conv_w_reflection(inputs, 7, 64, 1)
    x = conv_w_reflection(x, 3, 128, 2)
    x = conv_w_reflection(x, 3, 256, 2)
    for _ in range(num_residual_blocks):
        x = residual_block(x, 256)
    x = upsample_conv(x, 3, 128, 2)
    x = upsample_conv(x, 3, 64, 2)
    x = ReflectionPadding2D(padding=(3, 3))(x)
    x = tf.keras.layers.Conv2D(3, 7, strides=1, activation='tanh')(x)
    # Map tanh's [-1, 1] range to [0, 1].
    x = tf.keras.layers.Lambda(lambda t: tf.math.scalar_mul(.5, t) + .5)(x)
    return tf.keras.Model(inputs=inputs, outputs=x)

def unet_downsample(input_tensor, filters, size, apply_norm=True):
    """U-Net encoder step: stride-2 conv, optional instance norm, leaky ReLU.

    The convolution uses ``padding='same'``, so H and W are halved (rounded
    up). Kernels are initialized from a normal distribution with mean 0 and
    stddev 0.02.

    Args:
        input_tensor: channels-last 4D Keras tensor.
        filters: number of output channels.
        size: convolution kernel size.
        apply_norm: whether to apply ``normalization(..., 'instance')``.

    Returns:
        The activated Keras tensor.
    """
    conv = tf.keras.layers.Conv2D(
        filters, size, strides=2, padding='same',
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
        use_bias=True)
    features = conv(input_tensor)
    features = normalization(features, method='instance') if apply_norm else features
    return tf.keras.layers.Activation(tf.nn.leaky_relu)(features)

def unet_upsample(input_tensor, filters, size, apply_dropout=False, last=False):
    """U-Net decoder step: stride-2 transposed conv, norm, dropout, activation.

    The transposed convolution uses ``padding='same'``, so H and W are doubled.
    Kernels are initialized from a normal distribution with mean 0 and stddev
    0.02.

    Args:
        input_tensor: channels-last 4D Keras tensor.
        filters: number of output channels.
        size: transposed-convolution kernel size.
        apply_dropout: add a ``Dropout(0.5)`` after normalization.
        last: mark this as the output layer. Normalization is skipped and
            tanh is used instead of ReLU.

    Returns:
        The activated Keras tensor.
    """
    initializer = tf.random_normal_initializer(0., 0.02)
    x = tf.keras.layers.Conv2DTranspose(filters, size, strides=2, padding='same',
                                        kernel_initializer=initializer,
                                        use_bias=True)(input_tensor)
    # Do not normalize the output layer. Instance norm there would pin each
    # output channel's per-image mean and spread to the learned beta/gamma
    # before tanh, discarding global colour and brightness.
    if not last:
        x = normalization(x, method='instance')
    if apply_dropout:
        x = tf.keras.layers.Dropout(0.5)(x)
    activation = tf.nn.tanh if last else tf.nn.relu
    return tf.keras.layers.Activation(activation)(x)

def create_unet_generator(shape=(256, 256, 3)):
    """Build a U-Net generator (encoder-decoder with skip connections).

    The encoder has eight stride-2 ``unet_downsample`` steps. Instance norm is
    applied on all of them except the first and the innermost. They are
    mirrored by seven ``unet_upsample`` steps, each concatenated with the
    encoder output of the same resolution, plus a final tanh upsampling layer.
    The tanh output is rescaled from [-1, 1] to [0, 1].

    The shape comments assume the default 256x256 input. H and W should be
    multiples of 256 so that every skip connection's spatial size matches.

    Args:
        shape: input shape ``(H, W, C)``, channels last.

    Returns:
        A ``tf.keras.Model`` producing 3-channel images in [0, 1].
    """
    inputs = tf.keras.layers.Input(shape=shape)
    down1 = unet_downsample(inputs, 64, 4, apply_norm=False) # 128,128,64
    down2 = unet_downsample(down1, 128, 4) # 64,64,128
    down3 = unet_downsample(down2, 256, 4) # 32,32,256
    down4 = unet_downsample(down3, 512, 4) # 16,16,512
    down5 = unet_downsample(down4, 512, 4) # 8,8,512
    down6 = unet_downsample(down5, 512, 4) # 4,4,512
    down7 = unet_downsample(down6, 512, 4) # 2,2,512
    # The innermost map is 1x1. Instance norm over a single spatial position
    # has zero variance and returns only beta, which would make the whole
    # bottleneck independent of the input. So it is left unnormalized.
    down8 = unet_downsample(down7, 512, 4, apply_norm=False) # 1,1,512
    up1 = unet_upsample(down8, 512, 4, apply_dropout=True) # 2,2,512
    up1 = tf.keras.layers.Concatenate()([up1, down7]) # 2,2,1024
    up2 = unet_upsample(up1, 512, 4, apply_dropout=True) # 4,4,512
    up2 = tf.keras.layers.Concatenate()([up2, down6]) # 4,4,1024
    up3 = unet_upsample(up2, 512, 4, apply_dropout=True) # 8,8,512
    up3 = tf.keras.layers.Concatenate()([up3, down5]) # 8,8,1024
    up4 = unet_upsample(up3, 512, 4) # 16,16,512
    up4 = tf.keras.layers.Concatenate()([up4, down4]) # 16,16,1024
    up5 = unet_upsample(up4, 256, 4) # 32,32,256
    up5 = tf.keras.layers.Concatenate()([up5, down3]) # 32,32,512
    up6 = unet_upsample(up5, 128, 4) # 64,64,128
    up6 = tf.keras.layers.Concatenate()([up6, down2]) # 64,64,256
    up7 = unet_upsample(up6, 64, 4) # 128,128,64
    up7 = tf.keras.layers.Concatenate()([up7, down1]) # 128,128,128
    up8 = unet_upsample(up7, 3, 4, last=True) # 256,256,3
    # Map tanh's [-1, 1] range to [0, 1].
    x = tf.keras.layers.Lambda(lambda t: tf.math.scalar_mul(.5, t) + .5)(up8)
    return tf.keras.Model(inputs=inputs, outputs=x)

def dis_downsample(input_tensor,
                   kernel_size,
                   filters,
                   stride, norm=None):
  """Discriminator step: 1-pixel reflection pad, conv, optional norm, LeakyReLU(0.2).

  The padding is always 1 pixel per side, regardless of ``kernel_size``.
  Kernels are initialized from a normal distribution with mean 0 and stddev
  0.02.

  Args:
    input_tensor: channels-last 4D Keras tensor.
    kernel_size: convolution kernel size.
    filters: number of output channels.
    stride: convolution stride.
    norm: ``None`` to skip normalization, otherwise a method name passed to
      ``normalization`` (``'batch'`` or ``'instance'``).

  Returns:
    The activated Keras tensor.
  """
  padded = ReflectionPadding2D(padding=(1, 1))(input_tensor)
  features = tf.keras.layers.Conv2D(
      filters, kernel_size, strides=stride,
      kernel_initializer=tf.random_normal_initializer(0., 0.02))(padded)
  if norm is not None:
    features = normalization(features, method=norm)
  return tf.keras.layers.LeakyReLU(alpha=0.2)(features)

def create_discriminator(shape=(256, 256, 3)):
    """Convolutional discriminator that outputs a spatial grid of raw scores.

    Four 4x4 ``dis_downsample`` steps (64, 128, 256, 512 filters; strides
    2, 2, 2, 1; batch norm on all but the first) are followed by a
    reflection-padded 4x4 conv to a single channel with no activation. For a
    256x256 input the output is 30x30x1.

    Args:
        shape: input shape ``(H, W, C)``, channels last.

    Returns:
        A ``tf.keras.Model`` producing one unbounded score per output location.
    """
    inputs = tf.keras.layers.Input(shape=shape)
    # (filters, stride, norm) for each downsampling step; all use 4x4 kernels.
    steps = ((64, 2, None), (128, 2, 'batch'), (256, 2, 'batch'), (512, 1, 'batch'))
    features = inputs
    for filters, stride, norm in steps:
        features = dis_downsample(features, 4, filters, stride, norm=norm)
    features = ReflectionPadding2D(padding=(1, 1))(features)
    scores = tf.keras.layers.Conv2D(
        filters=1, kernel_size=4, strides=1,
        kernel_initializer=tf.random_normal_initializer(0., 0.02))(features)
    return tf.keras.Model(inputs=inputs, outputs=scores)

def create_LSdiscriminator(shape=(256, 256, 3)):
    """Discriminator with four stride-2 5x5 conv steps and a linear 1-unit head.

    The conv steps have 64, 128, 256 and 512 filters, with instance norm on all
    but the first. ``Dense`` applied to a 4D tensor acts only on the last axis,
    so the output is a spatial map of scores (15x15x1 for a 256x256 input),
    not one scalar per image.

    NOTE(review): if a single score per image was intended (a fully connected
    head), a ``Flatten`` is missing before ``Dense``. Confirm against the loss.

    Args:
        shape: input shape ``(H, W, C)``, channels last.

    Returns:
        A ``tf.keras.Model`` producing linear scores.
    """
    inputs = tf.keras.layers.Input(shape=shape)
    features = inputs
    for filters, norm in ((64, None), (128, 'instance'), (256, 'instance'), (512, 'instance')):
        features = dis_downsample(features, 5, filters, 2, norm=norm)
    scores = tf.keras.layers.Dense(1, activation='linear')(features)
    return tf.keras.Model(inputs=inputs, outputs=scores)


In [20]:
# Build the ResNet generator for 256x256 RGB inputs and print its layer table.
g = create_generator(shape=(256, 256, 3))
g.summary()

Model: "model_9"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_11 (InputLayer)           [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
reflection_padding2d_52 (Reflec (None, 262, 262, 3)  0           input_11[0][0]                   
__________________________________________________________________________________________________
conv2d_63 (Conv2D)              (None, 256, 256, 64) 9472        reflection_padding2d_52[0][0]    
__________________________________________________________________________________________________
batch_normalization_47 (BatchNo (None, 256, 256, 64) 256         conv2d_63[0][0]                  
____________________________________________________________________________________________