In [5]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, Flatten
from keras.models import Model
from keras.layers.advanced_activations import LeakyReLU

import random

In [3]:
mnist = pd.read_csv("../datasets/mnist_train_small.csv", header=None).values

In [4]:
X, y = mnist[:, 1:].reshape(-1, 28, 28, 1), mnist[:, 0]

In [9]:
in_disc = Input(shape=(28, 28, 1))
x = Conv2D(16, (3, 3))(in_disc)
x = LeakyReLU(alpha=0.2)(x)
x = MaxPooling2D()(x)
x = Conv2D(32, (3, 3))(x)
x = LeakyReLU(alpha=0.2)(x)
x = MaxPooling2D()(x)
x = Conv2D(32, (3, 3))(x)
x = LeakyReLU(alpha=0.2)(x)
x = MaxPooling2D()(x)
x = Flatten()(x)
x = Dense(10, activation="tanh")(x)
out = Dense(1, activation="sigmoid")(x)

In [10]:
disc = Model(in_disc, out)

In [12]:
in_gen = Input(shape=(7, 7, 1))
x = Conv2D(16, (3, 3), padding="same")(in_gen)
x = LeakyReLU(alpha=0.2)(x)
x = UpSampling2D()(x)
x = Conv2D(8, (3, 3), padding="same")(x)
x = LeakyReLU(alpha=0.2)(x)
x = UpSampling2D()(x)
x = Conv2D(4, (3, 3), padding="same")(x)
out = Conv2D(1, (3, 3), padding="same", activation="sigmoid")(x)

In [13]:
gen = Model(in_gen, out)

In [16]:
in_gan = Input(shape=(7, 7, 1))
layer_gen = gen(in_gan)
layer_disc = disc(layer_gen)

In [17]:
gan = Model(in_gan, layer_disc)

In [18]:
X_mod = X / 255

In [19]:
gan.compile(optimizer="adam", loss="binary_crossentropy")

In [20]:
disc.compile(optimizer="adam", loss="binary_crossentropy")

In [21]:
gen.compile(optimizer="adam", loss="binary_crossentropy")

In [None]:
for i in range(10):
    noise = np.random.rand(500, 50)

    pred = gen.predict(noise)
    real = X_mod[np.random.randint(0, len(X_mod), 500)]

    disc_input = np.vstack([pred, real])

    result = np.vstack([np.zeros([500, 1]), np.ones([500, 1])])

    disc.trainable = True
    disc.fit(disc_input, result, batch_size=100, epochs=5, verbose=0)

    disc.trainable = False
    gan.fit(noise, np.ones([500, 1]), batch_size=50, epochs=10, verbose=0)
    
    plt.figure()
    
    plt.imshow(pred[0].reshape(28, 28))

  'Discrepancy between trainable weights and collected trainable'


In [42]:
gan.summary()

Model: "model_6"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_6 (InputLayer)         (None, 50)                0         
_________________________________________________________________
model_5 (Model)              (None, 784)               420084    
_________________________________________________________________
model_4 (Model)              (None, 1)                 415321    
Total params: 1,250,726
Trainable params: 835,405
Non-trainable params: 415,321
_________________________________________________________________


  'Discrepancy between trainable weights and collected trainable'
