# JAX Scaled Arithmetics / AutoScale quickstart

**JAX Scaled Arithmetics** is a thin library implementing numerically stable scaled arithmetics, allowing easy training and inference of
deep neural networks in low precision (BF16, FP16, FP8) with full scale propagation.

In [1]:
import jax
import jax_scalify as jsa
import ml_dtypes
import numpy as np

In [2]:
# The `scalify` interpreter traces the function graph and adds scale propagation where necessary.
@jsa.scalify
def fn(a, b):
    """Return `a + b`; the `scalify` decorator handles scaled inputs."""
    total = a + b
    return total

In [3]:
# Start with plain (unscaled) float16 NumPy arrays as inputs.
a = np.array([1, 2], dtype=np.float16)
b = np.array([3, 6], dtype=np.float16)
print("INPUTS:", a, b)

# No scaled arrays are involved, so no scaled arithmetics: `fn` runs in "normal" JAX mode.
out = fn(a, b)
print("OUTPUT:", out, type(out))



INPUTS: [1. 2.] [3. 6.]
OUTPUT: [4. 8.] <class 'jaxlib.xla_extension.DeviceArray'>


In [4]:
# Wrap the inputs as scaled arrays.
# NOTE: the scale dtype (float32 here) does not have to match the data dtype (float16 here).
scale_a = np.float32(1)
scale_b = np.float32(2)
sa = jsa.as_scaled_array(a, scale=scale_a)
sb = jsa.as_scaled_array(b, scale=scale_b)
print("SCALED inputs:", sa, sb)

# `as_scaled_array` does not change the value of the tensor: converting back gives the original values.
unscaled_a = np.asarray(sa)
unscaled_b = np.asarray(sb)
print("UNSCALED inputs:", unscaled_a, unscaled_b)

SCALED inputs: ScaledArray(data=array([1., 2.], dtype=float16), scale=1.0) ScaledArray(data=array([1.5, 3. ], dtype=float16), scale=2.0)
UNSCALED inputs: [1. 2.] [3. 6.]


In [5]:
# Calling `fn` with scaled arrays triggers the `scalify` graph transformation.
# NOTE: by default, scale propagation rounds scales to a power of two.
sout = fn(sa, sb)
print("SCALED OUTPUT:", sout)

# A different scale rounding mode can be selected with a `ScalifyConfig` context:
with jsa.ScalifyConfig(rounding_mode=jsa.Pow2RoundMode.NONE):
    sout_unrounded = fn(sa, sb)
    print("No scale rounding:", sout_unrounded)

SCALED OUTPUT: ScaledArray(data=DeviceArray([2., 4.], dtype=float16), scale=DeviceArray(2., dtype=float32))
No scale rounding: ScaledArray(data=DeviceArray([1.789, 3.578], dtype=float16), scale=DeviceArray(2.236068, dtype=float32))


In [6]:
# JAX Scaled Arithmetics offers basic dynamic rescaling methods, e.g. max, l1, l2.
# Here the max-based variant is applied to `sout`, the scaled output of `fn`.
sout_rescaled = jsa.ops.dynamic_rescale_max(sout)
print("RESCALED OUTPUT:", sout_rescaled)

# Equivalent methods are available to dynamically rescale gradients, for instance the
# l1 variant below (left as a bare last expression, so the notebook displays its repr):
jsa.ops.dynamic_rescale_l1_grad

RESCALED OUTPUT: ScaledArray(data=DeviceArray([0.5, 1. ], dtype=float16), scale=DeviceArray(8., dtype=float32))


functools.partial(<jax._src.custom_derivatives.custom_vjp object at 0x7fc7b337c4c0>, <function dynamic_rescale_l1_base at 0x7fc7b3380430>)

In [7]:
# NOTE: in normal JAX mode (unscaled input such as `a`), these rescale operations are no-ops:
# the identity check below compares the returned object with the input object itself.
jsa.ops.dynamic_rescale_max(a) is a

True

In [8]:
# Minimal simulated FP8 support is provided using `jax.lax.reduce_precision` and `ml_dtypes`
# (`ml_dtypes` is imported in the import cell at the top of the notebook).
# Similarly to `dynamic_rescale`, `cast_ml_dtype(_grad)` are available to cast in the forward and backward passes.
sc = jsa.as_scaled_array(np.array([17., 19.]), scale=np.float32(1))

@jsa.scalify
def cast_fn(v):
    """Cast `v` to the FP8 format `float8_e4m3fn` using `jsa.ops.cast_ml_dtype`."""
    return jsa.ops.cast_ml_dtype(v, ml_dtypes.float8_e4m3fn)

sc_fp8 = cast_fn(sc)
print(sc_fp8)

ScaledArray(data=DeviceArray([16., 20.], dtype=float32), scale=1.0)
