Skip to content

Commit

Permalink
[flang][hlfir] fix missing conversion in transpose simplification
Browse files Browse the repository at this point in the history
It seems just replacing the operation was not replacing all of the uses
when the types of the expression before and after this pass differ (due
to differing shape information). Now the shape information is always
kept the same.

This fixes #63399

Differential Revision: https://reviews.llvm.org/D153333
  • Loading branch information
tblah committed Jun 21, 2023
1 parent 9d796d0 commit 74adc3e
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 9 deletions.
4 changes: 3 additions & 1 deletion flang/include/flang/Optimizer/Builder/HLFIRTools.h
Original file line number Diff line number Diff line change
Expand Up @@ -363,11 +363,13 @@ using ElementalKernelGenerator = std::function<hlfir::Entity(
mlir::Location, fir::FirOpBuilder &, mlir::ValueRange)>;
/// Generate an hlfir.elemental operation, given a callback that generates the
/// element value for each iteration.
/// If exprType is specified, it is used as the result type of the elemental
/// op; otherwise the result type is computed from elementType and shape.
hlfir::ElementalOp genElementalOp(mlir::Location loc,
fir::FirOpBuilder &builder,
mlir::Type elementType, mlir::Value shape,
mlir::ValueRange typeParams,
const ElementalKernelGenerator &genKernel);
const ElementalKernelGenerator &genKernel,
mlir::Type exprType = mlir::Type{});

/// Structure to describe a loop nest.
struct LoopNest {
Expand Down
12 changes: 6 additions & 6 deletions flang/lib/Optimizer/Builder/HLFIRTools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -722,12 +722,12 @@ static hlfir::ExprType getArrayExprType(mlir::Type elementType,
isPolymorphic);
}

hlfir::ElementalOp
hlfir::genElementalOp(mlir::Location loc, fir::FirOpBuilder &builder,
mlir::Type elementType, mlir::Value shape,
mlir::ValueRange typeParams,
const ElementalKernelGenerator &genKernel) {
mlir::Type exprType = getArrayExprType(elementType, shape, false);
hlfir::ElementalOp hlfir::genElementalOp(
mlir::Location loc, fir::FirOpBuilder &builder, mlir::Type elementType,
mlir::Value shape, mlir::ValueRange typeParams,
const ElementalKernelGenerator &genKernel, mlir::Type exprType) {
if (!exprType)
exprType = getArrayExprType(elementType, shape, false);
auto elementalOp =
builder.create<hlfir::ElementalOp>(loc, exprType, shape, typeParams);
auto insertPt = builder.saveInsertionPoint();
Expand Down
11 changes: 9 additions & 2 deletions flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,16 @@ class TransposeAsElementalConversion
return val;
};
hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
loc, builder, elementType, resultShape, typeParams, genKernel);
loc, builder, elementType, resultShape, typeParams, genKernel,
transpose.getResult().getType());

rewriter.replaceOp(transpose, elementalOp.getResult());
// It would not be safe to replace block arguments with a value of a
// different hlfir.expr type. The types can differ when they carry differing
// amounts of shape information, so verify that they match before replacing.
assert(elementalOp.getResult().getType() ==
transpose.getResult().getType());

rewriter.replaceOp(transpose, elementalOp);
return mlir::success();
}

