# 데이터셋 불러오기

In [1]:
from tensorflow.keras.datasets import mnist

# mnist.load_data() yields ((x_train, y_train), (x_test, y_test)).
# Only the training images are used to train the GAN; the labels and the
# test split are discarded.
train_images, _ = mnist.load_data()[0]

In [2]:
# Inspect the raw pixel dtype (uint8) before rescaling to float.
raw_dtype = train_images.dtype
raw_dtype

dtype('uint8')

In [3]:
# Rescale uint8 pixels from [0, 255] to [-1, 1]. This is min-max scaling,
# not z-score standardization. The new range matches the generator's tanh
# output range.
train_images = train_images.astype('float32')  # fresh float32 copy
train_images -= 127.5                          # center: [0, 255] -> [-127.5, 127.5]
train_images /= 127.5                          # scale:  -> [-1, 1]
print(train_images.dtype)

float32


In [4]:
import numpy as np

# Append a trailing channel axis so the images are channels-last,
# (N, 28, 28) -> (N, 28, 28, 1), as expected by an Input of shape (28, 28, 1).
train_images = train_images[..., np.newaxis]
print(train_images.shape)

(60000, 28, 28, 1)


# generator, discriminator 모델 정의

In [5]:
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential


def _generator_upsample(filters, strides):
    """Return a 5x5 'same'-padded ELU Conv2DTranspose followed by BatchNormalization."""
    return [
        layers.Conv2DTranspose(filters, (5, 5), strides=strides,
                               padding='same', activation='elu'),
        layers.BatchNormalization(),
    ]


# Generator: 100-dim noise vector -> 28x28x1 image, squashed to [-1, 1] by tanh.
# Spatial size: 7x7 -> 7x7 (stride 1) -> 14x14 (stride 2) -> 28x28 (stride 2).
generator_model = Sequential([
    layers.Input(shape=(100,)),
    # Project the noise, then reshape it into a 7x7 map with 256 channels.
    layers.Dense(7 * 7 * 256, use_bias=False, activation='elu'),
    layers.BatchNormalization(),
    layers.Reshape((7, 7, 256)),
    *_generator_upsample(128, strides=(1, 1)),
    *_generator_upsample(64, strides=(2, 2)),
    # Output layer: a single channel with tanh activation.
    layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same',
                           activation='tanh'),
])
generator_model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense (Dense)               (None, 12544)             1254400   
                                                                 
 batch_normalization (Batch  (None, 12544)             50176     
 Normalization)                                                  
                                                                 
 reshape (Reshape)           (None, 7, 7, 256)         0         
                                                                 
 conv2d_transpose (Conv2DTr  (None, 7, 7, 128)         819328    
 anspose)                                                        
                                                                 
 batch_normalization_1 (Bat  (None, 7, 7, 128)         512       
 chNormalization)                                                
                                                        

In [6]:
def _discriminator_downsample(filters):
    """Return a stride-2, 5x5 'same'-padded ELU Conv2D followed by 30% Dropout."""
    return [
        layers.Conv2D(filters, (5, 5), strides=(2, 2), padding='same',
                      activation='elu'),
        layers.Dropout(0.3),
    ]


# Discriminator: 28x28x1 image -> one unnormalized score (logit).
# The final Dense has no activation, so it must be paired with a loss that
# sets from_logits=True.
discriminator_model = Sequential([
    layers.Input(shape=(28, 28, 1)),
    *_discriminator_downsample(64),   # -> (14, 14, 64)
    *_discriminator_downsample(128),  # -> (7, 7, 128)
    layers.Flatten(),
    layers.Dense(1),
])
discriminator_model.summary()

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 14, 14, 64)        1664      
                                                                 
 dropout (Dropout)           (None, 14, 14, 64)        0         
                                                                 
 conv2d_1 (Conv2D)           (None, 7, 7, 128)         204928    
                                                                 
 dropout_1 (Dropout)         (None, 7, 7, 128)         0         
                                                                 
 flatten (Flatten)           (None, 6272)              0         
                                                                 
 dense_1 (Dense)             (None, 1)                 6273      
                                                                 
Total params: 212865 (831.50 KB)
Trainable params: 212

# 손실함수와 옵티마이저 정의

In [7]:
# The discriminator's last layer is Dense(1) with no activation, so it emits
# logits; from_logits=True applies the sigmoid inside the loss.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def generator_loss(fake_output):
    """Generator loss: push the discriminator to label generated images as real.

    Args:
        fake_output: discriminator logits for a batch of generated images.

    Returns:
        Binary cross-entropy of those logits against all-ones (real) targets.
    """
    return cross_entropy(tf.ones_like(fake_output), fake_output)


def discriminator_loss(real_output, fake_output):
    """Discriminator loss: real images should score 1, generated images 0.

    Args:
        real_output: discriminator logits for a batch of real images.
        fake_output: discriminator logits for a batch of generated images.

    Returns:
        Sum of the binary cross-entropy on the real batch (target 1) and on
        the generated batch (target 0).
    """
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    # Generated samples must be labelled 0 (fake). Targeting ones here would
    # train the discriminator to call everything real, so it could never
    # learn to separate real from generated images.
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return real_loss + fake_loss


# Separate Adam instances (default hyperparameters) so each network keeps
# its own optimizer state.
generator_optimizer = tf.keras.optimizers.Adam()
discriminator_optimizer = tf.keras.optimizers.Adam()

# 배치 데이터 준비

In [8]:
# Shuffle before batching so each epoch draws different batches in a
# different order. Without it, every epoch replays identical batches in
# identical order.
BUFFER_SIZE = len(train_images)  # buffer spans the whole set -> full shuffle
BATCH_SIZE = 256
train_dataset = (
    tf.data.Dataset.from_tensor_slices(train_images)
    .shuffle(BUFFER_SIZE)  # reshuffled on every iteration by default
    .batch(BATCH_SIZE)
)
print(train_dataset)

<_BatchDataset element_spec=TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name=None)>
