Skip to content

Commit

Permalink
Convert fir.allocmem and fir.freemem operations to calls to malloc an…
Browse files Browse the repository at this point in the history
…d free, respectively

This patch is part of the upstreaming effort from the fir-dev branch.

Address review comments
- move CHECK blocks to after the mlir code in the test file
- fix style with respect to anonymous namespaces: only include class definitions in the namespace and make functions static and outside the namespace
- fix a few nits
- remove TODO in favor of notifyMatchFailure
- removed unnecessary CHECK line from convert-to-llvm.fir
- rebase on main - add TODO back in
- get successfull test of TODO in AllocMemOp converion of derived type with LEN params
- clearer comments and reduced use of auto
- move defintion of computeDerivedTypeSize to fix build error

Co-authored-by: Eric Schweitz <eschweitz@nvidia.com>
Co-authored-by: Jean Perier <jperier@nvidia.com>

Reviewed By: awarzynski, clementval, kiranchandramohan, schweitz

Differential Revision: https://reviews.llvm.org/D114104
  • Loading branch information
AlexisPerry committed Dec 7, 2021
1 parent 8421fa5 commit c2acd45
Show file tree
Hide file tree
Showing 3 changed files with 217 additions and 52 deletions.
212 changes: 164 additions & 48 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Expand Up @@ -69,6 +69,7 @@ class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
mlir::Type convertType(mlir::Type ty) const {
return lowerTy().convertType(ty);
}
mlir::Type voidPtrTy() const { return getVoidPtrType(); }

mlir::Type getVoidPtrType() const {
return mlir::LLVM::LLVMPointerType::get(
Expand Down Expand Up @@ -608,13 +609,15 @@ struct CallOpConversion : public FIROpConversion<fir::CallOp> {
return success();
}
};
} // namespace

static mlir::Type getComplexEleTy(mlir::Type complex) {
if (auto cc = complex.dyn_cast<mlir::ComplexType>())
return cc.getElementType();
return complex.cast<fir::ComplexType>().getElementType();
}

namespace {
/// Compare complex values
///
/// Per 10.1, the only comparisons available are .EQ. (oeq) and .NE. (une).
Expand Down Expand Up @@ -878,6 +881,119 @@ struct GenTypeDescOpConversion : public FIROpConversion<fir::GenTypeDescOp> {
return failure();
}
};
} // namespace

/// Return the LLVMFuncOp corresponding to the standard malloc call.
static mlir::LLVM::LLVMFuncOp
getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
auto module = op->getParentOfType<mlir::ModuleOp>();
if (mlir::LLVM::LLVMFuncOp mallocFunc =
module.lookupSymbol<mlir::LLVM::LLVMFuncOp>("malloc"))
return mallocFunc;
mlir::OpBuilder moduleBuilder(
op->getParentOfType<mlir::ModuleOp>().getBodyRegion());
auto indexType = mlir::IntegerType::get(op.getContext(), 64);
return moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
rewriter.getUnknownLoc(), "malloc",
mlir::LLVM::LLVMFunctionType::get(getVoidPtrType(op.getContext()),
indexType,
/*isVarArg=*/false));
}

/// Helper function for generating the LLVM IR that computes the size
/// in bytes for a derived type.
static mlir::Value
computeDerivedTypeSize(mlir::Location loc, mlir::Type ptrTy, mlir::Type idxTy,
mlir::ConversionPatternRewriter &rewriter) {
auto nullPtr = rewriter.create<mlir::LLVM::NullOp>(loc, ptrTy);
mlir::Value one = genConstantIndex(loc, idxTy, rewriter, 1);
llvm::SmallVector<mlir::Value> args{nullPtr, one};
auto gep = rewriter.create<mlir::LLVM::GEPOp>(loc, ptrTy, args);
return rewriter.create<mlir::LLVM::PtrToIntOp>(loc, idxTy, gep);
}

