Skip to content

Commit

Permalink
[flang] add fir.box_offset operation (#73641)
Browse files Browse the repository at this point in the history
This operation allows computing the address of descriptor fields. It is
needed to help attaching descriptors in OpenMP/OpenACC target region.
The pointers inside the descriptor structure must be mapped too, but the
fir.box is abstract, so these fields cannot be computed with
fir.coordinate_of.

To preserve the abstraction of the descriptor layout in FIR, introduce
an operation specifically to !fir.ref<fir.box<>> address fields based on
field names (base_addr or derived_type).
  • Loading branch information
jeanPerier committed Nov 29, 2023
1 parent c145e4c commit 91e1b4a
Show file tree
Hide file tree
Showing 7 changed files with 219 additions and 19 deletions.
10 changes: 10 additions & 0 deletions flang/include/flang/Optimizer/Dialect/FIRAttr.td
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,16 @@ def fir_FortranVariableFlagsAttr : fir_Attr<"FortranVariableFlags"> {
"::fir::FortranVariableFlagsAttr::get($_builder.getContext(), $0)";
}

def fir_BoxFieldAttr : I32EnumAttr<
"BoxFieldAttr", "",
[
I32EnumAttrCase<"base_addr", 0>,
I32EnumAttrCase<"derived_type", 1>
]> {
let cppNamespace = "fir";
}


// mlir::SideEffects::Resource for modelling operations which add debugging information
def DebuggingResource : Resource<"::fir::DebuggingResource">;

Expand Down
39 changes: 39 additions & 0 deletions flang/include/flang/Optimizer/Dialect/FIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3036,4 +3036,43 @@ def fir_DeclareOp : fir_Op<"declare", [AttrSizedOperandSegments,
let hasVerifier = 1;
}

def fir_BoxOffsetOp : fir_Op<"box_offset", [NoMemoryEffect]> {

let summary = "Get the address of a field in a fir.ref<fir.box>";

let description = [{
Given the address of a fir.box, compute the address of a field inside
the fir.box.
This allows keeping the actual runtime descriptor layout abstract in
FIR while providing access to the pointer addresses in the runtime
descriptor for OpenMP/OpenACC target mapping.

To avoid requiring too much information about the fields that the runtime
descriptor implementation must have, only the base_addr and derived_type
descriptor fields can be addressed.

```
%addr = fir.box_offset %box base_addr : (!fir.ref<!fir.box<!fir.array<?xi32>>>) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>
%tdesc = fir.box_offset %box derived_type : (!fir.ref<!fir.box<!fir.type<t>>>) -> !fir.llvm_ptr<!fir.tdesc<!fir.type<t>>>

```
}];

let arguments = (ins
AnyReferenceLike:$box_ref,
fir_BoxFieldAttr:$field
);

let results = (outs RefOrLLVMPtr);
let hasVerifier = 1;

let assemblyFormat = [{
$box_ref $field attr-dict `:` functional-type(operands, results)
}];

let builders = [
OpBuilder<(ins "mlir::Value":$boxRef, "fir::BoxFieldAttr":$field)>
];
}

#endif
59 changes: 40 additions & 19 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3656,6 +3656,27 @@ struct NegcOpConversion : public FIROpConversion<fir::NegcOp> {
}
};

struct BoxOffsetOpConversion : public FIROpConversion<fir::BoxOffsetOp> {
using FIROpConversion::FIROpConversion;

mlir::LogicalResult
matchAndRewrite(fir::BoxOffsetOp boxOffset, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {

mlir::Type pty = ::getLlvmPtrType(boxOffset.getContext());
mlir::Type boxType = fir::unwrapRefType(boxOffset.getBoxRef().getType());
mlir::Type llvmBoxTy =
lowerTy().convertBoxTypeAsStruct(mlir::cast<fir::BaseBoxType>(boxType));
unsigned fieldId = boxOffset.getField() == fir::BoxFieldAttr::derived_type
? getTypeDescFieldId(boxType)
: kAddrPosInBox;
rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
boxOffset, pty, llvmBoxTy, adaptor.getBoxRef(),
llvm::ArrayRef<mlir::LLVM::GEPArg>{0, fieldId});
return mlir::success();
}
};

