GAN analysis toolkit accompanying the paper A Neural Tangent Kernel Perspective of GANs (Jean-Yves Franceschi,* Emmanuel de Bézenac,* Ibrahim Ayed,* Mickaël Chen, Sylvain Lamprier, Patrick Gallinari), accepted at ICML 2022.
This code was tested with Python 3.8.1 and 3.9.2, and run on Nvidia Titan RTX GPUs (24 GB of VRAM) with CUDA 11.2, as well as on Nvidia Titan V (12 GB) and Nvidia GeForce RTX 2080 Ti (11 GB) GPUs with CUDA 10.2.
The code is primarily based on JAX and Neural Tangents.
A list of required Python packages is available in the requirements.txt file.
We refer to the JAX installation instructions for performing computations on GPU.
To download the Density and AB datasets, execute the following command, which saves them in gantk2/data/images:
bash gantk2/data/download_images.sh
We provide the following generic command to reproduce the experiments of the paper:
python -m gantk2.train --loss_config $LOSS_CONFIG --arch_config $ARCH_CONFIG --data_config $DATA_CONFIG --save_path $SAVE_PATH --save_name $SAVE_NAME --device $DEVICE
where $DEVICE is the GPU index, $SAVE_PATH is the directory where the experiment folder will be created and $SAVE_NAME is the name of the experiment folder.
Different options are available for $LOSS_CONFIG, $ARCH_CONFIG and $DATA_CONFIG, corresponding to the sets of hyperparameters used for the experiments of the paper:
- for $LOSS_CONFIG: inf_ipm (infinite-width IPM), ipm (finite-width IPM), ipm_reset (finite-width IPM with reset), inf_lsgan (infinite-width LSGAN) or lsgan (finite-width LSGAN);
- for $ARCH_CONFIG: rbf (RBF kernel, only for infinite-width losses), relu, relu_nobias or relu_highbias (used for CelebA);
- for $DATA_CONFIG: eight_gaussians, density, ab, mnist or celeba.
For example, to reproduce the experiment on the eight Gaussians dataset with a ReLU network in the infinite-width regime and the IPM loss:
python -m gantk2.train --loss_config inf_ipm --arch_config relu --data_config eight_gaussians --device 0 --save_path saves --save_name test
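Several runs can be scripted by building the same command line from Python. This is a minimal sketch that uses only the flags shown above; build_train_command is a hypothetical helper, and the loop merely prints the commands (passing each list to subprocess.run(cmd, check=True) would actually launch them).

```python
import shlex


def build_train_command(loss_config, arch_config, data_config,
                        save_path, save_name, device):
    """Return the argument list for one gantk2.train run (hypothetical helper)."""
    return [
        "python", "-m", "gantk2.train",
        "--loss_config", loss_config,
        "--arch_config", arch_config,
        "--data_config", data_config,
        "--save_path", save_path,
        "--save_name", save_name,
        "--device", str(device),
    ]


# Example sweep: every premade loss on the eight Gaussians dataset with a ReLU
# network, each run saved in its own experiment folder under "saves".
commands = [
    build_train_command(loss, "relu", "eight_gaussians", "saves",
                        f"eight_gaussians_{loss}", 0)
    for loss in ["inf_ipm", "ipm", "ipm_reset", "inf_lsgan", "lsgan"]
]
for cmd in commands:
    print(shlex.join(cmd))  # shell-quoted, ready to copy into a terminal
```

Building argument lists rather than single strings avoids shell quoting issues when the commands are later passed to subprocess.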
The saved experiment folder contains a configuration file, visualizations in the img subfolder, and checkpoints and metrics in the chkpt subfolder.
In particular, chkpt/metrics.csv contains metrics for all evaluated timesteps during training (the Sinkhorn divergence corresponds to the s column).
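Since chkpt/metrics.csv is a CSV file, the Sinkhorn divergence can be extracted with the standard library. This sketch assumes only that the file has a header row containing the s column described above; the other column names are not documented here, and load_sinkhorn is a hypothetical helper.

```python
import csv
from pathlib import Path


def load_sinkhorn(experiment_dir):
    """Return the values of the s column (Sinkhorn divergence) in file order.

    Assumes chkpt/metrics.csv has a header row containing an "s" column;
    rows with an empty or missing s entry are skipped.
    """
    metrics_file = Path(experiment_dir) / "chkpt" / "metrics.csv"
    with metrics_file.open(newline="") as f:
        reader = csv.DictReader(f)
        if reader.fieldnames is None or "s" not in reader.fieldnames:
            raise KeyError(f"no 's' column in {metrics_file}")
        return [float(row["s"]) for row in reader if row["s"]]


# Usage with the example run above, saved in $SAVE_PATH/$SAVE_NAME = saves/test:
# values = load_sinkhorn("saves/test")
# print("last:", values[-1], "best:", min(values))
```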
We refer to gantk2/args/exp_configs.py for details about these premade configurations, and to gantk2/args/args.py for the complete set of arguments of the training script, which can also be obtained via:
python -m gantk2.train --help
We provide here the commands to reproduce the plots shown in the paper.
For the 1d plots, execute the following command:
python -m gantk2.plots.plot_adequation_1d --ade1d_config $ADE1D_CONFIG --device $DEVICE [--plot_output_file $PLOT_OUTPUT_FILE]
where $DEVICE is the GPU index and $PLOT_OUTPUT_FILE is the name of the file where the plot will be saved.
By default, the plot is shown rather than saved.
Two options are available for $ADE1D_CONFIG, corresponding to the sets of hyperparameters used for the 1d plots of the paper: ipm_relu (IPM with ReLU discriminator) and lsgan_relu (LSGAN with ReLU discriminator).
For the 2d plots, execute:
python -m gantk2.plots.plot_adequation_2d --ade2d_config $ADE2D_CONFIG --device $DEVICE [--plot_output_file $PLOT_OUTPUT_FILE]
where $ADE2D_CONFIG only accepts lsgan_relu (LSGAN with ReLU discriminator), corresponding to the set of hyperparameters used for the 2d plots of the paper.
Note that other arguments can also be set, such as --loss_config $LOSS_CONFIG or --arch_config $ARCH_CONFIG, where $LOSS_CONFIG takes the value ipm or lsgan.
For other argument values, refer to the previous section and to Reproducing Experiments.
For the vector field plots, execute:
python -m gantk2.plots.plot_vector_field --loss_config $LOSS_CONFIG --arch_config $ARCH_CONFIG --device $DEVICE [--plot_output_file $PLOT_OUTPUT_FILE]
For argument values, refer to the previous sections and to Reproducing Experiments.
The corresponding plots can be found in the img subfolder of the chosen experiment directory.