diff --git a/Cargo.lock b/Cargo.lock index fb0b6d0..6c22aed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -322,7 +322,6 @@ dependencies = [ "serde", "serde_json", "sonic-rs", - "sonic-simd", "v_jsonescape", ] diff --git a/Cargo.toml b/Cargo.toml index cda204f..4073608 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,7 +24,6 @@ name = "escape" harness = false [dependencies] -sonic-simd = "0.1" [dev-dependencies] criterion2 = "3" @@ -39,6 +38,8 @@ sonic-rs = "0.5" [profile.bench] lto = true codegen-units = 1 +debug = true +strip = false [profile.instruments] inherits = "release" diff --git a/benches/escape.rs b/benches/escape.rs index 02616b1..a60916a 100644 --- a/benches/escape.rs +++ b/benches/escape.rs @@ -41,6 +41,10 @@ fn get_affine_sources() -> Vec { } fn run_benchmarks(c: &mut Criterion, sources: &[String], prefix: &str) { + let first = &sources[0]; + assert_eq!(escape(first), sonic_rs::to_string(first).unwrap()); + assert_eq!(escape(first), serde_json::to_string(first).unwrap()); + c.bench_function(&format!("{} escape simd", prefix), |b| { b.iter(|| { for source in sources { diff --git a/src/lib.rs b/src/lib.rs index e0114ec..7be2b40 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,13 +2,13 @@ //! //! Only takes the string escaping part to avoid the abstraction overhead. +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +use std::arch::is_x86_feature_detected; use std::slice::from_raw_parts; -#[cfg(not(all(target_feature = "neon", target_arch = "aarch64")))] -use sonic_simd::u8x32; -use sonic_simd::{BitMask, Mask, Simd}; -#[cfg(all(target_feature = "neon", target_arch = "aarch64"))] -use sonic_simd::{bits::NeonBits, u8x16}; +use simd::{BitMask, Mask, Simd}; + +mod simd; #[inline(always)] unsafe fn load(ptr: *const u8) -> V { @@ -292,6 +292,22 @@ const NEED_ESCAPED: [u8; 256] = [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ]; +#[cfg(all( + any(target_arch = "x86", target_arch = "x86_64"), + not(feature = "codspeed") +))] +static COMPUTE_LANES: std::sync::Once = std::sync::Once::new(); +#[cfg(all( + any(target_arch = "x86", target_arch = "x86_64"), + not(feature = "codspeed") +))] +static mut LANES: usize = simd::avx2::Simd256u::LANES; +#[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "codspeed"))] +const LANES: usize = simd::avx2::Simd256u::LANES; + +#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))] +const LANES: usize = 16; + // only check the src length. 
#[inline(always)] unsafe fn escape_unchecked(src: &mut *const u8, nb: &mut usize, dst: &mut *mut u8) { @@ -324,36 +340,148 @@ fn check_cross_page(ptr: *const u8, step: usize) -> bool { } #[inline(always)] -fn format_string(value: &str, dst: &mut [u8]) -> usize { - #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] - let mut v: u8x16; - #[cfg(not(all(target_arch = "aarch64", target_feature = "neon")))] - let mut v: u8x32; - - #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] - const LANES: usize = 16; - #[cfg(not(all(target_arch = "aarch64", target_feature = "neon")))] - const LANES: usize = 32; - - #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] - #[inline] - fn escaped_mask(v: u8x16) -> NeonBits { - let x1f = u8x16::splat(0x1f); // 0x00 ~ 0x20 - let blash = u8x16::splat(b'\\'); - let quote = u8x16::splat(b'"'); - let v = v.le(&x1f) | v.eq(&blash) | v.eq("e); - v.bitmask() - } +fn escaped_mask_generic(v: simd::v128::Simd128u) -> u16 { + use simd::v128::Simd128u as u8x16; + + let x1f = u8x16::splat(0x1f); // 0x00 ~ 0x20 + let blash = u8x16::splat(b'\\'); + let quote = u8x16::splat(b'"'); + let v = v.le(&x1f) | v.eq(&blash) | v.eq("e); + v.bitmask() +} - #[cfg(not(all(target_arch = "aarch64", target_feature = "neon")))] - #[inline] - fn escaped_mask(v: u8x32) -> u32 { - let x1f = u8x32::splat(0x1f); // 0x00 ~ 0x20 - let blash = u8x32::splat(b'\\'); - let quote = u8x32::splat(b'"'); - let v = v.le(&x1f) | v.eq(&blash) | v.eq("e); - v.bitmask() - } +#[cfg(target_arch = "aarch64")] +#[inline(always)] +fn escaped_mask_neon(v: simd::neon::Simd128u) -> simd::bits::NeonBits { + use simd::neon::Simd128u as u8x16; + + let x1f = u8x16::splat(0x1f); // 0x00 ~ 0x20 + let blash = u8x16::splat(b'\\'); + let quote = u8x16::splat(b'"'); + let v = v.le(&x1f) | v.eq(&blash) | v.eq("e); + v.bitmask() +} + +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +#[inline(always)] +fn escaped_mask_sse2(v: simd::sse2::Simd128u) -> u16 { + use simd::sse2::Simd128u as u8x16; + + let x1f = u8x16::splat(0x1f); // 0x00 ~ 0x20 + let blash = u8x16::splat(b'\\'); + let quote = u8x16::splat(b'"'); + let v = v.le(&x1f) | v.eq(&blash) | v.eq("e); + v.bitmask() +} + +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +#[inline(always)] +fn escaped_mask_avx2(v: simd::avx2::Simd256u) -> u32 { + use simd::avx2::Simd256u as u8x32; + + let x1f = u8x32::splat(0x1f); // 0x00 ~ 0x20 + let blash = u8x32::splat(b'\\'); + let quote = u8x32::splat(b'"'); + let v = v.le(&x1f) | v.eq(&blash) | v.eq("e); + v.bitmask() +} + +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +#[inline(always)] +fn escaped_mask_avx512(v: simd::avx512::Simd512u) -> u64 { + use simd::avx512::Simd512u as u8x64; + + let x1f = u8x64::splat(0x1f); // 0x00 ~ 0x20 + let blash = u8x64::splat(b'\\'); + let quote = u8x64::splat(b'"'); + let v = v.le(&x1f) | v.eq(&blash) | v.eq("e); + v.bitmask() +} + +macro_rules! escape { + ($mask:expr, $nb:expr, $dptr:expr, $sptr:expr) => { + if $mask.all_zero() { + $nb -= LANES; + $dptr = $dptr.add(LANES); + $sptr = $sptr.add(LANES); + } else { + let cn = $mask.first_offset(); + $nb -= cn; + $dptr = $dptr.add(cn); + $sptr = $sptr.add(cn); + escape_unchecked(&mut $sptr, &mut $nb, &mut $dptr); + } + }; +} + +macro_rules! 
load_v { + ($placeholder:expr, $sptr:expr, $nb:expr) => {{ + #[cfg(not(any(target_os = "linux", target_os = "macos")))] + { + std::ptr::copy_nonoverlapping($sptr, $placeholder[..].as_mut_ptr(), $nb); + load($placeholder[..].as_ptr()) + } + #[cfg(any(target_os = "linux", target_os = "macos"))] + { + if check_cross_page($sptr, LANES) { + std::ptr::copy_nonoverlapping($sptr, $placeholder[..].as_mut_ptr(), $nb); + load($placeholder[..].as_ptr()) + } else { + #[cfg(any(debug_assertions, miri))] + { + std::ptr::copy_nonoverlapping($sptr, $placeholder[..].as_mut_ptr(), $nb); + load($placeholder[..].as_ptr()) + } + #[cfg(not(any(debug_assertions, miri)))] + { + load($sptr) + } + } + } + }}; +} + +#[inline(always)] +fn format_string(value: &str, dst: &mut [u8]) -> usize { + #[cfg(target_arch = "aarch64")] + let mut v_neon: simd::neon::Simd128u; + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + let mut v_sse2: simd::sse2::Simd128u; + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + let mut v_avx2: simd::avx2::Simd256u; + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + let mut v_avx512: simd::avx512::Simd512u; + + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + let has_avx512 = is_x86_feature_detected!("avx512f"); + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + let has_avx2 = is_x86_feature_detected!("avx2"); + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + let has_sse2 = is_x86_feature_detected!("sse2"); + + #[cfg(target_arch = "aarch64")] + let has_neon = cfg!(target_os = "macos") || std::arch::is_aarch64_feature_detected!("neon"); + + let mut v_generic: simd::v128::Simd128u; + + #[cfg(all( + any(target_arch = "x86", target_arch = "x86_64"), + not(feature = "codspeed") + ))] + COMPUTE_LANES.call_once(|| { + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + { + if is_x86_feature_detected!("avx512f") { + unsafe { + LANES = simd::avx512::Simd512u::LANES; + } + } else if !is_x86_feature_detected!("avx2") { + unsafe { + LANES = simd::sse2::Simd128u::LANES; + } + } + } + }); unsafe { let slice = value.as_bytes(); @@ -365,66 +493,184 @@ fn format_string(value: &str, dst: &mut [u8]) -> usize { *dptr = b'"'; dptr = dptr.add(1); while nb >= LANES { - v = load(sptr); - v.write_to_slice_unaligned_unchecked(std::slice::from_raw_parts_mut(dptr, LANES)); - let mask = escaped_mask(v); - if mask.all_zero() { - nb -= LANES; - dptr = dptr.add(LANES); - sptr = sptr.add(LANES); - } else { - let cn = mask.first_offset(); - nb -= cn; - dptr = dptr.add(cn); - sptr = sptr.add(cn); - escape_unchecked(&mut sptr, &mut nb, &mut dptr); + #[cfg(target_arch = "aarch64")] + { + if has_neon { + v_neon = load(sptr); + v_neon.write_to_slice_unaligned_unchecked(std::slice::from_raw_parts_mut( + dptr, LANES, + )); + let mask = escaped_mask_neon(v_neon); + escape!(mask, nb, dptr, sptr); + } else { + v_generic = load(sptr); + v_generic.write_to_slice_unaligned_unchecked(std::slice::from_raw_parts_mut( + dptr, LANES, + )); + let mask = escaped_mask_generic(v_generic); + escape!(mask, nb, dptr, sptr); + } + } + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + { + if has_avx512 { + v_avx512 = load(sptr); + v_avx512.write_to_slice_unaligned_unchecked(std::slice::from_raw_parts_mut( + dptr, LANES, + )); + let mask = escaped_mask_avx512(v_avx512); + escape!(mask, nb, dptr, sptr); + } else if has_avx2 { + v_avx2 = load(sptr); + v_avx2.write_to_slice_unaligned_unchecked(std::slice::from_raw_parts_mut( + dptr, LANES, + )); + let mask = 
escaped_mask_avx2(v_avx2); + escape!(mask, nb, dptr, sptr); + } else if has_sse2 { + v_sse2 = load(sptr); + v_sse2.write_to_slice_unaligned_unchecked(std::slice::from_raw_parts_mut( + dptr, LANES, + )); + let mask = escaped_mask_sse2(v_sse2); + escape!(mask, nb, dptr, sptr); + } else { + v_generic = load(sptr); + v_generic.write_to_slice_unaligned_unchecked(std::slice::from_raw_parts_mut( + dptr, LANES, + )); + let mask = escaped_mask_generic(v_generic); + escape!(mask, nb, dptr, sptr); + } } } - // Scratch buffer reused for mask materialisation; stay uninitialised. - #[cfg(not(miri))] - #[allow(invalid_value, clippy::uninit_assumed_init)] - let mut placeholder: [u8; LANES] = core::mem::MaybeUninit::uninit().assume_init(); - #[cfg(miri)] - let mut placeholder: [u8; LANES] = [0; LANES]; - while nb > 0 { - v = { - #[cfg(not(any(target_os = "linux", target_os = "macos")))] - { - std::ptr::copy_nonoverlapping(sptr, placeholder[..].as_mut_ptr(), nb); - load(placeholder[..].as_ptr()) + #[cfg(target_arch = "aarch64")] + { + if has_neon { + const LANES: usize = simd::neon::Simd128u::LANES; + let mut placeholder: [u8; LANES] = [0; LANES]; + while nb > 0 { + v_neon = load_v!(placeholder, sptr, nb); + v_neon.write_to_slice_unaligned_unchecked(std::slice::from_raw_parts_mut( + dptr, LANES, + )); + let mask = escaped_mask_neon(v_neon).clear_high_bits(LANES - nb); + if mask.all_zero() { + dptr = dptr.add(nb); + break; + } else { + let cn = mask.first_offset(); + nb -= cn; + dptr = dptr.add(cn); + sptr = sptr.add(cn); + escape_unchecked(&mut sptr, &mut nb, &mut dptr); + } } - #[cfg(any(target_os = "linux", target_os = "macos"))] - { - if check_cross_page(sptr, LANES) { - std::ptr::copy_nonoverlapping(sptr, placeholder[..].as_mut_ptr(), nb); - load(placeholder[..].as_ptr()) + } else { + const LANES: usize = simd::v128::Simd128u::LANES; + let mut placeholder: [u8; LANES] = [0; LANES]; + while nb > 0 { + v_generic = load_v!(placeholder, sptr, nb); + v_generic.write_to_slice_unaligned_unchecked(std::slice::from_raw_parts_mut( + dptr, LANES, + )); + let mask = escaped_mask_generic(v_generic).clear_high_bits(LANES - nb); + if mask.all_zero() { + dptr = dptr.add(nb); + break; } else { - #[cfg(not(debug_assertions))] - { - // disable memory sanitizer here - load(sptr) - } - #[cfg(debug_assertions)] - { - std::ptr::copy_nonoverlapping(sptr, placeholder[..].as_mut_ptr(), nb); - load(placeholder[..].as_ptr()) - } + let cn = mask.first_offset(); + nb -= cn; + dptr = dptr.add(cn); + sptr = sptr.add(cn); + escape_unchecked(&mut sptr, &mut nb, &mut dptr); + } + } + } + } + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + { + if has_avx512 { + const LANES: usize = simd::avx512::Simd512u::LANES; + let mut placeholder: [u8; LANES] = [0; LANES]; + while nb > 0 { + v_avx512 = load_v!(placeholder, sptr, nb); + v_avx512.write_to_slice_unaligned_unchecked(std::slice::from_raw_parts_mut( + dptr, LANES, + )); + let mask = escaped_mask_avx512(v_avx512).clear_high_bits(LANES - nb); + if mask.all_zero() { + dptr = dptr.add(nb); + break; + } else { + let cn = mask.first_offset(); + nb -= cn; + dptr = dptr.add(cn); + sptr = sptr.add(cn); + escape_unchecked(&mut sptr, &mut nb, &mut dptr); + } + } + } else if has_avx2 { + const LANES: usize = simd::avx2::Simd256u::LANES; + let mut placeholder: [u8; LANES] = [0; LANES]; + while nb > 0 { + v_avx2 = load_v!(placeholder, sptr, nb); + v_avx2.write_to_slice_unaligned_unchecked(std::slice::from_raw_parts_mut( + dptr, LANES, + )); + let mask = 
escaped_mask_avx2(v_avx2).clear_high_bits(LANES - nb); + if mask.all_zero() { + dptr = dptr.add(nb); + break; + } else { + let cn = mask.first_offset(); + nb -= cn; + dptr = dptr.add(cn); + sptr = sptr.add(cn); + escape_unchecked(&mut sptr, &mut nb, &mut dptr); + } + } + } else if has_sse2 { + const LANES: usize = simd::sse2::Simd128u::LANES; + let mut placeholder: [u8; LANES] = [0; LANES]; + while nb > 0 { + v_sse2 = load_v!(placeholder, sptr, nb); + v_sse2.write_to_slice_unaligned_unchecked(std::slice::from_raw_parts_mut( + dptr, LANES, + )); + let mask = escaped_mask_sse2(v_sse2).clear_high_bits(LANES - nb); + if mask.all_zero() { + dptr = dptr.add(nb); + break; + } else { + let cn = mask.first_offset(); + nb -= cn; + dptr = dptr.add(cn); + sptr = sptr.add(cn); + escape_unchecked(&mut sptr, &mut nb, &mut dptr); } } - }; - v.write_to_slice_unaligned_unchecked(std::slice::from_raw_parts_mut(dptr, LANES)); - - let mask = escaped_mask(v).clear_high_bits(LANES - nb); - if mask.all_zero() { - dptr = dptr.add(nb); - break; } else { - let cn = mask.first_offset(); - nb -= cn; - dptr = dptr.add(cn); - sptr = sptr.add(cn); - escape_unchecked(&mut sptr, &mut nb, &mut dptr); + const LANES: usize = simd::v128::Simd128u::LANES; + let mut placeholder: [u8; LANES] = [0; LANES]; + while nb > 0 { + v_generic = load_v!(placeholder, sptr, nb); + v_generic.write_to_slice_unaligned_unchecked(std::slice::from_raw_parts_mut( + dptr, LANES, + )); + let mask = escaped_mask_generic(v_generic).clear_high_bits(LANES - nb); + if mask.all_zero() { + dptr = dptr.add(nb); + break; + } else { + let cn = mask.first_offset(); + nb -= cn; + dptr = dptr.add(cn); + sptr = sptr.add(cn); + escape_unchecked(&mut sptr, &mut nb, &mut dptr); + } + } } } *dptr = b'"'; @@ -436,7 +682,10 @@ fn format_string(value: &str, dst: &mut [u8]) -> usize { pub fn escape(value: &str) -> String { let capacity = value.len() * 6 + 32 + 3; let mut buf = Vec::with_capacity(capacity); - unsafe { buf.set_len(capacity) }; + #[allow(clippy::uninit_vec)] + unsafe { + buf.set_len(capacity) + }; let cnt = format_string(value, &mut buf); unsafe { buf.set_len(cnt) }; unsafe { String::from_utf8_unchecked(buf) } diff --git a/src/simd/README.md b/src/simd/README.md new file mode 100644 index 0000000..aa1b464 --- /dev/null +++ b/src/simd/README.md @@ -0,0 +1,11 @@ +# sonic_simd + +Borrowed from https://github.com/cloudwego/sonic-rs. +It uses runtime SIMD feature detection rather than compile-time detection. + +A portable SIMD library that provides low-level APIs for x86 and ARM. Other platforms fall back to the scalar implementation. + +TODO: + +1. Support RISC-V. +2. Support wasm.
\ No newline at end of file diff --git a/src/simd/avx2.rs b/src/simd/avx2.rs new file mode 100644 index 0000000..cbad942 --- /dev/null +++ b/src/simd/avx2.rs @@ -0,0 +1,89 @@ +#[cfg(target_arch = "x86")] +use std::arch::x86::*; +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +use std::ops::{BitAnd, BitOr, BitOrAssign}; + +use super::{Mask, Simd}; + +#[derive(Debug)] +#[repr(transparent)] +pub struct Simd256u(__m256i); + +#[derive(Debug)] +#[repr(transparent)] +pub struct Mask256(__m256i); + +impl Mask for Mask256 { + type BitMask = u32; + type Element = u8; + + #[inline(always)] + fn bitmask(self) -> Self::BitMask { + unsafe { _mm256_movemask_epi8(self.0) as u32 } + } +} + +impl BitAnd for Mask256 { + type Output = Self; + + #[inline(always)] + fn bitand(self, rhs: Mask256) -> Self::Output { + unsafe { Mask256(_mm256_and_si256(self.0, rhs.0)) } + } +} + +impl BitOr for Mask256 { + type Output = Self; + + #[inline(always)] + fn bitor(self, rhs: Mask256) -> Self::Output { + unsafe { Mask256(_mm256_or_si256(self.0, rhs.0)) } + } +} + +impl BitOrAssign for Mask256 { + #[inline(always)] + fn bitor_assign(&mut self, rhs: Mask256) { + unsafe { self.0 = _mm256_or_si256(self.0, rhs.0) } + } +} + +impl Simd for Simd256u { + const LANES: usize = 32; + type Mask = Mask256; + type Element = u8; + + #[inline(always)] + unsafe fn loadu(ptr: *const u8) -> Self { + unsafe { Simd256u(_mm256_loadu_si256(ptr as *const __m256i)) } + } + + #[inline(always)] + unsafe fn storeu(&self, ptr: *mut u8) { + unsafe { _mm256_storeu_si256(ptr as *mut __m256i, self.0) } + } + + #[inline(always)] + fn eq(&self, rhs: &Self) -> Self::Mask { + unsafe { + let eq = _mm256_cmpeq_epi8(self.0, rhs.0); + Mask256(eq) + } + } + + #[inline(always)] + fn splat(ch: u8) -> Self { + unsafe { Simd256u(_mm256_set1_epi8(ch as i8)) } + } + + #[inline(always)] + fn le(&self, rhs: &Self) -> Self::Mask { + unsafe { + let max = _mm256_max_epu8(self.0, rhs.0); + let eq = _mm256_cmpeq_epi8(max, rhs.0); + Mask256(eq) + } + } +} diff --git a/src/simd/avx512.rs b/src/simd/avx512.rs new file mode 100644 index 0000000..98efdb6 --- /dev/null +++ b/src/simd/avx512.rs @@ -0,0 +1,82 @@ +#[cfg(target_arch = "x86")] +use std::arch::x86::*; +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +use std::ops::{BitAnd, BitOr, BitOrAssign}; + +use super::{Mask, Simd}; + +#[derive(Debug)] +#[repr(transparent)] +pub struct Simd512u(__m512i); + +#[derive(Debug, Clone, Copy)] +#[repr(transparent)] +pub struct Mask512(__mmask64); + +impl Mask for Mask512 { + type BitMask = u64; + type Element = u8; + + #[inline(always)] + fn bitmask(self) -> Self::BitMask { + self.0 + } +} + +impl BitOr for Mask512 { + type Output = Self; + + #[inline(always)] + fn bitor(self, rhs: Self) -> Self::Output { + Mask512(self.0 | rhs.0) + } +} + +impl BitOrAssign for Mask512 { + #[inline(always)] + fn bitor_assign(&mut self, rhs: Self) { + self.0 |= rhs.0; + } +} + +impl BitAnd for Mask512 { + type Output = Self; + + #[inline(always)] + fn bitand(self, rhs: Mask512) -> Self::Output { + Mask512(self.0 & rhs.0) + } +} + +impl Simd for Simd512u { + const LANES: usize = 64; + type Element = u8; + type Mask = Mask512; + + #[inline(always)] + unsafe fn loadu(ptr: *const u8) -> Self { + unsafe { Simd512u(_mm512_loadu_si512(ptr as *const __m512i)) } + } + + #[inline(always)] + unsafe fn storeu(&self, ptr: *mut u8) { + unsafe { _mm512_storeu_si512(ptr as *mut __m512i, self.0) } + } + + #[inline(always)] + fn eq(&self, rhs: &Self) -> Self::Mask { + unsafe { 
Mask512(_mm512_cmpeq_epi8_mask(self.0, rhs.0)) } + } + + #[inline(always)] + fn splat(ch: u8) -> Self { + unsafe { Simd512u(_mm512_set1_epi8(ch as i8)) } + } + + #[inline(always)] + fn le(&self, rhs: &Self) -> Self::Mask { + unsafe { Mask512(_mm512_cmple_epu8_mask(self.0, rhs.0)) } + } +} diff --git a/src/simd/bits.rs b/src/simd/bits.rs new file mode 100644 index 0000000..3bdb694 --- /dev/null +++ b/src/simd/bits.rs @@ -0,0 +1,105 @@ +use super::traits::BitMask; + +macro_rules! impl_bits { + () => {}; + ($($ty:ty)*) => { + $( + impl BitMask for $ty { + const LEN: usize = std::mem::size_of::<$ty>() * 8; + + #[inline] + fn before(&self, rhs: &Self) -> bool { + (self.as_little_endian() & rhs.as_little_endian().wrapping_sub(1)) != 0 + } + + #[inline] + fn first_offset(&self) -> usize { + self.as_little_endian().trailing_zeros() as usize + } + + #[inline] + fn as_little_endian(&self) -> Self { + #[cfg(target_endian = "little")] + { + self.clone() + } + #[cfg(target_endian = "big")] + { + self.swap_bytes() + } + } + + #[inline] + fn all_zero(&self) -> bool { + *self == 0 + } + + #[inline] + fn clear_high_bits(&self, n: usize) -> Self { + debug_assert!(n <= Self::LEN); + *self & ((u64::MAX as $ty) >> n) + } + } + )* + }; +} + +impl_bits!(u16 u32 u64); + +#[cfg(target_arch = "aarch64")] +/// Use u64 representation the bitmask of Neon vector. +/// (low) +/// Vector: 00-ff-ff-ff-ff-00-00-00 +/// Mask : 0000-1111-1111-1111-1111-0000-0000-0000 +/// +/// first_offset() = 1 +/// clear_high_bits(4) = Mask(0000-1111-1111-1111-[0000]-0000-0000-0000) +/// +/// reference: https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon +pub struct NeonBits(u64); + +#[cfg(target_arch = "aarch64")] +impl NeonBits { + #[inline] + pub fn new(u: u64) -> Self { + Self(u) + } +} + +#[cfg(target_arch = "aarch64")] +impl BitMask for NeonBits { + const LEN: usize = 16; + + #[inline] + fn first_offset(&self) -> usize { + (self.as_little_endian().0.trailing_zeros() as usize) >> 2 + } + + #[inline] + fn before(&self, rhs: &Self) -> bool { + (self.as_little_endian().0 & rhs.as_little_endian().0.wrapping_sub(1)) != 0 + } + + #[inline] + fn as_little_endian(&self) -> Self { + #[cfg(target_endian = "little")] + { + Self::new(self.0) + } + #[cfg(target_endian = "big")] + { + Self::new(self.0.swap_bytes()) + } + } + + #[inline] + fn all_zero(&self) -> bool { + self.0 == 0 + } + + #[inline] + fn clear_high_bits(&self, n: usize) -> Self { + debug_assert!(n <= Self::LEN); + Self(self.0 & u64::MAX >> (n * 4)) + } +} diff --git a/src/simd/mod.rs b/src/simd/mod.rs new file mode 100644 index 0000000..a4c80ff --- /dev/null +++ b/src/simd/mod.rs @@ -0,0 +1,16 @@ +#![allow(non_camel_case_types)] + +pub mod bits; +mod traits; + +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +pub(crate) mod avx2; +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +pub(crate) mod avx512; +#[cfg(target_arch = "aarch64")] +pub(crate) mod neon; +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +pub(crate) mod sse2; +pub(crate) mod v128; + +pub use self::traits::{BitMask, Mask, Simd}; diff --git a/src/simd/neon.rs b/src/simd/neon.rs new file mode 100644 index 0000000..d365062 --- /dev/null +++ b/src/simd/neon.rs @@ -0,0 +1,90 @@ +use std::arch::aarch64::*; + +use super::{Mask, Simd, bits::NeonBits}; + +#[derive(Debug)] +#[repr(transparent)] +pub struct Simd128u(uint8x16_t); + +impl Simd for Simd128u { + const LANES: usize = 16; + type Mask = Mask128; + type 
Element = u8; + + #[inline(always)] + unsafe fn loadu(ptr: *const u8) -> Self { + unsafe { Self(vld1q_u8(ptr)) } + } + + #[inline(always)] + unsafe fn storeu(&self, ptr: *mut u8) { + unsafe { vst1q_u8(ptr, self.0) }; + } + + #[inline(always)] + fn eq(&self, lhs: &Self) -> Self::Mask { + unsafe { Mask128(vceqq_u8(self.0, lhs.0)) } + } + + #[inline(always)] + fn splat(ch: u8) -> Self { + unsafe { Self(vdupq_n_u8(ch)) } + } + + // less or equal + #[inline(always)] + fn le(&self, lhs: &Self) -> Self::Mask { + unsafe { Mask128(vcleq_u8(self.0, lhs.0)) } + } +} + +#[derive(Debug)] +#[repr(transparent)] +pub struct Mask128(pub(crate) uint8x16_t); + +impl Mask for Mask128 { + type BitMask = NeonBits; + type Element = u8; + + /// Convert Mask Vector 0x00-ff-ff to Bits 0b0000-1111-1111 + /// Reference: https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon + #[inline(always)] + fn bitmask(self) -> Self::BitMask { + unsafe { + let v16 = vreinterpretq_u16_u8(self.0); + let sr4 = vshrn_n_u16(v16, 4); + let v64 = vreinterpret_u64_u8(sr4); + NeonBits::new(vget_lane_u64(v64, 0)) + } + } +} + +// Bitwise AND for Mask128 +impl std::ops::BitAnd for Mask128 { + type Output = Self; + + #[inline(always)] + fn bitand(self, rhs: Mask128) -> Self::Output { + unsafe { Self(vandq_u8(self.0, rhs.0)) } + } +} + +// Bitwise OR for Mask128 +impl std::ops::BitOr for Mask128 { + type Output = Self; + + #[inline(always)] + fn bitor(self, rhs: Mask128) -> Self::Output { + unsafe { Self(vorrq_u8(self.0, rhs.0)) } + } +} + +// Bitwise OR assignment for Mask128 +impl std::ops::BitOrAssign for Mask128 { + #[inline(always)] + fn bitor_assign(&mut self, rhs: Mask128) { + unsafe { + self.0 = vorrq_u8(self.0, rhs.0); + } + } +} diff --git a/src/simd/sse2.rs b/src/simd/sse2.rs new file mode 100644 index 0000000..543bc21 --- /dev/null +++ b/src/simd/sse2.rs @@ -0,0 +1,86 @@ +#[cfg(target_arch = "x86")] +use std::arch::x86::*; +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +use std::ops::{BitAnd, BitOr, BitOrAssign}; + +use super::{Mask, Simd}; + +#[derive(Debug)] +#[repr(transparent)] +pub struct Simd128u(__m128i); + +#[derive(Debug)] +#[repr(transparent)] +pub struct Mask128(__m128i); + +impl Mask for Mask128 { + type BitMask = u16; + type Element = u8; + + #[inline(always)] + fn bitmask(self) -> Self::BitMask { + unsafe { _mm_movemask_epi8(self.0) as u16 } + } +} + +impl BitAnd for Mask128 { + type Output = Self; + + #[inline(always)] + fn bitand(self, rhs: Mask128) -> Self::Output { + unsafe { Mask128(_mm_and_si128(self.0, rhs.0)) } + } +} + +impl BitOr for Mask128 { + type Output = Self; + + #[inline(always)] + fn bitor(self, rhs: Mask128) -> Self::Output { + unsafe { Mask128(_mm_or_si128(self.0, rhs.0)) } + } +} + +impl BitOrAssign for Mask128 { + #[inline(always)] + fn bitor_assign(&mut self, rhs: Mask128) { + self.0 = unsafe { _mm_or_si128(self.0, rhs.0) }; + } +} + +impl Simd for Simd128u { + const LANES: usize = 16; + type Mask = Mask128; + type Element = u8; + + #[inline(always)] + unsafe fn loadu(ptr: *const u8) -> Self { + Simd128u(unsafe { _mm_loadu_si128(ptr as *const __m128i) }) + } + + #[inline(always)] + unsafe fn storeu(&self, ptr: *mut u8) { + unsafe { _mm_storeu_si128(ptr as *mut __m128i, self.0) } + } + + #[inline(always)] + fn eq(&self, rhs: &Self) -> Self::Mask { + Mask128(unsafe { _mm_cmpeq_epi8(self.0, rhs.0) }) + } + + #[inline(always)] + fn splat(ch: u8) -> Self { + Simd128u(unsafe { _mm_set1_epi8(ch as i8) }) + 
} + + #[inline(always)] + fn le(&self, rhs: &Self) -> Self::Mask { + unsafe { + let max = _mm_max_epu8(self.0, rhs.0); + let eq = _mm_cmpeq_epi8(max, rhs.0); + Mask128(eq) + } + } +} diff --git a/src/simd/traits.rs b/src/simd/traits.rs new file mode 100644 index 0000000..985e262 --- /dev/null +++ b/src/simd/traits.rs @@ -0,0 +1,64 @@ +use std::ops::{BitAnd, BitOr, BitOrAssign}; + +/// Portable SIMD traits +pub trait Simd: Sized { + const LANES: usize; + + type Element; + type Mask: Mask; + + /// # Safety + unsafe fn from_slice_unaligned_unchecked(slice: &[u8]) -> Self { + debug_assert!(slice.len() >= Self::LANES); + unsafe { Self::loadu(slice.as_ptr()) } + } + + /// # Safety + unsafe fn write_to_slice_unaligned_unchecked(&self, slice: &mut [u8]) { + debug_assert!(slice.len() >= Self::LANES); + unsafe { self.storeu(slice.as_mut_ptr()) } + } + + /// # Safety + unsafe fn loadu(ptr: *const u8) -> Self; + + /// # Safety + unsafe fn storeu(&self, ptr: *mut u8); + + fn eq(&self, rhs: &Self) -> Self::Mask; + + fn splat(elem: Self::Element) -> Self; + + /// less or equal + fn le(&self, rhs: &Self) -> Self::Mask; +} + +/// Portable SIMD mask traits +pub trait Mask: Sized + BitOr + BitOrAssign + BitAnd { + type Element; + type BitMask: BitMask; + + fn bitmask(self) -> Self::BitMask; +} + +/// Trait for the bitmask of a vector Mask. +pub trait BitMask { + /// Total bits in the bitmask. + const LEN: usize; + + /// get the offset of the first `1` bit. + fn first_offset(&self) -> usize; + + #[allow(unused)] + /// check if this bitmask is before the other bitmask. + fn before(&self, rhs: &Self) -> bool; + + /// convert bitmask as little endian + fn as_little_endian(&self) -> Self; + + /// whether all bits are zero. + fn all_zero(&self) -> bool; + + /// clear high n bits. 
+ fn clear_high_bits(&self, n: usize) -> Self; +} diff --git a/src/simd/v128.rs b/src/simd/v128.rs new file mode 100644 index 0000000..8e03fa1 --- /dev/null +++ b/src/simd/v128.rs @@ -0,0 +1,101 @@ +use std::ops::{BitAnd, BitOr, BitOrAssign}; + +use super::{Mask, Simd}; + +#[derive(Debug)] +pub struct Simd128u([u8; 16]); + +#[derive(Debug)] +pub struct Mask128(pub(crate) [u8; 16]); + +impl Simd for Simd128u { + type Element = u8; + const LANES: usize = 16; + type Mask = Mask128; + + unsafe fn loadu(ptr: *const u8) -> Self { + let v = unsafe { std::slice::from_raw_parts(ptr, Self::LANES) }; + let mut res = [0u8; 16]; + res.copy_from_slice(v); + Self(res) + } + + unsafe fn storeu(&self, ptr: *mut u8) { + let data = &self.0; + unsafe { std::ptr::copy_nonoverlapping(data.as_ptr(), ptr, Self::LANES) }; + } + + fn eq(&self, rhs: &Self) -> Self::Mask { + let mut mask = [0u8; 16]; + for (i, item) in mask.iter_mut().enumerate().take(Self::LANES) { + *item = if self.0[i] == rhs.0[i] { 1 } else { 0 }; + } + Mask128(mask) + } + + fn splat(value: u8) -> Self { + Self([value; Self::LANES]) + } + + fn le(&self, rhs: &Self) -> Self::Mask { + let mut mask = [0u8; 16]; + for i in 0..Self::LANES { + mask[i] = if self.0[i] <= rhs.0[i] { 1 } else { 0 }; + } + Mask128(mask) + } +} + +impl Mask for Mask128 { + type BitMask = u16; + type Element = u8; + + fn bitmask(self) -> Self::BitMask { + #[cfg(target_endian = "little")] + { + self.0 + .iter() + .enumerate() + .fold(0, |acc, (i, &b)| acc | ((b as u16) << i)) + } + #[cfg(target_endian = "big")] + { + self.0 + .iter() + .enumerate() + .fold(0, |acc, (i, &b)| acc | ((b as u16) << (15 - i))) + } + } +} + +impl BitAnd for Mask128 { + type Output = Self; + + fn bitand(self, rhs: Self) -> Self::Output { + let mut result = [0u8; 16]; + for i in 0..16 { + result[i] = self.0[i] & rhs.0[i]; + } + Mask128(result) + } +} + +impl BitOr for Mask128 { + type Output = Self; + + fn bitor(self, rhs: Self) -> Self::Output { + let mut result = [0u8; 16]; + for i in 0..16 { + result[i] = self.0[i] | rhs.0[i]; + } + Mask128(result) + } +} + +impl BitOrAssign for Mask128 { + fn bitor_assign(&mut self, rhs: Self) { + for i in 0..16 { + self.0[i] |= rhs.0[i]; + } + } +}
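Usage note: the public API is unchanged by this diff; `escape(&str) -> String` still returns the complete JSON string literal, surrounding quotes included, and the updated benchmark asserts that its output matches `serde_json`/`sonic_rs` string serialization. Below is a minimal sketch of that contract, assuming the library crate is named `escape` (the crate name is not shown in this diff) and that `serde_json` is available as in the dev-dependencies.

fn main() {
    // Input exercising the three classes the escape masks look for:
    // a quote, a backslash, and control bytes (<= 0x1f).
    let input = "a \"quoted\" line\nwith a backslash \\ and control byte \u{1}";

    // `escape` writes the opening and closing quotes itself, so the result is a
    // full JSON string literal. The crate path `escape::` is an assumption here.
    let escaped = escape::escape(input);

    // The benchmark in benches/escape.rs asserts this same equivalence for its
    // own test corpus.
    assert_eq!(escaped, serde_json::to_string(input).unwrap());
    println!("{escaped}");
}

The `value.len() * 6 + 32 + 3` capacity in `escape` covers the worst case in which every input byte expands to a six-byte `\uXXXX` sequence, plus the two surrounding quotes and some slack, so the pre-sized buffer handed to `format_string` is always large enough for the escaped output.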