/// Conversion pattern for operation that must be dead. The information in these
/// operations is used by other operation. At this point they should not have
/// anymore uses.
Expand Down Expand Up @@ -3807,25 +3828,25 @@ class FIRToLLVMLowering
AllocaOpConversion, AllocMemOpConversion, BoxAddrOpConversion,
BoxCharLenOpConversion, BoxDimsOpConversion, BoxEleSizeOpConversion,
BoxIsAllocOpConversion, BoxIsArrayOpConversion, BoxIsPtrOpConversion,
BoxProcHostOpConversion, BoxRankOpConversion, BoxTypeCodeOpConversion,
BoxTypeDescOpConversion, CallOpConversion, CmpcOpConversion,
ConstcOpConversion, ConvertOpConversion, CoordinateOpConversion,
DTEntryOpConversion, DivcOpConversion, EmboxOpConversion,
EmboxCharOpConversion, EmboxProcOpConversion, ExtractValueOpConversion,
FieldIndexOpConversion, FirEndOpConversion, FreeMemOpConversion,
GlobalLenOpConversion, GlobalOpConversion, HasValueOpConversion,
InsertOnRangeOpConversion, InsertValueOpConversion,
IsPresentOpConversion, LenParamIndexOpConversion, LoadOpConversion,
MulcOpConversion, NegcOpConversion, NoReassocOpConversion,
SelectCaseOpConversion, SelectOpConversion, SelectRankOpConversion,
SelectTypeOpConversion, ShapeOpConversion, ShapeShiftOpConversion,
ShiftOpConversion, SliceOpConversion, StoreOpConversion,
StringLitOpConversion, SubcOpConversion, TypeDescOpConversion,
TypeInfoOpConversion, UnboxCharOpConversion, UnboxProcOpConversion,
UndefOpConversion, UnreachableOpConversion,
UnrealizedConversionCastOpConversion, XArrayCoorOpConversion,
XEmboxOpConversion, XReboxOpConversion, ZeroOpConversion>(typeConverter,
options);
BoxOffsetOpConversion, BoxProcHostOpConversion, BoxRankOpConversion,
BoxTypeCodeOpConversion, BoxTypeDescOpConversion, CallOpConversion,
CmpcOpConversion, ConstcOpConversion, ConvertOpConversion,
CoordinateOpConversion, DTEntryOpConversion, DivcOpConversion,
EmboxOpConversion, EmboxCharOpConversion, EmboxProcOpConversion,
ExtractValueOpConversion, FieldIndexOpConversion, FirEndOpConversion,
FreeMemOpConversion, GlobalLenOpConversion, GlobalOpConversion,
HasValueOpConversion, InsertOnRangeOpConversion,
InsertValueOpConversion, IsPresentOpConversion,
LenParamIndexOpConversion, LoadOpConversion, MulcOpConversion,
NegcOpConversion, NoReassocOpConversion, SelectCaseOpConversion,
SelectOpConversion, SelectRankOpConversion, SelectTypeOpConversion,
ShapeOpConversion, ShapeShiftOpConversion, ShiftOpConversion,
SliceOpConversion, StoreOpConversion, StringLitOpConversion,
SubcOpConversion, TypeDescOpConversion, TypeInfoOpConversion,
UnboxCharOpConversion, UnboxProcOpConversion, UndefOpConversion,
UnreachableOpConversion, UnrealizedConversionCastOpConversion,
XArrayCoorOpConversion, XEmboxOpConversion, XReboxOpConversion,
ZeroOpConversion>(typeConverter, options);
mlir::populateFuncToLLVMConversionPatterns(typeConverter, pattern);
mlir::populateOpenMPToLLVMConversionPatterns(typeConverter, pattern);
mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, pattern);
Expand Down
33 changes: 33 additions & 0 deletions flang/lib/Optimizer/Dialect/FIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3584,6 +3584,39 @@ void fir::IfOp::resultToSourceOps(llvm::SmallVectorImpl<mlir::Value> &results,
results.push_back(term->getOperand(resultNum));
}

