Skip to content

Commit

Permalink
Fixed registering statically generated code (issue introduced by SoA …
Browse files Browse the repository at this point in the history
…change set). Simplified the code and disabled the "basic hash" function (to be replaced eventually; aiming for a lower-cost hash is less important now because of the thread-local code cache, and the basic hash was only applied to the legacy code path).
  • Loading branch information
hfp committed May 3, 2016
1 parent 46d014d commit de0af05
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 55 deletions.
8 changes: 4 additions & 4 deletions scripts/libxsmm_dispatch.py
Expand Up @@ -53,16 +53,16 @@
print("LIBXSMM_GEMM_DESCRIPTOR(desc, LIBXSMM_ALIGNMENT, LIBXSMM_FLAGS,")
print(" " + mnksig + ", " + ldxsig + ",")
print(" LIBXSMM_ALPHA, LIBXSMM_BETA, INTERNAL_PREFETCH);")
print("LIBXSMM_HASH_FUNCTION_CALL(hash, indx, LIBXSMM_HASH_FUNCTION, desc);")
print("LIBXSMM_HASH_FUNCTION_CALL(hash, indx, desc);")
print("func.dmm = (libxsmm_dmmfunction)libxsmm_dmm_" + mnkstr + ";")
print("internal_register_static_code(&desc, indx, hash, func, result + indx, &cdp_reg, &cdp_tot);")
print("internal_register_static_code(&desc, indx, hash, func, result, &cdp_reg, &cdp_tot);")
if (2 != precision): # only single-precision
print("LIBXSMM_GEMM_DESCRIPTOR(desc, LIBXSMM_ALIGNMENT, LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_F32PREC,")
print(" " + mnksig + ", " + ldxsig + ",")
print(" LIBXSMM_ALPHA, LIBXSMM_BETA, INTERNAL_PREFETCH);")
print("LIBXSMM_HASH_FUNCTION_CALL(hash, indx, LIBXSMM_HASH_FUNCTION, desc);")
print("LIBXSMM_HASH_FUNCTION_CALL(hash, indx, desc);")
print("func.smm = (libxsmm_smmfunction)libxsmm_smm_" + mnkstr + ";")
print("internal_register_static_code(&desc, indx, hash, func, result + indx, &csp_reg, &csp_tot);")
print("internal_register_static_code(&desc, indx, hash, func, result, &csp_reg, &csp_tot);")
elif (1 < argc):
print("/* no static code */")
else:
Expand Down
86 changes: 36 additions & 50 deletions src/libxsmm.c
Expand Up @@ -84,7 +84,7 @@
/* alternative hash algorithm (instead of CRC32) */
#if !defined(LIBXSMM_HASH_BASIC) && !defined(LIBXSMM_REGSIZE)
# if !defined(LIBXSMM_MAX_STATIC_TARGET_ARCH) || (LIBXSMM_X86_SSE4_2 > LIBXSMM_MAX_STATIC_TARGET_ARCH)
# define LIBXSMM_HASH_BASIC
/*# define LIBXSMM_HASH_BASIC*/
# endif
#endif

Expand All @@ -105,15 +105,13 @@
#endif

#if defined(LIBXSMM_HASH_BASIC)
# define LIBXSMM_HASH_FUNCTION libxsmm_hash_npot
# define LIBXSMM_HASH_FUNCTION_CALL(HASH, INDX, HASH_FUNCTION, DESCRIPTOR) \
HASH = (HASH_FUNCTION)(&(DESCRIPTOR), LIBXSMM_GEMM_DESCRIPTOR_SIZE, LIBXSMM_REGSIZE); \
# define LIBXSMM_HASH_FUNCTION_CALL(HASH, INDX, DESCRIPTOR) \
HASH = libxsmm_hash_npot(&(DESCRIPTOR), LIBXSMM_GEMM_DESCRIPTOR_SIZE, LIBXSMM_REGSIZE); \
assert((LIBXSMM_REGSIZE) > (HASH)); \
INDX = (HASH)
#else
# define LIBXSMM_HASH_FUNCTION libxsmm_crc32
# define LIBXSMM_HASH_FUNCTION_CALL(HASH, INDX, HASH_FUNCTION, DESCRIPTOR) \
HASH = (HASH_FUNCTION)(&(DESCRIPTOR), LIBXSMM_GEMM_DESCRIPTOR_SIZE, 25071975/*seed*/); \
# define LIBXSMM_HASH_FUNCTION_CALL(HASH, INDX, DESCRIPTOR) \
HASH = libxsmm_crc32(&(DESCRIPTOR), LIBXSMM_GEMM_DESCRIPTOR_SIZE, 25071975/*seed*/); \
INDX = LIBXSMM_HASH_MOD(HASH, LIBXSMM_REGSIZE)
#endif

