This colab is a supplement to the [Federated Learning with Formal Differential Privacy Guarantees](http://ai.googleblog.com/2022/02/federated-learning-with-formal.html) blogpost, providing details on the parameters we used for training the Gboard Spanish-language next-word-prediction model, and the 
corresponding $\rho$-zCDP (zero concentrated differential privacy) and $(\varepsilon,\delta)$-DP (differential privacy) guarantee for the DP-FTRL algorithm. 

# Code Locations
---

Our core algorithm is available in open-source; the routines for [estimating cumulative sums using tree aggregation](https://github.com/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py) with DP are released in TensorFlow Privacy. These are then integrated with TensorFlow Federated's [aggregation libraries](https://www.tensorflow.org/federated/api_docs/python/tff/aggregators/DifferentiallyPrivateFactory?version=nightly#tree_aggregation), which allow them to be plugged into different learning algorithms, in particular [Federated Averaging](https://www.tensorflow.org/federated/api_docs/python/tff/learning/build_federated_averaging_process).

The [code to perform privacy accounting](https://github.com/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/analysis/tree_aggregation_accountant.py) (which we reproduce below to make this colab stand alone) can also be found in TensorFlow Privacy.

# Notion of differential privacy (DP)
---
The DP analysis covers the use of the Gboard training data cache for the training of a production model from random initialization using DP-FTRL.

In this work, we conform to the add/remove notion of DP (where neighboring data sets differ by the addition or removal of a single client). When a client is absent at a training step, we assume that its model update is replaced with the all-zeros vector. This assumption entails a subtle modification to the traditional add/remove definition: it allows neighboring data sets to have the same number of records. The formal definition is provided in Definition 1.1 in [the paper](https://arxiv.org/abs/2103.00039). It is a special instantiation of Definition II.3 (removal DP) in [Erlingsson et al.](https://arxiv.org/abs/2001.03618).

Our privacy guarantee holds for all well-behaved clients (that is, clients that faithfully follow the algorithm, including participation limits). We emphasize that, due to the design of the algorithm, a misbehaving client does not adversely affect the DP guarantee of any well-behaved client.

The notion of adjacency is with respect to arbitrary training datasets on each client device (i.e., the device removed might have an arbitrarily large local dataset containing arbitrary training examples). For users with a single device, this corresponds directly to user-level DP; for devices shared by multiple users, this provides a stronger notion of DP than user-level; for a user with, say, 2 devices that both happen to participate in training the model, the notion is weaker, but group privacy can be used to obtain a user-level guarantee.


# Differential privacy analysis for the DP-FTRL algorithm 
---
Here, we instantiate the DP-FTRL algorithm without tree restarts from the [arXiv version v3](https://arxiv.org/abs/2103.00039) of our ICML 2021 paper. The privacy analysis is from Appendix D.2 of the arXiv version.

The main idea in the algorithm is the following:

Let $T$ be the number of rounds (`total_steps` in the code below) of the server-side training process, and in each round $t \in [T]$, we compute $\nabla_t$, which is the sum of gradients over `COHORT_SIZE` distinct clients who contributed at round $t$, at the current model state $\theta_t$. Now, we create a forest of binary trees as follows: 

   1. Create a complete binary tree with $T'$ leaf nodes, where $T'$ is the smallest power of two $\geq T$. 
   1. Set the first $T$ leaf nodes (from left) to be $\nabla_1,\ldots,\nabla_T$ respectively, and delete the remaining leaf nodes. 
   1. Delete every internal node $z$ for which the subtree rooted at $z$ is no longer a complete binary tree (that is, delete all $z$ where any leaf of the subtree rooted at $z$ has been deleted). 
   1. Set each node $z$ of this forest to be the sum of all the leaves in the subtree rooted at $z$. Call the resulting forest $\cal T$.
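The construction above can be sketched in a few lines of Python. This is only an illustration: the function names `forest_nodes` and `forest_roots` are ours (they do not appear in TensorFlow Privacy), leaves are 0-indexed here (so leaf `i` holds $\nabla_{i+1}$), and each node is represented by the half-open range `[lo, hi)` of leaves it sums over.

```python
def forest_nodes(T):
    """Return the leaf ranges [lo, hi) of every node kept in the forest."""
    nodes, size = [], 1
    while size <= T:
        # A node of width `size` survives only if all of its leaves exist,
        # i.e. its range ends at or before leaf T.
        nodes.extend((lo, lo + size) for lo in range(0, T - size + 1, size))
        size *= 2
    return nodes


def forest_roots(T):
    """Return the roots of the forest, one per set bit in the binary form of T."""
    roots, lo = [], 0
    for bit in reversed(range(T.bit_length())):
        size = 1 << bit
        if T & size:
            roots.append((lo, lo + size))
            lo += size
    return roots


# With T = 5 leaves (T' = 8): 5 leaves, 2 nodes of width 2, 1 node of width 4.
print(len(forest_nodes(5)))  # 8
print(forest_roots(5))       # [(0, 4), (4, 5)]
```

When $T$ is itself a power of two, the forest is a single complete binary tree with $2T-1$ nodes.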

In this formulation, each client may participate *at most once* in the computation of any $\nabla_t$, and at most $E$ times overall (`max_participation` in the code below). Furthermore, each client may participate *at most once every 24 hours*. We encode this constraint in the privacy accounting via $\xi$ (`min_separation` in the code below), which is the minimum number of rounds between two successive participations of any client. Recall the following:

  * Each node $z\in\cal T$ stores a sum of all the $\nabla_t$'s in the leaves of the sub-tree rooted at $z$. 
  * Each leaf node in $\cal T$ only affects the values at the nodes on the path from it to the root of its tree in $\cal T$. 
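As a small illustration of the second point, the sketch below lists the forest nodes whose sums include a given leaf. The helper name `nodes_touched_by_leaf` is hypothetical, leaves are 0-indexed, and each node is identified by the half-open leaf range `[lo, hi)` it covers.

```python
def nodes_touched_by_leaf(t, T):
    """Leaf ranges [lo, hi) of the forest nodes whose sum includes leaf t."""
    touched, size = [], 1
    while size <= T:
        lo = (t // size) * size  # start of the width-`size` block containing t
        if lo + size <= T:       # the node exists only if all its leaves do
            touched.append((lo, lo + size))
        size *= 2
    return touched


# In a complete tree with 8 leaves, a leaf touches log2(8) + 1 = 4 nodes.
print(nodes_touched_by_leaf(3, 8))  # [(3, 4), (2, 4), (0, 4), (0, 8)]
# With T = 5, the last leaf is its own root and touches a single node.
print(nodes_touched_by_leaf(4, 5))  # [(4, 5)]
```

A single participation therefore changes at most $\log_2 T' + 1$ node values, which is why the sensitivity of the whole forest grows only logarithmically in the number of rounds when $E=1$.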

Let $L$ be the clipping norm of each individual client gradient participating in any of the $\nabla_t$'s. For noise multiplier $\sigma$ (`noise_multiplier` in the code below), the noise added to each node of $\cal T$ is sampled from ${\cal N}\left(0, \sigma^2 L^2\right)^{p}$, where $p$ is the dimensionality of the model update.

Let $\zeta^*$ bound the squared $\ell_2$-sensitivity of $\cal T$, defined as follows. Consider an empty forest with the same structure as $\cal T$, and call it $\widehat{\cal T}$. Let $c\in\{0,1\}^T$ be any bit vector with $\|c\|_0\leq E$ (where $E$ is the maximum number of participations per client) in which any two successive ones have at least $\xi$ zeros between them. Let $\widehat{\cal T}(c)$ be an instantiation of $\widehat{\cal T}$ with the leaf nodes being $c$, and each internal node being the sum of all the leaves in the sub-tree rooted at that node. We define $\zeta^*=\max\limits_c\sum_{z\in\widehat{\cal T}(c)}z^2$. In Theorem D.3 [in the paper](https://arxiv.org/abs/2103.00039) (and in function `_tree_sensitivity_square_sum`), we provide a dynamic-programming approach to calculate an upper bound on $\zeta^*$. 
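For very small $T$, the definition of $\zeta^*$ can be evaluated directly by enumerating every admissible bit vector $c$. The sketch below does exactly that; it is a brute-force illustration of the definition, not the dynamic program from the paper, and the name `brute_force_zeta` is ours. Its cost grows combinatorially in $T$ and $E$, so it is only useful as a sanity check on toy inputs.

```python
import itertools


def brute_force_zeta(T, E, xi):
    """Maximize the sum of squared node values over admissible bit vectors c."""
    # Enumerate forest nodes as half-open leaf ranges [lo, hi).
    nodes, size = [], 1
    while size <= T:
        nodes.extend((lo, lo + size) for lo in range(0, T - size + 1, size))
        size *= 2

    best = 0
    for k in range(1, E + 1):
        # `ones` holds the (sorted) positions of the ones in c.
        for ones in itertools.combinations(range(T), k):
            # Require at least `xi` zeros between successive ones.
            if any(b - a - 1 < xi for a, b in zip(ones, ones[1:])):
                continue
            total = sum(
                sum(lo <= p < hi for p in ones) ** 2 for lo, hi in nodes)
            best = max(best, total)
    return best


print(brute_force_zeta(T=4, E=2, xi=0))  # 10: two adjacent sibling leaves
print(brute_force_zeta(T=4, E=2, xi=1))  # 8: the separation forces them apart
```

The example shows how the separation constraint lowers the sensitivity: with $\xi=0$ the worst case places both participations under the same parent, while $\xi=1$ forces them into different subtrees.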

Since our algorithm is primarily based on the Gaussian mechanism, we can easily provide a $\rho$-zCDP guarantee, along with an $(\varepsilon,\delta)$-DP guarantee (with $\delta$ being `target_delta` in the code below). The zCDP bound is immediate and is computed by the function `compute_zcdp`. It is a direct consequence of Lemma 2.4 of [Bun and Steinke '16](https://arxiv.org/pdf/1605.02065.pdf) and Theorem D.3 [in our paper](https://arxiv.org/abs/2103.00039).
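To spell out the arithmetic: the noise standard deviation is $\sigma L$ and the squared sensitivity is at most $\zeta^* L^2$, so $L$ cancels and $\rho = \zeta^*/(2\sigma^2)$. The sketch below computes this, together with the generic zCDP-to-DP conversion $\varepsilon = \rho + 2\sqrt{\rho\ln(1/\delta)}$ from Bun and Steinke (Proposition 1.3). The function names are ours, and this generic conversion is an assumption made for illustration: it is not the conversion used in this colab, which goes through RDP orders and is tighter.

```python
import math


def zcdp_from_sensitivity(zeta, noise_multiplier):
    """rho for the Gaussian mechanism: squared sensitivity over 2 * sigma^2."""
    return zeta / (2.0 * noise_multiplier ** 2)


def eps_from_zcdp(rho, delta):
    """Generic (loose) conversion from rho-zCDP to (eps, delta)-DP."""
    return rho + 2.0 * math.sqrt(rho * math.log(1.0 / delta))


# Using the rounded rho = 0.81 reported at the end of this colab:
print(f"{eps_from_zcdp(0.81, 1e-10):.2f}")  # 9.45
```

The value obtained this way is somewhat larger than the $\varepsilon$ reported at the end of this colab, which is expected because the RDP-based accounting optimizes over many orders.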

# General-purpose accounting code

The code in this section is a replica of the open-source privacy accounting code
[here](https://github.com/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/analysis/tree_aggregation_accountant.py), reproduced for ease of verification. The `_tree_sensitivity_square_sum` function here computes the upper bound on $\zeta^*$ mentioned earlier.

In [None]:
!pip install tensorflow_privacy
import itertools
import time
import warnings

from absl import app
from absl import flags

import six
from typing import Collection, Union, Dict, Tuple
import numpy as np
import math
import collections
import tensorflow_privacy as tfp

# RDP orders to consider
ORDERS = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))

def _check_nonnegative(value: Union[int, float], name: str):
  if value < 0:
    raise ValueError(f"Provided {name} must be non-negative, got {value}")


def _check_possible_tree_participation(num_participation: int,
                                       min_separation: int, start: int,
                                       end: int, steps: int) -> bool:
  """Check if participation is possible with `min_separation` in `steps`.

  This function checks if it is possible for a sample to appear
  `num_participation` times in `steps`, assuming there are at least
  `min_separation`
  nodes between the appearance of the same sample in the streaming data (leaf
  nodes in tree aggregation). The first appearance of the sample is after
  `start` steps, and the sample won't appear in the `end` steps after the given
  `steps`.

  Args:
    num_participation: The number of times a sample will appear.
    min_separation: The minimum number of nodes between two appearances of a
      sample. If a sample appears in consecutive x, y steps in a streaming
      setting, then `min_separation=y-x-1`.
    start:  The first appearance of the sample is after `start` steps.
    end: The sample won't appear in the `end` steps after the given `steps`.
    steps: Total number of steps (leaf nodes in tree aggregation).

  Returns:
    True if a sample can appear `num_participation` times with given conditions.
  """
  return start + (min_separation + 1) * num_participation <= steps + end


def _tree_sensitivity_square_sum(
    num_participation: int, min_separation: int, start: int, end: int,
    steps: int, hist_buffer: Dict[Tuple[int, int, int, int], float]) -> float:
  """Compute worst-case sum of squared sensitivities for `num_participation`.

  This is the key algorithm for DP accounting for DP-FTRL tree aggregation
  without restart, which recursively counts the worst-case occurrence of a
  sample in all the nodes in a tree. This implements a dynamic programming
  algorithm that exhausts the possible `num_participation` appearances of a
  sample in `steps` leaf nodes. See Appendix D of
  "Practical and Private (Deep) Learning without Sampling or Shuffling"
  https://arxiv.org/abs/2103.00039.

  Args:
    num_participation: The number of times a sample will appear.
    min_separation: The minimum number of nodes between two appearances of a
      sample. If a sample appears in consecutive x, y steps in a streaming
      setting, then `min_separation=y-x-1`.
    start:  The first appearance of the sample is after `start` steps.
    end: The sample won't appear in the `end` steps after the given `steps`.
    steps: Total number of steps (leaf nodes in tree aggregation).
    hist_buffer: A dictionary that stores the worst-case sum of squared
      sensitivities, keyed by (num_participation, start, end, steps).

  Returns:
    The worst-case sum of squared sensitivities for the given input.
  """
  key_tuple = (num_participation, start, end, steps)
  if key_tuple in hist_buffer:
    return hist_buffer[key_tuple]
  if not _check_possible_tree_participation(num_participation, min_separation,
                                            start, end, steps):
    sum_value = -np.inf
  elif num_participation == 0:
    sum_value = 0.
  elif num_participation == 1 and steps == 1:
    sum_value = 1.
  else:
    steps_log2 = math.log2(steps)
    max_2power = math.floor(steps_log2)
    if max_2power == steps_log2:
      sum_value = num_participation**2
      max_2power -= 1
    else:
      sum_value = 0.
    candidate_sum = []
    for right_part in range(num_participation + 1):
      for right_start in range(min_separation + 1):
        left_sum = _tree_sensitivity_square_sum(
            num_participation=num_participation - right_part,
            min_separation=min_separation,
            start=start,
            end=right_start,
            steps=2**max_2power,
            hist_buffer=hist_buffer)
        if np.isinf(left_sum):
          candidate_sum.append(-np.inf)
          continue  # Early pruning for dynamic programming
        right_sum = _tree_sensitivity_square_sum(
            num_participation=right_part,
            min_separation=min_separation,
            start=right_start,
            end=end,
            steps=steps - 2**max_2power,
            hist_buffer=hist_buffer)
        candidate_sum.append(left_sum + right_sum)
    sum_value += max(candidate_sum)
  hist_buffer[key_tuple] = sum_value
  return sum_value


def _max_tree_sensitivity_square_sum(max_participation: int,
                                     min_separation: int, steps: int) -> float:
  """Compute the worst-case sum of squared sensitivities in tree aggregation.

  See Appendix D of
  "Practical and Private (Deep) Learning without Sampling or Shuffling"
  https://arxiv.org/abs/2103.00039.

  Args:
    max_participation: The maximum number of times a sample will appear.
    min_separation: The minimum number of nodes between two appearances of a
      sample. If a sample appears in consecutive x, y steps in a streaming
      setting, then `min_separation=y-x-1`.
    steps: Total number of steps (leaf nodes in tree aggregation).

  Returns:
    The worst-case sum of squared sensitivities for the given input.
  """
  num_participation = max_participation
  while not _check_possible_tree_participation(
      num_participation, min_separation, 0, min_separation, steps):
    num_participation -= 1
  candidate_sum, hist_buffer = [], collections.OrderedDict()
  for num_part in range(1, num_participation + 1):
    candidate_sum.append(
        _tree_sensitivity_square_sum(num_part, min_separation, 0,
                                     min_separation, steps, hist_buffer))
  return max(candidate_sum)


def _compute_gaussian_rdp(sigma: float, sum_sensitivity_square: float,
                          alpha: float) -> float:
  """Computes RDP of Gaussian mechanism."""
  if np.isinf(alpha):
    return np.inf
  return alpha * sum_sensitivity_square / (2 * sigma**2)


def compute_rdp_single_tree(
    noise_multiplier: float, total_steps: int, max_participation: int,
    min_separation: int,
    orders: Union[float, Collection[float]] = ORDERS
) -> Tuple[Union[float, np.ndarray], float]:
  """Computes RDP of the Tree Aggregation Protocol for a single tree.

  The accounting assumes a single tree is constructed for `total_steps` leaf
  nodes, where the same sample will appear at most `max_participation` times,
  and there are at least `min_separation` nodes between two appearances. The key
  idea is to (recursively) count the worst-case occurrence of a sample
  in all the nodes in a tree, which implements a dynamic programming algorithm
  that exhausts the possible `num_participation` appearances of a sample in
  `steps` leaf nodes.

  See Appendix D of
  "Practical and Private (Deep) Learning without Sampling or Shuffling"
  https://arxiv.org/abs/2103.00039.

  Args:
    noise_multiplier: A non-negative float representing the ratio of the
      standard deviation of the Gaussian noise to the l2-sensitivity of a single
      contribution (a leaf node), which is usually set in
      `TreeCumulativeSumQuery` and `TreeResidualSumQuery` from
      `dp_query.tree_aggregation_query`.
    total_steps: Total number of steps (leaf nodes in tree aggregation).
    max_participation: The maximum number of times a sample can appear.
    min_separation: The minimum number of nodes between two appearances of a
      sample. If a sample appears in consecutive x, y steps in a streaming
      setting, then `min_separation=y-x-1`.
    orders: An array (or a scalar) of RDP orders.

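  Returns:
    A tuple `(rdp, sum_sensitivity_square)`: the RDP at each order (can be
    `np.inf`) and the worst-case sum of squared sensitivities over the tree.
  """
  _check_nonnegative(noise_multiplier, "noise_multiplier")
  _check_nonnegative(total_steps, "total_steps")
  _check_nonnegative(max_participation, "max_participation")
  _check_nonnegative(min_separation, "min_separation")
  sum_sensitivity_square = _max_tree_sensitivity_square_sum(
      max_participation, min_separation, total_steps)
  if noise_multiplier == 0:
    # Without noise there is no privacy guarantee. Return a pair so that callers
    # can always unpack `(rdp, sum_sensitivity_square)`.
    return np.inf, sum_sensitivity_square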
  if np.isscalar(orders):
    rdp = _compute_gaussian_rdp(noise_multiplier, sum_sensitivity_square,
                                orders)
  else:
    rdp = np.array([
        _compute_gaussian_rdp(noise_multiplier, sum_sensitivity_square, alpha)
        for alpha in orders
    ])
  return rdp, sum_sensitivity_square

def compute_zcdp(noise_multiplier, sensitivity_sq):
  """Computes zCDP of the Tree Aggregation Protocol for a single tree.

  Uses Lemma 2.4 from https://arxiv.org/pdf/1605.02065.pdf.

  Args:
    noise_multiplier: A non-negative float representing the ratio of the
      standard deviation of the Gaussian noise to the l2-sensitivity of a single
      contribution (a leaf node), which is usually set in
      `TreeCumulativeSumQuery` and `TreeResidualSumQuery` from
      `dp_query.tree_aggregation_query`.
    sensitivity_sq: The sum of squared sensitivities of the nodes in the binary
      tree, assuming each minibatch has l2-sensitivity of one.

  Returns: zCDP parameter.
  """
  return sensitivity_sq / (2 * pow(noise_multiplier, 2))

def eps_from_rdp(rdp, target_delta):
  """Compute epsilon for (eps, target_delta)-DP from rdp."""
  return tfp.get_privacy_spent(
          ORDERS, rdp, target_delta=target_delta)[0]

# Application to the training of a production Gboard language model


We can now plug the parameters used in the production training run into this privacy accounting code. The actual DP-training procedure via DP-FTRL aggregates client updates via the following code path: 

```python
NOISE_MULTIPLIER = 7.0
COHORT_SIZE = 6500
tff.aggregators.DifferentiallyPrivateFactory.tree_aggregation(
        noise_multiplier=NOISE_MULTIPLIER,
        clients_per_round=COHORT_SIZE,       
        use_efficient=True,
        ...)
```

Under the hood, this uses [`TreeResidualSumQuery.tree_aggregation`](https://github.com/tensorflow/federated/blob/139a123d6918631fc6604df67a0a2dc58c971d40/tensorflow_federated/python/aggregators/differential_privacy.py#L222), which estimates per-round model updates. This is achieved via post-processing of the tree aggregation method: we estimate the sum of updates on round $t$ by taking the difference between the private estimates of the cumulative sums after rounds $t$ and $t-1$. Finally, for optimization we want an average update, which is provided (again via post-processing) by a [`tfp.NormalizedQuery`](https://www.tensorflow.org/responsible_ai/privacy/api_docs/python/tf_privacy/NormalizedQuery) parameterized by `COHORT_SIZE`. The privacy calculations are independent of the constant `COHORT_SIZE`, but utility depends on (roughly) `COHORT_SIZE` clients contributing to each round.
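The post-processing described above can be illustrated with a toy, standard-library-only sketch. It is not the TensorFlow Privacy or TensorFlow Federated implementation: scalars stand in for model-update vectors, the function names are ours, and it uses the plain tree estimator (summing the noisy nodes in the binary decomposition of each prefix) rather than the variance-reduced estimator selected by `use_efficient=True`.

```python
import itertools
import random


def noisy_prefix_sums(values, noise_std, seed=0):
    """Estimate cumulative sums by adding up noisy tree nodes."""
    rng = random.Random(seed)
    exact = [0.0] + list(itertools.accumulate(values))
    noisy_node = {}  # (lo, hi) -> noisy node sum; noise is drawn once per node

    def node(lo, hi):
        if (lo, hi) not in noisy_node:
            noisy_node[(lo, hi)] = (
                exact[hi] - exact[lo] + rng.gauss(0.0, noise_std))
        return noisy_node[(lo, hi)]

    prefix = []
    for t in range(1, len(values) + 1):
        total, lo = 0.0, 0
        # Split [0, t) into complete subtrees, largest first (binary digits of t).
        for bit in reversed(range(t.bit_length())):
            size = 1 << bit
            if t & size:
                total += node(lo, lo + size)
                lo += size
        prefix.append(total)
    return prefix


def per_round_averages(values, noise_std, cohort_size, seed=0):
    """Residuals of successive noisy prefix sums, normalized by the cohort size."""
    prefix = [0.0] + noisy_prefix_sums(values, noise_std, seed)
    return [(b - a) / cohort_size for a, b in zip(prefix, prefix[1:])]


# With zero noise the per-round sums are recovered exactly.
print(per_round_averages([2.0, 4.0, 6.0, 8.0, 10.0], noise_std=0.0, cohort_size=2))
# [1.0, 2.0, 3.0, 4.0, 5.0]
```

Both the differencing and the division by `cohort_size` operate only on already-noised quantities, which is why neither step affects the privacy accounting.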

Finally, the parameters `min_separation`, `max_participation`, and `total_steps` are computed after training completes. As each device is configured to participate in training at most once per 24 hours, `min_separation` is computed by taking the smallest number of training rounds completed in any 24-hour period, and `max_participation` is the maximum number of times any device can participate in `total_steps` training rounds while satisfying the participation constraint. 


In [None]:
noise_multiplier=7.0
min_separation=313
max_participation=6
total_steps=2000
target_delta=1e-10

start_time = time.time()

rdp, sensitivity_sq = compute_rdp_single_tree(
    noise_multiplier, total_steps, 
    max_participation, min_separation)
eps = eps_from_rdp(rdp, target_delta)
zcdp = compute_zcdp(noise_multiplier, sensitivity_sq)
print(f'Accounting time {time.time()-start_time:.2f} secs')
print(f'zCDP = {zcdp:.2f}')
print(f'(epsilon, delta) = ({eps:.2f}, {target_delta:.2g})') 

Accounting time 347.85 secs
zCDP = 0.81
(epsilon, delta) = (8.90, 1e-10)