//===----------------------------------------------------------------------===//
// BoxOffsetOp
//===----------------------------------------------------------------------===//

mlir::LogicalResult fir::BoxOffsetOp::verify() {
auto boxType = mlir::dyn_cast_or_null<fir::BaseBoxType>(
fir::dyn_cast_ptrEleTy(getBoxRef().getType()));
if (!boxType)
return emitOpError("box_ref operand must have !fir.ref<!fir.box<T>> type");
if (getField() != fir::BoxFieldAttr::base_addr &&
getField() != fir::BoxFieldAttr::derived_type)
return emitOpError("cannot address provided field");
if (getField() == fir::BoxFieldAttr::derived_type)
if (!fir::boxHasAddendum(boxType))
return emitOpError("can only address derived_type field of derived type "
"or unlimited polymorphic fir.box");
return mlir::success();
}

void fir::BoxOffsetOp::build(mlir::OpBuilder &builder,
mlir::OperationState &result, mlir::Value boxRef,
fir::BoxFieldAttr field) {
mlir::Type valueType =
fir::unwrapPassByRefType(fir::unwrapRefType(boxRef.getType()));
mlir::Type resultType = valueType;
if (field == fir::BoxFieldAttr::base_addr)
resultType = fir::LLVMPointerType::get(fir::ReferenceType::get(valueType));
else if (field == fir::BoxFieldAttr::derived_type)
resultType = fir::LLVMPointerType::get(
fir::TypeDescType::get(fir::unwrapSequenceType(valueType)));
build(builder, result, {resultType}, boxRef, field);
}

//===----------------------------------------------------------------------===//

