Add timeout to collective ops to detect deadlocks.
The timeout is set as an argument to a collective op. When set to a non-zero value, a completion timeout is armed to detect staleness. If the timeout fires, execution is aborted with a DEADLINE_EXCEEDED error.

PiperOrigin-RevId: 313861868
Change-Id: I7fee45736608ad7fbcc9dd980db2fd302c9cb4df
tensorflower-gardener committed May 29, 2020
1 parent 85396ef commit 66529c3
Showing 8 changed files with 219 additions and 38 deletions.
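For context, here is a hypothetical usage sketch of the new `timeout` argument on the Python API changed below. The group/instance keys and reduction ops are illustrative only; a timeout of 0 (the default) preserves the previous behavior.

# Hypothetical usage of the timeout argument added in this commit.
# Group/instance keys are illustrative; the timeout is in seconds.
from tensorflow.python.ops import collective_ops


def all_reduce_with_timeout(t):
  return collective_ops.all_reduce(
      t,
      group_size=2,
      group_key=1,
      instance_key=1,
      merge_op='Add',
      final_op='Id',
      communication_hint='auto',
      timeout=30.0)  # non-zero arms the completion watchdog
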
60 changes: 51 additions & 9 deletions tensorflow/core/common_runtime/base_collective_executor.cc
@@ -221,23 +221,42 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
const CollectiveParams& col_params,
const string& exec_key,
StatusCallback done) {
const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);

// On any individual collective Op failure we need to abort the
// BufRendezvous so that other Ops in the instance don't hang
// waiting for transmissions that will never happen. Do so after a
// delay so that the original error status is more likely to
// propagate up, and peers are unlikely to re-create the purged
// BufRendezvous by late-arriving requests.
StatusCallback done_safe = [this, done](const Status& s) {
if (!s.ok()) {
Ref(); // Ensure this lasts until the closure executes.
SchedNonBlockingClosureAfter(1000000, [this, s] {
remote_access_->buf_rendezvous()->StartAbort(s);
Unref();
});
StatusCallback done_safe = [this, done, is_callback_called](const Status& s) {
auto should_call_callback = !is_callback_called->exchange(true);
if (should_call_callback) {
if (!s.ok()) {
Ref(); // Ensure this lasts until the closure executes.
SchedNonBlockingClosureAfter(1000000, [this, s] {
remote_access_->buf_rendezvous()->StartAbort(s);
Unref();
});
}
done(s);
}
done(s);
};

auto timeout_microseconds = static_cast<int64>(
col_params.instance.impl_details.timeout_seconds * 1'000'000);
if (timeout_microseconds > 0) {
// TODO(xldrx): Share the timeout watchdog thread among collectives.
SchedNonBlockingClosureAfter(
timeout_microseconds, [is_callback_called, done_safe] {
if (!is_callback_called->load()) {
auto status = Status(error::DEADLINE_EXCEEDED,
"Collective has timed out during execution.");
done_safe(status);
}
});
}

Tensor* output = ctx->mutable_output(0);
const Tensor* input = (col_params.instance.type == REDUCTION_COLLECTIVE ||
col_params.instance.type == GATHER_COLLECTIVE ||
@@ -284,7 +303,30 @@ void BaseCollectiveExecutor::CompleteParamsAsync(
const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
StatusCallback done) {
cp->instance.gpu_ring_order = *gpu_ring_order_;
cem_->GetParamResolver()->CompleteParamsAsync(device, cp, cancel_mgr, done);
const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
auto done_with_timeout = done;
auto timeout_microseconds =
static_cast<int64>(cp->instance.impl_details.timeout_seconds * 1'000'000);
if (timeout_microseconds > 0) {
// TODO(xldrx): Share the timeout watchdog thread among collectives.
SchedNonBlockingClosureAfter(
timeout_microseconds, [is_callback_called, done] {
if (!is_callback_called->load()) {
auto status =
Status(error::DEADLINE_EXCEEDED,
"Collective has timed out waiting for other workers.");
done(status);
}
});
done_with_timeout = [is_callback_called, done](const Status& s) {
auto should_call_callback = !is_callback_called->exchange(true);
if (should_call_callback) {
done(s);
}
};
}
cem_->GetParamResolver()->CompleteParamsAsync(device, cp, cancel_mgr,
done_with_timeout);
}

Status BaseCollectiveExecutor::CreateCollective(
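As a rough standalone illustration (not TensorFlow code), the once-only callback plus timeout watchdog pattern used in ExecuteAsync above can be sketched in Python; the helper name and error string here are made up:

import threading


def wrap_done_with_timeout(done, timeout_seconds):
  """Wraps `done` so it fires at most once, either normally or on timeout."""
  lock = threading.Lock()
  called = [False]

  def done_safe(status):
    with lock:
      if called[0]:
        return  # the collective or the watchdog already reported
      called[0] = True
    done(status)

  if timeout_seconds > 0:
    watchdog = threading.Timer(
        timeout_seconds,
        lambda: done_safe('DEADLINE_EXCEEDED: collective timed out'))
    watchdog.daemon = True
    watchdog.start()
  return done_safe

The C++ version gets the same first-caller-wins guarantee from std::atomic<bool>::exchange, and additionally delays aborting the BufRendezvous so the original error is more likely to propagate first.
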
2 changes: 2 additions & 0 deletions tensorflow/core/framework/collective.h
@@ -84,6 +84,8 @@ struct CollImplDetails {
dependencies; // collective instances on which this node depends
string communication_hint; // user-supplied hint for implementation choice,
// e.g. ring or nccl
float timeout_seconds; // If non-zero, sets a completion timeout for the
// collective op to detect staleness.
};

// Data common to all members of a collective instance.
15 changes: 14 additions & 1 deletion tensorflow/core/kernels/collective_ops.cc
@@ -85,6 +85,9 @@ class CollectiveGatherOpKernel : public CollectiveOpKernel {
OP_REQUIRES_OK(
c, c->GetAttr("communication_hint",
&col_params_.instance.impl_details.communication_hint));
OP_REQUIRES_OK(
c, c->GetAttr("timeout_seconds",
&col_params_.instance.impl_details.timeout_seconds));
const NodeDef& real_node = c->def();
col_params_.name = strings::StrCat(real_node.name(), ": Gather");
col_params_.group.device_type = c->device_type();
@@ -176,10 +179,14 @@ class CollectiveReduceOpKernel : public CollectiveOpKernel {
OP_REQUIRES_OK(
c, c->GetAttr("communication_hint",
&col_params_.instance.impl_details.communication_hint));
OP_REQUIRES_OK(
c, c->GetAttr("timeout_seconds",
&col_params_.instance.impl_details.timeout_seconds));
VLOG(2) << "CollectiveReduce instance " << col_params_.instance.instance_key
<< " merge_op " << merge_op_name << " final_op " << final_op_name
<< " communication_hint "
<< col_params_.instance.impl_details.communication_hint;
<< col_params_.instance.impl_details.communication_hint
<< " timeout " << col_params_.instance.impl_details.timeout_seconds;

const NodeDef& real_node = c->def();
col_params_.name = strings::StrCat(real_node.name(), ": Reduce(",
@@ -284,6 +291,9 @@ class CollectiveBcastSendOpKernel : public CollectiveOpKernel {
OP_REQUIRES_OK(
c, c->GetAttr("communication_hint",
&col_params_.instance.impl_details.communication_hint));
OP_REQUIRES_OK(
c, c->GetAttr("timeout_seconds",
&col_params_.instance.impl_details.timeout_seconds));
col_params_.is_source = true;
col_params_.instance.impl_details.subdiv_offsets = {0};

@@ -363,6 +373,9 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpKernel {
OP_REQUIRES_OK(
c, c->GetAttr("communication_hint",
&col_params_.instance.impl_details.communication_hint));
OP_REQUIRES_OK(
c, c->GetAttr("timeout_seconds",
&col_params_.instance.impl_details.timeout_seconds));
col_params_.is_source = false;
col_params_.instance.impl_details.subdiv_offsets = {0};

4 changes: 4 additions & 0 deletions tensorflow/core/ops/collective_ops.cc
@@ -31,6 +31,7 @@ REGISTER_OP("CollectiveReduce")
.Attr("subdiv_offsets: list(int)")
.Attr("wait_for: list(int) = []")
.Attr("communication_hint: string = 'auto'")
.Attr("timeout_seconds: float = 0")
.SetIsStateful()
.SetShapeFn(shape_inference::UnchangedShape);

@@ -43,6 +44,7 @@ REGISTER_OP("CollectiveGather")
.Attr("instance_key: int")
.Attr("shape: shape")
.Attr("communication_hint: string = 'auto'")
.Attr("timeout_seconds: float = 0")
.SetIsStateful()
.SetShapeFn([](shape_inference::InferenceContext* c) {
// Scalar input is not supported.
@@ -86,6 +88,7 @@ REGISTER_OP("CollectiveBcastSend")
.Attr("instance_key: int")
.Attr("shape: shape")
.Attr("communication_hint: string = 'auto'")
.Attr("timeout_seconds: float = 0")
.SetIsStateful()
.SetShapeFn(shape_inference::ExplicitShape);

@@ -97,6 +100,7 @@ REGISTER_OP("CollectiveBcastRecv")
.Attr("instance_key: int")
.Attr("shape: shape")
.Attr("communication_hint: string = 'auto'")
.Attr("timeout_seconds: float = 0")
.SetIsStateful()
.SetShapeFn(shape_inference::ExplicitShape);

66 changes: 52 additions & 14 deletions tensorflow/python/ops/collective_ops.py
@@ -20,8 +20,15 @@
from tensorflow.python.ops import gen_collective_ops


def all_reduce(t, group_size, group_key, instance_key, merge_op, final_op,
subdiv_offsets=(0,), communication_hint='auto'):
def all_reduce(t,
group_size,
group_key,
instance_key,
merge_op,
final_op,
subdiv_offsets=(0,),
communication_hint='auto',
timeout=0):
"""Reduces tensors collectively, across devices.
Args:
Expand All @@ -40,6 +47,9 @@ def all_reduce(t, group_size, group_key, instance_key, merge_op, final_op,
communication_hint: preferred collective communication. The implementation
may fall back to another mechanism. Options include `auto`, `ring`, and
`nccl`.
timeout: If set to a non-zero value, adds a completion timeout in seconds to
detect staleness. If the timeout fires, a DeadlineExceededError is raised.
This feature is experimental.
Returns:
An Op implementing the distributed reduction.
@@ -57,11 +67,16 @@ def all_reduce(t, group_size, group_key, instance_key, merge_op, final_op,
merge_op=merge_op,
final_op=final_op,
subdiv_offsets=subdiv_offsets,
communication_hint=communication_hint.lower())
communication_hint=communication_hint.lower(),
timeout_seconds=timeout)


def all_gather(t, group_size, group_key, instance_key,
communication_hint='auto'):
def all_gather(t,
group_size,
group_key,
instance_key,
communication_hint='auto',
timeout=0):
"""Accumulates tensors collectively, across devices, along first dimension.
Args:
@@ -73,6 +88,9 @@ def all_gather(t, group_size, group_key, instance_key,
communication_hint: preferred collective communication. The implementation
may fall back to another mechanism. Options include `auto`, `ring`, and
`nccl`.
timeout: If set to a non-zero value, adds a completion timeout in seconds to
detect staleness. If the timeout fires, a DeadlineExceededError is raised.
This feature is experimental.
Returns:
An Op implementing the distributed operation.
@@ -88,11 +106,18 @@ def all_gather(t, group_size, group_key, instance_key,
group_size=group_size,
group_key=group_key,
instance_key=instance_key,
communication_hint=communication_hint.lower())


def broadcast_send(t, shape, dtype, group_size, group_key, instance_key,
communication_hint='auto'):
communication_hint=communication_hint.lower(),
timeout_seconds=timeout)


def broadcast_send(t,
shape,
dtype,
group_size,
group_key,
instance_key,
communication_hint='auto',
timeout=0):
"""Broadcasts one tensor to a group of others, across devices.
Args:
@@ -107,6 +132,9 @@ def broadcast_send(t, shape, dtype, group_size, group_key, instance_key,
communication_hint: preferred collective communication. The implementation
may fall back to another mechanism. Options include `auto`, `ring`, and
`nccl`.
timeout: If set to a non-zero value, adds a completion timeout in seconds to
detect staleness. If the timeout fires, a DeadlineExceededError is raised.
This feature is experimental.
Returns:
An Op implementing the distributed broadcast send.
@@ -139,11 +167,17 @@ def broadcast_send(t, shape, dtype, group_size, group_key, instance_key,
group_size=group_size,
group_key=group_key,
instance_key=instance_key,
communication_hint=communication_hint.lower())
communication_hint=communication_hint.lower(),
timeout_seconds=timeout)


def broadcast_recv(shape, dtype, group_size, group_key, instance_key,
communication_hint='auto'):
def broadcast_recv(shape,
dtype,
group_size,
group_key,
instance_key,
communication_hint='auto',
timeout=0):
"""Receives a broadcasts tensor, across devices.
Args:
@@ -157,6 +191,9 @@ def broadcast_recv(shape, dtype, group_size, group_key, instance_key,
communication_hint: preferred collective communication. The implementation
may fall back to another mechanism. Options include `auto`, `ring`, and
`nccl`.
timeout: If set to a non-zero value, adds a completion timeout in seconds to
detect staleness. If the timeout fires, a DeadlineExceededError is raised.
This feature is experimental.
Returns:
An Op implementing the broadcast receive.
@@ -173,4 +210,5 @@ def broadcast_recv(shape, dtype, group_size, group_key, instance_key,
group_size=group_size,
group_key=group_key,
instance_key=instance_key,
communication_hint=communication_hint.lower())
communication_hint=communication_hint.lower(),
timeout_seconds=timeout)
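
Callers that opt into a non-zero timeout can catch the resulting error. A hedged sketch, assuming the DEADLINE_EXCEEDED status surfaces as tf.errors.DeadlineExceededError; shapes, dtypes, and keys are illustrative:

import tensorflow as tf
from tensorflow.python.ops import collective_ops

try:
  value = collective_ops.broadcast_recv(
      shape=[2, 2],
      dtype=tf.float32,
      group_size=2,
      group_key=1,
      instance_key=2,
      timeout=15.0)
except tf.errors.DeadlineExceededError:
  # Raised when the completion timeout added in this change fires.
  raise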
