In [3]:
import sys

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, BatchNormalization, LeakyReLU
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import ExponentialDecay, InverseTimeDecay
from tensorflow.keras.utils import plot_model

In [5]:
class Generator:
    """Fully connected GAN generator mapping a latent vector to an image tensor.

    The Keras model is built by ``model()`` and compiled in the constructor
    with binary cross-entropy and an Adam optimizer whose learning rate decays
    with the training step.

    Attributes:
        W, H, C: Output width, height and channel count.
        OPTIMIZER: Adam optimizer used to compile the generator.
        LATENT_SPACE_SIZE: Length of the latent input vector.
        latent_space: One sample of shape (LATENT_SPACE_SIZE,) drawn from a
            standard normal distribution N(0, 1).
        Generator: The compiled Keras ``Sequential`` model.
    """

    def __init__(self, width = 28, height = 28, channels = 1, latent_size = 100, initial_lr = 0.0002,
                 lr_decay = 8e-9):
        """Build and compile the generator.

        Args:
            width: Output image width.
            height: Output image height.
            channels: Output image channel count.
            latent_size: Dimension of the latent (noise) input vector.
            initial_lr: Initial Adam learning rate.
            lr_decay: Inverse-time decay coefficient; the learning rate at
                step ``t`` is ``initial_lr / (1 + lr_decay * t)``.
        """
        self.W = width
        self.H = height
        self.C = channels
        # Inverse-time decay reproduces the legacy Keras ``Adam(decay=...)``
        # behaviour that the 8e-9 coefficient comes from. ExponentialDecay with
        # decay_rate=8e-9 would instead multiply the learning rate by 8e-9 every
        # ``decay_steps`` steps, which effectively stops training.
        self.OPTIMIZER = Adam(learning_rate = InverseTimeDecay(
                        initial_learning_rate = initial_lr,
                        decay_steps = 1,
                        decay_rate = lr_decay))

        self.LATENT_SPACE_SIZE = latent_size
        # A single draw from the standard normal distribution N(0, 1).
        self.latent_space = np.random.normal(0, 1, size = self.LATENT_SPACE_SIZE)

        self.Generator = self.model()
        self.Generator.compile(loss = 'binary_crossentropy', optimizer = self.OPTIMIZER)

    def model(self, block_starting_size = 128, num_blocks = 4):
        """Return an (uncompiled) ``Sequential`` generator network.

        The network is ``num_blocks`` blocks of Dense -> LeakyReLU(0.2) ->
        BatchNormalization, with the Dense width starting at
        ``block_starting_size`` and doubling each block. A final tanh Dense
        layer emits ``W * H * C`` values, which are reshaped into an image.

        Args:
            block_starting_size: Units in the first Dense layer.
            num_blocks: Number of Dense blocks. At least one block is always
                added, even if this is less than 1.

        Returns:
            A ``tf.keras.Sequential`` model taking inputs of shape
            (batch, LATENT_SPACE_SIZE) and producing (batch, W, H, C).
        """
        model = Sequential()
        # Explicit Input layer instead of the deprecated ``input_shape`` kwarg.
        model.add(Input(shape = (self.LATENT_SPACE_SIZE,)))

        block_size = block_starting_size
        model.add(Dense(block_size))
        # Positional slope: named ``alpha`` in tf.keras 2, ``negative_slope`` in Keras 3.
        model.add(LeakyReLU(0.2))
        model.add(BatchNormalization(momentum = 0.8))

        for _ in range(num_blocks - 1):
            block_size *= 2
            model.add(Dense(block_size))
            model.add(LeakyReLU(0.2))
            model.add(BatchNormalization(momentum = 0.8))

        # Flat vector with one tanh-squashed value per output pixel/channel.
        model.add(Dense(self.W * self.H * self.C, activation = 'tanh'))
        # Convert the vector into image form.
        # NOTE(review): Keras image tensors are conventionally (H, W, C); this
        # uses (W, H, C). The two are identical for square images. Confirm the
        # discriminator uses the same order before changing it.
        model.add(Reshape((self.W, self.H, self.C)))
        return model

    def summary(self):
        """Print the Keras model summary; returns what Keras' ``summary()`` returns."""
        return self.Generator.summary()

    def save_model(self, to_file = '/Users/gimhyeongeun/Desktop/세미나/Machine Learning/GAN/Generator_Model.png',
                   show_shapes = False):
        """Render the generator architecture to an image file.

        Args:
            to_file: Output image path. The default is the original author's
                absolute local path; pass a path for your own machine.
            show_shapes: Whether to annotate layers with their tensor shapes.
        """
        plot_model(self.Generator, to_file = to_file, show_shapes = show_shapes)