In [2]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_datasets as tfds


# 定义生成器
def build_generator():
    model = keras.Sequential()
    model.add(layers.Input(shape=(256, 256, 3)))
    model.add(layers.Conv2D(64, kernel_size=7, padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Conv2D(128, kernel_size=3, strides=2, padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Conv2D(256, kernel_size=3, strides=2, padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Conv2DTranspose(128, kernel_size=3, strides=2, padding='same'))
    model.add(layers.ReLU())
    model.add(layers.Conv2DTranspose(64, kernel_size=3, strides=2, padding='same'))
    model.add(layers.ReLU())
    model.add(layers.Conv2D(3, kernel_size=7, padding='same', activation='tanh'))
    return model


# 定义判别器
def build_discriminator():
    model = keras.Sequential()
    model.add(layers.Input(shape=(256, 256, 3)))
    model.add(layers.Conv2D(64, kernel_size=4, strides=2, padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Conv2D(128, kernel_size=4, strides=2, padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Conv2D(256, kernel_size=4, strides=2, padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Conv2D(512, kernel_size=4, strides=2, padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Conv2D(1, kernel_size=4, padding='same'))
    return model


# 定义CycleGAN类
class CycleGAN(keras.Model):
    def __init__(self, generator_g, generator_f, discriminator_x, discriminator_y, lambda_cycle=10):
        super(CycleGAN, self).__init__()
        self.generator_g = generator_g
        self.generator_f = generator_f
        self.discriminator_x = discriminator_x
        self.discriminator_y = discriminator_y
        self.lambda_cycle = lambda_cycle

    def train_step(self, real_x, real_y):
        with tf.GradientTape(persistent=True) as tape:
            fake_y = self.generator_g(real_x)
            cycled_x = self.generator_f(fake_y)

            fake_x = self.generator_f(real_y)
            cycled_y = self.generator_g(fake_x)

            disc_real_x = self.discriminator_x(real_x)
            disc_real_y = self.discriminator_y(real_y)

            disc_fake_x = self.discriminator_x(fake_x)
            disc_fake_y = self.discriminator_y(fake_y)

            # 计算损失
            loss_gen_g = keras.losses.binary_crossentropy(tf.ones_like(disc_fake_y), disc_fake_y)
            loss_gen_f = keras.losses.binary_crossentropy(tf.ones_like(disc_fake_x), disc_fake_x)
            loss_cycle_x = tf.reduce_mean(tf.abs(real_x - cycled_x))
            loss_cycle_y = tf.reduce_mean(tf.abs(real_y - cycled_y))

            total_gen_g = loss_gen_g + self.lambda_cycle * loss_cycle_x
            total_gen_f = loss_gen_f + self.lambda_cycle * loss_cycle_y

            loss_disc_x = keras.losses.binary_crossentropy(tf.ones_like(disc_real_x), disc_real_x) + \
                          keras.losses.binary_crossentropy(tf.zeros_like(disc_fake_x), disc_fake_x)

            loss_disc_y = keras.losses.binary_crossentropy(tf.ones_like(disc_real_y), disc_real_y) + \
                          keras.losses.binary_crossentropy(tf.zeros_like(disc_fake_y), disc_fake_y)

        # 计算梯度
        grads_g = tape.gradient(total_gen_g, self.generator_g.trainable_variables)
        grads_f = tape.gradient(total_gen_f, self.generator_f.trainable_variables)
        grads_disc_x = tape.gradient(loss_disc_x, self.discriminator_x.trainable_variables)
        grads_disc_y = tape.gradient(loss_disc_y, self.discriminator_y.trainable_variables)

        # 更新权重
        self.optimizer.apply_gradients(zip(grads_g, self.generator_g.trainable_variables))
        self.optimizer.apply_gradients(zip(grads_f, self.generator_f.trainable_variables))
        self.optimizer.apply_gradients(zip(grads_disc_x, self.discriminator_x.trainable_variables))
        self.optimizer.apply_gradients(zip(grads_disc_y, self.discriminator_y.trainable_variables))

        return {
            "g_loss": total_gen_g,
            "f_loss": total_gen_f,
            "d_x_loss": loss_disc_x,
            "d_y_loss": loss_disc_y,
        }


# 构建模型
generator_g = build_generator()  # G: Horse -> Zebra
generator_f = build_generator()  # F: Zebra -> Horse
discriminator_x = build_discriminator()  # D_X
discriminator_y = build_discriminator()  # D_Y

cyclegan = CycleGAN(generator_g, generator_f, discriminator_x, discriminator_y)
cyclegan.compile(optimizer=keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5))

# 加载Horse2Zebra数据集
dataset, metadata = tfds.load('cycle_gan/horse2zebra', with_info=True, as_supervised=True)
train_horses, train_zebras = dataset['trainA'], dataset['trainB']


# 数据预处理
def preprocess_image(image, label):
    image = tf.image.resize(image, [256, 256])
    image = (image / 127.5) - 1  # 归一化到[-1, 1]
    return image


train_horses = train_horses.map(preprocess_image).batch(1


SyntaxError: incomplete input (2009901198.py, line 116)