Yiyong Li¹, Zhun Sun¹\*, Chao Li²,³ (\*corresponding author: zhunsun@gmail.com)

¹BIGO Ltd · ²Tohoku University · ³AIP, RIKEN

Paper: https://bmvc2022.mpi-inf.mpg.de/0551.pdf
Image-to-image translation by (compressed) CycleGAN.

- Clone this repo:

```bash
git clone https://github.com/liyiyong-nk/gan_BIGP.git
cd gan_BIGP
```
- Install dependencies:

```bash
conda create -n BIGP python=3.7.11
conda activate BIGP
pip install torch==1.7.0 torchvision==0.8.0 torchaudio==0.7.0
pip install -r requirements.txt
```
- Download the dataset:

```bash
bash ./download_dataset <dataset_name>
```

This downloads the dataset to the folder `datasets/<dataset_name>` (e.g., `datasets/horse2zebra`).
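For reference, the expected layout (inferred from the example paths used later in this README; the exact split folders may vary by dataset) looks like:

```
datasets/horse2zebra/
├── train/
│   ├── A/   # real horse images, e.g. n02381460_2.jpg
│   └── B/   # real zebra images
└── test/
    ├── A/
    └── B/
```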
Use the official CycleGAN code to train the original dense CycleGAN, then load its parameters into our model.
Use the pretrained dense generator to generate style-transfer results on the training set, and put those results in the folder `train_set_result/<dataset_name>`. For example, `train_set_result/horse2zebra/B/n02381460_2_fake.png` is the fake zebra image translated from the real horse image `datasets/horse2zebra/train/A/n02381460_2.jpg` by the original dense CycleGAN. A sketch of this step is shown below.
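A minimal sketch of this generation step, assuming the generator class and checkpoint layout of the official CycleGAN repo (`ResnetGenerator`, `checkpoints/horse2zebra/latest_net_G_A.pth`, and the preprocessing are assumptions here, not part of this repo):

```python
import functools
import os

import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image

# `ResnetGenerator` follows the official CycleGAN repo; adapt the import and
# checkpoint path to wherever you trained the dense model.
from models.networks import ResnetGenerator

device = torch.device("cuda:0")
# norm_layer must match how you trained; this repo prunes affine IN scales.
netG = ResnetGenerator(3, 3, ngf=64,
                       norm_layer=functools.partial(nn.InstanceNorm2d, affine=True),
                       n_blocks=9).to(device)
netG.load_state_dict(torch.load("checkpoints/horse2zebra/latest_net_G_A.pth"))
netG.eval()

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.ToTensor(),
    transforms.Normalize((0.5,) * 3, (0.5,) * 3),  # map pixels to [-1, 1]
])

src_dir = "datasets/horse2zebra/train/A"
dst_dir = "train_set_result/horse2zebra/B"
os.makedirs(dst_dir, exist_ok=True)

with torch.no_grad():
    for name in sorted(os.listdir(src_dir)):
        img = preprocess(Image.open(os.path.join(src_dir, name)).convert("RGB"))
        fake = netG(img.unsqueeze(0).to(device))[0]
        # undo the [-1, 1] normalization before writing the png
        out = os.path.splitext(name)[0] + "_fake.png"
        save_image(fake * 0.5 + 0.5, os.path.join(dst_dir, out))
```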
Stage 1:

```bash
python gan_bigp.py --lrw 1e-5 --alpha 4e-6 --contral_rate 0.001 --epochs 200 --dataset horse2zebra --task A2B --gpu 1
```
Remark: the sparsity term `alpha * sparse_loss` (= 4e-6 × 334967.2812 = 1.3399 here) and the perceptual loss (= 0.5922 here) should not differ by much. Reaching your target compression ratio may take several stages; following this remark, modify `g_path` in `gan_bigp.py` before each stage. The functions `BIG_loss` and `update_in` in `gan_bigp.py` can be flexibly applied to your own project; a sketch of the underlying idea follows.
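`BIG_loss` and `update_in` are defined in `gan_bigp.py`; the sketch below only illustrates the general idea (an L1 penalty on IN scale factors plus a hard-zeroing step). The loss form and threshold here are assumptions, not the paper's exact formulation:

```python
import torch
import torch.nn as nn

def in_sparsity_loss(model: nn.Module) -> torch.Tensor:
    """L1 penalty on affine InstanceNorm scale factors (weighted by --alpha).

    Illustrative stand-in for BIG_loss; the paper's actual channel-selection
    criterion may differ.
    """
    loss = torch.zeros((), device=next(model.parameters()).device)
    for m in model.modules():
        if isinstance(m, nn.InstanceNorm2d) and m.affine:
            loss = loss + m.weight.abs().sum()
    return loss

@torch.no_grad()
def zero_small_in_channels(model: nn.Module, threshold: float = 1e-3) -> None:
    """Hard-set the IN weight/bias of near-zero channels to exactly 0 so the
    corresponding channels can later be removed by extract_subnet.py.

    Illustrative stand-in for update_in; the threshold (and how often it is
    applied, cf. --contral_rate) are assumptions.
    """
    for m in model.modules():
        if isinstance(m, nn.InstanceNorm2d) and m.affine:
            mask = m.weight.abs() < threshold
            m.weight[mask] = 0.0
            m.bias[mask] = 0.0
```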
Stage 2:

```bash
python gan_bigp.py --lrw 1e-5 --alpha 1.25e-5 --contral_rate 0.006 --epochs 200 --dataset horse2zebra --task A2B --gpu 1
```
Stage 3:

```bash
python gan_bigp.py --lrw 1e-5 --alpha 1.25e-5 --contral_rate 0.0125 --epochs 200 --dataset horse2zebra --task A2B --gpu 1
```
The training results (checkpoints, loss curves, etc.) are saved in `results/<dataset_name>/<task_name>`.
During pruning, the weights and biases of the pruned Instance Normalization (IN) channels have already been set to 0, so extracting the subnetwork model is straightforward:
```bash
python extract_subnet.py --dataset <dataset_name> --task <task_name> --model_str <model_str>
```

The extracted subnetworks are saved in `subnet_structures/<dataset_name>/<task_name>`. A sketch of how the surviving architecture can be read off is shown below.
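Because pruned channels have exactly zero IN scale, the surviving architecture can be read off by inspection. A minimal sketch of the idea (the actual traversal inside `extract_subnet.py` may differ):

```python
import torch
import torch.nn as nn

@torch.no_grad()
def surviving_channels(model: nn.Module) -> dict:
    """Per-layer count of channels whose IN scale was not driven to 0."""
    kept = {}
    for name, m in model.named_modules():
        if isinstance(m, nn.InstanceNorm2d) and m.affine:
            kept[name] = ((m.weight != 0).sum().item(), m.num_features)
    return kept

# e.g. {'model.2': (41, 64), ...} means 41 of 64 channels survive in that layer
```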
Set `image_dir`, `result_dir`, and `g_path` in `test.py` (see the illustrative settings below), then run:

```bash
python test.py
```

The generated images are saved in `result_dir`.
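For example, near the top of `test.py` (the values below are illustrative, including the checkpoint filename; point them at your own data and pruned generator):

```python
# Illustrative settings; adjust to your dataset, output folder, and checkpoint.
image_dir  = "datasets/horse2zebra/test/A"                   # images to translate
result_dir = "results/horse2zebra/A2B/test_images"           # output folder
g_path     = "subnet_structures/horse2zebra/A2B/subnet.pth"  # pruned generator
```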
If you use this code for your research, please cite our paper:

```bibtex
@inproceedings{bigp2022bmvc,
  title={Are we pruning the correct channels in image-to-image translation models?},
  author={Li, Yiyong and Sun, Zhun and Li, Chao},
  booktitle={BMVC},
  year={2022}
}
```
Our code is built on top of https://github.com/VITA-Group/GAN-Slimming.