Permalink
Browse files

Added sources

  • Loading branch information...
1 parent bb3fdd6 commit 153587ba60fbd926ba8ceeaee472e7df8ca7d181 David Garcia committed Aug 26, 2016
Showing with 866 additions and 0 deletions.
  1. +28 −0 srez_demo.py
  2. +53 −0 srez_input.py
  3. +190 −0 srez_main.py
  4. +481 −0 srez_model.py
  5. +114 −0 srez_train.py
View
@@ -0,0 +1,28 @@
+import moviepy.editor as mpe
+import numpy as np
+import numpy.random
+import os.path
+import scipy.misc
+import tensorflow as tf
+
+FLAGS = tf.app.flags.FLAGS
+
+def demo1(sess):
+ """Demo based on images dumped during training"""
+
+ # Get images that were dumped during training
+ filenames = tf.gfile.ListDirectory(FLAGS.train_dir)
+ filenames = sorted(filenames)
+ filenames = [os.path.join(FLAGS.train_dir, f) for f in filenames if f[-4:]=='.png']
+
+ assert len(filenames) >= 1
+
+ fps = 30
+
+ # Create video file from PNGs
+ print("Producing video file...")
+ filename = os.path.join(FLAGS.train_dir, 'demo1.mp4')
+ clip = mpe.ImageSequenceClip(filenames, fps=fps)
+ clip.write_videofile(filename)
+ print("Done!")
+
View
@@ -0,0 +1,53 @@
+import tensorflow as tf
+
+FLAGS = tf.app.flags.FLAGS
+
+def setup_inputs(sess, filenames, image_size=None, capacity_factor=3):
+
+ if image_size is None:
+ image_size = FLAGS.sample_size
+
+ # Read each JPEG file
+ reader = tf.WholeFileReader()
+ filename_queue = tf.train.string_input_producer(filenames)
+ key, value = reader.read(filename_queue)
+ channels = 3
+ image = tf.image.decode_jpeg(value, channels=channels, name="dataset_image")
+ image.set_shape([None, None, channels])
+
+ # Crop and other random augmentations
+ image = tf.image.random_flip_left_right(image)
+ image = tf.image.random_saturation(image, .95, 1.05)
+ image = tf.image.random_brightness(image, .05)
+ image = tf.image.random_contrast(image, .95, 1.05)
+
+ wiggle = 8
+ off_x, off_y = 25-wiggle, 60-wiggle
+ crop_size = 128
+ crop_size_plus = crop_size + 2*wiggle
+ image = tf.image.crop_to_bounding_box(image, off_y, off_x, crop_size_plus, crop_size_plus)
+ image = tf.random_crop(image, [crop_size, crop_size, 3])
+
+ image = tf.reshape(image, [1, crop_size, crop_size, 3])
+ image = tf.cast(image, tf.float32)/255.0
+
+ if crop_size != image_size:
+ image = tf.image.resize_area(image, [image_size, image_size])
+
+ # The feature is simply a Kx downscaled version
+ K = 4
+ downsampled = tf.image.resize_area(image, [image_size//K, image_size//K])
+
+ feature = tf.reshape(downsampled, [image_size//K, image_size//K, 3])
+ label = tf.reshape(image, [image_size, image_size, 3])
+
+ # Using asynchronous queues
+ features, labels = tf.train.batch([feature, label],
+ batch_size=FLAGS.batch_size,
+ num_threads=4,
+ capacity = capacity_factor*FLAGS.batch_size,
+ name='labels_and_features')
+
+ tf.train.start_queue_runners(sess=sess)
+
+ return features, labels
View
@@ -0,0 +1,190 @@
+import srez_demo
+import srez_input
+import srez_model
+import srez_train
+
+import os.path
+import random
+import numpy as np
+import numpy.random
+
+import tensorflow as tf
+
+FLAGS = tf.app.flags.FLAGS
+
+# Configuration (alphabetically)
+tf.app.flags.DEFINE_integer('batch_size', 16,
+ "Number of samples per batch.")
+
+tf.app.flags.DEFINE_string('checkpoint_dir', 'checkpoint',
+ "Output folder where checkpoints are dumped.")
+
+tf.app.flags.DEFINE_integer('checkpoint_period', 10000,
+ "Number of batches in between checkpoints")
+
+tf.app.flags.DEFINE_string('dataset', 'dataset',
+ "Path to the dataset directory.")
+
+tf.app.flags.DEFINE_float('epsilon', 1e-8,
+ "Fuzz term to avoid numerical instability")
+
+tf.app.flags.DEFINE_string('run', 'demo',
+ "Which operation to run. [demo|train]")
+
+tf.app.flags.DEFINE_float('gene_l1_factor', .90,
+ "Multiplier for generator L1 loss term")
+
+tf.app.flags.DEFINE_float('learning_beta1', 0.5,
+ "Beta1 parameter used for AdamOptimizer")
+
+tf.app.flags.DEFINE_float('learning_rate_start', 0.00050,
+ "Starting learning rate used for AdamOptimizer")
+
+tf.app.flags.DEFINE_integer('learning_rate_half_life', 10000,
+ "Number of batches until learning rate is halved")
+
+tf.app.flags.DEFINE_bool('log_device_placement', False,
+ "Log the device where variables are placed.")
+
+tf.app.flags.DEFINE_integer('sample_size', 64,
+ "Image sample size in pixels. Range [64,128]")
+
+tf.app.flags.DEFINE_integer('summary_period', 200,
+ "Number of batches between summary data dumps")
+
+tf.app.flags.DEFINE_integer('random_seed', 0,
+ "Seed used to initialize rng.")
+
+tf.app.flags.DEFINE_integer('test_vectors', 16,
+ """Number of features to use for testing""")
+
+tf.app.flags.DEFINE_string('train_dir', 'train',
+ "Output folder where training logs are dumped.")
+
+tf.app.flags.DEFINE_integer('train_time', 20,
+ "Time in minutes to train the model")
+
+def prepare_dirs(delete_train_dir=False):
+ # Create checkpoint dir (do not delete anything)
+ if not tf.gfile.Exists(FLAGS.checkpoint_dir):
+ tf.gfile.MakeDirs(FLAGS.checkpoint_dir)
+
+ # Cleanup train dir
+ if delete_train_dir:
+ if tf.gfile.Exists(FLAGS.train_dir):
+ tf.gfile.DeleteRecursively(FLAGS.train_dir)
+ tf.gfile.MakeDirs(FLAGS.train_dir)
+
+ # Return names of training files
+ if not tf.gfile.Exists(FLAGS.dataset) or \
+ not tf.gfile.IsDirectory(FLAGS.dataset):
+ raise FileNotFoundError("Could not find folder `%s'" % (FLAGS.dataset,))
+
+ filenames = tf.gfile.ListDirectory(FLAGS.dataset)
+ filenames = sorted(filenames)
+ random.shuffle(filenames)
+ filenames = [os.path.join(FLAGS.dataset, f) for f in filenames]
+
+ return filenames
+
+
+def setup_tensorflow():
+ # Create session
+ config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
+ sess = tf.Session(config=config)
+
+ # Initialize rng with a deterministic seed
+ with sess.graph.as_default():
+ tf.set_random_seed(FLAGS.random_seed)
+
+ random.seed(FLAGS.random_seed)
+ np.random.seed(FLAGS.random_seed)
+
+ summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)
+
+ return sess, summary_writer
+
+def _demo():
+ # Load checkpoint
+ if not tf.gfile.IsDirectory(FLAGS.checkpoint_dir):
+ raise FileNotFoundError("Could not find folder `%s'" % (FLAGS.checkpoint_dir,))
+
+ # Setup global tensorflow state
+ sess, summary_writer = setup_tensorflow()
+
+ # Prepare directories
+ filenames = prepare_dirs(delete_train_dir=False)
+
+ # Setup async input queues
+ features, labels = srez_input.setup_inputs(sess, filenames)
+
+ # Create and initialize model
+ [gene_minput, gene_moutput,
+ gene_output, gene_var_list,
+ disc_real_output, disc_fake_output, disc_var_list] = \
+ srez_model.create_model(sess, features, labels)
+
+ # Restore variables from checkpoint
+ saver = tf.train.Saver()
+ filename = 'checkpoint_new.txt'
+ filename = os.path.join(FLAGS.checkpoint_dir, filename)
+ saver.restore(sess, filename)
+
+ # Execute demo
+ srez_demo.demo1(sess)
+
+class TrainData(object):
+ def __init__(self, dictionary):
+ self.__dict__.update(dictionary)
+
+def _train():
+ # Setup global tensorflow state
+ sess, summary_writer = setup_tensorflow()
+
+ # Prepare directories
+ all_filenames = prepare_dirs(delete_train_dir=True)
+
+ # Separate training and test sets
+ train_filenames = all_filenames[:-FLAGS.test_vectors]
+ test_filenames = all_filenames[-FLAGS.test_vectors:]
+
+ # TBD: Maybe download dataset here
+
+ # Setup async input queues
+ train_features, train_labels = srez_input.setup_inputs(sess, train_filenames)
+ test_features, test_labels = srez_input.setup_inputs(sess, test_filenames)
+
+ # Add some noise during training (think denoising autoencoders)
+ noise_level = .03
+ noisy_train_features = train_features + \
+ tf.random_normal(train_features.get_shape(), stddev=noise_level)
+
+ # Create and initialize model
+ [gene_minput, gene_moutput,
+ gene_output, gene_var_list,
+ disc_real_output, disc_fake_output, disc_var_list] = \
+ srez_model.create_model(sess, noisy_train_features, train_labels)
+
+ gene_loss = srez_model.create_generator_loss(disc_fake_output, gene_output, train_features)
+ disc_real_loss, disc_fake_loss = \
+ srez_model.create_discriminator_loss(disc_real_output, disc_fake_output)
+ disc_loss = tf.add(disc_real_loss, disc_fake_loss, name='disc_loss')
+
+ (global_step, learning_rate, gene_minimize, disc_minimize) = \
+ srez_model.create_optimizers(gene_loss, gene_var_list,
+ disc_loss, disc_var_list)
+
+ # Train model
+ train_data = TrainData(locals())
+ srez_train.train_model(train_data)
+
+def main(argv=None):
+ # Training or showing off?
+
+ if FLAGS.run == 'demo':
+ _demo()
+ elif FLAGS.run == 'train':
+ _train()
+
+if __name__ == '__main__':
+ tf.app.run()
Oops, something went wrong.

0 comments on commit 153587b

Please sign in to comment.