Skip to content

Commit

Permalink
[flang] Fix array substring emboxing code generation
Browse files Browse the repository at this point in the history
The code generation of the fir.embox op creating descriptors for
array substring with a non constant length base was using the
substring length to compute the first dimension result stride.
Fix it to use the input length instead.

Reviewed By: PeteSteinfeld

Differential Revision: https://reviews.llvm.org/D154086
  • Loading branch information
jeanPerier committed Jun 29, 2023
1 parent e6fed06 commit b881fc2
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 16 deletions.
42 changes: 26 additions & 16 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1266,6 +1266,22 @@ struct EmboxCommonConversion : public FIROpConversion<OP> {
return CFI_attribute_other;
}

mlir::Value getCharacterByteSize(mlir::Location loc,
mlir::ConversionPatternRewriter &rewriter,
fir::CharacterType charTy,
mlir::ValueRange lenParams) const {
auto i64Ty = mlir::IntegerType::get(rewriter.getContext(), 64);
mlir::Value size =
genTypeStrideInBytes(loc, i64Ty, rewriter, this->convertType(charTy));
if (charTy.hasConstantLen())
return size; // Length accounted for in the genTypeStrideInBytes GEP.
// Otherwise, multiply the single character size by the length.
assert(!lenParams.empty());
auto len64 = FIROpConversion<OP>::integerCast(loc, rewriter, i64Ty,
lenParams.back());
return rewriter.create<mlir::LLVM::MulOp>(loc, i64Ty, size, len64);
}

// Get the element size and CFI type code of the boxed value.
std::tuple<mlir::Value, mlir::Value> getSizeAndTypeCode(
mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
Expand All @@ -1286,18 +1302,9 @@ struct EmboxCommonConversion : public FIROpConversion<OP> {
return {genTypeStrideInBytes(loc, i64Ty, rewriter,
this->convertType(boxEleTy)),
typeCodeVal};
if (auto charTy = boxEleTy.dyn_cast<fir::CharacterType>()) {
mlir::Value size =
genTypeStrideInBytes(loc, i64Ty, rewriter, this->convertType(charTy));
if (charTy.getLen() == fir::CharacterType::unknownLen()) {
// Multiply the single character size by the length.
assert(!lenParams.empty());
auto len64 = FIROpConversion<OP>::integerCast(loc, rewriter, i64Ty,
lenParams.back());
size = rewriter.create<mlir::LLVM::MulOp>(loc, i64Ty, size, len64);
}
return {size, typeCodeVal};
};
if (auto charTy = boxEleTy.dyn_cast<fir::CharacterType>())
return {getCharacterByteSize(loc, rewriter, charTy, lenParams),
typeCodeVal};
if (fir::isa_ref_type(boxEleTy)) {
auto ptrTy = mlir::LLVM::LLVMPointerType::get(
mlir::LLVM::LLVMVoidType::get(rewriter.getContext()));
Expand Down Expand Up @@ -1691,7 +1698,7 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
sourceBox = operands[xbox.getSourceBoxOffset()];
sourceBoxType = xbox.getSourceBox().getType();
}
auto [boxTy, dest, eleSize] = consDescriptorPrefix(
auto [boxTy, dest, resultEleSize] = consDescriptorPrefix(
xbox, fir::unwrapRefType(xbox.getMemref().getType()), rewriter,
xbox.getOutRank(), adaptor.getSubstr(), adaptor.getLenParams(),
sourceBox, sourceBoxType);
Expand Down Expand Up @@ -1720,7 +1727,8 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
// Adjust the element scaling factor if the element is a dependent type.
if (fir::hasDynamicSize(seqEleTy)) {
if (auto charTy = seqEleTy.dyn_cast<fir::CharacterType>()) {
prevPtrOff = eleSize;
prevPtrOff =
getCharacterByteSize(loc, rewriter, charTy, adaptor.getLenParams());
} else if (seqEleTy.isa<fir::RecordType>()) {
// prevPtrOff = ;
TODO(loc, "generate call to calculate size of PDT");
Expand All @@ -1734,8 +1742,10 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
const auto hasSubcomp = !xbox.getSubcomponent().empty();
const bool hasSubstr = !xbox.getSubstr().empty();
// Initial element stride that will be use to compute the step in
// each dimension.
mlir::Value prevDimByteStride = eleSize;
// each dimension. Initially, this is the size of the input element.
// Note that when there are no components/substring, the resultEleSize
// that was previously computed matches the input element size.
mlir::Value prevDimByteStride = resultEleSize;
if (hasSubcomp) {
// We have a subcomponent. The step value needs to be the number of
// bytes per element (which is a derived type).
Expand Down
25 changes: 25 additions & 0 deletions flang/test/Fir/embox-substring.fir
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,28 @@ func.func @embox_index_substr(%addr : !fir.ref<!fir.array<?x!fir.char<1,2>>>) {
%3 = fir.embox %addr (%1) [%2] : (!fir.ref<!fir.array<?x!fir.char<1,2>>>, !fir.shapeshift<1>, !fir.slice<1>) -> !fir.box<!fir.array<?x!fir.char<1,?>>>
return
}

// CHARACTER(*) :: C(2)
// CALL DUMP(C(:)(1:1))
// Test that the resulting stride is based on the input length, not the substring one.
func.func @substring_dyn_base(%base_addr: !fir.ref<!fir.array<2x!fir.char<1,?>>>, %base_len: index) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%2 = fircg.ext_embox %base_addr(%c2)[%c1, %c2, %c1] substr %c0, %c1 typeparams %base_len : (!fir.ref<!fir.array<2x!fir.char<1,?>>>, index, index, index, index, index, index, index) -> !fir.box<!fir.array<2x!fir.char<1>>>
fir.call @dump(%2) : (!fir.box<!fir.array<2x!fir.char<1>>>) -> ()
return
}
func.func private @dump(!fir.box<!fir.array<2x!fir.char<1>>>)

// CHECK-LABEL: llvm.func @substring_dyn_base(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>,
// CHECK-SAME: %[[VAL_1:.*]]: i64) {
// CHECK: %[[VAL_5:.*]] = llvm.mlir.constant(1 : index) : i64
// CHECK: llvm.getelementptr
// CHECK: %[[VAL_28:.*]] = llvm.mlir.null : !llvm.ptr<i8>
// CHECK: %[[VAL_29:.*]] = llvm.getelementptr %[[VAL_28]][1] : (!llvm.ptr<i8>) -> !llvm.ptr<i8>
// CHECK: %[[VAL_30:.*]] = llvm.ptrtoint %[[VAL_29]] : !llvm.ptr<i8> to i64
// CHECK: %[[VAL_31:.*]] = llvm.mul %[[VAL_30]], %[[VAL_1]] : i64
// CHECK: %[[VAL_42:.*]] = llvm.mul %[[VAL_31]], %[[VAL_5]] : i64
// CHECK: %[[VAL_43:.*]] = llvm.insertvalue %[[VAL_42]], %{{.*}}[7, 0, 2] : !llvm.struct<(ptr<array<2 x array<1 x i8>>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>

0 comments on commit b881fc2

Please sign in to comment.