# JAX Scalify: Quickstart on end-to-end scaled arithmetic

**JAX Scalify** is a library implemeting general scaled arithmetic in JAX, allowing end-to-end scale propagation in computational graphs and easy training/inference of deep neural networks in low precision (mainly FP16 & FP8).

JAX Scalify supports converting any computational graph into a scaled computational graph, i.e. a function with `ScaledArray` inputs/outputs.

```python
@dataclass
class ScaledArray:
    # Main `data` component, with "low precision"
    data: Array
    # Scale, usually scalar, represented in E8M0 or FP32.
    scale: Array
```
It fully decouples scale propagation from model definition, allowing easy experimentation and debugging with low precision formats such as FP16 and FP8.

## Scaled array representation

In Scalify, every tensor is a `ScaledArray`. This systematic approach simplifies the use of FP16 and FP8 in LLM training, decoupling the scale and numerical stability questions from the high-level model definition. 

Below is presented the basics of `ScaledArray` construction, and how it helps representing very large or very small tensors.

In [226]:
import numpy as np
import numpy.testing as npt

import jax
import jax.numpy as jnp
import jax_scalify as jsa

In [227]:
# Let's start at the beginning: convert an array to a ScaledArray.
a = np.array([1, 2], np.float16)
# Analogue of `np.asarray`, with in addition passing of the scale to use.
# NOTE: scale dtype does not have to match core data dtype. Usually using np.float32
sa = jsa.as_scaled_array(a, scale=np.float32(1))
assert isinstance(sa, jsa.ScaledArray)

# `a` and `sa` represent the same formal tensor.
print("Normal `a`:", a)
print("Scaled `a`:", sa, " ~ ", np.asarray(sa))

Normal `a`: [1. 2.]
Scaled `a`: ScaledArray(data=array([1., 2.], dtype=float16), scale=1.0)  ~  [1. 2.]


In [228]:
# Scalify preserves the semantics of arrays and computational graphs.
# Passing a different scale does not change the "value" of a represented tensor.
sa = jsa.as_scaled_array(a, scale=np.float32(0.5))
# `a` and `sa` still represent the same formal tensor.
print("Normal `a`:", a)
print("Scaled `a`:", sa, " ~ ", np.asarray(sa))

Normal `a`: [1. 2.]
Scaled `a`: ScaledArray(data=array([2., 4.], dtype=float16), scale=0.5)  ~  [1. 2.]


In [229]:
# Why using Scaled Arrays? => representing very "small" or "large" tensor.
# Large FP32 tensor.
a = np.array([2, 4], np.float32) * 256**2
# Overflowing to Inf in FP16
a_fp16 = a.astype(np.float16)
# "Properly" represented with a large scale.
sa_fp16 = jsa.as_scaled_array(a, scale=np.float32(256**2 * 2)).astype(np.float16)

print("<< Scaled Arrays with large values >>")
print("Normal `a` FP32:", a)
print("Normal `a` FP16:", a_fp16)
# FP16 scaled representation does not overflow. 
print("Scaled `a` FP16:", sa_fp16, " ~ ", np.asarray(sa_fp16, dtype=np.float32))

a = np.array([2, 4], np.float32) * (256*32)**-2
a_fp16 = a.astype(np.float16)
sa_fp16 = jsa.as_scaled_array(a, scale=np.float32(256**-2)).astype(np.float16)

print("\n<< Scaled Arrays with small values >>")
print("Normal `a` FP32:", a)
# Underflowing + loss of precision in sub-normals representation.
print("Normal `a` FP16:", a_fp16)
# FP16 scaled representation does not underflow.
# NOTE: scale factor does not need to be "perfect" to keep accurate representation.
print("Scaled `a` FP16:", sa_fp16, " ~ ", np.asarray(sa_fp16, dtype=np.float32))

<< Scaled Arrays with large values >>
Normal `a` FP32: [131072. 262144.]
Normal `a` FP16: [inf inf]
Scaled `a` FP16: ScaledArray(data=array([1., 2.], dtype=float16), scale=131072.0)  ~  [131072. 262144.]

<< Scaled Arrays with small values >>
Normal `a` FP32: [2.9802322e-08 5.9604645e-08]
Normal `a` FP16: [0.e+00 6.e-08]
Scaled `a` FP16: ScaledArray(data=array([0.001953, 0.003906], dtype=float16), scale=1.5258789e-05)  ~  [2.9802322e-08 5.9604645e-08]


  a_fp16 = a.astype(np.float16)


### Scaled array and FP8 formats

How does it work with FP8? The same `:)`
Last generation GPUs supports two FP8 formats define by the OCP FP8 specification (https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1):
* `float8_e4m3fn`: 4 exponent bits, 3 mantissa bits, no infinity;
* `float8_e5m2fnuz`: 5 exponent bits, 2 mantissa bits, with infinity;

