Skip to content

Commit

Permalink
[flang] HLFIR to FIR lowering for complex parts
Browse files Browse the repository at this point in the history
This revision implements HLFIR to FIR lowering for complex parts.

Reviewed By: jeanPerier

Differential Revision: https://reviews.llvm.org/D146487
  • Loading branch information
EthanLuisMcDonough committed Apr 4, 2023
1 parent 0e0db0a commit e89e244
Show file tree
Hide file tree
Showing 4 changed files with 193 additions and 10 deletions.
2 changes: 2 additions & 0 deletions flang/lib/Optimizer/Dialect/FIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2327,6 +2327,8 @@ mlir::LogicalResult fir::ReboxOp::verify() {
inputEleTy.isa<fir::RecordType>() || outEleTy.isa<mlir::NoneType>() ||
(inputEleTy.isa<mlir::NoneType>() && outEleTy.isa<fir::RecordType>()) ||
(getSlice() && inputEleTy.isa<fir::CharacterType>()) ||
(getSlice() && fir::isa_complex(inputEleTy) &&
outEleTy.isa<mlir::FloatType>()) ||
areCompatibleCharacterTypes(inputEleTy, outEleTy);
if (!typeCanMismatch)
return emitOpError(
Expand Down
46 changes: 36 additions & 10 deletions flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,9 +364,6 @@ class DesignateOpConversion
auto module = designate->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));

if (designate.getComplexPart())
TODO(loc, "hlfir::designate with complex part");

hlfir::Entity baseEntity(designate.getMemref());

if (baseEntity.isMutableBox())
Expand All @@ -377,6 +374,7 @@ class DesignateOpConversion
auto [base, shape] = hlfir::genVariableFirBaseShapeAndParams(
loc, builder, baseEntity, firBaseTypeParameters);
mlir::Type baseEleTy = hlfir::getFortranElementType(base.getType());
mlir::Type resultEleTy = hlfir::getFortranElementType(designateResultType);

mlir::Value fieldIndex;
if (designate.getComponent()) {
Expand Down Expand Up @@ -428,12 +426,7 @@ class DesignateOpConversion
if (fieldIndex && baseEntity.isArray()) {
// array%scalar_comp or array%array_comp(indices)
// Generate triples for array(:, :, ...).
auto one = builder.createIntegerConstant(loc, idxTy, 1);
for (auto [lb, ub] : hlfir::genBounds(loc, builder, baseEntity)) {
triples.push_back(builder.createConvert(loc, idxTy, lb));
triples.push_back(builder.createConvert(loc, idxTy, ub));
triples.push_back(one);
}
triples = genFullSliceTriples(builder, loc, baseEntity);
sliceFields.push_back(fieldIndex);
// Add indices in the field path for "array%array_comp(indices)"
// case.
Expand Down Expand Up @@ -464,7 +457,12 @@ class DesignateOpConversion
builder.create<mlir::arith::SubIOp>(loc, substring[0], one);
substring.push_back(designate.getTypeparams()[0]);
}

if (designate.getComplexPart()) {
if (triples.empty())
triples = genFullSliceTriples(builder, loc, baseEntity);
sliceFields.push_back(builder.createIntegerConstant(
loc, idxTy, *designate.getComplexPart()));
}
mlir::Value slice;
if (!triples.empty())
slice =
Expand Down Expand Up @@ -517,6 +515,16 @@ class DesignateOpConversion
base = fir::factory::CharacterExprHelper{builder, loc}.genSubstringBase(
base, designate.getSubstring()[0], resultAddressType);

// Scalar complex part ref
if (designate.getComplexPart()) {
// Sequence types should have already been handled by this point
assert(!designateResultType.isa<fir::SequenceType>());
auto index = builder.createIntegerConstant(loc, builder.getIndexType(),
*designate.getComplexPart());
auto coorTy = fir::ReferenceType::get(resultEleTy);
base = builder.create<fir::CoordinateOp>(loc, coorTy, base, index);
}

// Cast/embox the computed scalar address if needed.
if (designateResultType.isa<fir::BoxCharType>()) {
assert(designate.getTypeparams().size() == 1 &&
Expand All @@ -530,6 +538,24 @@ class DesignateOpConversion
}
return mlir::success();
}

private:
// Generates triple for full slice
// Used for component and complex part slices when a triple is
// not specified
static llvm::SmallVector<mlir::Value>
genFullSliceTriples(fir::FirOpBuilder &builder, mlir::Location loc,
hlfir::Entity baseEntity) {
llvm::SmallVector<mlir::Value> triples;
mlir::Type idxTy = builder.getIndexType();
auto one = builder.createIntegerConstant(loc, idxTy, 1);
for (auto [lb, ub] : hlfir::genBounds(loc, builder, baseEntity)) {
triples.push_back(builder.createConvert(loc, idxTy, lb));
triples.push_back(builder.createConvert(loc, idxTy, ub));
triples.push_back(one);
}
return triples;
}
};

