In [10]:
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout
from tensorflow.keras.layers import BatchNormalization, Activation, LeakyReLU
from tensorflow.keras.layers import UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model
import numpy as np
import matplotlib.pyplot as plt

In [11]:
import os
if not os.path.exists('./gan_images'):
    os.makedirs('./gan_images')

In [13]:
generator = Sequential()
generator.add(Dense(128*7*7, input_dim=100, activation=LeakyReLU(0.2)))
generator.add(BatchNormalization()) #정규화
generator.add(Reshape((7, 7, 128)))
generator.add(UpSampling2D()) #가로세로를 2배씩 확장해주는것(MAXPOOLING 이랑 반대)
generator.add(Conv2D(64, kernel_size=5, padding='same')) #가상의 패딩을 줘서 사이즈가 똑같이 나오게
generator.add(BatchNormalization())
generator.add(Activation(LeakyReLU(0.2)))
generator.add(UpSampling2D())
generator.add(Conv2D(1, kernel_size=5, padding='same', activation='tanh'))

In [14]:
generator.summary()

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_1 (Dense)              (None, 6272)              633472    
_________________________________________________________________
batch_normalization_2 (Batch (None, 6272)              25088     
_________________________________________________________________
reshape_1 (Reshape)          (None, 7, 7, 128)         0         
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 14, 14, 64)        204864    
_________________________________________________________________
batch_normalization_3 (Batch (None, 14, 14, 64)        256       
_________________________________________________________________
activation_1 (Activation)    (None, 14, 14, 64)       

In [15]:
discriminator = Sequential()
discriminator.add(Conv2D(64, kernel_size=5, strides=2, input_shape=(28, 28, 1), padding='same'))
discriminator.add(Activation(LeakyReLU(0.2)))
discriminator.add(Dropout(0.3))
discriminator.add(Conv2D(128, kernel_size=5, strides=2, padding='same'))
discriminator.add(Activation(LeakyReLU(0.2)))
discriminator.add(Dropout(0.3))
discriminator.add(Flatten())
discriminator.add(Dense(1, activation='sigmoid'))

In [16]:
discriminator.compile(loss='binary_crossentropy', optimizer='adam')
discriminator.trainable = False #gan핵심!! 판별할때마다 기준이 바뀌면 곤란하므로 학습하지 않음

In [17]:
discriminator.summary()

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_3 (Conv2D)            (None, 14, 14, 64)        1664      
_________________________________________________________________
activation_2 (Activation)    (None, 14, 14, 64)        0         
_________________________________________________________________
dropout (Dropout)            (None, 14, 14, 64)        0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 7, 7, 128)         204928    
_________________________________________________________________
activation_3 (Activation)    (None, 7, 7, 128)         0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 7, 7, 128)         0         
_________________________________________________________________
flatten (Flatten)            (None, 6272)             

In [18]:
ginput = Input(shape=(100,))
dis_output = discriminator(generator(ginput))
gan = Model(ginput, dis_output) #ginput이 입력, dis_output이 출력
gan.compile(loss='binary_crossentropy', optimizer='adam')
gan.summary()

Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 100)]             0         
_________________________________________________________________
sequential_1 (Sequential)    (None, 28, 28, 1)         865281    
_________________________________________________________________
sequential_2 (Sequential)    (None, 1)                 212865    
Total params: 1,078,146
Trainable params: 852,609
Non-trainable params: 225,537
_________________________________________________________________


In [21]:
def gan_train(epoch, batch_size, saving_interval):
    (X_train, _), (_, _) = mnist.load_data()
    X_train = X_train.reshape(X_train.shape[0], 28, 28, 1).astype('float')
    X_train = (X_train - 127.5) / 127.5 #전체 픽셀데이터를 -1에서 1사이로 맞춤
    
    true = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))
    
    for i in range(epoch):
        idx = np.random.randint(0, X_train.shape[0], batch_size) #0에서 x_train.shape[0]까지 배치사이즈만큼 인덱스를 뽑아와서
        imgs = X_train[idx]
        d_loss_real = discriminator.train_on_batch(imgs, true) #실제 mnist이미지를 가지고 한거
        
        noise = np.random.normal(0, 1, (batch_size, 100))
        gen_imgs = generator.predict(noise)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
        
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
        g_loss = gan.train_on_batch(noise, true) #generator 입장에서의 loss
        
        print('epoch:%d' % i, 'd_loss:%.4f' % d_loss, 'g_loss:%.4f' % g_loss)
        
        if i % saving_interval == 0:
            noise = np.random.normal(0, 1, (25, 100))
            gen_imgs = generator.predict(noise)
            
            gen_imgs = 0.5 * gen_imgs + 0.5 # image를 0 ~ 1 사이로 나오게 rescale
            
            fig, axs = plt.subplots(5,5)
            count = 0
            for j in range(5):
                for k in range(5):
                    axs[j, k].imshow(gen_imgs[count, :, :, 0], cmap='gray')
                    axs[j, k].axis('off')
                    count += 1
            fig.savefig('gan_images/gan_mnist_%d.png' % i)


