Skip to content

Commit

Permalink
[export] Fix handling of float0 when exporting
Browse files Browse the repository at this point in the history
There were two problems:
  * the float0 dtype was not part of the schema,
  * invoking jax.vjp on a reloaded function failed,
    because of a mismatch in the type of symbolic zeros
    (arrays of zeros were passed where float0 was expected).

We changed the schema to add `f0`, and gave the new
enum member a value larger than all existing values
(22), to preserve backwards compatibility.
  • Loading branch information
gnecula committed Dec 18, 2023
1 parent 259c285 commit 7aba11f
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 46 deletions.
13 changes: 12 additions & 1 deletion jax/experimental/export/export.py
Expand Up @@ -31,9 +31,11 @@
import jax
from jax import sharding

from jax._src import ad_util
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
Expand Down Expand Up @@ -1067,7 +1069,16 @@ def f_flat_vjp_fwd(*args_flat):
def f_flat_vjp_bwd(residual, ct_res_flat):
args_flat = residual # residual is the primal argument flat tuple
exp_vjp = exported.vjp()
in_ct_flat = call_exported(exp_vjp)(*args_flat, *ct_res_flat)
# ct_res_flat may contain arrays of zeros where exp_vjp expects float0.
# We build the proper float0 arrays before invoking exp_vjp.
def fix_float0_ct(ct_res, expected_aval):
if expected_aval.dtype != dtypes.float0:
return ct_res
return ad_util.zeros_like_aval(expected_aval)

ct_res_fixed = map(fix_float0_ct,
ct_res_flat, exp_vjp.in_avals[len(args_flat):])
in_ct_flat = call_exported(exp_vjp)(*args_flat, *ct_res_fixed)
return in_ct_flat

f_flat.defvjp(f_flat_vjp_fwd, f_flat_vjp_bwd)
Expand Down
59 changes: 33 additions & 26 deletions jax/experimental/export/serialization.fbs
Expand Up @@ -42,36 +42,38 @@ enum AbstractValueKind: byte {
}

enum DType: byte {
bool,
i8,
i16,
i32,
i64,
ui8,
ui16,
ui32,
ui64,
f16,
f32,
f64,
c64,
c128,

bf16,

i4,
ui4,

f8_e4m3b11fnuz,
f8_e4m3fn,
f8_e4m3fnuz,
f8_e5m2,
f8_e5m2fnuz,
// Last used id: 22
bool = 0,
i8 = 1,
i16 = 2,
i32 = 3,
i64 = 4,
ui8 = 5,
ui16 = 6,
ui32 = 7,
ui64 = 8,
f0 = 22, // Used in JAX to represent float0
f16 = 9,
f32 = 10,
f64 = 11,
c64 = 12,
c128 = 13,

bf16 = 14,

i4 = 15,
ui4 = 16,

f8_e4m3b11fnuz = 17,
f8_e4m3fn = 18,
f8_e4m3fnuz = 19,
f8_e5m2 = 20,
f8_e5m2fnuz = 21,
}

table AbstractValue {
kind: AbstractValueKind;
shape: [string]; // we support shape polymorphism
shape: [string]; // Support shape polymorphism
dtype: DType;
}

Expand Down Expand Up @@ -101,6 +103,11 @@ table DisabledSafetyCheck {
}

