Skip to content
Permalink
Branch: master
Find file Copy path
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
363 lines (314 sloc) 16.5 KB
"""TensorBoard callback for training visualization.
This is the TF v1 version. A subset of the functionality
also works with other backends.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import warnings
from .. import backend as K
from ..engine.training_utils import standardize_input_data
from . import Callback
class TensorBoard(Callback):
    """TensorBoard basic visualizations.

    [TensorBoard](https://www.tensorflow.org/guide/summaries_and_tensorboard)
    is a visualization tool provided with TensorFlow.

    This callback writes a log for TensorBoard, which allows
    you to visualize dynamic graphs of your training and test
    metrics, as well as activation histograms for the different
    layers in your model.

    If you have installed TensorFlow with pip, you should be able
    to launch TensorBoard from the command line:
    ```sh
    tensorboard --logdir=/full_path_to_your_logs
    ```

    When using a backend other than TensorFlow, TensorBoard will still work
    (if you have TensorFlow installed), but the only feature available will
    be the display of the losses and metrics plots.

    # Arguments
        log_dir: the path of the directory where to save the log
            files to be parsed by TensorBoard.
        histogram_freq: frequency (in epochs) at which to compute activation
            and weight histograms for the layers of the model. If set to 0,
            histograms won't be computed. Validation data (or split) must be
            specified for histogram visualizations.
        batch_size: size of batch of inputs to feed to the network
            for histogram and embedding computation.
        write_graph: whether to visualize the graph in TensorBoard.
            The log file can become quite large when
            write_graph is set to True.
        write_grads: whether to visualize gradient histograms in TensorBoard.
            `histogram_freq` must be greater than 0.
        write_images: whether to write model weights to visualize as
            images in TensorBoard.
        embeddings_freq: frequency (in epochs) at which selected embedding
            layers will be saved. If set to 0, embeddings won't be computed.
            Data to be visualized in TensorBoard's Embedding tab must be passed
            as `embeddings_data`.
        embeddings_layer_names: a list of names of layers to keep an eye on.
            If None or an empty list, all `Embedding` layers will be watched.
        embeddings_metadata: a dictionary which maps layer name to a file name
            in which metadata for this embedding layer is saved. See the
            [details](https://www.tensorflow.org/guide/embedding#metadata)
            about the metadata file format. If the same metadata file is
            used for all embedding layers, a single string can be passed.
        embeddings_data: data to be embedded at layers specified in
            `embeddings_layer_names`. Numpy array (if the model has a single
            input) or list of Numpy arrays (if the model has multiple inputs).
            Learn [more about embeddings](
            https://www.tensorflow.org/guide/embedding).
        update_freq: `'batch'` or `'epoch'` or integer. When using `'batch'`,
            writes the losses and metrics to TensorBoard after each batch.
            When using `'epoch'`, writes them after each epoch. If using an
            integer, let's say `10000`, the callback will write the metrics
            and losses to TensorBoard every 10000 samples. Note that writing
            too frequently to TensorBoard can slow down your training.
    """

    def __init__(self, log_dir='./logs',
                 histogram_freq=0,
                 batch_size=32,
                 write_graph=True,
                 write_grads=False,
                 write_images=False,
                 embeddings_freq=0,
                 embeddings_layer_names=None,
                 embeddings_metadata=None,
                 embeddings_data=None,
                 update_freq='epoch'):
        super(TensorBoard, self).__init__()
        # TensorFlow is imported lazily and bound as module-level globals so
        # the other methods can use `tf` / `projector`; the module itself does
        # not import TensorFlow at import time.
        global tf, projector
        try:
            import tensorflow as tf
            from tensorflow.contrib.tensorboard.plugins import projector
        except ImportError:
            raise ImportError('You need the TensorFlow (v1) module installed to '
                              'use TensorBoard.')

        # Only scalar logging is supported on non-TensorFlow backends, so every
        # graph-dependent feature is switched off with a warning.
        if K.backend() != 'tensorflow':
            if histogram_freq != 0:
                warnings.warn('You are not using the TensorFlow backend. '
                              'histogram_freq was set to 0')
                histogram_freq = 0
            if write_graph:
                warnings.warn('You are not using the TensorFlow backend. '
                              'write_graph was set to False')
                write_graph = False
            if write_images:
                warnings.warn('You are not using the TensorFlow backend. '
                              'write_images was set to False')
                write_images = False
            if embeddings_freq != 0:
                warnings.warn('You are not using the TensorFlow backend. '
                              'embeddings_freq was set to 0')
                embeddings_freq = 0

        self.log_dir = log_dir
        self.histogram_freq = histogram_freq
        self.merged = None
        self.write_graph = write_graph
        self.write_grads = write_grads
        self.write_images = write_images
        self.embeddings_freq = embeddings_freq
        self.embeddings_layer_names = embeddings_layer_names
        self.embeddings_metadata = embeddings_metadata or {}
        self.batch_size = batch_size
        self.embeddings_data = embeddings_data
        if update_freq == 'batch':
            # It is the same as writing as frequently as possible.
            self.update_freq = 1
        else:
            self.update_freq = update_freq
        # Running sample counters used by `on_batch_end` for integer
        # `update_freq` logging.
        self.samples_seen = 0
        self.samples_seen_at_last_write = 0

    def set_model(self, model):
        """Build summary ops, the event-file writer and embedding ops for `model`."""
        self.model = model
        if K.backend() == 'tensorflow':
            self.sess = K.get_session()
        if self.histogram_freq and self.merged is None:
            self._add_histogram_summaries(model)
        self.merged = tf.summary.merge_all()

        if self.write_graph:
            self.writer = tf.summary.FileWriter(self.log_dir,
                                                self.sess.graph)
        else:
            self.writer = tf.summary.FileWriter(self.log_dir)

        if self.embeddings_freq and self.embeddings_data is not None:
            # Must run after `self.writer` exists: the projector config is
            # attached to it.
            self._setup_embeddings(model)

    def _add_histogram_summaries(self, model):
        """Register weight/gradient/image/output summary ops for every layer."""
        for layer in model.layers:
            for weight in layer.weights:
                # ':' (as in 'kernel:0') is replaced so the name is usable
                # as a summary tag.
                mapped_weight_name = weight.name.replace(':', '_')
                tf.summary.histogram(mapped_weight_name, weight)

                if self.write_grads and weight in layer.trainable_weights:
                    grads = model.optimizer.get_gradients(model.total_loss,
                                                          weight)
                    # Gradients of type IndexedSlices are histogrammed via
                    # their `.values`; other gradients are used as-is.
                    grads = [grad.values
                             if type(grad).__name__ == 'IndexedSlices'
                             else grad
                             for grad in grads]
                    tf.summary.histogram('{}_grad'.format(mapped_weight_name),
                                         grads)

                if self.write_images:
                    w_img = self._weight_to_image(weight)
                    if w_img is not None:
                        shape = K.int_shape(w_img)
                        assert len(shape) == 4 and shape[-1] in [1, 3, 4]
                        tf.summary.image(mapped_weight_name, w_img)

            if hasattr(layer, 'output'):
                if isinstance(layer.output, list):
                    for i, output in enumerate(layer.output):
                        tf.summary.histogram('{}_out_{}'.format(layer.name, i),
                                             output)
                else:
                    tf.summary.histogram('{}_out'.format(layer.name),
                                         layer.output)

    @staticmethod
    def _weight_to_image(weight):
        """Reshape `weight` into a rank-4 image tensor for `tf.summary.image`.

        Returns None when the squeezed weight's rank is not 1, 2 or 3
        (e.g. 3D-convnet kernels), which cannot be displayed as images.
        """
        w_img = tf.squeeze(weight)
        shape = K.int_shape(w_img)
        if len(shape) == 2:  # dense layer kernel case
            # Put the longer axis second so the image is wider than tall.
            if shape[0] > shape[1]:
                w_img = tf.transpose(w_img)
                shape = K.int_shape(w_img)
            return tf.reshape(w_img, [1, shape[0], shape[1], 1])
        if len(shape) == 3:  # convnet case
            if K.image_data_format() == 'channels_last':
                # switch to channels_first to display
                # every kernel as a separate image
                w_img = tf.transpose(w_img, perm=[2, 0, 1])
                shape = K.int_shape(w_img)
            return tf.reshape(w_img, [shape[0], shape[1], shape[2], 1])
        if len(shape) == 1:  # bias case
            return tf.reshape(w_img, [1, shape[0], 1, 1])
        # not possible to handle 3D convnets etc.
        return None

    def _setup_embeddings(self, model):
        """Create embedding variables, their assign ops, a saver and projector config."""
        self.embeddings_data = standardize_input_data(self.embeddings_data,
                                                      model.input_names)

        embeddings_layer_names = self.embeddings_layer_names
        if not embeddings_layer_names:
            embeddings_layer_names = [layer.name for layer in model.layers
                                      if type(layer).__name__ == 'Embedding']

        self.assign_embeddings = []
        embeddings_vars = {}

        # Row offset and row count of the batch currently being assigned;
        # both are fed in `_write_embeddings`.
        self.batch_id = batch_id = tf.placeholder(tf.int32)
        self.step = step = tf.placeholder(tf.int32)

        n_samples = self.embeddings_data[0].shape[0]
        for layer in model.layers:
            if layer.name not in embeddings_layer_names:
                continue
            embedding_input = layer.output
            embedding_size = int(np.prod(embedding_input.shape[1:]))
            # Flatten each sample's activation to a single row.
            embedding_input = tf.reshape(embedding_input,
                                         (step, embedding_size))
            embedding = K.variable(K.zeros((n_samples, embedding_size)),
                                   name=layer.name + '_embedding')
            embeddings_vars[layer.name] = embedding
            # Write the batch into rows [batch_id, batch_id + step).
            batch = tf.assign(embedding[batch_id:batch_id + step],
                              embedding_input)
            self.assign_embeddings.append(batch)

        self.saver = tf.train.Saver(list(embeddings_vars.values()))

        if isinstance(self.embeddings_metadata, str):
            # One metadata file shared by every watched layer.
            embeddings_metadata = {layer_name: self.embeddings_metadata
                                   for layer_name in embeddings_vars}
        else:
            embeddings_metadata = self.embeddings_metadata

        config = projector.ProjectorConfig()
        for layer_name, tensor in embeddings_vars.items():
            embedding = config.embeddings.add()
            embedding.tensor_name = tensor.name
            if layer_name in embeddings_metadata:
                embedding.metadata_path = embeddings_metadata[layer_name]

        projector.visualize_embeddings(self.writer, config)

    def on_epoch_end(self, epoch, logs=None):
        """Write histograms/embeddings when due, then the epoch's scalar logs.

        # Raises
            ValueError: if histograms are requested without validation data,
                or embeddings are requested without `embeddings_data`.
        """
        logs = logs or {}

        if not self.validation_data and self.histogram_freq:
            raise ValueError("If printing histograms, validation_data must be "
                             "provided, and cannot be a generator.")
        if self.embeddings_data is None and self.embeddings_freq:
            raise ValueError("To visualize embeddings, embeddings_data must "
                             "be provided.")

        if (self.validation_data and self.histogram_freq and
                epoch % self.histogram_freq == 0):
            self._write_histograms(epoch)

        if (self.embeddings_freq and self.embeddings_data is not None and
                epoch % self.embeddings_freq == 0):
            self._write_embeddings(epoch)

        index = epoch if self.update_freq == 'epoch' else self.samples_seen
        self._write_logs(logs, index)

    def _write_histograms(self, epoch):
        """Evaluate the merged summary op on validation data, batch by batch."""
        val_data = self.validation_data
        tensors = (self.model.inputs +
                   self.model.targets +
                   self.model.sample_weights)
        uses_learning_phase = self.model.uses_learning_phase
        if uses_learning_phase:
            tensors += [K.learning_phase()]

        assert len(val_data) == len(tensors)
        val_size = val_data[0].shape[0]
        for i in range(0, val_size, self.batch_size):
            step = min(self.batch_size, val_size - i)
            if uses_learning_phase:
                # do not slice the learning phase
                batch_val = [x[i:i + step] for x in val_data[:-1]]
                batch_val.append(val_data[-1])
            else:
                batch_val = [x[i:i + step] for x in val_data]
            assert len(batch_val) == len(tensors)
            feed_dict = dict(zip(tensors, batch_val))
            summary_str = self.sess.run(self.merged, feed_dict=feed_dict)
            self.writer.add_summary(summary_str, epoch)

    def _write_embeddings(self, epoch):
        """Fill the embedding variables from `embeddings_data` and checkpoint them.

        A second forward pass is needed because `embeddings_data` is passed
        explicitly: the sizes of the embedding variables must be known in
        `set_model`, when `validation_data` is not yet set. See
        https://github.com/keras-team/keras/pull/7766#issuecomment-329195622
        """
        embeddings_data = self.embeddings_data
        n_samples = embeddings_data[0].shape[0]
        model_input = self.model.input
        for i in range(0, n_samples, self.batch_size):
            step = min(self.batch_size, n_samples - i)
            batch = slice(i, i + step)

            if isinstance(model_input, list):
                feed_dict = {_input: embeddings_data[idx][batch]
                             for idx, _input in enumerate(model_input)}
            else:
                feed_dict = {model_input: embeddings_data[0][batch]}

            feed_dict.update({self.batch_id: i, self.step: step})

            if self.model.uses_learning_phase:
                feed_dict[K.learning_phase()] = False

            self.sess.run(self.assign_embeddings, feed_dict=feed_dict)

        # Checkpoint once, after every batch has been assigned. The path and
        # step (`epoch`) are identical for every batch, so saving per batch
        # would just rewrite the same checkpoint repeatedly.
        self.saver.save(self.sess,
                        os.path.join(self.log_dir, 'keras_embedding.ckpt'),
                        epoch)

    def _write_logs(self, logs, index):
        """Write every entry of `logs` (except 'batch'/'size') as a scalar at step `index`."""
        for name, value in logs.items():
            if name in ['batch', 'size']:
                continue
            summary = tf.Summary()
            summary_value = summary.value.add()
            if isinstance(value, (np.ndarray, np.generic)):
                # Unwrap numpy arrays/scalars to a plain Python scalar.
                summary_value.simple_value = value.item()
            else:
                summary_value.simple_value = value
            summary_value.tag = name
            self.writer.add_summary(summary, index)
        self.writer.flush()

    def on_train_end(self, _=None):
        """Close the event-file writer."""
        self.writer.close()

    def on_batch_end(self, batch, logs=None):
        """Count seen samples and write logs every `update_freq` samples."""
        if self.update_freq == 'epoch':
            return
        # Guard against `logs=None`, matching `on_epoch_end`.
        logs = logs or {}
        self.samples_seen += logs.get('size', 0)
        samples_seen_since = self.samples_seen - self.samples_seen_at_last_write
        if samples_seen_since >= self.update_freq:
            self._write_logs(logs, self.samples_seen)
            self.samples_seen_at_last_write = self.samples_seen
You can’t perform that action at this time.