diff --git a/rust/.cargo/config.toml b/rust/.cargo/config.toml index b327ff53f8..6102a36c00 100644 --- a/rust/.cargo/config.toml +++ b/rust/.cargo/config.toml @@ -30,11 +30,5 @@ rustflags = [ [target.x86_64-unknown-linux-gnu] rustflags = ["-C", "target-cpu=haswell", "-C", "target-feature=+avx2,+fma,+f16c"] -[target.x86_64-unknown-linux-gnu-avx512bf16] -rustflags = ["-C", "target-feature=+avx2,+fma,+f16c,+avx512bf16"] - [target.aarch64-apple-darwin] rustflags = ["-C", "target-cpu=apple-m1", "-C", "target-feature=+neon,+fp16,+fhm,+dotprod"] - -[target.aarch64-apple-darwin-m2] -rustflags = ["-C", "target-cpu=apple-m2", "-C", "target-feature=+neon,+fp16,+fhm,+dotprod,+bf16"] diff --git a/rust/lance-linalg/build.rs b/rust/lance-linalg/build.rs index 3ba4ddeb68..29192df8ec 100644 --- a/rust/lance-linalg/build.rs +++ b/rust/lance-linalg/build.rs @@ -50,7 +50,8 @@ fn main() { .compile("f16"); } - if cfg!(all(feature = "avx512bf16")) { + if cfg!(all(target_os = "linux", feature = "avx512bf16")) { + // No enable bf16 kernels on sappphire rapids cc::Build::new() .compiler("clang") .std("c17") diff --git a/rust/lance-linalg/src/distance/l2.rs b/rust/lance-linalg/src/distance/l2.rs index 5e1e03f413..a34a937364 100644 --- a/rust/lance-linalg/src/distance/l2.rs +++ b/rust/lance-linalg/src/distance/l2.rs @@ -118,7 +118,8 @@ impl L2 for BFloat16Type { #[cfg(any( all(target_os = "macos", target_feature = "neon"), - all(target_os = "linux", feature = "avx512fp16") + all(target_os = "linux", feature = "avx512fp16"), + all(target_os = "linux", feature = "avx512bf16") ))] mod kernel { use super::*; diff --git a/rust/lance-linalg/src/distance/norm_l2.rs b/rust/lance-linalg/src/distance/norm_l2.rs index 1483ab68b0..05f63a05ad 100644 --- a/rust/lance-linalg/src/distance/norm_l2.rs +++ b/rust/lance-linalg/src/distance/norm_l2.rs @@ -28,17 +28,19 @@ pub trait Normalize { fn norm_l2(&self) -> f32; } -// `avx512fp16` is not supported in rustc yet. Once it is supported, we can -// move it to target_feture. +// `avx512fp16` and `avx512bf16` are not supported in rustc yet. Once they are +// supported, we can move this to target_feture. #[cfg(any( all(target_os = "macos", target_feature = "neon"), - feature = "avx512fp16" + all(target_os = "linux", feature = "avx512fp16"), + all(target_os = "linux", feature = "avx512bf16"), ))] mod kernel { use super::*; extern "C" { pub fn norm_l2_f16(ptr: *const f16, len: u32) -> f32; + pub fn norm_l2_bf16(ptr: *const bf16, len: u32) -> f32; } } @@ -88,32 +90,18 @@ impl Normalize for &[f16] { } } -// `avx512bf16` is not supported in rustc yet. Once it is supported, we can -// move it to target_feture. -#[cfg(any( -all(target_os = "macos", target_feature = "neon"), -feature = "avx512bf16", -))] -mod kernel { - use super::*; - - extern "C" { - pub fn norm_l2_bf16(ptr: *const bf16, len: u32) -> f32; - } -} - impl Normalize for &[bf16] { fn norm_l2(&self) -> f32 { #[cfg(any( all(target_os = "macos", target_feature = "neon"), - feature = "avx512bf16", + all(target_os = "linux", target_feature = "avx512bf16"), ))] unsafe { kernel::norm_l2_bf16(self.as_ptr(), self.len() as u32) } #[cfg(not(any( all(target_os = "macos", target_feature = "neon"), - feature = "avx512bf16", + all(target_os = "linux", target_feature = "avx512bf16"), )))] norm_l2_impl::(self) } diff --git a/rust/lance-linalg/src/simd/bf16.c b/rust/lance-linalg/src/simd/bf16.c index 402500af01..4cfc926a94 100644 --- a/rust/lance-linalg/src/simd/bf16.c +++ b/rust/lance-linalg/src/simd/bf16.c @@ -16,10 +16,6 @@ #include #include -// See https://github.com/ashvardanian/SimSIMD/blob/main/include/simsimd/spatial.h -// https://www.intel.com/content/www/us/en/developer/articles/technical/intel-deep-learning-boost-new-instruction-bfloat16.html -/// Needs avx512f, avx512bf16 and avx512vl - /// @brief Computes the L2 norm of a bf16 vector. /// @param x A bf16 vector /// @param dimension The dimension of the vectors