1. Pad sizes (batch, feature, and matrix dimensions) to multiples of 8
2. bfloat16 represents the same range as float32 (same 8 exponent bits), but with fewer mantissa bits than float16
3. Summation, exp, log should use full precision

https://on-demand.gputechconf.com/gtcdc/2019/pdf/dc91247-automatic-mixed-precision-in-tensorflow.pdf
https://www.tensorflow.org/guide/mixed_precision
https://developer.download.nvidia.com/video/gputechconf/gtc/2020/presentations/s21929-tensor-core-performance-on-nvidia-gpus-the-ultimate-guide.pdf
https://pytorch.org/blog/what-every-user-should-know-about-mixed-precision-training-in-pytorch/

In [1]:
import jax
import jax.numpy as jnp
import numpy as np
import flax
import flax.linen as nn


# Global JAX configuration for this notebook: switch on the `jax_enable_x64`
# flag before any array is created.
jax.config.update('jax_enable_x64', True)
# jax.config.update('jax_platform_name', 'cpu')

# Root PRNG key (seed 0); consumed below when the Dense layer is initialised.
key = jax.random.PRNGKey(seed=0)

I0000 00:00:1699456257.913252   28069 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.


In [11]:
# Benchmark fixture: one square Dense layer applied to a length-N vector.
# N = 4096, a multiple of 8.
N = 2 ** 12
# Precision under test; used for the input, the parameters and the computation.
dtype = jnp.float16
# All-zero input vector, explicitly placed on the default device.
x = jax.device_put(jnp.zeros(shape=(N,), dtype=dtype))
# N -> N linear layer whose parameter dtype and compute dtype are both `dtype`.
ln = nn.Dense(features=N, param_dtype=dtype, dtype=dtype)

In [12]:
# Initialise the layer: yields the output for `x` together with the freshly
# created variables (kernel and bias).
out, params = ln.init_with_output(key, x)

# Compiled forward pass used by the timing cells below.
fn = jax.jit(ln.apply)

# Warm-up call: run the jitted function once and wait for it to finish so the
# one-off compilation cost is paid here rather than inside %%timeit.
fn(params, x).block_until_ready()

# Show the dtype of the output and of every parameter leaf.
# `jax.tree_util.tree_map` is the supported spelling; the top-level
# `jax.tree_map` alias is deprecated and absent from newer JAX releases.
out.dtype, jax.tree_util.tree_map(lambda leaf: leaf.dtype, params)

(dtype('float16'),
 {'params': {'bias': dtype('float16'), 'kernel': dtype('float16')}})

In [10]:
%%timeit
fn(params, x).block_until_ready()

313 µs ± 12.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [13]:
%%timeit
fn(params, x).block_until_ready()

181 µs ± 3.19 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [12]:
# Square all-zero float16 operand for the raw matmul benchmark
# (1024 x 1024; each side is a multiple of 8).
M = jnp.zeros(shape=(1 << 10,) * 2, dtype=jnp.float16)

In [11]:
%%timeit
jnp.sum(M @ M).block_until_ready()

551 µs ± 14.3 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [13]:
%%timeit
jnp.sum(M @ M).block_until_ready()

357 µs ± 39.1 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
