From 007d1ceb5b3b5b36e1e20a9d51d5c4847e9e6a5d Mon Sep 17 00:00:00 2001
From: Xin Wang
Date: Wed, 4 Nov 2020 18:53:12 -0800
Subject: [PATCH] Added ModelPruningListener and ModelPruningHook to allow
 running Python pruning updates in tf.estimator.Estimator.

PiperOrigin-RevId: 340766821
---
 model_pruning/__init__.py                 |   2 +
 model_pruning/python/pruning_hook.py      | 116 ++++++++++++++++++++++
 model_pruning/python/pruning_hook_test.py |  84 ++++++++++++++++
 3 files changed, 202 insertions(+)
 create mode 100644 model_pruning/python/pruning_hook.py
 create mode 100644 model_pruning/python/pruning_hook_test.py

diff --git a/model_pruning/__init__.py b/model_pruning/__init__.py
index 6744fcb319f..c3bf85c139b 100644
--- a/model_pruning/__init__.py
+++ b/model_pruning/__init__.py
@@ -30,6 +30,8 @@
 from model_pruning.python.pruning import get_weight_sparsity
 from model_pruning.python.pruning import get_weights
 from model_pruning.python.pruning import Pruning
+from model_pruning.python.pruning_hook import ModelPruningHook
+from model_pruning.python.pruning_hook import ModelPruningListener
 from model_pruning.python.pruning_interface import apply_matrix_compression
 from model_pruning.python.pruning_interface import apply_pruning
 from model_pruning.python.pruning_interface import get_matrix_compression_object
diff --git a/model_pruning/python/pruning_hook.py b/model_pruning/python/pruning_hook.py
new file mode 100644
index 00000000000..9fe4a982526
--- /dev/null
+++ b/model_pruning/python/pruning_hook.py
@@ -0,0 +1,116 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Hooks for model pruning.
+
+Model pruning hooks are used in estimators (instances of tf.estimator.Estimator)
+to explicitly run pruning updates on the graph.
+"""
+import tensorflow.compat.v1 as tf
+
+
+class ModelPruningListener(tf.estimator.CheckpointSaverListener):
+  """Listener class for ModelPruningHook.
+
+  Used to run the pruning object's Python update function periodically.
+  """
+
+  def __init__(self, pruning_obj):
+    """Initializer.
+
+    Args:
+      pruning_obj: Pruning object whose update function needs to be run.
+    """
+    self.pruning_obj = pruning_obj
+
+  def before_save(self, session, global_step_value):
+    """Before save processing."""
+    # Disable all the protected-access violations in this function, as we
+    # need to unfinalize the graph to call run_update_step.
+    # pylint: disable=protected-access
+    session.graph._unsafe_unfinalize()
+    self.pruning_obj.run_update_step(session, global_step_value)
+
+
+class ModelPruningHook(tf.estimator.SessionRunHook):
+  """Prune the model every N steps."""
+
+  _STEPS_PER_RUN = 1
+
+  def __init__(self, every_steps=None, listeners=None):
+    """Initialize a `ModelPruningHook`.
+
+    Args:
+      every_steps: `int`, prune every N steps.
+      listeners: List of `ModelPruningListener` subclass instances.
+    """
+    tf.logging.info("Creating ModelPruningHook.")
+    self._every_steps = every_steps
+    self._listeners = listeners
+    self._timer = tf.estimator.SecondOrStepTimer(every_steps=every_steps)
+
+  def _call_prune_listener(self, session, step):
+    """Calls model pruning listeners; returns should_stop_training."""
+    tf.logging.info("Calling model pruning listeners at step %d...",
+                    step)
+    for listener in self._listeners:
+      listener.before_save(session, step)
+
+    should_stop_training = False
+    for listener in self._listeners:
+      if listener.after_save(session, step):
+        tf.logging.info(
+            "A model pruning listener requested that training be stopped. "
+            "listener: {}".format(listener))
+        should_stop_training = True
+    return should_stop_training
+
+  def begin(self):
+    self._global_step_tensor = tf.compat.v1.train.get_or_create_global_step()
+    if self._global_step_tensor is None:
+      raise RuntimeError(
+          "Global step should be created to use ModelPruningHook.")
+    for l in self._listeners:
+      l.begin()
+
+  def after_create_session(self, session, coord):
+    global_step = session.run(self._global_step_tensor)
+    self._call_prune_listener(session, global_step)
+    self._timer.update_last_triggered_step(global_step)
+
+  def before_run(self, run_context):  # pylint: disable=unused-argument
+    return tf.estimator.SessionRunArgs(self._global_step_tensor)
+
+  def after_run(self, run_context, run_values):
+    stale_global_step = run_values.results
+    if not self._timer.should_trigger_for_step(stale_global_step +
+                                               self._STEPS_PER_RUN):
+      return
+
+    # Get the real value after train op.
+    global_step = run_context.session.run(self._global_step_tensor)
+    if not self._timer.should_trigger_for_step(global_step):
+      return
+
+    self._timer.update_last_triggered_step(global_step)
+    if self._call_prune_listener(run_context.session, global_step):
+      run_context.request_stop()
+
+  def end(self, session):
+    last_step = session.run(self._global_step_tensor)
+    if last_step != self._timer.last_triggered_step():
+      self._call_prune_listener(session, last_step)
+    for l in self._listeners:
+      l.end(session, last_step)
diff --git a/model_pruning/python/pruning_hook_test.py b/model_pruning/python/pruning_hook_test.py
new file mode 100644
index 00000000000..d57dd54b609
--- /dev/null
+++ b/model_pruning/python/pruning_hook_test.py
@@ -0,0 +1,84 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for ModelPruningHook."""
+
+import tensorflow.compat.v1 as tf
+
+from model_pruning.python import pruning_hook
+
+
+class MockPruningObject(object):
+  """Mock pruning object that has a run_update_step() function."""
+
+  def __init__(self):
+    self.logged_steps = []
+
+  def run_update_step(self, session, global_step):  # pylint: disable=unused-argument
+    self.logged_steps.append(global_step)
+
+
+class PruningHookTest(tf.test.TestCase):
+
+  def test_prune_after_session_creation(self):
+    every_steps = 10
+    pruning_obj = MockPruningObject()
+    listener = pruning_hook.ModelPruningListener(pruning_obj)
+    hook = pruning_hook.ModelPruningHook(every_steps=every_steps,
+                                         listeners=[listener])
+    mon_sess = tf.train.MonitoredSession(hooks=[hook])  # pylint: disable=unused-variable
+    self.evaluate(tf.global_variables_initializer())
+
+    self.assertEqual(len(pruning_obj.logged_steps), 1)
+    self.assertEqual(pruning_obj.logged_steps[0], 0)
+
+  def test_prune_every_n_steps(self):
+    every_steps = 10
+    pruning_obj = MockPruningObject()
+
+    with tf.Graph().as_default():
+      listener = pruning_hook.ModelPruningListener(pruning_obj)
+      hook = pruning_hook.ModelPruningHook(every_steps=every_steps,
+                                           listeners=[listener])
+      global_step = tf.train.get_or_create_global_step()
+      train_op = tf.constant(0)
+      global_step_increment_op = tf.assign_add(global_step, 1)
+      with tf.train.MonitoredSession(tf.train.ChiefSessionCreator(),
+                                     hooks=[hook]) as mon_sess:
+        mon_sess.run(tf.global_variables_initializer())
+
+        mon_sess.run(train_op)
+        mon_sess.run(global_step_increment_op)
+        # ModelPruningHook runs once after session creation, at step 0.
+        self.assertEqual(len(pruning_obj.logged_steps), 1)
+        self.assertEqual(pruning_obj.logged_steps[0], 0)
+
+        for _ in range(every_steps - 1):
+          mon_sess.run(train_op)
+          mon_sess.run(global_step_increment_op)
+
+        self.assertEqual(len(pruning_obj.logged_steps), 2)
+        self.assertSameElements(pruning_obj.logged_steps, [0, every_steps])
+
+        for _ in range(every_steps - 1):
+          mon_sess.run(train_op)
+          mon_sess.run(global_step_increment_op)
+
+        self.assertEqual(len(pruning_obj.logged_steps), 2)
+        self.assertSameElements(pruning_obj.logged_steps, [0, every_steps])
+
+
+if __name__ == '__main__':
+  tf.test.main()
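
Usage sketch (illustrative only, not part of the patch above): after this change both classes are re-exported from the model_pruning package, so a training job can attach the hook to an Estimator's train() call like any other SessionRunHook. In the sketch below, pruning_obj stands for any object exposing run_update_step(session, global_step), which is the only method ModelPruningListener itself invokes on it; build_pruning_hooks, my_model_fn, my_input_fn, and the step counts are hypothetical names, not part of this commit.

import model_pruning


def build_pruning_hooks(pruning_obj, prune_every_steps=100):
  """Returns hooks that run pruning_obj.run_update_step during training."""
  # ModelPruningListener wraps the pruning object; ModelPruningHook triggers it
  # once after session creation and then every `prune_every_steps` global steps.
  listener = model_pruning.ModelPruningListener(pruning_obj)
  return [
      model_pruning.ModelPruningHook(
          every_steps=prune_every_steps, listeners=[listener])
  ]


# Hypothetical call site (my_model_fn / my_input_fn are placeholders):
# estimator = tf.estimator.Estimator(model_fn=my_model_fn, model_dir=model_dir)
# estimator.train(input_fn=my_input_fn,
#                 hooks=build_pruning_hooks(pruning_obj),
#                 max_steps=10000)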