In [22]:
gan_train(4001, 32, 200)

epoch:0 d_loss:0.4157 g_loss:0.1824
epoch:1 d_loss:0.3847 g_loss:0.0236
epoch:2 d_loss:0.4015 g_loss:0.0058
epoch:3 d_loss:0.4217 g_loss:0.0039
epoch:4 d_loss:0.4385 g_loss:0.0097
epoch:5 d_loss:0.3476 g_loss:0.0638
epoch:6 d_loss:0.4076 g_loss:0.1584
epoch:7 d_loss:0.5641 g_loss:0.0943
epoch:8 d_loss:0.5938 g_loss:0.1209
epoch:9 d_loss:0.5536 g_loss:0.3142
epoch:10 d_loss:0.4543 g_loss:0.7974
epoch:11 d_loss:0.4250 g_loss:1.0659
epoch:12 d_loss:0.3488 g_loss:1.0352
epoch:13 d_loss:0.2719 g_loss:1.0529
epoch:14 d_loss:0.1781 g_loss:1.6356
epoch:15 d_loss:0.1102 g_loss:2.3455
epoch:16 d_loss:0.1043 g_loss:2.7151
epoch:17 d_loss:0.1163 g_loss:2.4296
epoch:18 d_loss:0.0948 g_loss:2.1150
epoch:19 d_loss:0.2392 g_loss:2.8236
epoch:20 d_loss:0.3359 g_loss:3.2631
epoch:21 d_loss:0.4599 g_loss:3.2626
epoch:22 d_loss:0.8498 g_loss:2.7804
epoch:23 d_loss:0.9605 g_loss:1.9332
epoch:24 d_loss:0.5969 g_loss:2.1694
epoch:25 d_loss:0.7963 g_loss:2.4410
epoch:26 d_loss:0.6080 g_loss:2.1704
epoch:27 d_

  fig, axs = plt.subplots(5,5)


epoch:3801 d_loss:0.5251 g_loss:1.4907
epoch:3802 d_loss:0.5162 g_loss:1.5358
epoch:3803 d_loss:0.4874 g_loss:1.5930
epoch:3804 d_loss:0.6497 g_loss:1.3648
epoch:3805 d_loss:0.6041 g_loss:1.6907
epoch:3806 d_loss:0.3741 g_loss:1.9557
epoch:3807 d_loss:0.5077 g_loss:1.7072
epoch:3808 d_loss:0.4365 g_loss:1.4326
epoch:3809 d_loss:0.4627 g_loss:1.3649
epoch:3810 d_loss:0.4677 g_loss:1.4132
epoch:3811 d_loss:0.4331 g_loss:1.6378
epoch:3812 d_loss:0.5523 g_loss:1.9279
epoch:3813 d_loss:0.4565 g_loss:1.8926
epoch:3814 d_loss:0.5401 g_loss:2.0428
epoch:3815 d_loss:0.3064 g_loss:1.9661
epoch:3816 d_loss:0.3500 g_loss:2.0125
epoch:3817 d_loss:0.4895 g_loss:1.5005
epoch:3818 d_loss:0.4120 g_loss:1.7462
epoch:3819 d_loss:0.5079 g_loss:1.8638
epoch:3820 d_loss:0.5733 g_loss:1.6179
epoch:3821 d_loss:0.3611 g_loss:1.6588
epoch:3822 d_loss:0.4352 g_loss:1.8533
epoch:3823 d_loss:0.4207 g_loss:2.0905
epoch:3824 d_loss:0.4534 g_loss:1.4575
epoch:3825 d_loss:0.6233 g_loss:1.4748
epoch:3826 d_loss:0.5285 