This repository contains a simple implementation of a generative adversarial network (GAN) using TensorFlow 2.4.1 on the MNIST handwritten digit data. The GAN consists of two neural networks in competition with each other: one (the generator) tries to generate samples that are as close as possible to the real samples provided, and the other (the discriminator) tries to distinguish the real samples from the generated samples as well as possible.
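The competition described above is usually expressed as a pair of binary cross-entropy losses. The sketch below is a minimal, framework-free illustration of those losses. It assumes the discriminator outputs the probability that its input is real and uses the common non-saturating form of the generator loss. It is not necessarily what main.py computes; in TensorFlow 2 these terms are typically computed with tf.keras.losses.BinaryCrossentropy.

```python
import math


def discriminator_loss(d_real, d_fake):
    """Mean binary cross-entropy for the discriminator.

    d_real: probabilities the discriminator assigns to real samples being real.
    d_fake: probabilities it assigns to generated samples being real.
    The discriminator wants d_real -> 1 and d_fake -> 0.
    """
    real_term = sum(-math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(-math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term


def generator_loss(d_fake):
    """Non-saturating generator loss: the generator wants d_fake -> 1."""
    return sum(-math.log(p) for p in d_fake) / len(d_fake)
```

A discriminator that cannot tell the two apart (outputting 0.5 everywhere) has a loss of 2 log 2, about 1.386; the loss falls as it separates real from fake, while the generator's loss falls as its samples are scored closer to 1.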
Simply place train.csv in the Python working directory and run main.py. By default, the network trains for 75,000 steps, so training may take a while.
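If train.csv follows the Kaggle Digit Recognizer layout (an assumption: a header row, then one row per image containing the label followed by 784 pixel values from 0 to 255), it can be parsed with the standard library alone. The function name load_mnist_csv is hypothetical, and the rescaling to [-1, 1] assumes the generator ends in a tanh activation; main.py may preprocess the data differently.

```python
import csv


def load_mnist_csv(f):
    """Hypothetical loader for an MNIST CSV in the assumed Kaggle layout.

    Returns (labels, images), where each image is a flat list of 784 floats
    rescaled from 0..255 to [-1, 1].
    """
    reader = csv.reader(f)
    next(reader)  # skip the header row ("label,pixel0,...,pixel783")
    labels, images = [], []
    for row in reader:
        if len(row) != 785:
            raise ValueError(f"expected 785 columns, got {len(row)}")
        labels.append(int(row[0]))
        # Map 0 -> -1.0 and 255 -> 1.0, matching a tanh output range.
        images.append([int(v) / 127.5 - 1.0 for v in row[1:]])
    return labels, images
```

Typical use would be `with open("train.csv", newline="") as f: labels, images = load_mnist_csv(f)`.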
Here are 16 real digits drawn at random from the dataset.
For comparison, here are 16 digits generated by the neural network.
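A 4x4 panel like these can be produced by sampling 16 images and tiling them into a single 112x112 array. The helper below is hypothetical and assumes each image is a flat list of 784 values in row-major order; how main.py actually renders its figures is not shown here.

```python
def tile_grid(images, rows=4, cols=4, size=28):
    """Arrange rows*cols flat size*size images into one 2-D grid (a list of rows)."""
    if len(images) != rows * cols:
        raise ValueError(f"expected {rows * cols} images, got {len(images)}")
    grid = [[0.0] * (cols * size) for _ in range(rows * size)]
    for k, img in enumerate(images):
        # Top-left corner of the k-th tile, filling the grid row by row.
        r0, c0 = (k // cols) * size, (k % cols) * size
        for i in range(size):
            for j in range(size):
                grid[r0 + i][c0 + j] = img[i * size + j]
    return grid
```

For the real-digit panel, the 16 images could be chosen with `random.sample(images, 16)`; the resulting grid can then be displayed with `matplotlib.pyplot.imshow(grid, cmap="gray")`.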
They do resemble real digits, but the generator appears to have found 1 to be the easiest digit with which to fool the discriminator, so it produces a disproportionate number of them (a form of mode collapse). In addition, the generated digits contain visible noise, which suggests a weakness in the discriminator, since it failed to pick up on these obvious differences. Using a convolutional neural network might yield better results.
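One way to quantify the apparent over-production of 1s is to label a batch of generated digits with a separately trained MNIST classifier (hypothetical, not part of this repository) and measure how evenly the predicted labels are spread. The sketch below computes the label frequencies and their entropy normalized by log(10): a value near 1 means all ten digits appear about equally often, while a value near 0 means the generator has collapsed onto a few digits.

```python
import math
from collections import Counter


def label_histogram(predicted_labels, num_classes=10):
    """Fraction of samples assigned to each class 0..num_classes-1."""
    counts = Counter(predicted_labels)
    n = len(predicted_labels)
    return [counts.get(c, 0) / n for c in range(num_classes)]


def normalized_entropy(freqs):
    """Entropy of a label distribution divided by its maximum, log(len(freqs))."""
    h = -sum(p * math.log(p) for p in freqs if p > 0)
    return h / math.log(len(freqs))
```

Comparing this score for generated samples against the score for real samples gives a rough, single-number view of how much diversity the generator has lost.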