diff --git a/torch/quantization/quantize.py b/torch/quantization/quantize.py index a57a4ea6bcb8..77752a8af9c9 100644 --- a/torch/quantization/quantize.py +++ b/torch/quantization/quantize.py @@ -258,9 +258,12 @@ def _remove_activation_post_process(module): delattr(module, 'activation_post_process') # remove activation_post_proceess hook + handle_ids_to_remove = set() for handle_id, hook_fn in module._forward_hooks.items(): if hook_fn is _observer_forward_hook: - module._forward_hooks.pop(handle_id) + handle_ids_to_remove.add(handle_id) + for handle_id in handle_ids_to_remove: + module._forward_hooks.pop(handle_id) # TODO: rename to something more general def _remove_qconfig(module):