In [1]:
# Rho clipper: model안에서 존재하는지, weight가갱신되는지 확인 할 것
# summary (Tensorboard)

# tf.model 안에서 sequential을 다시 쓰는 것은 튜토리얼에 없음. block의 경우 layer를 쓰라고 나와있음. model은 serialization 및 training 등에 사용
#  (training이란.. 모델 container는 안에 있는 var들을 모두 모아서..이건 layer도 되는데? training 가능?)
# graph가 연결이 안되어있는 경우 grad를 구하면 None 됨. 
# self.weight_decay 다 적용 안됨 
# Spectral normalization을 dense에도 적용 필요 함 

In [2]:
import tensorflow as tf

# https://github.com/thisisiron/spectral_normalization-tf2/blob/master/sn.py
# https://groups.google.com/a/tensorflow.org/g/discuss/c/PRjyj6tiQvU?pli=1

class SpectralNormalization(tf.keras.layers.Wrapper):
    """Wrapper that divides the wrapped layer's kernel by an estimate of its spectral norm.

    The estimate is obtained by power iteration on the kernel reshaped to the
    2-D matrix ``[prod(kernel.shape[:-1]), kernel.shape[-1]]``. Two persistent,
    non-trainable vectors ``u`` and ``v`` carry the iteration state between calls.

    Note: every call overwrites ``layer.kernel`` with ``kernel / sigma``; the
    un-normalized kernel is not restored afterwards.

    Args:
        layer: Keras layer that exposes a ``kernel`` attribute once built.
        iteration: number of power-iteration steps performed per call.
        eps: small constant added to the vector norms to avoid division by zero.
        training: when False the power iteration is skipped and the stored
            ``u`` / ``v`` vectors are used unchanged.
        **kwargs: forwarded to ``tf.keras.layers.Wrapper``.

    Raises:
        ValueError: if ``layer`` is not a ``tf.keras.layers.Layer`` instance.
    """

    def __init__(self, layer, iteration=1, eps=1e-12, training=True, **kwargs):
        self.iteration = iteration
        self.eps = eps
        self.do_power_iteration = training
        if not isinstance(layer, tf.keras.layers.Layer):
            raise ValueError(
                'Please initialize `SpectralNormalization` layer with a '
                '`Layer` instance. You passed: {input}'.format(input=layer))
        super(SpectralNormalization, self).__init__(layer, **kwargs)

    def build(self, input_shape):
        """Build the wrapped layer (if needed) and create the ``u`` / ``v`` vectors."""
        # Do not rebuild a layer that already owns its weights.
        if not self.layer.built:
            self.layer.build(input_shape)

        self.w = self.layer.kernel
        self.w_shape = self.w.shape.as_list()

        # update_weights() flattens the kernel to [rows, out_features], so `v`
        # needs one entry per row: the product of every kernel dim but the last.
        # e.g. Conv2D 4 x 4 x 3 x 32 -> 48 rows, Dense 2048 x 1 -> 2048 rows.
        rows = 1
        for dim in self.w_shape[:-1]:
            rows *= dim

        self.v = self.add_weight(shape=(1, rows),
                                 initializer=tf.initializers.TruncatedNormal(stddev=0.02),
                                 trainable=False,
                                 name='sn_v',
                                 dtype=tf.float32)

        self.u = self.add_weight(shape=(1, self.w_shape[-1]),
                                 initializer=tf.initializers.TruncatedNormal(stddev=0.02),
                                 trainable=False,
                                 name='sn_u',
                                 dtype=tf.float32)

        super(SpectralNormalization, self).build()

    def call(self, inputs):
        """Normalize the kernel in place, then run the wrapped layer on ``inputs``."""
        self.update_weights()
        output = self.layer(inputs)
        # self.restore_weights()  # Restore weights because of this formula "W = W - alpha * W_SN`"
        return output

    def update_weights(self):
        """Run power iteration, store the new ``u`` / ``v`` and rescale the kernel by ``1 / sigma``."""
        w_reshaped = tf.reshape(self.w, [-1, self.w_shape[-1]])

        u_hat = self.u
        v_hat = self.v

        if self.do_power_iteration:
            for _ in range(self.iteration):
                v_ = tf.matmul(u_hat, tf.transpose(w_reshaped))
                v_hat = v_ / (tf.reduce_sum(v_**2)**0.5 + self.eps)

                u_ = tf.matmul(v_hat, w_reshaped)
                u_hat = u_ / (tf.reduce_sum(u_**2)**0.5 + self.eps)

        # sigma = v W u^T, a 1 x 1 tensor that broadcasts over the kernel below.
        sigma = tf.matmul(tf.matmul(v_hat, w_reshaped), tf.transpose(u_hat))
        self.u.assign(u_hat)
        self.v.assign(v_hat)
        self.layer.kernel.assign(self.w / sigma)


In [3]:
import tensorflow as tf 
import tensorflow_addons as tfa
import matplotlib.pyplot as plt

