Skip to content

Commit

Permalink
[flang] Improved performance of runtime Matmul/MatmulTranspose.
Browse files Browse the repository at this point in the history
This patch mostly affects performance of the code produced by
HLIFR lowering. If MATMUL argument is an array slice, then
HLFIR lowering passes the slice to the runtime, whereas
FIR lowering would create a contiguous temporary for the slice.
Performance might be better than the generic implementation
for cases where the leading dimension is contiguous.
This patch improves CPU2000/178.galgel making HLFIR version
faster than FIR version (due to avoiding the temporary copies
for MATMUL arguments).

Reviewed By: klausler

Differential Revision: https://reviews.llvm.org/D159134
  • Loading branch information
vzakhari committed Aug 30, 2023
1 parent 8f48392 commit 4d97717
Show file tree
Hide file tree
Showing 4 changed files with 421 additions and 30 deletions.
107 changes: 94 additions & 13 deletions flang/runtime/matmul-transpose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,25 +52,64 @@ using namespace Fortran::runtime;
// DO 2 I = 1, NROWS
// DO 2 K = 1, N
// 2 RES(I,J) = RES(I,J) + X(K,I)*Y(K,J) ! loop-invariant last term
template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
bool X_HAS_STRIDED_COLUMNS, bool Y_HAS_STRIDED_COLUMNS>
inline static void MatrixTransposedTimesMatrix(
CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows,
SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y,
SubscriptValue n) {
SubscriptValue n, std::size_t xColumnByteStride = 0,
std::size_t yColumnByteStride = 0) {
using ResultType = CppTypeFor<RCAT, RKIND>;

std::memset(product, 0, rows * cols * sizeof *product);
for (SubscriptValue j{0}; j < cols; ++j) {
for (SubscriptValue i{0}; i < rows; ++i) {
for (SubscriptValue k{0}; k < n; ++k) {
ResultType x_ki = static_cast<ResultType>(x[i * n + k]);
ResultType y_kj = static_cast<ResultType>(y[j * n + k]);
ResultType x_ki;
if constexpr (!X_HAS_STRIDED_COLUMNS) {
x_ki = static_cast<ResultType>(x[i * n + k]);
} else {
x_ki = static_cast<ResultType>(reinterpret_cast<const XT *>(
reinterpret_cast<const char *>(x) + i * xColumnByteStride)[k]);
}
ResultType y_kj;
if constexpr (!Y_HAS_STRIDED_COLUMNS) {
y_kj = static_cast<ResultType>(y[j * n + k]);
} else {
y_kj = static_cast<ResultType>(reinterpret_cast<const YT *>(
reinterpret_cast<const char *>(y) + j * yColumnByteStride)[k]);
}
product[j * rows + i] += x_ki * y_kj;
}
}
}
}

template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
inline static void MatrixTransposedTimesMatrixHelper(
CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows,
SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y,
SubscriptValue n, std::optional<std::size_t> xColumnByteStride,
std::optional<std::size_t> yColumnByteStride) {
if (!xColumnByteStride) {
if (!yColumnByteStride) {
MatrixTransposedTimesMatrix<RCAT, RKIND, XT, YT, false, false>(
product, rows, cols, x, y, n);
} else {
MatrixTransposedTimesMatrix<RCAT, RKIND, XT, YT, false, true>(
product, rows, cols, x, y, n, 0, *yColumnByteStride);
}
} else {
if (!yColumnByteStride) {
MatrixTransposedTimesMatrix<RCAT, RKIND, XT, YT, true, false>(
product, rows, cols, x, y, n, *xColumnByteStride);
} else {
MatrixTransposedTimesMatrix<RCAT, RKIND, XT, YT, true, true>(
product, rows, cols, x, y, n, *xColumnByteStride, *yColumnByteStride);
}
}
}