table Exported {
/// We increment the serialization version every time we change the
/// schema, even if the change is backwards compatible.
/// Note that this field has different semantics and purpose from
/// `mlir_module_serialization_version`, which encodes
/// the calling convention of the `mlir_module_serialized`.
serialization_version: uint16;

function_name: string;
Expand Down
10 changes: 8 additions & 2 deletions jax/experimental/export/serialization.py
Expand Up @@ -36,6 +36,11 @@
T = TypeVar("T")
SerT = TypeVar("SerT")

# The _SERIALIZATION_VERSION changes when we change the serialization schema
# even if the change is backwards compatible.
# Version 1, Nov 2023, first version.
# Version 2, Dec 16th, 2023, adds the f0 dtype.
_SERIALIZATION_VERSION = 2

def serialize(exp: export.Exported, vjp_order: int = 0) -> bytearray:
"""Serialize an Exported.
Expand Down Expand Up @@ -102,7 +107,7 @@ def _serialize_exported(
vjp = _serialize_exported(builder, exp.vjp(), vjp_order - 1)

ser_flatbuf.ExportedStart(builder)
ser_flatbuf.ExportedAddSerializationVersion(builder, 1)
ser_flatbuf.ExportedAddSerializationVersion(builder, _SERIALIZATION_VERSION)
ser_flatbuf.ExportedAddFunctionName(builder, fun_name)
ser_flatbuf.ExportedAddInTree(builder, in_tree)
ser_flatbuf.ExportedAddInAvals(builder, in_avals)
Expand Down Expand Up @@ -142,7 +147,7 @@ def _serialize_array(

def _deserialize_exported(exp: ser_flatbuf.Exported) -> export.Exported:
serialization_version = exp.SerializationVersion()
if serialization_version != 1:
if serialization_version != _SERIALIZATION_VERSION:
raise NotImplementedError(
f"deserialize unsupported version {serialization_version}"
)
Expand Down Expand Up @@ -296,6 +301,7 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef):
np.dtype("uint16"): ser_flatbuf.DType.ui16,
np.dtype("uint32"): ser_flatbuf.DType.ui32,
np.dtype("uint64"): ser_flatbuf.DType.ui64,
dtypes.float0: ser_flatbuf.DType.f0,
np.dtype("float16"): ser_flatbuf.DType.f16,
np.dtype("float32"): ser_flatbuf.DType.f32,
np.dtype("float64"): ser_flatbuf.DType.f64,
Expand Down
28 changes: 17 additions & 11 deletions jax/experimental/export/serialization_generated.py
Expand Up @@ -20,20 +20,20 @@
from flatbuffers.compat import import_numpy
np = import_numpy()

class PyTreeDefKind:
class PyTreeDefKind(object):
leaf = 0
none = 1
tuple = 2
list = 3
dict = 4


class AbstractValueKind:
class AbstractValueKind(object):
shapedArray = 0
abstractToken = 1


class DType:
class DType(object):
bool = 0
i8 = 1
i16 = 2
Expand All @@ -56,20 +56,21 @@ class DType:
f8_e4m3fnuz = 19
f8_e5m2 = 20
f8_e5m2fnuz = 21
f0 = 22


class ShardingKind:
class ShardingKind(object):
unspecified = 0
hlo_sharding = 1


class DisabledSafetyCheckKind:
class DisabledSafetyCheckKind(object):
platform = 0
custom_call = 1
shape_assertions = 2


class PyTreeDef:
class PyTreeDef(object):
__slots__ = ['_tab']

@classmethod
Expand Down Expand Up @@ -161,7 +162,7 @@ def PyTreeDefEnd(builder):



class AbstractValue:
class AbstractValue(object):
__slots__ = ['_tab']

@classmethod
Expand Down Expand Up @@ -233,7 +234,7 @@ def AbstractValueEnd(builder):



class Sharding:
class Sharding(object):
__slots__ = ['_tab']

@classmethod
Expand Down Expand Up @@ -302,7 +303,7 @@ def ShardingEnd(builder):



class Effect:
class Effect(object):
__slots__ = ['_tab']

@classmethod
Expand Down Expand Up @@ -338,7 +339,7 @@ def EffectEnd(builder):



class DisabledSafetyCheck:
class DisabledSafetyCheck(object):
__slots__ = ['_tab']

@classmethod
Expand Down Expand Up @@ -384,7 +385,7 @@ def DisabledSafetyCheckEnd(builder):



class Exported:
class Exported(object):
__slots__ = ['_tab']

@classmethod
Expand All @@ -402,6 +403,11 @@ def GetRootAsExported(cls, buf, offset=0):
def Init(self, buf, pos):
self._tab = flatbuffers.table.Table(buf, pos)

# We increment the serialization version every time we change the
# schema, even if the change is backwards compatible.
# Note that this field has different semantics and purpose from
# `mlir_module_serialization_version`, which encodes
# the calling convention of the `mlir_module_serialized`.
# Exported
def SerializationVersion(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
Expand Down
51 changes: 45 additions & 6 deletions tests/export_test.py
Expand Up @@ -141,12 +141,12 @@ def _testing_multi_platform_fun_expected(x,
]


def get_exported(fun, max_vjp_orders=0,
def get_exported(fun, vjp_order=0,
**export_kwargs):
"""Like export.export but with serialization + deserialization."""
def serde_exported(*fun_args, **fun_kwargs):
exp = export.export(fun, **export_kwargs)(*fun_args, **fun_kwargs)
serialized = serialization.serialize(exp, vjp_order=max_vjp_orders)
serialized = serialization.serialize(exp, vjp_order=vjp_order)
return serialization.deserialize(serialized)
return serde_exported

Expand Down Expand Up @@ -343,15 +343,15 @@ def test_primitive_lowering(ctx, arg):
def test_grad(self):
f = lambda x: jnp.sum(jnp.sin(x))
x = np.arange(4, dtype=np.float32)
exp_f = get_exported(f, max_vjp_orders=1)(x)
exp_f = get_exported(f, vjp_order=1)(x)

f1 = export.call_exported(exp_f)
self.assertAllClose(jax.grad(f)(x), jax.grad(f1)(x))

def test_higher_order_grad(self):
f = lambda x: x ** 3
x = np.float32(4.)
exp_f = get_exported(f, max_vjp_orders=3)(x)
exp_f = get_exported(f, vjp_order=3)(x)

f1 = export.call_exported(exp_f)
self.assertAllClose(jax.grad(f)(x),
Expand All @@ -361,14 +361,53 @@ def test_higher_order_grad(self):
self.assertAllClose(jax.grad(jax.grad(jax.grad(f)))(x),
jax.grad(jax.grad(jax.grad(f1)))(x))

def test_grad_int(self):
def f(xi, xf):
return (2 * xi.T, xf.T * xf.T)

xi = np.arange(6, dtype=np.int32).reshape((2, 3))
xf = np.arange(12, dtype=np.float32).reshape((3, 4))

# Native JAX 1st order vjp
(f_outi, f_outf), f_vjp = jax.vjp(f, xi, xf)
f_outi_ct = np.ones(f_outi.shape, dtype=f_outi.dtype)
f_outf_ct = np.ones(f_outf.shape, dtype=f_outf.dtype)
xi_ct, xf_ct = f_vjp((f_outi_ct, f_outf_ct))

# Native JAX 2nd order vjp
res, f_vjp2 = jax.vjp(f_vjp, (f_outi_ct, f_outf_ct))
self.assertAllClose(res, (xi_ct, xf_ct))
(f_outi_ct2, f_outf_ct2), = f_vjp2((xi_ct, xf_ct))

exp = get_exported(f, vjp_order=2)(xi, xf)
fr = export.call_exported(exp)

res = fr(xi, xf)
self.assertAllClose(res, (f_outi, f_outf))

# Reloaded 1st order vjp
(fr_outi, fr_outf), fr_vjp = jax.vjp(fr, xi, xf)
self.assertAllClose(fr_outi, f_outi)
self.assertAllClose(fr_outf, f_outf)
xri_ct, xrf_ct = fr_vjp((f_outi_ct, f_outf_ct))
self.assertAllClose(xri_ct, xi_ct)
self.assertAllClose(xrf_ct, xf_ct)

# Reloaded 2nd order vjp
res, f_vjp2 = jax.vjp(fr_vjp, (f_outi_ct, f_outf_ct))
self.assertAllClose(res, (xi_ct, xf_ct))
(fr_outi_ct2, fr_outf_ct2), = f_vjp2((xi_ct, xf_ct))
self.assertAllClose(fr_outi_ct2, f_outi_ct2)
self.assertAllClose(fr_outf_ct2, f_outf_ct2)

def test_pytree_vjp(self):
def f(a_b_pair, *, a, b):
return (dict(res=a_b_pair, a=2. * a, b=3. * b),
jnp.sin(4. * a))

a = np.arange(4, dtype=np.float32)
b = np.arange(6, dtype=np.float32)
exp_f = get_exported(f, max_vjp_orders=1)((a, b), a=a, b=b)
exp_f = get_exported(f, vjp_order=1)((a, b), a=a, b=b)

out_ct = f((a, b), a=a, b=b) # The output has the right structure as the cotangent
def f1_jax(a, b): # For VJP, make a function without kwargs
Expand Down Expand Up @@ -902,7 +941,7 @@ def f_jax(x): # x: f32[10,20] -> f32[20,10]
if with_mesh:
stack.enter_context(mesh)
# Serialize higher-order gradients
exp = get_exported(f_jax_pjit, max_vjp_orders=2)(x)
exp = get_exported(f_jax_pjit, vjp_order=2)(x)
exp_vjp = exp.vjp()
# Try 2nd order grad as well
exp_vjp2 = exp_vjp.vjp()
Expand Down

0 comments on commit 7aba11f

Please sign in to comment.