diff --git a/opacus/privacy_engine.py b/opacus/privacy_engine.py
index 1af891c4..558c8f8e 100644
--- a/opacus/privacy_engine.py
+++ b/opacus/privacy_engine.py
@@ -212,6 +212,26 @@ def _prepare_model(
                 loss_reduction=loss_reduction,
             )
 
+    def _prepare_criterion(
+        self,
+        *,
+        module: GradSampleModule,
+        optimizer: DPOptimizer,
+        criterion=nn.CrossEntropyLoss(),
+        loss_reduction: str = "mean",
+        **kwargs,
+    ) -> DPLossFastGradientClipping:
+        """
+        Prepare the DP loss class, which packages the two backward passes for fast gradient clipping.
+
+        Args:
+            module: GradSampleModule used for training
+            optimizer: DPOptimizer used for training
+            criterion: Loss function used for training
+            loss_reduction: "mean" or "sum", indicates how the loss is aggregated over the batch
+        """
+        return DPLossFastGradientClipping(module, optimizer, criterion, loss_reduction)
+
     def is_compatible(
         self,
         *,
@@ -403,9 +423,14 @@ def make_private(
             self.accountant.get_optimizer_hook_fn(sample_rate=sample_rate)
         )
         if grad_sample_mode == "ghost":
-            criterion = DPLossFastGradientClipping(
-                module, optimizer, criterion, loss_reduction
+            criterion = self._prepare_criterion(
+                module=module,
+                optimizer=optimizer,
+                criterion=criterion,
+                loss_reduction=loss_reduction,
+                **kwargs,
             )
+            return module, optimizer, criterion, data_loader
 
         return module, optimizer, data_loader
 
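For context, a minimal usage sketch of the path this diff touches. The model, optimizer, data, and hyperparameter values below are illustrative, not part of this change; the point is that with `grad_sample_mode="ghost"`, `make_private` now builds the criterion via `_prepare_criterion` and returns the `DPLossFastGradientClipping` wrapper, which the caller uses like an ordinary loss:

```python
import torch
import torch.nn as nn
from opacus import PrivacyEngine

# Illustrative model and data; any nn.Module and DataLoader would do.
model = nn.Linear(16, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
dataset = torch.utils.data.TensorDataset(
    torch.randn(64, 16), torch.randint(0, 2, (64,))
)
train_loader = torch.utils.data.DataLoader(dataset, batch_size=8)

privacy_engine = PrivacyEngine()
# With grad_sample_mode="ghost", make_private also returns the wrapped
# criterion (a DPLossFastGradientClipping built by _prepare_criterion).
model, optimizer, criterion, train_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=train_loader,
    criterion=nn.CrossEntropyLoss(),
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    grad_sample_mode="ghost",
)

for x, y in train_loader:
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    # The wrapper's backward() packages the two backward passes of fast
    # gradient clipping: per-sample norm computation, then the clipped pass.
    loss.backward()
    optimizer.step()
```

Note the added `return module, optimizer, criterion, data_loader` inside the ghost branch: callers using ghost clipping get a four-tuple, while the default path still returns the three-tuple on the last line.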