In [269]:
import  tensorflow as tf
import  os
# 1 隱藏 INFO logs, 2 額外隱藏WARNING logs, 設置為3所有 ERROR logs也不顯示
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
from    tensorflow import keras
from    tensorflow.keras import layers
import  numpy as np


class Generator(keras.Model):

    def __init__(self):
        super(Generator, self).__init__()

        # z: [b, 100] => [b, 3*3*512] => [b, 3, 3, 512] => [b, 64, 64, 3]
        self.fc = layers.Dense(3*3*512)

        self.conv1 = layers.Conv2DTranspose(256, 3, 3, 'valid')
        self.bn1 = layers.BatchNormalization()

        self.conv2 = layers.Conv2DTranspose(128, 5, 2, 'valid')
        self.bn2 = layers.BatchNormalization()

        self.conv3 = layers.Conv2DTranspose(3, 4, 3, 'valid')

    def call(self, inputs, training=None):
        # [z, 100] => [z, 3*3*512]
        x = self.fc(inputs)
        x = tf.reshape(x, [-1, 3, 3, 512])
        x = tf.nn.leaky_relu(x)

        #
        x = tf.nn.leaky_relu(self.bn1(self.conv1(x), training=training))
        x = tf.nn.leaky_relu(self.bn2(self.conv2(x), training=training))
        x = self.conv3(x)
        x = tf.tanh(x)
        print("Generator call")

        return x

In [270]:

class Discriminator(keras.Model):

    def __init__(self):
        super(Discriminator, self).__init__()

        # [b, 64, 64, 3] => [b, 1]
        self.conv1 = layers.Conv2D(64, 5, 3, 'valid')
        self.conv2 = layers.Conv2D(128, 5, 3, 'valid')
        self.bn2 = layers.BatchNormalization()
        self.conv3 = layers.Conv2D(256, 5, 3, 'valid')
        self.bn3 = layers.BatchNormalization()

        # [b, h, w ,c] => [b, -1]
        self.flatten = layers.Flatten()
        self.fc = layers.Dense(1)


    def call(self, inputs, training=None):

        x = tf.nn.leaky_relu(self.conv1(inputs))
        x = tf.nn.leaky_relu(self.bn2(self.conv2(x), training=training))
        x = tf.nn.leaky_relu(self.bn3(self.conv3(x), training=training))

        # [b, h, w, c] => [b, -1]
        x = self.flatten(x)
        # [b, -1] => [b, 1]
        logits = self.fc(x)
        print("Discriminator call")

        return logits

In [271]:
d = Discriminator()
g = Generator()

x = tf.random.normal([2, 64, 64, 3])
z = tf.random.normal([2, 100])

prob = d(x)
print(prob)
x_hat = g(z)
print(x_hat.shape)


Discriminator call
tf.Tensor(
[[0.00307981]
 [0.08289697]], shape=(2, 1), dtype=float32)
Generator call
(2, 64, 64, 3)


In [272]:
def make_anime_dataset(img_paths, batch_size, resize=64, drop_remainder=True, shuffle=True, repeat=1):
    # @tf.function
    def _map_fn(img):
        img = tf.image.resize(img, [resize, resize])
        # img = tf.image.random_crop(img,[resize, resize])
        # img = tf.image.random_flip_left_right(img)
        # img = tf.image.random_flip_up_down(img)
        img = tf.clip_by_value(img, 0, 255)
        img = img / 127.5 - 1 #-1~1
        return img

    dataset = disk_image_batch_dataset(img_paths,
                                          batch_size,
                                          drop_remainder=drop_remainder,
                                          map_fn=_map_fn,
                                          shuffle=shuffle,
                                          repeat=repeat)
    img_shape = (resize, resize, 3)
    len_dataset = len(img_paths) // batch_size

    return dataset, img_shape, len_dataset

In [273]:
def celoss_ones(logits):
    return - tf.reduce_mean(logits)

def celoss_zeros(logits):
    return tf.reduce_mean(logits)

