Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def __init__(
self.trainable_parameters = [p for _, p in trainable_parameters(self._module)]
self.max_grad_norm = max_grad_norm
self.use_ghost_clipping = use_ghost_clipping
self._per_sample_gradient_norms = None

def get_clipping_coef(self) -> torch.Tensor:
"""Get per-example gradient scaling factor for clipping."""
def get_norm_sample(self) -> torch.Tensor:
    """Compute the overall per-sample gradient norm across trainable parameters.

    Stacks the cached per-parameter, per-sample norms (``param._norm_sample``,
    populated by the backprop hooks) and reduces them with an L2 norm over the
    parameter axis, yielding a single gradient norm per example in the batch.
    The result is also cached via the ``per_sample_gradient_norms`` property
    setter for later (non-privatized, debugging-only) inspection.

    Returns:
        1-D tensor of shape ``(batch_size,)`` with one gradient norm per sample.
    """
    # Shape after stack: (num_params, batch_size); norm over dim=0 collapses
    # the parameter axis, leaving one L2 norm per sample.
    per_param_norms = torch.stack(
        [param._norm_sample for param in self.trainable_parameters], dim=0
    )
    norm_sample = per_param_norms.norm(2, dim=0)
    # Property assignment — caches into self._per_sample_gradient_norms.
    self.per_sample_gradient_norms = norm_sample
    return norm_sample

def capture_activations_hook(
Expand Down Expand Up @@ -231,3 +233,17 @@ def capture_backprops_hook(
if len(module.activations) == 0:
if hasattr(module, "max_batch_len"):
del module.max_batch_len

@property
def per_sample_gradient_norms(self) -> torch.Tensor:
    """Per-sample gradient norms from the most recent backward pass.

    Note: these values are not privatized and should only be used for
    debugging purposes or in non-private settings.

    Raises:
        AttributeError: if no forward/backward pass has populated the norms yet.
    """
    norms = self._per_sample_gradient_norms
    # Guard clause: surface a clear error instead of silently returning None.
    if norms is None:
        raise AttributeError(
            "per_sample_gradient_norms is not set. Please call forward and backward on the model before accessing this property."
        )
    return norms

@per_sample_gradient_norms.setter
def per_sample_gradient_norms(self, norms):
    """Cache the (non-privatized) per-sample gradient norms on the instance."""
    self._per_sample_gradient_norms = norms
Loading