From 8e79ebffb03285450ce2a981c509684e6bf98de7 Mon Sep 17 00:00:00 2001 From: Tuan Nguyen Date: Sun, 3 Oct 2021 23:23:12 -0400 Subject: [PATCH 1/2] Revised distillation modifier --- src/sparseml/pytorch/optim/manager.py | 7 +- src/sparseml/pytorch/optim/modifier.py | 1 + .../pytorch/optim/modifier_distillation.py | 99 +++++-------------- src/sparseml/pytorch/utils/helpers.py | 14 +++ .../optim/test_modifier_distillation.py | 24 +++-- 5 files changed, 58 insertions(+), 87 deletions(-) diff --git a/src/sparseml/pytorch/optim/manager.py b/src/sparseml/pytorch/optim/manager.py index 272dc0d0643..ed3a74fa84a 100644 --- a/src/sparseml/pytorch/optim/manager.py +++ b/src/sparseml/pytorch/optim/manager.py @@ -455,6 +455,7 @@ def loss_update( optimizer: Optimizer, epoch: float, steps_per_epoch: int, + **kwargs, ) -> Tensor: """ Optional call that can be made on the optimizer to update the contained @@ -468,13 +469,15 @@ def loss_update( (calculate batch number using this and epoch) :return: the modified loss tensor """ - super().loss_update(loss, module, optimizer, epoch, steps_per_epoch) + super().loss_update(loss, module, optimizer, epoch, steps_per_epoch, **kwargs) for mod in self._modifiers: if not mod.enabled: continue - loss = mod.loss_update(loss, module, optimizer, epoch, steps_per_epoch) + loss = mod.loss_update( + loss, module, optimizer, epoch, steps_per_epoch, **kwargs + ) return loss diff --git a/src/sparseml/pytorch/optim/modifier.py b/src/sparseml/pytorch/optim/modifier.py index ec9ef674fa6..b4ad053f441 100644 --- a/src/sparseml/pytorch/optim/modifier.py +++ b/src/sparseml/pytorch/optim/modifier.py @@ -287,6 +287,7 @@ def loss_update( optimizer: Optimizer, epoch: float, steps_per_epoch: int, + **kwargs, ): """ Optional call that can be made on the optimizer to update the modifiers diff --git a/src/sparseml/pytorch/optim/modifier_distillation.py b/src/sparseml/pytorch/optim/modifier_distillation.py index 68fe02404d4..16459afcd70 100644 --- a/src/sparseml/pytorch/optim/modifier_distillation.py +++ b/src/sparseml/pytorch/optim/modifier_distillation.py @@ -19,7 +19,7 @@ import logging from copy import deepcopy -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Union import torch.nn.functional as TF from torch import Tensor @@ -28,7 +28,7 @@ from sparseml.optim import ModifierProp from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier -from sparseml.pytorch.utils import BaseLogger, tensors_module_forward +from sparseml.pytorch.utils import BaseLogger __all__ = [ @@ -221,6 +221,9 @@ def loss_update( optimizer: Optimizer, epoch: float, steps_per_epoch: int, + student_outputs: Union[Tensor, Dict, Iterable] = None, + teacher_outputs: Union[Tensor, Dict, Iterable] = None, + **kwargs, ) -> Tensor: """ Updates the bass loss with the distillation loss @@ -233,58 +236,35 @@ def loss_update( (calculate batch number using this and epoch) :return: loss tensor with knowledge distillation loss added """ - loss = super().loss_update(loss, module, optimizer, epoch, steps_per_epoch) + loss = super().loss_update( + loss, module, optimizer, epoch, steps_per_epoch, **kwargs + ) if not self._distillation_enabled or self._disable_distillation: return loss - if self._student_outputs is None or self._student_inputs is None: - raise RuntimeError( - "A forward pass of the module must be run before calling loss_update " - "with a DistillationModifier" - ) - - # ensure that teacher model is in eval mode and on correct device - 
self._teacher.eval() - target_device = ( - self._student_inputs.device - if isinstance(self._student_inputs, Tensor) - else self._student_inputs[0].device - if isinstance(self._student_inputs, Iterable) - else [ - tens.device - for tens in self._student_inputs.values() - if isinstance(tens, Tensor) - ][0] - ) - self._teacher.to(target_device) - - teacher_outputs = tensors_module_forward( - self._student_inputs, self._teacher, check_feat_lab_inp=False - ) + if student_outputs is None or teacher_outputs is None: + return loss - assert type(self._student_outputs) == type( - teacher_outputs - ), "Student and teacher models must have the same output type" + if type(student_outputs) != type(teacher_outputs): + raise ValueError( + "Student and teacher models must have the same output type" + ) distill_losses = [] - if isinstance(self._student_outputs, Tensor): + if isinstance(student_outputs, Tensor): distill_losses.append( - self._calc_distill_loss(self._student_outputs, teacher_outputs) + self._calc_distill_loss(student_outputs, teacher_outputs) ) - elif isinstance(self._student_outputs, Dict): - for key in self._distill_output_keys or self._student_outputs: + elif isinstance(student_outputs, Dict): + for key in self._distill_output_keys or student_outputs: distill_losses.append( - self._calc_distill_loss( - self._student_outputs[key], teacher_outputs[key] - ) + self._calc_distill_loss(student_outputs[key], teacher_outputs[key]) ) - elif isinstance(self._student_outputs, Iterable): - for idx in self._distill_output_keys or range(len(self._student_outputs)): + elif isinstance(student_outputs, Iterable): + for idx in self._distill_output_keys or range(len(student_outputs)): distill_losses.append( - self._calc_distill_loss( - self._student_outputs[idx], teacher_outputs[idx] - ) + self._calc_distill_loss(student_outputs[idx], teacher_outputs[idx]) ) # get distillation loss as average of individual output distillation loss values @@ -292,10 +272,9 @@ def loss_update( distillation_loss = ((1.0 - self._hardness) * loss) + ( self._hardness * teacher_loss ) - - _log_losses( - self.loggers, epoch, steps_per_epoch, loss, teacher_loss, distillation_loss - ) + global_step = kwargs.get("global_step") + global_step = epoch * steps_per_epoch if global_step is None else global_step + _log_losses(self.loggers, global_step, loss, teacher_loss, distillation_loss) return distillation_loss def finalize( @@ -340,46 +319,22 @@ def _check_distillation_update( "Using self distillation with copy of the module's current state" ) self._teacher = deepcopy(module) - self._set_student_hook(module) self._distillation_enabled = True if self.end_pending(epoch, steps_per_epoch): - self._disable_student_hook() self._distillation_enabled = False - def _set_student_hook(self, module: Module): - # delete hook if already exists - self._disable_student_hook() - - def _track_inputs_and_outputs_hook(mod, inputs, outputs): - self._student_inputs = inputs - self._student_outputs = outputs - - self._track_student_hook = module.register_forward_hook( - _track_inputs_and_outputs_hook - ) - - def _disable_student_hook(self): - if self._track_student_hook is not None: - self._track_student_hook.remove() - self._track_student_hook = None - self._student_inputs = None - self._student_outputs = None - def _is_distillation_epoch(self, epoch): return self.start_epoch <= epoch < self.end_epoch def _log_losses( loggers: List[BaseLogger], - epoch: float, - steps_per_epoch: int, + global_step: int, original_loss: float, teacher_loss: float, 
distillation_loss: float, ): - step = round(epoch) if steps_per_epoch <= 0 else round(epoch * steps_per_epoch) - losses = { "original_loss": original_loss, "teacher_loss": teacher_loss, @@ -388,4 +343,4 @@ def _log_losses( for logger in loggers: for (name, loss) in losses.items(): - logger.log_scalar(f"DistillationModifier/{name}", loss.item(), step) + logger.log_scalar(f"DistillationModifier/{name}", loss.item(), global_step) diff --git a/src/sparseml/pytorch/utils/helpers.py b/src/sparseml/pytorch/utils/helpers.py index 56e93154f90..7a93832e973 100644 --- a/src/sparseml/pytorch/utils/helpers.py +++ b/src/sparseml/pytorch/utils/helpers.py @@ -42,6 +42,7 @@ __all__ = [ "default_device", + "device_of", "get_optim_learning_rate", "get_optim_groups_learning_rates", "set_optim_learning_rate", @@ -98,6 +99,19 @@ def default_device() -> str: return "cuda:{}".format(",".join(device_ids)) +def device_of(inputs: Any): + if isinstance(inputs, Tensor): + return inputs.device + elif isinstance(inputs, Dict): + for tens in inputs.values(): + return device_of(tens) + elif isinstance(inputs, Iterable): + return device_of(inputs[0]) + else: + raise RuntimeError("Unknown type of inputs to device_of function") + return default_device() + + ############################## # # pytorch optim helpers diff --git a/tests/sparseml/pytorch/optim/test_modifier_distillation.py b/tests/sparseml/pytorch/optim/test_modifier_distillation.py index 3454afc0f81..d6afb5156f2 100644 --- a/tests/sparseml/pytorch/optim/test_modifier_distillation.py +++ b/tests/sparseml/pytorch/optim/test_modifier_distillation.py @@ -70,9 +70,18 @@ def test_lifecycle( # test distillation has been applied # fake forward pass - fake_loss = model(self._get_fake_batch(model_lambda)).mean() + student_inputs = self._get_fake_batch(model_lambda) + student_outputs = model(student_inputs) + teacher_outputs = student_outputs + 0.5 # fake teacher model's outputs + fake_loss = student_outputs.mean() updated_loss = modifier.loss_update( - fake_loss, model, optimizer, -1, test_steps_per_epoch + fake_loss, + model, + optimizer, + -1, + test_steps_per_epoch, + student_outputs, + teacher_outputs, ) assert isinstance(updated_loss, torch.Tensor) @@ -98,19 +107,8 @@ def test_loss_update( model = model_lambda() optimizer = optim_lambda(model) - with pytest.raises(RuntimeError): - modifier.loss_update( - test_loss, model, optimizer, test_epoch, test_steps_per_epoch - ) - self.initialize_helper(modifier, model) - # should fail until a forward pass is run - with pytest.raises(RuntimeError): - modifier.loss_update( - test_loss, model, optimizer, test_epoch, test_steps_per_epoch - ) - # run fake forward pass and try updating the loss _ = model(self._get_fake_batch(model_lambda)) new_loss = modifier.loss_update( From fd23713704d0763fa67dd439d65e40ec306b139b Mon Sep 17 00:00:00 2001 From: Tuan Nguyen Date: Wed, 6 Oct 2021 13:36:22 -0400 Subject: [PATCH 2/2] Move teacher model's logic back to modifier --- .../pytorch/optim/modifier_distillation.py | 21 +++++++++++++++---- .../optim/test_modifier_distillation.py | 11 ++++++++-- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/src/sparseml/pytorch/optim/modifier_distillation.py b/src/sparseml/pytorch/optim/modifier_distillation.py index 16459afcd70..f4affac2086 100644 --- a/src/sparseml/pytorch/optim/modifier_distillation.py +++ b/src/sparseml/pytorch/optim/modifier_distillation.py @@ -21,6 +21,7 @@ from copy import deepcopy from typing import Any, Dict, Iterable, List, Optional, Union +import torch import 
torch.nn.functional as TF from torch import Tensor from torch.nn import Module @@ -28,7 +29,7 @@ from sparseml.optim import ModifierProp from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier -from sparseml.pytorch.utils import BaseLogger +from sparseml.pytorch.utils import BaseLogger, device_of, tensors_module_forward __all__ = [ @@ -222,7 +223,7 @@ def loss_update( epoch: float, steps_per_epoch: int, student_outputs: Union[Tensor, Dict, Iterable] = None, - teacher_outputs: Union[Tensor, Dict, Iterable] = None, + teacher_inputs: Union[Tensor, Iterable[Tensor], Dict[Any, Tensor]] = None, **kwargs, ) -> Tensor: """ @@ -243,8 +244,20 @@ def loss_update( if not self._distillation_enabled or self._disable_distillation: return loss - if student_outputs is None or teacher_outputs is None: - return loss + if student_outputs is None or teacher_inputs is None: + raise ValueError( + "Student outputs and teacher inputs are required for " + "distillation loss update" + ) + + # ensure that teacher model is in eval mode and on correct device + self._teacher.eval() + target_device = device_of(teacher_inputs) + self._teacher.to(target_device) + with torch.no_grad(): + teacher_outputs = tensors_module_forward( + teacher_inputs, self._teacher, check_feat_lab_inp=False + ) if type(student_outputs) != type(teacher_outputs): raise ValueError( diff --git a/tests/sparseml/pytorch/optim/test_modifier_distillation.py b/tests/sparseml/pytorch/optim/test_modifier_distillation.py index d6afb5156f2..a9111360680 100644 --- a/tests/sparseml/pytorch/optim/test_modifier_distillation.py +++ b/tests/sparseml/pytorch/optim/test_modifier_distillation.py @@ -110,9 +110,16 @@ def test_loss_update( self.initialize_helper(modifier, model) # run fake forward pass and try updating the loss - _ = model(self._get_fake_batch(model_lambda)) + inputs = self._get_fake_batch(model_lambda) + student_outputs = model(inputs) new_loss = modifier.loss_update( - test_loss, model, optimizer, test_epoch, test_steps_per_epoch + test_loss, + model, + optimizer, + test_epoch, + test_steps_per_epoch, + student_outputs, + inputs, ) assert isinstance(new_loss, Tensor)
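
Taken together, the two commits drop the forward-hook bookkeeping in favor of an explicit hand-off: the training loop passes the student's outputs and the raw batch into loss_update, and DistillationModifier runs the teacher forward pass itself (eval mode, under no_grad, on the device reported by device_of). Below is a minimal sketch, not part of the patch, of how a training loop might drive the revised API once both commits are applied; train_one_epoch, criterion, data_loader, labels, and step are placeholder names, while the loss_update signature and the optional global_step kwarg come from the diffs above.

def train_one_epoch(
    manager, model, criterion, optimizer, data_loader, epoch, steps_per_epoch
):
    # `manager` is assumed to be an already-initialized ScheduledModifierManager
    # whose recipe contains a DistillationModifier.
    model.train()
    for step, (inputs, labels) in enumerate(data_loader):
        optimizer.zero_grad()

        # single student forward pass; outputs are handed to the modifiers
        # explicitly instead of being captured by a registered forward hook
        student_outputs = model(inputs)
        loss = criterion(student_outputs, labels)

        # extra kwargs are forwarded by ScheduledModifierManager.loss_update to
        # every enabled modifier; DistillationModifier runs the teacher on
        # teacher_inputs and blends the distillation loss into the base loss
        loss = manager.loss_update(
            loss,
            model,
            optimizer,
            epoch,
            steps_per_epoch,
            student_outputs=student_outputs,
            teacher_inputs=inputs,
            global_step=epoch * steps_per_epoch + step,  # optional; falls back to epoch * steps_per_epoch
        )

        loss.backward()
        optimizer.step()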
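
The new device_of helper added to sparseml/pytorch/utils/helpers.py is how the modifier now resolves the teacher's target device for tensor, sequence, and dict batches alike. A small illustration of its behavior, assuming the patched sparseml is importable; the tensor shapes and key names below are made up.

import torch

from sparseml.pytorch.utils import device_of

# device_of returns the device of the first tensor it can find, recursing into
# dict values and sequence elements, and raises for unsupported input types
single = torch.zeros(4, 3)
sequence = [torch.zeros(4, 3), torch.zeros(4)]
mapping = {"input_ids": torch.zeros(4, 8, dtype=torch.long)}

print(device_of(single))    # the tensor's own device, e.g. cpu
print(device_of(sequence))  # device of the first element
print(device_of(mapping))   # device of the first value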