In [1]:
#@title License
# Copyright 2025 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [None]:
# @title Imports
import numpy as np
from scipy.spatial.distance import cdist
from scipy.spatial.distance import pdist
from sklearn.metrics.pairwise import rbf_kernel
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

In [None]:
# Base seed for the notebook. NOTE(review): the trailing `# @param` comment
# looks like a Colab form annotation, so keep it on the assignment line.
seed=7 # @param {isTemplate: true}

# Module-level NumPy random generator initialised from `seed`. The experiment
# code further down constructs its own per-run generators.
_RNG = np.random.default_rng(seed=seed)

# MMD Utils

In [None]:
def ensure_list(input_data):
  """Returns `input_data` unchanged if it is a list, else wraps it in one.

  Args:
    input_data: Any object.

  Returns:
    `input_data` itself when it is a `list` instance; otherwise a new
    single-element list containing `input_data`.
  """
  return input_data if isinstance(input_data, list) else [input_data]


def rbf_kernel(samples_x, samples_y=None, bw=1.0):
  """Returns squared exponential (or RBF) kernel matrix.

  Adapted from Ben Chugg: https://github.com/bchugg/testing-by-betting.

  Note: this definition shadows `sklearn.metrics.pairwise.rbf_kernel`, which
  the imports cell brings into scope under the same name.

  Args:
    samples_x: Input data. Either a single sample or a list of samples; the
      entries are stacked row-wise with `np.vstack`.
    samples_y: Input data, in the same format as `samples_x`. If None,
      `samples_x` is used, which yields the kernel matrix of `samples_x` with
      itself.
    bw: Kernel bandwidth.

  Returns:
    Array with one row per stacked x sample and one column per stacked y
    sample, whose (i, j) entry is exp(-||x_i - y_j||^2 / (2 * bw^2)).
  """
  if samples_y is None:
    samples_y = samples_x

  # A bare sample is wrapped in a list so that `np.vstack` always produces a
  # 2-D array with one sample per row.
  rows_x = np.vstack(ensure_list(samples_x))
  rows_y = np.vstack(ensure_list(samples_y))

  # Matrix euclidean pairwise distances.
  distances = cdist(rows_x, rows_y, 'euclidean')
  sq_distances = distances * distances
  return np.exp(-sq_distances / (2 * bw * bw))

def get_first_crossing_index(wealth_process, significance):
  """Finds where the wealth process first reaches the level 1/significance.

  Args:
    wealth_process: List of wealth values.
    significance: Significance level.

  Returns:
    Index of the first entry that is greater than or equal to 1/significance,
    or len(wealth_process) when no entry reaches that level.
  """
  threshold = 1 / significance
  crossings = np.nonzero(np.array(wealth_process) >= threshold)[0]
  if crossings.size == 0:
    # The threshold is never reached.
    return len(wealth_process)
  # First time step at which the process is at or above the threshold.
  return crossings[0]

# Main classes for sequential testing

Below we include classes for the Online Newton Step (ONS) betting strategy, the Online Gradient Ascent (OGA) strategy, and a generic one-sided sequential two-sample test that we then use for privacy audits.

In [None]:
class OnlineNewtonStep:
  """Betting strategy based on the Online Newton Step (ONS) update.

  See Shekhar and Ramdas (2023), "Nonparametric Two-Sample Testing by Betting"
  https://arxiv.org/pdf/2112.09162.pdf for more details. For clarity we use
  the notation of the manuscript.

  Attributes:
    tau: Threshold value. See Theorem 3.1. in the manuscript.
    scaling_constant: Scaling constant, see Definition 5. in
      https://arxiv.org/pdf/2112.09162
    sum_grads_squared: Running sum of squared z values, initialised at 1.
    previous_lambda: Bet returned by the most recent update (0 initially).
  """

  def __init__(self, tau) -> None:
    self.tau = tau
    self.previous_lambda = 0
    self.sum_grads_squared = 1
    self.scaling_constant = 2 / (2 - np.log(3))

  def next_bet(self, payoff_history):
    """Computes the next bet from the most recent payoff.

    Only the last entry of `payoff_history` is read; earlier payoffs enter
    through the state stored on the instance.

    Args:
      payoff_history: List of previous payoffs.

    Returns:
      0 when `payoff_history` is empty; otherwise the updated bet, kept inside
      the interval [0, 1 / (8 + 4 * tau)].
    """
    if not payoff_history:
      # Nothing observed yet: bet 0 and leave the state untouched.
      return 0

    last_payoff = payoff_history[-1]
    grad = -last_payoff / (1 + self.previous_lambda * last_payoff)
    self.sum_grads_squared += grad**2

    bet = (
        self.previous_lambda
        - self.scaling_constant * grad / self.sum_grads_squared
    )

    # Clamp the raw update to the admissible range of bets.
    cap = 1 / (8 + 4 * self.tau)
    if cap < bet:
      bet = cap
    if 0 > bet:
      bet = 0

    self.previous_lambda = bet
    return bet


