Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ravel_pytree now produces jit-compatible unravel functions #13834

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
96 changes: 77 additions & 19 deletions jax/_src/flatten_util.py
Expand Up @@ -12,12 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import warnings
from typing import Callable, Sequence

import numpy as np

from jax._src.tree_util import tree_flatten, tree_unflatten
from jax._src.util import safe_zip, unzip2
from jax._src.tree_util import tree_flatten, tree_unflatten, PyTreeDef
from jax._src.util import safe_zip, unzip2, HashableArrayWrapper
from jax._src.typing import DType, Shape

import jax.numpy as jnp
from jax._src import dtypes
Expand All @@ -26,6 +29,24 @@
zip = safe_zip


@dataclasses.dataclass(frozen=True)
class UnravelPyTree:
  """Hashable callable that rebuilds a pytree from a flat array.

  Calling an instance applies ``unravel_list`` to the flat input to obtain a
  list of leaves and then unflattens those leaves with ``treedef``.

  Instances hash and compare by their two fields, so two instances built
  from equal ``unravel_list`` and ``treedef`` values are interchangeable.

  Attributes are documented inline below.
  """
  # Callable mapping a flat array to the list of leaves.
  unravel_list: Callable
  # Tree structure used to reassemble the leaves.
  treedef: PyTreeDef

  def __call__(self, flat):
    """Return the pytree obtained by unraveling ``flat``."""
    return tree_unflatten(self.treedef, self.unravel_list(flat))

  def __hash__(self):
    return hash((self.unravel_list, self.treedef))

  def __eq__(self, other):
    # Defer to the other operand for foreign types instead of answering
    # False outright; `a == other` still evaluates to False when neither
    # side knows how to compare.
    if not isinstance(other, UnravelPyTree):
      return NotImplemented
    return (self.unravel_list == other.unravel_list
            and self.treedef == other.treedef)



def ravel_pytree(pytree):
"""Ravel (flatten) a pytree of arrays down to a 1D array.

Expand All @@ -47,9 +68,60 @@ def ravel_pytree(pytree):
"""
leaves, treedef = tree_flatten(pytree)
flat, unravel_list = _ravel_list(leaves)
unravel_pytree = lambda flat: tree_unflatten(treedef, unravel_list(flat))
unravel_pytree = UnravelPyTree(unravel_list, treedef)
return flat, unravel_pytree

@dataclasses.dataclass(frozen=True)
class UnravelListWithSameDtype:
  """Hashable callable that splits a flat array back into shaped chunks.

  Calling an instance splits the input at ``indices[:-1]`` and reshapes each
  resulting chunk to the corresponding entry of ``shapes``. No dtype
  conversion is performed.

  Instances hash and compare by value so that equal instances can be
  treated as the same object by hash-based caches.
  """
  # Target shape of each chunk, in order.
  shapes: Sequence[Shape]
  # Split offsets; the last entry is dropped before splitting.
  indices: np.ndarray

  def __call__(self, arr):
    """Return the list of reshaped chunks of ``arr``."""
    chunks = jnp.split(arr, self.indices[:-1])
    return [chunk.reshape(shape) for chunk, shape in zip(chunks, self.shapes)]

  def _shape_key(self):
    # Normalise to nested tuples: `shapes` is only promised to be a
    # Sequence, and a list (outer or inner) would be unhashable.
    return tuple(tuple(shape) for shape in self.shapes)

  def __hash__(self):
    return hash((self._shape_key(), HashableArrayWrapper(self.indices)))

  def __eq__(self, other):
    if not isinstance(other, UnravelListWithSameDtype):
      return NotImplemented
    # np.array_equal is shape-safe: `np.all(a == b)` misbehaves (and raises
    # on recent NumPy) when the two index arrays differ in length.
    return (self._shape_key() == other._shape_key()
            and bool(np.array_equal(self.indices, other.indices)))


# When there is more than one distinct input dtype, we perform type
# conversions and produce a dtype-specific unravel function.
@dataclasses.dataclass(frozen=True)
class UnravelListWithDifferentDtypes:
  """Hashable, dtype-specific callable that splits a flat array into leaves.

  Calling an instance checks that the input has dtype ``to_dtype``, splits
  it at ``indices[:-1]``, reshapes each chunk to the matching entry of
  ``shapes`` and converts it to the matching entry of ``from_dtypes``.

  Instances hash and compare by value so that equal instances can be
  treated as the same object by hash-based caches.

  Raises (from ``__call__``):
    TypeError: if the input array's dtype differs from ``to_dtype``.
  """
  # Dtype to restore for each chunk, in order.
  from_dtypes: Sequence[DType]
  # Dtype the flat input is required to have.
  to_dtype: DType
  # Target shape of each chunk, in order.
  shapes: Sequence[Shape]
  # Split offsets; the last entry is dropped before splitting.
  indices: np.ndarray

  def __call__(self, arr):
    """Return the list of reshaped, dtype-restored chunks of ``arr``."""
    arr_dtype = dtypes.dtype(arr)
    if arr_dtype != self.to_dtype:
      raise TypeError(f"unravel function given array of dtype {arr_dtype}, "
                      f"but expected dtype {self.to_dtype}")
    chunks = jnp.split(arr, self.indices[:-1])
    with warnings.catch_warnings():
      warnings.simplefilter("ignore")  # ignore complex-to-real cast warning
      return [lax.convert_element_type(chunk.reshape(shape), dtype)
              for chunk, shape, dtype in zip(chunks, self.shapes,
                                             self.from_dtypes)]

  def _static_key(self):
    # `from_dtypes` and `shapes` are only promised to be Sequences and may
    # arrive as lists, which are unhashable; normalise them to tuples so
    # that hashing works and equality agrees with the hash.
    return (tuple(self.from_dtypes), self.to_dtype,
            tuple(tuple(shape) for shape in self.shapes))

  def __hash__(self):
    return hash((self._static_key(), HashableArrayWrapper(self.indices)))

  def __eq__(self, other):
    if not isinstance(other, UnravelListWithDifferentDtypes):
      return NotImplemented
    # np.array_equal is shape-safe: `np.all(a == b)` misbehaves (and raises
    # on recent NumPy) when the two index arrays differ in length.
    return (self._static_key() == other._static_key()
            and bool(np.array_equal(self.indices, other.indices)))

