In [1]:
import time

import numpy as np
import tensorflow as tf

# define Upsample block 

In [11]:
# Record which TensorFlow version this notebook was executed with.
print(tf.__version__)

2.0.0-alpha0


In [12]:
class UpSampleLayer(tf.keras.Model):
    """One ``Conv2DTranspose`` (upsampling) layer with overridable defaults.

    Args:
        filter: Number of output channels of the transposed convolution.
        **kwargs: Optional overrides for the convolution settings. Accepted keys
            and their defaults: ``kernel=4``, ``strides=2``, ``padding='same'``,
            ``activation='relu'``.

    Raises:
        TypeError: If ``kwargs`` contains any other key. Unknown keys used to be
            stored and then silently ignored, so a typo such as ``kernal=7``
            had no effect.
    """

    # Default convolution settings; each may be overridden through **kwargs.
    _CONV_DEFAULTS = {'kernel': 4, 'strides': 2, 'padding': 'same', 'activation': 'relu'}

    def __init__(self, filter, **kwargs):
        super(UpSampleLayer, self).__init__()
        unknown = sorted(set(kwargs) - set(self._CONV_DEFAULTS))
        if unknown:
            raise TypeError(f'UpSampleLayer got unexpected conv option(s): {unknown}')

        # Effective settings (overrides win over defaults); kept for debugging.
        self.conv_dict = {'filter': filter, **self._CONV_DEFAULTS, **kwargs}

        self.conv = tf.keras.layers.Conv2DTranspose(
            self.conv_dict['filter'],
            kernel_size=self.conv_dict['kernel'],
            strides=self.conv_dict['strides'],
            padding=self.conv_dict['padding'],
            activation=self.conv_dict['activation'])

    def call(self, x):
        """Apply the transposed convolution to ``x``."""
        return self.conv(x)

In [5]:
class UpSampleBlock(tf.keras.Model):
    """Three chained ``UpSampleLayer``s: 128 -> 64 -> 3 channels.

    The first two layers use 4x4 kernels with stride 2 (each doubles H and W);
    the last uses a 7x7 kernel with stride 1 to map to 3 channels. All layers
    use 'same' padding and ReLU.
    """

    def __init__(self):
        super(UpSampleBlock, self).__init__()
        shared = dict(padding='same', activation='relu')
        self.cv1 = UpSampleLayer(128, kernel=4, strides=2, **shared)
        self.cv2 = UpSampleLayer(64, kernel=4, strides=2, **shared)
        self.cv3 = UpSampleLayer(3, kernel=7, strides=1, **shared)

    def call(self, x):
        result = x
        for stage in (self.cv1, self.cv2, self.cv3):
            result = stage(result)
        return result

# Test upsample block

In [6]:
ub = UpSampleBlock()

# Synthetic data for shape checks. A fixed seed keeps the cell reproducible.
SEED = 0
rng = np.random.RandomState(SEED)
data_size = 10
img_width = 128
img_height = 128
img_channel = 3

# Channels-last layout: (batch, height, width, channels). Cast to float32 to
# match the Keras layers' default dtype instead of relying on implicit casting.
x_data = rng.normal(size=[data_size, img_height, img_width, img_channel]).astype(np.float32)
# NOTE(review): random per-sample attributes, not used by the cells that follow yet.
hair_color = rng.uniform(low=0, high=3, size=[data_size])
gender = rng.uniform(low=0, high=1, size=[data_size])
old = rng.uniform(low=0, high=1, size=[data_size])

train_dataset = tf.data.Dataset.from_tensor_slices(x_data)
train_dataset = train_dataset.shuffle(2, seed=SEED).batch(2)

# Push one batch through the upsample block; the two stride-2 layers double H and W twice.
for img in train_dataset.take(1):
    print(f'chekck input must be 4 dim :{img.shape}')
    x = ub(img)
    print(f'x.shape:{x.shape}')

chekck input must be 4 dim :(2, 128, 128, 3)
x.shape:(2, 512, 512, 3)


# define Down-sample block 

