# Tensor parallelism in JAX with `pmap`

## Setup

Public Colab TPU instances (https://colab.research.google.com) ship an outdated JAX version, because newer JAX releases dropped support for Colab TPUs.

To get access to multiple devices, we recommend running this notebook on [Kaggle TPU VMs](https://www.kaggle.com/docs/tpu), which give you 20 hours of TPU access per week.

In [None]:
#@title Imports
from typing import Tuple
import dataclasses
import functools

import jax
print(jax.__version__)

import jax.numpy as jnp
import numpy as np

In [None]:
#@title Notebook setting
USE_MOCK_DEVICES = True #@param {type:"boolean"}
import os

if USE_MOCK_DEVICES:
    print('Using 8 mock devices.')
    # Forces XLA to use `n` CPU threads as host devices.
    os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

    if len(jax.local_devices()) < 8:
        raise Exception("Notebook requires 8 devices to run")

jax.devices()

This is a supporting notebook for Introduction to Tensor Parallelism in JAX.

It focuses on implementing a simplified version of the sharded two-layer multi-layer perceptron (MLP) from the [Megatron-LM](https://arxiv.org/pdf/1909.08053.pdf) paper. Ideally, readers should familiarise themselves with [Parallel Evaluation in JAX](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/06-parallelism.ipynb) from the [JAX 101 series](https://jax.readthedocs.io/en/latest/jax-101/index.html) before working through the examples.

#  MLP



## Vanilla transformer $FFN$.
We start by detailing the MLP block.

In the vanilla transformer we have a fully connected feed-forward network, which is applied to each position separately and identically. It consists
of two linear transformations with a $ReLU$ activation in between:

$$
FFN(x) = \max(0, x W_1 \, + \, b_1) W_2 + b_2
$$


$FFN(x) : \mathbb{R}^{d_{\text{model}}} \rightarrow \mathbb{R}^{d_{\text{model}}}$




For small models, the dimensionality of the input and output is $d_{\text{model}} = 512$, and the inner layer has dimensionality $d_{ff} = 2048$. The actual values are not particularly important. However, it is worth pointing out that $d_{ff} = c \cdot d_{\text{model}}$ for some constant $c$ (here $c = 4$).

Parameter dimensions:
* $W_1 \in \mathbb{R}^{d_{\text{model}} \times d_{ff}}$, $b_1 \in \mathbb{R}^{d_{ff}}$
* $W_2 \in \mathbb{R}^{d_{ff} \times d_{\text{model}}}$, $b_2 \in \mathbb{R}^{d_{\text{model}}}$

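As a quick shape check, the formula and parameter dimensions above can be sketched in a few lines of JAX. This is only an illustrative sketch: the names `ffn`, `x_demo`, `ffn_out` and the scaled-normal initialisation are assumptions made for the example and are not used elsewhere in this notebook.

```python
import jax
import jax.numpy as jnp

d_model, d_ff = 512, 2048  # Sizes quoted above, d_ff = 4 * d_model.


def ffn(x, w_1, b_1, w_2, b_2):
    """FFN(x) = max(0, x W_1 + b_1) W_2 + b_2, applied to every row of x."""
    hidden = jnp.maximum(x @ w_1 + b_1, 0.0)  # [positions, d_ff]
    return hidden @ w_2 + b_2                 # [positions, d_model]


# Hypothetical initialisation, chosen only to get reasonable magnitudes.
k_1, k_2, k_x = jax.random.split(jax.random.PRNGKey(0), 3)
w_1 = jax.random.normal(k_1, (d_model, d_ff)) / jnp.sqrt(d_model)
b_1 = jnp.zeros((d_ff,))
w_2 = jax.random.normal(k_2, (d_ff, d_model)) / jnp.sqrt(d_ff)
b_2 = jnp.zeros((d_model,))

x_demo = jax.random.normal(k_x, (3, d_model))  # Three positions.
ffn_out = ffn(x_demo, w_1, b_1, w_2, b_2)
print(ffn_out.shape)  # (3, 512)
```

Because the same weights act on every row, evaluating `ffn` on a single position gives the same result as the matching row of the batched output, which is what "applied to each position separately and identically" means in practice.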

In [None]:
# To be explicit, we define the parameters here.
@dataclasses.dataclass
class Params:
    W_1:  jnp.ndarray  # Shape: "* d_model d_ff"
    W_2:  jnp.ndarray  # Shape: "* d_ff d_model"


# A vanilla MLP implementation.
def mlp(params: Params, x: jax.Array) -> jax.Array:
    """Vanilla MLP without biases or dropout."""
    a, b = params  # Unpacked as a (W_1, W_2) pair; the cells below pass plain tuples.
    y = jnp.maximum(jnp.matmul(x, a), 0.0)
    z = jnp.matmul(y, b)
    return z

##  Sharded MLP

We need to compute $Y = GeLU(XA)$, but both $X$ (data matrix) and $A$ are quite large, and we have a few layers.

> One option is to split A along its columns $A = [A_1, A_2 ]$.
>
> This partitioning allows the GeLU nonlinearity to be independently applied to the output of each partitioned GEMM:
> $$ [Y_1, \; Y_2] = [GeLU(XA_1), \; GeLU(XA_2)]  $$

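The reason the column split is preferred can be checked numerically. The sketch below compares it with the alternative from the Megatron-LM paper, where $A$ is split along its rows and $X$ along its columns; in that case each shard only holds a partial sum, and a nonlinearity applied per shard no longer matches the unsharded result. All names ending in `_demo`, `_c1`, `_r1` and so on are hypothetical and exist only for this illustration.

```python
import jax
import jax.numpy as jnp

key_x, key_a = jax.random.split(jax.random.PRNGKey(0))
X_demo = jax.random.normal(key_x, (2, 4))
A_demo = jax.random.normal(key_a, (4, 6))

# Column split of A: every shard produces a disjoint block of output columns,
# so GeLU can be applied per shard and the blocks are simply concatenated.
A_c1, A_c2 = jnp.split(A_demo, 2, axis=1)
Y_cols = jnp.concatenate(
    [jax.nn.gelu(X_demo @ A_c1), jax.nn.gelu(X_demo @ A_c2)], axis=1)

# Row split of A (and column split of X): every shard produces a partial sum
# of the full output, so GeLU applied per shard is not the same computation.
X_r1, X_r2 = jnp.split(X_demo, 2, axis=1)
A_r1, A_r2 = jnp.split(A_demo, 2, axis=0)
Y_rows = jax.nn.gelu(X_r1 @ A_r1) + jax.nn.gelu(X_r2 @ A_r2)

Y_full = jax.nn.gelu(X_demo @ A_demo)
print(jnp.allclose(Y_cols, Y_full, atol=1e-5))  # Column split matches.
print(jnp.allclose(Y_rows, Y_full, atol=1e-5))  # Row split does not, in general.
```

With the row split, a synchronisation point would be needed before the nonlinearity; the column split avoids it.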


<img alt="Sharded MLP" src="https://raw.githubusercontent.com/eemlcommunity/PracticalSessions2023/main/tensor_parallelism/mlp.png" width="600px"/>

In [None]:
# To set the ground, let's show what a simple 2-layer MLP
# looks like in pure JAX.

d_model = 4
d_ff = 12

W_1 = jnp.arange(d_model * d_ff).reshape(d_model, d_ff)
b_1 = jnp.arange(d_ff).reshape((1, d_ff))

W_2 = jnp.arange(d_model * d_ff).reshape(d_ff, d_model)
b_2 = jnp.arange(d_model).reshape((1, d_model))

X = jnp.arange(d_model).reshape((1, -1))

# Forward pass of a 2-layer vanilla transformer MLP
print(jax.nn.relu(X @ W_1 + b_1) @ W_2 + b_2)

print("\nTensor shapes:")
for n, t in (("W_1", W_1), ("b_1", b_1), ("W_2", W_2), ("b_2", b_2), ("X", X)):
    print(n.ljust(2), t.shape, t.dtype)

## Implementation sketches

For now we leave GeLU and dropout out of consideration and use ReLU as the nonlinearity.

In [None]:
## Version without biases
A_1, A_2 = jnp.split(W_1, 2, axis=1)  # Split on columns.
B_1, B_2 = jnp.split(W_2, 2, axis=0)  # Split on rows.

Y_1, Y_2 = [jax.nn.relu(X @ A_1),
            jax.nn.relu(X @ A_2)]

Z_1, Z_2 = [Y_1 @ B_1,  Y_2 @ B_2]
Z = Z_1 + Z_2                         # "All-reduce "g" step".

# Sanity check
np.testing.assert_array_equal(
    jax.nn.relu(X @ W_1) @ W_2,
    Z)

In [None]:
## Version with biases
A_1, A_2 = jnp.split(W_1, 2, axis=1)  # Split on columns.
a_1, a_2 = jnp.split(b_1, 2, axis=1)  # Split on columns.

B_1, B_2 = jnp.split(W_2, 2, axis=0)  # Split on rows.

Y_1, Y_2 = [jax.nn.relu(X @ A_1 + a_1),  # This happens on device 1
            jax.nn.relu(X @ A_2 + a_2)]  # This happens on device 2

# For the second bias we need to be a bit smarter: each device adds only its
# own slice of b_2, i.e. [b_2_0, 0, ..., 0], [0, b_2_1, ..., 0], so that the
# slices don't interfere with each other during the all-reduce (sum) collective.

## Find indices for the update
idx_1, idx_2 = jnp.split(np.arange(d_model), 2)
## Split the bias on the "model" dimension.
B_b1, B_b2 = jnp.split(b_2, 2, axis=1)

# Do partial update.
Z_1, Z_2 = [(Y_1 @ B_1).at[:,idx_1].add(B_b1),    # This happens on device 1
            (Y_2 @ B_2).at[:,idx_2].add(B_b2)]    # This happens on device 2


Z = Z_1 + Z_2                                     # "All-reduce "g" step".

# Sanity check
np.testing.assert_array_equal(
    jax.nn.relu(X @ W_1 + b_1) @ W_2 + b_2 ,
    Z)

## Matmul from slides

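The cells in this section rely on the block form of a matrix product. Assuming $A$ is split into $n$ column blocks and $B$ into $n$ matching row blocks, it reads:

$$
C = AB = \begin{bmatrix} A_1 & A_2 & \cdots & A_n \end{bmatrix} \begin{bmatrix} B_1 \\ B_2 \\ \vdots \\ B_n \end{bmatrix} = \sum_{i=1}^{n} A_i B_i
$$

Each device can therefore compute one partial product $A_i B_i$ locally, and a single sum across devices (`psum`) recovers the full $C$ on every device.
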
In [None]:
#@title Sharding utils
def shard_on_columns(x: jnp.ndarray, N: int) -> jax.Array:
    """Splits matrix on column axis by number of shards."""
    assert len(x.shape) == 2
    x = jnp.expand_dims(x, axis=0)  # [1, H_1, H_2]
    x = jnp.split(x, N, axis=2)     # [1, H_1, H_2/N] N times
    x = np.concatenate(x, axis=0)   # [N, H_1, H_2/N]
    return x


def shard_on_rows(x: jnp.ndarray, N: int) -> jax.Array:
    """Splits matrix on row axis by number of shards."""
    assert len(x.shape) == 2
    x = jnp.expand_dims(x, axis=0)  # [1, H_1,   H_2]
    x = jnp.split(x, N, axis=1)     # [1, H_1/N, H_2] N times
    x = np.concatenate(x, axis=0)   # [N, H_1/N, H_2]
    return x


def _unshard(x: jnp.ndarray, N: int, axis: int) -> jax.Array:
    return jnp.squeeze(jnp.concatenate(jnp.split(x, N, axis=0), axis=axis))


def unshard_on_columns(x: jnp.ndarray,  N: int)-> jnp.ndarray:
    # `x` should have shape [N, R, C/N].
    assert len(x.shape) == 3
    n, r, c_by_n = x.shape
    x_unsharded = _unshard(x, N, axis=2)
    assert x_unsharded.shape == (r, c_by_n * n)
    return x_unsharded


def unshard_on_rows(x: jnp.ndarray, N: int) -> jnp.ndarray:
    assert len(x.shape) == 3
    n, r_by_n, c = x.shape
    x_unsharded = _unshard(x, N, axis=1)
    assert x_unsharded.shape == (r_by_n * n, c)
    return x_unsharded


X = np.random.normal(size=(16, 32))
np.testing.assert_allclose(X, unshard_on_columns(shard_on_columns(X, 8), 8))
np.testing.assert_allclose(X, unshard_on_rows(shard_on_rows(X, 8), 8))

In [None]:
dim_in, dim_out = (1024, 1024*16)

A = jax.random.uniform(jax.random.PRNGKey(0), (dim_in, dim_out))
B = jax.random.uniform(jax.random.PRNGKey(1), (dim_out, dim_in))

# Manually reshape the data.
A_pmap = shard_on_columns(A, 8)
B_pmap = shard_on_rows(B, 8)

def dot_psum(x, y):
    return jax.lax.psum(x @ y, 'pmap_axis')

In [None]:
dot_pmap = jax.pmap(dot_psum, axis_name="pmap_axis")
C_pmap = dot_pmap(A_pmap, B_pmap)
C_host = A @ B

np.testing.assert_allclose(
    C_pmap[0],  # <- C will be replicated on all N devices, pick any of them.
    C_host,
    rtol=1e-6   # I can't stress it enough, floats are not your friends.
)

In [None]:
%%timeit
for i in range(25):
    C_host = A @ B

In [None]:
@functools.partial(jax.pmap, axis_name="pmap_axis")
def loop(A, B):
    for i in range(25):
        C_pmap = dot_psum(A, B)
    return C_pmap

In [None]:
%%timeit
C_pmap = loop(A_pmap, B_pmap)

## `pmap` Implementation

In [None]:
from typing import Tuple


def init_params(nrows: int, ncols: int, dtype = np.float32):
    """A simple init function."""
    return lambda rng: jax.random.uniform(rng, shape=(nrows, ncols), dtype=dtype)


def make_sharded_mlp(hidden_dim: int,
                     pmap_axis='pmap_axis'):
    """Create Megatron-style sharded MLP, ref: https://arxiv.org/abs/1909.08053.

    Implementation notes:
    ----
    Each Megatron layer consists of two sub-layers:
    1) Y = a(X * A), where a is e.g. GeLU, or any other element-wise function.
    2) Z = d(Y * B), where d is dropout, or any other element-wise function.

    Z = Dropout(GeLU(X @ A) @ B)

    We shard A across its columns, and B across its rows:
    - X has shape   N x M,
    - A has shape   M x C_A and its shard A_i has shape ~ M x (C_A / n)
    - B has shape C_A x C_B and its shard B_j has shape ~ (C_A / n) x C_B

    For `n` devices, we have:

    Z = d(Y @ B) = d([Y_1, ..., Y_n] @ [B_1, ..., B_n]^T)
    = d([      Y_1 @ B_1 + ... +       Y_n @ B_n ])
    = d([ a(X*A_1) @ B_1 + ... +  a(X*A_n) @ B_n ])
             ^                             ^
             |                             |
         Z_1 on device 1               Z_n on device n
         \---------------- all reduce ------------/


    where Y_i ~ N x (C_A / n)
        Z_i ~ N x  C_B

    Note: Only after the all-reduce (psum) step, we are allowed to apply d.

    Docs:
    ----

    The returned `apply_fn` must be wrapped inside of a `jax.pmap` whose axis
    name equals `pmap_axis`. The arguments passed to `apply_fn` are assumed to
    be sharded.

    When creating parameters via the returned `init_fn`, the parameters will be
    sharded across devices according to the input shape (the first dimension is
    assumed to be sharded).

    Args:
    hidden_dim: The per-shard hidden dimension. The effective total hidden
    dimension is this multiplied by number of shards.
    pmap_axis: The sharding axis of the outer pmap.
    Returns:
    An init_fn and an apply_fn, as per the Stax API conventions.
    """

    def init_fn(rng: jax.random.PRNGKey, input_shape: Tuple[int, int, int]):
        num_shards, input_dim, output_dim = input_shape

        key_a, key_b = jax.random.split(rng, 2)

        keys_a = jax.random.split(key_a, num=num_shards)
        param_a = jax.pmap(init_params(input_dim, hidden_dim), 'pmap_axis')(keys_a)

        keys_b = jax.random.split(key_b, num=num_shards)
        param_b = jax.pmap(init_params(hidden_dim, output_dim), 'pmap_axis')(keys_b)

        return input_shape, (param_a, param_b)

    def apply_fn(params: Params, x: jax.Array) -> jax.Array:
        # Each layer's parameters are sharded across N devices.
        a, b = params
        y = jnp.maximum(jnp.matmul(x, a), 0.0)
        z = jnp.matmul(y, b)
        ...  # All-reduce sum (i.e. sum and broadcast).
        return z

    return init_fn, apply_fn

In [None]:
def loss_fn(params: Params, inputs: jax.Array,  targets: jax.Array, logits_fn):
    logits = logits_fn(params, inputs)
    logits = jax.nn.log_softmax(logits, axis=-1)
    # Per batch example.
    loss = -jnp.sum(logits * targets, axis=-1)

    # Mean loss per example.
    return jnp.mean(loss)

In [None]:
# On TPU donut (2x2) we have 8 cores.
num_shards = jax.device_count()

batch_size = 2
x_dim = 10
hidden_dim = 240  # hidden size of the MLP
out_dim = 10      # output size of the MLP

# Dummy input data with batch size = 2
x = jnp.arange(x_dim * batch_size).reshape(batch_size, x_dim)
y = jax.nn.one_hot(jnp.arange(1, batch_size + 1), out_dim)

# We broadcast* the data.
x_b = jax.lax.broadcast(x, (num_shards,))
y_b = jax.lax.broadcast(y, (num_shards,))

# *) Or, to put it differently, we replicate it on every device.

In [None]:
key = jax.random.PRNGKey(0)
# `make_sharded_mlp` expects the per-shard hidden dimension as its first argument.
init_fn, apply_fn = make_sharded_mlp(hidden_dim // num_shards)

In [None]:
# Wrap sharded apply_fn with pmap
mlp_sharded =  (
    lambda params, inputs: jax.pmap(apply_fn, axis_name='pmap_axis')(params, inputs)
)

In [None]:
# Sanity check if the sharded version is close to the unsharded.
hidden_dim = num_shards*2048

# Create the weights on the host (NumPy arrays).
W_1_single = np.random.normal(size=(x_dim, hidden_dim))
W_2_single = np.random.normal(size=(hidden_dim, x_dim))
print(W_1_single.shape, W_2_single.shape)

# 'Shard' the weights manually for pmap.
W_1_sharded = shard_on_columns(W_1_single, num_shards)
W_2_sharded = shard_on_rows(W_2_single, num_shards)
print(W_1_sharded.shape, W_2_sharded.shape)

np.testing.assert_allclose(
    mlp_sharded((W_1_sharded, W_2_sharded), x_b)[0],
    mlp((W_1_single, W_2_single), x_b[0]),
    atol=1e-5, rtol=1e-3)

In [None]:
print(loss_fn((W_1_single, W_2_single), x_b[0], y_b[0], mlp))
print(loss_fn((W_1_sharded, W_2_sharded), x_b, y_b, mlp_sharded))

In [None]:
# Compare gradients
## Single device MLP.
grads_single = jax.grad(functools.partial(loss_fn, logits_fn=mlp))(
    (W_1_single, W_2_single), x_b[0], y_b[0])

dw1_single, dw2_single = grads_single

## Sharded-MLP on 8 devices.
grads_sharded = jax.grad(functools.partial(loss_fn, logits_fn=mlp_sharded))(
    (W_1_sharded, W_2_sharded), x_b, y_b)

dw1_sharded, dw2_sharded = grads_sharded

In [None]:
# Compare gradient norms.
print(np.linalg.norm(dw1_sharded))
print(np.linalg.norm(dw1_single))

In [None]:
# Compare running times: 2-layer MLP sharded on 8 devices.
%timeit mlp_sharded((W_1_sharded, W_2_sharded), x_b)[0]

In [None]:
# Compare running times: 2-layer MLP on 1 device.
%timeit mlp((W_1_single, W_2_single), x_b[0])

In [None]:
loss_with_mlp_sharded = functools.partial(loss_fn, logits_fn=mlp_sharded)


def update_pmap_inside(params, x_b, y_b):
    loss_val, grads = jax.value_and_grad(loss_with_mlp_sharded)(params, x_b, y_b)
    (A, B), (dA, dB) = params, grads

    A_new = jax.pmap(lambda x, dx: x - 0.01 * dx, axis_name = 'pmap_axis')(A, dA)
    B_new = jax.pmap(lambda x, dx: x - 0.01 * dx, axis_name = 'pmap_axis')(B, dB)

    return loss_val, grads, (A_new, B_new)

In [None]:
# Can you guess what's wrong with this implementation?

# Can it be improved?
input_shape, params = init_fn(key, (num_shards, x_dim, out_dim))
print(loss_with_mlp_sharded(params, x_b, y_b))

# Warm up the compiled function and print statistics.
_, grads, _ = update_pmap_inside(params, x_b, y_b)

dw1, dw2 = grads
print(f"Norm of dW_1 %.2f" % np.linalg.norm(dw1))
print(f"Norm of dW_2 %.2f\n" % np.linalg.norm(dw2))

for i in range(0, 100):
    # Log the loss for the first 10 steps, then every 10th step.
    if i < 10 or i % 10 == 0:
        print(i, loss_with_mlp_sharded(params, x_b, y_b))
    # Take an update step on every iteration, not only on the logged ones.
    loss_val, grads, params = update_pmap_inside(params, x_b, y_b)

In [None]:
# Measure the performance.
%%timeit _, params = init_fn(key, (num_shards, x_dim, out_dim));
loss_val, grads, params = update_pmap_inside(params, x_b, y_b)

In [None]:
def loss_from_apply_fn(params, xs, ys):
    return loss_fn(params, xs, ys, apply_fn)

@functools.partial(jax.pmap, axis_name='pmap_axis')
def update_pmap_outside(params, xs, ys):
    loss_val, grads = jax.value_and_grad(loss_from_apply_fn)(params, xs, ys)
    new_params = jax.tree_util.tree_map(
      lambda param, g: param - g * 0.01, params, grads)

    return loss_val, grads, new_params

In [None]:
# Can it be improved?
_, params = init_fn(key, (num_shards, x_dim, out_dim))

# Warmstart jitted function and print statistics.
loss_val, grads, _ =  update_pmap_outside(params, x_b, y_b)

print(loss_val.shape)
print(loss_val[0])

dw1, dw2 = grads
print(f"Norm of dW_1 %.2f" % np.linalg.norm(dw1))
print(f"Norm of dW_2 %.2f\n" % np.linalg.norm(dw2))

In [None]:
%%timeit _, params = init_fn(key, (num_shards, x_dim, out_dim));
loss_val, grads, params = update_pmap_outside(params, x_b, y_b)

## A Note about gradients

As a refresher, consider the partial derivatives of $f: \mathbb{R} \times \mathbb{R} \rightarrow \mathbb{R}$:
$$
f(x,y) = x + y \hspace{0.5in} \rightarrow \hspace{0.5in} \frac{\partial f}{\partial x} = 1 \hspace{0.5in} \frac{\partial f}{\partial y} = 1
$$

In the context of AD, it is straightforward to see that the sum operation passes the incoming gradient unchanged to every one of its inputs during backpropagation.

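The same behaviour shows up for a sum over a whole vector. The following is a minimal sketch (the names `total`, `v`, `g` and `ct` are made up for this example): every input receives the full incoming gradient, not a fraction of it.

```python
import jax
import jax.numpy as jnp


def total(v):
    return jnp.sum(v)


v = jnp.array([1.0, -2.0, 3.5])

# Gradient of a sum: one per input, independent of the input values.
g = jax.grad(total)(v)
print(g)  # [1. 1. 1.]

# With an incoming cotangent of 5, every input receives the full 5.
_, vjp_fn = jax.vjp(total, v)
(ct,) = vjp_fn(jnp.float32(5.0))
print(ct)  # [5. 5. 5.]
```
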
In [None]:
x = jnp.array(1.0, dtype=jnp.float32)
y = jnp.array(1.0, dtype=jnp.float32)

z = jnp.ones_like(x) * 5

jax.make_jaxpr(jax.grad(lambda x, y: x + y))(x, y)

In [None]:
# This is evident in the vjp of sum, it just redistributes the incoming signal
# backwards to the input.
jax.make_jaxpr(jax.vjp(lambda x, y:  x + y, *(x, y))[1])(z)

In [None]:
# Now let's investigate the first case,
# pmap inside of gradient.
N_DEVICES = jax.local_device_count()

def f(x):
  # Just call all-reduce (psum) from all devices.
  return jax.lax.psum(x, axis_name="pmap_axis")

def pmap_f(x):
  # We call all-reduce (psum) on all devices. The
  # leading device dimension will have the same
  # value on each position.
  return jax.pmap(f, axis_name="pmap_axis")(x)[0]

inputs = np.array([.1] * N_DEVICES)
outs, jax_grads = jax.value_and_grad(pmap_f)(inputs)

np.testing.assert_allclose(outs, inputs * N_DEVICES)
np.testing.assert_allclose(jax_grads, [1.] * N_DEVICES)

In [None]:
# Case 2: pmap outside of gradient.
def fwd(x):
  # Just call all-reduce (psum) from all devices.
  return jax.lax.psum(x, axis_name='i')

def fwd_bwd(x):
  # Run forward pass, return gradient.
  return jax.value_and_grad(fwd)(x)

In [None]:
N_DEVICES = jax.local_device_count()

input = jax.lax.broadcast(0.5 , (N_DEVICES,))
val, grad = jax.pmap(fwd_bwd, axis_name='i')(input)
# What do you expect gradient to be?

In [None]:
val, grad

In [None]:
# Why is this the case?
print(jax.make_jaxpr(jax.pmap(fwd_bwd, axis_name='i'))(input).pretty_print(use_color=True, source_info=True))

## Custom VJPs

### `psum` forwards, `id` backwards.

In [None]:
# Since we know what the problem is (an additional psum on the backward pass),
# we need to implement a custom VJP that is a psum on the forward pass and the
# identity function on the backward pass.

# The process to implement a custom VJP has three steps:

# Step 1: Define the function.
@functools.partial(jax.custom_vjp, nondiff_argnums=(1, ))
def fwd_psum_bwd_id(x, axis_name: str):
    return jax.lax.psum(x, axis_name)

# Step 2: Specify the forward pass. More specifically, the function
# should return the primal output and the residuals (cached activations) that
# we later need to calculate the VJP on the backward pass.
def fwd_psum_bwd_id_fwd(
    x: jnp.ndarray, axis_name: str) -> Tuple[jnp.ndarray, None]:
  # Here, we need to calculate the psum on the forward pass.
  # Since during the backward pass we don't touch the incoming gradients
  # and only pass them through to earlier nodes in the computational graph,
  # we don't need to return any residuals.
  return fwd_psum_bwd_id(x, axis_name), None

# Step 3: Implement the backward pass.
def fwd_psum_bwd_id_bwd(
    unused_axis_name,
    unused_residuals, g) -> Tuple[jnp.ndarray]:
    # Pass through gradients. Note that we're returning a tuple.
    return (g,)

fwd_psum_bwd_id.defvjp(
    fwd=fwd_psum_bwd_id_fwd,
    bwd=fwd_psum_bwd_id_bwd)

In [None]:
# Now, let's see if the gradients match our expectations.
def fwd(x):
    return fwd_psum_bwd_id(x, axis_name='i')

def fwd_bwd(x):
    return jax.value_and_grad(fwd)(x)

In [None]:
val, grad = jax.pmap(fwd_bwd, axis_name='i')(input)

In [None]:
val, grad

In [None]:
print(jax.make_jaxpr(jax.pmap(fwd_bwd, axis_name='i'))(input))

### `id` forwards, `psum` backwards.

In [None]:
@functools.partial(jax.custom_vjp, nondiff_argnums=(1, ))
def fwd_id_bwd_psum(x, axis_name: str):
    ...


def fwd_id_bwd_psum_fwd(x, axis_name) -> Tuple[jnp.ndarray, None]:
    ...


def fwd_id_bwd_psum_bwd(axis_name, unused_residuals, g):
    ...


fwd_id_bwd_psum.defvjp(
    fwd=fwd_id_bwd_psum_fwd,
    bwd=fwd_id_bwd_psum_bwd)

In [None]:
# Now, let's see if the gradients match our expectations.
def fwd(x):
    return fwd_id_bwd_psum(x, axis_name='i')

def fwd_bwd(x):
    return jax.value_and_grad(fwd)(x)

In [None]:
jax.pmap(fwd_bwd, axis_name='i')(inputs)

# Full example

In [None]:
def mlp_with_custom_vjps(params: Params, x: jax.Array) -> jax.Array:
    """Vanilla MLP without dropout."""
    a, b = params

    ...  # Id forwards, All-reduce backwards.
    y = jnp.maximum(jnp.matmul(x, a), 0.0)
    z = jnp.matmul(y, b)
    ...   # All-reduce forwards, id backwards.
    return z

def loss_from_mlp(
    params: Params, xs: jax.Array, ys: jax.Array):
    return loss_fn(params, xs, ys, mlp_with_custom_vjps)


@functools.partial(jax.pmap, axis_name='pmap_axis')
def update_pmap_outside(params: Params, xs: jax.Array, ys: jax.Array):
    loss_val, grads = jax.value_and_grad(loss_from_mlp)(params, xs, ys)
    new_params = jax.tree_util.tree_map(
      lambda param, g: param - g * 0.01, params, grads)

    return loss_val, grads, new_params

In [None]:
_, params = init_fn(key, (num_shards, x_dim, out_dim))

# Warmstart pmapped function and print statistics.
loss_val, grads, _ =  update_pmap_outside(params, x_b, y_b)

print(loss_val.shape)
print(loss_val[0])

dw1, dw2 = grads
print(f"Norm of dW_1 %.2f" % np.linalg.norm(dw1))
print(f"Norm of dW_2 %.2f\n" % np.linalg.norm(dw2))

In [None]:
%%timeit _, params = init_fn(key, (num_shards, x_dim, out_dim)); update_pmap_outside(params, x_b, y_b)
loss_val, grads, params = update_pmap_outside(params, x_b, y_b)
loss_val = loss_val[0]

# Exercises for the reader

## MNIST

* Implement the training loop with real data.
* Compare the single-device MLP with the sharded MLP (wrong gradients) and the sharded MLP with a properly implemented backward pass.
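
As a possible starting point for the training loop, the batches have to be replicated across devices in the same way as `x_b` and `y_b` earlier in this notebook. The helper below is a hedged sketch rather than a reference solution: `replicated_minibatches` and its arguments are hypothetical names, and it uses zero-filled stand-in arrays instead of the real MNIST data.

```python
import numpy as np


def replicated_minibatches(images, labels, batch_size, num_devices, seed=0):
    """Yields shuffled (images, labels) batches with a leading device axis."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(images))
    # Drop the last incomplete batch so every batch has the same shape.
    for start in range(0, len(order) - batch_size + 1, batch_size):
        idx = order[start:start + batch_size]
        x_batch, y_batch = images[idx], labels[idx]
        # Tensor parallelism shards the weights, not the data, so each device
        # sees the same batch: [num_devices, batch_size, ...].
        yield (np.broadcast_to(x_batch, (num_devices,) + x_batch.shape),
               np.broadcast_to(y_batch, (num_devices,) + y_batch.shape))


# Stand-in data with MNIST-like shapes (10 examples, 784 pixels, 10 classes).
fake_images = np.zeros((10, 784), dtype=np.float32)
fake_labels = np.zeros((10, 10), dtype=np.float32)

batches = list(replicated_minibatches(fake_images, fake_labels,
                                      batch_size=4, num_devices=8))
print(len(batches), batches[0][0].shape, batches[0][1].shape)
# 2 (8, 4, 784) (8, 4, 10)
```

Each yielded pair could then be passed to a pmapped update function in place of `x_b` and `y_b`.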

In [None]:
import tensorflow_datasets as tfds

data_dir = '/tmp/tfds'

# Fetch full datasets for evaluation
# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1).
# You can convert them to NumPy arrays (or iterables of NumPy arrays) with tfds.as_numpy.
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
mnist_data = tfds.as_numpy(mnist_data)
train_data, test_data = mnist_data['train'], mnist_data['test']
num_labels = info.features['label'].num_classes
h, w, c = info.features['image'].shape
num_pixels = h * w * c

# Full train set
train_images, train_labels = train_data['image'], train_data['label']
train_images = np.reshape(train_images, (len(train_images), num_pixels))
train_labels = jax.nn.one_hot(train_labels, num_labels)

# Full test set
test_images, test_labels = test_data['image'], test_data['label']
test_images = np.reshape(test_images, (len(test_images), num_pixels))
test_labels = jax.nn.one_hot(test_labels, num_labels)


print('Train:', train_images.shape, train_labels.shape)
print('Test:', test_images.shape, test_labels.shape)

## Take home challenges.

* Implement sharded Self-Attention from [Megatron-LM](https://arxiv.org/pdf/1909.08053.pdf).
* Try adding sharded biases (and/or dropout). How would you handle random keys?
* Experiment with GLU variants: [GLU Variants Improve Transformer](https://arxiv.org/pdf/2002.05202.pdf)
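
For the random-key question, one possible approach (a sketch under assumptions, not the only answer) is to distinguish two cases. Where activations are replicated, for example after the all-reduce, every shard should draw the same dropout mask and can reuse the same key. Where activations are sharded, each shard needs its own key, which can be derived with `jax.random.fold_in`. The names `base_key`, `num_shards_demo`, `replicated_keys` and `per_shard_keys` below are hypothetical; inside a `pmap`, the shard index could come from `jax.lax.axis_index` instead of a Python loop.

```python
import jax
import jax.numpy as jnp

base_key = jax.random.PRNGKey(0)
num_shards_demo = 4

# Replicated activations: every shard reuses the same key, so the dropout
# masks agree across shards.
replicated_keys = jnp.stack([base_key] * num_shards_demo)

# Sharded activations: fold the shard index into the key so that every shard
# draws an independent mask.
per_shard_keys = jnp.stack(
    [jax.random.fold_in(base_key, i) for i in range(num_shards_demo)])

mask_a = jax.random.bernoulli(replicated_keys[0], 0.9, (8,))
mask_b = jax.random.bernoulli(replicated_keys[1], 0.9, (8,))
print(bool((mask_a == mask_b).all()))  # Same key, same mask.
print(per_shard_keys.shape)            # One key per shard.
```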