In [1]:
CONFIG_FILE = 'experiments/cfgs/gans/mnist.yml'
DATASETS = ['mnist']

In [2]:
import argparse
import sys

import tensorflow as tf

from models.gan import MnistDefenseGAN, FmnistDefenseDefenseGAN, \
    CelebADefenseGAN
from utils.config import load_config

Instructions for updating:
Use the retry module or similar alternatives.


In [3]:
from __future__ import print_function
import os
import sys
import json
import zipfile
import argparse
import requests
import subprocess
from tqdm import tqdm
from six.moves import urllib

parser = argparse.ArgumentParser(description='Download dataset for DCGAN.')
parser.add_argument('datasets', metavar='N', type=str, nargs='+', choices=['celebA', 'lsun', 'mnist','f-mnist'],
           help='name of dataset to download [celebA, lsun, mnist, fmnist]')

def download(url, dirpath):
  filename = url.split('/')[-1]
  filepath = os.path.join(dirpath, filename)
  u = urllib.request.urlopen(url)
  f = open(filepath, 'wb')
  filesize = int(u.headers["Content-Length"])
  print("Downloading: %s Bytes: %s" % (filename, filesize))

  downloaded = 0
  block_sz = 8192
  status_width = 70
  while True:
    buf = u.read(block_sz)
    if not buf:
      print('')
      break
    else:
      print('', end='\r')
    downloaded += len(buf)
    f.write(buf)
    status = (("[%-" + str(status_width + 1) + "s] %3.2f%%") %
      ('=' * int(float(downloaded) / filesize * status_width) + '>', downloaded * 100. / filesize))
    print(status, end='')
    sys.stdout.flush()
  f.close()
  return filepath

def download_file_from_google_drive(id, destination):
  URL = "https://docs.google.com/uc?export=download"
  session = requests.Session()

  response = session.get(URL, params={ 'id': id }, stream=True,verify=False)
  token = get_confirm_token(response)

  if token:
    params = { 'id' : id, 'confirm' : token }
    response = session.get(URL, params=params, stream=True,verify=False)

  save_response_content(response, destination)

def get_confirm_token(response):
  for key, value in response.cookies.items():
    if key.startswith('download_warning'):
      return value
  return None

def save_response_content(response, destination, chunk_size=32*1024):
  total_size = int(response.headers.get('content-length', 0))
  with open(destination, "wb") as f:
    for chunk in tqdm(response.iter_content(chunk_size), total=total_size,
              unit='B', unit_scale=True, desc=destination):
      if chunk: # filter out keep-alive new chunks
        f.write(chunk)

def unzip(filepath):
  print("Extracting: " + filepath)
  dirpath = os.path.dirname(filepath)
  with zipfile.ZipFile(filepath) as zf:
    zf.extractall(dirpath)
  os.remove(filepath)

def download_celeb_a(dirpath):
  data_dir = 'celebA'
  if os.path.exists(os.path.join(dirpath, data_dir)):
    print('Found Celeb-A - skip')
    return

  filename, drive_id  = "img_align_celeba.zip", "0B7EVK8r0v71pZjFTYXZWM3FlRnM"
  save_path = os.path.join(dirpath, filename)

  if os.path.exists(save_path):
    print('[*] {} already exists'.format(save_path))
  else:
    download_file_from_google_drive(drive_id, save_path)

  zip_dir = ''
  with zipfile.ZipFile(save_path) as zf:
    zip_dir = zf.namelist()[0]
    zf.extractall(dirpath)
  os.remove(save_path)
  os.rename(os.path.join(dirpath, zip_dir), os.path.join(dirpath, data_dir))

def _list_categories(tag):
  url = 'http://lsun.cs.princeton.edu/htbin/list.cgi?tag=' + tag
  f = urllib.request.urlopen(url)
  return json.loads(f.read())

def _download_lsun(out_dir, category, set_name, tag):
  url = 'http://lsun.cs.princeton.edu/htbin/download.cgi?tag={tag}' \
      '&category={category}&set={set_name}'.format(**locals())
  print(url)
  if set_name == 'test':
    out_name = 'test_lmdb.zip'
  else:
    out_name = '{category}_{set_name}_lmdb.zip'.format(**locals())
  out_path = os.path.join(out_dir, out_name)
  cmd = ['curl', url, '-o', out_path]
  print('Downloading', category, set_name, 'set')
  subprocess.call(cmd)

def download_lsun(dirpath):
  data_dir = os.path.join(dirpath, 'lsun')
  if os.path.exists(data_dir):
    print('Found LSUN - skip')
    return
  else:
    os.mkdir(data_dir)

  tag = 'latest'
  categories = ['bedroom']

  for category in categories:
    _download_lsun(data_dir, category, 'train', tag)
    _download_lsun(data_dir, category, 'val', tag)
  _download_lsun(data_dir, '', 'test', tag)

def download_mnist(dirpath):
  data_dir = os.path.join(dirpath, 'mnist')
  if os.path.exists(data_dir):
    print('Found MNIST - skip')
    return
  else:
    os.mkdir(data_dir)
  url_base = 'http://yann.lecun.com/exdb/mnist/'
  file_names = ['train-images-idx3-ubyte.gz',
                'train-labels-idx1-ubyte.gz',
                't10k-images-idx3-ubyte.gz',
                't10k-labels-idx1-ubyte.gz']
  for file_name in file_names:
    url = (url_base+file_name).format(**locals())
    print(url)
    out_path = os.path.join(data_dir,file_name)
    cmd = ['curl', url, '-o', out_path]
    print('Downloading ', file_name)
    subprocess.call(cmd)
    cmd = ['gzip', '-d', out_path]
    print('Decompressing ', file_name)
    subprocess.call(cmd)


