Skip to content

Commit

Permalink
[mlir][sparse] Move a few routines to CodegenUtils.
Browse files Browse the repository at this point in the history
Move a few supporting routines for generating function calls to CodegenUtils so
that they can be used by the codegen path for sparse tensor file input and
output.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D135691
  • Loading branch information
bixia1 committed Oct 11, 2022
1 parent 51db96a commit 2d252a0
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 46 deletions.
33 changes: 33 additions & 0 deletions mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -550,3 +550,36 @@ void mlir::sparse_tensor::translateIndicesArray(
}
assert(dstIndices.size() == dstRank);
}

FlatSymbolRefAttr mlir::sparse_tensor::getFunc(ModuleOp module, StringRef name,
TypeRange resultType,
ValueRange operands,
EmitCInterface emitCInterface) {
MLIRContext *context = module.getContext();
auto result = SymbolRefAttr::get(context, name);
auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
if (!func) {
OpBuilder moduleBuilder(module.getBodyRegion());
func = moduleBuilder.create<func::FuncOp>(
module.getLoc(), name,
FunctionType::get(context, operands.getTypes(), resultType));
func.setPrivate();
if (static_cast<bool>(emitCInterface))
func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
UnitAttr::get(context));
}
return result;
}

func::CallOp mlir::sparse_tensor::createFuncCall(
OpBuilder &builder, Location loc, StringRef name, TypeRange resultType,
ValueRange operands, EmitCInterface emitCInterface) {
auto module = builder.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
FlatSymbolRefAttr fn =
getFunc(module, name, resultType, operands, emitCInterface);
return builder.create<func::CallOp>(loc, resultType, fn, operands);
}

Type mlir::sparse_tensor::getOpaquePointerType(OpBuilder &builder) {
return LLVM::LLVMPointerType::get(builder.getI8Type());
}
23 changes: 23 additions & 0 deletions mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/ExecutionEngine/SparseTensor/Enums.h"
Expand All @@ -28,6 +30,10 @@ class Value;

namespace sparse_tensor {

/// Shorthand aliases for the `emitCInterface` argument to `getFunc()`,
/// `createFuncCall()`, and `replaceOpWithFuncCall()`.
enum class EmitCInterface : bool { Off = false, On = true };

//===----------------------------------------------------------------------===//
// SparseTensorLoopEmiter class, manages sparse tensors and helps to generate
// loop structure to (co-iterate) sparse tensors.
Expand Down Expand Up @@ -225,6 +231,23 @@ void translateIndicesArray(OpBuilder &builder, Location loc,
ArrayRef<Value> dstShape,
SmallVectorImpl<Value> &dstIndices);

/// Returns a function reference (first hit also inserts into module). Sets
/// the "_emit_c_interface" on the function declaration when requested,
/// so that LLVM lowering generates a wrapper function that takes care
/// of ABI complications with passing in and returning MemRefs to C functions.
FlatSymbolRefAttr getFunc(ModuleOp module, StringRef name, TypeRange resultType,
ValueRange operands, EmitCInterface emitCInterface);

/// Creates a `CallOp` to the function reference returned by `getFunc()` in
/// the builder's module.
func::CallOp createFuncCall(OpBuilder &builder, Location loc, StringRef name,
TypeRange resultType, ValueRange operands,
EmitCInterface emitCInterface);

/// Returns the equivalent of `void*` for opaque arguments to the
/// execution engine.
Type getOpaquePointerType(OpBuilder &builder);

//===----------------------------------------------------------------------===//
// Inlined constant generators.
//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
Expand All @@ -36,61 +34,17 @@ using namespace mlir::sparse_tensor;

namespace {

/// Shorthand aliases for the `emitCInterface` argument to `getFunc()`,
/// `createFuncCall()`, and `replaceOpWithFuncCall()`.
enum class EmitCInterface : bool { Off = false, On = true };

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Returns the equivalent of `void*` for opaque arguments to the
/// execution engine.
static Type getOpaquePointerType(OpBuilder &builder) {
return LLVM::LLVMPointerType::get(builder.getI8Type());
}

/// Maps each sparse tensor type to an opaque pointer.
static Optional<Type> convertSparseTensorTypes(Type type) {
if (getSparseTensorEncoding(type) != nullptr)
return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8));
return llvm::None;
}

/// Returns a function reference (first hit also inserts into module). Sets
/// the "_emit_c_interface" on the function declaration when requested,
/// so that LLVM lowering generates a wrapper function that takes care
/// of ABI complications with passing in and returning MemRefs to C functions.
static FlatSymbolRefAttr getFunc(ModuleOp module, StringRef name,
TypeRange resultType, ValueRange operands,
EmitCInterface emitCInterface) {
MLIRContext *context = module.getContext();
auto result = SymbolRefAttr::get(context, name);
auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
if (!func) {
OpBuilder moduleBuilder(module.getBodyRegion());
func = moduleBuilder.create<func::FuncOp>(
module.getLoc(), name,
FunctionType::get(context, operands.getTypes(), resultType));
func.setPrivate();
if (static_cast<bool>(emitCInterface))
func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
UnitAttr::get(context));
}
return result;
}

/// Creates a `CallOp` to the function reference returned by `getFunc()` in
/// the builder's module.
static func::CallOp createFuncCall(OpBuilder &builder, Location loc,
StringRef name, TypeRange resultType,
ValueRange operands,
EmitCInterface emitCInterface) {
auto module = builder.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
auto fn = getFunc(module, name, resultType, operands, emitCInterface);
return builder.create<func::CallOp>(loc, resultType, fn, operands);
}

/// Replaces the `op` with a `CallOp` to the function reference returned
/// by `getFunc()`.
static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op,
Expand Down

0 comments on commit 2d252a0

Please sign in to comment.