In [274]:
def gradient_penalty(discriminator, batch_x, fake_image):

    batchsz = batch_x.shape[0]

    # [b, h, w, c]
    t = tf.random.uniform([batchsz, 1, 1, 1])
    # [b, 1, 1, 1] => [b, h, w, c]
    t = tf.broadcast_to(t, batch_x.shape)

    interplate = t * batch_x + (1 - t) * fake_image

    with tf.GradientTape() as tape:
        tape.watch([interplate])
        d_interplote_logits = discriminator(interplate, training=True)
    grads = tape.gradient(d_interplote_logits, interplate)

    # grads:[b, h, w, c] => [b, -1]
    grads = tf.reshape(grads, [grads.shape[0], -1])
    gp = tf.norm(grads, axis=1) #[b]
    gp = tf.reduce_mean( (gp-1)**2 )

    return gp

In [275]:

def d_loss_fn(generator, discriminator, batch_z, batch_x, is_training):
    # 1. treat real image as real
    # 2. treat generated image as fake
    fake_image = generator(batch_z, is_training)
    d_fake_logits = discriminator(fake_image, is_training)
    d_real_logits = discriminator(batch_x, is_training)

    d_loss_real = celoss_ones(d_real_logits)
    d_loss_fake = celoss_zeros(d_fake_logits)
    gp = gradient_penalty(discriminator, batch_x, fake_image)

    loss = d_loss_real + d_loss_fake + 10. * gp

    return loss, gp


def g_loss_fn(generator, discriminator, batch_z, is_training):

    fake_image = generator(batch_z, is_training)
    d_fake_logits = discriminator(fake_image, is_training)
    loss = celoss_ones(d_fake_logits)

    return loss

In [276]:

from tensorflow.keras import layers, optimizers, datasets  # 导入 TF 子库
(x, y), (x_val, y_val) = datasets.mnist.load_data()  # 加载数据集
x = np.resize(x, (len(x),64, 64,3))
x = x[:512]
print(x.shape)
db_iter = iter(x)
def train_gan():
    tf.random.set_seed(233)
    np.random.seed(233)
    assert tf.__version__.startswith('2.')


    # hyper parameters
    z_dim = 100
    epochs = 3000000
    batch_size = 512
    learning_rate = 0.0005
    is_training = True

    generator = Generator() 
    generator.build(input_shape = (None, z_dim))
    discriminator = Discriminator()
    discriminator.build(input_shape=(None, 64, 64, 3))
    z_sample = tf.random.normal([100, z_dim])


    g_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
    d_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)


    for epoch in range(epochs):

        for _ in range(5):
            batch_z = tf.random.normal([batch_size, z_dim])
            batch_x = next(db_iter)
            batch_x = np.resize(batch_x,(1, batch_x.shape[0], batch_x.shape[1], batch_x.shape[2]))
            
            batch_x = np.array(batch_x, dtype = 'float64') 

            # train D
            with tf.GradientTape() as tape:
                d_loss, gp = d_loss_fn(generator, discriminator, batch_z, batch_x, is_training)
            grads = tape.gradient(d_loss, discriminator.trainable_variables)
            d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))
        
        batch_z = tf.random.normal([batch_size, z_dim])

        with tf.GradientTape() as tape:
            g_loss = g_loss_fn(generator, discriminator, batch_z, is_training)
        grads = tape.gradient(g_loss, generator.trainable_variables)
        g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))

        if epoch % 100 == 0:
            print(epoch, 'd-loss:',float(d_loss), 'g-loss:', float(g_loss),
                  'gp:', float(gp))

            # z = tf.random.normal([100, z_dim])
            # fake_image = generator(z, training=False)
            # img_path = os.path.join('images', 'wgan-%d.png'%epoch)
            # save_result(fake_image.numpy(), 10, img_path, color_mode='P')

train_gan()

(512, 64, 64, 3)
Generator call
Discriminator call
Generator call
Discriminator call


KeyboardInterrupt: 