Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions flax/linen/fp8_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,16 @@ def quantize_dequantize(x, q_dtype, scale, compute_dtype):


def compute_scale(amax, scale, fp8_max, margin=0):
"""Default function to convert amax to scaling factor."""
# This function copied from the TransformerEngine is used to compute its
# `scale`. However, our scale matches its `scale_inv` concept. So, we apply
# the reciprocal operation at the entry and exit of the function.
# The algorithm for computing the new scale is sourced from
# https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/jax.html#transformer_engine.jax.update_fp8_metas
# wherein the `original_scale` corresponds to the reciprocal of the `scale`
# passed in this function.
scale = 1.0 / scale
exp = jnp.floor(jnp.log2(fp8_max / amax)) - margin
sf = jnp.round(lax.pow(2., jnp.abs(exp)))

sf = (fp8_max / amax) / (2**margin)
sf = jnp.where(amax > 0.0, sf, scale)
sf = jnp.where(lax.is_finite(amax), sf, scale)
sf = jnp.where(exp < 0, 1.0 / sf, sf)
sf = jnp.where(jnp.isfinite(amax), sf, scale)

return 1.0 / sf


Expand Down Expand Up @@ -155,7 +155,7 @@ def setup(self) -> None:
OVERWRITE_WITH_GRADIENT, 'output_grad_scale', *scale_args)


def __call__(self, *args, **kwargs) -> jnp.ndarray:
def __call__(self, *args, **kwargs):

assert len(args) == 3
x = args[0]
Expand Down