Skip to content

Commit

Permalink
Feature mxfp4 bf16 avx2 gemms (#881)
Browse files Browse the repository at this point in the history
* Added some working code

* Some more amends

* More amends

* Got it working

* Fully removed m mod 8 requirement
  • Loading branch information
egeor committed May 14, 2024
1 parent feacc5e commit 9534052
Show file tree
Hide file tree
Showing 18 changed files with 175 additions and 127 deletions.
2 changes: 0 additions & 2 deletions samples/xgemm/kernel_test/generate_gemm_test_scripts.sh
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,6 @@ for BINARY_POSTOP in 0 1; do
OUTNAME="mxfp4f32gemm_"
KSTART=32
KSTEP=32
MSTART=8
MSTEP=8
elif [[ ("$PREC" == 'BF16_BF16_F32_BF16') && ("$AVNNI" == '0') ]] ; then
OUTNAME="bf16_flatgemm_"
KSTART=2
Expand Down
4 changes: 0 additions & 4 deletions samples/xgemm/kernel_test_adl.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
mxfp4bf16f32gemm_nn_eqld
mxfp4bf16f32gemm_nn_gtld
mxfp4bf16gemm_nn_eqld
mxfp4bf16gemm_nn_gtld
hf8bf16bf16_spmm_eqld
hf8bf16bf16_spmm_gtld
hf8bf16f32_spmm_eqld
Expand Down
4 changes: 0 additions & 4 deletions samples/xgemm/kernel_test_avx512_vl256.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
mxfp4bf16f32gemm_nn_eqld
mxfp4bf16f32gemm_nn_gtld
mxfp4bf16gemm_nn_eqld
mxfp4bf16gemm_nn_gtld
hf8bf16bf16_spmm_eqld
hf8bf16bf16_spmm_gtld
hf8bf16f32_spmm_eqld
Expand Down
4 changes: 0 additions & 4 deletions samples/xgemm/kernel_test_avx512_vl256_clx.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
mxfp4bf16f32gemm_nn_eqld
mxfp4bf16f32gemm_nn_gtld
mxfp4bf16gemm_nn_eqld
mxfp4bf16gemm_nn_gtld
hf8bf16bf16_spmm_eqld
hf8bf16bf16_spmm_gtld
hf8bf16f32_spmm_eqld
Expand Down
4 changes: 0 additions & 4 deletions samples/xgemm/kernel_test_avx512_vl256_cpx.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
mxfp4bf16f32gemm_nn_eqld
mxfp4bf16f32gemm_nn_gtld
mxfp4bf16gemm_nn_eqld
mxfp4bf16gemm_nn_gtld
hf8bf16bf16_spmm_eqld
hf8bf16bf16_spmm_gtld
hf8bf16f32_spmm_eqld
Expand Down
4 changes: 0 additions & 4 deletions samples/xgemm/kernel_test_clx.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
mxfp4bf16f32gemm_nn_eqld
mxfp4bf16f32gemm_nn_gtld
mxfp4bf16gemm_nn_eqld
mxfp4bf16gemm_nn_gtld
hf8bf16bf16_spmm_eqld
hf8bf16bf16_spmm_gtld
hf8bf16f32_spmm_eqld
Expand Down
4 changes: 0 additions & 4 deletions samples/xgemm/kernel_test_cpx.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
mxfp4bf16f32gemm_nn_eqld
mxfp4bf16f32gemm_nn_gtld
mxfp4bf16gemm_nn_eqld
mxfp4bf16gemm_nn_gtld
hf8bf16bf16_spmm_eqld
hf8bf16bf16_spmm_gtld
hf8bf16f32_spmm_eqld
Expand Down
4 changes: 0 additions & 4 deletions samples/xgemm/kernel_test_hsw.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
mxfp4bf16f32gemm_nn_eqld
mxfp4bf16f32gemm_nn_gtld
mxfp4bf16gemm_nn_eqld
mxfp4bf16gemm_nn_gtld
hf8bf16bf16_spmm_eqld
hf8bf16bf16_spmm_gtld
hf8bf16f32_spmm_eqld
Expand Down
4 changes: 0 additions & 4 deletions samples/xgemm/kernel_test_knl.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
mxfp4bf16f32gemm_nn_eqld
mxfp4bf16f32gemm_nn_gtld
mxfp4bf16gemm_nn_eqld
mxfp4bf16gemm_nn_gtld
hf8bf16bf16_spmm_eqld
hf8bf16bf16_spmm_gtld
hf8bf16f32_spmm_eqld
Expand Down
4 changes: 0 additions & 4 deletions samples/xgemm/kernel_test_knm.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
mxfp4bf16f32gemm_nn_eqld
mxfp4bf16f32gemm_nn_gtld
mxfp4bf16gemm_nn_eqld
mxfp4bf16gemm_nn_gtld
hf8bf16bf16_spmm_eqld
hf8bf16bf16_spmm_gtld
hf8bf16f32_spmm_eqld
Expand Down
4 changes: 0 additions & 4 deletions samples/xgemm/kernel_test_skx.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
mxfp4bf16f32gemm_nn_eqld
mxfp4bf16f32gemm_nn_gtld
mxfp4bf16gemm_nn_eqld
mxfp4bf16gemm_nn_gtld
hf8bf16bf16_spmm_eqld
hf8bf16bf16_spmm_gtld
hf8bf16f32_spmm_eqld
Expand Down
4 changes: 0 additions & 4 deletions samples/xgemm/kernel_test_srf.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
mxfp4bf16f32gemm_nn_eqld
mxfp4bf16f32gemm_nn_gtld
mxfp4bf16gemm_nn_eqld
mxfp4bf16gemm_nn_gtld
hf8bf16bf16_spmm_eqld
hf8bf16bf16_spmm_gtld
hf8bf16f32_spmm_eqld
Expand Down
17 changes: 8 additions & 9 deletions src/generator_gemm.c
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ void libxsmm_generator_gemm_kernel( libxsmm_generated_code* io_generated_
}
}

