GAN Compression

paper | demo

[NEW!] The PyTorch implementation of a general conditional GAN Compression framework is released.

We introduce GAN Compression, a general-purpose method for compressing conditional GANs. Our method reduces the computation of widely used conditional GAN models, including pix2pix, CycleGAN, and GauGAN, by 9-21x while preserving visual fidelity. It is effective across a wide range of generator architectures and learning objectives, in both paired and unpaired settings.

GAN Compression: Efficient Architectures for Interactive Conditional GANs
Muyang Li, Ji Lin, Yaoyao Ding, Zhijian Liu, Jun-Yan Zhu, and Song Han
MIT, Adobe Research, SJTU
In CVPR 2020.

Demos

Overview

GAN Compression framework: ① Given a pre-trained teacher generator G', we distill a smaller “once-for-all” student generator G that contains all possible channel numbers through weight sharing, sampling a different channel configuration for G at each training step. ② We then extract many sub-generators from the “once-for-all” generator and evaluate their performance; no retraining is needed, which is the key advantage of the “once-for-all” design. ③ Finally, we choose the best sub-generator given the compression ratio and performance targets (FID or mAP), fine-tune it, and obtain the final compressed model.
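As a rough sketch (our illustration, not the authors' code), the weight-sharing idea behind the “once-for-all” student can be shown with plain NumPy: every smaller channel configuration reuses a leading slice of the full weight tensor, so all sub-generators share one set of weights and training any sampled slice updates them in place.

```python
import numpy as np

# Hypothetical sketch of "once-for-all" weight sharing: a full conv weight of
# shape [C_out, C_in, k, k] serves every smaller channel configuration by
# slicing its leading channels, so no sub-generator needs its own weights.
rng = np.random.default_rng(0)
full_weight = rng.standard_normal((64, 32, 3, 3))  # largest configuration

def sub_weight(full, c_out, c_in):
    """Extract the shared weights for a sub-generator with fewer channels."""
    return full[:c_out, :c_in]

# At each training step, sample a channel setting and train that slice.
choices = [16, 24, 32, 48, 64]
c_out = int(rng.choice(choices))
w = sub_weight(full_weight, c_out, 32)
# The slice is a view: it shares memory with the full weight tensor,
# so gradient updates to any sub-generator update the shared weights.
assert np.shares_memory(w, full_weight)
```

Because the slices are views rather than copies, evaluating a candidate sub-generator in step ② amounts to slicing and running a forward pass, which is why no retraining is required.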

Colab Notebook

PyTorch Colab notebook: CycleGAN and pix2pix.

Prerequisites

  • Linux
  • Python 3
  • CPU or NVIDIA GPU + CUDA CuDNN

Getting Started

Installation

  • Clone this repo:

    git clone git@github.com:mit-han-lab/gan-compression.git
    cd gan-compression
  • Install PyTorch 1.4 and other dependencies (e.g., torchvision).

    • For pip users, run pip install -r requirements.txt.
    • For Conda users, we provide an installation script ./scripts/conda_deps.sh. Alternatively, you can create a new Conda environment with conda env create -f environment.yaml.
  • Install torchprofile.

    pip install --upgrade git+https://github.com/mit-han-lab/torchprofile.git
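torchprofile traces a model to count its multiply-accumulate operations (MACs), the metric behind the 9-21x computation reductions quoted above. As a back-of-envelope illustration of what such a profiler counts (our sketch, not torchprofile's implementation), the MACs of a single convolution layer can be computed analytically:

```python
# Rough MACs estimate for one convolution layer (illustrative only; a tracer
# like torchprofile walks the actual model graph to count these automatically).
def conv_macs(c_in, c_out, k, h_out, w_out):
    # Each of the c_out * h_out * w_out output elements needs
    # c_in * k * k multiply-accumulates.
    return c_out * h_out * w_out * c_in * k * k

# e.g. a 3x3 conv with 64 -> 128 channels on a 128x128 output feature map:
macs = conv_macs(64, 128, 3, 128, 128)
print(f"{macs / 1e9:.2f} GMACs")  # ~1.21 GMACs
```

Summing such terms over every layer is what makes the full and compressed generators directly comparable in cost.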

CycleGAN

Setup

  • Download the CycleGAN dataset (e.g., horse2zebra).

    bash datasets/download_cyclegan_dataset.sh horse2zebra
  • Compute the statistics of the ground-truth images of your dataset, which are needed for FID. We provide pre-computed real statistics for several datasets. For example,

    bash ./datasets/download_real_stat.sh horse2zebra A
    bash ./datasets/download_real_stat.sh horse2zebra B

Apply a Pre-trained Model

  • Download the pre-trained models.

    python scripts/download_model.py --model cycle_gan --task horse2zebra --stage full
    python scripts/download_model.py --model cycle_gan --task horse2zebra --stage compressed
  • Test the original full model.

    bash scripts/cycle_gan/horse2zebra/test_full.sh
  • Test the compressed model.

    bash scripts/cycle_gan/horse2zebra/test_compressed.sh

Pix2pix

Setup

  • Download the pix2pix dataset (e.g., edges2shoes).

    bash ./datasets/download_pix2pix_dataset.sh edges2shoes-r
  • Compute the statistics of the ground-truth images of your dataset, which are needed for FID. We provide pre-computed real statistics for several datasets. For example,

    bash datasets/download_real_stat.sh edges2shoes-r B

Apply a Pre-trained Model

  • Download the pre-trained models.

    python scripts/download_model.py --model pix2pix --task edges2shoes-r --stage full
    python scripts/download_model.py --model pix2pix --task edges2shoes-r --stage compressed
  • Test the original full model.

    bash scripts/pix2pix/edges2shoes-r/test_full.sh
  • Test the compressed model.

    bash scripts/pix2pix/edges2shoes-r/test_compressed.sh

Cityscapes Dataset

We cannot provide the Cityscapes dataset due to licensing restrictions. Please download it from https://cityscapes-dataset.com and preprocess it with the script datasets/prepare_cityscapes_dataset.py. You need to download gtFine_trainvaltest.zip and leftImg8bit_trainvaltest.zip and unzip them in the same folder; for example, you may put gtFine and leftImg8bit in database/cityscapes-origin. You can then prepare the dataset with the following command:

python datasets/prepare_cityscapes_dataset.py \
--gtFine_dir database/cityscapes-origin/gtFine \
--leftImg8bit_dir database/cityscapes-origin/leftImg8bit \
--output_dir database/cityscapes \
--table_path ./dataset/table.txt

You will get a preprocessed dataset in database/cityscapes and a mapping table (used to compute mAP) in dataset/table.txt.

FID Computation

To compute the FID score, you need statistics of the ground-truth images of your dataset. We provide the script get_real_stat.py to extract them. For example, for the edges2shoes dataset, you could run the following command:

python get_real_stat.py \
--dataroot database/edges2shoes-r \
--output_path real_stat/edges2shoes-r_B.npz \
--direction AtoB
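For reference, FID models the real and generated Inception features as two Gaussians and measures the Fréchet distance between them. The sketch below (our illustration; the repository delegates the actual computation to a pytorch-fid-style implementation, and we make no assumption about the .npz key names) shows the formula, using a symmetric rewriting of the matrix square root so that plain NumPy suffices:

```python
import numpy as np

def sqrtm_psd(mat):
    """Matrix square root of a symmetric PSD matrix via eigendecomposition."""
    w, v = np.linalg.eigh(mat)
    return v @ np.diag(np.sqrt(np.clip(w, 0.0, None))) @ v.T

def frechet_distance(mu1, sigma1, mu2, sigma2):
    """FID between N(mu1, sigma1) and N(mu2, sigma2):
    ||mu1 - mu2||^2 + Tr(sigma1) + Tr(sigma2) - 2*Tr(sqrt(sigma1 sigma2)).
    Tr(sqrt(sigma1 sigma2)) is evaluated in the equivalent symmetric form
    Tr(sqrt(sigma1^{1/2} sigma2 sigma1^{1/2})), which stays PSD."""
    diff = mu1 - mu2
    s1_half = sqrtm_psd(sigma1)
    tr_covmean = np.trace(sqrtm_psd(s1_half @ sigma2 @ s1_half))
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2)
                 - 2.0 * tr_covmean)

# Identical statistics give an FID of (numerically) zero.
mu, sigma = np.zeros(4), np.eye(4)
print(frechet_distance(mu, sigma, mu, sigma))  # ~0.0
```

Lower is better: the compressed generator is tuned so its feature statistics stay close to those of the real images.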

Citation

If you use this code for your research, please cite our paper.

@inproceedings{li2020gan,
  title={GAN Compression: Efficient Architectures for Interactive Conditional GANs},
  author={Li, Muyang and Lin, Ji and Ding, Yaoyao and Liu, Zhijian and Zhu, Jun-Yan and Han, Song},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year={2020}
}

Acknowledgements

Our code is developed based on pytorch-CycleGAN-and-pix2pix and SPADE.

We also thank pytorch-fid for FID computation and drn for mAP computation.
