diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir index e2ab876f8b46a..b52612d0d1f10 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir @@ -24,10 +24,46 @@ // CHECK-NOT: copy // CHECK: %[[call:.*]]:2 = call @inner_func(%[[arg0]]) %0, %1 = call @inner_func(%t0) : (tensor) -> (tensor, f32) - // CHECK: return %[[call]]#1, %[[call]]#0 : f32, memref + // CHECK: return %[[call]]#1, %[[call]]#0 : f32, memref return %1, %0 : f32, tensor } "test.finish" () : () -> () }) : () -> () +// ----- +#enc1 = #test.tensor_encoding<"hello"> +#enc2 = #test.tensor_encoding<"not hello"> + +"test.symbol_scope_isolated"() ({ + // CHECK: func @inner_func( + // CHECK-SAME: %[[arg0:.*]]: memref>) + // CHECK-SAME: -> memref> + func.func @inner_func(%t: tensor) + -> tensor { + // CHECK: return %[[arg0]] + return %t : tensor + } + + // CHECK: func @outer_func( + // CHECK-SAME: %[[arg0:.*]]: memref>) + // CHECK-SAME: -> (memref>, + // CHECK-SAME: memref>) + func.func @outer_func(%t0: tensor) + -> (tensor, tensor) { + // CHECK: %[[call:.*]] = call @inner_func(%[[arg0]]) + %0 = call @inner_func(%t0) + : (tensor) -> (tensor) + + // CHECK: %[[local:.*]] = "test.create_memref_op"() : () + // CHECK-SAME: -> memref> + %local = "test.create_tensor_op"() : () -> tensor + // CHECK: %[[dummy:.*]] = "test.dummy_memref_op"(%[[local]]) + %1 = "test.dummy_tensor_op"(%local) : (tensor) + -> tensor + + // CHECK: return %[[call]], %[[dummy]] + return %0, %1 : tensor, tensor + } + "test.finish" () : () -> () +}) : () -> () diff --git a/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp b/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp index 1e2d4a7c8f08d..4069a74dbf63c 100644 --- a/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp +++ b/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp @@ -11,11 +11,25 @@ #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" #include "mlir/Dialect/Bufferization/Transforms/Transforms.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/Pass.h" +#include "TestAttributes.h" // TestTensorEncodingAttr, TestMemRefLayoutAttr +#include "TestDialect.h" + using namespace mlir; namespace { +MemRefLayoutAttrInterface +getMemRefLayoutForTensorEncoding(RankedTensorType tensorType) { + if (auto encoding = dyn_cast_if_present( + tensorType.getEncoding())) { + return cast(test::TestMemRefLayoutAttr::get( + tensorType.getContext(), encoding.getDummy())); + } + return {}; +} + struct TestOneShotModuleBufferizePass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOneShotModuleBufferizePass) @@ -25,6 +39,7 @@ struct TestOneShotModuleBufferizePass : PassWrapper(pass) {} void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); registry.insert(); } StringRef getArgument() const final { @@ -41,6 +56,17 @@ struct TestOneShotModuleBufferizePass bufferization::OneShotBufferizationOptions opt; opt.bufferizeFunctionBoundaries = true; + opt.functionArgTypeConverterFn = + [&](bufferization::TensorLikeType tensor, Attribute memSpace, + func::FuncOp, const bufferization::BufferizationOptions &) { + assert(isa(tensor) && "tests only builtin tensors"); + auto tensorType = cast(tensor); + auto layout = getMemRefLayoutForTensorEncoding(tensorType); + return cast( + MemRefType::get(tensorType.getShape(), + tensorType.getElementType(), layout, memSpace)); + }; + bufferization::BufferizationState bufferizationState; if (failed(bufferization::runOneShotModuleBufferize(getOperation(), opt, diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td index 5685004bbbd25..9e7e4f883b576 100644 --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -22,6 +22,7 @@ include "mlir/IR/AttrTypeBase.td" include "mlir/IR/BuiltinAttributeInterfaces.td" include "mlir/IR/EnumAttr.td" include "mlir/IR/OpAsmInterface.td" +include "mlir/IR/TensorEncoding.td" // All of the attributes will extend this class. class Test_Attr traits = []> @@ -439,4 +440,20 @@ def TestCustomStorageCtorAttr : Test_Attr<"TestCustomStorageCtorAttr"> { let hasStorageCustomConstructor = 1; } +def TestTensorEncodingAttr : Test_Attr<"TestTensorEncoding", + [DeclareAttrInterfaceMethods]> { + let mnemonic = "tensor_encoding"; + + let parameters = (ins "mlir::StringAttr":$dummy); + let assemblyFormat = "`<` $dummy `>`"; +} + +def TestMemRefLayoutAttr : Test_Attr<"TestMemRefLayout", + [DeclareAttrInterfaceMethods]> { + let mnemonic = "memref_layout"; + + let parameters = (ins "mlir::StringAttr":$dummy); + let assemblyFormat = "`<` $dummy `>`"; +} + #endif // TEST_ATTRDEFS diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp index fe1e9166a3099..9db7b01dd193b 100644 --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -541,6 +541,24 @@ test::detail::TestCustomStorageCtorAttrAttrStorage::construct( return nullptr; } +//===----------------------------------------------------------------------===// +// TestTensorEncodingAttr +//===----------------------------------------------------------------------===// + +::llvm::LogicalResult TestTensorEncodingAttr::verifyEncoding( + mlir::ArrayRef shape, mlir::Type elementType, + llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) const { + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// TestMemRefLayoutAttr +//===----------------------------------------------------------------------===// + +mlir::AffineMap TestMemRefLayoutAttr::getAffineMap() const { + return mlir::AffineMap::getMultiDimIdentityMap(1, getContext()); +} + //===----------------------------------------------------------------------===// // TestDialect //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.h b/mlir/test/lib/Dialect/Test/TestAttributes.h index 778d84fae7365..0ad5ab641c6d0 100644 --- a/mlir/test/lib/Dialect/Test/TestAttributes.h +++ b/mlir/test/lib/Dialect/Test/TestAttributes.h @@ -24,6 +24,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/DialectResourceBlobManager.h" +#include "mlir/IR/TensorEncoding.h" // generated files require above includes to come first #include "TestAttrInterfaces.h.inc" diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h index f2adca6310d78..bcf3b55d33cb9 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.h +++ b/mlir/test/lib/Dialect/Test/TestDialect.h @@ -18,6 +18,7 @@ #include "TestInterfaces.h" #include "TestTypes.h" #include "mlir/Bytecode/BytecodeImplementation.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/DLTI/Traits.h" #include "mlir/Dialect/Func/IR/FuncOps.h" diff --git a/mlir/test/lib/Dialect/Test/TestDialect.td b/mlir/test/lib/Dialect/Test/TestDialect.td index 2b5491fc0c6a0..37a263f1d10b8 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.td +++ b/mlir/test/lib/Dialect/Test/TestDialect.td @@ -24,7 +24,10 @@ def Test_Dialect : Dialect { let useDefaultTypePrinterParser = 0; let useDefaultAttributePrinterParser = 1; let isExtensible = 1; - let dependentDialects = ["::mlir::DLTIDialect"]; + let dependentDialects = [ + "::mlir::DLTIDialect", + "::mlir::bufferization::BufferizationDialect" + ]; let discardableAttrs = (ins "mlir::IntegerAttr":$discardable_attr_key, "SimpleAAttr":$other_discardable_attr_key diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp index 53055fea215b7..b211e243f234c 100644 --- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp +++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp @@ -1425,6 +1425,39 @@ TestMultiSlotAlloca::handleDestructuringComplete( return createNewMultiAllocaWithoutSlot(slot, builder, *this); } +namespace { +/// Returns test dialect's memref layout for test dialect's tensor encoding when +/// applicable. +MemRefLayoutAttrInterface +getMemRefLayoutForTensorEncoding(RankedTensorType tensorType) { + if (auto encoding = + dyn_cast(tensorType.getEncoding())) { + return cast(test::TestMemRefLayoutAttr::get( + tensorType.getContext(), encoding.getDummy())); + } + return {}; +} + +/// Auxiliary bufferization function for test and builtin tensors. +bufferization::BufferLikeType +convertTensorToBuffer(mlir::Operation *op, + const bufferization::BufferizationOptions &options, + bufferization::TensorLikeType tensorLike) { + auto buffer = + *tensorLike.getBufferType(options, [&]() { return op->emitError(); }); + if (auto memref = dyn_cast(buffer)) { + // Note: For the sake of testing, we want to ensure that encoding -> layout + // bufferization happens. This is currently achieved manually. + auto layout = + getMemRefLayoutForTensorEncoding(cast(tensorLike)); + return cast( + MemRefType::get(memref.getShape(), memref.getElementType(), layout, + memref.getMemorySpace())); + } + return buffer; +} +} // namespace + ::mlir::LogicalResult test::TestDummyTensorOp::bufferize( ::mlir::RewriterBase &rewriter, const ::mlir::bufferization::BufferizationOptions &options, @@ -1435,8 +1468,8 @@ ::mlir::LogicalResult test::TestDummyTensorOp::bufferize( return failure(); const auto outType = getOutput().getType(); - const auto bufferizedOutType = test::TestMemrefType::get( - getContext(), outType.getShape(), outType.getElementType(), nullptr); + const auto bufferizedOutType = + convertTensorToBuffer(getOperation(), options, outType); // replace op with memref analogy auto dummyMemrefOp = test::TestDummyMemrefOp::create( rewriter, getLoc(), bufferizedOutType, *buffer); @@ -1470,13 +1503,12 @@ ::mlir::LogicalResult test::TestCreateTensorOp::bufferize( mlir::FailureOr test::TestCreateTensorOp::getBufferType( - mlir::Value value, const mlir::bufferization::BufferizationOptions &, + mlir::Value value, const mlir::bufferization::BufferizationOptions &options, const mlir::bufferization::BufferizationState &, llvm::SmallVector<::mlir::Value> &) { - const auto type = dyn_cast(value.getType()); + const auto type = dyn_cast(value.getType()); if (type == nullptr) return failure(); - return cast(test::TestMemrefType::get( - getContext(), type.getShape(), type.getElementType(), nullptr)); + return convertTensorToBuffer(getOperation(), options, type); } diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 6ea27187655ee..b08933c08d132 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -32,6 +32,7 @@ include "mlir/Interfaces/MemorySlotInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ValueBoundsOpInterface.td" include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td" +include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td" // Include the attribute definitions. include "TestAttrDefs.td" @@ -2322,7 +2323,7 @@ def SideEffectWithRegionOp : TEST_Op<"side_effect_with_region_op", } //===----------------------------------------------------------------------===// -// Copy Operation Test +// Copy Operation Test //===----------------------------------------------------------------------===// def CopyOp : TEST_Op<"copy", []> { @@ -3663,10 +3664,10 @@ def TestDummyTensorOp : TEST_Op<"dummy_tensor_op", ["bufferize", "bufferizesToMemoryRead", "bufferizesToMemoryWrite", "getAliasingValues"]>]> { let arguments = (ins - Arg:$input + Arg:$input ); let results = (outs - Arg:$output + Arg:$output ); let extraClassDefinition = [{ @@ -3688,10 +3689,10 @@ def TestDummyTensorOp : TEST_Op<"dummy_tensor_op", def TestDummyMemrefOp : TEST_Op<"dummy_memref_op", []> { let arguments = (ins - Arg:$input + Arg:$input ); let results = (outs - Arg:$output + Arg:$output ); } @@ -3701,7 +3702,7 @@ def TestCreateTensorOp : TEST_Op<"create_tensor_op", "bufferizesToMemoryWrite", "getAliasingValues", "bufferizesToAllocation"]>]> { let arguments = (ins); - let results = (outs Arg:$output); + let results = (outs Arg:$output); let extraClassDefinition = [{ bool test::TestCreateTensorOp::bufferizesToMemoryRead(::mlir::OpOperand&, const ::mlir::bufferization::AnalysisState&) { @@ -3725,7 +3726,7 @@ def TestCreateTensorOp : TEST_Op<"create_tensor_op", def TestCreateMemrefOp : TEST_Op<"create_memref_op"> { let arguments = (ins); - let results = (outs Arg:$output); + let results = (outs Arg:$output); } //===----------------------------------------------------------------------===//