diff --git a/flang/include/flang/Optimizer/Dialect/FIRType.h b/flang/include/flang/Optimizer/Dialect/FIRType.h index 9d884453d2ef60..3a0254ab4de8ed 100644 --- a/flang/include/flang/Optimizer/Dialect/FIRType.h +++ b/flang/include/flang/Optimizer/Dialect/FIRType.h @@ -283,6 +283,12 @@ bool isBoxNone(mlir::Type ty); /// e.g. !fir.box> bool isBoxedRecordType(mlir::Type ty); +/// Return true iff `ty` is a scalar boxed record type. +/// e.g. !fir.box> +/// !fir.box>> +/// !fir.class> +bool isScalarBoxedRecordType(mlir::Type ty); + /// Return the nested RecordType if one if found. Return ty otherwise. mlir::Type getDerivedType(mlir::Type ty); diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp index 0da45e87436b72..5227fdfb8f6f61 100644 --- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp +++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp @@ -1708,7 +1708,7 @@ IntrinsicLibrary::genElementalCall( for (const fir::ExtendedValue &arg : args) { auto *box = arg.getBoxOf(); if (!arg.getUnboxed() && !arg.getCharBox() && - !(box && fir::isPolymorphicType(fir::getBase(*box).getType()))) + !(box && fir::isScalarBoxedRecordType(fir::getBase(*box).getType()))) fir::emitFatalError(loc, "nonscalar intrinsic argument"); } if (outline) diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp index dac0e6557538fa..decb93f3d55e38 100644 --- a/flang/lib/Optimizer/Dialect/FIRType.cpp +++ b/flang/lib/Optimizer/Dialect/FIRType.cpp @@ -290,6 +290,20 @@ bool isBoxedRecordType(mlir::Type ty) { return false; } +bool isScalarBoxedRecordType(mlir::Type ty) { + if (auto refTy = fir::dyn_cast_ptrEleTy(ty)) + ty = refTy; + if (auto boxTy = ty.dyn_cast()) { + if (boxTy.getEleTy().isa()) + return true; + if (auto heapTy = boxTy.getEleTy().dyn_cast()) + return heapTy.getEleTy().isa(); + if (auto ptrTy = boxTy.getEleTy().dyn_cast()) + return ptrTy.getEleTy().isa(); + } + return false; +} + static bool isAssumedType(mlir::Type ty) { if (auto boxTy = ty.dyn_cast()) { if (boxTy.getEleTy().isa()) diff --git a/flang/test/Lower/polymorphic-temp.f90 b/flang/test/Lower/polymorphic-temp.f90 index f8627ef5b80a36..5dfd36af3f3657 100644 --- a/flang/test/Lower/polymorphic-temp.f90 +++ b/flang/test/Lower/polymorphic-temp.f90 @@ -207,4 +207,23 @@ subroutine test_merge_intrinsic(a, b) ! CHECK: %[[SELECT:.*]] = arith.select %[[CMPI]], %[[ARG0]], %[[ARG1]] : !fir.class> ! CHECK: fir.call @_QMpoly_tmpPcheck_scalar(%[[SELECT]]) {{.*}} : (!fir.class>) -> () + subroutine test_merge_intrinsic2(a, b, i) + class(p1), allocatable, intent(in) :: a + type(p1), allocatable :: b + integer, intent(in) :: i + + call check_scalar(merge(a, b, i==1)) + end subroutine + + +! CHECK-LABEL: func.func @_QMpoly_tmpPtest_merge_intrinsic2( +! CHECK-SAME: %[[A:.*]]: !fir.ref>>> {fir.bindc_name = "a"}, %[[B:.*]]: !fir.ref>>> {fir.bindc_name = "b"}, %[[I:.*]]: !fir.ref {fir.bindc_name = "i"}) { +! CHECK: %[[LOAD_A:.*]] = fir.load %[[A]] : !fir.ref>>> +! CHECK: %[[LOAD_B:.*]] = fir.load %[[B]] : !fir.ref>>> +! CHECK: %[[LOAD_I:.*]] = fir.load %[[I]] : !fir.ref +! CHECK: %[[C1:.*]] = arith.constant 1 : i32 +! CHECK: %[[CMPI:.*]] = arith.cmpi eq, %[[LOAD_I]], %[[C1]] : i32 +! CHECK: %[[B_CONV:.*]] = fir.convert %[[LOAD_B]] : (!fir.box>>) -> !fir.class>> +! CHECK: %{{.*}} = arith.select %[[CMPI]], %[[LOAD_A]], %[[B_CONV]] : !fir.class>> + end module diff --git a/flang/unittests/Optimizer/FIRTypesTest.cpp b/flang/unittests/Optimizer/FIRTypesTest.cpp index e30800a3caf560..41588e2c98b2f4 100644 --- a/flang/unittests/Optimizer/FIRTypesTest.cpp +++ b/flang/unittests/Optimizer/FIRTypesTest.cpp @@ -147,6 +147,43 @@ TEST_F(FIRTypesTest, isBoxedRecordType) { fir::ReferenceType::get(mlir::IntegerType::get(&context, 32))))); } +// Test fir::isScalarBoxedRecordType from flang/Optimizer/Dialect/FIRType.h. +TEST_F(FIRTypesTest, isScalarBoxedRecordType) { + mlir::Type recTy = fir::RecordType::get(&context, "dt"); + mlir::Type seqRecTy = + fir::SequenceType::get({fir::SequenceType::getUnknownExtent()}, recTy); + mlir::Type ty = fir::BoxType::get(recTy); + EXPECT_TRUE(fir::isScalarBoxedRecordType(ty)); + EXPECT_TRUE(fir::isScalarBoxedRecordType(fir::ReferenceType::get(ty))); + + // CLASS(T), ALLOCATABLE + ty = fir::ClassType::get(fir::HeapType::get(recTy)); + EXPECT_TRUE(fir::isScalarBoxedRecordType(ty)); + + // TYPE(T), ALLOCATABLE + ty = fir::BoxType::get(fir::HeapType::get(recTy)); + EXPECT_TRUE(fir::isScalarBoxedRecordType(ty)); + + // TYPE(T), POINTER + ty = fir::BoxType::get(fir::PointerType::get(recTy)); + EXPECT_TRUE(fir::isScalarBoxedRecordType(ty)); + + // CLASS(T), POINTER + ty = fir::ClassType::get(fir::PointerType::get(recTy)); + EXPECT_TRUE(fir::isScalarBoxedRecordType(ty)); + + // TYPE(T), DIMENSION(10) + ty = fir::BoxType::get(fir::SequenceType::get({10}, recTy)); + EXPECT_FALSE(fir::isScalarBoxedRecordType(ty)); + + // TYPE(T), DIMENSION(:) + ty = fir::BoxType::get(seqRecTy); + EXPECT_FALSE(fir::isScalarBoxedRecordType(ty)); + + EXPECT_FALSE(fir::isScalarBoxedRecordType(fir::BoxType::get( + fir::ReferenceType::get(mlir::IntegerType::get(&context, 32))))); +} + TEST_F(FIRTypesTest, updateTypeForUnlimitedPolymorphic) { // RecordType are not changed.