Skip to content

Commit

Permalink
Merge pull request rust-lang#3 from bluss/memrchr-must-go-faster
Browse files Browse the repository at this point in the history
Provide a faster fallback for memrchr
  • Loading branch information
BurntSushi committed Aug 19, 2015
2 parents 1a2171b + 6a1b5d7 commit 0663286
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 5 deletions.
4 changes: 4 additions & 0 deletions benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ fn iterator(b: &mut test::Bencher) {
b.iter(|| {
assert!(haystack.iter().position(|&b| b == needle).is_none());
});
b.bytes = haystack.len() as u64;
}

#[bench]
Expand All @@ -23,6 +24,7 @@ fn libc_memchr(b: &mut test::Bencher) {
b.iter(|| {
assert!(memchr::memchr(needle, &haystack).is_none());
});
b.bytes = haystack.len() as u64;
}

#[bench]
Expand All @@ -32,6 +34,7 @@ fn iterator_reversed(b: &mut test::Bencher) {
b.iter(|| {
assert!(haystack.iter().rposition(|&b| b == needle).is_none());
});
b.bytes = haystack.len() as u64;
}

#[bench]
Expand All @@ -41,4 +44,5 @@ fn libc_memrchr(b: &mut test::Bencher) {
b.iter(|| {
assert!(memchr::memrchr(needle, &haystack).is_none());
});
b.bytes = haystack.len() as u64;
}
133 changes: 128 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ to the corresponding functions in `libc`.
extern crate libc;

use libc::funcs::c95::string;
use libc::types::common::c95::c_void;
use libc::types::os::arch::c95::{c_int, size_t};
use libc::c_void;
use libc::{c_int, size_t};

/// A safe interface to `memchr`.
///
Expand Down Expand Up @@ -78,18 +78,127 @@ pub fn memrchr(needle: u8, haystack: &[u8]) -> Option<usize> {
}
}

// Non-Linux, 32- or 64-bit: use the word-at-a-time fallback scanner.
// (A single cfg carries the whole condition; the old bare
// `#[cfg(not(target_os = "linux"))]` attribute was superseded by this
// one and stacking both was redundant.)
#[cfg(all(not(target_os = "linux"),
          any(target_pointer_width = "32", target_pointer_width = "64")))]
fn memrchr_specific(needle: u8, haystack: &[u8]) -> Option<usize> {
    fallback::memrchr(needle, haystack)
}

// Fallback for the rare platforms that are neither 32-bit nor 64-bit:
// a plain byte-at-a-time scan from the back of the haystack.
#[cfg(all(not(target_os = "linux"),
          not(target_pointer_width = "32"),
          not(target_pointer_width = "64")))]
fn memrchr_specific(needle: u8, haystack: &[u8]) -> Option<usize> {
    let mut idx = haystack.len();
    while idx > 0 {
        idx -= 1;
        if haystack[idx] == needle {
            return Some(idx);
        }
    }
    None
}

memrchr_specific(needle, haystack)
}

#[cfg(not(target_os = "linux"))]
mod fallback {
    use std::cmp;

    const LO_U64: u64 = 0x0101010101010101;
    const HI_U64: u64 = 0x8080808080808080;

    // use truncation (the upper half is simply dropped on 32-bit targets)
    const LO_USIZE: usize = LO_U64 as usize;
    const HI_USIZE: usize = HI_U64 as usize;

    #[cfg(target_pointer_width = "32")]
    const USIZE_BYTES: usize = 4;
    #[cfg(target_pointer_width = "64")]
    const USIZE_BYTES: usize = 8;

    /// Return `true` if `x` contains any zero byte.
    ///
    /// From *Matters Computational*, J. Arndt
    ///
    /// "The idea is to subtract one from each of the bytes and then look for
    /// bytes where the borrow propagated all the way to the most significant
    /// bit."
    #[inline]
    fn contains_zero_byte(x: usize) -> bool {
        x.wrapping_sub(LO_USIZE) & !x & HI_USIZE != 0
    }

