48 changes: 43 additions & 5 deletions opacus/optimizers/optimizer.py
@@ -15,6 +15,7 @@
 from __future__ import annotations
 
 import logging
+from collections import defaultdict
 from typing import Callable, List, Optional, Union
 
 import torch
@@ -240,10 +241,6 @@ def __init__(
         self.step_hook = None
         self.generator = generator
         self.secure_mode = secure_mode
-
-        self.param_groups = self.original_optimizer.param_groups
-        self.defaults = self.original_optimizer.defaults
-        self.state = self.original_optimizer.state
         self._step_skip_queue = []
         self._is_last_step_skipped = False
 
@@ -376,6 +373,48 @@ def accumulated_iterations(self) -> int:
             )
         return vals[0]
 
+    @property
+    def param_groups(self) -> List[dict]:
+        """
+        Returns a list containing dictionaries of all parameter groups managed by the optimizer.
+        """
+        return self.original_optimizer.param_groups
+
+    @param_groups.setter
+    def param_groups(self, param_groups: List[dict]):
+        """
+        Updates the param_groups of the optimizer.
+        """
+        self.original_optimizer.param_groups = param_groups
+
+    @property
+    def state(self) -> defaultdict:
+        """
+        Returns a dictionary holding the current optimization state.
+        """
+        return self.original_optimizer.state
+
+    @state.setter
+    def state(self, state: defaultdict):
+        """
+        Updates the state of the optimizer.
+        """
+        self.original_optimizer.state = state
+
+    @property
+    def defaults(self) -> dict:
+        """
+        Returns a dictionary containing default values for optimization.
+        """
+        return self.original_optimizer.defaults
+
+    @defaults.setter
+    def defaults(self, defaults: dict):
+        """
+        Updates the defaults of the optimizer.
+        """
+        self.original_optimizer.defaults = defaults
+
     def attach_step_hook(self, fn: Callable[[DPOptimizer], None]):
         """
         Attaches a hook to be executed after gradient clipping/noising, but before the
@@ -386,7 +425,6 @@ def attach_step_hook(self, fn: Callable[[DPOptimizer], None]):
         Args:
             fn: hook function. Expected signature: ``foo(optim: DPOptimizer)``
         """
-
         self.step_hook = fn
 
     def clip_and_accumulate(self):
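
Taken together, the diff replaces the one-time attribute copies in __init__ with properties that delegate every read and write of param_groups, state, and defaults to original_optimizer. With the old aliasing, anything that rebinds those attributes on the wrapped optimizer (for example, torch.optim.Optimizer.load_state_dict replaces self.state and self.param_groups wholesale) would leave the wrapper pointing at stale objects. Below is a minimal sketch of the delegation pattern; SimpleWrapper is a hypothetical stand-in for illustration, not the actual DPOptimizer, which carries much more machinery.

from collections import defaultdict
from typing import List

import torch


class SimpleWrapper:
    """Hypothetical wrapper mirroring a torch.optim.Optimizer's bookkeeping."""

    def __init__(self, original_optimizer: torch.optim.Optimizer):
        self.original_optimizer = original_optimizer

    @property
    def param_groups(self) -> List[dict]:
        # Every read goes to the wrapped optimizer, so the view can never
        # go stale, even if the inner optimizer rebinds the attribute.
        return self.original_optimizer.param_groups

    @param_groups.setter
    def param_groups(self, param_groups: List[dict]):
        # Writes pass straight through instead of rebinding a local alias.
        self.original_optimizer.param_groups = param_groups

    @property
    def state(self) -> defaultdict:
        return self.original_optimizer.state

    @state.setter
    def state(self, state: defaultdict):
        self.original_optimizer.state = state


params = [torch.nn.Parameter(torch.zeros(1))]
inner = torch.optim.SGD(params, lr=0.1)
wrapper = SimpleWrapper(inner)

# Mutations through the wrapper are visible on the inner optimizer ...
wrapper.param_groups[0]["lr"] = 0.01
assert inner.param_groups[0]["lr"] == 0.01

# ... and the wrapper keeps tracking the inner optimizer even after
# load_state_dict rebinds inner.state / inner.param_groups wholesale,
# which a one-time attribute copy made at construction would not survive.
inner.load_state_dict(inner.state_dict())
assert wrapper.param_groups[0]["lr"] == 0.01
assert wrapper.state is inner.state

Delegating defaults works the same way (omitted here for brevity): the wrapper then exposes no bookkeeping state of its own, so the two objects cannot drift apart.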