Skip to content

Commit

Permalink
Lazy eval for vlogging on pmap/pjit critical path.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangqiaorjc committed Jun 26, 2021
1 parent a50c273 commit 57669bf
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 19 deletions.
26 changes: 14 additions & 12 deletions jax/interpreters/pxla.py
Expand Up @@ -704,8 +704,9 @@ def parallel_callable(fun: lu.WrappedFun,
for shape, aval in safe_zip(global_arg_shapes, sharded_avals)]
else:
global_sharded_avals = sharded_avals # type: ignore
logging.vlog(2, "sharded_avals: %s", sharded_avals)
logging.vlog(2, "global_sharded_avals: %s", global_sharded_avals)
if logging.vlog_is_on(2):
logging.vlog(2, "sharded_avals: %s", sharded_avals)
logging.vlog(2, "global_sharded_avals: %s", global_sharded_avals)

with core.extend_axis_env(axis_name, global_axis_size, None): # type: ignore
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
Expand Down Expand Up @@ -740,16 +741,17 @@ def parallel_callable(fun: lu.WrappedFun,
if local_out_parts is None:
local_out_parts = out_parts

logging.vlog(2, "num_replicas: %d num_local_replicas: %d",
num_global_replicas, num_local_replicas)
logging.vlog(2, "num_partitions: %d local_num_partitions: %d",
num_partitions, local_num_partitions)
logging.vlog(2, "arg_parts: %s", arg_parts)
logging.vlog(2, "local_arg_parts: %s", local_arg_parts)
logging.vlog(2, "out_parts: %s", out_parts)
logging.vlog(2, "local_out_parts: %s", local_out_parts)
logging.vlog(2, "devices: %s", devices)
logging.vlog(2, "local_devices: %s", local_devices)
if logging.vlog_is_on(2):
logging.vlog(2, "num_replicas: %d num_local_replicas: %d",
num_global_replicas, num_local_replicas)
logging.vlog(2, "num_partitions: %d local_num_partitions: %d",
num_partitions, local_num_partitions)
logging.vlog(2, "arg_parts: %s", arg_parts)
logging.vlog(2, "local_arg_parts: %s", local_arg_parts)
logging.vlog(2, "out_parts: %s", out_parts)
logging.vlog(2, "local_out_parts: %s", local_out_parts)
logging.vlog(2, "devices: %s", devices)
logging.vlog(2, "local_devices: %s", local_devices)

num_local_shards = num_local_replicas * local_num_partitions
num_global_shards = num_global_replicas * num_partitions
Expand Down
17 changes: 10 additions & 7 deletions jax/interpreters/sharded_jit.py
Expand Up @@ -79,10 +79,11 @@ def _sharded_callable(
for arg, parts, lparts
in safe_zip(abstract_args, in_parts, local_in_parts)]

logging.vlog(2, "abstract_args: %s", abstract_args)
logging.vlog(2, "global_abstract_args: %s", global_abstract_args)
logging.vlog(2, "in_parts: %s", in_parts)
logging.vlog(2, "local_in_parts: %s", local_in_parts)
if logging.vlog_is_on(2):
logging.vlog(2, "abstract_args: %s", abstract_args)
logging.vlog(2, "global_abstract_args: %s", global_abstract_args)
logging.vlog(2, "in_parts: %s", in_parts)
logging.vlog(2, "local_in_parts: %s", local_in_parts)

jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(fun, global_abstract_args)

Expand Down Expand Up @@ -115,16 +116,18 @@ def _sharded_callable(
f"sharded_jit computation requires {local_nparts} local devices, "
f"but only {xb.local_device_count()} local devices are available.")

logging.vlog(2, "nparts: %d local_nparts: %d", nparts, local_nparts)
if logging.vlog_is_on(2):
logging.vlog(2, "nparts: %d local_nparts: %d", nparts, local_nparts)

out_parts = out_parts_thunk()

local_out_parts = local_out_parts_thunk()
if local_out_parts is None:
local_out_parts = out_parts

logging.vlog(2, "out_parts: %s", out_parts)
logging.vlog(2, "local_out_parts: %s", local_out_parts)
if logging.vlog_is_on(2):
logging.vlog(2, "out_parts: %s", out_parts)
logging.vlog(2, "local_out_parts: %s", local_out_parts)

local_out_avals = [pxla.get_local_aval(out, parts, lparts)
for out, parts, lparts
Expand Down

0 comments on commit 57669bf

Please sign in to comment.