    /// Fill every byte of a `usize` with `b`.
    #[cfg(target_pointer_width = "32")]
    #[inline]
    fn repeat_byte(b: u8) -> usize {
        let mut rep = (b as usize) << 8 | b as usize;
        rep = rep << 16 | rep;
        rep
    }

    /// Fill every byte of a `usize` with `b`.
    #[cfg(target_pointer_width = "64")]
    #[inline]
    fn repeat_byte(b: u8) -> usize {
        let mut rep = (b as usize) << 8 | b as usize;
        rep = rep << 16 | rep;
        rep = rep << 32 | rep;
        rep
    }

    /// Return the last index matching the byte `x` in `text`.
    pub fn memrchr(x: u8, text: &[u8]) -> Option<usize> {
        // Scan for a single byte value by reading two `usize` words at a time.
        //
        // Split `text` in three parts:
        // - unaligned tail, after the last word-aligned address in text,
        // - body, scanned two words at a time,
        // - the first remaining bytes, < 2 word size.
        let len = text.len();
        let ptr = text.as_ptr();

        // Number of bytes that lie past the last word-aligned address in
        // `text`. BUGFIX: the previous code used `USIZE_BYTES - align`,
        // which left `ptr + offset` misaligned and made every `*const
        // usize` read below an unaligned load — undefined behavior, and a
        // fault on strict-alignment targets. The tail is the `align`
        // trailing bytes themselves (capped by `len` for short inputs).
        let endptr = unsafe { ptr.offset(len as isize) };
        let align = (endptr as usize) & (USIZE_BYTES - 1);
        let tail = cmp::min(align, len);

        // Search the unaligned tail byte-by-byte, from the back.
        for (index, &byte) in text[len - tail..].iter().enumerate().rev() {
            if byte == x {
                return Some(len - tail + index);
            }
        }

        // Search the word-aligned body of the text, two words per step.
        let repeated_x = repeat_byte(x);
        let mut offset = len - tail;

        while offset >= 2 * USIZE_BYTES {
            // SAFETY: `offset <= len - tail`, so both reads are in bounds,
            // and `ptr + offset` is word-aligned by construction (it starts
            // at `endptr - tail` and decreases in word-sized steps).
            unsafe {
                let u = *(ptr.offset(offset as isize - 2 * USIZE_BYTES as isize) as *const usize);
                let v = *(ptr.offset(offset as isize - USIZE_BYTES as isize) as *const usize);

                // XOR with the repeated needle turns matching bytes into
                // zero bytes; stop as soon as either word has one.
                let zu = contains_zero_byte(u ^ repeated_x);
                let zv = contains_zero_byte(v ^ repeated_x);
                if zu || zv {
                    break;
                }
            }
            offset -= 2 * USIZE_BYTES;
        }

        // Locate the match inside the two words the body loop stopped at
        // (plus any leading remainder shorter than two words).
        text[..offset].iter().rposition(|&byte| byte == x)
    }
}

#[cfg(target_os = "linux")]
mod ffi {
use libc::types::common::c95::c_void;
use libc::types::os::arch::c95::{c_int, size_t};
use libc::c_void;
use libc::{c_int, size_t};
extern {
pub fn memrchr(cx: *const c_void, c: c_int, n: size_t) -> *mut c_void;
}
Expand Down Expand Up @@ -186,4 +295,18 @@ mod tests {
}
quickcheck::quickcheck(prop as fn(u8, Vec<u8>) -> bool);
}

#[test]
fn qc_correct_reversed() {
    // Property: for every possible needle byte, `memrchr` agrees with the
    // naive reverse scan over the haystack.
    fn prop(a: Vec<u8>) -> bool {
        (0u32..256)
            .map(|b| b as u8)
            .all(|needle| memrchr(needle, &a) == a.iter().rposition(|elt| *elt == needle))
    }
    quickcheck::quickcheck(prop as fn(Vec<u8>) -> bool);
}
}

0 comments on commit 0663286

Please sign in to comment.