Expand Down
91 changes: 91 additions & 0 deletions flang/test/HLFIR/simplify-hlfir-intrinsics.fir
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,94 @@ func.func @transpose3(%arg0: !hlfir.expr<?x2xi32>) {
// CHECK: }
// CHECK: return
// CHECK: }

// Test: the hlfir.expr produced by hlfir.transpose has multiple uses.
// %0 is consumed by hlfir.shape_of, by hlfir.apply inside the following
// hlfir.elemental, and by hlfir.destroy. All expression types in this
// function are !hlfir.expr<2x2xf32>.
func.func @transpose4(%arg0: !hlfir.expr<2x2xf32>, %arg1: !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>) {
%0 = hlfir.transpose %arg0 : (!hlfir.expr<2x2xf32>) -> !hlfir.expr<2x2xf32>
// First use of %0: query its shape.
%1 = hlfir.shape_of %0 : (!hlfir.expr<2x2xf32>) -> !fir.shape<2>
%2 = hlfir.elemental %1 : (!fir.shape<2>) -> !hlfir.expr<2x2xf32> {
^bb0(%arg2: index, %arg3: index):
// Second use of %0: element access from inside another region.
%3 = hlfir.apply %0, %arg2, %arg3 : (!hlfir.expr<2x2xf32>, index, index) -> f32
%4 = math.cos %3 fastmath<contract> : f32
hlfir.yield_element %4 : f32
}
hlfir.assign %2 to %arg1 realloc : !hlfir.expr<2x2xf32>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
hlfir.destroy %2 : !hlfir.expr<2x2xf32>
// Third use of %0: destruction of the transposed expression.
hlfir.destroy %0 : !hlfir.expr<2x2xf32>
return
}
// CHECK-LABEL: func.func @transpose4(
// CHECK-SAME: %[[ARG0:.*]]: !hlfir.expr<2x2xf32>
// CHECK-SAME: %[[ARG1:.*]]:
// CHECK: %[[SHAPE0:.*]] = fir.shape
// CHECK: %[[TRANSPOSE:.*]] = hlfir.elemental %[[SHAPE0]] : (!fir.shape<2>) -> !hlfir.expr<2x2xf32> {
// CHECK: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index):
// CHECK: %[[ELE:.*]] = hlfir.apply %[[ARG0]], %[[J]], %[[I]] : (!hlfir.expr<2x2xf32>, index, index) -> f32
// CHECK: hlfir.yield_element %[[ELE]] : f32
// CHECK: }
// CHECK: %[[SHAPE1:.*]] = hlfir.shape_of %[[TRANSPOSE]] : (!hlfir.expr<2x2xf32>) -> !fir.shape<2>
// CHECK: %[[COS:.*]] = hlfir.elemental %[[SHAPE1]] : (!fir.shape<2>) -> !hlfir.expr<2x2xf32> {
// CHECK: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index):
// CHECK: %[[ELE:.*]] = hlfir.apply %[[TRANSPOSE]], %[[I]], %[[J]] : (!hlfir.expr<2x2xf32>, index, index) -> f32
// CHECK: %[[COS_ELE:.*]] = math.cos %[[ELE]] fastmath<contract> : f32
// CHECK: hlfir.yield_element %[[COS_ELE]] : f32
// CHECK: }
// CHECK: hlfir.assign %[[COS]] to %[[ARG1]] realloc
// CHECK: hlfir.destroy %[[COS]] : !hlfir.expr<2x2xf32>
// CHECK: hlfir.destroy %[[TRANSPOSE]] : !hlfir.expr<2x2xf32>
// CHECK: return
// CHECK: }

// Regression test (added together with the fix for issue #63399).
// The hlfir.transpose result (%16) has type !hlfir.expr<2x2xf64>, while the
// hlfir.elemental that reads it through hlfir.apply (%18) produces
// !hlfir.expr<?x?xf64>, so the expression types in this function carry
// differing amounts of shape information. %16 is used by hlfir.shape_of,
// hlfir.apply and hlfir.destroy.
func.func @transpose5(%arg0: !fir.ref<tuple<!fir.box<!fir.array<2x2xf64>>, !fir.box<!fir.array<2x2xf64>>>> {fir.host_assoc}) attributes {fir.internal_proc} {
%0 = fir.address_of(@_QFEb) : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf64>>>>
%1:2 = hlfir.declare %0 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFEb"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf64>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf64>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf64>>>>)
// Load the first box of the tuple argument and declare it as "_QFEa".
%c0_i32 = arith.constant 0 : i32
%2 = fir.coordinate_of %arg0, %c0_i32 : (!fir.ref<tuple<!fir.box<!fir.array<2x2xf64>>, !fir.box<!fir.array<2x2xf64>>>>, i32) -> !fir.ref<!fir.box<!fir.array<2x2xf64>>>
%3 = fir.load %2 : !fir.ref<!fir.box<!fir.array<2x2xf64>>>
%4 = fir.box_addr %3 : (!fir.box<!fir.array<2x2xf64>>) -> !fir.ref<!fir.array<2x2xf64>>
%c0 = arith.constant 0 : index
%5:3 = fir.box_dims %3, %c0 : (!fir.box<!fir.array<2x2xf64>>, index) -> (index, index, index)
%c1 = arith.constant 1 : index
%6:3 = fir.box_dims %3, %c1 : (!fir.box<!fir.array<2x2xf64>>, index) -> (index, index, index)
%7 = fir.shape %5#1, %6#1 : (index, index) -> !fir.shape<2>
%8:2 = hlfir.declare %4(%7) {uniq_name = "_QFEa"} : (!fir.ref<!fir.array<2x2xf64>>, !fir.shape<2>) -> (!fir.ref<!fir.array<2x2xf64>>, !fir.ref<!fir.array<2x2xf64>>)
// Load the second box of the tuple argument and declare it as "_QFEc".
%c1_i32 = arith.constant 1 : i32
%9 = fir.coordinate_of %arg0, %c1_i32 : (!fir.ref<tuple<!fir.box<!fir.array<2x2xf64>>, !fir.box<!fir.array<2x2xf64>>>>, i32) -> !fir.ref<!fir.box<!fir.array<2x2xf64>>>
%10 = fir.load %9 : !fir.ref<!fir.box<!fir.array<2x2xf64>>>
%11 = fir.box_addr %10 : (!fir.box<!fir.array<2x2xf64>>) -> !fir.ref<!fir.array<2x2xf64>>
%c0_0 = arith.constant 0 : index
%12:3 = fir.box_dims %10, %c0_0 : (!fir.box<!fir.array<2x2xf64>>, index) -> (index, index, index)
%c1_1 = arith.constant 1 : index
%13:3 = fir.box_dims %10, %c1_1 : (!fir.box<!fir.array<2x2xf64>>, index) -> (index, index, index)
%14 = fir.shape %12#1, %13#1 : (index, index) -> !fir.shape<2>
%15:2 = hlfir.declare %11(%14) {uniq_name = "_QFEc"} : (!fir.ref<!fir.array<2x2xf64>>, !fir.shape<2>) -> (!fir.ref<!fir.array<2x2xf64>>, !fir.ref<!fir.array<2x2xf64>>)
// Transpose of the variable declared as "_QFEa"; the operand is a
// !fir.ref to an array rather than an hlfir.expr.
%16 = hlfir.transpose %8#0 : (!fir.ref<!fir.array<2x2xf64>>) -> !hlfir.expr<2x2xf64>
%17 = hlfir.shape_of %16 : (!hlfir.expr<2x2xf64>) -> !fir.shape<2>
// Note the result type here has dynamic extents (?x?) even though the
// transpose result it reads from is 2x2.
%18 = hlfir.elemental %17 : (!fir.shape<2>) -> !hlfir.expr<?x?xf64> {
^bb0(%arg1: index, %arg2: index):
%19 = hlfir.apply %16, %arg1, %arg2 : (!hlfir.expr<2x2xf64>, index, index) -> f64
%20 = math.cos %19 fastmath<contract> : f64
hlfir.yield_element %20 : f64
}
hlfir.assign %18 to %1#0 realloc : !hlfir.expr<?x?xf64>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf64>>>>
hlfir.destroy %18 : !hlfir.expr<?x?xf64>
hlfir.destroy %16 : !hlfir.expr<2x2xf64>
return
}
// CHECK-LABEL: func.func @transpose5(
// ...
// CHECK: %[[TRANSPOSE:.*]] = hlfir.elemental %[[SHAPE0:.*]]
// CHECK: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index):
// CHECK: %[[ELE:.*]] = hlfir.designate %[[ARRAY:.*]] (%[[J]], %[[I]])
// CHECK: %[[LOAD:.*]] = fir.load %[[ELE]]
// CHECK: hlfir.yield_element %[[LOAD]]
// CHECK: }
// CHECK: %[[SHAPE1:.*]] = hlfir.shape_of %[[TRANSPOSE]]
// CHECK: %[[COS:.*]] = hlfir.elemental %[[SHAPE1]]
// ...
// CHECK: hlfir.assign %[[COS]] to %{{.*}} realloc
// CHECK: hlfir.destroy %[[COS]]
// CHECK: hlfir.destroy %[[TRANSPOSE]]
// CHECK: return
// CHECK: }

0 comments on commit 74adc3e

Please sign in to comment.