Skip to content

Commit

Permalink
[mlir][OpDSL] Add type function attributes.
Browse files Browse the repository at this point in the history
Previously, OpDSL operation used hardcoded type conversion operations (cast or cast_unsigned). Supporting signed and unsigned casts thus meant implementing two different operations. Type function attributes allow us to define a single operation that has a cast type function attribute which at operation instantiation time may be set to cast or cast_unsigned. We may for example, defina a matmul operation with a cast argument:

```
@linalg_structured_op
def matmul(A=TensorDef(T1, S.M, S.K), B=TensorDef(T2, S.K, S.N), C=TensorDef(U, S.M, S.N, output=True),
    cast=TypeFnAttrDef(default=TypeFn.cast)):
  C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
```

When instantiating the operation the attribute may be set to the desired cast function:

```
linalg.matmul(lhs, rhs, outs=[out], cast=TypeFn.cast_unsigned)
```

The revsion introduces a enum in the Linalg dialect that maps one-by-one to the type functions defined by OpDSL.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D119718
  • Loading branch information
gysit committed Feb 25, 2022
1 parent 3fe6f93 commit 51fdd80
Show file tree
Hide file tree
Showing 24 changed files with 759 additions and 475 deletions.
12 changes: 12 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt
Expand Up @@ -44,6 +44,18 @@ add_dependencies(mlir-headers LinalgOdsGen)

add_mlir_dialect(LinalgOps linalg)

set(LLVM_TARGET_DEFINITIONS LinalgOps.td)
mlir_tablegen(LinalgOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(LinalgOpsEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRLinalgOpsEnumsIncGen)
add_dependencies(mlir-headers MLIRLinalgOpsEnumsIncGen)

set(LLVM_TARGET_DEFINITIONS LinalgOps.td)
mlir_tablegen(LinalgOpsAttrDefs.h.inc -gen-attrdef-decls)
mlir_tablegen(LinalgOpsAttrDefs.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(MLIRLinalgOpsAttributesIncGen)
add_dependencies(mlir-headers MLIRLinalgOpsAttributesIncGen)

add_mlir_doc(LinalgDoc LinalgOps Dialects/ -gen-op-doc)
add_dependencies(LinalgOpsDocGen LinalgOdsGen)

Expand Down
13 changes: 13 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
Expand Up @@ -104,6 +104,19 @@ LogicalResult verifyStructuredOpInterface(Operation *op);

#include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.h.inc"

//===----------------------------------------------------------------------===//
// Linalg Enums
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/LinalgOpsEnums.h.inc"

//===----------------------------------------------------------------------===//
// Linalg Attributes
//===----------------------------------------------------------------------===//

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgOpsAttrDefs.h.inc"

//===----------------------------------------------------------------------===//
// Linalg Interfaces
//===----------------------------------------------------------------------===//
Expand Down
14 changes: 14 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
Expand Up @@ -13,6 +13,7 @@
#ifndef LINALG_BASE
#define LINALG_BASE

include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpBase.td"

def Linalg_Dialect : Dialect {
Expand Down Expand Up @@ -57,4 +58,17 @@ def Linalg_Dialect : Dialect {
}];
}

// Define a TypeFn enum matching the OpDSL TypeFn class.
def TypeFn : I32EnumAttr<"TypeFn", "", [
I32EnumAttrCase<"cast", 0>,
I32EnumAttrCase<"cast_unsigned", 1>
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::linalg";
}

def TypeFnAttr : EnumAttr<Linalg_Dialect, TypeFn, "type_fn"> {
let assemblyFormat = "`<` $value `>`";
}

#endif // LINALG_BASE

0 comments on commit 51fdd80

Please sign in to comment.