diff --git a/spare_kernels/aarch64_neon_4x4.rs b/spare_kernels/aarch64_neon_4x4.rs
new file mode 100644
index 0000000..319b134
--- /dev/null
+++ b/spare_kernels/aarch64_neon_4x4.rs
@@ -0,0 +1,150 @@
+#[cfg(target_arch="aarch64")]
+struct KernelArmNeon;
+
+#[cfg(target_arch="aarch64")]
+impl GemmKernel for KernelArmNeon {
+    type Elem = T;
+
+    type MRTy = U4;
+    type NRTy = U4;
+
+    #[inline(always)]
+    fn align_to() -> usize { 16 }
+
+    #[inline(always)]
+    fn always_masked() -> bool { false }
+
+    #[inline(always)]
+    fn nc() -> usize { archparam::S_NC }
+    #[inline(always)]
+    fn kc() -> usize { archparam::S_KC }
+    #[inline(always)]
+    fn mc() -> usize { archparam::S_MC }
+
+    #[inline(always)]
+    unsafe fn kernel(
+        k: usize,
+        alpha: T,
+        a: *const T,
+        b: *const T,
+        beta: T,
+        c: *mut T, rsc: isize, csc: isize) {
+        kernel_target_arm_neon(k, alpha, a, b, beta, c, rsc, csc)
+    }
+}
+
+// 4x4 NEON kernel, unrolled by 4, developed for Apple Silicon M1
+#[cfg(target_arch="aarch64")]
+#[target_feature(enable="neon")]
+unsafe fn kernel_target_arm_neon(k: usize, alpha: T, a: *const T, b: *const T,
+                                 beta: T, c: *mut T, rsc: isize, csc: isize)
+{
+    use core::arch::aarch64::*;
+    const MR: usize = KernelArmNeon::MR;
+    const NR: usize = KernelArmNeon::NR;
+
+    let (mut a, mut b, rsc, csc) = if rsc == 1 { (b, a, csc, rsc) } else { (a, b, rsc, csc) };
+
+    let mut ab = [vmovq_n_f32(0.); MR];
+    let mut ab2 = [vmovq_n_f32(0.); MR];
+    let mut ab3 = [vmovq_n_f32(0.); MR];
+    let mut ab4 = [vmovq_n_f32(0.); MR];
+    let use_fma = true;
+
+    // Compute
+    // ab_ij = a_i * b_j for all i, j
+    macro_rules! ab_ij_equals_ai_bj {
+        ($dest:ident, $av:expr, $bv:expr) => {
+            if use_fma {
+                $dest[0] = vfmaq_laneq_f32($dest[0], $bv, $av, 0);
+                $dest[1] = vfmaq_laneq_f32($dest[1], $bv, $av, 1);
+                $dest[2] = vfmaq_laneq_f32($dest[2], $bv, $av, 2);
+                $dest[3] = vfmaq_laneq_f32($dest[3], $bv, $av, 3);
+            } else {
+                $dest[0] = vaddq_f32($dest[0], vmulq_laneq_f32($bv, $av, 0));
+                $dest[1] = vaddq_f32($dest[1], vmulq_laneq_f32($bv, $av, 1));
+                $dest[2] = vaddq_f32($dest[2], vmulq_laneq_f32($bv, $av, 2));
+                $dest[3] = vaddq_f32($dest[3], vmulq_laneq_f32($bv, $av, 3));
+            }
+        }
+    }
+
+    const UNROLL_BY: usize = 4;
+
+    for _ in 0..k / UNROLL_BY {
+        let av = vld1q_f32(a);
+        let bv = vld1q_f32(b);
+        // eprintln!("a: {av:?}");
+        // eprintln!("b: {bv:?}");
+
+        // FMLA instruction
+        // Cortex 7A: FMA has 7 cycles latency or 3 cycles when the dependency is on the accumulator
+        // M1: Latency 3, throughput 0.25
+        ab_ij_equals_ai_bj!(ab, av, bv);
+
+        let av = vld1q_f32(a.add(4));
+        let bv = vld1q_f32(b.add(4));
+
+        ab_ij_equals_ai_bj!(ab2, av, bv);
+
+        if UNROLL_BY > 2 {
+
+            let av = vld1q_f32(a.add(8));
+            let bv = vld1q_f32(b.add(8));
+
+            ab_ij_equals_ai_bj!(ab3, av, bv);
+
+            let av = vld1q_f32(a.add(12));
+            let bv = vld1q_f32(b.add(12));
+
+            ab_ij_equals_ai_bj!(ab4, av, bv);
+
+        }
+
+        a = a.offset(UNROLL_BY as isize * MR as isize);
+        b = b.offset(UNROLL_BY as isize * NR as isize);
+    }
+
+    for _ in 0..k % UNROLL_BY {
+        let av = vld1q_f32(a);
+        let bv = vld1q_f32(b);
+
+        ab_ij_equals_ai_bj!(ab, av, bv);
+
+        a = a.offset(MR as isize);
+        b = b.offset(NR as isize);
+    }
+
+    macro_rules! c {
+        ($i:expr, $j:expr) => (c.offset(rsc * $i as isize + csc * $j as isize));
+    }
+
+    macro_rules! extract {
+        ($v:expr, $imm:expr) => (
+            f32::from_bits(vgetq_lane_u32(core::mem::transmute::<_, uint32x4_t>($v), $imm))
+        )
+    }
+
+    // Combine accumulators and multiply by alpha
+    loop4!(i, ab[i] = vaddq_f32(vaddq_f32(ab[i], ab2[i]), vaddq_f32(ab3[i], ab4[i])));
+    loop4!(i, ab[i] = vmulq_n_f32(ab[i], alpha));
+
+    if beta == 0. {
+        // set C = α A B
+        if csc == 1 {
+            loop4!(i, vst1q_f32(c![i, 0], ab[i]));
+        } else {
+            loop4!(i, vst1q_lane_f32(c![i, 0], ab[i], 0));
+            loop4!(i, vst1q_lane_f32(c![i, 1], ab[i], 1));
+            loop4!(i, vst1q_lane_f32(c![i, 2], ab[i], 2));
+            loop4!(i, vst1q_lane_f32(c![i, 3], ab[i], 3));
+        }
+    } else {
+        // set C = α A B + beta C
+        loop4!(i, *c![i, 0] = *c![i, 0] * beta + extract!(ab[i], 0));
+        loop4!(i, *c![i, 1] = *c![i, 1] * beta + extract!(ab[i], 1));
+        loop4!(i, *c![i, 2] = *c![i, 2] * beta + extract!(ab[i], 2));
+        loop4!(i, *c![i, 3] = *c![i, 3] * beta + extract!(ab[i], 3));
+    }
+}
+
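The core of the NEON kernel is the lane-broadcast FMA: every k-step is a rank-1 update of the 4x4 tile, ab[i][j] += a[i] * b[j], and the rsc == 1 swap at the top of the function computes the transposed tile instead, so the contiguous-store path (csc == 1) can be used. As a standalone sketch (not part of the patch; the name rank1_update_4x4 is ours), one such update without the unrolling and macro machinery looks like this:

#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn rank1_update_4x4(
    ab: &mut [core::arch::aarch64::float32x4_t; 4],
    a: *const f32, // 4 packed A values for this k-step
    b: *const f32, // 4 packed B values for this k-step
) {
    use core::arch::aarch64::*;
    let av = vld1q_f32(a);
    let bv = vld1q_f32(b);
    // ab[i] += bv * av[i]: one FMLA per row, with a[i] broadcast from lane i
    ab[0] = vfmaq_laneq_f32(ab[0], bv, av, 0);
    ab[1] = vfmaq_laneq_f32(ab[1], bv, av, 1);
    ab[2] = vfmaq_laneq_f32(ab[2], bv, av, 2);
    ab[3] = vfmaq_laneq_f32(ab[3], bv, av, 3);
}

The patch unrolls this by 4 with separate accumulator sets (ab through ab4), presumably so that consecutive FMLAs do not depend on the same accumulator and the latency noted in the comment stays hidden.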
diff --git a/spare_kernels/x86_sse_sgemm.rs b/spare_kernels/x86_sse_sgemm.rs
new file mode 100644
index 0000000..720c93c
--- /dev/null
+++ b/spare_kernels/x86_sse_sgemm.rs
@@ -0,0 +1,84 @@
+
+// 4x4 SSE sgemm
+macro_rules! mm_transpose4 {
+    ($c0:expr, $c1:expr, $c2:expr, $c3:expr) => {{
+        // This is _MM_TRANSPOSE4_PS except we take variables, not references
+        let tmp0 = _mm_unpacklo_ps($c0, $c1);
+        let tmp2 = _mm_unpacklo_ps($c2, $c3);
+        let tmp1 = _mm_unpackhi_ps($c0, $c1);
+        let tmp3 = _mm_unpackhi_ps($c2, $c3);
+
+        $c0 = _mm_movelh_ps(tmp0, tmp2);
+        $c1 = _mm_movehl_ps(tmp2, tmp0);
+        $c2 = _mm_movelh_ps(tmp1, tmp3);
+        $c3 = _mm_movehl_ps(tmp3, tmp1);
+    }}
+}
+
+#[inline(always)]
+#[cfg(any(target_arch="x86", target_arch="x86_64"))]
+unsafe fn kernel_x86_sse(k: usize, alpha: T, a: *const T, b: *const T,
+                         beta: T, c: *mut T, rsc: isize, csc: isize)
+{
+    let mut ab = [_mm_setzero_ps(); MR];
+
+    let mut bv;
+    let (mut a, mut b) = (a, b);
+
+    // Compute A B
+    for _ in 0..k {
+        bv = _mm_load_ps(b as _); // aligned due to GemmKernel::align_to
+
+        loop_m!(i, {
+            // Compute ab_i += [ai b_j+0, ai b_j+1, ai b_j+2, ai b_j+3]
+            let aiv = _mm_set1_ps(at(a, i));
+            ab[i] = _mm_add_ps(ab[i], _mm_mul_ps(aiv, bv));
+        });
+
+        a = a.add(MR);
+        b = b.add(NR);
+    }
+
+    // Compute α (A B)
+    let alphav = _mm_set1_ps(alpha);
+    loop_m!(i, ab[i] = _mm_mul_ps(alphav, ab[i]));
+
+    macro_rules! c {
+        ($i:expr, $j:expr) => (c.offset(rsc * $i as isize + csc * $j as isize));
+    }
+
+    // C ← α A B + β C
+    let mut cv = [_mm_setzero_ps(); MR];
+    let betav = _mm_set1_ps(beta);
+    if beta != 0. {
+        // Read C
+        if csc == 1 {
+            loop_m!(i, cv[i] = _mm_loadu_ps(c![i, 0]));
+        } else if rsc == 1 {
+            loop_m!(i, cv[i] = _mm_loadu_ps(c![0, i]));
+            mm_transpose4!(cv[0], cv[1], cv[2], cv[3]);
+        } else {
+            loop_m!(i, cv[i] = _mm_set_ps(*c![i, 3], *c![i, 2], *c![i, 1], *c![i, 0]));
+        }
+        // Compute β C
+        loop_m!(i, cv[i] = _mm_mul_ps(cv[i], betav));
+    }
+
+    // Compute (α A B) + (β C)
+    loop_m!(i, cv[i] = _mm_add_ps(cv[i], ab[i]));
+
+    // Store C back to memory
+    if csc == 1 {
+        loop_m!(i, _mm_storeu_ps(c![i, 0], cv[i]));
+    } else if rsc == 1 {
+        mm_transpose4!(cv[0], cv[1], cv[2], cv[3]);
+        loop_m!(i, _mm_storeu_ps(c![0, i], cv[i]));
+    } else {
+        // extract the nth value of a vector using _mm_cvtss_f32 (extract lowest)
+        // in combination with shuffle (move nth value to first position)
+        loop_m!(i, *c![i, 0] = _mm_cvtss_f32(cv[i]));
+        loop_m!(i, *c![i, 1] = _mm_cvtss_f32(_mm_shuffle_ps(cv[i], cv[i], 1)));
+        loop_m!(i, *c![i, 2] = _mm_cvtss_f32(_mm_shuffle_ps(cv[i], cv[i], 2)));
+        loop_m!(i, *c![i, 3] = _mm_cvtss_f32(_mm_shuffle_ps(cv[i], cv[i], 3)));
+    }
+}
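Both kernels implement the same micro-kernel contract: a and b point at packed panels holding MR = 4 values of A and NR = 4 values of B per k-step, rsc and csc are the row and column strides of C in elements, and the result is C ← α A B + β C, with the β = 0 case written so that C is never read. A plain scalar version of that contract, useful as a reference when checking either SIMD path (this helper is ours, not part of the patch, and assumes T = f32):

unsafe fn reference_kernel_4x4(k: usize, alpha: f32, a: *const f32, b: *const f32,
                               beta: f32, c: *mut f32, rsc: isize, csc: isize)
{
    // Accumulate the 4x4 tile of A B; a and b advance by 4 elements per k-step.
    let mut ab = [[0.0_f32; 4]; 4];
    for l in 0..k {
        for i in 0..4 {
            for j in 0..4 {
                ab[i][j] += *a.add(l * 4 + i) * *b.add(l * 4 + j);
            }
        }
    }
    // Write back through the C strides; when beta == 0, C is not read.
    for i in 0..4 {
        for j in 0..4 {
            let cij = c.offset(rsc * i as isize + csc * j as isize);
            *cij = if beta == 0. {
                alpha * ab[i][j]
            } else {
                alpha * ab[i][j] + beta * *cij
            };
        }
    }
}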