// Contiguous numeric matrix*vector multiplication
// matrix(rows,n) * column vector(n) -> column vector(rows)
// Straightforward algorithm:
Expand All @@ -85,21 +124,43 @@ inline static void MatrixTransposedTimesMatrix(
// DO 2 I = 1, NROWS
// DO 2 K = 1, N
// 2 RES(I) = RES(I) + X(K,I)*Y(K)
template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
bool X_HAS_STRIDED_COLUMNS>
inline static void MatrixTransposedTimesVector(
CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows,
SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y) {
SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y,
std::size_t xColumnByteStride = 0) {
using ResultType = CppTypeFor<RCAT, RKIND>;
std::memset(product, 0, rows * sizeof *product);
for (SubscriptValue i{0}; i < rows; ++i) {
for (SubscriptValue k{0}; k < n; ++k) {
ResultType x_ki = static_cast<ResultType>(x[i * n + k]);
ResultType x_ki;
if constexpr (!X_HAS_STRIDED_COLUMNS) {
x_ki = static_cast<ResultType>(x[i * n + k]);
} else {
x_ki = static_cast<ResultType>(reinterpret_cast<const XT *>(
reinterpret_cast<const char *>(x) + i * xColumnByteStride)[k]);
}
ResultType y_k = static_cast<ResultType>(y[k]);
product[i] += x_ki * y_k;
}
}
}

template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
inline static void MatrixTransposedTimesVectorHelper(
CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows,
SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y,
std::optional<std::size_t> xColumnByteStride) {
if (!xColumnByteStride) {
MatrixTransposedTimesVector<RCAT, RKIND, XT, YT, false>(
product, rows, n, x, y);
} else {
MatrixTransposedTimesVector<RCAT, RKIND, XT, YT, true>(
product, rows, n, x, y, *xColumnByteStride);
}
}

// Implements an instance of MATMUL for given argument types.
template <bool IS_ALLOCATING, TypeCategory RCAT, int RKIND, typename XT,
typename YT>
Expand Down Expand Up @@ -149,19 +210,39 @@ inline static void DoMatmulTranspose(
const SubscriptValue rows{extent[0]};
const SubscriptValue cols{extent[1]};
if constexpr (RCAT != TypeCategory::Logical) {
if (x.IsContiguous() && y.IsContiguous() &&
if (x.IsContiguous(1) && y.IsContiguous(1) &&
(IS_ALLOCATING || result.IsContiguous())) {
// Contiguous numeric matrices
// Contiguous numeric matrices (maybe with columns
// separated by a stride).
std::optional<std::size_t> xColumnByteStride;
if (!x.IsContiguous()) {
// X's columns are strided.
SubscriptValue xAt[2]{};
x.GetLowerBounds(xAt);
xAt[1]++;
xColumnByteStride = x.SubscriptsToByteOffset(xAt);
}
std::optional<std::size_t> yColumnByteStride;
if (!y.IsContiguous()) {
// Y's columns are strided.
SubscriptValue yAt[2]{};
y.GetLowerBounds(yAt);
yAt[1]++;
yColumnByteStride = y.SubscriptsToByteOffset(yAt);
}
if (resRank == 2) { // M*M -> M
MatrixTransposedTimesMatrix<RCAT, RKIND, XT, YT>(
// TODO: use BLAS-3 GEMM for supported types.
MatrixTransposedTimesMatrixHelper<RCAT, RKIND, XT, YT>(
result.template OffsetElement<WriteResult>(), rows, cols,
x.OffsetElement<XT>(), y.OffsetElement<YT>(), n);
x.OffsetElement<XT>(), y.OffsetElement<YT>(), n, xColumnByteStride,
yColumnByteStride);
return;
}
if (xRank == 2) { // M*V -> V
MatrixTransposedTimesVector<RCAT, RKIND, XT, YT>(
// TODO: use BLAS-2 GEMM for supported types.
MatrixTransposedTimesVectorHelper<RCAT, RKIND, XT, YT>(
result.template OffsetElement<WriteResult>(), rows, n,
x.OffsetElement<XT>(), y.OffsetElement<YT>());
x.OffsetElement<XT>(), y.OffsetElement<YT>(), xColumnByteStride);
return;
}
// else V*M -> V (not allowed because TRANSPOSE() is only defined for rank
Expand Down
133 changes: 116 additions & 17 deletions flang/runtime/matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,23 +69,61 @@ class Accumulator {
// DO 2 J = 1, NCOLS
// DO 2 I = 1, NROWS
// 2 RES(I,J) = RES(I,J) + X(I,K)*Y(K,J) ! loop-invariant last term
template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
bool X_HAS_STRIDED_COLUMNS, bool Y_HAS_STRIDED_COLUMNS>
inline void MatrixTimesMatrix(CppTypeFor<RCAT, RKIND> *RESTRICT product,
SubscriptValue rows, SubscriptValue cols, const XT *RESTRICT x,
const YT *RESTRICT y, SubscriptValue n) {
const YT *RESTRICT y, SubscriptValue n, std::size_t xColumnByteStride = 0,
std::size_t yColumnByteStride = 0) {
using ResultType = CppTypeFor<RCAT, RKIND>;
std::memset(product, 0, rows * cols * sizeof *product);
const XT *RESTRICT xp0{x};
for (SubscriptValue k{0}; k < n; ++k) {
ResultType *RESTRICT p{product};
for (SubscriptValue j{0}; j < cols; ++j) {
const XT *RESTRICT xp{xp0};
auto yv{static_cast<ResultType>(y[k + j * n])};
ResultType yv;
if constexpr (!Y_HAS_STRIDED_COLUMNS) {
yv = static_cast<ResultType>(y[k + j * n]);
} else {
yv = static_cast<ResultType>(reinterpret_cast<const YT *>(
reinterpret_cast<const char *>(y) + j * yColumnByteStride)[k]);
}
for (SubscriptValue i{0}; i < rows; ++i) {
*p++ += static_cast<ResultType>(*xp++) * yv;
}
}
xp0 += rows;
if constexpr (!X_HAS_STRIDED_COLUMNS) {
xp0 += rows;
} else {
xp0 = reinterpret_cast<const XT *>(
reinterpret_cast<const char *>(xp0) + xColumnByteStride);
}
}
}

template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
inline void MatrixTimesMatrixHelper(CppTypeFor<RCAT, RKIND> *RESTRICT product,
SubscriptValue rows, SubscriptValue cols, const XT *RESTRICT x,
const YT *RESTRICT y, SubscriptValue n,
std::optional<std::size_t> xColumnByteStride,
std::optional<std::size_t> yColumnByteStride) {
if (!xColumnByteStride) {
if (!yColumnByteStride) {
MatrixTimesMatrix<RCAT, RKIND, XT, YT, false, false>(
product, rows, cols, x, y, n);
} else {
MatrixTimesMatrix<RCAT, RKIND, XT, YT, false, true>(
product, rows, cols, x, y, n, 0, *yColumnByteStride);
}
} else {
if (!yColumnByteStride) {
MatrixTimesMatrix<RCAT, RKIND, XT, YT, true, false>(
product, rows, cols, x, y, n, *xColumnByteStride);
} else {
MatrixTimesMatrix<RCAT, RKIND, XT, YT, true, true>(
product, rows, cols, x, y, n, *xColumnByteStride, *yColumnByteStride);
}
}
}

Expand All @@ -103,18 +141,37 @@ inline void MatrixTimesMatrix(CppTypeFor<RCAT, RKIND> *RESTRICT product,
// DO 2 K = 1, N
// DO 2 J = 1, NROWS
// 2 RES(J) = RES(J) + X(J,K)*Y(K)
template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
bool X_HAS_STRIDED_COLUMNS>
inline void MatrixTimesVector(CppTypeFor<RCAT, RKIND> *RESTRICT product,
SubscriptValue rows, SubscriptValue n, const XT *RESTRICT x,
const YT *RESTRICT y) {
const YT *RESTRICT y, std::size_t xColumnByteStride = 0) {
using ResultType = CppTypeFor<RCAT, RKIND>;
std::memset(product, 0, rows * sizeof *product);
[[maybe_unused]] const XT *RESTRICT xp0{x};
for (SubscriptValue k{0}; k < n; ++k) {
ResultType *RESTRICT p{product};
auto yv{static_cast<ResultType>(*y++)};
for (SubscriptValue j{0}; j < rows; ++j) {
*p++ += static_cast<ResultType>(*x++) * yv;
}
if constexpr (X_HAS_STRIDED_COLUMNS) {
xp0 = reinterpret_cast<const XT *>(
reinterpret_cast<const char *>(xp0) + xColumnByteStride);
x = xp0;
}
}
}

template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
inline void MatrixTimesVectorHelper(CppTypeFor<RCAT, RKIND> *RESTRICT product,
SubscriptValue rows, SubscriptValue n, const XT *RESTRICT x,
const YT *RESTRICT y, std::optional<std::size_t> xColumnByteStride) {
if (!xColumnByteStride) {
MatrixTimesVector<RCAT, RKIND, XT, YT, false>(product, rows, n, x, y);
} else {
MatrixTimesVector<RCAT, RKIND, XT, YT, true>(
product, rows, n, x, y, *xColumnByteStride);
}
}

Expand All @@ -132,10 +189,11 @@ inline void MatrixTimesVector(CppTypeFor<RCAT, RKIND> *RESTRICT product,
// DO 2 K = 1, N
// DO 2 J = 1, NCOLS
// 2 RES(J) = RES(J) + X(K)*Y(K,J)
template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
bool Y_HAS_STRIDED_COLUMNS>
inline void VectorTimesMatrix(CppTypeFor<RCAT, RKIND> *RESTRICT product,
SubscriptValue n, SubscriptValue cols, const XT *RESTRICT x,
const YT *RESTRICT y) {
const YT *RESTRICT y, std::size_t yColumnByteStride = 0) {
using ResultType = CppTypeFor<RCAT, RKIND>;
std::memset(product, 0, cols * sizeof *product);
for (SubscriptValue k{0}; k < n; ++k) {
Expand All @@ -144,11 +202,29 @@ inline void VectorTimesMatrix(CppTypeFor<RCAT, RKIND> *RESTRICT product,
const YT *RESTRICT yp{&y[k]};
for (SubscriptValue j{0}; j < cols; ++j) {
*p++ += xv * static_cast<ResultType>(*yp);
yp += n;
if constexpr (!Y_HAS_STRIDED_COLUMNS) {
yp += n;
} else {
yp = reinterpret_cast<const YT *>(
reinterpret_cast<const char *>(yp) + yColumnByteStride);
}
}
}
}

template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
bool SPARSE_COLUMNS = false>
inline void VectorTimesMatrixHelper(CppTypeFor<RCAT, RKIND> *RESTRICT product,
SubscriptValue n, SubscriptValue cols, const XT *RESTRICT x,
const YT *RESTRICT y, std::optional<std::size_t> yColumnByteStride) {
if (!yColumnByteStride) {
VectorTimesMatrix<RCAT, RKIND, XT, YT, false>(product, n, cols, x, y);
} else {
VectorTimesMatrix<RCAT, RKIND, XT, YT, true>(
product, n, cols, x, y, *yColumnByteStride);
}
}

// Implements an instance of MATMUL for given argument types.
template <bool IS_ALLOCATING, TypeCategory RCAT, int RKIND, typename XT,
typename YT>
Expand Down Expand Up @@ -194,13 +270,35 @@ static inline void DoMatmul(
CppTypeFor<RCAT == TypeCategory::Logical ? TypeCategory::Integer : RCAT,
RKIND>;
if constexpr (RCAT != TypeCategory::Logical) {
if (x.IsContiguous() && y.IsContiguous() &&
if (x.IsContiguous(1) && y.IsContiguous(1) &&
(IS_ALLOCATING || result.IsContiguous())) {
// Contiguous numeric matrices
// Contiguous numeric matrices (maybe with columns
// separated by a stride).
std::optional<std::size_t> xColumnByteStride;
if (!x.IsContiguous()) {
// X's columns are strided.
SubscriptValue xAt[2]{};
x.GetLowerBounds(xAt);
xAt[1]++;
xColumnByteStride = x.SubscriptsToByteOffset(xAt);
}
std::optional<std::size_t> yColumnByteStride;
if (!y.IsContiguous()) {
// Y's columns are strided.
SubscriptValue yAt[2]{};
y.GetLowerBounds(yAt);
yAt[1]++;
yColumnByteStride = y.SubscriptsToByteOffset(yAt);
}
// Note that BLAS GEMM can be used for the strided
// columns by setting proper leading dimension size.
// This implies that the column stride is divisible
// by the element size, which is usually true.
if (resRank == 2) { // M*M -> M
if (std::is_same_v<XT, YT>) {
if constexpr (std::is_same_v<XT, float>) {
// TODO: call BLAS-3 SGEMM
// TODO: try using CUTLASS for device.
} else if constexpr (std::is_same_v<XT, double>) {
// TODO: call BLAS-3 DGEMM
} else if constexpr (std::is_same_v<XT, std::complex<float>>) {
Expand All @@ -209,9 +307,10 @@ static inline void DoMatmul(
// TODO: call BLAS-3 ZGEMM
}
}
MatrixTimesMatrix<RCAT, RKIND, XT, YT>(
MatrixTimesMatrixHelper<RCAT, RKIND, XT, YT>(
result.template OffsetElement<WriteResult>(), extent[0], extent[1],
x.OffsetElement<XT>(), y.OffsetElement<YT>(), n);
x.OffsetElement<XT>(), y.OffsetElement<YT>(), n, xColumnByteStride,
yColumnByteStride);
return;
} else if (xRank == 2) { // M*V -> V
if (std::is_same_v<XT, YT>) {
Expand All @@ -225,9 +324,9 @@ static inline void DoMatmul(
// TODO: call BLAS-2 ZGEMV(x,y)
}
}
MatrixTimesVector<RCAT, RKIND, XT, YT>(
MatrixTimesVectorHelper<RCAT, RKIND, XT, YT>(
result.template OffsetElement<WriteResult>(), extent[0], n,
x.OffsetElement<XT>(), y.OffsetElement<YT>());
x.OffsetElement<XT>(), y.OffsetElement<YT>(), xColumnByteStride);
return;
} else { // V*M -> V
if (std::is_same_v<XT, YT>) {
Expand All @@ -241,9 +340,9 @@ static inline void DoMatmul(
// TODO: call BLAS-2 ZGEMV(y,x)
}
}
VectorTimesMatrix<RCAT, RKIND, XT, YT>(
VectorTimesMatrixHelper<RCAT, RKIND, XT, YT>(
result.template OffsetElement<WriteResult>(), n, extent[0],
x.OffsetElement<XT>(), y.OffsetElement<YT>());
x.OffsetElement<XT>(), y.OffsetElement<YT>(), yColumnByteStride);
return;
}
}
Expand Down

0 comments on commit 4d97717

Please sign in to comment.