Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
rok committed Dec 8, 2023
1 parent ce5dcb7 commit 1578a24
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 31 deletions.
6 changes: 0 additions & 6 deletions rust/.cargo/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,5 @@ rustflags = [
[target.x86_64-unknown-linux-gnu]
rustflags = ["-C", "target-cpu=haswell", "-C", "target-feature=+avx2,+fma,+f16c"]

[target.x86_64-unknown-linux-gnu-avx512bf16]
rustflags = ["-C", "target-feature=+avx2,+fma,+f16c,+avx512bf16"]

[target.aarch64-apple-darwin]
rustflags = ["-C", "target-cpu=apple-m1", "-C", "target-feature=+neon,+fp16,+fhm,+dotprod"]

[target.aarch64-apple-darwin-m2]
rustflags = ["-C", "target-cpu=apple-m2", "-C", "target-feature=+neon,+fp16,+fhm,+dotprod,+bf16"]
3 changes: 2 additions & 1 deletion rust/lance-linalg/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ fn main() {
.compile("f16");
}

if cfg!(all(feature = "avx512bf16")) {
if cfg!(all(target_os = "linux", feature = "avx512bf16")) {
// No enable bf16 kernels on sappphire rapids
cc::Build::new()
.compiler("clang")
.std("c17")
Expand Down
3 changes: 2 additions & 1 deletion rust/lance-linalg/src/distance/l2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ impl L2 for BFloat16Type {

#[cfg(any(
all(target_os = "macos", target_feature = "neon"),
all(target_os = "linux", feature = "avx512fp16")
all(target_os = "linux", feature = "avx512fp16"),
all(target_os = "linux", feature = "avx512bf16")
))]
mod kernel {
use super::*;
Expand Down
26 changes: 7 additions & 19 deletions rust/lance-linalg/src/distance/norm_l2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,19 @@ pub trait Normalize<T: Float> {
fn norm_l2(&self) -> f32;
}

// `avx512fp16` is not supported in rustc yet. Once it is supported, we can
// move it to target_feture.
// `avx512fp16` and `avx512bf16` are not supported in rustc yet. Once they are
// supported, we can move this to target_feture.
#[cfg(any(
all(target_os = "macos", target_feature = "neon"),
feature = "avx512fp16"
all(target_os = "linux", feature = "avx512fp16"),
all(target_os = "linux", feature = "avx512bf16"),
))]
mod kernel {
use super::*;

extern "C" {
pub fn norm_l2_f16(ptr: *const f16, len: u32) -> f32;
pub fn norm_l2_bf16(ptr: *const bf16, len: u32) -> f32;
}
}

Expand Down Expand Up @@ -88,32 +90,18 @@ impl Normalize<f16> for &[f16] {
}
}

// `avx512bf16` is not supported in rustc yet. Once it is supported, we can
// move it to target_feture.
#[cfg(any(
all(target_os = "macos", target_feature = "neon"),
feature = "avx512bf16",
))]
mod kernel {
use super::*;

extern "C" {
pub fn norm_l2_bf16(ptr: *const bf16, len: u32) -> f32;
}
}

impl Normalize<bf16> for &[bf16] {
fn norm_l2(&self) -> f32 {
#[cfg(any(
all(target_os = "macos", target_feature = "neon"),
feature = "avx512bf16",
all(target_os = "linux", target_feature = "avx512bf16"),
))]
unsafe {
kernel::norm_l2_bf16(self.as_ptr(), self.len() as u32)
}
#[cfg(not(any(
all(target_os = "macos", target_feature = "neon"),
feature = "avx512bf16",
all(target_os = "linux", target_feature = "avx512bf16"),
)))]
norm_l2_impl::<bf16, 32>(self)
}
Expand Down
4 changes: 0 additions & 4 deletions rust/lance-linalg/src/simd/bf16.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@
#include <stdint.h>
#include <immintrin.h>

// See https://github.com/ashvardanian/SimSIMD/blob/main/include/simsimd/spatial.h
// https://www.intel.com/content/www/us/en/developer/articles/technical/intel-deep-learning-boost-new-instruction-bfloat16.html
/// Needs avx512f, avx512bf16 and avx512vl

/// @brief Computes the L2 norm of a bf16 vector.
/// @param x A bf16 vector
/// @param dimension The dimension of the vectors
Expand Down

0 comments on commit 1578a24

Please sign in to comment.