From 4f7f9e6c4c3628bdb32c23a1aca621e319fc968d Mon Sep 17 00:00:00 2001 From: khanrc Date: Sun, 10 Sep 2017 23:23:17 +0900 Subject: [PATCH] add lsun dataset / remove useless shape information in tfrecord --- config.py | 17 +++++++++++++++++ convert.py | 41 ++++++++++++++++++++++++++++++++++++++--- inputpipe.py | 6 ++---- train.py | 14 ++++++++++---- 4 files changed, 67 insertions(+), 11 deletions(-) diff --git a/config.py b/config.py index 3547f8b..9fa207f 100644 --- a/config.py +++ b/config.py @@ -27,6 +27,23 @@ def get_model(mtype, name, training): return model(name=name, training=training) +def get_dataset(dataset_name): + celebA_64 = './data/celebA_tfrecords/*.tfrecord' + celebA_128 = './data/celebA_128_tfrecords/*.tfrecord' + lsun_bedroom_128 = './data/lsun/bedroom_128_tfrecords/*.tfrecord' + + if dataset_name == 'celeba': + path = celebA_128 + n_examples = 202599 + elif dataset_name == 'lsun': + path = lsun_bedroom_128 + n_examples = 3033042 + else: + raise ValueError('{} is does not supported. dataset must be celeba or lsun.'.format(dataset_name)) + + return path, n_examples + + def pprint_args(FLAGS): print("\nParameters:") for attr, value in sorted(vars(FLAGS).items()): diff --git a/convert.py b/convert.py index 8130bd1..63bd4b4 100644 --- a/convert.py +++ b/convert.py @@ -82,7 +82,7 @@ def convert(source_dir, target_dir, crop_size, out_size, exts=[''], num_shards=1 im = scipy.misc.imresize(im, out_size) example = tf.train.Example(features=tf.train.Features(feature={ - "shape": _int64_features(im.shape), + # "shape": _int64_features(im.shape), "image": _bytes_features([im.tostring()]) })) writer.write(example.SerializeToString()) @@ -90,6 +90,41 @@ def convert(source_dir, target_dir, crop_size, out_size, exts=[''], num_shards=1 writer.close() +''' Below function burrowed from https://github.com/fyu/lsun. +Process: LMDB => images => tfrecords +It is more efficient method to skip intermediate images, but that is a little messy job. +The method through images is inefficient but convenient. +''' +def export_images(db_path, out_dir, flat=False, limit=-1): + print('Exporting {} to {}'.format(db_path, out_dir)) + env = lmdb.open(db_path, map_size=1099511627776, max_readers=100, readonly=True) + num_images = env.stat()['entries'] + count = 0 + with env.begin(write=False) as txn: + cursor = txn.cursor() + for key, val in cursor: + if not flat: + image_out_dir = join(out_dir, '/'.join(key[:6])) + else: + image_out_dir = out_dir + if not exists(image_out_dir): + os.makedirs(image_out_dir) + image_out_path = join(image_out_dir, key + '.webp') + with open(image_out_path, 'w') as fp: + fp.write(val) + count += 1 + if count == limit: + break + if count % 10000 == 0: + print('{}/{} ...'.format(count, num_images)) + + if __name__ == "__main__": - convert('./data/celebA', './data/celebA_tfrecords_test', crop_size=[128, 128], out_size=[64, 64], exts=['jpg'], - num_shards=128, tfrecords_prefix='celebA') + # CelebA + # convert('./data/celebA', './data/celebA_tfrecords_test', crop_size=[128, 128], out_size=[64, 64], + # exts=['jpg'], num_shards=128, tfrecords_prefix='celebA') + + # LSUN + # export_images('./tf.gans-comparison/data/lsun/bedroom_val_lmdb/', './tf.gans-comparison/data/lsun/bedroom_val_images/', flat=True) + convert('./data/lsun/bedroom_train_images', './data/lsun/bedroom_128_tfrecords', crop_size=[128, 128], out_size=[128, 128], + exts=['webp'], num_shards=128, tfrecords_prefix='lsun_bedroom') \ No newline at end of file diff --git a/inputpipe.py b/inputpipe.py index 0ea6abf..b9e6e50 100644 --- a/inputpipe.py +++ b/inputpipe.py @@ -12,15 +12,13 @@ def read_parse_preproc(filename_queue): features = tf.parse_single_example( records, features={ - "shape": tf.FixedLenFeature([3], tf.int64), "image": tf.FixedLenFeature([], tf.string) } ) image = tf.decode_raw(features["image"], tf.uint8) - # shape = tf.cast(features["shape"], tf.int32) # useless - - image = tf.reshape(image, [64, 64, 3]) # The image_shape must be explicitly specified + image = tf.reshape(image, [128, 128, 3]) # The image_shape must be explicitly specified + image = tf.image.resize_images(image, [64, 64]) image = tf.cast(image, tf.float32) image = image / 127.5 - 1.0 # preproc - normalize diff --git a/train.py b/train.py index 7237e52..0bd0cbe 100644 --- a/train.py +++ b/train.py @@ -17,6 +17,7 @@ def build_parser(): models_str = ' / '.join(config.model_zoo) parser.add_argument('--model', help=models_str, required=True) # DRAGAN, CramerGAN parser.add_argument('--name', help='default: name=model') + parser.add_argument('--dataset', help='CelebA / LSUN', required=True) parser.add_argument('--renew', action='store_true', help='train model from scratch - \ clean saved checkpoints and summaries', default=False) # more arguments: dataset @@ -26,9 +27,9 @@ def build_parser(): def input_pipeline(glob_pattern, batch_size, num_threads, num_epochs): tfrecords_list = glob.glob(glob_pattern) - num_examples = utils.num_examples_from_tfrecords(tfrecords_list) + # num_examples = utils.num_examples_from_tfrecords(tfrecords_list) # takes too long time for lsun X = ip.shuffle_batch_join(tfrecords_list, batch_size=batch_size, num_threads=num_threads, num_epochs=num_epochs) - return X, num_examples + return X def sample_z(shape): @@ -66,7 +67,8 @@ def train(model, input_op, num_epochs, batch_size, n_examples, renew=False): # make config_summary before define of summary_writer - bypass bug of tensorboard # It seems that batch_size should have been contained in the model config ... - model_config_list = [[k, str(w)] for k, w in sorted(model.args.items()) + [('batch_size', batch_size)]] + config_list = [('batch_size', batch_size), ('dataset', FLAGS.dataset)] + model_config_list = [[k, str(w)] for k, w in sorted(model.args.items()) + config_list] model_config_summary_op = tf.summary.text('config', tf.convert_to_tensor(model_config_list), collections=[]) model_config_summary = sess.run(model_config_summary_op) @@ -119,12 +121,16 @@ def train(model, input_op, num_epochs, batch_size, n_examples, renew=False): parser = build_parser() FLAGS = parser.parse_args() FLAGS.model = FLAGS.model.upper() + FLAGS.dataset = FLAGS.dataset.lower() if FLAGS.name is None: FLAGS.name = FLAGS.model.lower() config.pprint_args(FLAGS) + # get information for dataset + dataset_pattern, n_examples = config.get_dataset(FLAGS.dataset) + # input pipeline - X, n_examples = input_pipeline('./data/celebA_tfrecords/*.tfrecord', batch_size=FLAGS.batch_size, + X = input_pipeline(dataset_pattern, batch_size=FLAGS.batch_size, num_threads=FLAGS.num_threads, num_epochs=FLAGS.num_epochs) model = config.get_model(FLAGS.model, FLAGS.name, training=True) train(model=model, input_op=X, num_epochs=FLAGS.num_epochs, batch_size=FLAGS.batch_size,