In [2]:
from tensorflow import keras 
from tensorflow.keras import layers 
from tensorflow_docs.vis import embed 
import matplotlib.pyplot as plt 
import tensorflow as tf 
import numpy as np 
import imageio

# 파라미터 설정 

In [10]:
# Fix TensorFlow's global seed first so later random ops (dataset shuffling,
# weight initialisation, latent sampling) start from a known state.
tf.random.set_seed(999)

# Data geometry: square single-channel images, ten label classes.
image_size = 28
num_channels = 1
num_classes = 10

# Training / model hyper-parameters.
batch_size = 64
latent_dim = 128  # length of the noise vector fed to the generator

# 데이터셋 로드 및 전처리 

In [20]:
# Load MNIST and merge the train and test splits into one pool: every
# available digit is used for GAN training.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_labels = np.concatenate([y_train, y_test])

# Preprocess. Shapes are derived from the config constants rather than
# repeated as literals, so the data always matches the model input sizes.
all_digits = all_digits.astype("float32") / 255.0  # cast to float32 and divide by 255
all_digits = np.reshape(
    all_digits, (-1, image_size, image_size, num_channels)
)  # add an explicit channel axis
all_labels = keras.utils.to_categorical(all_labels, num_classes)  # one-hot labels

# Input pipeline: (image, one-hot label) pairs, shuffled and batched.
dataset = tf.data.Dataset.from_tensor_slices((all_digits, all_labels))
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)

# Input widths of the two networks: the class condition is concatenated to
# the noise vector (generator) and to the image channels (discriminator).
generator_in_channels = latent_dim + num_classes  # 128 + 10 = 138
discriminator_in_channels = num_channels + num_classes  # 1 + 10 = 11


# 모델 생성 

## 식별자, 생성자 모델 생성 

In [37]:
# Discriminator: takes an image with the label map stacked on its channel
# axis and outputs a single unnormalised score (logit) per sample.
discriminator = keras.Sequential(
    [
        layers.InputLayer((image_size, image_size, discriminator_in_channels)),
        layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.GlobalMaxPooling2D(),
        layers.Dense(1),  # no activation: the output is a logit
    ],
    name="discriminator",
)

# The generator upsamples twice with stride 2, so it must start from a
# feature map a quarter of the target size (7 -> 14 -> 28 for MNIST).
assert image_size % 4 == 0, "image_size must be divisible by 4"
generator_base_size = image_size // 4

# Generator: maps a (noise + one-hot label) vector to an image whose pixel
# values lie in [0, 1] (sigmoid output).
generator = keras.Sequential(
    [
        layers.InputLayer((generator_in_channels,)),  # (138,)
        layers.Dense(
            generator_base_size * generator_base_size * generator_in_channels
        ),  # 7 * 7 * 138
        layers.LeakyReLU(alpha=0.2),
        layers.Reshape(
            (generator_base_size, generator_base_size, generator_in_channels)
        ),  # (7, 7, 138)
        layers.Conv2DTranspose(
            128, (4, 4), strides=(2, 2), padding="same"
        ),  # (7, 7, 138) -> (14, 14, 128)
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2DTranspose(
            128, (4, 4), strides=(2, 2), padding="same"
        ),  # (14, 14, 128) -> (28, 28, 128)
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(
            num_channels, (7, 7), padding="same", activation="sigmoid"
        ),  # (28, 28, 128) -> (28, 28, 1)
    ],
    name="generator",
)


## CGAN 생성 

