From de0af05137c9542ef87dbac6b82bb9b828eb5c1a Mon Sep 17 00:00:00 2001 From: Hans Pabst Date: Tue, 3 May 2016 19:55:24 +0200 Subject: [PATCH] Fixed registering statically generated code (issue introduced by SoA change set). Simplified code, and disabled "basic hash" function (to be replaced eventually, aiming for lower cost hash is less important due to thread-local code cache; basic hash only applied to legacy code path). --- scripts/libxsmm_dispatch.py | 8 ++-- src/libxsmm.c | 86 ++++++++++++++++--------------------- version.txt | 2 +- 3 files changed, 41 insertions(+), 55 deletions(-) diff --git a/scripts/libxsmm_dispatch.py b/scripts/libxsmm_dispatch.py index e4d3bc9aa6..2c3121ac85 100755 --- a/scripts/libxsmm_dispatch.py +++ b/scripts/libxsmm_dispatch.py @@ -53,16 +53,16 @@ print("LIBXSMM_GEMM_DESCRIPTOR(desc, LIBXSMM_ALIGNMENT, LIBXSMM_FLAGS,") print(" " + mnksig + ", " + ldxsig + ",") print(" LIBXSMM_ALPHA, LIBXSMM_BETA, INTERNAL_PREFETCH);") - print("LIBXSMM_HASH_FUNCTION_CALL(hash, indx, LIBXSMM_HASH_FUNCTION, desc);") + print("LIBXSMM_HASH_FUNCTION_CALL(hash, indx, desc);") print("func.dmm = (libxsmm_dmmfunction)libxsmm_dmm_" + mnkstr + ";") - print("internal_register_static_code(&desc, indx, hash, func, result + indx, &cdp_reg, &cdp_tot);") + print("internal_register_static_code(&desc, indx, hash, func, result, &cdp_reg, &cdp_tot);") if (2 != precision): # only single-precision print("LIBXSMM_GEMM_DESCRIPTOR(desc, LIBXSMM_ALIGNMENT, LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_F32PREC,") print(" " + mnksig + ", " + ldxsig + ",") print(" LIBXSMM_ALPHA, LIBXSMM_BETA, INTERNAL_PREFETCH);") - print("LIBXSMM_HASH_FUNCTION_CALL(hash, indx, LIBXSMM_HASH_FUNCTION, desc);") + print("LIBXSMM_HASH_FUNCTION_CALL(hash, indx, desc);") print("func.smm = (libxsmm_smmfunction)libxsmm_smm_" + mnkstr + ";") - print("internal_register_static_code(&desc, indx, hash, func, result + indx, &csp_reg, &csp_tot);") + print("internal_register_static_code(&desc, indx, hash, func, result, &csp_reg, &csp_tot);") elif (1 < argc): print("/* no static code */") else: diff --git a/src/libxsmm.c b/src/libxsmm.c index 8b6ee684c0..05871ed9d3 100644 --- a/src/libxsmm.c +++ b/src/libxsmm.c @@ -84,7 +84,7 @@ /* alternative hash algorithm (instead of CRC32) */ #if !defined(LIBXSMM_HASH_BASIC) && !defined(LIBXSMM_REGSIZE) # if !defined(LIBXSMM_MAX_STATIC_TARGET_ARCH) || (LIBXSMM_X86_SSE4_2 > LIBXSMM_MAX_STATIC_TARGET_ARCH) -# define LIBXSMM_HASH_BASIC +/*# define LIBXSMM_HASH_BASIC*/ # endif #endif @@ -105,15 +105,13 @@ #endif #if defined(LIBXSMM_HASH_BASIC) -# define LIBXSMM_HASH_FUNCTION libxsmm_hash_npot -# define LIBXSMM_HASH_FUNCTION_CALL(HASH, INDX, HASH_FUNCTION, DESCRIPTOR) \ - HASH = (HASH_FUNCTION)(&(DESCRIPTOR), LIBXSMM_GEMM_DESCRIPTOR_SIZE, LIBXSMM_REGSIZE); \ +# define LIBXSMM_HASH_FUNCTION_CALL(HASH, INDX, DESCRIPTOR) \ + HASH = libxsmm_hash_npot(&(DESCRIPTOR), LIBXSMM_GEMM_DESCRIPTOR_SIZE, LIBXSMM_REGSIZE); \ assert((LIBXSMM_REGSIZE) > (HASH)); \ INDX = (HASH) #else -# define LIBXSMM_HASH_FUNCTION libxsmm_crc32 -# define LIBXSMM_HASH_FUNCTION_CALL(HASH, INDX, HASH_FUNCTION, DESCRIPTOR) \ - HASH = (HASH_FUNCTION)(&(DESCRIPTOR), LIBXSMM_GEMM_DESCRIPTOR_SIZE, 25071975/*seed*/); \ +# define LIBXSMM_HASH_FUNCTION_CALL(HASH, INDX, DESCRIPTOR) \ + HASH = libxsmm_crc32(&(DESCRIPTOR), LIBXSMM_GEMM_DESCRIPTOR_SIZE, 25071975/*seed*/); \ INDX = LIBXSMM_HASH_MOD(HASH, LIBXSMM_REGSIZE) #endif @@ -303,7 +301,7 @@ LIBXSMM_RETARGETABLE LIBXSMM_VISIBILITY_INTERNAL LIBXSMM_LOCK_TYPE internal_regl # define INTERNAL_FIND_CODE_JIT(DESCRIPTOR, CODE, RESULT) #endif -#define INTERNAL_FIND_CODE(DESCRIPTOR, CODE, HASH_FUNCTION, DIFF_FUNCTION) \ +#define INTERNAL_FIND_CODE(DESCRIPTOR, CODE) \ internal_regentry flux_entry; \ { \ INTERNAL_FIND_CODE_CACHE_DECL(cache_id, cache_keys, cache, cache_hit); \ @@ -312,7 +310,7 @@ LIBXSMM_RETARGETABLE LIBXSMM_VISIBILITY_INTERNAL LIBXSMM_LOCK_TYPE internal_regl INTERNAL_FIND_CODE_CACHE_BEGIN(cache_id, cache_keys, cache, cache_hit, flux_entry, DESCRIPTOR) { \ /* check if the requested xGEMM is already JITted */ \ LIBXSMM_PRAGMA_FORCEINLINE /* must precede a statement */ \ - LIBXSMM_HASH_FUNCTION_CALL(hash, i = i0, HASH_FUNCTION, *(DESCRIPTOR)); \ + LIBXSMM_HASH_FUNCTION_CALL(hash, i = i0, *(DESCRIPTOR)); \ (CODE) += i; /* actual entry */ \ do { \ INTERNAL_FIND_CODE_READ(CODE, flux_entry.function.pmm); /* read registered code */ \ @@ -321,14 +319,14 @@ LIBXSMM_RETARGETABLE LIBXSMM_VISIBILITY_INTERNAL LIBXSMM_LOCK_TYPE internal_regl if (0 == (LIBXSMM_HASH_COLLISION & flux_entry.function.imm)) { /* check for no collision */ \ /* calculate bitwise difference (deep check) */ \ LIBXSMM_PRAGMA_FORCEINLINE /* must precede a statement */ \ - diff = (DIFF_FUNCTION)(DESCRIPTOR, &internal_registry_keys[i].descriptor); \ + diff = libxsmm_gemm_diff(DESCRIPTOR, &internal_registry_keys[i].descriptor); \ if (0 != diff) { /* new collision discovered (but no code version yet) */ \ /* allow to fix-up current entry inside of the guarded/locked region */ \ flux_entry.function.pmm = 0; \ } \ } \ /* collision discovered but code version exists; perform deep check */ \ - else if (0 != (DIFF_FUNCTION)(DESCRIPTOR, &internal_registry_keys[i].descriptor)) { \ + else if (0 != libxsmm_gemm_diff(DESCRIPTOR, &internal_registry_keys[i].descriptor)) { \ /* continue linearly searching code starting at re-hashed index position */ \ const unsigned int index = LIBXSMM_HASH_MOD(LIBXSMM_HASH_VALUE(hash), LIBXSMM_REGSIZE); \ unsigned int next; \ @@ -336,7 +334,7 @@ LIBXSMM_RETARGETABLE LIBXSMM_VISIBILITY_INTERNAL LIBXSMM_LOCK_TYPE internal_regl i = i0, next = LIBXSMM_HASH_MOD(i0 + 1, LIBXSMM_REGSIZE); \ /* skip any (still invalid) descriptor which corresponds to no code, or continue on difference */ \ (0 == (CODE = (internal_registry + i))->function.pmm || \ - 0 != (diff = (DIFF_FUNCTION)(DESCRIPTOR, &internal_registry_keys[i].descriptor))) \ + 0 != (diff = libxsmm_gemm_diff(DESCRIPTOR, &internal_registry_keys[i].descriptor))) \ /* entire registry was searched and no code version was found */ \ && next != i0; \ i = next, next = LIBXSMM_HASH_MOD(i + 1, LIBXSMM_REGSIZE)); \ @@ -370,7 +368,7 @@ LIBXSMM_RETARGETABLE LIBXSMM_VISIBILITY_INTERNAL LIBXSMM_LOCK_TYPE internal_regl } \ return flux_entry.function.xmm -#define INTERNAL_DISPATCH_MAIN(DESCRIPTOR_DECL, DESC, FLAGS, M, N, K, PLDA, PLDB, PLDC, PALPHA, PBETA, PREFETCH, SELECTOR/*smm or dmm*/, HASH_FUNCTION, DIFF_FUNCTION) { \ +#define INTERNAL_DISPATCH_MAIN(DESCRIPTOR_DECL, DESC, FLAGS, M, N, K, PLDA, PLDB, PLDC, PALPHA, PBETA, PREFETCH, SELECTOR/*smm or dmm*/) { \ INTERNAL_FIND_CODE_DECLARE(code); \ const signed char scalpha = (signed char)(0 == (PALPHA) ? LIBXSMM_ALPHA : *(PALPHA)), scbeta = (signed char)(0 == (PBETA) ? LIBXSMM_BETA : *(PBETA)); \ if (0 == ((FLAGS) & (LIBXSMM_GEMM_FLAG_TRANS_A | LIBXSMM_GEMM_FLAG_TRANS_B)) && 1 == scalpha && (1 == scbeta || 0 == scbeta)) { \ @@ -381,7 +379,7 @@ return flux_entry.function.xmm 0 == (PLDC) ? LIBXSMM_LD(M, N) : *(PLDC), scalpha, scbeta, \ 0 > internal_dispatch_main_prefetch ? internal_prefetch : internal_dispatch_main_prefetch); \ { \ - INTERNAL_FIND_CODE(DESC, code, HASH_FUNCTION, DIFF_FUNCTION).SELECTOR; \ + INTERNAL_FIND_CODE(DESC, code).SELECTOR; \ } \ } \ else { /* TODO: not supported (bypass) */ \ @@ -390,52 +388,52 @@ return flux_entry.function.xmm } #if defined(LIBXSMM_GEMM_DIFF_MASK_A) /* no padding i.e., LIBXSMM_GEMM_DESCRIPTOR_SIZE */ -# define INTERNAL_DISPATCH(FLAGS, M, N, K, PLDA, PLDB, PLDC, PALPHA, PBETA, PREFETCH, SELECTOR/*smm or dmm*/, HASH_FUNCTION, DIFF_FUNCTION) \ +# define INTERNAL_DISPATCH(FLAGS, M, N, K, PLDA, PLDB, PLDC, PALPHA, PBETA, PREFETCH, SELECTOR/*smm or dmm*/) \ INTERNAL_DISPATCH_MAIN(libxsmm_gemm_descriptor descriptor, &descriptor, \ - FLAGS, M, N, K, PLDA, PLDB, PLDC, PALPHA, PBETA, PREFETCH, SELECTOR/*smm or dmm*/, HASH_FUNCTION, DIFF_FUNCTION) + FLAGS, M, N, K, PLDA, PLDB, PLDC, PALPHA, PBETA, PREFETCH, SELECTOR/*smm or dmm*/) #else /* padding: LIBXSMM_GEMM_DESCRIPTOR_SIZE -> LIBXSMM_ALIGNMENT */ -# define INTERNAL_DISPATCH(FLAGS, M, N, K, PLDA, PLDB, PLDC, PALPHA, PBETA, PREFETCH, SELECTOR/*smm or dmm*/, HASH_FUNCTION, DIFF_FUNCTION) { \ +# define INTERNAL_DISPATCH(FLAGS, M, N, K, PLDA, PLDB, PLDC, PALPHA, PBETA, PREFETCH, SELECTOR/*smm or dmm*/) { \ INTERNAL_DISPATCH_MAIN(union { libxsmm_gemm_descriptor desc; char simd[LIBXSMM_ALIGNMENT]; } simd_descriptor; \ for (i = LIBXSMM_GEMM_DESCRIPTOR_SIZE; i < sizeof(simd_descriptor.simd); ++i) simd_descriptor.simd[i] = 0, &simd_descriptor.desc, \ - FLAGS, M, N, K, PLDA, PLDB, PLDC, PALPHA, PBETA, PREFETCH, SELECTOR/*smm or dmm*/, HASH_FUNCTION, DIFF_FUNCTION) + FLAGS, M, N, K, PLDA, PLDB, PLDC, PALPHA, PBETA, PREFETCH, SELECTOR/*smm or dmm*/) #endif -#define INTERNAL_SMMDISPATCH(PFLAGS, M, N, K, PLDA, PLDB, PLDC, PALPHA, PBETA, PREFETCH, HASH_FUNCTION, DIFF_FUNCTION) \ +#define INTERNAL_SMMDISPATCH(PFLAGS, M, N, K, PLDA, PLDB, PLDC, PALPHA, PBETA, PREFETCH) \ INTERNAL_DISPATCH((0 == (PFLAGS) ? LIBXSMM_FLAGS : *(PFLAGS)) | LIBXSMM_GEMM_FLAG_F32PREC, \ - M, N, K, PLDA, PLDB, PLDC, PALPHA, PBETA, PREFETCH, smm, HASH_FUNCTION, DIFF_FUNCTION) - -#define INTERNAL_DMMDISPATCH(PFLAGS, M, N, K, PLDA, PLDB, PLDC, PALPHA, PBETA, PREFETCH, HASH_FUNCTION, DIFF_FUNCTION) \ + M, N, K, PLDA, PLDB, PLDC, PALPHA, PBETA, PREFETCH, smm) +#define INTERNAL_DMMDISPATCH(PFLAGS, M, N, K, PLDA, PLDB, PLDC, PALPHA, PBETA, PREFETCH) \ INTERNAL_DISPATCH((0 == (PFLAGS) ? LIBXSMM_FLAGS : *(PFLAGS)), \ - M, N, K, PLDA, PLDB, PLDC, PALPHA, PBETA, PREFETCH, dmm, HASH_FUNCTION, DIFF_FUNCTION) + M, N, K, PLDA, PLDB, PLDC, PALPHA, PBETA, PREFETCH, dmm) LIBXSMM_INLINE LIBXSMM_RETARGETABLE void internal_register_static_code( const libxsmm_gemm_descriptor* desc, unsigned int index, unsigned int hash, libxsmm_xmmfunction src, - internal_regentry* dst_entries, unsigned int* registered, unsigned int* total) + internal_regentry* registry, unsigned int* registered, unsigned int* total) { - internal_regkey* dst_keys = internal_registry_keys; - assert(0 != desc && 0 != src.dmm && 0 != dst_keys && 0 != dst_entries && 0 != registered && 0 != total); + internal_regkey* dst_key = internal_registry_keys + index; + internal_regentry* dst_entry = registry + index; + assert(0 != desc && 0 != src.dmm && 0 != dst_key && 0 != registry && 0 != registered && 0 != total); - if (0 != dst_entries->function.pmm) { /* collision? */ + if (0 != dst_entry->function.pmm) { /* collision? */ /* start at a re-hashed index position */ const unsigned int start = LIBXSMM_HASH_MOD(LIBXSMM_HASH_VALUE(hash), LIBXSMM_REGSIZE); - internal_regentry *const registry = dst_entries - index; /* recalculate base address */ unsigned int i0, i, next; /* mark current entry as a collision (this might be already the case) */ - dst_entries->function.imm |= LIBXSMM_HASH_COLLISION; + dst_entry->function.imm |= LIBXSMM_HASH_COLLISION; /* start linearly searching for an available slot */ for (i = (start != index) ? start : LIBXSMM_HASH_MOD(start + 1, LIBXSMM_REGSIZE), i0 = i, next = LIBXSMM_HASH_MOD(i + 1, LIBXSMM_REGSIZE); - 0 != (dst_entries = registry + i)->function.pmm && next != i0; i = next, next = LIBXSMM_HASH_MOD(i + 1, LIBXSMM_REGSIZE)); + 0 != (dst_entry = registry + i)->function.pmm && next != i0; i = next, next = LIBXSMM_HASH_MOD(i + 1, LIBXSMM_REGSIZE)); - dst_keys += i; + /* corresponding key position */ + dst_key = internal_registry_keys + i; } - if (0 == dst_entries->function.pmm) { /* registry not (yet) exhausted */ - dst_entries->function.xmm = src; - dst_entries->size = 0; /* statically generated code */ - dst_keys->descriptor = *desc; + if (0 == dst_entry->function.pmm) { /* registry not (yet) exhausted */ + dst_entry->function.xmm = src; + dst_entry->size = 0; /* statically generated code */ + dst_key->descriptor = *desc; ++(*registered); } @@ -951,11 +949,7 @@ LIBXSMM_INLINE LIBXSMM_RETARGETABLE libxsmm_xmmfunction internal_xmmdispatch(con INTERNAL_FIND_CODE_DECLARE(code); assert(descriptor); { -#if defined(LIBXSMM_HASH_BASIC) - INTERNAL_FIND_CODE(descriptor, code, libxsmm_hash_npot, libxsmm_gemm_diff); -#else - INTERNAL_FIND_CODE(descriptor, code, libxsmm_crc32, libxsmm_gemm_diff); -#endif + INTERNAL_FIND_CODE(descriptor, code); } } @@ -972,11 +966,7 @@ LIBXSMM_EXTERN_C LIBXSMM_RETARGETABLE libxsmm_smmfunction libxsmm_smmdispatch(in const float* alpha, const float* beta, const int* flags, const int* prefetch) { -#if defined(LIBXSMM_HASH_BASIC) - INTERNAL_SMMDISPATCH(flags, m, n, k, lda, ldb, ldc, alpha, beta, prefetch, libxsmm_hash_npot, libxsmm_gemm_diff); -#else - INTERNAL_SMMDISPATCH(flags, m, n, k, lda, ldb, ldc, alpha, beta, prefetch, libxsmm_crc32, libxsmm_gemm_diff); -#endif + INTERNAL_SMMDISPATCH(flags, m, n, k, lda, ldb, ldc, alpha, beta, prefetch); } @@ -985,10 +975,6 @@ LIBXSMM_EXTERN_C LIBXSMM_RETARGETABLE libxsmm_dmmfunction libxsmm_dmmdispatch(in const double* alpha, const double* beta, const int* flags, const int* prefetch) { -#if defined(LIBXSMM_HASH_BASIC) - INTERNAL_DMMDISPATCH(flags, m, n, k, lda, ldb, ldc, alpha, beta, prefetch, libxsmm_hash_npot, libxsmm_gemm_diff); -#else - INTERNAL_DMMDISPATCH(flags, m, n, k, lda, ldb, ldc, alpha, beta, prefetch, libxsmm_crc32, libxsmm_gemm_diff); -#endif + INTERNAL_DMMDISPATCH(flags, m, n, k, lda, ldb, ldc, alpha, beta, prefetch); } diff --git a/version.txt b/version.txt index 56db1af24e..a9b11034e0 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -master-1.4-35 +master-1.4-36