diff --git a/flang/runtime/matmul-transpose.cpp b/flang/runtime/matmul-transpose.cpp index 1a31ccc4591cd..43fcf7c084906 100644 --- a/flang/runtime/matmul-transpose.cpp +++ b/flang/runtime/matmul-transpose.cpp @@ -52,25 +52,64 @@ using namespace Fortran::runtime; // DO 2 I = 1, NROWS // DO 2 K = 1, N // 2 RES(I,J) = RES(I,J) + X(K,I)*Y(K,J) ! loop-invariant last term -template +template inline static void MatrixTransposedTimesMatrix( CppTypeFor *RESTRICT product, SubscriptValue rows, SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y, - SubscriptValue n) { + SubscriptValue n, std::size_t xColumnByteStride = 0, + std::size_t yColumnByteStride = 0) { using ResultType = CppTypeFor; std::memset(product, 0, rows * cols * sizeof *product); for (SubscriptValue j{0}; j < cols; ++j) { for (SubscriptValue i{0}; i < rows; ++i) { for (SubscriptValue k{0}; k < n; ++k) { - ResultType x_ki = static_cast(x[i * n + k]); - ResultType y_kj = static_cast(y[j * n + k]); + ResultType x_ki; + if constexpr (!X_HAS_STRIDED_COLUMNS) { + x_ki = static_cast(x[i * n + k]); + } else { + x_ki = static_cast(reinterpret_cast( + reinterpret_cast(x) + i * xColumnByteStride)[k]); + } + ResultType y_kj; + if constexpr (!Y_HAS_STRIDED_COLUMNS) { + y_kj = static_cast(y[j * n + k]); + } else { + y_kj = static_cast(reinterpret_cast( + reinterpret_cast(y) + j * yColumnByteStride)[k]); + } product[j * rows + i] += x_ki * y_kj; } } } } +template +inline static void MatrixTransposedTimesMatrixHelper( + CppTypeFor *RESTRICT product, SubscriptValue rows, + SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y, + SubscriptValue n, std::optional xColumnByteStride, + std::optional yColumnByteStride) { + if (!xColumnByteStride) { + if (!yColumnByteStride) { + MatrixTransposedTimesMatrix( + product, rows, cols, x, y, n); + } else { + MatrixTransposedTimesMatrix( + product, rows, cols, x, y, n, 0, *yColumnByteStride); + } + } else { + if (!yColumnByteStride) { + MatrixTransposedTimesMatrix( + 
product, rows, cols, x, y, n, *xColumnByteStride); + } else { + MatrixTransposedTimesMatrix( + product, rows, cols, x, y, n, *xColumnByteStride, *yColumnByteStride); + } + } +} + // Contiguous numeric matrix*vector multiplication // matrix(rows,n) * column vector(n) -> column vector(rows) // Straightforward algorithm: @@ -85,21 +124,43 @@ inline static void MatrixTransposedTimesMatrix( // DO 2 I = 1, NROWS // DO 2 K = 1, N // 2 RES(I) = RES(I) + X(K,I)*Y(K) -template +template inline static void MatrixTransposedTimesVector( CppTypeFor *RESTRICT product, SubscriptValue rows, - SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y) { + SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y, + std::size_t xColumnByteStride = 0) { using ResultType = CppTypeFor; std::memset(product, 0, rows * sizeof *product); for (SubscriptValue i{0}; i < rows; ++i) { for (SubscriptValue k{0}; k < n; ++k) { - ResultType x_ki = static_cast(x[i * n + k]); + ResultType x_ki; + if constexpr (!X_HAS_STRIDED_COLUMNS) { + x_ki = static_cast(x[i * n + k]); + } else { + x_ki = static_cast(reinterpret_cast( + reinterpret_cast(x) + i * xColumnByteStride)[k]); + } ResultType y_k = static_cast(y[k]); product[i] += x_ki * y_k; } } } +template +inline static void MatrixTransposedTimesVectorHelper( + CppTypeFor *RESTRICT product, SubscriptValue rows, + SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y, + std::optional xColumnByteStride) { + if (!xColumnByteStride) { + MatrixTransposedTimesVector( + product, rows, n, x, y); + } else { + MatrixTransposedTimesVector( + product, rows, n, x, y, *xColumnByteStride); + } +} + // Implements an instance of MATMUL for given argument types. 
template @@ -149,19 +210,39 @@ inline static void DoMatmulTranspose( const SubscriptValue rows{extent[0]}; const SubscriptValue cols{extent[1]}; if constexpr (RCAT != TypeCategory::Logical) { - if (x.IsContiguous() && y.IsContiguous() && + if (x.IsContiguous(1) && y.IsContiguous(1) && (IS_ALLOCATING || result.IsContiguous())) { - // Contiguous numeric matrices + // Contiguous numeric matrices (maybe with columns + // separated by a stride). + std::optional xColumnByteStride; + if (!x.IsContiguous()) { + // X's columns are strided. + SubscriptValue xAt[2]{}; + x.GetLowerBounds(xAt); + xAt[1]++; + xColumnByteStride = x.SubscriptsToByteOffset(xAt); + } + std::optional yColumnByteStride; + if (!y.IsContiguous()) { + // Y's columns are strided. + SubscriptValue yAt[2]{}; + y.GetLowerBounds(yAt); + yAt[1]++; + yColumnByteStride = y.SubscriptsToByteOffset(yAt); + } if (resRank == 2) { // M*M -> M - MatrixTransposedTimesMatrix( + // TODO: use BLAS-3 GEMM for supported types. + MatrixTransposedTimesMatrixHelper( result.template OffsetElement(), rows, cols, - x.OffsetElement(), y.OffsetElement(), n); + x.OffsetElement(), y.OffsetElement(), n, xColumnByteStride, + yColumnByteStride); return; } if (xRank == 2) { // M*V -> V - MatrixTransposedTimesVector( + // TODO: use BLAS-2 GEMV for supported types. + MatrixTransposedTimesVectorHelper( result.template OffsetElement(), rows, n, - x.OffsetElement(), y.OffsetElement()); + x.OffsetElement(), y.OffsetElement(), xColumnByteStride); return; } // else V*M -> V (not allowed because TRANSPOSE() is only defined for rank diff --git a/flang/runtime/matmul.cpp b/flang/runtime/matmul.cpp index df260e1fa5ebd..b46a94de01ced 100644 --- a/flang/runtime/matmul.cpp +++ b/flang/runtime/matmul.cpp @@ -69,10 +69,12 @@ class Accumulator { // DO 2 J = 1, NCOLS // DO 2 I = 1, NROWS // 2 RES(I,J) = RES(I,J) + X(I,K)*Y(K,J) ! 
loop-invariant last term -template +template inline void MatrixTimesMatrix(CppTypeFor *RESTRICT product, SubscriptValue rows, SubscriptValue cols, const XT *RESTRICT x, - const YT *RESTRICT y, SubscriptValue n) { + const YT *RESTRICT y, SubscriptValue n, std::size_t xColumnByteStride = 0, + std::size_t yColumnByteStride = 0) { using ResultType = CppTypeFor; std::memset(product, 0, rows * cols * sizeof *product); const XT *RESTRICT xp0{x}; @@ -80,12 +82,48 @@ inline void MatrixTimesMatrix(CppTypeFor *RESTRICT product, ResultType *RESTRICT p{product}; for (SubscriptValue j{0}; j < cols; ++j) { const XT *RESTRICT xp{xp0}; - auto yv{static_cast(y[k + j * n])}; + ResultType yv; + if constexpr (!Y_HAS_STRIDED_COLUMNS) { + yv = static_cast(y[k + j * n]); + } else { + yv = static_cast(reinterpret_cast( + reinterpret_cast(y) + j * yColumnByteStride)[k]); + } for (SubscriptValue i{0}; i < rows; ++i) { *p++ += static_cast(*xp++) * yv; } } - xp0 += rows; + if constexpr (!X_HAS_STRIDED_COLUMNS) { + xp0 += rows; + } else { + xp0 = reinterpret_cast( + reinterpret_cast(xp0) + xColumnByteStride); + } + } +} + +template +inline void MatrixTimesMatrixHelper(CppTypeFor *RESTRICT product, + SubscriptValue rows, SubscriptValue cols, const XT *RESTRICT x, + const YT *RESTRICT y, SubscriptValue n, + std::optional xColumnByteStride, + std::optional yColumnByteStride) { + if (!xColumnByteStride) { + if (!yColumnByteStride) { + MatrixTimesMatrix( + product, rows, cols, x, y, n); + } else { + MatrixTimesMatrix( + product, rows, cols, x, y, n, 0, *yColumnByteStride); + } + } else { + if (!yColumnByteStride) { + MatrixTimesMatrix( + product, rows, cols, x, y, n, *xColumnByteStride); + } else { + MatrixTimesMatrix( + product, rows, cols, x, y, n, *xColumnByteStride, *yColumnByteStride); + } } } @@ -103,18 +141,37 @@ inline void MatrixTimesMatrix(CppTypeFor *RESTRICT product, // DO 2 K = 1, N // DO 2 J = 1, NROWS // 2 RES(J) = RES(J) + X(J,K)*Y(K) -template +template inline void 
MatrixTimesVector(CppTypeFor *RESTRICT product, SubscriptValue rows, SubscriptValue n, const XT *RESTRICT x, - const YT *RESTRICT y) { + const YT *RESTRICT y, std::size_t xColumnByteStride = 0) { using ResultType = CppTypeFor; std::memset(product, 0, rows * sizeof *product); + [[maybe_unused]] const XT *RESTRICT xp0{x}; for (SubscriptValue k{0}; k < n; ++k) { ResultType *RESTRICT p{product}; auto yv{static_cast(*y++)}; for (SubscriptValue j{0}; j < rows; ++j) { *p++ += static_cast(*x++) * yv; } + if constexpr (X_HAS_STRIDED_COLUMNS) { + xp0 = reinterpret_cast( + reinterpret_cast(xp0) + xColumnByteStride); + x = xp0; + } + } +} + +template +inline void MatrixTimesVectorHelper(CppTypeFor *RESTRICT product, + SubscriptValue rows, SubscriptValue n, const XT *RESTRICT x, + const YT *RESTRICT y, std::optional xColumnByteStride) { + if (!xColumnByteStride) { + MatrixTimesVector(product, rows, n, x, y); + } else { + MatrixTimesVector( + product, rows, n, x, y, *xColumnByteStride); } } @@ -132,10 +189,11 @@ inline void MatrixTimesVector(CppTypeFor *RESTRICT product, // DO 2 K = 1, N // DO 2 J = 1, NCOLS // 2 RES(J) = RES(J) + X(K)*Y(K,J) -template +template inline void VectorTimesMatrix(CppTypeFor *RESTRICT product, SubscriptValue n, SubscriptValue cols, const XT *RESTRICT x, - const YT *RESTRICT y) { + const YT *RESTRICT y, std::size_t yColumnByteStride = 0) { using ResultType = CppTypeFor; std::memset(product, 0, cols * sizeof *product); for (SubscriptValue k{0}; k < n; ++k) { @@ -144,11 +202,29 @@ inline void VectorTimesMatrix(CppTypeFor *RESTRICT product, const YT *RESTRICT yp{&y[k]}; for (SubscriptValue j{0}; j < cols; ++j) { *p++ += xv * static_cast(*yp); - yp += n; + if constexpr (!Y_HAS_STRIDED_COLUMNS) { + yp += n; + } else { + yp = reinterpret_cast( + reinterpret_cast(yp) + yColumnByteStride); + } } } } +template +inline void VectorTimesMatrixHelper(CppTypeFor *RESTRICT product, + SubscriptValue n, SubscriptValue cols, const XT *RESTRICT x, + const YT *RESTRICT y, 
std::optional yColumnByteStride) { + if (!yColumnByteStride) { + VectorTimesMatrix(product, n, cols, x, y); + } else { + VectorTimesMatrix( + product, n, cols, x, y, *yColumnByteStride); + } +} + // Implements an instance of MATMUL for given argument types. template @@ -194,13 +270,35 @@ static inline void DoMatmul( CppTypeFor; if constexpr (RCAT != TypeCategory::Logical) { - if (x.IsContiguous() && y.IsContiguous() && + if (x.IsContiguous(1) && y.IsContiguous(1) && (IS_ALLOCATING || result.IsContiguous())) { - // Contiguous numeric matrices + // Contiguous numeric matrices (maybe with columns + // separated by a stride). + std::optional xColumnByteStride; + if (!x.IsContiguous()) { + // X's columns are strided. + SubscriptValue xAt[2]{}; + x.GetLowerBounds(xAt); + xAt[1]++; + xColumnByteStride = x.SubscriptsToByteOffset(xAt); + } + std::optional yColumnByteStride; + if (!y.IsContiguous()) { + // Y's columns are strided. + SubscriptValue yAt[2]{}; + y.GetLowerBounds(yAt); + yAt[1]++; + yColumnByteStride = y.SubscriptsToByteOffset(yAt); + } + // Note that BLAS GEMM can be used for the strided + // columns by setting proper leading dimension size. + // This implies that the column stride is divisible + // by the element size, which is usually true. if (resRank == 2) { // M*M -> M if (std::is_same_v) { if constexpr (std::is_same_v) { // TODO: call BLAS-3 SGEMM + // TODO: try using CUTLASS for device. 
} else if constexpr (std::is_same_v) { // TODO: call BLAS-3 DGEMM } else if constexpr (std::is_same_v>) { @@ -209,9 +307,10 @@ static inline void DoMatmul( // TODO: call BLAS-3 ZGEMM } } - MatrixTimesMatrix( + MatrixTimesMatrixHelper( result.template OffsetElement(), extent[0], extent[1], - x.OffsetElement(), y.OffsetElement(), n); + x.OffsetElement(), y.OffsetElement(), n, xColumnByteStride, + yColumnByteStride); return; } else if (xRank == 2) { // M*V -> V if (std::is_same_v) { @@ -225,9 +324,9 @@ static inline void DoMatmul( // TODO: call BLAS-2 ZGEMV(x,y) } } - MatrixTimesVector( + MatrixTimesVectorHelper( result.template OffsetElement(), extent[0], n, - x.OffsetElement(), y.OffsetElement()); + x.OffsetElement(), y.OffsetElement(), xColumnByteStride); return; } else { // V*M -> V if (std::is_same_v) { @@ -241,9 +340,9 @@ static inline void DoMatmul( // TODO: call BLAS-2 ZGEMV(y,x) } } - VectorTimesMatrix( + VectorTimesMatrixHelper( result.template OffsetElement(), n, extent[0], - x.OffsetElement(), y.OffsetElement()); + x.OffsetElement(), y.OffsetElement(), yColumnByteStride); return; } } diff --git a/flang/unittests/Runtime/Matmul.cpp b/flang/unittests/Runtime/Matmul.cpp index 30ce3d8a88825..1d6c5ccc609b4 100644 --- a/flang/unittests/Runtime/Matmul.cpp +++ b/flang/unittests/Runtime/Matmul.cpp @@ -27,6 +27,16 @@ TEST(Matmul, Basic) { std::vector{3, 2}, std::vector{6, 7, 8, 9, 10, 11})}; auto v{MakeArray( std::vector{2}, std::vector{-1, -2})}; + + // X2 0 2 4 Y2 -1 -1 + // 1 3 5 6 9 + // -1 -1 -1 7 10 + // 8 11 + auto x2{MakeArray(std::vector{3, 3}, + std::vector{0, 1, -1, 2, 3, -1, 4, 5})}; + auto y2{MakeArray(std::vector{4, 2}, + std::vector{-1, 6, 7, 8, -1, 9, 10, 11})}; + StaticDescriptor<2, true> statDesc; Descriptor &result{statDesc.descriptor()}; @@ -73,6 +83,98 @@ TEST(Matmul, Basic) { EXPECT_EQ(*result.ZeroBasedIndexedElement(2), -30); result.Destroy(); + // Test non-contiguous sections. 
+ static constexpr int sectionRank{2}; + StaticDescriptor sectionStaticDescriptorX2; + Descriptor &sectionX2{sectionStaticDescriptorX2.descriptor()}; + sectionX2.Establish(x2->type(), x2->ElementBytes(), + /*p=*/nullptr, /*rank=*/sectionRank); + static const SubscriptValue lowersX2[]{1, 1}, uppersX2[]{2, 3}; + // Section of X2: + // +--------+ + // | 0 2 4| + // | 1 3 5| + // +--------+ + // -1 -1 -1 + const auto errorX2{CFI_section( + &sectionX2.raw(), &x2->raw(), lowersX2, uppersX2, /*strides=*/nullptr)}; + ASSERT_EQ(errorX2, 0) << "CFI_section failed for X2: " << errorX2; + + StaticDescriptor sectionStaticDescriptorY2; + Descriptor &sectionY2{sectionStaticDescriptorY2.descriptor()}; + sectionY2.Establish(y2->type(), y2->ElementBytes(), + /*p=*/nullptr, /*rank=*/sectionRank); + static const SubscriptValue lowersY2[]{2, 1}; + // Section of Y2: + // -1 -1 + // +-----+ + // | 6 9| + // | 7 10| + // | 8 11| + // +-----+ + const auto errorY2{CFI_section(&sectionY2.raw(), &y2->raw(), lowersY2, + /*uppers=*/nullptr, /*strides=*/nullptr)}; + ASSERT_EQ(errorY2, 0) << "CFI_section failed for Y2: " << errorY2; + + RTNAME(Matmul)(result, sectionX2, *y, __FILE__, __LINE__); + ASSERT_EQ(result.rank(), 2); + EXPECT_EQ(result.GetDimension(0).LowerBound(), 1); + EXPECT_EQ(result.GetDimension(0).Extent(), 2); + EXPECT_EQ(result.GetDimension(1).LowerBound(), 1); + EXPECT_EQ(result.GetDimension(1).Extent(), 2); + ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4})); + EXPECT_EQ(*result.ZeroBasedIndexedElement(0), 46); + EXPECT_EQ(*result.ZeroBasedIndexedElement(1), 67); + EXPECT_EQ(*result.ZeroBasedIndexedElement(2), 64); + EXPECT_EQ(*result.ZeroBasedIndexedElement(3), 94); + result.Destroy(); + + RTNAME(Matmul)(result, *x, sectionY2, __FILE__, __LINE__); + ASSERT_EQ(result.rank(), 2); + EXPECT_EQ(result.GetDimension(0).LowerBound(), 1); + EXPECT_EQ(result.GetDimension(0).Extent(), 2); + EXPECT_EQ(result.GetDimension(1).LowerBound(), 1); + EXPECT_EQ(result.GetDimension(1).Extent(), 2); + 
ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4})); + EXPECT_EQ(*result.ZeroBasedIndexedElement(0), 46); + EXPECT_EQ(*result.ZeroBasedIndexedElement(1), 67); + EXPECT_EQ(*result.ZeroBasedIndexedElement(2), 64); + EXPECT_EQ(*result.ZeroBasedIndexedElement(3), 94); + result.Destroy(); + + RTNAME(Matmul)(result, sectionX2, sectionY2, __FILE__, __LINE__); + ASSERT_EQ(result.rank(), 2); + EXPECT_EQ(result.GetDimension(0).LowerBound(), 1); + EXPECT_EQ(result.GetDimension(0).Extent(), 2); + EXPECT_EQ(result.GetDimension(1).LowerBound(), 1); + EXPECT_EQ(result.GetDimension(1).Extent(), 2); + ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4})); + EXPECT_EQ(*result.ZeroBasedIndexedElement(0), 46); + EXPECT_EQ(*result.ZeroBasedIndexedElement(1), 67); + EXPECT_EQ(*result.ZeroBasedIndexedElement(2), 64); + EXPECT_EQ(*result.ZeroBasedIndexedElement(3), 94); + result.Destroy(); + + RTNAME(Matmul)(result, *v, sectionX2, __FILE__, __LINE__); + ASSERT_EQ(result.rank(), 1); + EXPECT_EQ(result.GetDimension(0).LowerBound(), 1); + EXPECT_EQ(result.GetDimension(0).Extent(), 3); + ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8})); + EXPECT_EQ(*result.ZeroBasedIndexedElement(0), -2); + EXPECT_EQ(*result.ZeroBasedIndexedElement(1), -8); + EXPECT_EQ(*result.ZeroBasedIndexedElement(2), -14); + result.Destroy(); + + RTNAME(Matmul)(result, sectionY2, *v, __FILE__, __LINE__); + ASSERT_EQ(result.rank(), 1); + EXPECT_EQ(result.GetDimension(0).LowerBound(), 1); + EXPECT_EQ(result.GetDimension(0).Extent(), 3); + ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8})); + EXPECT_EQ(*result.ZeroBasedIndexedElement(0), -24); + EXPECT_EQ(*result.ZeroBasedIndexedElement(1), -27); + EXPECT_EQ(*result.ZeroBasedIndexedElement(2), -30); + result.Destroy(); + // X F F T Y F T // F T T F T // F F diff --git a/flang/unittests/Runtime/MatmulTranspose.cpp b/flang/unittests/Runtime/MatmulTranspose.cpp index 83db1328963a6..2362887c414ec 100644 --- 
a/flang/unittests/Runtime/MatmulTranspose.cpp +++ b/flang/unittests/Runtime/MatmulTranspose.cpp @@ -32,6 +32,17 @@ TEST(MatmulTranspose, Basic) { std::vector{0, 0, 0, 1, 1, 0, 1, 1})}; auto v{MakeArray( std::vector{2}, std::vector{-1, -2})}; + // X2 0 1 Y2 -1 -1 Z2 6 7 8 + // 2 3 6 9 9 10 11 + // 4 5 7 10 -1 -1 -1 + // -1 -1 8 11 + auto x2{MakeArray(std::vector{4, 2}, + std::vector{0, 2, 4, -1, 1, 3, 5, -1})}; + auto y2{MakeArray(std::vector{4, 2}, + std::vector{-1, 6, 7, 8, -1, 9, 10, 11})}; + auto z2{MakeArray(std::vector{3, 3}, + std::vector{6, 9, -1, 7, 10, -1, 8, 11, -1})}; + StaticDescriptor<2, true> statDesc; Descriptor &result{statDesc.descriptor()}; @@ -89,6 +100,104 @@ TEST(MatmulTranspose, Basic) { EXPECT_EQ(*result.ZeroBasedIndexedElement(11), 19); result.Destroy(); + // Test non-contiguous sections. + static constexpr int sectionRank{2}; + StaticDescriptor sectionStaticDescriptorX2; + Descriptor &sectionX2{sectionStaticDescriptorX2.descriptor()}; + sectionX2.Establish(x2->type(), x2->ElementBytes(), + /*p=*/nullptr, /*rank=*/sectionRank); + static const SubscriptValue lowersX2[]{1, 1}, uppersX2[]{3, 2}; + // Section of X2: + // +-----+ + // | 0 1| + // | 2 3| + // | 4 5| + // +-----+ + // -1 -1 + const auto errorX2{CFI_section( + &sectionX2.raw(), &x2->raw(), lowersX2, uppersX2, /*strides=*/nullptr)}; + ASSERT_EQ(errorX2, 0) << "CFI_section failed for X2: " << errorX2; + + StaticDescriptor sectionStaticDescriptorY2; + Descriptor &sectionY2{sectionStaticDescriptorY2.descriptor()}; + sectionY2.Establish(y2->type(), y2->ElementBytes(), + /*p=*/nullptr, /*rank=*/sectionRank); + static const SubscriptValue lowersY2[]{2, 1}; + // Section of Y2: + // -1 -1 + // +-----+ + // | 6 9| + // | 7 10| + // | 8 11| + // +-----+ + const auto errorY2{CFI_section(&sectionY2.raw(), &y2->raw(), lowersY2, + /*uppers=*/nullptr, /*strides=*/nullptr)}; + ASSERT_EQ(errorY2, 0) << "CFI_section failed for Y2: " << errorY2; + + StaticDescriptor sectionStaticDescriptorZ2; + Descriptor 
&sectionZ2{sectionStaticDescriptorZ2.descriptor()}; + sectionZ2.Establish(z2->type(), z2->ElementBytes(), + /*p=*/nullptr, /*rank=*/sectionRank); + static const SubscriptValue lowersZ2[]{1, 1}, uppersZ2[]{2, 3}; + // Section of Z2: + // +--------+ + // | 6 7 8| + // | 9 10 11| + // +--------+ + // -1 -1 -1 + const auto errorZ2{CFI_section( + &sectionZ2.raw(), &z2->raw(), lowersZ2, uppersZ2, /*strides=*/nullptr)}; + ASSERT_EQ(errorZ2, 0) << "CFI_section failed for Z2: " << errorZ2; + + RTNAME(MatmulTranspose)(result, sectionX2, *y, __FILE__, __LINE__); + ASSERT_EQ(result.rank(), 2); + EXPECT_EQ(result.GetDimension(0).LowerBound(), 1); + EXPECT_EQ(result.GetDimension(0).Extent(), 2); + EXPECT_EQ(result.GetDimension(1).LowerBound(), 1); + EXPECT_EQ(result.GetDimension(1).Extent(), 2); + ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4})); + EXPECT_EQ(*result.ZeroBasedIndexedElement(0), 46); + EXPECT_EQ(*result.ZeroBasedIndexedElement(1), 67); + EXPECT_EQ(*result.ZeroBasedIndexedElement(2), 64); + EXPECT_EQ(*result.ZeroBasedIndexedElement(3), 94); + result.Destroy(); + + RTNAME(MatmulTranspose)(result, *x, sectionY2, __FILE__, __LINE__); + ASSERT_EQ(result.rank(), 2); + EXPECT_EQ(result.GetDimension(0).LowerBound(), 1); + EXPECT_EQ(result.GetDimension(0).Extent(), 2); + EXPECT_EQ(result.GetDimension(1).LowerBound(), 1); + EXPECT_EQ(result.GetDimension(1).Extent(), 2); + ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4})); + EXPECT_EQ(*result.ZeroBasedIndexedElement(0), 46); + EXPECT_EQ(*result.ZeroBasedIndexedElement(1), 67); + EXPECT_EQ(*result.ZeroBasedIndexedElement(2), 64); + EXPECT_EQ(*result.ZeroBasedIndexedElement(3), 94); + result.Destroy(); + + RTNAME(MatmulTranspose)(result, sectionX2, sectionY2, __FILE__, __LINE__); + ASSERT_EQ(result.rank(), 2); + EXPECT_EQ(result.GetDimension(0).LowerBound(), 1); + EXPECT_EQ(result.GetDimension(0).Extent(), 2); + EXPECT_EQ(result.GetDimension(1).LowerBound(), 1); + EXPECT_EQ(result.GetDimension(1).Extent(), 
2); + ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4})); + EXPECT_EQ(*result.ZeroBasedIndexedElement(0), 46); + EXPECT_EQ(*result.ZeroBasedIndexedElement(1), 67); + EXPECT_EQ(*result.ZeroBasedIndexedElement(2), 64); + EXPECT_EQ(*result.ZeroBasedIndexedElement(3), 94); + result.Destroy(); + + RTNAME(MatmulTranspose)(result, sectionZ2, *v, __FILE__, __LINE__); + ASSERT_EQ(result.rank(), 1); + EXPECT_EQ(result.GetDimension(0).LowerBound(), 1); + EXPECT_EQ(result.GetDimension(0).Extent(), 3); + ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8})); + EXPECT_EQ(*result.ZeroBasedIndexedElement(0), -24); + EXPECT_EQ(*result.ZeroBasedIndexedElement(1), -27); + EXPECT_EQ(*result.ZeroBasedIndexedElement(2), -30); + result.Destroy(); + // X F F Y F T V T F T // T F F T // T T F F