mlir::ParseResult fir::isValidCaseAttr(mlir::Attribute attr) {
Expand Down
39 changes: 39 additions & 0 deletions flang/test/Fir/box-offset-codegen.fir
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Test fir.box_offset code generation.
// RUN: tco %s | FileCheck %s
// RUN: %flang_fc1 -emit-llvm %s -o - | FileCheck %s

func.func @scalar_addr(%scalar : !fir.ref<!fir.box<!fir.type<t>>>) -> !fir.llvm_ptr<!fir.ref<!fir.type<t>>> {
%addr = fir.box_offset %scalar base_addr : (!fir.ref<!fir.box<!fir.type<t>>>) -> !fir.llvm_ptr<!fir.ref<!fir.type<t>>>
return %addr : !fir.llvm_ptr<!fir.ref<!fir.type<t>>>
}
// CHECK-LABEL: define ptr @scalar_addr(
// CHECK-SAME: ptr %[[BOX:.*]]) {
// CHECK: %[[VAL_0:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8, ptr, [1 x i64] }, ptr %[[BOX]], i32 0, i32 0
// CHECK: ret ptr %[[VAL_0]]

func.func @scalar_tdesc(%scalar : !fir.ref<!fir.box<!fir.type<t>>>) -> !fir.llvm_ptr<!fir.tdesc<!fir.type<t>>> {
%tdesc = fir.box_offset %scalar derived_type : (!fir.ref<!fir.box<!fir.type<t>>>) -> !fir.llvm_ptr<!fir.tdesc<!fir.type<t>>>
return %tdesc : !fir.llvm_ptr<!fir.tdesc<!fir.type<t>>>
}
// CHECK-LABEL: define ptr @scalar_tdesc(
// CHECK-SAME: ptr %[[BOX:.*]]) {
// CHECK: %[[VAL_0:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8, ptr, [1 x i64] }, ptr %[[BOX]], i32 0, i32 7
// CHECK: ret ptr %[[VAL_0]]

func.func @array_addr(%array : !fir.ref<!fir.class<!fir.ptr<!fir.array<?x!fir.type<t>>>>>) -> !fir.llvm_ptr<!fir.ptr<!fir.array<?x!fir.type<t>>>> {
%addr = fir.box_offset %array base_addr : (!fir.ref<!fir.class<!fir.ptr<!fir.array<?x!fir.type<t>>>>>) -> !fir.llvm_ptr<!fir.ptr<!fir.array<?x!fir.type<t>>>>
return %addr : !fir.llvm_ptr<!fir.ptr<!fir.array<?x!fir.type<t>>>>
}
// CHECK-LABEL: define ptr @array_addr(
// CHECK-SAME: ptr %[[BOX:.*]]) {
// CHECK: %[[VAL_0:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]], ptr, [1 x i64] }, ptr %[[BOX]], i32 0, i32 0
// CHECK: ret ptr %[[VAL_0]]

func.func @array_tdesc(%array : !fir.ref<!fir.class<!fir.ptr<!fir.array<?x!fir.type<t>>>>>) -> !fir.llvm_ptr<!fir.tdesc<!fir.type<t>>> {
%tdesc = fir.box_offset %array derived_type : (!fir.ref<!fir.class<!fir.ptr<!fir.array<?x!fir.type<t>>>>>) -> !fir.llvm_ptr<!fir.tdesc<!fir.type<t>>>
return %tdesc : !fir.llvm_ptr<!fir.tdesc<!fir.type<t>>>
}
// CHECK-LABEL: define ptr @array_tdesc(
// CHECK-SAME: ptr %[[BOX:.*]]) {
// CHECK: %[[VAL_0:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]], ptr, [1 x i64] }, ptr %[[BOX]], i32 0, i32 8
// CHECK: ret ptr %[[VAL_0]]
42 changes: 42 additions & 0 deletions flang/test/Fir/box-offset.fir
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Test fir.box_offset parse/print/parse/print identity.
// RUN: fir-opt %s | fir-opt | FileCheck %s

func.func @test_box_offset(%unlimited : !fir.ref<!fir.class<none>>, %type_star : !fir.ref<!fir.box<!fir.array<?xnone>>>) {
%box1 = fir.alloca !fir.box<i32>
%addr1 = fir.box_offset %box1 base_addr : (!fir.ref<!fir.box<i32>>) -> !fir.llvm_ptr<!fir.ref<i32>>

%box2 = fir.alloca !fir.box<!fir.type<t>>
%addr2 = fir.box_offset %box2 base_addr : (!fir.ref<!fir.box<!fir.type<t>>>) -> !fir.llvm_ptr<!fir.ref<!fir.type<t>>>
%tdesc2 = fir.box_offset %box2 derived_type : (!fir.ref<!fir.box<!fir.type<t>>>) -> !fir.llvm_ptr<!fir.tdesc<!fir.type<t>>>

%box3 = fir.alloca !fir.box<!fir.array<?xi32>>
%addr3 = fir.box_offset %box3 base_addr : (!fir.ref<!fir.box<!fir.array<?xi32>>>) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>

%box4 = fir.alloca !fir.box<!fir.ptr<!fir.array<?x!fir.type<t>>>>
%addr4 = fir.box_offset %box4 base_addr : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?x!fir.type<t>>>>>) -> !fir.llvm_ptr<!fir.ptr<!fir.array<?x!fir.type<t>>>>
%tdesc4 = fir.box_offset %box4 derived_type : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?x!fir.type<t>>>>>) -> !fir.llvm_ptr<!fir.tdesc<!fir.type<t>>>

%addr5 = fir.box_offset %unlimited base_addr : (!fir.ref<!fir.class<none>>) -> !fir.llvm_ptr<!fir.ref<none>>
%tdesc5 = fir.box_offset %unlimited derived_type : (!fir.ref<!fir.class<none>>) -> !fir.llvm_ptr<!fir.tdesc<none>>

