In [3]:
!pip install git+https://github.com/tensorflow/examples.git

In [4]:
import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output
import tensorflow as tf
import tensorflow_datasets
from tensorflow_examples.models.pix2pix import pix2pix
from tqdm import tqdm

AUTOTUNE = tf.data.AUTOTUNE

## CFG

In [5]:
class CFG:
    """Hyper-parameters and paths shared by every cell of this notebook."""

    # --- data -----------------------------------------------------------
    data_path = 'cycle_gan/horse2zebra'   # TFDS dataset name
    buffer = 1000                         # shuffle buffer size
    batch = 1                             # batch size
    image_height = 256                    # size of the random crop fed to the networks
    image_width = 256
    out_channels = 3                      # channels produced by the generators

    # --- optimisation ---------------------------------------------------
    lr = 2e-4                             # Adam learning rate
    beta = 0.5                            # Adam beta_1
    epochs = 20
    L = 10                                # cycle-loss weight; identity loss uses half of it

    # --- loss / checkpoints ---------------------------------------------
    loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    ckpt_path = "./ckpt"

## 导入数据

In [7]:
data, _ = tensorflow_datasets.load(CFG.data_path, with_info=True, as_supervised=True)

# The 'A' splits hold the horse images (h_*), the 'B' splits the zebras (z_*).
h_train, h_test = data['trainA'], data['testA']
z_train, z_test = data['trainB'], data['testB']
h_train, h_test

## 数据处理