In [7]:
class DownSampleLayer(tf.keras.Model):
    """One ``Conv2D`` layer with overridable defaults.

    Args:
        filter: Number of output channels of the convolution.
        **kwargs: Optional overrides for the convolution settings. Accepted keys
            and their defaults: ``kernel=4``, ``strides=2``, ``padding='same'``,
            ``activation='relu'``. Pass ``activation=None`` for a linear output.

    Raises:
        TypeError: If ``kwargs`` contains any other key. Unknown keys used to be
            stored and then silently ignored, so a typo had no effect.
    """

    # Default convolution settings; each may be overridden through **kwargs.
    _CONV_DEFAULTS = {'kernel': 4, 'strides': 2, 'padding': 'same', 'activation': 'relu'}

    def __init__(self, filter, **kwargs):
        super(DownSampleLayer, self).__init__()
        unknown = sorted(set(kwargs) - set(self._CONV_DEFAULTS))
        if unknown:
            raise TypeError(f'DownSampleLayer got unexpected conv option(s): {unknown}')

        # Effective settings (overrides win over defaults); kept for debugging.
        self.conv_dict = {'filter': filter, **self._CONV_DEFAULTS, **kwargs}

        self.conv = tf.keras.layers.Conv2D(
            self.conv_dict['filter'],
            kernel_size=self.conv_dict['kernel'],
            strides=self.conv_dict['strides'],
            padding=self.conv_dict['padding'],
            activation=self.conv_dict['activation'])

    def call(self, x):
        """Apply the convolution to ``x``."""
        return self.conv(x)

In [6]:
class DownSampleBlock(tf.keras.Model):
    """Three chained ``DownSampleLayer``s: 64 -> 128 -> 256 channels.

    A 7x7 stride-1 conv is followed by two 4x4 stride-2 convs, so H and W
    shrink by 4x overall.
    """

    def __init__(self):
        super(DownSampleBlock, self).__init__()
        # Padding and activation come from DownSampleLayer's defaults ('same', 'relu').
        self.cv1 = DownSampleLayer(64, strides=1, kernel=7)
        self.cv2 = DownSampleLayer(128, strides=2, kernel=4)
        self.cv3 = DownSampleLayer(256, strides=2, kernel=4)

    def call(self, x):
        for stage in (self.cv1, self.cv2, self.cv3):
            x = stage(x)
        return x

# Test downsample block

In [7]:
# Smoke test: push one batch of the synthetic dataset through the downsample block.
db = DownSampleBlock()

for img in train_dataset.take(1):
    print(f'chekck input must be 4 dim :{img.shape}')
    x = db(img)
    # Two stride-2 layers shrink H and W by 4x; the last layer has 256 filters.
    print(f'x.shape:{x.shape}')

chekck input must be 4 dim :(2, 128, 128, 3)
x.shape:(2, 32, 32, 256)


In [15]:
class HiddenBlock(tf.keras.Model):
    """Five stacked 4x4, stride-2 convolutions with leaky-ReLU (128 -> 2048 filters).

    Args:
        padding: Padding for every convolution. Defaults to ``'same'``, so each
            layer halves the spatial size (e.g. 64x64 -> 2x2). With the previous
            hard-coded ``'valid'``, a 128x128 image fed through ``DiscSrc``
            shrank to 2x2 before the last layer. The 4x4 kernel then no longer
            fit and a negative-dimension error was raised.
    """

    def __init__(self, padding='same'):
        super(HiddenBlock, self).__init__()
        # The activation is passed as a callable because 'LeakyReLU' is not a
        # registered Keras activation name in TF 2.0 and fails to resolve.
        # NOTE(review): tf.nn.leaky_relu uses alpha=0.2 - confirm the intended slope.
        conv_kwargs = dict(kernel=4, strides=2, padding=padding, activation=tf.nn.leaky_relu)
        self.hidden1 = DownSampleLayer(128, **conv_kwargs)
        self.hidden2 = DownSampleLayer(256, **conv_kwargs)
        self.hidden3 = DownSampleLayer(512, **conv_kwargs)
        self.hidden4 = DownSampleLayer(1024, **conv_kwargs)
        self.hidden5 = DownSampleLayer(2048, **conv_kwargs)

    def call(self, input):
        """Run ``input`` through hidden1..hidden5 in order."""
        output = input
        for layer in (self.hidden1, self.hidden2, self.hidden3, self.hidden4, self.hidden5):
            output = layer(output)
        return output

