Skip to content

Commit

Permalink
[flang] Fold MATMUL() (#72176)
Browse files Browse the repository at this point in the history
Implements constant folding for matrix multiplication for all four
accepted type categories.
  • Loading branch information
klausler committed Nov 14, 2023
1 parent 2602d88 commit 0fdf912
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 6 deletions.
4 changes: 3 additions & 1 deletion flang/lib/Evaluate/fold-complex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//

#include "fold-implementation.h"
#include "fold-matmul.h"
#include "fold-reduction.h"

namespace Fortran::evaluate {
Expand Down Expand Up @@ -64,13 +65,14 @@ Expr<Type<TypeCategory::Complex, KIND>> FoldIntrinsicFunction(
}
} else if (name == "dot_product") {
return FoldDotProduct<T>(context, std::move(funcRef));
} else if (name == "matmul") {
return FoldMatmul(context, std::move(funcRef));
} else if (name == "product") {
auto one{Scalar<Part>::FromInteger(value::Integer<8>{1}).value};
return FoldProduct<T>(context, std::move(funcRef), Scalar<T>{one});
} else if (name == "sum") {
return FoldSum<T>(context, std::move(funcRef));
}
// TODO: matmul
return Expr<T>{std::move(funcRef)};
}

Expand Down
4 changes: 3 additions & 1 deletion flang/lib/Evaluate/fold-integer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//

#include "fold-implementation.h"
#include "fold-matmul.h"
#include "fold-reduction.h"
#include "flang/Evaluate/check-expression.h"

Expand Down Expand Up @@ -1042,6 +1043,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
ScalarFunc<T, Int4>([&fptr](const Scalar<Int4> &places) -> Scalar<T> {
return fptr(static_cast<int>(places.ToInt64()));
}));
} else if (name == "matmul") {
return FoldMatmul(context, std::move(funcRef));
} else if (name == "max") {
return FoldMINorMAX(context, std::move(funcRef), Ordering::Greater);
} else if (name == "max0" || name == "max1") {
Expand Down Expand Up @@ -1279,7 +1282,6 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
} else if (name == "ubound") {
return UBOUND(context, std::move(funcRef));
}
// TODO: dot_product, matmul, sign
return Expr<T>{std::move(funcRef)};
}

Expand Down
4 changes: 3 additions & 1 deletion flang/lib/Evaluate/fold-logical.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//

#include "fold-implementation.h"
#include "fold-matmul.h"
#include "fold-reduction.h"
#include "flang/Evaluate/check-expression.h"
#include "flang/Runtime/magic-numbers.h"
Expand Down Expand Up @@ -231,6 +232,8 @@ Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
if (auto *expr{UnwrapExpr<Expr<SomeLogical>>(args[0])}) {
return Fold(context, ConvertToType<T>(std::move(*expr)));
}
} else if (name == "matmul") {
return FoldMatmul(context, std::move(funcRef));
} else if (name == "out_of_range") {
if (Expr<SomeType> * cx{UnwrapExpr<Expr<SomeType>>(args[0])}) {
auto restorer{context.messages().DiscardMessages()};
Expand Down Expand Up @@ -367,7 +370,6 @@ Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
name == "__builtin_ieee_support_underflow_control") {
return Expr<T>{true};
}
// TODO: logical, matmul, parity
return Expr<T>{std::move(funcRef)};
}

Expand Down
103 changes: 103 additions & 0 deletions flang/lib/Evaluate/fold-matmul.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
//===-- lib/Evaluate/fold-matmul.h ----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef FORTRAN_EVALUATE_FOLD_MATMUL_H_
#define FORTRAN_EVALUATE_FOLD_MATMUL_H_

#include "fold-implementation.h"

