<a href="https://colab.research.google.com/github/kimhwijin/TensorflowWithKeras/blob/master/CNN/DCGAN_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras import datasets, layers, models, optimizers

In [20]:
class DCGAN():
  """Deep Convolutional GAN for 28x28 images (trained on MNIST by ``train``).

  Builds three Keras models:
    * ``discriminator`` -- image -> sigmoid probability that the image is real.
    * ``generator`` -- latent noise vector -> image in [-1, 1] (tanh output).
    * ``combined`` -- ``discriminator(generator(z))`` with the discriminator
      frozen; training it updates only the generator.

  Args:
    rows: image height. The generator always upsamples 7x7 -> 28x28, so this
      must be 28 for the generator output to fit the discriminator input.
    cols: image width (same constraint as ``rows``).
    channels: number of image channels (``train`` feeds 1-channel MNIST).
    z: dimensionality of the latent noise vector.
  """

  def __init__(self, rows, cols, channels, z=10):
    # Input shape
    self.img_rows = rows
    self.img_cols = cols
    self.channels = channels
    self.img_shape = (self.img_rows, self.img_cols, self.channels)
    self.latent_dim = z

    # One optimizer per compiled model: tf.keras optimizers (TF >= 2.11) bind
    # to the variables they first update and reject a different variable set.
    d_optimizer = optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
    g_optimizer = optimizers.Adam(learning_rate=0.0002, beta_1=0.5)

    # Build and compile the discriminator.
    self.discriminator = self.build_discriminator()
    self.discriminator.compile(loss='binary_crossentropy', optimizer=d_optimizer, metrics=['accuracy'])

    # Build the generator.
    self.generator = self.build_generator()

    # The generator takes noise as input and generates images.
    z_input = tf.keras.Input(shape=(self.latent_dim,))
    img = self.generator(z_input)

    # Freeze the discriminator inside the combined model so that training the
    # generator does not also push the discriminator toward calling fakes real.
    # tf.keras respects the trainable state captured at compile time, so the
    # discriminator compiled above still updates its own weights in
    # ``self.discriminator.train_on_batch``.
    # NOTE(review): Keras 3 may not snapshot trainable state at compile time;
    # confirm the discriminator still trains if this is run on Keras 3.
    self.discriminator.trainable = False

    # The discriminator takes generated images as input and determines validity.
    valid = self.discriminator(img)

    # Combined generator -> discriminator model; training it teaches the
    # generator to fool the (frozen) discriminator.
    self.combined = models.Model(z_input, valid)
    self.combined.compile(loss='binary_crossentropy', optimizer=g_optimizer)

  def build_generator(self):
    """Build the generator: (latent_dim,) noise -> (28, 28, channels) image.

    A dense layer projects the noise to a 7x7x128 feature map. Two
    UpSampling2D + Conv2D stages bring it to 28x28, and a final tanh
    convolution maps it to ``self.channels`` channels in [-1, 1].
    """
    model = models.Sequential()

    model.add(layers.Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim))
    model.add(layers.Reshape((7, 7, 128)))
    model.add(layers.UpSampling2D())  # 7x7 -> 14x14
    model.add(layers.Conv2D(128, kernel_size=3, padding="same"))
    model.add(layers.BatchNormalization(momentum=0.8))
    model.add(layers.Activation("relu"))
    model.add(layers.UpSampling2D())  # 14x14 -> 28x28
    model.add(layers.Conv2D(64, kernel_size=3, padding="same"))
    model.add(layers.BatchNormalization(momentum=0.8))
    model.add(layers.Activation("relu"))
    model.add(layers.Conv2D(self.channels, kernel_size=3, padding="same"))
    model.add(layers.Activation("tanh"))  # matches the [-1, 1] scaling of real images

    model.summary()

    noise = tf.keras.Input(shape=(self.latent_dim,))
    img = model(noise)

    return models.Model(noise, img)

  def build_discriminator(self):
    """Build the discriminator: ``self.img_shape`` image -> P(real) via sigmoid.

    Stacked strided convolutions with LeakyReLU, dropout and batch
    normalization, followed by a single sigmoid unit.
    """
    model = models.Sequential()

    model.add(layers.Conv2D(32, kernel_size=3, strides=2, input_shape=self.img_shape, padding='same'))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Dropout(0.25))
    model.add(layers.Conv2D(64, kernel_size=3, strides=2, padding='same'))
    # Pad bottom/right by one pixel (7x7 -> 8x8 for 28x28 inputs).
    model.add(layers.ZeroPadding2D(padding=((0,1),(0,1))))
    model.add(layers.BatchNormalization(momentum=0.8))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Dropout(0.25))
    model.add(layers.Conv2D(128, kernel_size=3, strides=2, padding="same"))
    model.add(layers.BatchNormalization(momentum=0.8))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Dropout(0.25))
    model.add(layers.Conv2D(256, kernel_size=3, strides=1, padding="same"))
    model.add(layers.BatchNormalization(momentum=0.8))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Dropout(0.25))
    model.add(layers.Flatten())
    model.add(layers.Dense(1, activation='sigmoid'))

    model.summary()

    img = tf.keras.Input(shape=self.img_shape)
    validity = model(img)

    return models.Model(img, validity)

  def train(self, epochs, batch_size=256, save_interval=50):
    """Train the GAN on the MNIST training images.

    Despite its name, ``epochs`` counts training *iterations*. Each iteration
    draws one random batch (with replacement); it does not sweep the dataset.

    Args:
      epochs: number of training iterations.
      batch_size: number of real and of generated images per iteration.
      save_interval: save a sample grid every ``save_interval`` iterations,
        iteration 0 included.
    """
    # Load the MNIST training images; labels are not needed.
    (X_train, _), (_, _) = datasets.mnist.load_data()

    # Rescale pixel values from [0, 255] to [-1, 1] to match the generator's tanh.
    X_train = X_train / 127.5 - 1.
    # Add a trailing channel axis.
    X_train = np.expand_dims(X_train, axis=3)

    # Adversarial ground truths.
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for epoch in range(epochs):
      # --- Train the discriminator ---
      # Select a random batch of real images.
      idx = np.random.randint(0, X_train.shape[0], batch_size)
      imgs = X_train[idx]

      # Generate a batch of fake images from Gaussian noise
      # (verbose=0 silences the per-call progress bar).
      noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
      gen_imgs = self.generator.predict(noise, verbose=0)

      # Real images are labelled 1, generated images 0.
      d_loss_real = self.discriminator.train_on_batch(imgs, valid)
      d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
      d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

      # --- Train the generator ---
      # Ask the frozen discriminator to output 1 ("real") for generated images.
      g_loss = self.combined.train_on_batch(noise, valid)

      print('%d [D loss : %f, acc : %.2f%%] [G loss : %f]' %(epoch, d_loss[0], 100*d_loss[1], g_loss))

      if epoch % save_interval == 0:
        self.save_imgs(epoch)

  def save_imgs(self, epoch, out_dir='images'):
    """Save a 5x5 grid of generator samples to ``<out_dir>/dcgan_mnist_<epoch>.png``.

    Args:
      epoch: iteration number, used in the output file name.
      out_dir: directory to write into; created if it does not exist.
    """
    r, c = 5, 5
    noise = np.random.normal(0, 1, (r * c, self.latent_dim))
    gen_imgs = self.generator.predict(noise, verbose=0)

    # Rescale from the generator's [-1, 1] range to [0, 1] for display.
    gen_imgs = 0.5 * gen_imgs + 0.5

    fig, axs = plt.subplots(r, c)
    # axs.flat walks the grid row by row, matching sample order 0..r*c-1.
    for cnt, ax in enumerate(axs.flat):
      ax.imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
      ax.axis('off')

    # savefig does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, 'dcgan_mnist_%d.png' % epoch))

    # Close this specific figure so repeated calls do not leak open figures.
    plt.close(fig)

In [22]:
# Build a DCGAN for 28x28 single-channel (MNIST) images and run 500 training
# iterations, saving a grid of generated samples every 50 iterations.
dcgan = DCGAN(rows=28, cols=28, channels=1)
dcgan.train(epochs=500, batch_size=256, save_interval=50)

Model: "sequential_13"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_46 (Conv2D)           (None, 14, 14, 32)        320       
_________________________________________________________________
leaky_re_lu_28 (LeakyReLU)   (None, 14, 14, 32)        0         
_________________________________________________________________
dropout_28 (Dropout)         (None, 14, 14, 32)        0         
_________________________________________________________________
conv2d_47 (Conv2D)           (None, 7, 7, 64)          18496     
_________________________________________________________________
zero_padding2d_7 (ZeroPaddin (None, 8, 8, 64)          0         
_________________________________________________________________
batch_normalization_33 (Batc (None, 8, 8, 64)          256       
_________________________________________________________________
leaky_re_lu_29 (LeakyReLU)   (None, 8, 8, 64)        

In [None]:
# Mount Google Drive into the Colab runtime (prompts for authorization).
from google.colab import drive
drive.mount(mountpoint='/content/drive')