Skip to content

Commit

Permalink
AMX IGEMM M=1 specialized loop to use input pointer directly.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 624293152
  • Loading branch information
fbarchard authored and xnnpack-bot committed Apr 12, 2024
1 parent eebe54c commit 058ff10
Show file tree
Hide file tree
Showing 26 changed files with 2,117 additions and 1,436 deletions.
348 changes: 194 additions & 154 deletions src/amalgam/gen/avx512amx.c
Original file line number Diff line number Diff line change
Expand Up @@ -1282,67 +1282,87 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_16x64c4__avx512amx(
a += 16;

size_t k = kc;
while (k >= 64 * sizeof(int8_t)) {
const __m512i vin0 = _mm512_loadu_epi32(a0);
a0 += 64;
_mm512_store_epi32(vintile + 0, vin0);
const __m512i vin1 = _mm512_loadu_epi32(a1);
a1 += 64;
_mm512_store_epi32(vintile + 16, vin1);
const __m512i vin2 = _mm512_loadu_epi32(a2);
a2 += 64;
_mm512_store_epi32(vintile + 32, vin2);
const __m512i vin3 = _mm512_loadu_epi32(a3);
a3 += 64;
_mm512_store_epi32(vintile + 48, vin3);
const __m512i vin4 = _mm512_loadu_epi32(a4);
a4 += 64;
_mm512_store_epi32(vintile + 64, vin4);
const __m512i vin5 = _mm512_loadu_epi32(a5);
a5 += 64;
_mm512_store_epi32(vintile + 80, vin5);
const __m512i vin6 = _mm512_loadu_epi32(a6);
a6 += 64;
_mm512_store_epi32(vintile + 96, vin6);
const __m512i vin7 = _mm512_loadu_epi32(a7);
a7 += 64;
_mm512_store_epi32(vintile + 112, vin7);
const __m512i vin8 = _mm512_loadu_epi32(a8);
a8 += 64;
_mm512_store_epi32(vintile + 128, vin8);
const __m512i vin9 = _mm512_loadu_epi32(a9);
a9 += 64;
_mm512_store_epi32(vintile + 144, vin9);
const __m512i vin10 = _mm512_loadu_epi32(a10);
a10 += 64;
_mm512_store_epi32(vintile + 160, vin10);
const __m512i vin11 = _mm512_loadu_epi32(a11);
a11 += 64;
_mm512_store_epi32(vintile + 176, vin11);
const __m512i vin12 = _mm512_loadu_epi32(a12);
a12 += 64;
_mm512_store_epi32(vintile + 192, vin12);
const __m512i vin13 = _mm512_loadu_epi32(a13);
a13 += 64;
_mm512_store_epi32(vintile + 208, vin13);
const __m512i vin14 = _mm512_loadu_epi32(a14);
a14 += 64;
_mm512_store_epi32(vintile + 224, vin14);
const __m512i vin15 = _mm512_loadu_epi32(a15);
a15 += 64;
_mm512_store_epi32(vintile + 240, vin15);
_tile_loadd(4, vintile, 64);
_tile_loadd(5, (const int8_t*) w + 0, 256);
_tile_dpbssd(0, 4, 5);
_tile_loadd(5, (const int8_t*) w + 64, 256);
_tile_dpbssd(1, 4, 5);
_tile_loadd(5, (const int8_t*) w + 128, 256);
_tile_dpbssd(2, 4, 5);
_tile_loadd(5, (const int8_t*) w + 192, 256);
_tile_dpbssd(3, 4, 5);

w = (const int8_t*) w + 4096;
k -= 64 * sizeof(int8_t);
// Specialized path for mr == 1: tile 4 is loaded directly from a0 instead of
// being staged through the vintile buffer used by the general path.
if (mr == 1)
{
  while (k >= 64 * sizeof(int8_t)) {
    _tile_loadd(4, a0, 64);   // Directly load input for mr=1
    // Advance the pointer that was just consumed. Advancing a15 here left a0
    // fixed, so every iteration re-read the same 64 input bytes (the 1x64c4
    // kernel and the general path both step a0).
    a0 += 64;

    // Four weight tiles at byte offsets 0/64/128/192 (row stride 256) are
    // multiplied against the input tile and accumulated into tiles 0..3.
    _tile_loadd(5, (const int8_t*) w + 0, 256);
    _tile_dpbssd(0, 4, 5);
    _tile_loadd(5, (const int8_t*) w + 64, 256);
    _tile_dpbssd(1, 4, 5);
    _tile_loadd(5, (const int8_t*) w + 128, 256);
    _tile_dpbssd(2, 4, 5);
    _tile_loadd(5, (const int8_t*) w + 192, 256);
    _tile_dpbssd(3, 4, 5);

    // 4096 weight bytes and 64 input bytes are consumed per iteration.
    w = (const int8_t*) w + 4096;
    k -= 64 * sizeof(int8_t);
  }
}
else {
while (k >= 64 * sizeof(int8_t)) {
const __m512i vin0 = _mm512_loadu_epi32(a0);
a0 += 64;
_mm512_store_epi32(vintile + 0, vin0);
const __m512i vin1 = _mm512_loadu_epi32(a1);
a1 += 64;
_mm512_store_epi32(vintile + 16, vin1);
const __m512i vin2 = _mm512_loadu_epi32(a2);
a2 += 64;
_mm512_store_epi32(vintile + 32, vin2);
const __m512i vin3 = _mm512_loadu_epi32(a3);
a3 += 64;
_mm512_store_epi32(vintile + 48, vin3);
const __m512i vin4 = _mm512_loadu_epi32(a4);
a4 += 64;
_mm512_store_epi32(vintile + 64, vin4);
const __m512i vin5 = _mm512_loadu_epi32(a5);
a5 += 64;
_mm512_store_epi32(vintile + 80, vin5);
const __m512i vin6 = _mm512_loadu_epi32(a6);
a6 += 64;
_mm512_store_epi32(vintile + 96, vin6);
const __m512i vin7 = _mm512_loadu_epi32(a7);
a7 += 64;
_mm512_store_epi32(vintile + 112, vin7);
const __m512i vin8 = _mm512_loadu_epi32(a8);
a8 += 64;
_mm512_store_epi32(vintile + 128, vin8);
const __m512i vin9 = _mm512_loadu_epi32(a9);
a9 += 64;
_mm512_store_epi32(vintile + 144, vin9);
const __m512i vin10 = _mm512_loadu_epi32(a10);
a10 += 64;
_mm512_store_epi32(vintile + 160, vin10);
const __m512i vin11 = _mm512_loadu_epi32(a11);
a11 += 64;
_mm512_store_epi32(vintile + 176, vin11);
const __m512i vin12 = _mm512_loadu_epi32(a12);
a12 += 64;
_mm512_store_epi32(vintile + 192, vin12);
const __m512i vin13 = _mm512_loadu_epi32(a13);
a13 += 64;
_mm512_store_epi32(vintile + 208, vin13);
const __m512i vin14 = _mm512_loadu_epi32(a14);
a14 += 64;
_mm512_store_epi32(vintile + 224, vin14);
const __m512i vin15 = _mm512_loadu_epi32(a15);
a15 += 64;
_mm512_store_epi32(vintile + 240, vin15);
_tile_loadd(4, vintile, 64);
_tile_loadd(5, (const int8_t*) w + 0, 256);
_tile_dpbssd(0, 4, 5);
_tile_loadd(5, (const int8_t*) w + 64, 256);
_tile_dpbssd(1, 4, 5);
_tile_loadd(5, (const int8_t*) w + 128, 256);
_tile_dpbssd(2, 4, 5);
_tile_loadd(5, (const int8_t*) w + 192, 256);
_tile_dpbssd(3, 4, 5);

w = (const int8_t*) w + 4096;
k -= 64 * sizeof(int8_t);
}
}

if XNN_UNLIKELY(k != 0) {
Expand Down Expand Up @@ -2150,22 +2170,22 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x64c4__avx512amx(
a += 1;

size_t k = kc;
while (k >= 64 * sizeof(int8_t)) {
const __m512i vin0 = _mm512_loadu_epi32(a0);
a0 += 64;
_mm512_store_epi32(vintile + 0, vin0);
_tile_loadd(4, vintile, 64);
_tile_loadd(5, (const int8_t*) w + 0, 256);
_tile_dpbssd(0, 4, 5);
_tile_loadd(5, (const int8_t*) w + 64, 256);
_tile_dpbssd(1, 4, 5);
_tile_loadd(5, (const int8_t*) w + 128, 256);
_tile_dpbssd(2, 4, 5);
_tile_loadd(5, (const int8_t*) w + 192, 256);
_tile_dpbssd(3, 4, 5);

w = (const int8_t*) w + 4096;
k -= 64 * sizeof(int8_t);
{
while (k >= 64 * sizeof(int8_t)) {
_tile_loadd(4, a0, 64); // Directly load input for mr=1
a0 += 64;
_tile_loadd(5, (const int8_t*) w + 0, 256);
_tile_dpbssd(0, 4, 5);
_tile_loadd(5, (const int8_t*) w + 64, 256);
_tile_dpbssd(1, 4, 5);
_tile_loadd(5, (const int8_t*) w + 128, 256);
_tile_dpbssd(2, 4, 5);
_tile_loadd(5, (const int8_t*) w + 192, 256);
_tile_dpbssd(3, 4, 5);

w = (const int8_t*) w + 4096;
k -= 64 * sizeof(int8_t);
}
}

if XNN_UNLIKELY(k != 0) {
Expand Down Expand Up @@ -3567,67 +3587,87 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_16x64c4__avx512amx(
a += 16;

size_t k = kc;
while (k >= 64 * sizeof(int8_t)) {
const __m512i vin0 = _mm512_loadu_epi32(a0);
a0 += 64;
_mm512_store_epi32(vintile + 0, vin0);
const __m512i vin1 = _mm512_loadu_epi32(a1);
a1 += 64;
_mm512_store_epi32(vintile + 16, vin1);
const __m512i vin2 = _mm512_loadu_epi32(a2);
a2 += 64;
_mm512_store_epi32(vintile + 32, vin2);
const __m512i vin3 = _mm512_loadu_epi32(a3);
a3 += 64;
_mm512_store_epi32(vintile + 48, vin3);
const __m512i vin4 = _mm512_loadu_epi32(a4);
a4 += 64;
_mm512_store_epi32(vintile + 64, vin4);
const __m512i vin5 = _mm512_loadu_epi32(a5);
a5 += 64;
_mm512_store_epi32(vintile + 80, vin5);
const __m512i vin6 = _mm512_loadu_epi32(a6);
a6 += 64;
_mm512_store_epi32(vintile + 96, vin6);
const __m512i vin7 = _mm512_loadu_epi32(a7);
a7 += 64;
_mm512_store_epi32(vintile + 112, vin7);
const __m512i vin8 = _mm512_loadu_epi32(a8);
a8 += 64;
_mm512_store_epi32(vintile + 128, vin8);
const __m512i vin9 = _mm512_loadu_epi32(a9);
a9 += 64;
_mm512_store_epi32(vintile + 144, vin9);
const __m512i vin10 = _mm512_loadu_epi32(a10);
a10 += 64;
_mm512_store_epi32(vintile + 160, vin10);
const __m512i vin11 = _mm512_loadu_epi32(a11);
a11 += 64;
_mm512_store_epi32(vintile + 176, vin11);
const __m512i vin12 = _mm512_loadu_epi32(a12);
a12 += 64;
_mm512_store_epi32(vintile + 192, vin12);
const __m512i vin13 = _mm512_loadu_epi32(a13);
a13 += 64;
_mm512_store_epi32(vintile + 208, vin13);
const __m512i vin14 = _mm512_loadu_epi32(a14);
a14 += 64;
_mm512_store_epi32(vintile + 224, vin14);
const __m512i vin15 = _mm512_loadu_epi32(a15);
a15 += 64;
_mm512_store_epi32(vintile + 240, vin15);
_tile_loadd(4, vintile, 64);
_tile_loadd(5, (const int8_t*) w + 0, 256);
_tile_dpbssd(0, 4, 5);
_tile_loadd(5, (const int8_t*) w + 64, 256);
_tile_dpbssd(1, 4, 5);
_tile_loadd(5, (const int8_t*) w + 128, 256);
_tile_dpbssd(2, 4, 5);
_tile_loadd(5, (const int8_t*) w + 192, 256);
_tile_dpbssd(3, 4, 5);

w = (const int8_t*) w + 4096;
k -= 64 * sizeof(int8_t);
// Specialized path for mr == 1: tile 4 is loaded directly from a0 instead of
// being staged through the vintile buffer used by the general path.
if (mr == 1)
{
  while (k >= 64 * sizeof(int8_t)) {
    _tile_loadd(4, a0, 64);   // Directly load input for mr=1
    // Advance the pointer that was just consumed. Advancing a15 here left a0
    // fixed, so every iteration re-read the same 64 input bytes (the 1x64c4
    // kernel and the general path both step a0).
    a0 += 64;

    // Four weight tiles at byte offsets 0/64/128/192 (row stride 256) are
    // multiplied against the input tile and accumulated into tiles 0..3.
    _tile_loadd(5, (const int8_t*) w + 0, 256);
    _tile_dpbssd(0, 4, 5);
    _tile_loadd(5, (const int8_t*) w + 64, 256);
    _tile_dpbssd(1, 4, 5);
    _tile_loadd(5, (const int8_t*) w + 128, 256);
    _tile_dpbssd(2, 4, 5);
    _tile_loadd(5, (const int8_t*) w + 192, 256);
    _tile_dpbssd(3, 4, 5);

    // 4096 weight bytes and 64 input bytes are consumed per iteration.
    w = (const int8_t*) w + 4096;
    k -= 64 * sizeof(int8_t);
  }
}
else {
while (k >= 64 * sizeof(int8_t)) {
const __m512i vin0 = _mm512_loadu_epi32(a0);
a0 += 64;
_mm512_store_epi32(vintile + 0, vin0);
const __m512i vin1 = _mm512_loadu_epi32(a1);
a1 += 64;
_mm512_store_epi32(vintile + 16, vin1);
const __m512i vin2 = _mm512_loadu_epi32(a2);
a2 += 64;
_mm512_store_epi32(vintile + 32, vin2);
const __m512i vin3 = _mm512_loadu_epi32(a3);
a3 += 64;
_mm512_store_epi32(vintile + 48, vin3);
const __m512i vin4 = _mm512_loadu_epi32(a4);
a4 += 64;
_mm512_store_epi32(vintile + 64, vin4);
const __m512i vin5 = _mm512_loadu_epi32(a5);
a5 += 64;
_mm512_store_epi32(vintile + 80, vin5);
const __m512i vin6 = _mm512_loadu_epi32(a6);
a6 += 64;
_mm512_store_epi32(vintile + 96, vin6);
const __m512i vin7 = _mm512_loadu_epi32(a7);
a7 += 64;
_mm512_store_epi32(vintile + 112, vin7);
const __m512i vin8 = _mm512_loadu_epi32(a8);
a8 += 64;
_mm512_store_epi32(vintile + 128, vin8);
const __m512i vin9 = _mm512_loadu_epi32(a9);
a9 += 64;
_mm512_store_epi32(vintile + 144, vin9);
const __m512i vin10 = _mm512_loadu_epi32(a10);
a10 += 64;
_mm512_store_epi32(vintile + 160, vin10);
const __m512i vin11 = _mm512_loadu_epi32(a11);
a11 += 64;
_mm512_store_epi32(vintile + 176, vin11);
const __m512i vin12 = _mm512_loadu_epi32(a12);
a12 += 64;
_mm512_store_epi32(vintile + 192, vin12);
const __m512i vin13 = _mm512_loadu_epi32(a13);
a13 += 64;
_mm512_store_epi32(vintile + 208, vin13);
const __m512i vin14 = _mm512_loadu_epi32(a14);
a14 += 64;
_mm512_store_epi32(vintile + 224, vin14);
const __m512i vin15 = _mm512_loadu_epi32(a15);
a15 += 64;
_mm512_store_epi32(vintile + 240, vin15);
_tile_loadd(4, vintile, 64);
_tile_loadd(5, (const int8_t*) w + 0, 256);
_tile_dpbssd(0, 4, 5);
_tile_loadd(5, (const int8_t*) w + 64, 256);
_tile_dpbssd(1, 4, 5);
_tile_loadd(5, (const int8_t*) w + 128, 256);
_tile_dpbssd(2, 4, 5);
_tile_loadd(5, (const int8_t*) w + 192, 256);
_tile_dpbssd(3, 4, 5);

w = (const int8_t*) w + 4096;
k -= 64 * sizeof(int8_t);
}
}

if XNN_UNLIKELY(k != 0) {
Expand Down Expand Up @@ -4493,22 +4533,22 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x64c4__avx512amx(
a += 1;

size_t k = kc;
while (k >= 64 * sizeof(int8_t)) {
const __m512i vin0 = _mm512_loadu_epi32(a0);
a0 += 64;
_mm512_store_epi32(vintile + 0, vin0);
_tile_loadd(4, vintile, 64);
_tile_loadd(5, (const int8_t*) w + 0, 256);
_tile_dpbssd(0, 4, 5);
_tile_loadd(5, (const int8_t*) w + 64, 256);
_tile_dpbssd(1, 4, 5);
_tile_loadd(5, (const int8_t*) w + 128, 256);
_tile_dpbssd(2, 4, 5);
_tile_loadd(5, (const int8_t*) w + 192, 256);
_tile_dpbssd(3, 4, 5);

w = (const int8_t*) w + 4096;
k -= 64 * sizeof(int8_t);
{
while (k >= 64 * sizeof(int8_t)) {
_tile_loadd(4, a0, 64); // Directly load input for mr=1
a0 += 64;
_tile_loadd(5, (const int8_t*) w + 0, 256);
_tile_dpbssd(0, 4, 5);
_tile_loadd(5, (const int8_t*) w + 64, 256);
_tile_dpbssd(1, 4, 5);
_tile_loadd(5, (const int8_t*) w + 128, 256);
_tile_dpbssd(2, 4, 5);
_tile_loadd(5, (const int8_t*) w + 192, 256);
_tile_dpbssd(3, 4, 5);

w = (const int8_t*) w + 4096;
k -= 64 * sizeof(int8_t);
}
}

if XNN_UNLIKELY(k != 0) {
Expand Down

0 comments on commit 058ff10

Please sign in to comment.