Skip to content

Commit

Permalink
add tf version detection for new summary functions (#4989)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidvetrano authored and fchollet committed Jan 11, 2017
1 parent 875bc59 commit aa18604
Showing 1 changed file with 21 additions and 9 deletions.
30 changes: 21 additions & 9 deletions keras/callbacks.py
Expand Up @@ -543,7 +543,10 @@ def set_model(self, model):
for layer in self.model.layers:

for weight in layer.weights:
tf.histogram_summary(weight.name, weight)
if hasattr(tf, 'histogram_summary'):
tf.histogram_summary(weight.name, weight)
else:
tf.summary.histogram(weight.name, weight)

if self.write_images:
w_img = tf.squeeze(weight)
Expand All @@ -557,17 +560,26 @@ def set_model(self, model):

w_img = tf.expand_dims(tf.expand_dims(w_img, 0), -1)

tf.image_summary(weight.name, w_img)
if hasattr(tf, 'image_summary'):
tf.image_summary(weight.name, w_img)
else:
tf.summary.image(weight.name, w_img)

if hasattr(layer, 'output'):
tf.histogram_summary('{}_out'.format(layer.name),
layer.output)
if parse_version(tf.__version__) >= parse_version('0.12.0'):
self.merged = tf.summary.merge_all()
else:
if hasattr(tf, 'histogram_summary'):
tf.histogram_summary('{}_out'.format(layer.name),
layer.output)
else:
tf.summary.histogram('{}_out'.format(layer.name),
layer.output)

if hasattr(tf, 'merge_all_summaries'):
self.merged = tf.merge_all_summaries()
else:
self.merged = tf.summary.merge_all()

if self.write_graph:
if parse_version(tf.__version__) >= parse_version('0.12.0'):
if hasattr(tf, 'summary') and hasattr(tf.summary, 'FileWriter'):
self.writer = tf.summary.FileWriter(self.log_dir,
self.sess.graph)
elif parse_version(tf.__version__) >= parse_version('0.8.0'):
Expand All @@ -577,7 +589,7 @@ def set_model(self, model):
self.writer = tf.train.SummaryWriter(self.log_dir,
self.sess.graph_def)
else:
if parse_version(tf.__version__) >= parse_version('0.12.0'):
if hasattr(tf, 'summary') and hasattr(tf.summary, 'FileWriter'):
self.writer = tf.summary.FileWriter(self.log_dir)
else:
self.writer = tf.train.SummaryWriter(self.log_dir)
Expand Down

0 comments on commit aa18604

Please sign in to comment.