diff --git a/opacus/accountants/__init__.py b/opacus/accountants/__init__.py index 4cc0dae0..dc87e42d 100644 --- a/opacus/accountants/__init__.py +++ b/opacus/accountants/__init__.py @@ -16,21 +16,14 @@ from .gdp import GaussianAccountant from .prv import PRVAccountant from .rdp import RDPAccountant +from .registry import create_accountant, register_accountant __all__ = [ "IAccountant", "GaussianAccountant", "RDPAccountant", + "PRVAccountant", + "register_accountant", + "create_accountant", ] - - -def create_accountant(mechanism: str) -> IAccountant: - if mechanism == "rdp": - return RDPAccountant() - elif mechanism == "gdp": - return GaussianAccountant() - elif mechanism == "prv": - return PRVAccountant() - - raise ValueError(f"Unexpected accounting mechanism: {mechanism}") diff --git a/opacus/accountants/registry.py b/opacus/accountants/registry.py new file mode 100644 index 00000000..c0c23405 --- /dev/null +++ b/opacus/accountants/registry.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Type + +from .accountant import IAccountant +from .gdp import GaussianAccountant +from .prv import PRVAccountant +from .rdp import RDPAccountant + + +_ACCOUNTANTS: Dict[str, Type[IAccountant]] = { + "rdp": RDPAccountant, + "gdp": GaussianAccountant, + "prv": PRVAccountant, +} + + +def register_accountant( + mechanism: str, accountant: Type[IAccountant], force: bool = False +): + r""" + Register a new accountant class to be used with a specified mechanism name. + + Args: + mechanism: Name of the mechanism to register the accountant for + accountant: Accountant class (subclass of IAccountant) to register + force: If True, overwrites existing accountant for the specified mechanism. + + Raises: + ValueError: If the mechanism is already registered. + """ + if mechanism in _ACCOUNTANTS and not force: + raise ValueError(f"Accountant for mechanism {mechanism} is already registered") + + _ACCOUNTANTS[mechanism] = accountant + + +def create_accountant(mechanism: str) -> IAccountant: + r""" + Creates and returns an accountant instance for the specified privacy mechanism. + + Args: + mechanism: Name of the privacy accounting mechanism to use. + + Returns: + An instance of the appropriate accountant class (subclass of IAccountant) + for the specified mechanism. + + Raises: + ValueError: If the specified mechanism is not registered. + """ + if mechanism in _ACCOUNTANTS: + return _ACCOUNTANTS[mechanism]() + + raise ValueError(f"Unexpected accounting mechanism: {mechanism}") diff --git a/opacus/tests/accountants_test.py b/opacus/tests/accountants_test.py index 5871f190..16523294 100644 --- a/opacus/tests/accountants_test.py +++ b/opacus/tests/accountants_test.py @@ -19,13 +19,77 @@ from hypothesis import given, settings from opacus.accountants import ( GaussianAccountant, + IAccountant, PRVAccountant, RDPAccountant, create_accountant, + register_accountant, + registry, ) from opacus.accountants.utils import get_noise_multiplier +class AccountantRegistryTest(unittest.TestCase): + + class DummyAccountant(IAccountant): + def __init__(self): + pass + + def __len__(self): + return 0 + + def step(self, **kwargs): + pass + + def get_epsilon(self, **kwargs): + return 0.0 + + def mechanism(cls) -> str: + return "dummy" + + class Dummy2Accountant(DummyAccountant): + pass + + def test_register_accountant(self) -> None: + try: + register_accountant("dummy", AccountantRegistryTest.DummyAccountant) + self.assertIsInstance( + create_accountant("dummy"), AccountantRegistryTest.DummyAccountant + ) + self.assertEqual(create_accountant("dummy").mechanism(), "dummy") + finally: + if "dummy" in registry._ACCOUNTANTS: + del registry._ACCOUNTANTS["dummy"] + + def test_create_accountant_not_registered(self) -> None: + with self.assertRaises(ValueError): + create_accountant("not_registered") + + def test_register_existing_accountant(self): + try: + register_accountant("dummy", AccountantRegistryTest.DummyAccountant) + + with self.assertRaises(ValueError): + register_accountant("rdp", AccountantRegistryTest.DummyAccountant) + finally: + if "dummy" in registry._ACCOUNTANTS: + del registry._ACCOUNTANTS["dummy"] + + def test_force_register_existing_accountant(self) -> None: + try: + register_accountant("dummy", AccountantRegistryTest.DummyAccountant) + + register_accountant( + "dummy", AccountantRegistryTest.Dummy2Accountant, force=True + ) + self.assertIsInstance( + create_accountant("dummy"), AccountantRegistryTest.Dummy2Accountant + ) + finally: + if "dummy" in registry._ACCOUNTANTS: + del registry._ACCOUNTANTS["dummy"] + + class AccountingTest(unittest.TestCase): def test_rdp_accountant(self) -> None: noise_multiplier = 1.5