60 changes: 20 additions & 40 deletions mlir/lib/Interfaces/ViewLikeInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,54 +17,43 @@ using namespace mlir;
/// Include the definitions of the loop-like interfaces.
#include "mlir/Interfaces/ViewLikeInterface.cpp.inc"

static LogicalResult verifyOpWithOffsetSizesAndStridesPart(
OffsetSizeAndStrideOpInterface op, StringRef name,
unsigned expectedNumElements, StringRef attrName, ArrayAttr attr,
llvm::function_ref<bool(int64_t)> isDynamic, ValueRange values) {
LogicalResult mlir::verifyListOfOperandsOrIntegers(
Operation *op, StringRef name, unsigned expectedNumElements, ArrayAttr attr,
ValueRange values, llvm::function_ref<bool(int64_t)> isDynamic) {
/// Check static and dynamic offsets/sizes/strides breakdown.
if (attr.size() != expectedNumElements)
return op.emitError("expected ")
return op->emitError("expected ")
<< expectedNumElements << " " << name << " values";
unsigned expectedNumDynamicEntries =
llvm::count_if(attr.getValue(), [&](Attribute attr) {
return isDynamic(attr.cast<IntegerAttr>().getInt());
});
if (values.size() != expectedNumDynamicEntries)
return op.emitError("expected ")
return op->emitError("expected ")
<< expectedNumDynamicEntries << " dynamic " << name << " values";
return success();
}

LogicalResult mlir::verify(OffsetSizeAndStrideOpInterface op) {
std::array<unsigned, 3> ranks = op.getArrayAttrRanks();
if (failed(verifyOpWithOffsetSizesAndStridesPart(
op, "offset", ranks[0],
OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(),
op.static_offsets(), ShapedType::isDynamicStrideOrOffset,
op.offsets())))
if (failed(verifyListOfOperandsOrIntegers(
op, "offset", ranks[0], op.static_offsets(), op.offsets(),
ShapedType::isDynamicStrideOrOffset)))
return failure();
if (failed(verifyOpWithOffsetSizesAndStridesPart(
op, "size", ranks[1],
OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(),
op.static_sizes(), ShapedType::isDynamic, op.sizes())))
if (failed(verifyListOfOperandsOrIntegers(op, "size", ranks[1],
op.static_sizes(), op.sizes(),
ShapedType::isDynamic)))
return failure();
if (failed(verifyOpWithOffsetSizesAndStridesPart(
op, "stride", ranks[2],
OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(),
op.static_strides(), ShapedType::isDynamicStrideOrOffset,
op.strides())))
if (failed(verifyListOfOperandsOrIntegers(
op, "stride", ranks[2], op.static_strides(), op.strides(),
ShapedType::isDynamicStrideOrOffset)))
return failure();
return success();
}

/// Print a list with either (1) the static integer value in `arrayAttr` if
/// `isDynamic` evaluates to false or (2) the next value otherwise.
/// This allows idiomatic printing of mixed value and integer attributes in a
/// list. E.g. `[%arg0, 7, 42, %arg42]`.
static void
printListOfOperandsOrIntegers(OpAsmPrinter &p, ValueRange values,
ArrayAttr arrayAttr,
llvm::function_ref<bool(int64_t)> isDynamic) {
void mlir::printListOfOperandsOrIntegers(
OpAsmPrinter &p, ValueRange values, ArrayAttr arrayAttr,
llvm::function_ref<bool(int64_t)> isDynamic) {
p << '[';
unsigned idx = 0;
llvm::interleaveComma(arrayAttr, p, [&](Attribute a) {
Expand Down Expand Up @@ -95,18 +84,9 @@ void mlir::printOffsetsSizesAndStrides(OpAsmPrinter &p,
p.printOptionalAttrDict(op.getAttrs(), elidedAttrs);
}

/// Parse a mixed list with either (1) static integer values or (2) SSA values.
/// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal`
/// encode the position of SSA values. Add the parsed SSA values to `ssa`
/// in-order.
//
/// E.g. after parsing "[%arg0, 7, 42, %arg42]":
/// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
/// 2. `ssa` is filled with "[%arg0, %arg1]".
static ParseResult
parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result,
StringRef attrName, int64_t dynVal,
SmallVectorImpl<OpAsmParser::OperandType> &ssa) {
ParseResult mlir::parseListOfOperandsOrIntegers(
OpAsmParser &parser, OperationState &result, StringRef attrName,
int64_t dynVal, SmallVectorImpl<OpAsmParser::OperandType> &ssa) {
if (failed(parser.parseLSquare()))
return failure();
// 0-D.
Expand Down
39 changes: 39 additions & 0 deletions mlir/test/Dialect/Linalg/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -351,3 +351,42 @@ func @linalg_effects(%a : tensor<?x?xf32>, %b : memref<?x?xf32>, %c : tensor<?x?
outs(%b : memref<?x?xf32>)
return
}
// -----

func @init_tensor_canonicalize() -> (tensor<4x5x?xf32>) {
%c6 = constant 6 : index
%0 = linalg.init_tensor [4, 5, %c6] : tensor<4x5x?xf32>
return %0 : tensor<4x5x?xf32>
}
// CHECK: func @init_tensor_canonicalize
// CHECK: %[[T0:.+]] = linalg.init_tensor [4, 5, 6] : tensor<4x5x6xf32>
// CHECK: %[[T1:.+]] = tensor_cast %[[T0]] : tensor<4x5x6xf32> to tensor<4x5x?xf32>
// CHECK: return %[[T1]]

// -----

func @init_tensor_static_dim() -> (index, index) {
%c0 = constant 0 : index
%c2 = constant 2 : index
%c6 = constant 6 : index
%0 = linalg.init_tensor [4, 5, %c6] : tensor<4x5x?xf32>
%1 = dim %0, %c2 : tensor<4x5x?xf32>
%2 = dim %0, %c0 : tensor<4x5x?xf32>
return %1, %2 : index, index
}
// CHECK: func @init_tensor_static_dim
// CHECK-DAG: %[[C4:.+]] = constant 4 : index
// CHECK-DAG: %[[C6:.+]] = constant 6 : index
// CHECK: return %[[C6]], %[[C4]]

// -----

func @init_tensor_dynamic_dim(%arg0 : index) -> (index) {
%c2 = constant 2 : index
%0 = linalg.init_tensor [4, 5, %arg0] : tensor<4x5x?xf32>
%1 = dim %0, %c2 : tensor<4x5x?xf32>
return %1 : index
}
// CHECK: func @init_tensor_dynamic_dim
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
// CHECK: return %[[ARG0]]
42 changes: 40 additions & 2 deletions mlir/test/Dialect/Linalg/roundtrip.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
// RUN: mlir-opt -split-input-file %s | FileCheck %s
// | mlir-opt | FileCheck %s
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s

// TODO: Re-enable LLVM lowering test after IndexedGenericOp is lowered.
//
Expand Down Expand Up @@ -698,3 +697,42 @@ func @memref_reshape_zero_dim(%arg0 : memref<1x1xf32>, %arg1 : memref<f32>) -> (
// CHECK-LABEL: func @memref_reshape_zero_dim
// CHECK: linalg.reshape %{{.*}} [] : memref<1x1xf32> into memref<f32>
// CHECK: linalg.reshape %{{.*}} [] : memref<f32> into memref<1x1xf32>

// -----

func @init_tensor(%arg0 : index, %arg1 : index)
{
%0 = linalg.init_tensor [3, 42] : tensor<3x42xf32>
%1 = linalg.init_tensor [4, %arg0, %arg1, 5] : tensor<4x?x?x5xf32>
return
}
// CHECK-LABEL: func @init_tensor
// CHECK: linalg.init_tensor [3, 42] : tensor<3x42xf32>
// CHECK: linalg.init_tensor [4, %{{.*}}, %{{.*}}, 5] : tensor<4x?x?x5xf32>

// -----

func @init_tensor_err(%arg0 : index, %arg1 : index)
{
// expected-error @+1 {{specified type 'tensor<4x?x?x5xf32>' does not match the inferred type 'tensor<4x5x?x?xf32>'}}
%1 = linalg.init_tensor [4, 5, %arg0, %arg1] : tensor<4x?x?x5xf32>
return
}

// -----

func @init_tensor_err(%arg0 : index)
{
// expected-error @+1 {{expected 4 sizes values}}
%1 = linalg.init_tensor [4, 5, %arg0] : tensor<4x?x?x5xf32>
return
}

// -----

func @init_tensor_err(%arg0 : index)
{
// expected-error @+1 {{expected 2 dynamic sizes values}}
%1 = "linalg.init_tensor"(%arg0) {static_sizes = [4, -1, -1, 5]} : (index) -> tensor<4x?x?x5xf32>
return
}
16 changes: 16 additions & 0 deletions mlir/test/Dialect/Standard/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,19 @@ func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
%1 = dim %0, %c3 : memref<*xf32>
return %1 : index
}

// Test case: Folding dim(tensor_cast %0, %idx) -> dim %0, %idx
// CHECK-LABEL: func @fold_dim_of_tensor_cast
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x?xf32>
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[C4:.+]] = constant 4 : index
// CHECK: %[[T0:.+]] = dim %[[ARG0]], %[[C1]]
// CHECK-NEXT: return %[[C4]], %[[T0]]
func @fold_dim_of_tensor_cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%0 = tensor_cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
%1 = dim %0, %c0 : tensor<?x?xf32>
%2 = dim %0, %c1 : tensor<?x?xf32>
return %1, %2: index, index
}