##### Copyright 2024 Google Inc.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# FAX - Federated primitives in JAX

In this colab, we will learn what FAX is, how to use it, and why it was designed. We'll cover a few JAX-related topics along the way, though some familiarity with JAX will help.

# Imports

In [None]:
!pip install --upgrade google-fax

In [None]:
import jax
import jax.numpy as jnp
import fax

# What is FAX?

FAX has two main goals:

1. Create a JAX authoring surface for FL research that uses TFF-like primitives.
2. Enable the use of [Federated Automatic Differentiation](https://arxiv.org/abs/2301.07806) (federated AD).

## What is federated AD?

Suppose I have (in the parlance of TFF) a **federated computation** with the following type signature:
```
foo: (x@SERVER, y@CLIENTS) -> float32@SERVER
```
This looks like a server's loss function: it takes in a server value (along with client data) and outputs a server-placed float. Writing `z = foo(x, y)`, in many settings we want the **derivative** of the output with respect to the server input, `dz/dx`. This lets us do things like gradient descent on `x`.

For non-federated computations, frameworks like TF, PyTorch, and JAX let us simply call `grad(foo)`. For example, in JAX we can do the following:


In [None]:
def square_and_dot(x, y):
  return jnp.dot(jnp.square(x), y)

In [None]:
x = jnp.array([1.0, -3.0])
y = jnp.array([2.0, 2.0])
square_and_dot(x, y)

Array(20., dtype=float32)

To get the derivative of the output with respect to `x`, we can just use `jax.grad`:

In [None]:
jax.grad(square_and_dot)(x, y)

Array([  4., -12.], dtype=float32)

With a single call to `jax.grad`, we get the derivative. This relies on what is known as **automatic differentiation** (AD).
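As a quick sanity check, we can compare against the gradient derived by hand. Since `square_and_dot(x, y)` is $\sum_i x_i^2 y_i$, its gradient with respect to `x` is $2 x_i y_i$ elementwise. A minimal sketch:

```python
import jax
import jax.numpy as jnp

def square_and_dot(x, y):
  return jnp.dot(jnp.square(x), y)

x = jnp.array([1.0, -3.0])
y = jnp.array([2.0, 2.0])

# Gradient computed by JAX's automatic differentiation.
auto_grad = jax.grad(square_and_dot)(x, y)
# Gradient derived by hand: d/dx_i (x_i^2 * y_i) = 2 * x_i * y_i.
manual_grad = 2.0 * x * y

print(auto_grad, manual_grad)  # both equal [4., -12.]
```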

We would like to be able to use this for **federated computation**. While the [federated AD paper](https://arxiv.org/abs/2301.07806) gives a theoretical framework for doing this, it does not provide an implementation.

This is where FAX comes in. FAX defines federated primitives (e.g. `federated_broadcast`, `federated_map`, etc.) as JAX primitives. This allows JAX to differentiate through them automatically, enabling federated AD!
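This colab does not show how FAX defines its primitives internally. As a rough illustration of the idea only, the sketch below uses `jax.custom_vjp` to give a hypothetical `broadcast_to_clients` function an explicit reverse-mode rule, encoding a key fact federated AD relies on: the adjoint of a broadcast is a sum over clients. `broadcast_to_clients`, `broadcast_then_sum` and `NUM_CLIENTS` are illustrative names, not FAX APIs, and JAX could differentiate `jnp.broadcast_to` on its own; the custom rule just makes the structure visible.

```python
import jax
import jax.numpy as jnp

NUM_CLIENTS = 3  # hypothetical fixed number of clients for this sketch

@jax.custom_vjp
def broadcast_to_clients(x):
  # Server value -> CLIENTS value: one copy per client along a new leading axis.
  return jnp.broadcast_to(x, (NUM_CLIENTS,) + jnp.shape(x))

def _broadcast_fwd(x):
  # No residuals are needed for the backward pass.
  return broadcast_to_clients(x), None

def _broadcast_bwd(_, client_cotangents):
  # The adjoint of a broadcast is a sum over the client axis.
  return (jnp.sum(client_cotangents, axis=0),)

broadcast_to_clients.defvjp(_broadcast_fwd, _broadcast_bwd)

def broadcast_then_sum(x):
  # Plain-JAX stand-in for "broadcast to clients, then sum back on the server".
  return jnp.sum(broadcast_to_clients(x), axis=0)

print(jax.grad(broadcast_then_sum)(jnp.float32(2.0)))  # gradient is 3
```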

# FAX and Federated Primitives

Let's take a look at how we can define federated computations in FAX. Recall that in TFF we have two placements, `SERVER` and `CLIENTS`. These are represented as a singleton and as a list, respectively.

FAX takes a similar approach for the server: a `SERVER`-placed value is just a single array.

In [None]:
server_value = jnp.array([-1.0, 1.0])
server_value.shape

(2,)

For `CLIENTS` placements, instead of using a list, we add an extra axis to our tensors. This axis represents our clients.

In [None]:
client_values = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
client_values.shape

(3, 2)
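If the per-client data starts out as separate arrays, one way to assemble such a `CLIENTS` value is to stack them along a new leading axis with `jnp.stack`; indexing that axis recovers a single client's value. The per-client arrays below are simply the rows of `client_values` from the cell above.

```python
import jax.numpy as jnp

# Per-client values, one array of shape (2,) per client.
per_client = [
    jnp.array([1.0, 2.0]),
    jnp.array([3.0, 4.0]),
    jnp.array([5.0, 6.0]),
]

# Stack along a new leading axis: axis 0 indexes clients.
client_values = jnp.stack(per_client, axis=0)
print(client_values.shape)  # (3, 2)

# Client 1's value is row 1 of the stacked array.
print(client_values[1])
```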

In the above code, we have an example with 3 clients, each of which has a 1-d tensor of shape `(2,)`. The server also has a 1-d tensor of shape `(2,)`. Let's define a map to transform such vectors.

In [None]:
def add_constant(x):
  return x + jnp.array([1.0, 2.0])

print(add_constant(server_value))
print(add_constant(client_values))

[0. 3.]
[[2. 4.]
 [4. 6.]
 [6. 8.]]


We see something interesting here: the same function works on both values. Because `add_constant` is elementwise, JAX's NumPy-style broadcasting applies it across the extra client axis of the CLIENTS value automatically.
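Broadcasting only gives per-client results for elementwise functions. A function that contracts over an axis, such as a squared norm built on `jnp.dot`, breaks when applied directly to the stacked array (`jnp.dot` on two `(3, 2)` arrays raises a shape error). `jax.vmap` makes the per-client semantics explicit. This sketch only illustrates what a per-client map has to do; it makes no claim about how `fax.federated_map_clients` is implemented, and `squared_norm` is a hypothetical example function.

```python
import jax
import jax.numpy as jnp

client_values = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

def squared_norm(v):
  # Written for a single vector of shape (2,).
  return jnp.dot(v, v)

# Map squared_norm over axis 0, i.e. once per client.
per_client_norms = jax.vmap(squared_norm)(client_values)
print(per_client_norms)  # one value per client: 5, 25, 61

# For an elementwise function, vmap and plain broadcasting agree.
def add_constant(x):
  return x + jnp.array([1.0, 2.0])

assert jnp.allclose(jax.vmap(add_constant)(client_values), add_constant(client_values))
```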

Now let's try to apply a `federated_map`.

In [None]:
@fax.fax_program(placements={'clients': 3})
def clients_add_constant(x):
  return fax.federated_map_clients(add_constant, x)

clients_add_constant(client_values)

Array([[2., 4.],
       [4., 6.],
       [6., 8.]], dtype=float32)

We see the same thing! Let's try another primitive, `federated_broadcast`.

In [None]:
@fax.fax_program(placements={'clients': 3})
def broadcast(x):
  return fax.federated_broadcast(x)

broadcast(server_value)

Array([[-1.,  1.],
       [-1.,  1.],
       [-1.,  1.]], dtype=float32)

This does what we expect: we send the same vector to all clients. So far, no real surprises. The cool thing, though, is that we can differentiate through these computations, just as we did above.

The main restriction is that we differentiate server-placed outputs with respect to server-placed inputs. So let's do that.

In [None]:
@fax.fax_program(placements={'clients': 3})
def broadcast_and_sum(x):
  client_x = fax.federated_broadcast(x)
  return fax.federated_sum(client_x)

broadcast_and_sum(2.0)

Array(6., dtype=float32)

Now we differentiate using *reverse-mode differentiation*, which is what `jax.grad` computes (more on that at the end). We pass `argnums=0` to make explicit which argument we differentiate with respect to; this is also the default.

In [None]:
jax.grad(broadcast_and_sum, argnums=0)(2.0)

Array(3., dtype=float32)

Why is this the derivative? Well, let's think about what this function is. We get something like:

```
x -> [x, x, x] -> sum([x, x, x]) = 3x
```
Taking a derivative with respect to `x`, we should get 3!
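For a function with a scalar input and a scalar output like this one, forward mode (`jax.jvp`) and reverse mode (`jax.grad`) must give the same derivative. Below is a plain-JAX stand-in for `broadcast_and_sum`, not using FAX, where a hypothetical `num_clients = 3` plays the role of `placements={'clients': 3}`, checked both ways.

```python
import jax
import jax.numpy as jnp

num_clients = 3  # hypothetical stand-in for placements={'clients': 3}

def broadcast_and_sum_plain(x):
  # Broadcast: copy x to every client along a new leading axis.
  client_x = jnp.broadcast_to(x, (num_clients,) + jnp.shape(x))
  # Sum over the client axis back to a single server value.
  return jnp.sum(client_x, axis=0)

x = jnp.float32(2.0)
# Reverse mode: jax.grad computes a vector-Jacobian product.
reverse = jax.grad(broadcast_and_sum_plain)(x)
# Forward mode: push a unit tangent through with jax.jvp.
_, forward = jax.jvp(broadcast_and_sum_plain, (x,), (jnp.float32(1.0),))
print(reverse, forward)  # both equal 3
```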

# Federated Linear Regression and federated AD

Let's see how this all works in a more interesting example. We're going to do something akin to linear regression.

Let's assume each client has its own 2-d vector `y`. Given a 2-d linear regression model `x`, we'll set up our objective function as follows:

In [None]:
def compute_loss(x, y):
  return 0.5*jnp.square(jnp.dot(x, y) - 1.0)

Essentially, this is linear regression where (1) each client has a single example and (2) every client's label for that example is 1.0. These details are not important; the point is that `x` and `y` go in, and a scalar loss comes out.
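Differentiating by hand, $\ell(x, y) = \tfrac{1}{2}(x \cdot y - 1)^2$ has gradient $\nabla_x \ell = (x \cdot y - 1)\, y$. A quick check of that formula against `jax.grad`, at the same inputs used in the next cell (`compute_loss_grad` is a hypothetical helper for this check):

```python
import jax
import jax.numpy as jnp

def compute_loss(x, y):
  return 0.5 * jnp.square(jnp.dot(x, y) - 1.0)

def compute_loss_grad(x, y):
  # Hand-derived gradient with respect to x: (x . y - 1) * y.
  return (jnp.dot(x, y) - 1.0) * y

x = jnp.array([2.0, 3.0])
y = jnp.array([5.0, 6.0])
print(jax.grad(compute_loss)(x, y))  # residual is 27, so 27 * [5, 6] = [135, 162]
print(compute_loss_grad(x, y))
```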

Let's try it out.

In [None]:
compute_loss(jnp.array([2.0, 3.0]), jnp.array([5.0, 6.0]))

Array(364.5, dtype=float32)

Great! Now we can build a federated evaluation, `federated_eval`, of this loss function. To do that we will:

1. Broadcast `x` to the clients.
2. Do `compute_loss(x, y)` for each client.
3. Average the results.
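For comparison, the three steps can be sketched in plain JAX with clients as the leading axis: `jnp.broadcast_to` for step 1, `jax.vmap` for step 2, and `jnp.mean` over the client axis for step 3. `plain_federated_eval` is a hypothetical name for this sketch, not part of FAX.

```python
import jax
import jax.numpy as jnp

def compute_loss(x, y):
  return 0.5 * jnp.square(jnp.dot(x, y) - 1.0)

def plain_federated_eval(server_vector, client_vectors):
  num_clients = client_vectors.shape[0]
  # 1. Broadcast the server vector: one copy per client.
  broadcast_vector = jnp.broadcast_to(server_vector, (num_clients,) + server_vector.shape)
  # 2. Compute each client's loss on its own data.
  client_losses = jax.vmap(compute_loss)(broadcast_vector, client_vectors)
  # 3. Average the per-client losses.
  return jnp.mean(client_losses)

server_vector = jnp.array([2.0, -1.0])
client_vectors = jnp.array([[1.0, 2.0], [3.0, -4.0], [-7.0, 6.0]])
print(plain_federated_eval(server_vector, client_vectors))  # mean of 0.5, 40.5, 220.5, about 87.1667
```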

We can do that in FAX as follows.

In [None]:
@fax.fax_program(placements={'clients': 3})
def federated_eval(server_vector, client_vectors):
  broadcast_vector = fax.federated_broadcast(server_vector)
  client_losses = fax.federated_map_clients(compute_loss, (broadcast_vector, client_vectors))
  return fax.federated_mean(client_losses)

server_vector = jnp.array([2.0, -1.0])
client_vectors = jnp.array([[1.0, 2.0], [3.0, -4.0], [-7.0, 6.0]])
federated_eval(server_vector, client_vectors)

Array(87.16667, dtype=float32)

Just as above, we can use `jax.grad` to differentiate through this function.

In [None]:
grad_fn = jax.grad(federated_eval, argnums=0)
grad_fn(server_vector, client_vectors)

Array([ 57.666668, -54.666668], dtype=float32)

Let's pause here to think about what this means.

We have some loss function $\ell(x, y)$. Our federated evaluation is computing, for some set of clients $C$:

$$
\dfrac{1}{|C|}\sum_{i \in C} \ell(x, y_i).
$$

Taking the derivative with respect to $x$, and using the fact that differentiation is linear, we get:

$$
\dfrac{d}{dx}\left(\dfrac{1}{|C|}\sum_{i \in C} \ell(x, y_i) \right) = \dfrac{1}{|C|} \sum_{i \in C} \dfrac{d\ell}{dx}(x, y_i).
$$
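The identity can be checked numerically on the example above, independently of FAX: compute the gradient of the averaged loss, and separately average the per-client gradients obtained with `jax.vmap(jax.grad(...))`. Both should match the `grad_fn` output shown earlier. `mean_loss` is a hypothetical helper for this check.

```python
import jax
import jax.numpy as jnp

def compute_loss(x, y):
  return 0.5 * jnp.square(jnp.dot(x, y) - 1.0)

def mean_loss(x, ys):
  # Loss averaged over clients, as a function of the server vector x.
  return jnp.mean(jax.vmap(compute_loss, in_axes=(None, 0))(x, ys))

server_vector = jnp.array([2.0, -1.0])
client_vectors = jnp.array([[1.0, 2.0], [3.0, -4.0], [-7.0, 6.0]])

# Left-hand side: gradient of the average loss.
grad_of_mean = jax.grad(mean_loss)(server_vector, client_vectors)
# Right-hand side: average of the per-client gradients.
per_client_grads = jax.vmap(jax.grad(compute_loss), in_axes=(None, 0))(server_vector, client_vectors)
mean_of_grads = jnp.mean(per_client_grads, axis=0)
print(grad_of_mean, mean_of_grads)  # both about [57.67, -54.67]
```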

In other words, we are just computing **the average gradient across clients**. This is exactly what FedSGD does! The only missing ingredient is to use that derivative to update our model. We'll do that via gradient descent with learning rate 0.01.

In [None]:
@fax.fax_program(placements={'clients': 3})
def fed_sgd_step(server_vector, client_vectors):
  server_grad = grad_fn(server_vector, client_vectors)
  return server_vector - 0.01 * server_grad

print(federated_eval(server_vector, client_vectors))
updated_server_vector = fed_sgd_step(server_vector, client_vectors)
print(federated_eval(updated_server_vector, client_vectors))

87.16667
35.551258


As expected, after one step of FedSGD, the updated server vector has a lower loss!

Moreover, going from `federated_eval -> FedSGD` was essentially trivial - we just applied federated AD!
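Repeating the step gives multi-round FedSGD. The loop below is a plain-JAX sketch, not a FAX API: it reuses the same loss, clients and learning rate of 0.01, and records the federated loss after each round. Because this loss is quadratic and the learning rate is small relative to its largest curvature, the loss should decrease every round. `fed_sgd_round`, `mean_loss` and `LEARNING_RATE` are hypothetical names for this sketch.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.01

def compute_loss(x, y):
  return 0.5 * jnp.square(jnp.dot(x, y) - 1.0)

def mean_loss(x, ys):
  # Federated evaluation: average of per-client losses.
  return jnp.mean(jax.vmap(compute_loss, in_axes=(None, 0))(x, ys))

def fed_sgd_round(x, ys):
  # One FedSGD round: average client gradient, then a server gradient step.
  return x - LEARNING_RATE * jax.grad(mean_loss)(x, ys)

server_vector = jnp.array([2.0, -1.0])
client_vectors = jnp.array([[1.0, 2.0], [3.0, -4.0], [-7.0, 6.0]])

losses = [float(mean_loss(server_vector, client_vectors))]
for _ in range(10):
  server_vector = fed_sgd_round(server_vector, client_vectors)
  losses.append(float(mean_loss(server_vector, client_vectors)))
print(losses)  # starts at about 87.17, then about 35.55, ...
```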