namespace {
/// Lower a `fir.allocmem` instruction into `llvm.call @malloc`
struct AllocMemOpConversion : public FIROpConversion<fir::AllocMemOp> {
using FIROpConversion::FIROpConversion;

mlir::LogicalResult
matchAndRewrite(fir::AllocMemOp heap, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::Type ty = convertType(heap.getType());
mlir::LLVM::LLVMFuncOp mallocFunc = getMalloc(heap, rewriter);
mlir::Location loc = heap.getLoc();
auto ity = lowerTy().indexType();
if (auto recTy = fir::unwrapSequenceType(heap.getAllocatedType())
.dyn_cast<fir::RecordType>())
if (recTy.getNumLenParams() != 0) {
TODO(loc,
"fir.allocmem codegen of derived type with length parameters");
return failure();
}
mlir::Value size = genTypeSizeInBytes(loc, ity, rewriter, ty);
for (mlir::Value opnd : adaptor.getOperands())
size = rewriter.create<mlir::LLVM::MulOp>(
loc, ity, size, integerCast(loc, rewriter, ity, opnd));
heap->setAttr("callee", mlir::SymbolRefAttr::get(mallocFunc));
auto malloc = rewriter.create<mlir::LLVM::CallOp>(
loc, ::getVoidPtrType(heap.getContext()), size, heap->getAttrs());
rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(heap, ty,
malloc.getResult(0));
return success();
}

// Compute the (allocation) size of the allocmem type in bytes.
mlir::Value genTypeSizeInBytes(mlir::Location loc, mlir::Type idxTy,
mlir::ConversionPatternRewriter &rewriter,
mlir::Type llTy) const {
// Use the primitive size, if available.
auto ptrTy = llTy.dyn_cast<mlir::LLVM::LLVMPointerType>();
if (auto size =
mlir::LLVM::getPrimitiveTypeSizeInBits(ptrTy.getElementType()))
return genConstantIndex(loc, idxTy, rewriter, size / 8);

// Otherwise, generate the GEP trick in LLVM IR to compute the size.
return computeDerivedTypeSize(loc, ptrTy, idxTy, rewriter);
}
};
} // namespace

/// Return the LLVMFuncOp corresponding to the standard free call.
static mlir::LLVM::LLVMFuncOp
getFree(fir::FreeMemOp op, mlir::ConversionPatternRewriter &rewriter) {
auto module = op->getParentOfType<mlir::ModuleOp>();
if (mlir::LLVM::LLVMFuncOp freeFunc =
module.lookupSymbol<mlir::LLVM::LLVMFuncOp>("free"))
return freeFunc;
mlir::OpBuilder moduleBuilder(module.getBodyRegion());
auto voidType = mlir::LLVM::LLVMVoidType::get(op.getContext());
return moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
rewriter.getUnknownLoc(), "free",
mlir::LLVM::LLVMFunctionType::get(voidType,
getVoidPtrType(op.getContext()),
/*isVarArg=*/false));
}

namespace {
/// Lower a `fir.freemem` instruction into `llvm.call @free`
struct FreeMemOpConversion : public FIROpConversion<fir::FreeMemOp> {
using FIROpConversion::FIROpConversion;

mlir::LogicalResult
matchAndRewrite(fir::FreeMemOp freemem, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::LLVM::LLVMFuncOp freeFunc = getFree(freemem, rewriter);
mlir::Location loc = freemem.getLoc();
auto bitcast = rewriter.create<mlir::LLVM::BitcastOp>(
freemem.getLoc(), voidPtrTy(), adaptor.getOperands()[0]);
freemem->setAttr("callee", mlir::SymbolRefAttr::get(freeFunc));
rewriter.create<mlir::LLVM::CallOp>(
loc, mlir::TypeRange{}, mlir::ValueRange{bitcast}, freemem->getAttrs());
rewriter.eraseOp(freemem);
return success();
}
};

/// Convert `fir.end`
struct FirEndOpConversion : public FIROpConversion<fir::FirEndOp> {
Expand Down Expand Up @@ -987,11 +1103,12 @@ struct GlobalOpConversion : public FIROpConversion<fir::GlobalOp> {
return mlir::LLVM::Linkage::External;
}
};
} // namespace

