diff --git a/flang/include/flang/Runtime/c-or-cpp.h b/flang/include/flang/Runtime/c-or-cpp.h index 4babd885cad32..8bac523907750 100644 --- a/flang/include/flang/Runtime/c-or-cpp.h +++ b/flang/include/flang/Runtime/c-or-cpp.h @@ -13,11 +13,13 @@ #define IF_CPLUSPLUS(x) x #define IF_NOT_CPLUSPLUS(x) #define DEFAULT_VALUE(x) = (x) +#define RESTRICT __restrict #else #include #define IF_CPLUSPLUS(x) #define IF_NOT_CPLUSPLUS(x) x #define DEFAULT_VALUE(x) +#define RESTRICT restrict #endif #define FORTRAN_EXTERN_C_BEGIN IF_CPLUSPLUS(extern "C" {) diff --git a/flang/include/flang/Runtime/descriptor.h b/flang/include/flang/Runtime/descriptor.h index 2b927df3bcd29..75c5e2176d929 100644 --- a/flang/include/flang/Runtime/descriptor.h +++ b/flang/include/flang/Runtime/descriptor.h @@ -304,7 +304,10 @@ class Descriptor { bool IsContiguous(int leadingDimensions = maxRank) const { auto bytes{static_cast(ElementBytes())}; - for (int j{0}; j < leadingDimensions && j < raw_.rank; ++j) { + if (leadingDimensions > raw_.rank) { + leadingDimensions = raw_.rank; + } + for (int j{0}; j < leadingDimensions; ++j) { const Dimension &dim{GetDimension(j)}; if (bytes != dim.ByteStride()) { return false; diff --git a/flang/runtime/dot-product.cpp b/flang/runtime/dot-product.cpp index db790c392c11b..4b8029768b950 100644 --- a/flang/runtime/dot-product.cpp +++ b/flang/runtime/dot-product.cpp @@ -15,21 +15,29 @@ namespace Fortran::runtime { -template +// Beware: DOT_PRODUCT of COMPLEX data uses the complex conjugate of the first +// argument; MATMUL does not. + +// General accumulator for any type and stride; this is not used for +// contiguous numeric vectors. +template class Accumulator { public: - using Result = RESULT; + using Result = AccumulationType; Accumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {} - void Accumulate(SubscriptValue xAt, SubscriptValue yAt) { - if constexpr (XCAT == TypeCategory::Complex) { - sum_ += std::conj(static_cast(*x_.Element(&xAt))) * - static_cast(*y_.Element(&yAt)); - } else if constexpr (XCAT == TypeCategory::Logical) { + void AccumulateIndexed(SubscriptValue xAt, SubscriptValue yAt) { + if constexpr (RCAT == TypeCategory::Logical) { sum_ = sum_ || (IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt)); } else { - sum_ += static_cast(*x_.Element(&xAt)) * - static_cast(*y_.Element(&yAt)); + const XT &xElement{*x_.Element(&xAt)}; + const YT &yElement{*y_.Element(&yAt)}; + if constexpr (RCAT == TypeCategory::Complex) { + sum_ += std::conj(static_cast(xElement)) * + static_cast(yElement); + } else { + sum_ += static_cast(xElement) * static_cast(yElement); + } } } Result GetResult() const { return sum_; } @@ -39,9 +47,10 @@ class Accumulator { Result sum_{}; }; -template -static inline RESULT DoDotProduct( +template +static inline CppTypeFor DoDotProduct( const Descriptor &x, const Descriptor &y, Terminator &terminator) { + using Result = CppTypeFor; RUNTIME_CHECK(terminator, x.rank() == 1 && y.rank() == 1); SubscriptValue n{x.GetDimension(0).Extent()}; if (SubscriptValue yN{y.GetDimension(0).Extent()}; yN != n) { @@ -49,24 +58,48 @@ static inline RESULT DoDotProduct( "DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd", static_cast(n), static_cast(yN)); } - if constexpr (std::is_same_v) { - if constexpr (std::is_same_v) { - // TODO: call BLAS-1 SDOT or SDSDOT - } else if constexpr (std::is_same_v) { - // TODO: call BLAS-1 DDOT - } else if constexpr (std::is_same_v>) { - // TODO: call BLAS-1 CDOTC - } else if constexpr (std::is_same_v>) { - // TODO: call BLAS-1 ZDOTC + if constexpr (RCAT != TypeCategory::Logical) { + if (x.GetDimension(0).ByteStride() == sizeof(XT) && + y.GetDimension(0).ByteStride() == sizeof(YT)) { + // Contiguous numeric vectors + if constexpr (std::is_same_v) { + // Contiguous homogeneous numeric vectors + if constexpr (std::is_same_v) { + // TODO: call BLAS-1 SDOT or SDSDOT + } else if constexpr (std::is_same_v) { + // TODO: call BLAS-1 DDOT + } else if constexpr (std::is_same_v>) { + // TODO: call BLAS-1 CDOTC + } else if constexpr (std::is_same_v>) { + // TODO: call BLAS-1 ZDOTC + } + } + XT *xp{x.OffsetElement(0)}; + YT *yp{y.OffsetElement(0)}; + using AccumType = AccumulationType; + AccumType accum{}; + if constexpr (RCAT == TypeCategory::Complex) { + for (SubscriptValue j{0}; j < n; ++j) { + accum += std::conj(static_cast(*xp++)) * + static_cast(*yp++); + } + } else { + for (SubscriptValue j{0}; j < n; ++j) { + accum += + static_cast(*xp++) * static_cast(*yp++); + } + } + return static_cast(accum); } } + // Non-contiguous, heterogeneous, & LOGICAL cases SubscriptValue xAt{x.GetDimension(0).LowerBound()}; SubscriptValue yAt{y.GetDimension(0).LowerBound()}; - Accumulator accumulator{x, y}; + Accumulator accumulator{x, y}; for (SubscriptValue j{0}; j < n; ++j) { - accumulator.Accumulate(xAt++, yAt++); + accumulator.AccumulateIndexed(xAt++, yAt++); } - return accumulator.GetResult(); + return static_cast(accumulator.GetResult()); } template struct DotProduct { @@ -79,7 +112,7 @@ template struct DotProduct { GetResultType(XCAT, XKIND, YCAT, YKIND)}) { if constexpr (resultType->first == RCAT && resultType->second <= RKIND) { - return DoDotProduct, + return DoDotProduct, CppTypeFor>(x, y, terminator); } } @@ -97,26 +130,32 @@ template struct DotProduct { Result operator()(const Descriptor &x, const Descriptor &y, const char *source, int line) const { Terminator terminator{source, line}; - auto xCatKind{x.type().GetCategoryAndKind()}; - auto yCatKind{y.type().GetCategoryAndKind()}; - RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value()); - return ApplyType(xCatKind->first, xCatKind->second, terminator, - x, y, terminator, yCatKind->first, yCatKind->second); + if (RCAT != TypeCategory::Logical && x.type() == y.type()) { + // No conversions needed, operands and result have same known type + return typename DP1::template DP2{}( + x, y, terminator); + } else { + auto xCatKind{x.type().GetCategoryAndKind()}; + auto yCatKind{y.type().GetCategoryAndKind()}; + RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value()); + return ApplyType(xCatKind->first, xCatKind->second, + terminator, x, y, terminator, yCatKind->first, yCatKind->second); + } } }; extern "C" { std::int8_t RTNAME(DotProductInteger1)( const Descriptor &x, const Descriptor &y, const char *source, int line) { - return DotProduct{}(x, y, source, line); + return DotProduct{}(x, y, source, line); } std::int16_t RTNAME(DotProductInteger2)( const Descriptor &x, const Descriptor &y, const char *source, int line) { - return DotProduct{}(x, y, source, line); + return DotProduct{}(x, y, source, line); } std::int32_t RTNAME(DotProductInteger4)( const Descriptor &x, const Descriptor &y, const char *source, int line) { - return DotProduct{}(x, y, source, line); + return DotProduct{}(x, y, source, line); } std::int64_t RTNAME(DotProductInteger8)( const Descriptor &x, const Descriptor &y, const char *source, int line) { @@ -130,9 +169,10 @@ common::int128_t RTNAME(DotProductInteger16)( #endif // TODO: REAL/COMPLEX(2 & 3) +// Intermediate results and operations are at least 64 bits float RTNAME(DotProductReal4)( const Descriptor &x, const Descriptor &y, const char *source, int line) { - return DotProduct{}(x, y, source, line); + return DotProduct{}(x, y, source, line); } double RTNAME(DotProductReal8)( const Descriptor &x, const Descriptor &y, const char *source, int line) { @@ -152,7 +192,7 @@ long double RTNAME(DotProductReal16)( void RTNAME(CppDotProductComplex4)(std::complex &result, const Descriptor &x, const Descriptor &y, const char *source, int line) { - auto z{DotProduct{}(x, y, source, line)}; + auto z{DotProduct{}(x, y, source, line)}; result = std::complex{ static_cast(z.real()), static_cast(z.imag())}; } diff --git a/flang/runtime/matmul.cpp b/flang/runtime/matmul.cpp index ec1581456fcb9..2d0459c0f35a9 100644 --- a/flang/runtime/matmul.cpp +++ b/flang/runtime/matmul.cpp @@ -22,19 +22,19 @@ #include "flang/Runtime/matmul.h" #include "terminator.h" #include "tools.h" +#include "flang/Runtime/c-or-cpp.h" #include "flang/Runtime/cpp-type.h" #include "flang/Runtime/descriptor.h" +#include namespace Fortran::runtime { +// General accumulator for any type and stride; this is not used for +// contiguous numeric cases. template class Accumulator { public: - // Accumulate floating-point results in (at least) double precision - using Result = CppTypeFor(sizeof(double))) - : RKIND>; + using Result = AccumulationType; Accumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {} void Accumulate(const SubscriptValue xAt[], const SubscriptValue yAt[]) { if constexpr (RCAT == TypeCategory::Logical) { @@ -52,6 +52,103 @@ class Accumulator { Result sum_{}; }; +// Contiguous numeric matrix*matrix multiplication +// matrix(rows,n) * matrix(n,cols) -> matrix(rows,cols) +// Straightforward algorithm: +// DO 1 I = 1, NROWS +// DO 1 J = 1, NCOLS +// RES(I,J) = 0 +// DO 1 K = 1, N +// 1 RES(I,J) = RES(I,J) + X(I,K)*Y(K,J) +// With loop distribution and transposition to avoid the inner sum +// reduction and to avoid non-unit strides: +// DO 1 I = 1, NROWS +// DO 1 J = 1, NCOLS +// 1 RES(I,J) = 0 +// DO 2 K = 1, N +// DO 2 J = 1, NCOLS +// DO 2 I = 1, NROWS +// 2 RES(I,J) = RES(I,J) + X(I,K)*Y(K,J) ! loop-invariant last term +template +inline void MatrixTimesMatrix(CppTypeFor *RESTRICT product, + SubscriptValue rows, SubscriptValue cols, const XT *RESTRICT x, + const YT *RESTRICT y, SubscriptValue n) { + using ResultType = CppTypeFor; + std::memset(product, 0, rows * cols * sizeof *product); + const XT *RESTRICT xp0{x}; + for (SubscriptValue k{0}; k < n; ++k) { + ResultType *RESTRICT p{product}; + for (SubscriptValue j{0}; j < cols; ++j) { + const XT *RESTRICT xp{xp0}; + auto yv{static_cast(y[k + j * n])}; + for (SubscriptValue i{0}; i < rows; ++i) { + *p++ += static_cast(*xp++) * yv; + } + } + xp0 += rows; + } +} + +// Contiguous numeric matrix*vector multiplication +// matrix(rows,n) * column vector(n) -> column vector(rows) +// Straightforward algorithm: +// DO 1 J = 1, NROWS +// RES(J) = 0 +// DO 1 K = 1, N +// 1 RES(J) = RES(J) + X(J,K)*Y(K) +// With loop distribution and transposition to avoid the inner +// sum reduction and to avoid non-unit strides: +// DO 1 J = 1, NROWS +// 1 RES(J) = 0 +// DO 2 K = 1, N +// DO 2 J = 1, NROWS +// 2 RES(J) = RES(J) + X(J,K)*Y(K) +template +inline void MatrixTimesVector(CppTypeFor *RESTRICT product, + SubscriptValue rows, SubscriptValue n, const XT *RESTRICT x, + const YT *RESTRICT y) { + using ResultType = CppTypeFor; + std::memset(product, 0, rows * sizeof *product); + for (SubscriptValue k{0}; k < n; ++k) { + ResultType *RESTRICT p{product}; + auto yv{static_cast(*y++)}; + for (SubscriptValue j{0}; j < rows; ++j) { + *p++ += static_cast(*x++) * yv; + } + } +} + +// Contiguous numeric vector*matrix multiplication +// row vector(n) * matrix(n,cols) -> row vector(cols) +// Straightforward algorithm: +// DO 1 J = 1, NCOLS +// RES(J) = 0 +// DO 1 K = 1, N +// 1 RES(J) = RES(J) + X(K)*Y(K,J) +// With loop distribution and transposition to avoid the inner +// sum reduction and one non-unit stride (the other remains): +// DO 1 J = 1, NCOLS +// 1 RES(J) = 0 +// DO 2 K = 1, N +// DO 2 J = 1, NCOLS +// 2 RES(J) = RES(J) + X(K)*Y(K,J) +template +inline void VectorTimesMatrix(CppTypeFor *RESTRICT product, + SubscriptValue n, SubscriptValue cols, const XT *RESTRICT x, + const YT *RESTRICT y) { + using ResultType = CppTypeFor; + std::memset(product, 0, cols * sizeof *product); + for (SubscriptValue k{0}; k < n; ++k) { + ResultType *RESTRICT p{product}; + auto xv{static_cast(*x++)}; + const YT *RESTRICT yp{&y[k]}; + for (SubscriptValue j{0}; j < cols; ++j) { + *p++ += xv * static_cast(*yp); + yp += n; + } + } +} + // Implements an instance of MATMUL for given argument types. template @@ -79,36 +176,82 @@ static inline void DoMatmul( } } else { RUNTIME_CHECK(terminator, resRank == result.rank()); - RUNTIME_CHECK(terminator, result.type() == (TypeCode{RCAT, RKIND})); + RUNTIME_CHECK( + terminator, result.ElementBytes() == static_cast(RKIND)); RUNTIME_CHECK(terminator, result.GetDimension(0).Extent() == extent[0]); RUNTIME_CHECK(terminator, resRank == 1 || result.GetDimension(1).Extent() == extent[1]); } - using WriteResult = - CppTypeFor; SubscriptValue n{x.GetDimension(xRank - 1).Extent()}; if (n != y.GetDimension(0).Extent()) { terminator.Crash("MATMUL: arrays do not conform (%jd != %jd)", static_cast(n), static_cast(y.GetDimension(0).Extent())); } + using WriteResult = + CppTypeFor; + if constexpr (RCAT != TypeCategory::Logical) { + if (x.IsContiguous() && y.IsContiguous() && + (IS_ALLOCATING || result.IsContiguous())) { + // Contiguous numeric matrices + if (resRank == 2) { // M*M -> M + if (std::is_same_v) { + if constexpr (std::is_same_v) { + // TODO: call BLAS-3 SGEMM + } else if constexpr (std::is_same_v) { + // TODO: call BLAS-3 DGEMM + } else if constexpr (std::is_same_v>) { + // TODO: call BLAS-3 CGEMM + } else if constexpr (std::is_same_v>) { + // TODO: call BLAS-3 ZGEMM + } + } + MatrixTimesMatrix( + result.template OffsetElement(), extent[0], extent[1], + x.OffsetElement(), y.OffsetElement(), n); + return; + } else if (xRank == 2) { // M*V -> V + if (std::is_same_v) { + if constexpr (std::is_same_v) { + // TODO: call BLAS-2 SGEMV(x,y) + } else if constexpr (std::is_same_v) { + // TODO: call BLAS-2 DGEMV(x,y) + } else if constexpr (std::is_same_v>) { + // TODO: call BLAS-2 CGEMV(x,y) + } else if constexpr (std::is_same_v>) { + // TODO: call BLAS-2 ZGEMV(x,y) + } + } + MatrixTimesVector( + result.template OffsetElement(), extent[0], n, + x.OffsetElement(), y.OffsetElement()); + return; + } else { // V*M -> V + if (std::is_same_v) { + if constexpr (std::is_same_v) { + // TODO: call BLAS-2 SGEMV(y,x) + } else if constexpr (std::is_same_v) { + // TODO: call BLAS-2 DGEMV(y,x) + } else if constexpr (std::is_same_v>) { + // TODO: call BLAS-2 CGEMV(y,x) + } else if constexpr (std::is_same_v>) { + // TODO: call BLAS-2 ZGEMV(y,x) + } + } + VectorTimesMatrix( + result.template OffsetElement(), n, extent[0], + x.OffsetElement(), y.OffsetElement()); + return; + } + } + } + // General algorithms for LOGICAL and noncontiguity SubscriptValue xAt[2], yAt[2], resAt[2]; x.GetLowerBounds(xAt); y.GetLowerBounds(yAt); result.GetLowerBounds(resAt); if (resRank == 2) { // M*M -> M - if constexpr (std::is_same_v) { - if constexpr (std::is_same_v) { - // TODO: call BLAS-3 SGEMM - } else if constexpr (std::is_same_v) { - // TODO: call BLAS-3 DGEMM - } else if constexpr (std::is_same_v>) { - // TODO: call BLAS-3 CGEMM - } else if constexpr (std::is_same_v>) { - // TODO: call BLAS-3 ZGEMM - } - } SubscriptValue x1{xAt[1]}, y0{yAt[0]}, y1{yAt[1]}, res1{resAt[1]}; for (SubscriptValue i{0}; i < extent[0]; ++i) { for (SubscriptValue j{0}; j < extent[1]; ++j) { @@ -125,44 +268,31 @@ static inline void DoMatmul( ++resAt[0]; ++xAt[0]; } - } else { - if constexpr (std::is_same_v) { - if constexpr (std::is_same_v) { - // TODO: call BLAS-2 SGEMV - } else if constexpr (std::is_same_v) { - // TODO: call BLAS-2 DGEMV - } else if constexpr (std::is_same_v>) { - // TODO: call BLAS-2 CGEMV - } else if constexpr (std::is_same_v>) { - // TODO: call BLAS-2 ZGEMV + } else if (xRank == 2) { // M*V -> V + SubscriptValue x1{xAt[1]}, y0{yAt[0]}; + for (SubscriptValue j{0}; j < extent[0]; ++j) { + Accumulator accumulator{x, y}; + for (SubscriptValue k{0}; k < n; ++k) { + xAt[1] = x1 + k; + yAt[0] = y0 + k; + accumulator.Accumulate(xAt, yAt); } + *result.template Element(resAt) = accumulator.GetResult(); + ++resAt[0]; + ++xAt[0]; } - if (xRank == 2) { // M*V -> V - SubscriptValue x1{xAt[1]}, y0{yAt[0]}; - for (SubscriptValue j{0}; j < extent[0]; ++j) { - Accumulator accumulator{x, y}; - for (SubscriptValue k{0}; k < n; ++k) { - xAt[1] = x1 + k; - yAt[0] = y0 + k; - accumulator.Accumulate(xAt, yAt); - } - *result.template Element(resAt) = accumulator.GetResult(); - ++resAt[0]; - ++xAt[0]; - } - } else { // V*M -> V - SubscriptValue x0{xAt[0]}, y0{yAt[0]}; - for (SubscriptValue j{0}; j < extent[0]; ++j) { - Accumulator accumulator{x, y}; - for (SubscriptValue k{0}; k < n; ++k) { - xAt[0] = x0 + k; - yAt[0] = y0 + k; - accumulator.Accumulate(xAt, yAt); - } - *result.template Element(resAt) = accumulator.GetResult(); - ++resAt[0]; - ++yAt[1]; + } else { // V*M -> V + SubscriptValue x0{xAt[0]}, y0{yAt[0]}; + for (SubscriptValue j{0}; j < extent[0]; ++j) { + Accumulator accumulator{x, y}; + for (SubscriptValue k{0}; k < n; ++k) { + xAt[0] = x0 + k; + yAt[0] = y0 + k; + accumulator.Accumulate(xAt, yAt); } + *result.template Element(resAt) = accumulator.GetResult(); + ++resAt[0]; + ++yAt[1]; } } } diff --git a/flang/runtime/tools.h b/flang/runtime/tools.h index ee2641b305b05..3e0a68b180172 100644 --- a/flang/runtime/tools.h +++ b/flang/runtime/tools.h @@ -334,5 +334,12 @@ std::optional> inline constexpr GetResultType( return std::nullopt; } +// Accumulate floating-point results in (at least) double precision +template +using AccumulationType = CppTypeFor(sizeof(double))) + : KIND>; + } // namespace Fortran::runtime #endif // FORTRAN_RUNTIME_TOOLS_H_