%addr6 = fir.box_offset %type_star base_addr : (!fir.ref<!fir.box<!fir.array<?xnone>>>) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xnone>>>
%tdesc6 = fir.box_offset %type_star derived_type : (!fir.ref<!fir.box<!fir.array<?xnone>>>) -> !fir.llvm_ptr<!fir.tdesc<none>>
return
}
// CHECK-LABEL: func.func @test_box_offset(
// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.class<none>>,
// CHECK-SAME: %[[VAL_1:.*]]: !fir.ref<!fir.box<!fir.array<?xnone>>>) {
// CHECK: %[[VAL_2:.*]] = fir.alloca !fir.box<i32>
// CHECK: %[[VAL_3:.*]] = fir.box_offset %[[VAL_2]] base_addr : (!fir.ref<!fir.box<i32>>) -> !fir.llvm_ptr<!fir.ref<i32>>
// CHECK: %[[VAL_4:.*]] = fir.alloca !fir.box<!fir.type<t>>
// CHECK: %[[VAL_5:.*]] = fir.box_offset %[[VAL_4]] base_addr : (!fir.ref<!fir.box<!fir.type<t>>>) -> !fir.llvm_ptr<!fir.ref<!fir.type<t>>>
// CHECK: %[[VAL_6:.*]] = fir.box_offset %[[VAL_4]] derived_type : (!fir.ref<!fir.box<!fir.type<t>>>) -> !fir.llvm_ptr<!fir.tdesc<!fir.type<t>>>
// CHECK: %[[VAL_7:.*]] = fir.alloca !fir.box<!fir.array<?xi32>>
// CHECK: %[[VAL_8:.*]] = fir.box_offset %[[VAL_7]] base_addr : (!fir.ref<!fir.box<!fir.array<?xi32>>>) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>
// CHECK: %[[VAL_9:.*]] = fir.alloca !fir.box<!fir.ptr<!fir.array<?x!fir.type<t>>>>
// CHECK: %[[VAL_10:.*]] = fir.box_offset %[[VAL_9]] base_addr : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?x!fir.type<t>>>>>) -> !fir.llvm_ptr<!fir.ptr<!fir.array<?x!fir.type<t>>>>
// CHECK: %[[VAL_11:.*]] = fir.box_offset %[[VAL_9]] derived_type : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?x!fir.type<t>>>>>) -> !fir.llvm_ptr<!fir.tdesc<!fir.type<t>>>
// CHECK: %[[VAL_12:.*]] = fir.box_offset %[[VAL_0]] base_addr : (!fir.ref<!fir.class<none>>) -> !fir.llvm_ptr<!fir.ref<none>>
// CHECK: %[[VAL_13:.*]] = fir.box_offset %[[VAL_0]] derived_type : (!fir.ref<!fir.class<none>>) -> !fir.llvm_ptr<!fir.tdesc<none>>
// CHECK: %[[VAL_14:.*]] = fir.box_offset %[[VAL_1]] base_addr : (!fir.ref<!fir.box<!fir.array<?xnone>>>) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xnone>>>
// CHECK: %[[VAL_15:.*]] = fir.box_offset %[[VAL_1]] derived_type : (!fir.ref<!fir.box<!fir.array<?xnone>>>) -> !fir.llvm_ptr<!fir.tdesc<none>>
16 changes: 16 additions & 0 deletions flang/test/Fir/invalid.fir
Original file line number Diff line number Diff line change
Expand Up @@ -962,3 +962,19 @@ func.func @fp_to_logical(%arg0: f32) -> !fir.logical<4> {
%0 = fir.convert %arg0 : (f32) -> !fir.logical<4>
return %0 : !fir.logical<4>
}

// -----

func.func @bad_box_offset(%not_a_box : !fir.ref<i32>) {
// expected-error@+1{{'fir.box_offset' op box_ref operand must have !fir.ref<!fir.box<T>> type}}
%addr1 = fir.box_offset %not_a_box base_addr : (!fir.ref<i32>) -> !fir.llvm_ptr<!fir.ref<i32>>
return
}

// -----

func.func @bad_box_offset(%no_addendum : !fir.ref<!fir.box<i32>>) {
// expected-error@+1{{'fir.box_offset' op can only address derived_type field of derived type or unlimited polymorphic fir.box}}
%addr1 = fir.box_offset %no_addendum derived_type : (!fir.ref<!fir.box<i32>>) -> !fir.llvm_ptr<!fir.tdesc<!fir.type<none>>>
return
}

0 comments on commit 91e1b4a

Please sign in to comment.