In [1]:
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt

import keras
from keras.models import Model, Sequential
from keras.layers import Dense, Flatten, Activation, Dropout, Input
from keras.layers import LeakyReLU, Reshape, concatenate
from keras.layers.convolutional import Conv2D, MaxPooling2D, UpSampling2D, ZeroPadding2D
from keras.layers.normalization import BatchNormalization
from keras.layers.embeddings import Embedding
from keras import losses, optimizers

from keras.datasets import mnist

import keras.backend as K

In [None]:
class BIGAN:
    def __init__(self):
        img_rows = 28
        img_cols = 28
        img_channels = 1
        self.img_shape = (img_row, img_cols, img_channels)
        self.latet_dim = 100
        
        # optimizer
        optimizer = optimizers.Adam(lr = 0.0002, beta_1 = 0.5, beta_2 = 0.999)
        
        # build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss = losses.binary_crossentropy, optimizer = optimizer, metrics = ['accuracy'])
        
        self.generator = self.build_generator()
        self.encoder = self.build_encoder()
        
        # The part of the bigan that trains the discriminator and encoder
        self.discriminator.trainable = False
        
        # Generate image from sampled noise
        z = Input(shape = (self.latet_dim, ))
        img_ = self.generator(z)
        
        # encode image
        img = Input(shape = self.img_shape)
        z_ = self.encoder(img)
        
        # Latent -> img is fake, and img -> latent is valid
        fake = self.discriminator([z, img_])
        valid = self.discriminator([z_, img])
        
        self.bigan_generator = Model(inputs = [z, img], outputs = [fake, valid])
        self.bigan_generator.compile(loss = [losses.binary_crossentropy, losses.binary_crossentropy], optimizer = optimizer)
        
    def build_encoder(self):
        model = Sequential()

        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(self.latent_dim))

        model.summary()

        img = Input(shape=self.img_shape)
        z = model(img)

        return Model(img, z)

    def build_generator(self):
        model = Sequential()

        model.add(Dense(512, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        model.summary()

        z = Input(shape=(self.latent_dim,))
        gen_img = model(z)

        return Model(z, gen_img)

    def build_discriminator(self):

        z = Input(shape=(self.latent_dim, ))
        img = Input(shape=self.img_shape)
        d_in = concatenate([z, Flatten()(img)])

        model = Dense(1024)(d_in)
        model = LeakyReLU(alpha=0.2)(model)
        model = Dropout(0.5)(model)
        model = Dense(1024)(model)
        model = LeakyReLU(alpha=0.2)(model)
        model = Dropout(0.5)(model)
        model = Dense(1024)(model)
        model = LeakyReLU(alpha=0.2)(model)
        model = Dropout(0.5)(model)
        validity = Dense(1, activation="sigmoid")(model)

        return Model([z, img], validity)

    def train(self, epochs, batch_size=128, sample_interval=50):

        # Load the dataset
        (X_train, _), (_, _) = mnist.load_data()

        # Rescale -1 to 1
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        X_train = np.expand_dims(X_train, axis=3)
        
        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):


            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Sample noise and generate img
            z = np.random.normal(size=(batch_size, self.latent_dim))
            imgs_ = self.generator.predict(z)

            # Select a random batch of images and encode
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]
            z_ = self.encoder.predict(imgs)
            
            