class OnlineGradientAscent:
  """Online Gradient Ascent betting strategy.

  Attributes:
    m_t: Running sum of the per-step increments
      k(x, x) + k(y, y) - 2 * k(x, y) (documented by the authors as the current
      estimate of the second moment of the gradient).
    gradient_second_moments: Value of `m_t` recorded after every step.
    history_products: One weight per processed step; all weights are rescaled
      by the newest normalization factor at the end of every step.
    history_auxiliaryterm: Auxiliary term recorded at every step.
  """

  def __init__(self) -> None:
    # OGA state; everything starts empty.
    self.history_auxiliaryterm = []
    self.history_products = []
    self.gradient_second_moments = []
    self.m_t = 0

  def step(self, x_hist, y_hist, bw):
    """Processes the newest sample pair and updates the strategy state.

    Args:
      x_hist: History of the first set of samples; the last entry is the new
        sample.
      y_hist: History of the second set of samples; the last entry is the new
        sample.
      bw: Kernel bandwidth.

    Returns:
      The value v_t computed from the past samples and weights (described by
      the authors as the current MMD estimate); 0 when `x_hist` holds a single
      sample.
    """
    new_x, new_y = x_hist[-1], y_hist[-1]
    num_steps = len(x_hist)

    # Second-moment bookkeeping for the newest pair.
    k_xx = rbf_kernel(new_x, new_x, bw)[0, 0]
    k_yy = rbf_kernel(new_y, new_y, bw)[0, 0]
    k_xy = rbf_kernel(new_x, new_y, bw)[0, 0]
    increment = k_xx + k_yy - 2 * k_xy
    self.m_t += increment
    self.gradient_second_moments.append(self.m_t)

    v_t = 0
    if num_steps != 1:
      past_x, past_y = x_hist[:-1], y_hist[:-1]
      kernel_terms = (
          rbf_kernel(past_x, new_x, bw)
          - rbf_kernel(past_x, new_y, bw)
          + rbf_kernel(past_y, new_y, bw)
          - rbf_kernel(past_y, new_x, bw)
      ).flatten()
      weights = np.array(self.history_products)
      scale = 2 * np.sqrt(self.gradient_second_moments[:-1])
      v_t = np.sum(kernel_terms * weights / scale)

    # Auxiliary term for this step.
    aux_term = v_t / np.sqrt(self.m_t) + increment / (4 * self.m_t)
    self.history_auxiliaryterm.append(aux_term)

    # Weighted sum of the auxiliary terms of all earlier steps.
    s_t = 0
    past_aux = self.history_auxiliaryterm[: num_steps - 1]
    past_weights = self.history_products[: num_steps - 1]
    for aux, weight in zip(past_aux, past_weights):
      s_t += aux * weight**2

    # Normalization factor, capped at 1, applied to every weight (including
    # the new one, which starts at 1).
    gamma_t = np.min([1, 1 / (2 * np.sqrt(s_t + aux_term))])
    self.history_products = [
        weight * gamma_t for weight in self.history_products + [1]
    ]

    return v_t