def _ravel_list(lst):
if not lst: return jnp.array([], jnp.float32), lambda _: []
from_dtypes = [dtypes.dtype(l) for l in lst]
Expand All @@ -61,25 +133,11 @@ def _ravel_list(lst):
# Skip any dtype conversion, resulting in a dtype-polymorphic `unravel`.
# See https://github.com/google/jax/issues/7809.
del from_dtypes, to_dtype
def unravel(arr):
chunks = jnp.split(arr, indices[:-1])
return [chunk.reshape(shape) for chunk, shape in zip(chunks, shapes)]
raveled = jnp.concatenate([jnp.ravel(e) for e in lst])
unravel = UnravelListWithSameDtype(shapes, indices)
return raveled, unravel

# When there is more than one distinct input dtype, we perform type
# conversions and produce a dtype-specific unravel function.
def unravel(arr):
arr_dtype = dtypes.dtype(arr)
if arr_dtype != to_dtype:
raise TypeError(f"unravel function given array of dtype {arr_dtype}, "
f"but expected dtype {to_dtype}")
chunks = jnp.split(arr, indices[:-1])
with warnings.catch_warnings():
warnings.simplefilter("ignore") # ignore complex-to-real cast warning
return [lax.convert_element_type(chunk.reshape(shape), dtype)
for chunk, shape, dtype in zip(chunks, shapes, from_dtypes)]

ravel = lambda e: jnp.ravel(lax.convert_element_type(e, to_dtype))
raveled = jnp.concatenate([ravel(e) for e in lst])
unravel = UnravelListWithDifferentDtypes(from_dtypes, to_dtype, shapes, indices)
return raveled, unravel
23 changes: 23 additions & 0 deletions jax/_src/util.py
Expand Up @@ -513,6 +513,29 @@ def __eq__(self, other):
return self.x == other.x if self.hash is not None else self.x is other.x


class HashableArrayWrapper:
  """Wraps an array so it can be hashed and compared by value.

  The hash is computed lazily from the array's string form and cached on
  the wrapper. Equality requires the other operand to be a
  ``HashableArrayWrapper`` whose array has the same dtype, shape and
  elements.
  """
  x: np.ndarray
  _hash: Optional[int]

  def __init__(self, x):
    self.x = x
    self._hash = None  # filled in on first call to __hash__

  def __hash__(self):
    if self._hash is None:
      # This is a simple choice of hash function, that also works for very
      # large arrays.
      # A more sophisticated choice here would be to:
      # (a) for small arrays, hash the underlying buffer;
      # (b) for large arrays, randomly subsample elements.
      # TODO(kidger): do the above if necessary.
      self._hash = hash(str(self.x))
    return self._hash

  def __eq__(self, other):
    if not isinstance(other, HashableArrayWrapper):
      return NotImplemented
    mine, theirs = np.asarray(self.x), np.asarray(other.x)
    # Require matching dtypes: the hash is derived from the printed form,
    # which differs between e.g. int and float arrays holding equal values.
    # np.array_equal then compares shape and elements without the
    # broadcasting errors that `np.all(a == b)` hits on mismatched shapes.
    return mine.dtype == theirs.dtype and bool(np.array_equal(mine, theirs))


def _original_func(f):
if isinstance(f, property):
return cast(property, f).fget
Expand Down
19 changes: 19 additions & 0 deletions tests/tree_util_test.py
Expand Up @@ -442,6 +442,25 @@ def testDtypeMonomorphicUnravel(self):
with self.assertRaisesRegex(TypeError, 'but expected dtype'):
_ = unravel(y)

def test_no_recompile(self):
x1 = jnp.array([1, 2])
x2 = jnp.array([3, 4])
x_flat1, unravel1 = flatten_util.ravel_pytree((x1, x2))
x_flat2, unravel2 = flatten_util.ravel_pytree((x1, x2))
num_traces = 0

def run(flat, unravel):
nonlocal num_traces
num_traces += 1
flat = flat + 1
return unravel(flat)

run = jax.jit(run, static_argnums=1)

run(x_flat1, unravel1)
run(x_flat2, unravel2)
self.assertEqual(num_traces, 1)


class TreePrefixErrorsTest(jtu.JaxTestCase):

Expand Down