Commit

Add moved files back
xander-zitara committed May 2, 2023
1 parent d39d571 commit b986816
Showing 2 changed files with 234 additions and 0 deletions.
spare_kernels/aarch64_neon_4x4.rs (150 additions, 0 deletions)
@@ -0,0 +1,150 @@
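// NOTE: module fragment. `T` (= f32), the `U4` size marker, `archparam`,
// `loop4!` and the `GemmKernel` trait are expected from the enclosing
// kernel module; see the stand-in sketch after the second file.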
#[cfg(target_arch="aarch64")]
struct KernelArmNeon;

#[cfg(target_arch="aarch64")]
impl GemmKernel for KernelArmNeon {
type Elem = T;

type MRTy = U4;
type NRTy = U4;

#[inline(always)]
fn align_to() -> usize { 16 }

#[inline(always)]
fn always_masked() -> bool { false }

#[inline(always)]
fn nc() -> usize { archparam::S_NC }
#[inline(always)]
fn kc() -> usize { archparam::S_KC }
#[inline(always)]
fn mc() -> usize { archparam::S_MC }

#[inline(always)]
unsafe fn kernel(
k: usize,
alpha: T,
a: *const T,
b: *const T,
beta: T,
c: *mut T, rsc: isize, csc: isize) {
kernel_target_arm_neon(k, alpha, a, b, beta, c, rsc, csc)
}
}

// 4x4 NEON kernel, unrolled by 4; developed for Apple Silicon M1
#[cfg(target_arch="aarch64")]
#[target_feature(enable="neon")]
unsafe fn kernel_target_arm_neon(k: usize, alpha: T, a: *const T, b: *const T,
beta: T, c: *mut T, rsc: isize, csc: isize)
{
use core::arch::aarch64::*;
const MR: usize = KernelArmNeon::MR;
const NR: usize = KernelArmNeon::NR;

    // If columns of C are contiguous (rsc == 1), compute C^T = B^T A^T
    // instead: swap the packed panels and the strides so the vectorized
    // csc == 1 store path below applies (valid because MR == NR).
    let (mut a, mut b, rsc, csc) = if rsc == 1 { (b, a, csc, rsc) } else { (a, b, rsc, csc) };

let mut ab = [vmovq_n_f32(0.); MR];
let mut ab2 = [vmovq_n_f32(0.); MR];
let mut ab3 = [vmovq_n_f32(0.); MR];
let mut ab4 = [vmovq_n_f32(0.); MR];
let use_fma = true;

// Compute
// ab_ij = a_i * b_j for all i, j
macro_rules! ab_ij_equals_ai_bj {
($dest:ident, $av:expr, $bv:expr) => {
if use_fma {
$dest[0] = vfmaq_laneq_f32($dest[0], $bv, $av, 0);
$dest[1] = vfmaq_laneq_f32($dest[1], $bv, $av, 1);
$dest[2] = vfmaq_laneq_f32($dest[2], $bv, $av, 2);
$dest[3] = vfmaq_laneq_f32($dest[3], $bv, $av, 3);
} else {
$dest[0] = vaddq_f32($dest[0], vmulq_laneq_f32($bv, $av, 0));
$dest[1] = vaddq_f32($dest[1], vmulq_laneq_f32($bv, $av, 1));
$dest[2] = vaddq_f32($dest[2], vmulq_laneq_f32($bv, $av, 2));
$dest[3] = vaddq_f32($dest[3], vmulq_laneq_f32($bv, $av, 3));
}
}
}

const UNROLL_BY: usize = 4;
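    // Four independent accumulator arrays (ab, ab2, ab3, ab4) break the
    // dependency chain on the FMLA accumulator, hiding its latency across
    // the 4x-unrolled k loop below; they are summed once at the end.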

for _ in 0..k / UNROLL_BY {
let av = vld1q_f32(a);
let bv = vld1q_f32(b);
// eprintln!("a: {av:?}");
// eprintln!("b: {bv:?}");

        // FMLA instruction
        // Cortex-A72: FMA has 7 cycles of latency, or 3 cycles when the
        // dependency is on the accumulator
        // M1: latency 3, throughput 0.25
ab_ij_equals_ai_bj!(ab, av, bv);

let av = vld1q_f32(a.add(4));
let bv = vld1q_f32(b.add(4));

ab_ij_equals_ai_bj!(ab2, av, bv);

        if UNROLL_BY > 2 {
let av = vld1q_f32(a.add(8));
let bv = vld1q_f32(b.add(8));

ab_ij_equals_ai_bj!(ab3, av, bv);

let av = vld1q_f32(a.add(12));
let bv = vld1q_f32(b.add(12));

            ab_ij_equals_ai_bj!(ab4, av, bv);
        }

a = a.offset(UNROLL_BY as isize * MR as isize);
b = b.offset(UNROLL_BY as isize * NR as isize);
}

for _ in 0..k % UNROLL_BY {
let av = vld1q_f32(a);
let bv = vld1q_f32(b);

ab_ij_equals_ai_bj!(ab, av, bv);

a = a.offset(MR as isize);
b = b.offset(NR as isize);
}

macro_rules! c {
($i:expr, $j:expr) => (c.offset(rsc * $i as isize + csc * $j as isize));
}

    // Extract lane $imm of a float vector as f32: read the bits through a
    // u32 view and reinterpret (equivalent to vgetq_lane_f32).
    macro_rules! extract {
        ($v:expr, $imm:expr) => (
            f32::from_bits(vgetq_lane_u32(core::mem::transmute::<_, uint32x4_t>($v), $imm))
        )
    }

// Combine accumulators and multiply by alpha
loop4!(i, ab[i] = vaddq_f32(vaddq_f32(ab[i], ab2[i]), vaddq_f32(ab3[i], ab4[i])));
loop4!(i, ab[i] = vmulq_n_f32(ab[i], alpha));

if beta == 0. {
// set C = α A B
if csc == 1 {
loop4!(i, vst1q_f32(c![i, 0], ab[i]));
} else {
loop4!(i, vst1q_lane_f32(c![i, 0], ab[i], 0));
loop4!(i, vst1q_lane_f32(c![i, 1], ab[i], 1));
loop4!(i, vst1q_lane_f32(c![i, 2], ab[i], 2));
loop4!(i, vst1q_lane_f32(c![i, 3], ab[i], 3));
}
} else {
// set C = α A B + beta C
loop4!(i, *c![i, 0] = *c![i, 0] * beta + extract!(ab[i], 0));
loop4!(i, *c![i, 1] = *c![i, 1] * beta + extract!(ab[i], 1));
loop4!(i, *c![i, 2] = *c![i, 2] * beta + extract!(ab[i], 2));
loop4!(i, *c![i, 3] = *c![i, 3] * beta + extract!(ab[i], 3));
}
}
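For reference, both spare kernels implement the same microkernel contract: C ← α A B + β C on a 4x4 tile, where a and b point into packed panels laid out so that step p of the k loop reads a[p*MR + i] and b[p*NR + j]. A minimal scalar sketch of that contract (assuming T = f32 and MR = NR = 4, as in the code above; kernel_reference is an illustrative name, not part of the commit):

// Scalar model of the 4x4 microkernel: C[i,j] = alpha * sum_p A[i,p] B[p,j] + beta * C[i,j]
unsafe fn kernel_reference(k: usize, alpha: f32, a: *const f32, b: *const f32,
                           beta: f32, c: *mut f32, rsc: isize, csc: isize)
{
    const MR: usize = 4;
    const NR: usize = 4;
    let mut ab = [[0.0f32; NR]; MR];
    for p in 0..k {
        for i in 0..MR {
            for j in 0..NR {
                // packed panels: element i of A's column p, element j of B's row p
                ab[i][j] += *a.add(p * MR + i) * *b.add(p * NR + j);
            }
        }
    }
    for i in 0..MR {
        for j in 0..NR {
            let cij = c.offset(rsc * i as isize + csc * j as isize);
            // like the real kernels, never read C when beta == 0
            *cij = alpha * ab[i][j] + if beta != 0. { beta * *cij } else { 0. };
        }
    }
}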

spare_kernels/x86_sse_sgemm.rs (84 additions, 0 deletions)
@@ -0,0 +1,84 @@

// 4x4 SSE sgemm kernel
macro_rules! mm_transpose4 {
($c0:expr, $c1:expr, $c2:expr, $c3:expr) => {{
// This is _MM_TRANSPOSE4_PS except we take variables, not references
let tmp0 = _mm_unpacklo_ps($c0, $c1);
let tmp2 = _mm_unpacklo_ps($c2, $c3);
let tmp1 = _mm_unpackhi_ps($c0, $c1);
let tmp3 = _mm_unpackhi_ps($c2, $c3);

$c0 = _mm_movelh_ps(tmp0, tmp2);
$c1 = _mm_movehl_ps(tmp2, tmp0);
$c2 = _mm_movelh_ps(tmp1, tmp3);
$c3 = _mm_movehl_ps(tmp3, tmp1);
}}
}
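// E.g. rows ([a0 a1 a2 a3], [b0 b1 b2 b3], [c0 c1 c2 c3], [d0 d1 d2 d3])
// become ([a0 b0 c0 d0], [a1 b1 c1 d1], [a2 b2 c2 d2], [a3 b3 c3 d3]).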

#[inline(always)]
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
unsafe fn kernel_x86_sse(k: usize, alpha: T, a: *const T, b: *const T,
beta: T, c: *mut T, rsc: isize, csc: isize)
{
    // The intrinsics are assumed in scope in the enclosing module; import
    // them locally so the fragment also stands alone.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut ab = [_mm_setzero_ps(); MR];

let mut bv;
let (mut a, mut b) = (a, b);

// Compute A B
for _ in 0..k {
bv = _mm_load_ps(b as _); // aligned due to GemmKernel::align_to

loop_m!(i, {
// Compute ab_i += [ai b_j+0, ai b_j+1, ai b_j+2, ai b_j+3]
let aiv = _mm_set1_ps(at(a, i));
ab[i] = _mm_add_ps(ab[i], _mm_mul_ps(aiv, bv));
});

a = a.add(MR);
b = b.add(NR);
}

// Compute α (A B)
let alphav = _mm_set1_ps(alpha);
loop_m!(i, ab[i] = _mm_mul_ps(alphav, ab[i]));

macro_rules! c {
($i:expr, $j:expr) => (c.offset(rsc * $i as isize + csc * $j as isize));
}

    // C ← α A B + β C
    // Note: this array shadows the `c` pointer argument, but the c! macro
    // still refers to the pointer: macro_rules! hygiene resolves `c` in the
    // macro body at its definition site, where `c` was the *mut T argument.
    let mut c = [_mm_setzero_ps(); MR];
    let betav = _mm_set1_ps(beta);
if beta != 0. {
// Read C
if csc == 1 {
loop_m!(i, c[i] = _mm_loadu_ps(c![i, 0]));
} else if rsc == 1 {
loop_m!(i, c[i] = _mm_loadu_ps(c![0, i]));
mm_transpose4!(c[0], c[1], c[2], c[3]);
} else {
loop_m!(i, c[i] = _mm_set_ps(*c![i, 3], *c![i, 2], *c![i, 1], *c![i, 0]));
}
// Compute β C
loop_m!(i, c[i] = _mm_mul_ps(c[i], betav));
}

// Compute (α A B) + (β C)
loop_m!(i, c[i] = _mm_add_ps(c[i], ab[i]));

// Store C back to memory
if csc == 1 {
loop_m!(i, _mm_storeu_ps(c![i, 0], c[i]));
} else if rsc == 1 {
mm_transpose4!(c[0], c[1], c[2], c[3]);
loop_m!(i, _mm_storeu_ps(c![0, i], c[i]));
} else {
// extract the nth value of a vector using _mm_cvtss_f32 (extract lowest)
// in combination with shuffle (move nth value to first position)
loop_m!(i, *c![i, 0] = _mm_cvtss_f32(c[i]));
loop_m!(i, *c![i, 1] = _mm_cvtss_f32(_mm_shuffle_ps(c[i], c[i], 1)));
loop_m!(i, *c![i, 2] = _mm_cvtss_f32(_mm_shuffle_ps(c[i], c[i], 2)));
loop_m!(i, *c![i, 3] = _mm_cvtss_f32(_mm_shuffle_ps(c[i], c[i], 3)));
}
}
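Like the NEON file, this kernel is a module fragment: T, MR, NR, loop_m! and at() are expected from the enclosing kernel module. A hypothetical minimal set of stand-ins, reconstructed from how the code above uses them (illustrative only, not the crate's actual definitions):

type T = f32; // element type of the sgemm kernels

const MR: usize = 4; // rows of the micro-tile
const NR: usize = 4; // columns of the micro-tile

// Unrolled loop over 0..MR, in the style of the loop4!/loop_m! helpers.
macro_rules! loop_m {
    ($i:ident, $e:expr) => {{
        let $i: usize = 0; $e;
        let $i: usize = 1; $e;
        let $i: usize = 2; $e;
        let $i: usize = 3; $e;
    }}
}

// Unchecked indexed read from a packed panel.
#[inline(always)]
unsafe fn at(ptr: *const T, i: usize) -> T {
    *ptr.add(i)
}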