In [66]:
# NOTE(review): a class with this same name is defined again in a later cell
# of this notebook; the later definition shadows this one. Consider keeping
# only one of them.
class ConditionalGAN(keras.Model):
    """Conditional GAN wrapper with a custom training step.

    Holds a discriminator and a generator and trains both inside
    `train_step`, conditioning each network on the one-hot class label.

    Args:
        discriminator: model scoring (image + label map) inputs.
        generator: model mapping (noise + one-hot label) vectors to images.
        latent_dim: length of the noise vector sampled for the generator.
    """

    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        # Running means of the two losses, reported by `train_step`.
        self.gen_loss_tracker = keras.metrics.Mean(name="generator_loss")
        self.disc_loss_tracker = keras.metrics.Mean(name="discriminator_loss")

    @property
    def metrics(self):
        """Loss trackers owned by this model.

        This must be a property (not a plain method) to override
        `keras.Model.metrics`, which Keras reads to find the metrics it
        should reset between epochs.
        """
        return [self.gen_loss_tracker, self.disc_loss_tracker]

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        """Store one optimizer per network and the shared loss function."""
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, data):
        """Run one discriminator update followed by one generator update.

        Args:
            data: tuple `(real_images, one_hot_labels)` for one batch.

        Returns:
            dict with the running mean generator loss (`"g_loss"`) and
            discriminator loss (`"d_loss"`).
        """
        real_images, one_hot_labels = data

        # --- Condition for the discriminator ---
        # Broadcast each one-hot vector to every pixel so it can be stacked on
        # the image's channel axis: (batch, classes) -> (batch, H, W, classes).
        # tf.tile keeps channel k == class k; tf.repeat without `axis` would
        # flatten the tensor first and a later reshape would scramble that.
        image_one_hot_labels = tf.tile(
            one_hot_labels[:, None, None, :], [1, image_size, image_size, 1]
        )

        # --- Input for the generator ---
        # Sample noise and append the label condition to it.
        num_samples = tf.shape(real_images)[0]
        random_latent_vectors = tf.random.normal(
            shape=(num_samples, self.latent_dim)
        )
        random_vector_labels = tf.concat(
            [random_latent_vectors, one_hot_labels], axis=1
        )

        # Generate images from (noise + condition).
        generated_images = self.generator(random_vector_labels)

        # Attach the label maps and stack fake samples before real ones.
        fake_image_and_labels = tf.concat(
            [generated_images, image_one_hot_labels], -1
        )
        real_image_and_labels = tf.concat([real_images, image_one_hot_labels], -1)
        combined_images = tf.concat(
            [fake_image_and_labels, real_image_and_labels], axis=0
        )

        # Discriminator targets: fake images are labelled 1, real images 0.
        labels = tf.concat(
            [tf.ones((num_samples, 1)), tf.zeros((num_samples, 1))], axis=0
        )

        # --- Train the discriminator ---
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # --- Train the generator ---
        # Fresh noise, same label condition.
        random_latent_vectors = tf.random.normal(
            shape=(num_samples, self.latent_dim)
        )
        random_vector_labels = tf.concat(
            [random_latent_vectors, one_hot_labels], axis=1
        )

        # Targets claiming that every generated image is real (real == 0).
        misleading_labels = tf.zeros((num_samples, 1))

        # Only the generator's weights are updated here; the discriminator is
        # used for scoring but its weights are left untouched.
        with tf.GradientTape() as tape:
            fake_images = self.generator(random_vector_labels)
            fake_image_and_labels = tf.concat(
                [fake_images, image_one_hot_labels], -1
            )
            predictions = self.discriminator(fake_image_and_labels)
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(
            zip(grads, self.generator.trainable_weights)
        )

        # Update and report the running losses.
        self.gen_loss_tracker.update_state(g_loss)
        self.disc_loss_tracker.update_state(d_loss)
        return {
            "g_loss": self.gen_loss_tracker.result(),
            "d_loss": self.disc_loss_tracker.result(),
        }


