From 8f75a3c36cd6dc30c256a14c87187cd5244483b1 Mon Sep 17 00:00:00 2001 From: "Meng, Hengyu" Date: Wed, 17 Jan 2024 02:37:56 +0000 Subject: [PATCH 01/11] backup backup bug fix for Arc_XMX bugfix for dg2 modify UT of Dg2 fp16 revert debugging changes --- CMakeLists.txt | 1 - _clang-format | 122 ++++++++++++++++++ examples/01_gemm_universal/gemm_universal.cpp | 14 +- examples/02_basic_gemm/basic_gemm.cpp | 5 +- .../scaled_dot_product_attention.cpp | 18 +-- .../softmax.hpp | 4 +- include/common/common.hpp | 6 + include/common/utils/common.hpp | 4 +- include/common/utils/raw_send_nbarrier.hpp | 7 +- include/group/gemm/compute_policy.hpp | 14 +- include/group/gemm/impl/default_xmx_xe.hpp | 2 +- include/kernel/gemm/default_gemm.hpp | 30 ++--- include/kernel/gemm/impl/default_xe.hpp | 2 +- include/kernel/gemm/impl/kslicing_xe.hpp | 2 +- include/subgroup/tile/impl/payload_xe.hpp | 9 +- include/subgroup/tile/impl/prefetch_xe.hpp | 2 +- .../subgroup/tile/impl/tile_op_functor.hpp | 14 +- tests/integration/gemm/fp16/common.hpp | 20 +++ tests/integration/gemm/fp16/kernel_func.hpp | 12 +- tests/integration/gemm/fp16/main.cpp | 4 +- tests/unit/tile_load_store/main.cpp | 9 +- 21 files changed, 228 insertions(+), 73 deletions(-) create mode 100644 _clang-format diff --git a/CMakeLists.txt b/CMakeLists.txt index ebfee8f09..2b46837d6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -19,7 +19,6 @@ else() # Windows endif() project(XeTLA) - include(CTest) enable_testing() diff --git a/_clang-format b/_clang-format new file mode 100644 index 000000000..eee0a4ee7 --- /dev/null +++ b/_clang-format @@ -0,0 +1,122 @@ +#=============================================================================== +# Copyright 2016-2019 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#=============================================================================== + +--- +Language: Cpp +AccessModifierOffset: -4 +AlignAfterOpenBracket: DontAlign +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +AlignEscapedNewlines: DontAlign +AlignOperands: false +AlignTrailingComments: false +AllowAllParametersOfDeclarationOnNextLine: true +AllowShortBlocksOnASingleLine: true +AllowShortCaseLabelsOnASingleLine: true +AllowShortFunctionsOnASingleLine: Inline +AllowShortIfStatementsOnASingleLine: true +AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterDefinitionReturnType: None +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: true +AlwaysBreakTemplateDeclarations: Yes +BinPackArguments: true +BinPackParameters: true +BraceWrapping: + AfterClass: false + AfterControlStatement: false + AfterEnum: false + AfterFunction: false + AfterNamespace: false + AfterObjCDeclaration: false + AfterStruct: false + AfterUnion: false + AfterExternBlock: false + BeforeCatch: false + BeforeElse: false + IndentBraces: false + SplitEmptyFunction: true + SplitEmptyRecord: true + SplitEmptyNamespace: true +BreakBeforeBinaryOperators: All +BreakBeforeBraces: Custom +BreakBeforeInheritanceComma: false +BreakInheritanceList: BeforeColon +BreakBeforeTernaryOperators: true +BreakConstructorInitializers: BeforeComma +BreakAfterJavaFieldAnnotations: false +BreakStringLiterals: true +ColumnLimit: 80 +CommentPragmas: '^ IWYU pragma:' +CompactNamespaces: false +ConstructorInitializerAllOnOneLineOrOnePerLine: true +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 8 +Cpp11BracedListStyle: true +DerivePointerAlignment: false +FixNamespaceComments: true +ForEachMacros: +IncludeBlocks: Preserve +IncludeCategories: + - Regex: '<[[:alnum:].]+>' + Priority: 0 +IncludeIsMainRegex: '(Test)?$' +IndentCaseLabels: true +# IndentPPDirectives: AfterHash +IndentPPDirectives: None +IndentWidth: 4 +IndentWrappedFunctionNames: false +KeepEmptyLinesAtTheStartOfBlocks: true +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +PenaltyBreakAssignment: 2 +PenaltyBreakBeforeFirstCallParameter: 19 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyBreakTemplateDeclaration: 10 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 60 +PointerAlignment: Right +ReflowComments: false +SortIncludes: true +SortUsingDeclarations: true +SpaceAfterCStyleCast: false +SpaceAfterTemplateKeyword: true +SpaceBeforeAssignmentOperators: true +SpaceBeforeCpp11BracedList: true +SpaceBeforeCtorInitializerColon: true +SpaceBeforeInheritanceColon: true +SpaceBeforeParens: ControlStatements +SpaceBeforeRangeBasedForLoopColon: true +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 1 +SpacesInAngles: false +SpacesInContainerLiterals: false +SpacesInCStyleCastParentheses: false +SpacesInParentheses: false +SpacesInSquareBrackets: false +Standard: Cpp11 +StatementMacros: + - for_ + - PRAGMA_OMP + - PRAGMA_OMP_SIMD +TabWidth: 4 +UseTab: Never +... +# vim:ft=conf et ts=2 sw=2 diff --git a/examples/01_gemm_universal/gemm_universal.cpp b/examples/01_gemm_universal/gemm_universal.cpp index 55d94249d..ed566cea3 100644 --- a/examples/01_gemm_universal/gemm_universal.cpp +++ b/examples/01_gemm_universal/gemm_universal.cpp @@ -33,9 +33,9 @@ void gemm_universal_run(uint32_t iter) { size_t size_b = matrix_k * matrix_n; size_t size_c = matrix_m * matrix_n; - using data_type_a = bf16; - using data_type_b = bf16; - using data_type_c = bf16; + using data_type_a = fp16; + using data_type_b = fp16; + using data_type_c = fp16; using data_type_acc = float; //Turn on the profiling property to facilitate subsequent profiling @@ -91,7 +91,11 @@ void gemm_universal_run(uint32_t iter) { tune_key_value::dispatch_policy_kslicing>, elem_v_t, elem_v_t, - elem_t_t>>; + elem_t_t>, + elem_t_t>, + elem_t_t>>; using gemm_op_t = gpu::xetla::kernel::default_gemm_t< data_type_a, // input datatype for A mem_layout::row_major, // memory layout for A @@ -103,7 +107,7 @@ void gemm_universal_run(uint32_t iter) { mem_layout::row_major, // memory layout for C 8, // leading dimension alignment for C, in unit of element data_type_acc, // accumulator data type for intermediate resutls - gpu_arch::Xe, // GPU arch + gpu_arch::Dg2, // GPU arch tune_option>; // allocate temp buffers for global split diff --git a/examples/02_basic_gemm/basic_gemm.cpp b/examples/02_basic_gemm/basic_gemm.cpp index 918a78e10..330e85e09 100644 --- a/examples/02_basic_gemm/basic_gemm.cpp +++ b/examples/02_basic_gemm/basic_gemm.cpp @@ -135,7 +135,7 @@ void basic_gemm_run(sycl::queue queue, uint32_t iter) { data_type_acc, // accumulator data type for intermediate resutls wg_shape, // computation tile shape k_stride, // elements in each iteration - gpu_arch::Xe, // GPU arch + arch_tag_, // GPU arch gemm_tune_option>; gemm_t gemm; @@ -149,7 +149,7 @@ void basic_gemm_run(sycl::queue queue, uint32_t iter) { mem_space::global, // memory writing to global mem for C wg_shape, // computation tile shape k_stride, // elements in each iteration - gpu_arch::Xe, // GPU arch + arch_tag_, // GPU arch epilogue_tune_option>; // Step 3: define the shared local memory usages @@ -194,7 +194,6 @@ void basic_gemm_run(sycl::queue queue, uint32_t iter) { // the results is in the matAcc rather than real output C typename gemm_t::work_group_t g(item.get_local_linear_id()); gemm(g, matAcc, gemm_args); - // Step 7: write the results from matACC to real output C epilogue_t epilogue; epilogue(g, matAcc, md_c); diff --git a/examples/08_scaled_dot_product_attention/scaled_dot_product_attention.cpp b/examples/08_scaled_dot_product_attention/scaled_dot_product_attention.cpp index cb61058f6..0060c299e 100644 --- a/examples/08_scaled_dot_product_attention/scaled_dot_product_attention.cpp +++ b/examples/08_scaled_dot_product_attention/scaled_dot_product_attention.cpp @@ -254,14 +254,14 @@ void sdp_fwd_run(uint32_t iter) { using wg_shape0 = shape; using sg_shape0 = shape; - using post_op0_t = scalar_mul_op_t; + using post_op0_t = scalar_mul_op_t; using post_op1_t = elemwise_reduce_op_t; + dtype_in, gpu_arch::Dg2>; using post_op_t = chained_tile_op_t; using epilogue_policy0 = xetla::group::epilogue_policy_tile_op; - using group_swizzle = group_swizzle_default; + gpu_arch::Dg2>; + using group_swizzle = group_swizzle_default; using tune_option0 = dict_t< elem_v_t; using epilogue0_t = xetla::group::default_epilogue_selector_t< dtype_sfx, // onput datatype for C @@ -298,7 +298,7 @@ void sdp_fwd_run(uint32_t iter) { local, // memory writing to local mem for C wg_shape0, // computation tile shape wg_tile_k_qk, // elements in each iteration - gpu_arch::Xe, // GPU arch + gpu_arch::Dg2, // GPU arch tune_option0>; using gemm_op0_t = gemm_universal_t< dispatch_policy_default, gemm0_t, @@ -315,7 +315,7 @@ void sdp_fwd_run(uint32_t iter) { // we only need to do thread sync while store gemm results to SLM // one barrier is enough for that xetla_nbarrier_init<1>(); - xetla_nbarrier_t + xetla_nbarrier_t nbarrier; nbarrier.init_nbarrier(0, nbarrier_role::producer_consumer); @@ -386,7 +386,7 @@ void sdp_fwd_run(uint32_t iter) { float, // accumulator data type for intermediate resutls wg_shape1, // computation tile shape wg_tile_k_sv, // elements in each iteration - gpu_arch::Xe, // GPU arch + gpu_arch::Dg2, // GPU arch tune_option1>; // gemm arguments include matA & matB load information and @@ -465,7 +465,7 @@ void sdp_fwd_run(uint32_t iter) { 0); xetla_tstore_global(transpose_tdecs, out_reg); + gpu_arch::Dg2>(transpose_tdecs, out_reg); } }); }); diff --git a/examples/08_scaled_dot_product_attention/softmax.hpp b/examples/08_scaled_dot_product_attention/softmax.hpp index 184bc311b..58fb1c688 100644 --- a/examples/08_scaled_dot_product_attention/softmax.hpp +++ b/examples/08_scaled_dot_product_attention/softmax.hpp @@ -57,7 +57,7 @@ struct xetla_softmax_fwd_t { mem_desc_t, softmax_tile_desc_t, subgroup::msg_type_v, - gpu_arch::Xe>; + gpu_arch::Dg2>; // this tile will store the softmax result to global memory using softmax_store_t = subgroup::tile_t; @@ -65,7 +65,7 @@ struct xetla_softmax_fwd_t { mem_desc_t, softmax_tile_desc_t, subgroup::msg_type_v, - gpu_arch::Xe>; + gpu_arch::Dg2>; struct arguments_t { // available while original data is from SLM diff --git a/include/common/common.hpp b/include/common/common.hpp index 97c831c4b..cccc09bbb 100644 --- a/include/common/common.hpp +++ b/include/common/common.hpp @@ -21,3 +21,9 @@ #include #include + +#ifdef __SYCL_DEVICE_ONLY__ +#define CONSTANT __attribute__((opencl_constant)) +#else +#define CONSTANT +#endif diff --git a/include/common/utils/common.hpp b/include/common/utils/common.hpp index e67ccc829..52a4a61e3 100644 --- a/include/common/utils/common.hpp +++ b/include/common/utils/common.hpp @@ -46,7 +46,7 @@ constexpr uint32_t get_element_size_code() { enum class lsc_action : uint8_t { prefetch, load, store, atomic }; template -constexpr std::enable_if_t +constexpr std::enable_if_t check_lsc_cache_hint() { if constexpr (Action == lsc_action::prefetch) { // https://gfxspecs.intel.com/Predator/Home/Index/53560 @@ -126,7 +126,7 @@ get_prefetch_cache_hint_code() { } template -constexpr std::enable_if_t +constexpr std::enable_if_t get_store_cache_hint_code() { check_lsc_cache_hint(); if (L1H == cache_hint::none && L2H == cache_hint::none) { diff --git a/include/common/utils/raw_send_nbarrier.hpp b/include/common/utils/raw_send_nbarrier.hpp index fb7b92ee1..5050b9c51 100644 --- a/include/common/utils/raw_send_nbarrier.hpp +++ b/include/common/utils/raw_send_nbarrier.hpp @@ -107,14 +107,15 @@ struct xetla_nbarrier_t { /// @brief Generic work-group split barrier. /// __XETLA_API void arrive() { - __ESIMD_ENS::split_barrier<__ESIMD_ENS::split_barrier_action::signal>(); + // __ESIMD_ENS::split_barrier<__ESIMD_ENS::split_barrier_action::signal>(); + __ESIMD_NS::barrier(); } /// @brief named barrier wait within subgroup. /// __XETLA_API void wait() { - __ESIMD_ENS::split_barrier<__ESIMD_ENS::split_barrier_action::wait>(); - // __ESIMD_NS::barrier(); + // __ESIMD_ENS::split_barrier<__ESIMD_ENS::split_barrier_action::wait>(); + __ESIMD_NS::barrier(); } /// @brief named barrier signal from subgroup. diff --git a/include/group/gemm/compute_policy.hpp b/include/group/gemm/compute_policy.hpp index 43d6b6bef..bceb3f339 100644 --- a/include/group/gemm/compute_policy.hpp +++ b/include/group/gemm/compute_policy.hpp @@ -31,19 +31,20 @@ namespace gpu::xetla::group { /// @tparam perf_tuning_knob_ Is performance-related knobs. /// @tparam arch_tag_ Is the HW architecture. template + gpu_arch arch_tag_ = gpu_arch::Xe, typename enable = void> struct compute_policy_default_xmx {}; /// @brief Specialized for Xe architecture. -template -struct compute_policy_default_xmx { +template +struct compute_policy_default_xmx> { using compute_attr = compute_attr_; using perf_tuning_knob = perf_tuning_knob_; static constexpr int k_stride = perf_tuning_knob::k_stride; static constexpr int stages = perf_tuning_knob::stages; static constexpr int sync_freq = perf_tuning_knob::sync_freq; - static constexpr gpu_arch arch_tag = gpu_arch::Xe; + static constexpr gpu_arch arch_tag = arch_tag_; using dtype_mma_acc = typename compute_attr::dtype_acc; using dtype_mma_a = typename compute_attr::dtype_a; using dtype_mma_b = typename compute_attr::dtype_b; @@ -53,7 +54,8 @@ struct compute_policy_default_xmx> { + std::enable_if_t<(arch_tag_ <= gpu_arch::Xe)>> { public: using mem_desc_a_t = mem_desc_a_t_; using mem_desc_b_t = mem_desc_b_t_; diff --git a/include/kernel/gemm/default_gemm.hpp b/include/kernel/gemm/default_gemm.hpp index 455eeeaf0..0b0584062 100644 --- a/include/kernel/gemm/default_gemm.hpp +++ b/include/kernel/gemm/default_gemm.hpp @@ -1,18 +1,18 @@ /******************************************************************************* -* Copyright (c) 2022-2023 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ + * Copyright (c) 2022-2023 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ /// @file /// C++ API @@ -304,4 +304,4 @@ struct param_adaptor { using type = epilogue_t; }; -} // namespace gpu::xetla \ No newline at end of file +} // namespace gpu::xetla diff --git a/include/kernel/gemm/impl/default_xe.hpp b/include/kernel/gemm/impl/default_xe.hpp index 8f2da1cc3..93c949e09 100644 --- a/include/kernel/gemm/impl/default_xe.hpp +++ b/include/kernel/gemm/impl/default_xe.hpp @@ -35,7 +35,7 @@ namespace gpu::xetla::kernel { template class gemm_universal_t, gemm_t_, epilogue_t_, - std::enable_if_t<(group_swizzle_::arch_tag == gpu_arch::Xe)>> { + std::enable_if_t<(group_swizzle_::arch_tag <= gpu_arch::Xe)>> { using gemm_t = gemm_t_; using epilogue_t = epilogue_t_; using gemm_args_t = typename gemm_t::arguments_t; diff --git a/include/kernel/gemm/impl/kslicing_xe.hpp b/include/kernel/gemm/impl/kslicing_xe.hpp index 595b072ea..87dd795f8 100644 --- a/include/kernel/gemm/impl/kslicing_xe.hpp +++ b/include/kernel/gemm/impl/kslicing_xe.hpp @@ -39,7 +39,7 @@ template , gemm_t_, epilogue_t_, - std::enable_if_t<(group_swizzle_::arch_tag == gpu_arch::Xe)>> { + std::enable_if_t<(group_swizzle_::arch_tag <= gpu_arch::Xe)>> { using gemm_t = gemm_t_; using epilogue_t = epilogue_t_; using gemm_args_t = typename gemm_t::arguments_t; diff --git a/include/subgroup/tile/impl/payload_xe.hpp b/include/subgroup/tile/impl/payload_xe.hpp index 7a518899b..60fc4ee85 100644 --- a/include/subgroup/tile/impl/payload_xe.hpp +++ b/include/subgroup/tile/impl/payload_xe.hpp @@ -590,7 +590,7 @@ template , tile_desc_, msg_type::unaligned_2d, arch_tag_, - std::enable_if_t<(arch_tag_ == gpu_arch::Xe)>> { + std::enable_if_t<(arch_tag_ <= gpu_arch::Xe)>> { using dtype = dtype_; using mem_desc_t = mem_desc_t; @@ -1499,7 +1499,8 @@ struct prefetch_payload_t< tile_desc_t, num_coop_sg_, arch_tag_, - std::enable_if_t<(arch_tag_ == gpu_arch::Xe)>> { + std::enable_if_t<(arch_tag_ == gpu_arch::Xe) + && (tile_size_y_ != 1 || block_size_y_ != 1)>> { using dtype = dtype_; using mem_desc_t = mem_desc_t; @@ -1665,7 +1666,7 @@ struct prefetch_payload_t< mem_desc_t, tile_desc_t, num_coop_sg_, arch_tag_, - std::enable_if_t<(arch_tag_ == gpu_arch::Xe)>> { + std::enable_if_t<(arch_tag_ <= gpu_arch::Xe)>> { using dtype = dtype_; using mem_desc_t = mem_desc_t; @@ -1763,7 +1764,7 @@ template , tile_desc_, num_coop_sg_, arch_tag_, - std::enable_if_t<(arch_tag_ == gpu_arch::Xe)>> { + std::enable_if_t<(arch_tag_ <= gpu_arch::Xe)>> { using dtype = dtype_; using mem_desc_t = mem_desc_t; diff --git a/include/subgroup/tile/impl/prefetch_xe.hpp b/include/subgroup/tile/impl/prefetch_xe.hpp index 9205c68a3..a67400987 100644 --- a/include/subgroup/tile/impl/prefetch_xe.hpp +++ b/include/subgroup/tile/impl/prefetch_xe.hpp @@ -45,7 +45,7 @@ struct check_prefetch_type { static constexpr bool is_local_xe = ((payload_t::memory_space == mem_space::local) - && (payload_t::arch_tag == gpu_arch::Xe)); + && (payload_t::arch_tag <= gpu_arch::Xe)); }; } // namespace detail diff --git a/include/subgroup/tile/impl/tile_op_functor.hpp b/include/subgroup/tile/impl/tile_op_functor.hpp index 379bd062d..866aabc8d 100644 --- a/include/subgroup/tile/impl/tile_op_functor.hpp +++ b/include/subgroup/tile/impl/tile_op_functor.hpp @@ -177,7 +177,7 @@ struct gelu_fwd_w_op_t {}; /// @brief Is the element-wise gelu training forward op functor, specialized for Xe architecture. template struct gelu_fwd_w_op_t> { + std::enable_if_t<(arch_tag <= gpu_arch::Xe)>> { using dtype_out = dtype_out_; using mem_desc_w_t = mem_desc_t; @@ -295,7 +295,7 @@ struct gelu_bwd_op_t {}; /// @brief Is the element-wise gelu backward op functor, specialized for Xe architecture. template struct gelu_bwd_op_t> { + std::enable_if_t<(arch_tag <= gpu_arch::Xe)>> { using dtype_in = dtype_in_; using mem_desc_x_t = mem_desc_t; @@ -490,7 +490,7 @@ struct scale_v_offset_v_op_t {}; /// @brief Is the scale_v_offset_v op functor, specialized for Xe architecture. template struct scale_v_offset_v_op_t> { + std::enable_if_t<(arch_tag <= gpu_arch::Xe)>> { using scale_dtype = scale_dtype_; using offset_dtype = offset_dtype_; @@ -619,7 +619,7 @@ struct scale_v_op_t {}; /// @brief Is the scale_v op functor, specialized for Xe architecture. template struct scale_v_op_t> { + std::enable_if_t<(arch_tag <= gpu_arch::Xe)>> { using scale_dtype = scale_dtype_; using scale_mem_desc_t @@ -933,7 +933,7 @@ struct dropout_op_t {}; /// @brief Is the dropout op functor, specialized for Xe architecture. template struct dropout_op_t> { + std::enable_if_t<(arch_tag <= gpu_arch::Xe)>> { using dtype_mask = dtype_mask_; using mem_desc_mask_t = mem_desc_t; @@ -1010,7 +1010,7 @@ struct rng_dropout_op_t {}; /// @brief Is the random number generator and dropout op functor, specialized for Xe architecture. template struct rng_dropout_op_t> { + std::enable_if_t<(arch_tag <= gpu_arch::Xe)>> { using dtype_mask = dtype_mask_; using mem_desc_mask_t = mem_desc_t; @@ -1114,7 +1114,7 @@ struct scalar_mul_op_t {}; /// @brief Is the scalar_multiply op functor, specialized for Xe architecture. template struct scalar_mul_op_t> { + std::enable_if_t<(arch_tag <= gpu_arch::Xe)>> { using dtype_in = dtype_in_; using mem_desc_in_t = mem_desc_t; diff --git a/tests/integration/gemm/fp16/common.hpp b/tests/integration/gemm/fp16/common.hpp index 01da3e8a8..8a6ceaf16 100644 --- a/tests/integration/gemm/fp16/common.hpp +++ b/tests/integration/gemm/fp16/common.hpp @@ -41,6 +41,26 @@ class TestBase { static constexpr mma_engine engine = mma_engine::xmx; }; +class Test : public TestBase { +public: + static constexpr size_t mat_m = 256; + static constexpr size_t mat_n = 256; + static constexpr size_t mat_k = 256; + static constexpr size_t wg_m = 8; + static constexpr size_t wg_n = 32; + static constexpr size_t sg_m = 8; + static constexpr size_t sg_n = 16; + static constexpr size_t sg_k = 32; + static constexpr uint32_t global_kslicing = 1; + static constexpr uint32_t local_kslicing = 1; + static constexpr mem_layout layout_a = mem_layout::row_major; + static constexpr mem_layout layout_b = mem_layout::row_major; + using data_type_a = fp16; + using data_type_b = fp16; + using data_type_c = fp16; + using data_type_acc = float; +}; + class Test0 : public TestBase { public: static constexpr size_t mat_m = 256; diff --git a/tests/integration/gemm/fp16/kernel_func.hpp b/tests/integration/gemm/fp16/kernel_func.hpp index 98fdb9572..adaef295d 100644 --- a/tests/integration/gemm/fp16/kernel_func.hpp +++ b/tests/integration/gemm/fp16/kernel_func.hpp @@ -29,8 +29,8 @@ template struct fp16_gemm_test_func { using tile_shape = tile_shape_t; - static constexpr uint32_t periodic_sync_interval = 8; - static constexpr uint32_t prefetch_distance = 3; + static constexpr uint32_t periodic_sync_interval = 0; + static constexpr uint32_t prefetch_distance = 0; using compute_attr = typename std::conditional<(engine == mma_engine::fpu), compute_attr_t, @@ -40,9 +40,9 @@ struct fp16_gemm_test_func { using compute_policy = typename std::conditional<(engine == mma_engine::fpu), compute_policy_default_fpu, + gpu_arch::Dg2>, compute_policy_default_xmx>::type; + gpu_arch::Dg2>>::type; using mem_desc_input_a = mem_desc_t; using mem_desc_input_b = mem_desc_t; @@ -52,11 +52,11 @@ struct fp16_gemm_test_func { using gemm_t = gemm_t; - using epilogue_t = epilogue_t, + using epilogue_t = epilogue_t, tile_shape, mem_desc_output_c>; using group_swizzle - = gpu::xetla::kernel::group_swizzle_default; + = gpu::xetla::kernel::group_swizzle_default; using dispatch_policy = dispatch_policy_kslicing; diff --git a/tests/integration/gemm/fp16/main.cpp b/tests/integration/gemm/fp16/main.cpp index fdc579917..cc1c07b3a 100644 --- a/tests/integration/gemm/fp16/main.cpp +++ b/tests/integration/gemm/fp16/main.cpp @@ -28,8 +28,8 @@ template class fp16_gemm_test : public ::testing::Test {}; TYPED_TEST_SUITE_P(fp16_gemm_test); TYPED_TEST_P(fp16_gemm_test, esimd) { - gemm_exec, fp16_gemm_func>( - esimd_compile_string); + gemm_exec, fp16_gemm_func( + esimd_compile_string); } REGISTER_TYPED_TEST_SUITE_P(fp16_gemm_test, esimd); using tests = ::testing::Types nd_range({1}, {1}); auto result_validate = std::bind(tile_load_store_result_validate, _1, _2, _3, 128, 64, 32, 32, 0); - kernel_run>( - nd_range, result_validate); + kernel_run>(nd_range, result_validate); } TEST(tile_load_transpose_store_1, esimd) { @@ -266,8 +267,8 @@ TEST(tile_load_store_unaligned_2d, esimd) { auto result_validate = std::bind(tile_load_store_result_validate, _1, _2, _3, 127, 63, 32, 32, 0); kernel_run>(nd_range, result_validate); + tile_load_store_unaligned_2d_func>(nd_range, result_validate); } TEST(tile_load_store_oob_1, esimd) { From eafab9e3b4f4187b4cea131366aa19397de5a30a Mon Sep 17 00:00:00 2001 From: "Ding, Yi1" Date: Tue, 27 Feb 2024 08:39:58 +0000 Subject: [PATCH 02/11] update 2024.1 --- CMakeLists.txt | 1 + _clang-format | 122 ------------------ examples/01_gemm_universal/gemm_universal.cpp | 8 +- include/common/common.hpp | 6 - include/group/gemm/compute_policy.hpp | 12 +- include/kernel/gemm/default_gemm.hpp | 28 ++-- include/subgroup/tile/impl/payload_xe.hpp | 3 +- tests/integration/gemm/fp16/main.cpp | 6 +- 8 files changed, 31 insertions(+), 155 deletions(-) delete mode 100644 _clang-format diff --git a/CMakeLists.txt b/CMakeLists.txt index 2b46837d6..ebfee8f09 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -19,6 +19,7 @@ else() # Windows endif() project(XeTLA) + include(CTest) enable_testing() diff --git a/_clang-format b/_clang-format deleted file mode 100644 index eee0a4ee7..000000000 --- a/_clang-format +++ /dev/null @@ -1,122 +0,0 @@ -#=============================================================================== -# Copyright 2016-2019 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -#=============================================================================== - ---- -Language: Cpp -AccessModifierOffset: -4 -AlignAfterOpenBracket: DontAlign -AlignConsecutiveAssignments: false -AlignConsecutiveDeclarations: false -AlignEscapedNewlines: DontAlign -AlignOperands: false -AlignTrailingComments: false -AllowAllParametersOfDeclarationOnNextLine: true -AllowShortBlocksOnASingleLine: true -AllowShortCaseLabelsOnASingleLine: true -AllowShortFunctionsOnASingleLine: Inline -AllowShortIfStatementsOnASingleLine: true -AllowShortLoopsOnASingleLine: false -AlwaysBreakAfterDefinitionReturnType: None -AlwaysBreakAfterReturnType: None -AlwaysBreakBeforeMultilineStrings: true -AlwaysBreakTemplateDeclarations: Yes -BinPackArguments: true -BinPackParameters: true -BraceWrapping: - AfterClass: false - AfterControlStatement: false - AfterEnum: false - AfterFunction: false - AfterNamespace: false - AfterObjCDeclaration: false - AfterStruct: false - AfterUnion: false - AfterExternBlock: false - BeforeCatch: false - BeforeElse: false - IndentBraces: false - SplitEmptyFunction: true - SplitEmptyRecord: true - SplitEmptyNamespace: true -BreakBeforeBinaryOperators: All -BreakBeforeBraces: Custom -BreakBeforeInheritanceComma: false -BreakInheritanceList: BeforeColon -BreakBeforeTernaryOperators: true -BreakConstructorInitializers: BeforeComma -BreakAfterJavaFieldAnnotations: false -BreakStringLiterals: true -ColumnLimit: 80 -CommentPragmas: '^ IWYU pragma:' -CompactNamespaces: false -ConstructorInitializerAllOnOneLineOrOnePerLine: true -ConstructorInitializerIndentWidth: 4 -ContinuationIndentWidth: 8 -Cpp11BracedListStyle: true -DerivePointerAlignment: false -FixNamespaceComments: true -ForEachMacros: -IncludeBlocks: Preserve -IncludeCategories: - - Regex: '<[[:alnum:].]+>' - Priority: 0 -IncludeIsMainRegex: '(Test)?$' -IndentCaseLabels: true -# IndentPPDirectives: AfterHash -IndentPPDirectives: None -IndentWidth: 4 -IndentWrappedFunctionNames: false -KeepEmptyLinesAtTheStartOfBlocks: true -MacroBlockBegin: '' -MacroBlockEnd: '' -MaxEmptyLinesToKeep: 1 -NamespaceIndentation: None -PenaltyBreakAssignment: 2 -PenaltyBreakBeforeFirstCallParameter: 19 -PenaltyBreakComment: 300 -PenaltyBreakFirstLessLess: 120 -PenaltyBreakString: 1000 -PenaltyBreakTemplateDeclaration: 10 -PenaltyExcessCharacter: 1000000 -PenaltyReturnTypeOnItsOwnLine: 60 -PointerAlignment: Right -ReflowComments: false -SortIncludes: true -SortUsingDeclarations: true -SpaceAfterCStyleCast: false -SpaceAfterTemplateKeyword: true -SpaceBeforeAssignmentOperators: true -SpaceBeforeCpp11BracedList: true -SpaceBeforeCtorInitializerColon: true -SpaceBeforeInheritanceColon: true -SpaceBeforeParens: ControlStatements -SpaceBeforeRangeBasedForLoopColon: true -SpaceInEmptyParentheses: false -SpacesBeforeTrailingComments: 1 -SpacesInAngles: false -SpacesInContainerLiterals: false -SpacesInCStyleCastParentheses: false -SpacesInParentheses: false -SpacesInSquareBrackets: false -Standard: Cpp11 -StatementMacros: - - for_ - - PRAGMA_OMP - - PRAGMA_OMP_SIMD -TabWidth: 4 -UseTab: Never -... -# vim:ft=conf et ts=2 sw=2 diff --git a/examples/01_gemm_universal/gemm_universal.cpp b/examples/01_gemm_universal/gemm_universal.cpp index ed566cea3..5144c2e16 100644 --- a/examples/01_gemm_universal/gemm_universal.cpp +++ b/examples/01_gemm_universal/gemm_universal.cpp @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. *******************************************************************************/ -#include #include "xetla.hpp" +#include enum class kslicing_impl_t : uint8_t { none = 0, global = 1, local = 2 }; @@ -33,9 +33,9 @@ void gemm_universal_run(uint32_t iter) { size_t size_b = matrix_k * matrix_n; size_t size_c = matrix_m * matrix_n; - using data_type_a = fp16; - using data_type_b = fp16; - using data_type_c = fp16; + using data_type_a = bf16; + using data_type_b = bf16; + using data_type_c = bf16; using data_type_acc = float; //Turn on the profiling property to facilitate subsequent profiling diff --git a/include/common/common.hpp b/include/common/common.hpp index cccc09bbb..97c831c4b 100644 --- a/include/common/common.hpp +++ b/include/common/common.hpp @@ -21,9 +21,3 @@ #include #include - -#ifdef __SYCL_DEVICE_ONLY__ -#define CONSTANT __attribute__((opencl_constant)) -#else -#define CONSTANT -#endif diff --git a/include/group/gemm/compute_policy.hpp b/include/group/gemm/compute_policy.hpp index bceb3f339..d22c45829 100644 --- a/include/group/gemm/compute_policy.hpp +++ b/include/group/gemm/compute_policy.hpp @@ -105,19 +105,21 @@ struct compute_policy_unaligned_xmx + gpu_arch arch_tag_ = gpu_arch::Xe, typename enable = void> struct compute_policy_default_fpu {}; /// @brief Specialized for Xe architecture. -template -struct compute_policy_default_fpu { +template +struct compute_policy_default_fpu> { using compute_attr = compute_attr_; using perf_tuning_knob = perf_tuning_knob_; static constexpr int k_stride = perf_tuning_knob::k_stride; static constexpr int stages = perf_tuning_knob::stages; static constexpr int sync_freq = perf_tuning_knob::sync_freq; - static constexpr gpu_arch arch_tag = gpu_arch::Xe; + static constexpr gpu_arch arch_tag = arch_tag_; + using dtype_mma_acc = typename compute_attr::dtype_acc; using dtype_mma_a = typename compute_attr::dtype_a; using dtype_mma_b = typename compute_attr::dtype_b; diff --git a/include/kernel/gemm/default_gemm.hpp b/include/kernel/gemm/default_gemm.hpp index 0b0584062..63625f4c8 100644 --- a/include/kernel/gemm/default_gemm.hpp +++ b/include/kernel/gemm/default_gemm.hpp @@ -1,18 +1,18 @@ /******************************************************************************* - * Copyright (c) 2022-2023 Intel Corporation - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - *******************************************************************************/ +* Copyright (c) 2022-2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ /// @file /// C++ API diff --git a/include/subgroup/tile/impl/payload_xe.hpp b/include/subgroup/tile/impl/payload_xe.hpp index 60fc4ee85..302167448 100644 --- a/include/subgroup/tile/impl/payload_xe.hpp +++ b/include/subgroup/tile/impl/payload_xe.hpp @@ -1308,7 +1308,8 @@ struct prefetch_payload_t< tile_desc_t, num_coop_sg_, arch_tag_, - std::enable_if_t<(arch_tag_ == gpu_arch::Dg2)>> { + std::enable_if_t<(arch_tag_ == gpu_arch::Dg2 + && (tile_size_y_ != 1 || block_size_y_ != 1))>> { using dtype = dtype_; using mem_desc_t = mem_desc_t; diff --git a/tests/integration/gemm/fp16/main.cpp b/tests/integration/gemm/fp16/main.cpp index cc1c07b3a..6e52131dd 100644 --- a/tests/integration/gemm/fp16/main.cpp +++ b/tests/integration/gemm/fp16/main.cpp @@ -16,8 +16,8 @@ #include "common.hpp" #include "kernel_func.hpp" -#include #include +#include std::string esimd_compile_string = " -vc-codegen -doubleGRF " @@ -28,8 +28,8 @@ template class fp16_gemm_test : public ::testing::Test {}; TYPED_TEST_SUITE_P(fp16_gemm_test); TYPED_TEST_P(fp16_gemm_test, esimd) { - gemm_exec, fp16_gemm_func( - esimd_compile_string); + gemm_exec, fp16_gemm_func>( + esimd_compile_string); } REGISTER_TYPED_TEST_SUITE_P(fp16_gemm_test, esimd); using tests = ::testing::Types Date: Tue, 27 Feb 2024 08:44:35 +0000 Subject: [PATCH 03/11] using instead of deriving --- include/kernel/gemm/default_gemm.hpp | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/include/kernel/gemm/default_gemm.hpp b/include/kernel/gemm/default_gemm.hpp index 63625f4c8..502b28076 100644 --- a/include/kernel/gemm/default_gemm.hpp +++ b/include/kernel/gemm/default_gemm.hpp @@ -61,10 +61,9 @@ template > -struct default_gemm_t - : default_gemm_config_t::type {}; +using default_gemm_t = typename default_gemm_config_t::type; } // namespace kernel template @@ -158,11 +157,10 @@ template > -struct default_gemm_selector_t - : default_gemm_selector_config_t::type { -}; +using default_gemm_selector_t = typename default_gemm_selector_config_t::type; template > -struct default_epilogue_selector_t - : default_epilogue_selector_config_t::type {}; +using default_epilogue_selector_t = + typename default_epilogue_selector_config_t::type; } // namespace group template From f53a39ba48c3fab3a2c23cd983d2cf0081afe86b Mon Sep 17 00:00:00 2001 From: "Ding, Yi1" Date: Tue, 27 Feb 2024 08:55:06 +0000 Subject: [PATCH 04/11] reformat sdp --- .../scaled_dot_product_attention.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/08_scaled_dot_product_attention/scaled_dot_product_attention.cpp b/examples/08_scaled_dot_product_attention/scaled_dot_product_attention.cpp index 0060c299e..9c25fc9cc 100644 --- a/examples/08_scaled_dot_product_attention/scaled_dot_product_attention.cpp +++ b/examples/08_scaled_dot_product_attention/scaled_dot_product_attention.cpp @@ -232,8 +232,8 @@ void sdp_fwd_run(uint32_t iter) { cl::sycl::nd_range<3> nd_range(group_range * local_range, local_range); constexpr uint32_t warmup = 10; - int64_t ops = int64_t(4 * batch_num * head_num * sequence_len) * sequence_len - * head_size; + int64_t ops = int64_t(4 * batch_num * head_num * sequence_len) + * sequence_len * head_size; profiling_helper prof("sdp", ops, "gflops"); try { for (uint32_t i = 0; i < iter + warmup; i++) { From c96aef2989e8f80f99a429610e523859464367be Mon Sep 17 00:00:00 2001 From: "Ding, Yi1" Date: Fri, 1 Mar 2024 16:50:17 +0000 Subject: [PATCH 05/11] Enable dg2 sdp exmaple --- .editorconfig | 15 ++ CMakeLists.txt | 4 +- .../scaled_dot_product_attention.cpp | 198 ++++++++++-------- .../softmax.hpp | 12 +- examples/CMakeLists.txt | 7 +- include/common/utils/raw_send_nbarrier.hpp | 28 ++- include/group/gemm/compute_policy.hpp | 2 - include/kernel/default_config/common.hpp | 75 ++++--- .../default_config/decision_tree_policy.hpp | 72 +++---- .../kernel/default_config/dummy_policy.hpp | 4 +- include/kernel/gemm/default_gemm.hpp | 10 +- include/kernel/gemm/gemm_preset.hpp | 4 +- include/subgroup/tile/impl/payload_xe.hpp | 2 +- tests/integration/gemm/fp16/common.hpp | 22 +- tests/integration/gemm/fp16/kernel_func.hpp | 12 +- tests/utils/execution.hpp | 103 ++++++++- 16 files changed, 346 insertions(+), 224 deletions(-) create mode 100644 .editorconfig diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 000000000..d2f375aac --- /dev/null +++ b/.editorconfig @@ -0,0 +1,15 @@ +# EditorConfig is awesome: https://EditorConfig.org + +# top-most EditorConfig file +root = true + +# Unix-style newlines with a newline ending every file +[*] +end_of_line = lf +insert_final_newline = true +trim_trailing_whitespace = true + +# C/C++ follows clang-format +[*.{c,cpp,h,hpp}] +indent_style = space +indent_size = 4 diff --git a/CMakeLists.txt b/CMakeLists.txt index ebfee8f09..d305e5a65 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,14 +9,14 @@ if (NOT CMAKE_BUILD_TYPE) endif() if(UNIX) else() # Windows - # Force CMake to use icx-cl rather than the default C++ compiler/linker + # Force CMake to use icx-cl rather than the default C++ compiler/linker # (needed on Windows only) # include (CMakeForceCompiler) # CMAKE_FORCE_CXX_COMPILER (icx-cl IntelDPCPP) set(CMAKE_CXX_COMPILER icx-cl) include (Platform/Windows-Clang) include(cmake/GTestExternal.cmake) -endif() +endif() project(XeTLA) diff --git a/examples/08_scaled_dot_product_attention/scaled_dot_product_attention.cpp b/examples/08_scaled_dot_product_attention/scaled_dot_product_attention.cpp index 9c25fc9cc..420a854c4 100644 --- a/examples/08_scaled_dot_product_attention/scaled_dot_product_attention.cpp +++ b/examples/08_scaled_dot_product_attention/scaled_dot_product_attention.cpp @@ -135,10 +135,12 @@ int sdp_fwd_result_validate(dtype_in *q_device, dtype_in *k_device, return result ? 0 : 1; } -void sdp_fwd_run(uint32_t iter) { - // Tips, the example demonstrates programming kernel with XeTLA, it works as expected with current configurations. - // Please make sure you fully understand these configurations before you do any modifications, incomplete changes may lead to unexpected behaviors. - // Please contact us for support. +template +void sdp_fwd_run(uint32_t iter, uint32_t warmup = 10) { + // Tips, the example demonstrates programming kernel with XeTLA, it works as + // expected with current configurations. Please make sure you fully understand + // these configurations before you do any modifications, incomplete changes + // may lead to unexpected behaviors. Please contact us for support. using dtype_in = bf16; using dtype_out = bf16; @@ -150,9 +152,11 @@ void sdp_fwd_run(uint32_t iter) { constexpr uint32_t matrix_n_qk = sequence_len; constexpr uint32_t matrix_k_qk = head_size; - constexpr uint32_t wg_tile_m_qk = 64; - constexpr uint32_t wg_tile_n_qk = 512; - constexpr uint32_t sg_tile_m_qk = 32; + constexpr uint32_t wg_tile_m_qksv = arch_tag == gpu_arch::Xe ? 64 : 32; + + constexpr uint32_t wg_tile_m_qk = wg_tile_m_qksv; + constexpr uint32_t wg_tile_n_qk = 512; // must == sl_kv + constexpr uint32_t sg_tile_m_qk = arch_tag == gpu_arch::Xe ? 32 : 16; constexpr uint32_t sg_tile_n_qk = 32; constexpr uint32_t wg_tile_k_qk = 32; @@ -161,10 +165,11 @@ void sdp_fwd_run(uint32_t iter) { constexpr uint32_t matrix_n_sv = head_size; constexpr uint32_t matrix_k_sv = sequence_len; - constexpr uint32_t wg_tile_m_sv = 64; - constexpr uint32_t wg_tile_n_sv = 64; + // constexpr uint32_t wg_tile_m_sv = 64; + constexpr uint32_t wg_tile_m_sv = wg_tile_m_qksv; + constexpr uint32_t wg_tile_n_sv = 64; // must == head_dim constexpr uint32_t sg_tile_m_sv = 8; - constexpr uint32_t sg_tile_n_sv = 16; + constexpr uint32_t sg_tile_n_sv = arch_tag == gpu_arch::Xe ? 16 : 8; constexpr uint32_t wg_tile_k_sv = 32; // buffer size of softmax row data @@ -178,11 +183,12 @@ void sdp_fwd_run(uint32_t iter) { auto context = queue.get_info(); auto device = queue.get_info(); - std::cout << "Running on " << device.get_info() << "\n"; + print_device_details(device); constexpr uint32_t size_qkv = matrix_m_qk * matrix_k_qk; constexpr uint32_t size_mask = matrix_m_qk * matrix_n_qk; constexpr uint32_t size_out = matrix_m_sv * matrix_n_sv; + const float scale_qk = 1.f / std::sqrt(head_size); auto q = alloc_device_and_init( batch_cnt * size_qkv, @@ -220,6 +226,11 @@ void sdp_fwd_run(uint32_t iter) { constexpr uint32_t subgroup_range_m = wg_tile_m_qk / sg_tile_m_qk; constexpr uint32_t subgroup_range_n = wg_tile_n_qk / sg_tile_n_qk; + constexpr uint32_t slm_size + = wg_tile_m_qk * wg_tile_n_qk * sizeof(dtype_sfx); + XETLA_ASSERT(slm_size <= device.get_info(), + "SLM size too large!"); + static_assert(subgroup_range_m * subgroup_range_n == thread_num, "Given thread number should equal to pre-set value 32!"); std::cout << "group_num_x: " << group_range_n @@ -231,7 +242,6 @@ void sdp_fwd_run(uint32_t iter) { cl::sycl::range<3> local_range {1, subgroup_range_m, subgroup_range_n}; cl::sycl::nd_range<3> nd_range(group_range * local_range, local_range); - constexpr uint32_t warmup = 10; int64_t ops = int64_t(4 * batch_num * head_num * sequence_len) * sequence_len * head_size; profiling_helper prof("sdp", ops, "gflops"); @@ -239,14 +249,13 @@ void sdp_fwd_run(uint32_t iter) { for (uint32_t i = 0; i < iter + warmup; i++) { if (i >= warmup) { prof.cpu_start(); } auto gpu_event = queue.submit([&](handler &cgh) { - cgh.parallel_for< - class Test>(nd_range, [=](nd_item<3> item) KERNEL_MAIN { + cgh.parallel_for(nd_range, [=](nd_item<3> item) KERNEL_MAIN { using namespace gpu::xetla; using namespace gpu::xetla::group; using namespace gpu::xetla::kernel; using namespace gpu::xetla::subgroup; - uint32_t batch_id = item.get_group(0); + const uint32_t batch_id = item.get_group(0); // disable sync in gemm static constexpr uint32_t periodic_sync_interval = 0; static constexpr uint32_t prefetch_distance = 3; @@ -254,19 +263,23 @@ void sdp_fwd_run(uint32_t iter) { using wg_shape0 = shape; using sg_shape0 = shape; - using post_op0_t = scalar_mul_op_t; + using post_op0_t = scalar_mul_op_t; using post_op1_t = elemwise_reduce_op_t; + dtype_in, arch_tag>; using post_op_t = chained_tile_op_t; using epilogue_policy0 = xetla::group::epilogue_policy_tile_op; - using group_swizzle = group_swizzle_default; - - using tune_option0 = dict_t< - elem_v_t, + arch_tag>; + using group_swizzle = group_swizzle_default; + + using elem_opt_mode_t + = elem_v_t; + using elem_opt_type_t = elem_v_t< + tune_key::param_optimizer_type, + tune_key_value::param_optimizer_decision_tree>; + using tune_option0 = dict_t< // + elem_opt_type_t, elem_opt_mode_t, elem_t_t, elem_t_t, @@ -285,10 +298,10 @@ void sdp_fwd_run(uint32_t iter) { 8, // leading dimension for B, in unit of element mem_space:: global, // memory reading from global mem for B - float, // accumulator data type for intermediate resutls + float, // accumulator data type for intermediate results wg_shape0, // computation tile shape wg_tile_k_qk, // elements in each iteration - gpu_arch::Dg2, // GPU arch + arch_tag, // GPU arch tune_option0>; using epilogue0_t = xetla::group::default_epilogue_selector_t< dtype_sfx, // onput datatype for C @@ -298,7 +311,7 @@ void sdp_fwd_run(uint32_t iter) { local, // memory writing to local mem for C wg_shape0, // computation tile shape wg_tile_k_qk, // elements in each iteration - gpu_arch::Dg2, // GPU arch + arch_tag, // GPU arch tune_option0>; using gemm_op0_t = gemm_universal_t< dispatch_policy_default, gemm0_t, @@ -307,29 +320,27 @@ void sdp_fwd_run(uint32_t iter) { using tile_shape0 = typename gemm0_t::tile_shape; // initialize SLM size - constexpr uint32_t slm_size - = wg_tile_m_qk * wg_tile_n_qk * sizeof(dtype_sfx); xetla_local_init(); // initialize named barrier count // we only need to do thread sync while store gemm results to SLM // one barrier is enough for that xetla_nbarrier_init<1>(); - xetla_nbarrier_t - nbarrier; + xetla_nbarrier_t nbarrier; nbarrier.init_nbarrier(0, nbarrier_role::producer_consumer); // initialize gemm op: gemm result store to shared local memory - typename post_op0_t::arguments_t post_op0_arg(0.125); + typename post_op0_t::arguments_t post_op0_arg(scale_qk); typename post_op1_t::arguments_t post_op1_arg( + // attn_mask pre-load ptr batch offset attn_mask + batch_id / head_num * size_mask + wg_tile_m_qk * wg_tile_n_qk - * item.get_group( - 1), // attn_mask pre-load ptr batch offset - {matrix_n_qk, // attn_mask tdesc width + * item.get_group(1), + { + matrix_n_qk, // attn_mask tdesc width matrix_m_qk, // attn_mask tdesc height - matrix_n_qk} // attn_mask tdesc pitch - ); + matrix_n_qk, // attn_mask tdesc pitch + }); typename gemm_op0_t::arguments_t arg0(matrix_m_qk, matrix_k_qk, matrix_n_qk, q + batch_id * size_qkv, // matA_ptr + batch offset @@ -339,22 +350,20 @@ void sdp_fwd_run(uint32_t iter) { 0, // matC_base matrix_n_qk, // matC load width {{post_op0_arg, post_op1_arg}}); - gemm_op0_t gemm_op0; - gemm_op0(item, arg0); + gemm_op0_t {}(item, arg0); xetla_fence(); nbarrier.arrive_wait(); // softmax start: result store to SLM using softmax_op_t = xetla_softmax_fwd_t; + mem_space::local, SIMD, thread_num, softmax_sz, + arch_tag>; typename softmax_op_t::arguments_t arg1; - softmax_op_t softmax_op; - arg1.data_in_base = 0; arg1.data_out_base = 0; - softmax_op(item, &arg1); + softmax_op_t {}(item, &arg1); xetla_fence(); nbarrier.arrive_wait(); @@ -362,10 +371,8 @@ void sdp_fwd_run(uint32_t iter) { using wg_shape1 = shape; using sg_shape1 = shape; - using tune_option1 = dict_t< - elem_v_t, + using tune_option1 = dict_t< // + elem_opt_type_t, elem_opt_mode_t, elem_t_t, elem_v_t, @@ -383,10 +390,10 @@ void sdp_fwd_run(uint32_t iter) { 8, // leading dimension for B, in unit of element mem_space:: global, // memory reading from global mem for B - float, // accumulator data type for intermediate resutls + float, // accumulator data type for intermediate results wg_shape1, // computation tile shape wg_tile_k_sv, // elements in each iteration - gpu_arch::Dg2, // GPU arch + arch_tag, // GPU arch tune_option1>; // gemm arguments include matA & matB load information and @@ -395,6 +402,8 @@ void sdp_fwd_run(uint32_t iter) { using work_group_t = typename gemm1_t::work_group_t; using mem_desc_a_t = typename gemm1_t::mem_desc_a_t; using mem_desc_b_t = typename gemm1_t::mem_desc_b_t; + using mem_desc_c_t = mem_desc_t; // Using gemm::matAcc init a matC class for future storage using matAcc_t = typename gemm1_t::matAcc_t; using matC_t = tile_t matrix_n - ? matrix_n - : (start_n + wg_tile_n_sv); + uint32_t boundary_n + = std::min(start_n + wg_tile_n_sv, matrix_n); uint32_t boundary_k = wg_tile_k; work_group_t g; @@ -431,42 +439,45 @@ void sdp_fwd_run(uint32_t iter) { mem_desc_b.init(matB_ptr, {boundary_n, boundary_k, matB_ld}, {start_n, start_k}); - uint32_t inner_loop_count + uint32_t sg_k_count = (wg_tile_k + wg_tile_k_sv - 1) / wg_tile_k_sv; - gemm_args_t gemm_args( - mem_desc_a, mem_desc_b, inner_loop_count); + gemm_args_t gemm_args(mem_desc_a, mem_desc_b, sg_k_count); matAcc_t matAcc; - matC_t matC; - gemm1_t gemm; matAcc.init(0); - gemm(g, matAcc, gemm_args); + gemm1_t {}(g, matAcc, gemm_args); + // permute store + matC_t matC; subgroup::elemwise_cvt(matC, matAcc); - xetla_tdescriptor transpose_tdecs; - // Define a temprary vector as output buffer - xetla_vector out_reg; // Calculate new coordination of each element - uint32_t b = item.get_group(0) / head_num; - uint32_t n = item.get_group(0) % head_num; - uint32_t f = start_m + gemm1_t::get_matC_offset_y(g); - uint32_t h = start_n + gemm1_t::get_matC_offset_x(g); - - // transpose 8 * 16 tile and store to global - for (uint32_t j = 0; j < sg_tile_m_sv; ++j, ++f) { - uint32_t dst_offset - = b * head_num * sequence_len * head_size - + f * head_num * head_size + n * head_size; - out_reg = matC.reg.xetla_select( - j * sg_tile_n_sv); - xetla_fill_tdesc( - transpose_tdecs.xetla_format(), - out + dst_offset, head_size, 1, head_size, h, - 0); - xetla_tstore_global(transpose_tdecs, out_reg); - } + const uint32_t b = batch_id / head_num; + const uint32_t n = batch_id % head_num; + const uint32_t batch_offset + = b * head_num * sequence_len * head_size + + start_m * head_num * head_size + n * head_size + + start_n; + const uint32_t f = gemm1_t::get_matC_offset_y(g); + const uint32_t h = gemm1_t::get_matC_offset_x(g); + + const auto ld_c = head_num * head_size; + mem_desc_c_t mem_desc_c; + mem_desc_c.init( + out + batch_offset, // dst_base = out_ptr + wg offset + { + std::min(h + sg_tile_n_sv, wg_tile_n_sv), + std::min(f + sg_tile_m_sv, wg_tile_m_sv), + ld_c, + }, + {int(h), int(f)}); + + constexpr auto msg_type_c = msg_type::block_2d; + using mat_tile_desc = typename matC_t::tile_desc; + using matC_payload_t = subgroup::mem_payload_t; + matC_payload_t matC_payload(mem_desc_c); + subgroup::tile_store(matC, matC_payload); }); }); gpu_event.wait(); @@ -488,7 +499,7 @@ void sdp_fwd_run(uint32_t iter) { mem_layout::col_major, mem_layout::row_major, mem_layout::row_major)); - //performance + // performance prof.print_profiling_result(profiling_selector::GPU); free(q, context); @@ -498,28 +509,41 @@ void sdp_fwd_run(uint32_t iter) { free(out, context); } +template +struct main_wrapper { + static constexpr auto exec = []() { + if constexpr (arch_tag == gpu_arch::Dg2) { + sdp_fwd_run(10); + } else { + sdp_fwd_run(10); + } + }; +}; + int main() { // This example implements scaled-dot-production with batch_size: 16, - // num_heads: 16, sequence_lenth: 512, head_size: 64. It will be shown how to + // num_heads: 16, sequence_length: 512, head_size: 64. It will be shown how to // remap the index space of each work-item used for gemm1, softmax and gemm2. // Description: - // Scaled-dot-production mechanism can be seen as two chained batch MatMul with - // a softmax in the middle layer. It can be descripted as following + // Scaled-dot-production mechanism can be seen as two chained batch MatMul + // with a softmax in the middle layer. It can be described as following // mathematical expression: - // softmax(Q · (K.transpose(-1, -2)) * (1 / sqr_root(num_heads)) + attn_mask) · V + // softmax(Q · (K.transpose(-1, -2)) * (1 / sqr_root(num_heads)) + + // attn_mask) · V // where: // Q, K, V: input data // shape(Q) = [16 x 16, 512, 64] // shape(K) = [16 x 16, 512, 64] // shape(V) = [16 x 16, 512, 64] // shape(attn_mask) = [16, 512, 512] + // shape(DST) = [16, 512, 16, 64] // This kernel is designed to execute the following task: // 1: S = (Q · (K.transpose(-1, -2))) * (1 / sqr_root(num_heads)) + attn_mask // 2: S' = softmax(S) // 3: O = S' · V - sdp_fwd_run(10); + dispatch_arch::exec(); return 0; } diff --git a/examples/08_scaled_dot_product_attention/softmax.hpp b/examples/08_scaled_dot_product_attention/softmax.hpp index 58fb1c688..0fc04b8aa 100644 --- a/examples/08_scaled_dot_product_attention/softmax.hpp +++ b/examples/08_scaled_dot_product_attention/softmax.hpp @@ -24,7 +24,7 @@ using namespace gpu::xetla::subgroup; template + uint32_t thread_num_, uint32_t softmax_size_, gpu_arch arch_tag> struct xetla_softmax_fwd_t { using dtype_in = dtype_in_; using dtype_out = dtype_out_; @@ -56,16 +56,14 @@ struct xetla_softmax_fwd_t { using softmax_load_payload_t = subgroup::mem_payload_t< mem_desc_t, softmax_tile_desc_t, - subgroup::msg_type_v, - gpu_arch::Dg2>; + subgroup::msg_type_v, arch_tag>; // this tile will store the softmax result to global memory using softmax_store_t = subgroup::tile_t; using softmax_store_payload_t = subgroup::mem_payload_t< mem_desc_t, softmax_tile_desc_t, - subgroup::msg_type_v, - gpu_arch::Dg2>; + subgroup::msg_type_v, arch_tag>; struct arguments_t { // available while original data is from SLM @@ -113,10 +111,10 @@ struct xetla_softmax_fwd_t { row_data_32 = softmax_load.reg.xetla_select(0); // get max - float xmax = hmax(row_data_32); + float x_max = hmax(row_data_32); // get exp_sum - row_data_32 -= xmax; + row_data_32 -= x_max; row_data_32 = exp(row_data_32); float exp_sum = sum(row_data_32); diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 1f50f5d7d..193696628 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -1,6 +1,11 @@ include_directories(${CMAKE_SOURCE_DIR}/include) include_directories(${CMAKE_SOURCE_DIR}) +# Creates a separate device code module for each SYCL* kernel +# so that kernel for Dg2 and Xe will be JIT separately +add_compile_options(-fsycl-device-code-split=per_kernel) +add_link_options(-fsycl-device-code-split=per_kernel) + add_subdirectory(01_gemm_universal) add_subdirectory(02_basic_gemm) add_subdirectory(03_gemm_relu_bias) @@ -13,4 +18,4 @@ add_subdirectory(09_gate_recurrent_unit) add_subdirectory(10_gemm_large_n) if(UNIX) # pvc not available on win? add_subdirectory(11_stream_k_gemm) -endif() \ No newline at end of file +endif() diff --git a/include/common/utils/raw_send_nbarrier.hpp b/include/common/utils/raw_send_nbarrier.hpp index 5050b9c51..7bde822b0 100644 --- a/include/common/utils/raw_send_nbarrier.hpp +++ b/include/common/utils/raw_send_nbarrier.hpp @@ -41,8 +41,12 @@ enum class nbarrier_role : uint8_t { /// as consumer. /// template -struct xetla_nbarrier_t { + gpu_arch arch_tag = gpu_arch::Xe, typename enable = void> +struct xetla_nbarrier_t; + +template +struct xetla_nbarrier_t> { /// /// @brief Description of named barrier objection. /// Structure is defined in @@ -87,8 +91,9 @@ struct xetla_nbarrier_t { } }; -template -struct xetla_nbarrier_t { +template +struct xetla_nbarrier_t> { /// /// @brief Description of named barrier objection. /// Structure is defined in @@ -106,24 +111,15 @@ struct xetla_nbarrier_t { /// @brief Generic work-group split barrier. /// - __XETLA_API void arrive() { - // __ESIMD_ENS::split_barrier<__ESIMD_ENS::split_barrier_action::signal>(); - __ESIMD_NS::barrier(); - } + __XETLA_API void arrive() { __ESIMD_NS::barrier(); } /// @brief named barrier wait within subgroup. /// - __XETLA_API void wait() { - // __ESIMD_ENS::split_barrier<__ESIMD_ENS::split_barrier_action::wait>(); - __ESIMD_NS::barrier(); - } + __XETLA_API void wait() { __ESIMD_NS::barrier(); } /// @brief named barrier signal from subgroup. /// - __XETLA_API void arrive_wait() { - arrive(); - wait(); - } + __XETLA_API void arrive_wait() { __ESIMD_NS::barrier(); } }; /// @} xetla_util_named_barrier diff --git a/include/group/gemm/compute_policy.hpp b/include/group/gemm/compute_policy.hpp index d22c45829..aa4343a28 100644 --- a/include/group/gemm/compute_policy.hpp +++ b/include/group/gemm/compute_policy.hpp @@ -55,7 +55,6 @@ struct compute_policy_default_xmx +using data_type_a_t = + typename T::template find_elem_t::type; +template +using data_type_b_t = + typename T::template find_elem_t::type; +template +using data_type_c_t = + typename T::template find_elem_t::type; +template +constexpr auto memory_layout_a_v + = T::template find_elem_v; +template +constexpr auto memory_alignment_a_v + = T::template find_elem_v; +template +constexpr auto memory_layout_b_v + = T::template find_elem_v; +template +constexpr auto memory_alignment_b_v + = T::template find_elem_v; +template +constexpr auto memory_layout_c_v + = T::template find_elem_v; +template +constexpr auto memory_alignment_c_v + = T::template find_elem_v; +template +constexpr auto gpu_arch_v = T::template find_elem_v; enum class tune_key_value : uint8_t { pre_processing_default, @@ -68,45 +98,24 @@ enum class tune_key_value : uint8_t { // parameter optimizer enum class param_optimizer_tag : uint8_t { kernel, work_group }; +enum class param_optimizer_mode : uint8_t { full, keep_shape }; template struct param_optimizer; struct param_optimizer_base { template - struct validate_attribute { - static constexpr bool value = []() constexpr { - bool valid = true; - valid &= std::is_same::type, - typename U::template find_elem_t< - tune_key::data_type_a>::type>::value; - valid &= T::template find_elem_v == U::template find_elem_v; - valid &= T::template find_elem_v == U::template find_elem_v; - valid &= std::is_same::type, - typename U::template find_elem_t< - tune_key::data_type_b>::type>::value; - valid &= T::template find_elem_v == U::template find_elem_v; - valid &= T::template find_elem_v == U::template find_elem_v; - valid &= std::is_same::type, - typename U::template find_elem_t< - tune_key::data_type_c>::type>::value; - valid &= T::template find_elem_v == U::template find_elem_v; - valid &= T::template find_elem_v == U::template find_elem_v; - valid &= T::template find_elem_v == U::template find_elem_v; - return valid; - } - (); - }; + static constexpr bool valid_attribute_v + = std::is_same_v, data_type_a_t> // + && memory_layout_a_v == memory_layout_a_v // + && memory_alignment_a_v == memory_alignment_a_v // + && std::is_same_v, data_type_b_t> // + && memory_layout_b_v == memory_layout_b_v // + && memory_alignment_b_v == memory_alignment_b_v // + && std::is_same_v, data_type_c_t> // + && memory_layout_c_v == memory_layout_c_v // + && memory_alignment_c_v == memory_alignment_c_v // + && gpu_arch_v == gpu_arch_v; }; // parameter adaptor diff --git a/include/kernel/default_config/decision_tree_policy.hpp b/include/kernel/default_config/decision_tree_policy.hpp index c8b0b3c21..f9d89fbd5 100644 --- a/include/kernel/default_config/decision_tree_policy.hpp +++ b/include/kernel/default_config/decision_tree_policy.hpp @@ -264,53 +264,47 @@ struct kslicing_handler { }; } // namespace decision_tree_rule -template +template struct fallback_optimizer { - using type = typename opt_dict_t_::template update_t< - elem_t_t::type>, - elem_t_t::type>, - elem_t_t::type>, - elem_v_t>, - elem_v_t>, - elem_v_t>, - elem_v_t>, - elem_v_t>, - elem_v_t>, - elem_v_t>>; + using type = typename opt_dict::template update_t< + elem_t_t>, + elem_t_t>, + elem_t_t>, + elem_v_t>, + elem_v_t>, + elem_v_t>, + elem_v_t>, + elem_v_t>, + elem_v_t>, + elem_v_t>>; }; -template +template struct decision_tree_optimizer : param_optimizer_base { struct impl { - using type = typename dict_t_ ::template update_generator_t< - decision_tree_rule::data_type_handler>:: - template update_generator_t< - decision_tree_rule::tile_shape_handler>:: - template update_generator_t< - decision_tree_rule::kslicing_handler>; + template typename G> + using apply_handeler = T::template update_generator_t; + static constexpr bool keep_shape + = (mode_ == param_optimizer_mode::keep_shape); + + using t0 = dict_t_; + using t1 = apply_handeler; + using t2_0 = apply_handeler; + using t2 = std::conditional_t; + using t3 = apply_handeler; + + using type = t3; + + // If any of data_type / mem_layout / mem_align is changed, + // then change it back via fallback_optimizer using fallback_type = fallback_optimizer; }; static constexpr bool use_fallback - = !(param_optimizer_base::template validate_attribute::value); - using type = typename std::conditional::type::type; + = !(param_optimizer_base::template valid_attribute_v); + using type = typename std::conditional_t::type; }; } // namespace gpu::xetla diff --git a/include/kernel/default_config/dummy_policy.hpp b/include/kernel/default_config/dummy_policy.hpp index 7bed9f2c9..b18c882e7 100644 --- a/include/kernel/default_config/dummy_policy.hpp +++ b/include/kernel/default_config/dummy_policy.hpp @@ -255,8 +255,8 @@ struct dummy_optimizer : param_optimizer_base { using fallback_type = fallback_optimizer; }; static constexpr bool use_fallback - = !(param_optimizer_base::template validate_attribute::value); + = !(param_optimizer_base::template valid_attribute_v); using type = typename std::conditional::type::type; }; diff --git a/include/kernel/gemm/default_gemm.hpp b/include/kernel/gemm/default_gemm.hpp index 502b28076..ffea4a0b4 100644 --- a/include/kernel/gemm/default_gemm.hpp +++ b/include/kernel/gemm/default_gemm.hpp @@ -73,8 +73,11 @@ struct param_optimizer { param_optimizer_type> != dict_t_::impl::key_not_found) && (dict_t_::template find_elem_v == tune_key_value::param_optimizer_decision_tree); + static constexpr auto optimizer_mode + = dict_t_::template find_elem_v; using type = typename std::conditional, + decision_tree_optimizer, dummy_optimizer { param_optimizer_type> != dict_t_::impl::key_not_found) && (dict_t_::template find_elem_v == tune_key_value::param_optimizer_decision_tree); + static constexpr auto optimizer_mode + = dict_t_::template find_elem_v; using type = typename std::conditional, + decision_tree_optimizer, dummy_optimizer>::type::type; }; diff --git a/include/kernel/gemm/gemm_preset.hpp b/include/kernel/gemm/gemm_preset.hpp index fde7f21ab..d62d8c2b0 100644 --- a/include/kernel/gemm/gemm_preset.hpp +++ b/include/kernel/gemm/gemm_preset.hpp @@ -71,7 +71,9 @@ using default_param_t = dict_t<>::template update_dict_t< elem_t_t>, elem_t_t>, elem_v_t>; + tune_key_value::param_optimizer_dummy>, + elem_v_t>; namespace kernel { using param_kslicing_g1l1_t = default_param_t::template update_t< diff --git a/include/subgroup/tile/impl/payload_xe.hpp b/include/subgroup/tile/impl/payload_xe.hpp index 302167448..bd4934a1f 100644 --- a/include/subgroup/tile/impl/payload_xe.hpp +++ b/include/subgroup/tile/impl/payload_xe.hpp @@ -1308,7 +1308,7 @@ struct prefetch_payload_t< tile_desc_t, num_coop_sg_, arch_tag_, - std::enable_if_t<(arch_tag_ == gpu_arch::Dg2 + std::enable_if_t<(arch_tag_ <= gpu_arch::Dg2 && (tile_size_y_ != 1 || block_size_y_ != 1))>> { using dtype = dtype_; using mem_desc_t diff --git a/tests/integration/gemm/fp16/common.hpp b/tests/integration/gemm/fp16/common.hpp index 8a6ceaf16..10c859631 100644 --- a/tests/integration/gemm/fp16/common.hpp +++ b/tests/integration/gemm/fp16/common.hpp @@ -17,9 +17,9 @@ #pragma once #include "kernel_func.hpp" +#include #include #include -#include class TestBase { public: @@ -41,26 +41,6 @@ class TestBase { static constexpr mma_engine engine = mma_engine::xmx; }; -class Test : public TestBase { -public: - static constexpr size_t mat_m = 256; - static constexpr size_t mat_n = 256; - static constexpr size_t mat_k = 256; - static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 32; - static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; - static constexpr size_t sg_k = 32; - static constexpr uint32_t global_kslicing = 1; - static constexpr uint32_t local_kslicing = 1; - static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = fp16; - using data_type_c = fp16; - using data_type_acc = float; -}; - class Test0 : public TestBase { public: static constexpr size_t mat_m = 256; diff --git a/tests/integration/gemm/fp16/kernel_func.hpp b/tests/integration/gemm/fp16/kernel_func.hpp index adaef295d..98fdb9572 100644 --- a/tests/integration/gemm/fp16/kernel_func.hpp +++ b/tests/integration/gemm/fp16/kernel_func.hpp @@ -29,8 +29,8 @@ template struct fp16_gemm_test_func { using tile_shape = tile_shape_t; - static constexpr uint32_t periodic_sync_interval = 0; - static constexpr uint32_t prefetch_distance = 0; + static constexpr uint32_t periodic_sync_interval = 8; + static constexpr uint32_t prefetch_distance = 3; using compute_attr = typename std::conditional<(engine == mma_engine::fpu), compute_attr_t, @@ -40,9 +40,9 @@ struct fp16_gemm_test_func { using compute_policy = typename std::conditional<(engine == mma_engine::fpu), compute_policy_default_fpu, + gpu_arch::Xe>, compute_policy_default_xmx>::type; + gpu_arch::Xe>>::type; using mem_desc_input_a = mem_desc_t; using mem_desc_input_b = mem_desc_t; @@ -52,11 +52,11 @@ struct fp16_gemm_test_func { using gemm_t = gemm_t; - using epilogue_t = epilogue_t, + using epilogue_t = epilogue_t, tile_shape, mem_desc_output_c>; using group_swizzle - = gpu::xetla::kernel::group_swizzle_default; + = gpu::xetla::kernel::group_swizzle_default; using dispatch_policy = dispatch_policy_kslicing; diff --git a/tests/utils/execution.hpp b/tests/utils/execution.hpp index 3d85114da..66519472e 100644 --- a/tests/utils/execution.hpp +++ b/tests/utils/execution.hpp @@ -16,6 +16,7 @@ #pragma once +#include #include "common.hpp" #include "profiling.hpp" #include "xetla.hpp" @@ -89,11 +90,13 @@ void gemm_exec(const std::string &compile_str, size_t batch = 1) { std::vector kernelId = {get_kernel_id()}; auto inputBundle = get_kernel_bundle(context, kernelId); - static const std::string env_set_str = "SYCL_PROGRAM_COMPILE_OPTIONS="+compile_str; - putenv(const_cast(env_set_str.c_str())); + static const std::string env_set_str + = "SYCL_PROGRAM_COMPILE_OPTIONS=" + compile_str; + putenv(const_cast(env_set_str.c_str())); kernel_bundle exeBundle = build(inputBundle); - static const std::string env_unset_str = "SYCL_PROGRAM_COMPILE_OPTIONS="; - putenv(const_cast(env_unset_str.c_str())); + static const std::string env_unset_str + = "SYCL_PROGRAM_COMPILE_OPTIONS="; + putenv(const_cast(env_unset_str.c_str())); using namespace gpu::xetla::group; using namespace gpu::xetla::kernel; @@ -227,3 +230,95 @@ void kernel_run(auto nd_range, auto validate_result) { free(B_host); free(C_host); } + +template