class ParentComponentOpConversion
Expand Down
72 changes: 72 additions & 0 deletions flang/test/Fir/rebox.fir
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,78 @@ func.func @test_rebox_4(%arg0: !fir.box<!fir.array<?x!fir.char<1,?>>>) {
}
func.func private @bar_test_rebox_4(!fir.box<!fir.ptr<!fir.array<?x!fir.char<1,10>>>>)

// Testing complex part slice reboxing
// subroutine test_cmplx_2(a)
// complex :: a(:)
// call bar1(a%re)
// end subroutine

// CHECK-LABEL: define void @test_cmplx_1(
// CHECK-SAME: ptr %[[INBOX:.*]])
func.func @test_cmplx_1(%arg0: !fir.box<!fir.array<?x!fir.complex<4>>>) {
// CHECK: %[[OUTBOX_ALLOC:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }
%c1 = arith.constant 1 : index
%c1_i32 = arith.constant 0 : i32
%c0 = arith.constant 0 : index
%0:3 = fir.box_dims %arg0, %c0 : (!fir.box<!fir.array<?x!fir.complex<4>>>, index) -> (index, index, index)
%1 = fir.slice %c1, %0#1, %c1 path %c1_i32 : (index, index, index, i32) -> !fir.slice<1>
%2 = fir.rebox %arg0 [%1] : (!fir.box<!fir.array<?x!fir.complex<4>>>, !fir.slice<1>) -> !fir.box<!fir.array<?xf32>>
// CHECK: %[[INSTRIDE_0_GEP:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, ptr %[[INBOX]], i32 0, i32 7, i64 0, i32 1
// CHECK: %[[INSTRIDE_0:.*]] = load i64, ptr %[[INSTRIDE_0_GEP]]
// CHECK: %[[INSTRIDE_1_GEP:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, ptr %[[INBOX]], i32 0, i32 7, i32 0, i32 2
// CHECK: %[[INSTRIDE_1:.*]] = load i64, ptr %[[INSTRIDE_1_GEP]]
// CHECK: %[[FRONT_GEP:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, ptr %[[INBOX]], i32 0, i32 0
// CHECK: %[[FRONT_PTR:.*]] = load ptr, ptr %[[FRONT_GEP]]
// CHECK: %[[FIELD_OFFSET_GEP:.*]] = getelementptr { float, float }, ptr %[[FRONT_PTR]], i64 0, i32 0
// CHECK: %[[FRONT_OFFSET:.*]] = mul i64 0, %[[INSTRIDE_1]]
// CHECK: %[[OFFSET_GEP:.*]] = getelementptr i8, ptr %[[FIELD_OFFSET_GEP]], i64 %[[FRONT_OFFSET]]
// CHECK: %[[SUB_1:.*]] = sub i64 %[[INSTRIDE_0]], 1
// CHECK: %[[ADD_1:.*]] = add i64 %[[SUB_1]], 1
// CHECK: %[[DIV_1:.*]] = sdiv i64 %[[ADD_1]], 1
// CHECK: %[[CHECK_NONZERO:.*]] = icmp sgt i64 %[[DIV_1]], 0
// CHECK: %[[CHECKED_BOUND:.*]] = select i1 %[[CHECK_NONZERO]], i64 %[[DIV_1]], i64 0
// CHECK: %[[STRIDE:.*]] = mul i64 1, %[[INSTRIDE_1]]
// CHECK: %[[VAL_BUILD_1:.*]] = insertvalue { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] } { ptr undef, i64 ptrtoint (ptr getelementptr (float, ptr null, i32 1) to i64), i32 {{.*}}, i8 1, i8 27, i8 0, i8 0, [1 x [3 x i64]] [{{\[}}3 x i64] [i64 1, i64 undef, i64 undef]] }, i64 %[[CHECKED_BOUND]], 7, 0, 1
// CHECK: %[[VAL_BUILD_2:.*]] = insertvalue { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] } %[[VAL_BUILD_1]], i64 %[[STRIDE]], 7, 0, 2
// CHECK: %[[VAL_BUILD_3:.*]] = insertvalue { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] } %[[VAL_BUILD_2]], ptr %[[OFFSET_GEP]], 0
// CHECK: store { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] } %[[VAL_BUILD_3]], ptr %[[OUTBOX_ALLOC]]
fir.call @bar1(%2) : (!fir.box<!fir.array<?xf32>>) -> ()
// CHECK: call void @bar1(ptr %[[OUTBOX_ALLOC]])
return
}

