From db5855234d2dc3eeff56bcaf998a1dfa2a4baa19 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Wed, 17 Jan 2024 19:41:21 -0800 Subject: [PATCH] [MLIR] Add a new interface for "IR parameterization" This implements the ability to define "meta program": that is a mechanism similar to C++ template. So as an example, this input IR: ``` testparametric.func @callee(%arg0: !testparametric.param<"A"> ) attributes { metaParams = ["A", "B"]} { %value = testparametric.add %arg0, %arg0 : (!testparametric.param<"A">, !testparametric.param<"A">) -> !testparametric.param<"A"> testparametric.print_attr #testparametric.param<"B"> return } func.func @caller() { %cst0 = arith.constant 0 : i32 %cst1 = arith.constant 1. : f32 %cst2 = arith.constant 2. : f64 testparametric.call @callee(%cst0) meta = {"A" = i32, "B" = 32 : i64 } : (i32) -> () testparametric.call @callee(%cst1) meta = {"A" = f32, "B" = 64 : i64 } : (f32) -> () testparametric.call @callee(%cst2) meta = {"A" = f64, "B" = 128 : i64 } : (f64) -> () return } ``` Will see the @callee parametric function be instantiated for each call-site: ``` func.func @caller() { %c0_i32 = arith.constant 0 : i32 %cst = arith.constant 1.000000e+00 : f32 %cst_0 = arith.constant 2.000000e+00 : f64 testparametric.call @callee$__mlir_instance__$A$i32$B$32(%c0_i32) meta = {} : (i32) -> () testparametric.call @callee$__mlir_instance__$A$f32$B$64(%cst) meta = {} : (f32) -> () testparametric.call @callee$__mlir_instance__$A$f64$B$128(%cst_0) meta = {} : (f64) -> () return } testparametric.func @callee$__mlir_instance__$A$f32$B$64(%arg0: f32) { %0 = add %arg0, %arg0 : (f32, f32) -> f32 print_attr 64 : i64 return } testparametric.func @callee$__mlir_instance__$A$f64$B$128(%arg0: f64) { %0 = add %arg0, %arg0 : (f64, f64) -> f64 print_attr 128 : i64 return } testparametric.func @callee$__mlir_instance__$A$i32$B$32(%arg0: i32) { %0 = add %arg0, %arg0 : (i32, i32) -> i32 print_attr 32 : i64 return } ``` --- mlir/include/mlir/IR/SymbolInterfaces.td | 2 +- mlir/include/mlir/Interfaces/CMakeLists.txt | 1 + .../ParametricSpecializationOpInterface.h | 25 ++ .../ParametricSpecializationOpInterface.td | 46 +++ .../Transforms/ParametricSpecialization.h | 11 + mlir/lib/Interfaces/CMakeLists.txt | 2 + .../ParametricSpecializationOpInterface.cpp | 13 + mlir/lib/Transforms/CMakeLists.txt | 2 + .../Transforms/ParametricSpecialization.cpp | 13 + mlir/test/Parametric/ops.mlir | 18 ++ mlir/test/lib/Dialect/CMakeLists.txt | 1 + .../lib/Dialect/TestParametric/CMakeLists.txt | 68 ++++ .../TestParametric/TestParametricAttrDefs.td | 38 +++ .../TestParametricAttributes.cpp | 42 +++ .../TestParametric/TestParametricAttributes.h | 33 ++ .../TestParametric/TestParametricDialect.cpp | 297 ++++++++++++++++++ .../TestParametric/TestParametricDialect.h | 45 +++ .../TestParametric/TestParametricDialect.td | 27 ++ .../TestParametricInterfaces.cpp | 11 + .../TestParametric/TestParametricInterfaces.h | 33 ++ .../TestParametricInterfaces.td | 34 ++ .../TestParametric/TestParametricOps.td | 202 ++++++++++++ .../TestParametric/TestParametricTypeDefs.td | 37 +++ .../TestParametric/TestParametricTypes.cpp | 42 +++ .../TestParametric/TestParametricTypes.h | 154 +++++++++ .../lib/Dialect/TestParametric/lit.local.cfg | 1 + mlir/test/lib/Transforms/CMakeLists.txt | 1 + .../TestParametricSpecialization.cpp | 191 +++++++++++ mlir/tools/mlir-lsp-server/CMakeLists.txt | 1 + .../tools/mlir-lsp-server/mlir-lsp-server.cpp | 4 + mlir/tools/mlir-opt/CMakeLists.txt | 1 + mlir/tools/mlir-opt/mlir-opt.cpp | 6 + 32 files changed, 1401 insertions(+), 1 deletion(-) create mode 100644 mlir/include/mlir/Interfaces/ParametricSpecializationOpInterface.h create mode 100644 mlir/include/mlir/Interfaces/ParametricSpecializationOpInterface.td create mode 100644 mlir/include/mlir/Transforms/ParametricSpecialization.h create mode 100644 mlir/lib/Interfaces/ParametricSpecializationOpInterface.cpp create mode 100644 mlir/lib/Transforms/ParametricSpecialization.cpp create mode 100644 mlir/test/Parametric/ops.mlir create mode 100644 mlir/test/lib/Dialect/TestParametric/CMakeLists.txt create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricAttrDefs.td create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricAttributes.cpp create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricAttributes.h create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricDialect.cpp create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricDialect.h create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricDialect.td create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricInterfaces.cpp create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricInterfaces.h create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricInterfaces.td create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricOps.td create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricTypeDefs.td create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricTypes.cpp create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricTypes.h create mode 100644 mlir/test/lib/Dialect/TestParametric/lit.local.cfg create mode 100644 mlir/test/lib/Transforms/TestParametricSpecialization.cpp diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td index 844601f8f6837..68bae5d3f991d 100644 --- a/mlir/include/mlir/IR/SymbolInterfaces.td +++ b/mlir/include/mlir/IR/SymbolInterfaces.td @@ -154,7 +154,7 @@ def Symbol : OpInterface<"SymbolOpInterface"> { "bool", "isDeclaration", (ins), [{}], /*defaultImplementation=*/[{ // By default, assume that the operation defines a symbol. - return false; + return false; }] >, ]; diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt index d81298bb4daf0..2f3e34e266e3f 100644 --- a/mlir/include/mlir/Interfaces/CMakeLists.txt +++ b/mlir/include/mlir/Interfaces/CMakeLists.txt @@ -9,6 +9,7 @@ add_mlir_interface(InferIntRangeInterface) add_mlir_interface(InferTypeOpInterface) add_mlir_interface(LoopLikeInterface) add_mlir_interface(ParallelCombiningOpInterface) +add_mlir_interface(ParametricSpecializationOpInterface) add_mlir_interface(RuntimeVerifiableOpInterface) add_mlir_interface(ShapedOpInterfaces) add_mlir_interface(SideEffectInterfaces) diff --git a/mlir/include/mlir/Interfaces/ParametricSpecializationOpInterface.h b/mlir/include/mlir/Interfaces/ParametricSpecializationOpInterface.h new file mode 100644 index 0000000000000..88770e7239ac0 --- /dev/null +++ b/mlir/include/mlir/Interfaces/ParametricSpecializationOpInterface.h @@ -0,0 +1,25 @@ +//===- ParametricSpecializationOpInterface.h - Parallel combining op interface +//---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the operation interface for ops that parallel combining +// operations. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_PARAMETRICSPECIALIZATIONOPINTERFACES_H_ +#define MLIR_INTERFACES_PARAMETRICSPECIALIZATIONOPINTERFACES_H_ + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/SymbolTable.h" + +/// Include the generated interface declarations. +#include "mlir/Interfaces/ParametricSpecializationOpInterface.h.inc" + +#endif // MLIR_INTERFACES_PARAMETRICSPECIALIZATIONOPINTERFACES_H_ diff --git a/mlir/include/mlir/Interfaces/ParametricSpecializationOpInterface.td b/mlir/include/mlir/Interfaces/ParametricSpecializationOpInterface.td new file mode 100644 index 0000000000000..e3c12d6b4b60f --- /dev/null +++ b/mlir/include/mlir/Interfaces/ParametricSpecializationOpInterface.td @@ -0,0 +1,46 @@ +//===-- ParametricSpecializationOpInterface.td -------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_PARAMETRICSPECIALIZATIONOPINTERFACES +#define MLIR_INTERFACES_PARAMETRICSPECIALIZATIONOPINTERFACES + +include "mlir/IR/OpBase.td" + +def ParametricOpInterface : OpInterface<"ParametricOpInterface"> { + let cppNamespace = "::mlir"; + let methods = [ + InterfaceMethod<"", + "::mlir::LogicalResult", "specialize", (ins + "::mlir::DictionaryAttr":$params)>, + InterfaceMethod<"", + "::mlir::LogicalResult", "checkOperand", (ins + "::mlir::OpOperand &":$operand, + "::mlir::Type":$concreteType)>, + InterfaceMethod<"Only for symbol operation which will be cloned, mangle in-place.", + "::mlir::FailureOr<::mlir::StringAttr>", "getMangledName", (ins + "::mlir::DictionaryAttr":$metaArgs), "", [{ + return failure(); + }] +>, + ]; +} + +def SpecializingOpInterface : OpInterface<"SpecializingOpInterface"> { + let cppNamespace = "::mlir"; + let methods = [ + InterfaceMethod<"", + "::mlir::SymbolRefAttr", "getTarget", (ins)>, + InterfaceMethod<"", + "::mlir::DictionaryAttr", "getMetaArgs", (ins)>, + InterfaceMethod<"", + "::mlir::LogicalResult", "setSpecializedTarget", (ins + "::mlir::SymbolOpInterface":$target)>, + ]; +} + +#endif // MLIR_INTERFACES_PARAMETRICSPECIALIZATIONOPINTERFACES diff --git a/mlir/include/mlir/Transforms/ParametricSpecialization.h b/mlir/include/mlir/Transforms/ParametricSpecialization.h new file mode 100644 index 0000000000000..1bbe3e2a557ef --- /dev/null +++ b/mlir/include/mlir/Transforms/ParametricSpecialization.h @@ -0,0 +1,11 @@ +//===- RemoveDeadValues.h - Specialize Meta Program -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/Operation.h" + +namespace mlir {} diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt index e7c76e70ed6b5..1998b66f168f3 100644 --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -13,6 +13,7 @@ set(LLVM_OPTIONAL_SOURCES LoopLikeInterface.cpp MemorySlotInterfaces.cpp ParallelCombiningOpInterface.cpp + ParametricSpecializationOpInterface.cpp RuntimeVerifiableOpInterface.cpp ShapedOpInterfaces.cpp SideEffectInterfaces.cpp @@ -80,6 +81,7 @@ add_mlir_library(MLIRLoopLikeInterface add_mlir_interface_library(MemorySlotInterfaces) add_mlir_interface_library(ParallelCombiningOpInterface) +add_mlir_interface_library(ParametricSpecializationOpInterface) add_mlir_interface_library(RuntimeVerifiableOpInterface) add_mlir_interface_library(ShapedOpInterfaces) add_mlir_interface_library(SideEffectInterfaces) diff --git a/mlir/lib/Interfaces/ParametricSpecializationOpInterface.cpp b/mlir/lib/Interfaces/ParametricSpecializationOpInterface.cpp new file mode 100644 index 0000000000000..80fc2caf0d12a --- /dev/null +++ b/mlir/lib/Interfaces/ParametricSpecializationOpInterface.cpp @@ -0,0 +1,13 @@ +//===- ParametricSpecializationOpInterface.cpp ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/ParametricSpecializationOpInterface.h" +#include "mlir/Support/LogicalResult.h" + +/// Include the definitions of the interface. +#include "mlir/Interfaces/ParametricSpecializationOpInterface.cpp.inc" diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt index af51a4ab1157f..8254f9d212c60 100644 --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -10,6 +10,7 @@ add_mlir_library(MLIRTransforms LoopInvariantCodeMotion.cpp Mem2Reg.cpp OpStats.cpp + ParametricSpecialization.cpp PrintIR.cpp RemoveDeadValues.cpp SCCP.cpp @@ -32,6 +33,7 @@ add_mlir_library(MLIRTransforms MLIRFunctionInterfaces MLIRLoopLikeInterface MLIRMemorySlotInterfaces + MLIRParametricSpecializationOpInterface MLIRPass MLIRRuntimeVerifiableOpInterface MLIRSideEffectInterfaces diff --git a/mlir/lib/Transforms/ParametricSpecialization.cpp b/mlir/lib/Transforms/ParametricSpecialization.cpp new file mode 100644 index 0000000000000..fcc3daacad447 --- /dev/null +++ b/mlir/lib/Transforms/ParametricSpecialization.cpp @@ -0,0 +1,13 @@ +//===- RemoveDeadValues.cpp - Specialize Meta Program ---------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/ParametricSpecialization.h" + +using namespace mlir; + +void specialize(Operation *op) {} \ No newline at end of file diff --git a/mlir/test/Parametric/ops.mlir b/mlir/test/Parametric/ops.mlir new file mode 100644 index 0000000000000..ed8c87cd48cce --- /dev/null +++ b/mlir/test/Parametric/ops.mlir @@ -0,0 +1,18 @@ + + +testparametric.func @callee(%arg0: !testparametric.param<"A"> ) attributes { metaParams = ["A", "B"]} { + %value = testparametric.add %arg0, %arg0 : (!testparametric.param<"A">, !testparametric.param<"A">) -> !testparametric.param<"A"> + testparametric.print_attr #testparametric.param<"B"> + return +} + +func.func @caller() { + %cst0 = arith.constant 0 : i32 + %cst1 = arith.constant 1. : f32 + %cst2 = arith.constant 2. : f64 + testparametric.call @callee(%cst0) meta = {"A" = i32, "B" = 32 : i64 } : (i32) -> () + testparametric.call @callee(%cst0) meta = {"A" = i32, "B" = 32 : i64 } : (i32) -> () + testparametric.call @callee(%cst1) meta = {"A" = f32, "B" = 64 : i64 } : (f32) -> () + testparametric.call @callee(%cst2) meta = {"A" = f64, "B" = 128 : i64 } : (f64) -> () + return +} diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt index 30a17c201ff76..8c1be74f15899 100644 --- a/mlir/test/lib/Dialect/CMakeLists.txt +++ b/mlir/test/lib/Dialect/CMakeLists.txt @@ -17,6 +17,7 @@ add_subdirectory(SPIRV) add_subdirectory(Tensor) add_subdirectory(Test) add_subdirectory(TestDyn) +add_subdirectory(TestParametric) add_subdirectory(Tosa) add_subdirectory(Transform) add_subdirectory(Vector) diff --git a/mlir/test/lib/Dialect/TestParametric/CMakeLists.txt b/mlir/test/lib/Dialect/TestParametric/CMakeLists.txt new file mode 100644 index 0000000000000..dcc79f15993b7 --- /dev/null +++ b/mlir/test/lib/Dialect/TestParametric/CMakeLists.txt @@ -0,0 +1,68 @@ +set(LLVM_OPTIONAL_SOURCES + TestParametricDialect.cpp +) + +set(LLVM_TARGET_DEFINITIONS TestParametricInterfaces.td) +mlir_tablegen(TestParametricAttrInterfaces.h.inc -gen-attr-interface-decls) +mlir_tablegen(TestParametricAttrInterfaces.cpp.inc -gen-attr-interface-defs) +mlir_tablegen(TestParametricTypeInterfaces.h.inc -gen-type-interface-decls) +mlir_tablegen(TestParametricTypeInterfaces.cpp.inc -gen-type-interface-defs) +mlir_tablegen(TestParametricOpInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(TestParametricOpInterfaces.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(MLIRTestParametricInterfaceIncGen) + +set(LLVM_TARGET_DEFINITIONS TestParametricOps.td) +mlir_tablegen(TestParametricAttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(TestParametricAttrDefs.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(MLIRTestParametricAttrDefIncGen) + +set(LLVM_TARGET_DEFINITIONS TestParametricTypeDefs.td) +mlir_tablegen(TestParametricTypeDefs.h.inc -gen-typedef-decls -typedefs-dialect=testparametric) +mlir_tablegen(TestParametricTypeDefs.cpp.inc -gen-typedef-defs -typedefs-dialect=testparametric) +add_public_tablegen_target(MLIRTestParametricTypeDefIncGen) + +set(LLVM_TARGET_DEFINITIONS TestParametricOps.td) +mlir_tablegen(TestParametricOps.h.inc -gen-op-decls) +mlir_tablegen(TestParametricOps.cpp.inc -gen-op-defs) +mlir_tablegen(TestParametricOpsDialect.h.inc -gen-dialect-decls -dialect=testparametric) +mlir_tablegen(TestParametricOpsDialect.cpp.inc -gen-dialect-defs -dialect=testparametric) +add_public_tablegen_target(MLIRTestParametricOpsIncGen) + +# Exclude testparametrics from libMLIR.so +add_mlir_library(MLIRTestParametricDialect + TestParametricAttributes.cpp + TestParametricDialect.cpp + TestParametricInterfaces.cpp + TestParametricTypes.cpp + + EXCLUDE_FROM_LIBMLIR + + DEPENDS + MLIRTestParametricAttrDefIncGen + MLIRTestParametricInterfaceIncGen + MLIRTestParametricTypeDefIncGen + MLIRTestParametricOpsIncGen + + LINK_LIBS PUBLIC + MLIRControlFlowInterfaces + MLIRDataLayoutInterfaces + MLIRDerivedAttributeOpInterface + MLIRDestinationStyleOpInterface + MLIRDialect + MLIRDLTIDialect + MLIRFuncDialect + MLIRFunctionInterfaces + MLIRFuncTransforms + MLIRIR + MLIRInferIntRangeInterface + MLIRInferTypeOpInterface + MLIRLinalgDialect + MLIRLinalgTransforms + MLIRLLVMDialect + MLIRPass + MLIRReduce + MLIRTensorDialect + MLIRTransformUtils + MLIRTransforms +) + diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricAttrDefs.td b/mlir/test/lib/Dialect/TestParametric/TestParametricAttrDefs.td new file mode 100644 index 0000000000000..c9133a99a3441 --- /dev/null +++ b/mlir/test/lib/Dialect/TestParametric/TestParametricAttrDefs.td @@ -0,0 +1,38 @@ +//===-- TestAttrDefs.td - Test dialect attr definitions ----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// TableGen data attribute definitions for Test dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TESTPARAMETRIC_ATTRDEFS +#define TESTPARAMETRIC_ATTRDEFS + +// To get the test dialect definition. +include "TestParametricDialect.td" +include "mlir/Dialect/Utils/StructuredOpsUtils.td" +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinAttributeInterfaces.td" +include "mlir/IR/EnumAttr.td" +include "mlir/IR/OpAsmInterface.td" + +// All of the attributes will extend this class. +class TestParametric_Attr traits = []> + : AttrDef; + +def TestParametric_ParamAttr : TestParametric_Attr<"Param"> { + let mnemonic = "param"; + // List of type parameters. + let parameters = ( + ins + "::mlir::StringAttr":$ref + ); + let assemblyFormat = "`<` $ref `>`"; +} + +#endif // TESTPARAMETRIC_ATTRDEFS diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricAttributes.cpp b/mlir/test/lib/Dialect/TestParametric/TestParametricAttributes.cpp new file mode 100644 index 0000000000000..a5dde555b3e89 --- /dev/null +++ b/mlir/test/lib/Dialect/TestParametric/TestParametricAttributes.cpp @@ -0,0 +1,42 @@ +//===- TestAttributes.cpp - MLIR Test Dialect Attributes --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains attributes defined by the TestDialect for testing various +// features of MLIR. +// +//===----------------------------------------------------------------------===// + +#include "TestParametricAttributes.h" +#include "TestParametricDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/ExtensibleDialect.h" +#include "mlir/IR/Types.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/ADT/bit.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; +using namespace testparametric; + +//===----------------------------------------------------------------------===// +// TestParametricDialect +//===----------------------------------------------------------------------===// + +#define GET_ATTRDEF_CLASSES +#include "TestParametricAttrDefs.cpp.inc" + +void TestParametricDialect::registerAttributes() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "TestParametricAttrDefs.cpp.inc" + >(); +} diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricAttributes.h b/mlir/test/lib/Dialect/TestParametric/TestParametricAttributes.h new file mode 100644 index 0000000000000..ec728ee34b5bd --- /dev/null +++ b/mlir/test/lib/Dialect/TestParametric/TestParametricAttributes.h @@ -0,0 +1,33 @@ +//===- TestTypes.h - MLIR Test Dialect Types --------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains types defined by the TestDialect for testing various +// features of MLIR. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TESTPARAMETRICATTRIBUTES_H +#define MLIR_TESTPARAMETRICATTRIBUTES_H + +#include + +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/DialectImplementation.h" + +#include "TestParametricAttrInterfaces.h.inc" +#include "mlir/IR/DialectResourceBlobManager.h" + +namespace testparametric {} // namespace testparametric + +#define GET_ATTRDEF_CLASSES +#include "TestParametricAttrDefs.h.inc" + +#endif // MLIR_TESTPARAMETRICATTRIBUTES_H diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricDialect.cpp b/mlir/test/lib/Dialect/TestParametric/TestParametricDialect.cpp new file mode 100644 index 0000000000000..693d93910e818 --- /dev/null +++ b/mlir/test/lib/Dialect/TestParametric/TestParametricDialect.cpp @@ -0,0 +1,297 @@ +//===- TestParametricDialect.cpp - MLIR Dialect for Testing +//----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestParametricDialect.h" +#include "TestParametricAttributes.h" +#include "TestParametricInterfaces.h" +#include "TestParametricTypes.h" +#include "mlir/Bytecode/BytecodeImplementation.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/ExtensibleDialect.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/ODSSupport.h" +#include "mlir/IR/OperationSupport.h" + +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/Verifier.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/FoldUtils.h" +#include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/STLFunctionalExtras.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/Support/Base64.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include +#include + +// Include this before the using namespace lines below to +// test that we don't have namespace dependencies. +#include "TestParametricOpsDialect.cpp.inc" + +using namespace mlir; +using namespace testparametric; + +void TestParametricDialect::initialize() { + registerAttributes(); + registerTypes(); + addOperations< +#define GET_OP_LIST +#include "TestParametricOps.cpp.inc" + >(); +} +void testparametric::registerTestParametricDialect(DialectRegistry ®istry) { + registry.insert(); +} + +#include "TestParametricOpInterfaces.cpp.inc" +#include "TestParametricTypeInterfaces.cpp.inc" + +#define GET_OP_CLASSES +#include "TestParametricOps.cpp.inc" + +::mlir::ParseResult ParametricFuncOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + auto buildFuncType = + [](Builder &builder, ArrayRef argTypes, ArrayRef results, + function_interface_impl::VariadicFlag, + std::string &) { return builder.getFunctionType(argTypes, results); }; + + return function_interface_impl::parseFunctionOp( + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); +} + +void ParametricFuncOp::print(mlir::OpAsmPrinter &p) { + function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); +} + +LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + // Check that the callee attribute was specified. + auto fnAttr = (*this)->getAttrOfType("callee"); + if (!fnAttr) + return emitOpError("requires a 'callee' symbol reference attribute"); + ParametricFuncOp fn = + symbolTable.lookupNearestSymbolFrom(*this, fnAttr); + if (!fn) + return emitOpError() << "'" << fnAttr.getValue() + << "' does not reference a valid function"; + + // Verify that the operand and result types match the callee. + auto fnType = fn.getFunctionType(); + if (fnType.getNumInputs() != getNumOperands()) + return emitOpError("incorrect number of operands for callee"); + + DictionaryAttr metaParams = fn.getMetaParamsAttr(); + DictionaryAttr metaArgs = getMetaArgs(); + if (metaParams && metaArgs.size() != metaParams.size()) + return emitOpError("incorrect number of meta operands for callee"); + + for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) { + auto operandType = getOperand(i).getType(); + auto paramType = fnType.getInput(i); + if (auto metaParamType = dyn_cast(paramType)) { + auto metaArg = metaArgs.get(metaParamType.getRef()); + if (!metaArg) + return emitOpError("Missing meta args for type operand ") + << metaParamType.getRef(); + auto metaArgType = dyn_cast(metaArg); + if (!metaArgType) + return emitOpError("Expected TypeAttr for meta args ") + << metaParamType.getRef() << ", got " << metaArg; + if (metaArgType.getValue() != operandType) + return emitOpError("Mismatch between operand type and meta args type: ") + << operandType << " vs " << metaArgType; + continue; + } + if (operandType != paramType) { + return emitOpError("operand type mismatch: expected operand type ") + << fnType.getInput(i) << ", but provided " + << getOperand(i).getType() << " for operand number " << i; + } + } + if (fnType.getNumResults() != getNumResults()) + return emitOpError("incorrect number of results for callee"); + + for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) { + if (getResult(i).getType() != fnType.getResult(i)) { + auto diag = emitOpError("result type mismatch at index ") << i; + diag.attachNote() << " op result types: " << getResultTypes(); + diag.attachNote() << "function result types: " << fnType.getResults(); + return diag; + } + } + return success(); +} + +/// Specialization Interface Implementation + +static LogicalResult replaceValueType(Value value, Type newType) { + for (OpOperand &use : value.getUses()) { + if (auto paramOp = dyn_cast(use.getOwner())) { + if (failed(paramOp.checkOperand(use, newType))) { + paramOp.emitOpError() << "fails to replace operand type for operand #" + << use.getOperandNumber() << " with " << newType; + return failure(); + } + } + } + value.setType(newType); + return success(); +} +static LogicalResult replaceValueType(Value value, DictionaryAttr metaArgs) { + auto paramType = dyn_cast(value.getType()); + if (!paramType) + return success(); + auto metaArg = + llvm::dyn_cast_or_null(metaArgs.get(paramType.getRef())); + if (!metaArg) { + if (value.getDefiningOp()) + value.getDefiningOp()->emitError() + << "expected TypeAttr for specializing meta arg " << paramType + << ", got " << metaArgs; + return failure(); + } + return replaceValueType(value, metaArg.getValue()); +} + +LogicalResult ParametricFuncOp::specialize(DictionaryAttr metaArgs) { + auto mangledName = getMangledName(metaArgs); + if (failed(mangledName)) + return failure(); + setSymNameAttr(*mangledName); + removeMetaParamsAttr(); + + auto specializeTypes = [&](auto typeRange, SmallVector &specialized) { + for (Type ty : typeRange) { + auto paramType = dyn_cast(ty); + if (!paramType) { + specialized.push_back(ty); + continue; + } + auto metaArg = + llvm::dyn_cast_or_null(metaArgs.get(paramType.getRef())); + if (!metaArg) { + emitOpError() << "expected TypeAttr for specializing meta arg " + << paramType << ", got " << metaArgs; + return failure(); + } + specialized.push_back(metaArg.getValue()); + } + return success(); + }; + auto fnType = getFunctionType(); + SmallVector argTypes, resTypes; + if (failed(specializeTypes(fnType.getInputs(), argTypes))) + return failure(); + if (failed(specializeTypes(fnType.getResults(), resTypes))) + return failure(); + for (auto argTypes : llvm::zip(argTypes, this->getArguments())) { + auto newType = std::get<0>(argTypes); + auto blockArg = std::get<1>(argTypes); + if (failed(replaceValueType(blockArg, newType))) + return failure(); + } + + setFunctionType(FunctionType::get(getContext(), argTypes, resTypes)); + if (getFunctionBody() + .walk([&](Operation *op) { + if (auto parametricOp = dyn_cast(op)) { + if (failed(parametricOp.specialize(metaArgs))) + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }) + .wasInterrupted()) + return failure(); + return success(); +} + +LogicalResult ParametricFuncOp::checkOperand(mlir::OpOperand &, mlir::Type) { + return success(); +} + +FailureOr +ParametricFuncOp::getMangledName(DictionaryAttr metaArgs) { + auto name = getNameAttr(); + if (!name) + return failure(); + std::string mangledName; + llvm::raw_string_ostream os(mangledName); + os << name.getValue() << "$__mlir_instance__"; + for (NamedAttribute name : metaArgs) { + os << "$" << name.getName().getValue(); + Attribute value = name.getValue(); + if (auto intAttr = dyn_cast(value)) + os << "$" << intAttr.getValue(); + else + os << "$" << value; + } + + return StringAttr::get(getContext(), os.str()); +} + +LogicalResult AddOp::specialize(DictionaryAttr metaArgs) { + if (failed(replaceValueType(getResult(), metaArgs))) + return failure(); + return success(); +} + +LogicalResult AddOp::checkOperand(mlir::OpOperand &, mlir::Type) { + return success(); +} + +SymbolRefAttr CallOp::getTarget() { return getCalleeAttr(); } + +LogicalResult CallOp::setSpecializedTarget(SymbolOpInterface target) { + // TODO: check validity first. + setCalleeAttr(SymbolRefAttr::get(target.getNameAttr())); + setMetaArgsAttr(DictionaryAttr::get(getContext())); + return success(); +} + +LogicalResult PrintAttrOp::specialize(DictionaryAttr metaArgs) { + auto valueAttr = dyn_cast_or_null(getValueAttr()); + if (!valueAttr) + return success(); + auto metaArg = metaArgs.get(valueAttr.getRef()); + if (!metaArg) { + emitOpError() << "failed to specialize, missing " << valueAttr.getRef() + << " entry in " << metaArgs; + return failure(); + } + setValueAttr(metaArg); + return success(); +} + +LogicalResult PrintAttrOp::checkOperand(mlir::OpOperand &, mlir::Type) { + return success(); +} diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricDialect.h b/mlir/test/lib/Dialect/TestParametric/TestParametricDialect.h new file mode 100644 index 0000000000000..510c94c1100dc --- /dev/null +++ b/mlir/test/lib/Dialect/TestParametric/TestParametricDialect.h @@ -0,0 +1,45 @@ +//===- TestDialect.h - MLIR Dialect for testing -----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines a fake 'test' dialect that can be used for testing things +// that do not have a respective counterpart in the main source directories. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TESTDIALECT_H +#define MLIR_TESTDIALECT_H + +#include "TestParametricAttributes.h" +#include "TestParametricInterfaces.h" +#include "TestParametricTypes.h" + +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Interfaces/ParametricSpecializationOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#include + +//===----------------------------------------------------------------------===// +// TestDialect +//===----------------------------------------------------------------------===// + +#include "TestParametricOpInterfaces.h.inc" +#include "TestParametricOpsDialect.h.inc" + +#define GET_OP_CLASSES +#include "TestParametricOps.h.inc" + +namespace testparametric { +void registerTestParametricDialect(::mlir::DialectRegistry ®istry); +} // namespace testparametric + +#endif // MLIR_TESTDIALECT_H diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricDialect.td b/mlir/test/lib/Dialect/TestParametric/TestParametricDialect.td new file mode 100644 index 0000000000000..54575cdc723ee --- /dev/null +++ b/mlir/test/lib/Dialect/TestParametric/TestParametricDialect.td @@ -0,0 +1,27 @@ +//===-- TestDialect.td - Test dialect definition -----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TESTPARAMETRIC_DIALECT +#define TESTPARAMETRIC_DIALECT + +include "mlir/IR/OpBase.td" + +def TestParametric_Dialect : Dialect { + let name = "testparametric"; + let cppNamespace = "::testparametric"; + let useDefaultAttributePrinterParser = 1; + let useDefaultTypePrinterParser = 1; + + let extraClassDeclaration = [{ + void registerAttributes(); + void registerTypes(); + private: + }]; +} + +#endif // TESTPARAMETRIC_DIALECT diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricInterfaces.cpp b/mlir/test/lib/Dialect/TestParametric/TestParametricInterfaces.cpp new file mode 100644 index 0000000000000..597654c639e81 --- /dev/null +++ b/mlir/test/lib/Dialect/TestParametric/TestParametricInterfaces.cpp @@ -0,0 +1,11 @@ +//===- TestInterfaces.cpp - MLIR interfaces for testing ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestParametricInterfaces.h" + +using namespace mlir; diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricInterfaces.h b/mlir/test/lib/Dialect/TestParametric/TestParametricInterfaces.h new file mode 100644 index 0000000000000..092c24058c017 --- /dev/null +++ b/mlir/test/lib/Dialect/TestParametric/TestParametricInterfaces.h @@ -0,0 +1,33 @@ +//===- TestInterfaces.h - MLIR interfaces for testing -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares interfaces for the 'test' dialect that can be used for +// testing the interface infrastructure. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TEST_LIB_DIALECT_TEST_TESTINTERFACES_H +#define MLIR_TEST_LIB_DIALECT_TEST_TESTINTERFACES_H + +#include "mlir/IR/BuiltinAttributes.h" + +#include "llvm/ADT/DenseMap.h" + +namespace mlir { + +class SpecializationParams { +public: + SpecializationParams() {} + +private: + DenseMap params; +}; + +} // namespace mlir + +#endif // MLIR_TEST_LIB_DIALECT_TEST_TESTINTERFACES_H diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricInterfaces.td b/mlir/test/lib/Dialect/TestParametric/TestParametricInterfaces.td new file mode 100644 index 0000000000000..954ab0cac1fcf --- /dev/null +++ b/mlir/test/lib/Dialect/TestParametric/TestParametricInterfaces.td @@ -0,0 +1,34 @@ +//===-- TestInterfaces.td - Test dialect interfaces --------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TEST_DIALECT_TESTPARAMETRIC_INTERFACES +#define MLIR_TEST_DIALECT_TESTPARAMETRIC_INTERFACES + +include "mlir/IR/OpBase.td" + +def TestParametricOpInterface : OpInterface<"TestParametricOpInterface"> { + let cppNamespace = "::mlir"; + let methods = [ + InterfaceMethod<"", + "LogicalResult", "specializeAttr", (ins + "::mlir::StringAttr":$parameterName, + "::mlir::Attribute":$concreteAttr)>, + InterfaceMethod<"", + "LogicalResult", "specializeType", (ins + "::mlir::StringAttr":$parameterName, + "::mlir::Type":$concreteType)>, + InterfaceMethod<"", + "LogicalResult", "checkOperand", (ins + "::mlir::OpOperand &":$operand, + "::mlir::Type":$concreteType)>, + ]; +} + + + +#endif // MLIR_TEST_DIALECT_TESTPARAMETRIC_INTERFACES diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricOps.td b/mlir/test/lib/Dialect/TestParametric/TestParametricOps.td new file mode 100644 index 0000000000000..43c04c763957c --- /dev/null +++ b/mlir/test/lib/Dialect/TestParametric/TestParametricOps.td @@ -0,0 +1,202 @@ +//===-- TestOps.td - Test dialect operation definitions ----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TESTPARAMETRIC_OPS +#define TESTPARAMETRIC_OPS + +include "TestParametricDialect.td" +include "mlir/Dialect/DLTI/DLTIBase.td" +include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td" +include "mlir/IR/EnumAttr.td" +include "mlir/Interfaces/FunctionInterfaces.td" +include "mlir/IR/OpBase.td" +include "mlir/IR/OpAsmInterface.td" +include "mlir/IR/PatternBase.td" +include "mlir/IR/RegionKindInterface.td" +include "mlir/IR/SymbolInterfaces.td" +include "mlir/Interfaces/CallInterfaces.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/CopyOpInterface.td" +include "mlir/Interfaces/DataLayoutInterfaces.td" +include "mlir/Interfaces/DestinationStyleOpInterface.td" +include "mlir/Interfaces/InferIntRangeInterface.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/LoopLikeInterface.td" +include "mlir/Interfaces/ParametricSpecializationOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + + +// Include the attribute definitions. +include "TestParametricAttrDefs.td" +// Include the type definitions. +include "TestParametricTypeDefs.td" + + +class TESTParametric_Op traits = []> : + Op; + +def TESTParametric_ParametricFuncOp : TESTParametric_Op<"func", [ + AutomaticAllocationScope, FunctionOpInterface, IsolatedFromAbove, OpAsmOpInterface, + DeclareOpInterfaceMethods + ]> { + let summary = "Parametric function."; + let description = [{ + }]; + + let arguments = (ins SymbolNameAttr:$sym_name, + TypeAttrOf:$function_type, + OptionalAttr:$metaParams, + OptionalAttr:$sym_visibility, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs); + let regions = (region AnyRegion:$body); + + let extraClassDeclaration = [{ + //===------------------------------------------------------------------===// + // FunctionOpInterface Methods + //===------------------------------------------------------------------===// + + /// Returns the region on the current operation that is callable. This may + /// return null in the case of an external callable object, e.g. an external + /// function. + ::mlir::Region *getCallableRegion() { return isExternal() ? nullptr : &getBody(); } + + /// Returns the argument types of this function. + ::llvm::ArrayRef<::mlir::Type> getArgumentTypes() { return getFunctionType().getInputs(); } + + /// Returns the result types of this function. + ::llvm::ArrayRef<::mlir::Type> getResultTypes() { return getFunctionType().getResults(); } + + //===------------------------------------------------------------------===// + // OpAsmOpInterface Methods + //===------------------------------------------------------------------===// + + /// Allow the dialect prefix to be omitted. + static ::llvm::StringRef getDefaultDialect() { return "testparametric"; } + + //===------------------------------------------------------------------===// + // SymbolOpInterface Methods + //===------------------------------------------------------------------===// + + bool isDeclaration() { return isExternal(); } + + //===------------------------------------------------------------------===// + // ParametricOpInterface Methods + //===------------------------------------------------------------------===// + + ::mlir::FailureOr<::mlir::StringAttr> getMangledName(::mlir::DictionaryAttr); + }]; + let hasCustomAssemblyFormat = 1; +} + + +def ReturnOp : TESTParametric_Op<"return", [Pure, HasParent<"ParametricFuncOp">, + ReturnLike, Terminator]> { + let summary = "Function return operation"; + let description = [{ + }]; + + let arguments = (ins Variadic:$operands); + + let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; +} + +def CallOp : TESTParametric_Op<"call", + [CallOpInterface, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { + let summary = "call operation"; + let description = [{ + }]; + + let arguments = (ins + FlatSymbolRefAttr:$callee, + Variadic:$operands, + DictionaryAttr:$metaArgs + ); + let results = (outs Variadic); + + let builders = [ + OpBuilder<(ins "mlir::SymbolRefAttr":$callee, "mlir::TypeRange":$results, + CArg<"mlir::ValueRange", "{}">:$operands), [{ + $_state.addOperands(operands); + $_state.addAttribute("callee", callee); + $_state.addTypes(results); + }]>, + OpBuilder<(ins "mlir::StringAttr":$callee, "mlir::TypeRange":$results, + CArg<"mlir::ValueRange", "{}">:$operands), [{ + build($_builder, $_state, mlir::SymbolRefAttr::get(callee), results, operands); + }]>, + OpBuilder<(ins "llvm::StringRef":$callee, "mlir::TypeRange":$results, + CArg<"mlir::ValueRange", "{}">:$operands), [{ + build($_builder, $_state, mlir::StringAttr::get($_builder.getContext(), callee), + results, operands); + }]>]; + + let extraClassDeclaration = [{ + mlir::FunctionType getCalleeType(); + + /// Get the argument operands to the called function. + operand_range getArgOperands() { + return {arg_operand_begin(), arg_operand_end()}; + } + + mlir::MutableOperandRange getArgOperandsMutable() { + return getOperandsMutable(); + } + + operand_iterator arg_operand_begin() { return operand_begin(); } + operand_iterator arg_operand_end() { return operand_end(); } + + /// Return the callee of this operation. + mlir::CallInterfaceCallable getCallableForCallee() { + return (*this)->getAttrOfType("callee"); + } + + /// Set the callee for this operation. + void setCalleeFromCallable(mlir::CallInterfaceCallable callee) { + (*this)->setAttr("callee", callee.get()); + } + }]; + + let assemblyFormat = [{ + $callee `(` $operands `)` `meta` `=` $metaArgs attr-dict `:` functional-type($operands, results) + }]; +} + +def AddOp : TESTParametric_Op<"add", [Pure, DeclareOpInterfaceMethods]> { + let summary = "Add operation"; + let description = [{ + }]; + + let arguments = (ins + AnyType:$lhs, + AnyType:$rhs + ); + let results = (outs + AnyType:$result + ); + + let assemblyFormat = "$lhs `` `,` $rhs attr-dict `:` functional-type(operands, results)"; +} + +def PrintAttrOp : TESTParametric_Op<"print_attr", [DeclareOpInterfaceMethods]> { + let summary = "Print operation"; + let description = [{ + }]; + + let arguments = (ins + AnyAttr:$value + ); + + let assemblyFormat = "$value attr-dict"; +} + + + +#endif // TESTPARAMETRIC_OPS diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricTypeDefs.td b/mlir/test/lib/Dialect/TestParametric/TestParametricTypeDefs.td new file mode 100644 index 0000000000000..d3f3d0327db91 --- /dev/null +++ b/mlir/test/lib/Dialect/TestParametric/TestParametricTypeDefs.td @@ -0,0 +1,37 @@ +//===-- TestTypeDefs.td - Test dialect type definitions ----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// TableGen data type definitions for Test dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TESTPARAMETRIC_TYPEDEFS +#define TESTPARAMETRIC_TYPEDEFS + +// To get the test dialect def. +include "TestParametricDialect.td" +include "TestParametricAttrDefs.td" +include "TestParametricInterfaces.td" +include "mlir/IR/BuiltinTypes.td" +include "mlir/Interfaces/DataLayoutInterfaces.td" + +// All of the types will extend this class. +class TestParametric_Type traits = []> + : TypeDef; + +def TestParametric_ParamType : TestParametric_Type<"Param"> { + let mnemonic = "param"; + // List of type parameters. + let parameters = ( + ins + "::mlir::StringAttr":$ref + ); + let assemblyFormat = "`<` $ref `>`"; +} + +#endif // TESTPARAMETRIC_TYPEDEFS diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricTypes.cpp b/mlir/test/lib/Dialect/TestParametric/TestParametricTypes.cpp new file mode 100644 index 0000000000000..aaad6b880361b --- /dev/null +++ b/mlir/test/lib/Dialect/TestParametric/TestParametricTypes.cpp @@ -0,0 +1,42 @@ +//===- TestTypes.cpp - MLIR Test Dialect Types ------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains types defined by the TestDialect for testing various +// features of MLIR. +// +//===----------------------------------------------------------------------===// + +#include "TestParametricTypes.h" +#include "TestParametricDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/ExtensibleDialect.h" +#include "mlir/IR/Types.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/TypeSize.h" +#include + +using namespace mlir; +using namespace testparametric; + +#define GET_TYPEDEF_CLASSES +#include "TestParametricTypeDefs.cpp.inc" + +//===----------------------------------------------------------------------===// +// TestDialect +//===----------------------------------------------------------------------===// + +void TestParametricDialect::registerTypes() { + addTypes< +#define GET_TYPEDEF_LIST +#include "TestParametricTypeDefs.cpp.inc" + >(); +} diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricTypes.h b/mlir/test/lib/Dialect/TestParametric/TestParametricTypes.h new file mode 100644 index 0000000000000..0e397d2bfb1ce --- /dev/null +++ b/mlir/test/lib/Dialect/TestParametric/TestParametricTypes.h @@ -0,0 +1,154 @@ +//===- TestTypes.h - MLIR Test Dialect Types --------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains types defined by the TestDialect for testing various +// features of MLIR. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TESTPARAMETRICTYPES_H +#define MLIR_TESTPARAMETRICTYPES_H + +#include +#include + +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Types.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" + +namespace test { +class TestAttrWithFormatAttr; + +/// FieldInfo represents a field in the StructType data type. It is used as a +/// parameter in TestTypeDefs.td. +struct FieldInfo { + ::llvm::StringRef name; + ::mlir::Type type; + + // Custom allocation called from generated constructor code + FieldInfo allocateInto(::mlir::TypeStorageAllocator &alloc) const { + return FieldInfo{alloc.copyInto(name), type}; + } +}; + +/// A custom type for a test type parameter. +struct CustomParam { + int value; + + bool operator==(const CustomParam &other) const { + return other.value == value; + } +}; + +inline llvm::hash_code hash_value(const test::CustomParam ¶m) { + return llvm::hash_value(param.value); +} + +} // namespace test + +namespace mlir { +template <> +struct FieldParser { + static FailureOr parse(AsmParser &parser) { + auto value = FieldParser::parse(parser); + if (failed(value)) + return failure(); + return test::CustomParam{*value}; + } +}; + +inline mlir::AsmPrinter &operator<<(mlir::AsmPrinter &printer, + test::CustomParam param) { + return printer << param.value; +} + +/// Overload the attribute parameter parser for optional integers. +template <> +struct FieldParser> { + static FailureOr> parse(AsmParser &parser) { + std::optional value; + value.emplace(); + OptionalParseResult result = parser.parseOptionalInteger(*value); + if (result.has_value()) { + if (succeeded(*result)) + return value; + return failure(); + } + value.reset(); + return value; + } +}; +} // namespace mlir + +#include "TestParametricTypeInterfaces.h.inc" + +namespace test { + +/// Storage for simple named recursive types, where the type is identified by +/// its name and can "contain" another type, including itself. +struct TestRecursiveTypeStorage : public ::mlir::TypeStorage { + using KeyTy = ::llvm::StringRef; + + explicit TestRecursiveTypeStorage(::llvm::StringRef key) : name(key) {} + + bool operator==(const KeyTy &other) const { return name == other; } + + static TestRecursiveTypeStorage * + construct(::mlir::TypeStorageAllocator &allocator, const KeyTy &key) { + return new (allocator.allocate()) + TestRecursiveTypeStorage(allocator.copyInto(key)); + } + + ::mlir::LogicalResult mutate(::mlir::TypeStorageAllocator &allocator, + ::mlir::Type newBody) { + // Cannot set a different body than before. + if (body && body != newBody) + return ::mlir::failure(); + + body = newBody; + return ::mlir::success(); + } + + ::llvm::StringRef name; + ::mlir::Type body; +}; + +/// Simple recursive type identified by its name and pointing to another named +/// type, potentially itself. This requires the body to be mutated separately +/// from type creation. +class TestRecursiveType + : public ::mlir::Type::TypeBase { +public: + using Base::Base; + + static constexpr ::mlir::StringLiteral name = "test.recursive"; + + static TestRecursiveType get(::mlir::MLIRContext *ctx, + ::llvm::StringRef name) { + return Base::get(ctx, name); + } + + /// Body getter and setter. + ::mlir::LogicalResult setBody(Type body) { return Base::mutate(body); } + ::mlir::Type getBody() const { return getImpl()->body; } + + /// Name/key getter. + ::llvm::StringRef getName() { return getImpl()->name; } +}; + +} // namespace test + +#define GET_TYPEDEF_CLASSES +#include "TestParametricTypeDefs.h.inc" + +#endif // MLIR_TESTPARAMETRICTYPES_H diff --git a/mlir/test/lib/Dialect/TestParametric/lit.local.cfg b/mlir/test/lib/Dialect/TestParametric/lit.local.cfg new file mode 100644 index 0000000000000..65a7f202dc82a --- /dev/null +++ b/mlir/test/lib/Dialect/TestParametric/lit.local.cfg @@ -0,0 +1 @@ +config.suffixes.remove(".td") diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt index 2a3a8608db544..038064417bf52 100644 --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -24,6 +24,7 @@ add_mlir_library(MLIRTestTransforms TestControlFlowSink.cpp TestInlining.cpp TestIntRangeInference.cpp + TestParametricSpecialization.cpp TestMakeIsolatedFromAbove.cpp TestTopologicalSort.cpp ${MLIRTestTransformsPDLSrc} diff --git a/mlir/test/lib/Transforms/TestParametricSpecialization.cpp b/mlir/test/lib/Transforms/TestParametricSpecialization.cpp new file mode 100644 index 0000000000000..157d4b044c503 --- /dev/null +++ b/mlir/test/lib/Transforms/TestParametricSpecialization.cpp @@ -0,0 +1,191 @@ +//===- TestParametricSpecialization.cpp - Pass for metaprog +// specialization--===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Threading.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Interfaces/ParametricSpecializationOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/ParametricSpecialization.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" +#include +#include + +#define DEBUG_TYPE "parametric-specialization" + +using namespace mlir; + +namespace { + +struct SpecializingRequest { + /// Op to specialize + ParametricOpInterface targetOp; + /// The arguments to specialize it with. + DictionaryAttr metaArgs; + /// The "callers" to update + SmallVector callers; + /// The operation post-specialization + OwningOpRef specialized; +}; + +struct TestParametricSpecializationPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestParametricSpecializationPass) + + StringRef getArgument() const final { + return "test-parametric-specialization"; + } + StringRef getDescription() const final { + return "Test the parametric specialization of parametric programs."; + } + + void runOnOperation() override { + Operation *op = getOperation(); + if (!op->hasTrait()) { + op->emitOpError() + << getArgument() + << " pass can only run on an operation that defines a SymbolTable"; + signalPassFailure(); + } + OpBuilder builder(op->getContext()); + SymbolTable symTab(op); + + MLIRContext &ctx = getContext(); + + // Walk the body of the module, and find "roots": operations that are + // already specialized. We'll use these as "roots" to specialize the + // parametric ones. + SmallVector rootOps; + for (Operation &nestedOp : op->getRegion(0).getOps()) { + if (!isa(nestedOp)) + rootOps.push_back(&nestedOp); + } + + std::map specializationRequests; + std::mutex tasksMutex; + LogicalResult result = success(); + + // Run in parallel on every root, and for each, walk the body and find + // "calls" to functions that need specialization. + result = failableParallelForEach(&ctx, rootOps, [&](Operation *root) { + auto result = root->walk([&](Operation *innerOp) { + auto specializingOp = dyn_cast(innerOp); + if (!specializingOp) + return WalkResult::advance(); + auto targetNameAttr = specializingOp.getTarget(); + auto targetOp = symTab.lookup(targetNameAttr.getRootReference()); + if (!targetOp) { + innerOp->emitOpError() + << "can't find target '" << targetNameAttr << "' in SymbolTable"; + return WalkResult::interrupt(); + } + auto parametricTargetOp = dyn_cast(targetOp); + if (!parametricTargetOp) { + auto diag = targetOp->emitOpError(); + diag << "expected target to implement 'ParametricOpInterface'"; + diag.attachNote() << "while specializing " << *innerOp; + return WalkResult::interrupt(); + } + auto metaArgs = specializingOp.getMetaArgs(); + auto failureOrMangledName = parametricTargetOp.getMangledName(metaArgs); + if (failed(failureOrMangledName)) { + parametricTargetOp->emitOpError() + << "failed to mangled with meta args " << metaArgs; + return WalkResult::interrupt(); + } + StringAttr mangledName = *failureOrMangledName; + std::unique_lock lock(tasksMutex); + auto &request = specializationRequests[mangledName.getValue()]; + if (request.targetOp && request.targetOp != targetOp) { + auto diag = targetOp->emitOpError(); + diag << "unexpected mangling collision while specializing with " + "meta args " + << metaArgs << ", mangled name " << mangledName; + diag.attachNote() << "while specializing " << *innerOp; + return WalkResult::interrupt(); + } + request.targetOp = parametricTargetOp; + request.metaArgs = metaArgs; + request.callers.push_back(specializingOp); + LLVM_DEBUG({ llvm::errs() << "Request for " << mangledName << "\n"; }); + return WalkResult::advance(); + }); + return success(!result.wasInterrupted()); + }); + if (failed(result)) { + signalPassFailure(); + return; + } + LLVM_DEBUG({ + llvm::errs() << "Got " << specializationRequests.size() << " requests\n"; + }); + + std::map *, OwningOpRef>> + specializationResults; + result = failableParallelForEach( + &ctx, specializationRequests, + [&](std::pair &request) { + ParametricOpInterface targetOp = request.second.targetOp; + DictionaryAttr metaArgs = request.second.metaArgs; + OwningOpRef specializedOp(targetOp.clone()); + if (failed(specializedOp->specialize(metaArgs))) { + std::unique_lock lock(tasksMutex); + targetOp->emitOpError() << "failed to specialize with " << metaArgs; + return failure(); + } + request.second.specialized = std::move(specializedOp); + return success(); + }); + if (failed(result)) { + signalPassFailure(); + return; + } + + llvm::ThreadPool &threadPool = ctx.getThreadPool(); + llvm::ThreadPoolTaskGroup tasksGroup(threadPool); + for (auto &request : specializationRequests) { + Operation *op = request.second.specialized.get(); + symTab.insert(op); + LLVM_DEBUG({ + llvm::errs() << "Inserted " << cast(op).getName() + << "\n"; + }); + tasksGroup.async([&] { + Operation *op = request.second.specialized.release(); + auto specializedOp = cast(op); + for (SpecializingOpInterface caller : request.second.callers) { + if (failed(caller.setSpecializedTarget(specializedOp))) { + std::unique_lock lock(tasksMutex); + caller->emitOpError() << "failed to specialize\n"; + signalPassFailure(); + } + } + }); + } + tasksGroup.wait(); + } +}; +} // namespace + +namespace mlir { +namespace test { +void registerTestParametricSpecializationPass() { + PassRegistration(); +} +} // namespace test +} // namespace mlir diff --git a/mlir/tools/mlir-lsp-server/CMakeLists.txt b/mlir/tools/mlir-lsp-server/CMakeLists.txt index 0134b54eef1b0..6480056c66e2b 100644 --- a/mlir/tools/mlir-lsp-server/CMakeLists.txt +++ b/mlir/tools/mlir-lsp-server/CMakeLists.txt @@ -18,6 +18,7 @@ if(MLIR_INCLUDE_TESTS) MLIRTestAnalysis MLIRTestDialect MLIRTestDynDialect + MLIRTestParametricDialect MLIRTestIR MLIRTestPass MLIRTestReducer diff --git a/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp b/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp index f0ecc5adc68b3..f64e278237d59 100644 --- a/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp +++ b/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp @@ -20,6 +20,9 @@ void registerTestDialect(DialectRegistry &); void registerTestDynDialect(DialectRegistry &); void registerTestTransformDialectExtension(DialectRegistry &); } // namespace test +namespace testparametric { +void registerTestParametricDialect(DialectRegistry &); +} // namespace testparametric #endif int main(int argc, char **argv) { @@ -31,6 +34,7 @@ int main(int argc, char **argv) { ::test::registerTestDialect(registry); ::test::registerTestTransformDialectExtension(registry); ::test::registerTestDynDialect(registry); + ::testparametric::registerTestParametricDialect(registry); #endif return failed(MlirLspServerMain(argc, argv, registry)); } diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt index 9ad5b32c24f9d..176152f4fd0f4 100644 --- a/mlir/tools/mlir-opt/CMakeLists.txt +++ b/mlir/tools/mlir-opt/CMakeLists.txt @@ -35,6 +35,7 @@ if(MLIR_INCLUDE_TESTS) MLIRTestAnalysis MLIRTestDialect MLIRTestDynDialect + MLIRTestParametricDialect MLIRTestIR MLIRTestOneToNTypeConversionPass MLIRTestPass diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 428bdd9691e09..54aaf1072ba16 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -125,6 +125,7 @@ void registerTestNextAccessPass(); void registerTestOneToNTypeConversionPass(); void registerTestOpaqueLoc(); void registerTestPadFusion(); +void registerTestParametricSpecializationPass(); void registerTestPreparationPassWithAllowedMemrefResults(); void registerTestRecursiveTypesPass(); void registerTestSCFUtilsPass(); @@ -154,6 +155,9 @@ void registerTestDynDialect(DialectRegistry &); void registerTestTilingInterfaceTransformDialectExtension(DialectRegistry &); void registerTestTransformDialectExtension(DialectRegistry &); } // namespace test +namespace testparametric { +void registerTestParametricDialect(DialectRegistry &); +} // namespace testparametric #ifdef MLIR_INCLUDE_TESTS void registerTestPasses() { @@ -248,6 +252,7 @@ void registerTestPasses() { mlir::test::registerTestOneToNTypeConversionPass(); mlir::test::registerTestOpaqueLoc(); mlir::test::registerTestPadFusion(); + mlir::test::registerTestParametricSpecializationPass(); mlir::test::registerTestRecursiveTypesPass(); mlir::test::registerTestSCFUtilsPass(); mlir::test::registerTestSCFWhileOpBuilderPass(); @@ -293,6 +298,7 @@ int main(int argc, char **argv) { ::test::registerTestTransformDialectExtension(registry); ::test::registerTestTilingInterfaceTransformDialectExtension(registry); ::test::registerTestDynDialect(registry); + ::testparametric::registerTestParametricDialect(registry); #endif return mlir::asMainReturnCode(mlir::MlirOptMain( argc, argv, "MLIR modular optimizer driver\n", registry));