In [1]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers

# 超参数设置
latent_dim = 100
n_critic = 5
epochs = 10000
batch_size = 64
clip_value = 0.01


# 生成器模型
def build_generator():
    model = tf.keras.Sequential()
    model.add(layers.Dense(256, input_dim=latent_dim))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Dense(512))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Dense(2, activation='tanh'))  # 生成2D数据
    return model


# 判别器模型
def build_critic():
    model = tf.keras.Sequential()
    model.add(layers.Dense(512, input_dim=2))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Dense(256))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Dense(1))  # 输出为一个标量
    return model


# 初始化生成器和判别器
generator = build_generator()
critic = build_critic()

# 编译判别器
critic.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.00005), loss='mean_squared_error')


# 生成数据
def generate_real_samples(n):
    x1 = np.random.uniform(-1.0, 1.0, n)
    x2 = np.random.uniform(-1.0, 1.0, n)
    return np.array(list(zip(x1, x2)))


def generate_latent_points(n):
    return np.random.normal(0, 1, (n, latent_dim))


def generate_fake_samples(n):
    z = generate_latent_points(n)
    return generator.predict(z)


# 训练WGAN
for epoch in range(epochs):
    # 训练判别器
    for _ in range(n_critic):
        # 生成真实样本和假样本
        X_real = generate_real_samples(batch_size)
        X_fake = generate_fake_samples(batch_size)

        # 更新判别器
        d_loss_real = critic.train_on_batch(X_real, np.ones((batch_size, 1)))
        d_loss_fake = critic.train_on_batch(X_fake, -np.ones((batch_size, 1)))

        # 进行权重裁剪
        for layer in critic.layers:
            weights = layer.get_weights()
            weights = [np.clip(w, -clip_value, clip_value) for w in weights]
            layer.set_weights(weights)

    # 训练生成器
    z = generate_latent_points(batch_size)
    g_loss = critic.train_on_batch(generator.predict(z), np.ones((batch_size, 1)))

    # 每1000轮打印一次损失
    if epoch % 1000 == 0:
        print(f'Epoch: {epoch}, D Loss Real: {d_loss_real}, D Loss Fake: {d_loss_fake}, G Loss: {g_loss}')


# 可视化生成的样本
def plot_generated_samples(generator, n=100):
    samples = generate_fake_samples(n)
    plt.scatter(samples[:, 0], samples[:, 1])
    plt.title("Generated Samples")
    plt.xlim(-2, 2)
    plt.ylim(-2, 2)
    plt.show()


plot_generated_samples(generator)


[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step 


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 906us/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step  
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 785us/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step  
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 986us/step
Epoch: 0, D Loss Real: 0.9971429109573364, D Loss Fake: 0.9974877238273621, G Loss: 0.9976844787597656
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step  
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 824us/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 934us/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 926us/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 866us/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 999us/step
[1m2/2[0m [32m━━━━━━

KeyboardInterrupt: 