// Testing triple on complex part slice
// subroutine test_cmplx_2(a)
// complex :: a(:)
// call bar1(a(7:60:5)%im)
// end subroutine

// CHECK-LABEL: define void @test_cmplx_2(
// CHECK-SAME: ptr %[[INBOX:.*]])
func.func @test_cmplx_2(%arg0: !fir.box<!fir.array<?x!fir.complex<4>>>) {
// CHECK: %[[OUTBOX_ALLOC:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }
%c7 = arith.constant 7 : index
%c5 = arith.constant 5 : index
%c60 = arith.constant 60 : index
%c1_i32 = arith.constant 1 : i32
%0 = fir.slice %c7, %c60, %c5 path %c1_i32 : (index, index, index, i32) -> !fir.slice<1>
%1 = fir.rebox %arg0 [%0] : (!fir.box<!fir.array<?x!fir.complex<4>>>, !fir.slice<1>) -> !fir.box<!fir.array<11xf32>>
%2 = fir.convert %1 : (!fir.box<!fir.array<11xf32>>) -> !fir.box<!fir.array<?xf32>>
// CHECK: %[[INSTRIDE_0_GEP:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, ptr %[[INBOX]], i32 0, i32 7, i32 0, i32 2
// CHECK: %[[INSTRIDE_0:.*]] = load i64, ptr %[[INSTRIDE_0_GEP]]
// CHECK: %[[FRONT_GEP:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, ptr %[[INBOX]], i32 0, i32 0
// CHECK: %[[FRONT_PTR:.*]] = load ptr, ptr %[[FRONT_GEP]]
// CHECK: %[[FIELD_OFFSET_GEP:.*]] = getelementptr { float, float }, ptr %[[FRONT_PTR]], i64 0, i32 1
// CHECK: %[[FRONT_OFFSET:.*]] = mul i64 6, %[[INSTRIDE_0]]
// CHECK: %[[OFFSET_GEP:.*]] = getelementptr i8, ptr %[[FIELD_OFFSET_GEP]], i64 %[[FRONT_OFFSET]]
// CHECK: %[[STRIDE:.*]] = mul i64 5, %[[INSTRIDE_0]]
// CHECK: %[[VAL_BUILD_1:.*]] = insertvalue { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] } { ptr undef, i64 ptrtoint (ptr getelementptr (float, ptr null, i32 1) to i64), i32 {{.*}}, i8 1, i8 27, i8 0, i8 0, [1 x [3 x i64]] [{{\[}}3 x i64] [i64 1, i64 11, i64 undef]] }, i64 %[[STRIDE]], 7, 0, 2
// CHECK: %[[VAL_BUILD_2:.*]] = insertvalue { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] } %[[VAL_BUILD_1]], ptr %[[OFFSET_GEP]], 0
// CHECK: store { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] } %[[VAL_BUILD_2]], ptr %[[OUTBOX_ALLOC]]
fir.call @bar1(%2) fastmath<contract> : (!fir.box<!fir.array<?xf32>>) -> ()
// CHECK: call void @bar1(ptr %[[OUTBOX_ALLOC]])
return
}

// Test reboxing of unlimited polymorphic.

Expand Down
83 changes: 83 additions & 0 deletions flang/test/HLFIR/designate-codegen-complex-part.fir
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Test code generation to FIR of hlfir.designate operations
// with complex parts.
// RUN: fir-opt %s -convert-hlfir-to-fir | FileCheck %s

