In [None]:
import functools

import flax
from flax import linen as nn
import jax
from jax import numpy as jnp
from jax import random
from jax.sharding import NamedSharding, PartitionSpec
from jax import tree_util

# Single device

In [None]:
# Problem size for the single-device benchmark.
BATCH_SIZE = 1024
INPUT_SIZE = 2**13   # input features per example
LAYER_SIZE = 2**13   # output features of the Dense layer
USE_BIAS = True


class MLP(nn.Module):
    """A single `nn.Dense` layer (despite the name, there is only one layer).

    Attributes:
      layer_size: number of output features of the Dense layer.
      use_bias: whether the Dense layer adds a learned bias.
    """
    layer_size: int
    use_bias: bool

    @nn.compact
    def __call__(self, x):
        dense = nn.Dense(features=self.layer_size, use_bias=self.use_bias)
        return dense(x)

In [None]:
model = MLP(LAYER_SIZE, USE_BIAS)
x = jnp.ones((BATCH_SIZE, INPUT_SIZE))
prng_key = random.key(0)
params = model.init(prng_key, x)

# Show which device(s) hold the kernel; no sharding is requested in this section.
display(params['params']['Dense_0']['kernel'].devices())
# Summarize the parameter tree by shape: rendering `params` itself would copy
# every array back to the host just to print a truncated repr.
tree_util.tree_map(jnp.shape, params)

In [None]:
_ = model.apply(params, x)
%timeit model.apply(params, x).block_until_ready()

In [None]:
def run_timeit(model, params, x):
    _ = model.apply(params, x)
    result = %timeit -o model.apply(params, x).block_until_ready()
    return result

def get_flops(batch_size, input_size, layer_size, t):
    """Return the achieved FLOP rate (FLOP/s) of a dense matmul.

    A (batch_size, input_size) @ (input_size, layer_size) matmul performs
    batch_size * input_size * layer_size multiply-adds, counted as 2 FLOPs
    each. The bias add is not counted.

    Args:
      batch_size: rows of the input batch.
      input_size: input features (the contracted dimension).
      layer_size: output features.
      t: duration of one matmul in seconds; must be positive.

    Returns:
      FLOPs per second.

    Raises:
      ValueError: if `t` is not positive.
    """
    if t <= 0:
        raise ValueError(f"t must be a positive duration in seconds, got {t!r}")
    return 2 * batch_size * input_size * layer_size / t

def get_space(batch_size, input_size, layer_size, bytes_per_element=4):
    """Return the bytes needed to hold the input batch and the kernel.

    Counts the (batch_size, input_size) input and the (input_size, layer_size)
    kernel; the bias and the output are not included.

    Args:
      batch_size: rows of the input batch.
      input_size: input features.
      layer_size: output features.
      bytes_per_element: size of one value in bytes (4 for float32, 2 for
        bfloat16/float16).

    Returns:
      Total size in bytes.
    """
    return bytes_per_element * (batch_size * input_size + input_size * layer_size)

In [None]:
# Batch sizes 1, 4, 16, ..., 4**8.
batch_sizes = [4**exponent for exponent in range(9)]

# Time the forward pass once per batch size; keep every timeit result for plotting.
timeits = []
for batch_size in batch_sizes:
    x = jnp.ones((batch_size, INPUT_SIZE))
    timeits.append(run_timeit(model, params, x))

In [None]:
import matplotlib.pyplot as plt

# Achieved throughput for each batch size, from the mean time reported by %timeit.
tflops = [
    get_flops(b, INPUT_SIZE, LAYER_SIZE, timing.average) / 1e12
    for b, timing in zip(batch_sizes, timeits)
]

fig, ax = plt.subplots()
ax.plot(batch_sizes, tflops, marker='o')
# Batch sizes are powers of 4, so a log2 x-axis spaces the points evenly.
ax.set_xscale('log', base=2)
ax.set(xlabel='Batch size', ylabel='TFLOP/s', title='Dense layer on a single device');

# Sharding `params` with `jax.eval_shape`
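`jax.device_put` can only shard an array that already exists, typically on a single chip, so it can't be used for arrays that are too large to fit on one chip. Some functions accept a sharding directly; for example, `jax.numpy.ones` has a `device` argument that accepts a `NamedSharding`. For a general solution, use `jax.eval_shape` to compute the shapes of a function's outputs without allocating any memory. With Flax partitioning metadata, it also gives the shardings. Then run the function under `jax.jit` with `out_shardings`, so the outputs are created directly in their sharded layout. Note that all arguments passed via `eval_shape()` are treated as dynamic. Static arguments can be included via closure, for example using `functools.partial()`.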

In [None]:
# A problem size too large for one device: with the default float32, a
# 2**16 x 2**16 array is 2**32 values = 16 GiB.
BATCH_SIZE = 2**16
INPUT_SIZE = 2**16
LAYER_SIZE = 2**16
USE_BIAS = True

# On TPU v3, both of the following fail with RESOURCE_EXHAUSTED when run on a
# single device, which is what the sharding techniques below work around:
#   jnp.ones((BATCH_SIZE, INPUT_SIZE))
#   MLP(LAYER_SIZE, USE_BIAS).init(random.key(0), jnp.ones((1, INPUT_SIZE)))

In [None]:
mesh = jax.make_mesh((8, 1), ('x', 'y'))
mesh

In [None]:
class MLP(nn.Module):
    """Single Dense layer whose parameters carry partitioning metadata.

    `nn.with_partitioning` wraps each initializer so the created parameter is
    boxed together with mesh-axis names; `nn.get_sharding` later maps those
    names onto a concrete mesh.

    Attributes:
      layer_size: number of output features of the Dense layer.
      use_bias: whether the Dense layer adds a learned bias.
      kernel_axes: mesh-axis names for the (input, output) dimensions of the kernel.
      bias_axes: mesh-axis names for the single dimension of the bias.
    """
    layer_size: int
    use_bias: bool
    kernel_axes: tuple = ('x', 'y')
    # A 1-tuple: a bare ('x') is just the string 'x', which only works as a
    # list of axis names by accident of the name being a single character.
    bias_axes: tuple = ('x',)

    @nn.compact
    def __call__(self, x):
        kernel_init = nn.with_partitioning(nn.initializers.lecun_normal(), self.kernel_axes)
        bias_init = nn.with_partitioning(nn.initializers.constant(0), self.bias_axes)
        x = nn.Dense(self.layer_size, use_bias=self.use_bias, kernel_init=kernel_init, bias_init=bias_init)(x)
        return x

In [None]:
model = MLP(LAYER_SIZE, USE_BIAS)
prng_key = random.key(0)
# Trace model.init abstractly: eval_shape returns shapes/dtypes (plus the
# partitioning boxes) without allocating device memory, so the full batch
# shape costs nothing. The key is supplied by closure via functools.partial.
params_shape = jax.eval_shape(
    functools.partial(model.init, prng_key),
    jax.ShapeDtypeStruct((BATCH_SIZE, INPUT_SIZE), dtype=float),
)
params_shape

In [None]:
# Turn the axis names recorded by nn.with_partitioning into shardings on `mesh`.
(params_sharding := nn.get_sharding(params_shape, mesh))

In [None]:
# Run model.init under jit with out_shardings, so the parameters are produced
# directly in their sharded layout instead of first on a single device.
# A jitted function must be called with concrete arrays (ShapeDtypeStruct is
# only accepted by eval_shape / .lower()). init only uses the input's shape
# to size the kernel, so a one-row dummy batch is enough.
init_input = jnp.ones((1, INPUT_SIZE))
params = jax.jit(model.init, out_shardings=params_sharding)(prng_key, init_input)
# Display shapes rather than the arrays: rendering `params` would copy every
# parameter (16 GiB for the float32 kernel alone) back to the host.
tree_util.tree_map(jnp.shape, params)

In [None]:
# The kernel is boxed by nn.with_partitioning, so `.value` gives the underlying
# array; the plot shows which device holds which block of it.
jax.debug.visualize_array_sharding(params['params']['Dense_0']['kernel'].value)

In [None]:
# Same for the 1-D bias, which the MLP partitions over mesh axis 'x'.
jax.debug.visualize_array_sharding(params['params']['Dense_0']['bias'].value)

In [None]:
# Batch sizes 1, 4, ..., 4**6 (smaller range than the single-device sweep).
batch_sizes = [4**exponent for exponent in range(7)]

# Same sweep as before, now with the sharded params; the inputs are still
# created with plain jnp.ones (no sharding requested).
timeits = []
for batch_size in batch_sizes:
    x = jnp.ones((batch_size, INPUT_SIZE))
    timeits.append(run_timeit(model, params, x))

In [None]:
import matplotlib.pyplot as plt

# Achieved throughput per batch size. FLOPs of the whole matmul divided by
# wall-clock time, i.e. aggregated over all devices in the mesh.
tflops = [
    get_flops(b, INPUT_SIZE, LAYER_SIZE, timing.average) / 1e12
    for b, timing in zip(batch_sizes, timeits)
]

fig, ax = plt.subplots()
ax.plot(batch_sizes, tflops, marker='o')
# Batch sizes are powers of 4, so a log2 x-axis spaces the points evenly.
ax.set_xscale('log', base=2)
ax.set(xlabel='Batch size', ylabel='TFLOP/s (all devices)', title='Dense layer with sharded params');

# Sharding `x`

## Using `device` argument in `jnp.ones`

In [None]:
# Passing a NamedSharding as `device` creates x directly in sharded form:
# rows are split over mesh axis 'x', columns over 'y'.
x = jnp.ones(
    (BATCH_SIZE, INPUT_SIZE),
    device=NamedSharding(mesh, PartitionSpec('x', 'y')),
)
jax.debug.visualize_array_sharding(x)

## Using `jax.jit` and `NamedSharding`

In [None]:
# Any Flax initializer works here, e.g. nn.initializers.constant(1).
initializer = nn.initializers.normal()
initializer_sharding = NamedSharding(mesh, PartitionSpec('x', 'y'))
# Key, shape and dtype are bound up front, so the jitted function takes no
# arguments and out_shardings alone decides the layout of its result.
initializer_sharded = jax.jit(
    functools.partial(initializer, prng_key, (BATCH_SIZE, INPUT_SIZE), dtype=float),
    out_shardings=initializer_sharding,
)
x = initializer_sharded()
jax.debug.visualize_array_sharding(x)

## Using `jax.jit` and `jax.eval_shape`

In [None]:
# A plain (unpartitioned) Flax initializer: a function of (key, shape, dtype).
(initializer := nn.initializers.normal())

In [None]:
# Same initializer, but its output is boxed with the mesh-axis names ('x', 'y').
(initializer_partitioned := nn.with_partitioning(initializer, ('x', 'y')))

In [None]:
# Abstract evaluation: yields the boxed shape/dtype without allocating the array.
# No dtype is passed, so the initializer's default dtype is used.
(initializer_shape := jax.eval_shape(
    functools.partial(initializer_partitioned, prng_key, (BATCH_SIZE, INPUT_SIZE))
))

In [None]:
# Map the ('x', 'y') axis names recorded in the box onto `mesh`.
(initializer_sharding := nn.get_sharding(initializer_shape, mesh))

In [None]:
# Same as the NamedSharding example, except the output sharding was derived
# from partitioning metadata via eval_shape + nn.get_sharding.
x = jax.jit(
    functools.partial(initializer, prng_key, (BATCH_SIZE, INPUT_SIZE), dtype=float),
    out_shardings=initializer_sharding,
)()
jax.debug.visualize_array_sharding(x)

In [None]:
# Forward pass with both the input and the parameters sharded; the plot shows
# the sharding XLA chose for the output.
# NOTE(review): x is split over mesh axis 'x' along its rows, and the kernel
# is also split over 'x' along its input (contracted) rows. XLA therefore has
# to reshard one operand for the matmul. Confirm this fits in device memory on
# TPU v3.
jax.debug.visualize_array_sharding(y := model.apply(params, x))