def download_fmnist(dirpath):
  data_dir = os.path.join(dirpath, 'f-mnist')
  if os.path.exists(data_dir):
    print('Found F-MNIST - skip')
    return
  else:
    os.mkdir(data_dir)
  url_base = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'
  file_names = ['train-images-idx3-ubyte.gz',
                'train-labels-idx1-ubyte.gz',
                't10k-images-idx3-ubyte.gz',
                't10k-labels-idx1-ubyte.gz']
  for file_name in file_names:
    url = (url_base+file_name).format(**locals())
    print(url)
    out_path = os.path.join(data_dir,file_name)
    cmd = ['curl', url, '-o', out_path]
    print('Downloading ', file_name)
    subprocess.call(cmd)
    cmd = ['gzip', '-d', out_path]
    print('Decompressing ', file_name)
    subprocess.call(cmd)

def prepare_data_dir(path = './data'):
  if not os.path.exists(path):
    os.mkdir(path)

In [None]:
  prepare_data_dir()

  if any(name in DATASETS for name in ['CelebA', 'celebA', 'celebA']):
    download_celeb_a('./data')
  if 'mnist' in DATASETS:
    download_mnist('./data')
  if 'fmnist' in DATASETS:
    download_fmnist('./data')

In [4]:
def main(cfg, *args):
    FLAGS = tf.app.flags.FLAGS
    ds_gan = {
        'mnist': MnistDefenseGAN, 'f-mnist': FmnistDefenseDefenseGAN,
        'celeba': CelebADefenseGAN,
    }
    GAN = ds_gan[FLAGS.dataset_name]

    gan = GAN(cfg=cfg, test_mode=not FLAGS.is_train)

    if FLAGS.is_train:
        gan.train()

    if FLAGS.train_encoder:
        gan.load(checkpoint_dir=FLAGS.init_path)
        gan.train(phase='just_enc')

    if FLAGS.save_recs:
        gan.reconstruct_dataset(ckpt_path=FLAGS.init_path,
                                max_num=FLAGS.max_num)

    if FLAGS.test_generator:
        gan.load_generator(ckpt_path=FLAGS.init_path)
        gan.sess.run(gan.global_step.initializer)
        gan.generate_image(iteration=0)

    if FLAGS.test_batch:
        gan.test_batch()

    if FLAGS.save_ds:
        gan.save_ds()


In [None]:
    # Note: The load_config() call will convert all the parameters that are defined in
    # experiments/config files into FLAGS.param_name and can be passed in from command line.
    # arguments : python train.py --cfg <config_path> --<param_name> <param_value>
    cfg = load_config(CONFIG_FILE)
    flags = tf.app.flags

    flags.DEFINE_boolean("is_train", False,
                         "True for training, False for testing. [False]")
    flags.DEFINE_boolean("save_recs", False,
                         "True for saving reconstructions. [False]")
    flags.DEFINE_boolean("debug", False,
                         "True for debug. [False]")
    flags.DEFINE_boolean("test_generator", False,
                         "True for generator samples. [False]")
    flags.DEFINE_boolean("test_decoder", False,
                         "True for decoder samples. [False]")
    flags.DEFINE_boolean("test_again", False,
                         "True for not using cache. [False]")
    flags.DEFINE_boolean("test_batch", False,
                         "True for visualizing the batches and labels. [False]")
    flags.DEFINE_boolean("save_ds", False,
                         "True for saving the dataset in a pickle file. ["
                         "False]")
    flags.DEFINE_boolean("tensorboard_log", True, "True for saving "
                                                  "tensorboard logs. [True]")
    flags.DEFINE_boolean("train_encoder", False,
                         "Add an encoder to a pretrained model. ["
                         "False]")
    flags.DEFINE_boolean("init_with_enc", False,
                         "Initializes the z with an encoder, must run "
                         "--train_encoder first. [False]")
    flags.DEFINE_integer("max_num", -1,
                         "True for saving the dataset in a pickle file ["
                         "False]")
    flags.DEFINE_string("init_path", None, "Checkpoint path. [None]")


In [13]:
%tb

main_cfg = lambda x: main(cfg, x)
tf.app.run(main=main_cfg)


SystemExit: 

[#] MnistDefenseGAN.dataset_name is set to mnist.
[#] MnistDefenseGAN.batch_size is set to 50.
[#] MnistDefenseGAN.use_bn is set to False.
[#] MnistDefenseGAN.test_batch_size is set to 20.
[#] MnistDefenseGAN.mode is set to wgan-gp.
[#] MnistDefenseGAN.gradient_penalty_lambda is set to 10.0.
[#] MnistDefenseGAN.train_iters is set to 200000.
[#] MnistDefenseGAN.critic_iters is set to 5.
[#] MnistDefenseGAN.latent_dim is set to 128.
[#] MnistDefenseGAN.net_dim is set to 64.
[#] MnistDefenseGAN.input_transform_type is set to 0.
[#] MnistDefenseGAN.debug is set to False.
[#] MnistDefenseGAN.rec_iters is set to 200.
[#] MnistDefenseGAN.image_dim is set to [28, 28, 1].
[#] MnistDefenseGAN.rec_rr is set to 10.
[#] MnistDefenseGAN.rec_lr is set to 10.0.
[#] MnistDefenseGAN.test_again is set to False.
[-] MnistDefenseGAN.loss_type is not set.
[#] MnistDefenseGAN.loss_type is set to None.
[-] MnistDefenseGAN.attribute is not set.
[#] MnistDefenseGAN.attribute is set to None.
[#] MnistDefenseGAN.

SystemExit: 