From 09774db80a515b667a91b14fe21a6134f3856c7a Mon Sep 17 00:00:00 2001
From: Zheng Xu
Date: Mon, 13 May 2019 19:51:32 -0700
Subject: [PATCH] Lint and small fixes.

---
 eval.py | 114 +++++++++++++++++++++++++++-----------------------------
 1 file changed, 54 insertions(+), 60 deletions(-)

diff --git a/eval.py b/eval.py
index 407ded0d..f891944b 100644
--- a/eval.py
+++ b/eval.py
@@ -18,32 +18,40 @@
 import os
 import time
 
+from absl import logging
 import eval_util
-import losses
 import frame_level_models
-import video_level_models
+import losses
 import readers
 import tensorflow as tf
-from tensorflow.python.lib.io import file_io
 from tensorflow import app
 from tensorflow import flags
 from tensorflow import gfile
-from tensorflow import logging
+# from tensorflow import logging
+from tensorflow.python.lib.io import file_io
 import utils
+import video_level_models
 
 FLAGS = flags.FLAGS
 
 if __name__ == "__main__":
   # Dataset flags.
-  flags.DEFINE_string("train_dir", "/tmp/yt8m_model/",
-                      "The directory to load the model files from. "
-                      "The tensorboard metrics files are also saved to this "
-                      "directory.")
+  flags.DEFINE_string(
+      "train_dir", "/tmp/yt8m_model/",
+      "The directory to load the model files from. "
+      "The tensorboard metrics files are also saved to this "
+      "directory.")
   flags.DEFINE_string(
       "eval_data_pattern", "",
       "File glob defining the evaluation dataset in tensorflow.SequenceExample "
       "format. The SequenceExamples are expected to have an 'rgb' byte array "
       "sequence feature as well as a 'labels' int64 context feature.")