In [16]:
class InstanceNormalization(tf.keras.layers.Layer):
    """Instance normalization for rank-4 (image) tensors only.

    Each sample is normalized independently over axes 1 and 2, with separate
    statistics for every channel (last axis). The result is then scaled by a
    learned per-channel ``gamma`` (initialised to 1) and shifted by a learned
    per-channel ``beta`` (initialised to 0).

    Args:
        epsilon: Added to the variance before the square root for numerical stability.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, epsilon=1e-5, **kwargs):
        super(InstanceNormalization, self).__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        # One scale / offset per channel (the last axis), as an explicit rank-1 shape.
        param_shape = tf.TensorShape(input_shape)[-1:]
        self.gamma = self.add_weight(name='gamma',
                                     shape=param_shape,
                                     initializer='ones',
                                     trainable=True)
        self.beta = self.add_weight(name='beta',
                                    shape=param_shape,
                                    initializer='zeros',
                                    trainable=True)
        # The base build() must run last so the layer is marked as built.
        super(InstanceNormalization, self).build(input_shape)

    @tf.function
    def call(self, inputs):
        # Reduce over the spatial axes only, so statistics are per sample and per channel.
        # TF 2.x tf.nn.moments takes `keepdims`; the TF1 spelling `keep_dims` raises TypeError.
        mean, variance = tf.nn.moments(inputs, axes=[1, 2], keepdims=True)
        normalized = (inputs - mean) / tf.sqrt(variance + self.epsilon)
        return self.gamma * normalized + self.beta


In [19]:
class ResidualBlock(tf.keras.Model):
    """Two size-preserving convolutions with instance norm and an identity skip.

    Computes ``relu(input + IN(conv2(IN(relu(conv1(input))))))``, where IN is
    ``InstanceNormalization`` (stored as ``bn1`` / ``bn2``).

    Args:
        filter: Channels of both convolutions. This must equal the channel count
            of the block's input, otherwise ``input + output`` cannot be added.
        kernel: Kernel size of both convolutions.
    """

    def __init__(self, filter, kernel):
        super(ResidualBlock, self).__init__()
        # strides=1 keeps H and W unchanged. DownSampleLayer defaults to strides=2,
        # which made the skip-connection addition fail on mismatched shapes.
        self.layer1 = DownSampleLayer(filter, kernel=kernel, strides=1, activation='relu')
        self.bn1 = InstanceNormalization()
        # activation=None gives a linear conv; the string 'none' is not a valid Keras activation.
        self.layer2 = DownSampleLayer(filter, kernel=kernel, strides=1, activation=None)
        self.bn2 = InstanceNormalization()

    @tf.function
    def call(self, input):
        output = self.layer1(input)
        output = self.bn1(output)
        output = self.layer2(output)
        output = self.bn2(output)
        # `ReLU` was an undefined name here; tf.nn.relu is the functional equivalent.
        return tf.nn.relu(input + output)

In [18]:
class DiscSrc(tf.keras.Model):
    """Discriminator head that scores whether an image is real or generated.

    Layout:
    - 4x4 stride-2 conv (64 filters, leaky-ReLU),
    - then ``HiddenBlock``,
    - then a 3x3 stride-1 conv to a single channel.

    The output is a map of raw logits with no activation. ``adverserial_loss``
    feeds these logits into a sigmoid cross-entropy.
    """

    def __init__(self):
        super(DiscSrc, self).__init__()
        # 'same' padding halves the size exactly (128 -> 64). 'valid' gave 63, and
        # together with later layers shrank the map below the kernel size.
        # The activation is a callable because 'LeakyReLU' is not a TF 2.0 activation name.
        self.cv1 = DownSampleLayer(64, kernel=4, strides=2, padding='same', activation=tf.nn.leaky_relu)
        self.hidden = HiddenBlock()
        # activation=None: DownSampleLayer defaults to 'relu', which clamps the logits
        # to >= 0. sigmoid(logit) could then never drop below 0.5 (i.e. never say "fake").
        self.fc = DownSampleLayer(1, kernel=3, strides=1, padding='same', activation=None)

    @tf.function
    def call(self, input):
        output = self.cv1(input)
        output = self.hidden(output)
        output = self.fc(output)
        return output

In [20]:
class DiscCls(tf.keras.Model):
    """Discriminator head that predicts domain (attribute) logits for an image.

    Args:
        unit: Output channels of the final conv (presumably one per domain - confirm).
        kernel: Kernel size of the final stride-1 conv.
    """

    def __init__(self, unit, kernel):
        super(DiscCls, self).__init__()
        # The activation is a callable because 'LeakyReLU' is not a TF 2.0 activation name.
        self.cv1 = DownSampleLayer(64, kernel=4, strides=2, padding='same', activation=tf.nn.leaky_relu)
        self.hidden = HiddenBlock()
        # activation=None: domain_cls_loss applies a sigmoid cross-entropy to raw
        # logits, and the layer default 'relu' would clamp them at 0.
        # NOTE(review): with padding='same' the output keeps the spatial size of the
        # HiddenBlock output instead of collapsing to 1x1. If one logit vector per image
        # is intended, use padding='valid' with kernel equal to that spatial size.
        self.fc = DownSampleLayer(unit, kernel=kernel, strides=1, padding='same', activation=None)

    @tf.function
    def call(self, input):
        output = self.cv1(input)
        output = self.hidden(output)
        output = self.fc(output)
        return output

In [21]:
class Generator(tf.keras.Model):
    """Encoder -> residual bottleneck -> decoder image-to-image generator.

    Example for a 128x128 input (every conv uses 'same' padding):
    - ds1 (7x7, stride 1, 64): 128x128
    - ds2 (4x4, stride 2, 128): 64x64
    - ds3 (4x4, stride 2, 256): 32x32
    - ``num_residual`` ResidualBlocks (256 channels): 32x32
    - us1 (128): 64x64
    - us2 (64): 128x128
    - 7x7 conv to 3 channels with tanh.

    The output therefore has the same height and width as the input, as
    ``reconstruction_loss(input_image, reconstructed_image)`` requires.

    Args:
        unit: Accepted for backward compatibility; currently unused.
        kernel: Accepted for backward compatibility; currently unused.
            NOTE(review): confirm what ``unit`` and ``kernel`` were meant to configure.
        num_residual: Number of residual blocks in the bottleneck.
    """

    def __init__(self, unit, kernel, num_residual=6):
        super(Generator, self).__init__()
        # Fixes relative to the original version:
        # - 'relu' (lowercase) replaces 'ReLU', which is not a registered TF 2.0 activation name.
        # - 'same' padding replaces 'valid', which made the output smaller than the input.
        self.ds1 = DownSampleLayer(64, kernel=7, strides=1, padding='same', activation='relu')
        self.ds2 = DownSampleLayer(128, kernel=4, strides=2, padding='same', activation='relu')
        self.ds3 = DownSampleLayer(256, kernel=4, strides=2, padding='same', activation='relu')
        # Size-preserving bottleneck. HiddenBlock kept halving the map down to 1x1,
        # so the two upsampling layers could never restore the input size.
        self.res_blocks = [ResidualBlock(256, kernel=3) for _ in range(num_residual)]
        # UpSampleLayer (a single layer) is the class that accepts these arguments;
        # UpSampleBlock.__init__ takes none and raised TypeError.
        self.us1 = UpSampleLayer(128, kernel=4, strides=2, padding='same', activation='relu')
        self.us2 = UpSampleLayer(64, kernel=4, strides=2, padding='same', activation='relu')
        self.conv = DownSampleLayer(3, kernel=7, strides=1, padding='same', activation='tanh')

    @staticmethod
    def _append_domain(images, domain):
        """Tile a (batch, num_domains) label over H x W and append it as extra channels."""
        domain = tf.cast(domain, images.dtype)
        domain = tf.reshape(domain, [-1, 1, 1, domain.shape[-1]])
        domain = tf.tile(domain, [1, tf.shape(images)[1], tf.shape(images)[2], 1])
        return tf.concat([images, domain], axis=-1)

    @tf.function
    def call(self, input, domain=None):
        """Translate ``input``, optionally conditioned on a target ``domain`` label.

        When ``domain`` is given it is concatenated to the image channels, matching
        how ``train_step`` calls ``generator(image, domain)``. Note that ds1 is built
        on its first call, so use the same convention (with or without ``domain``)
        on every call.
        """
        output = input if domain is None else self._append_domain(input, domain)
        output = self.ds1(output)
        output = self.ds2(output)
        output = self.ds3(output)
        for block in self.res_blocks:
            output = block(output)
        output = self.us1(output)
        output = self.us2(output)
        return self.conv(output)
    

In [22]:
# NOTE(review): leftover debug cell; it only prints a marker string and can be removed.
print('test')

test


In [None]:

# Loss-term weights passed to G_loss / D_loss: reconstruction (cycle) and domain classification.
lambda_rec, lambda_cls = 10, 1


def adverserial_loss(logits, real=True):
  """Mean sigmoid cross-entropy of `logits` against an all-real or all-fake target.

  Args:
    logits: Raw (pre-sigmoid) discriminator scores.
    real: If true, the target is 1 everywhere ("real"); otherwise 0 ("fake").

  Returns:
    Scalar loss tensor.
  """
  # tf.losses.sigmoid_cross_entropy is TF1-only and missing from TF 2.x; this is the TF2 equivalent.
  labels = tf.ones_like(logits) if real else tf.zeros_like(logits)
  return tf.reduce_mean(
      tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits))

def reconstruction_loss(image, rec_image):
  """Unweighted L1 reconstruction loss: mean absolute pixel difference.

  The caller (G_loss) applies `lambda_rec`. The previous version also applied it
  here, so the term was weighted twice. It also took np.abs of the *mean*
  difference (which lets positive and negative errors cancel) and converted the
  tensor through NumPy, which cut the gradient.
  """
  return tf.reduce_mean(tf.abs(image - rec_image))

def domain_cls_loss(domain, logits):
  """Unweighted multi-label domain classification loss (mean sigmoid cross-entropy).

  Args:
    domain: Target labels with the same shape as `logits`; cast to the logits dtype.
    logits: Raw domain-classifier outputs.

  The callers (G_loss / D_loss) apply `lambda_cls`. The previous version also
  applied it here, so the term was weighted twice.
  """
  labels = tf.cast(domain, logits.dtype)
  return tf.reduce_mean(
      tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits))

def G_loss(fake_D_src, target_D_cls, target_domain,input_image, reconstructed_image, lambda_cls,lambda_rec):
  """Generator objective: fool the discriminator + hit the target domain + reconstruct the input.

  Returns adv + lambda_cls * cls + lambda_rec * rec, where the fake images are
  scored against the "real" target.
  """
  adv_term = adverserial_loss(fake_D_src, real=True)
  cls_term = domain_cls_loss(target_domain, target_D_cls)
  rec_term = reconstruction_loss(input_image, reconstructed_image)
  return adv_term + lambda_cls * cls_term + lambda_rec * rec_term

def D_loss(real_D_src, fake_D_src, original_domain, original_D_cls, lambda_cls):
  """Discriminator objective, minimised by the discriminator optimizer.

  The discriminator should classify real images as real and generated images as
  fake, and should predict the real image's original domain.

  The adversarial cross-entropies are already losses to minimise. The previous
  `-1 *` made the discriminator *maximise* its own classification error.
  """
  adv_term = adverserial_loss(real_D_src, real=True) + adverserial_loss(fake_D_src, real=False)
  cls_term = domain_cls_loss(original_domain, original_D_cls)
  return adv_term + lambda_cls * cls_term

# Separate Adam optimizers (learning rate 1e-4) for the generator and the discriminator.
# tf.train.AdamOptimizer is TF1-only and absent from the TF 2.x namespace.
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

def train_step(input_image, original_domain, target_domain):
  """Run one joint generator/discriminator update.

  Relies on module-level `generator` and `discriminator` objects, which are not
  defined in the visible cells (TODO). `discriminator(images)` is expected to
  return a `(src_logits, cls_logits)` pair.

  Returns:
    (generator_loss, discriminator_loss) for logging.
  """
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    # Translate to the target domain, then back to the original domain (cycle).
    fake_image = generator(input_image, target_domain)            # step (b)
    reconstructed_image = generator(fake_image, original_domain)  # step (c)

    # Score the real image (step (a)) and the fake image (step (d)). The fake image
    # is also used to train the auxiliary domain classifier for the generator.
    real_D_src, original_D_cls = discriminator(input_image)
    fake_D_src, target_D_cls = discriminator(fake_image)

    generator_loss = G_loss(fake_D_src, target_D_cls, target_domain, input_image,
                            reconstructed_image, lambda_cls, lambda_rec)
    discriminator_loss = D_loss(real_D_src, fake_D_src, original_domain,
                                original_D_cls, lambda_cls)

  # Take gradients after leaving the tape contexts so the gradient
  # computation itself is not recorded.
  gradients_of_generator = gen_tape.gradient(generator_loss, generator.trainable_variables)
  gradients_of_discriminator = disc_tape.gradient(discriminator_loss, discriminator.trainable_variables)

  generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
  discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
  return generator_loss, discriminator_loss

def train(train_dataset, epochs):
  """Train for `epochs` passes over `train_dataset` and print each epoch's wall time.

  `train_dataset` must yield `(input_image, original_domain)` pairs. The synthetic
  dataset built earlier yields images only - TODO confirm. `random_target_domain_generation`
  is not defined in the visible cells (TODO).
  """
  for epoch in range(epochs):
    start = time.time()

    for input_image, original_domain in train_dataset:
      target_domain = random_target_domain_generation()
      train_step(input_image, original_domain, target_domain)

    print(f'epoch {epoch + 1}/{epochs} took {time.time() - start:.1f}s')
          



