CycleGAN is a deep learning architecture for image-to-image translation that can be trained on unpaired datasets. We reimplement it using the JAX framework and train it on several datasets.
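Unpaired training works because of the cycle-consistency loss from the original CycleGAN paper. Translating an image from A to B and back again should reproduce the original, and the same holds for the B-to-A-to-B direction. The block below is a minimal JAX sketch of that term. It assumes the generators are plain callables on NHWC image batches and uses the paper's L1 norm with weight lambda = 10. The function names are illustrative and are not taken from this repository's code.

```python
import jax.numpy as jnp


def l1(a, b):
    """Mean absolute error between two image batches."""
    return jnp.mean(jnp.abs(a - b))


def cycle_consistency_loss(g_ab, g_ba, real_a, real_b, lam=10.0):
    """lam * (|G_BA(G_AB(a)) - a|_1 + |G_AB(G_BA(b)) - b|_1).

    g_ab, g_ba: hypothetical generator callables mapping A->B and B->A.
    real_a, real_b: image batches of shape (N, H, W, C).
    """
    rec_a = g_ba(g_ab(real_a))  # A -> B -> A
    rec_b = g_ab(g_ba(real_b))  # B -> A -> B
    return lam * (l1(rec_a, real_a) + l1(rec_b, real_b))


if __name__ == "__main__":
    a = jnp.zeros((1, 4, 4, 3))
    b = jnp.ones((1, 4, 4, 3))
    # Toy generators that exactly invert each other give (near) zero loss.
    print(cycle_consistency_loss(lambda x: x + 0.5, lambda x: x - 0.5, a, b))
```

In a full training step this term is added to the adversarial losses of both generators. Lambda controls how strongly reconstructions are enforced relative to fooling the discriminators.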
TODO: script to automatically install dependencies?
conda create -n cg-jax python=3.9
conda activate cg-jax
If CUDA is available (highly recommended for training):
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Otherwise:
pip install --upgrade "jax[cpu]"
Finally:
pip install flax optax matplotlib jupyter tqdm
pip install torch torchvision # For data loading
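After installing, you can confirm which backend JAX picked up. If the CUDA wheel installed correctly, jax.default_backend() should report "gpu"; otherwise it falls back to "cpu". The helper name below is illustrative only.

```python
import jax


def describe_backend() -> str:
    """Return a short summary of the backend and devices JAX will use."""
    backend = jax.default_backend()  # e.g. "gpu", "tpu", or "cpu"
    devices = jax.devices()          # devices for the default backend
    return f"{backend}: {len(devices)} device(s)"


if __name__ == "__main__":
    print(describe_backend())
```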
To train, prepare a dataset directory with subdirectories trainA, trainB, testA, and testB, where trainA and testA hold the training and test images from domain A, and trainB and testB those from domain B. Then run
python main.py --train -d <dataset dir path> -m <model checkpoints and outputs path>
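Before a long training run, it can help to check that the dataset directory has the expected layout. The sketch below uses a hypothetical helper, check_dataset_dir, that is not part of this repository. It assumes images are .jpg, .jpeg, or .png files placed directly inside each split directory.

```python
from pathlib import Path

REQUIRED_SPLITS = ("trainA", "trainB", "testA", "testB")
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}  # assumed image formats


def check_dataset_dir(root):
    """Return {split: image_count}; raise if a split is missing or empty."""
    root = Path(root)
    counts = {}
    for split in REQUIRED_SPLITS:
        split_dir = root / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing subdirectory: {split_dir}")
        n = sum(1 for p in split_dir.iterdir() if p.suffix.lower() in IMAGE_SUFFIXES)
        if n == 0:
            raise ValueError(f"no images found in {split_dir}")
        counts[split] = n
    return counts


if __name__ == "__main__":
    import sys
    print(check_dataset_dir(sys.argv[1]))
```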
To translate a single image with a trained model, run
python main.py --predict <A|B> -d <file path> -m <model checkpoints and outputs path>
where <A|B> is the domain (A or B) that the input image belongs to.
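The preprocessing applied by main.py is not described here. The sketch below shows a common CycleGAN convention, and treating it as this repo's behavior is an assumption. In that convention, inputs are RGB, resized to 256x256, and scaled to [-1, 1] to match a tanh generator output. to_model_input and from_model_output are hypothetical helper names.

```python
import numpy as np
from PIL import Image


def to_model_input(img, size=256):
    """PIL image -> float32 array of shape (1, size, size, 3) in [-1, 1]."""
    img = img.convert("RGB").resize((size, size), Image.BICUBIC)
    x = np.asarray(img, dtype=np.float32) / 127.5 - 1.0  # [0, 255] -> [-1, 1]
    return x[None]  # add a batch dimension


def from_model_output(y):
    """(1, H, W, 3) array in [-1, 1] -> uint8 PIL image."""
    x = np.clip((y[0] + 1.0) * 127.5, 0, 255).astype(np.uint8)
    return Image.fromarray(x)


if __name__ == "__main__":
    x = to_model_input(Image.new("RGB", (64, 32), (255, 0, 0)))
    print(x.shape, x.min(), x.max())
```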
Hyperparameters and other configuration options are gathered in train.get_default_opts.
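train.get_default_opts is the authoritative source for this repo's settings. For orientation, the hypothetical sketch below collects the defaults reported in the original CycleGAN paper. These are Adam with learning rate 2e-4 and beta1 = 0.5, batch size 1, and cycle weight 10. The learning rate stays constant for 100 epochs and then decays linearly to zero over the next 100. The Opts, get_opts, and lr_at_epoch names are illustrative and may not match the actual field names or values used here.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Opts:
    lr: float = 2e-4           # Adam learning rate
    beta1: float = 0.5         # Adam beta1
    batch_size: int = 1
    lambda_cycle: float = 10.0  # weight of the cycle-consistency loss
    n_epochs: int = 100        # epochs at constant learning rate
    n_epochs_decay: int = 100  # epochs of linear decay to zero
    image_size: int = 256


def get_opts(**overrides):
    """Defaults with selected fields overridden, e.g. get_opts(batch_size=4)."""
    return replace(Opts(), **overrides)


def lr_at_epoch(opts, epoch):
    """Constant for n_epochs, then linear decay to 0 over n_epochs_decay."""
    if epoch < opts.n_epochs:
        return opts.lr
    frac = (epoch - opts.n_epochs) / opts.n_epochs_decay
    return opts.lr * max(0.0, 1.0 - frac)
```

A frozen dataclass keeps options immutable during training. dataclasses.replace gives a simple, type-checked way to override individual fields from the command line.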
Grant execution permission to the relevant get-*.sh script, then run it. For example,
chmod +x ./get-horse2zebra.sh
./get-horse2zebra.sh
The datasets we use are hosted by the original CycleGAN authors.
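On systems without a POSIX shell, the download step can be approximated in Python. The hypothetical helper below fetches a zip archive and extracts it. It assumes the dataset is distributed as a single .zip file. The URL is passed in by the caller (copy it from the relevant get-*.sh script) rather than hard-coded here.

```python
import tempfile
import urllib.request
import zipfile
from pathlib import Path


def download_and_extract(url, dest_dir="."):
    """Download a .zip archive from url and extract it into dest_dir."""
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    with tempfile.TemporaryDirectory() as tmp:
        archive = Path(tmp) / "dataset.zip"
        urllib.request.urlretrieve(url, str(archive))  # fetch to a temp file
        with zipfile.ZipFile(archive) as zf:
            zf.extractall(dest)  # e.g. creates dest/<name>/trainA, ...
    return dest


if __name__ == "__main__":
    import sys
    print(download_and_extract(sys.argv[1], sys.argv[2] if len(sys.argv) > 2 else "."))
```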