diff --git a/opacus/optimizers/optimizer.py b/opacus/optimizers/optimizer.py
index bbd554a6..21be2147 100644
--- a/opacus/optimizers/optimizer.py
+++ b/opacus/optimizers/optimizer.py
@@ -15,6 +15,7 @@ from __future__ import annotations
 
 import logging
+from collections import defaultdict
 from typing import Callable, List, Optional, Union
 
 import torch
@@ -240,10 +241,6 @@ def __init__(
         self.step_hook = None
         self.generator = generator
         self.secure_mode = secure_mode
-
-        self.param_groups = self.original_optimizer.param_groups
-        self.defaults = self.original_optimizer.defaults
-        self.state = self.original_optimizer.state
 
         self._step_skip_queue = []
         self._is_last_step_skipped = False
@@ -376,6 +373,48 @@ def accumulated_iterations(self) -> int:
         )
         return vals[0]
 
+    @property
+    def param_groups(self) -> List[dict]:
+        """
+        Returns a list of dicts, one per parameter group, covering all parameters managed by the optimizer.
+        """
+        return self.original_optimizer.param_groups
+
+    @param_groups.setter
+    def param_groups(self, param_groups: List[dict]):
+        """
+        Updates the param_groups of the wrapped optimizer.
+        """
+        self.original_optimizer.param_groups = param_groups
+
+    @property
+    def state(self) -> defaultdict:
+        """
+        Returns a dictionary holding the current optimization state.
+        """
+        return self.original_optimizer.state
+
+    @state.setter
+    def state(self, state: defaultdict):
+        """
+        Updates the state of the wrapped optimizer.
+        """
+        self.original_optimizer.state = state
+
+    @property
+    def defaults(self) -> dict:
+        """
+        Returns a dictionary containing the optimizer's default hyperparameter values.
+        """
+        return self.original_optimizer.defaults
+
+    @defaults.setter
+    def defaults(self, defaults: dict):
+        """
+        Updates the defaults of the wrapped optimizer.
+        """
+        self.original_optimizer.defaults = defaults
+
     def attach_step_hook(self, fn: Callable[[DPOptimizer], None]):
         """
         Attaches a hook to be executed after gradient clipping/noising, but before the
@@ -386,7 +425,6 @@ def attach_step_hook(self, fn: Callable[[DPOptimizer], None]):
         Args:
            fn: hook function. Expected signature: ``foo(optim: DPOptimizer)``
         """
-
         self.step_hook = fn
 
     def clip_and_accumulate(self):
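
Context for the change above: torch.optim.Optimizer.load_state_dict() rebuilds param_groups and state from scratch, so the eager copies previously taken in __init__ could silently go stale, and assignments made through the wrapper never reached the wrapped optimizer. The forwarding properties make both views resolve to the same live objects. Below is a minimal sketch of the restored invariant, not part of the patch; it assumes the usual DPOptimizer keyword arguments, and in real training the optimizer would come from PrivacyEngine.make_private() rather than being constructed directly.

    import torch
    from torch import nn
    from opacus.optimizers import DPOptimizer

    # Wrap a plain SGD optimizer; the model is a stand-in.
    model = nn.Linear(4, 2)
    base = torch.optim.SGD(model.parameters(), lr=0.1)
    dp_optimizer = DPOptimizer(
        optimizer=base,
        noise_multiplier=1.0,
        max_grad_norm=1.0,
        expected_batch_size=8,
    )

    # load_state_dict() replaces base.param_groups and base.state with
    # freshly built objects. Snapshot attributes taken at __init__ time
    # would keep pointing at the stale ones; the forwarding properties
    # always return whatever the wrapped optimizer currently holds.
    base.load_state_dict(base.state_dict())
    assert dp_optimizer.param_groups is base.param_groups
    assert dp_optimizer.state is base.state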