namespace Fortran::evaluate {

template <typename T>
static Expr<T> FoldMatmul(FoldingContext &context, FunctionRef<T> &&funcRef) {
using Element = typename Constant<T>::Element;
auto args{funcRef.arguments()};
CHECK(args.size() == 2);
Folder<T> folder{context};
Constant<T> *ma{folder.Folding(args[0])};
Constant<T> *mb{folder.Folding(args[1])};
if (!ma || !mb) {
return Expr<T>{std::move(funcRef)};
}
CHECK(ma->Rank() >= 1 && ma->Rank() <= 2 && mb->Rank() >= 1 &&
mb->Rank() <= 2 && (ma->Rank() == 2 || mb->Rank() == 2));
ConstantSubscript commonExtent{ma->shape().back()};
if (mb->shape().front() != commonExtent) {
context.messages().Say(
"Arguments to MATMUL have distinct extents %zd and %zd on their last and first dimensions"_err_en_US,
commonExtent, mb->shape().front());
return MakeInvalidIntrinsic(std::move(funcRef));
}
ConstantSubscript rows{ma->Rank() == 1 ? 1 : ma->shape()[0]};
ConstantSubscript columns{mb->Rank() == 1 ? 1 : mb->shape()[1]};
std::vector<Element> elements;
elements.reserve(rows * columns);
bool overflow{false};
[[maybe_unused]] const auto &rounding{
context.targetCharacteristics().roundingMode()};
// result(j,k) = SUM(A(j,:) * B(:,k))
for (ConstantSubscript ci{0}; ci < columns; ++ci) {
for (ConstantSubscript ri{0}; ri < rows; ++ri) {
ConstantSubscripts aAt{ma->lbounds()};
if (ma->Rank() == 2) {
aAt[0] += ri;
}
ConstantSubscripts bAt{mb->lbounds()};
if (mb->Rank() == 2) {
bAt[1] += ci;
}
Element sum{};
[[maybe_unused]] Element correction{};
for (ConstantSubscript j{0}; j < commonExtent; ++j) {
Element aElt{ma->At(aAt)};
Element bElt{mb->At(bAt)};
if constexpr (T::category == TypeCategory::Real ||
T::category == TypeCategory::Complex) {
// Kahan summation
auto product{aElt.Multiply(bElt, rounding)};
overflow |= product.flags.test(RealFlag::Overflow);
auto next{correction.Add(product.value, rounding)};
overflow |= next.flags.test(RealFlag::Overflow);
auto added{sum.Add(next.value, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
correction = added.value.Subtract(sum, rounding)
.value.Subtract(next.value, rounding)
.value;
sum = std::move(added.value);
} else if constexpr (T::category == TypeCategory::Integer) {
auto product{aElt.MultiplySigned(bElt)};
overflow |= product.SignedMultiplicationOverflowed();
auto added{sum.AddSigned(product.lower)};
overflow |= added.overflow;
sum = std::move(added.value);
} else {
static_assert(T::category == TypeCategory::Logical);
sum = sum.OR(aElt.AND(bElt));
}
++aAt.back();
++bAt.front();
}
elements.push_back(sum);
}
}
if (overflow) {
context.messages().Say(
"MATMUL of %s data overflowed during computation"_warn_en_US,
T::AsFortran());
}
ConstantSubscripts shape;
if (ma->Rank() == 2) {
shape.push_back(rows);
}
if (mb->Rank() == 2) {
shape.push_back(columns);
}
return Expr<T>{Constant<T>{std::move(elements), std::move(shape)}};
}
} // namespace Fortran::evaluate
#endif // FORTRAN_EVALUATE_FOLD_MATMUL_H_
4 changes: 3 additions & 1 deletion flang/lib/Evaluate/fold-real.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//

#include "fold-implementation.h"
#include "fold-matmul.h"
#include "fold-reduction.h"

namespace Fortran::evaluate {
Expand Down Expand Up @@ -269,6 +270,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
}
return result.value;
}));
} else if (name == "matmul") {
return FoldMatmul(context, std::move(funcRef));
} else if (name == "max") {
return FoldMINorMAX(context, std::move(funcRef), Ordering::Greater);
} else if (name == "maxval") {
Expand Down Expand Up @@ -446,7 +449,6 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
return result.value;
}));
}
// TODO: matmul
return Expr<T>{std::move(funcRef)};
}

