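1. Pad tensor dimensions (batch size, hidden/feature sizes, vocabulary size) to multiples of 8 so that half-precision matmuls can map onto Tensor Cores.
2. bfloat16 covers the same dynamic range as float32 (it keeps float32's 8-bit exponent) but has only 7 explicit mantissa bits (vs. 10 for float16 and 23 for float32), so it rarely overflows/underflows and usually needs no loss scaling — it is the precision, not the range, that is reduced.
3. Summations/reductions, exp, and log should be computed in full precision (float32), even when the surrounding computation runs in 16-bit.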

https://on-demand.gputechconf.com/gtcdc/2019/pdf/dc91247-automatic-mixed-precision-in-tensorflow.pdf
https://www.tensorflow.org/guide/mixed_precision
https://developer.download.nvidia.com/video/gputechconf/gtc/2020/presentations/s21929-tensor-core-performance-on-nvidia-gpus-the-ultimate-guide.pdf
https://pytorch.org/blog/what-every-user-should-know-about-mixed-precision-training-in-pytorch/

In [1]:
import jax
import jax.numpy as jnp
import numpy as np
import flax
import flax.linen as nn


# Optional global JAX config switches, intentionally left disabled:
# jax.config.update('jax_enable_x64', True)
# jax.config.update('jax_platform_name', 'cpu')

# Root PRNG key (seed 0); the `m.init(key, x)` calls in the cells below all reuse it.
key = jax.random.PRNGKey(seed=0)

In [32]:
# Display the enum member that `precision=` defaults to in the Dense experiments below.
getattr(jax.lax.Precision, "DEFAULT")

<Precision.DEFAULT: 0>

In [29]:
# Baseline: bfloat16 compute dtype, no explicit matmul precision.
# Trace shows params created as f32 and cast to bf16 inside apply; dot_general has no precision attribute.
x = jnp.zeros((4,), dtype=jnp.bfloat16)
m = nn.Dense(features=5, dtype=jnp.bfloat16)
p = m.init(key, x)

jax.make_jaxpr(m.apply)(p, x)

{ lambda ; a:f32[5] b:f32[4,5] c:bf16[4]. let
    d:bf16[4,5] = convert_element_type[new_dtype=bfloat16 weak_type=False] b
    e:bf16[5] = convert_element_type[new_dtype=bfloat16 weak_type=False] a
    f:bf16[5] = dot_general[dimension_numbers=(([0], [0]), ([], []))] c d
    g:bf16[5] = add f e
  in (g,) }

In [30]:
# Same layer, but with precision=DEFAULT passed explicitly.
# Trace: dot_general gains a precision=(DEFAULT, DEFAULT) attribute, yet its result is still bf16.
x = jnp.zeros((4,), dtype=jnp.bfloat16)
m = nn.Dense(
    features=5,
    dtype=jnp.bfloat16,
    precision=jax.lax.Precision.DEFAULT,
)
p = m.init(key, x)

jax.make_jaxpr(m.apply)(p, x)

{ lambda ; a:f32[5] b:f32[4,5] c:bf16[4]. let
    d:bf16[4,5] = convert_element_type[new_dtype=bfloat16 weak_type=False] b
    e:bf16[5] = convert_element_type[new_dtype=bfloat16 weak_type=False] a
    f:bf16[5] = dot_general[
      dimension_numbers=(([0], [0]), ([], []))
      precision=(<Precision.DEFAULT: 0>, <Precision.DEFAULT: 0>)
    ] c d
    g:bf16[5] = add f e
  in (g,) }

In [33]:
# Same layer, requesting precision=HIGHEST for the matmul.
# Trace: only the dot_general precision attribute changes to (HIGHEST, HIGHEST);
# operands are still cast to bf16 and the result dtype stays bf16.
x = jnp.zeros((4,), dtype=jnp.bfloat16)
m = nn.Dense(
    features=5,
    dtype=jnp.bfloat16,
    precision=jax.lax.Precision.HIGHEST,
)
p = m.init(key, x)

jax.make_jaxpr(m.apply)(p, x)

{ lambda ; a:f32[5] b:f32[4,5] c:bf16[4]. let
    d:bf16[4,5] = convert_element_type[new_dtype=bfloat16 weak_type=False] b
    e:bf16[5] = convert_element_type[new_dtype=bfloat16 weak_type=False] a
    f:bf16[5] = dot_general[
      dimension_numbers=(([0], [0]), ([], []))
      precision=(<Precision.HIGHEST: 2>, <Precision.HIGHEST: 2>)
    ] c d
    g:bf16[5] = add f e
  in (g,) }