Skip to content

Commit

Permalink
Avoid re-flattening in jit() when no donate_argnums are present. (goo…
Browse files Browse the repository at this point in the history
…gle#3945)

Following the same special-casing of static_argnums, this should provide a speedup specially when the number of arguments provided is large.
  • Loading branch information
Adrià Puigdomènech committed Aug 3, 2020
1 parent 1b3dd65 commit 4e873f4
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions jax/api.py
Expand Up @@ -164,7 +164,10 @@ def f_jitted(*args, **kwargs):
else:
dyn_args = args
args_flat, in_tree = tree_flatten((dyn_args, kwargs))
donated_invars = donation_vector(donate_argnums, dyn_args, kwargs)
if donate_argnums:
donated_invars = donation_vector(donate_argnums, dyn_args, kwargs)
else:
donated_invars = (False,) * len(args_flat)
for arg in args_flat: _check_arg(arg)
flat_fun, out_tree = flatten_fun(f, in_tree)
out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend,
Expand Down Expand Up @@ -1180,7 +1183,10 @@ def f_pmapped(*args, **kwargs):
else:
dyn_args, dyn_in_axes = args, in_axes
args, in_tree = tree_flatten((dyn_args, kwargs))
donated_invars = donation_vector(donate_tuple, dyn_args, kwargs)
if donate_argnums:
donated_invars = donation_vector(donate_tuple, dyn_args, kwargs)
else:
donated_invars = (False,) * len(args)
in_axes_flat = flatten_axes("pmap in_axes", in_tree, (dyn_in_axes, 0))
local_axis_size = _mapped_axis_size(in_tree, args, in_axes_flat, "pmap")
for arg in args: _check_arg(arg)
Expand Down

0 comments on commit 4e873f4

Please sign in to comment.