This repository was archived by the owner on Jun 3, 2025. It is now read-only.
Merged
7 changes: 5 additions & 2 deletions src/sparseml/pytorch/optim/manager.py
@@ -455,6 +455,7 @@ def loss_update(
optimizer: Optimizer,
epoch: float,
steps_per_epoch: int,
**kwargs,
) -> Tensor:
"""
Optional call that can be made on the optimizer to update the contained
@@ -468,13 +469,15 @@ def loss_update(
(calculate batch number using this and epoch)
:return: the modified loss tensor
"""
super().loss_update(loss, module, optimizer, epoch, steps_per_epoch)
super().loss_update(loss, module, optimizer, epoch, steps_per_epoch, **kwargs)

for mod in self._modifiers:
if not mod.enabled:
continue

loss = mod.loss_update(loss, module, optimizer, epoch, steps_per_epoch)
loss = mod.loss_update(
loss, module, optimizer, epoch, steps_per_epoch, **kwargs
)

return loss

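For context on the manager change above: loss_update now simply forwards arbitrary keyword arguments to every enabled modifier. A minimal, self-contained sketch of that forwarding pattern (ToyModifier and manager_loss_update below are illustrative stand-ins, not SparseML classes):

from typing import Any, List

import torch
from torch import Tensor


class ToyModifier:
    """Stand-in for a ScheduledModifier that tolerates extra keyword arguments."""

    enabled = True

    def loss_update(self, loss: Tensor, **kwargs: Any) -> Tensor:
        # A distillation-style modifier would read kwargs["student_outputs"] here;
        # modifiers that need nothing extra simply ignore the kwargs.
        return loss


def manager_loss_update(loss: Tensor, modifiers: List[ToyModifier], **kwargs: Any) -> Tensor:
    # Mirrors the manager loop above: skip disabled modifiers, forward kwargs to the rest.
    for mod in modifiers:
        if not mod.enabled:
            continue
        loss = mod.loss_update(loss, **kwargs)
    return loss


updated = manager_loss_update(torch.tensor(1.0), [ToyModifier()], student_outputs=None)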
1 change: 1 addition & 0 deletions src/sparseml/pytorch/optim/modifier.py
@@ -287,6 +287,7 @@ def loss_update(
optimizer: Optimizer,
epoch: float,
steps_per_epoch: int,
**kwargs,
):
"""
Optional call that can be made on the optimizer to update the modifiers
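The base-class edit is just a signature extension: Modifier.loss_update now accepts **kwargs so the manager can hand through data (such as distillation tensors) that only some modifiers consume. A hypothetical modifier following the same keyword pattern as the DistillationModifier below (the class and its scaling behavior are examples, not part of SparseML):

from typing import Any

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer


class LossScaleModifier:
    """Toy modifier that scales the loss and ignores keyword arguments it does not use."""

    def __init__(self, scale: float = 1.0):
        self.scale = scale

    def loss_update(
        self,
        loss: Tensor,
        module: Module,
        optimizer: Optimizer,
        epoch: float,
        steps_per_epoch: int,
        student_outputs: Any = None,  # consumed only by distillation-style modifiers
        **kwargs: Any,
    ) -> Tensor:
        return loss * self.scale


model = torch.nn.Linear(2, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
scaled = LossScaleModifier(0.5).loss_update(torch.tensor(2.0), model, opt, 0.0, 100, global_step=7)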
102 changes: 35 additions & 67 deletions src/sparseml/pytorch/optim/modifier_distillation.py
@@ -19,16 +19,17 @@

import logging
from copy import deepcopy
from typing import Any, Dict, Iterable, List, Optional
from typing import Any, Dict, Iterable, List, Optional, Union

import torch
import torch.nn.functional as TF
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer

from sparseml.optim import ModifierProp
from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier
from sparseml.pytorch.utils import BaseLogger, tensors_module_forward
from sparseml.pytorch.utils import BaseLogger, device_of, tensors_module_forward


__all__ = [
@@ -221,6 +222,9 @@ def loss_update(
optimizer: Optimizer,
epoch: float,
steps_per_epoch: int,
student_outputs: Union[Tensor, Dict, Iterable] = None,
teacher_inputs: Union[Tensor, Iterable[Tensor], Dict[Any, Tensor]] = None,
**kwargs,
) -> Tensor:
"""
Updates the base loss with the distillation loss
@@ -233,69 +237,57 @@
(calculate batch number using this and epoch)
:return: loss tensor with knowledge distillation loss added
"""
loss = super().loss_update(loss, module, optimizer, epoch, steps_per_epoch)
loss = super().loss_update(
loss, module, optimizer, epoch, steps_per_epoch, **kwargs
)

if not self._distillation_enabled or self._disable_distillation:
return loss

if self._student_outputs is None or self._student_inputs is None:
raise RuntimeError(
"A forward pass of the module must be run before calling loss_update "
"with a DistillationModifier"
if student_outputs is None or teacher_inputs is None:
raise ValueError(
"Student outputs and teacher inputs are required for "
"distillation loss update"
)

# ensure that teacher model is in eval mode and on correct device
self._teacher.eval()
target_device = (
self._student_inputs.device
if isinstance(self._student_inputs, Tensor)
else self._student_inputs[0].device
if isinstance(self._student_inputs, Iterable)
else [
tens.device
for tens in self._student_inputs.values()
if isinstance(tens, Tensor)
][0]
)
target_device = device_of(teacher_inputs)
self._teacher.to(target_device)
with torch.no_grad():
teacher_outputs = tensors_module_forward(
teacher_inputs, self._teacher, check_feat_lab_inp=False
)

teacher_outputs = tensors_module_forward(
self._student_inputs, self._teacher, check_feat_lab_inp=False
)

assert type(self._student_outputs) == type(
teacher_outputs
), "Student and teacher models must have the same output type"
if type(student_outputs) != type(teacher_outputs):
raise ValueError(
"Student and teacher models must have the same output type"
)

distill_losses = []
if isinstance(self._student_outputs, Tensor):
if isinstance(student_outputs, Tensor):
distill_losses.append(
self._calc_distill_loss(self._student_outputs, teacher_outputs)
self._calc_distill_loss(student_outputs, teacher_outputs)
)
elif isinstance(self._student_outputs, Dict):
for key in self._distill_output_keys or self._student_outputs:
elif isinstance(student_outputs, Dict):
for key in self._distill_output_keys or student_outputs:
distill_losses.append(
self._calc_distill_loss(
self._student_outputs[key], teacher_outputs[key]
)
self._calc_distill_loss(student_outputs[key], teacher_outputs[key])
)
elif isinstance(self._student_outputs, Iterable):
for idx in self._distill_output_keys or range(len(self._student_outputs)):
elif isinstance(student_outputs, Iterable):
for idx in self._distill_output_keys or range(len(student_outputs)):
distill_losses.append(
self._calc_distill_loss(
self._student_outputs[idx], teacher_outputs[idx]
)
self._calc_distill_loss(student_outputs[idx], teacher_outputs[idx])
)

# get distillation loss as average of individual output distillation loss values
teacher_loss = sum(distill_losses) / len(distill_losses)
distillation_loss = ((1.0 - self._hardness) * loss) + (
self._hardness * teacher_loss
)

_log_losses(
self.loggers, epoch, steps_per_epoch, loss, teacher_loss, distillation_loss
)
global_step = kwargs.get("global_step")
global_step = epoch * steps_per_epoch if global_step is None else global_step
_log_losses(self.loggers, global_step, loss, teacher_loss, distillation_loss)
return distillation_loss

def finalize(
@@ -340,46 +332,22 @@ def _check_distillation_update(
"Using self distillation with copy of the module's current state"
)
self._teacher = deepcopy(module)
self._set_student_hook(module)
self._distillation_enabled = True

if self.end_pending(epoch, steps_per_epoch):
self._disable_student_hook()
self._distillation_enabled = False

def _set_student_hook(self, module: Module):
# delete hook if already exists
self._disable_student_hook()

def _track_inputs_and_outputs_hook(mod, inputs, outputs):
self._student_inputs = inputs
self._student_outputs = outputs

self._track_student_hook = module.register_forward_hook(
_track_inputs_and_outputs_hook
)

def _disable_student_hook(self):
if self._track_student_hook is not None:
self._track_student_hook.remove()
self._track_student_hook = None
self._student_inputs = None
self._student_outputs = None

def _is_distillation_epoch(self, epoch):
return self.start_epoch <= epoch < self.end_epoch


def _log_losses(
loggers: List[BaseLogger],
epoch: float,
steps_per_epoch: int,
global_step: int,
original_loss: float,
teacher_loss: float,
distillation_loss: float,
):
step = round(epoch) if steps_per_epoch <= 0 else round(epoch * steps_per_epoch)

losses = {
"original_loss": original_loss,
"teacher_loss": teacher_loss,
@@ -388,4 +356,4 @@ def _log_losses(

for logger in loggers:
for (name, loss) in losses.items():
logger.log_scalar(f"DistillationModifier/{name}", loss.item(), step)
logger.log_scalar(f"DistillationModifier/{name}", loss.item(), global_step)
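The reworked DistillationModifier no longer hooks the student module: the caller runs the student forward pass and hands student_outputs and teacher_inputs to loss_update, while the modifier runs the teacher under no_grad and blends the two losses by hardness. A minimal sketch of that contract follows; the tiny models and the temperature-scaled KL divergence are illustrative assumptions (the diff does not show _calc_distill_loss), not the modifier's exact internals:

import torch
import torch.nn.functional as F
from torch import nn

student = nn.Linear(8, 4)
teacher = nn.Linear(8, 4)

inputs = torch.randn(16, 8)
labels = torch.randint(0, 4, (16,))

# The trainer now owns the student forward pass ...
student_outputs = student(inputs)
task_loss = F.cross_entropy(student_outputs, labels)

# ... while the modifier runs the teacher under no_grad, as in the diff above.
with torch.no_grad():
    teacher_outputs = teacher(inputs)

# Temperature-scaled soft-target KL divergence: a common distillation loss.
temperature, hardness = 2.0, 0.5
distill = F.kl_div(
    F.log_softmax(student_outputs / temperature, dim=-1),
    F.softmax(teacher_outputs / temperature, dim=-1),
    reduction="batchmean",
) * temperature ** 2

# Same interpolation as the modifier: (1 - hardness) * task loss + hardness * distill loss.
combined = (1.0 - hardness) * task_loss + hardness * distill

# Against the real modifier, the equivalent call is roughly (argument names from this diff):
# loss = modifier.loss_update(
#     task_loss, student, optimizer, epoch, steps_per_epoch,
#     student_outputs=student_outputs, teacher_inputs=inputs,
#     global_step=epoch * steps_per_epoch + batch_index,
# )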
14 changes: 14 additions & 0 deletions src/sparseml/pytorch/utils/helpers.py
@@ -42,6 +42,7 @@

__all__ = [
"default_device",
"device_of",
"get_optim_learning_rate",
"get_optim_groups_learning_rates",
"set_optim_learning_rate",
@@ -98,6 +99,19 @@ def default_device() -> str:
return "cuda:{}".format(",".join(device_ids))


def device_of(inputs: Any):
if isinstance(inputs, Tensor):
return inputs.device
elif isinstance(inputs, Dict):
for tens in inputs.values():
return device_of(tens)
elif isinstance(inputs, Iterable):
return device_of(inputs[0])
else:
raise RuntimeError("Unknown type of inputs to device_of function")
return default_device()


##############################
#
# pytorch optim helpers
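The new device_of helper returns the device of the first tensor it finds in a (possibly nested) input structure, which is how the teacher model gets moved next to the teacher inputs. A quick usage sketch, assuming SparseML from this branch is installed (the tensors and containers are arbitrary examples):

import torch

from sparseml.pytorch.utils import device_of  # re-exported, as used in the diff above

x = torch.zeros(2, 3)
print(device_of(x))                   # the tensor's own device, e.g. cpu
print(device_of([x, torch.ones(1)]))  # first element of a list/tuple
print(device_of({"input_ids": x}))    # first value of a dict
# Leaves that are neither tensors nor containers raise RuntimeError.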
35 changes: 20 additions & 15 deletions tests/sparseml/pytorch/optim/test_modifier_distillation.py
@@ -70,9 +70,18 @@ def test_lifecycle(

# test distillation has been applied
# fake forward pass
fake_loss = model(self._get_fake_batch(model_lambda)).mean()
student_inputs = self._get_fake_batch(model_lambda)
student_outputs = model(student_inputs)
teacher_outputs = student_outputs + 0.5 # fake teacher model's outputs
fake_loss = student_outputs.mean()
updated_loss = modifier.loss_update(
fake_loss, model, optimizer, -1, test_steps_per_epoch
fake_loss,
model,
optimizer,
-1,
test_steps_per_epoch,
student_outputs,
teacher_outputs,
)

assert isinstance(updated_loss, torch.Tensor)
@@ -98,23 +107,19 @@ def test_loss_update(
model = model_lambda()
optimizer = optim_lambda(model)

with pytest.raises(RuntimeError):
modifier.loss_update(
test_loss, model, optimizer, test_epoch, test_steps_per_epoch
)

self.initialize_helper(modifier, model)

# should fail until a forward pass is run
with pytest.raises(RuntimeError):
modifier.loss_update(
test_loss, model, optimizer, test_epoch, test_steps_per_epoch
)

# run fake forward pass and try updating the loss
_ = model(self._get_fake_batch(model_lambda))
inputs = self._get_fake_batch(model_lambda)
student_outputs = model(inputs)
new_loss = modifier.loss_update(
test_loss, model, optimizer, test_epoch, test_steps_per_epoch
test_loss,
model,
optimizer,
test_epoch,
test_steps_per_epoch,
student_outputs,
inputs,
)

assert isinstance(new_loss, Tensor)