In [7]:
# !pip install tensorflow-gpu=2.0.0

In [1]:
import tensorflow as tf
import numpy as np
from PIL import Image
tf.__version__

'2.0.0'

In [2]:
BUFFER_SIZE = 60000
BATCH_SIZE = 25
IMG_SIZE = (128,128,3)
EPOCHS = 100

In [3]:
import pathlib
import csv
import json
ROOT = pathlib.Path("./Dataset/Dataset-CelebA")
DATASET_ROOT = ROOT/"CelebA"
DATA_ROOT = DATASET_ROOT/"Img"/"img_align_celeba"
DATA_PATH = str(DATA_ROOT)
EVAL_PATH = str(DATASET_ROOT/"Eval"/"list_eval_partition.txt")
ANNO_PATH = str(DATASET_ROOT/"Anno"/"list_attr_celeba.txt")
LANDMARK_PATH = str(ROOT/"all_landmark.json")
SEL_ATTRS = ["Goatee"]

In [4]:
from tensorflow.keras.layers import *
from tensorflow.keras import Model

In [5]:
class autoencoder_model(Model):
    def __init__(self, num_filters,outputChannels,img_size,data_format):
        super(autoencoder_model, self).__init__()
        # channels_last
        
        # formula = (W-Kernel+2P)/S + 1
        self.ZeroPadding2D1 = ZeroPadding2D(3)
        self.Conv1 = Conv2D(filters=num_filters,kernel_size=7,strides=1,
                            input_shape=img_size,data_format=data_format)
        self.norm1 = BatchNormalization()
        self.Conv2 = Conv2D(filters=num_filters*2,kernel_size=3,strides=2,padding='same')
        self.norm2 = BatchNormalization()
        # 64x64
        self.Conv3 = Conv2D(filters=num_filters*4,kernel_size=3,strides=2,padding='same')
        self.norm3 = BatchNormalization()
        # 32x32
        self.Conv4 = Conv2D(filters=num_filters*8,kernel_size=3,strides=2,padding='same')
        self.norm4 = BatchNormalization()
        # 16x16
        
        # formula = (W-1)*S-2P+Kernel
        self.Dcon1 = Conv2DTranspose(filters=num_filters*4,kernel_size=4,strides=2,padding='same',
                                    input_shape=(16,16,num_filters*8))
        self.normd1 = BatchNormalization()
        # 32x32
        self.Dcon2 = Conv2DTranspose(filters=num_filters*2,kernel_size=4,strides=2,padding='same')
        self.normd2 = BatchNormalization()
        # 64x64
        self.Dcon3 = Conv2DTranspose(filters=num_filters,kernel_size=4,strides=2,padding='same')
        self.normd3 = BatchNormalization()
        # 128x128
        # need self.ZeroPadding2D = ZeroPadding2D(3)
        self.ZeroPadding2D2 = ZeroPadding2D(3)
        self.Dcon4 = Conv2D(filters=outputChannels,kernel_size=7,strides=1)
        
        self.Conv1_out = None
        self.Conv2_out = None
        self.Conv3_out = None
        
    def call(self, inputTensor):
        # Whole Network
        EncoderOutput = self.Encoder(inputTensor)
        DecoderOutput = self.Decoder(EncoderOutput)
        return DecoderOutput
    def Encoder(self, inputTensor):
        # a half of network 
        inputTensor = self.ZeroPadding2D1(inputTensor)
        Conv1_out = self.Conv1(inputTensor)
        Conv1_out = self.norm1(Conv1_out)
        self.Conv1_out = activations.relu(Conv1_out)
        
        Conv2_out = self.Conv2(Conv1_out)
        Conv2_out = self.norm2(Conv2_out)
        self.Conv2_out = activations.relu(Conv2_out)
        
        Conv3_out = self.Conv3(Conv2_out)
        Conv3_out = self.norm3(Conv3_out)
        self.Conv3_out = activations.relu(Conv3_out)
        
        Conv4_out = self.Conv4(Conv3_out)
        Conv4_out = self.norm1(Conv4_out)
        Conv4_out = activations.relu(Conv4_out)
        
        return Conv4_out
    def Decoder(self, inputTensor):
        # a half of network 
        Dcon1_out = self.Dcon1(inputTensor)
        Dcon1_out += self.Conv3_out
        Dcon1_out = self.normd1(Dcon1_out)
        Dcon1_out = activations.relu(Dcon1_out)
        
        Dcon2_out = self.Dcon2(Dcon1_out)
        Dcon2_out += self.Conv2_out
        Dcon2_out = self.normd2(Dcon2_out)
        Dcon2_out = activations.relu(Dcon2_out)
        
        Dcon3_out = self.Dcon3(Dcon2_out)
        Dcon3_out += self.Conv1_out
        Dcon3_out = self.normd3(Dcon3_out)
        Dcon3_out = activations.relu(Dcon3_out)
        
        Dcon4_out = self.ZeroPadding2D2(Dcon3_out)
        Dcon4_out = self.Dcon4(Dcon4_out)
        return Dcon4_out