func.func @test_set_scalar(%arg0: !fir.ref<!fir.complex<4>>, %arg1: !fir.ref<f32>) {
%0:2 = hlfir.declare %arg0 {uniq_name = "a"} : (!fir.ref<!fir.complex<4>>) -> (!fir.ref<!fir.complex<4>>, !fir.ref<!fir.complex<4>>)
%1:2 = hlfir.declare %arg1 {uniq_name = "b"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
%2 = fir.load %1#0 : !fir.ref<f32>
%3 = hlfir.designate %0#0 imag : (!fir.ref<!fir.complex<4>>) -> !fir.ref<f32>
hlfir.assign %2 to %3 : f32, !fir.ref<f32>
return
}
// CHECK-LABEL: func.func @test_set_scalar(
// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.complex<4>>, %[[VAL_1:.*]]: !fir.ref<f32>) {
// CHECK: %[[VAL_2:.*]] = fir.declare %[[VAL_0]] {uniq_name = "a"} : (!fir.ref<!fir.complex<4>>) -> !fir.ref<!fir.complex<4>>
// CHECK: %[[VAL_3:.*]] = fir.declare %[[VAL_1]] {uniq_name = "b"} : (!fir.ref<f32>) -> !fir.ref<f32>
// CHECK: %[[VAL_4:.*]] = fir.load %[[VAL_3]] : !fir.ref<f32>
// CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_6:.*]] = fir.coordinate_of %[[VAL_2]], %[[VAL_5]] : (!fir.ref<!fir.complex<4>>, index) -> !fir.ref<f32>
// CHECK: fir.store %[[VAL_4]] to %[[VAL_6]] : !fir.ref<f32>

func.func @test_scalar_at_index(%arg0: !fir.box<!fir.array<?x!fir.complex<4>>>, %arg1: !fir.ref<i32>) {
%0:2 = hlfir.declare %arg0 {uniq_name = "a"} : (!fir.box<!fir.array<?x!fir.complex<4>>>) -> (!fir.box<!fir.array<?x!fir.complex<4>>>, !fir.box<!fir.array<?x!fir.complex<4>>>)
%1:2 = hlfir.declare %arg1 {uniq_name = "b"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
%2 = fir.load %1#0 : !fir.ref<i32>
%3 = fir.convert %2 : (i32) -> i64
%4 = hlfir.designate %0#0 (%3) real : (!fir.box<!fir.array<?x!fir.complex<4>>>, i64) -> !fir.ref<f32>
return
}
// CHECK-LABEL: func.func @test_scalar_at_index(
// CHECK-SAME: %[[VAL_0:.*]]: !fir.box<!fir.array<?x!fir.complex<4>>>, %[[VAL_1:.*]]: !fir.ref<i32>) {
// CHECK: %[[VAL_2:.*]] = fir.declare %[[VAL_0]] {uniq_name = "a"} : (!fir.box<!fir.array<?x!fir.complex<4>>>) -> !fir.box<!fir.array<?x!fir.complex<4>>>
// CHECK: %[[VAL_3:.*]] = fir.rebox %[[VAL_2]] : (!fir.box<!fir.array<?x!fir.complex<4>>>) -> !fir.box<!fir.array<?x!fir.complex<4>>>
// CHECK: %[[VAL_4:.*]] = fir.declare %[[VAL_1]] {uniq_name = "b"} : (!fir.ref<i32>) -> !fir.ref<i32>
// CHECK: %[[VAL_5:.*]] = fir.load %[[VAL_4]] : !fir.ref<i32>
// CHECK: %[[VAL_6:.*]] = fir.convert %[[VAL_5]] : (i32) -> i64
// CHECK: %[[VAL_7:.*]] = fir.array_coor %[[VAL_2]] %[[VAL_6]] : (!fir.box<!fir.array<?x!fir.complex<4>>>, i64) -> !fir.ref<!fir.complex<4>>
// CHECK: %[[VAL_8:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_9:.*]] = fir.coordinate_of %[[VAL_7]], %[[VAL_8]] : (!fir.ref<!fir.complex<4>>, index) -> !fir.ref<f32>

func.func @test_complete_slice(%arg0: !fir.box<!fir.array<?x!fir.complex<4>>>) {
%c0 = arith.constant 0 : index
%0:2 = hlfir.declare %arg0 {uniq_name = "a"} : (!fir.box<!fir.array<?x!fir.complex<4>>>) -> (!fir.box<!fir.array<?x!fir.complex<4>>>, !fir.box<!fir.array<?x!fir.complex<4>>>)
%1:3 = fir.box_dims %0#0, %c0 : (!fir.box<!fir.array<?x!fir.complex<4>>>, index) -> (index, index, index)
%2 = fir.shape %1#1 : (index) -> !fir.shape<1>
%3 = hlfir.designate %0#0 imag shape %2 : (!fir.box<!fir.array<?x!fir.complex<4>>>, !fir.shape<1>) -> !fir.box<!fir.array<?xf32>>
return
}
// CHECK-LABEL: func.func @test_complete_slice(
// CHECK-SAME: %[[VAL_0:.*]]: !fir.box<!fir.array<?x!fir.complex<4>>>) {
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_2:.*]] = fir.declare %[[VAL_0]] {uniq_name = "a"} : (!fir.box<!fir.array<?x!fir.complex<4>>>) -> !fir.box<!fir.array<?x!fir.complex<4>>>
// CHECK: %[[VAL_3:.*]] = fir.rebox %[[VAL_2]] : (!fir.box<!fir.array<?x!fir.complex<4>>>) -> !fir.box<!fir.array<?x!fir.complex<4>>>
// CHECK: %[[VAL_4:.*]]:3 = fir.box_dims %[[VAL_3]], %[[VAL_1]] : (!fir.box<!fir.array<?x!fir.complex<4>>>, index) -> (index, index, index)
// CHECK: %[[VAL_5:.*]] = fir.shape %[[VAL_4]]#1 : (index) -> !fir.shape<1>
// CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_7:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_8:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_9:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_8]] : (!fir.box<!fir.array<?x!fir.complex<4>>>, index) -> (index, index, index)
// CHECK: %[[VAL_10:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_11:.*]] = fir.slice %[[VAL_7]], %[[VAL_9]]#1, %[[VAL_6]] path %[[VAL_10]] : (index, index, index, index) -> !fir.slice<1>
// CHECK: %[[VAL_12:.*]] = fir.rebox %[[VAL_2]] [%[[VAL_11]]] : (!fir.box<!fir.array<?x!fir.complex<4>>>, !fir.slice<1>) -> !fir.box<!fir.array<?xf32>>

