This project is a simple implementation of the paper "Generative Adversarial Nets" (Goodfellow et al., 2014).
As described in the paper, a Generative Adversarial Network (GAN) consists of two models: the `Generator` model and the `Discriminator` model.
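For reference, the paper trains these two models against each other as a two-player minimax game. The Discriminator $D$ learns to tell real samples $x$ from generated samples $G(z)$, while the Generator $G$ maps noise $z$ to samples intended to fool $D$. This README does not show exactly how this project's code realizes the objective:

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z))\right)\right]
$$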
Using the MNIST dataset, the GAN aims to generate accurate images of handwritten digits from 0 to 9. The following screenshot shows an example result generated by a model trained for 50 epochs.
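The README does not describe how MNIST images are preprocessed. A common choice, assumed here, is to scale the 8-bit pixel values (0 to 255) into [-1, 1] so that real images share a range with a Generator whose last layer is `tanh`. The helper `scale_pixels` below is hypothetical and is not part of this project:

```python
# Hypothetical helper, not taken from this project's code.
# Assumes the Generator ends in tanh, so real images are scaled to [-1, 1].
def scale_pixels(pixels):
    """Map 8-bit pixel values in [0, 255] to floats in [-1, 1]."""
    return [p / 127.5 - 1.0 for p in pixels]


# A flattened 28x28 MNIST image has 784 pixel values.
image = [0] * 392 + [255] * 392
scaled = scale_pixels(image)
print(len(scaled), min(scaled), max(scaled))  # 784 -1.0 1.0
```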
Install the required packages:

```shell
pip3 install -r requirements.txt
```
Training the model is straightforward. Once you have created an instance of `MNISTGAN`, you can call its `train` method to start training:
```python
mnist_gan = MNISTGAN()
mnist_gan.train(epochs=100, batch_size=64, lr=0.001, is_save_images=True)
```
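The internals of `train` are not shown in this README. As a minimal sketch of what a GAN training step typically optimizes, the standard losses from the paper can be written with plain floats instead of tensors. Here `d_real` and `d_fake` are the Discriminator's predicted probabilities that real and generated images are real; the function names and the `EPS` guard are illustrative assumptions, not this project's API:

```python
import math

EPS = 1e-12  # guards against log(0); an illustrative choice


def discriminator_loss(d_real, d_fake):
    """Binary cross-entropy for D: push D(x) toward 1 and D(G(z)) toward 0."""
    real_term = -sum(math.log(p + EPS) for p in d_real) / len(d_real)
    fake_term = -sum(math.log(1.0 - p + EPS) for p in d_fake) / len(d_fake)
    return real_term + fake_term


def generator_loss(d_fake):
    """Non-saturating loss for G: maximize log D(G(z)) instead of
    minimizing log(1 - D(G(z))), as the paper suggests for early training."""
    return -sum(math.log(p + EPS) for p in d_fake) / len(d_fake)


# A Discriminator that cannot tell real from fake outputs 0.5 everywhere.
print(round(generator_loss([0.5]), 4))  # 0.6931
```

In a typical training loop, each batch first updates the Discriminator with `discriminator_loss`, then updates the Generator with `generator_loss`. In PyTorch the same quantities are usually computed with `torch.nn.BCELoss`.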
Similar to training, you can use the `test` method to generate images with your trained model:
```python
mnist_gan.test(model_path='mymodel.pt')
```
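The README does not say what `test` does internally. Generating images with a GAN typically means drawing latent noise vectors $z \sim \mathcal{N}(0, 1)$, passing them through the trained Generator, and mapping the outputs back to 8-bit pixels. The sketch below covers only the first and last steps. The latent size of 100, the `tanh` output range [-1, 1], and the helper names are all assumptions, not read from this project:

```python
import random

LATENT_DIM = 100  # assumed latent size; the project's actual value is not stated


def sample_noise(batch_size, latent_dim=LATENT_DIM, seed=None):
    """Draw batch_size latent vectors with entries from N(0, 1)."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(latent_dim)] for _ in range(batch_size)]


def to_pixels(values):
    """Map assumed tanh outputs in [-1, 1] back to integer pixels in [0, 255]."""
    return [max(0, min(255, round((v + 1.0) * 127.5))) for v in values]


z = sample_noise(16, seed=0)
print(len(z), len(z[0]))             # 16 100
print(to_pixels([-1.0, 0.0, 1.0]))   # [0, 128, 255]
```

Passing a fixed `seed` makes the sampled noise, and therefore the generated digits, reproducible between runs.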