6 changes: 6 additions & 0 deletions src/sparseml/core/modifier/stage.py
@@ -137,8 +137,11 @@ def initialize(self, state: "State", **kwargs):
if self.applied:
return

accelerator = kwargs.get("accelerator", None)
for modifier in self.modifiers:
modifier.initialize(state, **kwargs)
if accelerator:
accelerator.wait_for_everyone()
state.loggers.system.info(tag="stage", string="Modifiers initialized")

def finalize(self, state: "State", **kwargs):
@@ -153,8 +156,11 @@ def finalize(self, state: "State", **kwargs):
if self.applied:
return

accelerator = kwargs.get("accelerator", None)
for modifier in self.modifiers:
modifier.finalize(state, **kwargs)
if accelerator:
accelerator.wait_for_everyone()

self.applied = True
state.loggers.system.info(tag="stage", string="Modifiers finalized")
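The two new wait_for_everyone() calls act as barriers: in a multi-process run (for example under FSDP launched through Hugging Face Accelerate), no rank moves past modifier initialization or finalization until every rank has finished its loop. A minimal sketch of the same pattern, using an invented DummyModifier rather than sparseml's real modifiers:

    # Sketch only: barrier after a collective setup loop, mirroring the calls
    # added in stage.py above. DummyModifier is invented for illustration.
    from accelerate import Accelerator

    class DummyModifier:
        def initialize(self, state, **kwargs):
            pass  # per-rank setup work would happen here

    def initialize_all(modifiers, state, accelerator=None):
        for modifier in modifiers:
            modifier.initialize(state, accelerator=accelerator)
        if accelerator:
            accelerator.wait_for_everyone()  # block until every rank is done

    initialize_all([DummyModifier()], state=None, accelerator=Accelerator())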
34 changes: 22 additions & 12 deletions src/sparseml/modifiers/smoothquant/pytorch.py
@@ -22,6 +22,7 @@
from sparseml.core.model.pytorch import ModifiableModelPyTorch
from sparseml.modifiers.smoothquant.base import SmoothQuantModifier, SmoothQuantScale
from sparseml.modifiers.utils.pytorch_helpers import run_calibration_forward
from sparseml.utils.fsdp.helpers import get_fsdp_parent


_LOGGER = logging.getLogger(__name__)
@@ -56,7 +57,7 @@ def on_initialize(self, state: State, **kwargs) -> bool:

self._setup_scale_hooks()
self._calibrate(state.model, calibration_dataloader)
-self._apply_smoothing()
+self._apply_smoothing(state.model)

return True

Expand Down Expand Up @@ -138,7 +139,7 @@ def _calibrate(self, model: ModifiableModelPyTorch, calibration_dataloader: List
del self.hooks_

@torch.no_grad()
-def _apply_smoothing(self):
+def _apply_smoothing(self, model: ModifiableModelPyTorch):
"""
After calibration, apply smoothing to the activations and push the transform
into the following weights by applying the inverse to each balance weight.
@@ -162,17 +163,26 @@ def _apply_smoothing(self, model: ModifiableModelPyTorch):
scales, torch.Tensor([MINIMUM_SMOOTHING_SCALE]).to(scales.device)
)

-# invert the smoothing in the following layers
-for layer in balance_layers:
-    layer.weight.mul_(scales.view(1, -1))
-
-# apply the smoothing
-if smooth_layer.weight.ndim == 1:
-    smooth_layer.weight.div_(scales)
+@torch.no_grad()
+def smooth(module):
+    if module in balance_layers:
+        module.weight.mul_(scales.view(1, -1))
+    elif module == smooth_layer:
+        if module.weight.ndim == 1:
+            module.weight.div_(scales)
+        else:
+            module.weight.div_(scales.view(-1, 1))
+        if hasattr(module, "bias") and module.bias is not None:
+            module.bias.div_(scales)
+
+parent = get_fsdp_parent(mapping.smooth_name, model.model)
+if parent is not None:
+    parent.apply(smooth)
else:
-    smooth_layer.weight.div_(scales.view(-1, 1))
-if hasattr(smooth_layer, "bias") and smooth_layer.bias is not None:
-    smooth_layer.bias.div_(scales)
+    # if we're not running with FSDP we can apply smoothing directly
+    for layer in balance_layers:
+        smooth(layer)
+    smooth(smooth_layer)

def _calculate_smoothing_scales(
self, balance_layers: List[Module], activation_scales: torch.Tensor
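The heart of the change above is that the smoothing math now lives in a smooth(module) closure. When the smoothing layer sits under an FSDP wrapper, the closure is handed to that wrapper's apply(), which visits every submodule (FSDP's apply takes care of gathering the sharded parameters first); without FSDP the closure is simply called on each layer directly. A rough standalone illustration of the closure-plus-apply pattern, using plain torch.nn modules and made-up scales instead of sparseml's calibrated values:

    # Sketch only: the layers and scales below are invented for illustration.
    import torch
    from torch import nn

    smooth_layer = nn.LayerNorm(4)                       # 1-D weight + bias
    balance_layers = [nn.Linear(4, 8), nn.Linear(4, 8)]  # 2-D weights
    scales = torch.rand(4) + 0.5                         # stand-in smoothing scales

    @torch.no_grad()
    def smooth(module):
        if module in balance_layers:
            module.weight.mul_(scales.view(1, -1))  # fold inverse into balance weights
        elif module is smooth_layer:
            if module.weight.ndim == 1:
                module.weight.div_(scales)
            else:
                module.weight.div_(scales.view(-1, 1))
            if module.bias is not None:
                module.bias.div_(scales)

    parent = nn.Sequential(smooth_layer, *balance_layers)  # stand-in for the FSDP parent
    parent.apply(smooth)  # Module.apply visits every submodule, calling smooth on each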
1 change: 1 addition & 0 deletions src/sparseml/transformers/finetune/session_mixin.py
@@ -430,6 +430,7 @@ def one_shot(self, calib_data: DataLoader, stage: Optional[str] = None):
calib_data=calib_data,
start=-1,
copy_data=False,
accelerator=self.accelerator,
)

self.accelerator.wait_for_everyone()
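The single extra argument is how the accelerator reaches the stage code: the session's apply() forwards its keyword arguments down to each modifier's initialize/finalize, where the barriers shown earlier pick it up via kwargs.get("accelerator"). A stripped-down sketch of that plumbing, with placeholder Session and Stage classes rather than sparseml's:

    # Sketch only: Session and Stage are illustrative stand-ins.
    class Stage:
        def initialize(self, state, **kwargs):
            accelerator = kwargs.get("accelerator", None)
            if accelerator:
                accelerator.wait_for_everyone()

    class Session:
        def __init__(self, stages):
            self.stages = stages

        def apply(self, **kwargs):
            for stage in self.stages:
                stage.initialize(state=None, **kwargs)

    # accelerator=None keeps the example single-process and runnable
    Session([Stage()]).apply(calib_data=None, start=-1, copy_data=False, accelerator=None)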
30 changes: 30 additions & 0 deletions src/sparseml/utils/fsdp/helpers.py
@@ -12,6 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import operator
from typing import Optional


try:
from torch.distributed.fsdp import (
FullStateDictConfig,
@@ -33,6 +37,7 @@
"maybe_get_wrapped",
"unwrap_and_export_model",
"save_pretrained_fsdp",
"get_fsdp_parent",
]


@@ -124,3 +129,28 @@ def save_pretrained_fsdp(model, accelerator, output_dir):
save_function=accelerator.save,
state_dict=state_dict,
)


def get_fsdp_parent(layer_name: str, model: Module) -> Optional[Module]:
"""
Gets the closest parent of layer_name that is wrapped by FSDP. If no FSDP wrapper
is found, returns None.

:param layer_name: layer name in model to get parent of
:param model: pytorch module to search through
:return: FSDP wrapped parent of layer_name if available, otherwise None
"""
if not is_fsdp_model(model):
return None

parent_name = layer_name
parent = operator.attrgetter(parent_name)(model)
while not isinstance(parent, FullyShardedDataParallel):
if len(parent_name) == 0: # we've reached the root module and it's not FSDP
# this should never get hit because we check for an FSDP root above
# but while statements without a backup are too scary
return None
parent_name = ".".join(parent_name.split(".")[:-1])
parent = operator.attrgetter(parent_name)(model)

return parent
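A short usage sketch for the new helper, on an invented toy model: without any FSDP wrapping it returns None, while in an FSDP run it returns the closest FullyShardedDataParallel ancestor of the named layer, which is the module whose apply() should receive the smoothing function:

    # Sketch only: Toy is invented for illustration; get_fsdp_parent comes from
    # the module defined above.
    from torch import nn
    from sparseml.utils.fsdp.helpers import get_fsdp_parent

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.decoder = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))

    print(get_fsdp_parent("decoder.0", Toy()))  # None: no FSDP wrapper anywhere
    # Under FSDP, the returned module is the wrapped ancestor of "decoder.0", and
    # calling .apply(fn) on it lets fn operate on the gathered parameters.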