diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td index cfce9fca504ec..7bfb304b973f5 100644 --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -17,6 +17,7 @@ include "mlir/Dialect/Arith/IR/ArithBase.td" include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td" include "mlir/Dialect/LLVMIR/LLVMAttrDefs.td" +include "mlir/Interfaces/MemorySlotInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.td" include "flang/Optimizer/Dialect/FIRDialect.td" @@ -80,7 +81,10 @@ def AnyRefOfConstantSizeAggregateType : TypeConstraint< // Memory SSA operations //===----------------------------------------------------------------------===// -def fir_AllocaOp : fir_Op<"alloca", [AttrSizedOperandSegments]> { +def fir_AllocaOp : fir_Op<"alloca", [ + AttrSizedOperandSegments, + DeclareOpInterfaceMethods +]> { let summary = "allocate storage for a temporary on the stack given a type"; let description = [{ This primitive operation is used to allocate an object on the stack. A @@ -288,8 +292,11 @@ def fir_FreeMemOp : fir_Op<"freemem", [MemoryEffects<[MemFree]>]> { let assemblyFormat = "$heapref attr-dict `:` qualified(type($heapref))"; } -def fir_LoadOp : fir_OneResultOp<"load", [FirAliasTagOpInterface, - DeclareOpInterfaceMethods]> { +def fir_LoadOp : fir_OneResultOp<"load", [ + FirAliasTagOpInterface, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods +]> { let summary = "load a value from a memory reference"; let description = [{ Load a value from a memory reference into an ssa-value (virtual register). @@ -319,8 +326,11 @@ def fir_LoadOp : fir_OneResultOp<"load", [FirAliasTagOpInterface, }]; } -def fir_StoreOp : fir_Op<"store", [FirAliasTagOpInterface, - DeclareOpInterfaceMethods]> { +def fir_StoreOp : fir_Op<"store", [ + FirAliasTagOpInterface, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods +]> { let summary = "store an SSA-value to a memory location"; let description = [{ diff --git a/flang/include/flang/Optimizer/Support/InitFIR.h b/flang/include/flang/Optimizer/Support/InitFIR.h index 67e9287ddad4f..41a979f97aece 100644 --- a/flang/include/flang/Optimizer/Support/InitFIR.h +++ b/flang/include/flang/Optimizer/Support/InitFIR.h @@ -129,6 +129,7 @@ inline void registerMLIRPassesForFortranTools() { mlir::affine::registerAffineLoopTilingPass(); mlir::affine::registerAffineDataCopyGenerationPass(); + mlir::registerMem2RegPass(); mlir::registerLowerAffinePass(); } diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index 4e797d651cb7a..c2a3d52fe88d2 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -186,6 +186,36 @@ static mlir::Type wrapAllocaResultType(mlir::Type intype) { return fir::ReferenceType::get(intype); } +llvm::SmallVector fir::AllocaOp::getPromotableSlots() { + // TODO: support promotion of dynamic allocas + if (isDynamic()) + return {}; + + return {mlir::MemorySlot{getResult(), getAllocatedType()}}; +} + +mlir::Value fir::AllocaOp::getDefaultValue(const mlir::MemorySlot &slot, + mlir::OpBuilder &builder) { + return fir::UndefOp::create(builder, getLoc(), slot.elemType); +} + +void fir::AllocaOp::handleBlockArgument(const mlir::MemorySlot &slot, + mlir::BlockArgument argument, + mlir::OpBuilder &builder) {} + +std::optional +fir::AllocaOp::handlePromotionComplete(const mlir::MemorySlot &slot, + mlir::Value defaultValue, + mlir::OpBuilder &builder) { + if (defaultValue && defaultValue.use_empty()) { + assert(mlir::isa(defaultValue.getDefiningOp()) && + "Expected undef op to be the default value"); + defaultValue.getDefiningOp()->erase(); + } + this->erase(); + return std::nullopt; +} + mlir::Type fir::AllocaOp::getAllocatedType() { return mlir::cast(getType()).getEleTy(); } @@ -2861,6 +2891,39 @@ llvm::SmallVector fir::LenParamIndexOp::getAttributes() { // LoadOp //===----------------------------------------------------------------------===// +bool fir::LoadOp::loadsFrom(const mlir::MemorySlot &slot) { + return getMemref() == slot.ptr; +} + +bool fir::LoadOp::storesTo(const mlir::MemorySlot &slot) { return false; } + +mlir::Value fir::LoadOp::getStored(const mlir::MemorySlot &slot, + mlir::OpBuilder &builder, + mlir::Value reachingDef, + const mlir::DataLayout &dataLayout) { + return mlir::Value(); +} + +bool fir::LoadOp::canUsesBeRemoved( + const mlir::MemorySlot &slot, + const SmallPtrSetImpl &blockingUses, + mlir::SmallVectorImpl &newBlockingUses, + const mlir::DataLayout &dataLayout) { + if (blockingUses.size() != 1) + return false; + mlir::Value blockingUse = (*blockingUses.begin())->get(); + return blockingUse == slot.ptr && getMemref() == slot.ptr; +} + +mlir::DeletionKind fir::LoadOp::removeBlockingUses( + const mlir::MemorySlot &slot, + const SmallPtrSetImpl &blockingUses, + mlir::OpBuilder &builder, mlir::Value reachingDefinition, + const mlir::DataLayout &dataLayout) { + getResult().replaceAllUsesWith(reachingDefinition); + return mlir::DeletionKind::Delete; +} + void fir::LoadOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, mlir::Value refVal) { if (!refVal) { @@ -4256,6 +4319,39 @@ llvm::LogicalResult fir::SliceOp::verify() { // StoreOp //===----------------------------------------------------------------------===// +bool fir::StoreOp::loadsFrom(const mlir::MemorySlot &slot) { return false; } + +bool fir::StoreOp::storesTo(const mlir::MemorySlot &slot) { + return getMemref() == slot.ptr; +} + +mlir::Value fir::StoreOp::getStored(const mlir::MemorySlot &slot, + mlir::OpBuilder &builder, + mlir::Value reachingDef, + const mlir::DataLayout &dataLayout) { + return getValue(); +} + +bool fir::StoreOp::canUsesBeRemoved( + const mlir::MemorySlot &slot, + const SmallPtrSetImpl &blockingUses, + mlir::SmallVectorImpl &newBlockingUses, + const mlir::DataLayout &dataLayout) { + if (blockingUses.size() != 1) + return false; + mlir::Value blockingUse = (*blockingUses.begin())->get(); + return blockingUse == slot.ptr && getMemref() == slot.ptr && + getValue() != slot.ptr; +} + +mlir::DeletionKind fir::StoreOp::removeBlockingUses( + const mlir::MemorySlot &slot, + const SmallPtrSetImpl &blockingUses, + mlir::OpBuilder &builder, mlir::Value reachingDefinition, + const mlir::DataLayout &dataLayout) { + return mlir::DeletionKind::Delete; +} + mlir::Type fir::StoreOp::elementType(mlir::Type refType) { return fir::dyn_cast_ptrEleTy(refType); } diff --git a/flang/test/Fir/mem2reg.mlir b/flang/test/Fir/mem2reg.mlir new file mode 100644 index 0000000000000..25d114a55e1a4 --- /dev/null +++ b/flang/test/Fir/mem2reg.mlir @@ -0,0 +1,68 @@ +// RUN: fir-opt %s --allow-unregistered-dialect --mem2reg --split-input-file | FileCheck %s + +// CHECK-LABEL: func.func @basic() -> i32 { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 5 : i32 +// CHECK: return %[[CONSTANT_0]] : i32 +// CHECK: } +func.func @basic() -> i32 { + %0 = arith.constant 5 : i32 + %1 = fir.alloca i32 + fir.store %0 to %1 : !fir.ref + %2 = fir.load %1 : !fir.ref + return %2 : i32 +} + +// ----- + +// CHECK-LABEL: func.func @default_value() -> i32 { +// CHECK: %[[UNDEFINED_0:.*]] = fir.undefined i32 +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 5 : i32 +// CHECK: return %[[UNDEFINED_0]] : i32 +// CHECK: } +func.func @default_value() -> i32 { + %0 = arith.constant 5 : i32 + %1 = fir.alloca i32 + %2 = fir.load %1 : !fir.ref + fir.store %0 to %1 : !fir.ref + return %2 : i32 +} + +// ----- + +// CHECK-LABEL: func.func @basic_float() -> f32 { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 5.200000e+00 : f32 +// CHECK: return %[[CONSTANT_0]] : f32 +// CHECK: } +func.func @basic_float() -> f32 { + %0 = arith.constant 5.2 : f32 + %1 = fir.alloca f32 + fir.store %0 to %1 : !fir.ref + %2 = fir.load %1 : !fir.ref + return %2 : f32 +} + +// ----- + +// CHECK-LABEL: func.func @cycle( +// CHECK-SAME: %[[ARG0:.*]]: i64, +// CHECK-SAME: %[[ARG1:.*]]: i1, +// CHECK-SAME: %[[ARG2:.*]]: i64) { +// CHECK: cf.cond_br %[[ARG1]], ^bb1(%[[ARG2]] : i64), ^bb2(%[[ARG2]] : i64) +// CHECK: ^bb1(%[[VAL_0:.*]]: i64): +// CHECK: "test.use"(%[[VAL_0]]) : (i64) -> () +// CHECK: cf.br ^bb2(%[[ARG0]] : i64) +// CHECK: ^bb2(%[[VAL_1:.*]]: i64): +// CHECK: cf.br ^bb1(%[[VAL_1]] : i64) +// CHECK: } +func.func @cycle(%arg0: i64, %arg1: i1, %arg2: i64) { + %alloca = fir.alloca i64 + fir.store %arg2 to %alloca : !fir.ref + cf.cond_br %arg1, ^bb1, ^bb2 +^bb1: + %use = fir.load %alloca : !fir.ref + "test.use"(%use) : (i64) -> () + fir.store %arg0 to %alloca : !fir.ref + cf.br ^bb2 +^bb2: + cf.br ^bb1 +}