# Worksheet 26

Name:  
UID: 

### Topics

- Generative Adversarial Networks

## Generative Adversarial Networks



In [None]:
import numpy as np
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers import Dense
from keras.layers.activation import ReLU
from PIL import Image as im

TEMPFILE = 'temp.png'  # scratch image: each epoch's plot is written here and read back

# Define the parameters
np.random.seed(0)      # seeds NumPy's global RNG (real data and latent noise are drawn from it)
gen_input_dim = 100    # length of the noise vector fed to the generator
epochs = 100           # number of training iterations; each runs one batch
batch_size = 128       # number of generated samples per iteration
images = []            # one PIL frame per epoch, assembled into a GIF at the end

# Define the generator model: noise vector -> 2-D point
generator = Sequential()
generator.add(Dense(32, input_dim=gen_input_dim, activation='tanh'))
generator.add(Dense(2))

# Define the discriminator model: 2-D point -> sigmoid score
# (trained below with target 1 for real points and 0 for generated ones)
discriminator = Sequential()
discriminator.add(Dense(16, input_dim=2))
discriminator.add(ReLU())
discriminator.add(Dense(1, activation='sigmoid'))

# Compile the models
# The discriminator is compiled while still trainable, so its own
# train_on_batch() calls keep updating its weights.
generator.compile(loss='mse')
discriminator.compile(loss='binary_crossentropy')

# Freeze the discriminator BEFORE the stacked model is compiled.
# Keras keeps the `trainable` values that were in effect when compile() was
# called until compile() is called again, so flipping the flag later (e.g.
# inside the training loop) does not change which weights an already-compiled
# model updates. Without this line, gan.train_on_batch() would also update the
# discriminator, pushing it to label generated points as real.
discriminator.trainable = False

# Define the GAN model: generator followed by the (frozen) discriminator.
# Training it adjusts only the generator's weights.
gan = Sequential()
gan.add(generator)
gan.add(discriminator)
gan.compile(loss='binary_crossentropy')

# Define the real data: 1000 points from a correlated 2-D Gaussian
# (zero mean, unit variances, covariance 0.5)
x_real = np.random.multivariate_normal([0, 0], [[1, 0.5], [0.5, 1]], 1000)

# Train the GAN: each epoch performs one discriminator update on real and
# generated points, one generator update through the stacked model, and then
# renders a scatter-plot frame for the animation.

# Target arrays are constant, so build them once up front.
real_targets = np.ones((len(x_real), 1))       # 1 = "real", paired with x_real
fake_targets = np.zeros((batch_size, 1))       # 0 = "generated", paired with x_fake
generator_targets = np.ones((batch_size, 1))   # generator is rewarded when fakes score as real

for epoch in range(epochs):
    # Draw latent noise and push it through the current generator
    z = np.random.normal(size=(batch_size, gen_input_dim))
    x_fake = generator.predict(z)

    # Discriminator step: the full real set, then this epoch's generated batch
    # NOTE(review): Keras documents that a compiled model keeps the `trainable`
    # values it had at compile() time; these per-step toggles may therefore not
    # affect the already-compiled models. Confirm the freeze is applied before
    # `gan` is compiled.
    discriminator.trainable = True
    discriminator.train_on_batch(x_real, real_targets)
    discriminator.train_on_batch(x_fake, fake_targets)

    # Generator step: same noise, trained through the stacked model
    discriminator.trainable = False
    gan.train_on_batch(z, generator_targets)

    # Render this epoch's frame: real points in red, generated points in blue
    fig, ax = plt.subplots()
    for points, colour in ((x_real, 'r'), (x_fake, 'b')):
        ax.scatter(points[:, 0], points[:, 1], c=colour)
    ax.set_title('Epoch {}'.format(epoch))
    fig.savefig(TEMPFILE)
    plt.close()

    # Round-trip through an array so the stored frame does not depend on
    # TEMPFILE, which is overwritten on the next iteration.
    frame_pixels = np.asarray(im.open(TEMPFILE))
    images.append(im.fromarray(frame_pixels))

# Assemble the per-epoch frames into a single animated GIF.
# The first frame drives the save; the remaining frames are appended to it.
gif_options = dict(
    save_all=True,             # write every frame, not just the first
    append_images=images[1:],  # frames following the first one
    duration=10,               # per-frame display time (milliseconds per Pillow's GIF writer)
    loop=0,                    # 0 = repeat forever
    optimize=False,
)
images[0].save('gan.gif', **gif_options)