$L_{flow} = L_{lm}+L_{tv}, L_{adv}, L_{cls}, L_{recon}$

In [6]:
from tensorflow.keras.losses import *
from tensorflow.keras import activations

In [7]:
def flow_loss():
    mae = MeanAbsoluteError()
    def landmark_loss(ori_a,ori_b,wrap_a):
        # (wrap_a.x+ori_b.x-ori_a.x)^2 + (wrap_a.y+ori_b.y-ori_a.y)^2
        distance = tf.math.square(wrap_a.x+ori_b.x-ori_a.x) +\
        tf.math.square(wrap_a.y+ori_b.y-ori_a.y)
        summation = tf.math.reduce_sum(summation)
        return summation
    def totalVariation_loss(flow_x,flow_y):
        loss_x = tf.image.total_variation(flow_x)
        loss_y = tf.image.total_variation(flow_y)
        return loss_x + loss_y
    return landmark_loss
def lsgan_loss():
    mse = MeanSquaredError()
    def discriminator_loss(real_output, fake_output):
        real_loss = mse(tf.ones_like(real_output),real_output)
        fake_loss = mse(tf.zeros_like(fake_output),fake_output)
        return real_loss + fake_loss
    def generator_loss(fake_output):
        return mse(tf.ones_like(fake_output),fake_output)
    return discriminator_loss, generator_loss
def cls_loss(real_output,fake_output):
    bce = BinaryCrossentropy()
    real_loss = bce(tf.ones_like(real_output),real_output)
    fake_loss = bce(tf.zeros_like(fake_output),fake_output)
    return real_loss + fake_loss
def recon_loss(a,b):
    mae = MeanAbsoluteError()
    return mae(a,b)

