diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 000000000..d2f375aac --- /dev/null +++ b/.editorconfig @@ -0,0 +1,15 @@ +# EditorConfig is awesome: https://EditorConfig.org + +# top-most EditorConfig file +root = true + +# Unix-style newlines with a newline ending every file +[*] +end_of_line = lf +insert_final_newline = true +trim_trailing_whitespace = true + +# C/C++ follows clang-format +[*.{c,cpp,h,hpp}] +indent_style = space +indent_size = 4 diff --git a/CMakeLists.txt b/CMakeLists.txt index ebfee8f09..d305e5a65 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,14 +9,14 @@ if (NOT CMAKE_BUILD_TYPE) endif() if(UNIX) else() # Windows - # Force CMake to use icx-cl rather than the default C++ compiler/linker + # Force CMake to use icx-cl rather than the default C++ compiler/linker # (needed on Windows only) # include (CMakeForceCompiler) # CMAKE_FORCE_CXX_COMPILER (icx-cl IntelDPCPP) set(CMAKE_CXX_COMPILER icx-cl) include (Platform/Windows-Clang) include(cmake/GTestExternal.cmake) -endif() +endif() project(XeTLA) diff --git a/examples/01_gemm_universal/gemm_universal.cpp b/examples/01_gemm_universal/gemm_universal.cpp index 55d94249d..a895c303a 100644 --- a/examples/01_gemm_universal/gemm_universal.cpp +++ b/examples/01_gemm_universal/gemm_universal.cpp @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. *******************************************************************************/ -#include #include "xetla.hpp" +#include enum class kslicing_impl_t : uint8_t { none = 0, global = 1, local = 2 }; -template +template void gemm_universal_run(uint32_t iter) { // Tips, the example demonstrates programming kernel with XeTLA, it works as expected with current configurations. // Please make sure you fully understand these configurations before you do any modifications, incomplete changes may lead to unexpected behaviors. @@ -82,7 +83,7 @@ void gemm_universal_run(uint32_t iter) { constexpr uint32_t num_local_splitk = (kslicing_type == kslicing_impl_t::local) ? 2 : 1; - // Mirco-kernel configuration + // Micro-kernel configuration using tune_option = dict_t< elem_v_t, @@ -102,8 +103,8 @@ void gemm_universal_run(uint32_t iter) { data_type_c, // output datatype for C mem_layout::row_major, // memory layout for C 8, // leading dimension alignment for C, in unit of element - data_type_acc, // accumulator data type for intermediate resutls - gpu_arch::Xe, // GPU arch + data_type_acc, // accumulator data type for intermediate results + arch_tag, // GPU arch tune_option>; // allocate temp buffers for global split @@ -184,36 +185,42 @@ void gemm_universal_run(uint32_t iter) { free(Cnt, context); } +template +struct main_wrapper { + static constexpr auto exec = []() { + // An example code for calculating matrix multiplication using + // GEMM_UNIVERSAL API: + // C = A x B + // The resulted matrix C is partitioned by the group range + // in to multiple blocks. The block matrix + // C + // is computed by the workgroup with id: (0, i_w, j_w). + // (i_w, j_w) is an element in range specified by group range. + // Each thread with index (0, i_s, j_s) inside the same workgroup + // is responsible for a sub block of matrix multiplication, which is + // C[i_s*sg_m:(i_s+1):sg_m,j_s*sg_n:(j_s+1)*sg_n] + + // Alternatively, some threads can cooperate on the same sub block + // matrix given the same (i_s, j_s), i.e. the index space is extended + // from (0, i_s, j_s) to (k_s, i_s, j_s). 
+ + // Another method to achieve the same effect is to extend the index space + // in group range, i.e. from (0, i_w, j_w) to (k_w, i_w, j_w) + + // More detailed description referring to the cooperation (kslicing) could + // be found in the example 01_gemm_universal with custom implementation + + // basic gemm_universal + gemm_universal_run(10); + + // basic gemm_universal with workgroup cooperation + // gemm_universal_run(10); + + // basic gemm_universal with thread cooperation + // gemm_universal_run(10); + }; +}; int main() { - // An example code for calculating matrix multiplication using - // GEMM_UNIVERSAL API: - // C = A x B - // The resulted matrix C is partitioned by the group range - // in to multiple blocks. The block matrix - // C - // is computed by the workgroup with id: (0, i_w, j_w). - // (i_w, j_w) is an element in range specified by group range. - // Each thread with index (0, i_s, j_s) inside the same workgroup - // is responsible for a sub block of matrix multiplication, which is - // C[i_s*sg_m:(i_s+1):sg_m,j_s*sg_n:(j_s+1)*sg_n] - - // Alternatively, some threads can cooperate on the same sub block - // matrix given the same (i_s, j_s), i.e. the index space is extended - // from (0, i_s, j_s) to (k_s, i_s, j_s). - - // Another method to achieve the same effect is to extend the index space - // in group range, i.e. from (0, i_w, j_w) to (k_w, i_w, j_w) - - // More detailed description referring to the cooperation (kslicing) could - // be found in the example 01_gemm_universal with custom implementation - - // basic gemm_universal - gemm_universal_run(10); - - // basic gemm_universal with workgroup cooperation - // gemm_universal_run(10); - - // basic gemm_universal with thread cooperation - // gemm_universal_run(10); - return (0); + dispatch_arch::exec(); + return 0; } diff --git a/examples/02_basic_gemm/basic_gemm.cpp b/examples/02_basic_gemm/basic_gemm.cpp index 918a78e10..44866e82f 100644 --- a/examples/02_basic_gemm/basic_gemm.cpp +++ b/examples/02_basic_gemm/basic_gemm.cpp @@ -13,10 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. *******************************************************************************/ -#include #include "xetla.hpp" +#include -template +template void basic_gemm_run(sycl::queue queue, uint32_t iter) { // Tips, the example demonstrates programming kernel with XeTLA, it works as expected with current configurations. // Please make sure you fully understand these configurations before you do any modifications, incomplete changes may lead to unexpected behaviors. 
@@ -110,11 +110,11 @@ void basic_gemm_run(sycl::queue queue, uint32_t iter) { // should larger than 8 static constexpr uint32_t k_stride = 32; - // Step 1: define mirco-kernel's configuration + // Step 1: define Micro-kernel's configuration using wg_shape = shape; using sg_shape = shape; - // Mirco-kernel configuration + // Micro-kernel configuration using gemm_tune_option = dict_t, elem_v_t; gemm_t gemm; @@ -149,24 +149,26 @@ void basic_gemm_run(sycl::queue queue, uint32_t iter) { mem_space::global, // memory writing to global mem for C wg_shape, // computation tile shape k_stride, // elements in each iteration - gpu_arch::Xe, // GPU arch + arch_tag, // GPU arch epilogue_tune_option>; // Step 3: define the shared local memory usages // developers have the responsibility to set - // shared loacal memory through XeTLA API + // shared local memory through XeTLA API static constexpr uint32_t barrier_count = gemm_t::barrier_count; static constexpr uint32_t slm_size = gemm_t::slm_size; + static_assert(slm_size <= arch_attr_t::local_mem_size, + "The local memory size excess!"); xetla_nbarrier_init(); xetla_local_init(); - // Step 4: ecah workgroup gets it individual index to start computation + // Step 4: each workgroup gets it individual index to start computation int start_n = item.get_group(2) * wg_tile_n; int start_m = item.get_group(1) * wg_tile_m; // no slicing in K direction so start from zero for all WG int start_k = 0; - // Each workgroup will compute all data in K based on no k_sliciing + // Each workgroup will compute all data in K based on no k_slicing // The developer can set how much data a subgroup compute by k_stride uint32_t wg_tile_k = matrix_k; uint32_t inner_loop_count @@ -183,7 +185,7 @@ void basic_gemm_run(sycl::queue queue, uint32_t iter) { mem_desc_output_c md_c( {C}, {matrix_n, matrix_m, ldc}, {start_n, start_m}); - // Step 6: real calculation with accumulator varibales which suppose + // Step 6: real calculation with accumulator variables which suppose // will be in register. typename gemm_t::matAcc_t matAcc; matAcc.init(0); @@ -194,8 +196,7 @@ void basic_gemm_run(sycl::queue queue, uint32_t iter) { // the results is in the matAcc rather than real output C typename gemm_t::work_group_t g(item.get_local_linear_id()); gemm(g, matAcc, gemm_args); - - // Step 7: write the results from matACC to real output C + // Step 7: write the results from matAcc to real output C epilogue_t epilogue; epilogue(g, matAcc, md_c); }); @@ -220,23 +221,21 @@ void basic_gemm_run(sycl::queue queue, uint32_t iter) { free(C, context); } +template +struct main_wrapper { + static constexpr auto exec = []() { + // This case shows how to use batch-reduce (br) GEMM microkernel to + // solve a standard GEMM + // Turn on the profiling property to facilitate subsequent profiling + sycl::property_list properties { + sycl::property::queue::enable_profiling()}; + + // Define SYCL queue, context and device + auto queue = sycl::queue(properties); + basic_gemm_run(queue, 10); + }; +}; int main() { - // This case shows how to use batch-reduce (br) GEMM microkernel to - // solve a standard GEMM - // Turn on the profiling property to facilitate subsequent profiling - sycl::property_list properties {sycl::property::queue::enable_profiling()}; - - // Define SYCL queue, context and device - auto queue = sycl::queue(properties); - auto device = queue.get_device(); - - // Detect the execution size, 8 for Arc, 16 for PVC. 
- int ExecSize - = device.get_info(); - if (ExecSize == 8) { - basic_gemm_run(queue, 10); - } else { - basic_gemm_run(queue, 10); - } - return (0); + dispatch_arch::exec(); + return 0; } diff --git a/examples/03_gemm_relu_bias/gemm_relu_bias.cpp b/examples/03_gemm_relu_bias/gemm_relu_bias.cpp index f81e946b0..caa820b50 100644 --- a/examples/03_gemm_relu_bias/gemm_relu_bias.cpp +++ b/examples/03_gemm_relu_bias/gemm_relu_bias.cpp @@ -14,8 +14,8 @@ * limitations under the License. *******************************************************************************/ #include -#include #include "xetla.hpp" +#include using namespace cl::sycl; using namespace gpu::xetla; @@ -140,7 +140,7 @@ void gemm_relu_bias_run(uint32_t iter) { using epilogue_policy = xetla::group::epilogue_policy_tile_op; - // Mirco-kernel configuration + // Micro-kernel configuration using tune_option = dict_t< elem_v_t, @@ -156,7 +156,7 @@ void gemm_relu_bias_run(uint32_t iter) { data_type_c, // output datatype for C mem_layout::row_major, // memory layout for C 8, // leading dimension alignment for C, in unit of element - data_type_acc, // accumulator data type for intermediate resutls + data_type_acc, // accumulator data type for intermediate results gpu_arch::Xe, // GPU arch tune_option>; using gemm_op_t = typename default_config_t::type; @@ -223,7 +223,7 @@ int main() { // The purpose of this example is to illustrate the epilogue_t API in XeTLA. // It allows user to implement multiple Ops inside a kernel call to avoid - // overheads in invokation, memory transfer, etc. + // overheads in invocation, memory transfer, etc. // Take the following python code as an example: // Original: @@ -231,7 +231,7 @@ int main() { // > x = to.matmul(A, B) // > y = to.nn.functional.relu(x) - // It takes two kernel invokations and the ReLU Op is a elementwise operation + // It takes two kernel invocations and the ReLU Op is a elementwise operation // that could be fused into MatMul Op, which is basically calling GEMM kernel. // Fusion: diff --git a/examples/04_gemm_polynomial/gemm_polynomial.cpp b/examples/04_gemm_polynomial/gemm_polynomial.cpp index 2aa2b61a9..981561e66 100644 --- a/examples/04_gemm_polynomial/gemm_polynomial.cpp +++ b/examples/04_gemm_polynomial/gemm_polynomial.cpp @@ -14,8 +14,8 @@ * limitations under the License. 
*******************************************************************************/ #include -#include #include "xetla.hpp" +#include #include "gemm_polynomial.hpp" @@ -137,7 +137,7 @@ void gemm_polynomial_run(int iter) { using epilogue_policy = xetla::group::epilogue_policy_tile_op; - // Mirco-kernel configuration + // Micro-kernel configuration using tune_option = dict_t< elem_v_t, @@ -154,7 +154,7 @@ void gemm_polynomial_run(int iter) { data_type_c, // output datatype for C mem_layout::row_major, // memory layout for C 8, // leading dimension alignment for C, in unit of element - data_type_acc, // accumulator data type for intermediate resutls + data_type_acc, // accumulator data type for intermediate results gpu_arch::Xe, // GPU arch tune_option>; diff --git a/examples/05_batch_gemm/batch_gemm.cpp b/examples/05_batch_gemm/batch_gemm.cpp index 3011f26a1..ff66838c1 100644 --- a/examples/05_batch_gemm/batch_gemm.cpp +++ b/examples/05_batch_gemm/batch_gemm.cpp @@ -90,7 +90,7 @@ void batch_gemm_run(uint32_t iter) { using wg_shape = shape; using sg_shape = shape; - // Mirco-kernel configuration + // Micro-kernel configuration using tune_option = dict_t, @@ -106,7 +106,7 @@ void batch_gemm_run(uint32_t iter) { mem_layout::row_major, // memory layout for B 8, // leading dimension for B, in unit of element mem_space::global, // memory reading from global mem for B - data_type_acc, // accumulator data type for intermediate resutls + data_type_acc, // accumulator data type for intermediate results wg_shape, // computation tile shape wg_tile_k, // elements in each iteration gpu_arch::Xe, // GPU arch diff --git a/examples/05_batch_gemm/batch_gemm.hpp b/examples/05_batch_gemm/batch_gemm.hpp index ce2a814d7..fe00dbaa5 100644 --- a/examples/05_batch_gemm/batch_gemm.hpp +++ b/examples/05_batch_gemm/batch_gemm.hpp @@ -173,8 +173,8 @@ class batch_gemm_t { /// @return The size of local memory required. __XETLA_API static constexpr uint32_t get_slm_size() { constexpr uint32_t size = gemm_t::slm_size + epilogue_t::slm_size; - static_assert(size <= (128 * 1024), - "The local memory size should be less than 128KB!"); + static_assert(size <= arch_attr_t::local_mem_size, + "The local memory size excess!"); return size; }; diff --git a/examples/06_gemm_softmax/gemm_softmax.cpp b/examples/06_gemm_softmax/gemm_softmax.cpp index 724f43a57..cb6562677 100644 --- a/examples/06_gemm_softmax/gemm_softmax.cpp +++ b/examples/06_gemm_softmax/gemm_softmax.cpp @@ -14,8 +14,8 @@ * limitations under the License. 
*******************************************************************************/ -#include #include "xetla.hpp" +#include using namespace gpu::xetla; using namespace cl::sycl; @@ -156,8 +156,8 @@ void gemm_softmax_run(uint32_t iter) { cl::sycl::nd_range<3> nd_range(group_range * local_range, local_range); uint32_t warmup = 10; - int64_t ops - = 2 * static_cast(matrix_m) * matrix_n * matrix_k * batch_num; + int64_t ops = 2 * static_cast(matrix_m) * matrix_n * matrix_k + * batch_num; profiling_helper prof("gemm_softmax", ops, "gflops"); try { for (uint32_t i = 0; i < iter + warmup; i++) { @@ -178,11 +178,11 @@ void gemm_softmax_run(uint32_t iter) { // should larger than 8 static constexpr uint32_t k_iter_num = 16; - // Step 1: define mirco-kernel's configuration + // Step 1: define Micro-kernel's configuration using wg_shape = shape; using sg_shape = shape; - // Mirco-kernel configuration + // Micro-kernel configuration using tune_option = dict_t< elem_v_t; using sg_shape_layer1 = shape; - // Mirco-kernel configuration + // Micro-kernel configuration using epilogue_policy_layer1 = xetla::group::epilogue_policy_tile_op< xetla::subgroup::chained_tile_op_t, gpu_arch::Xe>; @@ -184,7 +184,7 @@ void mlp_run(uint32_t iter) { mem_layout::row_major, // memory layout for W 8, // leading dimension for W, in unit of element mem_space::global, // memory reading from global mem for W - data_type_acc, // accumulator data type for intermediate resutls + data_type_acc, // accumulator data type for intermediate results wg_shape_layer1, // computation tile shape wg_tile_k, // elements in each iteration gpu_arch::Xe, // GPU arch @@ -203,7 +203,7 @@ void mlp_run(uint32_t iter) { using wg_shape_layer2 = shape; using sg_shape_layer2 = shape; - // Mirco-kernel configuration + // Micro-kernel configuration using layer2_tune_option = dict_t, @@ -219,7 +219,7 @@ void mlp_run(uint32_t iter) { mem_layout::row_major, // memory layout for V 8, // leading dimension for V, in unit of element mem_space::global, // memory reading from global mem for V - data_type_acc, // accumulator data type for intermediate resutls + data_type_acc, // accumulator data type for intermediate results wg_shape_layer2, // computation tile shape wg_tile_k, // elements in each iteration gpu_arch::Xe, // GPU arch diff --git a/examples/08_scaled_dot_product_attention/scaled_dot_product_attention.cpp b/examples/08_scaled_dot_product_attention/scaled_dot_product_attention.cpp index cb61058f6..2a4ac852f 100644 --- a/examples/08_scaled_dot_product_attention/scaled_dot_product_attention.cpp +++ b/examples/08_scaled_dot_product_attention/scaled_dot_product_attention.cpp @@ -135,10 +135,12 @@ int sdp_fwd_result_validate(dtype_in *q_device, dtype_in *k_device, return result ? 0 : 1; } -void sdp_fwd_run(uint32_t iter) { - // Tips, the example demonstrates programming kernel with XeTLA, it works as expected with current configurations. - // Please make sure you fully understand these configurations before you do any modifications, incomplete changes may lead to unexpected behaviors. - // Please contact us for support. +template +void sdp_fwd_run(uint32_t iter, uint32_t warmup = 10) { + // Tips, the example demonstrates programming kernel with XeTLA, it works as + // expected with current configurations. Please make sure you fully understand + // these configurations before you do any modifications, incomplete changes + // may lead to unexpected behaviors. Please contact us for support. 
using dtype_in = bf16; using dtype_out = bf16; @@ -150,9 +152,15 @@ void sdp_fwd_run(uint32_t iter) { constexpr uint32_t matrix_n_qk = sequence_len; constexpr uint32_t matrix_k_qk = head_size; - constexpr uint32_t wg_tile_m_qk = 64; - constexpr uint32_t wg_tile_n_qk = 512; - constexpr uint32_t sg_tile_m_qk = 32; + constexpr double slm_ratio_to_pvc + = static_cast(arch_attr_t::local_mem_size) + / arch_attr_t::local_mem_size; + + constexpr uint32_t wg_tile_m_qksv = 64 * slm_ratio_to_pvc; + + constexpr uint32_t wg_tile_m_qk = wg_tile_m_qksv; + constexpr uint32_t wg_tile_n_qk = 512; // must == sl_kv + constexpr uint32_t sg_tile_m_qk = 32 * slm_ratio_to_pvc; constexpr uint32_t sg_tile_n_qk = 32; constexpr uint32_t wg_tile_k_qk = 32; @@ -161,10 +169,11 @@ void sdp_fwd_run(uint32_t iter) { constexpr uint32_t matrix_n_sv = head_size; constexpr uint32_t matrix_k_sv = sequence_len; - constexpr uint32_t wg_tile_m_sv = 64; - constexpr uint32_t wg_tile_n_sv = 64; + // constexpr uint32_t wg_tile_m_sv = 64; + constexpr uint32_t wg_tile_m_sv = wg_tile_m_qksv; + constexpr uint32_t wg_tile_n_sv = 64; // must == head_dim constexpr uint32_t sg_tile_m_sv = 8; - constexpr uint32_t sg_tile_n_sv = 16; + constexpr uint32_t sg_tile_n_sv = 16 * slm_ratio_to_pvc; constexpr uint32_t wg_tile_k_sv = 32; // buffer size of softmax row data @@ -178,11 +187,12 @@ void sdp_fwd_run(uint32_t iter) { auto context = queue.get_info(); auto device = queue.get_info(); - std::cout << "Running on " << device.get_info() << "\n"; + print_device_details(device); constexpr uint32_t size_qkv = matrix_m_qk * matrix_k_qk; constexpr uint32_t size_mask = matrix_m_qk * matrix_n_qk; constexpr uint32_t size_out = matrix_m_sv * matrix_n_sv; + const float scale_qk = 1.f / std::sqrt(head_size); auto q = alloc_device_and_init( batch_cnt * size_qkv, @@ -220,6 +230,11 @@ void sdp_fwd_run(uint32_t iter) { constexpr uint32_t subgroup_range_m = wg_tile_m_qk / sg_tile_m_qk; constexpr uint32_t subgroup_range_n = wg_tile_n_qk / sg_tile_n_qk; + constexpr uint32_t slm_size + = wg_tile_m_qk * wg_tile_n_qk * sizeof(dtype_sfx); + static_assert(slm_size <= arch_attr_t::local_mem_size, + "The local memory size excess!"); + static_assert(subgroup_range_m * subgroup_range_n == thread_num, "Given thread number should equal to pre-set value 32!"); std::cout << "group_num_x: " << group_range_n @@ -231,22 +246,20 @@ void sdp_fwd_run(uint32_t iter) { cl::sycl::range<3> local_range {1, subgroup_range_m, subgroup_range_n}; cl::sycl::nd_range<3> nd_range(group_range * local_range, local_range); - constexpr uint32_t warmup = 10; - int64_t ops = int64_t(4 * batch_num * head_num * sequence_len) * sequence_len - * head_size; + int64_t ops = int64_t(4 * batch_num * head_num * sequence_len) + * sequence_len * head_size; profiling_helper prof("sdp", ops, "gflops"); try { for (uint32_t i = 0; i < iter + warmup; i++) { if (i >= warmup) { prof.cpu_start(); } auto gpu_event = queue.submit([&](handler &cgh) { - cgh.parallel_for< - class Test>(nd_range, [=](nd_item<3> item) KERNEL_MAIN { + cgh.parallel_for(nd_range, [=](nd_item<3> item) KERNEL_MAIN { using namespace gpu::xetla; using namespace gpu::xetla::group; using namespace gpu::xetla::kernel; using namespace gpu::xetla::subgroup; - uint32_t batch_id = item.get_group(0); + const uint32_t batch_id = item.get_group(0); // disable sync in gemm static constexpr uint32_t periodic_sync_interval = 0; static constexpr uint32_t prefetch_distance = 3; @@ -254,19 +267,23 @@ void sdp_fwd_run(uint32_t iter) { using wg_shape0 = shape; using 
sg_shape0 = shape; - using post_op0_t = scalar_mul_op_t; + using post_op0_t = scalar_mul_op_t; using post_op1_t = elemwise_reduce_op_t; + dtype_in, arch_tag>; using post_op_t = chained_tile_op_t; using epilogue_policy0 = xetla::group::epilogue_policy_tile_op; - using group_swizzle = group_swizzle_default; - - using tune_option0 = dict_t< - elem_v_t, + arch_tag>; + using group_swizzle = group_swizzle_default; + + using elem_opt_mode_t + = elem_v_t; + using elem_opt_type_t = elem_v_t< + tune_key::param_optimizer_type, + tune_key_value::param_optimizer_decision_tree>; + using tune_option0 = dict_t< // + elem_opt_type_t, elem_opt_mode_t, elem_t_t, elem_t_t, @@ -277,28 +294,30 @@ void sdp_fwd_run(uint32_t iter) { using gemm0_t = xetla::group::default_gemm_selector_t< dtype_in, // input datatype for A mem_layout::row_major, // memory layout for A - 8, // leading dimension for A, in unit of element + // alignment for A, in unit of element + DEVICE_MEM_ALIGNMENT / sizeof(dtype_in), mem_space:: global, // memory reading from global mem for A dtype_in, // input datatype for B mem_layout::row_major, // memory layout for B - 8, // leading dimension for B, in unit of element + // alignment for B, in unit of element + DEVICE_MEM_ALIGNMENT / sizeof(dtype_in), mem_space:: global, // memory reading from global mem for B - float, // accumulator data type for intermediate resutls + float, // accumulator data type for intermediate results wg_shape0, // computation tile shape wg_tile_k_qk, // elements in each iteration - gpu_arch::Xe, // GPU arch + arch_tag, // GPU arch tune_option0>; using epilogue0_t = xetla::group::default_epilogue_selector_t< dtype_sfx, // onput datatype for C mem_layout::row_major, // memory layout for C - 8, // leading dimension for C, in unit of element + 8, // alignment for C, in unit of element mem_space:: local, // memory writing to local mem for C wg_shape0, // computation tile shape wg_tile_k_qk, // elements in each iteration - gpu_arch::Xe, // GPU arch + arch_tag, // GPU arch tune_option0>; using gemm_op0_t = gemm_universal_t< dispatch_policy_default, gemm0_t, @@ -307,29 +326,27 @@ void sdp_fwd_run(uint32_t iter) { using tile_shape0 = typename gemm0_t::tile_shape; // initialize SLM size - constexpr uint32_t slm_size - = wg_tile_m_qk * wg_tile_n_qk * sizeof(dtype_sfx); xetla_local_init(); // initialize named barrier count // we only need to do thread sync while store gemm results to SLM // one barrier is enough for that xetla_nbarrier_init<1>(); - xetla_nbarrier_t - nbarrier; + xetla_nbarrier_t nbarrier; nbarrier.init_nbarrier(0, nbarrier_role::producer_consumer); // initialize gemm op: gemm result store to shared local memory - typename post_op0_t::arguments_t post_op0_arg(0.125); + typename post_op0_t::arguments_t post_op0_arg(scale_qk); typename post_op1_t::arguments_t post_op1_arg( + // attn_mask pre-load ptr batch offset attn_mask + batch_id / head_num * size_mask + wg_tile_m_qk * wg_tile_n_qk - * item.get_group( - 1), // attn_mask pre-load ptr batch offset - {matrix_n_qk, // attn_mask tdesc width + * item.get_group(1), + { + matrix_n_qk, // attn_mask tdesc width matrix_m_qk, // attn_mask tdesc height - matrix_n_qk} // attn_mask tdesc pitch - ); + matrix_n_qk, // attn_mask tdesc pitch + }); typename gemm_op0_t::arguments_t arg0(matrix_m_qk, matrix_k_qk, matrix_n_qk, q + batch_id * size_qkv, // matA_ptr + batch offset @@ -339,22 +356,20 @@ void sdp_fwd_run(uint32_t iter) { 0, // matC_base matrix_n_qk, // matC load width {{post_op0_arg, post_op1_arg}}); - gemm_op0_t gemm_op0; - 
gemm_op0(item, arg0); + gemm_op0_t {}(item, arg0); xetla_fence(); nbarrier.arrive_wait(); // softmax start: result store to SLM using softmax_op_t = xetla_softmax_fwd_t; + mem_space::local, SIMD, thread_num, softmax_sz, + arch_tag>; typename softmax_op_t::arguments_t arg1; - softmax_op_t softmax_op; - arg1.data_in_base = 0; arg1.data_out_base = 0; - softmax_op(item, &arg1); + softmax_op_t {}(item, &arg1); xetla_fence(); nbarrier.arrive_wait(); @@ -362,10 +377,8 @@ void sdp_fwd_run(uint32_t iter) { using wg_shape1 = shape; using sg_shape1 = shape; - using tune_option1 = dict_t< - elem_v_t, + using tune_option1 = dict_t< // + elem_opt_type_t, elem_opt_mode_t, elem_t_t, elem_v_t, @@ -375,18 +388,19 @@ void sdp_fwd_run(uint32_t iter) { using gemm1_t = xetla::group::default_gemm_selector_t< dtype_in, // input datatype for A mem_layout::row_major, // memory layout for A - 8, // leading dimension for A, in unit of element + 8, // alignment for A, in unit of element mem_space:: local, // memory reading from local mem for A dtype_in, // input datatype for B mem_layout::row_major, // memory layout for B - 8, // leading dimension for B, in unit of element + // alignment for B, in unit of element + DEVICE_MEM_ALIGNMENT / sizeof(dtype_in), mem_space:: global, // memory reading from global mem for B - float, // accumulator data type for intermediate resutls + float, // accumulator data type for intermediate results wg_shape1, // computation tile shape wg_tile_k_sv, // elements in each iteration - gpu_arch::Xe, // GPU arch + arch_tag, // GPU arch tune_option1>; // gemm arguments include matA & matB load information and @@ -395,6 +409,9 @@ void sdp_fwd_run(uint32_t iter) { using work_group_t = typename gemm1_t::work_group_t; using mem_desc_a_t = typename gemm1_t::mem_desc_a_t; using mem_desc_b_t = typename gemm1_t::mem_desc_b_t; + using mem_desc_c_t = mem_desc_t< // + dtype_out, mem_layout::row_major, mem_space::global, + DEVICE_MEM_ALIGNMENT / sizeof(dtype_out)>; // Using gemm::matAcc init a matC class for future storage using matAcc_t = typename gemm1_t::matAcc_t; using matC_t = tile_t matrix_n - ? 
matrix_n - : (start_n + wg_tile_n_sv); + uint32_t boundary_n + = std::min(start_n + wg_tile_n_sv, matrix_n); uint32_t boundary_k = wg_tile_k; work_group_t g; @@ -431,42 +447,45 @@ void sdp_fwd_run(uint32_t iter) { mem_desc_b.init(matB_ptr, {boundary_n, boundary_k, matB_ld}, {start_n, start_k}); - uint32_t inner_loop_count + uint32_t sg_k_count = (wg_tile_k + wg_tile_k_sv - 1) / wg_tile_k_sv; - gemm_args_t gemm_args( - mem_desc_a, mem_desc_b, inner_loop_count); + gemm_args_t gemm_args(mem_desc_a, mem_desc_b, sg_k_count); matAcc_t matAcc; - matC_t matC; - gemm1_t gemm; matAcc.init(0); - gemm(g, matAcc, gemm_args); + gemm1_t {}(g, matAcc, gemm_args); + // permute store + matC_t matC; subgroup::elemwise_cvt(matC, matAcc); - xetla_tdescriptor transpose_tdecs; - // Define a temprary vector as output buffer - xetla_vector out_reg; // Calculate new coordination of each element - uint32_t b = item.get_group(0) / head_num; - uint32_t n = item.get_group(0) % head_num; - uint32_t f = start_m + gemm1_t::get_matC_offset_y(g); - uint32_t h = start_n + gemm1_t::get_matC_offset_x(g); - - // transpose 8 * 16 tile and store to global - for (uint32_t j = 0; j < sg_tile_m_sv; ++j, ++f) { - uint32_t dst_offset - = b * head_num * sequence_len * head_size - + f * head_num * head_size + n * head_size; - out_reg = matC.reg.xetla_select( - j * sg_tile_n_sv); - xetla_fill_tdesc( - transpose_tdecs.xetla_format(), - out + dst_offset, head_size, 1, head_size, h, - 0); - xetla_tstore_global(transpose_tdecs, out_reg); - } + const uint32_t b = batch_id / head_num; + const uint32_t n = batch_id % head_num; + const uint32_t batch_offset + = b * head_num * sequence_len * head_size + + start_m * head_num * head_size + n * head_size + + start_n; + const uint32_t f = gemm1_t::get_matC_offset_y(g); + const uint32_t h = gemm1_t::get_matC_offset_x(g); + + const auto ld_c = head_num * head_size; + mem_desc_c_t mem_desc_c; + mem_desc_c.init( + out + batch_offset, // dst_base = out_ptr + wg offset + { + std::min(h + sg_tile_n_sv, wg_tile_n_sv), + std::min(f + sg_tile_m_sv, wg_tile_m_sv), + ld_c, + }, + {int(h), int(f)}); + + constexpr auto msg_type_c = msg_type::block_2d; + using mat_tile_desc = typename matC_t::tile_desc; + using matC_payload_t = subgroup::mem_payload_t; + matC_payload_t matC_payload(mem_desc_c); + subgroup::tile_store(matC, matC_payload); }); }); gpu_event.wait(); @@ -488,7 +507,7 @@ void sdp_fwd_run(uint32_t iter) { mem_layout::col_major, mem_layout::row_major, mem_layout::row_major)); - //performance + // performance prof.print_profiling_result(profiling_selector::GPU); free(q, context); @@ -498,28 +517,36 @@ void sdp_fwd_run(uint32_t iter) { free(out, context); } +template +struct main_wrapper { + static constexpr auto exec = []() { + // This example implements scaled-dot-production with batch_size: 16, + // num_heads: 16, sequence_length: 512, head_size: 64. It will be shown how to + // remap the index space of each work-item used for gemm1, softmax and gemm2. + + // Description: + // Scaled-dot-production mechanism can be seen as two chained batch MatMul + // with a softmax in the middle layer. 
It can be described as following + // mathematical expression: + // softmax(Q · (K.transpose(-1, -2)) * (1 / sqr_root(num_heads)) + + // attn_mask) · V + // where: + // Q, K, V: input data + // shape(Q) = [16 x 16, 512, 64] + // shape(K) = [16 x 16, 512, 64] + // shape(V) = [16 x 16, 512, 64] + // shape(attn_mask) = [16, 512, 512] + // shape(DST) = [16, 512, 16, 64] + + // This kernel is designed to execute the following task: + // 1: S = (Q · (K.transpose(-1, -2))) * (1 / sqr_root(num_heads)) + attn_mask + // 2: S' = softmax(S) + // 3: O = S' · V + sdp_fwd_run(10); + }; +}; + int main() { - // This example implements scaled-dot-production with batch_size: 16, - // num_heads: 16, sequence_lenth: 512, head_size: 64. It will be shown how to - // remap the index space of each work-item used for gemm1, softmax and gemm2. - - // Description: - // Scaled-dot-production mechanism can be seen as two chained batch MatMul with - // a softmax in the middle layer. It can be descripted as following - // mathematical expression: - // softmax(Q · (K.transpose(-1, -2)) * (1 / sqr_root(num_heads)) + attn_mask) · V - // where: - // Q, K, V: input data - // shape(Q) = [16 x 16, 512, 64] - // shape(K) = [16 x 16, 512, 64] - // shape(V) = [16 x 16, 512, 64] - // shape(attn_mask) = [16, 512, 512] - - // This kernel is designed to execute the following task: - // 1: S = (Q · (K.transpose(-1, -2))) * (1 / sqr_root(num_heads)) + attn_mask - // 2: S' = softmax(S) - // 3: O = S' · V - - sdp_fwd_run(10); + dispatch_arch::exec(); return 0; } diff --git a/examples/08_scaled_dot_product_attention/softmax.hpp b/examples/08_scaled_dot_product_attention/softmax.hpp index 184bc311b..0fc04b8aa 100644 --- a/examples/08_scaled_dot_product_attention/softmax.hpp +++ b/examples/08_scaled_dot_product_attention/softmax.hpp @@ -24,7 +24,7 @@ using namespace gpu::xetla::subgroup; template + uint32_t thread_num_, uint32_t softmax_size_, gpu_arch arch_tag> struct xetla_softmax_fwd_t { using dtype_in = dtype_in_; using dtype_out = dtype_out_; @@ -56,16 +56,14 @@ struct xetla_softmax_fwd_t { using softmax_load_payload_t = subgroup::mem_payload_t< mem_desc_t, softmax_tile_desc_t, - subgroup::msg_type_v, - gpu_arch::Xe>; + subgroup::msg_type_v, arch_tag>; // this tile will store the softmax result to global memory using softmax_store_t = subgroup::tile_t; using softmax_store_payload_t = subgroup::mem_payload_t< mem_desc_t, softmax_tile_desc_t, - subgroup::msg_type_v, - gpu_arch::Xe>; + subgroup::msg_type_v, arch_tag>; struct arguments_t { // available while original data is from SLM @@ -113,10 +111,10 @@ struct xetla_softmax_fwd_t { row_data_32 = softmax_load.reg.xetla_select(0); // get max - float xmax = hmax(row_data_32); + float x_max = hmax(row_data_32); // get exp_sum - row_data_32 -= xmax; + row_data_32 -= x_max; row_data_32 = exp(row_data_32); float exp_sum = sum(row_data_32); diff --git a/examples/10_gemm_large_n/gemm_large_n.cpp b/examples/10_gemm_large_n/gemm_large_n.cpp index a0e0b599c..8ebe93d31 100644 --- a/examples/10_gemm_large_n/gemm_large_n.cpp +++ b/examples/10_gemm_large_n/gemm_large_n.cpp @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*******************************************************************************/ -#include #include "xetla.hpp" +#include #include #include @@ -83,7 +83,7 @@ void gemm_large_n_run(uint32_t iter) { // default 8 static constexpr uint32_t wg_num_n = 8; - // Mirco-kernel configuration + // Micro-kernel configuration using group_swizzle = xetla::kernel::group_swizzle_snake; @@ -106,7 +106,7 @@ void gemm_large_n_run(uint32_t iter) { data_type_c, // output datatype for C mem_layout::row_major, // memory layout for C 8, // leading dimension alignment for C, in unit of element - data_type_acc, // accumulator data type for intermediate resutls + data_type_acc, // accumulator data type for intermediate results gpu_arch::Xe, // GPU arch tune_option>; diff --git a/examples/11_stream_k_gemm/stream_k_gemm.cpp b/examples/11_stream_k_gemm/stream_k_gemm.cpp index 87bd3fe93..0844dbe3c 100644 --- a/examples/11_stream_k_gemm/stream_k_gemm.cpp +++ b/examples/11_stream_k_gemm/stream_k_gemm.cpp @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. *******************************************************************************/ -#include #include "xetla.hpp" +#include void stream_k_gemm_run(uint32_t iter) { // Tips, the example demonstrates programming kernel with XeTLA, it works as expected with current configurations. @@ -85,7 +85,7 @@ void stream_k_gemm_run(uint32_t iter) { sg_tile_n, // subgroup size in dim0 sg_tile_m>; // subgroup size in dim1 - // Mirco-kernel configuration + // Micro-kernel configuration using gemm_config = xetla::group::gemm_selector_t< data_type_a, // input datatype for A data_type_b, // input datatype for B @@ -95,7 +95,7 @@ void stream_k_gemm_run(uint32_t iter) { mem_space::global, // memory reading from global mem for B 8, // leading dimension for A, in unit of element 8, // leading dimension for B, in unit of element - data_type_acc, // accumulator data type for intermediate resutls + data_type_acc, // accumulator data type for intermediate results tile_shape, // computation tile shape sg_tile_k, // elements in each iteration mma_engine::xmx, // compute engine @@ -303,7 +303,7 @@ void stream_k_gemm_relu_biasadd_run(uint32_t iter) { sg_tile_n, // subgroup size in dim0 sg_tile_m>; // subgroup size in dim1 - // Mirco-kernel configuration + // Micro-kernel configuration using gemm_config = xetla::group::gemm_selector_t< data_type_a, // input datatype for A data_type_b, // input datatype for B @@ -313,7 +313,7 @@ void stream_k_gemm_relu_biasadd_run(uint32_t iter) { mem_space::global, // memory reading from global mem for B 8, // leading dimension for A, in unit of element 8, // leading dimension for B, in unit of element - data_type_acc, // accumulator data type for intermediate resutls + data_type_acc, // accumulator data type for intermediate results tile_shape, // computation tile shape sg_tile_k, // elements in each iteration mma_engine::xmx, // compute engine diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 1f50f5d7d..193696628 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -1,6 +1,11 @@ include_directories(${CMAKE_SOURCE_DIR}/include) include_directories(${CMAKE_SOURCE_DIR}) +# Creates a separate device code module for each SYCL* kernel +# so that kernel for Dg2 and Xe will be JIT separately +add_compile_options(-fsycl-device-code-split=per_kernel) +add_link_options(-fsycl-device-code-split=per_kernel) + add_subdirectory(01_gemm_universal) add_subdirectory(02_basic_gemm) 
add_subdirectory(03_gemm_relu_bias) @@ -13,4 +18,4 @@ add_subdirectory(09_gate_recurrent_unit) add_subdirectory(10_gemm_large_n) if(UNIX) # pvc not available on win? add_subdirectory(11_stream_k_gemm) -endif() \ No newline at end of file +endif() diff --git a/include/common/core/arch_config.hpp b/include/common/core/arch_config.hpp index 8c020641d..c7f6157d8 100644 --- a/include/common/core/arch_config.hpp +++ b/include/common/core/arch_config.hpp @@ -120,6 +120,7 @@ struct arch_attr_t { using mma_attr = mma_attr_t; static constexpr uint32_t max_wg_num = 64; + static constexpr uint32_t local_mem_size = 128 * 1024; }; template <> @@ -133,6 +134,7 @@ struct arch_attr_t { using mma_attr = mma_attr_t; static constexpr uint32_t max_wg_num = 64; + static constexpr uint32_t local_mem_size = 64 * 1024; }; /// @} xetla_core_arch_config diff --git a/include/common/utils/common.hpp b/include/common/utils/common.hpp index e67ccc829..52a4a61e3 100644 --- a/include/common/utils/common.hpp +++ b/include/common/utils/common.hpp @@ -46,7 +46,7 @@ constexpr uint32_t get_element_size_code() { enum class lsc_action : uint8_t { prefetch, load, store, atomic }; template -constexpr std::enable_if_t +constexpr std::enable_if_t check_lsc_cache_hint() { if constexpr (Action == lsc_action::prefetch) { // https://gfxspecs.intel.com/Predator/Home/Index/53560 @@ -126,7 +126,7 @@ get_prefetch_cache_hint_code() { } template -constexpr std::enable_if_t +constexpr std::enable_if_t get_store_cache_hint_code() { check_lsc_cache_hint(); if (L1H == cache_hint::none && L2H == cache_hint::none) { diff --git a/include/common/utils/raw_send_nbarrier.hpp b/include/common/utils/raw_send_nbarrier.hpp index fb7b92ee1..13bc45426 100644 --- a/include/common/utils/raw_send_nbarrier.hpp +++ b/include/common/utils/raw_send_nbarrier.hpp @@ -41,8 +41,12 @@ enum class nbarrier_role : uint8_t { /// as consumer. /// template -struct xetla_nbarrier_t { + gpu_arch arch_tag = gpu_arch::Xe, typename enable = void> +struct xetla_nbarrier_t; + +template +struct xetla_nbarrier_t> { /// /// @brief Description of named barrier objection. /// Structure is defined in @@ -87,8 +91,9 @@ struct xetla_nbarrier_t { } }; -template -struct xetla_nbarrier_t { +template +struct xetla_nbarrier_t> { /// /// @brief Description of named barrier objection. /// Structure is defined in @@ -114,7 +119,6 @@ struct xetla_nbarrier_t { /// __XETLA_API void wait() { __ESIMD_ENS::split_barrier<__ESIMD_ENS::split_barrier_action::wait>(); - // __ESIMD_NS::barrier(); } /// @brief named barrier signal from subgroup. 
diff --git a/include/experimental/group/gemm/compute_policy.hpp b/include/experimental/group/gemm/compute_policy.hpp index 460ec110c..11d2c0e39 100644 --- a/include/experimental/group/gemm/compute_policy.hpp +++ b/include/experimental/group/gemm/compute_policy.hpp @@ -58,7 +58,7 @@ struct compute_policy_int4_dequantize_xmx::mma_attr::mma_n_in_elem; static constexpr uint32_t block_bytes_y_b = 32; static_assert(block_bytes_x_a == block_bytes_y_b, "mat_a x need to match with mat_b y"); diff --git a/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp b/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp index 2df93ab4e..4fb330e5e 100644 --- a/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp +++ b/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp @@ -376,13 +376,8 @@ class gemm_universal_t::local_mem_size, + "The local memory size excess!"); return size; } diff --git a/include/group/gemm/compute_policy.hpp b/include/group/gemm/compute_policy.hpp index 43d6b6bef..338fc46d5 100644 --- a/include/group/gemm/compute_policy.hpp +++ b/include/group/gemm/compute_policy.hpp @@ -31,19 +31,20 @@ namespace gpu::xetla::group { /// @tparam perf_tuning_knob_ Is performance-related knobs. /// @tparam arch_tag_ Is the HW architecture. template + gpu_arch arch_tag_ = gpu_arch::Xe, typename enable = void> struct compute_policy_default_xmx {}; /// @brief Specialized for Xe architecture. -template -struct compute_policy_default_xmx { +template +struct compute_policy_default_xmx> { using compute_attr = compute_attr_; using perf_tuning_knob = perf_tuning_knob_; static constexpr int k_stride = perf_tuning_knob::k_stride; static constexpr int stages = perf_tuning_knob::stages; static constexpr int sync_freq = perf_tuning_knob::sync_freq; - static constexpr gpu_arch arch_tag = gpu_arch::Xe; + static constexpr gpu_arch arch_tag = arch_tag_; using dtype_mma_acc = typename compute_attr::dtype_acc; using dtype_mma_a = typename compute_attr::dtype_a; using dtype_mma_b = typename compute_attr::dtype_b; @@ -53,7 +54,8 @@ struct compute_policy_default_xmx::mma_attr::mma_n_in_elem; static constexpr uint32_t block_bytes_y_b = 32; static constexpr uint32_t block_size_y_b = block_bytes_y_b / sizeof(dtype_mma_b); @@ -90,7 +92,7 @@ struct compute_policy_unaligned_xmx::mma_attr::mma_n_in_elem; static constexpr uint32_t block_bytes_y_b = 32; static constexpr uint32_t block_size_y_b = block_bytes_y_b / sizeof(dtype_mma_b); @@ -103,19 +105,20 @@ struct compute_policy_unaligned_xmx + gpu_arch arch_tag_ = gpu_arch::Xe, typename enable = void> struct compute_policy_default_fpu {}; /// @brief Specialized for Xe architecture. 
-template -struct compute_policy_default_fpu { +template +struct compute_policy_default_fpu> { using compute_attr = compute_attr_; using perf_tuning_knob = perf_tuning_knob_; static constexpr int k_stride = perf_tuning_knob::k_stride; static constexpr int stages = perf_tuning_knob::stages; static constexpr int sync_freq = perf_tuning_knob::sync_freq; - static constexpr gpu_arch arch_tag = gpu_arch::Xe; + static constexpr gpu_arch arch_tag = arch_tag_; using dtype_mma_acc = typename compute_attr::dtype_acc; using dtype_mma_a = typename compute_attr::dtype_a; using dtype_mma_b = typename compute_attr::dtype_b; diff --git a/include/group/gemm/impl/default_fpu_xe.hpp b/include/group/gemm/impl/default_fpu_xe.hpp index 73e57e860..33cfa0f43 100644 --- a/include/group/gemm/impl/default_fpu_xe.hpp +++ b/include/group/gemm/impl/default_fpu_xe.hpp @@ -310,7 +310,8 @@ class gemm_t< if constexpr (enable_periodic_sync) { if ((i % sync_freq) == 0) { if constexpr (wg_size_x > 1) { nbarrier_a.arrive(); } - if constexpr (wg_size_y > 1) { nbarrier_b.arrive(); } + if constexpr (arch_tag >= gpu_arch::Xe) + if constexpr (wg_size_y > 1) { nbarrier_b.arrive(); } } } SW_BARRIER(); @@ -343,7 +344,8 @@ class gemm_t< if constexpr (enable_periodic_sync) { if ((i % sync_freq) == 0) { if constexpr (wg_size_x > 1) { nbarrier_a.wait(); } - if constexpr (wg_size_y > 1) { nbarrier_b.wait(); } + if constexpr (arch_tag >= gpu_arch::Xe) + if constexpr (wg_size_y > 1) { nbarrier_b.wait(); } } } } diff --git a/include/group/gemm/impl/default_xmx_xe.hpp b/include/group/gemm/impl/default_xmx_xe.hpp index 08d2d216f..832b5d4a8 100644 --- a/include/group/gemm/impl/default_xmx_xe.hpp +++ b/include/group/gemm/impl/default_xmx_xe.hpp @@ -37,7 +37,7 @@ class gemm_t< mem_desc_a_t_, // memory attribute of matA mem_desc_b_t_, // memory attribute of matB pre_processing_t_, // pre_processing functor - std::enable_if_t<(arch_tag_ == gpu_arch::Xe)>> { + std::enable_if_t<(arch_tag_ <= gpu_arch::Xe)>> { public: using mem_desc_a_t = mem_desc_a_t_; using mem_desc_b_t = mem_desc_b_t_; @@ -310,7 +310,8 @@ class gemm_t< if constexpr (enable_periodic_sync) { if ((i % sync_freq) == 0) { if constexpr (wg_size_x > 1) { nbarrier_a.arrive(); } - if constexpr (wg_size_y > 1) { nbarrier_b.arrive(); } + if constexpr (arch_tag >= gpu_arch::Xe) + if constexpr (wg_size_y > 1) { nbarrier_b.arrive(); } } } subgroup::tile_load( @@ -346,7 +347,8 @@ class gemm_t< if constexpr (enable_periodic_sync) { if ((i % sync_freq) == 0) { if constexpr (wg_size_x > 1) { nbarrier_a.wait(); } - if constexpr (wg_size_y > 1) { nbarrier_b.wait(); } + if constexpr (arch_tag >= gpu_arch::Xe) + if constexpr (wg_size_y > 1) { nbarrier_b.wait(); } } } } diff --git a/include/group/gemm/impl/unaligned_xmx_xe.hpp b/include/group/gemm/impl/unaligned_xmx_xe.hpp index 04d7dcdaf..508f9d082 100755 --- a/include/group/gemm/impl/unaligned_xmx_xe.hpp +++ b/include/group/gemm/impl/unaligned_xmx_xe.hpp @@ -373,7 +373,7 @@ class gemm_t(matB_t::tile_size_y); xetla_fence(); nbarrier_a.arrive(); - nbarrier_b.arrive(); + if (arch_tag >= gpu_arch::Xe) nbarrier_b.arrive(); #pragma unroll for (uint32_t i = 1; i < num_cyclic - 1; i++) { tile_load(partial_matA, matA_payload); @@ -429,7 +429,7 @@ class gemm_t= gpu_arch::Xe) nbarrier_b.wait(); tile_load(matA, matA_local_ld_payload); tile_load(matB, matB_local_ld_payload); @@ -463,7 +463,7 @@ class gemm_t= gpu_arch::Xe) nbarrier_b.arrive(); SW_BARRIER(); matA_acc_t matA_acc; matB_acc_t matB_acc; @@ -498,7 +498,7 @@ class gemm_t= gpu_arch::Xe) nbarrier_b.wait(); } 
private: diff --git a/include/kernel/default_config/common.hpp b/include/kernel/default_config/common.hpp index c0910b3e9..fd8c621a2 100644 --- a/include/kernel/default_config/common.hpp +++ b/include/kernel/default_config/common.hpp @@ -52,8 +52,38 @@ enum class tune_key : uint8_t { dispatch_policy, group_swizzle_policy, param_optimizer_type, + param_optimizer_level, source_location }; +template +using data_type_a_t = + typename T::template find_elem_t::type; +template +using data_type_b_t = + typename T::template find_elem_t::type; +template +using data_type_c_t = + typename T::template find_elem_t::type; +template +constexpr auto memory_layout_a_v + = T::template find_elem_v; +template +constexpr auto memory_alignment_a_v + = T::template find_elem_v; +template +constexpr auto memory_layout_b_v + = T::template find_elem_v; +template +constexpr auto memory_alignment_b_v + = T::template find_elem_v; +template +constexpr auto memory_layout_c_v + = T::template find_elem_v; +template +constexpr auto memory_alignment_c_v + = T::template find_elem_v; +template +constexpr auto gpu_arch_v = T::template find_elem_v; enum class tune_key_value : uint8_t { pre_processing_default, @@ -68,45 +98,28 @@ enum class tune_key_value : uint8_t { // parameter optimizer enum class param_optimizer_tag : uint8_t { kernel, work_group }; +// optimizer_level (currently only useful with param_optimizer_decision_tree) +enum class param_optimizer_level : uint8_t { + full, // optimize all available options + keep_shape, // optimize all except keeping the original wg/sg tile shape +}; template struct param_optimizer; struct param_optimizer_base { template - struct validate_attribute { - static constexpr bool value = []() constexpr { - bool valid = true; - valid &= std::is_same::type, - typename U::template find_elem_t< - tune_key::data_type_a>::type>::value; - valid &= T::template find_elem_v == U::template find_elem_v; - valid &= T::template find_elem_v == U::template find_elem_v; - valid &= std::is_same::type, - typename U::template find_elem_t< - tune_key::data_type_b>::type>::value; - valid &= T::template find_elem_v == U::template find_elem_v; - valid &= T::template find_elem_v == U::template find_elem_v; - valid &= std::is_same::type, - typename U::template find_elem_t< - tune_key::data_type_c>::type>::value; - valid &= T::template find_elem_v == U::template find_elem_v; - valid &= T::template find_elem_v == U::template find_elem_v; - valid &= T::template find_elem_v == U::template find_elem_v; - return valid; - } - (); - }; + static constexpr bool valid_attribute_v + = std::is_same_v, data_type_a_t> // + && memory_layout_a_v == memory_layout_a_v // + && memory_alignment_a_v == memory_alignment_a_v // + && std::is_same_v, data_type_b_t> // + && memory_layout_b_v == memory_layout_b_v // + && memory_alignment_b_v == memory_alignment_b_v // + && std::is_same_v, data_type_c_t> // + && memory_layout_c_v == memory_layout_c_v // + && memory_alignment_c_v == memory_alignment_c_v // + && gpu_arch_v == gpu_arch_v; }; // parameter adaptor diff --git a/include/kernel/default_config/decision_tree_policy.hpp b/include/kernel/default_config/decision_tree_policy.hpp index c8b0b3c21..e84ca2bfe 100644 --- a/include/kernel/default_config/decision_tree_policy.hpp +++ b/include/kernel/default_config/decision_tree_policy.hpp @@ -264,53 +264,47 @@ struct kslicing_handler { }; } // namespace decision_tree_rule -template +template struct fallback_optimizer { - using type = typename opt_dict_t_::template update_t< - elem_t_t::type>, - 
elem_t_t::type>, - elem_t_t::type>, - elem_v_t>, - elem_v_t>, - elem_v_t>, - elem_v_t>, - elem_v_t>, - elem_v_t>, - elem_v_t>>; + using type = typename opt_dict::template update_t< + elem_t_t>, + elem_t_t>, + elem_t_t>, + elem_v_t>, + elem_v_t>, + elem_v_t>, + elem_v_t>, + elem_v_t>, + elem_v_t>, + elem_v_t>>; }; -template +template struct decision_tree_optimizer : param_optimizer_base { struct impl { - using type = typename dict_t_ ::template update_generator_t< - decision_tree_rule::data_type_handler>:: - template update_generator_t< - decision_tree_rule::tile_shape_handler>:: - template update_generator_t< - decision_tree_rule::kslicing_handler>; + template typename G> + using apply_handeler = T::template update_generator_t; + static constexpr bool keep_shape + = (mode_ == param_optimizer_level::keep_shape); + + using t0 = dict_t_; + using t1 = apply_handeler; + using t2_0 = apply_handeler; + using t2 = std::conditional_t; + using t3 = apply_handeler; + + using type = t3; + + // If any of data_type / mem_layout / mem_align is changed, + // then change it back via fallback_optimizer using fallback_type = fallback_optimizer; }; static constexpr bool use_fallback - = !(param_optimizer_base::template validate_attribute::value); - using type = typename std::conditional::type::type; + = !(param_optimizer_base::template valid_attribute_v); + using type = typename std::conditional_t::type; }; } // namespace gpu::xetla diff --git a/include/kernel/default_config/dummy_policy.hpp b/include/kernel/default_config/dummy_policy.hpp index 7bed9f2c9..b18c882e7 100644 --- a/include/kernel/default_config/dummy_policy.hpp +++ b/include/kernel/default_config/dummy_policy.hpp @@ -255,8 +255,8 @@ struct dummy_optimizer : param_optimizer_base { using fallback_type = fallback_optimizer; }; static constexpr bool use_fallback - = !(param_optimizer_base::template validate_attribute::value); + = !(param_optimizer_base::template valid_attribute_v); using type = typename std::conditional::type::type; }; diff --git a/include/kernel/gemm/default_gemm.hpp b/include/kernel/gemm/default_gemm.hpp index 455eeeaf0..b56e1ab73 100644 --- a/include/kernel/gemm/default_gemm.hpp +++ b/include/kernel/gemm/default_gemm.hpp @@ -29,12 +29,12 @@ namespace kernel { template > struct default_gemm_config_t : param_adaptor::template update_dict_t< typename tune_option::template update_t< elem_t_t, elem_v_t, elem_v_t>>>::type> {}; + arch_tag>>>>::type> {}; template > -struct default_gemm_t - : default_gemm_config_t::type {}; +using default_gemm_t = typename default_gemm_config_t::type; } // namespace kernel template @@ -74,12 +73,20 @@ struct param_optimizer { param_optimizer_type> != dict_t_::impl::key_not_found) && (dict_t_::template find_elem_v == tune_key_value::param_optimizer_decision_tree); + static constexpr auto arch_tag + = (dict_t_::impl::template find_elem_index< + tune_key::gpu_arch> != dict_t_::impl::key_not_found) + ? 
dict_t_::template find_elem_v + : gpu_arch::Xe; + static constexpr auto optimizer_level + = dict_t_::template find_elem_v; using type = typename std::conditional, + decision_tree_optimizer, dummy_optimizer>::type::type; + kernel::param_kslicing_g1l1_t, + kernel::param_kslicing_g2l1_t, + kernel::param_kslicing_g1l2_t>>::type::type; }; template @@ -124,12 +131,12 @@ namespace group { template > + typename wg_shape, uint32_t wg_tile_k, gpu_arch arch_tag = gpu_arch::Xe, + typename tune_option = dict_t<>> struct default_gemm_selector_config_t : param_adaptor::template update_dict_t< typename tune_option::template update_t< elem_t_t, elem_v_t, elem_v_t, elem_v_t>>>::type> {}; + arch_tag>>>>::type> {}; template > -struct default_gemm_selector_t - : default_gemm_selector_config_t::type { -}; + typename wg_shape, uint32_t wg_tile_k, gpu_arch arch_tag = gpu_arch::Xe, + typename tune_option = dict_t<>> +using default_gemm_selector_t = typename default_gemm_selector_config_t::type; template > + gpu_arch arch_tag = gpu_arch::Xe, typename tune_option = dict_t<>> struct default_epilogue_selector_config_t : param_adaptor::template update_dict_t< typename tune_option::template update_t< elem_t_t, elem_v_t, elem_v_t, elem_v_t>>>::type> {}; + arch_tag>>>>::type> {}; template > -struct default_epilogue_selector_t - : default_epilogue_selector_config_t::type {}; + gpu_arch arch_tag = gpu_arch::Xe, typename tune_option = dict_t<>> +using default_epilogue_selector_t = + typename default_epilogue_selector_config_t::type; } // namespace group template @@ -201,10 +207,18 @@ struct param_optimizer { param_optimizer_type> != dict_t_::impl::key_not_found) && (dict_t_::template find_elem_v == tune_key_value::param_optimizer_decision_tree); + static constexpr auto optimizer_level + = dict_t_::template find_elem_v; + static constexpr auto arch_tag + = (dict_t_::impl::template find_elem_index< + tune_key::gpu_arch> != dict_t_::impl::key_not_found) + ? 
dict_t_::template find_elem_v + : gpu_arch::Xe; using type = typename std::conditional, + decision_tree_optimizer, dummy_optimizer>::type::type; + group::param_dict1_wg_t>>::type::type; }; template @@ -304,4 +318,4 @@ struct param_adaptor { using type = epilogue_t; }; -} // namespace gpu::xetla \ No newline at end of file +} // namespace gpu::xetla diff --git a/include/kernel/gemm/gemm_preset.hpp b/include/kernel/gemm/gemm_preset.hpp index fde7f21ab..5afeef300 100644 --- a/include/kernel/gemm/gemm_preset.hpp +++ b/include/kernel/gemm/gemm_preset.hpp @@ -45,36 +45,40 @@ using param_performance_default elem_v_t, elem_v_t>; +template using param_runtime_default = dict_t, elem_v_t, - elem_v_t, + elem_v_t, elem_t_t>, + group::epilogue_policy_default>, elem_v_t, elem_t_t>>; + kernel::group_swizzle_default>>; } // namespace detail - +template using default_param_t = dict_t<>::template update_dict_t< detail::param_dtype_bf16_bf16_bf16>::template update_dict_t::template update_dict_t::template update_dict_t::template update_dict_t::template update_dict_t:: + param_runtime_default>:: template update_t, elem_v_t, elem_v_t, elem_t_t>, elem_t_t>, elem_v_t>; + tune_key_value::param_optimizer_dummy>, + elem_v_t>; namespace kernel { -using param_kslicing_g1l1_t = default_param_t::template update_t< +template +using param_kslicing_g1l1_t = default_param_t::template update_t< elem_v_t, elem_v_t, elem_t_t>, @@ -83,7 +87,8 @@ using param_kslicing_g1l1_t = default_param_t::template update_t< elem_v_t>; -using param_kslicing_g2l1_t = default_param_t::template update_t< +template +using param_kslicing_g2l1_t = default_param_t::template update_t< elem_v_t, elem_v_t, elem_t_t>, @@ -92,7 +97,8 @@ using param_kslicing_g2l1_t = default_param_t::template update_t< elem_v_t>; -using param_kslicing_g1l2_t = default_param_t::template update_t< +template +using param_kslicing_g1l2_t = default_param_t::template update_t< elem_v_t, elem_v_t, elem_t_t>, @@ -104,7 +110,8 @@ using param_kslicing_g1l2_t = default_param_t::template update_t< } // namespace kernel namespace group { -using param_dict1_wg_t = default_param_t::template update_t< +template +using param_dict1_wg_t = default_param_t::template update_t< elem_t_t, elem_t_t>, elem_v_t, @@ -112,6 +119,6 @@ using param_dict1_wg_t = default_param_t::template update_t< elem_v_t, elem_v_t, elem_t_t>>; + group::epilogue_policy_default>>; } } // namespace gpu::xetla diff --git a/include/kernel/gemm/impl/default_xe.hpp b/include/kernel/gemm/impl/default_xe.hpp index 8f2da1cc3..78c63ffbd 100644 --- a/include/kernel/gemm/impl/default_xe.hpp +++ b/include/kernel/gemm/impl/default_xe.hpp @@ -35,7 +35,7 @@ namespace gpu::xetla::kernel { template class gemm_universal_t, gemm_t_, epilogue_t_, - std::enable_if_t<(group_swizzle_::arch_tag == gpu_arch::Xe)>> { + std::enable_if_t<(group_swizzle_::arch_tag <= gpu_arch::Xe)>> { using gemm_t = gemm_t_; using epilogue_t = epilogue_t_; using gemm_args_t = typename gemm_t::arguments_t; @@ -176,8 +176,8 @@ class gemm_universal_t, gemm_t_, /// @return The size of local memory required. 
__XETLA_API static constexpr uint32_t get_slm_size() { constexpr uint32_t size = gemm_t::slm_size + epilogue_t::slm_size; - static_assert(size <= (128 * 1024), - "The local memory size should be less than 128KB!"); + static_assert(size <= arch_attr_t::local_mem_size, + "The local memory size exceeds the limit!"); return size; }; diff --git a/include/kernel/gemm/impl/kslicing_xe.hpp b/include/kernel/gemm/impl/kslicing_xe.hpp index 595b072ea..6415c39c1 100644 --- a/include/kernel/gemm/impl/kslicing_xe.hpp +++ b/include/kernel/gemm/impl/kslicing_xe.hpp @@ -39,7 +39,7 @@ template , gemm_t_, epilogue_t_, - std::enable_if_t<(group_swizzle_::arch_tag == gpu_arch::Xe)>> { + std::enable_if_t<(group_swizzle_::arch_tag <= gpu_arch::Xe)>> { using gemm_t = gemm_t_; using epilogue_t = epilogue_t_; using gemm_args_t = typename gemm_t::arguments_t; @@ -239,8 +239,8 @@ class gemm_universal_t::local_mem_size, + "The local memory size exceeds the limit!"); return size; } diff --git a/include/kernel/gemm/impl/stream_k_xe.hpp b/include/kernel/gemm/impl/stream_k_xe.hpp index 620242062..eec3948d9 100644 --- a/include/kernel/gemm/impl/stream_k_xe.hpp +++ b/include/kernel/gemm/impl/stream_k_xe.hpp @@ -217,8 +217,8 @@ class gemm_universal_t, gemm_t_, /// @return The size of local memory required. __XETLA_API static constexpr uint32_t get_slm_size() { constexpr uint32_t size = gemm_t::slm_size + epilogue_t::slm_size; - static_assert(size <= (128 * 1024), - "The local memory size should be less than 128KB!"); + static_assert(size <= arch_attr_t::local_mem_size, + "The local memory size exceeds the limit!"); return size; }; diff --git a/include/subgroup/tile/impl/payload_xe.hpp b/include/subgroup/tile/impl/payload_xe.hpp index 7a518899b..bd4934a1f 100644 --- a/include/subgroup/tile/impl/payload_xe.hpp +++ b/include/subgroup/tile/impl/payload_xe.hpp @@ -590,7 +590,7 @@ template , tile_desc_, msg_type::unaligned_2d, arch_tag_, - std::enable_if_t<(arch_tag_ == gpu_arch::Xe)>> { + std::enable_if_t<(arch_tag_ <= gpu_arch::Xe)>> { using dtype = dtype_; using mem_desc_t = mem_desc_t; @@ -1308,7 +1308,8 @@ struct prefetch_payload_t< tile_desc_t, num_coop_sg_, arch_tag_, - std::enable_if_t<(arch_tag_ == gpu_arch::Dg2)>> { + std::enable_if_t<(arch_tag_ <= gpu_arch::Dg2 + && (tile_size_y_ != 1 || block_size_y_ != 1))>> { using dtype = dtype_; using mem_desc_t = mem_desc_t; @@ -1499,7 +1500,8 @@ struct prefetch_payload_t< tile_desc_t, num_coop_sg_, arch_tag_, - std::enable_if_t<(arch_tag_ == gpu_arch::Xe)>> { + std::enable_if_t<(arch_tag_ == gpu_arch::Xe) + && (tile_size_y_ != 1 || block_size_y_ != 1)>> { using dtype = dtype_; using mem_desc_t = mem_desc_t; @@ -1665,7 +1667,7 @@ struct prefetch_payload_t< mem_desc_t, tile_desc_t, num_coop_sg_, arch_tag_, - std::enable_if_t<(arch_tag_ == gpu_arch::Xe)>> { + std::enable_if_t<(arch_tag_ <= gpu_arch::Xe)>> { using dtype = dtype_; using mem_desc_t = mem_desc_t; @@ -1763,7 +1765,7 @@ template , tile_desc_, num_coop_sg_, arch_tag_, - std::enable_if_t<(arch_tag_ == gpu_arch::Xe)>> { + std::enable_if_t<(arch_tag_ <= gpu_arch::Xe)>> { using dtype = dtype_; using mem_desc_t = mem_desc_t; diff --git a/include/subgroup/tile/impl/prefetch_xe.hpp b/include/subgroup/tile/impl/prefetch_xe.hpp index 9205c68a3..a67400987 100644 --- a/include/subgroup/tile/impl/prefetch_xe.hpp +++ b/include/subgroup/tile/impl/prefetch_xe.hpp @@ -45,7 +45,7 @@ struct check_prefetch_type { static constexpr bool is_local_xe = ((payload_t::memory_space == mem_space::local) - && (payload_t::arch_tag == gpu_arch::Xe)); + && (payload_t::arch_tag <= 
gpu_arch::Xe)); }; } // namespace detail diff --git a/include/subgroup/tile/impl/tile_op_functor.hpp b/include/subgroup/tile/impl/tile_op_functor.hpp index 379bd062d..866aabc8d 100644 --- a/include/subgroup/tile/impl/tile_op_functor.hpp +++ b/include/subgroup/tile/impl/tile_op_functor.hpp @@ -177,7 +177,7 @@ struct gelu_fwd_w_op_t {}; /// @brief Is the element-wise gelu training forward op functor, specialized for Xe architecture. template struct gelu_fwd_w_op_t> { + std::enable_if_t<(arch_tag <= gpu_arch::Xe)>> { using dtype_out = dtype_out_; using mem_desc_w_t = mem_desc_t; @@ -295,7 +295,7 @@ struct gelu_bwd_op_t {}; /// @brief Is the element-wise gelu backward op functor, specialized for Xe architecture. template struct gelu_bwd_op_t> { + std::enable_if_t<(arch_tag <= gpu_arch::Xe)>> { using dtype_in = dtype_in_; using mem_desc_x_t = mem_desc_t; @@ -490,7 +490,7 @@ struct scale_v_offset_v_op_t {}; /// @brief Is the scale_v_offset_v op functor, specialized for Xe architecture. template struct scale_v_offset_v_op_t> { + std::enable_if_t<(arch_tag <= gpu_arch::Xe)>> { using scale_dtype = scale_dtype_; using offset_dtype = offset_dtype_; @@ -619,7 +619,7 @@ struct scale_v_op_t {}; /// @brief Is the scale_v op functor, specialized for Xe architecture. template struct scale_v_op_t> { + std::enable_if_t<(arch_tag <= gpu_arch::Xe)>> { using scale_dtype = scale_dtype_; using scale_mem_desc_t @@ -933,7 +933,7 @@ struct dropout_op_t {}; /// @brief Is the dropout op functor, specialized for Xe architecture. template struct dropout_op_t> { + std::enable_if_t<(arch_tag <= gpu_arch::Xe)>> { using dtype_mask = dtype_mask_; using mem_desc_mask_t = mem_desc_t; @@ -1010,7 +1010,7 @@ struct rng_dropout_op_t {}; /// @brief Is the random number generator and dropout op functor, specialized for Xe architecture. template struct rng_dropout_op_t> { + std::enable_if_t<(arch_tag <= gpu_arch::Xe)>> { using dtype_mask = dtype_mask_; using mem_desc_mask_t = mem_desc_t; @@ -1114,7 +1114,7 @@ struct scalar_mul_op_t {}; /// @brief Is the scalar_multiply op functor, specialized for Xe architecture. template struct scalar_mul_op_t> { + std::enable_if_t<(arch_tag <= gpu_arch::Xe)>> { using dtype_in = dtype_in_; using mem_desc_in_t = mem_desc_t; diff --git a/media/docs/construct_a_gemm.md b/media/docs/construct_a_gemm.md index a5a3cbf17..ef54b287a 100644 --- a/media/docs/construct_a_gemm.md +++ b/media/docs/construct_a_gemm.md @@ -6,16 +6,16 @@ As shown in the diagram below, each workgroup will calculate a sub-matrix, repre ![ALT](/media/docs/dom.jpg "GEMM decomposition by workgroup and subgroup") -## Basic Components +## Basic Components 1. Select a `GEMM building block`, considering the division of work-group and sub-group -2. Decide if `splitK` or `steamK` is needed in specific shape +2. Decide if `splitK` or `streamK` is needed for the specific shape 3. Define `epilogue` that specifies what you want to fuse after the GEMM computation based on accumulator 4. Instantiate a `gemm` implementation by the selections from 1)-3). For a runnable code example, you can refer to the code in the [02_basic_gemm](/examples/02_basic_gemm). -### Task Mapping +### Task Mapping Before launching the GPU kernel, it is crucial to determine how to map the entire GEMM computation onto the GPU, considering work-group and sub-group configurations. Efficiently utilizing GPU resources requires careful consideration of factors such as the operation's shape, data type, and the hardware specifications of the GPU.
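As a rough, self-contained illustration of this mapping (the matrix and tile sizes below are assumed for this sketch, not taken from the example that follows), the workgroup range and the number of subgroups per workgroup follow directly from the tile shapes:

```c++
#include <cstdint>

// Hypothetical sizes, chosen only to illustrate the mapping arithmetic.
constexpr uint32_t matrix_m = 4096, matrix_n = 4096;
constexpr uint32_t wg_tile_m = 256, wg_tile_n = 256; // tile of C owned by one workgroup
constexpr uint32_t sg_tile_m = 32, sg_tile_n = 64; // tile of C owned by one subgroup

// Workgroups needed to cover the output matrix C (rounded up for partial tiles).
constexpr uint32_t group_range_m = (matrix_m + wg_tile_m - 1) / wg_tile_m; // 16
constexpr uint32_t group_range_n = (matrix_n + wg_tile_n - 1) / wg_tile_n; // 16

// Subgroups (threads) inside one workgroup.
constexpr uint32_t sg_per_wg = (wg_tile_m / sg_tile_m) * (wg_tile_n / sg_tile_n); // 8 * 4 = 32
```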
A typical configuration for workgroups and subgroups may resemble the example below, especially when the input shape is sufficient to fully utilize the GPU. ```c++ @@ -64,7 +64,7 @@ Alternatively, the subgroup-level splitK is also available i which can accumulat ![ALT](/media/docs/subgroup_splitK.jpg "split K in subgroup level") -For kernel level API, we can set two parameters in dispatch policy of `gemm_universal` API. Definitely, you can set both value to large than 1 for mixing workgroup and subgroup level split K together. +For the kernel-level API, we can set two parameters in the dispatch policy of the `gemm_universal` API. Both values can be set greater than 1 to combine workgroup-level and subgroup-level split K. ```c++ using dispatch_policy @@ -87,7 +87,7 @@ decide the location of input and output matrix which is either from global or sh mem_space::global, // memory reading from global mem for B 8, // buffer alignment for A, in unit of element 8, // buffer alignment for B, in unit of element - data_type_acc, // accumulator data type for intermediate resutls + data_type_acc, // accumulator data type for intermediate results tile_shape, // computation tile shape sg_tile_k, // elements in each iteration mma_engine::xmx, // compute engine @@ -122,7 +122,7 @@ class epilogue_t {}; - `tile_shape` is the problem size of each group and subgroup. - `mem_desc_c` is the description of buffer `c`, which includes `memory data type`, `memory space` and `memory layout`... -In example [03_gemm_relu_bias](/examples/03_gemm_relu_bias), a chain of operations is effectively fused into the GEMM computation. +In example [03_gemm_relu_bias](/examples/03_gemm_relu_bias), a chain of operations is effectively fused into the GEMM computation. First, using pre-defined post-operations `relu` and `bias_add`, and then pass it to `epilogue_policy::tile_op_t`. ```c++ @@ -132,7 +132,7 @@ using tile_op_t = chained_tile_op_t< >; ``` -### GEMM Instantiate +### GEMM Instantiate After configuration of BRGEMM and epilogue, it's simple to build entire GEMM with: - assigning tasks to each group, setting working boundaries and starting position accordingly.
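Returning to the k-slicing dispatch policy mentioned above: the snippet below is a minimal sketch of what setting those two cooperation ratios can look like. The template-parameter names and their order, the `group_swizzle_default` argument, and the surrounding `gemm_t`/`epilogue_t` types are assumptions made for illustration; the (truncated) `dispatch_policy` block earlier in this document remains the authoritative form.

```c++
// Sketch only: names and parameter order are assumed, not taken from this patch.
using group_swizzle = gpu::xetla::kernel::group_swizzle_default<gpu_arch::Xe>;
using dispatch_policy = gpu::xetla::kernel::dispatch_policy_kslicing<
        group_swizzle, // maps workgroup ids onto output tiles
        2, // global k-slicing ratio: workgroups cooperating on the K dimension
        2>; // local k-slicing ratio: subgroups cooperating on the K dimension

// The policy is then plugged into gemm_universal_t together with the gemm and
// epilogue types selected earlier (assumed to be defined as in the sections above).
using gemm_op_t
        = gpu::xetla::kernel::gemm_universal_t<dispatch_policy, gemm_t, epilogue_t>;
```

Setting either ratio back to 1 disables that level of cooperation.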
@@ -153,7 +153,7 @@ Finally, the actual data will be passed using gemm_op_t::arguments_t, and all of typename gemm_op_t::arguments_t arg(matrix_n, matrix_k, matrix_m, A, matrix_k, B, matrix_n, C, matrix_n); ``` -```c++ +```c++ gemm_op_t gemm_op; gemm_op(item, arg); diff --git a/tests/integration/default_config/group_gemm/kernel_func.hpp b/tests/integration/default_config/group_gemm/kernel_func.hpp index 13c91ba1f..aecaf6fd9 100644 --- a/tests/integration/default_config/group_gemm/kernel_func.hpp +++ b/tests/integration/default_config/group_gemm/kernel_func.hpp @@ -34,11 +34,11 @@ struct default_config_group_gemm_test_func { // should larger than 8 static constexpr uint32_t k_stride = sg_k; - // Step 1: define mirco-kernel's configuration + // Step 1: define Micro-kernel's configuration using wg_shape = shape; using sg_shape = shape; - // Mirco-kernel configuration + // Micro-kernel configuration using gemm_tune_option = dict_t< elem_v_t, @@ -59,7 +59,7 @@ struct default_config_group_gemm_test_func { layout_b, // memory layout for B 8, // leading dimension alignment for B, in unit of element mem_space::global, // memory reading from global mem for B - dtype_acc, // accumulator data type for intermediate resutls + dtype_acc, // accumulator data type for intermediate results wg_shape, // computation tile shape k_stride, // elements in each iteration gpu_arch::Xe, // GPU arch diff --git a/tests/integration/default_config/kernel_gemm/kernel_func.hpp b/tests/integration/default_config/kernel_gemm/kernel_func.hpp index 84e9ab1b3..f16a50c97 100644 --- a/tests/integration/default_config/kernel_gemm/kernel_func.hpp +++ b/tests/integration/default_config/kernel_gemm/kernel_func.hpp @@ -48,7 +48,7 @@ struct default_config_kernel_gemm_test_func { dtype_c, // output datatype for C mem_layout::row_major, // memory layout for C 8, // leading dimension alignment for C, in unit of element - dtype_acc, // accumulator data type for intermediate resutls + dtype_acc, // accumulator data type for intermediate results gpu_arch::Xe, // GPU arch tune_option>; diff --git a/tests/integration/gemm/bf16_stream_k/main.cpp b/tests/integration/gemm/bf16_stream_k/main.cpp index df3dfc54b..55b3570b9 100644 --- a/tests/integration/gemm/bf16_stream_k/main.cpp +++ b/tests/integration/gemm/bf16_stream_k/main.cpp @@ -14,8 +14,8 @@ * limitations under the License. 
*******************************************************************************/ -#include #include "xetla.hpp" +#include using namespace gpu::xetla; //The number of times the kernel is executed @@ -245,7 +245,7 @@ void stream_k_gemm_run(uint32_t iter) { static constexpr uint32_t periodic_sync_interval = 4; static constexpr uint32_t prefetch_distance = 4; - // Mirco-kernel configuration + // Micro-kernel configuration using gemm_config = typename xetla::group::gemm_selector_t< data_type_a, // input datatype for A data_type_b, // input datatype for B @@ -255,7 +255,7 @@ void stream_k_gemm_run(uint32_t iter) { mem_space::global, // memory reading from global mem for B 8, // leading dimension for A, in unit of element 8, // leading dimension for B, in unit of element - data_type_acc, // accumulator data type for intermediate resutls + data_type_acc, // accumulator data type for intermediate results tile_shape, // computation tile shape sg_tile_k, // elements in each iteration mma_engine::xmx, // compute engine @@ -299,9 +299,11 @@ void stream_k_gemm_run(uint32_t iter) { gemm_config::k_stride, wg_tile_n, sg_tile_m, sg_tile_n, avail_xecores); - - static const std::string env_set_str = "SYCL_PROGRAM_COMPILE_OPTIONS= -vc-codegen -doubleGRF -vc-disable-indvars-opt -Xfinalizer ' -printregusage -enableBCR -DPASTokenReduction '"; - putenv(const_cast(env_set_str.c_str())); + static const std::string env_set_str + = "SYCL_PROGRAM_COMPILE_OPTIONS= -vc-codegen -doubleGRF " + "-vc-disable-indvars-opt -Xfinalizer ' -printregusage -enableBCR " + "-DPASTokenReduction '"; + putenv(const_cast(env_set_str.c_str())); //Define and initialize the data required for the calculation auto A = alloc_device_and_init( size_a, @@ -434,7 +436,7 @@ void stream_k_gemm_run(uint32_t iter) { } static const std::string env_unset_str = "SYCL_PROGRAM_COMPILE_OPTIONS="; - putenv(const_cast(env_unset_str.c_str())); + putenv(const_cast(env_unset_str.c_str())); ASSERT_EQ(0, gemm_result_validate(A, B, C, Bias, matrix_m, matrix_k, matrix_n, diff --git a/tests/integration/gemm/fp16/common.hpp b/tests/integration/gemm/fp16/common.hpp index 01da3e8a8..10c859631 100644 --- a/tests/integration/gemm/fp16/common.hpp +++ b/tests/integration/gemm/fp16/common.hpp @@ -17,9 +17,9 @@ #pragma once #include "kernel_func.hpp" +#include #include #include -#include class TestBase { public: diff --git a/tests/integration/gemm/fp16/main.cpp b/tests/integration/gemm/fp16/main.cpp index fdc579917..6e52131dd 100644 --- a/tests/integration/gemm/fp16/main.cpp +++ b/tests/integration/gemm/fp16/main.cpp @@ -16,8 +16,8 @@ #include "common.hpp" #include "kernel_func.hpp" -#include #include +#include std::string esimd_compile_string = " -vc-codegen -doubleGRF " diff --git a/tests/unit/tile_load_store/main.cpp b/tests/unit/tile_load_store/main.cpp index fcfc6a3c3..82f3eebba 100644 --- a/tests/unit/tile_load_store/main.cpp +++ b/tests/unit/tile_load_store/main.cpp @@ -161,8 +161,9 @@ TEST(tile_load_store, esimd) { cl::sycl::nd_range<1> nd_range({1}, {1}); auto result_validate = std::bind(tile_load_store_result_validate, _1, _2, _3, 128, 64, 32, 32, 0); - kernel_run>( - nd_range, result_validate); + kernel_run>(nd_range, result_validate); } TEST(tile_load_transpose_store_1, esimd) { @@ -266,8 +267,8 @@ TEST(tile_load_store_unaligned_2d, esimd) { auto result_validate = std::bind(tile_load_store_result_validate, _1, _2, _3, 127, 63, 32, 32, 0); kernel_run>(nd_range, result_validate); + tile_load_store_unaligned_2d_func>(nd_range, result_validate); } 
TEST(tile_load_store_oob_1, esimd) { diff --git a/tests/utils/execution.hpp b/tests/utils/execution.hpp index 3d85114da..dd30d2756 100644 --- a/tests/utils/execution.hpp +++ b/tests/utils/execution.hpp @@ -16,6 +16,8 @@ #pragma once +#include +#include #include "common.hpp" #include "profiling.hpp" #include "xetla.hpp" @@ -89,11 +91,13 @@ void gemm_exec(const std::string &compile_str, size_t batch = 1) { std::vector kernelId = {get_kernel_id()}; auto inputBundle = get_kernel_bundle(context, kernelId); - static const std::string env_set_str = "SYCL_PROGRAM_COMPILE_OPTIONS="+compile_str; - putenv(const_cast<char *>(env_set_str.c_str())); + static const std::string env_set_str + = "SYCL_PROGRAM_COMPILE_OPTIONS=" + compile_str; + putenv(const_cast<char *>(env_set_str.c_str())); kernel_bundle exeBundle = build(inputBundle); - static const std::string env_unset_str = "SYCL_PROGRAM_COMPILE_OPTIONS="; - putenv(const_cast<char *>(env_unset_str.c_str())); + static const std::string env_unset_str + = "SYCL_PROGRAM_COMPILE_OPTIONS="; + putenv(const_cast<char *>(env_unset_str.c_str())); using namespace gpu::xetla::group; using namespace gpu::xetla::kernel; @@ -164,12 +168,15 @@ void gemm_exec(const std::string &compile_str, size_t batch = 1) { } } -/// @brief The template function to execute kernel in esimd way for unit test framework +/// @brief The template function to execute a kernel in the esimd way for the +/// unit test framework /// -/// @tparam data_type data_type The data type of buffer used in kernel and buffer allocation +/// @tparam data_type The data type of the buffer used in the kernel and for +/// buffer allocation /// @tparam KERNEL the kernel function struct /// @param nd_range the range of workitems -/// @param validate_result validation function, taking 3 parameters buffer A, B as input C as output +/// @param validate_result validation function taking 3 parameters: buffers A and B +/// as input, C as output /// template @@ -227,3 +234,109 @@ void kernel_run(auto nd_range, auto validate_result) { free(B_host); free(C_host); } + +/// @brief Uses the gpu_arch of the current machine to run F::exec +/// +/// @tparam F The gpu_arch-templated function wrapper +/// +/// @example example usage in /examples/01 or /examples/02 +template