diff --git a/.circleci/config.yml b/.circleci/config.yml index a5981fb1..6bdb8f04 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -323,7 +323,6 @@ jobs: - checkout - pip_dev_install - unit_tests - - coveralls_upload_parallel unittest_py38_torch_release: docker: @@ -332,7 +331,6 @@ jobs: - checkout - pip_dev_install - unit_tests - - coveralls_upload_parallel unittest_py39_torch_release: docker: @@ -341,7 +339,6 @@ jobs: - checkout - pip_dev_install - unit_tests - - coveralls_upload_parallel unittest_py39_torch_nightly: docker: @@ -351,7 +348,19 @@ jobs: - pip_dev_install: args: "-n" - unit_tests - - coveralls_upload_parallel + + prv_accountant_values: + docker: + - image: cimg/python:3.7.5 + steps: + - checkout + - py_3_7_setup + - pip_dev_install + - run: + name: "Unit test prv accountant" + no_output_timeout: 1h + command: | + python -m unittest opacus.tests.prv_accountant integrationtest_py37_torch_release_cpu: docker: @@ -477,7 +486,6 @@ jobs: - pip_dev_install - run_nvidia_smi - command_unit_tests_multi_gpu - - coveralls_upload_parallel auto_deploy_site: @@ -537,14 +545,8 @@ workflows: filters: *exclude_ghpages - integrationtest_py37_torch_release_cuda: filters: *exclude_ghpages - - finish_coveralls_parallel: + - prv_accountant_values: filters: *exclude_ghpages - requires: - - unittest_py37_torch_release - - unittest_py38_torch_release - - unittest_py39_torch_release - - unittest_py39_torch_nightly - - unittest_multi_gpu nightly: when: @@ -560,10 +562,6 @@ workflows: filters: *exclude_ghpages - micro_benchmarks_py37_torch_release_cuda: filters: *exclude_ghpages - - finish_coveralls_parallel: - filters: *exclude_ghpages - requires: - - unittest_py39_torch_nightly website_deployment: when: diff --git a/examples/char-lstm-classification.py b/examples/char-lstm-classification.py index f71e6cf1..d2945a19 100644 --- a/examples/char-lstm-classification.py +++ b/examples/char-lstm-classification.py @@ -332,10 +332,8 @@ def train( f"\t Epoch {epoch}. 
Accuracy: {mean(accs):.6f} | Loss: {mean(losses):.6f}" ) try: - epsilon, best_alpha = privacy_engine.accountant.get_privacy_spent( - delta=target_delta - ) - printstr += f" | (ε = {epsilon:.2f}, δ = {target_delta}) for α = {best_alpha}" + epsilon = privacy_engine.accountant.get_epsilon(delta=target_delta) + printstr += f" | (ε = {epsilon:.2f}, δ = {target_delta})" except AttributeError: pass print(printstr) @@ -359,10 +357,8 @@ def test(model, test_loader, privacy_engine, target_delta, device="cuda:0"): mean_acc = mean(accs) printstr = "\n----------------------------\n" f"Test Accuracy: {mean_acc:.6f}" if privacy_engine: - epsilon, best_alpha = privacy_engine.accountant.get_privacy_spent( - delta=target_delta - ) - printstr += f" (ε = {epsilon:.2f}, δ = {target_delta}) for α = {best_alpha}" + epsilon = privacy_engine.accountant.get_epsilon(delta=target_delta) + printstr += f" (ε = {epsilon:.2f}, δ = {target_delta})" print(printstr + "\n----------------------------\n") return mean_acc diff --git a/examples/cifar10.py b/examples/cifar10.py index 6a85eb6d..9581b601 100644 --- a/examples/cifar10.py +++ b/examples/cifar10.py @@ -195,15 +195,12 @@ def compute_loss_stateless_model(params, sample, target): if i % args.print_freq == 0: if not args.disable_dp: - epsilon, best_alpha = privacy_engine.accountant.get_privacy_spent( - delta=args.delta, - alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)), - ) + epsilon = privacy_engine.accountant.get_epsilon(delta=args.delta) print( f"\tTrain Epoch: {epoch} \t" f"Loss: {np.mean(losses):.6f} " f"Acc@1: {np.mean(top1_acc):.6f} " - f"(ε = {epsilon:.2f}, δ = {args.delta}) for α = {best_alpha}" + f"(ε = {epsilon:.2f}, δ = {args.delta})" ) else: print( diff --git a/examples/dcgan.py b/examples/dcgan.py index 7f0ab26e..c1065f12 100644 --- a/examples/dcgan.py +++ b/examples/dcgan.py @@ -333,14 +333,12 @@ def forward(self, input): optimizerG.step() if not opt.disable_dp: - epsilon, best_alpha = privacy_engine.accountant.get_privacy_spent( - delta=opt.delta - ) + epsilon = privacy_engine.accountant.get_epsilon(delta=opt.delta) data_bar.set_description( f"epoch: {epoch}, Loss_D: {errD.item()} " f"Loss_G: {errG.item()} D(x): {D_x} " f"D(G(z)): {D_G_z1}/{D_G_z2}" - "(ε = %.2f, δ = %.2f) for α = %.2f" % (epsilon, opt.delta, best_alpha) + "(ε = %.2f, δ = %.2f)" % (epsilon, opt.delta) ) else: data_bar.set_description( diff --git a/examples/imdb.py b/examples/imdb.py index 0582571e..844f41fb 100644 --- a/examples/imdb.py +++ b/examples/imdb.py @@ -95,14 +95,12 @@ def train(args, model, train_loader, optimizer, privacy_engine, epoch): accuracies.append(acc.item()) if not args.disable_dp: - epsilon, best_alpha = privacy_engine.accountant.get_privacy_spent( - delta=args.delta - ) + epsilon = privacy_engine.accountant.get_epsilon(delta=args.delta) print( f"Train Epoch: {epoch} \t" f"Train Loss: {np.mean(losses):.6f} " f"Train Accuracy: {np.mean(accuracies):.6f} " - f"(ε = {epsilon:.2f}, δ = {args.delta}) for α = {best_alpha}" + f"(ε = {epsilon:.2f}, δ = {args.delta})" ) else: print( diff --git a/examples/mnist.py b/examples/mnist.py index a97c34c7..71e285c5 100644 --- a/examples/mnist.py +++ b/examples/mnist.py @@ -72,13 +72,11 @@ def train(args, model, device, train_loader, optimizer, privacy_engine, epoch): losses.append(loss.item()) if not args.disable_dp: - epsilon, best_alpha = privacy_engine.accountant.get_privacy_spent( - delta=args.delta - ) + epsilon = privacy_engine.accountant.get_epsilon(delta=args.delta) print( f"Train Epoch: {epoch} \t" f"Loss: 
{np.mean(losses):.6f} "
-                f"(ε = {epsilon:.2f}, δ = {args.delta}) for α = {best_alpha}"
+                f"(ε = {epsilon:.2f}, δ = {args.delta})"
             )
         else:
             print(f"Train Epoch: {epoch} \t Loss: {np.mean(losses):.6f}")
diff --git a/opacus/accountants/__init__.py b/opacus/accountants/__init__.py
index 9e9725e6..4cc0dae0 100644
--- a/opacus/accountants/__init__.py
+++ b/opacus/accountants/__init__.py
@@ -14,6 +14,7 @@
 
 from .accountant import IAccountant
 from .gdp import GaussianAccountant
+from .prv import PRVAccountant
 from .rdp import RDPAccountant
 
 
@@ -29,5 +30,7 @@ def create_accountant(mechanism: str) -> IAccountant:
         return RDPAccountant()
     elif mechanism == "gdp":
         return GaussianAccountant()
+    elif mechanism == "prv":
+        return PRVAccountant()
 
     raise ValueError(f"Unexpected accounting mechanism: {mechanism}")
diff --git a/opacus/accountants/analysis/prv/__init__.py b/opacus/accountants/analysis/prv/__init__.py
new file mode 100644
index 00000000..d49d1674
--- /dev/null
+++ b/opacus/accountants/analysis/prv/__init__.py
@@ -0,0 +1,19 @@
+from .compose import compose_heterogeneous
+from .domain import Domain, compute_safe_domain_size
+from .prvs import (
+    DiscretePRV,
+    PoissonSubsampledGaussianPRV,
+    TruncatedPrivacyRandomVariable,
+    discretize,
+)
+
+
+__all__ = [
+    "DiscretePRV",
+    "Domain",
+    "PoissonSubsampledGaussianPRV",
+    "TruncatedPrivacyRandomVariable",
+    "compose_heterogeneous",
+    "compute_safe_domain_size",
+    "discretize",
+]
diff --git a/opacus/accountants/analysis/prv/compose.py b/opacus/accountants/analysis/prv/compose.py
new file mode 100644
index 00000000..04aecf5b
--- /dev/null
+++ b/opacus/accountants/analysis/prv/compose.py
@@ -0,0 +1,65 @@
+from typing import List
+
+import numpy as np
+from scipy.fft import irfft, rfft
+from scipy.signal import convolve
+
+from .prvs import DiscretePRV
+
+
+def _compose_fourier(dprv: DiscretePRV, num_self_composition: int) -> DiscretePRV:
+    if len(dprv) % 2 != 0:
+        raise ValueError("Can only compose evenly sized discrete PRVs")
+
+    composed_pmf = irfft(rfft(dprv.pmf) ** num_self_composition)
+
+    m = num_self_composition - 1
+    if num_self_composition % 2 == 0:
+        m += len(composed_pmf) // 2
+    composed_pmf = np.roll(composed_pmf, m)
+
+    domain = dprv.domain.shift_right(dprv.domain.shifts * (num_self_composition - 1))
+
+    return DiscretePRV(pmf=composed_pmf, domain=domain)
+
+
+def _compose_two(dprv_left: DiscretePRV, dprv_right: DiscretePRV) -> DiscretePRV:
+    pmf = convolve(dprv_left.pmf, dprv_right.pmf, mode="same")
+    domain = dprv_left.domain.shift_right(dprv_right.domain.shifts)
+    return DiscretePRV(pmf=pmf, domain=domain)
+
+
+def _compose_convolution_tree(dprvs: List[DiscretePRV]) -> DiscretePRV:
+    # repeatedly convolve neighbouring PRVs until we only have one left
+    while len(dprvs) > 1:
+        dprvs_conv = []
+        if len(dprvs) % 2 == 1:
+            dprvs_conv.append(dprvs.pop())
+
+        for dprv_left, dprv_right in zip(dprvs[:-1:2], dprvs[1::2]):
+            dprvs_conv.append(_compose_two(dprv_left, dprv_right))
+
+        dprvs = dprvs_conv
+    return dprvs[0]
+
+
+def compose_heterogeneous(
+    dprvs: List[DiscretePRV], num_self_compositions: List[int]
+) -> DiscretePRV:
+    r"""
+    Compose a heterogeneous list of PRVs with multiplicity. We use FFT to compose
+    identical PRVs with themselves first, then pairwise convolve the remaining PRVs.
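+    Composing a PRV with itself ``n`` times is a pointwise ``n``-th power in
+    Fourier space (see ``_compose_fourier``), which avoids carrying out ``n - 1``
+    explicit convolutions.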
+ + This is the approach taken in https://github.com/microsoft/prv_accountant + """ + if len(dprvs) != len(num_self_compositions): + raise ValueError("dprvs and num_self_compositions must have the same length") + + dprvs = [ + _compose_fourier(dprv, num_self_composition) + for dprv, num_self_composition in zip(dprvs, num_self_compositions) + ] + return _compose_convolution_tree(dprvs) diff --git a/opacus/accountants/analysis/prv/domain.py b/opacus/accountants/analysis/prv/domain.py new file mode 100644 index 00000000..5e86502d --- /dev/null +++ b/opacus/accountants/analysis/prv/domain.py @@ -0,0 +1,99 @@ +from dataclasses import dataclass +from typing import Sequence + +import numpy as np + +from ...rdp import RDPAccountant + + +@dataclass +class Domain: + r""" + Stores relevant information about the domain on which PRVs are discretized, and + includes a few convenience methods for manipulating it. + """ + t_min: float + t_max: float + size: int + shifts: float = 0.0 + + def __post_init__(self): + if not isinstance(self.size, int): + raise TypeError("`size` must be an integer") + if self.size % 2 != 0: + raise ValueError("`size` must be even") + + @classmethod + def create_aligned(cls, t_min: float, t_max: float, dt: float) -> "Domain": + t_min = np.floor(t_min / dt) * dt + t_max = np.ceil(t_max / dt) * dt + + size = int(np.round((t_max - t_min) / dt)) + 1 + + if size % 2 == 1: + size += 1 + t_max += dt + + domain = cls(t_min, t_max, size) + + if np.abs(domain.dt - dt) / dt >= 1e-8: + raise RuntimeError + + return domain + + def shift_right(self, dt: float) -> "Domain": + return Domain( + t_min=self.t_min + dt, + t_max=self.t_max + dt, + size=self.size, + shifts=self.shifts + dt, + ) + + @property + def dt(self): + return (self.t_max - self.t_min) / (self.size - 1) + + @property + def ts(self): + return np.linspace(self.t_min, self.t_max, self.size) + + def __getitem__(self, i: int) -> float: + return self.t_min + i * self.dt + + +def compute_safe_domain_size( + prvs, + max_self_compositions: Sequence[int], + eps_error: float, + delta_error: float, +) -> float: + """ + Compute safe domain size for the discretization of the PRVs. + + For details about this algorithm, see remark 5.6 in + https://www.microsoft.com/en-us/research/publication/numerical-composition-of-differential-privacy/ + """ + total_compositions = sum(max_self_compositions) + + rdp_accountant = RDPAccountant() + for prv, max_self_composition in zip(prvs, max_self_compositions): + rdp_accountant.history.append( + (prv.noise_multiplier, prv.sample_rate, max_self_composition) + ) + + L_max = rdp_accountant.get_epsilon(delta_error / 4) + + for prv, max_self_composition in zip(prvs, max_self_compositions): + rdp_accountant = RDPAccountant() + rdp_accountant.history = [(prv.noise_multiplier, prv.sample_rate, 1)] + L_max = max( + L_max, + rdp_accountant.get_epsilon(delta=delta_error / (8 * total_compositions)), + ) + + # FIXME: this implementation is adapted from the code accompanying the paper, but + # disagrees subtly with the theory from remark 5.6. It's not immediately clear this + # gives the right guarantees in all cases, though it's fine for eps_error < 1 and + # hence generic cases. + # cf. 
https://github.com/microsoft/prv_accountant/discussions/34 + return max(L_max, eps_error) + 3 diff --git a/opacus/accountants/analysis/prv/prvs.py b/opacus/accountants/analysis/prv/prvs.py new file mode 100644 index 00000000..c6a93244 --- /dev/null +++ b/opacus/accountants/analysis/prv/prvs.py @@ -0,0 +1,177 @@ +from dataclasses import dataclass +from typing import Tuple + +import numpy as np +from scipy import integrate +from scipy.special import erfc + +from ..rdp import _compute_rdp +from .domain import Domain + + +SQRT2 = np.sqrt(2) + + +class PoissonSubsampledGaussianPRV: + r""" + A Poisson subsampled Gaussian privacy random variable. + + For details about the formulas for the pdf and cdf, see propositions B1 and B4 in + https://www.microsoft.com/en-us/research/publication/numerical-composition-of-differential-privacy/ + """ + + def __init__(self, sample_rate: float, noise_multiplier: float) -> None: + self.sample_rate = sample_rate + self.noise_multiplier = noise_multiplier + + def pdf(self, t): + q = self.sample_rate + sigma = self.noise_multiplier + + z = np.log((np.exp(t) + q - 1) / q) + + return np.where( + t > np.log(1 - q), + sigma + * np.exp(-(sigma**2) * z**2 / 2 - 1 / (8 * sigma**2) + 2 * t) + / ( + SQRT2 + * np.sqrt(np.pi) + * (np.exp(t) + q - 1) + * ((np.exp(t) + q - 1) / q) ** 0.5 + ), + 0.0, + ) + + def cdf(self, t): + q = self.sample_rate + sigma = self.noise_multiplier + + z = np.log((np.exp(t) + q - 1) / q) + + return np.where( + t > np.log(1 - q), + -q * erfc((2 * z * sigma**2 - 1) / (2 * SQRT2 * sigma)) / 2 + - (1 - q) * erfc((2 * z * sigma**2 + 1) / (2 * SQRT2 * sigma)) / 2 + + 1.0, + 0.0, + ) + + def rdp(self, alpha: float) -> float: + return _compute_rdp(self.sample_rate, self.noise_multiplier, alpha) + + +# though we have only implemented the PoissonSubsampledGaussianPRV, this truncated prv +# class is generic, and would work with PRVs corresponding to different mechanisms +class TruncatedPrivacyRandomVariable: + def __init__( + self, prv: PoissonSubsampledGaussianPRV, t_min: float, t_max: float + ) -> None: + self._prv = prv + self.t_min = t_min + self.t_max = t_max + self._remaining_mass = self._prv.cdf(t_max) - self._prv.cdf(t_min) + + def pdf(self, t): + return np.where( + t < self.t_min, + 0.0, + np.where(t < self.t_max, self._prv.pdf(t) / self._remaining_mass, 0.0), + ) + + def cdf(self, t): + return np.where( + t < self.t_min, + 0.0, + np.where( + t < self.t_max, + (self._prv.cdf(t) - self._prv.cdf(self.t_min)) / self._remaining_mass, + 1.0, + ), + ) + + def mean(self) -> float: + """ + Calculate the mean using numerical integration. + """ + points = np.concatenate( + [ + [self.t_min], + -np.logspace(-5, -1, 5)[::-1], + np.logspace(-5, -1, 5), + [self.t_max], + ] + ) + + mean = 0.0 + for left, right in zip(points[:-1], points[1:]): + integral, _ = integrate.quad(self.cdf, left, right, limit=500) + mean += right * self.cdf(right) - left * self.cdf(left) - integral + + return mean + + +@dataclass +class DiscretePRV: + pmf: np.ndarray + domain: Domain + + def __len__(self) -> int: + if len(self.pmf) != self.domain.size: + raise ValueError("pmf and domain must have the same length") + return len(self.pmf) + + def compute_epsilon( + self, delta: float, delta_error: float, eps_error: float + ) -> Tuple[float, float, float]: + if delta <= 0: + return (float("inf"),) * 3 + + if np.finfo(np.longdouble).eps * self.domain.size > delta - delta_error: + raise ValueError( + "Floating point errors will dominate for such small values of delta. 
" + "Increase delta or reduce domain size." + ) + + t = self.domain.ts + p = self.pmf + d1 = np.flip(np.flip(p).cumsum()) + d2 = np.flip(np.flip(p * np.exp(-t)).cumsum()) + ndelta = np.exp(t) * d2 - d1 + + def find_epsilon(delta_target): + i = np.searchsorted(ndelta, -delta_target, side="left") + if i <= 0: + raise RuntimeError("Cannot compute epsilon") + return np.log((d1[i] - delta_target) / d2[i]) + + eps_upper = find_epsilon(delta - delta_error) + eps_error + eps_lower = find_epsilon(delta + delta_error) - eps_error + eps_estimate = find_epsilon(delta) + return eps_lower, eps_estimate, eps_upper + + def compute_delta_estimate(self, eps: float) -> float: + return np.where( + self.domain.ts >= eps, + self.pmf * (1.0 - np.exp(eps) * np.exp(-self.domain.ts)), + 0.0, + ).sum() + + +def discretize(prv, domain: Domain) -> DiscretePRV: + tC = domain.ts + tL = tC - domain.dt / 2 + tR = tC + domain.dt / 2 + discrete_pmf = prv.cdf(tR) - prv.cdf(tL) + + mean_d = np.dot(domain.ts, discrete_pmf) + mean_c = prv.mean() + + mean_shift = mean_c - mean_d + + if np.abs(mean_shift) >= domain.dt / 2: + raise RuntimeError("Discrete mean differs significantly from continuous mean.") + + domain_shifted = domain.shift_right(mean_shift) + + return DiscretePRV(pmf=discrete_pmf, domain=domain_shifted) diff --git a/opacus/accountants/prv.py b/opacus/accountants/prv.py new file mode 100644 index 00000000..ee0e0d6d --- /dev/null +++ b/opacus/accountants/prv.py @@ -0,0 +1,157 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +import numpy as np + +from .accountant import IAccountant +from .analysis.prv import ( + Domain, + PoissonSubsampledGaussianPRV, + TruncatedPrivacyRandomVariable, + compose_heterogeneous, + compute_safe_domain_size, + discretize, +) + + +class PRVAccountant(IAccountant): + r""" + Tracks privacy expenditure via numerical composition of Privacy loss Random + Variables (PRVs) using the approach suggested by Gopi et al[1]. The implementation + here is heavily inspired by the implementation of the authors that accompanied + their paper[2]. + + By utilising the Fast Fourier transform, this accountant is able to efficiently + calculate tight bounds on the privacy expenditure, and has been shown + experimentally to obtain tighter bounds than the RDP accountant. + + The idea behind this accountant is approximately as follows: + + A differentially private mechanism can be characterised by a PRV. The composition + of multiple differentially privacy mechanisms can be charaterised by the sum of the + corresponding PRVs. To get the density of the sum of PRVs, we convolve the + individual densities. + + This accountant discretizes the PRVs corresponding to each step of the + optimization, and convolves the approximations using the Fast Fourier Transform. + The mesh size and bounds for the discretization are chosen automatically to ensure + suitable approximation quality. 
+ + The resulting convolved density is used to recover epsilon. For more detail, see + the paper[1]. + + References: + [1] https://arxiv.org/abs/2106.02848 + [2] https://github.com/microsoft/prv_accountant + """ + + def __init__(self): + super().__init__() + + def step(self, *, noise_multiplier: float, sample_rate: float): + if len(self.history) >= 1: + (last_noise_multiplier, last_sample_rate, num_steps) = self.history.pop() + if ( + last_noise_multiplier == noise_multiplier + and last_sample_rate == sample_rate + ): + self.history.append( + (last_noise_multiplier, last_sample_rate, num_steps + 1) + ) + else: + self.history.append( + (last_noise_multiplier, last_sample_rate, num_steps) + ) + self.history.append((noise_multiplier, sample_rate, 1)) + + else: + self.history.append((noise_multiplier, sample_rate, 1)) + + def get_epsilon( + self, delta: float, *, eps_error: float = 0.01, delta_error: float = None + ) -> float: + """ + Return privacy budget (epsilon) expended so far. + + Args: + delta: target delta + eps_error: acceptable level of error in the epsilon estimate + delta_error: acceptable level of error in delta + """ + if delta_error is None: + delta_error = delta / 1000 + # we construct a discrete PRV from the history + dprv = self._get_dprv(eps_error=eps_error, delta_error=delta_error) + # this discrete PRV can be used to directly estimate and bound epsilon + _, _, eps_upper = dprv.compute_epsilon(delta, delta_error, eps_error) + # return upper bound as we want guarantee, not just estimate + return eps_upper + + def _get_dprv(self, eps_error, delta_error): + # convert history to privacy loss random variables (prvs). Opacus currently + # operates under the assumption that only a Poisson-subsampled Gaussian + # mechanism is ever used during optimisation + prvs = [ + PoissonSubsampledGaussianPRV(sample_rate, noise_multiplier) + for noise_multiplier, sample_rate, _ in self.history + ] + # compute a safe domain for discretization per Gopi et al. This determines both + # the mesh size and the truncation upper and lower bounds. 
num_self_compositions = [steps for _, _, steps in self.history]
+        domain = self._get_domain(
+            prvs=prvs,
+            num_self_compositions=num_self_compositions,
+            eps_error=eps_error,
+            delta_error=delta_error,
+        )
+        tprvs = [
+            TruncatedPrivacyRandomVariable(prv, domain.t_min, domain.t_max)
+            for prv in prvs
+        ]
+        # discretize and convolve prvs
+        dprvs = [discretize(tprv, domain) for tprv in tprvs]
+        return compose_heterogeneous(
+            dprvs=dprvs, num_self_compositions=num_self_compositions
+        )
+
+    def _get_domain(
+        self,
+        prvs: List[PoissonSubsampledGaussianPRV],
+        num_self_compositions: List[int],
+        eps_error: float,
+        delta_error: float,
+    ) -> Domain:
+        total_self_compositions = sum(num_self_compositions)
+
+        L = compute_safe_domain_size(
+            prvs=prvs,
+            max_self_compositions=num_self_compositions,
+            eps_error=eps_error,
+            delta_error=delta_error,
+        )
+
+        mesh_size = eps_error / np.sqrt(
+            total_self_compositions * np.log(12 / delta_error) / 2
+        )
+
+        return Domain.create_aligned(-L, L, mesh_size)
+
+    @classmethod
+    def mechanism(cls) -> str:
+        return "prv"
+
+    def __len__(self):
+        return len(self.history)
diff --git a/opacus/privacy_engine.py b/opacus/privacy_engine.py
index cbd1ea07..61450e44 100644
--- a/opacus/privacy_engine.py
+++ b/opacus/privacy_engine.py
@@ -107,13 +107,14 @@ class PrivacyEngine:
         >>> # continue training as normal
     """
 
-    def __init__(self, *, accountant: str = "rdp", secure_mode: bool = False):
+    def __init__(self, *, accountant: str = "prv", secure_mode: bool = False):
         """
         Args:
             accountant: Accounting mechanism. Currently supported:
                 - rdp (:class:`~opacus.accountants.RDPAccountant`)
                 - gdp (:class:`~opacus.accountants.GaussianAccountant`)
+                - prv (:class:`~opacus.accountants.PRVAccountant`)
             secure_mode: Set to ``True`` if cryptographically strong DP guarantee is
                 required. 
``secure_mode=True`` uses secure random number generator for noise and shuffling (as opposed to pseudo-rng in vanilla PyTorch) and diff --git a/opacus/tests/accountants_test.py b/opacus/tests/accountants_test.py index 3d3d0dd3..2ef2e3a5 100644 --- a/opacus/tests/accountants_test.py +++ b/opacus/tests/accountants_test.py @@ -17,7 +17,12 @@ import hypothesis.strategies as st from hypothesis import given, settings -from opacus.accountants import GaussianAccountant, RDPAccountant, create_accountant +from opacus.accountants import ( + GaussianAccountant, + PRVAccountant, + RDPAccountant, + create_accountant, +) from opacus.accountants.utils import get_noise_multiplier @@ -47,6 +52,19 @@ def test_gdp_accountant(self): self.assertLess(6.59, epsilon) self.assertLess(epsilon, 6.6) + def test_prv_accountant(self): + noise_multiplier = 1.5 + sample_rate = 0.04 + steps = int(90 // 0.04) + + accountant = PRVAccountant() + + for _ in range(steps): + accountant.step(noise_multiplier=noise_multiplier, sample_rate=sample_rate) + + epsilon = accountant.get_epsilon(delta=1e-5) + self.assertAlmostEqual(epsilon, 6.777395712150674) + def test_get_noise_multiplier_rdp_epochs(self): delta = 1e-5 sample_rate = 0.04 @@ -78,6 +96,38 @@ def test_get_noise_multiplier_rdp_steps(self): self.assertAlmostEqual(noise_multiplier, 1.3562, places=4) + def test_get_noise_multiplier_prv_epochs(self): + delta = 1e-5 + sample_rate = 0.04 + epsilon = 8 + epochs = 90 + + noise_multiplier = get_noise_multiplier( + target_epsilon=epsilon, + target_delta=delta, + sample_rate=sample_rate, + epochs=epochs, + accountant="prv", + ) + + self.assertAlmostEqual(noise_multiplier, 1.34765625, places=4) + + def test_get_noise_multiplier_prv_steps(self): + delta = 1e-5 + sample_rate = 0.04 + epsilon = 8 + steps = 2000 + + noise_multiplier = get_noise_multiplier( + target_epsilon=epsilon, + target_delta=delta, + sample_rate=sample_rate, + steps=steps, + accountant="prv", + ) + + self.assertAlmostEqual(noise_multiplier, 1.2915, places=4) + @given( epsilon=st.floats(1.0, 10.0), epochs=st.integers(10, 100), diff --git a/opacus/tests/dp_layers/common.py b/opacus/tests/dp_layers/common.py index d9826969..536278ba 100644 --- a/opacus/tests/dp_layers/common.py +++ b/opacus/tests/dp_layers/common.py @@ -22,7 +22,7 @@ import torch.nn as nn import torch.nn.functional as F from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence -from torch.testing import assert_allclose +from torch.testing import assert_close def clone_module(module: nn.Module) -> nn.Module: @@ -327,7 +327,7 @@ def _check_values( f"L1 Loss = {F.l1_loss(dp_out, nn_out)}", ) try: - assert_allclose( + assert_close( actual=dp_out, expected=nn_out, atol=atol, @@ -378,12 +378,10 @@ def _check_packed_sequence( ) try: - assert_allclose( + assert_close( actual=padded_seq_dp, expected=padded_seq_nn, atol=atol, rtol=rtol ) - assert_allclose( - actual=seq_lens_dp, expected=seq_lens_nn, atol=atol, rtol=rtol - ) + assert_close(actual=seq_lens_dp, expected=seq_lens_nn, atol=atol, rtol=rtol) except AssertionError: if failure_msgs is not None: failure_msgs.append(msg) diff --git a/opacus/tests/dp_layers/dp_rnn_test.py b/opacus/tests/dp_layers/dp_rnn_test.py index a70a4501..b310e5b4 100644 --- a/opacus/tests/dp_layers/dp_rnn_test.py +++ b/opacus/tests/dp_layers/dp_rnn_test.py @@ -105,6 +105,8 @@ def test_rnn( dp_rnn.load_state_dict(rnn.state_dict()) + # Packed sequences not happy with deterministic + torch.use_deterministic_algorithms(False) if packed_input_flag == 0: # no packed sequence input x 
= ( @@ -124,6 +126,7 @@ def test_rnn( ) else: raise ValueError("Invalid packed input flag") + torch.use_deterministic_algorithms(True) if zero_init: self.compare_forward_outputs( diff --git a/opacus/tests/grad_sample_module_test.py b/opacus/tests/grad_sample_module_test.py index ff19c82c..2a70a48c 100644 --- a/opacus/tests/grad_sample_module_test.py +++ b/opacus/tests/grad_sample_module_test.py @@ -25,7 +25,7 @@ ) from opacus.grad_sample.linear import compute_linear_grad_sample from opacus.grad_sample.utils import register_grad_sampler -from torch.testing import assert_allclose +from torch.testing import assert_close from torch.utils.data import DataLoader from torchvision import transforms from torchvision.datasets import FakeData @@ -106,7 +106,7 @@ def test_outputs_unaltered(self): f"MSE = {F.mse_loss(gs_out, normal_out)}, ", f"L1 Loss = {F.l1_loss(gs_out, normal_out)}", ) - assert_allclose(gs_out, normal_out, atol=1e-7, rtol=1e-5, msg=msg) + assert_close(gs_out, normal_out, atol=1e-7, rtol=1e-5, msg=msg) def test_zero_grad(self): x, _ = next(iter(self.dl)) @@ -161,7 +161,7 @@ def test_to_standard_module(self): f"L1 Loss = {F.l1_loss(gs_tensor, original_tensor)}", ) - assert_allclose(gs_tensor, original_tensor, atol=1e-6, rtol=1e-4, msg=msg) + assert_close(gs_tensor, original_tensor, atol=1e-6, rtol=1e-4, msg=msg) def test_remove_hooks(self): """ @@ -246,7 +246,7 @@ def test_state_dict(self): # check wrapped module state dict for key in og_state_dict.keys(): self.assertTrue(f"_module.{key}" in gs_state_dict) - assert_allclose(og_state_dict[key], gs_state_dict[f"_module.{key}"]) + assert_close(og_state_dict[key], gs_state_dict[f"_module.{key}"]) def test_load_state_dict(self): gs_state_dict = self.grad_sample_module.state_dict() @@ -257,7 +257,7 @@ def test_load_state_dict(self): # wrapped module is the same for key in self.original_model.state_dict().keys(): self.assertTrue(key in new_gs._module.state_dict()) - assert_allclose( + assert_close( self.original_model.state_dict()[key], new_gs._module.state_dict()[key] ) diff --git a/opacus/tests/grad_samples/common.py b/opacus/tests/grad_samples/common.py index f7fa1eac..4d2ff3a8 100644 --- a/opacus/tests/grad_samples/common.py +++ b/opacus/tests/grad_samples/common.py @@ -25,7 +25,7 @@ from opacus.utils.module_utils import trainable_parameters from opacus.utils.packed_sequences import compute_seq_lengths from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence -from torch.testing import assert_allclose +from torch.testing import assert_close def expander(x, factor: int = 2): @@ -344,7 +344,7 @@ def check_values( f"L1 Loss = {F.l1_loss(opacus_grad_sample, microbatch_grad_sample)}", ) try: - assert_allclose( + assert_close( actual=microbatch_grad_sample, expected=opacus_grad_sample, atol=atol, diff --git a/opacus/tests/grad_samples/conv2d_test.py b/opacus/tests/grad_samples/conv2d_test.py index f27ad158..55268e0e 100644 --- a/opacus/tests/grad_samples/conv2d_test.py +++ b/opacus/tests/grad_samples/conv2d_test.py @@ -22,7 +22,7 @@ from opacus.grad_sample.conv import convolution2d_backward_as_a_convolution from opacus.grad_sample.grad_sample_module import GradSampleModule from opacus.utils.tensor_utils import unfold2d -from torch.testing import assert_allclose +from torch.testing import assert_close from .common import GradSampleHooks_test, expander, shrinker @@ -150,7 +150,7 @@ def test_unfold2d( dilation=(dilation_w, dilation_h), ) - assert_allclose(X_unfold_torch, X_unfold_opacus, atol=0, rtol=0) + assert_close(X_unfold_torch, 
X_unfold_opacus, atol=0, rtol=0) def test_asymetric_dilation_and_kernel_size(self): """ diff --git a/opacus/tests/grad_samples/dp_rnn_test.py b/opacus/tests/grad_samples/dp_rnn_test.py index 39f29ad6..b6f64bb0 100644 --- a/opacus/tests/grad_samples/dp_rnn_test.py +++ b/opacus/tests/grad_samples/dp_rnn_test.py @@ -75,6 +75,7 @@ def test_rnn( using_packed_sequences: bool, packed_sequences_sorted: bool, ): + torch.use_deterministic_algorithms(False) rnn = model( D, H, diff --git a/opacus/tests/prv_accountant.py b/opacus/tests/prv_accountant.py new file mode 100644 index 00000000..b2efd290 --- /dev/null +++ b/opacus/tests/prv_accountant.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from opacus.accountants import PRVAccountant + + +# Computed from https://github.com/microsoft/prv_accountant +msr_values = { + (0.8, 0.001, 100): 0.270403364058018680360362395731, + (1.5, 0.001, 100): 0.0393694193095203287535710501288, + (2.0, 0.001, 100): 0.0291045369484074882560076247273, + (3.0, 0.001, 100): 0.0215200934390882663016508757892, + (0.8, 0.05, 100): 6.74105193941409996938318727189, + (1.5, 0.05, 100): 1.95385313568420015961635272106, + (2.0, 0.05, 100): 1.28363737962406521120328761754, + (3.0, 0.05, 100): 0.767933789666825572517439013609, + (0.8, 0.2, 100): 24.3336641779225644199868838768, + (1.5, 0.2, 100): 8.49897785912163072907787864096, + (2.0, 0.2, 100): 5.67205560882372950004537415225, + (3.0, 0.2, 100): 3.38122051022184422208738396876, + (0.8, 0.001, 1000): 0.477771266184522591657923840103, + (1.5, 0.001, 1000): 0.0998206602308474855167474970585, + (2.0, 0.001, 1000): 0.0712187825381152272985474382949, + (3.0, 0.001, 1000): 0.0476159308713589302097801692071, + (0.8, 0.05, 1000): 19.5258428053837356230815203162, + (1.5, 0.05, 1000): 6.20996924657726712126759593957, + (2.0, 0.05, 1000): 4.17279986668931890392286732094, + (3.0, 0.05, 1000): 2.52757432605973830774814814504, + (0.8, 0.001, 20000): 1.29600925757016161021795142005, + (1.5, 0.001, 20000): 0.437005108654860807693154356457, + (2.0, 0.001, 20000): 0.305837939762453436820521801565, + (3.0, 0.001, 20000): 0.194031030686054706269061398416, + # (0.8, 0.05, 20000): 140.655209760074228597659384832, + (1.5, 0.05, 20000): 38.3066179140872336006395926233, + (2.0, 0.05, 20000): 24.4217185703225858617315680021, + (3.0, 0.05, 20000): 13.9611146992367061159256991232, + (0.8, 0.001, 50000): 2.03451480771586634688219419331, + (1.5, 0.001, 50000): 0.704218384555810095193351116905, + (2.0, 0.001, 50000): 0.491354910211401318953505779064, + (3.0, 0.001, 50000): 0.309694161156544023327796821832, + # (0.8, 0.05, 50000): 291.384325452396524269715882838, + (1.5, 0.05, 50000): 73.1901767238329625797632616013, + (2.0, 0.05, 50000): 45.2369958613306906158868514467, + (3.0, 0.05, 50000): 24.9420056431444763234139827546, + # (0.8, 0.001, 100000): 2.92462211935041027643933375657, + # (1.5, 0.001, 100000): 1.01528807502760787251361307426, + # (2.0, 
0.001, 100000): 0.706719901081851564761393547087,
+    # (3.0, 0.001, 100000): 0.443671720880875475323534828931,
+}
+
+
+class PRVAccountantTest(unittest.TestCase):
+    def test_values(self):
+        for (sigma, q, steps), expected_epsilon in msr_values.items():
+            accountant = PRVAccountant()
+            accountant.history = [(sigma, q, steps)]
+            epsilon = accountant.get_epsilon(delta=1e-6)
+            self.assertAlmostEqual(epsilon, expected_epsilon, places=4)
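
A minimal sketch of how the new accountant is exercised, mirroring `test_prv_accountant` above (the parameter values here are illustrative, not asserted anywhere in the diff):

```python
from opacus.accountants import PRVAccountant, RDPAccountant

noise_multiplier = 1.5
sample_rate = 0.04
steps = 2250  # roughly 90 epochs at a sample rate of 0.04

prv_accountant = PRVAccountant()
rdp_accountant = RDPAccountant()
for _ in range(steps):
    prv_accountant.step(noise_multiplier=noise_multiplier, sample_rate=sample_rate)
    rdp_accountant.step(noise_multiplier=noise_multiplier, sample_rate=sample_rate)

# get_epsilon returns the upper bound eps_upper, accurate to within eps_error
# (default 0.01), so the result is a guarantee rather than a point estimate;
# the PRV bound is typically tighter than the RDP bound at the same delta
print("prv:", prv_accountant.get_epsilon(delta=1e-5))
print("rdp:", rdp_accountant.get_epsilon(delta=1e-5))
```

Equivalently, since this diff makes `"prv"` the default accountant, constructing `PrivacyEngine()` (or explicitly `PrivacyEngine(accountant="prv")`) and calling `privacy_engine.accountant.get_epsilon(delta=...)`, as the updated examples do, exercises the same code path.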