diff --git a/src/sparseml/core/modifier/stage.py b/src/sparseml/core/modifier/stage.py
index 804d0874531..c5a7b76e229 100644
--- a/src/sparseml/core/modifier/stage.py
+++ b/src/sparseml/core/modifier/stage.py
@@ -137,8 +137,11 @@ def initialize(self, state: "State", **kwargs):
         if self.applied:
             return
 
+        accelerator = kwargs.get("accelerator", None)
         for modifier in self.modifiers:
             modifier.initialize(state, **kwargs)
+            if accelerator:
+                accelerator.wait_for_everyone()
         state.loggers.system.info(tag="stage", string="Modifiers initialized")
 
     def finalize(self, state: "State", **kwargs):
@@ -153,8 +156,11 @@ def finalize(self, state: "State", **kwargs):
         if self.applied:
             return
 
+        accelerator = kwargs.get("accelerator", None)
         for modifier in self.modifiers:
             modifier.finalize(state, **kwargs)
+            if accelerator:
+                accelerator.wait_for_everyone()
         self.applied = True
         state.loggers.system.info(tag="stage", string="Modifiers finalized")
 
diff --git a/src/sparseml/modifiers/smoothquant/pytorch.py b/src/sparseml/modifiers/smoothquant/pytorch.py
index d488a3c3ce2..b3036ff0fdf 100644
--- a/src/sparseml/modifiers/smoothquant/pytorch.py
+++ b/src/sparseml/modifiers/smoothquant/pytorch.py
@@ -22,6 +22,7 @@
 from sparseml.core.model.pytorch import ModifiableModelPyTorch
 from sparseml.modifiers.smoothquant.base import SmoothQuantModifier, SmoothQuantScale
 from sparseml.modifiers.utils.pytorch_helpers import run_calibration_forward
+from sparseml.utils.fsdp.helpers import get_fsdp_parent
 
 
 _LOGGER = logging.getLogger(__name__)
@@ -56,7 +57,7 @@ def on_initialize(self, state: State, **kwargs) -> bool:
 
         self._setup_scale_hooks()
         self._calibrate(state.model, calibration_dataloader)
-        self._apply_smoothing()
+        self._apply_smoothing(state.model)
 
         return True
 
@@ -138,7 +139,7 @@ def _calibrate(self, model: ModifiableModelPyTorch, calibration_dataloader: List
         del self.hooks_
 
     @torch.no_grad()
-    def _apply_smoothing(self):
+    def _apply_smoothing(self, model: ModifiableModelPyTorch):
         """
         After calibration, apply smoothing to the activations and push the transform
         into the following weights by applying the inverse to each balance weight.
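Note on the next hunk: it folds the balance-layer and smooth-layer updates into a single smooth(module) callback so the same code path works with and without FSDP. When an FSDP parent is found, the callback is routed through parent.apply(...), which (per the PyTorch FSDP documentation) gathers the full parameters before running the function on each submodule; otherwise the callback is invoked on the layers directly. Below is a minimal, standalone sketch of that callback-plus-Module.apply pattern. It is illustration only, not code from this PR: plain nn.LayerNorm/nn.Linear modules stand in for the mapped layers, an nn.ModuleList stands in for the FSDP-wrapped parent, and the names are hypothetical.

import torch
import torch.nn as nn

# Stand-ins for mapping.smooth_layer / mapping.balance_layers (hypothetical).
smooth_layer = nn.LayerNorm(4)
balance_layers = [nn.Linear(4, 4), nn.Linear(4, 4)]
scales = torch.tensor([0.5, 1.0, 2.0, 4.0])


@torch.no_grad()
def smooth(module: nn.Module) -> None:
    # One callback handles both roles, mirroring the shape of the PR's smooth().
    if module in balance_layers:
        module.weight.mul_(scales.view(1, -1))  # push the scale into following weights
    elif module is smooth_layer:
        module.weight.div_(scales)  # 1-D LayerNorm weight
        if module.bias is not None:
            module.bias.div_(scales)


# Stand-in for the FSDP parent; Module.apply visits every submodule (and the
# container itself), calling smooth() on each one.
parent = nn.ModuleList([smooth_layer, *balance_layers])
parent.apply(smooth)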
@@ -162,17 +163,26 @@ def _apply_smoothing(self):
                 scales, torch.Tensor([MINIMUM_SMOOTHING_SCALE]).to(scales.device)
             )
 
-            # invert the smoothing in the following layers
-            for layer in balance_layers:
-                layer.weight.mul_(scales.view(1, -1))
-
-            # apply the smoothing
-            if smooth_layer.weight.ndim == 1:
-                smooth_layer.weight.div_(scales)
+            @torch.no_grad()
+            def smooth(module):
+                if module in balance_layers:
+                    module.weight.mul_(scales.view(1, -1))
+                elif module == smooth_layer:
+                    if module.weight.ndim == 1:
+                        module.weight.div_(scales)
+                    else:
+                        module.weight.div_(scales.view(-1, 1))
+                    if hasattr(module, "bias") and module.bias is not None:
+                        module.bias.div_(scales)
+
+            parent = get_fsdp_parent(mapping.smooth_name, model.model)
+            if parent is not None:
+                parent.apply(smooth)
             else:
-                smooth_layer.weight.div_(scales.view(-1, 1))
-                if hasattr(smooth_layer, "bias") and smooth_layer.bias is not None:
-                    smooth_layer.bias.div_(scales)
+                # if we're not running with FSDP, we can apply smoothing directly
+                for layer in balance_layers:
+                    smooth(layer)
+                smooth(smooth_layer)
 
     def _calculate_smoothing_scales(
         self, balance_layers: List[Module], activation_scales: torch.Tensor
diff --git a/src/sparseml/transformers/finetune/session_mixin.py b/src/sparseml/transformers/finetune/session_mixin.py
index 03d320acd84..ffbe95b0341 100644
--- a/src/sparseml/transformers/finetune/session_mixin.py
+++ b/src/sparseml/transformers/finetune/session_mixin.py
@@ -430,6 +430,7 @@ def one_shot(self, calib_data: DataLoader, stage: Optional[str] = None):
             calib_data=calib_data,
             start=-1,
             copy_data=False,
+            accelerator=self.accelerator,
         )
 
         self.accelerator.wait_for_everyone()
diff --git a/src/sparseml/utils/fsdp/helpers.py b/src/sparseml/utils/fsdp/helpers.py
index ad077e9022e..e0ba36683a2 100644
--- a/src/sparseml/utils/fsdp/helpers.py
+++ b/src/sparseml/utils/fsdp/helpers.py
@@ -12,6 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import operator
+from typing import Optional
+
+
 try:
     from torch.distributed.fsdp import (
         FullStateDictConfig,
@@ -33,6 +37,7 @@
     "maybe_get_wrapped",
     "unwrap_and_export_model",
     "save_pretrained_fsdp",
+    "get_fsdp_parent",
 ]
 
 
@@ -124,3 +129,28 @@ def save_pretrained_fsdp(model, accelerator, output_dir):
         save_function=accelerator.save,
         state_dict=state_dict,
     )
+
+
+def get_fsdp_parent(layer_name: str, model: Module) -> Optional[Module]:
+    """
+    Gets the closest parent of layer_name that is wrapped by FSDP. If no FSDP wrapper
+    is found, returns None.
+
+    :param layer_name: layer name in model to get parent of
+    :param model: pytorch module to search through
+    :return: FSDP wrapped parent of layer_name if available, otherwise None
+    """
+    if not is_fsdp_model(model):
+        return None
+
+    parent_name = layer_name
+    parent = operator.attrgetter(parent_name)(model)
+    while not isinstance(parent, FullyShardedDataParallel):
+        if len(parent_name) == 0:  # we've reached the root module and it's not FSDP
+            # this should never get hit because we check for an FSDP root above,
+            # but while loops without a fallback are too scary
+            return None
+        parent_name = ".".join(parent_name.split(".")[:-1])
+        parent = operator.attrgetter(parent_name)(model) if parent_name else model
+
+    return parent
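For reference, here is a self-contained sketch of the dotted-name walk that get_fsdp_parent performs: resolve layer_name with operator.attrgetter, then strip one trailing component at a time until a module of the wrapper class is reached. Standing up real FSDP needs a distributed process group, so this hypothetical get_wrapped_parent (not part of the PR) substitutes an ordinary nn.Sequential for FullyShardedDataParallel; the traversal logic is otherwise the same.

import operator
from typing import Optional, Type

import torch.nn as nn


def get_wrapped_parent(
    layer_name: str, model: nn.Module, wrapper_cls: Type[nn.Module]
) -> Optional[nn.Module]:
    """Walk up the dotted path of layer_name until a wrapper_cls module is hit."""
    parent_name = layer_name
    parent = operator.attrgetter(parent_name)(model)
    while not isinstance(parent, wrapper_cls):
        if len(parent_name) == 0:  # reached the root without finding the wrapper
            return None
        parent_name = ".".join(parent_name.split(".")[:-1])
        parent = operator.attrgetter(parent_name)(model) if parent_name else model
    return parent


model = nn.Sequential(nn.Sequential(nn.Linear(4, 4), nn.ReLU()), nn.Linear(4, 2))
# "0.0" names the innermost Linear; its closest Sequential ancestor is model[0].
print(get_wrapped_parent("0.0", model, nn.Sequential) is model[0])  # True

The empty-name guard on the last resolution step makes the root module itself the final candidate before giving up, which mirrors the root-level FSDP check at the top of get_fsdp_parent.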