func.func @test_slice_steps(%arg0: !fir.box<!fir.array<?x!fir.complex<4>>>) {
%c3 = arith.constant 3 : index
%c12 = arith.constant 12 : index
%c4 = arith.constant 4 : index
%0:2 = hlfir.declare %arg0 {uniq_name = "a"} : (!fir.box<!fir.array<?x!fir.complex<4>>>) -> (!fir.box<!fir.array<?x!fir.complex<4>>>, !fir.box<!fir.array<?x!fir.complex<4>>>)
%1 = fir.shape %c3 : (index) -> !fir.shape<1>
%2 = hlfir.designate %0#0 (%c4:%c12:%c3) real shape %1 : (!fir.box<!fir.array<?x!fir.complex<4>>>, index, index, index, !fir.shape<1>) -> !fir.box<!fir.array<3xf32>>
return
}
// CHECK-LABEL: func.func @test_slice_steps(
// CHECK-SAME: %[[VAL_0:.*]]: !fir.box<!fir.array<?x!fir.complex<4>>>) {
// CHECK: %[[VAL_1:.*]] = arith.constant 3 : index
// CHECK: %[[VAL_2:.*]] = arith.constant 12 : index
// CHECK: %[[VAL_3:.*]] = arith.constant 4 : index
// CHECK: %[[VAL_4:.*]] = fir.declare %[[VAL_0]] {uniq_name = "a"} : (!fir.box<!fir.array<?x!fir.complex<4>>>) -> !fir.box<!fir.array<?x!fir.complex<4>>>
// CHECK: %[[VAL_5:.*]] = fir.rebox %[[VAL_4]] : (!fir.box<!fir.array<?x!fir.complex<4>>>) -> !fir.box<!fir.array<?x!fir.complex<4>>>
// CHECK: %[[VAL_6:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1>
// CHECK: %[[VAL_7:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_8:.*]] = fir.slice %[[VAL_3]], %[[VAL_2]], %[[VAL_1]] path %[[VAL_7]] : (index, index, index, index) -> !fir.slice<1>
// CHECK: %[[VAL_9:.*]] = fir.rebox %[[VAL_4]] [%[[VAL_8]]] : (!fir.box<!fir.array<?x!fir.complex<4>>>, !fir.slice<1>) -> !fir.box<!fir.array<3xf32>>

0 comments on commit e89e244

Please sign in to comment.