From 568ed31721b129f2cb283764649b3b6812243d5b Mon Sep 17 00:00:00 2001 From: Evgeny Grigorenko Date: Fri, 15 Aug 2025 20:31:06 +0200 Subject: [PATCH 1/5] Add Accountant registry --- opacus/accountants/__init__.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/opacus/accountants/__init__.py b/opacus/accountants/__init__.py index 4cc0dae00..52ba1a570 100644 --- a/opacus/accountants/__init__.py +++ b/opacus/accountants/__init__.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Type from .accountant import IAccountant from .gdp import GaussianAccountant @@ -24,13 +25,13 @@ "RDPAccountant", ] +_ACCOUNTANTS = {"rdp": RDPAccountant, "gdp": GaussianAccountant, "prv": PRVAccountant} + +def register_accountant(mechanism: str, accountant: Type[IAccountant]): + _ACCOUNTANTS[mechanism] = accountant def create_accountant(mechanism: str) -> IAccountant: - if mechanism == "rdp": - return RDPAccountant() - elif mechanism == "gdp": - return GaussianAccountant() - elif mechanism == "prv": - return PRVAccountant() + if mechanism in _ACCOUNTANTS: + return _ACCOUNTANTS[mechanism]() raise ValueError(f"Unexpected accounting mechanism: {mechanism}") From 22eed0b9016e1dfbf5770317a4cd71ae72622e7a Mon Sep 17 00:00:00 2001 From: Evgeny Grigorenko Date: Thu, 4 Sep 2025 15:01:20 +0200 Subject: [PATCH 2/5] Refactor Accountant registry into a separate module and add corresponding tests --- opacus/accountants/__init__.py | 17 ++------ opacus/accountants/registry.py | 66 ++++++++++++++++++++++++++++++++ opacus/tests/accountants_test.py | 27 +++++++++++++ 3 files changed, 97 insertions(+), 13 deletions(-) create mode 100644 opacus/accountants/registry.py diff --git a/opacus/accountants/__init__.py b/opacus/accountants/__init__.py index 52ba1a570..27c30c2f6 100644 --- a/opacus/accountants/__init__.py +++ b/opacus/accountants/__init__.py @@ -11,27 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Type from .accountant import IAccountant from .gdp import GaussianAccountant from .prv import PRVAccountant from .rdp import RDPAccountant - +from .registry import register_accountant, create_accountant __all__ = [ "IAccountant", "GaussianAccountant", "RDPAccountant", + "PRVAccountant", + "register_accountant", + "create_accountant", ] - -_ACCOUNTANTS = {"rdp": RDPAccountant, "gdp": GaussianAccountant, "prv": PRVAccountant} - -def register_accountant(mechanism: str, accountant: Type[IAccountant]): - _ACCOUNTANTS[mechanism] = accountant - -def create_accountant(mechanism: str) -> IAccountant: - if mechanism in _ACCOUNTANTS: - return _ACCOUNTANTS[mechanism]() - - raise ValueError(f"Unexpected accounting mechanism: {mechanism}") diff --git a/opacus/accountants/registry.py b/opacus/accountants/registry.py new file mode 100644 index 000000000..bf302e4ee --- /dev/null +++ b/opacus/accountants/registry.py @@ -0,0 +1,66 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Type + +from .accountant import IAccountant +from .gdp import GaussianAccountant +from .prv import PRVAccountant +from .rdp import RDPAccountant + + +_ACCOUNTANTS: Dict[str, Type[IAccountant]] = { + "rdp": RDPAccountant, + "gdp": GaussianAccountant, + "prv": PRVAccountant, +} + + +def register_accountant(mechanism: str, accountant: Type[IAccountant]): + r""" + Register a new accountant class to be used with a specified mechanism name. + + Args: + mechanism: Name of the mechanism to register the accountant for + accountant: Accountant class (subclass of IAccountant) to register + + Example: + >>> register_accountant("my_accountant", MyAccountant) + """ + _ACCOUNTANTS[mechanism] = accountant + + +def create_accountant(mechanism: str) -> IAccountant: + r""" + Creates and returns an accountant instance for the specified privacy mechanism. + + Args: + mechanism: Name of the privacy accounting mechanism to use. + + Returns: + An instance of the appropriate accountant class (subclass of IAccountant) + for the specified mechanism. + + Raises: + ValueError: If the specified mechanism is not registered. + + Example: + >>> accountant = create_accountant("rdp") + >>> accountant.step(noise_multiplier=1.0, sample_rate=0.01) + >>> epsilon = accountant.get_epsilon(delta=1e-5) + """ + if mechanism in _ACCOUNTANTS: + return _ACCOUNTANTS[mechanism]() + + raise ValueError(f"Unexpected accounting mechanism: {mechanism}") diff --git a/opacus/tests/accountants_test.py b/opacus/tests/accountants_test.py index 5871f190c..e2e9c95d8 100644 --- a/opacus/tests/accountants_test.py +++ b/opacus/tests/accountants_test.py @@ -22,11 +22,38 @@ PRVAccountant, RDPAccountant, create_accountant, + IAccountant, + register_accountant, ) from opacus.accountants.utils import get_noise_multiplier class AccountingTest(unittest.TestCase): + def test_register_accountant(self) -> None: + class DummyAccountant(IAccountant): + def __init__(self): + pass + + def __len__(self): + return 0 + + def step(self, **kwargs): + pass + + def get_epsilon(self, **kwargs): + return 0.0 + + def mechanism(cls) -> str: + return "dummy" + + register_accountant("dummy", DummyAccountant) + self.assertIsInstance(create_accountant("dummy"), DummyAccountant) + self.assertEqual(create_accountant("dummy").mechanism(), "dummy") + + def test_get_accountant_not_registered(self) -> None: + with self.assertRaises(ValueError): + create_accountant("not_registered") + def test_rdp_accountant(self) -> None: noise_multiplier = 1.5 sample_rate = 0.04 From 4481a3c3a0f55d229fd76baad0c49dc864cb11ac Mon Sep 17 00:00:00 2001 From: Evgeny Grigorenko Date: Wed, 10 Sep 2025 17:13:19 +0200 Subject: [PATCH 3/5] Fail on existing accountant + isort fix --- opacus/accountants/__init__.py | 3 ++- opacus/accountants/registry.py | 15 ++++++------ opacus/tests/accountants_test.py | 40 ++++++++++++++++++++------------ 3 files changed, 34 insertions(+), 24 deletions(-) diff --git a/opacus/accountants/__init__.py b/opacus/accountants/__init__.py index 27c30c2f6..dc87e42d8 100644 --- a/opacus/accountants/__init__.py +++ b/opacus/accountants/__init__.py @@ -16,7 +16,8 @@ from .gdp import GaussianAccountant from .prv import PRVAccountant from .rdp import RDPAccountant -from .registry import register_accountant, create_accountant +from .registry import create_accountant, register_accountant + __all__ = [ "IAccountant", diff --git a/opacus/accountants/registry.py b/opacus/accountants/registry.py index bf302e4ee..c4df143b3 100644 --- a/opacus/accountants/registry.py +++ b/opacus/accountants/registry.py @@ -27,17 +27,21 @@ } -def register_accountant(mechanism: str, accountant: Type[IAccountant]): +def register_accountant(mechanism: str, accountant: Type[IAccountant], force: bool = False): r""" Register a new accountant class to be used with a specified mechanism name. Args: mechanism: Name of the mechanism to register the accountant for accountant: Accountant class (subclass of IAccountant) to register + force: If True, overwrites existing accountant for the specified mechanism. - Example: - >>> register_accountant("my_accountant", MyAccountant) + Raises: + ValueError: If the mechanism is already registered. """ + if mechanism in _ACCOUNTANTS and not force: + raise ValueError(f"Accountant for mechanism {mechanism} is already registered") + _ACCOUNTANTS[mechanism] = accountant @@ -54,11 +58,6 @@ def create_accountant(mechanism: str) -> IAccountant: Raises: ValueError: If the specified mechanism is not registered. - - Example: - >>> accountant = create_accountant("rdp") - >>> accountant.step(noise_multiplier=1.0, sample_rate=0.01) - >>> epsilon = accountant.get_epsilon(delta=1e-5) """ if mechanism in _ACCOUNTANTS: return _ACCOUNTANTS[mechanism]() diff --git a/opacus/tests/accountants_test.py b/opacus/tests/accountants_test.py index e2e9c95d8..1ebf0907c 100644 --- a/opacus/tests/accountants_test.py +++ b/opacus/tests/accountants_test.py @@ -19,41 +19,51 @@ from hypothesis import given, settings from opacus.accountants import ( GaussianAccountant, + IAccountant, PRVAccountant, RDPAccountant, create_accountant, - IAccountant, register_accountant, ) from opacus.accountants.utils import get_noise_multiplier -class AccountingTest(unittest.TestCase): - def test_register_accountant(self) -> None: - class DummyAccountant(IAccountant): - def __init__(self): - pass +class DummyAccountant(IAccountant): + def __init__(self): + pass + + def __len__(self): + return 0 - def __len__(self): - return 0 + def step(self, **kwargs): + pass - def step(self, **kwargs): - pass + def get_epsilon(self, **kwargs): + return 0.0 - def get_epsilon(self, **kwargs): - return 0.0 + def mechanism(cls) -> str: + return "dummy" - def mechanism(cls) -> str: - return "dummy" +class AccountingTest(unittest.TestCase): + def test_register_accountant(self) -> None: register_accountant("dummy", DummyAccountant) self.assertIsInstance(create_accountant("dummy"), DummyAccountant) self.assertEqual(create_accountant("dummy").mechanism(), "dummy") - def test_get_accountant_not_registered(self) -> None: + def test_create_accountant_not_registered(self) -> None: with self.assertRaises(ValueError): create_accountant("not_registered") + def test_register_existing_accountant(self): + with self.assertRaises(ValueError): + register_accountant("rdp", DummyAccountant) + + def test_force_register_existing_accountant(self) -> None: + register_accountant("rdp", DummyAccountant, force=True) + self.assertIsInstance(create_accountant("rdp"), DummyAccountant) + self.assertEqual(create_accountant("rdp").mechanism(), "dummy") + def test_rdp_accountant(self) -> None: noise_multiplier = 1.5 sample_rate = 0.04 From 9cc9caae814f8714f8d06332a6eb697c0dfa6555 Mon Sep 17 00:00:00 2001 From: Evgeny Grigorenko Date: Thu, 11 Sep 2025 10:12:19 +0200 Subject: [PATCH 4/5] Fix Black formatting --- opacus/accountants/registry.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/opacus/accountants/registry.py b/opacus/accountants/registry.py index c4df143b3..c0c234059 100644 --- a/opacus/accountants/registry.py +++ b/opacus/accountants/registry.py @@ -27,7 +27,9 @@ } -def register_accountant(mechanism: str, accountant: Type[IAccountant], force: bool = False): +def register_accountant( + mechanism: str, accountant: Type[IAccountant], force: bool = False +): r""" Register a new accountant class to be used with a specified mechanism name. From 3d64ca8c4c53a1cb3e407162e14df8c3078560fb Mon Sep 17 00:00:00 2001 From: Evgeny Grigorenko Date: Tue, 16 Sep 2025 18:04:07 +0200 Subject: [PATCH 5/5] Refactor `accountants_test.py` to group accountant tests into `AccountantRegistryTest`, add cleanup for registered accountants. --- opacus/tests/accountants_test.py | 67 ++++++++++++++++++++++---------- 1 file changed, 47 insertions(+), 20 deletions(-) diff --git a/opacus/tests/accountants_test.py b/opacus/tests/accountants_test.py index 1ebf0907c..165232948 100644 --- a/opacus/tests/accountants_test.py +++ b/opacus/tests/accountants_test.py @@ -24,46 +24,73 @@ RDPAccountant, create_accountant, register_accountant, + registry, ) from opacus.accountants.utils import get_noise_multiplier -class DummyAccountant(IAccountant): - def __init__(self): - pass +class AccountantRegistryTest(unittest.TestCase): - def __len__(self): - return 0 + class DummyAccountant(IAccountant): + def __init__(self): + pass - def step(self, **kwargs): - pass + def __len__(self): + return 0 - def get_epsilon(self, **kwargs): - return 0.0 + def step(self, **kwargs): + pass - def mechanism(cls) -> str: - return "dummy" + def get_epsilon(self, **kwargs): + return 0.0 + def mechanism(cls) -> str: + return "dummy" + + class Dummy2Accountant(DummyAccountant): + pass -class AccountingTest(unittest.TestCase): def test_register_accountant(self) -> None: - register_accountant("dummy", DummyAccountant) - self.assertIsInstance(create_accountant("dummy"), DummyAccountant) - self.assertEqual(create_accountant("dummy").mechanism(), "dummy") + try: + register_accountant("dummy", AccountantRegistryTest.DummyAccountant) + self.assertIsInstance( + create_accountant("dummy"), AccountantRegistryTest.DummyAccountant + ) + self.assertEqual(create_accountant("dummy").mechanism(), "dummy") + finally: + if "dummy" in registry._ACCOUNTANTS: + del registry._ACCOUNTANTS["dummy"] def test_create_accountant_not_registered(self) -> None: with self.assertRaises(ValueError): create_accountant("not_registered") def test_register_existing_accountant(self): - with self.assertRaises(ValueError): - register_accountant("rdp", DummyAccountant) + try: + register_accountant("dummy", AccountantRegistryTest.DummyAccountant) + + with self.assertRaises(ValueError): + register_accountant("rdp", AccountantRegistryTest.DummyAccountant) + finally: + if "dummy" in registry._ACCOUNTANTS: + del registry._ACCOUNTANTS["dummy"] def test_force_register_existing_accountant(self) -> None: - register_accountant("rdp", DummyAccountant, force=True) - self.assertIsInstance(create_accountant("rdp"), DummyAccountant) - self.assertEqual(create_accountant("rdp").mechanism(), "dummy") + try: + register_accountant("dummy", AccountantRegistryTest.DummyAccountant) + + register_accountant( + "dummy", AccountantRegistryTest.Dummy2Accountant, force=True + ) + self.assertIsInstance( + create_accountant("dummy"), AccountantRegistryTest.Dummy2Accountant + ) + finally: + if "dummy" in registry._ACCOUNTANTS: + del registry._ACCOUNTANTS["dummy"] + +class AccountingTest(unittest.TestCase): def test_rdp_accountant(self) -> None: noise_multiplier = 1.5 sample_rate = 0.04