diff --git a/BUILD.bazel b/BUILD.bazel index e8a1bbfe..7edb6fa8 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -27,13 +27,6 @@ exports_files([ ".github/workflows/build.yml", ]) -# To enable OneDNN BRGeMM support, build with: -# bazel build --define gemma_onednn_brgemm=1 ... -config_setting( - name = "gemma_onednn_brgemm", - define_values = {"gemma_onednn_brgemm": "1"}, -) - cc_library( name = "basics", srcs = ["util/basics.cc"], @@ -321,17 +314,7 @@ test_suite( cc_library( name = "matmul_env", srcs = ["ops/matmul.cc"], - hdrs = [ - "ops/brgemm.h", - "ops/matmul.h", - ], - defines = select({ - ":gemma_onednn_brgemm": [ - "GEMMA_ONEDNN_BRGEMM=1", - "DNNL_EXPERIMENTAL_UKERNEL", - ], - "//conditions:default": [], - }), + hdrs = ["ops/matmul.h"], deps = [ ":allocator", ":basics", @@ -342,20 +325,14 @@ cc_library( "@highway//:hwy", "@highway//:nanobenchmark", "@highway//:profiler", - ] + select({ - ":gemma_onednn_brgemm": ["@onednn"], - "//conditions:default": [], - }), + ], ) cc_library( name = "matmul", # allow depending only on this target, without also matmul_env. hdrs = ["ops/matmul.h"], - textual_hdrs = [ - "ops/brgemm-inl.h", - "ops/matmul-inl.h", - ], + textual_hdrs = ["ops/matmul-inl.h"], deps = [ ":allocator", ":basics", @@ -369,10 +346,7 @@ cc_library( "@highway//:hwy", "@highway//:nanobenchmark", "@highway//:profiler", - ] + select({ - ":gemma_onednn_brgemm": ["@onednn"], - "//conditions:default": [], - }), + ], ) cc_library( @@ -389,7 +363,6 @@ cc_library( "ops/matmul_static.h", ], textual_hdrs = [ - "ops/brgemm-inl.h", "ops/matmul_static-inl.h", "ops/matmul-inl.h", ], @@ -406,10 +379,7 @@ cc_library( "@highway//:hwy", "@highway//:profiler", "@highway//:timer", - ] + select({ - ":gemma_onednn_brgemm": ["@onednn"], - "//conditions:default": [], - }), + ], ) cc_library( diff --git a/CMakeLists.txt b/CMakeLists.txt index 6791fbcc..a68a3e13 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,10 +22,6 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -# Optional: OneDNN BRGeMM micro-kernel support (x86-64 only). -# Enable with: cmake -DGEMMA_ONEDNN_BRGEMM=ON ... -option(GEMMA_ONEDNN_BRGEMM "Enable OneDNN BRGeMM micro-kernel for MatMul (x86-64)" OFF) - if(EMSCRIPTEN) add_compile_options("-sMEMORY64") add_compile_options("-msimd128") @@ -89,23 +85,6 @@ if(EMSCRIPTEN) target_compile_options(benchmark PRIVATE -Wno-c2y-extensions) endif() -# OneDNN BRGeMM micro-kernel support (optional, x86-64 only). -if(GEMMA_ONEDNN_BRGEMM) - set(DNNL_BUILD_TESTS OFF CACHE BOOL "" FORCE) - set(DNNL_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE) - set(DNNL_CPU_RUNTIME "SEQ" CACHE STRING "" FORCE) - set(DNNL_GPU_RUNTIME "NONE" CACHE STRING "" FORCE) - set(DNNL_LIBRARY_TYPE "STATIC" CACHE STRING "" FORCE) - set(DNNL_EXPERIMENTAL_UKERNEL ON CACHE BOOL "" FORCE) - FetchContent_Declare(onednn - GIT_REPOSITORY https://github.com/uxlfoundation/oneDNN.git - GIT_TAG v3.11 - EXCLUDE_FROM_ALL - ) - FetchContent_MakeAvailable(onednn) - message(STATUS "OneDNN BRGeMM micro-kernel support enabled") -endif() - # Base source files set(SOURCES compression/compress-inl.h @@ -164,8 +143,6 @@ set(SOURCES ops/matmul-inl.h ops/matmul.cc ops/matmul.h - ops/brgemm.h - ops/brgemm-inl.h ops/ops-inl.h ops/ops.h ops/sum-inl.h @@ -218,10 +195,6 @@ target_link_libraries(libgemma hwy hwy_contrib sentencepiece-static) target_include_directories(libgemma PUBLIC ${sentencepiece_SOURCE_DIR}) target_compile_definitions(libgemma PRIVATE $<$:_CRT_SECURE_NO_WARNINGS NOMINMAX>) target_compile_options(libgemma PRIVATE $<$:-Wno-deprecated-declarations>) -if(GEMMA_ONEDNN_BRGEMM) - target_compile_definitions(libgemma PUBLIC GEMMA_ONEDNN_BRGEMM=1 DNNL_EXPERIMENTAL_UKERNEL) - target_link_libraries(libgemma dnnl) -endif() install(TARGETS libgemma DESTINATION lib) # Shared library target for C# interop @@ -246,10 +219,6 @@ target_compile_definitions(gemma_shared $<$:_CRT_SECURE_NO_WARNINGS NOMINMAX> ) target_compile_options(gemma_shared PRIVATE $<$:-Wno-deprecated-declarations>) -if(GEMMA_ONEDNN_BRGEMM) - target_compile_definitions(gemma_shared PUBLIC GEMMA_ONEDNN_BRGEMM=1 DNNL_EXPERIMENTAL_UKERNEL) - target_link_libraries(gemma_shared PRIVATE dnnl) -endif() install(TARGETS gemma_shared DESTINATION lib) install(FILES gemma/c_api.h DESTINATION include/gemma) install(FILES gemma/GemmaInterop.cs DESTINATION include/gemma) diff --git a/MODULE.bazel b/MODULE.bazel index e44c370b..0dea7752 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -25,17 +25,6 @@ git_override( http_archive = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") -# OneDNN v3.11 for BRGeMM micro-kernel support (optional, x86-64 only). -http_archive( - name = "onednn", - build_file = "@//bazel:onednn.BUILD", - sha256 = "04df98b18300daf6c3aa7cc2d5e7ce8a8f430fed1787151daed0254d8dd4e64e", - strip_prefix = "oneDNN-3.11", - urls = [ - "https://github.com/uxlfoundation/oneDNN/archive/refs/tags/v3.11.tar.gz", - ], -) - http_archive( name = "com_google_absl_py", sha256 = "8a3d0830e4eb4f66c4fa907c06edf6ce1c719ced811a12e26d9d3162f8471758", diff --git a/bazel/onednn.BUILD b/bazel/onednn.BUILD deleted file mode 100644 index 0cbd436d..00000000 --- a/bazel/onednn.BUILD +++ /dev/null @@ -1,227 +0,0 @@ -load("@bazel_skylib//rules:expand_template.bzl", "expand_template") - -exports_files(["LICENSE"]) - -expand_template( - name = "dnnl_config_h", - out = "include/oneapi/dnnl/dnnl_config.h", - substitutions = { - "#cmakedefine DNNL_EXPERIMENTAL_UKERNEL": "#define DNNL_EXPERIMENTAL_UKERNEL 1", - "#cmakedefine DNNL_SAFE_RBP": "#undef DNNL_SAFE_RBP", - "#cmakedefine DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_${DNNL_CPU_THREADING_RUNTIME}": "#define DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_SEQ", - "#cmakedefine DNNL_CPU_RUNTIME DNNL_RUNTIME_${DNNL_CPU_RUNTIME}": "#define DNNL_CPU_RUNTIME DNNL_RUNTIME_SEQ", - "#cmakedefine DNNL_DISABLE_GPU_REF_KERNELS": "#define DNNL_DISABLE_GPU_REF_KERNELS", - "#cmakedefine DNNL_GPU_RUNTIME DNNL_RUNTIME_${DNNL_GPU_RUNTIME}": "#define DNNL_GPU_RUNTIME DNNL_RUNTIME_NONE", - "#cmakedefine DNNL_GPU_VENDOR DNNL_VENDOR_${DNNL_GPU_VENDOR}": "#define DNNL_GPU_VENDOR DNNL_VENDOR_NONE", - "#cmakedefine DNNL_USE_RT_OBJECTS_IN_PRIMITIVE_CACHE": "#undef DNNL_USE_RT_OBJECTS_IN_PRIMITIVE_CACHE", - "#cmakedefine DNNL_WITH_SYCL": "#undef DNNL_WITH_SYCL", - "#cmakedefine DNNL_WITH_LEVEL_ZERO": "#undef DNNL_WITH_LEVEL_ZERO", - "#cmakedefine DNNL_SYCL_CUDA": "#undef DNNL_SYCL_CUDA", - "#cmakedefine DNNL_SYCL_GENERIC": "#undef DNNL_SYCL_GENERIC", - "#cmakedefine DNNL_SYCL_HIP": "#undef DNNL_SYCL_HIP", - "#cmakedefine DNNL_ENABLE_STACK_CHECKER": "#undef DNNL_ENABLE_STACK_CHECKER", - "#cmakedefine ONEDNN_BUILD_GRAPH": "#define ONEDNN_BUILD_GRAPH", - "#cmakedefine DNNL_EXPERIMENTAL_SPARSE": "#undef DNNL_EXPERIMENTAL_SPARSE", - "#cmakedefine DNNL_EXPERIMENTAL_LOGGING": "#undef DNNL_EXPERIMENTAL_LOGGING", - "#cmakedefine DNNL_EXPERIMENTAL_PROFILING": "#undef DNNL_EXPERIMENTAL_PROFILING", - "#cmakedefine DNNL_EXPERIMENTAL_SYCL_KERNEL_COMPILER": "#undef DNNL_EXPERIMENTAL_SYCL_KERNEL_COMPILER", - "#cmakedefine DNNL_EXPERIMENTAL": "#undef DNNL_EXPERIMENTAL", - "#cmakedefine01 BUILD_TRAINING": "#define BUILD_TRAINING 1", - "#cmakedefine01 BUILD_INFERENCE": "#define BUILD_INFERENCE 0", - "#cmakedefine01 BUILD_PRIMITIVE_ALL": "#define BUILD_PRIMITIVE_ALL 1", - "#cmakedefine01 BUILD_BATCH_NORMALIZATION": "#define BUILD_BATCH_NORMALIZATION 0", - "#cmakedefine01 BUILD_BINARY": "#define BUILD_BINARY 0", - "#cmakedefine01 BUILD_CONCAT": "#define BUILD_CONCAT 0", - "#cmakedefine01 BUILD_CONVOLUTION": "#define BUILD_CONVOLUTION 0", - "#cmakedefine01 BUILD_DECONVOLUTION": "#define BUILD_DECONVOLUTION 0", - "#cmakedefine01 BUILD_ELTWISE": "#define BUILD_ELTWISE 0", - "#cmakedefine01 BUILD_GEMM_KERNELS_ALL": "#define BUILD_GEMM_KERNELS_ALL 1", - "#cmakedefine01 BUILD_GEMM_KERNELS_NONE": "#define BUILD_GEMM_KERNELS_NONE 0", - "#cmakedefine01 BUILD_GEMM_SSE41": "#define BUILD_GEMM_SSE41 1", - "#cmakedefine01 BUILD_GEMM_AVX2": "#define BUILD_GEMM_AVX2 1", - "#cmakedefine01 BUILD_GEMM_AVX512": "#define BUILD_GEMM_AVX512 1", - "#cmakedefine01 BUILD_GROUP_NORMALIZATION": "#define BUILD_GROUP_NORMALIZATION 1", - "#cmakedefine01 BUILD_INNER_PRODUCT": "#define BUILD_INNER_PRODUCT 0", - "#cmakedefine01 BUILD_LAYER_NORMALIZATION": "#define BUILD_LAYER_NORMALIZATION 0", - "#cmakedefine01 BUILD_LRN": "#define BUILD_LRN 0", - "#cmakedefine01 BUILD_MATMUL": "#define BUILD_MATMUL 0", - "#cmakedefine01 BUILD_POOLING": "#define BUILD_POOLING 0", - "#cmakedefine01 BUILD_PRELU": "#define BUILD_PRELU 0", - "#cmakedefine01 BUILD_REDUCTION": "#define BUILD_REDUCTION 0", - "#cmakedefine01 BUILD_REORDER": "#define BUILD_REORDER 0", - "#cmakedefine01 BUILD_RESAMPLING": "#define BUILD_RESAMPLING 0", - "#cmakedefine01 BUILD_RNN": "#define BUILD_RNN 0", - "#cmakedefine01 BUILD_SHUFFLE": "#define BUILD_SHUFFLE 0", - "#cmakedefine01 BUILD_SOFTMAX": "#define BUILD_SOFTMAX 0", - "#cmakedefine01 BUILD_SUM": "#define BUILD_SUM 0", - "#cmakedefine01 BUILD_PRIMITIVE_CPU_ISA_ALL": "#define BUILD_PRIMITIVE_CPU_ISA_ALL 1", - "#cmakedefine01 BUILD_SSE41": "#define BUILD_SSE41 0", - "#cmakedefine01 BUILD_AVX2": "#define BUILD_AVX2 0", - "#cmakedefine01 BUILD_AVX512": "#define BUILD_AVX512 0", - "#cmakedefine01 BUILD_AMX": "#define BUILD_AMX 0", - "#cmakedefine01 BUILD_PRIMITIVE_GPU_ISA_ALL": "#define BUILD_PRIMITIVE_GPU_ISA_ALL 0", - "#cmakedefine01 BUILD_XE2": "#define BUILD_XE2 0", - "#cmakedefine01 BUILD_XELP": "#define BUILD_XELP 0", - "#cmakedefine01 BUILD_XEHPG": "#define BUILD_XEHPG 0", - "#cmakedefine01 BUILD_XEHPC": "#define BUILD_XEHPC 0", - "#cmakedefine01 BUILD_XEHP": "#define BUILD_XEHP 0", - "#cmakedefine01 BUILD_SDPA": "#define BUILD_SDPA 1", - "#cmakedefine01 BUILD_XE3": "#define BUILD_XE3 0", - }, - template = "include/oneapi/dnnl/dnnl_config.h.in", -) - -expand_template( - name = "dnnl_version_h", - out = "include/oneapi/dnnl/dnnl_version.h", - substitutions = { - "@DNNL_VERSION_MAJOR@": "3", - "@DNNL_VERSION_MINOR@": "11", - "@DNNL_VERSION_PATCH@": "0", - }, - template = "include/oneapi/dnnl/dnnl_version.h.in", -) - -expand_template( - name = "dnnl_version_hash_h", - out = "include/oneapi/dnnl/dnnl_version_hash.h", - substitutions = { - "@DNNL_VERSION_HASH@": "fc6151651a4577beae5ffac5a4132e75d39e1409", - }, - template = "include/oneapi/dnnl/dnnl_version_hash.h.in", -) - -cc_library( - name = "onednn_autogen", - srcs = glob(["src/cpu/x64/gemm/**/*_kern_autogen*.cpp"]), - copts = [ - "-O1", - "-U_FORTIFY_SOURCE", - "-fexceptions", - "-UUSE_MKL", - "-UUSE_CBLAS", - "-DDNNL_ENABLE_MAX_CPU_ISA", - "-DDNNL_ENABLE_ITT_TASKS", - "-DDNNL_ENABLE_GRAPH_DUMP", - "-DDNNL_EXPERIMENTAL_UKERNEL", - ], - includes = [ - "include", - "src", - "src/common", - "src/cpu", - "src/cpu/gemm", - "src/graph", - "third_party", - "third_party/ittnotify", - "third_party/xbyak", - ], - textual_hdrs = glob([ - "include/**/*", - "src/common/*.hpp", - "src/cpu/*.hpp", - "src/cpu/**/*.hpp", - "src/cpu/jit_utils/**/*.hpp", - "src/graph/interface/*.hpp", - "src/graph/backend/*.hpp", - "src/graph/backend/dnnl/*.hpp", - "src/graph/backend/dnnl/executables/*.hpp", - "src/graph/backend/fake/*.hpp", - "src/graph/backend/dnnl/passes/*.hpp", - "src/graph/backend/dnnl/patterns/*.hpp", - "src/graph/backend/dnnl/kernels/*.hpp", - "src/graph/utils/*.hpp", - "src/graph/utils/pm/*.hpp", - "third_party/ittnotify/**/*.h", - "third_party/spdlog/**/*.h", - "third_party/xbyak/*.h", - ]) + [ - ":dnnl_config_h", - ":dnnl_version_h", - ":dnnl_version_hash_h", - ], - visibility = ["//visibility:public"], -) - -cc_library( - name = "onednn", - srcs = glob( - [ - "src/common/*.cpp", - "src/cpu/*.cpp", - "src/cpu/**/*.cpp", - "src/cpu/jit_utils/**/*.cpp", - "src/cpu/x64/**/*.cpp", - "src/graph/interface/*.cpp", - "src/graph/backend/*.cpp", - "src/graph/backend/dnnl/*.cpp", - "src/graph/backend/dnnl/executables/*.cpp", - "src/graph/backend/fake/*.cpp", - "src/graph/backend/dnnl/passes/*.cpp", - "src/graph/backend/dnnl/patterns/*.cpp", - "src/graph/backend/dnnl/kernels/*.cpp", - "src/graph/utils/*.cpp", - "src/graph/utils/pm/*.cpp", - "third_party/ittnotify/*.c", - ], - exclude = [ - "src/cpu/aarch64/**", - "src/cpu/rv64/**", - "src/cpu/ppc64/**", - "src/cpu/s390x/**", - "src/cpu/x64/gemm/**/*_kern_autogen.cpp", - "src/cpu/sycl/**", - ], - ), - copts = [ - "-fexceptions", - "-UUSE_MKL", - "-UUSE_CBLAS", - "-DDNNL_ENABLE_MAX_CPU_ISA", - "-DDNNL_ENABLE_ITT_TASKS", - "-DDNNL_ENABLE_GRAPH_DUMP", - "-DDNNL_EXPERIMENTAL_UKERNEL", - ], - includes = [ - "include", - "src", - "src/common", - "src/cpu", - "src/cpu/gemm", - "src/graph", - "third_party", - "third_party/ittnotify", - "third_party/xbyak", - ], - linkopts = [ - "-lrt", - "-Wl,--allow-multiple-definition", - ], - textual_hdrs = glob([ - "include/**/*", - "src/common/*.hpp", - "src/cpu/*.hpp", - "src/cpu/**/*.hpp", - "src/cpu/jit_utils/**/*.hpp", - "src/graph/interface/*.hpp", - "src/graph/backend/*.hpp", - "src/graph/backend/dnnl/*.hpp", - "src/graph/backend/fake/*.hpp", - "src/graph/backend/dnnl/passes/*.hpp", - "src/graph/backend/dnnl/patterns/*.hpp", - "src/graph/backend/dnnl/kernels/*.hpp", - "src/graph/utils/*.hpp", - "src/graph/utils/pm/*.hpp", - "third_party/ittnotify/**/*.h", - "third_party/spdlog/**/*.h", - "third_party/xbyak/*.h", - ]) + [ - ":dnnl_config_h", - ":dnnl_version_h", - ":dnnl_version_hash_h", - ], - visibility = ["//visibility:public"], - deps = [ - ":onednn_autogen", - ], -) diff --git a/ops/bench_matmul.cc b/ops/bench_matmul.cc index e9432276..67c702f5 100644 --- a/ops/bench_matmul.cc +++ b/ops/bench_matmul.cc @@ -130,11 +130,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { keep += hwy::ConvertScalarTo(C.Row(0)[hwy::Unpredictable1()]); // Only record times after autotuning finished. - bool done = per_key->autotune.Best(); -#if GEMMA_ONEDNN_BRGEMM - done = done || per_key->brgemm_autotune.Best(); -#endif - if (done) times.push_back(elapsed); + if (per_key->autotune.Best()) times.push_back(elapsed); } hwy::PreventElision(keep); env.ctx.pools.MaybeStopSpinning(use_spinning); diff --git a/ops/brgemm-inl.h b/ops/brgemm-inl.h deleted file mode 100644 index 04f8f1cf..00000000 --- a/ops/brgemm-inl.h +++ /dev/null @@ -1,560 +0,0 @@ -// Copyright 2026 DeepMind Technologies Limited. -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// BRGeMM dispatch for BF16 MatMul on Intel AMX/AVX-512. - -#include -#include - -#include -#include -#include - -#include "ops/brgemm.h" -#include "ops/matmul.h" -#include "util/mat.h" -#include "util/threading_context.h" -#include "util/zones.h" -#include "hwy/base.h" - -// Include guard for (potentially) SIMD code. -#if defined(THIRD_PARTY_GEMMA_CPP_BRGEMM_TOGGLE) == defined(HWY_TARGET_TOGGLE) -#ifdef THIRD_PARTY_GEMMA_CPP_BRGEMM_TOGGLE -#undef THIRD_PARTY_GEMMA_CPP_BRGEMM_TOGGLE -#else -#define THIRD_PARTY_GEMMA_CPP_BRGEMM_TOGGLE -#endif - -#include "hwy/highway.h" - -HWY_BEFORE_NAMESPACE(); -namespace gcpp { -namespace HWY_NAMESPACE { -namespace hn = hwy::HWY_NAMESPACE; - -#if GEMMA_ONEDNN_BRGEMM - -static bool MakeBrgemm(dnnl::ukernel::brgemm& brg, int64_t m, int64_t n, - int64_t k, int64_t batch, int64_t lda, int64_t ldb, - int64_t ldc, dnnl::memory::data_type a_dt, - dnnl::memory::data_type b_dt, - dnnl::memory::data_type c_dt, bool add_C) { - try { - brg = dnnl::ukernel::brgemm(m, n, k, batch, lda, ldb, ldc, a_dt, b_dt, c_dt, - true); - if (!brg) { - HWY_WARN("BRGeMM: kernel creation failed m=%lld n=%lld k=%lld.", - static_cast(m), static_cast(n), - static_cast(k)); - return false; - } - brg.set_add_C(add_C); - if (!brg.finalize()) { - HWY_WARN("BRGeMM: kernel finalize failed m=%lld n=%lld k=%lld.", - static_cast(m), static_cast(n), - static_cast(k)); - return false; - } - brg.generate(); - return true; - } catch (...) { - HWY_WARN("BRGeMM: kernel JIT exception m=%lld n=%lld k=%lld.", - static_cast(m), static_cast(n), - static_cast(k)); - return false; - } -} - -// JIT-compiles brgemm kernels, B-packing transforms, and offset tables for -// the given matrix dimensions and tiling config. Returns false on failure. -static HWY_NOINLINE bool InitBRGeMMKernels(const BRGeMMConfig& cfg, size_t M, - size_t K, size_t N, size_t lda, - size_t ldb_orig, - BRGeMMKernelEntry& ke) { - using dnnl::ukernel::brgemm; - using dnnl::ukernel::pack_type; - using dnnl::ukernel::transform; - - ke.K_blk = cfg.K_blk; - ke.N_blk = cfg.N_blk; - ke.M_blk = std::min(cfg.M_blk, M); - ke.div_M_blk = hwy::Divisor(ke.M_blk); - ke.div_N_blk = hwy::Divisor(ke.N_blk); - ke.div_K_blk = hwy::Divisor(ke.K_blk); - - ke.M_tail = ke.div_M_blk.Remainder(M); - ke.N_tail = ke.div_N_blk.Remainder(N); - ke.K_tail = ke.div_K_blk.Remainder(K); - - // Floor division: K_tail remainder is handled by a dedicated brg_ktail - // kernel rather than padding K, avoiding extra memory writes to zero-pad - // A and B along the K dimension. - ke.K_chunks = ke.div_K_blk.Divide(K); - ke.N_full_tiles = ke.div_N_blk.Divide(N); - ke.M_full_tiles = ke.div_M_blk.Divide(M); - ke.N_total_tiles = ke.N_full_tiles + (ke.N_tail ? 1 : 0); - ke.M_total_tiles = ke.M_full_tiles + (ke.M_tail ? 1 : 0); - ke.N_padded = ke.N_total_tiles * ke.N_blk; - - if (ke.M_total_tiles == 0 || ke.N_total_tiles == 0 || - (ke.K_chunks == 0 && ke.K_tail == 0)) { - return false; - } - - ke.K_super_size = std::min(cfg.batch_size, ke.K_chunks); - ke.K_super_blocks = (ke.K_chunks > 0) ? ke.K_chunks / ke.K_super_size : 0; - ke.K_super_rem = (ke.K_chunks > 0) ? ke.K_chunks % ke.K_super_size : 0; - ke.batch_full = ke.K_super_size; - ke.batch_rem = ke.K_super_rem; - - const auto a_dt = dnnl::memory::data_type::bf16; - const auto b_dt = dnnl::memory::data_type::bf16; - const auto c_dt = dnnl::memory::data_type::f32; - ke.a_dt_size = dnnl::memory::data_type_size(a_dt); - ke.b_dt_size = dnnl::memory::data_type_size(b_dt); - - const auto pack = brgemm::get_B_pack_type(a_dt, b_dt); - if (pack == pack_type::undef) return false; - ke.need_pack = (pack != pack_type::no_trans); - - ke.lda = lda; - ke.ldb_orig = ldb_orig; - - // Indexed by tail flag: [0] = full tile size, [1] = tail size (or full if - // no tail). Separate kernels are JIT-compiled for full vs. tail tile widths - // along both M and N dimensions. - ke.m_sizes[0] = ke.M_blk; - ke.m_sizes[1] = ke.M_tail ? ke.M_tail : ke.M_blk; - ke.n_sizes[0] = ke.N_blk; - ke.n_sizes[1] = ke.N_tail ? ke.N_tail : ke.N_blk; - const int64_t ldb_for[2] = { - static_cast(ke.N_blk), - static_cast(ke.N_tail ? ke.N_tail : ke.N_blk)}; - const int64_t ldc_for[2] = { - static_cast(ke.N_blk), - static_cast(ke.N_tail ? ke.N_tail : ke.N_blk)}; - - // JIT a brgemm kernel for each (mi, ni) where mi/ni indicate whether we - // are processing the M-tail or N-tail: 0 = full block, 1 = tail block. - // Skipped when the corresponding tail is zero (no partial tile exists). - size_t max_sp = 0; - for (int mi = 0; mi < 2; ++mi) { - for (int ni = 0; ni < 2; ++ni) { - if (mi == 1 && ke.M_tail == 0) continue; - if (ni == 1 && ke.N_tail == 0) continue; - if (mi == 0 && ke.M_full_tiles == 0) continue; - if (ni == 0 && ke.N_full_tiles == 0) continue; - - const int64_t ms = static_cast(ke.m_sizes[mi]); - const int64_t ns = static_cast(ke.n_sizes[ni]); - - if (ke.K_chunks > 0) { - if (!MakeBrgemm(ke.brg_first_all[mi][ni], ms, ns, - static_cast(ke.K_blk), - static_cast(ke.K_super_size), - static_cast(ke.lda), ldb_for[ni], ldc_for[ni], - a_dt, b_dt, c_dt, false)) { - return false; - } - max_sp = - std::max(max_sp, ke.brg_first_all[mi][ni].get_scratchpad_size()); - } - if (ke.K_super_blocks > 1) { - if (!MakeBrgemm(ke.brg_full[mi][ni], ms, ns, - static_cast(ke.K_blk), - static_cast(ke.batch_full), - static_cast(ke.lda), ldb_for[ni], ldc_for[ni], - a_dt, b_dt, c_dt, true)) { - return false; - } - max_sp = std::max(max_sp, ke.brg_full[mi][ni].get_scratchpad_size()); - } - if (ke.K_super_rem > 0) { - const bool rem_is_first = (ke.K_super_blocks == 0); - auto& target = - rem_is_first ? ke.brg_first_rem[mi][ni] : ke.brg_rem[mi][ni]; - if (!MakeBrgemm(target, ms, ns, static_cast(ke.K_blk), - static_cast(ke.batch_rem), - static_cast(ke.lda), ldb_for[ni], ldc_for[ni], - a_dt, b_dt, c_dt, !rem_is_first)) { - return false; - } - max_sp = std::max(max_sp, target.get_scratchpad_size()); - } - if (ke.K_tail > 0) { - const bool add_c = (ke.K_chunks > 0); - if (!MakeBrgemm(ke.brg_ktail[mi][ni], ms, ns, - static_cast(ke.K_tail), 1, - static_cast(ke.lda), ldb_for[ni], ldc_for[ni], - a_dt, b_dt, c_dt, add_c)) { - return false; - } - max_sp = std::max(max_sp, ke.brg_ktail[mi][ni].get_scratchpad_size()); - } - } - } - ke.scratchpad_size = max_sp + 64; - - // Create B-packing transforms. - if (ke.need_pack) { - for (int ni = 0; ni < 2; ++ni) { - if (ni == 1 && ke.N_tail == 0) continue; - if (ni == 0 && ke.N_full_tiles == 0) continue; - - const int64_t ns = static_cast(ke.n_sizes[ni]); - if (ke.K_chunks > 0) { - const int64_t K_full = static_cast(ke.K_chunks * ke.K_blk); - try { - ke.pack_B[ni] = transform(K_full, ns, pack_type::trans, - static_cast(ke.ldb_orig), - ldb_for[ni], b_dt, b_dt); - if (!ke.pack_B[ni]) return false; - ke.pack_B[ni].generate(); - ke.blocked_B_size[ni] = - static_cast(ldb_for[ni]) * ke.K_blk * ke.b_dt_size; - } catch (...) { - return false; - } - } - if (ke.K_tail > 0) { - try { - ke.pack_B_ktail[ni] = transform( - static_cast(ke.K_tail), ns, pack_type::trans, - static_cast(ke.ldb_orig), ldb_for[ni], b_dt, b_dt); - if (!ke.pack_B_ktail[ni]) return false; - ke.pack_B_ktail[ni].generate(); - ke.blocked_B_ktail_size[ni] = - static_cast(ldb_for[ni]) * ke.K_tail * ke.b_dt_size; - } catch (...) { - return false; - } - } - } - } - - // Precompute A/B offset tables for each K-super-block. - for (int ni = 0; ni < 2; ++ni) { - if (ni == 1 && ke.N_tail == 0) continue; - if (ni == 0 && ke.N_full_tiles == 0) continue; - const size_t cur_n = ke.n_sizes[ni]; - - if (ke.K_chunks > 0) { - ke.offsets_first_all[ni].resize(ke.K_super_size); - for (size_t i = 0; i < ke.K_super_size; ++i) { - const int64_t a_off = static_cast(i * ke.K_blk * ke.a_dt_size); - const int64_t b_off = - ke.need_pack - ? static_cast(i * ke.blocked_B_size[ni]) - : static_cast(i * cur_n * ke.K_blk * ke.b_dt_size); - ke.offsets_first_all[ni][i] = {a_off, b_off}; - } - } - - if (ke.K_super_blocks > 1) { - ke.offsets_full[ni].resize(ke.K_super_blocks - 1); - for (size_t ks = 1; ks < ke.K_super_blocks; ++ks) { - auto& tbl = ke.offsets_full[ni][ks - 1]; - tbl.resize(ke.batch_full); - const size_t k_start = ks * ke.K_super_size; - for (size_t i = 0; i < ke.batch_full; ++i) { - const size_t k_idx = k_start + i; - const int64_t a_off = - static_cast(k_idx * ke.K_blk * ke.a_dt_size); - const int64_t b_off = - ke.need_pack ? static_cast(k_idx * ke.blocked_B_size[ni]) - : static_cast(k_idx * cur_n * ke.K_blk * - ke.b_dt_size); - tbl[i] = {a_off, b_off}; - } - } - } - - if (ke.K_super_rem > 0) { - const size_t k_base = ke.K_super_blocks * ke.K_super_size; - auto& rem_tbl = (ke.K_super_blocks == 0) ? ke.offsets_first_rem[ni] - : ke.offsets_rem[ni]; - rem_tbl.resize(ke.K_super_rem); - for (size_t i = 0; i < ke.K_super_rem; ++i) { - const size_t k_idx = k_base + i; - const int64_t a_off = - static_cast(k_idx * ke.K_blk * ke.a_dt_size); - const int64_t b_off = - ke.need_pack - ? static_cast(k_idx * ke.blocked_B_size[ni]) - : static_cast(k_idx * cur_n * ke.K_blk * ke.b_dt_size); - rem_tbl[i] = {a_off, b_off}; - } - } - } - - return true; -} - -template -static HWY_NOINLINE bool DoMatMul_BRGeMM( - const MatPtrT& A, const MatPtrT& B, RowPtrs C, size_t M, - size_t K, size_t N, float scale, const float* HWY_RESTRICT add, - const BRGeMMConfig& cfg, ThreadingContext& ctx, size_t cluster_idx) { - using dnnl::ukernel::brgemm; - - // Level-1 cache: kernels keyed on (M, K, N, config). - const BRGeMMKernelKey kern_key{ - M, K, N, cfg.M_blk, cfg.N_blk, cfg.K_blk, cfg.batch_size}; - auto& kern_cache = GetBRGeMMKernelCache(); - auto kern_it = kern_cache.find(kern_key); - - if (kern_it == kern_cache.end()) { - BRGeMMKernelEntry ke; - if (!InitBRGeMMKernels(cfg, M, K, N, A.Stride(), B.Stride(), ke)) { - return false; - } - kern_it = kern_cache.emplace(kern_key, std::move(ke)).first; - } - - BRGeMMKernelEntry& ke = kern_it->second; - - // Level-2 cache: packed B keyed on (B_ptr, K, N, config). - const uint8_t* A_base = reinterpret_cast(A.Row(0)); - const uint8_t* B_base = reinterpret_cast(B.Row(0)); - - const BRGeMMPackedBKey pb_key{reinterpret_cast(B_base), K, N, - ke.K_blk, ke.N_blk}; - auto& pb_cache = GetBRGeMMPackedBCache(); - auto pb_it = pb_cache.find(pb_key); - - if (pb_it == pb_cache.end()) { - BRGeMMPackedBEntry pe; - pe.B_tile_offset.resize(ke.N_total_tiles, 0); - pe.B_ktail_offset.resize(ke.N_total_tiles, 0); - - if (ke.need_pack) { - size_t total_packed = 0; - for (size_t nt = 0; nt < ke.N_total_tiles; ++nt) { - const int ni = (nt < ke.N_full_tiles) ? 0 : 1; - pe.B_tile_offset[nt] = total_packed; - if (ke.K_chunks > 0) - total_packed += ke.blocked_B_size[ni] * ke.K_chunks; - pe.B_ktail_offset[nt] = total_packed; - if (ke.K_tail > 0) total_packed += ke.blocked_B_ktail_size[ni]; - } - - pe.B_packed_buf.Resize(total_packed); - uint8_t* B_packed = pe.B_packed_buf.data(); - if (!B_packed) { - HWY_WARN("BRGeMM: packed B allocation failed."); - return false; - } - - for (size_t nt = 0; nt < ke.N_total_tiles; ++nt) { - const int ni = (nt < ke.N_full_tiles) ? 0 : 1; - const size_t b_row = - (nt < ke.N_full_tiles) ? nt * ke.N_blk : ke.N_full_tiles * ke.N_blk; - const uint8_t* B_in = B_base + b_row * ke.ldb_orig * ke.b_dt_size; - - try { - if (ke.K_chunks > 0) { - ke.pack_B[ni].execute(const_cast(B_in), - B_packed + pe.B_tile_offset[nt]); - } - if (ke.K_tail > 0) { - const uint8_t* B_in_ktail = - B_in + ke.K_chunks * ke.K_blk * ke.b_dt_size; - ke.pack_B_ktail[ni].execute(const_cast(B_in_ktail), - B_packed + pe.B_ktail_offset[nt]); - } - } catch (...) { - HWY_WARN("BRGeMM: B-packing execution failed."); - return false; - } - } - } - - pb_it = pb_cache.emplace(pb_key, std::move(pe)).first; - } - - const BRGeMMPackedBEntry& pe = pb_it->second; - const uint8_t* B_packed = ke.need_pack ? pe.B_packed_buf.data() : nullptr; - - std::vector> offsets_ktail(1); - if (ke.K_tail > 0) offsets_ktail[0] = {0, 0}; - - // Execute one (m, n) tile for a given K-super-block. - const auto execute_tile = [&](size_t m_start, size_t n_start, size_t k_super, - float* temp_C, uint8_t* scratch) HWY_ATTR { - const size_t m_tile_idx = ke.div_M_blk.Divide(m_start); - const size_t n_tile_idx = ke.div_N_blk.Divide(n_start); - const int mi = (m_tile_idx < ke.M_full_tiles) ? 0 : 1; - const int ni = (n_tile_idx < ke.N_full_tiles) ? 0 : 1; - const size_t cur_m = ke.m_sizes[mi]; - const size_t cur_n = ke.n_sizes[ni]; - - const size_t real_m = (m_tile_idx < ke.M_full_tiles) - ? m_tile_idx * ke.M_blk - : ke.M_full_tiles * ke.M_blk; - const size_t real_n = (n_tile_idx < ke.N_full_tiles) - ? n_tile_idx * ke.N_blk - : ke.N_full_tiles * ke.N_blk; - - const uint8_t* A_tile = A_base + real_m * ke.lda * ke.a_dt_size; - const void* B_tile = - ke.need_pack - ? static_cast(B_packed + pe.B_tile_offset[n_tile_idx]) - : static_cast(B_base + - real_n * ke.ldb_orig * ke.b_dt_size); - - float* C_tile_ptr = temp_C; - const size_t k_total = ke.K_super_blocks + (ke.K_super_rem > 0 ? 1 : 0); - - if (k_super < ke.K_super_blocks) { - if (k_super == 0) { - ke.brg_first_all[mi][ni].execute(A_tile, const_cast(B_tile), - ke.offsets_first_all[ni], C_tile_ptr, - scratch); - } else { - ke.brg_full[mi][ni].execute(A_tile, const_cast(B_tile), - ke.offsets_full[ni][k_super - 1], - C_tile_ptr, scratch); - } - } else if (ke.K_super_rem > 0 && k_super == ke.K_super_blocks) { - if (ke.K_super_blocks == 0) { - ke.brg_first_rem[mi][ni].execute(A_tile, const_cast(B_tile), - ke.offsets_first_rem[ni], C_tile_ptr, - scratch); - } else { - ke.brg_rem[mi][ni].execute(A_tile, const_cast(B_tile), - ke.offsets_rem[ni], C_tile_ptr, scratch); - } - } - - const bool is_last = (k_total > 0) ? (k_super == k_total - 1) : true; - if (is_last) { - if (ke.K_tail > 0) { - const uint8_t* A_ktail = A_tile + ke.K_chunks * ke.K_blk * ke.a_dt_size; - const void* B_ktail = - ke.need_pack - ? static_cast(B_packed + - pe.B_ktail_offset[n_tile_idx]) - : static_cast( - B_base + (real_n * ke.ldb_orig + ke.K_chunks * ke.K_blk) * - ke.b_dt_size); - ke.brg_ktail[mi][ni].execute(A_ktail, const_cast(B_ktail), - offsets_ktail, C_tile_ptr, scratch); - } - - // Scale and copy temp_C to output. - const hn::ScalableTag df; - const auto vscale = hn::Set(df, scale); - const size_t lanes = hn::Lanes(df); - for (size_t m = 0; m < cur_m; ++m) { - TC* C_row = C.Row(real_m + m) + real_n; - const float* t_row = C_tile_ptr + m * cur_n; - const float* add_row = add ? add + real_n : nullptr; - size_t n = 0; - if (add_row) { - for (; n + lanes <= cur_n; n += lanes) { - const auto v = hn::Load(df, t_row + n); - const auto va = hn::Load(df, add_row + n); - const auto result = hn::MulAdd(v, vscale, va); - if constexpr (hwy::IsSame()) { - hn::Store(result, df, HWY_RCAST_ALIGNED(float*, C_row) + n); - } else { - const hn::Rebind dc; - hn::Store(hn::DemoteTo(dc, result), dc, C_row + n); - } - } - for (; n < cur_n; ++n) { - float val = t_row[n] * scale + add_row[n]; - C_row[n] = hwy::ConvertScalarTo(val); - } - } else { - for (; n + lanes <= cur_n; n += lanes) { - const auto v = hn::Load(df, t_row + n); - const auto result = hn::Mul(v, vscale); - if constexpr (hwy::IsSame()) { - hn::Store(result, df, HWY_RCAST_ALIGNED(float*, C_row) + n); - } else { - const hn::Rebind dc; - hn::Store(hn::DemoteTo(dc, result), dc, C_row + n); - } - } - for (; n < cur_n; ++n) { - float val = t_row[n] * scale; - C_row[n] = hwy::ConvertScalarTo(val); - } - } - } - } - }; - - // Parallel dispatch: K-super outer, N middle, M inner (keeps B in L2). - const size_t k_total_supers = - ke.K_super_blocks + (ke.K_super_rem > 0 ? 1 : 0); - const size_t k_iters = (k_total_supers > 0) ? k_total_supers : size_t{1}; - - const size_t num_threads = ctx.pools.MaxWorkersPerCluster(); - const size_t total_n_tiles = ke.N_total_tiles; - const size_t total_m_tiles = ke.M_total_tiles; - const size_t n_tasks = - std::max(size_t{1}, std::min(total_n_tiles, num_threads)); - - const hwy::pool::Caller caller = ctx.pool_callers.Get(Callers::kBRGeMM); - - ParallelForWithinCluster( - n_tasks, ctx, cluster_idx, caller, - [&](uint64_t task_idx, size_t /*worker*/) HWY_ATTR { - const size_t tiles_per_task = total_n_tiles / n_tasks; - const size_t extra = total_n_tiles % n_tasks; - const size_t n_begin = task_idx * tiles_per_task + - std::min(static_cast(task_idx), extra); - const size_t n_end = - n_begin + tiles_per_task + (task_idx < extra ? 1 : 0); - - auto& tbufs = GetBRGeMMThreadBufs(); - tbufs.MaybeSetHwContext(ke.brg_first_all[0][0]); - uint8_t* sp = tbufs.EnsureScratch(ke.scratchpad_size); - - const size_t n_tiles_in_range = n_end - n_begin; - const size_t total_tc = total_m_tiles * n_tiles_in_range; - float* tc_base = tbufs.EnsureTempC(total_tc); - - for (size_t ks = 0; ks < k_iters; ++ks) { - size_t n_idx = 0; - for (size_t nt = n_begin; nt < n_end; ++nt) { - const size_t n = nt * ke.N_blk; - for (size_t mt = 0; mt < total_m_tiles; ++mt) { - const size_t m = mt * ke.M_blk; - float* temp_C = tc_base + (mt * n_tiles_in_range + n_idx) * - BRGeMMThreadBufs::kMaxTempCSize; - execute_tile(m, n, ks, temp_C, sp); - } - ++n_idx; - } - } - }); - - dnnl::ukernel::brgemm::release_hw_context(); - auto& main_bufs = GetBRGeMMThreadBufs(); - main_bufs.hw_ctx_kernel = nullptr; - return true; -} - -#endif // GEMMA_ONEDNN_BRGEMM - -// NOLINTNEXTLINE(google-readability-namespace-comments) -} // namespace HWY_NAMESPACE -} // namespace gcpp -HWY_AFTER_NAMESPACE(); - -#endif // NOLINT diff --git a/ops/brgemm.h b/ops/brgemm.h deleted file mode 100644 index f2c6eea4..00000000 --- a/ops/brgemm.h +++ /dev/null @@ -1,292 +0,0 @@ -// Copyright 2026 DeepMind Technologies Limited. -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// OneDNN BRGeMM micro-kernel integration for MatMul on Intel AMX/AVX-512. -// Enabled at compile time via GEMMA_ONEDNN_BRGEMM=1 (Bazel: --define -// gemma_onednn_brgemm=1). - -#ifndef THIRD_PARTY_GEMMA_CPP_OPS_BRGEMM_H_ -#define THIRD_PARTY_GEMMA_CPP_OPS_BRGEMM_H_ - -#include -#include - -// opt-in -#ifndef GEMMA_ONEDNN_BRGEMM -#define GEMMA_ONEDNN_BRGEMM 0 -#endif // GEMMA_ONEDNN_BRGEMM - -#if GEMMA_ONEDNN_BRGEMM -#include - -#include -#include -#include -#include - -#include "oneapi/dnnl/dnnl.hpp" -#include "oneapi/dnnl/dnnl_ukernel.hpp" -#include "hwy/aligned_allocator.h" -#include "hwy/base.h" -#endif // GEMMA_ONEDNN_BRGEMM - -namespace gcpp { - -struct BRGeMMConfig { - size_t M_blk; - size_t N_blk = 32; - size_t K_blk = 32; - size_t batch_size; - size_t par_m; -}; - -#if GEMMA_ONEDNN_BRGEMM - -// Generates autotuning candidates. Fixed: N_blk=32, K_blk=32 (AMX BF16). -// Tunable: M_blk in {32,64}, batch_size in {16,32,64,128,256}. -inline std::vector BRGeMMCandidates(size_t M, size_t K, - size_t N) { - std::vector out; - out.reserve(10); // At most 2 M_blk * 5 batch_size candidates. - static constexpr size_t kNBlk = 32; - static constexpr size_t kKBlk = 32; - static constexpr size_t kMBlkValues[] = {32, 64}; - static constexpr size_t kBatchValues[] = {16, 32, 64, 128, 256}; - - const size_t k_chunks = K / kKBlk; - for (size_t mb : kMBlkValues) { - if (mb > M) continue; - if (kNBlk > N) continue; - for (size_t bs : kBatchValues) { - const size_t eff_bs = (k_chunks > 0) ? std::min(bs, k_chunks) : size_t{1}; - bool dup = false; - for (const auto& c : out) { - if (c.M_blk == mb && c.batch_size == eff_bs) { - dup = true; - break; - } - } - if (dup) continue; - out.push_back({mb, kNBlk, kKBlk, eff_bs, /*par_m=*/1}); - } - } - if (out.empty()) { - out.push_back({std::min(M, size_t{32}), std::min(N, size_t{32}), 32, 1, 1}); - } - return out; -} - -// Hugepage-backed buffer via mmap with MADV_HUGEPAGE for packed-B matrices. -class HugePageBuffer { - public: - HugePageBuffer() = default; - ~HugePageBuffer() { - if (ptr_ && size_) munmap(ptr_, size_); - } - - HugePageBuffer(HugePageBuffer&& o) noexcept : ptr_(o.ptr_), size_(o.size_) { - o.ptr_ = nullptr; - o.size_ = 0; - } - HugePageBuffer& operator=(HugePageBuffer&& o) noexcept { - if (this != &o) { - if (ptr_ && size_) munmap(ptr_, size_); - ptr_ = o.ptr_; - size_ = o.size_; - o.ptr_ = nullptr; - o.size_ = 0; - } - return *this; - } - - HugePageBuffer(const HugePageBuffer&) = delete; - HugePageBuffer& operator=(const HugePageBuffer&) = delete; - - void Resize(size_t n) { - if (ptr_ && size_) munmap(ptr_, size_); - static constexpr size_t kHugePageSize = 2u << 20; - size_ = (n + kHugePageSize - 1) & ~(kHugePageSize - 1); - ptr_ = static_cast(mmap(nullptr, size_, PROT_READ | PROT_WRITE, - MAP_PRIVATE | MAP_ANONYMOUS, -1, 0)); - if (ptr_ == MAP_FAILED) { - ptr_ = nullptr; - size_ = 0; - return; - } - madvise(ptr_, size_, MADV_HUGEPAGE); - for (size_t off = 0; off < size_; off += kHugePageSize) { - ptr_[off] = 0; - hwy::PreventElision(ptr_[off]); - } - } - - uint8_t* data() { return ptr_; } - const uint8_t* data() const { return ptr_; } - size_t size() const { return size_; } - - private: - uint8_t* ptr_ = nullptr; - size_t size_ = 0; -}; - -// Kernel cache key: identifies a JIT-compiled kernel set. -struct BRGeMMKernelKey { - size_t M, K, N; - size_t M_blk, N_blk, K_blk, batch_size; - bool operator==(const BRGeMMKernelKey& o) const { - return M == o.M && K == o.K && N == o.N && M_blk == o.M_blk && - N_blk == o.N_blk && K_blk == o.K_blk && batch_size == o.batch_size; - } -}; - -struct BRGeMMKernelKeyHash { - size_t operator()(const BRGeMMKernelKey& k) const { - size_t h = 14695981039346656037ULL; - h = (h ^ k.M) * 1099511628211ULL; - h = (h ^ k.K) * 1099511628211ULL; - h = (h ^ k.N) * 1099511628211ULL; - h = (h ^ k.M_blk) * 1099511628211ULL; - h = (h ^ k.N_blk) * 1099511628211ULL; - h = (h ^ k.K_blk) * 1099511628211ULL; - h = (h ^ k.batch_size) * 1099511628211ULL; - return h; - } -}; - -// Cached JIT-compiled kernels with precomputed tile parameters and offsets. -struct BRGeMMKernelEntry { - size_t M_blk, N_blk, K_blk; - // Precomputed divisors for fast modulo/division by block sizes. - hwy::Divisor div_M_blk{1}; - hwy::Divisor div_N_blk{1}; - hwy::Divisor div_K_blk{1}; - size_t M_tail, N_tail, K_tail; - size_t K_chunks; - size_t M_full_tiles, N_full_tiles; - size_t M_total_tiles, N_total_tiles; - size_t K_super_size, K_super_blocks; - size_t K_super_rem; - size_t batch_full, batch_rem; - size_t m_sizes[2], n_sizes[2]; - size_t lda; - size_t ldb_orig; - bool need_pack; - size_t a_dt_size, b_dt_size; - size_t N_padded; - - // Kernels indexed by [m_tail_flag][n_tail_flag]. - dnnl::ukernel::brgemm brg_first_all[2][2]; - dnnl::ukernel::brgemm brg_full[2][2]; - dnnl::ukernel::brgemm brg_ktail[2][2]; - dnnl::ukernel::brgemm brg_first_rem[2][2]; - dnnl::ukernel::brgemm brg_rem[2][2]; - - // B-packing transforms indexed by n_tail_flag. - dnnl::ukernel::transform pack_B[2], pack_B_ktail[2]; - size_t blocked_B_size[2] = {0, 0}; - size_t blocked_B_ktail_size[2] = {0, 0}; - - size_t scratchpad_size = 0; - - using OffsetVec = - std::vector>; - OffsetVec offsets_first_all[2]; - std::vector offsets_full[2]; - OffsetVec offsets_first_rem[2]; - OffsetVec offsets_rem[2]; -}; - -// Packed-B cache key. -struct BRGeMMPackedBKey { - uintptr_t B_ptr; - size_t K, N; - size_t K_blk, N_blk; - bool operator==(const BRGeMMPackedBKey& o) const { - return B_ptr == o.B_ptr && K == o.K && N == o.N && K_blk == o.K_blk && - N_blk == o.N_blk; - } -}; - -struct BRGeMMPackedBKeyHash { - size_t operator()(const BRGeMMPackedBKey& k) const { - size_t h = 14695981039346656037ULL; - h = (h ^ k.B_ptr) * 1099511628211ULL; - h = (h ^ k.K) * 1099511628211ULL; - h = (h ^ k.N) * 1099511628211ULL; - h = (h ^ k.K_blk) * 1099511628211ULL; - h = (h ^ k.N_blk) * 1099511628211ULL; - return h; - } -}; - -struct BRGeMMPackedBEntry { - HugePageBuffer B_packed_buf; - std::vector B_tile_offset; - std::vector B_ktail_offset; -}; - -// Thread-local buffers for BRGeMM parallel dispatch. -struct BRGeMMThreadBufs { - static constexpr size_t kMaxTempCSize = 64 * 64; - - hwy::AlignedVector scratch; - hwy::AlignedVector tc_storage; - const void* hw_ctx_kernel = nullptr; - - uint8_t* EnsureScratch(size_t size) { - if (scratch.size() < size) scratch.resize(size); - return scratch.data(); - } - - float* EnsureTempC(size_t n_tiles) { - const size_t need = n_tiles * kMaxTempCSize * sizeof(float); - if (tc_storage.size() < need) tc_storage.resize(need); - return reinterpret_cast(tc_storage.data()); - } - - void MaybeSetHwContext(const dnnl::ukernel::brgemm& brg) { - const void* brg_ptr = &brg; - if (hw_ctx_kernel != brg_ptr) { - brg.set_hw_context(); - hw_ctx_kernel = brg_ptr; - } - } -}; - -inline BRGeMMThreadBufs& GetBRGeMMThreadBufs() { - static thread_local BRGeMMThreadBufs bufs; - return bufs; -} - -// Singleton caches. Thread-safety: MatMul is not called concurrently per env. -inline auto& GetBRGeMMKernelCache() { - static std::unordered_map - cache; - return cache; -} - -inline auto& GetBRGeMMPackedBCache() { - static std::unordered_map - cache; - return cache; -} - -#endif // GEMMA_ONEDNN_BRGEMM - -} // namespace gcpp - -#endif // THIRD_PARTY_GEMMA_CPP_OPS_BRGEMM_H_ diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index fa7d11e5..4b217a15 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -41,9 +41,6 @@ #include "hwy/highway.h" // After highway.h #include "compression/compress-inl.h" -#if GEMMA_ONEDNN_BRGEMM -#include "ops/brgemm-inl.h" -#endif // GEMMA_ONEDNN_BRGEMM HWY_BEFORE_NAMESPACE(); namespace gcpp { @@ -1080,48 +1077,6 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, MMPerKey& per_key = MMImpl::FindOrAddPerKey( M, K, N, num_B, cache.VectorBytes(), env.per_cluster[cluster_idx]); -#if GEMMA_ONEDNN_BRGEMM - // BRGeMM path for BF16×BF16 on Intel AMX/AVX-512. - // Requires M,N,K >= 32 and K % 32 == 0 (AMX tile constraint). - if constexpr (IsBF16() && IsBF16()) { - if (M >= 32 && N >= 32 && K >= 32 && (K % 32) == 0) { - const float scale = A.Scale() * B.Scale(); - MMAutoTune& brg_tuner = per_key.brgemm_autotune; - - if (HWY_LIKELY(brg_tuner.Best())) { - if (DoMatMul_BRGeMM(A, B, C_rows, M, K, N, scale, add, - *brg_tuner.Best(), env.ctx, cluster_idx)) { - return &per_key; - } - // BRGeMM failed; fall through to standard matmul. - } - - if (HWY_UNLIKELY(!brg_tuner.HasCandidates())) { - brg_tuner.SetCandidates(BRGeMMCandidates(M, K, N)); - } - - const BRGeMMConfig& cfg = brg_tuner.NextConfig(); - const uint64_t t0 = hwy::timer::Start(); - if (DoMatMul_BRGeMM(A, B, C_rows, M, K, N, scale, add, cfg, env.ctx, - cluster_idx)) { - const uint64_t t1 = - env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start(); - brg_tuner.NotifyTicks(t1 - t0); - - if (HWY_UNLIKELY(env.print_best && brg_tuner.Best())) { - const BRGeMMConfig& best = *brg_tuner.Best(); - fprintf(stderr, - "BRGeMM best: %zux%zux%zu M_blk=%zu N_blk=%zu K_blk=%zu " - "batch=%zu\n", - M, K, N, best.M_blk, best.N_blk, best.K_blk, best.batch_size); - } - return &per_key; - } - // BRGeMM failed; fall through to standard matmul. - } - } // if constexpr BF16/float -#endif // GEMMA_ONEDNN_BRGEMM - // (Also auto-tunes, hence outside the timed section to prevent interference.) const StridedViewBF A_view = MMDecompress::MaybeDecompressA(A, per_key.autotune_par_a, env, options); diff --git a/ops/matmul.h b/ops/matmul.h index f0d95d1e..8d75f2bb 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -24,7 +24,6 @@ #include // IWYU pragma: begin_exports -#include "ops/brgemm.h" // BRGeMMConfig, GEMMA_ONEDNN_BRGEMM #include "util/basics.h" #include "util/mat.h" #include "util/threading.h" @@ -640,9 +639,6 @@ class MMKeys { struct MMPerKey { MMAutoTune autotune; MMAutoTune autotune_par_a; -#if GEMMA_ONEDNN_BRGEMM - MMAutoTune brgemm_autotune; -#endif // GEMMA_ONEDNN_BRGEMM }; // Stores state shared across MatMul calls. Non-copyable. `ctx` must outlive diff --git a/util/zones.cc b/util/zones.cc index b552bb17..aec4bbd0 100644 --- a/util/zones.cc +++ b/util/zones.cc @@ -135,8 +135,6 @@ const char* CallerName(Callers caller) { return "Att.DotSoftmaxWeightedSum"; case Callers::kBlobWriter: return "BlobWriter"; - case Callers::kBRGeMM: - return "BRGeMM"; case Callers::kCompress: return "Compress"; case Callers::kFixupWeights: diff --git a/util/zones.h b/util/zones.h index ba3d5a9b..64b859d2 100644 --- a/util/zones.h +++ b/util/zones.h @@ -81,7 +81,6 @@ enum class Callers { // Keep sorted kAttComputeQKV, kAttDotSoftmaxWeightedSum, kBlobWriter, - kBRGeMM, kCompress, kFixupWeights, kFlashAttention,