class ResnetGenerator(tf.keras.Model):  # https://www.tensorflow.org/api_docs/python/tf/keras/Sequential (input_shape issue)
    """Generator: down-sampling convs, residual blocks, a CAM branch and adaILN up-sampling.

    Args:
        output_nc: number of channels of the generated image.
        ngf: number of filters of the first convolution.
        n_blocks: number of residual blocks in each bottleneck (down and up).
        light: when True, gamma/beta are computed from globally average-pooled
            features instead of the flattened feature map.

    ``call`` returns ``(image, cam_logit, heatmap)``.
    """

    def __init__(self, output_nc, ngf=64, n_blocks=6, light=False):
        assert(n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        self.output_nc = output_nc
        self.ngf = ngf
        self.n_blocks = n_blocks
        self.light = light

        self.DownBlock = []
        self.DownBlock.append(Conv(filters=self.ngf, kernel_size=7, strides=1, pad=3, normal='IN',
                                   act='relu', use_bias=True, pad_type='REFLECT'))

        # Down-Sampling
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            self.DownBlock.append(Conv(filters=self.ngf*mult*2, kernel_size=3, strides=2, pad=1, normal='IN',
                                       act='relu', use_bias=False, pad_type='REFLECT'))

        # Down-Sampling Bottleneck
        mult = 2**n_downsampling
        # Channel count of the bottleneck feature map; everything that is
        # multiplied against that map (gamma, beta) must have this width.
        bottleneck_dim = self.ngf * mult
        for i in range(self.n_blocks):
            self.DownBlock.append(ResnetBlock(bottleneck_dim, use_bias=False))

        # CAM
        self.cam_fc = tf.keras.layers.Dense(1, kernel_regularizer=tf.keras.regularizers.L2(0.0001), kernel_initializer=tf.random_normal_initializer(mean=0.0, stddev=0.02), use_bias=True)
        self.gap = tf.keras.layers.GlobalAveragePooling2D()
        self.gmp = tf.keras.layers.GlobalMaxPool2D()
        self.conv1x1 = tf.keras.layers.Conv2D(filters=bottleneck_dim, kernel_size=(1, 1),
                                              kernel_regularizer=tf.keras.regularizers.L2(0.0001),
                                              strides=(1, 1), use_bias=True)
        self.relu = tf.keras.layers.ReLU()

        # Gamma, Beta block. Dense infers its input size, so `light` only
        # changes what is fed in (see call), not how the layers are declared.
        # The width is the bottleneck channel count: adaILN broadcasts gamma/beta
        # over a feature map with `bottleneck_dim` channels, so a width of
        # ngf * n_blocks only worked when n_blocks happened to equal `mult`.
        self.FC = []
        for i in range(2):
            self.FC.append(tf.keras.layers.Dense(bottleneck_dim, kernel_regularizer=tf.keras.regularizers.L2(0.0001), use_bias=True))
            self.FC.append(tf.keras.layers.ReLU())
        self.gamma = tf.keras.layers.Dense(bottleneck_dim, kernel_regularizer=tf.keras.regularizers.L2(0.0001), use_bias=True)
        self.beta = tf.keras.layers.Dense(bottleneck_dim, kernel_regularizer=tf.keras.regularizers.L2(0.0001), use_bias=True)

        # Up-Sampling Bottleneck
        self.UpBlock1 = []
        for _ in range(self.n_blocks):
            self.UpBlock1.append(ResnetAdaILNBlock(bottleneck_dim, use_bias=True, smoothing=True))

        # Up-Sampling
        self.UpBlock2 = []
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            self.UpBlock2.append(tf.keras.layers.UpSampling2D(size=(2, 2), interpolation='nearest'))
            self.UpBlock2.append(Conv(filters=self.ngf * mult // 2, kernel_size=3, strides=1, pad=1, normal='ILN', act='relu',
                                      use_bias=False, pad_type='REFLECT', ILN_factor=self.ngf * mult // 2))
        self.UpBlock2.append(Conv(filters=self.output_nc, kernel_size=7, strides=1, pad=3, normal=None,
                                  act='tanh', use_bias=False, pad_type='REFLECT'))

    def call(self, x):
        """Run the generator; returns ``(image, cam_logit, heatmap)``."""
        for layer in self.DownBlock:
            x = layer(x)

        # CAM branch: the same Dense layer scores the average- and max-pooled
        # features; its (kernel + bias), squeezed, reweights the feature map.
        gap = self.gap(x)  # global average pooling, e.g. 1 x 32 x 32 x 256 -> 1 x 256
        gap_logit = self.cam_fc(gap)
        gap_weight = tf.squeeze(self.cam_fc.trainable_variables[0] + self.cam_fc.trainable_variables[1])
        gap = x * gap_weight
        gmp = self.gmp(x)  # global max pooling
        gmp_logit = self.cam_fc(gmp)
        gmp_weight = tf.squeeze(self.cam_fc.trainable_variables[0] + self.cam_fc.trainable_variables[1])
        gmp = x * gmp_weight

        cam_logit = tf.keras.layers.concatenate([gap_logit, gmp_logit], axis=-1)
        x = tf.keras.layers.concatenate([gap, gmp], axis=-1)
        x = self.relu(self.conv1x1(x))
        heatmap = tf.reduce_sum(x, axis=-1, keepdims=True)

        # Light mode pools before the MLP; otherwise the whole map is flattened.
        x_ = self.gap(x) if self.light else x
        x_ = tf.keras.layers.Flatten()(x_)
        for layer in self.FC:
            x_ = layer(x_)
        # Computed for both modes: previously this lived in the non-light branch
        # only, so light=True failed below with unbound gamma/beta.
        gamma, beta = self.gamma(x_), self.beta(x_)

        for layer in self.UpBlock1:
            x = layer(x, gamma, beta)

        for layer in self.UpBlock2:
            x = layer(x)

        return x, cam_logit, heatmap



class Conv(tf.keras.layers.Layer):
    """Explicitly padded Conv2D followed by optional normalization and activation.

    Args:
        filters: number of output channels.
        kernel_size: side length of the square kernel.
        strides: stride used for both spatial dimensions.
        pad: amount of padding per side; 0 disables padding.
        normal: ``'IN'`` (instance norm), ``'ILN'`` (instance-layer norm),
            ``'SN'`` (spectrally normalized convolution) or ``None``.
        act: ``'relu'``, ``'tanh'``, ``'lrelu'`` or ``None``.
        use_bias: whether the convolution has a bias term.
        pad_type: ``mode`` passed to ``tf.pad``.
        **kwargs: ``ILN_factor`` sets the number of features of the ILN layer;
            it defaults to ``filters`` (the channel count ILN receives).
    """

    def __init__(self, filters, kernel_size, strides, pad, normal=None, act=None, use_bias=False, pad_type='REFLECT', **kwargs):
        super(Conv, self).__init__()
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.pad = pad
        # `self.normal` is replaced by a layer object below for IN/ILN, so the
        # requested kind is kept separately for dispatch in call().
        self.norm_type = normal
        self.normal = normal
        self.act = act
        self.use_bias = use_bias
        self.pad_type = pad_type
        self.ILN_factor = kwargs.get('ILN_factor', filters)

        # normalization
        if self.norm_type == 'IN':
            self.normal = tfa.layers.InstanceNormalization(axis=3, center=True, scale=True,
                                                           beta_initializer="random_uniform",
                                                           gamma_initializer="random_uniform")
        elif self.norm_type == 'ILN':
            self.normal = ILN(self.ILN_factor)

        elif self.norm_type == 'SN':
            self.conv_sn = SpectralNormalization(tf.keras.layers.Conv2D(filters=self.filters, kernel_size=(self.kernel_size, self.kernel_size),
                                              kernel_initializer = tf.random_normal_initializer(mean=0.0, stddev=0.02),
                                              kernel_regularizer=tf.keras.regularizers.l2(0.0001),
                                              strides=(self.strides, self.strides), use_bias=self.use_bias))

        # activation
        if self.act == 'relu':
            self.act = tf.keras.layers.ReLU()
        elif self.act == 'tanh':
            self.act = tf.keras.activations.tanh
        elif self.act == 'lrelu':
            self.act = tf.keras.layers.LeakyReLU(alpha=0.2)

        # padding
        if self.pad > 0:
            if (self.kernel_size - self.strides) % 2 == 0:
                self.pad_top = self.pad
                self.pad_bottom = self.pad
                self.pad_left = self.pad
                self.pad_right = self.pad

            else:
                # Odd (kernel_size - strides): the bottom/right side takes the remainder.
                self.pad_top = self.pad
                self.pad_bottom = self.kernel_size - self.strides - self.pad_top
                self.pad_left = self.pad
                self.pad_right = self.kernel_size - self.strides - self.pad_left

        # conv2d (the 'SN' variant owns its convolution inside self.conv_sn)
        if self.norm_type != 'SN':
            self.conv = tf.keras.layers.Conv2D(filters=self.filters, kernel_size=(self.kernel_size, self.kernel_size),
                                              kernel_initializer = tf.random_normal_initializer(mean=0.0, stddev=0.02),
                                              kernel_regularizer=tf.keras.regularizers.l2(0.0001),
                                              strides=(self.strides, self.strides), use_bias=self.use_bias)

    def call(self, x):
        """Pad, convolve, normalize and activate ``x``.

        Implemented as ``call`` (not ``__call__``) so the Keras ``Layer.__call__``
        machinery still runs; ``layer(x)`` keeps working for callers.
        """
        if self.pad > 0:
            x = tf.pad(x, [[0, 0], [self.pad_top, self.pad_bottom], [self.pad_left, self.pad_right], [0, 0]], mode=self.pad_type)
        if self.norm_type == 'SN':
            x = self.conv_sn(x)
        else:
            x = self.conv(x)
            # Dispatch on the stored string: comparing the layer object in
            # `self.normal` to 'IN'/'ILN' is always False and skipped the norm.
            if self.norm_type in ('IN', 'ILN'):
                x = self.normal(x)
        if self.act is not None:
            x = self.act(x)
        return x


class ResnetBlock(tf.keras.layers.Layer):
    """Residual block: two 3x3 ``Conv`` units (configured with ``normal='IN'``) plus a skip connection.

    Args:
        dim: number of filters of both convolutions.
        use_bias: whether the convolutions use a bias term.
    """

    def __init__(self, dim, use_bias):
        super(ResnetBlock, self).__init__()
        self.use_bias = use_bias
        self.dim = dim
        # Both units share everything except the activation.
        shared = dict(filters=self.dim, kernel_size=3, strides=1, pad=1, normal='IN',
                      use_bias=self.use_bias, pad_type='REFLECT')
        self.conv1 = Conv(act='relu', **shared)
        self.conv2 = Conv(act=None, **shared)

    def call(self, inputs):
        """Return ``conv2(conv1(inputs)) + inputs``."""
        residual = self.conv2(self.conv1(inputs))
        return residual + inputs



class ResnetAdaILNBlock(tf.keras.layers.Layer):
    """Residual block whose two 3x3 convolutions are each followed by ``adaILN``.

    Args:
        dim: number of filters of both convolutions and features of both norms.
        use_bias: whether the convolutions use a bias term.
        smoothing: forwarded to both ``adaILN`` layers.
    """

    def __init__(self, dim, use_bias, smoothing=True):
        super(ResnetAdaILNBlock, self).__init__()
        self.dim = dim
        self.use_bias = use_bias
        self.smoothing = smoothing

        def conv3x3():
            return tf.keras.layers.Conv2D(filters=self.dim, kernel_size=(3, 3),
                                          kernel_regularizer=tf.keras.regularizers.l2(0.0001),
                                          strides=(1, 1), use_bias=self.use_bias)

        self.conv1 = conv3x3()
        self.norm1 = adaILN(self.dim, self.smoothing)
        self.relu1 = tf.keras.layers.ReLU()
        self.conv2 = conv3x3()
        self.norm2 = adaILN(self.dim, self.smoothing)

    @staticmethod
    def _reflect_pad(tensor):
        """Reflect-pad the two middle axes by one element on each side."""
        return tf.pad(tensor, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT')

    def call(self, inputs, gamma, beta):
        """Apply pad-conv-norm-relu, then pad-conv-norm, and add ``inputs``."""
        hidden = self.conv1(self._reflect_pad(inputs))
        hidden = self.relu1(self.norm1(hidden, gamma, beta))
        hidden = self.conv2(self._reflect_pad(hidden))
        hidden = self.norm2(hidden, gamma, beta)
        return hidden + inputs




class adaILN(tf.keras.layers.Layer):
    """Adaptive instance-layer normalization with externally supplied gamma and beta.

    The output mixes an instance-normalized and a layer-normalized version of
    the input with a learnable per-feature ratio ``rho`` (constrained to
    [0, 1]), then scales and shifts it with ``gamma`` and ``beta``.

    Args:
        num_features: length of the ``rho`` vector.
        smoothing: when True, ``rho - 0.1`` (clipped to [0, 1]) is used as the
            mixing ratio instead of ``rho`` itself.
    """

    def __init__(self, num_features, smoothing=True):
        super(adaILN, self).__init__()
        self.num_features = num_features
        self.eps = 1e-12
        self.smoothing = smoothing

    def build(self, input_shape):
        # `rho` does not depend on input_shape; it is created here so it
        # exists before the first call.
        self.rho = tf.Variable(initial_value=tf.fill([self.num_features], 1.0), dtype=tf.float32, trainable=True, constraint=lambda x: tf.clip_by_value(x, clip_value_min=0.0, clip_value_max=1.0))

    def call(self, inputs, gamma, beta):
        """Normalize ``inputs`` and apply ``gamma`` / ``beta``.

        ``gamma`` and ``beta`` get two singleton axes inserted at position 1
        before broadcasting (assumes (batch, channel) inputs - TODO confirm).
        """
        in_mean, in_var = tf.reduce_mean(inputs, axis=[1, 2], keepdims=True), tf.math.reduce_variance(inputs, axis=[1, 2], keepdims=True)
        out_in = (inputs - in_mean) / tf.math.sqrt(in_var + self.eps)
        ln_mean, ln_var = tf.reduce_mean(inputs, axis=[1, 2, 3], keepdims=True), tf.math.reduce_variance(inputs, axis=[1, 2, 3], keepdims=True)
        out_ln = (inputs - ln_mean) / tf.math.sqrt(ln_var + self.eps)
        rho = self.rho
        if self.smoothing:
            # Smooth a local copy only. Assigning the result back to the
            # variable lowered the trainable rho by 0.1 on every forward pass,
            # driving it to 0 after ten calls regardless of training.
            rho = tf.clip_by_value(rho - tf.constant(0.1), 0.0, 1.0)
        out = rho * out_in + (1 - rho) * out_ln
        out = out * tf.expand_dims(tf.expand_dims(gamma, axis=1), axis=1) + tf.expand_dims(tf.expand_dims(beta, axis=1), axis=1)
        return out


class ILN(tf.keras.layers.Layer):
    """Instance-layer normalization with its own learnable ``rho``, ``gamma`` and ``beta``.

    Args:
        num_features: length of the three per-feature parameter vectors.
    """

    def __init__(self, num_features):
        super(ILN, self).__init__()
        self.num_features = num_features
        self.eps = 1e-12

    def build(self, input_shape):
        def per_feature(value, **extra):
            # Trainable float32 vector of length num_features filled with `value`.
            return tf.Variable(initial_value=tf.fill([self.num_features], value),
                               dtype=tf.float32, trainable=True, **extra)

        # rho is kept inside [0, 1] by its constraint.
        self.rho = per_feature(0.0, constraint=lambda x: tf.clip_by_value(x, clip_value_min=0.0, clip_value_max=1.0))
        self.gamma = per_feature(1.0)
        self.beta = per_feature(0.0)

    def _standardize(self, inputs, axes):
        """Subtract the mean and divide by sqrt(variance + eps) over ``axes``."""
        mean = tf.reduce_mean(inputs, axis=axes, keepdims=True)
        variance = tf.math.reduce_variance(inputs, axis=axes, keepdims=True)
        return (inputs - mean) / tf.math.sqrt(variance + self.eps)

    def call(self, inputs):
        """Mix instance- and layer-normalized inputs by ``rho``, then scale and shift."""
        out_in = self._standardize(inputs, [1, 2])
        out_ln = self._standardize(inputs, [1, 2, 3])
        mixed = self.rho * out_in + (1 - self.rho) * out_ln
        # gamma and beta broadcast directly over the last axis here.
        return mixed * self.gamma + self.beta


class Discriminator_global(tf.keras.Model):
    """Discriminator built from ``Conv`` units with ``normal='SN'`` plus a CAM branch.

    Args:
        ndf: number of filters of the first convolution.
        n_layers: controls the number of stride-2 convolutions and the final width.

    ``call`` returns ``(out, cam_logit, heatmap)``.
    """

    def __init__(self, ndf=64, n_layers=6):
        assert(n_layers >= 0)
        super(Discriminator_global, self).__init__()
        self.ndf = ndf
        self.n_layers = n_layers

        def sn_conv(filters, strides, act='lrelu'):
            # Every convolution here is 4x4, pad 1, 'SN', with bias.
            return Conv(filters=filters, kernel_size=4, strides=strides, pad=1, normal='SN', act=act,
                        use_bias=True, pad_type='REFLECT')

        # Stride-2 stack: ndf, then doubling widths for i = 1 .. n_layers - 2.
        stack = [sn_conv(self.ndf, 2)]
        stack += [sn_conv(self.ndf * 2 ** i, 2) for i in range(1, self.n_layers - 1)]

        # Final stride-1 convolution.
        mult = 2 ** (self.n_layers - 1 - 1)
        stack.append(sn_conv(self.ndf * mult * 2, 1))
        self.model = tf.keras.Sequential(stack)

        # Class Activation Map
        self.gap = tf.keras.layers.GlobalAveragePooling2D()
        self.cam_fc = SpectralNormalization(tf.keras.layers.Dense(1, kernel_regularizer=tf.keras.regularizers.L2(0.0001), kernel_initializer = tf.random_normal_initializer(mean=0.0, stddev=0.02), use_bias=True))
        self.gmp = tf.keras.layers.GlobalMaxPool2D()
        self.conv1x1 = tf.keras.layers.Conv2D(filters= self.ndf * mult * 2 , kernel_size=(1,1), kernel_regularizer=tf.keras.regularizers.L2(0.0001), strides=(1,1), use_bias=True)
        self.lrelu = tf.keras.layers.LeakyReLU(alpha=0.2)
        self.conv = sn_conv(1, 1, act=None)

    def _cam_branch(self, feature_map, pooled):
        """Return ``cam_fc(pooled)`` and ``feature_map`` reweighted by cam_fc's (kernel + bias)."""
        logit = self.cam_fc(pooled)
        # Read the variables after the call above: cam_fc rewrites its kernel on each call.
        weight = tf.squeeze(self.cam_fc.trainable_variables[0] + self.cam_fc.trainable_variables[1])
        # The squeezed weight broadcasts over the last axis of the feature map.
        return logit, feature_map * weight

    def call(self, inputs):
        """Return ``(out, cam_logit, heatmap)`` for ``inputs``."""
        features = self.model(inputs)

        gap_logit, gap = self._cam_branch(features, self.gap(features))  # global average pooling
        gmp_logit, gmp = self._cam_branch(features, self.gmp(features))  # global max pooling

        cam_logit = tf.keras.layers.concatenate([gap_logit, gmp_logit], axis=-1)
        merged = tf.keras.layers.concatenate([gap, gmp], axis=-1)
        merged = self.lrelu(self.conv1x1(merged))
        heatmap = tf.reduce_sum(merged, axis=3, keepdims=True)

        return self.conv(merged), cam_logit, heatmap



class Discriminator_local(tf.keras.Model):
    """Shallower counterpart of ``Discriminator_global`` (two fewer stride-2 convolutions).

    Args:
        ndf: number of filters of the first convolution.
        n_layers: controls the number of stride-2 convolutions and the final width.

    ``call`` returns ``(out, cam_logit, heatmap)``.
    """

    def __init__(self, ndf=64, n_layers=6):
        assert(n_layers >= 0)
        super(Discriminator_local, self).__init__()
        self.ndf = ndf
        self.n_layers = n_layers

        def sn_conv(filters, strides, act='lrelu'):
            # Every convolution here is 4x4, pad 1, 'SN', with bias.
            return Conv(filters=filters, kernel_size=4, strides=strides, pad=1, normal='SN', act=act,
                        use_bias=True, pad_type='REFLECT')

        # Stride-2 stack: ndf, then doubling widths for i = 1 .. n_layers - 4.
        stack = [sn_conv(self.ndf, 2)]
        stack += [sn_conv(self.ndf * 2 ** i, 2) for i in range(1, self.n_layers - 2 - 1)]

        # Final stride-1 convolution.
        # NOTE(review): this width uses 2 ** (n_layers - 2) although the loop
        # above stops at 2 ** (n_layers - 4); the exponent matches
        # Discriminator_global and may be a copy-paste - confirm the intended width.
        mult = 2 ** (self.n_layers - 1 - 1)
        stack.append(sn_conv(self.ndf * mult * 2, 1))
        self.model = tf.keras.Sequential(stack)

        # Class Activation Map
        self.gap = tf.keras.layers.GlobalAveragePooling2D()
        self.cam_fc = SpectralNormalization(tf.keras.layers.Dense(1, kernel_regularizer=tf.keras.regularizers.L2(0.0001), kernel_initializer = tf.random_normal_initializer(mean=0.0, stddev=0.02), use_bias=True))
        self.gmp = tf.keras.layers.GlobalMaxPool2D()
        self.conv1x1 = tf.keras.layers.Conv2D(filters= self.ndf * mult * 2 , kernel_size=(1,1), kernel_regularizer=tf.keras.regularizers.L2(0.0001), strides=(1,1), use_bias=True)
        self.lrelu = tf.keras.layers.LeakyReLU(alpha=0.2)
        self.conv = sn_conv(1, 1, act=None)

    def _cam_branch(self, feature_map, pooled):
        """Return ``cam_fc(pooled)`` and ``feature_map`` reweighted by cam_fc's (kernel + bias)."""
        logit = self.cam_fc(pooled)
        # Read the variables after the call above: cam_fc rewrites its kernel on each call.
        weight = tf.squeeze(self.cam_fc.trainable_variables[0] + self.cam_fc.trainable_variables[1])
        # The squeezed weight broadcasts over the last axis of the feature map.
        return logit, feature_map * weight

    def call(self, inputs):
        """Return ``(out, cam_logit, heatmap)`` for ``inputs``."""
        features = self.model(inputs)

        gap_logit, gap = self._cam_branch(features, self.gap(features))  # global average pooling
        gmp_logit, gmp = self._cam_branch(features, self.gmp(features))  # global max pooling

        cam_logit = tf.keras.layers.concatenate([gap_logit, gmp_logit], axis=-1)
        merged = tf.keras.layers.concatenate([gap, gmp], axis=-1)
        merged = self.lrelu(self.conv1x1(merged))
        heatmap = tf.reduce_sum(merged, axis=3, keepdims=True)

        return self.conv(merged), cam_logit, heatmap



In [4]:
import time 
from glob import glob
import tensorflow_datasets as tfds
import datetime
import os
import numpy as np 
from matplotlib.pyplot import imsave

AUTOTUNE = tf.data.experimental.AUTOTUNE

def ad_loss(y_pred, y_true):
    """Mean of the element-wise squared difference between ``y_pred`` and ``y_true``."""
    squared_error = tf.math.squared_difference(y_pred, y_true)
    return tf.reduce_mean(squared_error)

def bce_loss(y_pred, y_true):
    """Binary cross-entropy of ``y_pred`` (treated as logits) against ``y_true``."""
    criterion = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    # Keras losses take (y_true, y_pred); this function takes them reversed.
    return criterion(y_true, y_pred)

def id_loss(y_pred, y_true):
    """Mean absolute difference between ``y_pred`` and ``y_true``."""
    absolute_error = tf.abs(y_pred - y_true)
    return tf.reduce_mean(absolute_error)

def recon_loss(y_pred, y_true):
    """Mean absolute difference between ``y_pred`` and ``y_true``."""
    absolute_error = tf.abs(y_pred - y_true)
    return tf.reduce_mean(absolute_error)

def normalize(image):
    """Cast ``image`` to float32 and map it with ``x / 127.5 - 1`` (0..255 becomes -1..1)."""
    as_float = tf.cast(image, tf.float32)
    return (as_float / 127.5) - 1

def train_augment(image,label,img_size):
    """Training-time augmentation; ``label`` is accepted but not used.

    Resizes to ``img_size + 30`` per side (nearest neighbour), takes a random
    ``img_size`` x ``img_size`` x 3 crop, flips left-right at random and
    normalizes the result with ``normalize``.
    """
    enlarged = img_size + 30
    image = tf.image.resize(image, [enlarged, enlarged], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    image = tf.image.random_crop(image, size=[img_size, img_size, 3])
    image = tf.image.random_flip_left_right(image)
    return normalize(tf.cast(image, dtype=tf.float32))

def test_augment(image,label,img_size):
    """Test-time preprocessing; ``label`` is accepted but not used.

    Resizes to ``img_size`` x ``img_size`` and normalizes with ``normalize``.
    """
    resized = tf.image.resize(image, [img_size, img_size])
    return normalize(tf.cast(resized, dtype=tf.float32))


def image_save(real_images, fake_images, path, iter):
    """Save real/fake image pairs side by side as ``<path>/iter_<iter>.png``.

    Each batch element becomes one row: the real image on the left and the
    fake image on the right. Both batches must have the same 4-D shape.
    """
    assert real_images.shape == fake_images.shape
    # Undo the [-1, 1] normalization: x * 0.5 + 0.5 maps it to [0, 1].
    real_images = real_images * 0.5 + 0.5
    fake_images = fake_images * 0.5 + 0.5
    batch_size, h, w, c = real_images.shape
    figure = np.zeros((batch_size * h, w * 2, c))
    for row, (real_img, fake_img) in enumerate(zip(real_images, fake_images)):
        top, bottom = h * row, h * (row + 1)
        figure[top:bottom, 0:w, ...] = real_img  # left half
        figure[top:bottom, w:, ...] = fake_img   # right half
    filename = 'iter_' + str(iter) + '.png'
    plt.imsave(os.path.join(path, filename), figure)


class UGATIT(object) :
    def __init__(self):

        # if self.light :
        #     self.model_name = 'UGATIT_light'
        # else :
        #     self.model_name = 'UGATIT'

        # self.result_dir = args.result_dir
        # self.ckpt_dir = args.ckpt_dir
        # self.dataset = args.dataset

        # self.iterations = args.iterations
        # self.decay_flag = args.decay_flag

        # self.batch_size = args.batch_size
        # self.print_freq = args.print_freq
        # self.save_freq = args.save_freq

        # self.lr = args.lr
        # self.weight_decay = args.weight_decay
        # self.ch = args.ch

        # """ Weight """
        # self.adv_weight = args.adv_weight
        # self.cycle_weight = args.cycle_weight
        # self.identity_weight = args.identity_weight
        # self.cam_weight = args.cam_weight

        # """ Generator """
        # self.n_res = args.n_res

        # """ Discriminator """
        # self.n_dis = args.n_dis

        # self.img_size = args.img_size
        # self.img_ch = args.img_ch

        # self.resume = args.resume

        self.light = False
        self.model_name = 'UGATIT'

        self.ckpt_path = 'ckpt'
        self.tensorboard_path = 'tensorboard'
        self.img_save_path = 'img_save'
        self.dataset = 'horse2zebra'

        self.iterations = 100000
        self.batch_size = 1
        self.sample_num = 5
        self.print_freq = 1
        self.save_freq = 10000
        self.sample_freq = 1
        self.lr = 0.0001
        self.weight_decay = 0.0001
        self.ch = 64

        """ Weight """
        self.adv_weight = 1
        self.cycle_weight = 10
        self.identity_weight = 10
        self.cam_weight = 1000

        """ Generator """
        self.n_res = 4

        """ Discriminator """
        self.n_dis = 6

        self.img_size = 128
        self.img_ch = 3

        print()

        print("##### Information #####")
        print("# light : ", self.light)
        # print("# dataset : ", self.dataset)
        print("# batch_size : ", self.batch_size)
        print("# iteration per epoch : ", self.iterations)

        print()

        print("##### Generator #####")
        print("# residual blocks : ", self.n_res)

        print()

        print("##### Discriminator #####")
        print("# discriminator layer : ", self.n_dis)

        print()

        print("##### Weight #####")
        print("# adv_weight : ", self.adv_weight)
        print("# cycle_weight : ", self.cycle_weight)
        print("# identity_weight : ", self.identity_weight)
        print("# cam_weight : ", self.cam_weight)

    ##################################################################################
    # Model
    ##################################################################################

    def build(self):
        """Create input pipelines, the six networks, optimizers, output dirs, checkpoint manager and summary writers.

        For ``dataset == 'horse2zebra'`` the data comes from TensorFlow
        Datasets; any other name is read from JPEG files under
        ``./<dataset>/{trainA,trainB,testA}/``.
        """
        # Dataset
        if self.dataset == 'horse2zebra':
            dataset, _ = tfds.load('cycle_gan/horse2zebra',
                                   with_info=True, as_supervised=True)
            train_A, train_B = dataset['trainA'], dataset['trainB']
            test_A = dataset['testA']
            train_map = lambda x, y: train_augment(x, y, self.img_size)
            test_A = test_A.map(lambda x, y: test_augment(x, y, self.img_size), num_parallel_calls=AUTOTUNE)
            test_A = test_A.batch(self.sample_num)
            test_A = test_A.repeat()  # only Horses -> Zebras
        else:
            train_A = tf.data.Dataset.list_files(f"./{self.dataset}/trainA/*.jpg")
            train_B = tf.data.Dataset.list_files(f"./{self.dataset}/trainB/*.jpg")
            test_A = tf.data.Dataset.list_files(f"./{self.dataset}/testA/*.jpg")
            # list_files yields paths, so each file is read and decoded first.
            # The augment functions take (image, label, img_size); there is no
            # label here, so None fills the unused slot. The previous
            # two-argument call raised TypeError.
            train_map = lambda path: train_augment(self._load_jpeg(path), None, self.img_size)
            test_A = test_A.map(lambda path: test_augment(self._load_jpeg(path), None, self.img_size),
                                num_parallel_calls=AUTOTUNE)
            test_A = test_A.repeat()
            test_A = test_A.batch(self.sample_num)
            test_A = test_A.prefetch(AUTOTUNE)  # only Selfie -> Anime

        train_A = self._train_pipeline(train_A, train_map)
        train_B = self._train_pipeline(train_B, train_map)
        self.train_iterator = iter(tf.data.Dataset.zip((train_A, train_B)))
        self.test_iterator = iter(test_A)

        # Model building (Total 6 models)
        self.genA2B = ResnetGenerator(output_nc=3, ngf=self.ch, n_blocks=self.n_res, light=self.light)  # A -> B gen
        self.genB2A = ResnetGenerator(output_nc=3, ngf=self.ch, n_blocks=self.n_res, light=self.light)  # B -> A gen
        self.disGA = Discriminator_global(ndf=self.ch, n_layers=self.n_dis)  # Global A disc
        self.disGB = Discriminator_global(ndf=self.ch, n_layers=self.n_dis)  # Global B disc
        self.disLA = Discriminator_local(ndf=self.ch, n_layers=self.n_dis)  # Local A disc
        self.disLB = Discriminator_local(ndf=self.ch, n_layers=self.n_dis)  # Local B disc

        # Optimizer
        self.gen_opt = tf.keras.optimizers.Adam(learning_rate=self.lr, beta_1=0.5, beta_2=0.999)
        self.disc_opt = tf.keras.optimizers.Adam(learning_rate=self.lr, beta_1=0.5, beta_2=0.999)

        # dir: every output path is placed under a timestamped run directory
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        self.ckpt_path = os.path.join(current_time, self.ckpt_path)
        self.tensorboard_path = os.path.join(current_time, self.tensorboard_path)
        self.img_save_path = os.path.join(current_time, self.img_save_path)
        os.makedirs(self.tensorboard_path, exist_ok=True)
        os.makedirs(self.img_save_path, exist_ok=True)

        # ckpt manager
        ckpt = tf.train.Checkpoint(genA2B=self.genA2B,
                                   genB2A=self.genB2A,
                                   disGA=self.disGA,
                                   disGB=self.disGB,
                                   disLA=self.disLA,
                                   disLB=self.disLB)
        self.ckpt_manager = tf.train.CheckpointManager(ckpt, self.ckpt_path, max_to_keep=4)

        # summary writer
        train_log_path = os.path.join(self.tensorboard_path, 'train')
        self.train_summary_writer = tf.summary.create_file_writer(train_log_path)
        print(f'******* Train result log will be written to {train_log_path} ******')
        test_log_path = os.path.join(self.tensorboard_path, 'test')
        self.test_summary_writer = tf.summary.create_file_writer(test_log_path)
        print(f'******* Test result log will be written to {test_log_path} ******')

    @staticmethod
    def _load_jpeg(path):
        """Read the file at ``path`` and decode it as a 3-channel JPEG image."""
        return tf.image.decode_jpeg(tf.io.read_file(path), channels=3)

    def _train_pipeline(self, dataset, map_fn):
        """Apply ``map_fn``, then repeat, batch by ``self.batch_size`` and prefetch."""
        dataset = dataset.map(map_fn, num_parallel_calls=AUTOTUNE)
        dataset = dataset.repeat()
        dataset = dataset.batch(self.batch_size)
        return dataset.prefetch(AUTOTUNE)

    def train(self):
        """Run the adversarial training loop for ``self.iterations`` steps.

        Every step:
          1. pulls one ``(real_A, real_B)`` batch from ``self.train_iterator``;
          2. updates the four discriminators (``disGA``, ``disGB``, ``disLA``,
             ``disLB``) with ``self.disc_opt``;
          3. updates both generators (``genA2B``, ``genB2A``) with ``self.gen_opt``;
          4. writes both total losses to ``self.train_summary_writer``.

        Additionally, every ``self.print_freq`` steps a progress line is printed,
        every ``self.sample_freq`` steps an A->B translation of a batch from
        ``self.test_iterator`` is passed to ``image_save``, and every
        ``self.save_freq`` steps a checkpoint is written through
        ``self.ckpt_manager``.

        Side effects: sets ``self.disc_tot_vars`` and ``self.gen_tot_vars`` on the
        first step (after the first forward pass of each model).
        """

        def real_fake_loss(real_logit, fake_logit):
            # Discriminator term: real logits are scored against an all-ones
            # target, fake logits against an all-zeros target.
            return (ad_loss(real_logit, tf.ones_like(real_logit, dtype=tf.float32))
                    + ad_loss(fake_logit, tf.zeros_like(fake_logit, dtype=tf.float32)))

        def fool_loss(fake_logit):
            # Generator term: fake logits are scored against an all-ones target.
            return ad_loss(fake_logit, tf.ones_like(fake_logit, dtype=tf.float32))

        print('training start !')

        # One pass of this loop consumes one training batch.
        # (`step` instead of `iter` so the builtin `iter` is not shadowed.)
        for step in range(1, self.iterations + 1):

            start_time = time.time()

            real_A, real_B = self.train_iterator.get_next()

            # ---------------- discriminator update ----------------
            # The fakes are generated outside the tape: only discriminator
            # variables are differentiated below, so recording the generator
            # ops on the tape would just cost memory.
            fake_A2B, _, _ = self.genA2B(real_A)
            fake_B2A, _, _ = self.genB2A(real_B)

            with tf.GradientTape() as disc_tape:

                real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A)
                real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A)
                real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B)
                real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B)

                fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
                fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
                fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
                fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

                D_ad_loss_GA = real_fake_loss(real_GA_logit, fake_GA_logit)
                D_ad_cam_loss_GA = real_fake_loss(real_GA_cam_logit, fake_GA_cam_logit)
                D_ad_loss_LA = real_fake_loss(real_LA_logit, fake_LA_logit)
                D_ad_cam_loss_LA = real_fake_loss(real_LA_cam_logit, fake_LA_cam_logit)
                D_ad_loss_GB = real_fake_loss(real_GB_logit, fake_GB_logit)
                D_ad_cam_loss_GB = real_fake_loss(real_GB_cam_logit, fake_GB_cam_logit)
                D_ad_loss_LB = real_fake_loss(real_LB_logit, fake_LB_logit)
                D_ad_cam_loss_LB = real_fake_loss(real_LB_cam_logit, fake_LB_cam_logit)

                D_loss_A = self.adv_weight * (D_ad_loss_GA + D_ad_cam_loss_GA + D_ad_loss_LA + D_ad_cam_loss_LA)
                D_loss_B = self.adv_weight * (D_ad_loss_GB + D_ad_cam_loss_GB + D_ad_loss_LB + D_ad_cam_loss_LB)
                # `.losses` is read inside the tape so these terms take part in the gradient.
                D_reg_loss = (tf.reduce_sum(self.disGA.losses) + tf.reduce_sum(self.disLA.losses)
                              + tf.reduce_sum(self.disGB.losses) + tf.reduce_sum(self.disLB.losses))
                D_loss = D_loss_A + D_loss_B + D_reg_loss

            if step == 1:
                # Collected only after the first forward pass above, i.e. once
                # every discriminator has been called at least once.
                self.disc_tot_vars = []
                for disc_model in [self.disGA, self.disGB, self.disLA, self.disLB]:
                    self.disc_tot_vars.extend(disc_model.trainable_variables)

            disc_grad = disc_tape.gradient(D_loss, self.disc_tot_vars)
            self.disc_opt.apply_gradients(zip(disc_grad, self.disc_tot_vars))

            # ------------------ generator update ------------------
            with tf.GradientTape() as gen_tape:

                # Translations A->B and B->A.
                fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
                fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)

                # Cycle: A->B->A and B->A->B.
                fake_A2B2A, _, _ = self.genB2A(fake_A2B)
                fake_B2A2B, _, _ = self.genA2B(fake_B2A)

                # Identity: each generator applied to an image of its target domain.
                fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
                fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)

                fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
                fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
                fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
                fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

                G_ad_loss_GA = fool_loss(fake_GA_logit)
                G_ad_cam_loss_GA = fool_loss(fake_GA_cam_logit)
                G_ad_loss_LA = fool_loss(fake_LA_logit)
                G_ad_cam_loss_LA = fool_loss(fake_LA_cam_logit)
                G_ad_loss_GB = fool_loss(fake_GB_logit)
                G_ad_cam_loss_GB = fool_loss(fake_GB_cam_logit)
                G_ad_loss_LB = fool_loss(fake_LB_logit)
                G_ad_cam_loss_LB = fool_loss(fake_LB_cam_logit)

                G_recon_loss_A = recon_loss(fake_A2B2A, real_A)
                G_recon_loss_B = recon_loss(fake_B2A2B, real_B)

                G_identity_loss_A = id_loss(fake_A2A, real_A)
                G_identity_loss_B = id_loss(fake_B2B, real_B)

                # CAM logits: translated inputs are paired with ones, identity inputs with zeros.
                G_cam_loss_A = (bce_loss(fake_B2A_cam_logit, tf.ones_like(fake_B2A_cam_logit))
                                + bce_loss(fake_A2A_cam_logit, tf.zeros_like(fake_A2A_cam_logit)))
                G_cam_loss_B = (bce_loss(fake_A2B_cam_logit, tf.ones_like(fake_A2B_cam_logit))
                                + bce_loss(fake_B2B_cam_logit, tf.zeros_like(fake_B2B_cam_logit)))

                G_loss_A = (self.adv_weight * (G_ad_loss_GA + G_ad_cam_loss_GA + G_ad_loss_LA + G_ad_cam_loss_LA)
                            + self.cycle_weight * G_recon_loss_A
                            + self.identity_weight * G_identity_loss_A
                            + self.cam_weight * G_cam_loss_A)
                G_loss_B = (self.adv_weight * (G_ad_loss_GB + G_ad_cam_loss_GB + G_ad_loss_LB + G_ad_cam_loss_LB)
                            + self.cycle_weight * G_recon_loss_B
                            + self.identity_weight * G_identity_loss_B
                            + self.cam_weight * G_cam_loss_B)
                G_reg_loss = tf.reduce_sum(self.genA2B.losses) + tf.reduce_sum(self.genB2A.losses)
                G_loss = G_loss_A + G_loss_B + G_reg_loss

            if step == 1:
                self.gen_tot_vars = self.genA2B.trainable_variables + self.genB2A.trainable_variables

            gen_grad = gen_tape.gradient(G_loss, self.gen_tot_vars)
            self.gen_opt.apply_gradients(zip(gen_grad, self.gen_tot_vars))

            # ---------------- logging / sampling / checkpoint ----------------
            with self.train_summary_writer.as_default():
                tf.summary.scalar('Discriminator_loss', D_loss, step=step)
                tf.summary.scalar('Generator_loss', G_loss, step=step)

            if step % self.print_freq == 0:
                print("[%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" % (step, self.iterations, time.time() - start_time, D_loss, G_loss))

            if step % self.sample_freq == 0:  # Only A->B
                # Separate names so the training batch variables are not overwritten.
                sample_A = self.test_iterator.get_next()
                sample_A2B, _, _ = self.genA2B(sample_A)
                image_save(sample_A, sample_A2B, self.img_save_path, step)

            if step % self.save_freq == 0:
                self.ckpt_manager.save(checkpoint_number=step)
                print(f'******* {str(step)} checkpoint saved to {self.ckpt_path} ******')
 

SyntaxError: ignored

In [None]:
import os
# List the entries of the current working directory (no path argument is given).
os.listdir()

In [None]:
# Create a UGATIT instance with no constructor arguments; later cells call methods on `tmp`.
tmp = UGATIT()

In [None]:
# Call build() on the instance before training.
# NOTE(review): presumably this creates the models, optimizers, checkpoint manager and summary writers that train() reads — confirm.
tmp.build()



*   항목 추가
*   항목 추가



In [None]:
# Start the training loop (runs for `iterations` steps and prints/saves according to the configured frequencies).
tmp.train()