diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index c6b7e28fe0aa8b..e529a50dae935b 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -998,6 +998,31 @@ def MemRef_ReinterpretCastOp: }]; } +//===----------------------------------------------------------------------===// +// RankOp +//===----------------------------------------------------------------------===// + +def MemRef_RankOp : MemRef_Op<"rank", [NoSideEffect]> { + let summary = "rank operation"; + let description = [{ + The `memref.rank` operation takes a memref operand and returns its rank. + + Example: + + ```mlir + %0 = memref.rank %arg0 : memref<*xf32> + %1 = memref.rank %arg1 : memref + ``` + }]; + + let arguments = (ins AnyRankedOrUnrankedMemRef:$memref); + let results = (outs Index); + + let verifier = ?; + let hasFolder = 1; + let assemblyFormat = "$memref attr-dict `:` type($memref)"; +} + //===----------------------------------------------------------------------===// // ReshapeOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td index 23b9df282af03e..2e50971db9e7a2 100644 --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -658,32 +658,6 @@ def ConstantOp : Std_Op<"constant", let hasFolder = 1; } -//===----------------------------------------------------------------------===// -// RankOp -//===----------------------------------------------------------------------===// - -def RankOp : Std_Op<"rank", [NoSideEffect]> { - let summary = "rank operation"; - let description = [{ - The `rank` operation takes a memref/tensor operand and returns its rank. - - Example: - - ```mlir - %1 = rank %arg0 : tensor<*xf32> - %2 = rank %arg1 : memref<*xf32> - ``` - }]; - - let arguments = (ins AnyTypeOf<[AnyRankedOrUnrankedMemRef, AnyTensor], - "any memref or tensor type">:$memrefOrTensor); - let results = (outs Index); - let verifier = ?; - - let hasFolder = 1; - let assemblyFormat = "$memrefOrTensor attr-dict `:` type($memrefOrTensor)"; -} - //===----------------------------------------------------------------------===// // ReturnOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index 3b1bfeeca6c107..21331fc649cd56 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -68,9 +68,9 @@ def Tensor_CastOp : Tensor_Op<"cast", [ def Tensor_DimOp : Tensor_Op<"dim", [NoSideEffect]> { let summary = "dimension index operation"; let description = [{ - The `dim` operation takes a tensor and a dimension operand of type `index`. - It returns the size of the requested dimension of the given tensor. - If the dimension index is out of bounds, the behavior is undefined. + The `tensor.dim` operation takes a tensor and a dimension operand of type + `index`. It returns the size of the requested dimension of the given + tensor. If the dimension index is out of bounds, the behavior is undefined. The specified tensor type is that of the first operand. @@ -558,6 +558,31 @@ def Tensor_InsertSliceOp : BaseOpWithOffsetSizesAndStrides< let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// RankOp +//===----------------------------------------------------------------------===// + +def Tensor_RankOp : Tensor_Op<"rank", [NoSideEffect]> { + let summary = "rank operation"; + let description = [{ + The `tensor.rank` operation takes a tensor operand and returns its rank. + + Example: + + ```mlir + %0 = tensor.rank %arg0 : tensor<*xf32> + %1 = tensor.rank %arg1 : tensor + ``` + }]; + + let arguments = (ins AnyTensor:$tensor); + let results = (outs Index); + + let verifier = ?; + let hasFolder = 1; + let assemblyFormat = "$tensor attr-dict `:` type($tensor)"; +} + //===----------------------------------------------------------------------===// // ReshapeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 521b3fcab0c6fe..28981dd87ecc9c 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -596,6 +596,28 @@ struct PrefetchOpLowering : public LoadStoreOpLowering { } }; +struct RankOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(memref::RankOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Type operandType = op.memref().getType(); + if (auto unrankedMemRefType = operandType.dyn_cast()) { + UnrankedMemRefDescriptor desc(adaptor.memref()); + rewriter.replaceOp(op, {desc.rank(rewriter, loc)}); + return success(); + } + if (auto rankedMemRefType = operandType.dyn_cast()) { + rewriter.replaceOp( + op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())}); + return success(); + } + return failure(); + } +}; + struct MemRefCastOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -1549,6 +1571,7 @@ void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter, MemRefReinterpretCastOpLowering, MemRefReshapeOpLowering, PrefetchOpLowering, + RankOpLowering, ReassociatingReshapeOpConversion, ReassociatingReshapeOpConversion, StoreOpLowering, diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp index e1e24faa4d2a69..5a1af7b33132ea 100644 --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -577,7 +577,7 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite( // Lower to `tensor.generate` otherwise. auto *ctx = rewriter.getContext(); - Value rank = rewriter.create(loc, tensor); + Value rank = rewriter.create(loc, tensor); rewriter.replaceOpWithNewOp( op, getExtentTensorType(ctx), ValueRange{rank}, [&](OpBuilder &b, Location loc, ValueRange args) { diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp index 200834a2d1bc6c..f588521ac6ef06 100644 --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -566,28 +566,6 @@ struct UnrealizedConversionCastOpLowering } }; -struct RankOpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(RankOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Type operandType = op.getMemrefOrTensor().getType(); - if (auto unrankedMemRefType = operandType.dyn_cast()) { - UnrankedMemRefDescriptor desc(adaptor.getMemrefOrTensor()); - rewriter.replaceOp(op, {desc.rank(rewriter, loc)}); - return success(); - } - if (auto rankedMemRefType = operandType.dyn_cast()) { - rewriter.replaceOp( - op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())}); - return success(); - } - return failure(); - } -}; - // Common base for load and store operations on MemRefs. Restricts the match // to supported MemRef types. Provides functionality to emit code accessing a // specific element of the underlying data buffer. @@ -987,7 +965,6 @@ void mlir::populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter, CondBranchOpLowering, ConstantOpLowering, GenericAtomicRMWOpLowering, - RankOpLowering, ReturnOpLowering, SelectOpLowering, SplatOpLowering, diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 4badc0b31ddb60..1916ffe36dd661 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1072,6 +1072,19 @@ LogicalResult PrefetchOp::fold(ArrayRef cstOperands, return foldMemRefCast(*this); } +//===----------------------------------------------------------------------===// +// RankOp +//===----------------------------------------------------------------------===// + +OpFoldResult RankOp::fold(ArrayRef operands) { + // Constant fold rank when the rank of the operand is known. + auto type = getOperand().getType(); + auto shapedType = type.dyn_cast(); + if (shapedType && shapedType.hasRank()) + return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank()); + return IntegerAttr(); +} + //===----------------------------------------------------------------------===// // ReinterpretCastOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp index 1a43c0937d038c..1d045b29121541 100644 --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -899,20 +899,6 @@ bool ConstantOp::isBuildableWith(Attribute value, Type type) { return value.isa(); } -//===----------------------------------------------------------------------===// -// RankOp -//===----------------------------------------------------------------------===// - -OpFoldResult RankOp::fold(ArrayRef operands) { - // Constant fold rank when the rank of the operand is known. - auto type = getOperand().getType(); - if (auto shapedType = type.dyn_cast()) - if (shapedType.hasRank()) - return IntegerAttr::get(IndexType::get(getContext()), - shapedType.getRank()); - return IntegerAttr(); -} - //===----------------------------------------------------------------------===// // ReturnOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index edddfb86e5539a..ecdd966a3c35e9 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -609,6 +609,19 @@ void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results, StaticTensorGenerate>(context); } +//===----------------------------------------------------------------------===// +// RankOp +//===----------------------------------------------------------------------===// + +OpFoldResult RankOp::fold(ArrayRef operands) { + // Constant fold rank when the rank of the operand is known. + auto type = getOperand().getType(); + auto shapedType = type.dyn_cast(); + if (shapedType && shapedType.hasRank()) + return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank()); + return IntegerAttr(); +} + //===----------------------------------------------------------------------===// // ReshapeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/BufferOptimizations.cpp b/mlir/lib/Transforms/BufferOptimizations.cpp index 64a005dfb55b1b..27e00a14c0d4d5 100644 --- a/mlir/lib/Transforms/BufferOptimizations.cpp +++ b/mlir/lib/Transforms/BufferOptimizations.cpp @@ -37,14 +37,16 @@ static bool defaultIsSmallAlloc(Value alloc, unsigned maximumSizeInBytes, if (!type || !alloc.getDefiningOp()) return false; if (!type.hasStaticShape()) { - // Check if the dynamic shape dimension of the alloc is produced by RankOp. - // If this is the case, it is likely to be small. Furthermore, the dimension - // is limited to the maximum rank of the allocated memref to avoid large - // values by multiplying several small values. + // Check if the dynamic shape dimension of the alloc is produced by + // `memref.rank`. If this is the case, it is likely to be small. + // Furthermore, the dimension is limited to the maximum rank of the + // allocated memref to avoid large values by multiplying several small + // values. if (type.getRank() <= maxRankOfAllocatedMemRef) { - return llvm::all_of( - alloc.getDefiningOp()->getOperands(), - [&](Value operand) { return operand.getDefiningOp(); }); + return llvm::all_of(alloc.getDefiningOp()->getOperands(), + [&](Value operand) { + return operand.getDefiningOp(); + }); } return false; } diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir index a26638a34151f9..009106f95e8a54 100644 --- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir @@ -1,7 +1,6 @@ // RUN: mlir-opt -convert-memref-to-llvm %s -split-input-file | FileCheck %s // RUN: mlir-opt -convert-memref-to-llvm='index-bitwidth=32' %s -split-input-file | FileCheck --check-prefix=CHECK32 %s - // CHECK-LABEL: func @view( // CHECK: %[[ARG0F:.*]]: index, %[[ARG1F:.*]]: index, %[[ARG2F:.*]]: index func @view(%arg0 : index, %arg1 : index, %arg2 : index) { @@ -835,3 +834,24 @@ func @expand_shape_dynamic(%arg0 : memref<1x?xf32>) -> memref<1x2x?xf32> { // CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 + +// ----- + +// CHECK-LABEL: func @rank_of_unranked +// CHECK32-LABEL: func @rank_of_unranked +func @rank_of_unranked(%unranked: memref<*xi32>) { + %rank = memref.rank %unranked : memref<*xi32> + return +} +// CHECK: %[[UNRANKED_DESC:.*]] = builtin.unrealized_conversion_cast +// CHECK-NEXT: llvm.extractvalue %[[UNRANKED_DESC]][0] : !llvm.struct<(i64, ptr)> +// CHECK32: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(i32, ptr)> + +// CHECK-LABEL: func @rank_of_ranked +// CHECK32-LABEL: func @rank_of_ranked +func @rank_of_ranked(%ranked: memref) { + %rank = memref.rank %ranked : memref + return +} +// CHECK: llvm.mlir.constant(1 : index) : i64 +// CHECK32: llvm.mlir.constant(1 : index) : i32 diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir index 015cb2fcaaf435..ea0ef33862ce5f 100644 --- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir +++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir @@ -203,7 +203,7 @@ func @shape_of(%arg : tensor<*xf32>) { // CHECK-LABEL: @shape_of_unranked // CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) func @shape_of_unranked(%arg : tensor<*xf32>) { - // CHECK: %[[RANK:.*]] = rank %[[ARG]] : tensor<*xf32> + // CHECK: %[[RANK:.*]] = tensor.rank %[[ARG]] : tensor<*xf32> // CHECK: %[[SHAPE:.*]] = tensor.generate %[[RANK]] { // CHECK: ^bb0(%[[I:.*]]: index): // CHECK: %[[EXTENT:.*]] = tensor.dim %[[ARG]], %[[I]] : tensor<*xf32> diff --git a/mlir/test/Conversion/StandardToLLVM/rank.mlir b/mlir/test/Conversion/StandardToLLVM/rank.mlir deleted file mode 100644 index 7c0a03aa8df37b..00000000000000 --- a/mlir/test/Conversion/StandardToLLVM/rank.mlir +++ /dev/null @@ -1,23 +0,0 @@ -// RUN: mlir-opt -convert-std-to-llvm %s -split-input-file | FileCheck %s -// RUN: mlir-opt -convert-std-to-llvm='index-bitwidth=32' %s -split-input-file | FileCheck --check-prefix=CHECK32 %s - -// CHECK-LABEL: func @rank_of_unranked -// CHECK32-LABEL: func @rank_of_unranked -func @rank_of_unranked(%unranked: memref<*xi32>) { - %rank = rank %unranked : memref<*xi32> - return -} -// CHECK-NEXT: llvm.mlir.undef -// CHECK-NEXT: llvm.insertvalue -// CHECK-NEXT: llvm.insertvalue -// CHECK-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(i64, ptr)> -// CHECK32: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(i32, ptr)> - -// CHECK-LABEL: func @rank_of_ranked -// CHECK32-LABEL: func @rank_of_ranked -func @rank_of_ranked(%ranked: memref) { - %rank = rank %ranked : memref - return -} -// CHECK: llvm.mlir.constant(1 : index) : i64 -// CHECK32: llvm.mlir.constant(1 : index) : i32 diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir index 251658fac76538..80282c21afab0a 100644 --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -185,10 +185,10 @@ func @dim_of_alloca(%size: index) -> index { // Test case: Folding of memref.dim(memref.alloca(rank(%v)), %idx) -> rank(%v) // CHECK-LABEL: func @dim_of_alloca_with_dynamic_size( // CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32> -// CHECK-NEXT: %[[RANK:.*]] = rank %[[MEM]] : memref<*xf32> +// CHECK-NEXT: %[[RANK:.*]] = memref.rank %[[MEM]] : memref<*xf32> // CHECK-NEXT: return %[[RANK]] : index func @dim_of_alloca_with_dynamic_size(%arg0: memref<*xf32>) -> index { - %0 = rank %arg0 : memref<*xf32> + %0 = memref.rank %arg0 : memref<*xf32> %1 = memref.alloca(%0) : memref %c0 = arith.constant 0 : index %2 = memref.dim %1, %c0 : memref @@ -438,3 +438,15 @@ func @reduced_memref(%arg0: memref<2x5x7x1xf32>, %arg1 :index) // CHECK: %[[RESULT:.+]] = memref.subview // CHECK-SAME: memref<2x5x7x1xf32> to memref<1x4x1xf32, #{{.+}}> // CHECK: return %[[RESULT]] + +// ----- + +// CHECK-LABEL: func @fold_rank_memref +func @fold_rank_memref(%arg0 : memref) -> (index) { + // Fold a rank into a constant + // CHECK-NEXT: [[C2:%.+]] = arith.constant 2 : index + %rank_0 = memref.rank %arg0 : memref + + // CHECK-NEXT: return [[C2]] + return %rank_0 : index +} diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir index 6014687e6e9ddd..55c5a821fb3dd1 100644 --- a/mlir/test/Dialect/MemRef/invalid.mlir +++ b/mlir/test/Dialect/MemRef/invalid.mlir @@ -844,3 +844,11 @@ func @test_alloc_memref_map_rank_mismatch() { %0 = memref.alloc() : memref<1024x64xf32, affine_map<(d0) -> (d0)>, 1> return } + +// ----- + +func @rank(%0: f32) { + // expected-error@+1 {{'memref.rank' op operand #0 must be unranked.memref of any type values or memref of any type values}} + "memref.rank"(%0): (f32)->index + return +} diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir index f716c5de21742a..4ff2f8b5517be1 100644 --- a/mlir/test/Dialect/MemRef/ops.mlir +++ b/mlir/test/Dialect/MemRef/ops.mlir @@ -207,3 +207,14 @@ func @collapse_shape_to_dynamic // CHECK: func @collapse_shape_to_dynamic // CHECK: memref.collapse_shape // CHECK-SAME: [0], [1], [2, 3, 4] + +// ----- + +func @rank(%t : memref<4x4x?xf32>) { + // CHECK: %{{.*}} = memref.rank %{{.*}} : memref<4x4x?xf32> + %0 = "memref.rank"(%t) : (memref<4x4x?xf32>) -> index + + // CHECK: %{{.*}} = memref.rank %{{.*}} : memref<4x4x?xf32> + %1 = memref.rank %t : memref<4x4x?xf32> + return +} diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index fc9abe439b8a2f..ec9601e269939f 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -183,7 +183,7 @@ func @extract_oob_from_tensor.from_elements(%element : index) -> index { // CHECK-LABEL: func @extract_from_tensor.generate // CHECK-SAME: %[[IDX:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32> func @extract_from_tensor.generate(%idx: index, %tensor: tensor<*xf32>) -> index { - %size = rank %tensor : tensor<*xf32> + %size = tensor.rank %tensor : tensor<*xf32> // CHECK-NEXT: %[[RES:.*]] = tensor.dim %[[TENSOR]], %[[IDX]] %0 = tensor.generate %size { ^bb0(%arg0: index): @@ -200,7 +200,7 @@ func @extract_from_tensor.generate(%idx: index, %tensor: tensor<*xf32>) -> index // CHECK-LABEL: func @extract_from_tensor.generate_2d // CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32> func @extract_from_tensor.generate_2d(%idx0: index, %idx1: index, %tensor: tensor<*xf32>) -> index { - %size = rank %tensor : tensor<*xf32> + %size = tensor.rank %tensor : tensor<*xf32> // CHECK-NEXT: %[[DIM0:.*]] = tensor.dim %[[TENSOR]], %[[IDX0]] // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[TENSOR]], %[[IDX1]] // CHECK-NEXT: %[[RES:.*]] = arith.addi %[[DIM0]], %[[DIM1]] @@ -221,7 +221,7 @@ func @extract_from_tensor.generate_2d(%idx0: index, %idx1: index, %tensor: tenso // CHECK-LABEL: func @extract_from_tensor.generate_sideeffects // CHECK-SAME: %[[IDX:.*]]: index func @extract_from_tensor.generate_sideeffects(%idx: index, %tensor: tensor<*xf32>, %mem: memref) -> index { - %size = rank %tensor : tensor<*xf32> + %size = tensor.rank %tensor : tensor<*xf32> // CHECK: %[[DTENSOR:.*]] = tensor.generate %0 = tensor.generate %size { ^bb0(%arg0: index): @@ -900,3 +900,18 @@ func @reshape_splat_constant_float64() -> tensor<2x4x2xf64> { // CHECK: %[[CST:.*]] = arith.constant dense<{{.*}}> : tensor<2x4x2xf64> // CHECK-NOT: tensor.expand_shape // CHECK: return %[[CST]] + +// ----- + +// CHECK-LABEL: func @fold_rank +func @fold_rank() -> (index) { + %const_0 = arith.constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> + : tensor<2x1x4xi32> + + // Fold a ank into a constant + // CHECK-NEXT: [[C3:%.+]] = arith.constant 3 : index + %rank_0 = tensor.rank %const_0 : tensor<2x1x4xi32> + + // CHECK-NEXT: return [[C3]] + return %rank_0 : index +} diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir index 8b40ec80e02d46..564526f16370f2 100644 --- a/mlir/test/Dialect/Tensor/invalid.mlir +++ b/mlir/test/Dialect/Tensor/invalid.mlir @@ -292,3 +292,11 @@ func @illegal_collapsing_reshape_mixed_tensor_2(%arg0 : tensor) : tensor into tensor return %0 : tensor } + +// ----- + +func @rank(%0: f32) { + // expected-error@+1 {{'tensor.rank' op operand #0 must be tensor of any type values}} + "tensor.rank"(%0): (f32)->index + return +} diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir index 63afc1f382b37d..8d50d151842184 100644 --- a/mlir/test/Dialect/Tensor/ops.mlir +++ b/mlir/test/Dialect/Tensor/ops.mlir @@ -160,3 +160,14 @@ func @legal_collapsing_reshape_dynamic_tensor // CHECK: func @legal_collapsing_reshape_dynamic_tensor // CHECK: tensor.collapse_shape // CHECK-SAME: [0], [1], [2, 3, 4] + +// ----- + +func @rank(%t : tensor<4x4x?xf32>) { + // CHECK: %{{.*}} = tensor.rank %{{.*}} : tensor<4x4x?xf32> + %0 = "tensor.rank"(%t) : (tensor<4x4x?xf32>) -> index + + // CHECK: %{{.*}} = tensor.rank %{{.*}} : tensor<4x4x?xf32> + %1 = tensor.rank %t : tensor<4x4x?xf32> + return +} diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir index fe2d7207d3d015..b83f530eeacc68 100644 --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -99,12 +99,6 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index, i64, f16) { // CHECK: %{{.*}} = arith.cmpf oeq, %{{.*}}, %{{.*}}: vector<4xf32> %70 = arith.cmpf oeq, %vcf32, %vcf32 : vector<4 x f32> - // CHECK: %{{.*}} = rank %arg0 : tensor<4x4x?xf32> - %71 = "std.rank"(%t) : (tensor<4x4x?xf32>) -> index - - // CHECK: %{{.*}} = rank %arg0 : tensor<4x4x?xf32> - %72 = rank %t : tensor<4x4x?xf32> - // CHECK: = constant unit %73 = constant unit diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir index 13cfd16daf9a49..49f29f09bf492f 100644 --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -1,13 +1,5 @@ // RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -verify-diagnostics -func @rank(f32) { -^bb(%0: f32): - "std.rank"(%0): (f32)->index // expected-error {{'std.rank' op operand #0 must be any memref or tensor type}} - - return -} - -// ----- func @affine_apply_no_map() { ^bb0: %i = arith.constant 0 : index diff --git a/mlir/test/Transforms/constant-fold.mlir b/mlir/test/Transforms/constant-fold.mlir index 5406a8588ce4b4..2e720eae3439ce 100644 --- a/mlir/test/Transforms/constant-fold.mlir +++ b/mlir/test/Transforms/constant-fold.mlir @@ -754,32 +754,6 @@ func @cmpf_inf() -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, // ----- -// CHECK-LABEL: func @fold_rank -func @fold_rank() -> (index) { - %const_0 = arith.constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> : tensor<2x1x4xi32> - - // Fold a rank into a constant - // CHECK-NEXT: [[C3:%.+]] = arith.constant 3 : index - %rank_0 = rank %const_0 : tensor<2x1x4xi32> - - // CHECK-NEXT: return [[C3]] - return %rank_0 : index -} - -// ----- - -// CHECK-LABEL: func @fold_rank_memref -func @fold_rank_memref(%arg0 : memref) -> (index) { - // Fold a rank into a constant - // CHECK-NEXT: [[C2:%.+]] = arith.constant 2 : index - %rank_0 = rank %arg0 : memref - - // CHECK-NEXT: return [[C2]] - return %rank_0 : index -} - -// ----- - // CHECK-LABEL: func @nested_isolated_region func @nested_isolated_region() { // CHECK-NEXT: func @isolated_op diff --git a/mlir/test/Transforms/promote-buffers-to-stack.mlir b/mlir/test/Transforms/promote-buffers-to-stack.mlir index c78f8a71dbb7bc..2b6cd3185fa11b 100644 --- a/mlir/test/Transforms/promote-buffers-to-stack.mlir +++ b/mlir/test/Transforms/promote-buffers-to-stack.mlir @@ -77,25 +77,25 @@ func @condBranchDynamicType( // ----- // CHECK-LABEL: func @dynamicRanked -func @dynamicRanked(%tensor: tensor<*xf32>) { - %0 = rank %tensor : tensor<*xf32> +func @dynamicRanked(%memref: memref<*xf32>) { + %0 = memref.rank %memref : memref<*xf32> %1 = memref.alloc(%0) : memref return } -// CHECK-NEXT: %[[RANK:.*]] = rank +// CHECK-NEXT: %[[RANK:.*]] = memref.rank %{{.*}} : memref<*xf32> // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca(%[[RANK]]) // ----- // CHECK-LABEL: func @dynamicRanked2D -func @dynamicRanked2D(%tensor: tensor<*xf32>) { - %0 = rank %tensor : tensor<*xf32> +func @dynamicRanked2D(%memref: memref<*xf32>) { + %0 = memref.rank %memref : memref<*xf32> %1 = memref.alloc(%0, %0) : memref return } -// CHECK-NEXT: %[[RANK:.*]] = rank +// CHECK-NEXT: %[[RANK:.*]] = memref.rank %{{.*}} : memref<*xf32> // RANK-NEXT: %[[ALLOC:.*]] = memref.alloca(%[[RANK]], %[[RANK]]) // DEFINDEX-NEXT: %[[ALLOC:.*]] = memref.alloc(%[[RANK]], %[[RANK]])