Expand Down Expand Up @@ -303,7 +301,7 @@ LIBXSMM_RETARGETABLE LIBXSMM_VISIBILITY_INTERNAL LIBXSMM_LOCK_TYPE internal_regl
# define INTERNAL_FIND_CODE_JIT(DESCRIPTOR, CODE, RESULT)
#endif

#define INTERNAL_FIND_CODE(DESCRIPTOR, CODE, HASH_FUNCTION, DIFF_FUNCTION) \
#define INTERNAL_FIND_CODE(DESCRIPTOR, CODE) \
internal_regentry flux_entry; \
{ \
INTERNAL_FIND_CODE_CACHE_DECL(cache_id, cache_keys, cache, cache_hit); \
Expand All @@ -312,7 +310,7 @@ LIBXSMM_RETARGETABLE LIBXSMM_VISIBILITY_INTERNAL LIBXSMM_LOCK_TYPE internal_regl
INTERNAL_FIND_CODE_CACHE_BEGIN(cache_id, cache_keys, cache, cache_hit, flux_entry, DESCRIPTOR) { \
/* check if the requested xGEMM is already JITted */ \
LIBXSMM_PRAGMA_FORCEINLINE /* must precede a statement */ \
LIBXSMM_HASH_FUNCTION_CALL(hash, i = i0, HASH_FUNCTION, *(DESCRIPTOR)); \
LIBXSMM_HASH_FUNCTION_CALL(hash, i = i0, *(DESCRIPTOR)); \
(CODE) += i; /* actual entry */ \
do { \
INTERNAL_FIND_CODE_READ(CODE, flux_entry.function.pmm); /* read registered code */ \
Expand All @@ -321,22 +319,22 @@ LIBXSMM_RETARGETABLE LIBXSMM_VISIBILITY_INTERNAL LIBXSMM_LOCK_TYPE internal_regl
if (0 == (LIBXSMM_HASH_COLLISION & flux_entry.function.imm)) { /* check for no collision */ \
/* calculate bitwise difference (deep check) */ \
LIBXSMM_PRAGMA_FORCEINLINE /* must precede a statement */ \
diff = (DIFF_FUNCTION)(DESCRIPTOR, &internal_registry_keys[i].descriptor); \
diff = libxsmm_gemm_diff(DESCRIPTOR, &internal_registry_keys[i].descriptor); \
if (0 != diff) { /* new collision discovered (but no code version yet) */ \
/* allow to fix-up current entry inside of the guarded/locked region */ \
flux_entry.function.pmm = 0; \
} \
} \
/* collision discovered but code version exists; perform deep check */ \
else if (0 != (DIFF_FUNCTION)(DESCRIPTOR, &internal_registry_keys[i].descriptor)) { \
else if (0 != libxsmm_gemm_diff(DESCRIPTOR, &internal_registry_keys[i].descriptor)) { \
/* continue linearly searching code starting at re-hashed index position */ \
const unsigned int index = LIBXSMM_HASH_MOD(LIBXSMM_HASH_VALUE(hash), LIBXSMM_REGSIZE); \
unsigned int next; \
for (i0 = (index != i ? index : LIBXSMM_HASH_MOD(index + 1, LIBXSMM_REGSIZE)), \
i = i0, next = LIBXSMM_HASH_MOD(i0 + 1, LIBXSMM_REGSIZE); \
/* skip any (still invalid) descriptor which corresponds to no code, or continue on difference */ \
(0 == (CODE = (internal_registry + i))->function.pmm || \
0 != (diff = (DIFF_FUNCTION)(DESCRIPTOR, &internal_registry_keys[i].descriptor))) \
0 != (diff = libxsmm_gemm_diff(DESCRIPTOR, &internal_registry_keys[i].descriptor))) \
/* entire registry was searched and no code version was found */ \
&& next != i0; \
i = next, next = LIBXSMM_HASH_MOD(i + 1, LIBXSMM_REGSIZE)); \
Expand Down Expand Up @@ -370,7 +368,7 @@ LIBXSMM_RETARGETABLE LIBXSMM_VISIBILITY_INTERNAL LIBXSMM_LOCK_TYPE internal_regl
} \
return flux_entry.function.xmm

#define INTERNAL_DISPATCH_MAIN(DESCRIPTOR_DECL, DESC, FLAGS, M, N, K, PLDA, PLDB, PLDC, PALPHA, PBETA, PREFETCH, SELECTOR/*smm or dmm*/, HASH_FUNCTION, DIFF_FUNCTION) { \
#define INTERNAL_DISPATCH_MAIN(DESCRIPTOR_DECL, DESC, FLAGS, M, N, K, PLDA, PLDB, PLDC, PALPHA, PBETA, PREFETCH, SELECTOR/*smm or dmm*/) { \
INTERNAL_FIND_CODE_DECLARE(code); \
const signed char scalpha = (signed char)(0 == (PALPHA) ? LIBXSMM_ALPHA : *(PALPHA)), scbeta = (signed char)(0 == (PBETA) ? LIBXSMM_BETA : *(PBETA)); \
if (0 == ((FLAGS) & (LIBXSMM_GEMM_FLAG_TRANS_A | LIBXSMM_GEMM_FLAG_TRANS_B)) && 1 == scalpha && (1 == scbeta || 0 == scbeta)) { \
Expand All @@ -381,7 +379,7 @@ return flux_entry.function.xmm
0 == (PLDC) ? LIBXSMM_LD(M, N) : *(PLDC), scalpha, scbeta, \
0 > internal_dispatch_main_prefetch ? internal_prefetch : internal_dispatch_main_prefetch); \
{ \
INTERNAL_FIND_CODE(DESC, code, HASH_FUNCTION, DIFF_FUNCTION).SELECTOR; \
INTERNAL_FIND_CODE(DESC, code).SELECTOR; \
} \
} \
else { /* TODO: not supported (bypass) */ \
Expand All @@ -390,52 +388,52 @@ return flux_entry.function.xmm
}

#if defined(LIBXSMM_GEMM_DIFF_MASK_A) /* no padding i.e., LIBXSMM_GEMM_DESCRIPTOR_SIZE */
# define INTERNAL_DISPATCH(FLAGS, M, N, K, PLDA, PLDB, PLDC, PALPHA, PBETA, PREFETCH, SELECTOR/*smm or dmm*/, HASH_FUNCTION, DIFF_FUNCTION) \
# define INTERNAL_DISPATCH(FLAGS, M, N, K, PLDA, PLDB, PLDC, PALPHA, PBETA, PREFETCH, SELECTOR/*smm or dmm*/) \
INTERNAL_DISPATCH_MAIN(libxsmm_gemm_descriptor descriptor, &descriptor, \
FLAGS, M, N, K, PLDA, PLDB, PLDC, PALPHA, PBETA, PREFETCH, SELECTOR/*smm or dmm*/, HASH_FUNCTION, DIFF_FUNCTION)
FLAGS, M, N, K, PLDA, PLDB, PLDC, PALPHA, PBETA, PREFETCH, SELECTOR/*smm or dmm*/)
#else /* padding: LIBXSMM_GEMM_DESCRIPTOR_SIZE -> LIBXSMM_ALIGNMENT */
# define INTERNAL_DISPATCH(FLAGS, M, N, K, PLDA, PLDB, PLDC, PALPHA, PBETA, PREFETCH, SELECTOR/*smm or dmm*/, HASH_FUNCTION, DIFF_FUNCTION) { \
# define INTERNAL_DISPATCH(FLAGS, M, N, K, PLDA, PLDB, PLDC, PALPHA, PBETA, PREFETCH, SELECTOR/*smm or dmm*/) { \
INTERNAL_DISPATCH_MAIN(union { libxsmm_gemm_descriptor desc; char simd[LIBXSMM_ALIGNMENT]; } simd_descriptor; \
for (i = LIBXSMM_GEMM_DESCRIPTOR_SIZE; i < sizeof(simd_descriptor.simd); ++i) simd_descriptor.simd[i] = 0, &simd_descriptor.desc, \
FLAGS, M, N, K, PLDA, PLDB, PLDC, PALPHA, PBETA, PREFETCH, SELECTOR/*smm or dmm*/, HASH_FUNCTION, DIFF_FUNCTION)
FLAGS, M, N, K, PLDA, PLDB, PLDC, PALPHA, PBETA, PREFETCH, SELECTOR/*smm or dmm*/)
#endif

