Your Local GAN: Designing Two Dimensional Local Attention Mechanisms for Generative Models
GIF generated by interpolating latent variables for Maltese dogs.
This repository hosts the official TensorFlow implementation of the paper "Your Local GAN: Designing Two Dimensional Local Attention Mechanisms for Generative Models".
We introduce a new local sparse attention layer that preserves two-dimensional geometry and locality. We show that by just replacing the dense attention layer of SAGAN with our construction, we obtain very significant FID, Inception score and pure visual improvements. FID score is improved from 18.65 to 15.94 on ImageNet, keeping all other parameters the same. The sparse attention patterns that we propose for our new layer are designed using a novel information theoretic criterion that uses information flow graphs. We also present a novel way to invert Generative Adversarial Networks with attention. Our method extracts from the attention layer of the discriminator a saliency map, which we use to construct a new loss function for the inversion. This allows us to visualize the newly introduced attention heads and show that they indeed capture interesting aspects of two-dimensional geometry of real images.
You can read the full paper here.
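To make the idea of attention that respects two-dimensional locality more concrete, here is a minimal pure-Python sketch. It is not the sparsity pattern from the paper or the code in this repository: the function names (`local_attention_mask`, `attended_fraction`) and the choice of a Manhattan-distance neighbourhood with a fixed `radius` are assumptions made only for illustration. It builds a boolean mask over a flattened image grid in which each pixel may attend only to pixels that are close in 2-D.

```python
def local_attention_mask(height, width, radius):
    """Illustrative 2-D local attention mask (hypothetical helper).

    Pixels of a height x width grid are flattened in row-major order.
    mask[q][k] is True when query pixel q may attend to key pixel k,
    i.e. when their Manhattan distance on the 2-D grid is <= radius.
    """
    n = height * width
    mask = [[False] * n for _ in range(n)]
    for q in range(n):
        q_row, q_col = divmod(q, width)
        for k in range(n):
            k_row, k_col = divmod(k, width)
            if abs(q_row - k_row) + abs(q_col - k_col) <= radius:
                mask[q][k] = True
    return mask


def attended_fraction(mask):
    """Fraction of (query, key) pairs that are kept by the mask."""
    kept = sum(sum(row) for row in mask)
    total = sum(len(row) for row in mask)
    return kept / total


if __name__ == "__main__":
    mask = local_attention_mask(4, 4, radius=1)
    # Flattened indices 3 and 4 are adjacent in 1-D but far apart in 2-D
    # (end of row 0 versus start of row 1), so they do not attend to each other.
    print(mask[3][4])
    print(attended_fraction(mask))
```

The point of the sketch is the contrast between 1-D and 2-D neighbourhoods: a mask defined on the flattened sequence would treat indices 3 and 4 as neighbours, whereas a mask defined on the grid does not.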
Explore our model
Probably the easiest way to explore our model is to play with it directly in this Colab notebook.
However, trying it locally should also be easy if you follow the instructions below.
We recommend installing YLG in an Anaconda virtual environment. To install Anaconda, refer to the official docs.
First, create a new virtual environment with Python 3.6:
conda create -n ylg python=3.6
conda activate ylg
Next, install the project requirements:
pip install -r requirements.txt
We make available a pre-trained model for YLG SAGAN, trained for 1M steps on ImageNet.
If you want to try the model, download it from here.
We recommend saving the pre-trained model under the
ylg/ folder, but you can also choose another location and set the
Generating images for any category of the ImageNet dataset is one command away.
Run python generate_images.py --category=valley to generate valleys! For a complete list of the category names, please check
There are several parameters that you can control, such as the number of generated images. You can discover them by running:
python generate_images.py --help
As you can see, the model is able to generate some really good-looking images, but not all generated images are photo-realistic. We expect that training bigger architectures, such as BigGAN, with our 2-d local sparse attention layers will significantly improve the quality of the generated images.
Invert your own images
In our paper, we present a new inversion technique: we extract a saliency map for the real image out of the attention layer of the Discriminator and we use it to weight a novel loss function in the discriminator's embedding space.
To the best of our knowledge, inversion of big models with attention is achieved to a satisfying degree. You are one command away from trying it out!
Run python inverse_image.py to invert a cute Maltese dog that is saved in the
real_images/ folder. You can run with your own images as well!
python inverse_image.py --image_path=<path> --category=<category> is the command to run.
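The saliency-weighted loss described at the start of this section can be illustrated with a small pure-Python sketch. This is not the loss implemented in inverse_image.py, whose exact form is given in the paper rather than in this README. It assumes, for illustration only, that the discriminator yields one feature vector per spatial location and that the saliency map is a list of non-negative weights, one per location; the function name `saliency_weighted_loss` is hypothetical.

```python
def saliency_weighted_loss(real_feats, fake_feats, saliency):
    """Saliency-weighted squared distance between two sets of embeddings.

    real_feats, fake_feats: lists of per-location feature vectors
        (assumed to come from the discriminator for the real image and
        for the generator's current reconstruction).
    saliency: list of non-negative weights, one per location (assumed
        to be derived from the discriminator's attention layer).
    """
    if not (len(real_feats) == len(fake_feats) == len(saliency)):
        raise ValueError("inputs must have one entry per location")
    total_weight = sum(saliency)
    if total_weight <= 0:
        raise ValueError("saliency must contain a positive weight")

    loss = 0.0
    for real, fake, weight in zip(real_feats, fake_feats, saliency):
        squared_dist = sum((r - f) ** 2 for r, f in zip(real, fake))
        loss += weight * squared_dist
    # Normalise so the loss does not depend on the scale of the saliency map.
    return loss / total_weight


if __name__ == "__main__":
    real = [[1.0, 0.0], [0.0, 1.0]]
    saliency = [3.0, 1.0]  # the first location is the salient one
    miss_salient = [[0.0, 0.0], [0.0, 1.0]]
    miss_background = [[1.0, 0.0], [0.0, 0.0]]
    print(saliency_weighted_loss(real, miss_salient, saliency))
    print(saliency_weighted_loss(real, miss_background, saliency))
```

Under these assumptions, an error of the same size costs more when it falls on a salient location than on a background one, which is the behaviour the weighting is meant to produce.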
Train from scratch
We totally understand that you might want to train your own model for a variety of reasons: experimentation with new modules, different datasets, etc. For that reason, we have created the branch
train, which slightly changes the API of the Generator and Discriminator for training. You can check out this branch and then use the
train_experiment_main.py script to train YLG from scratch. Please refer to the instructions of the tensorflow-gan library for setting up your training environment (host VM, TPUs/GPUs, bucket, etc.), and feel free to open an issue if you encounter any problem so we can look into it.
We would like to wholeheartedly thank the TensorFlow Research Cloud (TFRC) program that gave us access to v3-8 Cloud TPUs and GCP credits to train our models on ImageNet.
The code of this repository is heavily based on the tensorflow-gan library. We add the library as a dependency and re-implement only the parts that need modification for YLG. Every file that is modified from tensorflow-gan has a header indicating that it is subject to the license of the tensorflow-gan library.