Split utils into summary_utils, training_utils and utils
mfigurnov committed Mar 21, 2017
1 parent bc679b6 commit ef2ee4b
Showing 13 changed files with 506 additions and 426 deletions.
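The hunks below are almost entirely mechanical: each call site switches from the old catch-all utils module to whichever module the helper moved to. As a rough illustration (not part of the repository), the rename map implied by the hunks shown on this page could be applied by a small script like this; the MOVED mapping is inferred from this commit's diffs, and the script itself is hypothetical:

    # rename_utils_calls.py -- hypothetical helper illustrating the rename
    # map implied by this commit; it is not part of the repository.
    import re
    import sys

    # Destination module for each helper that left the old catch-all utils.
    MOVED = {
        'act_metric_map': 'summary_utils',
        'flops_metric_map': 'summary_utils',
        'add_heatmaps_image_summary': 'summary_utils',
        'sact_image_heatmap': 'summary_utils',
        'export_to_h5': 'summary_utils',
        'add_all_ponder_costs': 'training_utils',
        'finetuning_init_fn': 'training_utils',
        'stack_blocks': 'resnet_act',
    }

    def rewrite(source):
        """Rewrites utils.<fn> call sites to point at the new module."""
        return re.sub(
            r'\butils\.(\w+)',
            lambda m: '%s.%s' % (MOVED.get(m.group(1), 'utils'), m.group(1)),
            source)

    if __name__ == '__main__':
        for path in sys.argv[1:]:
            with open(path) as f:
                text = f.read()
            with open(path, 'w') as f:
                f.write(rewrite(text))

Note that one helper is renamed in place rather than moved: imagenet_eval.py below changes utils.parse_num_layers to utils.split_and_int, which a pure module-prefix rewrite like the sketch above would miss.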
20 changes: 11 additions & 9 deletions cifar_main.py
@@ -26,6 +26,8 @@

 import cifar_data_provider
 import cifar_model
+import summary_utils
+import training_utils
 import utils


@@ -135,20 +137,20 @@ def train():
   # Specify the loss function:
   tf.losses.softmax_cross_entropy(logits, one_hot_labels)
   if FLAGS.use_act:
-    utils.add_all_ponder_costs(end_points, weights=FLAGS.tau)
+    training_utils.add_all_ponder_costs(end_points, weights=FLAGS.tau)
   total_loss = tf.losses.get_total_loss()
   tf.summary.scalar('Total Loss', total_loss)

-  metric_map = {}  # utils.flops_metric_map(end_points, False)
+  metric_map = {}  # summary_utils.flops_metric_map(end_points, False)
   if FLAGS.use_act:
-    metric_map.update(utils.act_metric_map(end_points, False))
+    metric_map.update(summary_utils.act_metric_map(end_points, False))
   for name, value in metric_map.iteritems():
     tf.summary.scalar(name, value)

   if FLAGS.use_act and FLAGS.sact:
-    utils.add_heatmaps_image_summary(end_points)
+    summary_utils.add_heatmaps_image_summary(end_points)

-  init_fn = utils.finetuning_init_fn(FLAGS.finetune_path)
+  init_fn = training_utils.finetuning_init_fn(FLAGS.finetune_path)

   # Specify the optimization scheme:
   global_step = slim.get_or_create_global_step()
@@ -204,7 +206,7 @@ def evaluate():

   tf.losses.softmax_cross_entropy(logits, one_hot_labels)
   if FLAGS.use_act:
-    utils.add_all_ponder_costs(end_points, weights=FLAGS.tau)
+    training_utils.add_all_ponder_costs(end_points, weights=FLAGS.tau)

   loss = tf.losses.get_total_loss()

Expand All @@ -214,9 +216,9 @@ def evaluate():
'eval/Accuracy': slim.metrics.streaming_accuracy(predictions, labels),
'eval/Mean Loss': slim.metrics.streaming_mean(loss),
}
metric_map.update(utils.flops_metric_map(end_points, True))
metric_map.update(summary_utils.flops_metric_map(end_points, True))
if FLAGS.use_act:
metric_map.update(utils.act_metric_map(end_points, True))
metric_map.update(summary_utils.act_metric_map(end_points, True))
names_to_values, names_to_updates = slim.metrics.aggregate_metric_map(
metric_map)

Expand All @@ -226,7 +228,7 @@ def evaluate():
tf.add_to_collection(tf.GraphKeys.SUMMARIES, summ)

if FLAGS.use_act and FLAGS.sact:
utils.add_heatmaps_image_summary(end_points)
summary_utils.add_heatmaps_image_summary(end_points)

# This ensures that we make a single pass over all of the data.
num_batches = math.ceil(num_samples / float(FLAGS.eval_batch_size))
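finetuning_init_fn moves to training_utils together with the ponder-cost registration, since both belong to training setup rather than summaries. A minimal sketch of what such an init_fn helper typically looks like in slim follows; this is an assumption about the shape of the code, not the actual training_utils implementation:

    import tensorflow as tf
    from tensorflow.contrib import slim

    def finetuning_init_fn(finetune_path):
        """Returns an init_fn restoring model variables from finetune_path,
        or None when no finetuning checkpoint was requested."""
        if not finetune_path:
            return None
        # Restore only model variables; the global step and optimizer slots
        # start fresh for the finetuning run.
        return slim.assign_from_checkpoint_fn(finetune_path,
                                              slim.get_model_variables())

The returned function takes a session and is passed as init_fn to slim.learning.train, exactly as the call site above uses it.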
4 changes: 2 additions & 2 deletions cifar_model.py
@@ -25,7 +25,7 @@
 from tensorflow.contrib.slim.nets import resnet_utils

 import flopsometer
-import utils
+import resnet_act


 @slim.add_arg_scope
@@ -124,7 +124,7 @@ def resnet(inputs,
     net, current_flops = flopsometer.conv2d(
         net, 16, 3, activation_fn=None, normalizer_fn=None)
     end_points['flops'] += current_flops
-    net, end_points = utils.stack_blocks(
+    net, end_points = resnet_act.stack_blocks(
         net,
         blocks,
         use_act=use_act,
19 changes: 10 additions & 9 deletions cifar_model_test.py
@@ -23,8 +23,9 @@
 import tensorflow as tf
 from tensorflow.contrib import slim

-import utils
 import cifar_model
+import summary_utils
+import training_utils


 class CifarModelTest(tf.test.TestCase):
Expand All @@ -45,8 +46,8 @@ def _runBatch(self, is_training, use_act, model=[5]):
use_act=use_act,
sact=False)
if use_act:
metrics = utils.act_metric_map(end_points, False)
metrics.update(utils.flops_metric_map(end_points, False))
metrics = summary_utils.act_metric_map(end_points, False)
metrics.update(summary_utils.flops_metric_map(end_points, False))
else:
metrics = {}

Expand All @@ -60,7 +61,7 @@ def _runBatch(self, is_training, use_act, model=[5]):
tf.losses.softmax_cross_entropy(
logits, one_hot_labels, label_smoothing=0.1, weights=1.0)
if use_act:
utils.add_all_ponder_costs(end_points, weights=1.0)
training_utils.add_all_ponder_costs(end_points, weights=1.0)
total_loss = tf.losses.get_total_loss()
optimizer = tf.train.MomentumOptimizer(0.1, 0.9)
train_op = slim.learning.create_train_op(total_loss, optimizer)
@@ -123,8 +124,8 @@ def _runBatch(self, is_training):
           num_classes=num_classes,
           use_act=True,
           sact=True)
-      metrics = utils.act_metric_map(end_points, False)
-      metrics.update(utils.flops_metric_map(end_points, False))
+      metrics = summary_utils.act_metric_map(end_points, False)
+      metrics.update(summary_utils.flops_metric_map(end_points, False))

       # Check that there are no global updates as they break tf.cond.
       self.assertEqual(tf.get_collection(tf.GraphKeys.UPDATE_OPS), [])
Expand All @@ -135,7 +136,7 @@ def _runBatch(self, is_training):
one_hot_labels = slim.one_hot_encoding(labels, num_classes)
tf.losses.softmax_cross_entropy(
logits, one_hot_labels, label_smoothing=0.1, weights=1.0)
utils.add_all_ponder_costs(end_points, weights=1.0)
training_utils.add_all_ponder_costs(end_points, weights=1.0)
total_loss = tf.losses.get_total_loss()
optimizer = tf.train.MomentumOptimizer(0.1, 0.9)
train_op = slim.learning.create_train_op(total_loss, optimizer)
@@ -170,13 +171,13 @@ def testVisualizationBasic(self):
           use_act=True,
           sact=True)

-      vis_ponder = utils.sact_image_heatmap(
+      vis_ponder = summary_utils.sact_image_heatmap(
           end_points,
           'ponder_cost',
           num_images=num_images,
           alpha=0.75,
           border=border)
-      vis_units = utils.sact_image_heatmap(
+      vis_units = summary_utils.sact_image_heatmap(
           end_points,
           'num_units',
           num_images=num_images,
9 changes: 5 additions & 4 deletions imagenet_eval.py
@@ -29,6 +29,7 @@

 import imagenet_data_provider
 import imagenet_model
+import summary_utils
 import utils

 FLAGS = tf.app.flags.FLAGS
@@ -87,7 +88,7 @@ def main(_):

   # Define the model:
   with slim.arg_scope(imagenet_model.resnet_arg_scope(is_training=False)):
-    model = utils.parse_num_layers(FLAGS.model)
+    model = utils.split_and_int(FLAGS.model)
     logits, end_points = imagenet_model.get_network(
         images,
         model,
@@ -119,9 +120,9 @@ def main(_):
           slim.metrics.streaming_recall_at_k(end_points['predictions'],
                                              labels, 5),
   }
-  metric_map.update(utils.flops_metric_map(end_points, True))
+  metric_map.update(summary_utils.flops_metric_map(end_points, True))
   if FLAGS.use_act:
-    metric_map.update(utils.act_metric_map(end_points, True))
+    metric_map.update(summary_utils.act_metric_map(end_points, True))

   names_to_values, names_to_updates = slim.metrics.aggregate_metric_map(
       metric_map)
Expand All @@ -132,7 +133,7 @@ def main(_):
tf.add_to_collection(tf.GraphKeys.SUMMARIES, summ)

if FLAGS.use_act and FLAGS.sact:
utils.add_heatmaps_image_summary(end_points, border=10)
summary_utils.add_heatmaps_image_summary(end_points, border=10)

# This ensures that we make a single pass over all of the data.
num_batches = math.ceil(FLAGS.num_examples / float(FLAGS.batch_size))
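Across these call sites the second argument to act_metric_map and flops_metric_map is False in the training scripts and True in the evaluation scripts, which suggests it toggles streaming aggregation for slim's evaluation loop. A hedged sketch of that pattern follows; the function body, the end_points key layout, and the metric names are guesses, not the actual summary_utils code:

    import tensorflow as tf
    from tensorflow.contrib import slim

    def flops_metric_map(end_points, use_streaming_mean):
        """Maps each FLOPs tensor in end_points to a summary value."""
        metric_map = {}
        for name, tensor in end_points.items():
            if not name.endswith('flops'):
                continue
            value = tf.reduce_mean(tf.to_float(tensor))
            if use_streaming_mean:
                # Eval: a (value, update_op) pair that averages over batches,
                # suitable for slim.metrics.aggregate_metric_map.
                value = slim.metrics.streaming_mean(value)
            metric_map['flops/' + name] = value
        return metric_map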
7 changes: 4 additions & 3 deletions imagenet_export.py
@@ -29,6 +29,7 @@

 import imagenet_data_provider
 import imagenet_model
+import summary_utils
 import utils

 FLAGS = tf.app.flags.FLAGS
@@ -81,9 +82,9 @@ def main(_):
         use_act=FLAGS.use_act,
         sact=FLAGS.sact)

-  utils.export_to_h5(FLAGS.checkpoint_path, FLAGS.export_path,
-                     images, end_points, FLAGS.num_examples,
-                     FLAGS.batch_size, FLAGS.sact)
+  summary_utils.export_to_h5(FLAGS.checkpoint_path, FLAGS.export_path,
+                             images, end_points, FLAGS.num_examples,
+                             FLAGS.batch_size, FLAGS.sact)


 if __name__ == '__main__':
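export_to_h5 also lands in summary_utils. Judging from the call site it restores a checkpoint, runs enough batches to cover num_examples, and writes the fetched tensors to an HDF5 file. A bare-bones sketch under those assumptions (a simplified signature, not the actual implementation):

    import math

    import h5py
    import numpy as np
    import tensorflow as tf

    def export_to_h5(checkpoint_path, export_path, fetches, num_examples,
                     batch_size):
        """Runs fetches (a dict of tensors) over the dataset and dumps the
        concatenated per-batch results to export_path."""
        num_batches = int(math.ceil(num_examples / float(batch_size)))
        saver = tf.train.Saver()
        with tf.Session() as sess:
            saver.restore(sess, checkpoint_path)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                results = [sess.run(fetches) for _ in range(num_batches)]
            finally:
                coord.request_stop()
                coord.join(threads)
        with h5py.File(export_path, 'w') as f:
            for key in fetches:
                f.create_dataset(
                    key, data=np.concatenate([r[key] for r in results]))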
4 changes: 2 additions & 2 deletions imagenet_model.py
@@ -26,7 +26,7 @@

 import act
 import flopsometer
-import utils
+import resnet_act


 @slim.add_arg_scope
@@ -123,7 +123,7 @@ def resnet_v2(inputs,
       end_points['flops'] += current_flops
       net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool1')
     # Early stopping is broken in distributed training.
-    net, end_points = utils.stack_blocks(
+    net, end_points = resnet_act.stack_blocks(
         net,
         blocks,
         use_act=use_act,
19 changes: 10 additions & 9 deletions imagenet_model_test.py
@@ -24,7 +24,8 @@
 from tensorflow.contrib import slim

 import imagenet_model
-import utils
+import summary_utils
+import training_utils


 class ImagenetModelTest(tf.test.TestCase):
Expand All @@ -44,8 +45,8 @@ def _runBatch(self,
logits, end_points = imagenet_model.get_network(
images, model, num_classes, use_act, False)
if use_act:
metrics = utils.act_metric_map(end_points, False)
metrics.update(utils.flops_metric_map(end_points, False))
metrics = summary_utils.act_metric_map(end_points, False)
metrics.update(summary_utils.flops_metric_map(end_points, False))
else:
metrics = {}

Expand All @@ -60,7 +61,7 @@ def _runBatch(self,
tf.losses.softmax_cross_entropy(
logits, one_hot_labels, label_smoothing=0.1, weights=1.0)
if use_act:
utils.add_all_ponder_costs(end_points, weights=1.0)
training_utils.add_all_ponder_costs(end_points, weights=1.0)
total_loss = tf.losses.get_total_loss()
optimizer = tf.train.MomentumOptimizer(0.1, 0.9)
train_op = slim.learning.create_train_op(total_loss, optimizer)
@@ -115,8 +116,8 @@ def _runBatch(self, is_training):
         imagenet_model.resnet_arg_scope(is_training=is_training)):
       logits, end_points = imagenet_model.get_network(
           images, [50], num_classes, True, True)
-      metrics = utils.act_metric_map(end_points, False)
-      metrics.update(utils.flops_metric_map(end_points, False))
+      metrics = summary_utils.act_metric_map(end_points, False)
+      metrics.update(summary_utils.flops_metric_map(end_points, False))

       # Check that there are no global updates as they break tf.cond.
       # TODO: re-enable
Expand All @@ -128,7 +129,7 @@ def _runBatch(self, is_training):
one_hot_labels = slim.one_hot_encoding(labels, num_classes)
tf.losses.softmax_cross_entropy(
logits, one_hot_labels, label_smoothing=0.1, weights=1.0)
utils.add_all_ponder_costs(end_points, weights=1.0)
training_utils.add_all_ponder_costs(end_points, weights=1.0)
total_loss = tf.losses.get_total_loss()
optimizer = tf.train.MomentumOptimizer(0.1, 0.9)
train_op = slim.learning.create_train_op(total_loss, optimizer)
@@ -159,13 +160,13 @@ def testVisualizationBasic(self):
       logits, end_points = imagenet_model.get_network(
           images, [50], num_classes, True, True)

-      vis_ponder = utils.sact_image_heatmap(
+      vis_ponder = summary_utils.sact_image_heatmap(
           end_points,
           'ponder_cost',
           num_images=num_images,
           alpha=0.75,
           border=border)
-      vis_units = utils.sact_image_heatmap(
+      vis_units = summary_utils.sact_image_heatmap(
           end_points,
           'num_units',
           num_images=num_images,
12 changes: 7 additions & 5 deletions imagenet_train.py
@@ -27,6 +27,8 @@

 import imagenet_data_provider
 import imagenet_model
+import summary_utils
+import training_utils
 import utils

 FLAGS = tf.app.flags.FLAGS
@@ -132,7 +134,7 @@ def main(_):
     tf.losses.softmax_cross_entropy(
         logits, labels, label_smoothing=0.1, weights=1.0)
     if FLAGS.use_act:
-      utils.add_all_ponder_costs(end_points, weights=FLAGS.tau)
+      training_utils.add_all_ponder_costs(end_points, weights=FLAGS.tau)
     total_loss = tf.losses.get_total_loss()

     # Setup the moving averages:
@@ -175,7 +177,7 @@ def main(_):
           replica_id=replica_id,
           total_num_replicas=FLAGS.worker_replicas)

-    init_fn = utils.finetuning_init_fn(FLAGS.finetune_path)
+    init_fn = training_utils.finetuning_init_fn(FLAGS.finetune_path)

     train_tensor = slim.learning.create_train_op(
         total_loss,
Expand All @@ -186,14 +188,14 @@ def main(_):
tf.summary.scalar('losses/Total Loss', total_loss)
tf.summary.scalar('training/Learning Rate', learning_rate)

metric_map = {} # utils.flops_metric_map(end_points, False)
metric_map = {} # summary_utils.flops_metric_map(end_points, False)
if FLAGS.use_act:
metric_map.update(utils.act_metric_map(end_points, False))
metric_map.update(summary_utils.act_metric_map(end_points, False))
for name, value in metric_map.iteritems():
tf.summary.scalar(name, value)

if FLAGS.use_act and FLAGS.sact:
utils.add_heatmaps_image_summary(end_points, border=10)
summary_utils.add_heatmaps_image_summary(end_points, border=10)

if FLAGS.sync_replicas:
sync_optimizer = opt
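Every training path registers the ACT ponder costs through training_utils.add_all_ponder_costs(end_points, weights=...) before calling tf.losses.get_total_loss(), so the helper presumably adds a weighted ponder term to the graph's loss collection. A sketch consistent with those call sites (the end_points key naming is a guess):

    import tensorflow as tf

    def add_all_ponder_costs(end_points, weights):
        """Adds the mean ponder cost of every ACT block, scaled by weights,
        to the loss collection so tf.losses.get_total_loss() picks it up."""
        for name, tensor in end_points.items():
            if name.endswith('ponder_cost'):
                tf.losses.add_loss(tf.reduce_mean(tensor) * weights)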
(Diffs for the remaining five changed files were not rendered on this page.)
