##### Copyright 2024 Google Inc.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DrJAX - Differentiable MapReduce Primitives in JAX

In this colab, we will learn what DrJAX is, how to use it, and why it was designed. We'll cover a few JAX-related topics along the way, but some familiarity with JAX will be helpful.

# Imports

In [None]:
!pip install --upgrade drjax

In [None]:
import drjax
import jax
import jax.numpy as jnp

# What is DrJAX?

DrJAX has a few goals.

1. Create a JAX authoring surface for MapReduce-style algorithms built on JAX primitives.
2. Enable the expression of differentiable MapReduce computations (in the style of [Federated Automatic Differentiation](https://arxiv.org/abs/2301.07806)). We refer to this as MapReduce AD.
3. Ensure that DrJAX algorithms are as scalable and efficient as possible, especially when sharding across accelerators.

We will discuss only (1) and (2) in this colab.

## What is MapReduce AD?

Suppose we have a function that computes a scalar loss value `z` from an input `x` using MapReduce-style operations. In many settings, we want to compute the **derivative** of this function, `dz/dx`. This allows us to do things like gradient descent.

For non-MapReduce computations, frameworks like TF, PyTorch, and JAX let us simply call something like `grad(foo)`. For example, in JAX we can do the following:


In [None]:
def square_and_dot(x, y):
  return jnp.dot(jnp.square(x), y)

In [None]:
x = jnp.array([1.0, -3.0])
y = jnp.array([2.0, 2.0])
square_and_dot(x, y)

Array(20., dtype=float32)

To get the derivative of the output with respect to `x`, we can just use `jax.grad`:

In [None]:
jax.grad(square_and_dot)(x, y)

Array([  4., -12.], dtype=float32)

With a single call to `jax.grad`, we can compute the derivative. This uses what is known as **automatic differentiation** (AD).
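As a sanity check, this gradient can be worked out by hand. Since `square_and_dot(x, y)` equals the sum over `i` of `x_i**2 * y_i`, its partial derivative with respect to `x_i` is `2 * x_i * y_i`. The sketch below compares `jax.grad` against that closed form. The helper `manual_grad` is our own illustrative name.

```python
import jax
import jax.numpy as jnp

def square_and_dot(x, y):
  return jnp.dot(jnp.square(x), y)

def manual_grad(x, y):
  # d/dx_i of sum_j x_j**2 * y_j is 2 * x_i * y_i.
  return 2.0 * x * y

x = jnp.array([1.0, -3.0])
y = jnp.array([2.0, 2.0])
auto = jax.grad(square_and_dot)(x, y)
manual = manual_grad(x, y)
print(auto, manual)  # both equal [4., -12.]
```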

We would like to be able to use this for **MapReduce-style computations**. The [federated AD paper](https://arxiv.org/abs/2301.07806) gives a theoretical framework for doing this, but it does not provide an implementation.

This is where DrJAX comes in. DrJAX defines MapReduce primitives (e.g. `broadcast`, `map_fn`, `reduce_sum`) as JAX primitives. This allows JAX to differentiate through them automatically, enabling MapReduce AD!

# DrJAX and MapReduce Primitives

Let's take a look at how we can define MapReduce computations in DrJAX. Conceptually, DrJAX operates on two different kinds of values: non-partitioned values (these are just standard values) and **partitioned** values. The latter are values that are partitioned across workers in a MapReduce computation. We will also sometimes refer to these as unplaced and placed values. Intuitively, the "placement" is an extra data dimension over which we MapReduce.
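To make the placed/unplaced distinction concrete, here is a plain-JAX sketch of what the three primitives used in this colab compute. It assumes a partitioned value is an ordinary array whose leading axis indexes the groups. This illustration is consistent with the outputs shown below, but it is not DrJAX's implementation, which defines these operations as JAX primitives and handles sharding. The names `broadcast_sketch`, `map_fn_sketch`, `reduce_sum_sketch`, and `NUM_GROUPS` are hypothetical.

```python
import jax
import jax.numpy as jnp

NUM_GROUPS = 3  # hypothetical number of data groups

def broadcast_sketch(x):
  # Unplaced -> placed: copy x along a new leading group axis.
  return jnp.broadcast_to(x, (NUM_GROUPS,) + jnp.shape(x))

def map_fn_sketch(fn, placed):
  # Placed -> placed: apply fn independently to each group's slice.
  return jax.vmap(fn)(placed)

def reduce_sum_sketch(placed):
  # Placed -> unplaced: sum over the group axis.
  return jnp.sum(placed, axis=0)

v = jnp.array([-1.0, 1.0])
placed = broadcast_sketch(v)
print(placed.shape)  # (3, 2)
total = reduce_sum_sketch(map_fn_sketch(lambda t: 2.0 * t, placed))
print(total)  # value: [-6., 6.]
```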

In [None]:
nonpartitioned_value = jnp.array([-1.0, 1.0])
nonpartitioned_value.shape

(2,)

For partitioned values, we add an extra leading axis to our tensors. This axis indexes the partitions (groups) of data over which MapReduce will operate.

In [None]:
partitioned_value = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
partitioned_value.shape

(3, 2)

In the above code, we have an example with 3 groups of data, each of which holds a 1-d tensor of shape `(2,)`. We also have a non-partitioned 1-d tensor of shape `(2,)`. Let's define a map to transform such vectors.

In [None]:
def add_constant(x):
  return x + jnp.array([1.0, 2.0])

print(add_constant(nonpartitioned_value))
print(add_constant(partitioned_value))

[0. 3.]
[[2. 4.]
 [4. 6.]
 [6. 8.]]


We see something interesting here: the same function works on both values. This is because `jnp` follows NumPy broadcasting rules, so the `(2,)` constant is added to each row of the `(3, 2)` partitioned value.
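To see what is happening, compare the result against an explicit per-row map with `jax.vmap`. For an elementwise function like `add_constant`, broadcasting and an explicit map agree. For a function that reduces over its whole input, they do not, and that difference is why an explicit map over groups matters. This sketch is plain JAX and does not use DrJAX. The function name `total` is ours.

```python
import jax
import jax.numpy as jnp

def add_constant(x):
  return x + jnp.array([1.0, 2.0])

partitioned_value = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

# Implicit: NumPy-style broadcasting adds the (2,) constant to every row.
implicit = add_constant(partitioned_value)
# Explicit: vmap calls add_constant once per row (per data group).
explicit = jax.vmap(add_constant)(partitioned_value)
print(jnp.array_equal(implicit, explicit))  # True

def total(x):
  return jnp.sum(x)

# For a non-elementwise function the two differ.
print(total(partitioned_value))            # 21.0, summed over everything
print(jax.vmap(total)(partitioned_value))  # per-group sums 3, 7, 11
```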

Now let's map a function over the partitioned values via `drjax.map_fn`. We have to tell DrJAX the name of our MapReduce axis (in this case we call it `data_groups`), and how many groups of data we are mapping over (in this case, we'll use 3).

In [None]:
@drjax.program(placements={'data_groups': 3})
def partitioned_add_constant(x):
  return drjax.map_fn(add_constant, x)

partitioned_add_constant(partitioned_value)

Array([[2., 4.],
       [4., 6.],
       [6., 8.]], dtype=float32)

We get the same result! Let's try another primitive, `drjax.broadcast`.

In [None]:
@drjax.program(placements={'data_groups': 3})
def broadcast(x):
  return drjax.broadcast(x)

broadcast(nonpartitioned_value)

Array([[-1.,  1.],
       [-1.,  1.],
       [-1.,  1.]], dtype=float32)

This does what we expect: we send the same vector to all data groups. So far, no real surprises. The cool thing, though, is that we can differentiate through these primitives, just as we did above.

The main restriction is that we differentiate non-partitioned values with respect to non-partitioned values. Because we are using `jax.grad`, the output must also be a scalar. So let's do that.

In [None]:
@drjax.program(placements={'data_groups': 3})
def broadcast_and_sum(x):
  broadcast_x = drjax.broadcast(x)
  return drjax.reduce_sum(broadcast_x)

broadcast_and_sum(2.0)

Array(6., dtype=float32)

Now we differentiate using *reverse-mode AD*. We pass `argnums=0` to tell JAX which argument to differentiate with respect to. This is already the default, but we make it explicit here.

In [None]:
jax.grad(broadcast_and_sum, argnums=0)(2.0)

Array(3., dtype=float32)

Why is this the derivative? Well, let's think about what this function is. We get something like:

```
x -> [x, x, x] -> sum([x, x, x]) = 3x
```
Taking a derivative with respect to `x`, we should get 3!
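The same result falls out of how reverse-mode AD treats a broadcast. The vector-Jacobian product (VJP) of a broadcast sums the incoming cotangents over the groups, so `sum([x, x, x])` sends a cotangent of 1 from each copy back to `x`, giving 3. The sketch below checks this with `jax.vjp`. It uses `jnp.broadcast_to` as a stand-in for `drjax.broadcast`, which is our assumption for illustration and not DrJAX's code. The names `broadcast3` and `broadcast_then_sum` are hypothetical.

```python
import jax
import jax.numpy as jnp

def broadcast3(x):
  # Stand-in for a broadcast to 3 groups: x -> [x, x, x].
  return jnp.broadcast_to(x, (3,))

def broadcast_then_sum(x):
  return jnp.sum(broadcast3(x))

# VJP of the broadcast at x = 2.0 with distinct per-group cotangents.
y, vjp_fn = jax.vjp(broadcast3, jnp.array(2.0))
(ct_x,) = vjp_fn(jnp.array([1.0, 10.0, 100.0]))
print(y)     # [2. 2. 2.]
print(ct_x)  # 111.0: the cotangents are summed back onto x

# A sum sends cotangent [1, 1, 1] to the broadcast, hence d/dx = 3.
print(jax.grad(broadcast_then_sum)(2.0))  # 3.0
```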

# Linear Regression and MapReduce AD

Let's see how this all works in a more interesting example. We're going to do something akin to linear regression.

Let's assume each data group has its own 2-d vector `y`. Given a 2-d linear regression model `x`, we'll set up our objective function as follows:

In [None]:
def compute_loss(x, y):
  return 0.5 * jnp.square(jnp.dot(x, y) - 1.0)

Essentially, this is doing linear regression where (1) each data group has a single example and (2) all groups have label 1.0 for that example. The details are not important. The point is that `x` and `y` go in, and a scalar loss comes out.

Let's try it out.

In [None]:
compute_loss(jnp.array([2.0, 3.0]), jnp.array([5.0, 6.0]))

Array(364.5, dtype=float32)
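Writing $\ell(x, y)$ for `compute_loss(x, y)`, we can confirm this value by hand:

$$
x \cdot y = 2 \cdot 5 + 3 \cdot 6 = 28, \qquad \ell(x, y) = \dfrac{1}{2}(28 - 1)^2 = \dfrac{729}{2} = 364.5.
$$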

Great! Now, we can evaluate across data groups with this loss function. To do that we will:

1. Broadcast `x` across the MapReduce axis (i.e. the `data_groups` axis).
2. Do `compute_loss(x, y)` for each group.
3. Average the results.

We can do that in DrJAX as follows.

In [None]:
@drjax.program(placements={'data_groups': 3})
def partitioned_eval(model, partitioned_data):
  broadcast_vector = drjax.broadcast(model)
  partitioned_losses = drjax.map_fn(compute_loss, (broadcast_vector, partitioned_data))
  return drjax.reduce_mean(partitioned_losses)

model = jnp.array([2.0, -1.0])
partitioned_data = jnp.array([[1.0, 2.0], [3.0, -4.0], [-7.0, 6.0]])
partitioned_eval(model, partitioned_data)

Array(87.16667, dtype=float32)
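For comparison, the same three steps can be written in plain JAX on a single device. Here `jnp.broadcast_to` stands in for the broadcast, `jax.vmap` for the map, and `jnp.mean` over the group axis for the reduction. This sketch is only meant to show what the DrJAX program computes. It drops the placement annotations that DrJAX uses for sharding. The name `plain_jax_eval` is ours.

```python
import jax
import jax.numpy as jnp

def compute_loss(x, y):
  return 0.5 * jnp.square(jnp.dot(x, y) - 1.0)

def plain_jax_eval(model, partitioned_data):
  num_groups = partitioned_data.shape[0]
  # 1. Broadcast: one copy of the model per data group.
  broadcast_model = jnp.broadcast_to(model, (num_groups,) + model.shape)
  # 2. Map: per-group loss, vectorized over the leading group axis.
  losses = jax.vmap(compute_loss)(broadcast_model, partitioned_data)
  # 3. Reduce: average over groups.
  return jnp.mean(losses)

model = jnp.array([2.0, -1.0])
partitioned_data = jnp.array([[1.0, 2.0], [3.0, -4.0], [-7.0, 6.0]])
print(plain_jax_eval(model, partitioned_data))  # approximately 87.16667
```

The per-group losses here are 0.5, 40.5, and 220.5, whose mean is 87.1667. Passing `in_axes=(None, 0)` to `jax.vmap` would avoid materializing the broadcast copies. The explicit `broadcast_to` is kept only to mirror the three steps listed above.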

Just as above, we can use `jax.grad` to differentiate through this function.

In [None]:
grad_fn = jax.grad(partitioned_eval, argnums=0)
grad_fn(model, partitioned_data)

Array([ 57.666668, -54.666668], dtype=float32)
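These numbers can be checked by hand. Differentiating `compute_loss`, i.e. $\ell(x, y) = \frac{1}{2}(x \cdot y - 1)^2$, gives the per-group gradient

$$
\dfrac{d\ell}{dx}(x, y) = (x \cdot y - 1)\, y.
$$

With $x = (2, -1)$, the three groups contribute

$$
-1 \cdot (1, 2) = (-1, -2), \quad 9 \cdot (3, -4) = (27, -36), \quad -21 \cdot (-7, 6) = (147, -126),
$$

and the average of these three vectors is $(173, -164)/3 \approx (57.67, -54.67)$, matching the output above.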

Let's pause here to think about what this means.

We have some loss function $\ell(x, y)$. Our partitioned evaluation is computing, for some partitioned set of data $C$:

$$
\dfrac{1}{|C|}\sum_{i \in C} \ell(x, y_i).
$$

By taking a derivative we get:

$$
\dfrac{d}{dx}\left(\dfrac{1}{|C|}\sum_{i \in C} \ell(x, y_i) \right) = \dfrac{1}{|C|} \sum_{i \in C} \dfrac{d\ell}{dx}(x, y_i).
$$

In other words, we are computing **the average gradient across the partitioned data**. This is exactly the quantity distributed SGD computes at each step! The only missing ingredient is to use that gradient to update our model. We'll do that via gradient descent with learning rate 0.01.

In [None]:
@drjax.program(placements={'data_groups': 3})
def distributed_sgd_step(model, partitioned_data):
  grad = grad_fn(model, partitioned_data)
  return model - 0.01 * grad

print(partitioned_eval(model, partitioned_data))
updated_model = distributed_sgd_step(model, partitioned_data)
print(partitioned_eval(updated_model, partitioned_data))

87.16667
35.551258


As expected, after applying a step of distributed SGD, we get a model with lower loss!

Moreover, going from `partitioned_eval` to distributed SGD was essentially trivial: we just applied MapReduce AD!
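Below is a plain-JAX sketch of the same idea, run for a few more steps. It first checks that the gradient of the mean loss equals the mean of the per-group gradients, which is the identity derived above. It then applies several gradient-descent steps with the same learning rate of 0.01. The names `mean_loss`, `sgd_step`, and `LEARNING_RATE` are hypothetical, and this single-device version omits everything DrJAX adds for partitioned execution.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.01  # same step size as distributed_sgd_step above

def compute_loss(x, y):
  return 0.5 * jnp.square(jnp.dot(x, y) - 1.0)

def mean_loss(model, data):
  # in_axes=(None, 0): share the model across groups (the broadcast) and
  # map over the leading group axis of data; then average over groups.
  return jnp.mean(jax.vmap(compute_loss, in_axes=(None, 0))(model, data))

@jax.jit
def sgd_step(model, data):
  return model - LEARNING_RATE * jax.grad(mean_loss)(model, data)

model = jnp.array([2.0, -1.0])
data = jnp.array([[1.0, 2.0], [3.0, -4.0], [-7.0, 6.0]])

# Gradient of the mean loss == mean of the per-group gradients.
per_group_grads = jax.vmap(jax.grad(compute_loss), in_axes=(None, 0))(model, data)
assert jnp.allclose(jax.grad(mean_loss)(model, data), per_group_grads.mean(axis=0))

losses = []
for step in range(5):
  losses.append(float(mean_loss(model, data)))
  print(step, losses[-1])
  model = sgd_step(model, data)
```

Using `in_axes=None` for the model plays the role of the broadcast, and `jnp.mean` over the group axis plays the role of `drjax.reduce_mean`. Under `jax.grad`, JAX sums the per-group contributions back onto the shared model, just as in the `3x` example from the broadcast section.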

# Conclusion

Above, we showed how to use DrJAX to define MapReduce-style computations, and how to apply MapReduce AD to differentiate through them. We encourage you to try out your own MapReduce-style computations, especially at scale.