In [68]:
class ConditionalGAN(keras.Model):
    """Conditional GAN: a generator/discriminator pair trained on labelled images.

    Both networks receive the one-hot class label as an extra input, so the
    generator can be asked for a specific class.

    Args:
        discriminator: model that scores images concatenated with label maps.
        generator: model that turns (noise + one-hot label) vectors into images.
        latent_dim: size of the random noise vector given to the generator.
    """

    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        # Running averages of each network's loss.
        self.gen_loss_tracker = keras.metrics.Mean(name="generator_loss")
        self.disc_loss_tracker = keras.metrics.Mean(name="discriminator_loss")

    @property
    def metrics(self):
        """Metrics tracked by this model.

        Declared as a property so it overrides `keras.Model.metrics`; Keras
        reads that property to know which metrics to reset between epochs.
        """
        return [self.gen_loss_tracker, self.disc_loss_tracker]

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        """Register the discriminator/generator optimizers and the loss."""
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, data):
        """Perform one discriminator update and one generator update.

        Args:
            data: `(real_images, one_hot_labels)` batch from the dataset.

        Returns:
            dict mapping `"g_loss"` / `"d_loss"` to the running mean losses.
        """
        # Unpack the data.
        real_images, one_hot_labels = data

        # Turn each one-hot label into an image-sized map so it can be
        # concatenated with the images for the discriminator:
        # (batch, classes) -> (batch, H, W, classes), channel k == class k.
        # tf.tile is used because tf.repeat without `axis` flattens its input,
        # and reshaping that result does not keep classes on the channel axis.
        image_one_hot_labels = tf.tile(
            one_hot_labels[:, None, None, :], [1, image_size, image_size, 1]
        )

        # Sample random points in the latent space and concatenate the labels.
        # This is for the generator.
        num_samples = tf.shape(real_images)[0]
        random_latent_vectors = tf.random.normal(
            shape=(num_samples, self.latent_dim)
        )
        random_vector_labels = tf.concat(
            [random_latent_vectors, one_hot_labels], axis=1
        )

        # Decode the noise (guided by labels) to fake images.
        generated_images = self.generator(random_vector_labels)

        # Combine them with real images, attaching the label maps to both.
        fake_image_and_labels = tf.concat(
            [generated_images, image_one_hot_labels], -1
        )
        real_image_and_labels = tf.concat([real_images, image_one_hot_labels], -1)
        combined_images = tf.concat(
            [fake_image_and_labels, real_image_and_labels], axis=0
        )

        # Discriminator targets, in the same order: fake -> 1, real -> 0.
        labels = tf.concat(
            [tf.ones((num_samples, 1)), tf.zeros((num_samples, 1))], axis=0
        )

        # Train the discriminator.
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # Sample new random points in the latent space for the generator step.
        random_latent_vectors = tf.random.normal(
            shape=(num_samples, self.latent_dim)
        )
        random_vector_labels = tf.concat(
            [random_latent_vectors, one_hot_labels], axis=1
        )

        # Assemble labels that say "all real images" (real == 0).
        misleading_labels = tf.zeros((num_samples, 1))

        # Train the generator. Gradients are taken only with respect to the
        # generator's weights, so the discriminator is not updated here.
        with tf.GradientTape() as tape:
            fake_images = self.generator(random_vector_labels)
            fake_image_and_labels = tf.concat(
                [fake_images, image_one_hot_labels], -1
            )
            predictions = self.discriminator(fake_image_and_labels)
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(
            zip(grads, self.generator.trainable_weights)
        )

        # Monitor loss.
        self.gen_loss_tracker.update_state(g_loss)
        self.disc_loss_tracker.update_state(d_loss)
        return {
            "g_loss": self.gen_loss_tracker.result(),
            "d_loss": self.disc_loss_tracker.result(),
        }

# CGAN 학습 

In [69]:
# Both networks are optimised with Adam at the same learning rate.
adam_learning_rate = 0.0003
disc_optimizer = keras.optimizers.Adam(learning_rate=adam_learning_rate)
gen_optimizer = keras.optimizers.Adam(learning_rate=adam_learning_rate)

# The discriminator's last layer has no activation, so the loss is told to
# expect raw logits.
logit_bce = keras.losses.BinaryCrossentropy(from_logits=True)

# Wire the two networks into the conditional GAN and configure training.
cond_gan = ConditionalGAN(
    discriminator=discriminator,
    generator=generator,
    latent_dim=latent_dim,
)
cond_gan.compile(
    d_optimizer=disc_optimizer,
    g_optimizer=gen_optimizer,
    loss_fn=logit_bce,
)

# Train for 20 passes over the shuffled, batched dataset.
cond_gan.fit(dataset, epochs=20)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<keras.callbacks.History at 0x25374f6e190>