# In [None]:
import numpy as np

from keras.datasets import mnist
from keras.layers import (BatchNormalization, Dense, Flatten, Input,
                          LeakyReLU, Reshape)
from keras.models import Sequential, Model
from keras.optimizers import Adam

class GAN():
    """A simple fully-connected generative adversarial network built with Keras.

    Three models are built and compiled in ``__init__``:

    * ``self.discriminator`` -- maps an image of shape ``self.img_shape`` to a
      single sigmoid score; compiled with binary cross-entropy and accuracy.
    * ``self.generator`` -- maps a noise vector of length ``self.latent_dim``
      to an image of shape ``self.img_shape`` (``tanh`` output layer).
    * ``self.combined`` -- noise -> generator -> discriminator, with the
      discriminator frozen, compiled with binary cross-entropy.

    The defaults reproduce the original MNIST-sized configuration
    (28x28x1 images, 100-dimensional noise).
    """

    def __init__(self, img_rows=28, img_cols=28, channels=1, latent_dim=100):
        """Build and compile the discriminator, generator and combined model.

        Args:
            img_rows: Image height in pixels.
            img_cols: Image width in pixels.
            channels: Number of image channels.
            latent_dim: Length of the noise vector fed to the generator.
        """
        self.img_rows = img_rows
        self.img_cols = img_cols
        self.channels = channels
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = latent_dim

        # Discriminator: compiled while still trainable, so its own
        # fit/train_on_batch calls update its weights.
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=self._make_optimizer(),
                                   metrics=['accuracy'])

        # Generator: intended to be trained through ``self.combined``; this
        # standalone compile is kept so callers relying on a compiled
        # generator keep working.
        self.generator = self.build_generator()
        self.generator.compile(loss='binary_crossentropy',
                               optimizer=self._make_optimizer())

        # Combined model input: a batch of noise vectors.
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        # Freeze the discriminator inside the combined model so that training
        # ``self.combined`` only updates the generator. This relies on Keras
        # capturing the ``trainable`` flag at compile time (Keras 2 behaviour),
        # leaving the discriminator compiled above still trainable.
        # NOTE(review): confirm this holds for the Keras version in use.
        self.discriminator.trainable = False

        valid = self.discriminator(img)

        self.combined = Model(z, valid)
        self.combined.compile(loss='binary_crossentropy',
                              optimizer=self._make_optimizer())

    @staticmethod
    def _make_optimizer():
        """Return a fresh Adam optimizer (learning rate 0.0002, beta_1 0.5).

        Every compiled model gets its own instance. Newer Keras optimizers
        (TF >= 2.11 / Keras 3) bind their state to the first set of variables
        they are built with, so one instance shared between the discriminator
        and the combined model fails at train time; legacy optimizers would
        otherwise share their ``iterations`` counter across both models.
        """
        return Adam(0.0002, 0.5)

    def build_generator(self):
        """Build the generator network.

        Three Dense(256/512/1024) -> LeakyReLU(0.2) -> BatchNormalization(0.8)
        stages, then a ``tanh`` Dense layer with ``prod(self.img_shape)`` units
        reshaped to ``self.img_shape``. Prints the Sequential model summary.

        Returns:
            A functional ``Model`` taking noise of shape ``(self.latent_dim,)``
            and returning an image of shape ``self.img_shape``.
        """
        noise_shape = (self.latent_dim,)

        model = Sequential()

        model.add(Dense(256, input_shape=noise_shape))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        # One output unit per pixel/channel; int() avoids passing a NumPy
        # integer as the layer width.
        model.add(Dense(int(np.prod(self.img_shape)), activation='tanh'))
        model.add(Reshape(self.img_shape))

        model.summary()

        noise = Input(shape=noise_shape)
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):
        """Build the discriminator network.

        Flattens the image, then Dense(512) -> LeakyReLU(0.2) ->
        Dense(256) -> LeakyReLU(0.2) -> Dense(1, sigmoid). Prints the
        Sequential model summary.

        Returns:
            A functional ``Model`` taking an image of shape ``self.img_shape``
            and returning a single sigmoid score per image.
        """
        img_shape = self.img_shape

        model = Sequential()

        model.add(Flatten(input_shape=img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=img_shape)
        validity = model(img)

        return Model(img, validity)
    
    