class OneSidedTwoSampleSequentialTest:
  """One-sided sequential two-sample test driven by a betting wealth process.

  Attributes:
    wealth: Current wealth (starts at 1).
    wealth_hist: History of wealth values; entry 0 is the initial wealth.
    tau: Threshold value, computed from `epsilon` and `delta`.
    online_newton_step: Online Newton Step betting strategy.
    online_gradient_ascent: Online Gradient Ascent betting strategy.
    bandwidth: Kernel bandwidth.
    lambd: Current lambda value (the bet used for the next payoff).
    payoff_history: History of payoffs.
    x_hist: History of first set of samples.
    y_hist: History of second set of samples.
  """

  def __init__(self, epsilon: float, delta: float, bw: float) -> None:
    # Histories of observed samples and payoffs.
    self.x_hist = []
    self.y_hist = []
    self.payoff_history = []

    # Wealth process starts at 1 with a zero bet.
    self.wealth = 1
    self.wealth_hist = [1]
    self.lambd = 0
    self.bandwidth = bw

    exp_neg_eps = np.exp(-epsilon)
    self.tau = 2 * (
        (1 + delta * exp_neg_eps) * (1 - exp_neg_eps) / (1 + exp_neg_eps)
        + exp_neg_eps * delta
    )

    self.online_newton_step = OnlineNewtonStep(self.tau)
    self.online_gradient_ascent = OnlineGradientAscent()

  def step(self, x, y):
    """Feeds one new pair of samples to the test and updates the wealth.

    Args:
      x: New sample from the first distribution.
      y: New sample from the second distribution.

    Returns:
      The updated wealth value.
    """
    # Record the new pair before asking the strategy for its estimate.
    self.x_hist.append(x)
    self.y_hist.append(y)

    estimate = self.online_gradient_ascent.step(
        self.x_hist, self.y_hist, self.bandwidth
    )
    payoff = estimate - self.tau

    # The wealth is multiplied using the bet chosen before this payoff.
    growth = 1 + self.lambd * payoff
    self.wealth *= growth
    self.wealth_hist.append(self.wealth)

    # Choose the bet for the next round from the updated payoff history.
    self.payoff_history.append(payoff)
    self.lambd = self.online_newton_step.next_bet(self.payoff_history)

    return self.wealth


# Paper experiments



## Benchmark experiments with Gaussian distributions.

In [None]:
def plot_gaussians_experiments(
    epsilon: float,
    delta: float,
    location: float,
    significance: float,
    max_observations: int,
    initial_seed: int,
    num_simulations: int = 10,
):
  """Plots wealth process for a test with one dimensional gaussians.

  Runs `num_simulations` independent sequential tests, draws every wealth curve
  on the current matplotlib figure, and summarises how often and how quickly
  the rejection threshold 1 / significance is reached.

  Args:
    epsilon: Privacy parameter.
    delta: Privacy parameter.
    location: Mean of the second gaussian distribution. The first gaussian
      distribution has mean 0. Both have unit scale.
    significance: Significance level.
    max_observations: Maximum number of observations to use.
    initial_seed: Initial seed for the random number generator; simulation i
      uses seed `initial_seed + i`.
    num_simulations: Number of independent simulations to run and plot.

  Returns:
    A dictionary containing the rejection probability and the average number of
    observations to reject. Simulations that never reject contribute
    `max_observations + 1` to the average.
  """
  num_observations_to_reject = []

  for i in range(num_simulations):
    # A local generator per simulation keeps runs reproducible without
    # touching the notebook-level generator.
    rng = np.random.default_rng(seed=initial_seed + i)
    kernel_mmd_dp = OneSidedTwoSampleSequentialTest(epsilon, delta, bw=1.0)

    for _ in range(max_observations):
      samples_x = rng.normal(loc=0, scale=1, size=1)
      samples_y = rng.normal(loc=location, scale=1, size=1)

      # Perform the sequential test
      kernel_mmd_dp.step(samples_x, samples_y)

    curve = kernel_mmd_dp.wealth_hist
    num_observations_to_reject.append(
        get_first_crossing_index(curve, significance)
    )

    plt.plot(np.arange(len(curve)), curve, alpha=0.3, color='k')

  # Figure decorations are drawn once, after all curves, instead of once per
  # simulation.
  plt.axhline(y=1 / significance, color='r', linestyle='--')
  plt.text(
      max_observations + 8,
      1 / significance + 1,
      f'Rejection threshold $1 / \\alpha = {1/significance:.0f}$',
      fontsize=14,
      ha='right',
  )
  plt.xlabel('Number of observations', fontsize=16)
  plt.ylabel('Wealth process', fontsize=16)
  plt.xticks(fontsize=14)
  plt.yticks(fontsize=14)
  plt.ylim(0, 100)

  # The wealth history has max_observations + 1 entries (entry 0 is the initial
  # wealth), so a crossing index equal to max_observations is a rejection at the
  # final observation; only max_observations + 1 means "never rejected".
  crossing_indices = np.array(num_observations_to_reject)
  return {
      'rejection_rate': np.mean(crossing_indices <= max_observations),
      'avg_obs_to_reject': np.mean(num_observations_to_reject),
  }

In [None]:
# Both samples are drawn with mean 0 (location=0).
plot_gaussians_experiments(
    epsilon=0,
    delta=0,
    location=0,
    significance=0.05,
    max_observations=200,
    initial_seed=7,
)

