diff --git a/sycl-fusion/jit-compiler/include/Hashing.h b/sycl-fusion/jit-compiler/include/Hashing.h new file mode 100644 index 0000000000000..4491e2c13bf03 --- /dev/null +++ b/sycl-fusion/jit-compiler/include/Hashing.h @@ -0,0 +1,49 @@ +//==---- Hashing.h - helper for hashes for JIT internal representations ----==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef SYCL_FUSION_JIT_COMPILER_HASHING_H +#define SYCL_FUSION_JIT_COMPILER_HASHING_H + +#include "Parameter.h" + +#include "llvm/ADT/Hashing.h" + +#include +#include + +namespace jit_compiler { +inline llvm::hash_code hash_value(const ParameterInternalization &P) { + return llvm::hash_combine(P.LocalSize, P.Intern, P.Param); +} + +inline llvm::hash_code hash_value(const Parameter &P) { + return llvm::hash_combine(P.ParamIdx, P.KernelIdx); +} + +inline llvm::hash_code hash_value(const JITConstant &C) { + return llvm::hash_combine(C.Param, C.Value); +} + +inline llvm::hash_code hash_value(const ParameterIdentity &IP) { + return llvm::hash_combine(IP.LHS, IP.RHS); +} +} // namespace jit_compiler + +namespace std { +template inline llvm::hash_code hash_value(const vector &V) { + return llvm::hash_combine_range(V.begin(), V.end()); +} + +template struct hash> { + size_t operator()(const tuple &Tuple) const noexcept { + return llvm::hash_value(Tuple); + } +}; +} // namespace std + +#endif // SYCL_FUSION_JIT_COMPILER_HASHING_H diff --git a/sycl-fusion/jit-compiler/include/JITContext.h b/sycl-fusion/jit-compiler/include/JITContext.h index b082d85063c01..965fbdd24569f 100644 --- a/sycl-fusion/jit-compiler/include/JITContext.h +++ b/sycl-fusion/jit-compiler/include/JITContext.h @@ -12,8 +12,12 @@ #include #include #include +#include +#include #include +#include +#include "Hashing.h" #include "Kernel.h" #include "Parameter.h" @@ -23,6 +27,10 @@ class LLVMContext; namespace jit_compiler { +using CacheKeyT = + std::tuple, ParamIdentList, int, + std::vector, std::vector>; + /// /// Wrapper around a SPIR-V binary. class SPIRVBinary { @@ -51,6 +59,10 @@ class JITContext { SPIRVBinary &emplaceSPIRVBinary(std::string Binary); + std::optional getCacheEntry(CacheKeyT &Identifier) const; + + void addCacheEntry(CacheKeyT &Identifier, SYCLKernelInfo &Kernel); + private: // FIXME: Change this to std::shared_mutex after switching to C++17. using MutexT = std::shared_timed_mutex; @@ -64,6 +76,10 @@ class JITContext { MutexT BinariesMutex; std::vector Binaries; + + mutable MutexT CacheMutex; + + std::unordered_map Cache; }; } // namespace jit_compiler diff --git a/sycl-fusion/jit-compiler/include/Options.h b/sycl-fusion/jit-compiler/include/Options.h index fde8cc4e3298c..335f58fb64cf7 100644 --- a/sycl-fusion/jit-compiler/include/Options.h +++ b/sycl-fusion/jit-compiler/include/Options.h @@ -76,6 +76,8 @@ namespace option { struct JITEnableVerbose : public OptionBase {}; +struct JITEnableCaching : public OptionBase {}; + } // namespace option } // namespace jit_compiler diff --git a/sycl-fusion/jit-compiler/lib/JITContext.cpp b/sycl-fusion/jit-compiler/lib/JITContext.cpp index fba490e78670d..68c7031b9d8a9 100644 --- a/sycl-fusion/jit-compiler/lib/JITContext.cpp +++ b/sycl-fusion/jit-compiler/lib/JITContext.cpp @@ -33,3 +33,18 @@ SPIRVBinary &JITContext::emplaceSPIRVBinary(std::string Binary) { Binaries.emplace_back(std::move(Binary)); return Binaries.back(); } + +std::optional +JITContext::getCacheEntry(CacheKeyT &Identifier) const { + ReadLockT ReadLock{CacheMutex}; + auto Entry = Cache.find(Identifier); + if (Entry != Cache.end()) { + return Entry->second; + } + return {}; +} + +void JITContext::addCacheEntry(CacheKeyT &Identifier, SYCLKernelInfo &Kernel) { + WriteLockT WriteLock{CacheMutex}; + Cache.emplace(Identifier, Kernel); +} diff --git a/sycl-fusion/jit-compiler/lib/KernelFusion.cpp b/sycl-fusion/jit-compiler/lib/KernelFusion.cpp index 3ccf828b5a7c7..a29c8522d32e7 100644 --- a/sycl-fusion/jit-compiler/lib/KernelFusion.cpp +++ b/sycl-fusion/jit-compiler/lib/KernelFusion.cpp @@ -49,6 +49,19 @@ FusionResult KernelFusion::fuseKernels( // available (on a per-thread basis). ConfigHelper::setConfig(std::move(JITConfig)); + bool CachingEnabled = ConfigHelper::get(); + CacheKeyT CacheKey{KernelsToFuse, Identities, BarriersFlags, Internalization, + Constants}; + if (CachingEnabled) { + std::optional CachedKernel = JITCtx.getCacheEntry(CacheKey); + if (CachedKernel) { + helper::printDebugMessage("Re-using cached JIT kernel"); + return FusionResult{*CachedKernel, /*Cached*/ true}; + } + helper::printDebugMessage( + "Compiling new kernel, no suitable cached kernel found"); + } + SYCLModuleInfo ModuleInfo; // Copy the kernel information for the input kernels to the module // information. We could remove the copy, if we removed the const from the @@ -115,5 +128,9 @@ FusionResult KernelFusion::fuseKernels( FusedBinaryInfo.BinaryStart = SPIRVBin->address(); FusedBinaryInfo.BinarySize = SPIRVBin->size(); + if (CachingEnabled) { + JITCtx.addCacheEntry(CacheKey, FusedKernelInfo); + } + return FusionResult{FusedKernelInfo}; } diff --git a/sycl/doc/EnvironmentVariables.md b/sycl/doc/EnvironmentVariables.md index b695765a9f327..718044a40bc16 100644 --- a/sycl/doc/EnvironmentVariables.md +++ b/sycl/doc/EnvironmentVariables.md @@ -24,6 +24,7 @@ compiler and runtime. | `SYCL_USM_HOSTPTR_IMPORT` | Integer | Enable by specifying non-zero value. Buffers created with a host pointer will result in host data promotion to USM, improving data transfer performance. To use this feature, also set SYCL_HOST_UNIFIED_MEMORY=1. | | `SYCL_EAGER_INIT` | Integer | Enable by specifying non-zero value. Tells the SYCL runtime to do as much as possible initialization at objects construction as opposed to doing lazy initialization on the fly. This may mean doing some redundant work at warmup but ensures fastest possible execution on the following hot and reportable paths. It also instructs PI plugins to do the same. Default is "0". | | `SYCL_REDUCTION_PREFERRED_WORKGROUP_SIZE` | See [below](#sycl_reduction_preferred_workgroup_size) | Controls the preferred work-group size of reductions. | +| `SYCL_ENABLE_FUSION_CACHING` | '1' or '0' | Enable ('1') or disable ('0') caching of JIT compilations for kernel fusion. Caching avoids repeatedly running the JIT compilation pipeline if the same sequence of kernels is fused multiple times. Default value is '1'. | `(*) Note: Any means this environment variable is effective when set to any non-null value.` diff --git a/sycl/source/detail/config.def b/sycl/source/detail/config.def index 5dbe3b540d1e7..34e80375226f6 100644 --- a/sycl/source/detail/config.def +++ b/sycl/source/detail/config.def @@ -39,3 +39,4 @@ CONFIG(SYCL_QUEUE_THREAD_POOL_SIZE, 4, __SYCL_QUEUE_THREAD_POOL_SIZE) CONFIG(SYCL_RT_WARNING_LEVEL, 4, __SYCL_RT_WARNING_LEVEL) CONFIG(SYCL_REDUCTION_PREFERRED_WORKGROUP_SIZE, 16, __SYCL_REDUCTION_PREFERRED_WORKGROUP_SIZE) CONFIG(ONEAPI_DEVICE_SELECTOR, 1024, __ONEAPI_DEVICE_SELECTOR) +CONFIG(SYCL_ENABLE_FUSION_CACHING, 1, __SYCL_ENABLE_FUSION_CACHING) diff --git a/sycl/source/detail/config.hpp b/sycl/source/detail/config.hpp index 7322b8cfab643..7a21efb515b23 100644 --- a/sycl/source/detail/config.hpp +++ b/sycl/source/detail/config.hpp @@ -579,6 +579,34 @@ template <> class SYCLConfig { } }; +template <> class SYCLConfig { + using BaseT = SYCLConfigBase; + +public: + static bool get() { + constexpr bool DefaultValue = true; + + const char *ValStr = getCachedValue(); + + if (!ValStr) + return DefaultValue; + + return ValStr[0] == '1'; + } + + static void reset() { (void)getCachedValue(/*ResetCache=*/true); } + + static const char *getName() { return BaseT::MConfigName; } + +private: + static const char *getCachedValue(bool ResetCache = false) { + static const char *ValStr = BaseT::getRawValue(); + if (ResetCache) + ValStr = BaseT::getRawValue(); + return ValStr; + } +}; + #undef INVALID_CONFIG_EXCEPTION } // namespace detail diff --git a/sycl/source/detail/jit_compiler.cpp b/sycl/source/detail/jit_compiler.cpp index e1ee8e1c94301..4e4ab478fab0f 100644 --- a/sycl/source/detail/jit_compiler.cpp +++ b/sycl/source/detail/jit_compiler.cpp @@ -751,7 +751,8 @@ jit_compiler::fuseKernels(QueueImplPtr Queue, bool DebugEnabled = detail::SYCLConfig::get() > 0; JITConfig.set<::jit_compiler::option::JITEnableVerbose>(DebugEnabled); - // TODO: Enable caching in a separate PR. + JITConfig.set<::jit_compiler::option::JITEnableCaching>( + detail::SYCLConfig::get()); auto FusionResult = ::jit_compiler::KernelFusion::fuseKernels( *MJITContext, std::move(JITConfig), InputKernelInfo, InputKernelNames,