This is an implementation of the AttnGAN in PyTorch, with some experimental additions and changes.
- Download the Caltech-UCSD Birds-200-2011 dataset and extract it to the root folder of the project.
- Download metadata (includes captions) and copy its contents to the dataset folder.
- To train a DAMSM model, use the `python -m src.main train-damsm <EPOCHS> <NAME> [OPTIONS]` command. `EPOCHS` sets the number of training epochs; `NAME` is the name the model is going to be saved with and further referenced by. Options include:
  - Set patience for early stopping: `--patience=20`
  - Set device: `--device=cuda:0`
- To train the GAN, use `python -m src.main train-gan <EPOCHS> <NAME> <DAMSM> [OPTIONS]`. `EPOCHS` and `NAME` are the number of training epochs and the name of the model respectively. `DAMSM` is the name of the DAMSM model to be used for text-encoding and auxiliary DAMSM-loss. Options include:
  - Continue training of a saved model: `--gan=ExampleModelName`
  - Set device: `--device=cuda:1`
- To generate an image for each sample in the test set, use `python -m src.main validate-gan GAN DAMSM SAVEDIR [OPTIONS]`. `GAN` and `DAMSM` are the names of the models to be used. `SAVEDIR` is the output directory. Options include:
  - Set device: `--device=cuda:2`
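The three commands above depend on each other: the GAN needs a trained DAMSM model, and validation needs both. As a rough sketch, they could be chained from a small Python script like the one below. The helper names `build_pipeline` and `run_pipeline`, the model names, the output directory and the epoch counts are all hypothetical; only the subcommands and options listed above are taken from this README.

```python
import shlex
import subprocess


def build_pipeline(damsm, gan, save_dir, damsm_epochs, gan_epochs,
                   device=None, patience=None, python="python"):
    """Return the three CLI invocations (as argument lists) in run order.

    Hypothetical helper: it only assembles the commands documented above.
    """
    base = [python, "-m", "src.main"]
    # Optional flags are appended only when a value is given.
    device_opt = [f"--device={device}"] if device else []
    patience_opt = [f"--patience={patience}"] if patience is not None else []
    return [
        # 1. Train the DAMSM model.
        base + ["train-damsm", str(damsm_epochs), damsm] + patience_opt + device_opt,
        # 2. Train the GAN, referencing the DAMSM model by name.
        base + ["train-gan", str(gan_epochs), gan, damsm] + device_opt,
        # 3. Generate one image per test sample into save_dir.
        base + ["validate-gan", gan, damsm, save_dir] + device_opt,
    ]


def run_pipeline(commands, dry_run=True):
    """Print each command; execute it too unless dry_run is set."""
    for cmd in commands:
        print(shlex.join(cmd))
        if not dry_run:
            # check=True aborts the pipeline if a step fails.
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    # Example values are placeholders, not recommended settings.
    run_pipeline(build_pipeline("damsm_birds", "gan_birds", "generated",
                                damsm_epochs=50, gan_epochs=100,
                                device="cuda:0", patience=20))
```

With `dry_run=True` (the default here) the script only prints the commands, which is a cheap way to check names and options before starting a long training run from the project root.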
To use different hyperparameters, change the values in `config.py`.
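For readers unfamiliar with the `--patience` option of `train-damsm`, the sketch below shows how patience-based early stopping is commonly implemented. It assumes that patience counts consecutive epochs without a new best validation loss; the `EarlyStopping` class is hypothetical, and the criterion this project actually monitors may differ.

```python
class EarlyStopping:
    """Generic patience-based early stopping (illustrative, not this project's code)."""

    def __init__(self, patience):
        if patience < 1:
            raise ValueError("patience must be at least 1")
        self.patience = patience
        self.best = float("inf")  # lowest validation loss seen so far
        self.bad_epochs = 0       # consecutive epochs without improvement

    def step(self, val_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Usage: with patience=2, training stops after two epochs in a row
# that fail to beat the best loss so far.
stopper = EarlyStopping(patience=2)
for epoch, loss in enumerate([1.0, 0.8, 0.9, 0.85]):
    if stopper.step(loss):
        print(f"stopping after epoch {epoch}")
        break
```

A larger patience tolerates longer plateaus at the cost of extra epochs, so under this assumption `--patience=20` would allow twenty non-improving epochs before training ends.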