Commit

Lint and small fixes.
XericZephyr committed May 14, 2019
1 parent 73aee83 commit 09774db
Showing 1 changed file with 54 additions and 60 deletions.
114 changes: 54 additions & 60 deletions eval.py
@@ -18,32 +18,40 @@
import os
import time

from absl import logging
import eval_util
import losses
import frame_level_models
import video_level_models
import losses
import readers
import tensorflow as tf
from tensorflow.python.lib.io import file_io
from tensorflow import app
from tensorflow import flags
from tensorflow import gfile
from tensorflow import logging
# from tensorflow import logging
from tensorflow.python.lib.io import file_io
import utils
import video_level_models

FLAGS = flags.FLAGS

if __name__ == "__main__":
# Dataset flags.
flags.DEFINE_string("train_dir", "/tmp/yt8m_model/",
"The directory to load the model files from. "
"The tensorboard metrics files are also saved to this "
"directory.")
flags.DEFINE_string(
"train_dir", "/tmp/yt8m_model/",
"The directory to load the model files from. "
"The tensorboard metrics files are also saved to this "
"directory.")
flags.DEFINE_string(
"eval_data_pattern", "",
"File glob defining the evaluation dataset in tensorflow.SequenceExample "
"format. The SequenceExamples are expected to have an 'rgb' byte array "
"sequence feature as well as a 'labels' int64 context feature.")
flags.DEFINE_bool(
"segment_labels", False,
"If set, then --train_data_pattern must be frame-level features (but with"
" segment_labels). Otherwise, --train_data_pattern must be aggregated "
"video-level features. The model must also be set appropriately (i.e. to "
"read 3D batches VS 4D batches.")

# Other flags.
flags.DEFINE_integer("batch_size", 1024,
@@ -80,12 +88,12 @@ def get_input_evaluation_tensors(reader,
Raises:
IOError: If no files matching the given pattern were found.
"""
logging.info("Using batch size of " + str(batch_size) + " for evaluation.")
logging.info("Using batch size of %d for evaluation.", batch_size)
with tf.name_scope("eval_input"):
files = gfile.Glob(data_pattern)
if not files:
raise IOError("Unable to find the evaluation files.")
logging.info("number of evaluation files: " + str(len(files)))
logging.info("number of evaluation files: %d", len(files))
filename_queue = tf.train.string_input_producer(
files, shuffle=False, num_epochs=1)
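# shuffle=False and num_epochs=1: a single, deterministic pass over the evaluation files.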
eval_data = [
@@ -109,21 +117,18 @@ def build_graph(reader,
Args:
reader: The data file reader. It should inherit from BaseReader.
model: The core model (e.g. logistic or neural net). It should inherit
from BaseModel.
model: The core model (e.g. logistic or neural net). It should inherit from
BaseModel.
eval_data_pattern: glob path to the evaluation data files.
label_loss_fn: What kind of loss to apply to the model. It should inherit
from BaseLoss.
from BaseLoss.
batch_size: How many examples to process at a time.
num_readers: How many threads to use for I/O operations.
"""

global_step = tf.Variable(0, trainable=False, name="global_step")
input_data_dict = get_input_evaluation_tensors(
reader,
eval_data_pattern,
batch_size=batch_size,
num_readers=num_readers)
reader, eval_data_pattern, batch_size=batch_size, num_readers=num_readers)
video_id_batch = input_data_dict["video_ids"]
model_input_raw = input_data_dict["video_matrix"]
labels_batch = input_data_dict["labels"]
@@ -136,11 +141,12 @@ def build_graph(reader,
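# Unit-normalize (L2) the raw features along the feature axis before they reach the model.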
model_input = tf.nn.l2_normalize(model_input_raw, feature_dim)

with tf.variable_scope("tower"):
result = model.create_model(model_input,
num_frames=num_frames,
vocab_size=reader.num_classes,
labels=labels_batch,
is_training=False)
result = model.create_model(
model_input,
num_frames=num_frames,
vocab_size=reader.num_classes,
labels=labels_batch,
is_training=False)
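# is_training=False so models with training-only behaviour (e.g. dropout) run in inference mode.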
predictions = result["predictions"]
tf.summary.histogram("model_activations", predictions)
if "loss" in result.keys():
@@ -159,23 +165,6 @@ def build_graph(reader,
tf.add_to_collection("summary_op", tf.summary.merge_all())


def get_latest_checkpoint():
index_files = file_io.get_matching_files(os.path.join(FLAGS.train_dir, 'model.ckpt-*.index'))

# No files
if not index_files:
return None


# Index file path with the maximum step size.
latest_index_file = sorted(
[(int(os.path.basename(f).split("-")[-1].split(".")[0]), f)
for f in index_files])[-1][1]

# Chop off .index suffix and return
return latest_index_file[:-6]
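# Note: this hand-rolled scan over model.ckpt-*.index files is superseded below by
# tf.train.latest_checkpoint(FLAGS.train_dir), which reads the 'checkpoint' state file
# that tf.train.Saver maintains in the same directory.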


def evaluation_loop(video_id_batch, prediction_batch, label_batch, loss,
summary_op, saver, summary_writer, evl_metrics,
last_global_step_val):
@@ -198,24 +187,27 @@ def evaluation_loop(video_id_batch, prediction_batch, label_batch, loss,

global_step_val = -1
with tf.Session() as sess:
latest_checkpoint = get_latest_checkpoint()
latest_checkpoint = tf.train.latest_checkpoint(FLAGS.train_dir)
if latest_checkpoint:
logging.info("Loading checkpoint for eval: " + latest_checkpoint)
logging.info("Loading checkpoint for eval: %s", latest_checkpoint)
# Restores from checkpoint
saver.restore(sess, latest_checkpoint)
# Assuming model_checkpoint_path looks something like:
# /my-favorite-path/yt8m_train/model.ckpt-0, extract global_step from it.
global_step_val = os.path.basename(latest_checkpoint).split("-")[-1]

# Save model: re-export the restored variables to <train_dir>/inference_model/inference_model,
# a fixed path independent of the global step.
saver.save(sess, os.path.join(FLAGS.train_dir, "inference_model", "inference_model"))
saver.save(
sess,
os.path.join(FLAGS.train_dir, "inference_model", "inference_model"))
else:
logging.info("No checkpoint file found.")
return global_step_val

if global_step_val == last_global_step_val:
logging.info("skip this checkpoint global_step_val=%s "
"(same as the previous one).", global_step_val)
logging.info(
"skip this checkpoint global_step_val=%s "
"(same as the previous one).", global_step_val)
return global_step_val

sess.run([tf.local_variables_initializer()])
@@ -226,20 +218,21 @@ def evaluation_loop(video_id_batch, prediction_batch, label_batch, loss,
try:
threads = []
for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
threads.extend(qr.create_threads(
sess, coord=coord, daemon=True,
start=True))
threads.extend(
qr.create_threads(sess, coord=coord, daemon=True, start=True))
logging.info("enter eval_once loop global_step_val = %s. ",
global_step_val)

evl_metrics.clear()

examples_processed = 0
while not coord.should_stop():
logging.info("Running eval step...")
batch_start_time = time.time()
_, predictions_val, labels_val, loss_val, summary_val = sess.run(
fetches)
seconds_per_batch = time.time() - batch_start_time
logging.info("calculating metrics...")
example_per_second = labels_val.shape[0] / seconds_per_batch
examples_processed += labels_val.shape[0]

@@ -272,7 +265,7 @@ def evaluation_loop(video_id_batch, prediction_batch, label_batch, loss,
logging.info(epochinfo)
evl_metrics.clear()
except Exception as e: # pylint: disable=broad-except
logging.info("Unexpected exception: " + str(e))
logging.info("Unexpected exception: %s", str(e))
coord.request_stop(e)

coord.request_stop()
@@ -282,6 +275,7 @@ def evaluation_loop(video_id_batch, prediction_batch, label_batch, loss,


def evaluate():
"""Starts main evaluation loop."""
tf.set_random_seed(0) # for reproducibility

# Write json of flags
@@ -297,19 +291,21 @@ def evaluate():
flags_dict["feature_names"], flags_dict["feature_sizes"])

if flags_dict["frame_features"]:
reader = readers.YT8MFrameFeatureReader(feature_names=feature_names,
feature_sizes=feature_sizes)
reader = readers.YT8MFrameFeatureReader(
feature_names=feature_names,
feature_sizes=feature_sizes,
segment_labels=FLAGS.segment_labels)
else:
reader = readers.YT8MAggregatedFeatureReader(feature_names=feature_names,
feature_sizes=feature_sizes)
reader = readers.YT8MAggregatedFeatureReader(
feature_names=feature_names, feature_sizes=feature_sizes)

model = find_class_by_name(flags_dict["model"],
[frame_level_models, video_level_models])()
[frame_level_models, video_level_models])()
label_loss_fn = find_class_by_name(flags_dict["label_loss"], [losses])()

if FLAGS.eval_data_pattern is "":
raise IOError("'eval_data_pattern' was not specified. " +
"Nothing to evaluate.")
if not FLAGS.eval_data_pattern:
raise IOError("'eval_data_pattern' was not specified. Nothing to "
"evaluate.")

build_graph(
reader=reader,
@@ -327,8 +323,7 @@ def evaluate():

saver = tf.train.Saver(tf.global_variables())
summary_writer = tf.summary.FileWriter(
os.path.join(FLAGS.train_dir, "eval"),
graph=tf.get_default_graph())
os.path.join(FLAGS.train_dir, "eval"), graph=tf.get_default_graph())

evl_metrics = eval_util.EvaluationMetrics(reader.num_classes, FLAGS.top_k)

@@ -344,10 +339,9 @@

def main(unused_argv):
logging.set_verbosity(tf.logging.INFO)
print("tensorflow version: %s" % tf.__version__)
logging.info("tensorflow version: %s", tf.__version__)
evaluate()


if __name__ == "__main__":
app.run()
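Aside: most of the lint changes in this commit swap eager string concatenation in logging calls for lazy %-style arguments. A minimal, self-contained sketch of the before/after pattern, using absl logging as the updated imports do (the batch size is just the flag's default, for illustration):

from absl import logging

def log_batch_size(batch_size):
  # Old style: the message string is built even when INFO logging is filtered out.
  logging.info("Using batch size of " + str(batch_size) + " for evaluation.")
  # Lint-preferred style: formatting is deferred to the logging framework.
  logging.info("Using batch size of %d for evaluation.", batch_size)

log_batch_size(1024)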
