Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Activation histograms #17624

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 64 additions & 3 deletions keras/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2324,7 +2324,7 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):

* Metrics summary plots
* Training graph visualization
* Weight histograms
* Activation and Weight histograms
* Sampled profiling

When used in `Model.evaluate`, in addition to epoch summaries, there will be
Expand All @@ -2346,10 +2346,11 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
log_dir: the path of the directory where to save the log files to be
parsed by TensorBoard. e.g. log_dir = os.path.join(working_dir,
'logs') This directory should not be reused by any other callbacks.
histogram_freq: frequency (in epochs) at which to compute
histogram_freq: frequency (in epochs) at which to compute activation and
weight histograms for the layers of the model. If set to 0, histograms
won't be computed. Validation data (or split) must be specified for
histogram visualizations.
histogram visualizations. Only weight histograms are supported for
subclassed models.
write_graph: whether to visualize the graph in TensorBoard. The log file
can become quite large when write_graph is set to True.
write_images: whether to write model weights to visualize as image in
Expand Down Expand Up @@ -2488,6 +2489,12 @@ def __init__(
# Used to restore any existing `SummaryWriter` after training ends.
self._prev_summary_state = []

# Used to track activation values for histogram.
self._activations = {}

# Used to cache the original layer call methods.
self._layer_calls = {"inner": {}, "outer": {}}

def _validate_kwargs(self, kwargs):
"""Handle arguments were supported in V1."""
if kwargs.get("write_grads", False):
Expand Down Expand Up @@ -2549,6 +2556,8 @@ def set_model(self, model):
self._should_write_train_graph = True
if self.embeddings_freq:
self._configure_embeddings()
if self.histogram_freq:
self._configure_layer_calls()

@property
def _train_writer(self):
Expand Down Expand Up @@ -2736,6 +2745,41 @@ def _init_profile_batch(self, profile_batch):
self._start_batch == 0 and self._stop_batch == 0
)

def _configure_layer_calls(self):
    """Configures the layer call functions to record activations.

    For each layer with trainable variables, creates a shapeless
    `tf.Variable` that holds the layer's most recent output and builds a
    wrapped `call` that copies the output into that variable. Both the
    original and the wrapped `call` are cached in `self._layer_calls` so
    the wrappers can be installed only for the duration of a training
    batch (see `_override_layer_calls` / `_restore_layer_calls`).

    Layers whose `output` attribute is not defined (e.g. layers of
    subclassed models that were never connected in a functional graph)
    are skipped, matching the documented behavior that only weight
    histograms are supported for subclassed models.
    """
    for layer in self.model.layers:
        # Layers without trainable variables get no weight histograms,
        # so activations are not tracked for them either.
        if not layer.trainable_variables:
            continue

        try:
            # `layer.output` raises AttributeError when the layer has
            # never been called functionally (subclassed models); skip
            # such layers instead of failing in `set_model`.
            output_dtype = layer.output.dtype
        except AttributeError:
            continue

        # Shape is left fully unspecified so outputs with any (possibly
        # varying) batch shape can be assigned between batches.
        self._activations[layer.name] = tf.Variable(
            initial_value=float("nan"),
            trainable=False,
            dtype=output_dtype,
            shape=tf.TensorShape(None),
        )
        self._layer_calls["inner"][layer.name] = layer.call

        def outer_call(
            inputs, *args, layer=layer, layer_call=layer.call, **kwargs
        ):
            # Run the original `call`, then capture its output for the
            # end-of-epoch histogram summary. `layer` and `layer_call`
            # are bound as defaults to avoid late-binding closure bugs.
            outputs = layer_call(inputs, *args, **kwargs)
            self._activations[layer.name].assign(outputs)
            return outputs

        self._layer_calls["outer"][layer.name] = outer_call

def _override_layer_calls(self):
"""Overrides the `call` method of each layer to record activations."""
for layer in self.model.layers:
if layer.name in self._layer_calls["outer"]:
layer.call = self._layer_calls["outer"][layer.name]

def _restore_layer_calls(self):
"""Restores the `call` method of each layer to its original value."""
for layer in self.model.layers:
if layer.name in self._layer_calls["inner"]:
layer.call = self._layer_calls["inner"][layer.name]