Expand Down
4 changes: 2 additions & 2 deletions flang/lib/Evaluate/fold-reduction.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ static Expr<T> FoldDotProduct(
Expr<T> products{Fold(
context, Expr<T>{std::move(conjgA)} * Expr<T>{Constant<T>{*vb}})};
Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
Element correction; // Use Kahan summation for greater precision.
Element correction{}; // Use Kahan summation for greater precision.
const auto &rounding{context.targetCharacteristics().roundingMode()};
for (const Element &x : cProducts.values()) {
auto next{correction.Add(x, rounding)};
Expand Down Expand Up @@ -80,7 +80,7 @@ static Expr<T> FoldDotProduct(
Expr<T> products{
Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})};
Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
Element correction; // Use Kahan summation for greater precision.
Element correction{}; // Use Kahan summation for greater precision.
const auto &rounding{context.targetCharacteristics().roundingMode()};
for (const Element &x : cProducts.values()) {
auto next{correction.Add(x, rounding)};
Expand Down
41 changes: 41 additions & 0 deletions flang/test/Evaluate/fold-matmul.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
! RUN: %python %S/test_folding.py %s %flang_fc1
! Tests folding of MATMUL()
module m
integer, parameter :: ia(2,3) = reshape([1, 2, 2, 3, 3, 4], shape(ia))
integer, parameter :: ib(3,2) = reshape([1, 2, 3, 2, 3, 4], shape(ib))
integer, parameter :: ix(*) = [1, 2]
integer, parameter :: iy(*) = [1, 2, 3]
integer, parameter :: iab(*,*) = matmul(ia, ib)
integer, parameter :: ixa(*) = matmul(ix, ia)
integer, parameter :: iay(*) = matmul(ia, iy)
logical, parameter :: test_iab = all([iab] == [14, 20, 20, 29])
logical, parameter :: test_ixa = all(ixa == [5, 8, 11])
logical, parameter :: test_iay = all(iay == [14, 20])

real, parameter :: ra(*,*) = ia
real, parameter :: rb(*,*) = ib
real, parameter :: rx(*) = ix
real, parameter :: ry(*) = iy
real, parameter :: rab(*,*) = matmul(ra, rb)
real, parameter :: rxa(*) = matmul(rx, ra)
real, parameter :: ray(*) = matmul(ra, ry)
logical, parameter :: test_rab = all(rab == iab)
logical, parameter :: test_rxa = all(rxa == ixa)
logical, parameter :: test_ray = all(ray == iay)

complex, parameter :: za(*,*) = cmplx(ra, -1.)
complex, parameter :: zb(*,*) = cmplx(rb, -1.)
complex, parameter :: zx(*) = cmplx(rx, -1.)
complex, parameter :: zy(*) = cmplx(ry, -1.)
complex, parameter :: zab(*,*) = matmul(za, zb)
complex, parameter :: zxa(*) = matmul(zx, za)
complex, parameter :: zay(*) = matmul(za, zy)
logical, parameter :: test_zab = all([zab] == [(11,-12),(17,-15),(17,-15),(26,-18)])
logical, parameter :: test_zxa = all(zxa == [(3,-6),(6,-8),(9,-10)])
logical, parameter :: test_zay = all(zay == [(11,-12),(17,-15)])

logical, parameter :: la(16, 4) = reshape([((iand(shiftr(j,k),1)/=0, j=0,15), k=0,3)], shape(la))
logical, parameter :: lb(4, 16) = transpose(la)
logical, parameter :: lab(16, 16) = matmul(la, lb)
logical, parameter :: test_lab = all([lab] .eqv. [((iand(k,j)/=0, k=0,15), j=0,15)])
end

0 comments on commit 0fdf912

Please sign in to comment.