In [None]:
# The two samples are drawn with means 0 and 1.0 (location=1.0).
plot_gaussians_experiments(
    epsilon=0,
    delta=0,
    location=1.0,
    significance=0.05,
    max_observations=200,
    initial_seed=7,
)

In [None]:
# Power in different dimensions
alpha = 0.05
mean_distances = [0, 0.25, 0.5, 0.75, 1]
dimensions = np.arange(1, 6)
num_samples = 1000
num_simulations = 10

# One result list per mean distance; each list gets one entry per dimension.
mean_observations = [[] for _ in range(len(mean_distances))]
mean_proportions = [[] for _ in range(len(mean_distances))]
ci_observations = [[] for _ in range(len(mean_distances))]

for i, mean_distance in enumerate(mean_distances):
  observations = mean_observations[i]
  proportions = mean_proportions[i]
  ci_obs = ci_observations[i]

  for dim in dimensions:
    obs_to_reject = []
    for ii in range(num_simulations):
      _RNG = np.random.default_rng(seed=7 + ii)
      kernel_mmd_dp = OneSidedTwoSampleSequentialTest(
          epsilon=0, delta=0, bw=np.sqrt(dim)
      )
      for jj in range(num_samples):
        X1 = _RNG.normal(loc=0, scale=1, size=dim)
        # Every coordinate is shifted by mean_distance / sqrt(dim), so the
        # mean vector has Euclidean norm mean_distance in any dimension.
        X2 = _RNG.normal(loc=mean_distance / np.sqrt(dim), scale=1, size=dim)

        # Perform the sequential test
        kernel_mmd_dp.step(X1, X2)

      curve = kernel_mmd_dp.wealth_hist
      obs_to_reject.append(get_first_crossing_index(curve, alpha))

    observations.append(np.mean(obs_to_reject))
    # The wealth history has num_samples + 1 entries (entry 0 is the initial
    # wealth), so an index equal to num_samples is a rejection at the final
    # observation; only num_samples + 1 means "never rejected".
    proportions.append(np.mean(np.array(obs_to_reject) <= num_samples))
    # Half-width of a normal-approximation 95% interval for the mean.
    ci_obs.append(1.96 * np.std(obs_to_reject) / np.sqrt(num_simulations))

In [None]:
markers = ['o', 's', 'd', '^', '*', 'x']

# Mean number of observations needed to reject, per dimension, with one curve
# per mean distance. Labels use raw strings so the LaTeX backslashes are not
# parsed as (invalid) Python escape sequences.
for i in range(len(mean_distances)):
  plt.errorbar(dimensions, mean_observations[i], yerr=ci_observations[i],
               marker=markers[i], capsize=5, capthick=1,
               label=rf"$||\mu||_2 = {mean_distances[i]}$")

plt.xticks(dimensions, dimensions, fontsize=14)
plt.yticks(fontsize=14)
plt.xlabel(r'Dimension of $\mathcal{N}(\mu, I_d)$ distribution', fontsize=16)
plt.ylabel('Mean size to reject', fontsize=16)
plt.grid(True, linestyle='--', alpha=0.4)
plt.legend(loc='right', bbox_to_anchor=(1, 0.65), fontsize=14, framealpha=0.5)
plt.show()

In [None]:
# Rejection rate per dimension, with one curve per mean distance. Error bars
# are normal-approximation 95% half-widths for a proportion.
for i in range(len(mean_distances)):
  ci_props = [
      1.96 * np.sqrt(p * (1 - p) / num_simulations)
      for p in mean_proportions[i]
  ]
  # Raw f-string: the LaTeX backslash must not be parsed as a Python escape.
  plt.errorbar(dimensions, mean_proportions[i], yerr=ci_props,
               marker=markers[i], capsize=5, capthick=1,
               label=rf"$||\mu||_2 = {mean_distances[i]}$")

plt.xticks(dimensions, dimensions, fontsize=14)
plt.yticks(fontsize=14)
plt.xlabel(r'Dimension of $\mathcal{N}(\mu, I_d)$ distribution', fontsize=16)
plt.ylabel('Rejection rate', fontsize=16)
plt.grid(True, linestyle='--', alpha=0.4)
plt.legend(loc='right', bbox_to_anchor=(1, 0.6), fontsize=14)
# Dashed reference line at the significance level.
plt.axhline(y=alpha, color='k', linestyle='--', alpha=0.6)
plt.text(4.8, 0.065, r'$\alpha = 0.05$', fontsize=13, ha='center')
plt.show()