In [8]:
# Training-set preprocessing
def train_preprocess(x, flag=None, normalize=True, jitter_size=286):
    """Randomly jitter one training image and optionally scale it to [-1, 1].

    The image is enlarged to ``jitter_size`` x ``jitter_size`` with
    nearest-neighbour resizing. It is then randomly cropped back to
    ``CFG.image_height`` x ``CFG.image_width`` and randomly mirrored
    left/right.

    Args:
        x: a single 3-channel image (H x W x 3).
        flag: ignored. It exists so the function can be mapped directly over
            an ``as_supervised`` dataset, whose second element (presumably
            the label) lands here.
        normalize: when true, cast to float32 and map [0, 255] to [-1, 1].
            When false, return the augmented image without casting or
            rescaling.
        jitter_size: side length of the intermediate resize before the
            random crop (previously hard-coded to 286).

    Returns:
        The augmented (and, if requested, normalised) image tensor.
    """
    enlarged = tf.image.resize(x, [jitter_size, jitter_size],
                               method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    cropped = tf.image.random_crop(enlarged, size=[CFG.image_height, CFG.image_width, 3])
    flipped = tf.image.random_flip_left_right(cropped)
    if not normalize:
        return flipped
    # [0, 255] -> [-1, 1]
    return (tf.cast(flipped, tf.float32) / 127.5) - 1

# Test-set preprocessing: no augmentation, only rescaling.
def test_preprocess(x, flag=None):
    """Map a test image from [0, 255] to float32 in [-1, 1].

    ``flag`` is ignored; it absorbs the second element of an
    ``as_supervised`` dataset so the function can be used with ``map``.
    """
    as_float = tf.cast(x, tf.float32)
    scaled = as_float / 127.5
    return scaled - 1

In [9]:
# Build every pipeline from the untouched splits in `data`, not from the
# h_*/z_* names this cell overwrites. That makes re-running the cell
# idempotent instead of stacking a second round of preprocessing on
# already-batched data.
# Training: cache the raw images *before* the random augmentation so each
# epoch draws new crops/flips.
# Test: normalisation is deterministic, so its output can be cached.
# prefetch lets input preparation overlap with the training step.
h_train = (data['trainA'].cache()
           .map(train_preprocess, num_parallel_calls=AUTOTUNE)
           .shuffle(CFG.buffer).batch(CFG.batch).prefetch(AUTOTUNE))
z_train = (data['trainB'].cache()
           .map(train_preprocess, num_parallel_calls=AUTOTUNE)
           .shuffle(CFG.buffer).batch(CFG.batch).prefetch(AUTOTUNE))

h_test = (data['testA'].map(test_preprocess, num_parallel_calls=AUTOTUNE)
          .cache().shuffle(CFG.buffer).batch(CFG.batch).prefetch(AUTOTUNE))
z_test = (data['testB'].map(test_preprocess, num_parallel_calls=AUTOTUNE)
          .cache().shuffle(CFG.buffer).batch(CFG.batch).prefetch(AUTOTUNE))

h_train, z_train

## 取示例批次（从两个训练集各取第一个批次）

In [43]:
# One (shuffled, augmented, normalised) batch per domain, reused by the preview plots.
first_h, first_z = (next(iter(ds)) for ds in (h_train, z_train))

## 生成器和判别器

In [11]:
# generator_g maps X (horses) -> Y (zebras); generator_f maps Y -> X.
generator_g, generator_f = (
    pix2pix.unet_generator(CFG.out_channels, norm_type='instancenorm')
    for _ in range(2)
)

# discriminator_x scores X-domain images, discriminator_y scores Y-domain images.
# target=False: each is called on a single image (see train_loop).
discriminator_x, discriminator_y = (
    pix2pix.discriminator(norm_type='instancenorm', target=False)
    for _ in range(2)
)

## 损失函数

In [12]:
# Generator (adversarial) loss
def generator_loss(g):
    """BCE between the discriminator's logits on generated images and the 'real' label 1."""
    real_labels = tf.ones_like(g)
    return CFG.loss(real_labels, g)

# Discriminator loss
def discriminator_loss(x, g):
    """Average of the real-image and fake-image BCE terms.

    Args:
        x: discriminator logits on real images (target 1).
        g: discriminator logits on generated images (target 0).
    """
    real_term = CFG.loss(tf.ones_like(x), x)
    fake_term = CFG.loss(tf.zeros_like(g), g)
    # Halved to slow the discriminator relative to the generator.
    return 0.5 * (real_term + fake_term)

# Cycle-consistency loss
def cycle_loss(real_x, cycle_x):
    """Mean absolute error between an image and its round-trip reconstruction, weighted by CFG.L."""
    mean_abs_error = tf.reduce_mean(tf.abs(real_x - cycle_x))
    return CFG.L * mean_abs_error

# Identity loss
def identity_loss(real_x, same_x):
    """Mean absolute error between an image and the generator's output on that same image.

    Weighted by half of CFG.L.
    """
    half_lambda = CFG.L * 0.5
    mean_abs_error = tf.reduce_mean(tf.abs(real_x - same_x))
    return half_lambda * mean_abs_error

## 优化器

In [13]:
# One independent Adam optimiser per network, all sharing the CFG settings.
optimizer_g_g, optimizer_g_f, optimizer_d_x, optimizer_d_y = (
    tf.keras.optimizers.Adam(learning_rate=CFG.lr, beta_1=CFG.beta)
    for _ in range(4)
)

## checkpoint

In [14]:
# Track all four networks and their optimisers so training can resume.
checkpoint = tf.train.Checkpoint(generator_g=generator_g,
                                 generator_f=generator_f,
                                 discriminator_x=discriminator_x,
                                 discriminator_y=discriminator_y,
                                 generator_g_optimizer=optimizer_g_g,
                                 generator_f_optimizer=optimizer_g_f,
                                 discriminator_x_optimizer=optimizer_d_x,
                                 discriminator_y_optimizer=optimizer_d_y)

checkpoint_manager = tf.train.CheckpointManager(checkpoint, CFG.ckpt_path, max_to_keep=5)

# Previously the manager only ever saved. Restore the newest checkpoint in
# CFG.ckpt_path (if any) so that a re-run continues from it instead of
# silently starting over.
if checkpoint_manager.latest_checkpoint:
    checkpoint.restore(checkpoint_manager.latest_checkpoint)
    print('Restored checkpoint from {}'.format(checkpoint_manager.latest_checkpoint))

In [15]:
@tf.function
def train_loop(real_x: tf.Tensor, real_y: tf.Tensor) -> None:
    """Run one CycleGAN optimisation step on a pair of batches.

    ``train`` zips ``(h_train, z_train)``, so ``real_x`` is a horse batch and
    ``real_y`` a zebra batch. generator_g maps X -> Y and generator_f maps
    Y -> X. All four networks run with ``training=True`` under one tape, and
    each network then gets one update from its own Adam optimiser.

    Losses:
      * generators: adversarial + cycle (both directions, shared) + identity.
      * discriminators: mean of BCE on real (target 1) and fake (target 0).

    Args:
        real_x: batch from the X (horse) pipeline.
        real_y: batch from the Y (zebra) pipeline.

    Returns:
        None. Model and optimiser variables are updated in place.
    """
    
    # persistent=True: tape.gradient is called four times below.
    with tf.GradientTape(persistent=True) as tape:
        # X -> Y -> X round trip
        fake_y = generator_g(real_x, training=True)
        cycle_x = generator_f(fake_y, training=True)

        # Y -> X -> Y round trip
        fake_x = generator_f(real_y, training=True)
        cycle_y = generator_g(fake_x, training=True)

        # same_x and same_y are used for the identity loss
        same_x = generator_f(real_x, training=True)
        same_y = generator_g(real_y, training=True)

        real_x_d = discriminator_x(real_x, training=True)
        real_y_d = discriminator_y(real_y, training=True)

        fake_x_d = discriminator_x(fake_x, training=True)
        fake_y_d = discriminator_y(fake_y, training=True)

        # Compute losses
        loss_g_g = generator_loss(fake_y_d)
        loss_g_f = generator_loss(fake_x_d)
        # Cycle loss in both directions; added to both generators' totals.
        loss_c = cycle_loss(real_x, cycle_x) + cycle_loss(real_y, cycle_y)
        total_g_g = loss_g_g + loss_c + identity_loss(real_y, same_y)
        total_g_f = loss_g_f + loss_c + identity_loss(real_x, same_x)
        
        loss_d_x = discriminator_loss(real_x_d, fake_x_d)
        loss_d_y = discriminator_loss(real_y_d, fake_y_d)

    # Compute gradients (each loss only w.r.t. its own network's variables)
    gradient_g_g = tape.gradient(total_g_g, generator_g.trainable_variables)
    gradient_g_f = tape.gradient(total_g_f, generator_f.trainable_variables)

    gradient_d_x = tape.gradient(loss_d_x, discriminator_x.trainable_variables)
    gradient_d_y = tape.gradient(loss_d_y, discriminator_y.trainable_variables)

    # Apply updates
    optimizer_g_g.apply_gradients(zip(gradient_g_g, generator_g.trainable_variables))
    optimizer_g_f.apply_gradients(zip(gradient_g_f, generator_f.trainable_variables))

    optimizer_d_x.apply_gradients(zip(gradient_d_x, discriminator_x.trainable_variables))
    optimizer_d_y.apply_gradients(zip(gradient_d_y, discriminator_y.trainable_variables))

In [16]:
def generate(model, x):
    """Show the first image of batch ``x`` next to ``model``'s translation of it.

    Both images are assumed to lie in [-1, 1] and are mapped to [0, 1] for
    display.
    """
    pred = model(x)
    _, (ax_in, ax_out) = plt.subplots(1, 2, dpi=70, figsize=(8, 8))
    for ax, title, image in ((ax_in, 'Input Image', x[0]),
                             (ax_out, 'Predicted Image', pred[0])):
        ax.set_title(title, fontsize=15)
        ax.imshow(image * 0.5 + 0.5)  # [-1, 1] -> [0, 1]
        ax.axis('off')
    plt.show()

In [44]:
# Compare raw dataset images with what `train_preprocess` produces from them.
# The raw samples come straight from `data` (before any augmentation;
# presumably uint8 in [0, 255]). The h_train/z_train batches are already
# augmented and normalised, so re-augmenting them (as before) never showed
# a true original.
raw_h = next(iter(data['trainA']))[0]
raw_z = next(iter(data['trainB']))[0]

fig, axes = plt.subplots(2, 2, dpi=70, figsize=(8, 8))
for row, raw in enumerate((raw_h, raw_z)):
    axes[row, 0].imshow(raw)
    # train_preprocess normalises to [-1, 1]; map back to [0, 1] for display.
    axes[row, 1].imshow(train_preprocess(raw) * 0.5 + 0.5)
    for ax in axes[row]:
        ax.axis('off')
axes[0, 0].set_title('Original image', fontsize=15)
axes[0, 1].set_title('Preprocessed image', fontsize=15)

plt.show()

In [45]:
# Generator outputs are multiplied by this factor before display so that
# faint outputs (presumably near 0 before training) become visible.
CONTRAST = 8

to_z = generator_g(first_h)
to_h = generator_f(first_z)

panels = (
    ('Horse', first_h[0] * 0.5 + 0.5),
    ('To Zebra', to_z[0] * 0.5 * CONTRAST + 0.5),
    ('Zebra', first_z[0] * 0.5 + 0.5),  # previously overwritten by a stray second title
    ('To Horse', to_h[0] * 0.5 * CONTRAST + 0.5),
)

fig, axes = plt.subplots(2, 2, dpi=70, figsize=(8, 8))
for ax, (title, image) in zip(axes.flat, panels):
    ax.set_title(title, fontsize=15)
    # Contrast-boosted values can leave [0, 1]; clip explicitly rather than
    # relying on imshow's clip-with-warning behaviour.
    ax.imshow(tf.clip_by_value(image, 0.0, 1.0))
    ax.axis('off')

plt.show()

In [49]:
def train(h_train, z_train, first_h, save_every=5):
    """Train the CycleGAN for ``CFG.epochs`` epochs.

    Each epoch pairs horse and zebra batches with ``Dataset.zip`` and calls
    ``train_loop`` on every pair. Every ``save_every`` epochs, and always
    after the last epoch, it previews ``generator_g`` on ``first_h`` and
    writes a checkpoint.

    Args:
        h_train: batched X-domain (horse) dataset.
        z_train: batched Y-domain (zebra) dataset.
        first_h: a fixed horse batch used for the preview plots.
        save_every: checkpoint/preview interval in epochs. A falsy value
            disables the periodic saves; the final epoch is still saved.
    """
    # The zipped dataset does not change between epochs, so build it once.
    # Each `for` over it starts a new pass through the (shuffled) data.
    dataset = tf.data.Dataset.zip((h_train, z_train))
    steps = len(dataset)

    for epoch in range(1, CFG.epochs + 1):
        print('Epoch:', epoch)
        # Progress is counted in training steps (was unit='B' with
        # unit_scale, which displayed batches as bytes / kB).
        with tqdm(total=steps, desc='train', leave=True, ncols=100, unit='step') as pbar:
            for x, y in dataset:
                train_loop(x, y)
                pbar.update(1)

        periodic = bool(save_every) and epoch % save_every == 0
        if periodic or epoch == CFG.epochs:
            generate(generator_g, first_h)
            path = checkpoint_manager.save()
            print('Saving checkpoint for epoch {} at {}'.format(epoch, path))

def test(h_test):
    """Plot generator_g's translation for every batch in ``h_test``."""
    for batch in iter(h_test):
        generate(generator_g, batch)

In [47]:
train(h_train=h_train, z_train=z_train, first_h=first_h)

In [62]:
test(h_test=h_test)