In [8]:
class generator_model(Model):
    def __init__(self, img_size):
        super(generator_model, self).__init__()
        # autoencoder_model(num_filters,outputChannels,img_size,data_format)
        self.flowNet = autoencoder_model(32,2,img_size,"channels_last")
        self.maskNet = autoencoder_model(32,1,img_size,"channels_last")
        grid_np = np.mgrid[-1:1 + 1e-7 :2 / (img_size[0] - 1), -1:1 + 1e-7:2 / (img_size[1] - 1)]
        self.flow_grid_np = tf.convert_to_tensor(grid_np.reshape((128, 128, 2)))
        self.refinementNet = autoencoder_model(32,3,img_size,"channels_last")
        self.removeResidualNet = autoencoder_model(32,3,img_size,"channels_last")
    # x means be without attribute
    # y means be with attribute
    def call(self, Ax, By):
        # Figure 2 on Paper
        
        # Level 2 in Figure 2
        fake_Ay, Ax_flow, By_flow,\
        By_mask, fake_Ay_wo_refine, By_warpped = self.addAttribute(Ax, By, epoch)
        
        fake_Bx = self.removeAttribute(By, By_mask)
        
        # Level 3 in Figure 2
        fake_By, fake_Bx_flow, fake_Ay_flow,\
        Ax_mask, _           , _          = self.addAttribute(fake_Bx, fake_Ay, epoch)
        fake_Ax = self.removeAttribute(fake_Ay, Ax_mask)
        
        return_items = {}
        return_items['fake_Ay'] = fake_Ay
        return_items['fake_Ax'] = fake_Ax
        return_items['fake_By'] = fake_By
        return_items['fake_Bx'] = fake_Bx
        return_items['flows'] = [Ax_flow, By_flow, fake_Bx_flow, fake_Ay_flow]
        return_items['masks'] = [Ax_mask, By_mask]
        return return_items
        
    # Main Module
    def addAttribute(self, Ax, By, epoch):
        # ------------Flow Sub-Network------------
        # ---------Flow Net---------
        Ax_front = self.flowNet.Encoder(Ax)
        By_front = self.flowNet.Encoder(By)
        fusion_BottleNeck = tf.concat([Ax_front, By_front], axis=3)
        By_flow = self.flowNet.Decoder(fusion_BottleNeck)
        # ---------Bilinear Sampler---------
        By_warpped = self.warp_flow(By, By_flow)
        
        # ---------Mask Net---------
        Ax_mask_front = self.maskNet.Encoder(Ax)
        By_mask_front = self.maskNet.Encoder(By_warp_flow)
        # ---------Blend---------
        mask_fusion_BottleNeck = tf.concat([Ax_mask_front, By_mask_front], axis=3)
        By_mask = self.maskNet.Decoder(mask_fusion_BottleNeck)
        By_mask = Activation('sigmoid')(By_mask)
        AyStar = self.blend(By_mask,Ax,By_warp_flow)
        
        # ------------Refinement Sub-Network------------
        mask_input = tf.stop_gradient(By_mask)
        _, residual_Ay = self.refinementNet(tf.concat([AyStar, mask_input],axis=3))
        # Watch out
        refineWeight = min(input_refineWeight, 0.1*max(epoch - 10,0))
        residual_Ay = 2*tf.math.tanh(residual_Ay) * refineWeight * By_mask
        Ay = tf.clip_by_value(AyStar,residual_Ay,-1,1)
        
        # For Removal Module need(I don't understand on -By_flow)
        raw_By_mask = self.warp_flow(By_mask,-By_flow)
        
        return Ay, tf.zeros_like(By_flow), By_flow, raw_By_mask, AyStar, By_warpped
        
    # Another Module
    def removeAttribute(self, By, mask):
        _, rBy = self.removeResidualNet(tf.concat([By,mask],axis=1))
        B = tf.clip_by_value(By+ 2*tf.math.tanh(rBy) ,-1,1)
        return B
        
    # Little Function
    def warp_flow(self, image,flow):
        flow_grid = self.flow_grid_np + flow
        permute = tf.bijectors.Permute(permutation=[0, 2, 3, 1])
        warp_image = self.bilinearSampler(image,permute.forward(flow_grid))
        return warp_image
    
    def bilinearSampler(self,image,flow):
        return tf.contrib.resampler.resampler(image,warp,name='resampler')

    def blend(self,mask,a,b):
        return mask*a + (1-mask)*b

In [9]:
generator = generator_model(IMG_SIZE)

In [10]:
flow_loss_func = flow_loss()
lsgan_loss_func = lsgan_loss()
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

