Multi-GPU Workflow Hangs #11637

Closed
ntenenz opened this issue Jul 28, 2022 · 6 comments
Labels: bug (Something isn't working), NVIDIA GPU (Issues specific to NVIDIA GPUs)

@ntenenz

ntenenz commented Jul 28, 2022

Potentially related to #10969.

When run on a single device, the following code works as expected. However, if num_devices > 1, I'm observing that all devices allocate memory but an inconsistent number perform computation (as shown by nvtop / nvidia-smi), leading to an indefinite hang.

Environment:

  • OS: Ubuntu 20.04
  • Number of GPUs: 4
  • Python Version: 3.8.10
  • NVIDIA Driver: 515.48.07
  • CUDA Version: 11.7
  • JAX Version: 0.3.14

While the actual code in which I'm seeing this occur is far more complex, I've created a simpler repro.

import jax
import jax.numpy as jnp

def func(num_rows_per_device, num_devices, num_features):
    num_entries = num_rows_per_device * num_devices * num_features
    
    # often hangs
    def fn(x_):
        return jax.lax.cond(
            x_.sum() > num_entries,
            lambda: 0.,
            lambda: jnp.sum(x_**2)
        )
    
    # doesn't seem to hang
    # def fn(x_):
    #     return x_.sum()
    
    data = jnp.arange(num_entries).reshape(num_rows_per_device, num_devices, num_features).astype(jnp.float32)
    return jax.lax.map(jax.pmap(fn), data)

print(jax.jit(func, static_argnames=["num_rows_per_device", "num_devices", "num_features"])(5, 4, 2))
# seems to hang quite frequently
@ntenenz added the "bug (Something isn't working)" label Jul 28, 2022
@ntenenz changed the title from "Multi-GPU Workflow Hangs with jax.lax.cond" to "Multi-GPU Workflow Hangs" Jul 28, 2022
@hawkinsp added the "NVIDIA GPU (Issues specific to NVIDIA GPUs)" label Jul 28, 2022
@ntenenz
Author

ntenenz commented Jul 28, 2022

I know you mentioned you couldn't reproduce on a 4xT4 node @hawkinsp. Based on your experience, do you have a recommended path we can explore on our side to mitigate? As you can imagine, this is a blocker.

@hawkinsp
Member

Actually I have reproduced it and I'm looking at it. It's not deterministic; I just needed to run it more times.

@hawkinsp self-assigned this Jul 28, 2022
@ntenenz
Author

ntenenz commented Jul 28, 2022

Thanks! Let me know if I can be of assistance in running any tests / dev branches.

@hawkinsp
Member

I think the XLA patch below fixes things. I'll send a cleaned-up version of it out for review and it should make it into the next jaxlib.

The patch adds a barrier that ensures all participants complete ncclCommInitRank before any participant is allowed to issue the collective. Without a barrier, we can experience deadlocks. As best I understand it, the deadlock scenario looks like this:

Thread A:

  • calls ncclCommInitRank(), which succeeds,
  • issues the collective operation,
  • calls an operation that manipulates the device page tables, e.g., copying a device buffer to an unpinned host buffer.
  • Since this operation manipulates the device page tables, it seems to block waiting for the device stream.

Thread B:

  • calls ncclCommInitRank(), which calls cudaMalloc().
  • cudaMalloc() also manipulates the device page tables, and it cannot proceed without acquiring an internal lock around the device page-table state.
    But thread A already holds this lock, and thread A cannot make progress until thread B issues its collective.

This is a deadlock: neither thread can make progress. We can avoid the problem by adding a barrier after the calls to ncclCommInitRank(), requiring all GPUs to finish initialization before any of them issues its collective operation.
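
To make the pattern concrete, here is a minimal, self-contained sketch of the barrier the patch introduces, written against the C++ standard library rather than XLA's Abseil types; CliqueState, SimulatedInitRank, and Participant are illustrative names for this example, not XLA code. Each rank registers once its (simulated) ncclCommInitRank has finished, then waits until every rank has registered before it may issue its collective.

#include <condition_variable>
#include <cstdio>
#include <mutex>
#include <thread>
#include <vector>

// Shared per-clique state; in the real patch this role is played by
// NcclCliqueState with an absl::Mutex and absl::Notification.
struct CliqueState {
  std::mutex mu;
  std::condition_variable all_ready;
  int initialized = 0;  // ranks that have finished "ncclCommInitRank"
};

// Stand-in for ncclCommInitRank(); in XLA this may call cudaMalloc().
void SimulatedInitRank(int rank) { std::printf("rank %d: init done\n", rank); }

void Participant(CliqueState& clique, int rank, int num_ranks) {
  SimulatedInitRank(rank);

  // Barrier: record that this rank has initialized, then wait until every
  // rank has done so. Only afterwards may the collective be enqueued.
  {
    std::unique_lock<std::mutex> lock(clique.mu);
    ++clique.initialized;
    clique.all_ready.notify_all();
    clique.all_ready.wait(lock, [&] { return clique.initialized == num_ranks; });
  }

  std::printf("rank %d: issuing collective\n", rank);
}

int main() {
  constexpr int kNumRanks = 4;
  CliqueState clique;
  std::vector<std::thread> threads;
  for (int rank = 0; rank < kNumRanks; ++rank) {
    threads.emplace_back(Participant, std::ref(clique), rank, kNumRanks);
  }
  for (std::thread& t : threads) t.join();
  return 0;
}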

diff --git a/tensorflow/compiler/xla/service/gpu/nccl_utils.cc b/tensorflow/compiler/xla/service/gpu/nccl_utils.cc
index ffa157879b8..b67a1727677 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/nccl_utils.cc
@@ -20,6 +20,7 @@ limitations under the License.
 #include <utility>

 #include "absl/strings/str_format.h"
+#include "absl/synchronization/notification.h"
 #include "absl/time/time.h"
 #include "tensorflow/compiler/xla/debug_options_flags.h"
 #include "tensorflow/compiler/xla/service/global_device_id.h"
@@ -122,6 +123,12 @@ StatusOr<std::string> LocalNcclUniqueIdCallback(const NcclCliqueKey&) {
 struct NcclCliqueState {
   ncclUniqueId unique_id;
   int64_t run_id = -1;
+
+  absl::Notification ready;
+  absl::Mutex mu;
+  Status status;
+  int lowest_rank = -1;
+  absl::flat_hash_map<int, std::unique_ptr<NcclComm>> communicators;
 };

 using NcclClique = Lockable<NcclCliqueState>;
@@ -222,8 +229,11 @@ StatusOr<NcclComm::Lock> AcquireNcclComm(

   if (!clique->ok()) return clique->status();

-  auto comm_key = std::make_pair(std::move(clique_key), rank);
-  static auto& comms = *new ThreadSafeMap<decltype(comm_key), NcclComm>;
+  struct AllCommunicators {
+    absl::Mutex mu;
+    std::vector<NcclComm*> communicators ABSL_GUARDED_BY(mu);
+  };
+  static auto& all_communicators = *new AllCommunicators;

   // Launch a thread that periodically checks all NCCL communicators for
   // asynchronous errors. If an asynchronous error is observed, the communicator
@@ -233,16 +243,49 @@ StatusOr<NcclComm::Lock> AcquireNcclComm(
           tensorflow::ThreadOptions(), "nccl_async_error_thread", [&] {
             while (true) {
               absl::SleepFor(absl::Seconds(30));
-              comms.ForEachValue(CheckNcclAsyncError);
+              absl::MutexLock lock(&all_communicators.mu);
+              for (NcclComm* comm : all_communicators.communicators) {
+                CheckNcclAsyncError(*comm);
+              }
             }
           });
   (void)check_async_error_thread;  // Silence unused variable warning.

-  NcclComm::Lock comm = comms[comm_key].Acquire();
-  if (*comm == nullptr) {
-    int nranks = comm_key.first.devices().size();
+  NcclComm::Lock comm;
+  Status status;
+  if (!(**clique)->ready.HasBeenNotified()) {
+    auto comm_ptr = std::make_unique<NcclComm>();
+    comm = comm_ptr->Acquire();
+    int nranks = clique_key.devices().size();
     const ncclUniqueId& id = (**clique)->unique_id;
-    XLA_CUDA_RETURN_IF_ERROR(ncclCommInitRank(comm.get(), nranks, id, rank));
+    status = XLA_CUDA_STATUS(ncclCommInitRank(comm.get(), nranks, id, rank));
+
+    {
+      absl::MutexLock lock(&all_communicators.mu);
+      all_communicators.communicators.push_back(comm_ptr.get());
+    }
+
+    absl::MutexLock lock(&(**clique)->mu);
+    (**clique)->status.Update(status);
+    if ((**clique)->lowest_rank < 0 || rank < (**clique)->lowest_rank) {
+      (**clique)->lowest_rank = rank;
+    }
+    (**clique)->communicators[rank] = std::move(comm_ptr);
+
+    // Wait for all communicators to initialize before allowing any progress; otherwise we may get deadlocks.
+    auto all_initialized = [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED((**clique)->mu) {
+      return (**clique)->communicators.size() == num_local_participants;
+    };
+    (**clique)->mu.Await(absl::Condition(&all_initialized));
+    status = (**clique)->status;
+    if (rank == (**clique)->lowest_rank) {
+      (**clique)->ready.Notify();
+    }
+  } else {
+    comm = (**clique)->communicators[rank]->Acquire();
+  }
+  if (!(**clique)->status.ok()) {
+    return (**clique)->status;
   }
   return comm;
 }
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_utils.h b/tensorflow/compiler/xla/service/gpu/nccl_utils.h
index 39d7fbabefd..3a5949b661a 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/nccl_utils.h
@@ -84,7 +84,8 @@ class Lockable {
   // RAII type that will release the exclusive lock when it is destroyed.
   using Lock = std::unique_ptr<T, std::function<void(T*)>>;

-  explicit Lockable(T value = T()) : value_(std::move(value)) {}
+  Lockable() = default;
+  explicit Lockable(T&& value) : value_(std::move(value)) {}

   Lock Acquire() {
     absl::MutexLock lock(&mutex_);
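
A note on the Lockable change above: NcclCliqueState now holds an absl::Notification and an absl::Mutex, which are neither copyable nor movable, so the old pass-by-value constructor Lockable(T value = T()) can no longer be instantiated for that T; a defaulted constructor plus an rvalue-reference overload avoids the move. The sketch below illustrates the same issue using only the standard library; State, LockableByValue, and LockableDefault are hypothetical names for this example.

#include <mutex>
#include <utility>

// Like the patched NcclCliqueState, this state type holds a non-movable
// member (std::mutex here; absl::Mutex/absl::Notification in XLA).
struct State {
  std::mutex mu;
  int run_id = -1;
};

template <typename T>
class LockableByValue {
 public:
  // Using this constructor with T = State fails to compile: moving the
  // by-value parameter into value_ requires T to be movable.
  explicit LockableByValue(T value = T()) : value_(std::move(value)) {}
 private:
  T value_;
};

template <typename T>
class LockableDefault {
 public:
  LockableDefault() = default;  // fine: State is constructed in place
  explicit LockableDefault(T&& value) : value_(std::move(value)) {}
 private:
  T value_;
};

int main() {
  // LockableByValue<State> broken;  // would not compile
  LockableDefault<State> ok;
  (void)ok;
  return 0;
}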

@ntenenz
Author

ntenenz commented Jul 28, 2022

Thanks for such a prompt response!

copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this issue Jul 29, 2022
… ncclCommInitRank() before allowing any participant to proceed.

Fixes google/jax#11637

PiperOrigin-RevId: 464164328
@hawkinsp mentioned this issue Aug 1, 2022
@hawkinsp
Member

hawkinsp commented Aug 1, 2022

This is now fixed at head; you'll have to build jaxlib from source to get the fix or wait for us to make a release.
