Skip to content

Commit

Permalink
[host_callback] Make logging statements conditional.
Browse files Browse the repository at this point in the history
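Calling logging.vlog unconditionally with f-string messages meant the
messages were formatted eagerly, so significant time was wasted even when
verbose logging was off. The calls are now guarded by logging.vlog_is_on.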

PiperOrigin-RevId: 380559720
  • Loading branch information
gnecula authored and jax authors committed Jun 21, 2021
1 parent 50f48e8 commit 64974f3
Showing 1 changed file with 36 additions and 20 deletions.
56 changes: 36 additions & 20 deletions jax/experimental/host_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,8 +941,9 @@ def _unpack_transform(name, *params):
try:
arg = api.tree_unflatten(arg_treedef, arrays)
unpacked_transforms = _unpack_transforms(transforms)
logging.vlog(2,
f"Outside call consumer invoking call_func {callback} with {arg}, device={device}, transforms={unpacked_transforms}")
if logging.vlog_is_on(2):
logging.vlog(2,
f"Outside call consumer invoking call_func {callback}, device={device}, transforms={unpacked_transforms}")
res = callback(arg, device, unpacked_transforms)

if identity:
Expand All @@ -960,10 +961,11 @@ def _unpack_transform(name, *params):

canonical_flat_results = tuple(util.safe_map(xla.canonicalize_dtype, actual_flat_results))
actual_flat_results_aval = _values_to_avals(canonical_flat_results)
logging.vlog(
2,
f"Outside call consumer {callback} result {res} : {flat_results_aval}. Sending to infeed for device {device}."
)
if logging.vlog_is_on(2):
logging.vlog(
2,
f"Outside call consumer {callback} result {flat_results_aval}. Sending to infeed for device {device}."
)

if not all(ea.strip_weak_type() == ra.strip_weak_type()
for ea, ra in util.safe_zip(flat_results_aval,
Expand Down Expand Up @@ -993,10 +995,12 @@ def _unpack_transform(name, *params):
# TODO: implement a proper error handling for TPU
if device.platform != "tpu":
canonical_flat_results = [xla.canonicalize_dtype(np.arange(12345, dtype=np.int8))]
logging.vlog(2, f"Outside call consumer {callback} exception {e}. Sending to infeed the error result.")
if logging.vlog_is_on(2):
logging.vlog(2, f"Outside call consumer {callback} exception {e}. Sending to infeed the error result.")
device.transfer_to_infeed(tuple(canonical_flat_results))
else:
logging.vlog(2, f"Outside call consumer {callback} exception {e}. On TPU we do not send infeed.")
if logging.vlog_is_on(2):
logging.vlog(2, f"Outside call consumer {callback} exception {e}. On TPU we do not send infeed.")
raise e # Let the exception propagate


Expand Down Expand Up @@ -1627,9 +1631,11 @@ def _initialize_outfeed_receiver(
itertools.chain(*[backend.local_devices() for backend in clients]))
_outfeed_receiver.clients = clients # type: ignore[assignment]
_outfeed_receiver.devices = devices # type: ignore[assignment]
logging.vlog(
2, f"Starting outfeed_receiver for {[str(d) for d in devices]}. "
f"max_callback_queue_size_bytes={max_callback_queue_size_bytes}")
if logging.vlog_is_on(2):
logging.vlog(
2,
f"Starting outfeed_receiver for {[str(d) for d in devices]}. "
f"max_callback_queue_size_bytes={max_callback_queue_size_bytes}")
_outfeed_receiver.receiver = outfeed_receiver_module.start(
_outfeed_receiver_callback, tuple(clients),
max_callback_queue_size_bytes)
Expand Down Expand Up @@ -1661,9 +1667,11 @@ def barrier_wait(logging_name: Optional[str] = None):
for this invocation. See `Debugging` in the module documentation.
"""
logging_name = logging_name or ""
logging.vlog(2, f"barrier_wait[{logging_name}]: start")
if logging.vlog_is_on(2):
logging.vlog(2, f"barrier_wait[{logging_name}]: start")
if not _outfeed_receiver.receiver:
logging.vlog(2, f"barrier_wait[{logging_name}]: receiver not started")
if logging.vlog_is_on(2):
logging.vlog(2, f"barrier_wait[{logging_name}]: receiver not started")
return

lock = threading.Lock()
Expand All @@ -1672,22 +1680,30 @@ def barrier_wait(logging_name: Optional[str] = None):

def barrier_tap(dev_idx, _):
nonlocal num_at_large
logging.vlog(
2, f"barrier_wait[{logging_name}]: at barrier_tap for device {_outfeed_receiver.devices[dev_idx]} "
f". Thread {threading.current_thread()}")
if logging.vlog_is_on(2):
logging.vlog(
2,
f"barrier_wait[{logging_name}]: at barrier_tap for device {_outfeed_receiver.devices[dev_idx]} "
f". Thread {threading.current_thread()}")
with lock:
num_at_large -= 1
logging.vlog(2, f"barrier_wait[{logging_name}]: still waiting for {num_at_large} barrier_tap")
if logging.vlog_is_on(2):
logging.vlog(2, f"barrier_wait[{logging_name}]: still waiting for {num_at_large} barrier_tap")
cv.notify()

for d_idx, d in enumerate(_outfeed_receiver.devices):
logging.vlog(2, f"barrier_wait[{logging_name}]: enqueueing barrier on device {d}")
if logging.vlog_is_on(2):
logging.vlog(2,
f"barrier_wait[{logging_name}]: enqueueing barrier on device {d}")
x_on_dev = api.device_put(d_idx, device=d)
api.jit(lambda x: id_tap(barrier_tap, x), device=d)(x_on_dev)
logging.vlog(2, f"barrier_wait[{logging_name}]: waiting for callbacks")
if logging.vlog_is_on(2):
logging.vlog(2,
f"barrier_wait[{logging_name}]: waiting for callbacks")
with lock:
cv.wait_for(lambda: num_at_large == 0)
logging.vlog(2, f"barrier_wait[{logging_name}]: done")
if logging.vlog_is_on(2):
logging.vlog(2, f"barrier_wait[{logging_name}]: done")
if _outfeed_receiver.last_callback_exception is not None:
last_exception, formatted_last_exception = _outfeed_receiver.last_callback_exception
_outfeed_receiver.last_callback_exception = None
Expand Down

0 comments on commit 64974f3

Please sign in to comment.