239 changes: 147 additions & 92 deletions flax/linen/fp8_ops.py
@@ -139,8 +139,7 @@ def quantize(x, q_dtype, scale, compute_dtype):
def dequantize(x, dq_dtype, scale):
return x.astype(dq_dtype) * jnp.broadcast_to(scale.astype(dq_dtype), x.shape)


def quantize_dequantize(x, q_dtype, scale, compute_dtype):
def qdq(x, q_dtype, scale, compute_dtype):
qx = quantize(x, q_dtype, scale, compute_dtype)
return dequantize(qx, x.dtype, scale)

@@ -165,8 +164,8 @@ def compute_amax_history(x, amax_history):
return new_history


def quantize_and_update(
x, q_dtype, scale, amax_history, compute_dtype, use_direct_quant=False
def update_fp8_meta(
x, q_dtype, scale, amax_history
):
is_fmax32 = (scale.dtype == fm32 and amax_history.dtype == fm32)
# convert fm32->f32 so we can do math
@@ -181,20 +180,20 @@ def quantize_and_update(
new_scale = compute_scale(amax_from_history, scale, dtype_max)
new_history = compute_amax_history(x, amax_history)

# convert f32->fmax32 so the autodiff system accumulates fp8 meta correctly
if is_fmax32:
new_history = lax.convert_element_type(new_history, fp32_max_grad)
new_scale = lax.convert_element_type(new_scale, fp32_max_grad)

# Quantize the input
if not use_direct_quant:
qx = quantize_dequantize(x, q_dtype, new_scale, compute_dtype)
return qx, new_scale, new_history

return new_scale, new_history

def quantize_dequantize_update(x, q_dtype, scale, amax_history, compute_dtype):
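# Fake-quantization path: refresh the fp8 meta (scale and amax history), then
# run quantize followed by dequantize on x with the refreshed scale.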
updated_scale, updated_history = update_fp8_meta(x, q_dtype, scale, amax_history)
qdq_x = qdq(x, q_dtype, _fm32_to_float32(updated_scale), compute_dtype)
return qdq_x, updated_scale, updated_history

return qx, new_scale, new_history
def _fm32_to_float32(value):
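# fm32 carries fp8 meta through autodiff with max-based gradient accumulation;
# convert back to plain float32 before doing arithmetic with the value.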
if value.dtype == fm32:
return lax.convert_element_type(value, jnp.float32)
return value

def dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
preferred_element_type: DTypeLike | None,
@@ -242,14 +241,14 @@ def dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision,

@partial(custom_vjp, nondiff_argnums=(0, 1))
def in_qdq(compute_dtype, q_dtype, inp, scale, amax_history):
qin, _, _ = quantize_and_update(
qin, _, _ = quantize_dequantize_update(
inp, q_dtype, scale, amax_history, compute_dtype
)
return qin


def in_qdq_fwd(compute_dtype, q_dtype, inp, scale, amax_history):
qin, new_scale, new_history = quantize_and_update(
qin, new_scale, new_history = quantize_dequantize_update(
inp, q_dtype, scale, amax_history, compute_dtype
)
return qin, (new_scale, new_history)
@@ -275,7 +274,7 @@ def out_qdq_fwd(compute_dtype, q_dtype, out, scale, amax_history):

def out_qdq_bwd(compute_dtype, q_dtype, res, g):
scale, amax_history = res
q_g, new_scale, new_history = quantize_and_update(
q_g, new_scale, new_history = quantize_dequantize_update(
g, q_dtype, scale, amax_history, compute_dtype
)
return q_g, new_scale, new_history
@@ -284,91 +283,103 @@ def out_qdq_bwd(compute_dtype, q_dtype, res, g):
out_qdq.defvjp(out_qdq_fwd, out_qdq_bwd)


def q_dot_dq_impl(
lhs,
rhs,
lhs_scale,
rhs_scale,
out_grad_scale,
lhs_amax_history,
rhs_amax_history,
out_grad_amax_history,
compute_dtype,
dimension_numbers,
precision,
preferred_element_type,
is_training
):
new_lhs_scale, new_lhs_amax_history = quantize_and_update(
lhs,
jnp.float8_e4m3fn,
lhs_scale,
lhs_amax_history,
compute_dtype,
use_direct_quant=True
)
new_rhs_scale, new_rhs_amax_history = quantize_and_update(
rhs,
jnp.float8_e4m3fn,
rhs_scale,
rhs_amax_history,
compute_dtype,
use_direct_quant=True
@partial(custom_vjp, nondiff_argnums=(0, 1))
def in_q(compute_dtype, q_dtype, inp, scale, amax_history):
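# Quantize-only input path: refresh the fp8 meta, then quantize the input with
# the new scale; dequantization is deferred to out_dq after the matmul.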
new_scale, _ = update_fp8_meta(inp, q_dtype, scale, amax_history)
qin = quantize(inp, q_dtype, _fm32_to_float32(new_scale), compute_dtype)
return qin, new_scale

def in_q_fwd(compute_dtype, q_dtype, inp, scale, amax_history):
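# Forward rule: same computation as in_q, additionally saving the refreshed
# scale and amax history as residuals for the backward pass.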
new_scale, new_history = update_fp8_meta(inp, q_dtype, scale, amax_history)
qin = quantize(inp, q_dtype, _fm32_to_float32(new_scale), compute_dtype)
return (qin, new_scale), (new_scale, new_history)

def in_q_bwd(compute_dtype, q_dtype, res, _):
new_scale, new_history = res
# No gradient is computed for inp; the refreshed scale and amax history are
# returned in place of gradients so the fp8 meta still gets updated.
return None, new_scale, new_history

in_q.defvjp(in_q_fwd, in_q_bwd)


@partial(custom_vjp, nondiff_argnums=(0, ))
def out_dq(dq_type, lhs_scale, rhs_scale, out):
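# Dequantize-only output path: scale the fp8 matmul result by the product of
# the lhs and rhs scales; the backward rule passes the gradient through unchanged.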
q_out = dequantize(
out,
dq_type,
_fm32_to_float32(lhs_scale) * _fm32_to_float32(rhs_scale)
)
return q_out

def out_dq_fwd(dq_type, lhs_scale, rhs_scale, out):
return out_dq(dq_type, lhs_scale, rhs_scale, out), None

def out_dq_bwd(dq_type, _, g):
return None, None, g

out_dq.defvjp(out_dq_fwd, out_dq_bwd)

q_lhs = quantize(lhs, jnp.float8_e4m3fn, new_lhs_scale, preferred_element_type)
q_rhs = quantize(rhs, jnp.float8_e4m3fn, new_rhs_scale, preferred_element_type)

def quantized_dot_impl(
lhs,
q_lhs,
lhs_scale,  # actually the refreshed (new) lhs scale
rhs,
q_rhs,
rhs_scale,  # actually the refreshed (new) rhs scale
out_grad_scale,  # previous out-grad scale
out_grad_amax_history,  # previous out-grad amax history
compute_dtype,
dimension_numbers,
precision,
preferred_element_type,
is_training
):
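# Core fp8 dot: run lax.dot_general on the pre-quantized operands; when
# training, stash the residuals needed by the custom backward pass.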
out = lax.dot_general(
q_lhs,
q_rhs,
dimension_numbers,
preferred_element_type=preferred_element_type,
precision=lax.Precision.DEFAULT,
)

out = dequantize(out, preferred_element_type, new_lhs_scale * new_rhs_scale)
if is_training:
res = (
lhs,
rhs,
q_lhs,
lhs_scale,
rhs,
q_rhs,
new_lhs_scale,
new_rhs_scale,
rhs_scale,
out_grad_scale,
new_lhs_amax_history,
new_rhs_amax_history,
out_grad_amax_history,
)
return out, res
else:
return out


@partial(custom_vjp, nondiff_argnums=(8, 9, 10, 11))
def q_dot_dq(
def quantized_dot(
lhs,
q_lhs,
lhs_scale,  # actually the refreshed (new) lhs scale
rhs,
lhs_scale,
q_rhs,
rhs_scale,
out_grad_scale,
lhs_amax_history,
rhs_amax_history,
out_grad_amax_history,
out_grad_scale,  # previous out-grad scale
out_grad_amax_history,  # previous out-grad amax history
compute_dtype,
dimension_numbers,
precision=None,
preferred_element_type=None
):
return q_dot_dq_impl(
return quantized_dot_impl(
lhs,
rhs,
q_lhs,
lhs_scale,
rhs,
q_rhs,
rhs_scale,
out_grad_scale,
lhs_amax_history,
rhs_amax_history,
out_grad_amax_history,
compute_dtype,
dimension_numbers,
@@ -377,29 +388,28 @@ def q_dot_dq(
is_training=False,
)


def q_dot_dq_fwd(
def quantized_dot_fwd(
lhs,
rhs,
q_lhs,
lhs_scale,
rhs,
q_rhs,
rhs_scale,
out_grad_scale,
lhs_amax_history,
rhs_amax_history,
out_grad_amax_history,
compute_dtype,
dimension_numbers,
precision,
preferred_element_type,
):
return q_dot_dq_impl(
return quantized_dot_impl(
lhs,
rhs,
q_lhs,
lhs_scale,
rhs,
q_rhs,
rhs_scale,
out_grad_scale,
lhs_amax_history,
rhs_amax_history,
out_grad_amax_history,
compute_dtype,
dimension_numbers,
@@ -408,8 +418,7 @@ def q_dot_dq_fwd(
is_training=True
)


def q_dot_dq_bwd(
def quantized_dot_bwd(
compute_dtype,
dimension_numbers,
precision,
@@ -419,27 +428,23 @@ def q_dot_dq_bwd(
):
(
lhs,
rhs,
q_lhs,
lhs_scale,
rhs,
q_rhs,
new_lhs_scale,
new_rhs_scale,
rhs_scale,
out_grad_scale,
new_lhs_amax_history,
new_rhs_amax_history,
out_grad_amax_history,
) = res

new_out_grad_scale, new_out_grad_amax_history = quantize_and_update(
new_out_grad_scale, new_out_grad_amax_history = update_fp8_meta(
g,
jnp.float8_e5m2,
out_grad_scale,
out_grad_amax_history,
compute_dtype,
use_direct_quant=True
)

q_g = quantize(g, jnp.float8_e5m2, new_out_grad_scale, preferred_element_type)
q_g = quantize(g, jnp.float8_e5m2, _fm32_to_float32(new_out_grad_scale), preferred_element_type)

grad_lhs = dot_general_transpose_lhs(
q_g,
@@ -449,7 +454,11 @@ def q_dot_dq_bwd(
precision=lax.Precision.HIGHEST,
preferred_element_type=preferred_element_type,
)
grad_lhs = dequantize(grad_lhs, preferred_element_type, new_rhs_scale * new_out_grad_scale)
grad_lhs = dequantize(
grad_lhs,
preferred_element_type,
_fm32_to_float32(rhs_scale) * _fm32_to_float32(new_out_grad_scale)
)

grad_rhs = dot_general_transpose_rhs(
q_g,
@@ -459,21 +468,67 @@ def out_qdq_bwd(compute_dtype, q_dtype, res, g):
precision=lax.Precision.HIGHEST,
preferred_element_type=preferred_element_type,
)
grad_rhs = dequantize(grad_rhs, preferred_element_type, new_lhs_scale * new_out_grad_scale)
grad_rhs = dequantize(
grad_rhs,
preferred_element_type,
_fm32_to_float32(lhs_scale) * _fm32_to_float32(new_out_grad_scale)
)

return (
grad_lhs,
None,
None,
grad_rhs,
new_lhs_scale,
new_rhs_scale,
None,
None,
new_out_grad_scale,
new_lhs_amax_history,
new_rhs_amax_history,
new_out_grad_amax_history,
)

q_dot_dq.defvjp(q_dot_dq_fwd, q_dot_dq_bwd)
quantized_dot.defvjp(quantized_dot_fwd, quantized_dot_bwd)

# Convenience wrapper composing in_q -> quantized_dot -> out_dq (quantize-dot-dequantize).
def q_dot_dq(
lhs,
rhs,
lhs_scale,
rhs_scale,
out_grad_scale,
lhs_amax_history,
rhs_amax_history,
out_grad_amax_history,
compute_dtype,
dimension_numbers,
precision=None,
preferred_element_type=None
):
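# Quantize both operands (refreshing their fp8 meta), run the fp8 dot, then
# dequantize the result with the refreshed scales.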
q_lhs, new_lhs_scale = in_q(
compute_dtype, jnp.float8_e4m3fn, lhs, lhs_scale, lhs_amax_history
)
q_rhs, new_rhs_scale = in_q(
compute_dtype, jnp.float8_e4m3fn, rhs, rhs_scale, rhs_amax_history
)
y = quantized_dot(
lhs,
q_lhs,
new_lhs_scale,
rhs,
q_rhs,
new_rhs_scale,
out_grad_scale,
out_grad_amax_history,
compute_dtype,
dimension_numbers,
precision,
preferred_element_type
)
y = out_dq(
dq_type=preferred_element_type,
lhs_scale=new_lhs_scale,
rhs_scale=new_rhs_scale,
out=y
)
return y # type: ignore
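
# For reviewers: a minimal, self-contained usage sketch of the wrapper above.
# Everything in this helper is illustrative only; the helper name, tensor
# shapes, the scale / amax-history initial values, and the amax-history length
# are assumptions, not values fixed by this change.
def _example_q_dot_dq_usage():
  lhs = jnp.ones((4, 8), jnp.bfloat16)
  rhs = jnp.ones((8, 16), jnp.bfloat16)
  scale = jnp.ones((1,), jnp.float32)
  history = jnp.zeros((1024,), jnp.float32)
  return q_dot_dq(
      lhs, rhs,
      scale, scale, scale,          # lhs / rhs / out-grad scales
      history, history, history,    # lhs / rhs / out-grad amax histories
      jnp.bfloat16,                 # compute_dtype
      (((1,), (0,)), ((), ())),     # dimension_numbers for a plain matmul
      preferred_element_type=jnp.bfloat16,
  )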

@partial(custom_jvp, nondiff_argnums=(2, 3, 4))
def dot_general_with_precision(