void genCondBrOp(mlir::Location loc, mlir::Value cmp, mlir::Block *dest,
Optional<mlir::ValueRange> destOps,
mlir::ConversionPatternRewriter &rewriter,
mlir::Block *newBlock) {
static void genCondBrOp(mlir::Location loc, mlir::Value cmp, mlir::Block *dest,
Optional<mlir::ValueRange> destOps,
mlir::ConversionPatternRewriter &rewriter,
mlir::Block *newBlock) {
if (destOps.hasValue())
rewriter.create<mlir::LLVM::CondBrOp>(loc, cmp, dest, destOps.getValue(),
newBlock, mlir::ValueRange());
Expand All @@ -1000,25 +1117,27 @@ void genCondBrOp(mlir::Location loc, mlir::Value cmp, mlir::Block *dest,
}

template <typename A, typename B>
void genBrOp(A caseOp, mlir::Block *dest, Optional<B> destOps,
mlir::ConversionPatternRewriter &rewriter) {
static void genBrOp(A caseOp, mlir::Block *dest, Optional<B> destOps,
mlir::ConversionPatternRewriter &rewriter) {
if (destOps.hasValue())
rewriter.replaceOpWithNewOp<mlir::LLVM::BrOp>(caseOp, destOps.getValue(),
dest);
else
rewriter.replaceOpWithNewOp<mlir::LLVM::BrOp>(caseOp, llvm::None, dest);
}

void genCaseLadderStep(mlir::Location loc, mlir::Value cmp, mlir::Block *dest,
Optional<mlir::ValueRange> destOps,
mlir::ConversionPatternRewriter &rewriter) {
static void genCaseLadderStep(mlir::Location loc, mlir::Value cmp,
mlir::Block *dest,
Optional<mlir::ValueRange> destOps,
mlir::ConversionPatternRewriter &rewriter) {
auto *thisBlock = rewriter.getInsertionBlock();
auto *newBlock = createBlock(rewriter, dest);
rewriter.setInsertionPointToEnd(thisBlock);
genCondBrOp(loc, cmp, dest, destOps, rewriter, newBlock);
rewriter.setInsertionPointToEnd(newBlock);
}

namespace {
/// Conversion of `fir.select_case`
///
/// The `fir.select_case` operation is converted to a if-then-else ladder.
Expand Down Expand Up @@ -1103,11 +1222,12 @@ struct SelectCaseOpConversion : public FIROpConversion<fir::SelectCaseOp> {
return success();
}
};
} // namespace

template <typename OP>
void selectMatchAndRewrite(fir::LLVMTypeConverter &lowering, OP select,
typename OP::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) {
static void selectMatchAndRewrite(fir::LLVMTypeConverter &lowering, OP select,
typename OP::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) {
unsigned conds = select.getNumConditions();
auto cases = select.getCases().getValue();
mlir::Value selector = adaptor.selector();
Expand Down Expand Up @@ -1152,6 +1272,7 @@ void selectMatchAndRewrite(fir::LLVMTypeConverter &lowering, OP select,
/*branchWeights=*/ArrayRef<int32_t>());
}

namespace {
/// conversion of fir::SelectOp to an if-then-else ladder
struct SelectOpConversion : public FIROpConversion<fir::SelectOp> {
using FIROpConversion::FIROpConversion;
Expand Down Expand Up @@ -1299,6 +1420,7 @@ struct ZeroOpConversion : public FIROpConversion<fir::ZeroOp> {
return success();
}
};
} // namespace

/// Common base class for embox to descriptor conversion.
template <typename OP>
Expand Down Expand Up @@ -1642,18 +1764,6 @@ computeTripletExtent(mlir::ConversionPatternRewriter &rewriter,
return rewriter.create<mlir::LLVM::SelectOp>(loc, cmp, extent, zero);
}

/// Helper function for generating the LLVM IR that computes the size
/// in bytes for a derived type.
static mlir::Value
computeDerivedTypeSize(mlir::Location loc, mlir::Type ptrTy, mlir::Type idxTy,
mlir::ConversionPatternRewriter &rewriter) {
auto nullPtr = rewriter.create<mlir::LLVM::NullOp>(loc, ptrTy);
mlir::Value one = genConstantIndex(loc, idxTy, rewriter, 1);
llvm::SmallVector<mlir::Value> args{nullPtr, one};
auto gep = rewriter.create<mlir::LLVM::GEPOp>(loc, ptrTy, args);
return rewriter.create<mlir::LLVM::PtrToIntOp>(loc, idxTy, gep);
}

/// Create a generic box on a memory reference. This conversions lowers the
/// abstract box to the appropriate, initialized descriptor.
struct EmboxOpConversion : public EmboxCommonConversion<fir::EmboxOp> {
Expand Down Expand Up @@ -2135,6 +2245,7 @@ struct ValueOpCommon {
}
};

namespace {
/// Extract a subobject value from an ssa-value of aggregate type
struct ExtractValueOpConversion
: public FIROpAndTypeConversion<fir::ExtractValueOp>,
Expand Down Expand Up @@ -2245,6 +2356,7 @@ struct InsertOnRangeOpConversion
return success();
}
};
} // namespace

/// XArrayCoor is the address arithmetic on a dynamically shaped, sliced,
/// shifted etc. array.
Expand Down Expand Up @@ -2414,9 +2526,10 @@ struct XArrayCoorOpConversion

/// Generate inline code for complex addition/subtraction
template <typename LLVMOP, typename OPTY>
mlir::LLVM::InsertValueOp complexSum(OPTY sumop, mlir::ValueRange opnds,
mlir::ConversionPatternRewriter &rewriter,
fir::LLVMTypeConverter &lowering) {
static mlir::LLVM::InsertValueOp
complexSum(OPTY sumop, mlir::ValueRange opnds,
mlir::ConversionPatternRewriter &rewriter,
fir::LLVMTypeConverter &lowering) {
mlir::Value a = opnds[0];
mlir::Value b = opnds[1];
auto loc = sumop.getLoc();
Expand All @@ -2436,6 +2549,7 @@ mlir::LLVM::InsertValueOp complexSum(OPTY sumop, mlir::ValueRange opnds,
return rewriter.create<mlir::LLVM::InsertValueOp>(loc, ty, r1, ry, c1);
}

namespace {
struct AddcOpConversion : public FIROpConversion<fir::AddcOp> {
using FIROpConversion::FIROpConversion;

Expand Down Expand Up @@ -2675,10 +2789,11 @@ struct EmboxCharOpConversion : public FIROpConversion<fir::EmboxCharOp> {
return success();
}
};
} // namespace

/// Construct an `llvm.extractvalue` instruction. It will return value at
/// element \p x from \p tuple.
mlir::LLVM::ExtractValueOp
static mlir::LLVM::ExtractValueOp
genExtractValueWithIndex(mlir::Location loc, mlir::Value tuple, mlir::Type ty,
mlir::ConversionPatternRewriter &rewriter,
mlir::MLIRContext *ctx, int x) {
Expand All @@ -2687,6 +2802,7 @@ genExtractValueWithIndex(mlir::Location loc, mlir::Value tuple, mlir::Type ty,
return rewriter.create<mlir::LLVM::ExtractValueOp>(loc, xty, tuple, cx);
}

namespace {
/// Convert `!fir.boxchar_len` to `!llvm.extractvalue` for the 2nd part of the
/// boxchar.
struct BoxCharLenOpConversion : public FIROpConversion<fir::BoxCharLenOp> {
Expand Down Expand Up @@ -2818,26 +2934,26 @@ class FIRToLLVMLowering : public fir::FIRToLLVMLoweringBase<FIRToLLVMLowering> {
mlir::OwningRewritePatternList pattern(context);
pattern.insert<
AbsentOpConversion, AddcOpConversion, AddrOfOpConversion,
AllocaOpConversion, BoxAddrOpConversion, BoxCharLenOpConversion,
BoxDimsOpConversion, BoxEleSizeOpConversion, BoxIsAllocOpConversion,
BoxIsArrayOpConversion, BoxIsPtrOpConversion, BoxProcHostOpConversion,
BoxRankOpConversion, BoxTypeDescOpConversion, CallOpConversion,
CmpcOpConversion, ConstcOpConversion, ConvertOpConversion,
DispatchOpConversion, DispatchTableOpConversion, DTEntryOpConversion,
DivcOpConversion, EmboxOpConversion, EmboxCharOpConversion,
EmboxProcOpConversion, ExtractValueOpConversion, FieldIndexOpConversion,
FirEndOpConversion, HasValueOpConversion, GenTypeDescOpConversion,
GlobalLenOpConversion, GlobalOpConversion, InsertOnRangeOpConversion,
InsertValueOpConversion, IsPresentOpConversion,
LenParamIndexOpConversion, LoadOpConversion, NegcOpConversion,
NoReassocOpConversion, MulcOpConversion, SelectCaseOpConversion,
SelectOpConversion, SelectRankOpConversion, SelectTypeOpConversion,
ShapeOpConversion, ShapeShiftOpConversion, ShiftOpConversion,
SliceOpConversion, StoreOpConversion, StringLitOpConversion,
SubcOpConversion, UnboxCharOpConversion, UnboxProcOpConversion,
UndefOpConversion, UnreachableOpConversion, XArrayCoorOpConversion,
XEmboxOpConversion, XReboxOpConversion, ZeroOpConversion>(
typeConverter);
AllocaOpConversion, AllocMemOpConversion, BoxAddrOpConversion,
BoxCharLenOpConversion, BoxDimsOpConversion, BoxEleSizeOpConversion,
BoxIsAllocOpConversion, BoxIsArrayOpConversion, BoxIsPtrOpConversion,
BoxProcHostOpConversion, BoxRankOpConversion, BoxTypeDescOpConversion,
CallOpConversion, CmpcOpConversion, ConstcOpConversion,
ConvertOpConversion, DispatchOpConversion, DispatchTableOpConversion,
DTEntryOpConversion, DivcOpConversion, EmboxOpConversion,
EmboxCharOpConversion, EmboxProcOpConversion, ExtractValueOpConversion,
FieldIndexOpConversion, FirEndOpConversion, FreeMemOpConversion,
HasValueOpConversion, GenTypeDescOpConversion, GlobalLenOpConversion,
GlobalOpConversion, InsertOnRangeOpConversion, InsertValueOpConversion,
IsPresentOpConversion, LenParamIndexOpConversion, LoadOpConversion,
NegcOpConversion, NoReassocOpConversion, MulcOpConversion,
SelectCaseOpConversion, SelectOpConversion, SelectRankOpConversion,
SelectTypeOpConversion, ShapeOpConversion, ShapeShiftOpConversion,
ShiftOpConversion, SliceOpConversion, StoreOpConversion,
StringLitOpConversion, SubcOpConversion, UnboxCharOpConversion,
UnboxProcOpConversion, UndefOpConversion, UnreachableOpConversion,
XArrayCoorOpConversion, XEmboxOpConversion, XReboxOpConversion,
ZeroOpConversion>(typeConverter);
mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
pattern);
Expand Down
10 changes: 10 additions & 0 deletions flang/test/Fir/Todo/allocmem.fir
@@ -0,0 +1,10 @@
// RUN: %not_todo_cmd fir-opt --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" %s 2>&1 | FileCheck %s

// Test `fir.allocmem` of derived type with LEN parameters conversion to llvm.
// Not implemented yet.

func @allocmem_test(%arg0 : i32, %arg1 : i16) {
// CHECK: not yet implemented fir.allocmem codegen of derived type with length parameters
%0 = fir.allocmem !fir.type<_QTt(p1:i32,p2:i16){f1:i32,f2:f32}>(%arg0, %arg1 : i32, i16) {name = "_QEvar"}
return
}

0 comments on commit c2acd45

Please sign in to comment.