From d0d4436589c9d7a64f8c296135e337117e0a378f Mon Sep 17 00:00:00 2001 From: Denver Coneybeare Date: Wed, 17 Mar 2021 15:19:20 -0400 Subject: [PATCH 1/2] JniRunnable: Fix a deadlock when Detach() is called from Run(). --- firestore/src/android/jni_runnable_android.cc | 5 +- firestore/src/android/jni_runnable_android.h | 30 ++++- .../android/jni_runnable_android_test.cc | 124 ++++++++++++++---- .../firestore/internal/cpp/JniRunnable.java | 71 ++++++---- 4 files changed, 172 insertions(+), 58 deletions(-) diff --git a/firestore/src/android/jni_runnable_android.cc b/firestore/src/android/jni_runnable_android.cc index 6fd647a5eb..50a590b34d 100644 --- a/firestore/src/android/jni_runnable_android.cc +++ b/firestore/src/android/jni_runnable_android.cc @@ -1,5 +1,6 @@ #include "firestore/src/android/jni_runnable_android.h" +#include "app/src/assert.h" #include "app/src/util_android.h" #include "firestore/src/jni/declaration.h" #include "firestore/src/jni/env.h" @@ -30,9 +31,7 @@ Method kRunOnNewThread("runOnNewThread", Constructor kConstructor("(J)V"); void NativeRun(JNIEnv* env, jobject java_object, jlong data) { - if (data == 0) { - return; - } + FIREBASE_ASSERT_MESSAGE(data != 0, "NativeRun() invoked with data==0"); reinterpret_cast(data)->Run(); } diff --git a/firestore/src/android/jni_runnable_android.h b/firestore/src/android/jni_runnable_android.h index cf20b5904e..8ee9439424 100644 --- a/firestore/src/android/jni_runnable_android.h +++ b/firestore/src/android/jni_runnable_android.h @@ -52,12 +52,11 @@ class JniRunnableBase { * object's `run()` method will do nothing and complete as if successful. * * This method will block until all active invocations of `Run()` have - * completed, and will cause new invocations of the Java `Runnable` object's - * `run()` that occur while this method is blocked to also block until this - * method completes. + * completed. * - * Calling `Detach()` multiple times is allowed, but invocations after the - * first invocation have no effect. + * This method may be safely invoked multiple times. Subsequent invocations + * have no side effects but will still block while there are active + * invocations of `Run()`. */ void Detach(jni::Env& env); @@ -98,7 +97,8 @@ class JniRunnableBase { * A proxy for a Java `Runnable` that calls a C++ function. * * The template parameter `CallbackT` is typically a lambda or function pointer; - * it can be anything that can be "invoked" with zero arguments. + * it can be anything that can be "invoked" with either zero arguments or one + * argument whose type is `JniRunnableBase&`. * * Example: * @@ -118,9 +118,25 @@ class JniRunnable : public JniRunnableBase { JniRunnable(jni::Env& env, CallbackT callback) : JniRunnableBase(env), callback_(firebase::Move(callback)) {} - void Run() override { callback_(); } + void Run() override { Run(*this, callback_); } private: + // These two static overloads of `Run()` use SFINAE to invoke the callback + // with zero arguments or with one argument, depending on the signature of the + // callback. If the callback takes one argument then a reference to the + // `JniRunnable` object is specified for that argument. + template + static auto Run(JniRunnableType&, ZeroArgCallback callback) + -> decltype(callback()) { + return callback(); + } + + template + static auto Run(JniRunnableType& runnable, OneArgCallback callback) + -> decltype(callback(runnable)) { + return callback(runnable); + } + CallbackT callback_; }; diff --git a/firestore/src/tests/android/jni_runnable_android_test.cc b/firestore/src/tests/android/jni_runnable_android_test.cc index 6d47f8fc1e..5d4e4b0e4d 100644 --- a/firestore/src/tests/android/jni_runnable_android_test.cc +++ b/firestore/src/tests/android/jni_runnable_android_test.cc @@ -1,5 +1,7 @@ #include "firestore/src/android/jni_runnable_android.h" +#include "app/memory/atomic.h" +#include "app/src/mutex.h" #include "firestore/src/jni/declaration.h" #include "firestore/src/jni/object.h" #include "firestore/src/jni/ownership.h" @@ -18,6 +20,7 @@ using jni::Global; using jni::Local; using jni::Method; using jni::Object; +using jni::StaticField; using jni::StaticMethod; using jni::Task; using jni::Throwable; @@ -27,6 +30,8 @@ Method kLooperGetThread("getThread", "()Ljava/lang/Thread;"); Method kRunnableRun("run", "()V"); StaticMethod kCurrentThread("currentThread", "()Ljava/lang/Thread;"); Method kThreadGetId("getId", "()J"); +Method kThreadGetState("getState", "()Ljava/lang/Thread$State;"); +StaticField kThreadStateWaiting("WAITING", "Ljava/lang/Thread$State;"); class JniRunnableTest : public FirestoreAndroidIntegrationTest { public: @@ -34,7 +39,9 @@ class JniRunnableTest : public FirestoreAndroidIntegrationTest { FirestoreAndroidIntegrationTest::SetUp(); loader().LoadClass("android/os/Looper", kGetMainLooper, kLooperGetThread); loader().LoadClass("java/lang/Runnable", kRunnableRun); - loader().LoadClass("java/lang/Thread", kCurrentThread, kThreadGetId); + loader().LoadClass("java/lang/Thread", kCurrentThread, kThreadGetId, + kThreadGetState); + loader().LoadClass("java/lang/Thread$State", kThreadStateWaiting); ASSERT_TRUE(loader().ok()); } }; @@ -56,6 +63,16 @@ jlong GetMainThreadId(Env& env) { return env.Call(main_thread, kThreadGetId); } +/** + * Returns whether or not the given thread is in the "waiting" state. + * See java.lang.Thread.State.WAITING. + */ +bool IsThreadWaiting(Env& env, Object& thread) { + Local actual_state = env.Call(thread, kThreadGetState); + Local expected_state = env.Get(kThreadStateWaiting); + return Object::Equals(env, expected_state, actual_state); +} + TEST_F(JniRunnableTest, JavaRunCallsCppRun) { Env env; bool invoked = false; @@ -145,6 +162,27 @@ TEST_F(JniRunnableTest, DetachDetachesEvenIfAnExceptionIsPending) { EXPECT_TRUE(env.ok()); } +// Verify that b/181129657 does not regress; that is, calling `Detach()` from +// `Run()` should not deadlock. +TEST_F(JniRunnableTest, DetachCanBeCalledFromRun) { + Env env; + int run_count = 0; + auto runnable = MakeJniRunnable(env, [&run_count](JniRunnableBase& runnable) { + ++run_count; + Env env; + runnable.Detach(env); + }); + Local java_runnable = runnable.GetJavaRunnable(); + + // Call `run()` twice to verify that the call to `Detach()` successfully + // detaches and the second `run()` invocation does not call C++ `Run()`. + env.Call(java_runnable, kRunnableRun); + env.Call(java_runnable, kRunnableRun); + + EXPECT_TRUE(env.ok()); + EXPECT_EQ(run_count, 1); +} + TEST_F(JniRunnableTest, DestructionCausesJavaRunToDoNothing) { Env env; bool invoked = false; @@ -191,29 +229,21 @@ TEST_F(JniRunnableTest, RunOnMainThreadTaskFailsIfRunThrowsException) { } TEST_F(JniRunnableTest, RunOnMainThreadRunsSynchronouslyFromMainThread) { - class ChainedMainThreadJniRunnable : public JniRunnableBase { - public: - using JniRunnableBase::JniRunnableBase; - - void Run() override { - Env env; - EXPECT_EQ(GetCurrentThreadId(env), GetMainThreadId(env)); - if (is_nested_call_) { - return; - } - is_nested_call_ = true; - Local task = RunOnMainThread(env); - EXPECT_TRUE(task.IsComplete(env)); - EXPECT_TRUE(task.IsSuccessful(env)); - is_nested_call_ = false; - } - - private: - bool is_nested_call_ = false; - }; - Env env; - ChainedMainThreadJniRunnable runnable(env); + bool is_recursive_call = false; + auto runnable = + MakeJniRunnable(env, [&is_recursive_call](JniRunnableBase& runnable) { + Env env; + EXPECT_EQ(GetCurrentThreadId(env), GetMainThreadId(env)); + if (is_recursive_call) { + return; + } + is_recursive_call = true; + Local task = runnable.RunOnMainThread(env); + EXPECT_TRUE(task.IsComplete(env)); + EXPECT_TRUE(task.IsSuccessful(env)); + is_recursive_call = false; + }); Local task = runnable.RunOnMainThread(env); @@ -252,6 +282,54 @@ TEST_F(JniRunnableTest, RunOnNewThreadTaskFailsIfRunThrowsException) { EXPECT_TRUE(env.IsSameObject(exception, thrown_exception)); } +TEST_F(JniRunnableTest, DetachReturnsAfterLastRunOnAnotherThreadCompletes) { + Env env; + compat::Atomic run_count; + run_count.store(0); + Mutex detach_thread_mutex; + Global detach_thread; + auto runnable = + MakeJniRunnable(env, [&run_count, &detach_thread, + &detach_thread_mutex](JniRunnableBase& runnable) { + Env env; + auto old_run_count = run_count.fetch_add(1); + if (old_run_count == 0) { + // Wait for another call of `Run()` by another thread to call + // `Detach()` and start waiting for this call to `Run()` to return. + while (env.ok()) { + MutexLock lock(detach_thread_mutex); + if (detach_thread && IsThreadWaiting(env, detach_thread)) { + break; + } + } + EXPECT_TRUE(env.ok()) << "IsThreadWaiting() failed with an exception"; + } else if (old_run_count == 1) { + { + MutexLock lock(detach_thread_mutex); + detach_thread = env.Call(kCurrentThread); + } + runnable.Detach(env); + EXPECT_TRUE(env.ok()) << "Detach() failed with an exception"; + } else { + EXPECT_TRUE(false) << "Lambda was invoked too many times"; + } + }); + + // Wait for the first invocation of `Run()` to start. + Local task1 = runnable.RunOnNewThread(env); + while (true) { + if (run_count.load() > 0) { + break; + } + } + + // Start the second invocation of `Run()`, which will call `Detach()`. + Local task2 = runnable.RunOnNewThread(env); + + Await(env, task1); + Await(env, task2); +} + } // namespace } // namespace firestore } // namespace firebase diff --git a/firestore/src_java/com/google/firebase/firestore/internal/cpp/JniRunnable.java b/firestore/src_java/com/google/firebase/firestore/internal/cpp/JniRunnable.java index 9e800c8253..1236fa3664 100644 --- a/firestore/src_java/com/google/firebase/firestore/internal/cpp/JniRunnable.java +++ b/firestore/src_java/com/google/firebase/firestore/internal/cpp/JniRunnable.java @@ -4,14 +4,13 @@ import android.os.Looper; import com.google.android.gms.tasks.Task; import com.google.android.gms.tasks.TaskCompletionSource; -import java.util.concurrent.locks.ReentrantReadWriteLock; /** A {@link Runnable} whose {@link #run} method calls a native function. */ public final class JniRunnable implements Runnable { - private final ReentrantReadWriteLock.ReadLock readLock; - private final ReentrantReadWriteLock.WriteLock writeLock; - + private final Object lock = new Object(); + private final ThreadLocal currentThreadActiveRunCount = new ThreadLocalRunDepth(); + private int totalActiveRunCount; private long data; /** @@ -26,29 +25,35 @@ public JniRunnable(long data) { "data==0 is forbidden because 0 is reserved to indicate that we are detached from the" + " C++ function"); } - ReentrantReadWriteLock lock = new ReentrantReadWriteLock(/* fair= */ true); - readLock = lock.readLock(); - writeLock = lock.writeLock(); this.data = data; } /** - * Invokes the C++ method encapsulated by this object. + * Invokes the C++ function encapsulated by this object. * *

