Skip to content
Browse files

Fix SIGSEGV in StringPiece::find_first_of

Summary:
Our SSE version of find_first_of was reading past the end of
the StringPiece in some cases, which (very rarely) caused a seg-fault
when we were reading outside of our allotted virtual address space.

Modify the code to never read past the end of the underlying buffers
except when we think it's "safe" because we're still within the same
page. (ASSUMPTION: if a process is allowed to read a byte within a
page, then it is allowed to read _all_ bytes within that page.)

Test Plan:
Added tests that verify we won't go across page boundaries.

Sadly, this code hurts our benchmarks -- sometimes by up to 50% for
smaller strings.

Reviewed By: philipp@fb.com

FB internal diff: D707923

Blame Revision: D638500
  • Loading branch information...
1 parent a8b4b5e commit 4f7a54f61beb30942dc69ced22c776d2dd84e3d2 Mike Curtiss committed with jdelong Feb 12, 2013
Showing with 157 additions and 16 deletions.
  1. +96 −15 folly/Range.cpp
  2. +61 −1 folly/test/RangeTest.cpp
View
111 folly/Range.cpp
@@ -19,6 +19,7 @@
#include "folly/Range.h"
+#include <emmintrin.h> // __v16qi
#include "folly/Likely.h"
namespace folly {
@@ -56,6 +57,16 @@ size_t qfind_first_byte_of_memchr(const StringPiece& haystack,
namespace {
+// It's okay if pages are bigger than this (as powers of two), but they should
+// not be smaller.
+constexpr size_t kMinPageSize = 4096;
+#define PAGE_FOR(addr) \
+ (reinterpret_cast<intptr_t>(addr) / kMinPageSize)
+
+// Rounds up to the next multiple of 16
+#define ROUND_UP_16(val) \
+ ((val + 15) & ~0xF)
+
// build sse4.2-optimized version even if -msse4.2 is not passed to GCC
size_t qfind_first_byte_of_needles16(const StringPiece& haystack,
const StringPiece& needles)
@@ -64,22 +75,46 @@ size_t qfind_first_byte_of_needles16(const StringPiece& haystack,
// helper method for case where needles.size() <= 16
size_t qfind_first_byte_of_needles16(const StringPiece& haystack,
const StringPiece& needles) {
+ DCHECK(!haystack.empty());
+ DCHECK(!needles.empty());
DCHECK_LE(needles.size(), 16);
- if (needles.size() <= 2 && haystack.size() >= 256) {
+ if ((needles.size() <= 2 && haystack.size() >= 256) ||
+ // we can't load needles into SSE register if it could cross page boundary
+ (PAGE_FOR(needles.end() - 1) != PAGE_FOR(needles.data() + 15))) {
// benchmarking shows that memchr beats out SSE for small needle-sets
// with large haystacks.
// TODO(mcurtiss): could this be because of unaligned SSE loads?
return detail::qfind_first_byte_of_memchr(haystack, needles);
}
- auto arr2 = __builtin_ia32_loaddqu(needles.data());
- for (size_t i = 0; i < haystack.size(); i+= 16) {
+
+ __v16qi arr2 = __builtin_ia32_loaddqu(needles.data());
+
+ // If true, the last byte we want to load into the SSE register is on the
+ // same page as the last byte of the actual Range. No risk of segfault.
+ bool canSseLoadLastBlock =
+ (PAGE_FOR(haystack.end() - 1) ==
+ PAGE_FOR(haystack.data() + ROUND_UP_16(haystack.size()) - 1));
+ int64_t lastSafeBlockIdx = canSseLoadLastBlock ?
+ haystack.size() : static_cast<int64_t>(haystack.size()) - 16;
+
+ int64_t i = 0;
+ for (; i < lastSafeBlockIdx; i+= 16) {
auto arr1 = __builtin_ia32_loaddqu(haystack.data() + i);
auto index = __builtin_ia32_pcmpestri128(arr2, needles.size(),
arr1, haystack.size() - i, 0);
if (index < 16) {
return i + index;
}
}
+
+ if (!canSseLoadLastBlock) {
+ StringPiece tmp(haystack);
+ tmp.advance(i);
+ auto ret = detail::qfind_first_byte_of_memchr(tmp, needles);
+ if (ret != StringPiece::npos) {
+ return ret + i;
+ }
+ }
return StringPiece::npos;
}
@@ -127,6 +162,46 @@ size_t qfind_first_byte_of_byteset(const StringPiece& haystack,
return StringPiece::npos;
}
+inline size_t scanHaystackBlock(const StringPiece& haystack,
+ const StringPiece& needles,
+ int64_t idx)
+// inlining is okay because it's only called from other sse4.2 functions
+ __attribute__ ((__target__("sse4.2")));
+
+// Scans a 16-byte block of haystack (starting at blockStartIdx) to find first
+// needle. If blockStartIdx is near the end of haystack, it may read a few bytes
+// past the end; it is the caller's responsibility to ensure this is safe.
+inline size_t scanHaystackBlock(const StringPiece& haystack,
+ const StringPiece& needles,
+ int64_t blockStartIdx) {
+ // small needle sets should be handled by qfind_first_byte_of_needles16()
+ DCHECK_GT(needles.size(), 16);
+ DCHECK(blockStartIdx + 16 <= haystack.size() ||
+ (PAGE_FOR(haystack.data() + blockStartIdx) ==
+ PAGE_FOR(haystack.data() + blockStartIdx + 15)));
+ size_t b = 16;
+ auto arr1 = __builtin_ia32_loaddqu(haystack.data() + blockStartIdx);
+ int64_t j = 0;
+ for (; j < static_cast<int64_t>(needles.size()) - 16; j += 16) {
+ auto arr2 = __builtin_ia32_loaddqu(needles.data() + j);
+ auto index = __builtin_ia32_pcmpestri128(
+ arr2, 16, arr1, haystack.size() - blockStartIdx, 0);
+ b = std::min<size_t>(index, b);
+ }
+
+ // Avoid reading any bytes past the end needles by just reading the last
+ // 16 bytes of needles. We know this is safe because needles.size() > 16.
+ auto arr2 = __builtin_ia32_loaddqu(needles.end() - 16);
+ auto index = __builtin_ia32_pcmpestri128(
+ arr2, 16, arr1, haystack.size() - blockStartIdx, 0);
+ b = std::min<size_t>(index, b);
+
+ if (b < 16) {
+ return blockStartIdx + b;
+ }
+ return StringPiece::npos;
+}
+
size_t qfind_first_byte_of_sse42(const StringPiece& haystack,
const StringPiece& needles)
__attribute__ ((__target__("sse4.2"), noinline));
@@ -141,20 +216,26 @@ size_t qfind_first_byte_of_sse42(const StringPiece& haystack,
return qfind_first_byte_of_needles16(haystack, needles);
}
- size_t index = haystack.size();
- for (size_t i = 0; i < haystack.size(); i += 16) {
- size_t b = 16;
- auto arr1 = __builtin_ia32_loaddqu(haystack.data() + i);
- for (size_t j = 0; j < needles.size(); j += 16) {
- auto arr2 = __builtin_ia32_loaddqu(needles.data() + j);
- auto index = __builtin_ia32_pcmpestri128(arr2, needles.size() - j,
- arr1, haystack.size() - i, 0);
- b = std::min<size_t>(index, b);
- }
- if (b < 16) {
- return i + b;
+ int64_t i = 0;
+ for (; i < static_cast<int64_t>(haystack.size()) - 16; i += 16) {
+ auto ret = scanHaystackBlock(haystack, needles, i);
+ if (ret != StringPiece::npos) {
+ return ret;
}
};
+
+ if (i == haystack.size() - 16 ||
+ PAGE_FOR(haystack.end() - 1) == PAGE_FOR(haystack.data() + i + 15)) {
+ return scanHaystackBlock(haystack, needles, i);
+ } else {
+ auto ret = qfind_first_byte_of_nosse(StringPiece(haystack.data() + i,
+ haystack.end()),
+ needles);
+ if (ret != StringPiece::npos) {
+ return i + ret;
+ }
+ }
+
return StringPiece::npos;
}
View
62 folly/test/RangeTest.cpp
@@ -17,11 +17,14 @@
// @author Kristina Holst (kholst@fb.com)
// @author Andrei Alexandrescu (andrei.alexandrescu@fb.com)
+#include "folly/Range.h"
+
#include <limits>
+#include <stdlib.h>
#include <string>
+#include <sys/mman.h>
#include <boost/range/concepts.hpp>
#include <gtest/gtest.h>
-#include "folly/Range.h"
namespace folly { namespace detail {
@@ -336,3 +339,60 @@ TYPED_TEST(NeedleFinderTest, Base) {
}
}
}
+
+const size_t kPageSize = 4096;
+// Updates contents so that any read accesses past the last byte will
+// cause a SIGSEGV. It accomplishes this by changing access to the page that
+// begins immediately after the end of the contents (as allocators and mmap()
+// all operate on page boundaries, this is a reasonable assumption).
+// This function will also initialize buf, which caller must free().
+void createProtectedBuf(StringPiece& contents, char** buf) {
+ ASSERT_LE(contents.size(), kPageSize);
+ const size_t kSuccess = 0;
+ char* tmp;
+ if (kSuccess != posix_memalign((void**)buf, kPageSize, 2 * kPageSize)) {
+ ASSERT_FALSE(true);
+ }
+ mprotect(*buf + kPageSize, kPageSize, PROT_NONE);
+ size_t newBegin = kPageSize - contents.size();
+ memcpy(*buf + newBegin, contents.data(), contents.size());
+ contents.reset(*buf + newBegin, contents.size());
+}
+
+TYPED_TEST(NeedleFinderTest, NoSegFault) {
+ const string base = string(32, 'a') + string("b");
+ const string delims = string(32, 'c') + string("b");
+ for (int i = 0; i <= 32; i++) {
+ for (int j = 0; j <= 33; j++) {
+ for (int shouldFind = 0; shouldFind <= 1; ++shouldFind) {
+ StringPiece s1(base);
+ s1.advance(i);
+ ASSERT_TRUE(!s1.empty());
+ if (!shouldFind) {
+ s1.pop_back();
+ }
+ StringPiece s2(delims);
+ s2.advance(j);
+ char* buf1;
+ char* buf2;
+ createProtectedBuf(s1, &buf1);
+ createProtectedBuf(s2, &buf2);
+ // printf("s1: '%s' (%ld) \ts2: '%s' (%ld)\n",
+ // string(s1.data(), s1.size()).c_str(), s1.size(),
+ // string(s2.data(), s2.size()).c_str(), s2.size());
+ auto r1 = this->find_first_byte_of(s1, s2);
+ auto f1 = std::find_first_of(s1.begin(), s1.end(),
+ s2.begin(), s2.end());
+ auto e1 = (f1 == s1.end()) ? StringPiece::npos : f1 - s1.begin();
+ EXPECT_EQ(r1, e1);
+ auto r2 = this->find_first_byte_of(s2, s1);
+ auto f2 = std::find_first_of(s2.begin(), s2.end(),
+ s1.begin(), s1.end());
+ auto e2 = (f2 == s2.end()) ? StringPiece::npos : f2 - s2.begin();
+ EXPECT_EQ(r2, e2);
+ free(buf1);
+ free(buf2);
+ }
+ }
+ }
+}

0 comments on commit 4f7a54f

Please sign in to comment.
Something went wrong with that request. Please try again.