In [0]:
# example of defining and using the generator model
from math import ceil, sqrt

from numpy import zeros
from numpy.random import randn
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from matplotlib import pyplot

In [0]:
# define the standalone generator model
def define_generator(latent_dim):
  """Build the standalone (uncompiled) generator network.

  The latent vector is projected by a Dense layer into a 7x7x128 feature
  map. Two stride-2 transposed convolutions then upsample it 7 -> 14 -> 28,
  and a final 7x7 convolution reduces it to one channel with a sigmoid.

  Args:
    latent_dim: length of each input latent vector.

  Returns:
    A keras ``Sequential`` model mapping ``(batch, latent_dim)`` inputs to
    ``(batch, 28, 28, 1)`` outputs with values in (0, 1). It is not compiled.
  """
  generator = Sequential()
  # dense "foundation" that is reshaped into a 7x7 grid of 128 channels
  generator.add(Dense(7 * 7 * 128, input_dim=latent_dim))
  generator.add(LeakyReLU(alpha=0.2))
  generator.add(Reshape((7, 7, 128)))
  # two identical upsampling stages: 7x7 -> 14x14, then 14x14 -> 28x28
  for _stage in range(2):
    generator.add(Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"))
    generator.add(LeakyReLU(alpha=0.2))
  # collapse to a single channel; sigmoid keeps pixel values in (0, 1)
  generator.add(Conv2D(1, (7, 7), activation="sigmoid", padding="same"))
  return generator

In [0]:
def generate_latent_points(latent_dim, n_samples):
  """Sample a batch of points from a standard normal latent space.

  Draws from numpy's global (legacy) random state, so results are
  reproducible only if ``numpy.random.seed`` was called beforehand.

  Args:
    latent_dim: length of each latent vector.
    n_samples: number of latent vectors in the batch.

  Returns:
    A float array of shape ``(n_samples, latent_dim)``.
  """
  # randn fills the requested shape in C order, which yields exactly the
  # same values as drawing a flat vector and reshaping it row by row
  return randn(n_samples, latent_dim)

In [0]:
# use the generator to generate n fake examples, with class labels
def generate_fake_samples(g_model, latent_dim, n_samples):
  """Generate a batch of fake examples and label them as class 0.

  Args:
    g_model: model whose ``predict`` maps latent points to samples.
    latent_dim: length of each latent vector fed to ``g_model``.
    n_samples: number of examples to generate.

  Returns:
    A tuple ``(X, y)``. ``X`` is ``g_model``'s prediction for ``n_samples``
    random latent points. ``y`` is a ``(n_samples, 1)`` array of zeros,
    i.e. the "fake" class label.
  """
  latent_batch = generate_latent_points(latent_dim, n_samples)
  fake_images = g_model.predict(latent_batch)
  return fake_images, zeros((n_samples, 1))

In [0]:
# size of the latent space
latent_dim = 100

# define the (untrained) generator model
model = define_generator(latent_dim)

# generate samples; the "fake" labels are not needed here
# NOTE(review): no random seed is set in this notebook, so both the model's
# initial weights and the sampled latent points differ on every run
n_samples = 25
X, _ = generate_fake_samples(model, latent_dim, n_samples)

# sanity check: the generator emits one single-channel 28x28 image per sample
assert X.shape == (n_samples, 28, 28, 1), X.shape

In [0]:
# plot the generated samples on the smallest square grid that holds them
# (5x5 for n_samples = 25). A fixed 5x5 grid would fail for n_samples > 25.
n_side = ceil(sqrt(n_samples))
for i in range(n_samples):
  # define subplot (subplot positions are 1-based)
  pyplot.subplot(n_side, n_side, 1 + i)
  # hide axis lines, ticks and labels
  pyplot.axis("off")
  # plot single image (channel 0); gray_r draws larger values darker
  pyplot.imshow(X[i, :, :, 0], cmap="gray_r")

# show the figure
pyplot.show()