if ( l_is_Amxfp4_Bbf16_gemm > 0 ) {
if (io_generated_code->arch >= LIBXSMM_X86_AVX2 && io_generated_code->arch < LIBXSMM_X86_AVX512_SPR) {
io_generated_code->arch = LIBXSMM_X86_AVX2;
}
}

/* Check if it is a supported spmm with bitmap */
if ((l_xgemm_desc_mod.flags & LIBXSMM_GEMM_FLAG_DECOMPRESS_A_VIA_BITMASK) > 0) {
if ((io_generated_code->arch >= LIBXSMM_X86_GENERIC) && (io_generated_code->arch <= LIBXSMM_X86_ALLFEAT )) {
Expand Down Expand Up @@ -224,7 +230,7 @@ void libxsmm_generator_gemm_kernel( libxsmm_generated_code* io_generated_
} else {
/* We are good... */
}
} else if ((l_is_Amxfp4_Bfp32_gemm > 0) && (io_generated_code->arch >= LIBXSMM_X86_AVX2)) {
} else if ((l_is_Amxfp4_Bfp32_gemm > 0 || l_is_Amxfp4_Bbf16_gemm > 0) && (io_generated_code->arch >= LIBXSMM_X86_AVX2)) {
if ( (l_xgemm_desc_mod.flags & LIBXSMM_GEMM_FLAG_TRANS_A) > 0 ) {
LIBXSMM_HANDLE_ERROR( io_generated_code, LIBXSMM_ERR_TRANS_A );
return;
Expand All @@ -240,9 +246,6 @@ void libxsmm_generator_gemm_kernel( libxsmm_generated_code* io_generated_
} else if (l_xgemm_desc_mod.k % 32 != 0) {
LIBXSMM_HANDLE_ERROR( io_generated_code, LIBXSMM_ERR_ARCH_PREC );
return;
} else if (l_xgemm_desc_mod.m % 8 != 0) {
LIBXSMM_HANDLE_ERROR( io_generated_code, LIBXSMM_ERR_ARCH_PREC );
return;
} else {
/* We are good... */
}
Expand Down Expand Up @@ -340,16 +343,12 @@ void libxsmm_generator_gemm_kernel( libxsmm_generated_code* io_generated_
} else {
l_vector_length = 8;
}
} else if ( ( io_generated_code->arch >= LIBXSMM_X86_AVX2 ) && ( l_is_Amxfp4_Bfp32_gemm > 0 ) ) {
} else if ( ( io_generated_code->arch >= LIBXSMM_X86_AVX2 ) && ( l_is_Amxfp4_Bfp32_gemm > 0 || l_is_Amxfp4_Bbf16_gemm > 0) ) {
l_vector_length = 8;
if (l_xgemm_desc_mod.k % 32 != 0) {
LIBXSMM_HANDLE_ERROR( io_generated_code, LIBXSMM_ERR_ARCH_PREC );
return;
}
if (l_xgemm_desc_mod.m % 8 != 0) {
LIBXSMM_HANDLE_ERROR( io_generated_code, LIBXSMM_ERR_ARCH_PREC );
return;
}
} else if ( ( io_generated_code->arch <= LIBXSMM_X86_AVX512_VL256_SKX ) && LIBXSMM_DATATYPE_F64 == LIBXSMM_GEMM_GETENUM_AB_COMMON_PREC( l_xgemm_desc_mod.datatype ) ) {
l_vector_length = 4;
} else if ( ( io_generated_code->arch <= LIBXSMM_X86_AVX512_VL256_SKX ) && LIBXSMM_DATATYPE_F32 == LIBXSMM_GEMM_GETENUM_AB_COMMON_PREC( l_xgemm_desc_mod.datatype ) ) {
Expand Down
106 changes: 76 additions & 30 deletions src/generator_gemm_avx2_microkernel.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ void libxsmm_generator_gemm_avx2_kloop_kernel( libxsmm_generated_code*
const libxsmm_gemm_descriptor*, const unsigned int, const unsigned int, const int);

unsigned int l_is_Amxfp4_Bfp32_gemm = libxsmm_x86_is_Amxfp4_Bfp32_gemm(i_xgemm_desc);
if ( l_is_Amxfp4_Bfp32_gemm > 0) {
l_generator_microkernel = libxsmm_generator_gemm_avx2_microkernel_Amxfp4Bfp32;
unsigned int l_is_Amxfp4_Bbf16_gemm = libxsmm_x86_is_Amxfp4_Bbf16_gemm(i_xgemm_desc);
if ( l_is_Amxfp4_Bfp32_gemm > 0 || l_is_Amxfp4_Bbf16_gemm > 0 ) {
l_generator_microkernel = libxsmm_generator_gemm_avx2_microkernel_Amxfp4;
l_generator_microkernel(io_generated_code, i_gp_reg_mapping, i_micro_kernel_config,
i_xgemm_desc, i_m_blocking, i_n_blocking, i_k_blocking);
return;
Expand Down Expand Up @@ -65,7 +66,7 @@ void libxsmm_generator_gemm_avx2_kloop_kernel( libxsmm_generated_code*
}

LIBXSMM_API_INTERN
void libxsmm_generator_gemm_avx2_microkernel_Amxfp4Bfp32( libxsmm_generated_code* io_generated_code,
void libxsmm_generator_gemm_avx2_microkernel_Amxfp4( libxsmm_generated_code* io_generated_code,
const libxsmm_gp_reg_mapping* i_gp_reg_mapping,
const libxsmm_micro_kernel_config* i_micro_kernel_config,
const libxsmm_gemm_descriptor* i_xgemm_desc,
Expand All @@ -82,6 +83,7 @@ void libxsmm_generator_gemm_avx2_microkernel_Amxfp4Bfp32( libxsmm_generated_code
/* start register of accumulator */
unsigned int l_vec_reg_acc_start = 16 - (i_n_blocking * l_m_blocking);
unsigned int l_tmp_reg_acc_start = l_vec_reg_acc_start - (i_n_blocking * l_m_blocking);
unsigned int l_andmask_vreg = l_tmp_reg_acc_start - 1;
unsigned int l_lut_mant_vreg = 0;
unsigned int l_lut_sign_vreg = 1;
unsigned int l_vreg_m_start = 2;
Expand All @@ -91,12 +93,19 @@ void libxsmm_generator_gemm_avx2_microkernel_Amxfp4Bfp32( libxsmm_generated_code
unsigned int l_k = 0;
unsigned int l_k_unroll_factor = (i_n_blocking == 1) ? 16 : 4;
unsigned int l_pf_dist = 16;
unsigned int l_is_Amxfp4_Bbf16_gemm = libxsmm_x86_is_Amxfp4_Bbf16_gemm(i_xgemm_desc);

/* Load scale value GPR */
if (((i_xgemm_desc->flags & LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS) == 0) && ((i_xgemm_desc->flags & LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET) == 0) && ((i_xgemm_desc->flags & LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE) == 0)) {
libxsmm_generator_gemm_getval_stack_var( io_generated_code, i_micro_kernel_config, LIBXSMM_GEMM_STACK_VAR_MXSCALE_PTR, i_gp_reg_mapping->gp_reg_scf );
}

if (l_is_Amxfp4_Bbf16_gemm > 0) {
libxsmm_generator_gemm_getval_stack_var( io_generated_code, i_micro_kernel_config, LIBXSMM_GEMM_STACK_VAR_SSE_AVX2_LP_HELPER_PTR, i_gp_reg_mapping->gp_reg_help_0 );
libxsmm_x86_instruction_vec_move( io_generated_code, i_micro_kernel_config->instruction_set, LIBXSMM_X86_INSTR_VMOVUPS,
i_gp_reg_mapping->gp_reg_help_0, LIBXSMM_X86_GP_REG_UNDEF, 0, 0, 'y', l_andmask_vreg, 0, 0, 0 );
}

/* Here we handle all K loop */
/* Start of K loop */
if (i_k > 32) {
Expand Down Expand Up @@ -127,16 +136,24 @@ void libxsmm_generator_gemm_avx2_microkernel_Amxfp4Bfp32( libxsmm_generated_code
l_tmp_vreg = l_vreg_m_start + 2;
}

libxsmm_x86_instruction_vec_move( io_generated_code, io_generated_code->arch, LIBXSMM_X86_INSTR_VPMOVZXBD,
i_gp_reg_mapping->gp_reg_a, LIBXSMM_X86_GP_REG_UNDEF, 0,
i_micro_kernel_config->datatype_size_in * i_micro_kernel_config->vector_length * l_m + l_k * i_xgemm_desc->lda * i_micro_kernel_config->datatype_size_in,
i_micro_kernel_config->vector_name, l_m_vreg_k0, 0, 0, 0 );
if ( ( l_m == (l_m_blocking - 1) ) && (i_micro_kernel_config->use_masking_a_c != 0) ) {
libxsmm_generator_maskedload_8bit_avx2( io_generated_code, i_gp_reg_mapping->gp_reg_help_1,
i_gp_reg_mapping->gp_reg_a, LIBXSMM_X86_GP_REG_UNDEF, 0,
i_micro_kernel_config->datatype_size_in * i_micro_kernel_config->vector_length * l_m + l_k * i_xgemm_desc->lda * i_micro_kernel_config->datatype_size_in,
l_m_vreg_k0, i_m_blocking % i_micro_kernel_config->vector_length );
libxsmm_x86_instruction_vec_compute_2reg( io_generated_code, LIBXSMM_X86_INSTR_VPMOVZXBD, i_micro_kernel_config->vector_name, l_m_vreg_k0, l_m_vreg_k0);
} else {
libxsmm_x86_instruction_vec_move( io_generated_code, io_generated_code->arch, LIBXSMM_X86_INSTR_VPMOVZXBD,
i_gp_reg_mapping->gp_reg_a, LIBXSMM_X86_GP_REG_UNDEF, 0,
i_micro_kernel_config->datatype_size_in * i_micro_kernel_config->vector_length * l_m + l_k * i_xgemm_desc->lda * i_micro_kernel_config->datatype_size_in,
i_micro_kernel_config->vector_name, l_m_vreg_k0, 0, 0, 0 );
}
if (l_k % 8 == 0) {
libxsmm_x86_instruction_prefetch(io_generated_code,
LIBXSMM_X86_INSTR_PREFETCHT0,
i_gp_reg_mapping->gp_reg_a,
LIBXSMM_X86_GP_REG_UNDEF, 0,
i_micro_kernel_config->datatype_size_in * i_micro_kernel_config->vector_length * l_m + l_k * i_xgemm_desc->lda * i_micro_kernel_config->datatype_size_in + l_pf_dist * 64);
i_micro_kernel_config->datatype_size_in * i_micro_kernel_config->vector_length * l_m + l_k * i_xgemm_desc->lda * i_micro_kernel_config->datatype_size_in + l_pf_dist * 64);
}

libxsmm_x86_instruction_vec_compute_2reg_imm8( io_generated_code, LIBXSMM_X86_INSTR_VPSRLD_I, i_micro_kernel_config->vector_name, l_m_vreg_k0, l_m_vreg_k1, 4);
Expand Down Expand Up @@ -170,15 +187,26 @@ void libxsmm_generator_gemm_avx2_microkernel_Amxfp4Bfp32( libxsmm_generated_code
i_xgemm_desc->ldb * 0 * i_micro_kernel_config->datatype_size_in2 + 0 + l_k*2*i_micro_kernel_config->datatype_size_in2,
i_micro_kernel_config->vector_name,
l_n_vreg_k0, 0, 1, 0 );
/* Load odd k */
libxsmm_x86_instruction_vec_move( io_generated_code,
i_micro_kernel_config->instruction_set,
i_micro_kernel_config->b_vmove_instruction,
i_gp_reg_mapping->gp_reg_b,
LIBXSMM_X86_GP_REG_UNDEF, 0,
i_xgemm_desc->ldb * 0 * i_micro_kernel_config->datatype_size_in2 + 4 + l_k*2*i_micro_kernel_config->datatype_size_in2 ,
i_micro_kernel_config->vector_name,
l_n_vreg_k1, 0, 1, 0 );

if (l_is_Amxfp4_Bbf16_gemm == 0) {
/* Load odd k */
libxsmm_x86_instruction_vec_move( io_generated_code,
i_micro_kernel_config->instruction_set,
i_micro_kernel_config->b_vmove_instruction,
i_gp_reg_mapping->gp_reg_b,
LIBXSMM_X86_GP_REG_UNDEF, 0,
i_xgemm_desc->ldb * 0 * i_micro_kernel_config->datatype_size_in2 + i_micro_kernel_config->datatype_size_in2 + l_k*2*i_micro_kernel_config->datatype_size_in2 ,
i_micro_kernel_config->vector_name,
l_n_vreg_k1, 0, 1, 0 );
} else {
libxsmm_x86_instruction_vec_compute_3reg( io_generated_code, LIBXSMM_X86_INSTR_VPANDD,
i_micro_kernel_config->vector_name,
l_n_vreg_k0, l_andmask_vreg, l_n_vreg_k1 );

libxsmm_x86_instruction_vec_compute_2reg_imm8( io_generated_code, LIBXSMM_X86_INSTR_VPSLLD_I,
i_micro_kernel_config->vector_name,
l_n_vreg_k0, l_n_vreg_k0, 16 );
}
}
libxsmm_x86_instruction_vec_compute_3reg( io_generated_code,
i_micro_kernel_config->vmul_instruction,
Expand Down Expand Up @@ -208,15 +236,24 @@ void libxsmm_generator_gemm_avx2_microkernel_Amxfp4Bfp32( libxsmm_generated_code
i_xgemm_desc->ldb * l_n * i_micro_kernel_config->datatype_size_in2 + 0 + l_k*2*i_micro_kernel_config->datatype_size_in2,
i_micro_kernel_config->vector_name,
l_n_vreg_k0, 0, 1, 0 );
/* Load odd k */
libxsmm_x86_instruction_vec_move( io_generated_code,
i_micro_kernel_config->instruction_set,
i_micro_kernel_config->b_vmove_instruction,
i_gp_reg_mapping->gp_reg_b,
LIBXSMM_X86_GP_REG_UNDEF, 0,
i_xgemm_desc->ldb * l_n * i_micro_kernel_config->datatype_size_in2 + 4 + l_k*2*i_micro_kernel_config->datatype_size_in2,
i_micro_kernel_config->vector_name,
l_n_vreg_k1, 0, 1, 0 );
if (l_is_Amxfp4_Bbf16_gemm == 0) {
/* Load odd k */
libxsmm_x86_instruction_vec_move( io_generated_code,
i_micro_kernel_config->instruction_set,
i_micro_kernel_config->b_vmove_instruction,
i_gp_reg_mapping->gp_reg_b,
LIBXSMM_X86_GP_REG_UNDEF, 0,
i_xgemm_desc->ldb * l_n * i_micro_kernel_config->datatype_size_in2 + i_micro_kernel_config->datatype_size_in2 + l_k*2*i_micro_kernel_config->datatype_size_in2,
i_micro_kernel_config->vector_name,
l_n_vreg_k1, 0, 1, 0 );
} else {
libxsmm_x86_instruction_vec_compute_3reg( io_generated_code, LIBXSMM_X86_INSTR_VPANDD,
i_micro_kernel_config->vector_name,
l_n_vreg_k0, l_andmask_vreg, l_n_vreg_k1 );
libxsmm_x86_instruction_vec_compute_2reg_imm8( io_generated_code, LIBXSMM_X86_INSTR_VPSLLD_I,
i_micro_kernel_config->vector_name,
l_n_vreg_k0, l_n_vreg_k0, 16 );
}

for ( l_m = 0; l_m < l_m_blocking; l_m++ ) {
unsigned int l_m_vreg_k0 = l_vreg_m_start + 2 * l_m + 0;
Expand Down Expand Up @@ -254,10 +291,19 @@ void libxsmm_generator_gemm_avx2_microkernel_Amxfp4Bfp32( libxsmm_generated_code
for ( l_m = 0; l_m < l_m_blocking; l_m++ ) {
unsigned int l_scf_vreg = l_vreg_scf_start + l_m;

libxsmm_x86_instruction_vec_move( io_generated_code, io_generated_code->arch, LIBXSMM_X86_INSTR_VPMOVZXBD,
i_gp_reg_mapping->gp_reg_scf, LIBXSMM_X86_GP_REG_UNDEF, 0,
i_micro_kernel_config->datatype_size_in * i_micro_kernel_config->vector_length * l_m,
i_micro_kernel_config->vector_name, l_scf_vreg, 0, 0, 0 );
if ( ( l_m == (l_m_blocking - 1) ) && (i_micro_kernel_config->use_masking_a_c != 0) ) {
libxsmm_generator_maskedload_8bit_avx2( io_generated_code, i_gp_reg_mapping->gp_reg_help_1,
i_gp_reg_mapping->gp_reg_scf, LIBXSMM_X86_GP_REG_UNDEF, 0,
i_micro_kernel_config->datatype_size_in * i_micro_kernel_config->vector_length * l_m,
l_scf_vreg, i_m_blocking % i_micro_kernel_config->vector_length );
libxsmm_x86_instruction_vec_compute_2reg( io_generated_code, LIBXSMM_X86_INSTR_VPMOVZXBD, i_micro_kernel_config->vector_name, l_scf_vreg, l_scf_vreg);
} else {
libxsmm_x86_instruction_vec_move( io_generated_code, io_generated_code->arch, LIBXSMM_X86_INSTR_VPMOVZXBD,
i_gp_reg_mapping->gp_reg_scf, LIBXSMM_X86_GP_REG_UNDEF, 0,
i_micro_kernel_config->datatype_size_in * i_micro_kernel_config->vector_length * l_m,
i_micro_kernel_config->vector_name, l_scf_vreg, 0, 0, 0 );

}
libxsmm_x86_instruction_vec_compute_2reg_imm8( io_generated_code, LIBXSMM_X86_INSTR_VPSLLD_I, i_micro_kernel_config->vector_name, l_scf_vreg, l_scf_vreg, 23);

for ( l_n = 0; l_n < i_n_blocking; l_n++ ) {
Expand Down
2 changes: 1 addition & 1 deletion src/generator_gemm_avx2_microkernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#include "generator_gemm_common.h"

LIBXSMM_API_INTERN
void libxsmm_generator_gemm_avx2_microkernel_Amxfp4Bfp32( libxsmm_generated_code* io_generated_code,
void libxsmm_generator_gemm_avx2_microkernel_Amxfp4( libxsmm_generated_code* io_generated_code,
const libxsmm_gp_reg_mapping* i_gp_reg_mapping,
const libxsmm_micro_kernel_config* i_micro_kernel_config,
const libxsmm_gemm_descriptor* i_xgemm_desc,
Expand Down

0 comments on commit 9534052

Please sign in to comment.