From aa18604fec4a309658bcc83b6aa595b5f0838d36 Mon Sep 17 00:00:00 2001 From: David Vetrano Date: Wed, 11 Jan 2017 08:56:01 -0800 Subject: [PATCH] add tf version detection for new summary functions (#4989) --- keras/callbacks.py | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/keras/callbacks.py b/keras/callbacks.py index 9124a4dcd4d..fea19f6c747 100644 --- a/keras/callbacks.py +++ b/keras/callbacks.py @@ -543,7 +543,10 @@ def set_model(self, model): for layer in self.model.layers: for weight in layer.weights: - tf.histogram_summary(weight.name, weight) + if hasattr(tf, 'histogram_summary'): + tf.histogram_summary(weight.name, weight) + else: + tf.summary.histogram(weight.name, weight) if self.write_images: w_img = tf.squeeze(weight) @@ -557,17 +560,26 @@ def set_model(self, model): w_img = tf.expand_dims(tf.expand_dims(w_img, 0), -1) - tf.image_summary(weight.name, w_img) + if hasattr(tf, 'image_summary'): + tf.image_summary(weight.name, w_img) + else: + tf.summary.image(weight.name, w_img) if hasattr(layer, 'output'): - tf.histogram_summary('{}_out'.format(layer.name), - layer.output) - if parse_version(tf.__version__) >= parse_version('0.12.0'): - self.merged = tf.summary.merge_all() - else: + if hasattr(tf, 'histogram_summary'): + tf.histogram_summary('{}_out'.format(layer.name), + layer.output) + else: + tf.summary.histogram('{}_out'.format(layer.name), + layer.output) + + if hasattr(tf, 'merge_all_summaries'): self.merged = tf.merge_all_summaries() + else: + self.merged = tf.summary.merge_all() + if self.write_graph: - if parse_version(tf.__version__) >= parse_version('0.12.0'): + if hasattr(tf, 'summary') and hasattr(tf.summary, 'FileWriter'): self.writer = tf.summary.FileWriter(self.log_dir, self.sess.graph) elif parse_version(tf.__version__) >= parse_version('0.8.0'): @@ -577,7 +589,7 @@ def set_model(self, model): self.writer = tf.train.SummaryWriter(self.log_dir, self.sess.graph_def) else: - if parse_version(tf.__version__) >= parse_version('0.12.0'): + if hasattr(tf, 'summary') and hasattr(tf.summary, 'FileWriter'): self.writer = tf.summary.FileWriter(self.log_dir) else: self.writer = tf.train.SummaryWriter(self.log_dir)