Skip to content

Commit

Permalink
Merge branch 'master' into rnn
Browse files Browse the repository at this point in the history
  • Loading branch information
gtoderici committed Feb 24, 2017
2 parents 53ba9e5 + 9871cad commit 1ad1fcc
Show file tree
Hide file tree
Showing 10 changed files with 85 additions and 54 deletions.
22 changes: 17 additions & 5 deletions README.md
Expand Up @@ -47,10 +47,11 @@ follow the instructions [here](https://cloud.google.com/ml/docs/how-tos/getting-
If you are participating in the Google Cloud & YouTube-8M Video Understanding
Challenge hosted on kaggle.com, see [these instructions](https://www.kaggle.com/c/youtube8m#getting-started-with-google-cloud) instead.

Please also verify that you have Tensorflow 1.0.0 or higher installed by
running the following command:
Please also verify that you have Python 2.7+ and Tensorflow 1.0.0 or higher
installed by running the following commands:

```sh
python --version
python -c 'import tensorflow as tf; print(tf.__version__)'
```

Expand Down Expand Up @@ -112,7 +113,9 @@ should be deployed to the cloud worker. The module-name refers to the specific
python script which should be executed (in this case the train module). Since
the training data files are hosted in the public 'youtube8m-ml' storage bucket
in the 'us-central1' region, we've colocated our job in the same
region in order to have the fastest access to the data.
region in order to have the fastest access to the data. If you find that your
jobs are getting queued in the 'us-central1' region, you can try the 'us-east1'
region instead.

It may take several minutes before the job starts running on Google Cloud.
When it starts you will see outputs like the following:
Expand Down Expand Up @@ -231,7 +234,8 @@ submit training $JOB_NAME \

The 'FrameLevelLogisticModel' is designed to provide equivalent results to a
logistic model trained over the video-level features. Please look at the
'models.py' file to see how to implement your own models.
'video_level_models.py' or 'frame_level_models.py' files to see how to implement
your own models.


### Using Audio Features
Expand Down Expand Up @@ -269,6 +273,14 @@ the instructions on [tensorflow.org](https://www.tensorflow.org/install/).
This code has been tested with Tensorflow 1.0.0. Going forward, we will continue
to target the latest released version of Tensorflow.

Please verify that you have Python 2.7+ and Tensorflow 1.0.0 or higher
installed by running the following commands:

```sh
python --version
python -c 'import tensorflow as tf; print(tf.__version__)'
```

You can find complete instructions for downloading the dataset on the
[YouTube-8M website](https://research.google.com/youtube8m/download.html).
We recommend downloading the smaller video-level features dataset first when
Expand Down Expand Up @@ -338,7 +350,7 @@ When you are happy with your model, you can generate a csv file of predictions
from it by running

```sh
python inference.py --output_file=$MODEL_DIR/video_level_logistic_model/predictions.csv --input_data_pattern='/path/to/features/validate*.tfrecord' --train_dir=$MODEL_DIR/video_level_logistic_model
python inference.py --output_file=$MODEL_DIR/video_level_logistic_model/predictions.csv --input_data_pattern='/path/to/features/test*.tfrecord' --train_dir=$MODEL_DIR/video_level_logistic_model
```

This will output the top 20 predicted labels from the model for every example
Expand Down
7 changes: 4 additions & 3 deletions average_precision_calculator.py
Expand Up @@ -123,7 +123,7 @@ def accumulate(self, predictions, actuals, num_positives=None):
topk = self._top_n
heap = self._heap

for i in xrange(numpy.size(predictions)):
for i in range(numpy.size(predictions)):
if topk is None or len(heap) < topk:
heapq.heappush(heap, (predictions[i], actuals[i]))
else:
Expand All @@ -146,7 +146,8 @@ def peek_ap_at_n(self):
"""
if self.heap_size <= 0:
return 0
predlists = numpy.array(zip(*self._heap))
predlists = numpy.array(list(zip(*self._heap)))

ap = self.ap_at_n(predlists[0],
predlists[1],
n=self._top_n,
Expand Down Expand Up @@ -237,7 +238,7 @@ def ap_at_n(predictions, actuals, n=20, total_num_positives=None):
r = len(sortidx)
if n is not None:
r = min(r, n)
for i in xrange(r):
for i in range(r):
if actuals[sortidx[i]] > 0:
poscount += 1
ap += poscount / (i + 1) * delta_recall
Expand Down
5 changes: 3 additions & 2 deletions eval.py
Expand Up @@ -103,13 +103,14 @@ def get_input_evaluation_tensors(reader,
filename_queue = tf.train.string_input_producer(
files, shuffle=False, num_epochs=1)
eval_data = [
reader.prepare_reader(filename_queue) for _ in xrange(num_readers)
reader.prepare_reader(filename_queue) for _ in range(num_readers)
]
return tf.train.batch_join(
eval_data,
batch_size=batch_size,
capacity=3 * batch_size,
allow_smaller_final_batch=True)
allow_smaller_final_batch=True,
enqueue_many=True)


def build_graph(reader,
Expand Down
6 changes: 3 additions & 3 deletions eval_util.py
Expand Up @@ -117,12 +117,12 @@ def top_k_by_class(predictions, labels, k=20):
prediction_triplets= []
for video_index in range(predictions.shape[0]):
prediction_triplets.extend(top_k_triplets(predictions[video_index],labels[video_index], k))
out_predictions = [[] for v in xrange(num_classes)]
out_labels = [[] for v in xrange(num_classes)]
out_predictions = [[] for v in range(num_classes)]
out_labels = [[] for v in range(num_classes)]
for triplet in prediction_triplets:
out_predictions[triplet[0]].append(triplet[1])
out_labels[triplet[0]].append(triplet[2])
out_true_positives = [numpy.sum(labels[:,i]) for i in xrange(num_classes)]
out_true_positives = [numpy.sum(labels[:,i]) for i in range(num_classes)]

return out_predictions, out_labels, out_true_positives

Expand Down
2 changes: 1 addition & 1 deletion frame_level_models.py
Expand Up @@ -79,7 +79,7 @@ def create_model(self, model_input, vocab_size, num_frames, **unused_params):

output = slim.fully_connected(
avg_pooled, vocab_size, activation_fn=tf.nn.sigmoid,
weights_regularizer=slim.l2_regularizer(0.01))
weights_regularizer=slim.l2_regularizer(1e-8))
return {"predictions": output}

class DBoFModel(models.BaseModel):
Expand Down
13 changes: 9 additions & 4 deletions inference.py
Expand Up @@ -24,6 +24,7 @@
from tensorflow import flags
from tensorflow import gfile
from tensorflow import logging
from builtins import range

import eval_util
import losses
Expand Down Expand Up @@ -66,12 +67,15 @@

def format_lines(video_ids, predictions, top_k):
batch_size = len(video_ids)
for video_index in xrange(batch_size):
for video_index in range(batch_size):
top_indices = numpy.argpartition(predictions[video_index], -top_k)[-top_k:]
line = [(class_index, predictions[video_index][class_index])
for class_index in top_indices]
# print("Type - Test :")
# print(type(video_ids[video_index]))
# print(video_ids[video_index].decode('utf-8'))
line = sorted(line, key=lambda p: -p[1])
yield video_ids[video_index] + "," + " ".join("%i %f" % pair
yield video_ids[video_index].decode('utf-8') + "," + " ".join("%i %f" % pair
for pair in line) + "\n"


Expand Down Expand Up @@ -101,12 +105,13 @@ def get_input_data_tensors(reader, data_pattern, batch_size, num_readers=1):
filename_queue = tf.train.string_input_producer(
files, num_epochs=1, shuffle=False)
examples_and_labels = [reader.prepare_reader(filename_queue)
for _ in xrange(num_readers)]
for _ in range(num_readers)]

video_id_batch, video_batch, unused_labels, num_frames_batch = (
tf.train.batch_join(examples_and_labels,
batch_size=batch_size,
allow_smaller_final_batch = True))
allow_smaller_final_batch = True,
enqueue_many=True))
return video_id_batch, video_batch, num_frames_batch

def inference(reader, train_dir, data_pattern, out_file_location, batch_size, top_k):
Expand Down
6 changes: 3 additions & 3 deletions mean_average_precision_calculator.py
Expand Up @@ -64,7 +64,7 @@ def __init__(self, num_class):

self._ap_calculators = [] # member of AveragePrecisionCalculator
self._num_class = num_class # total number of classes
for i in xrange(num_class):
for i in range(num_class):
self._ap_calculators.append(
average_precision_calculator.AveragePrecisionCalculator())

Expand All @@ -89,7 +89,7 @@ def accumulate(self, predictions, actuals, num_positives=None):
num_positives = [None for i in predictions.shape[1]]

calculators = self._ap_calculators
for i in xrange(len(predictions)):
for i in range(len(predictions)):
calculators[i].accumulate(predictions[i], actuals[i], num_positives[i])

def clear(self):
Expand All @@ -108,5 +108,5 @@ def peek_map_at_n(self):
class.
"""
aps = [self._ap_calculators[i].peek_ap_at_n()
for i in xrange(self._num_class)]
for i in range(self._num_class)]
return aps
34 changes: 16 additions & 18 deletions readers.py
Expand Up @@ -18,7 +18,6 @@
import utils

from tensorflow import logging

def resize_axis(tensor, axis, new_size, fill_value=0):
"""Truncates or pads a tensor to new_size on on a given axis.
Expand Down Expand Up @@ -92,7 +91,7 @@ def __init__(self,
self.feature_sizes = feature_sizes
self.feature_names = feature_names

def prepare_reader(self, filename_queue,):
def prepare_reader(self, filename_queue, batch_size=1024):
"""Creates a single reader thread for pre-aggregated YouTube 8M Examples.
Args:
Expand All @@ -102,7 +101,7 @@ def prepare_reader(self, filename_queue,):
A tuple of video indexes, features, labels, and padding data.
"""
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
_, serialized_examples = reader.read_up_to(filename_queue, batch_size)

# set the mapping from the fields to data types in the proto
num_features = len(self.feature_names)
Expand All @@ -117,22 +116,13 @@ def prepare_reader(self, filename_queue,):
feature_map[self.feature_names[feature_index]] = tf.FixedLenFeature(
[self.feature_sizes[feature_index]], tf.float32)

features = tf.parse_single_example(serialized_example,
features=feature_map)

labels = (tf.cast(
tf.sparse_to_dense(features["labels"].values, (self.num_classes,), 1,
validate_indices=False),
tf.bool))
features = tf.parse_example(serialized_examples, features=feature_map)
labels = tf.sparse_to_indicator(features["labels"], self.num_classes)
labels.set_shape([None, self.num_classes])
concatenated_features = tf.concat([
features[feature_name] for feature_name in self.feature_names], 0)
fdim = concatenated_features.get_shape()[0].value
assert fdim == sum(self.feature_sizes), \
"dimensionality of the concatenated feature (={}) != sum of " \
"dimensionalities of groups of features (={})".format( \
fdim, sum(self.feature_sizes))
features[feature_name] for feature_name in self.feature_names], 1)

return features["video_id"], concatenated_features, labels, tf.constant(1)
return features["video_id"], concatenated_features, labels, tf.ones([tf.shape(serialized_examples)[0]])

class YT8MFrameFeatureReader(BaseReader):
"""Reads TFRecords of SequenceExamples.
Expand Down Expand Up @@ -258,5 +248,13 @@ def prepare_reader(self,

# concatenate different features
video_matrix = tf.concat(feature_matrices, 1)
return contexts["video_id"], video_matrix, labels, num_frames

# convert to batch format.
# TODO: Do proper batch reads to remove the IO bottleneck.
batch_video_ids = tf.expand_dims(contexts["video_id"], 0)
batch_video_matrix = tf.expand_dims(video_matrix, 0)
batch_labels = tf.expand_dims(labels, 0)
batch_frames = tf.expand_dims(num_frames, 0)

return batch_video_ids, batch_video_matrix, batch_labels, batch_frames

31 changes: 21 additions & 10 deletions train.py
Expand Up @@ -68,7 +68,7 @@
"label_loss", "CrossEntropyLoss",
"Which loss function to use for training the model.")
flags.DEFINE_float(
"regularization_penalty", 1e-3,
"regularization_penalty", 1,
"How much weight to give to the regularization loss (the label loss has "
"a weight of 1).")
flags.DEFINE_float("base_learning_rate", 0.01,
Expand All @@ -79,6 +79,10 @@
flags.DEFINE_float("learning_rate_decay_examples", 4000000,
"Multiply current learning rate by learning_rate_decay "
"every learning_rate_decay_examples.")
flags.DEFINE_integer("num_epochs", 5,
"How many passes to make over the dataset before "
"halting training.")

# Other flags.
flags.DEFINE_integer("num_readers", 8,
"How many threads to use for reading input files.")
Expand Down Expand Up @@ -153,14 +157,15 @@ def get_input_data_tensors(reader,
filename_queue = tf.train.string_input_producer(files,
num_epochs=num_epochs)
training_data = [
reader.prepare_reader(filename_queue) for _ in xrange(num_readers)]
reader.prepare_reader(filename_queue) for _ in range(num_readers)]

return tf.train.shuffle_batch_join(
training_data,
batch_size=batch_size,
capacity=FLAGS.batch_size * 5,
min_after_dequeue=FLAGS.batch_size,
allow_smaller_final_batch=True)
allow_smaller_final_batch=True,
enqueue_many=True)


def find_class_by_name(name, modules):
Expand All @@ -179,9 +184,9 @@ def build_graph(reader,
learning_rate_decay=0.95,
optimizer_class=tf.train.AdamOptimizer,
clip_gradient_norm=1.0,
regularization_penalty=1e-3,
regularization_penalty=1,
num_readers=1,
num_epochs=None):
num_epochs=100):
"""Creates the Tensorflow graph.
This will only be called once in the life of
Expand Down Expand Up @@ -250,6 +255,9 @@ def build_graph(reader,
reg_loss = result["regularization_loss"]
else:
reg_loss = tf.constant(0.0)
reg_losses = tf.losses.get_regularization_losses()
if reg_losses:
reg_loss += tf.add_n(reg_losses)
if regularization_penalty != 0:
tf.summary.scalar("reg_loss", reg_loss)

Expand Down Expand Up @@ -310,13 +318,15 @@ def train_loop(train_dir=None,
sv = tf.train.Supervisor(logdir=train_dir,
is_chief=is_chief,
global_step=global_step,
save_model_secs=60,
save_summaries_secs=60,
save_model_secs=15 * 60,
save_summaries_secs=120,
saver=saver)
sess = sv.prepare_or_wait_for_session(
master,
start_standard_services=start_supervisor_services,
config=tf.ConfigProto(log_device_placement=False))
config=tf.ConfigProto(
log_device_placement=False,
allow_soft_placement=True))

logging.info("prepared session")
sv.start_queue_runners(sess)
Expand Down Expand Up @@ -417,9 +427,10 @@ def main(unused_argv):
learning_rate_decay_examples=FLAGS.learning_rate_decay_examples,
regularization_penalty=FLAGS.regularization_penalty,
num_readers=FLAGS.num_readers,
batch_size=FLAGS.batch_size)
batch_size=FLAGS.batch_size,
num_epochs=FLAGS.num_epochs)
logging.info("built graph")
saver = tf.train.Saver()
saver = tf.train.Saver(max_to_keep=0, keep_checkpoint_every_n_hours=0.25)

train_loop(is_chief=is_chief,
train_dir=FLAGS.train_dir,
Expand Down

0 comments on commit 1ad1fcc

Please sign in to comment.