Skip to content

Commit

Permalink
Skip unneccessary unflattening of avals in pjit lowering path.
Browse files Browse the repository at this point in the history
The avals get flattened again when calling `from_flat_info` (here:
https://github.com/google/jax/blob/1641c8f1415a837f6f6c2537110f4be698621055/jax/_src/stages.py#L347),
so skip unflattening here.

PiperOrigin-RevId: 504260643
  • Loading branch information
LenaMartens authored and jax authors committed Jan 24, 2023
1 parent 1641c8f commit 7064be1
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,13 +266,12 @@ def lower(*args, **kwargs):

if kwargs:
args_kwargs_in_tree = in_tree
local_in_avals = in_tree.unflatten(flat_local_in_avals)
else:
args_kwargs_in_tree = treedef_tuple([in_tree, tree_flatten({})[1]])
local_in_avals = args_kwargs_in_tree.unflatten(flat_local_in_avals)

return stages.Lowered.from_flat_info(
lowering, args_kwargs_in_tree, local_in_avals, donate_argnums, out_tree)
lowering, args_kwargs_in_tree, flat_local_in_avals, donate_argnums,
out_tree)

wrapped.lower = lower
return wrapped
Expand Down

0 comments on commit 7064be1

Please sign in to comment.