Skip to content

Commit 6c9541d

Browse files
Implement simple type polymorphism for linalg named ops.
* It was decided that this was the end of the line for the existing custom tc parser/generator, and this is the first step to replacing it with a declarative format that maps well to mathy source languages. * One such source language is implemented here: https://github.com/stellaraccident/mlir-linalgpy/blob/main/samples/mm.py * In fact, this is the exact source of the declarative `polymorphic_matmul` in this change. * I am working separately to clean this python implementation up and add it to MLIR (probably as `mlir.tools.linalg_opgen` or equiv). The scope of the python side is greater than just generating named ops: the ops are callable and directly emit `linalg.generic` ops fully dynamically, and this is intended to be a feature for frontends like npcomp to define custom linear algebra ops at runtime. * There is more work required to handle full type polymorphism, especially with respect to integer formulations, since they require more specificity wrt types. * Followups to this change will bring the new generator to feature parity with the current one and delete the current. Roughly, this involves adding support for interface declarations and attribute symbol bindings. Differential Revision: https://reviews.llvm.org/D97135
1 parent b568d3d commit 6c9541d

File tree

9 files changed

+1131
-10
lines changed

9 files changed

+1131
-10
lines changed

mlir/docs/Dialects/Linalg.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,18 @@ void batchmatmul::regionBuilder(ArrayRef<BlockArgument> args) {
662662
}
663663
```
664664

665+
### YAML Based Named Structured Ops
666+
667+
Linalg provides a declarative generation tool (`mlir-linalg-ods-yaml-gen`) to
668+
automatically produce named ops from a YAML-based op description format
669+
intended to capture the structure of the named ops and be generated from a
670+
higher level "mathy" DSL syntax. This facility is currently in flight and is
671+
intended to subsume the above when ready. See the C++ class to YAML mapping
672+
traits in `mlir-mlinalg-ods-yaml-gen.cpp` as the source of truth for the schema.
673+
674+
Most of the above documentation roughly applies to this path and will be ported
675+
as migration continues.
676+
665677
## Open Issues and Design Alternatives<a name="open_issues"></a>
666678

667679
Multiple open issues and design alternatives are in flight and it is time to lay

mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# Declare a function to generate ODS with mlir-linalg-ods-gen
2-
function(add_linalg_ods_gen tc_filename output_file)
2+
function(add_linalg_ods_tc_gen tc_filename output_file)
33
set(TC_SOURCE ${CMAKE_CURRENT_SOURCE_DIR}/${tc_filename})
4-
set(GEN_ODS_FILE ${CMAKE_CURRENT_BINARY_DIR}/${output_file}.td)
5-
set(GEN_CPP_FILE ${CMAKE_CURRENT_BINARY_DIR}/${output_file}.cpp.inc)
4+
set(GEN_ODS_FILE ${CMAKE_CURRENT_BINARY_DIR}/${output_file}.tcgen.td)
5+
set(GEN_CPP_FILE ${CMAKE_CURRENT_BINARY_DIR}/${output_file}.tcgen.cpp.inc)
66
set_source_files_properties(
77
${GEN_ODS_FILE}
88
PROPERTIES GENERATED TRUE)
@@ -20,17 +20,52 @@ function(add_linalg_ods_gen tc_filename output_file)
2020
${MLIR_LINALG_ODS_GEN_TARGET}
2121
VERBATIM)
2222
add_custom_target(
23-
MLIR${output_file}IncGen
23+
MLIR${output_file}TcIncGen
2424
DEPENDS
2525
${MLIR_LINALG_ODS_GEN_EXE}
2626
${MLIR_LINALG_ODS_GEN_TARGET}
2727
${GEN_ODS_FILE} ${GEN_CPP_FILE})
2828
endfunction()
2929

30-
add_linalg_ods_gen(LinalgNamedStructuredOpsSpec.tc LinalgNamedStructuredOps)
30+
# Declare a function to generate ODS with mlir-linalg-ods-yaml-gen
31+
function(add_linalg_ods_yaml_gen yaml_ast_file output_file)
32+
set(YAML_AST_SOURCE ${CMAKE_CURRENT_SOURCE_DIR}/${yaml_ast_file})
33+
set(GEN_ODS_FILE ${CMAKE_CURRENT_BINARY_DIR}/${output_file}.yamlgen.td)
34+
set(GEN_CPP_FILE ${CMAKE_CURRENT_BINARY_DIR}/${output_file}.yamlgen.cpp.inc)
35+
set_source_files_properties(
36+
${GEN_ODS_FILE}
37+
PROPERTIES GENERATED TRUE)
38+
set_source_files_properties(
39+
${GEN_CPP_FILE}
40+
PROPERTIES GENERATED TRUE)
41+
add_custom_command(
42+
OUTPUT ${GEN_ODS_FILE} ${GEN_CPP_FILE}
43+
COMMAND ${MLIR_LINALG_ODS_YAML_GEN_EXE} ${YAML_AST_SOURCE} -o-ods-decl=${GEN_ODS_FILE} -o-impl=${GEN_CPP_FILE}
44+
MAIN_DEPENDENCY
45+
${YAML_AST_SOURCE}
46+
DEPENDS
47+
${MLIR_LINALG_ODS_YAML_GEN_EXE}
48+
${MLIR_LINALG_ODS_YAML_GEN_TARGET})
49+
add_custom_target(
50+
MLIR${output_file}YamlIncGen
51+
DEPENDS
52+
${MLIR_LINALG_ODS_YAML_GEN_EXE}
53+
${MLIR_LINALG_ODS_YAML_GEN_TARGET}
54+
${GEN_ODS_FILE} ${GEN_CPP_FILE})
55+
endfunction()
56+
57+
# TODO: Delete tc generation and replace with the YAML variant once all ops are
58+
# ported.
59+
add_linalg_ods_tc_gen(LinalgNamedStructuredOpsSpec.tc LinalgNamedStructuredOps)
60+
add_linalg_ods_yaml_gen(LinalgNamedStructuredOps.yaml LinalgNamedStructuredOps)
61+
3162
# Provide a short name for all external dependency that needs to
3263
# include Linalg in ODS
33-
add_custom_target(LinalgOdsGen DEPENDS MLIRLinalgNamedStructuredOpsIncGen)
64+
add_custom_target(LinalgOdsGen
65+
DEPENDS
66+
MLIRLinalgNamedStructuredOpsTcIncGen
67+
MLIRLinalgNamedStructuredOpsYamlIncGen
68+
)
3469
add_dependencies(mlir-headers LinalgOdsGen)
3570

3671
add_mlir_dialect(LinalgOps linalg)
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
--- !LinalgOpConfig
2+
metadata: !LinalgOpMetadata
3+
name: polymorphic_matmul
4+
cpp_op_name: PolymorphicMatmulOp
5+
doc: |-
6+
Type polymorphic matrix multiplication.
7+
8+
This op is presently here to test a new path for generation and will replace
9+
the existing 'matmul' op when ready. Do not use.
10+
structured_op: !LinalgStructuredOpConfig
11+
args:
12+
- !<LinalgTensorDef>
13+
name: A
14+
usage: input
15+
shape: affine_map<()[s0, s1, s2] -> (s0, s2)>
16+
- !<LinalgTensorDef>
17+
name: B
18+
usage: input
19+
shape: affine_map<()[s0, s1, s2] -> (s2, s1)>
20+
- !<LinalgTensorDef>
21+
name: C
22+
usage: output
23+
shape: affine_map<()[s0, s1, s2] -> (s0, s1)>
24+
indexing_maps: !LinalgIndexingMapsConfig
25+
static_indexing_maps:
26+
- affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
27+
- affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
28+
- affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
29+
iterator_types:
30+
- parallel
31+
- parallel
32+
- reduction
33+
assignments:
34+
- !ScalarAssign
35+
arg: C
36+
value: !ScalarExpression
37+
scalar_apply:
38+
fn_name: add
39+
operands:
40+
- !ScalarExpression
41+
scalar_arg: C
42+
- !ScalarExpression
43+
scalar_apply:
44+
fn_name: mul
45+
operands:
46+
- !ScalarExpression
47+
scalar_arg: A
48+
- !ScalarExpression
49+
scalar_arg: B
50+

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ def ConvOp : PoolingBase_Op<"conv", []> {
343343
// parallelized across; i.e. [zs] in the TF notation above whose number
344344
// match `xs` (i.e. 1 window loop per "image" dimension).
345345
// This may evolve in the future.
346-
// Conditionally check nPar is large enough for cases of ill-formed op:
346+
// Conditionally check nPar is large enough for cases of ill-formed op:
347347
// this avoids overflows before hitting the verifier.
348348
assert(nPar > getNumBatchDimensions() + getNumInputFeatureDimensions() &&
349349
"expected at least one window dimension (i.e. memref ranks greater "
@@ -806,6 +806,7 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
806806
//===----------------------------------------------------------------------===//
807807

808808
// This file is auto-generated from a TC def specification.
809-
include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.td"
809+
include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.tcgen.td"
810+
include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.td"
810811

811812
#endif // LINALG_STRUCTURED_OPS

mlir/lib/Dialect/Linalg/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRLinalg
1414
LINK_LIBS PUBLIC
1515
MLIRAffine
1616
MLIRIR
17+
MLIRParser
1718
MLIRSideEffectInterfaces
1819
MLIRViewLikeInterface
1920
MLIRStandard

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "mlir/IR/Matchers.h"
2121
#include "mlir/IR/OpImplementation.h"
2222
#include "mlir/IR/PatternMatch.h"
23+
#include "mlir/Parser.h"
2324

2425
#include "llvm/ADT/DenseMap.h"
2526
#include "llvm/ADT/SetVector.h"
@@ -121,6 +122,81 @@ static LogicalResult foldMemRefCast(Operation *op) {
121122
return success(folded);
122123
}
123124

125+
//===----------------------------------------------------------------------===//
126+
// Region builder helper.
127+
// TODO: Move this to a utility library.
128+
// The public methods on this class are referenced directly from generated code
129+
// and bind by name to math functions in the DSL as:
130+
// `applyfn__{fnName}`
131+
// Examples:
132+
// `applyfn__add`
133+
// `applyfn__mul`
134+
// The naming convention is intentional in order to match snake-cased DSL names.
135+
// See mlir-linalg-ods-yaml-gen.cpp for the code that mates to this class.
136+
//
137+
// Implementations of the math functions must be polymorphic over numeric types,
138+
// internally performing necessary casts. If the function application makes no
139+
// sense, then the only recourse is to assert and return nullptr. This can be
140+
// extended later if it becomes possible to fail construction of the region. The
141+
// invariant should be enforced at a higher level.
142+
//
143+
// TODO: These helpers are currently type polymorphic over the class of integer
144+
// and floating point types, but they will not internally cast within bit
145+
// widths of a class (mixed precision such as i8->i32) or across classes
146+
// (i.e. mixed float and integer). Many such combinations are ambiguous or need
147+
// to be handled with care and work is being considered to extend the op
148+
// language to make such cases explicit. In the mean-time, violating this will
149+
// fail verification, which is deemed acceptable.
150+
//===----------------------------------------------------------------------===//
151+
152+
namespace {
153+
154+
class RegionBuilderHelper {
155+
public:
156+
RegionBuilderHelper(Block &block) : block(block) {}
157+
158+
Value applyfn__add(Value lhs, Value rhs) {
159+
OpBuilder builder = getBuilder(lhs);
160+
if (isFloatingPoint(lhs))
161+
return builder.create<AddFOp>(lhs.getLoc(), lhs, rhs);
162+
else if (isInteger(lhs))
163+
return builder.create<AddIOp>(lhs.getLoc(), lhs, rhs);
164+
llvm_unreachable("unsupported non numeric type");
165+
}
166+
167+
Value applyfn__mul(Value lhs, Value rhs) {
168+
OpBuilder builder = getBuilder(lhs);
169+
if (isFloatingPoint(lhs))
170+
return builder.create<MulFOp>(lhs.getLoc(), lhs, rhs);
171+
else if (isInteger(lhs))
172+
return builder.create<MulIOp>(lhs.getLoc(), lhs, rhs);
173+
llvm_unreachable("unsupported non numeric type");
174+
}
175+
176+
void yieldOutputs(ValueRange values) {
177+
assert(!values.empty() && "linalg ops must yield outputs");
178+
if (values.empty())
179+
return;
180+
Value first = values.front();
181+
OpBuilder builder = getBuilder(first);
182+
builder.create<YieldOp>(first.getLoc(), values);
183+
}
184+
185+
private:
186+
Block &block;
187+
188+
bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); }
189+
bool isInteger(Value value) { return value.getType().isa<IntegerType>(); }
190+
191+
OpBuilder getBuilder(Value value) {
192+
OpBuilder builder(value.getContext());
193+
builder.setInsertionPointToEnd(&block);
194+
return builder;
195+
}
196+
};
197+
198+
} // namespace
199+
124200
//===----------------------------------------------------------------------===//
125201
// CopyOp
126202
//===----------------------------------------------------------------------===//
@@ -1868,7 +1944,8 @@ struct EraseDeadLinalgOp;
18681944
struct FoldTensorCastOp;
18691945
} // namespace
18701946

1871-
#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.cpp.inc"
1947+
#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.tcgen.cpp.inc"
1948+
#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
18721949

18731950
#define GET_OP_CLASSES
18741951
#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
@@ -2032,7 +2109,8 @@ fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
20322109
unsigned actual = body->getNumArguments();
20332110
unsigned expected = NamedStructuredOpType::getNumRegionArgs();
20342111
if (expected != actual) {
2035-
if (errorHandler) errorHandler(expected, actual);
2112+
if (errorHandler)
2113+
errorHandler(expected, actual);
20362114
return;
20372115
}
20382116

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// RUN: mlir-opt %s -split-input-file -linalg-generalize-named-ops | FileCheck %s
2+
3+
func @generalize_matmul_tensor_f32(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
4+
%0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
5+
outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
6+
return %0: tensor<16x32xf32>
7+
}
8+
9+
// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
10+
// CHECK-NEXT: %[[MUL:.+]] = mulf %[[A_ARG]], %[[B_ARG]] : f32
11+
// CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
12+
// CHECK-NEXT: linalg.yield %[[ADD]] : f32
13+
// CHECK-NEXT: -> tensor<16x32xf32>
14+
15+
// -----
16+
17+
func @generalize_matmul_tensor_i32(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
18+
%0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>)
19+
outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
20+
return %0: tensor<16x32xi32>
21+
}
22+
23+
// CHECK: ^{{.*}}(%[[A_ARG:.+]]: i32, %[[B_ARG:.+]]: i32, %[[C_ARG:.+]]: i32)
24+
// CHECK-NEXT: %[[MUL:.+]] = muli %[[A_ARG]], %[[B_ARG]] : i32
25+
// CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
26+
// CHECK-NEXT: linalg.yield %[[ADD]] : i32
27+
// CHECK-NEXT: -> tensor<16x32xi32>

mlir/tools/mlir-linalg-ods-gen/CMakeLists.txt

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@ set(LLVM_LINK_COMPONENTS
22
Core
33
Support
44
)
5+
6+
set(LLVM_OPTIONAL_SOURCES
7+
mlir-linalg-ods-gen.cpp
8+
mlir-linalg-ods-yaml-gen.cpp
9+
)
10+
11+
# Original mlir-linalg-ods-gen (to be replaced).
512
add_llvm_tool(mlir-linalg-ods-gen
613
mlir-linalg-ods-gen.cpp
714
)
@@ -30,3 +37,35 @@ if(LLVM_USE_HOST_TOOLS)
3037
endif()
3138
endif()
3239
endif()
40+
41+
42+
# New mlir-linalg-ods-yaml-gen.
43+
add_llvm_tool(mlir-linalg-ods-yaml-gen
44+
mlir-linalg-ods-yaml-gen.cpp
45+
)
46+
llvm_update_compile_flags(mlir-linalg-ods-yaml-gen)
47+
target_link_libraries(mlir-linalg-ods-yaml-gen PRIVATE
48+
MLIRIR
49+
MLIRSupport
50+
MLIRParser
51+
)
52+
53+
set(MLIR_LINALG_ODS_YAML_GEN mlir-linalg-ods-yaml-gen CACHE
54+
STRING "Native mlir-linalg-ods-yaml-gen executable. Saves building one when cross-compiling.")
55+
56+
set(MLIR_LINALG_ODS_YAML_GEN_EXE ${MLIR_LINALG_ODS_YAML_GEN} PARENT_SCOPE)
57+
set(MLIR_LINALG_ODS_YAML_GEN_TARGET mlir-linalg-ods-yaml-gen PARENT_SCOPE)
58+
59+
if(LLVM_USE_HOST_TOOLS)
60+
if ("${MLIR_LINALG_ODS_YAML_GEN_EXE}" STREQUAL mlir-linalg-ods-yaml-gen)
61+
build_native_tool(mlir-linalg-ods-yaml-gen MLIR_LINALG_ODS_YAML_GEN_EXE DEPENDS mlir-linalg-ods-yaml-gen)
62+
set(MLIR_LINALG_ODS_YAML_GEN_EXE ${MLIR_LINALG_ODS_YAML_GEN_EXE} PARENT_SCOPE)
63+
64+
add_custom_target(mlir-linalg-ods-yaml-gen-host DEPENDS ${MLIR_LINALG_ODS_YAML_GEN_EXE})
65+
set(MLIR_LINALG_ODS_YAML_GEN_TARGET mlir-linalg-ods-yaml-gen-host DEPENDS PARENT_SCOPE)
66+
67+
if(NOT LLVM_BUILD_UTILS)
68+
set_target_properties(mlir-linalg-ods-yaml-gen PROPERTIES EXCLUDE_FROM_ALL ON)
69+
endif()
70+
endif()
71+
endif()

0 commit comments

Comments
 (0)