Skip to content

Commit

Permalink
Revert "Auto merge of rust-lang#103779 - the8472:simd-str-contains, r…
Browse files Browse the repository at this point in the history
…=thomcc"

The current implementation seems to be unsound. See rust-lang#104726.
  • Loading branch information
pietroalbini committed Nov 22, 2022
1 parent a78c9be commit 7953508
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 311 deletions.
65 changes: 7 additions & 58 deletions library/alloc/benches/str.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use core::iter::Iterator;
use test::{black_box, Bencher};

#[bench]
Expand Down Expand Up @@ -123,13 +122,14 @@ fn bench_contains_short_short(b: &mut Bencher) {
let haystack = "Lorem ipsum dolor sit amet, consectetur adipiscing elit.";
let needle = "sit";

b.bytes = haystack.len() as u64;
b.iter(|| {
assert!(black_box(haystack).contains(black_box(needle)));
assert!(haystack.contains(needle));
})
}

static LONG_HAYSTACK: &str = "\
#[bench]
fn bench_contains_short_long(b: &mut Bencher) {
let haystack = "\
Lorem ipsum dolor sit amet, consectetur adipiscing elit. Suspendisse quis lorem sit amet dolor \
ultricies condimentum. Praesent iaculis purus elit, ac malesuada quam malesuada in. Duis sed orci \
eros. Suspendisse sit amet magna mollis, mollis nunc luctus, imperdiet mi. Integer fringilla non \
Expand Down Expand Up @@ -164,48 +164,10 @@ feugiat. Etiam quis mauris vel risus luctus mattis a a nunc. Nullam orci quam, i
vehicula in, porttitor ut nibh. Duis sagittis adipiscing nisl vitae congue. Donec mollis risus eu \
leo suscipit, varius porttitor nulla porta. Pellentesque ut sem nec nisi euismod vehicula. Nulla \
malesuada sollicitudin quam eu fermentum.";

#[bench]
fn bench_contains_2b_repeated_long(b: &mut Bencher) {
let haystack = LONG_HAYSTACK;
let needle = "::";

b.bytes = haystack.len() as u64;
b.iter(|| {
assert!(!black_box(haystack).contains(black_box(needle)));
})
}

#[bench]
fn bench_contains_short_long(b: &mut Bencher) {
let haystack = LONG_HAYSTACK;
let needle = "english";

b.bytes = haystack.len() as u64;
b.iter(|| {
assert!(!black_box(haystack).contains(black_box(needle)));
})
}

#[bench]
fn bench_contains_16b_in_long(b: &mut Bencher) {
let haystack = LONG_HAYSTACK;
let needle = "english language";

b.bytes = haystack.len() as u64;
b.iter(|| {
assert!(!black_box(haystack).contains(black_box(needle)));
})
}

#[bench]
fn bench_contains_32b_in_long(b: &mut Bencher) {
let haystack = LONG_HAYSTACK;
let needle = "the english language sample text";

b.bytes = haystack.len() as u64;
b.iter(|| {
assert!(!black_box(haystack).contains(black_box(needle)));
assert!(!haystack.contains(needle));
})
}

Expand All @@ -214,20 +176,8 @@ fn bench_contains_bad_naive(b: &mut Bencher) {
let haystack = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
let needle = "aaaaaaaab";

b.bytes = haystack.len() as u64;
b.iter(|| {
assert!(!black_box(haystack).contains(black_box(needle)));
})
}

#[bench]
fn bench_contains_bad_simd(b: &mut Bencher) {
let haystack = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
let needle = "aaabaaaa";

b.bytes = haystack.len() as u64;
b.iter(|| {
assert!(!black_box(haystack).contains(black_box(needle)));
assert!(!haystack.contains(needle));
})
}

Expand All @@ -236,9 +186,8 @@ fn bench_contains_equal(b: &mut Bencher) {
let haystack = "Lorem ipsum dolor sit amet, consectetur adipiscing elit.";
let needle = "Lorem ipsum dolor sit amet, consectetur adipiscing elit.";

b.bytes = haystack.len() as u64;
b.iter(|| {
assert!(black_box(haystack).contains(black_box(needle)));
assert!(haystack.contains(needle));
})
}

Expand Down
26 changes: 5 additions & 21 deletions library/alloc/tests/str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1590,27 +1590,11 @@ fn test_bool_from_str() {
assert_eq!("not even a boolean".parse::<bool>().ok(), None);
}

fn check_contains_all_substrings(haystack: &str) {
let mut modified_needle = String::new();

for i in 0..haystack.len() {
// check different haystack lengths since we special-case short haystacks.
let haystack = &haystack[0..i];
assert!(haystack.contains(""));
for j in 0..haystack.len() {
for k in j + 1..=haystack.len() {
let needle = &haystack[j..k];
assert!(haystack.contains(needle));
modified_needle.clear();
modified_needle.push_str(needle);
modified_needle.replace_range(0..1, "\0");
assert!(!haystack.contains(&modified_needle));

modified_needle.clear();
modified_needle.push_str(needle);
modified_needle.replace_range(needle.len() - 1..needle.len(), "\0");
assert!(!haystack.contains(&modified_needle));
}
fn check_contains_all_substrings(s: &str) {
assert!(s.contains(""));
for i in 0..s.len() {
for j in i + 1..=s.len() {
assert!(s.contains(&s[i..j]));
}
}
}
Expand Down
232 changes: 0 additions & 232 deletions library/core/src/str/pattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
)]

use crate::cmp;
use crate::cmp::Ordering;
use crate::fmt;
use crate::slice::memchr;

Expand Down Expand Up @@ -947,32 +946,6 @@ impl<'a, 'b> Pattern<'a> for &'b str {
haystack.as_bytes().starts_with(self.as_bytes())
}

/// Checks whether the pattern matches anywhere in the haystack
#[inline]
fn is_contained_in(self, haystack: &'a str) -> bool {
if self.len() == 0 {
return true;
}

match self.len().cmp(&haystack.len()) {
Ordering::Less => {
if self.len() == 1 {
return haystack.as_bytes().contains(&self.as_bytes()[0]);
}

#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
if self.len() <= 32 {
if let Some(result) = simd_contains(self, haystack) {
return result;
}
}

self.into_searcher(haystack).next_match().is_some()
}
_ => self == haystack,
}
}

/// Removes the pattern from the front of haystack, if it matches.
#[inline]
fn strip_prefix_of(self, haystack: &'a str) -> Option<&'a str> {
Expand Down Expand Up @@ -1711,208 +1684,3 @@ impl TwoWayStrategy for RejectAndMatch {
SearchStep::Match(a, b)
}
}

/// SIMD search for short needles based on
/// Wojciech Muła's "SIMD-friendly algorithms for substring searching"[0]
///
/// It skips ahead by the vector width on each iteration (rather than the needle length as two-way
/// does) by probing the first and last byte of the needle for the whole vector width
/// and only doing full needle comparisons when the vectorized probe indicated potential matches.
///
/// Since the x86_64 baseline only offers SSE2 we only use u8x16 here.
/// If we ever ship std with for x86-64-v3 or adapt this for other platforms then wider vectors
/// should be evaluated.
///
/// For haystacks smaller than vector-size + needle length it falls back to
/// a naive O(n*m) search so this implementation should not be called on larger needles.
///
/// [0]: http://0x80.pl/articles/simd-strfind.html#sse-avx2
#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
#[inline]
fn simd_contains(needle: &str, haystack: &str) -> Option<bool> {
let needle = needle.as_bytes();
let haystack = haystack.as_bytes();

debug_assert!(needle.len() > 1);

use crate::ops::BitAnd;
use crate::simd::mask8x16 as Mask;
use crate::simd::u8x16 as Block;
use crate::simd::{SimdPartialEq, ToBitMask};

let first_probe = needle[0];

// the offset used for the 2nd vector
let second_probe_offset = if needle.len() == 2 {
// never bail out on len=2 needles because the probes will fully cover them and have
// no degenerate cases.
1
} else {
// try a few bytes in case first and last byte of the needle are the same
let Some(second_probe_offset) = (needle.len().saturating_sub(4)..needle.len()).rfind(|&idx| needle[idx] != first_probe) else {
// fall back to other search methods if we can't find any different bytes
// since we could otherwise hit some degenerate cases
return None;
};
second_probe_offset
};

// do a naive search if the haystack is too small to fit
if haystack.len() < Block::LANES + second_probe_offset {
return Some(haystack.windows(needle.len()).any(|c| c == needle));
}

let first_probe: Block = Block::splat(first_probe);
let second_probe: Block = Block::splat(needle[second_probe_offset]);
// first byte are already checked by the outer loop. to verify a match only the
// remainder has to be compared.
let trimmed_needle = &needle[1..];

// this #[cold] is load-bearing, benchmark before removing it...
let check_mask = #[cold]
|idx, mask: u16, skip: bool| -> bool {
if skip {
return false;
}

// and so is this. optimizations are weird.
let mut mask = mask;

while mask != 0 {
let trailing = mask.trailing_zeros();
let offset = idx + trailing as usize + 1;
// SAFETY: mask is between 0 and 15 trailing zeroes, we skip one additional byte that was already compared
// and then take trimmed_needle.len() bytes. This is within the bounds defined by the outer loop
unsafe {
let sub = haystack.get_unchecked(offset..).get_unchecked(..trimmed_needle.len());
if small_slice_eq(sub, trimmed_needle) {
return true;
}
}
mask &= !(1 << trailing);
}
return false;
};

let test_chunk = |idx| -> u16 {
// SAFETY: this requires at least LANES bytes being readable at idx
// that is ensured by the loop ranges (see comments below)
let a: Block = unsafe { haystack.as_ptr().add(idx).cast::<Block>().read_unaligned() };
// SAFETY: this requires LANES + block_offset bytes being readable at idx
let b: Block = unsafe {
haystack.as_ptr().add(idx).add(second_probe_offset).cast::<Block>().read_unaligned()
};
let eq_first: Mask = a.simd_eq(first_probe);
let eq_last: Mask = b.simd_eq(second_probe);
let both = eq_first.bitand(eq_last);
let mask = both.to_bitmask();

return mask;
};

let mut i = 0;
let mut result = false;
// The loop condition must ensure that there's enough headroom to read LANE bytes,
// and not only at the current index but also at the index shifted by block_offset
const UNROLL: usize = 4;
while i + second_probe_offset + UNROLL * Block::LANES < haystack.len() && !result {
let mut masks = [0u16; UNROLL];
for j in 0..UNROLL {
masks[j] = test_chunk(i + j * Block::LANES);
}
for j in 0..UNROLL {
let mask = masks[j];
if mask != 0 {
result |= check_mask(i + j * Block::LANES, mask, result);
}
}
i += UNROLL * Block::LANES;
}
while i + second_probe_offset + Block::LANES < haystack.len() && !result {
let mask = test_chunk(i);
if mask != 0 {
result |= check_mask(i, mask, result);
}
i += Block::LANES;
}

// Process the tail that didn't fit into LANES-sized steps.
// This simply repeats the same procedure but as right-aligned chunk instead
// of a left-aligned one. The last byte must be exactly flush with the string end so
// we don't miss a single byte or read out of bounds.
let i = haystack.len() - second_probe_offset - Block::LANES;
let mask = test_chunk(i);
if mask != 0 {
result |= check_mask(i, mask, result);
}

Some(result)
}

/// Compares short slices for equality.
///
/// It avoids a call to libc's memcmp which is faster on long slices
/// due to SIMD optimizations but it incurs a function call overhead.
///
/// # Safety
///
/// Both slices must have the same length.
#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))] // only called on x86
#[inline]
unsafe fn small_slice_eq(x: &[u8], y: &[u8]) -> bool {
// This function is adapted from
// https://github.com/BurntSushi/memchr/blob/8037d11b4357b0f07be2bb66dc2659d9cf28ad32/src/memmem/util.rs#L32

// If we don't have enough bytes to do 4-byte at a time loads, then
// fall back to the naive slow version.
//
// Potential alternative: We could do a copy_nonoverlapping combined with a mask instead
// of a loop. Benchmark it.
if x.len() < 4 {
for (&b1, &b2) in x.iter().zip(y) {
if b1 != b2 {
return false;
}
}
return true;
}
// When we have 4 or more bytes to compare, then proceed in chunks of 4 at
// a time using unaligned loads.
//
// Also, why do 4 byte loads instead of, say, 8 byte loads? The reason is
// that this particular version of memcmp is likely to be called with tiny
// needles. That means that if we do 8 byte loads, then a higher proportion
// of memcmp calls will use the slower variant above. With that said, this
// is a hypothesis and is only loosely supported by benchmarks. There's
// likely some improvement that could be made here. The main thing here
// though is to optimize for latency, not throughput.

// SAFETY: Via the conditional above, we know that both `px` and `py`
// have the same length, so `px < pxend` implies that `py < pyend`.
// Thus, derefencing both `px` and `py` in the loop below is safe.
//
// Moreover, we set `pxend` and `pyend` to be 4 bytes before the actual
// end of of `px` and `py`. Thus, the final dereference outside of the
// loop is guaranteed to be valid. (The final comparison will overlap with
// the last comparison done in the loop for lengths that aren't multiples
// of four.)
//
// Finally, we needn't worry about alignment here, since we do unaligned
// loads.
unsafe {
let (mut px, mut py) = (x.as_ptr(), y.as_ptr());
let (pxend, pyend) = (px.add(x.len() - 4), py.add(y.len() - 4));
while px < pxend {
let vx = (px as *const u32).read_unaligned();
let vy = (py as *const u32).read_unaligned();
if vx != vy {
return false;
}
px = px.add(4);
py = py.add(4);
}
let vx = (pxend as *const u32).read_unaligned();
let vy = (pyend as *const u32).read_unaligned();
vx == vy
}
}

0 comments on commit 7953508

Please sign in to comment.