add lsun dataset / remove useless shape information in tfrecord
khanrc committed Sep 10, 2017
1 parent fd0a9db commit 4f7f9e6
Showing 4 changed files with 67 additions and 11 deletions.
17 changes: 17 additions & 0 deletions config.py
@@ -27,6 +27,23 @@ def get_model(mtype, name, training):
    return model(name=name, training=training)


def get_dataset(dataset_name):
    celebA_64 = './data/celebA_tfrecords/*.tfrecord'
    celebA_128 = './data/celebA_128_tfrecords/*.tfrecord'
    lsun_bedroom_128 = './data/lsun/bedroom_128_tfrecords/*.tfrecord'

    if dataset_name == 'celeba':
        path = celebA_128
        n_examples = 202599
    elif dataset_name == 'lsun':
        path = lsun_bedroom_128
        n_examples = 3033042
    else:
        raise ValueError('{} is not supported. dataset must be celeba or lsun.'.format(dataset_name))

    return path, n_examples


def pprint_args(FLAGS):
    print("\nParameters:")
    for attr, value in sorted(vars(FLAGS).items()):
41 changes: 38 additions & 3 deletions convert.py
@@ -82,14 +82,49 @@ def convert(source_dir, target_dir, crop_size, out_size, exts=[''], num_shards=1

        im = scipy.misc.imresize(im, out_size)
        example = tf.train.Example(features=tf.train.Features(feature={
            "shape": _int64_features(im.shape),
            # "shape": _int64_features(im.shape),
            "image": _bytes_features([im.tostring()])
        }))
        writer.write(example.SerializeToString())

    writer.close()
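
Note: the feature helpers referenced above (_int64_features, _bytes_features) are defined earlier in convert.py and fall outside this hunk. A minimal sketch of what they presumably look like, using the standard TF 1.x protobuf wrappers:

import tensorflow as tf

def _int64_features(values):
    # e.g. _int64_features(im.shape) -> Int64List feature holding the 3 dims
    return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))

def _bytes_features(values):
    # e.g. _bytes_features([im.tostring()]) -> BytesList feature of raw pixel bytes
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))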


''' The function below is borrowed from https://github.com/fyu/lsun.
Process: LMDB => images => tfrecords.
Skipping the intermediate images would be more efficient, but it is a messy job;
going through images is inefficient but convenient.
'''
def export_images(db_path, out_dir, flat=False, limit=-1):
    print('Exporting {} to {}'.format(db_path, out_dir))
    env = lmdb.open(db_path, map_size=1099511627776, max_readers=100, readonly=True)
    num_images = env.stat()['entries']
    count = 0
    with env.begin(write=False) as txn:
        cursor = txn.cursor()
        for key, val in cursor:
            if not flat:
                # spread images over nested dirs keyed by the first 6 chars of the key
                image_out_dir = join(out_dir, '/'.join(key[:6]))
            else:
                image_out_dir = out_dir
            if not exists(image_out_dir):
                os.makedirs(image_out_dir)
            image_out_path = join(image_out_dir, key + '.webp')
            with open(image_out_path, 'wb') as fp:  # binary mode: .webp values are raw bytes
                fp.write(val)
            count += 1
            if count == limit:
                break
            if count % 10000 == 0:
                print('{}/{} ...'.format(count, num_images))


if __name__ == "__main__":
    convert('./data/celebA', './data/celebA_tfrecords_test', crop_size=[128, 128], out_size=[64, 64], exts=['jpg'],
            num_shards=128, tfrecords_prefix='celebA')
    # CelebA
    # convert('./data/celebA', './data/celebA_tfrecords_test', crop_size=[128, 128], out_size=[64, 64],
    #         exts=['jpg'], num_shards=128, tfrecords_prefix='celebA')

    # LSUN
    # export_images('./tf.gans-comparison/data/lsun/bedroom_val_lmdb/', './tf.gans-comparison/data/lsun/bedroom_val_images/', flat=True)
    convert('./data/lsun/bedroom_train_images', './data/lsun/bedroom_128_tfrecords', crop_size=[128, 128], out_size=[128, 128],
            exts=['webp'], num_shards=128, tfrecords_prefix='lsun_bedroom')
6 changes: 2 additions & 4 deletions inputpipe.py
@@ -12,15 +12,13 @@ def read_parse_preproc(filename_queue):
    features = tf.parse_single_example(
        records,
        features={
            "shape": tf.FixedLenFeature([3], tf.int64),
            "image": tf.FixedLenFeature([], tf.string)
        }
    )

    image = tf.decode_raw(features["image"], tf.uint8)
    # shape = tf.cast(features["shape"], tf.int32) # useless

    image = tf.reshape(image, [64, 64, 3]) # The image shape must be explicitly specified
    image = tf.reshape(image, [128, 128, 3]) # The image shape must be explicitly specified
    image = tf.image.resize_images(image, [64, 64])
    image = tf.cast(image, tf.float32)
    image = image / 127.5 - 1.0 # preproc - normalize
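
As a quick sanity check that each record really decodes to a 128x128x3 uint8 image, a single serialized example can be parsed without the queue runners. A minimal sketch, assuming TF 1.x and numpy; the shard filename is hypothetical, since the sharding scheme lives in convert.py:

import numpy as np
import tensorflow as tf

fn = './data/lsun/bedroom_128_tfrecords/lsun_bedroom-0-of-128.tfrecord'  # hypothetical name
record = next(tf.python_io.tf_record_iterator(fn))
example = tf.train.Example()
example.ParseFromString(record)
raw = example.features.feature['image'].bytes_list.value[0]
im = np.frombuffer(raw, np.uint8).reshape(128, 128, 3)  # raises if the record size drifts
print(im.shape, im.dtype)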

14 changes: 10 additions & 4 deletions train.py
@@ -17,6 +17,7 @@ def build_parser():
    models_str = ' / '.join(config.model_zoo)
    parser.add_argument('--model', help=models_str, required=True) # DRAGAN, CramerGAN
    parser.add_argument('--name', help='default: name=model')
    parser.add_argument('--dataset', help='CelebA / LSUN', required=True)
    parser.add_argument('--renew', action='store_true', help='train model from scratch - \
        clean saved checkpoints and summaries', default=False)
    # more arguments: dataset
@@ -26,9 +27,9 @@

def input_pipeline(glob_pattern, batch_size, num_threads, num_epochs):
    tfrecords_list = glob.glob(glob_pattern)
    num_examples = utils.num_examples_from_tfrecords(tfrecords_list)
    # num_examples = utils.num_examples_from_tfrecords(tfrecords_list) # takes too long for lsun
    X = ip.shuffle_batch_join(tfrecords_list, batch_size=batch_size, num_threads=num_threads, num_epochs=num_epochs)
    return X, num_examples
    return X
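
For context on the commented-out line: counting examples the way utils.num_examples_from_tfrecords presumably does requires a full scan of every shard, which is cheap for celebA (~200k records) but slow for lsun bedroom (~3M). A rough sketch of such a counter (assumed implementation, TF 1.x):

import tensorflow as tf

def num_examples_from_tfrecords(tfrecords_list):
    # Full pass over every record in every shard - O(total records).
    return sum(sum(1 for _ in tf.python_io.tf_record_iterator(fn))
               for fn in tfrecords_list)

Hardcoding n_examples in config.get_dataset sidesteps this scan, at the cost of keeping the counts in sync with the data.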


def sample_z(shape):
@@ -66,7 +67,8 @@ def train(model, input_op, num_epochs, batch_size, n_examples, renew=False):
    # make config_summary before defining summary_writer - bypasses a tensorboard bug

    # It seems that batch_size should have been contained in the model config ...
    model_config_list = [[k, str(w)] for k, w in sorted(model.args.items()) + [('batch_size', batch_size)]]
    config_list = [('batch_size', batch_size), ('dataset', FLAGS.dataset)]
    model_config_list = [[k, str(w)] for k, w in sorted(model.args.items()) + config_list]
    model_config_summary_op = tf.summary.text('config', tf.convert_to_tensor(model_config_list), collections=[])
    model_config_summary = sess.run(model_config_summary_op)

@@ -119,12 +121,16 @@ def train(model, input_op, num_epochs, batch_size, n_examples, renew=False):
    parser = build_parser()
    FLAGS = parser.parse_args()
    FLAGS.model = FLAGS.model.upper()
    FLAGS.dataset = FLAGS.dataset.lower()
    if FLAGS.name is None:
        FLAGS.name = FLAGS.model.lower()
    config.pprint_args(FLAGS)

    # get dataset information
    dataset_pattern, n_examples = config.get_dataset(FLAGS.dataset)

    # input pipeline
    X, n_examples = input_pipeline('./data/celebA_tfrecords/*.tfrecord', batch_size=FLAGS.batch_size,
    X = input_pipeline(dataset_pattern, batch_size=FLAGS.batch_size,
                       num_threads=FLAGS.num_threads, num_epochs=FLAGS.num_epochs)
    model = config.get_model(FLAGS.model, FLAGS.name, training=True)
    train(model=model, input_op=X, num_epochs=FLAGS.num_epochs, batch_size=FLAGS.batch_size,
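
With these changes, training can presumably be launched with something like the following; --model and --dataset are the required flags, and the model name is assumed to be a valid entry in config.model_zoo (e.g. DRAGAN from the parser's hint above):

python train.py --model dragan --dataset lsun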
