diff --git a/csrc/cpu/aten/kernels/WoqTppKrnl.cpp b/csrc/cpu/aten/kernels/WoqTppKrnl.cpp index b144a1f77..791df0542 100644 --- a/csrc/cpu/aten/kernels/WoqTppKrnl.cpp +++ b/csrc/cpu/aten/kernels/WoqTppKrnl.cpp @@ -1323,7 +1323,7 @@ class DequantGemmTPP< static_assert(N % 16 == 0, "N must be a multiple of 16"); if (std::is_same()) TLA_ASSERT(K % 2 == 0, "Kb must be a multiple of 2 for bfloat16"); - pgemm = std::make_shared>( + pgemm = new BrgemmTPP( M, N, K, @@ -1338,6 +1338,10 @@ class DequantGemmTPP< /*b_vnni*/ std::is_same()); } + ~DequantGemmTPP() { + delete pgemm; + } + inline void operator()( Tin* A, uint8_t* qB, @@ -1430,7 +1434,7 @@ class DequantGemmTPP< } private: - std::shared_ptr> pgemm; + BrgemmTPP* pgemm; long M; long K; long lda; @@ -1466,7 +1470,7 @@ class DequantGemmTPP< static_assert(N % 16 == 0, "N must be a multiple of 16"); TLA_ASSERT(K % 4 == 0, "Kb must be a multiple of 4 for int8 VNNI"); // TODO(jgong5): output fp32 directly - pgemm = std::make_shared( + pgemm = new TBrgemmTPP( M, N, K, @@ -1481,6 +1485,10 @@ class DequantGemmTPP< /*b_vnni*/ true); } + ~DequantGemmTPP() { + delete pgemm; + } + inline void operator()( uint8_t* A, uint8_t* qB, @@ -1620,7 +1628,7 @@ class DequantGemmTPP< } private: - std::shared_ptr pgemm; + TBrgemmTPP* pgemm; long M; long K; long lda; @@ -1654,8 +1662,8 @@ void qlinear_woq_affine_impl( const TensorList& others_list, int64_t quant_block_k, const std::optional& zps = std::nullopt, // dtype is TComp - at::Tensor t_scale_a = at::empty({1}, at::kFloat), - at::Tensor t_zp_a = at::empty({1}, at::kInt)) { + float* scales_a_ptr = nullptr, + int32_t* zps_a_ptr = nullptr) { const bool is_4bit_flag = is_4bit(qw_type); const bool sym_quant = is_sym_quant(qw_type); auto x_sizes = x.sizes(); @@ -1708,8 +1716,6 @@ void qlinear_woq_affine_impl( auto ldy = num_concats <= 1 ? N : Nc / num_concats * Nb; auto ldc = (no_y_buf || k_splits > 1) ? ldy : Nb; - auto scales_a_ptr = t_scale_a.data_ptr(); - auto zps_a_ptr = t_zp_a.data_ptr(); auto px = GetVLAPtr(x, {Kc, Kb}); auto pw = GetVLAPtr( (uint8_t*)qw_packed.data_ptr(), {Kc, Kb * (is_4bit_flag ? Nb / 2 : Nb)}); @@ -3273,8 +3279,6 @@ at::Tensor qlinear_woq_affine( x_reshape_contig, &scale_a, &zp_a); auto x_quantized = quantize_per_tensor( x_reshape_contig, scale_a, zp_a); - auto scale_a_t = at::full({1}, scale_a, at::kFloat); - auto zp_a_t = at::full({1}, zp_a, at::kInt); qlinear_woq_affine_impl< uint8_t, uint8_t, @@ -3296,8 +3300,8 @@ at::Tensor qlinear_woq_affine( others_list, quant_block_k, zp_list[int8_idx], - scale_a_t, - zp_a_t); + &scale_a, + &zp_a); } else { auto block_k = w_sizes[2]; if (quant_block_k <= 0) @@ -3312,6 +3316,8 @@ at::Tensor qlinear_woq_affine( zp_a, quant_block_k, quant_a_mode); + float* scale_a_ptr = (float*)scale_a.data_ptr(); + int32_t* zp_a_ptr = (int32_t*)zp_a.data_ptr(); range_dispatcher< long, QUANT_A_PER_K_BLOCK, @@ -3340,8 +3346,8 @@ at::Tensor qlinear_woq_affine( others_list, quant_block_k, zp_list[int8_idx], - scale_a, - zp_a); + scale_a_ptr, + zp_a_ptr); }, [&](auto quant_a_mode_) { failing_fallback(); }); } diff --git a/csrc/cpu/tpp/xsmm_functors.h b/csrc/cpu/tpp/xsmm_functors.h index 1a6241575..a9b449755 100644 --- a/csrc/cpu/tpp/xsmm_functors.h +++ b/csrc/cpu/tpp/xsmm_functors.h @@ -299,20 +299,39 @@ inline int meqn_push_ternary_op( op_metadata, type, dtype, flags); } +template +inline uint64_t string_to_hash_int( + const std::string& str, + const std::array& params) { + uint64_t hash_value = 0; + unsigned int b = 378551; + unsigned int a = 63689; + unsigned int i = 0; + for (i = 0; i < str.length(); i++) { + hash_value = hash_value * a + str[i]; + a = a * b; + } + for (auto param : params) { + hash_value = hash_value * a + param; + a = a * b; + } + return hash_value; +} + class BaseTPP { public: void* get_kernel() { auto& kernel_cache = get_kernel_cache(); void* kernel = NULL; - if (hash == "") - hash = hash_str(); + if (hash == 0) + hash = hash_int(); auto search = kernel_cache.find(hash); if (search != kernel_cache.end()) kernel = search->second; if (kernel == NULL) { kernel = build_kernel(); if (kernel == NULL) { - fprintf(stderr, "Unable to get JIT kernel for %s\n", hash.c_str()); + print_error(); exit(1); } // printf("TPP: %s @ %p\n", hash.c_str(), kernel); @@ -320,19 +339,16 @@ class BaseTPP { } return kernel; } - // We should make hash_str() public - std::string get_hash_str() { - return hash_str(); - } protected: - std::unordered_map& get_kernel_cache() { - static std::unordered_map kernel_cache; + std::unordered_map& get_kernel_cache() { + static std::unordered_map kernel_cache; return kernel_cache; } - virtual std::string hash_str() = 0; + virtual uint64_t hash_int() = 0; virtual void* build_kernel() = 0; - std::string hash = ""; + virtual void print_error() = 0; + uint64_t hash = 0; bool initialized = false; }; @@ -416,12 +432,21 @@ class UnaryTPP : public BaseTPP { } protected: - std::string hash_str() override { - char hash[200]; - snprintf( - hash, - 200, - "unary_r%d_c%d_i%d_o%d_di%d_do%d_dc%d_f%d_t%d", + uint64_t hash_int() override { + std::array params = { + rows, cols, ldi, ldo, dt_in, dt_out, dt_compute, (int)flags, type}; + uint64_t hash_value = string_to_hash_int<9>("unary", params); + return hash_value; + } + void* build_kernel() override { + libxsmm_meltw_unary_shape shape = libxsmm_create_meltw_unary_shape( + cols, rows, ldi, ldo, dt_in, dt_out, dt_compute); + return (void*)libxsmm_dispatch_meltw_unary_v2(type, shape, flags); + } + void print_error() override { + fprintf( + stderr, + "Unable to get JIT kernel for unary. Params: rows=%d, cols=%d, ldi=%d, ldo=%d, dt_in=%d, dt_out=%d, dt_compute=%d, flags=%d, type=%d\n", rows, cols, ldi, @@ -429,14 +454,8 @@ class UnaryTPP : public BaseTPP { dt_in, dt_out, dt_compute, - flags, + (int)flags, type); - return std::string(hash); - } - void* build_kernel() override { - libxsmm_meltw_unary_shape shape = libxsmm_create_meltw_unary_shape( - cols, rows, ldi, ldo, dt_in, dt_out, dt_compute); - return (void*)libxsmm_dispatch_meltw_unary_v2(type, shape, flags); } libxsmm_blasint rows = 0; @@ -515,12 +534,8 @@ class BinaryTPP : public BaseTPP { } protected: - std::string hash_str() override { - char hash[200]; - snprintf( - hash, - 200, - "binary_r%d_c%d_i0%d_i1%d_o%d_di0%d_di1%d_do%d_dc%d_f%d_t%d", + uint64_t hash_int() override { + std::array params = { rows, cols, ldi0, @@ -530,15 +545,32 @@ class BinaryTPP : public BaseTPP { dt_in1, dt_out, dt_compute, - flags, - type); - return std::string(hash); + (int)flags, + type}; + uint64_t hash_value = string_to_hash_int<11>("binary", params); + return hash_value; } void* build_kernel() override { libxsmm_meltw_binary_shape shape = libxsmm_create_meltw_binary_shape( cols, rows, ldi0, ldi1, ldo, dt_in0, dt_in1, dt_out, dt_compute); return (void*)libxsmm_dispatch_meltw_binary_v2(type, shape, flags); } + void print_error() override { + fprintf( + stderr, + "Unable to get JIT kernel for binary. Params: rows=%d, cols=%d, ldi0=%d, ldi1=%d, ldo=%d, dt_in0=%d, dt_in1=%d, dt_out=%d, dt_compute=%d, flags=%d, type=%d\n", + rows, + cols, + ldi0, + ldi1, + ldo, + dt_in0, + dt_in1, + dt_out, + dt_compute, + (int)flags, + type); + } libxsmm_blasint rows = 0; libxsmm_blasint cols = 0; @@ -965,18 +997,11 @@ class MulReduceTPP : public BaseTPP { } protected: - std::string hash_str() override { - char hash[200]; - snprintf( - hash, - 200, - "mul_reduce_eqn_t%d_%d_%d_r%d_c%d", - XsmmDtype(), - XsmmDtype(), - XsmmDtype(), - N, //); - M); - return std::string(hash); + uint64_t hash_int() override { + std::array params = { + N, M, XsmmDtype(), XsmmDtype(), XsmmDtype()}; + uint64_t hash_value = string_to_hash_int<5>("mul_reduce_eqn", params); + return hash_value; } void* build_kernel() override { auto dt1 = XsmmDtype(); @@ -1000,6 +1025,16 @@ class MulReduceTPP : public BaseTPP { debug_print_eqn_tree(my_eqn0); return (void*)meqn_dispatch(1, N, &ld, dt3, my_eqn0); } + void print_error() override { + fprintf( + stderr, + "Unable to get JIT kernel for mul_reduce_eqn. Params: N=%d, M=%d, dt1=%d, dt2=%d, dt3=%d\n", + N, + M, + XsmmDtype(), + XsmmDtype(), + XsmmDtype()); + } private: int N = 0; @@ -1958,12 +1993,8 @@ class BrgemmTPP { } protected: - std::string hash_str() override { - char hash[200]; - snprintf( - hash, - 200, - "brgemm_m%ld_n%ld_k%ld_a%ld_b%ld_t%ld_beta%d_at%d_uh%d_ld_a%ld_b%ld_c%ld_cfg%d_bv%d", + uint64_t hash_int() override { + std::array params = { p->M, p->N, p->K, @@ -1973,12 +2004,13 @@ class BrgemmTPP { (int)p->beta, p->a_trans, p->unroll_hint, - (long)p->lda, - (long)p->ldb, - (long)p->ldc, + p->lda, + p->ldb, + p->ldc, config, - p->b_vnni); - return std::string(hash); + p->b_vnni}; + uint64_t hash_value = string_to_hash_int<14>("brgemm", params); + return hash_value; } void* build_kernel() override { // float alpha = 1.0; @@ -2041,6 +2073,25 @@ class BrgemmTPP { return (void*)l_test_jit.gemm; } + void print_error() override { + fprintf( + stderr, + "Unable to get JIT kernel for brgemm. Params: M=%d, N=%d, K=%d, str_a=%d, str_b=%d, brgemm_type=%d, beta=%d, a_trans=%d, unroll_hint=%d, lda=%d, ldb=%d, ldc=%d, config=%d, b_vnni=%d", + p->M, + p->N, + p->K, + p->str_a, + p->str_b, + brgemm_type, + (int)p->beta, + p->a_trans, + p->unroll_hint, + p->lda, + p->ldb, + p->ldc, + config, + p->b_vnni); + } private: BrgemmTPP* p; @@ -2248,17 +2299,11 @@ class GeluBwdTPP : public BaseTPP { } protected: - std::string hash_str() override { - char hash[200]; - snprintf( - hash, - 200, - "gelu_bwd_eqn_t%d_%d_%d_i%d", - XsmmDtype(), - XsmmDtype(), - XsmmDtype(), - N); - return std::string(hash); + uint64_t hash_int() override { + std::array params = { + XsmmDtype(), XsmmDtype(), XsmmDtype(), N}; + uint64_t hash_value = string_to_hash_int<4>("gelu_bwd_eqn", params); + return hash_value; } void* build_kernel() override { auto dt1 = XsmmDtype(); @@ -2273,6 +2318,15 @@ class GeluBwdTPP : public BaseTPP { debug_print_eqn_tree(my_eqn0); return (void*)meqn_dispatch(N, 1, &ld, dt3, my_eqn0); } + void print_error() override { + fprintf( + stderr, + "Unable to get JIT kernel for gelu_bwd_eqn. Params: dt1=%d, dt2=%d, dt3=%d, N=%d", + XsmmDtype(), + XsmmDtype(), + XsmmDtype(), + N); + } private: int N = 0; @@ -2634,10 +2688,10 @@ class SiLUBwdTPP : public BaseTPP { } protected: - std::string hash_str() override { - char hash[200]; - snprintf(hash, 200, "silu_bwd_eqn_%d_%d", rows, cols); - return std::string(hash); + uint64_t hash_int() override { + std::array params = {rows, cols}; + uint64_t hash_value = string_to_hash_int<2>("silu_bwd_eqn", params); + return hash_value; } void* build_kernel() override { libxsmm_blasint my_eqn0 = libxsmm_matrix_eqn_create(); @@ -2657,6 +2711,13 @@ class SiLUBwdTPP : public BaseTPP { auto func0 = meqn_dispatch(cols, rows, &ldo, XsmmDtype(), my_eqn0); return (void*)func0; } + void print_error() override { + fprintf( + stderr, + "Unable to get JIT kernel for silu_bwd_eqn. Params: rows=%d, cols=%d", + rows, + cols); + } private: int rows = 0; @@ -2884,18 +2945,11 @@ class SoftMaxFwdTPP { } protected: - std::string hash_str() override { - char hash[200]; - snprintf( - hash, - 200, - "softmax_fwd_eqn0_ti%d_to%d_S1%d_S2%d_S3%d", - XsmmDtype(), - LIBXSMM_DATATYPE_F32, - S1, - S2, - S3); - return std::string(hash); + uint64_t hash_int() override { + std::array params = { + XsmmDtype(), LIBXSMM_DATATYPE_F32, S1, S2, S3}; + uint64_t hash_value = string_to_hash_int<5>("softmax_fwd_eqn0", params); + return hash_value; } void* build_kernel() override { auto dt_in = XsmmDtype(); @@ -2921,6 +2975,16 @@ class SoftMaxFwdTPP { return (void*)meqn_dispatch( S3, S1, &tmp_ld, LIBXSMM_DATATYPE_F32, my_eqn0); } + void print_error() override { + fprintf( + stderr, + "Unable to get JIT kernel for softmax_fwd_eqn0. Params: dt_in=%d, dt_out=%d, S1=%d, S2=%d, S3=%d", + XsmmDtype(), + LIBXSMM_DATATYPE_F32, + S1, + S2, + S3); + } private: int S1, S2, S3; @@ -2947,18 +3011,11 @@ class SoftMaxFwdTPP { } protected: - std::string hash_str() override { - char hash[200]; - snprintf( - hash, - 200, - "softmax_fwd_eqn1_ti%d_to%d_S1%d_S2%d_S3%d", - LIBXSMM_DATATYPE_F32, - XsmmDtype(), - S1, - S2, - S3); - return std::string(hash); + uint64_t hash_int() override { + std::array params = { + LIBXSMM_DATATYPE_F32, XsmmDtype(), S1, S2, S3}; + uint64_t hash_value = string_to_hash_int<5>("softmax_fwd_eqn1", params); + return hash_value; } void* build_kernel() override { auto dt_out = XsmmDtype(); @@ -2983,6 +3040,16 @@ class SoftMaxFwdTPP { /*debug_print_eqn_tree( my_eqn1 );*/ return (void*)meqn_dispatch(S3, S1, &ld, dt_out, my_eqn1); } + void print_error() override { + fprintf( + stderr, + "Unable to get JIT kernel for softmax_fwd_eqn1. Params: dt_in=%d, dt_out=%d, S1=%d, S2=%d, S3=%d", + LIBXSMM_DATATYPE_F32, + XsmmDtype(), + S1, + S2, + S3); + } private: int S1, S2, S3; @@ -3112,20 +3179,17 @@ class SoftMaxBwdTPP { } protected: - std::string hash_str() override { - char hash[200]; - snprintf( - hash, - 200, - "softmax_bwd_eqn%d_t1%d_t2%d_t3%d_S1%d_S2%d_S3%d", + uint64_t hash_int() override { + std::array params = { eqn_no, + XsmmDtype(), XsmmDtype(), - XsmmDtype(), LIBXSMM_DATATYPE_F32, S1, S2, - S3); - return std::string(hash); + S3}; + uint64_t hash_value = string_to_hash_int<7>("softmax_bwd_eqn", params); + return hash_value; } void* build_kernel() override { auto dt_1 = XsmmDtype(); @@ -3185,6 +3249,18 @@ class SoftMaxBwdTPP { } return (void*)func; } + void print_error() override { + fprintf( + stderr, + "Unable to get JIT kernel for softmax_bwd_eqn. Params: eqn_no=%d, dt_1=%d, dt_2=%d, dt_3=%d, S1=%d, S2=%d, S3=%d", + eqn_no, + XsmmDtype(), + XsmmDtype(), + LIBXSMM_DATATYPE_F32, + S1, + S2, + S3); + } private: int S1, S2, S3, eqn_no; @@ -3518,18 +3594,11 @@ class VarSoftMaxBwdTPP { } protected: - std::string hash_str() override { - char hash[200]; - snprintf( - hash, - 200, - "varsoftmax_bwd_eqn%d_t1%d_t2%d_t3%d_S3%d", - eqn_no, - XsmmDtype(), - XsmmDtype(), - XsmmDtype(), - S3); - return std::string(hash); + uint64_t hash_int() override { + std::array params = { + eqn_no, XsmmDtype(), XsmmDtype(), XsmmDtype(), S3}; + uint64_t hash_value = string_to_hash_int<5>("varsoftmax_bwd_eqn", params); + return hash_value; } void* build_kernel() override { auto dt_1 = XsmmDtype(); @@ -3568,6 +3637,16 @@ class VarSoftMaxBwdTPP { } return (void*)func; } + void print_error() override { + fprintf( + stderr, + "Unable to get JIT kernel for varsoftmax_bwd_eqn. Params: eqn_no=%d, dt_1=%d, dt_2=%d, dt_3=%d, S3=%d", + eqn_no, + XsmmDtype(), + XsmmDtype(), + XsmmDtype(), + S3); + } private: int S3, eqn_no; @@ -3686,17 +3765,10 @@ class LayerNormFwdTPP { } protected: - std::string hash_str() override { - char hash[200]; - snprintf( - hash, - 200, - "layernorm_fwd_eqn_t%d_S1%d_S2%d_S3%d", - XsmmDtype(), - S1, - S2, - S3); - return std::string(hash); + uint64_t hash_int() override { + std::array params = {XsmmDtype(), S1, S2, S3}; + uint64_t hash_value = string_to_hash_int<5>("layernorm_fwd_eqn", params); + return hash_value; } void* build_kernel() override { auto in_dt = XsmmDtype(); @@ -3723,6 +3795,15 @@ class LayerNormFwdTPP { debug_print_eqn_tree(my_eqn0); // printf return (void*)meqn_dispatch(S3, S1, &ld, out_dt, my_eqn0); } + void print_error() override { + fprintf( + stderr, + "Unable to get JIT kernel for layernorm_fwd_eqn. Params: dt_1=%dS1=%d, S2=%d, S3=%d", + XsmmDtype(), + S1, + S2, + S3); + } private: int S1, S2, S3; @@ -3861,18 +3942,10 @@ class LayerNormBwdTPP { } protected: - std::string hash_str() override { - char hash[200]; - snprintf( - hash, - 200, - "layernorm_bwd_eqn%d_t%d_S1%d_S2%d_S3%d", - eqn_no, - XsmmDtype(), - S1, - S2, - S3); - return std::string(hash); + uint64_t hash_int() override { + std::array params = {eqn_no, XsmmDtype(), S1, S2, S3}; + uint64_t hash_value = string_to_hash_int<6>("layernorm_bwd_eqn", params); + return hash_value; } void* build_kernel() override { auto in_dt = XsmmDtype(); @@ -3956,6 +4029,16 @@ class LayerNormBwdTPP { } return (void*)func; } + void print_error() override { + fprintf( + stderr, + "Unable to get JIT kernel for layernorm_bwd_eqn. Params: eqn_no=%d, dt_1=%d, S1=%d, S2=%d, S3=%d", + eqn_no, + XsmmDtype(), + S1, + S2, + S3); + } private: int S1, S2, S3, eqn_no; @@ -4072,17 +4155,10 @@ class GroupNormFwdTPP { } protected: - std::string hash_str() override { - char hash[200]; - snprintf( - hash, - 200, - "group_norm_fwd_eqn_t%d_S1%d_S2%d_S3%d", - XsmmDtype(), - S1, - S2, - S3); - return std::string(hash); + uint64_t hash_int() override { + std::array params = {XsmmDtype(), S1, S2, S3}; + uint64_t hash_value = string_to_hash_int<4>("groupnorm_fwd_eqn", params); + return hash_value; } void* build_kernel() override { auto in_dt = XsmmDtype(); @@ -4111,6 +4187,15 @@ class GroupNormFwdTPP { debug_print_eqn_tree(my_eqn0); // printf return (void*)meqn_dispatch(S3, S1, &ld, out_dt, my_eqn0); } + void print_error() override { + fprintf( + stderr, + "Unable to get JIT kernel for groupnorm_fwd_eqn. Params: dt_1=%d, S1=%d, S2=%d, S3=%d", + XsmmDtype(), + S1, + S2, + S3); + } private: int S1, S2, S3; @@ -4242,18 +4327,10 @@ class GroupNormBwdTPP { } protected: - std::string hash_str() override { - char hash[200]; - snprintf( - hash, - 200, - "group_norm_bwd_eqn%d_t%d_S1%d_S2%d_S3%d", - eqn_no, - XsmmDtype(), - S1, - S2, - S3); - return std::string(hash); + uint64_t hash_int() override { + std::array params = {eqn_no, XsmmDtype(), S1, S2, S3}; + uint64_t hash_value = string_to_hash_int<5>("groupnorm_bwd_eqn", params); + return hash_value; } void* build_kernel() override { auto in_dt = XsmmDtype(); @@ -4349,6 +4426,16 @@ class GroupNormBwdTPP { } return (void*)func; } + void print_error() override { + fprintf( + stderr, + "Unable to get JIT kernel for groupnorm_bwd_eqn. Params: eqn_no=%d, dt_1=%d, S1=%d, S2=%d, S3=%d", + eqn_no, + XsmmDtype(), + S1, + S2, + S3); + } private: int S1, S2, S3, eqn_no; @@ -4421,10 +4508,10 @@ class SplitSGDTPP : public BaseTPP { } protected: - std::string hash_str() override { - char hash[200]; - snprintf(hash, 200, "split_sgd_eqn_i%d", N); - return std::string(hash); + uint64_t hash_int() override { + std::array params = {N}; + uint64_t hash_value = string_to_hash_int<1>("split_sgd_eqn", params); + return hash_value; } void* build_kernel() override { libxsmm_blasint ld = N; @@ -4448,6 +4535,10 @@ class SplitSGDTPP : public BaseTPP { auto func0 = meqn_dispatch(N, 1, &ld, LIBXSMM_DATATYPE_I16, my_eqn0); return (void*)func0; } + void print_error() override { + fprintf( + stderr, "Unable to get JIT kernel for split_sgd_eqn. Params: N=%d", N); + } private: int N = 0; @@ -4662,17 +4753,11 @@ class FusedAdamWTPP { } protected: - std::string hash_str() override { - char hash[200]; - snprintf( - hash, - 200, - "fused_adamw_eqn%d_t%d_n%d_wd%d", - eqn_no, - XsmmDtype(), - p->N, - (p->weight_decay == 0.0 ? 0 : 1)); - return std::string(hash); + uint64_t hash_int() override { + std::array params = { + eqn_no, XsmmDtype(), p->N, (p->weight_decay == 0.0 ? 0 : 1)}; + uint64_t hash_value = string_to_hash_int<4>("fused_adamw_eqn", params); + return hash_value; } void* build_kernel() override { auto in_dt = XsmmDtype(); @@ -4754,6 +4839,15 @@ class FusedAdamWTPP { } return (void*)func; } + void print_error() override { + fprintf( + stderr, + "Unable to get JIT kernel for fused_adamw_eqn. Params: eqn_no=%d, dt_1=%d, N=%d, weight_decay=%d", + eqn_no, + XsmmDtype(), + p->N, + (p->weight_decay == 0.0 ? 0 : 1)); + } private: FusedAdamWTPP* p; @@ -4925,17 +5019,12 @@ class FusedSplitAdamWTPP { } protected: - std::string hash_str() override { - char hash[200]; - snprintf( - hash, - 200, - "fused_split_adamw_eqn%d_t%d_n%d_wd%d", - eqn_no, - XsmmDtype(), - p->N, - (p->weight_decay == 0.0 ? 0 : 1)); - return std::string(hash); + uint64_t hash_int() override { + std::array params = { + eqn_no, XsmmDtype(), p->N, (p->weight_decay == 0.0 ? 0 : 1)}; + uint64_t hash_value = + string_to_hash_int<4>("fused_split_adamw_eqn", params); + return hash_value; } void* build_kernel() override { auto in_dt = XsmmDtype(); @@ -5031,6 +5120,15 @@ class FusedSplitAdamWTPP { } return (void*)func; } + void print_error() override { + fprintf( + stderr, + "Unable to get JIT kernel for fused_split_adamw_eqn. Params: eqn_no=%d, dt_1=%d, N=%d, weight_decay=%d", + eqn_no, + XsmmDtype(), + p->N, + (p->weight_decay == 0.0 ? 0 : 1)); + } private: FusedSplitAdamWTPP* p; @@ -5214,17 +5312,12 @@ class FusedAdamStepTPP { } protected: - std::string hash_str() override { - char hash[200]; - snprintf( - hash, - 200, - "fused_adam_step_eqn%d_t%d_n%d_wd%d", - eqn_no, - XsmmDtype(), - p->N, - p->use_weight_decay); - return std::string(hash); + uint64_t hash_int() override { + std::array params = { + eqn_no, XsmmDtype(), p->N, p->use_weight_decay}; + uint64_t hash_value = + string_to_hash_int<4>("fused_adam_step_eqn", params); + return hash_value; } void* build_kernel() override { auto in_dt = XsmmDtype(); @@ -5315,6 +5408,15 @@ class FusedAdamStepTPP { } return (void*)func; } + void print_error() override { + fprintf( + stderr, + "Unable to get JIT kernel for fused_adam_step_eqn. Params: eqn_no=%d, dt_1=%d, N=%d, weight_decay=%d", + eqn_no, + XsmmDtype(), + p->N, + (p->use_weight_decay == 0.0 ? 0 : 1)); + } private: FusedAdamStepTPP* p;