
Commit

Added code
abdulfatir committed Mar 14, 2020
1 parent 1a4166b commit 084041f
Showing 18 changed files with 1,718 additions and 4 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1 +1,2 @@
.DS_Store
+__pycache__
15 changes: 11 additions & 4 deletions README.md
@@ -30,10 +30,10 @@ and the ECFD between two distributions is given by

### Generating samples from pre-trained models

-* Download the pre-trained models from releases.
+* Download the pre-trained generators from releases.
* Run the following command to generate an 8x8 grid of samples from a model trained on the CIFAR10 dataset:
```bash
-python gen_samples.py\
+python src/gen_samples.py \
--png \
--imsize 32 \
--noise_dim 32 \
@@ -46,7 +46,7 @@ python gen_samples.py\

* **Downloading Datasets**: All datasets except CelebA are downloaded automatically when the code is first run. Download CelebA by executing `python download.py celebA`, and rename the directory `./data/img_align_celeba` to `./data/celebA` after the script finishes (see the snippet after this file's diff).
* Run `python src/main.py --help` to see a description of all the available command-line arguments.
-* Run the following command to train OCFGAN-GP on the CIFAR10 dataset:
+* **Example**: run the following command to train on the CIFAR10 dataset:
```bash
python src/main.py \
--dataset cifar10 \
@@ -80,4 +80,11 @@ For any questions regarding the code or the paper, please email me at [abdulfati
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
year={2020}
}
```

#### Acknowledgements
Parts of the code/network structures in this repository have been adapted from the following repos:

* [ozanciga/gans-with-pytorch](https://github.com/ozanciga/gans-with-pytorch)
* [OctoberChang/MMD-GAN](https://github.com/OctoberChang/MMD-GAN)
* [mbinkowski/MMD-GAN](https://github.com/mbinkowski/MMD-GAN)
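
As a convenience, the CelebA preparation described in the README above can be run as one sequence. A minimal sketch, assuming the default `./data` root; the `mv` is only needed if the extracted directory still has its original name:

```bash
python download.py celebA
# Rename only if download.py left the directory under its original name.
[ -d ./data/img_align_celeba ] && mv ./data/img_align_celeba ./data/celebA
```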
37 changes: 37 additions & 0 deletions celeba.sh
@@ -0,0 +1,37 @@
#!/bin/bash
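# Train OCFGAN-GP (MODEL=cfgangp) on 32x32 CelebA; adjust GPU_ID and DATA_PATH as needed.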

BS=64
GPU_ID=0
MAX_GITER=125000
DATA_PATH=./data
DATASET=celeba
DATAROOT=${DATA_PATH}/celebA
ISIZE=32
NC=3
NOISE_DIM=64
MODEL=cfgangp
DOUT_DIM=32
NUM_FREQS=8
WEIGHT=gaussian_ecfd
SIGMA=0.

cmd="python src/main.py \
--dataset ${DATASET} \
--dataroot ${DATAROOT} \
--model ${MODEL} \
--batch_size ${BS} \
--image_size ${ISIZE} \
--nc ${NC} \
--noise_dim ${NOISE_DIM} \
--dout_dim ${DOUT_DIM} \
--max_giter ${MAX_GITER} \
--resultsroot ./out \
--gpu_device ${GPU_ID}"

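# ECFD-specific flags (number of frequencies, weighting distribution, sigma) apply only to the cfgangp model.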
if [ "${MODEL}" = "cfgangp" ]; then
cmd+=" --num_freqs ${NUM_FREQS} --weight ${WEIGHT} --sigmas ${SIGMA}"
fi

echo $cmd
eval $cmd

40 changes: 40 additions & 0 deletions celeba128.sh
@@ -0,0 +1,40 @@
#!/bin/bash
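# Same pipeline on 128x128 CelebA, with a ResNet generator and a dcgan5 discriminator.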

BS=64
GPU_ID=2
MAX_GITER=125000
DATA_PATH=./data
DATASET=celeba128
DATAROOT=${DATA_PATH}/celebA
ISIZE=128
NC=3
NOISE_DIM=100

MODEL=cfgangp
DOUT_DIM=1
NUM_FREQS=8
WEIGHT=gaussian_ecfd
SIGMA=0.

cmd="python src/main.py \
--dataset ${DATASET} \
--dataroot ${DATAROOT} \
--model ${MODEL} \
--gen resnet \
--disc dcgan5 \
--batch_size ${BS} \
--image_size ${ISIZE} \
--nc ${NC} \
--noise_dim ${NOISE_DIM} \
--dout_dim ${DOUT_DIM} \
--max_giter ${MAX_GITER} \
--resultsroot ./out \
--gpu_device ${GPU_ID}"

if [ "${MODEL}" = "cfgangp" ]; then
cmd+=" --num_freqs ${NUM_FREQS} --weight ${WEIGHT} --sigmas ${SIGMA}"
fi

echo $cmd
eval $cmd

36 changes: 36 additions & 0 deletions cifar10.sh
@@ -0,0 +1,36 @@
#!/bin/bash
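# Train OCFGAN-GP on CIFAR10 (32x32 RGB); this mirrors the README training example.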

BS=64
GPU_ID=0
MAX_GITER=125000
DATA_PATH=./data
DATASET=cifar10
DATAROOT=${DATA_PATH}/cifar10
ISIZE=32
NC=3
NOISE_DIM=32
MODEL=cfgangp
DOUT_DIM=${NOISE_DIM}
NUM_FREQS=8
WEIGHT=gaussian_ecfd
SIGMA=0.

cmd="python src/main.py \
--dataset ${DATASET} \
--dataroot ${DATAROOT} \
--model ${MODEL} \
--batch_size ${BS} \
--image_size ${ISIZE} \
--nc ${NC} \
--noise_dim ${NOISE_DIM} \
--dout_dim ${DOUT_DIM} \
--max_giter ${MAX_GITER} \
--resultsroot ./out \
--gpu_device ${GPU_ID}"

if [ "${MODEL}" = "cfgangp" ]; then
cmd+=" --num_freqs ${NUM_FREQS} --weight ${WEIGHT} --sigmas ${SIGMA}"
fi

echo $cmd
eval $cmd
Empty file added data/.gitkeep
Empty file.
180 changes: 180 additions & 0 deletions download.py
@@ -0,0 +1,180 @@
"""
Modification of https://github.com/stanfordnlp/treelstm/blob/master/scripts/download.py
Downloads the following:
- Celeb-A dataset
- LSUN dataset
- MNIST dataset
"""

from __future__ import print_function
import os
import sys
import json
import zipfile
import argparse
import requests
import subprocess
from tqdm import tqdm
from six.moves import urllib

parser = argparse.ArgumentParser(description='Download datasets (celebA, lsun, mnist).')
parser.add_argument('datasets', metavar='N', type=str, nargs='+', choices=['celebA', 'lsun', 'mnist'],
                    help='name of dataset to download [celebA, lsun, mnist]')

def download(url, dirpath):
filename = url.split('/')[-1]
filepath = os.path.join(dirpath, filename)
u = urllib.request.urlopen(url)
f = open(filepath, 'wb')
filesize = int(u.headers["Content-Length"])
print("Downloading: %s Bytes: %s" % (filename, filesize))

downloaded = 0
block_sz = 8192
status_width = 70
while True:
buf = u.read(block_sz)
if not buf:
print('')
break
else:
print('', end='\r')
downloaded += len(buf)
f.write(buf)
status = (("[%-" + str(status_width + 1) + "s] %3.2f%%") %
('=' * int(float(downloaded) / filesize * status_width) + '>', downloaded * 100. / filesize))
print(status, end='')
sys.stdout.flush()
f.close()
return filepath

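# Google Drive serves an interstitial warning page for large files; the helpers
# below read the 'download_warning' cookie to recover a confirm token and
# re-issue the request with it.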
def download_file_from_google_drive(id, destination):
URL = "https://docs.google.com/uc?export=download"
session = requests.Session()

response = session.get(URL, params={ 'id': id }, stream=True)
token = get_confirm_token(response)

if token:
params = { 'id' : id, 'confirm' : token }
response = session.get(URL, params=params, stream=True)

save_response_content(response, destination)

def get_confirm_token(response):
for key, value in response.cookies.items():
if key.startswith('download_warning'):
return value
return None

def save_response_content(response, destination, chunk_size=32*1024):
    total_size = int(response.headers.get('content-length', 0))
    with open(destination, "wb") as f:
        # tqdm iterates over chunks, so scale the total (and displayed rate) back to bytes
        for chunk in tqdm(response.iter_content(chunk_size),
                          total=(total_size + chunk_size - 1) // chunk_size,
                          unit='B', unit_scale=chunk_size, desc=destination):
            if chunk:  # filter out keep-alive chunks
                f.write(chunk)

def unzip(filepath):
print("Extracting: " + filepath)
dirpath = os.path.dirname(filepath)
with zipfile.ZipFile(filepath) as zf:
zf.extractall(dirpath)
os.remove(filepath)

def download_celeb_a(dirpath):
data_dir = 'celebA'
if os.path.exists(os.path.join(dirpath, data_dir)):
print('Found Celeb-A - skip')
return

filename, drive_id = "img_align_celeba.zip", "0B7EVK8r0v71pZjFTYXZWM3FlRnM"
save_path = os.path.join(dirpath, filename)

if os.path.exists(save_path):
print('[*] {} already exists'.format(save_path))
else:
download_file_from_google_drive(drive_id, save_path)

zip_dir = ''
with zipfile.ZipFile(save_path) as zf:
zip_dir = zf.namelist()[0]
zf.extractall(dirpath)
os.remove(save_path)
os.rename(os.path.join(dirpath, zip_dir), os.path.join(dirpath, data_dir))

def _list_categories(tag):
url = 'http://lsun.cs.princeton.edu/htbin/list.cgi?tag=' + tag
f = urllib.request.urlopen(url)
return json.loads(f.read())

def _download_lsun(out_dir, category, set_name, tag):
url = 'http://lsun.cs.princeton.edu/htbin/download.cgi?tag={tag}' \
'&category={category}&set={set_name}'.format(**locals())
print(url)
if set_name == 'test':
out_name = 'test_lmdb.zip'
else:
out_name = '{category}_{set_name}_lmdb.zip'.format(**locals())
out_path = os.path.join(out_dir, out_name)
cmd = ['curl', url, '-o', out_path]
print('Downloading', category, set_name, 'set')
subprocess.call(cmd)

def download_lsun(dirpath):
data_dir = os.path.join(dirpath, 'lsun')
if os.path.exists(data_dir):
print('Found LSUN - skip')
return
else:
os.mkdir(data_dir)

tag = 'latest'
#categories = _list_categories(tag)
categories = ['bedroom']

for category in categories:
_download_lsun(data_dir, category, 'train', tag)
_download_lsun(data_dir, category, 'val', tag)
_download_lsun(data_dir, '', 'test', tag)

def download_mnist(dirpath):
data_dir = os.path.join(dirpath, 'mnist')
if os.path.exists(data_dir):
print('Found MNIST - skip')
return
else:
os.mkdir(data_dir)
url_base = 'http://yann.lecun.com/exdb/mnist/'
file_names = ['train-images-idx3-ubyte.gz',
'train-labels-idx1-ubyte.gz',
't10k-images-idx3-ubyte.gz',
't10k-labels-idx1-ubyte.gz']
for file_name in file_names:
        url = url_base + file_name
print(url)
out_path = os.path.join(data_dir,file_name)
cmd = ['curl', url, '-o', out_path]
print('Downloading ', file_name)
subprocess.call(cmd)
cmd = ['gzip', '-d', out_path]
print('Decompressing ', file_name)
subprocess.call(cmd)

def prepare_data_dir(path = './data'):
if not os.path.exists(path):
os.mkdir(path)

if __name__ == '__main__':
args = parser.parse_args()
prepare_data_dir()

    if 'celebA' in args.datasets:
download_celeb_a('./data')
if 'lsun' in args.datasets:
download_lsun('./data')
if 'mnist' in args.datasets:
download_mnist('./data')
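
For reference, the downloader accepts any non-empty subset of the three dataset names declared in the argparse choices above:

```bash
python download.py celebA mnist   # e.g., CelebA and MNIST only
python download.py lsun           # LSUN (bedroom category) only
```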
36 changes: 36 additions & 0 deletions mnist.sh
@@ -0,0 +1,36 @@
#!/bin/bash
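# Train OCFGAN-GP on MNIST (32x32, single channel) with a shorter 50k-generator-iteration budget.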

BS=64
GPU_ID=0
MAX_GITER=50000
DATA_PATH=./data
DATASET=mnist
DATAROOT=${DATA_PATH}/mnist
ISIZE=32
NC=1
NOISE_DIM=10
MODEL=cfgangp
DOUT_DIM=${NOISE_DIM}
NUM_FREQS=8
WEIGHT=gaussian_ecfd
SIGMA=0.

cmd="python src/main.py \
--dataset ${DATASET} \
--dataroot ${DATAROOT} \
--model ${MODEL} \
--batch_size ${BS} \
--image_size ${ISIZE} \
--nc ${NC} \
--noise_dim ${NOISE_DIM} \
--dout_dim ${DOUT_DIM} \
--max_giter ${MAX_GITER} \
--resultsroot ./out \
--gpu_device ${GPU_ID}"

if [ "${MODEL}" = "cfgangp" ]; then
cmd+=" --num_freqs ${NUM_FREQS} --weight ${WEIGHT} --sigmas ${SIGMA}"
fi

echo $cmd
eval $cmd
