Skip to content

Commit

Permalink
[mlir][memref][transform] Add vector_to_llvm conversion patterns
Browse files Browse the repository at this point in the history
These patterns are exposed via a new "apply_conversion_patterns" op.

Also provide a new type converter that converts from memref to LLVM types. Conversion patterns that lower to LLVM are special: they require an `LLVMTypeConverter`; a normal `TypeConverter` is not enough. This revision also adds a new interface method to pattern descriptor ops to verify that the default type converter of the enclosing "apply_conversion_patterns" op is compatible with the set of patterns. At the moment, a simple `StringRef` is used. This can evolve to a richer type in the future if needed.

Differential Revision: https://reviews.llvm.org/D157369
  • Loading branch information
matthias-springer committed Aug 9, 2023
1 parent 76e6248 commit 7ec88f0
Show file tree
Hide file tree
Showing 11 changed files with 177 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,33 @@ include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"

def MemrefToLLVMTypeConverterOp : Op<Transform_Dialect,
"apply_conversion_patterns.memref.memref_to_llvm_type_converter",
[DeclareOpInterfaceMethods<TypeConverterBuilderOpInterface,
["getTypeConverterType"]>]> {
let description = [{
This operation provides an "LLVMTypeConverter" that lowers memref types to
LLVM types.

The type converter can be customized as follows:
- `use_aligned_alloc`: Use aligned_alloc in place of malloc for heap
allocations.
- `index_bitwidth`: Bitwidth of the index type, "0" indicates the size of a
machine word.
- `use_generic_functions`: Use generic allocation and deallocation functions
instead of the classic "malloc", "aligned_alloc" and "free" functions.
- `use_opaque_pointers`: Generate LLVM IR using opaque pointers instead of
typed pointers.
}];

let arguments = (ins
DefaultValuedAttr<BoolAttr, "false">:$use_aligned_alloc,
DefaultValuedAttr<I64Attr, "0">:$index_bitwidth,
DefaultValuedAttr<BoolAttr, "false">:$use_generic_functions,
DefaultValuedAttr<BoolAttr, "false">:$use_opaque_pointers);
let assemblyFormat = "attr-dict";
}

def ApplyExpandOpsPatternsOp : Op<Transform_Dialect,
"apply_patterns.memref.expand_ops",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
Expand Down
58 changes: 40 additions & 18 deletions mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,39 @@ def PatternDescriptorOpInterface : OpInterface<"PatternDescriptorOpInterface"> {
];
}

def TypeConverterBuilderOpInterface
: OpInterface<"TypeConverterBuilderOpInterface"> {
let description = [{
This interface should be implemented by ops that specify a type converter
for a dialect conversion. Such ops can be used with
"apply_conversion_patterns".
}];

let cppNamespace = "::mlir::transform";

let methods = [
InterfaceMethod<
/*desc=*/[{
Return the type converter to be used with a dialect conversion.
}],
/*returnType=*/"std::unique_ptr<::mlir::TypeConverter>",
/*name=*/"getTypeConverter",
/*arguments=*/(ins)
>,
StaticInterfaceMethod<
/*desc=*/[{
Return the type of type converter that this `getTypeConverter` returns.
This function is used for op verification.
}],
/*returnType=*/"StringRef",
/*name=*/"getTypeConverterType",
/*arguments=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{ return "TypeConverter"; }]
>,
];
}

def ConversionPatternDescriptorOpInterface
: OpInterface<"ConversionPatternDescriptorOpInterface"> {
let description = [{
Expand Down Expand Up @@ -300,27 +333,16 @@ def ConversionPatternDescriptorOpInterface
/*methodBody=*/"",
/*defaultImplementation=*/"return nullptr;"
>,
];
}