**Note:** there is still on-going IEEE standardization work on FP8 formats (see https://github.com/P3109/Public/blob/main/Shared%20Reports/P3109%20WG%20Interim%20Report.pdf). 

In [230]:
# FP8 formats properties
print("FP8-E4M3:", jnp.finfo(jnp.float8_e4m3fn))
print("FP8-E5M2:", jnp.finfo(jnp.float8_e5m2fnuz))

FP8-E4M3: Machine parameters for float8_e4m3fn
---------------------------------------------------------------
precision =   1   resolution = 1.00e-01
machep =     -3   eps =        1.25e-01
negep =      -4   epsneg =     6.25e-02
minexp =     -6   tiny =       1.56e-02
maxexp =      9   max =        4.48e+02
nexp =        4   min =        -max
smallest_normal = 1.56e-02   smallest_subnormal = 1.95e-03
---------------------------------------------------------------

FP8-E5M2: Machine parameters for float8_e5m2fnuz
---------------------------------------------------------------
precision =   1   resolution = 1.00e-01
machep =     -2   eps =        2.50e-01
negep =      -3   epsneg =     1.25e-01
minexp =    -15   tiny =       3.05e-05
maxexp =     16   max =        5.73e+04
nexp =        5   min =        -max
smallest_normal = 3.05e-05   smallest_subnormal = 7.63e-06
---------------------------------------------------------------



In [231]:

a = jnp.array([400, 448, 512], np.float32)
# Overflowing to NaN as no Inf available.
a_fp8_e4m3 = a.astype(jnp.float8_e4m3fn)
# Scaled representation, without overflowing.
as_fp8_e4m3 = jsa.as_scaled_array(a, scale=np.float32(128)).astype(jnp.float8_e4m3fn)

print("Normal `a` FP32:", a)
# NOTE: the loss of precision due to 3-bit mantissa.
print("Normal `a` FP8-E4M3:", a_fp8_e4m3)
print("Scaled `a` FP8-E4M3:", as_fp8_e4m3, " ~ ", np.asarray(as_fp8_e4m3, dtype=np.float32))

Normal `a` FP32: [400. 448. 512.]
Normal `a` FP8-E4M3: [384 448 nan]
Scaled `a` FP8-E4M3: ScaledArray(data=Array([3, 3.5, 4], dtype=float8_e4m3fn), scale=128.0)  ~  [384. 448. 512.]


In [252]:
import ml_dtypes
# Minimal FP8 simulated support is provided using jax.lax.reduce_precision and ml_dtypes.
# Similarly to `dynamic_rescale`, `cast_ml_dtype(_grad)` are available to cast in forward and backward passes
sc = jsa.as_scaled_array(np.array([17., 19.]), scale=np.float32(2))

@jsa.scalify
def cast_fn(v):
    return jsa.ops.cast_ml_dtype(v, ml_dtypes.float8_e4m3fn)

sc_fp8 = cast_fn(sc)
print("Scaled input in FP32:", sc)
# NOTE: still using FP32 (or FP16) as underlying storage.
print("Pseudo-cast to ML dtypes:", sc_fp8)

Scaled input in FP32: ScaledArray(data=array([8.5, 9.5]), scale=2.0)
Pseudo-cast to ML dtypes: ScaledArray(data=Array([ 8., 10.], dtype=float32), scale=2.0)


## `scalify` transform: end-to-end scale propagation

The `scalify` transform performs end-to-end scale propagation, with application of "unit-scaling" type rules. `scalify` for now only supports a subset of [LAX operators](../docs/operators.md), and will raise an error if unsupported operations are used in the computational graph.

In [232]:
# `scalify` transform is tracing the graph, adding scale propagation where necessary.
@jsa.scalify
def fn(a, b):
    return a + b

In [233]:
# Let's start with standard JAX inputs
a = np.array([1, 2], np.float16)
b = np.array([3, 6], np.float16)
# The function `fn` is unchanged with unscaled inputs. 
out = fn(a, b)

print("INPUTS:", a, b)
# "Unscaled" inputs => "normal" JAX mode with unscaled outputs
print("OUTPUT:", out, out.dtype, type(out))

INPUTS: [1. 2.] [3. 6.]
OUTPUT: [4. 8.] float16 <class 'jaxlib.xla_extension.ArrayImpl'>


In [234]:
# Let's create input scaled arrays.
sa = jsa.as_scaled_array(a, scale=np.float32(2))
sb = jsa.as_scaled_array(b, scale=np.float32(4))

print(f"Scaled inputs:\n\t{sa}\n\t{sb}")
# `as_scaled_array` does not change the semantic: same underlying tensor represented.
print("Equivalent input arrays:", np.asarray(sa), np.asarray(sb))

Scaled inputs:
	ScaledArray(data=array([0.5, 1. ], dtype=float16), scale=2.0)
	ScaledArray(data=array([0.75, 1.5 ], dtype=float16), scale=4.0)
Equivalent input arrays: [1. 2.] [3. 6.]


In [235]:
# Running `fn` on scaled arrays triggers `scalify` graph transformation & scale propagtion
sout = fn(sa, sb)
# NOTE: by default, scale propagation is using power-of-2.
assert isinstance(sout, jsa.ScaledArray)
print("Scaled output:", sout)
print("Equivalent unscaled output:", np.asarray(sout))

# To choose a different scale rounding:
with jsa.ScalifyConfig(rounding_mode=jsa.Pow2RoundMode.NONE):
    print("\nScaled output without scale rounding:", fn(sa, sb))

Scaled output: ScaledArray(data=Array([1., 2.], dtype=float16), scale=Array(4., dtype=float32))
Equivalent unscaled output: [4. 8.]

Scaled output without scale rounding: ScaledArray(data=Array([0.8945, 1.789 ], dtype=float16), scale=Array(4.472136, dtype=float32))


### Why using unit-scaling rules in `scalify`?

We present in this section how unit-scaling rules implemented in `scalify` are propagating optimally scaling in operations. We show a simple example of FP16 matrix multiplication where `scalify` avoids overflowing (compared to normal model).

In [236]:
# `scalify` scale propagation is using `unit-scaling` static scale propagation rules.
@jsa.scalify
def matmul_fn(a, b):
    return a @ b

In [237]:
# Large reduction axis K.
M, N, K = 16, 8, 1024
ascale = 128
bscale = 64
# IID Gaussian inputs.
a = np.random.randn(M, K).astype(np.float32) * ascale
b = np.random.randn(K, N).astype(np.float32) * bscale

# The function `fn` is unchanged with unscaled inputs. 
out = matmul_fn(a, b)

print("INPUTS std:", np.std(a), np.std(b))
# Large matmul output standard deviation.
print("OUTPUT std:", np.std(out))

INPUTS std: 128.61899 62.915386
OUTPUT std: 251969.02


In [238]:
# Let's create equivalent input scaled arrays.
sa = jsa.as_scaled_array(a, scale=np.float32(ascale))
sb = jsa.as_scaled_array(b, scale=np.float32(bscale))

# Scale propagation in `matmul`
sout = matmul_fn(sa, sb)

print("INPUTS data std:", np.std(sa.data), np.std(sb.data))
print("INPUTS scales:", sa.scale, sb.scale)
# Large scale is getting incorporated in `scale`, with main `data` std ~ 1.
print("OUTPUT data std and scale:", np.std(sout.data), sout.scale)

INPUTS data std: 1.0048358 0.9830529
INPUTS scales: 128.0 64.0
OUTPUT data std and scale: 0.9611855 262144.0


In [239]:
# How about the same matmul in FP16
a_fp16 = a.astype(np.float16)
b_fp16 = b.astype(np.float16)
out_fp16 = matmul_fn(a_fp16, b_fp16)

# Finite inputs, but overflowing output.
print("Are INPUTS finite?", np.all(np.isfinite(a_fp16)), np.all(np.isfinite(b_fp16)))
print("How many OUTPUT values finite? (vs nb entries)", np.sum(np.isfinite(out_fp16)), out_fp16.size)


Are INPUTS finite? True True
How many OUTPUT values finite? (vs nb entries) 28 128


In [240]:
# Let's create equivalent input scaled arrays.
sa_fp16 = sa.astype(np.float16)
sb_fp16 = sb.astype(np.float16)

# Scale propagation in `matmul` FP16 => not overflowing.
sout_fp16 = matmul_fn(sa_fp16, sb_fp16)

print("INPUTS data std:", np.std(sa_fp16.data), np.std(sb_fp16.data))
print("INPUTS scales:", sa_fp16.scale, sb_fp16.scale)
# Large scale is getting incorporated in `scale`, with main `data` std ~ 1.
print("OUTPUT data std and scale:", np.std(sout_fp16.data), sout_fp16.scale)

# Relative error vs FP32 matmul
rel_error = np.abs(np.asarray(sout_fp16, dtype=np.float32) - out) / out
print("Scalify FP16 matmul relative error (mean/max)", np.mean(rel_error), np.max(rel_error))

INPUTS data std: 1.005 0.983
INPUTS scales: 128.0 64.0
OUTPUT data std and scale: 0.961 262144.0
Scalify FP16 matmul relative error (mean/max) 0.00057976914 0.057415348


### `scalify` dynamic rescaling

As well known, the neural-network activations, weights and gradients do not follow perfect a Gaussian assumption. As a consequence, we provide in `scalify` a way to dynamically rescale tensor at any point in the computational graph. 

In [241]:
a = np.random.randn(1024).astype(np.float32) * 64
sa_in = jsa.as_scaled_array(a, scale=np.float32(4))

print("INPUT a:", np.std(a))
print("Static scaled INPUT a:", np.std(sa_in.data), sa_in.scale)


INPUT a: 65.71072
Static scaled INPUT a: 16.42768 4.0


In [242]:
# Dynamic rescaling of scaled array, using L2 norm (rounded to power-of-two).
sa_out = jsa.ops.dynamic_rescale_l2(sa_in)
print("Dynamic (re)scaled INPUT a:", np.std(sa_out.data), sa_out.scale)

# `dynamic_rescale` operations do not change the semantic of the tensor.
npt.assert_array_almost_equal(np.asarray(sa_out), a)

Dynamic (re)scaled INPUT a: 1.02673 64.0


In [245]:
# NOTE: in normal JAX mode, these rescale operations are no-ops:
jsa.ops.dynamic_rescale_max(a) is a

True