#define INTERNAL_SMMDISPATCH(PFLAGS, M, N, K, PLDA, PLDB, PLDC, PALPHA, PBETA, PREFETCH, HASH_FUNCTION, DIFF_FUNCTION) \
#define INTERNAL_SMMDISPATCH(PFLAGS, M, N, K, PLDA, PLDB, PLDC, PALPHA, PBETA, PREFETCH) \
INTERNAL_DISPATCH((0 == (PFLAGS) ? LIBXSMM_FLAGS : *(PFLAGS)) | LIBXSMM_GEMM_FLAG_F32PREC, \
M, N, K, PLDA, PLDB, PLDC, PALPHA, PBETA, PREFETCH, smm, HASH_FUNCTION, DIFF_FUNCTION)

#define INTERNAL_DMMDISPATCH(PFLAGS, M, N, K, PLDA, PLDB, PLDC, PALPHA, PBETA, PREFETCH, HASH_FUNCTION, DIFF_FUNCTION) \
M, N, K, PLDA, PLDB, PLDC, PALPHA, PBETA, PREFETCH, smm)
#define INTERNAL_DMMDISPATCH(PFLAGS, M, N, K, PLDA, PLDB, PLDC, PALPHA, PBETA, PREFETCH) \
INTERNAL_DISPATCH((0 == (PFLAGS) ? LIBXSMM_FLAGS : *(PFLAGS)), \
M, N, K, PLDA, PLDB, PLDC, PALPHA, PBETA, PREFETCH, dmm, HASH_FUNCTION, DIFF_FUNCTION)
M, N, K, PLDA, PLDB, PLDC, PALPHA, PBETA, PREFETCH, dmm)