## DP-SGD Sequential Test

For these experiments:

1. From the root directory of the cloned repository, run [`python run_experiment_jaxline.py --config=mnist_audit.py`](https://github.com/google-deepmind/jax_privacy/blob/main/experiments/image_classification/run_experiment_jaxline.py)
2. The config file can be found [here](https://github.com/google-deepmind/jax_privacy/blob/main/experiments/image_classification/configs/mnist_audit.py). Please update the epsilon value (line 64) to your desired target for the audit. The noise added per iteration will be adjusted to meet this epsilon over the number of iterations.
3. Logged train metrics will have a `canary_count` field and a `dot_product` field. To form the two distributions for analysis:
   - The first distribution consists of `dot_product` values where `canary_count=0`.
   - The second distribution consists of `dot_product` values where `canary_count=1`.



In [None]:
def empirical_epsilon(
    samples1, samples2, test_epsilons, alpha=0.05, delta=1e-5, bw_obs=20
):
  """Tracks the largest rejected epsilon as paired samples are observed.

  One sequential test is run per candidate epsilon. The first `bw_obs` samples
  of each sequence are used only to pick the kernel bandwidth (median pairwise
  distance of the pooled samples); the tests are then fed the remaining pairs
  one at a time.

  A candidate epsilon counts as rejected from the first step at which its
  wealth reaches 1 / alpha, and stays rejected afterwards: the sequential test
  stops at the first crossing, so a later drop in wealth does not undo it.

  Args:
    samples1: Sequence of scalar samples from the first distribution.
    samples2: Sequence of scalar samples from the second distribution.
    test_epsilons: Candidate epsilon values to test.
    alpha: Significance level of each sequential test.
    delta: Privacy parameter delta passed to every test.
    bw_obs: Number of leading samples from each sequence used to set the
      kernel bandwidth; these samples are not fed to the tests.

  Returns:
    A list with one entry per processed pair (indices `bw_obs` up to the
    shorter sequence length): the largest candidate epsilon rejected so far,
    or 0 if none has been rejected yet.
  """
  # Median heuristic on the pooled warm-up samples, reshaped to column vectors
  # because `pdist` expects one observation per row.
  warmup = np.concatenate(
      (
          np.array(samples1[:bw_obs]).reshape(-1, 1),
          np.array(samples2[:bw_obs]).reshape(-1, 1),
      ),
      axis=0,
  )
  med = np.median(pdist(warmup))

  n = min(len(samples1), len(samples2))
  testers = {
      eps: OneSidedTwoSampleSequentialTest(epsilon=eps, delta=delta, bw=med)
      for eps in test_epsilons
  }

  threshold = 1 / alpha
  rejected_so_far = set()
  lower_bound = []

  for jj in range(bw_obs, n):
    X1 = samples1[jj]
    X2 = samples2[jj]

    for eps in test_epsilons:
      testers[eps].step(X1, X2)

    # Same crossing rule (>=) as `get_first_crossing_index`.
    rejected_so_far.update(
        eps for eps in test_epsilons
        if testers[eps].wealth_hist[-1] >= threshold
    )
    lower_bound.append(max(rejected_so_far) if rejected_so_far else 0)

  return lower_bound

In [None]:
# load file with samples.
# We use two gaussians here, drawn from a generator seeded with the notebook
# seed so that the cell is reproducible.
_RNG = np.random.default_rng(seed=seed)
samples1 = _RNG.normal(size=1000)
# The mean of a normal is set through `loc`; `mean` is not an accepted keyword.
samples2 = _RNG.normal(loc=1.0, size=1000)

n = min(len(samples1), len(samples2))
test_epsilons = np.arange(0.001, 0.02, 0.002)

lower_bound_001 = empirical_epsilon(samples1, samples2, test_epsilons)

plt.plot(np.arange(len(lower_bound_001)), lower_bound_001, alpha=0.4, color='blue')

# Raw strings keep the LaTeX backslashes from being parsed as Python escapes.
plt.xlabel('Number of observations', fontsize=18)
plt.ylabel(r'Empirical lower bound $\epsilon$', fontsize=18)
plt.axhline(y=0.01, color='r', linestyle='--')
plt.text(0, 0.05, r"Theoretical $\epsilon^{ub} = 0.01$", fontsize=15, ha='left')
plt.ylim(-0.09, 1.09)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.show()