Commit cab8dda

[mlir] Start splitting the tensor dialect out of std.
This starts by moving `std.extract_element` to `tensor.extract` (this mirrors the naming of `vector.extract`).

Curiously, `std.extract_element` supposedly works on vectors as well, and this patch removes that functionality. I would tend to do that in a separate patch, but I couldn't find any downstream users relying on it, and the fact that we have `vector.extract` made it seem safe enough to lump in here.

This also sets up the `tensor` dialect as a dependency of the `std` dialect, as some ops that currently live in `std` depend on `tensor.extract` via their canonicalization patterns.

Part of RFC: https://llvm.discourse.group/t/rfc-split-the-tensor-dialect-from-std/2347/2

Differential Revision: https://reviews.llvm.org/D92991
1 parent 7b3470b commit cab8dda
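For reference, a minimal before/after sketch of the rename; the assembly syntax follows the op documentation shown in the diff below, and `%t`, `%i`, and `%j` are placeholder SSA values:

```mlir
// Before this commit: the std dialect op (it also accepted vectors).
%0 = extract_element %t[%i, %j] : tensor<4x4xi32>

// After this commit: the tensor dialect op (tensors only).
%0 = tensor.extract %t[%i, %j] : tensor<4x4xi32>
```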

42 files changed: +611 -311 lines. (This is a large commit; some file contents are hidden by default, so only a subset of the changed files is shown below.)

mlir/include/mlir/Dialect/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -14,5 +14,6 @@ add_subdirectory(SCF)
 add_subdirectory(Shape)
 add_subdirectory(SPIRV)
 add_subdirectory(StandardOps)
+add_subdirectory(Tensor)
 add_subdirectory(Tosa)
 add_subdirectory(Vector)

mlir/include/mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h

Lines changed: 2 additions & 1 deletion
@@ -10,6 +10,7 @@
 
 #include "mlir/Dialect/Linalg/EDSC/Builders.h"
 #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 
 #include "mlir/Transforms/FoldUtils.h"
 
@@ -46,7 +47,6 @@ using folded_std_constant_float = FoldedValueBuilder<ConstantFloatOp>;
 using folded_std_constant_index = FoldedValueBuilder<ConstantIndexOp>;
 using folded_std_constant_int = FoldedValueBuilder<ConstantIntOp>;
 using folded_std_dim = FoldedValueBuilder<DimOp>;
-using folded_std_extract_element = FoldedValueBuilder<ExtractElementOp>;
 using folded_std_index_cast = FoldedValueBuilder<IndexCastOp>;
 using folded_std_muli = FoldedValueBuilder<MulIOp>;
 using folded_std_mulf = FoldedValueBuilder<MulFOp>;
@@ -60,6 +60,7 @@ using folded_std_tensor_load = FoldedValueBuilder<TensorLoadOp>;
 using folded_std_view = FoldedValueBuilder<ViewOp>;
 using folded_std_zero_extendi = FoldedValueBuilder<ZeroExtendIOp>;
 using folded_std_sign_extendi = FoldedValueBuilder<SignExtendIOp>;
+using folded_tensor_extract = FoldedValueBuilder<tensor::ExtractOp>;
 } // namespace intrinsics
 } // namespace edsc
 } // namespace mlir

mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h

Lines changed: 2 additions & 1 deletion
@@ -9,6 +9,7 @@
 #define MLIR_DIALECT_STANDARDOPS_EDSC_INTRINSICS_H_
 
 #include "mlir/Dialect/StandardOps/EDSC/Builders.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 
 namespace mlir {
 namespace edsc {
@@ -28,7 +29,6 @@ using std_dealloc = OperationBuilder<DeallocOp>;
 using std_divis = ValueBuilder<SignedDivIOp>;
 using std_diviu = ValueBuilder<UnsignedDivIOp>;
 using std_dim = ValueBuilder<DimOp>;
-using std_extract_element = ValueBuilder<ExtractElementOp>;
 using std_fpext = ValueBuilder<FPExtOp>;
 using std_fptrunc = ValueBuilder<FPTruncOp>;
 using std_im = ValueBuilder<ImOp>;
@@ -52,6 +52,7 @@ using std_tensor_store = OperationBuilder<TensorStoreOp>;
 using std_view = ValueBuilder<ViewOp>;
 using std_zero_extendi = ValueBuilder<ZeroExtendIOp>;
 using std_sign_extendi = ValueBuilder<SignExtendIOp>;
+using tensor_extract = ValueBuilder<tensor::ExtractOp>;
 
 /// Branches into `block` with `operands`.
 BranchOp std_br(Block *block, ValueRange operands);

mlir/include/mlir/Dialect/StandardOps/IR/Ops.td

Lines changed: 0 additions & 53 deletions
@@ -1669,59 +1669,6 @@ def Exp2Op : FloatUnaryOp<"exp2"> {
   let summary = "base-2 exponential of the specified value";
 }
 
-//===----------------------------------------------------------------------===//
-// ExtractElementOp
-//===----------------------------------------------------------------------===//
-
-def ExtractElementOp : Std_Op<"extract_element",
-    [NoSideEffect,
-     TypesMatchWith<"result type matches element type of aggregate",
-                    "aggregate", "result",
-                    "$_self.cast<ShapedType>().getElementType()">]> {
-  let summary = "element extract operation";
-  let description = [{
-    The `extract_element` op reads a tensor or vector and returns one element
-    from it specified by an index list. The output of the 'extract_element' is a
-    new value with the same type as the elements of the tensor or vector. The
-    arity of indices matches the rank of the accessed value (i.e., if a tensor
-    is of rank 3, then 3 indices are required for the extract. The indices
-    should all be of `index` type.
-
-    Example:
-
-    ```mlir
-    %3 = extract_element %v[%1, %2] : vector<4x4xi32>
-    %4 = extract_element %t[%1, %2] : tensor<4x4xi32>
-    %5 = extract_element %ut[%1, %2] : tensor<*xi32>
-    ```
-  }];
-
-  let arguments = (ins AnyTypeOf<[AnyVector, AnyTensor]>:$aggregate,
-                       Variadic<Index>:$indices);
-  let results = (outs AnyType:$result);
-
-  let builders = [
-    OpBuilderDAG<(ins "Value":$aggregate, CArg<"ValueRange", "{}">:$indices), [{
-      auto resType = aggregate.getType().cast<ShapedType>()
-                         .getElementType();
-      build($_builder, $_state, resType, aggregate, indices);
-    }]>];
-
-  let extraClassDeclaration = [{
-    Value getAggregate() { return getOperand(0); }
-
-    operand_range getIndices() {
-      return {operand_begin() + 1, operand_end()};
-    }
-  }];
-
-  let hasFolder = 1;
-
-  let assemblyFormat = [{
-    $aggregate `[` $indices `]` attr-dict `:` type($aggregate)
-  }];
-}
-
 //===----------------------------------------------------------------------===//
 // TensorFromElementsOp
 //===----------------------------------------------------------------------===//
mlir/include/mlir/Dialect/Tensor/CMakeLists.txt (new file)

Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
+add_subdirectory(IR)
+add_subdirectory(Transforms)
mlir/include/mlir/Dialect/Tensor/IR/CMakeLists.txt (new file)

Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
+add_mlir_dialect(TensorOps tensor)
+add_mlir_doc(TensorOps -gen-dialect-doc TensorOps Dialects/)
mlir/include/mlir/Dialect/Tensor/IR/Tensor.h (new file)

Lines changed: 31 additions & 0 deletions

@@ -0,0 +1,31 @@
+//===- Tensor.h - Tensor dialect --------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TENSOR_IR_TENSOR_H_
+#define MLIR_DIALECT_TENSOR_IR_TENSOR_H_
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+//===----------------------------------------------------------------------===//
+// Tensor Dialect
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tensor/IR/TensorOpsDialect.h.inc"
+
+//===----------------------------------------------------------------------===//
+// Tensor Dialect Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Tensor/IR/TensorOps.h.inc"
+
+#endif // MLIR_DIALECT_TENSOR_IR_TENSOR_H_
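For downstream users of this new header, a minimal sketch of making the dialect available in a context; `setUpContext` is a hypothetical helper, while `mlir::tensor::TensorDialect` is the class generated from the `Tensor_Dialect` TableGen definition added below and `loadDialect` is the standard `MLIRContext` hook:

```cpp
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/MLIRContext.h"

// Hypothetical setup: load the new tensor dialect so that tensor.extract
// can be parsed and created in this context.
void setUpContext(mlir::MLIRContext &context) {
  context.loadDialect<mlir::tensor::TensorDialect>();
}
```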
mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td (new file)

Lines changed: 48 additions & 0 deletions

@@ -0,0 +1,48 @@
+//===- TensorBase.td - Base definitions for tensor dialect -*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TENSOR_BASE
+#define TENSOR_BASE
+
+include "mlir/IR/OpBase.td"
+
+def Tensor_Dialect : Dialect {
+  let name = "tensor";
+  let cppNamespace = "::mlir::tensor";
+  let description = [{
+    The `tensor` dialect is intended to hold core tensor creation and
+    manipulation ops, which are not strongly associated with any particular
+    other dialect or domain abstraction. The primary smoke test of this is ops
+    that make sense for any tensor element type.
+
+    We leave it to other dialects to hold the vast swath of possible
+    computations one might want to do on a tensor.
+
+    The `tensor` type is (for better or for worse) used to represent all kinds
+    of things, and supports an open-ended set of element types. Examples:
+
+    - representing large, dense aggregations of primitive types, suitable for
+      high-performance numerical computing.
+    - representing shapes in the `shape` dialect, which consist of small
+      1D tensors of `index` data type.
+    - representing aggregations of strings or "variant" types.
+    - representing large, sparse aggregations of primitive types, suitable
+      for high-performance numerical computing.
+
+    Thus, for the `tensor` dialect, we prefer for now to constrain the
+    scope as much as possible. The expectation is that at some point
+    in the future, the `tensor` dialect's scope may be broadened through a
+    careful discussion of the tradeoffs.
+
+    The `tensor` type is actually a builtin type (it lives in the builtin
+    dialect), and does not live in this dialect.
+
+  }];
+}
+
+#endif // TENSOR_BASE
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (new file)

Lines changed: 62 additions & 0 deletions

@@ -0,0 +1,62 @@
+//===- TensorOps.td - Tensor op definitions ----------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TENSOR_OPS
+#define TENSOR_OPS
+
+include "mlir/Dialect/Tensor/IR/TensorBase.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+class Tensor_Op<string mnemonic, list<OpTrait> traits = []>
+    : Op<Tensor_Dialect, mnemonic, traits> {
+  let printer = [{ return ::print(p, *this); }];
+  let verifier = [{ return ::verify(*this); }];
+  let parser = [{ return ::parse$cppClass(parser, result); }];
+}
+
+//===----------------------------------------------------------------------===//
+// ExtractOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_ExtractOp : Tensor_Op<"extract",
+    [NoSideEffect,
+     TypesMatchWith<"result type matches element type of tensor",
+                    "tensor", "result",
+                    "$_self.cast<ShapedType>().getElementType()">]> {
+  let summary = "element extraction operation";
+  let description = [{
+    The `tensor.extract` op reads a tensor and returns one
+    element from it specified by an index list. The output of the op is a
+    new value with the same type as the elements of the tensor. The
+    arity of indices must match the rank of the accessed value (i.e., if a
+    tensor is of rank 3, then 3 indices are required for the extract. The
+    indices should all be of `index` type.
+
+    Example:
+
+    ```mlir
+    %4 = tensor.extract %t[%1, %2] : tensor<4x4xi32>
+    %5 = tensor.extract %rt[%1, %2] : tensor<?x?xi32>
+    %6 = tensor.extract %ut[%1, %2] : tensor<*xi32>
+    ```
+  }];
+
+  let arguments = (ins AnyTensor:$tensor, Variadic<Index>:$indices);
+  let results = (outs AnyType:$result);
+  let assemblyFormat = "$tensor `[` $indices `]` attr-dict `:` type($tensor)";
+
+  let builders = [
+    OpBuilderDAG<(ins "Value":$tensor, CArg<"ValueRange", "{}">:$indices), [{
+      auto resType = tensor.getType().cast<ShapedType>().getElementType();
+      build($_builder, $_state, resType, tensor, indices);
+    }]>];
+
+  let hasFolder = 1;
+}
+
+#endif // TENSOR_OPS
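As a usage note for the builder declared above (which derives the result type from the tensor's element type), here is a minimal C++ sketch of creating the op programmatically; `extractElement`, `loc`, and the value arguments are hypothetical, while `tensor::ExtractOp` and the builder signature come from the TableGen definition:

```cpp
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"

// Hypothetical helper: extract element (i, j) from a ranked tensor value.
// The OpBuilderDAG builder declared above computes the result type from the
// tensor's element type, so only the tensor and the indices are passed here.
mlir::Value extractElement(mlir::OpBuilder &builder, mlir::Location loc,
                           mlir::Value tensor, mlir::Value i, mlir::Value j) {
  return builder.create<mlir::tensor::ExtractOp>(loc, tensor,
                                                 mlir::ValueRange{i, j});
}
```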
mlir/include/mlir/Dialect/Tensor/Transforms/CMakeLists.txt (new file)

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name Tensor)
+add_public_tablegen_target(MLIRTensorTransformsIncGen)
+
+add_mlir_doc(Passes -gen-pass-doc TensorPasses ./)