LIBXSMM_INLINE LIBXSMM_RETARGETABLE void internal_register_static_code(
const libxsmm_gemm_descriptor* desc, unsigned int index, unsigned int hash, libxsmm_xmmfunction src,
internal_regentry* dst_entries, unsigned int* registered, unsigned int* total)
internal_regentry* registry, unsigned int* registered, unsigned int* total)
{
internal_regkey* dst_keys = internal_registry_keys;
assert(0 != desc && 0 != src.dmm && 0 != dst_keys && 0 != dst_entries && 0 != registered && 0 != total);
internal_regkey* dst_key = internal_registry_keys + index;
internal_regentry* dst_entry = registry + index;
assert(0 != desc && 0 != src.dmm && 0 != dst_key && 0 != registry && 0 != registered && 0 != total);

if (0 != dst_entries->function.pmm) { /* collision? */
if (0 != dst_entry->function.pmm) { /* collision? */
/* start at a re-hashed index position */
const unsigned int start = LIBXSMM_HASH_MOD(LIBXSMM_HASH_VALUE(hash), LIBXSMM_REGSIZE);
internal_regentry *const registry = dst_entries - index; /* recalculate base address */
unsigned int i0, i, next;

/* mark current entry as a collision (this might be already the case) */
dst_entries->function.imm |= LIBXSMM_HASH_COLLISION;
dst_entry->function.imm |= LIBXSMM_HASH_COLLISION;

/* start linearly searching for an available slot */
for (i = (start != index) ? start : LIBXSMM_HASH_MOD(start + 1, LIBXSMM_REGSIZE), i0 = i, next = LIBXSMM_HASH_MOD(i + 1, LIBXSMM_REGSIZE);
0 != (dst_entries = registry + i)->function.pmm && next != i0; i = next, next = LIBXSMM_HASH_MOD(i + 1, LIBXSMM_REGSIZE));
0 != (dst_entry = registry + i)->function.pmm && next != i0; i = next, next = LIBXSMM_HASH_MOD(i + 1, LIBXSMM_REGSIZE));

dst_keys += i;
/* corresponding key position */
dst_key = internal_registry_keys + i;
}

if (0 == dst_entries->function.pmm) { /* registry not (yet) exhausted */
dst_entries->function.xmm = src;
dst_entries->size = 0; /* statically generated code */
dst_keys->descriptor = *desc;
if (0 == dst_entry->function.pmm) { /* registry not (yet) exhausted */
dst_entry->function.xmm = src;
dst_entry->size = 0; /* statically generated code */
dst_key->descriptor = *desc;
++(*registered);
}

Expand Down Expand Up @@ -951,11 +949,7 @@ LIBXSMM_INLINE LIBXSMM_RETARGETABLE libxsmm_xmmfunction internal_xmmdispatch(con
INTERNAL_FIND_CODE_DECLARE(code);
assert(descriptor);
{
#if defined(LIBXSMM_HASH_BASIC)
INTERNAL_FIND_CODE(descriptor, code, libxsmm_hash_npot, libxsmm_gemm_diff);
#else
INTERNAL_FIND_CODE(descriptor, code, libxsmm_crc32, libxsmm_gemm_diff);
#endif
INTERNAL_FIND_CODE(descriptor, code);
}
}

Expand All @@ -972,11 +966,7 @@ LIBXSMM_EXTERN_C LIBXSMM_RETARGETABLE libxsmm_smmfunction libxsmm_smmdispatch(in
const float* alpha, const float* beta,
const int* flags, const int* prefetch)
{
#if defined(LIBXSMM_HASH_BASIC)
INTERNAL_SMMDISPATCH(flags, m, n, k, lda, ldb, ldc, alpha, beta, prefetch, libxsmm_hash_npot, libxsmm_gemm_diff);
#else
INTERNAL_SMMDISPATCH(flags, m, n, k, lda, ldb, ldc, alpha, beta, prefetch, libxsmm_crc32, libxsmm_gemm_diff);
#endif
INTERNAL_SMMDISPATCH(flags, m, n, k, lda, ldb, ldc, alpha, beta, prefetch);
}


Expand All @@ -985,10 +975,6 @@ LIBXSMM_EXTERN_C LIBXSMM_RETARGETABLE libxsmm_dmmfunction libxsmm_dmmdispatch(in
const double* alpha, const double* beta,
const int* flags, const int* prefetch)
{
#if defined(LIBXSMM_HASH_BASIC)
INTERNAL_DMMDISPATCH(flags, m, n, k, lda, ldb, ldc, alpha, beta, prefetch, libxsmm_hash_npot, libxsmm_gemm_diff);
#else
INTERNAL_DMMDISPATCH(flags, m, n, k, lda, ldb, ldc, alpha, beta, prefetch, libxsmm_crc32, libxsmm_gemm_diff);
#endif
INTERNAL_DMMDISPATCH(flags, m, n, k, lda, ldb, ldc, alpha, beta, prefetch);
}

2 changes: 1 addition & 1 deletion version.txt
@@ -1 +1 @@
master-1.4-35
master-1.4-36

0 comments on commit de0af05

Please sign in to comment.