+  flags.DEFINE_bool(
+      "segment_labels", False,
+      "If set, then --train_data_pattern must be frame-level features (but with"
+      " segment_labels). Otherwise, --train_data_pattern must be aggregated "
+      "video-level features. The model must also be set appropriately (i.e. to "
+      "read 3D batches VS 4D batches.")
 
   # Other flags.
   flags.DEFINE_integer("batch_size", 1024,
@@ -80,12 +88,12 @@ def get_input_evaluation_tensors(reader,
   Raises:
     IOError: If no files matching the given pattern were found.
   """
-  logging.info("Using batch size of " + str(batch_size) + " for evaluation.")
+  logging.info("Using batch size of %d for evaluation.", batch_size)
   with tf.name_scope("eval_input"):
     files = gfile.Glob(data_pattern)
     if not files:
       raise IOError("Unable to find the evaluation files.")
-    logging.info("number of evaluation files: " + str(len(files)))
+    logging.info("number of evaluation files: %d", len(files))
     filename_queue = tf.train.string_input_producer(
         files, shuffle=False, num_epochs=1)
     eval_data = [
@@ -109,21 +117,18 @@ def build_graph(reader,
 
   Args:
     reader: The data file reader. It should inherit from BaseReader.
-    model: The core model (e.g. logistic or neural net). It should inherit
-           from BaseModel.
+    model: The core model (e.g. logistic or neural net). It should inherit from
+      BaseModel.
     eval_data_pattern: glob path to the evaluation data files.
     label_loss_fn: What kind of loss to apply to the model. It should inherit
-                   from BaseLoss.
+      from BaseLoss.
     batch_size: How many examples to process at a time.
     num_readers: How many threads to use for I/O operations.
""" global_step = tf.Variable(0, trainable=False, name="global_step") input_data_dict = get_input_evaluation_tensors( - reader, - eval_data_pattern, - batch_size=batch_size, - num_readers=num_readers) + reader, eval_data_pattern, batch_size=batch_size, num_readers=num_readers) video_id_batch = input_data_dict["video_ids"] model_input_raw = input_data_dict["video_matrix"] labels_batch = input_data_dict["labels"] @@ -136,11 +141,12 @@ def build_graph(reader, model_input = tf.nn.l2_normalize(model_input_raw, feature_dim) with tf.variable_scope("tower"): - result = model.create_model(model_input, - num_frames=num_frames, - vocab_size=reader.num_classes, - labels=labels_batch, - is_training=False) + result = model.create_model( + model_input, + num_frames=num_frames, + vocab_size=reader.num_classes, + labels=labels_batch, + is_training=False) predictions = result["predictions"] tf.summary.histogram("model_activations", predictions) if "loss" in result.keys(): @@ -159,23 +165,6 @@ def build_graph(reader, tf.add_to_collection("summary_op", tf.summary.merge_all()) -def get_latest_checkpoint(): - index_files = file_io.get_matching_files(os.path.join(FLAGS.train_dir, 'model.ckpt-*.index')) - - # No files - if not index_files: - return None - - - # Index file path with the maximum step size. - latest_index_file = sorted( - [(int(os.path.basename(f).split("-")[-1].split(".")[0]), f) - for f in index_files])[-1][1] - - # Chop off .index suffix and return - return latest_index_file[:-6] - - def evaluation_loop(video_id_batch, prediction_batch, label_batch, loss, summary_op, saver, summary_writer, evl_metrics, last_global_step_val): @@ -198,9 +187,9 @@ def evaluation_loop(video_id_batch, prediction_batch, label_batch, loss, global_step_val = -1 with tf.Session() as sess: - latest_checkpoint = get_latest_checkpoint() + latest_checkpoint = tf.train.latest_checkpoint(FLAGS.train_dir) if latest_checkpoint: - logging.info("Loading checkpoint for eval: " + latest_checkpoint) + logging.info("Loading checkpoint for eval: %s", latest_checkpoint) # Restores from checkpoint saver.restore(sess, latest_checkpoint) # Assuming model_checkpoint_path looks something like: @@ -208,14 +197,17 @@ def evaluation_loop(video_id_batch, prediction_batch, label_batch, loss, global_step_val = os.path.basename(latest_checkpoint).split("-")[-1] # Save model - saver.save(sess, os.path.join(FLAGS.train_dir, "inference_model", "inference_model")) + saver.save( + sess, + os.path.join(FLAGS.train_dir, "inference_model", "inference_model")) else: logging.info("No checkpoint file found.") return global_step_val if global_step_val == last_global_step_val: - logging.info("skip this checkpoint global_step_val=%s " - "(same as the previous one).", global_step_val) + logging.info( + "skip this checkpoint global_step_val=%s " + "(same as the previous one).", global_step_val) return global_step_val sess.run([tf.local_variables_initializer()]) @@ -226,9 +218,8 @@ def evaluation_loop(video_id_batch, prediction_batch, label_batch, loss, try: threads = [] for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS): - threads.extend(qr.create_threads( - sess, coord=coord, daemon=True, - start=True)) + threads.extend( + qr.create_threads(sess, coord=coord, daemon=True, start=True)) logging.info("enter eval_once loop global_step_val = %s. 
", global_step_val) @@ -236,10 +227,12 @@ def evaluation_loop(video_id_batch, prediction_batch, label_batch, loss, examples_processed = 0 while not coord.should_stop(): + logging.info("Running eval step...") batch_start_time = time.time() _, predictions_val, labels_val, loss_val, summary_val = sess.run( fetches) seconds_per_batch = time.time() - batch_start_time + logging.info("calculating metrics...") example_per_second = labels_val.shape[0] / seconds_per_batch examples_processed += labels_val.shape[0] @@ -272,7 +265,7 @@ def evaluation_loop(video_id_batch, prediction_batch, label_batch, loss, logging.info(epochinfo) evl_metrics.clear() except Exception as e: # pylint: disable=broad-except - logging.info("Unexpected exception: " + str(e)) + logging.info("Unexpected exception: %s", str(e)) coord.request_stop(e) coord.request_stop() @@ -282,6 +275,7 @@ def evaluation_loop(video_id_batch, prediction_batch, label_batch, loss, def evaluate(): + """Starts main evaluation loop.""" tf.set_random_seed(0) # for reproducibility # Write json of flags @@ -297,19 +291,21 @@ def evaluate(): flags_dict["feature_names"], flags_dict["feature_sizes"]) if flags_dict["frame_features"]: - reader = readers.YT8MFrameFeatureReader(feature_names=feature_names, - feature_sizes=feature_sizes) + reader = readers.YT8MFrameFeatureReader( + feature_names=feature_names, + feature_sizes=feature_sizes, + segment_labels=FLAGS.segment_labels) else: - reader = readers.YT8MAggregatedFeatureReader(feature_names=feature_names, - feature_sizes=feature_sizes) + reader = readers.YT8MAggregatedFeatureReader( + feature_names=feature_names, feature_sizes=feature_sizes) model = find_class_by_name(flags_dict["model"], - [frame_level_models, video_level_models])() + [frame_level_models, video_level_models])() label_loss_fn = find_class_by_name(flags_dict["label_loss"], [losses])() - if FLAGS.eval_data_pattern is "": - raise IOError("'eval_data_pattern' was not specified. " + - "Nothing to evaluate.") + if not FLAGS.eval_data_pattern: + raise IOError("'eval_data_pattern' was not specified. Nothing to " + "evaluate.") build_graph( reader=reader, @@ -327,8 +323,7 @@ def evaluate(): saver = tf.train.Saver(tf.global_variables()) summary_writer = tf.summary.FileWriter( - os.path.join(FLAGS.train_dir, "eval"), - graph=tf.get_default_graph()) + os.path.join(FLAGS.train_dir, "eval"), graph=tf.get_default_graph()) evl_metrics = eval_util.EvaluationMetrics(reader.num_classes, FLAGS.top_k) @@ -344,10 +339,9 @@ def evaluate(): def main(unused_argv): logging.set_verbosity(tf.logging.INFO) - print("tensorflow version: %s" % tf.__version__) + logging.info("tensorflow version: %s", tf.__version__) evaluate() if __name__ == "__main__": app.run() -