In [11]:
#@tf.function
def train_step(Ax,By,landmark_Ax,landmark_By):
    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
        generator_return_items = generator(Ax, By, training=True)
        All_fake = tf.concat([generator_return_items['fake_Ay'],\
                             generator_return_items['fake_Ax'],\
                             generator_return_items['fake_By'],\
                             generator_return_items['fake_Bx']],axis=0)
        All_real = tf.concat([Ax,By],axis=0)
        # Generator
        G_GAN_loss = \
            lsgan_loss_func.generator_loss(All_fake)
        G_recon_loss = recon_loss(generator_return_items['fake_Ax'],Ax) +\
                       recon_loss(generator_return_items['fake_By'],By)
        # cls_loss(1,0)
        G_classify_loss = cls_loss(generator_return_items['fake_By'],generator_return_items['fake_Ax'])
        G_landmark_loss = flow_loss_func.landmark_loss(landmark_By,landmark_Ax,generator_return_items['flows'][1])
        G_flow_loss = 0.0
        G_flow_loss = flow_loss_func.totalVariation_loss(flow_x,flow_y)
        for flow in generator_return_items['flows']:
            G_flow_loss += flow_loss_func.totalVariation_loss(flow[0],flow[1])
        G_loss = G_GAN_loss+G_recon_loss+G_classify_loss+G_landmark_loss+G_flow_loss
        
        # Discriminator
        D_GAN_loss = \
            lsgan_loss_func.discriminator_loss(All_real, All_fake)
        D_classify_loss = cls_loss(By,Ax)
        D_loss = D_GAN_loss+D_classify_loss
        print("G_loss: {:.10f}, D_loss: {:.10f}".format(G_loss,D_loss))
    gradients_of_generator = g_tape.gradient(G_loss,generator.trainable_variables)
    gradients_of_discriminator = g_tape.gradient(D_loss,generator.trainable_variables)
    
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, generator.trainable_variables))

In [12]:
def save_images(Ax,By,Ay,epoch):
    fig = plt.figure(figsize=(1,3))
    plt.subplot(1,3,1)
    plt.imshow(Ax+1)
    plt.axis('off')
    plt.subplot(1,3,2)
    plt.imshow(By+1)
    plt.axis('off')
    plt.subplot(1,3,3)
    plt.imshow(Ay+1)
    plt.axis('off')
    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))

In [0]:
#@tf.function
def test_step():
    

In [5]:
# celebA 資料處理
# 要Ax, By, landmark_Ax, landmark_By

def set_celebA_data():
    # 取得Train,Validation圖片檔案名
    with open(EVAL_PATH) as infile:
        lines = infile.readlines()
        lines = [line.strip() for line in lines]
        train_list = [line.split()[0] for line in lines if line.split()[1] == '0']
        # valid_list = [line.split()[0] for line in lines if line.split()[1] == '1']
    with open(LANDMARK_PATH) as infile:
        landmark_dict = json.load(infile)
    # 篩選有Landmark的圖片檔案名
    train_list = [img for img in train_list if img in landmark_dict]
    # valid_list = [img for img in valid_list if img in landmark_dict]
    
    # 載入屬性標記
    with open(ANNO_PATH) as infile:
        lines = infile.readlines()
        # 所有屬性欄位名稱
        all_attrNames = lines[1].split()
        attribute_dict = {}
        # SEL_ATTRS位在該檔案屬性欄位Index
        selected_attribute_index = [all_attrNames.index(sel_attr) \
                                   for idx, sel_attr in enumerate(SEL_ATTRS) \
                                   if sel_attr in all_attrNames]
        selected_attribute_index = np.array(selected_attribute_index)
        # 檔名迭代
        for line in lines[2:]:
            splits = line.split()
            # 取得所有欄位戳記
            attribute_value = [int(x) for x in splits[1:]]
            attribute_value = np.array(attribute_value)
            # attribute_value[attribute_value == -1] = 0
            # 選擇SEL_ATTRS
            attribute_dict[splits[0]] = attribute_value[selected_attribute_index]
    attribute_dict = {img: attribute_dict[img] \
                      for img in attribute_dict \
                      if  img in landmark_dict}
    train_attribute_dict = {img: attribute_dict[img]\
                            for img in attribute_dict
                            if img in train_list}
    # Filter into Ax and By
    Ax_list = []
    By_list = []
    for fname in train_attribute_dict:
        if train_attribute_dict[fname] == 1:
            By_list.append(fname)
        else:
            Ax_list.append(fname)
    return Ax_list,By_list,landmark_dict