def on_train_begin(self, logs=None):
self._global_train_batch = 0
self._previous_epoch_iterations = 0
Expand Down Expand Up @@ -2779,6 +2823,9 @@ def on_train_batch_begin(self, batch, logs=None):
if self._global_train_batch == self._start_batch:
self._start_trace()

if self.histogram_freq:
self._override_layer_calls()

def on_train_batch_end(self, batch, logs=None):
if self._should_write_train_graph:
self._write_keras_model_train_graph()
Expand All @@ -2805,6 +2852,9 @@ def on_train_batch_end(self, batch, logs=None):
if self._is_tracing and self._global_train_batch >= self._stop_batch:
self._stop_trace()

if self.histogram_freq:
self._restore_layer_calls()

def on_epoch_begin(self, epoch, logs=None):
# Keeps track of epoch for profiling.
if self.write_steps_per_second:
Expand Down Expand Up @@ -2902,6 +2952,17 @@ def _log_weights(self, epoch):
self._log_weight_as_image(
weight, image_weight_name, epoch
)
if layer.name in self._activations:
activation_name = layer.name + "/activations"
# Add a suffix to prevent summary tag name collision.
histogram_activation_name = (
activation_name + "/histogram"
)
tf.summary.histogram(
histogram_activation_name,
self._activations[layer.name],
step=epoch,
)
self._train_writer.flush()

def _log_weight_as_image(self, weight, weight_name, epoch):
Expand Down
73 changes: 54 additions & 19 deletions keras/callbacks_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1415,7 +1415,6 @@ def get_input_datasets():
return model, train_ds, callback, filepath

def _run_load_weights_on_restart_test_common_iterations(self):

(
model,
train_ds,
Expand Down Expand Up @@ -3162,15 +3161,33 @@ def test_TensorBoard_weight_histograms(self):
),
},
)
self.assertEqual(
self._strip_layer_names(summary_file.histograms, model_type),
{
_ObservedSummary(logdir=self.train_dir, tag="bias_0/histogram"),
_ObservedSummary(
logdir=self.train_dir, tag="kernel_0/histogram"
),
},
)
if "subclass" not in model_type:
self.assertEqual(
self._strip_layer_names(summary_file.histograms, model_type),
{
_ObservedSummary(
logdir=self.train_dir, tag="bias_0/histogram"
),
_ObservedSummary(
logdir=self.train_dir, tag="kernel_0/histogram"
),
_ObservedSummary(
logdir=self.train_dir, tag="activations/histogram"
),
},
)
else:
self.assertEqual(
self._strip_layer_names(summary_file.histograms, model_type),
{
_ObservedSummary(
logdir=self.train_dir, tag="bias_0/histogram"
),
_ObservedSummary(
logdir=self.train_dir, tag="kernel_0/histogram"
),
},
)

def test_TensorBoard_weight_images(self):
model = self._get_model()
Expand Down Expand Up @@ -3201,15 +3218,33 @@ def test_TensorBoard_weight_images(self):
),
},
)
self.assertEqual(
self._strip_layer_names(summary_file.histograms, model_type),
{
_ObservedSummary(logdir=self.train_dir, tag="bias_0/histogram"),
_ObservedSummary(
logdir=self.train_dir, tag="kernel_0/histogram"
),
},
)
if "subclass" not in model_type:
self.assertEqual(
self._strip_layer_names(summary_file.histograms, model_type),
{
_ObservedSummary(
logdir=self.train_dir, tag="bias_0/histogram"
),
_ObservedSummary(
logdir=self.train_dir, tag="kernel_0/histogram"
),
_ObservedSummary(
logdir=self.train_dir, tag="activations/histogram"
),
},
)
else:
self.assertEqual(
self._strip_layer_names(summary_file.histograms, model_type),
{
_ObservedSummary(
logdir=self.train_dir, tag="bias_0/histogram"
),
_ObservedSummary(
logdir=self.train_dir, tag="kernel_0/histogram"
),
},
)
if summary_file.convert_from_v2_summary_proto:
expected_image_summaries = {
_ObservedSummary(logdir=self.train_dir, tag="bias_0/image"),
Expand Down