Commit 129d6e5

[mlir] Move std.tensor_cast -> tensor.cast.
This is almost entirely mechanical.

Differential Revision: https://reviews.llvm.org/D93357
1 parent: a555ca8
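
For illustration, the rename is a one-token change in the IR; a minimal before/after (using the shapes from the op documentation below), with syntax and semantics otherwise unchanged:

```mlir
// Before this commit: the cast lived in the standard dialect.
%1 = tensor_cast %0 : tensor<8x16xf32> to tensor<?x?xf32>

// After this commit: the same cast, now in the tensor dialect.
%1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
```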

39 files changed: +500 -471 lines


mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ namespace linalg {
 class LinalgDependenceGraph;
 
 /// A struct containing the Linalg producer before and after fusion.
-/// When operating on tensors, `fusedProducer` may feed into a `tensor_cast` op
+/// When operating on tensors, `fusedProducer` may feed into a `tensor.cast` op
 /// before the consumer Linalg op, until enough canonicalizations have applied.
 struct FusionInfo {
   LinalgOp originalProducer;

mlir/include/mlir/Dialect/StandardOps/IR/Ops.h

Lines changed: 0 additions & 25 deletions
@@ -354,31 +354,6 @@ computeRankReductionMask(ArrayRef<int64_t> originalShape,
 /// ```
 bool canFoldIntoConsumerOp(MemRefCastOp castOp);
 
-/// Counterpart of `canFoldIntoConsumerOp(MemRefCastOp castOp)` for tensors.
-/// Determines whether TensorCastOp casts to a more dynamic version of the
-/// source tensor. This is useful to fold a tensor_cast into a consuming op and
-/// implement canonicalization patterns for ops in different dialects that may
-/// consume the results of tensor_cast operations. Such foldable tensor_cast
-/// operations are typically inserted as `subtensor` ops and are canonicalized,
-/// to preserve the type compatibility of their uses.
-///
-/// Returns true when all conditions are met:
-///   1. source and result are ranked tensors with same element type and rank.
-///   2. the tensor type has more static information than the result
-///
-/// Example:
-/// ```mlir
-///   %1 = tensor_cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
-///   %2 = consumer %1 ... : tensor<?x?xf32> ...
-/// ```
-///
-/// folds into:
-///
-/// ```mlir
-///   %2 = consumer %0 ... : tensor<8x16xf32> ...
-/// ```
-bool canFoldIntoConsumerOp(TensorCastOp castOp);
-
 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
 /// comparison predicates.
 bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,

mlir/include/mlir/Dialect/StandardOps/IR/Ops.td

Lines changed: 1 addition & 51 deletions
@@ -62,7 +62,7 @@ class CastOp<string mnemonic, list<OpTrait> traits = []> :
   let printer = [{
     return printStandardCastOp(this->getOperation(), p);
   }];
-  let verifier = [{ return ::verifyCastOp(*this); }];
+  let verifier = [{ return impl::verifyCastOp(*this, areCastCompatible); }];
 
   let hasFolder = 1;
 }
@@ -3428,56 +3428,6 @@ def TanhOp : FloatUnaryOp<"tanh"> {
   }];
 }
 
-//===----------------------------------------------------------------------===//
-// TensorCastOp
-//===----------------------------------------------------------------------===//
-
-def TensorCastOp : CastOp<"tensor_cast"> {
-  let summary = "tensor cast operation";
-  let description = [{
-    Syntax:
-
-    ```
-    operation ::= ssa-id `=` `std.tensor_cast` ssa-use `:` type `to` type
-    ```
-
-    Convert a tensor from one type to an equivalent type without changing any
-    data elements. The source and destination types must both be tensor types
-    with the same element type. If both are ranked, then the rank should be the
-    same and static dimensions should match. The operation is invalid if
-    converting to a mismatching constant dimension.
-
-    Example:
-
-    ```mlir
-    // Convert from unknown rank to rank 2 with unknown dimension sizes.
-    %2 = "std.tensor_cast"(%1) : (tensor<*xf32>) -> tensor<?x?xf32>
-    %2 = tensor_cast %1 : tensor<*xf32> to tensor<?x?xf32>
-
-    // Convert to a type with more known dimensions.
-    %3 = "std.tensor_cast"(%2) : (tensor<?x?xf32>) -> tensor<4x?xf32>
-
-    // Discard static dimension and rank information.
-    %4 = "std.tensor_cast"(%3) : (tensor<4x?xf32>) -> tensor<?x?xf32>
-    %5 = "std.tensor_cast"(%4) : (tensor<?x?xf32>) -> tensor<*xf32>
-    ```
-  }];
-
-  let arguments = (ins AnyTensor:$source);
-  let results = (outs AnyTensor);
-
-  let extraClassDeclaration = [{
-    /// Return true if `a` and `b` are valid operand and result pairs for
-    /// the operation.
-    static bool areCastCompatible(Type a, Type b);
-
-    /// The result of a tensor_cast is always a tensor.
-    TensorType getType() { return getResult().getType().cast<TensorType>(); }
-  }];
-
-  let hasCanonicalizer = 1;
-}
-
 //===----------------------------------------------------------------------===//
 // TensorLoadOp
 //===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Tensor/IR/Tensor.h

Lines changed: 34 additions & 0 deletions
@@ -28,4 +28,38 @@
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Tensor/IR/TensorOps.h.inc"
 
+//===----------------------------------------------------------------------===//
+// Tensor Dialect Helpers
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace tensor {
+
+/// Determines whether tensor::CastOp casts to a more dynamic version of the
+/// source tensor. This is useful to fold a tensor.cast into a consuming op and
+/// implement canonicalization patterns for ops in different dialects that may
+/// consume the results of tensor.cast operations. Such foldable tensor.cast
+/// operations are typically inserted as `subtensor` ops and are canonicalized,
+/// to preserve the type compatibility of their uses.
+///
+/// Returns true when all conditions are met:
+///   1. source and result are ranked tensors with same element type and rank.
+///   2. the tensor type has more static information than the result
+///
+/// Example:
+/// ```mlir
+///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
+///   %2 = consumer %1 ... : tensor<?x?xf32> ...
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %2 = consumer %0 ... : tensor<8x16xf32> ...
+/// ```
+bool canFoldIntoConsumerOp(CastOp castOp);
+
+} // namespace tensor
+} // namespace mlir
+
 #endif // MLIR_DIALECT_TENSOR_IR_TENSOR_H_

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 46 additions & 0 deletions
@@ -19,6 +19,52 @@ class Tensor_Op<string mnemonic, list<OpTrait> traits = []>
   let parser = [{ return ::parse$cppClass(parser, result); }];
 }
 
+//===----------------------------------------------------------------------===//
+// CastOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_CastOp : Tensor_Op<"cast", [NoSideEffect]> {
+  let summary = "tensor cast operation";
+  let description = [{
+    Convert a tensor from one type to an equivalent type without changing any
+    data elements. The source and destination types must both be tensor types
+    with the same element type. If both are ranked, then the rank should be the
+    same and static dimensions should match. The operation is invalid if
+    converting to a mismatching constant dimension.
+
+    Example:
+
+    ```mlir
+    // Convert from unknown rank to rank 2 with unknown dimension sizes.
+    %2 = tensor.cast %1 : tensor<*xf32> to tensor<?x?xf32>
+
+    // Convert to a type with more known dimensions.
+    %3 = tensor.cast %2 : tensor<?x?xf32> to tensor<4x?xf32>
+
+    // Discard static dimension and rank information.
+    %4 = tensor.cast %3 : tensor<4x?xf32> to tensor<?x?xf32>
+    %5 = tensor.cast %4 : tensor<?x?xf32> to tensor<*xf32>
+    ```
+  }];
+
+  let arguments = (ins AnyTensor:$source);
+  let results = (outs AnyTensor:$dest);
+  let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
+  let verifier = "return impl::verifyCastOp(*this, areCastCompatible);";
+
+  let extraClassDeclaration = [{
+    /// Return true if `a` and `b` are valid operand and result pairs for
+    /// the operation.
+    static bool areCastCompatible(Type a, Type b);
+
+    /// The result of a tensor.cast is always a tensor.
+    TensorType getType() { return getResult().getType().cast<TensorType>(); }
+  }];
+
+  let hasFolder = 1;
+  let hasCanonicalizer = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // ExtractOp
 //===----------------------------------------------------------------------===//
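
The ODS above declares `areCastCompatible`, but its C++ body lives elsewhere in the commit and is not part of this excerpt. As a rough sketch, a compatibility check following the documented rules (same element type, matching static dimensions, unranked casts allowed) could look like the following; the function name is hypothetical and this is not necessarily the commit's exact implementation:

```cpp
#include "mlir/IR/BuiltinTypes.h" // assumed header for TensorType et al.
#include "llvm/ADT/STLExtras.h"   // for llvm::zip

using namespace mlir;

// Sketch only: returns true if casting between tensor types `a` and `b`
// obeys the rules in the tensor.cast description above.
static bool areCastCompatibleSketch(Type a, Type b) {
  auto aT = a.dyn_cast<TensorType>();
  auto bT = b.dyn_cast<TensorType>();
  // Both types must be tensors with the same element type.
  if (!aT || !bT || aT.getElementType() != bT.getElementType())
    return false;
  auto aRanked = aT.dyn_cast<RankedTensorType>();
  auto bRanked = bT.dyn_cast<RankedTensorType>();
  // Casting to or from an unranked tensor is always allowed.
  if (!aRanked || !bRanked)
    return true;
  // Ranked-to-ranked: ranks must agree...
  if (aRanked.getRank() != bRanked.getRank())
    return false;
  // ...and static dimensions must match; a dynamic (?) dimension is
  // compatible with anything.
  for (auto dims : llvm::zip(aRanked.getShape(), bRanked.getShape())) {
    int64_t aDim = std::get<0>(dims), bDim = std::get<1>(dims);
    if (aDim != ShapedType::kDynamicSize &&
        bDim != ShapedType::kDynamicSize && aDim != bDim)
      return false;
  }
  return true;
}
```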

mlir/include/mlir/IR/OpDefinition.h

Lines changed: 7 additions & 0 deletions
@@ -1775,11 +1775,18 @@ void printOneResultOp(Operation *op, OpAsmPrinter &p);
 // These functions are out-of-line implementations of the methods in CastOp,
 // which avoids them being template instantiated/duplicated.
 namespace impl {
+// TODO: Remove the parse/print/build here (new ODS functionality obsoletes the
+// need for them, but some older ODS code in `std` still depends on them).
 void buildCastOp(OpBuilder &builder, OperationState &result, Value source,
                  Type destType);
 ParseResult parseCastOp(OpAsmParser &parser, OperationState &result);
 void printCastOp(Operation *op, OpAsmPrinter &p);
+// TODO: Create a CastOpInterface with a method areCastCompatible.
+// Also, consider adding functionality to CastOpInterface to be able to perform
+// the ChainedTensorCast canonicalization generically.
 Value foldCastOp(Operation *op);
+LogicalResult verifyCastOp(Operation *op,
+                           function_ref<bool(Type, Type)> areCastCompatible);
 } // namespace impl
 } // end namespace mlir
 
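
The new `verifyCastOp` declaration takes the op plus an `areCastCompatible` callback, so each cast-like op supplies its own type rule while sharing one verifier. Its definition is elsewhere in the commit; as a minimal sketch of what such a shared verifier plausibly does for a single-operand, single-result cast op (an illustration, not the verbatim implementation):

```cpp
#include "mlir/IR/Operation.h"

using namespace mlir;

// Sketch only: shared verifier for cast-like ops with one operand and one
// result; the op-specific compatibility rule is passed in as a callback.
static LogicalResult
verifyCastOpSketch(Operation *op,
                   function_ref<bool(Type, Type)> areCastCompatible) {
  Type opType = op->getOperand(0).getType();
  Type resType = op->getResult(0).getType();
  if (!areCastCompatible(opType, resType))
    return op->emitError("operand type ")
           << opType << " and result type " << resType
           << " are cast incompatible";
  return success();
}
```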

mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-elementwise-to-linalg -std-bufferize -tensor-constant-bufferize -linalg-bufferize -func-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-opt %s -convert-elementwise-to-linalg -std-bufferize -tensor-constant-bufferize -linalg-bufferize -tensor-bufferize -func-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
 // RUN: | FileCheck %s
@@ -8,7 +8,7 @@ func @main() {
   %b = constant dense<[10.0, 20.0, 30.0]> : tensor<3xf32>
 
   %addf = addf %a, %b : tensor<3xf32>
-  %addf_unranked = tensor_cast %addf : tensor<3xf32> to tensor<*xf32>
+  %addf_unranked = tensor.cast %addf : tensor<3xf32> to tensor<*xf32>
   call @print_memref_f32(%addf_unranked) : (tensor<*xf32>) -> ()
   // CHECK: Unranked Memref base@ = {{.*}} rank = 1 offset = 0 sizes = [3] strides = [1] data =
   // CHECK-NEXT: [11, 22, 33]

mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert-multiple-uses.mlir

Lines changed: 5 additions & 3 deletions
@@ -1,4 +1,6 @@
-// RUN: mlir-opt %s -linalg-bufferize -std-bufferize -tensor-constant-bufferize -func-bufferize \
+// RUN: mlir-opt %s -linalg-bufferize -std-bufferize \
+// RUN: -tensor-constant-bufferize -tensor-bufferize -func-bufferize \
+// RUN: -finalizing-bufferize \
 // RUN: -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
@@ -15,14 +17,14 @@ func @main() {
   %inserted_at_position_0 = subtensor_insert %insert_val into %const[0][1][1] : tensor<1xf32> into tensor<2xf32>
   %inserted_at_position_1 = subtensor_insert %insert_val into %const[1][1][1] : tensor<1xf32> into tensor<2xf32>
 
-  %unranked_at_position_0 = tensor_cast %inserted_at_position_0 : tensor<2xf32> to tensor<*xf32>
+  %unranked_at_position_0 = tensor.cast %inserted_at_position_0 : tensor<2xf32> to tensor<*xf32>
   call @print_memref_f32(%unranked_at_position_0) : (tensor<*xf32>) -> ()
 
   // CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}
   // CHECK-SAME: rank = 1 offset = 0 sizes = [2] strides = [1] data =
   // CHECK-NEXT: [20, 10]
 
-  %unranked_at_position_1 = tensor_cast %inserted_at_position_1 : tensor<2xf32> to tensor<*xf32>
+  %unranked_at_position_1 = tensor.cast %inserted_at_position_1 : tensor<2xf32> to tensor<*xf32>
   call @print_memref_f32(%unranked_at_position_1) : (tensor<*xf32>) -> ()
 
   // CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}

mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert.mlir

Lines changed: 4 additions & 2 deletions
@@ -1,4 +1,6 @@
-// RUN: mlir-opt %s -linalg-bufferize -std-bufferize -tensor-constant-bufferize -func-bufferize \
+// RUN: mlir-opt %s -linalg-bufferize -std-bufferize \
+// RUN: -tensor-constant-bufferize -tensor-bufferize -func-bufferize \
+// RUN: -finalizing-bufferize \
 // RUN: -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
@@ -9,7 +11,7 @@ func @main() {
   %insert_val = constant dense<20.0> : tensor<1xf32>
   %inserted = subtensor_insert %insert_val into %const[0][1][1] : tensor<1xf32> into tensor<2xf32>
 
-  %unranked = tensor_cast %inserted : tensor<2xf32> to tensor<*xf32>
+  %unranked = tensor.cast %inserted : tensor<2xf32> to tensor<*xf32>
   call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()
 
   // CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}

mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s -tensor-constant-bufferize -std-bufferize -linalg-bufferize \
-// RUN: -func-bufferize -finalizing-bufferize -convert-linalg-to-loops \
+// RUN: -tensor-bufferize -func-bufferize -finalizing-bufferize -convert-linalg-to-loops \
 // RUN: -convert-linalg-to-llvm -convert-std-to-llvm | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
@@ -19,7 +19,7 @@ func @main() {
   // Note that this is skipping a step and we would need at least some function
   // attribute to declare that this conversion is valid (e.g. when we statically
   // know that things will play nicely at the C ABI boundary).
-  %unranked = tensor_cast %0 : tensor<4xf32> to tensor<*xf32>
+  %unranked = tensor.cast %0 : tensor<4xf32> to tensor<*xf32>
   call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()
 
   // CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}
