From 8cf8f527fec3a8b1018e9507b864330f07fe98f4 Mon Sep 17 00:00:00 2001 From: Asher Mancinelli Date: Thu, 6 Feb 2025 17:14:22 -0800 Subject: [PATCH] [flang] Handle !dir$ unroll [01] When an explicit `N` is passed to the loop unroll directive, the unrolling factor should N if N > 1, and unrolling should be disabled when N is 0 or 1. Update docs and add test cases. --- flang/docs/Directives.md | 9 +++++- flang/lib/Lower/Bridge.cpp | 47 +++++++++++++++++++++++-------- flang/test/Integration/unroll.f90 | 45 +++++++++++++++++++++++++---- 3 files changed, 82 insertions(+), 19 deletions(-) diff --git a/flang/docs/Directives.md b/flang/docs/Directives.md index f356f762b13a2..c6c2e29a420ea 100644 --- a/flang/docs/Directives.md +++ b/flang/docs/Directives.md @@ -39,15 +39,22 @@ A list of non-standard directives supported by Flang * `!dir$ vector always` forces vectorization on the following loop regardless of cost model decisions. The loop must still be vectorizable. [This directive currently only works on plain do loops without labels]. +* `!dir$ unroll [n]` specifies that the compiler ought to unroll the immediately + following loop `n` times. When `n` is `0` or `1`, the loop should not be unrolled + at all. When `n` is `2` or greater, the loop should be unrolled exactly `n` + times if possible. When `n` is omitted, the compiler should attempt to fully + unroll the loop. Some compilers accept an optional `=` before the `n` when `n` + is present in the directive. Flang does not. # Directive Details ## Introduction Directives are commonly used in Fortran programs to specify additional actions to be performed by the compiler. The directives are always specified with the -`!dir$` or `cdir$` prefix. +`!dir$` or `cdir$` prefix. ## Loop Directives + Some directives are associated with the following construct, for example loop directives. Directives on loops are used to specify additional transformation to be performed by the compiler like enabling vectorisation, unrolling, interchange diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index a31629b17cf29..36e58e456dea3 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -63,6 +63,7 @@ #include "flang/Semantics/tools.h" #include "flang/Support/Version.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Parser/Parser.h" @@ -2170,11 +2171,38 @@ class FirConverter : public Fortran::lower::AbstractConverter { return builder->createIntegerConstant(loc, controlType, 1); // step } + // For unroll directives without a value, force full unrolling. + // For unroll directives with a value, if the value is greater than 1, + // force unrolling with the given factor. Otherwise, disable unrolling. + mlir::LLVM::LoopUnrollAttr + genLoopUnrollAttr(std::optional directiveArg) { + mlir::BoolAttr falseAttr = + mlir::BoolAttr::get(builder->getContext(), false); + mlir::BoolAttr trueAttr = mlir::BoolAttr::get(builder->getContext(), true); + mlir::IntegerAttr countAttr; + mlir::BoolAttr fullUnrollAttr; + bool shouldUnroll = true; + if (directiveArg.has_value()) { + auto unrollingFactor = directiveArg.value(); + if (unrollingFactor == 0 || unrollingFactor == 1) { + shouldUnroll = false; + } else { + countAttr = + builder->getIntegerAttr(builder->getI64Type(), unrollingFactor); + } + } else { + fullUnrollAttr = trueAttr; + } + + mlir::BoolAttr disableAttr = shouldUnroll ? falseAttr : trueAttr; + return mlir::LLVM::LoopUnrollAttr::get( + builder->getContext(), /*disable=*/disableAttr, /*count=*/countAttr, {}, + /*full=*/fullUnrollAttr, {}, {}, {}); + } + void addLoopAnnotationAttr( IncrementLoopInfo &info, llvm::SmallVectorImpl &dirs) { - mlir::BoolAttr f = mlir::BoolAttr::get(builder->getContext(), false); - mlir::BoolAttr t = mlir::BoolAttr::get(builder->getContext(), true); mlir::LLVM::LoopVectorizeAttr va; mlir::LLVM::LoopUnrollAttr ua; bool has_attrs = false; @@ -2182,20 +2210,15 @@ class FirConverter : public Fortran::lower::AbstractConverter { Fortran::common::visit( Fortran::common::visitors{ [&](const Fortran::parser::CompilerDirective::VectorAlways &) { + mlir::BoolAttr falseAttr = + mlir::BoolAttr::get(builder->getContext(), false); va = mlir::LLVM::LoopVectorizeAttr::get(builder->getContext(), - /*disable=*/f, {}, {}, - {}, {}, {}, {}); + /*disable=*/falseAttr, + {}, {}, {}, {}, {}, {}); has_attrs = true; }, [&](const Fortran::parser::CompilerDirective::Unroll &u) { - mlir::IntegerAttr countAttr; - if (u.v.has_value()) { - countAttr = builder->getIntegerAttr(builder->getI64Type(), - u.v.value()); - } - ua = mlir::LLVM::LoopUnrollAttr::get( - builder->getContext(), /*disable=*/f, /*count*/ countAttr, - {}, /*full*/ u.v.has_value() ? f : t, {}, {}, {}); + ua = genLoopUnrollAttr(u.v); has_attrs = true; }, [&](const auto &) {}}, diff --git a/flang/test/Integration/unroll.f90 b/flang/test/Integration/unroll.f90 index 9d69605e10d1b..aa47e465b63fc 100644 --- a/flang/test/Integration/unroll.f90 +++ b/flang/test/Integration/unroll.f90 @@ -3,14 +3,47 @@ ! CHECK-LABEL: unroll_dir subroutine unroll_dir integer :: a(10) - !dir$ unroll - ! CHECK: br i1 {{.*}}, label {{.*}}, label {{.*}}, !llvm.loop ![[ANNOTATION:.*]] + !dir$ unroll + ! CHECK: br i1 {{.*}}, label {{.*}}, label {{.*}}, !llvm.loop ![[UNROLL_ENABLE_FULL_ANNO:.*]] do i=1,10 - a(i)=i + a(i)=i end do end subroutine unroll_dir -! CHECK: ![[ANNOTATION]] = distinct !{![[ANNOTATION]], ![[UNROLL:.*]], ![[UNROLL_FULL:.*]]} -! CHECK: ![[UNROLL]] = !{!"llvm.loop.unroll.enable"} -! CHECK: ![[UNROLL_FULL]] = !{!"llvm.loop.unroll.full"} +! CHECK-LABEL: unroll_dir_0 +subroutine unroll_dir_0 + integer :: a(10) + !dir$ unroll 0 + ! CHECK: br i1 {{.*}}, label {{.*}}, label {{.*}}, !llvm.loop ![[UNROLL_DISABLE_ANNO:.*]] + do i=1,10 + a(i)=i + end do +end subroutine unroll_dir_0 + +! CHECK-LABEL: unroll_dir_1 +subroutine unroll_dir_1 + integer :: a(10) + !dir$ unroll 1 + ! CHECK: br i1 {{.*}}, label {{.*}}, label {{.*}}, !llvm.loop ![[UNROLL_DISABLE_ANNO]] + do i=1,10 + a(i)=i + end do +end subroutine unroll_dir_1 + +! CHECK-LABEL: unroll_dir_2 +subroutine unroll_dir_2 + integer :: a(10) + !dir$ unroll 2 + ! CHECK: br i1 {{.*}}, label {{.*}}, label {{.*}}, !llvm.loop ![[UNROLL_ENABLE_COUNT_2:.*]] + do i=1,10 + a(i)=i + end do +end subroutine unroll_dir_2 +! CHECK: ![[UNROLL_ENABLE_FULL_ANNO]] = distinct !{![[UNROLL_ENABLE_FULL_ANNO]], ![[UNROLL_ENABLE:.*]], ![[UNROLL_FULL:.*]]} +! CHECK: ![[UNROLL_ENABLE:.*]] = !{!"llvm.loop.unroll.enable"} +! CHECK: ![[UNROLL_FULL:.*]] = !{!"llvm.loop.unroll.full"} +! CHECK: ![[UNROLL_DISABLE_ANNO]] = distinct !{![[UNROLL_DISABLE_ANNO]], ![[UNROLL_DISABLE:.*]]} +! CHECK: ![[UNROLL_DISABLE]] = !{!"llvm.loop.unroll.disable"} +! CHECK: ![[UNROLL_ENABLE_COUNT_2]] = distinct !{![[UNROLL_ENABLE_COUNT_2]], ![[UNROLL_ENABLE]], ![[UNROLL_COUNT_2:.*]]} +! CHECK: ![[UNROLL_COUNT_2]] = !{!"llvm.loop.unroll.count", i32 2}