From 58fcb4173b3327f2c058ec54910c356e94da57ba Mon Sep 17 00:00:00 2001
From: Hans Pabst
Date: Sat, 7 May 2016 22:07:25 +0200
Subject: [PATCH] Implemented MATMUL-style routines for single-precision and
 double-precision as well as overloaded routines. Avoid relying on
 short-circuit evaluation in IF condition (construct_?mmfunction FUNCTIONs).

---
Reviewer notes (this section is not part of the commit message):
- Edited after export: libxsmm_smatmul/libxsmm_dmatmul now forward the
  resolved otransb (instead of the OPTIONAL, possibly absent transb) to
  libxsmm_?gemm, and both routines gained a header comment. The commit id
  above and the blob ids on the "index" lines still refer to the patch as
  originally exported.
- The line structure was restored from a whitespace-collapsed copy; the
  blanks of the fixed-form lines (indentation, column-73 continuation
  marks) are reconstructed. If context does not match, apply with
  --ignore-whitespace.

 src/libxsmm.template.f | 85 ++++++++++++++++++++++++++++++++++++++++++---
 version.txt            |  2 +-
 2 files changed, 80 insertions(+), 7 deletions(-)

diff --git a/src/libxsmm.template.f b/src/libxsmm.template.f
index ebf438394c..83506101f4 100644
--- a/src/libxsmm.template.f
+++ b/src/libxsmm.template.f
@@ -235,6 +235,11 @@ PURE SUBROUTINE LIBXSMM_MMFUNCTION1(a, b, c, &
           MODULE PROCEDURE libxsmm_blas_sgemm, libxsmm_blas_dgemm
         END INTERFACE
 
+        ! Overloaded MATMUL-style routines (single/double precision).
+        INTERFACE libxsmm_matmul
+          MODULE PROCEDURE libxsmm_smatmul, libxsmm_dmatmul
+        END INTERFACE
+
         !DIR$ ATTRIBUTES OFFLOAD:MIC :: libxsmm_init, libxsmm_finalize
         !DIR$ ATTRIBUTES OFFLOAD:MIC :: libxsmm_get_target_arch
         !DIR$ ATTRIBUTES OFFLOAD:MIC :: libxsmm_set_target_arch
@@ -356,6 +361,7 @@ TYPE(LIBXSMM_SMMFUNCTION) FUNCTION construct_smmfunction( &
           !DIR$ ATTRIBUTES OFFLOAD:MIC :: fn1
           PROCEDURE(LIBXSMM_MMFUNCTION1), POINTER :: fn1
           !DIR$ ATTRIBUTES OFFLOAD:MIC :: sdispatch
+          INTEGER(C_INT) :: oprefetch
           INTERFACE
             TYPE(C_FUNPTR) PURE FUNCTION sdispatch(                     &
      &      m, n, k, lda, ldb, ldc, alpha, beta, flags, prefetch)       &
@@ -367,9 +373,12 @@ TYPE(C_FUNPTR) PURE FUNCTION sdispatch( &
               INTEGER(C_INT), INTENT(IN) :: flags, prefetch
             END FUNCTION
           END INTERFACE
-          IF (.NOT.PRESENT(prefetch).OR.                                &
-     &        LIBXSMM_PREFETCH_NONE.EQ.prefetch)                        &
-     &    THEN
+          IF (.NOT.PRESENT(prefetch)) THEN
+            oprefetch = LIBXSMM_PREFETCH_NONE
+          ELSE
+            oprefetch = prefetch
+          END IF
+          IF (LIBXSMM_PREFETCH_NONE.EQ.oprefetch) THEN
             CALL C_F_PROCPOINTER(sdispatch(m, n, k,                     &
      &        lda, ldb, ldc, alpha, beta, flags, prefetch),             &
      &        fn0)
@@ -396,6 +405,7 @@ TYPE(LIBXSMM_DMMFUNCTION) FUNCTION construct_dmmfunction( &
           !DIR$ ATTRIBUTES OFFLOAD:MIC :: fn1
           PROCEDURE(LIBXSMM_MMFUNCTION1), POINTER :: fn1
           !DIR$ ATTRIBUTES OFFLOAD:MIC :: ddispatch
+          INTEGER(C_INT) :: oprefetch
           INTERFACE
             PURE FUNCTION ddispatch(                                    &
      &      m, n, k, lda, ldb, ldc, alpha, beta, flags, prefetch)       &
@@ -408,9 +418,12 @@ PURE FUNCTION ddispatch( &
               TYPE(C_FUNPTR) :: fn
             END FUNCTION
           END INTERFACE
-          IF (.NOT.PRESENT(prefetch).OR.                                &
-     &        LIBXSMM_PREFETCH_NONE.EQ.prefetch)                        &
-     &    THEN
+          IF (.NOT.PRESENT(prefetch)) THEN
+            oprefetch = LIBXSMM_PREFETCH_NONE
+          ELSE
+            oprefetch = prefetch
+          END IF
+          IF (LIBXSMM_PREFETCH_NONE.EQ.oprefetch) THEN
             CALL C_F_PROCPOINTER(ddispatch(m, n, k,                     &
      &        lda, ldb, ldc, alpha, beta, flags, prefetch),             &
      &        fn0)
@@ -728,4 +741,64 @@ SUBROUTINE internal_gemm(transa, transb, m, n, k, &
           CALL internal_gemm(transa, transb, m, n, k,                   &
      &      alpha, a, lda, b, ldb, beta, c, ldc)
         END SUBROUTINE
+
+        ! MATMUL-style front-end of libxsmm_sgemm: the sizes and leading
+        ! dimensions are taken from the shapes of c, a, and b. An absent
+        ! transb means 'N'; any value other than 'N'/'n' multiplies by
+        ! TRANSPOSE(b). alpha/beta are forwarded unchanged.
+        !DIR$ ATTRIBUTES OFFLOAD:MIC :: libxsmm_smatmul
+        SUBROUTINE libxsmm_smatmul(c, a, b, alpha, beta, transb)
+          REAL(C_FLOAT), INTENT(INOUT) :: c(:,:)
+          REAL(C_FLOAT), INTENT(IN) :: a(:,:), b(:,:)
+          REAL(C_FLOAT), INTENT(IN), OPTIONAL :: alpha, beta
+          CHARACTER, INTENT(IN), OPTIONAL :: transb
+          CHARACTER :: otransb
+          IF (.NOT.PRESENT(transb)) THEN
+            otransb = 'N'
+          ELSE
+            otransb = transb
+          END IF
+          IF (('N'.EQ.otransb).OR.('n'.EQ.otransb)) THEN
+            CALL libxsmm_sgemm('N', otransb,                            &
+     &        SIZE(c, 1), SIZE(c, 2), SIZE(a, 2),                       &
+     &        alpha, a, SIZE(a, 1),                                     &
+     &        b, SIZE(b, 1), beta, c, SIZE(c, 1))
+          ELSE
+            ! TODO: transpose is currently not supported by LIBXSMM
+            CALL libxsmm_sgemm('N', 'N',                                &
+     &        SIZE(c, 1), SIZE(c, 2), SIZE(a, 2),                       &
+     &        alpha, a, SIZE(a, 1),                                     &
+     &        TRANSPOSE(b), SIZE(b, 2), beta, c, SIZE(c, 1))
+          END IF
+        END SUBROUTINE
+
+        ! MATMUL-style front-end of libxsmm_dgemm: the sizes and leading
+        ! dimensions are taken from the shapes of c, a, and b. An absent
+        ! transb means 'N'; any value other than 'N'/'n' multiplies by
+        ! TRANSPOSE(b). alpha/beta are forwarded unchanged.
+        !DIR$ ATTRIBUTES OFFLOAD:MIC :: libxsmm_dmatmul
+        SUBROUTINE libxsmm_dmatmul(c, a, b, alpha, beta, transb)
+          REAL(C_DOUBLE), INTENT(INOUT) :: c(:,:)
+          REAL(C_DOUBLE), INTENT(IN) :: a(:,:), b(:,:)
+          REAL(C_DOUBLE), INTENT(IN), OPTIONAL :: alpha, beta
+          CHARACTER, INTENT(IN), OPTIONAL :: transb
+          CHARACTER :: otransb
+          IF (.NOT.PRESENT(transb)) THEN
+            otransb = 'N'
+          ELSE
+            otransb = transb
+          END IF
+          IF (('N'.EQ.otransb).OR.('n'.EQ.otransb)) THEN
+            CALL libxsmm_dgemm('N', otransb,                            &
+     &        SIZE(c, 1), SIZE(c, 2), SIZE(a, 2),                       &
+     &        alpha, a, SIZE(a, 1),                                     &
+     &        b, SIZE(b, 1), beta, c, SIZE(c, 1))
+          ELSE
+            ! TODO: transpose is currently not supported by LIBXSMM
+            CALL libxsmm_dgemm('N', 'N',                                &
+     &        SIZE(c, 1), SIZE(c, 2), SIZE(a, 2),                       &
+     &        alpha, a, SIZE(a, 1),                                     &
+     &        TRANSPOSE(b), SIZE(b, 2), beta, c, SIZE(c, 1))
+          END IF
+        END SUBROUTINE
       END MODULE
diff --git a/version.txt b/version.txt
index c6af7f4777..b8146b7606 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-master-1.4.1-17
+master-1.4.1-18