Skip to content

Commit

Permalink
Fix memory order in Latch
Browse files Browse the repository at this point in the history
Summary: Counter decrements should synchronize among themselves to guarantee that anything before the decrements happens before the semaphore post.

Reviewed By: dmm-fb

Differential Revision: D40828812

fbshipit-source-id: 93386a259c3a6f6a45716a9f0757f297bcf103b8
  • Loading branch information
ot authored and facebook-github-bot committed Oct 29, 2022
1 parent 3e98846 commit d8ed9cd
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 23 deletions.
2 changes: 1 addition & 1 deletion folly/synchronization/Latch.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class Latch {
FOLLY_ALWAYS_INLINE void count_down(ptrdiff_t n = 1) noexcept {
terminate_if(n < 0 || n > max());
if (FOLLY_LIKELY(n)) {
const auto count = count_.fetch_sub(n, std::memory_order_relaxed);
const auto count = count_.fetch_sub(n, std::memory_order_acq_rel);
terminate_if(count < n);
if (FOLLY_UNLIKELY(count == n)) {
semaphore_.post();
Expand Down
45 changes: 23 additions & 22 deletions folly/synchronization/test/LatchTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/

#include <numeric>
#include <thread>

#include <folly/portability/GTest.h>
Expand Down Expand Up @@ -91,98 +92,98 @@ TEST(LatchTest, CountDownN) {
}

TEST(LatchTest, CountDownThreads) {
std::atomic_int completed{0};
const int N = 32;
std::vector<int> done(N);
folly::Latch latch(N);
std::vector<std::thread> threads;
for (int i = 0; i < N; i++) {
threads.emplace_back([&] {
completed++;
threads.emplace_back([&, i] {
done[i] = 1;
latch.count_down();
});
}
EXPECT_TRUE(latch.try_wait_for(std::chrono::seconds(60)));
EXPECT_EQ(completed.load(), N);
EXPECT_EQ(std::accumulate(done.begin(), done.end(), 0), N);
for (auto& t : threads) {
t.join();
}
}

TEST(LatchTest, CountDownThreadsTwice1) {
std::atomic_int completed{0};
const int N = 32;
std::vector<int> done(N);
folly::Latch latch(N * 2);
std::vector<std::thread> threads;
for (int i = 0; i < N; i++) {
threads.emplace_back([&] {
completed++;
threads.emplace_back([&, i] {
done[i] = 1;
// count_down() multiple times within same thread
latch.count_down();
latch.count_down();
});
}
EXPECT_TRUE(latch.try_wait_for(std::chrono::seconds(60)));
EXPECT_EQ(completed.load(), N);
EXPECT_EQ(std::accumulate(done.begin(), done.end(), 0), N);
for (auto& t : threads) {
t.join();
}
}

TEST(LatchTest, CountDownThreadsTwice2) {
std::atomic_int completed{0};
const int N = 32;
std::vector<int> done(N);
folly::Latch latch(N * 2);
std::vector<std::thread> threads;
for (int i = 0; i < N; i++) {
threads.emplace_back([&] {
completed++;
threads.emplace_back([&, i] {
done[i] = 1;
// count_down() multiple times within same thread
latch.count_down(2);
});
}
EXPECT_TRUE(latch.try_wait_for(std::chrono::seconds(60)));
EXPECT_EQ(completed.load(), N);
EXPECT_EQ(std::accumulate(done.begin(), done.end(), 0), N);
for (auto& t : threads) {
t.join();
}
}

TEST(LatchTest, CountDownThreadsWait) {
std::atomic_int completed{0};
const int N = 32;
std::vector<int> done(N);
folly::Latch latch(N);
std::vector<std::thread> threads;
for (int i = 0; i < N; i++) {
threads.emplace_back([&] {
completed++;
threads.emplace_back([&, i] {
done[i] = 1;
// count_down() and wait() within thread
latch.count_down();
EXPECT_TRUE(latch.try_wait_for(std::chrono::seconds(60)));
EXPECT_EQ(completed.load(), N);
EXPECT_EQ(std::accumulate(done.begin(), done.end(), 0), N);
});
}
EXPECT_TRUE(latch.try_wait_for(std::chrono::seconds(60)));
EXPECT_EQ(completed.load(), N);
EXPECT_EQ(std::accumulate(done.begin(), done.end(), 0), N);
for (auto& t : threads) {
t.join();
}
}

TEST(LatchTest, CountDownThreadsArriveAndWait) {
std::atomic_int completed{0};
const int N = 32;
std::vector<int> done(N);
folly::Latch latch(N);
std::vector<std::thread> threads;
for (int i = 0; i < N; i++) {
threads.emplace_back([&] {
completed++;
threads.emplace_back([&, i] {
done[i] = 1;
// count_down() and wait() within thread
latch.arrive_and_wait();
EXPECT_EQ(completed.load(), N);
EXPECT_EQ(std::accumulate(done.begin(), done.end(), 0), N);
});
}
EXPECT_TRUE(latch.try_wait_for(std::chrono::seconds(60)));
EXPECT_EQ(completed.load(), N);
EXPECT_EQ(std::accumulate(done.begin(), done.end(), 0), N);
for (auto& t : threads) {
t.join();
}
Expand Down

0 comments on commit d8ed9cd

Please sign in to comment.