In [None]:
# 생성자 모델 : 가상의 이미지를 생성하는 모델
                
# 판별자 모델 : 이미지가 진짜인지 가짜인지 판별하는 모델..

# gan 모델 : 생성자와 판별자가 결합된 모델..
          #  생성자가 생성한 이미지를 판별자에게 판별하도록 하는 모델
          #  생성자가 생성한 이미지가 True 라벨로 학습 => 생성자가 진짜이미지와 같은 이미지를 생성하도록 학습..

In [1]:
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout
from tensorflow.keras.layers import BatchNormalization, Activation, LeakyReLU, UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model

In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

In [3]:
import os 
if not os.path.exists("./gan_images"):
  os.mkdir("./gan_images")

In [None]:
# 생성자

generator = Sequential()
generator.add(Dense(128*7*7, input_dim=100, activation=LeakyReLU(0.2)))
generator.add(BatchNormalization())
generator.add(Reshape((7,7,128)))
generator.add(UpSampling2D())
generator.add(Conv2D(64, kernel_size=5, padding='same'))
generator.add(BatchNormalization())
generator.add(Activation(LeakyReLU(0.2)))
generator.add(UpSampling2D())
generator.add(Conv2D(1, kernel_size=5, padding='same', activation='tanh'))  # 28*28*1
generator.summary()

In [None]:
# 판별자

discriminator = Sequential()
discriminator.add(Conv2D(64, kernel_size=5, strides=2, input_shape=(28,28,1), padding='same'))
discriminator.add(Activation(LeakyReLU(0.2)))
discriminator.add(Dropout(0.3))
discriminator.add(Conv2D(128, kernel_size=5, strides=2, padding='same'))
discriminator.add(Activation(LeakyReLU(0.2)))
discriminator.add(Dropout(0.3))
discriminator.add(Flatten())
discriminator.add(Dense(1, activation='sigmoid'))
discriminator.summary()
discriminator.compile(loss='binary_crossentropy', optimizer='adam')

In [12]:
ginput = Input(shape=(100,))
dis_output = discriminator(generator(ginput))
gan = Model(ginput, dis_output)

In [None]:
gan.summary()

In [19]:
gan.compile(loss='binary_crossentropy', optimizer='adam')

In [20]:
def gan_train(epoch, batch_size, saving_interval):

  (x_train, _), (_, _) = mnist.load_data()

  x_train = x_train.reshape(x_train.shape[0], 28, 28, 1).astype('float32')
  x_train = (x_train - 127.5)/ 127.5

  true = np.ones((batch_size, 1))
  fake = np.zeros((batch_size, 1))

  for i in range(epoch):

    idx = np.random.randint(0, x_train.shape[0], batch_size)
    imgs = x_train[idx]
    d_loss_real = discriminator.train_on_batch(imgs, true)

    noise = np.random.normal(0, 1, (batch_size, 100))
    gen_imgs = generator.predict(noise)
    d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)

    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
    
    g_loss = gan.train_on_batch(noise, true)

    print('epoch :', i, 'd_loss :', d_loss, 'g_loss :', g_loss)

    if i % saving_interval == 0:
      noise = np.random.normal(0,1,(25,100))
      gen_imgs = generator.predict(noise)

      gen_imgs = 0.5* gen_imgs + 0.5

      fig, axs = plt.subplots(5,5)
      count = 0

      for i in range(5):
        for k in range(5):
          axs[i, k].imshow(gen_imgs[count, :, :, 0], cmap='gray')
          axs[i, k].axis('off')
          count += 1

      fig.savefig('gan_images/gan_mnist_%d.png'%i)

In [None]:
gan_train(4001, 32, 200)

epoch : 0 d_loss : 0.45196833461523056 g_loss : 0.3055127263069153
epoch : 1 d_loss : 0.5142690779175609 g_loss : 0.01482410915195942
epoch : 2 d_loss : 0.5793946535013674 g_loss : 0.0010411683470010757
epoch : 3 d_loss : 0.6295207321500129 g_loss : 0.0005792236188426614
epoch : 4 d_loss : 0.5845661186122015 g_loss : 0.0015240119537338614
epoch : 5 d_loss : 0.5344323874014663 g_loss : 0.006040321663022041
epoch : 6 d_loss : 0.47872307407669723 g_loss : 0.024379204958677292
epoch : 7 d_loss : 0.4556487174704671 g_loss : 0.08127699792385101
epoch : 8 d_loss : 0.45551938004791737 g_loss : 0.17867746949195862
epoch : 9 d_loss : 0.44498256500810385 g_loss : 0.28093814849853516
epoch : 10 d_loss : 0.44790707901120186 g_loss : 0.40272897481918335
epoch : 11 d_loss : 0.4497331604361534 g_loss : 0.4608496129512787
epoch : 12 d_loss : 0.465970441699028 g_loss : 0.49562737345695496
epoch : 13 d_loss : 0.4657941460609436 g_loss : 0.5925681591033936
epoch : 14 d_loss : 0.45344604179263115 g_loss : 