If {@link #detach} has been invoked then this method does nothing and returns as if * successful. - * - *

This method will block if there is a thread blocked in {@link #detach}; otherwise, - * it will call the C++ function without blocking. This may even result in concurrent/parallel - * calls to the C++ function if {@link #run} is invoked concurrently. */ @Override public void run() { - readLock.lock(); + long dataCopy; + synchronized (lock) { + if (data == 0) { + return; + } + dataCopy = data; + totalActiveRunCount++; + currentThreadActiveRunCount.set(currentThreadActiveRunCount.get() + 1); + } + try { - nativeRun(data); + nativeRun(dataCopy); } finally { - readLock.unlock(); + synchronized (lock) { + currentThreadActiveRunCount.set(currentThreadActiveRunCount.get() - 1); + totalActiveRunCount--; + lock.notifyAll(); + } } } @@ -58,18 +63,27 @@ public void run() { *

After this method returns, all future invocations of {@link #run} will do nothing and return * as if successful. * - *

This method will block if there are active invocations of {@link #run}. Once all - * active invocations of {@link #run} have completed, then this method will proceed and return - * nearly instantly. Any invocations of {@link #run} that occur while {@link #detach} is blocked - * will also block, allowing the number of active invocations of {@link #run} to eventually reach - * zero and allow this method to proceed. + *

This method blocks until all invocations of the native function called from {@link #run} + * complete; therefore, when this method returns it is safe to delete any data that would be + * referenced by the native function. + * + *

This method may be safely invoked multiple times. Subsequent invocations have no side + * effects but will still block while there are active invocations of the native function. + * + * @throws InterruptedException if waiting for completion of the native function invocations is + * interrupted. */ - public void detach() { - writeLock.lock(); - try { + public void detach() throws InterruptedException { + synchronized (lock) { data = 0; - } finally { - writeLock.unlock(); + + // Wait for invocations of the native function to complete before returning. Do not consider + // native function invocations made by the current thread, which would happen if the native + // function called detach(), because that would cause this method to deadlock because the + // total run count would never reach zero. + while (totalActiveRunCount - currentThreadActiveRunCount.get() > 0) { + lock.wait(); + } } } @@ -139,4 +153,11 @@ void setException(Exception exception) { taskCompletionSource.setException(exception); } } + + private static final class ThreadLocalRunDepth extends ThreadLocal { + @Override + protected Integer initialValue() { + return 0; + } + } } From e54d5717074ba8201cadd5396e936e05da79b0df Mon Sep 17 00:00:00 2001 From: Denver Coneybeare Date: Wed, 17 Mar 2021 15:20:26 -0400 Subject: [PATCH 2/2] Modify JniRunnable.java to use a simple `synchronized` block. --- .../android/jni_runnable_android_test.cc | 75 ++++++++++--------- .../firestore/internal/cpp/JniRunnable.java | 41 ++-------- 2 files changed, 46 insertions(+), 70 deletions(-) diff --git a/firestore/src/tests/android/jni_runnable_android_test.cc b/firestore/src/tests/android/jni_runnable_android_test.cc index 5d4e4b0e4d..3af6795690 100644 --- a/firestore/src/tests/android/jni_runnable_android_test.cc +++ b/firestore/src/tests/android/jni_runnable_android_test.cc @@ -31,7 +31,7 @@ Method kRunnableRun("run", "()V"); StaticMethod kCurrentThread("currentThread", "()Ljava/lang/Thread;"); Method kThreadGetId("getId", "()J"); Method kThreadGetState("getState", "()Ljava/lang/Thread$State;"); -StaticField kThreadStateWaiting("WAITING", "Ljava/lang/Thread$State;"); +StaticField kThreadStateBlocked("BLOCKED", "Ljava/lang/Thread$State;"); class JniRunnableTest : public FirestoreAndroidIntegrationTest { public: @@ -41,7 +41,7 @@ class JniRunnableTest : public FirestoreAndroidIntegrationTest { loader().LoadClass("java/lang/Runnable", kRunnableRun); loader().LoadClass("java/lang/Thread", kCurrentThread, kThreadGetId, kThreadGetState); - loader().LoadClass("java/lang/Thread$State", kThreadStateWaiting); + loader().LoadClass("java/lang/Thread$State", kThreadStateBlocked); ASSERT_TRUE(loader().ok()); } }; @@ -64,12 +64,12 @@ jlong GetMainThreadId(Env& env) { } /** - * Returns whether or not the given thread is in the "waiting" state. - * See java.lang.Thread.State.WAITING. + * Returns whether or not the given thread is in the "blocked" state. + * See java.lang.Thread.State.BLOCKED. */ -bool IsThreadWaiting(Env& env, Object& thread) { +bool IsThreadBlocked(Env& env, Object& thread) { Local actual_state = env.Call(thread, kThreadGetState); - Local expected_state = env.Get(kThreadStateWaiting); + Local expected_state = env.Get(kThreadStateBlocked); return Object::Equals(env, expected_state, actual_state); } @@ -284,50 +284,55 @@ TEST_F(JniRunnableTest, RunOnNewThreadTaskFailsIfRunThrowsException) { TEST_F(JniRunnableTest, DetachReturnsAfterLastRunOnAnotherThreadCompletes) { Env env; - compat::Atomic run_count; - run_count.store(0); + compat::Atomic runnable1_run_invoke_count; + runnable1_run_invoke_count.store(0); Mutex detach_thread_mutex; Global detach_thread; - auto runnable = - MakeJniRunnable(env, [&run_count, &detach_thread, - &detach_thread_mutex](JniRunnableBase& runnable) { + + auto runnable1 = MakeJniRunnable( + env, [&runnable1_run_invoke_count, &detach_thread, &detach_thread_mutex] { + runnable1_run_invoke_count.fetch_add(1); Env env; - auto old_run_count = run_count.fetch_add(1); - if (old_run_count == 0) { - // Wait for another call of `Run()` by another thread to call - // `Detach()` and start waiting for this call to `Run()` to return. - while (env.ok()) { - MutexLock lock(detach_thread_mutex); - if (detach_thread && IsThreadWaiting(env, detach_thread)) { - break; - } + // Wait for `detach()` to be called and start blocking; then, return to + // allow `detach()` to unblock and do its job. + while (env.ok()) { + MutexLock lock(detach_thread_mutex); + if (detach_thread && IsThreadBlocked(env, detach_thread)) { + break; } - EXPECT_TRUE(env.ok()) << "IsThreadWaiting() failed with an exception"; - } else if (old_run_count == 1) { - { - MutexLock lock(detach_thread_mutex); - detach_thread = env.Call(kCurrentThread); - } - runnable.Detach(env); - EXPECT_TRUE(env.ok()) << "Detach() failed with an exception"; - } else { - EXPECT_TRUE(false) << "Lambda was invoked too many times"; } + EXPECT_TRUE(env.ok()) << "IsThreadBlocked() failed with an exception"; + }); + + auto runnable2 = + MakeJniRunnable(env, [&runnable1, &detach_thread, &detach_thread_mutex] { + Env env; + { + MutexLock lock(detach_thread_mutex); + detach_thread = env.Call(kCurrentThread); + } + runnable1.Detach(env); + EXPECT_TRUE(env.ok()) << "Detach() failed with an exception"; }); - // Wait for the first invocation of `Run()` to start. - Local task1 = runnable.RunOnNewThread(env); + // Wait for the `runnable1.Run()` to start to ensure that the lock is held. + Local task1 = runnable1.RunOnNewThread(env); while (true) { - if (run_count.load() > 0) { + if (runnable1_run_invoke_count.load() != 0) { break; } } - // Start the second invocation of `Run()`, which will call `Detach()`. - Local task2 = runnable.RunOnNewThread(env); + // Start a new thread to call `runnable1.Detach()`. + Local task2 = runnable2.RunOnNewThread(env); Await(env, task1); Await(env, task2); + + // Invoke `run()` again and ensure that `Detach()` successfully did its job; + // that is, verify that `Run()` is not invoked. + env.Call(runnable1.GetJavaRunnable(), kRunnableRun); + EXPECT_EQ(runnable1_run_invoke_count.load(), 1); } } // namespace diff --git a/firestore/src_java/com/google/firebase/firestore/internal/cpp/JniRunnable.java b/firestore/src_java/com/google/firebase/firestore/internal/cpp/JniRunnable.java index 1236fa3664..9362ee8a18 100644 --- a/firestore/src_java/com/google/firebase/firestore/internal/cpp/JniRunnable.java +++ b/firestore/src_java/com/google/firebase/firestore/internal/cpp/JniRunnable.java @@ -9,8 +9,6 @@ public final class JniRunnable implements Runnable { private final Object lock = new Object(); - private final ThreadLocal currentThreadActiveRunCount = new ThreadLocalRunDepth(); - private int totalActiveRunCount; private long data; /** @@ -36,24 +34,15 @@ public JniRunnable(long data) { */ @Override public void run() { - long dataCopy; + // NOTE: Because of the `synchronized` block below, the native function will not be called + // concurrently. If concurrent invocations are desired, then this class can be modified with a + // more complicated synchronization mechanism. + // e.g. https://gist.github.com/dconeybe/2d95fbc75f88de58a49804df5c55157b synchronized (lock) { if (data == 0) { return; } - dataCopy = data; - totalActiveRunCount++; - currentThreadActiveRunCount.set(currentThreadActiveRunCount.get() + 1); - } - - try { - nativeRun(dataCopy); - } finally { - synchronized (lock) { - currentThreadActiveRunCount.set(currentThreadActiveRunCount.get() - 1); - totalActiveRunCount--; - lock.notifyAll(); - } + nativeRun(data); } } @@ -69,21 +58,10 @@ public void run() { * *

This method may be safely invoked multiple times. Subsequent invocations have no side * effects but will still block while there are active invocations of the native function. - * - * @throws InterruptedException if waiting for completion of the native function invocations is - * interrupted. */ - public void detach() throws InterruptedException { + public void detach() { synchronized (lock) { data = 0; - - // Wait for invocations of the native function to complete before returning. Do not consider - // native function invocations made by the current thread, which would happen if the native - // function called detach(), because that would cause this method to deadlock because the - // total run count would never reach zero. - while (totalActiveRunCount - currentThreadActiveRunCount.get() > 0) { - lock.wait(); - } } } @@ -153,11 +131,4 @@ void setException(Exception exception) { taskCompletionSource.setException(exception); } } - - private static final class ThreadLocalRunDepth extends ThreadLocal { - @Override - protected Integer initialValue() { - return 0; - } - } }