In [6]:
Ax_filenames,By_filenames, landmark_dict = set_celebA_data()

In [13]:
Ax_filenames_ds = Ax_filenames[:10000]
By_filenames_ds = By_filenames[:10000]
with open(str(DATASET_ROOT/'Ax.csv'),'w',newline='') as myfile:
    wr = csv.writer(myfile, quoting=csv.QUOTE_ALL)
    wr.writerow(Ax_filenames_ds)
with open(str(DATASET_ROOT/'By.csv'),'w',newline='') as myfile:
    wr = csv.writer(myfile, quoting=csv.QUOTE_ALL)
    wr.writerow(By_filenames_ds)

NameError: name 'Ax_filenames' is not defined

In [14]:
Ax_filenames_ds = []
By_filenames_ds = []
with open(str(DATASET_ROOT/'Ax.csv')) as myfile:
    rows = csv.reader(myfile)
    for row in rows:
        Ax_filenames_ds = row
with open(str(DATASET_ROOT/'By.csv')) as myfile:
    rows = csv.reader(myfile)
    for row in rows:
        By_filenames_ds = row
with open(LANDMARK_PATH) as infile:
    landmark_dict = json.load(infile)

In [37]:
landmark_dict

{'021434.jpg': [[63.0, 116.0],
  [63.0, 128.0],
  [65.0, 139.0],
  [68.0, 149.0],
  [71.0, 159.0],
  [75.0, 169.0],
  [76.0, 175.0],
  [81.0, 179.0],
  [91.0, 184.0],
  [105.0, 181.0],
  [118.0, 175.0],
  [129.0, 168.0],
  [138.0, 158.0],
  [144.0, 146.0],
  [145.0, 136.0],
  [147.0, 125.0],
  [148.0, 112.0],
  [61.0, 105.0],
  [61.0, 103.0],
  [65.0, 103.0],
  [69.0, 103.0],
  [74.0, 105.0],
  [92.0, 103.0],
  [98.0, 102.0],
  [105.0, 102.0],
  [112.0, 103.0],
  [121.0, 105.0],
  [82.0, 113.0],
  [81.0, 122.0],
  [78.0, 129.0],
  [78.0, 136.0],
  [78.0, 140.0],
  [79.0, 142.0],
  [82.0, 142.0],
  [86.0, 140.0],
  [91.0, 140.0],
  [68.0, 113.0],
  [68.0, 110.0],
  [74.0, 110.0],
  [79.0, 113.0],
  [74.0, 115.0],
  [69.0, 115.0],
  [98.0, 113.0],
  [101.0, 110.0],
  [106.0, 110.0],
  [112.0, 112.0],
  [108.0, 115.0],
  [101.0, 115.0],
  [76.0, 155.0],
  [76.0, 153.0],
  [81.0, 151.0],
  [84.0, 152.0],
  [86.0, 151.0],
  [94.0, 153.0],
  [102.0, 156.0],
  [94.0, 159.0],
  [89.0, 161.0],


#### 進行中
* Batch Normalization
* Image Preprocessing
* Training Function
#### 未來
* Flow Visualizer

In [15]:
def decode_img(img):
    # img = tf.image.decode_jpeg(img,channels=3)
    img = tf.image.crop_to_bounding_box(img,20,0,178,178)
    img = tf.image.convert_image_dtype(img,tf.float32)
    img = tf.image.resize(img, [IMG_SIZE[0], IMG_SIZE[1]])
    img = img - 1
    return img
'''def process_path(file_path):
    img = tf.io.read_file(file_path)
    img = decode_img(img)
    filename = tf.strings.split(file_path, '/')
    landmark = landmark_dict[filename[-1]]
    return img, landmark

def prepare_for_training(ds, cache=True, shuffle_buffer_size=1000):
    if cache:
        if isinstance(cache, str):
            ds = ds.cache(cache)
        else:
            ds = ds.cache()

    ds = ds.shuffle(buffer_size=shuffle_buffer_size)
    ds = ds.repeat()
    ds = ds.batch(BATCH_SIZE)
    ds = ds.prefetch(buffer_size=AUTOTUNE)
    return ds
# image_batch, label_batch = next(iter(train_ds))
'''

"def process_path(file_path):\n    img = tf.io.read_file(file_path)\n    img = decode_img(img)\n    filename = tf.strings.split(file_path, '/')\n    landmark = landmark_dict[filename[-1]]\n    return img, landmark\n\ndef prepare_for_training(ds, cache=True, shuffle_buffer_size=1000):\n    if cache:\n        if isinstance(cache, str):\n            ds = ds.cache(cache)\n        else:\n            ds = ds.cache()\n\n    ds = ds.shuffle(buffer_size=shuffle_buffer_size)\n    ds = ds.repeat()\n    ds = ds.batch(BATCH_SIZE)\n    ds = ds.prefetch(buffer_size=AUTOTUNE)\n    return ds\n# image_batch, label_batch = next(iter(train_ds))\n"

In [38]:
Ax[0].shape

TensorShape([128, 128, 3])

In [16]:
import cv2

In [17]:
Ax = []
By = []
Ax_landmark = []
By_landmark = []
for imagename in Ax_filenames_ds:
    filename = str(DATA_ROOT/imagename)
    im = cv2.imread(filename,cv2.IMREAD_COLOR)
    im = cv2.cvtColor(im,cv2.COLOR_BGR2RGB)
    im = decode_img(im)
    Ax.append(im)
    Ax_landmark.append(landmark_dict[imagename])
for imagename in By_filenames_ds:
    filename = str(DATA_ROOT/imagename)
    im = cv2.imread(filename,cv2.IMREAD_COLOR)
    im = cv2.cvtColor(im,cv2.COLOR_BGR2RGB)
    im = decode_img(im)
    By.append(im)
    By_landmark.append(landmark_dict[imagename])

In [18]:
Ax_ds = tf.data.Dataset.from_tensor_slices((Ax, Ax_landmark))
Ax_ds = Ax_ds.shuffle(10000).batch(BATCH_SIZE)

In [19]:
By_ds = tf.data.Dataset.from_tensor_slices((By, By_landmark))
By_ds = By_ds.shuffle(10000).batch(BATCH_SIZE)

In [None]:
'''Axlist_ds = tf.data.Dataset.list_files(Ax_filenames_ds)
Bylist_ds = tf.data.Dataset.list_files(By_filenames_con)
Axlabeled_ds = Axlist_ds.map(process_path, num_parallel_calls=AUTOTUNE)
Bylabeled_ds = Bylist_ds.map(process_path, num_parallel_calls=AUTOTUNE)
Ax_ds = prepare_for_training(Axlabeled_ds)
By_ds = prepare_for_training(Bylabeled_ds)
'''

In [20]:
import matplotlib.pyplot as plt

In [0]:
def show_batch(image_batch):
    plt.figure(figsize=(10,10))
    for n in range(25):
        ax = plt.subplot(5,5,n+1)
        plt.imshow(image_batch[n])
        plt.axis('off')

In [21]:
for epoch in range(EPOCHS):
    By_iter = iter(By_ds)
    for one_Ax, one_Ax_landmark in Ax_ds:
        one_By, one_By_landmark = next(By_iter)
        train_step(one_Ax,one_By,one_Ax_landmark,one_By_landmark)
    testAx = np.random.randint(10000)
    testBy = np.random.randint(10000)
    testAy = generator(testAx,testBy)['fake_Ay']
    save_images(testAx,testBy,testAy,epoch)
    By_ds = By_ds.shuffle(10000).batch(BATCH_SIZE)

<class 'tensorflow.python.framework.ops.EagerTensor'>
(25, 128, 128, 2)
(128, 128, 2)


InvalidArgumentError: cannot compute AddV2 as input #1(zero-based) was expected to be a double tensor but is a float tensor [Op:AddV2] name: generator_model/add/

In [0]:
# Flow 視覺化