def TypeConverterBuilderOpInterface
: OpInterface<"TypeConverterBuilderOpInterface"> {
let description = [{
This interface should be implemented by ops that specify a type converter
for a dialect conversion. Such ops can be used with
"apply_conversion_patterns".
}];

let cppNamespace = "::mlir::transform";

let methods = [
InterfaceMethod<
/*desc=*/[{
Return the type converter to be used with a dialect conversion.
Verify the default type converter that is provided by the enclosing
"apply_conversion_patterns" op.
}],
/*returnType=*/"std::unique_ptr<::mlir::TypeConverter>",
/*name=*/"getTypeConverter",
/*arguments=*/(ins)
/*returnType=*/"::mlir::LogicalResult",
/*name=*/"verifyTypeConverter",
/*arguments=*/(ins "TypeConverterBuilderOpInterface":$builder),
/*methodBody=*/"",
/*defaultImplementation=*/"return success();"
>,
];
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,28 @@ include "mlir/Dialect/Vector/Transforms/VectorTransformsBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"

def ApplyVectorToLLVMConversionPatternsOp : Op<Transform_Dialect,
"apply_conversion_patterns.vector.vector_to_llvm",
[DeclareOpInterfaceMethods<ConversionPatternDescriptorOpInterface,
["verifyTypeConverter"]>]> {
let description = [{
Collects patterns that convert vector dialect ops to LLVM dialect ops. These
patterns require an "LLVMTypeConverter".

The patterns can be customized as follows:
- `reassociate_fp_reductions`: Allows LLVM to reassociate floating-point
reductions for speed.
- `force_32bit_vector_indices`: Allows the compiler to assume that vector
indices fit in 32-bit if that yields faster code.
}];

let arguments = (ins
DefaultValuedAttr<BoolAttr, "false">:$reassociate_fp_reductions,
DefaultValuedAttr<BoolAttr, "true">:$force_32bit_vector_indices);
let assemblyFormat = "attr-dict";
}


def ApplyCastAwayVectorLeadingOneDimPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.cast_away_vector_leading_one_dim",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class Value;
/// registered using addConversion and addMaterialization, respectively.
class TypeConverter {
public:
virtual ~TypeConverter() = default;

/// This class provides all of the information necessary to convert a type
/// signature.
class SignatureConversion {
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ add_mlir_dialect_library(MLIRMemRefTransformOps
MLIRAffineDialect
MLIRArithDialect
MLIRIR
MLIRLLVMCommonConversion
MLIRLLVMDialect
MLIRLoopLikeInterface
MLIRMemRefDialect
MLIRMemRefTransforms
Expand Down
25 changes: 25 additions & 0 deletions mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
Expand All @@ -26,6 +28,29 @@ using namespace mlir;
#define DEBUG_TYPE "memref-transforms"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")

//===----------------------------------------------------------------------===//
// Apply...ConversionPatternsOp
//===----------------------------------------------------------------------===//

std::unique_ptr<TypeConverter>
transform::MemrefToLLVMTypeConverterOp::getTypeConverter() {
LowerToLLVMOptions options(getContext());
options.allocLowering =
(getUseAlignedAlloc() ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
: LowerToLLVMOptions::AllocLowering::Malloc);
options.useGenericFunctions = getUseGenericFunctions();
options.useOpaquePointers = getUseOpaquePointers();

if (getIndexBitwidth() != kDeriveIndexBitwidthFromDataLayout)
options.overrideIndexBitwidth(getIndexBitwidth());

return std::make_unique<LLVMTypeConverter>(getContext(), options);
}

StringRef transform::MemrefToLLVMTypeConverterOp::getTypeConverterType() {
return "LLVMTypeConverter";
}

//===----------------------------------------------------------------------===//
// Apply...PatternsOp
//===----------------------------------------------------------------------===//
Expand Down
14 changes: 12 additions & 2 deletions mlir/lib/Dialect/Transform/IR/TransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -589,14 +589,24 @@ LogicalResult transform::ApplyConversionPatternsOp::verify() {
if (!llvm::hasSingleElement(typeConverterRegion.front()))
return emitOpError()
<< "expected exactly one op in default type converter region";
Operation *typeConverterOp = &typeConverterRegion.front().front();
if (!isa<transform::TypeConverterBuilderOpInterface>(typeConverterOp)) {
auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
&typeConverterRegion.front().front());
if (!typeConverterOp) {
InFlightDiagnostic diag = emitOpError()
<< "expected default converter child op to "
"implement TypeConverterBuilderOpInterface";
diag.attachNote(typeConverterOp->getLoc()) << "op without interface";
return diag;
}
// Check default type converter type.
if (!getPatterns().empty()) {
for (Operation &op : getPatterns().front()) {
auto descriptor =
cast<transform::ConversionPatternDescriptorOpInterface>(&op);
if (failed(descriptor.verifyTypeConverter(typeConverterOp)))
return failure();
}
}
}
if (!getLegalOps() && !getIllegalOps() && !getLegalDialects() &&
!getIllegalDialects())
Expand Down
3 changes: 3 additions & 0 deletions mlir/lib/Dialect/Vector/TransformOps/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ add_mlir_dialect_library(MLIRVectorTransformOps

LINK_LIBS PUBLIC
MLIRIR
MLIRLLVMCommonConversion
MLIRLLVMDialect
MLIRVectorDialect
MLIRVectorToLLVM
MLIRVectorTransforms
MLIRSideEffectInterfaces
MLIRTransformDialect
Expand Down
22 changes: 22 additions & 0 deletions mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
Expand All @@ -23,6 +26,25 @@ using namespace mlir;
using namespace mlir::vector;
using namespace mlir::transform;

//===----------------------------------------------------------------------===//
// Apply...ConversionPatternsOp
//===----------------------------------------------------------------------===//

void transform::ApplyVectorToLLVMConversionPatternsOp::populatePatterns(
TypeConverter &typeConverter, RewritePatternSet &patterns) {
populateVectorToLLVMConversionPatterns(
static_cast<LLVMTypeConverter &>(typeConverter), patterns,
getReassociateFpReductions(), getForce_32bitVectorIndices());
}

LogicalResult
transform::ApplyVectorToLLVMConversionPatternsOp::verifyTypeConverter(
transform::TypeConverterBuilderOpInterface builder) {
if (builder.getTypeConverterType() != "LLVMTypeConverter")
return emitOpError("expected LLVMTypeConverter");
return success();
}

//===----------------------------------------------------------------------===//
// Apply...PatternsOp
//===----------------------------------------------------------------------===//
Expand Down
19 changes: 19 additions & 0 deletions mlir/test/Dialect/Vector/transform-op-vector-to-llvm.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// RUN: mlir-opt %s -test-transform-dialect-interpreter -verify-diagnostics -allow-unregistered-dialect -split-input-file | FileCheck %s

// CHECK-LABEL: func @lower_to_llvm
// CHECK-NOT: vector.bitcast
// CHECK: llvm.bitcast
func.func @lower_to_llvm(%input: vector<f32>) -> vector<i32> {
%0 = vector.bitcast %input : vector<f32> to vector<i32>
return %0 : vector<i32>
}

transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_conversion_patterns to %0 {
transform.apply_conversion_patterns.vector.vector_to_llvm
} with type_converter {
transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
} {legal_dialects = ["func", "llvm"]} : !transform.any_op
}
3 changes: 3 additions & 0 deletions utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -4146,12 +4146,14 @@ cc_library(
":ArithDialect",
":AsmParser",
":IR",
":LLVMCommonConversion",
":LLVMDialect",
":SideEffectInterfaces",
":TransformDialect",
":TransformUtils",
":VectorDialect",
":VectorEnumsIncGen",
":VectorToLLVM",
":VectorToSCF",
":VectorTransformOpsIncGen",
":VectorTransforms",
Expand Down Expand Up @@ -11510,6 +11512,7 @@ cc_library(
":AffineDialect",
":ArithDialect",
":IR",
":LLVMCommonConversion",
":LoopLikeInterface",
":MemRefDialect",
":MemRefTransformOpsIncGen",
Expand Down

0 comments on commit 7ec88f0

Please sign in to comment.