205 changes: 204 additions & 1 deletion mlir/test/Dialect/Shape/canonicalize.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt -split-input-file -canonicalize <%s | FileCheck %s --dump-input=fail
// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -canonicalize <%s | FileCheck %s --dump-input=fail

// -----
// CHECK-LABEL: func @f
Expand Down Expand Up @@ -212,3 +212,206 @@ func @not_const(%arg0: !shape.shape) -> !shape.size {
%0 = shape.get_extent %arg0, 3
return %0 : !shape.size
}


// -----
// cstr_eq with non-constant but known equal shapes can be removed.
// CHECK-LABEL: func @f
func @f(%arg0 : !shape.shape) {
// CHECK-NEXT: shape.const_witness true
// CHECK-NEXT: consume.witness
// CHECK-NEXT: return
%0 = shape.cstr_eq %arg0, %arg0, %arg0
"consume.witness"(%0) : (!shape.witness) -> ()
return
}

// -----
// cstr_eq with equal const_shapes can be folded
// CHECK-LABEL: func @f
func @f() {
// CHECK-NEXT: shape.const_witness true
// CHECK-NEXT: consume.witness
// CHECK-NEXT: return
%cs0 = shape.const_shape [0, 1]
%cs1 = shape.const_shape [0, 1]
%cs2 = shape.const_shape [0, 1]
%0 = shape.cstr_eq %cs0, %cs1, %cs2
"consume.witness"(%0) : (!shape.witness) -> ()
return
}

// -----
// cstr_eq with unequal const_shapes cannot be folded
// CHECK-LABEL: func @f
func @f() {
// CHECK-NEXT: shape.const_shape
// CHECK-NEXT: shape.const_shape
// CHECK-NEXT: shape.cstr_eq
// CHECK-NEXT: consume.witness
// CHECK-NEXT: return
%cs0 = shape.const_shape [0, 1]
%cs1 = shape.const_shape [3, 1]
%0 = shape.cstr_eq %cs0, %cs1
"consume.witness"(%0) : (!shape.witness) -> ()
return
}

// -----
// cstr_eq without const_shapes cannot be folded
// CHECK-LABEL: func @f
func @f(%arg0: !shape.shape, %arg1: !shape.shape) {
// CHECK-NEXT: shape.cstr_eq
// CHECK-NEXT: consume.witness
// CHECK-NEXT: return
%0 = shape.cstr_eq %arg0, %arg1
"consume.witness"(%0) : (!shape.witness) -> ()
return
}

// -----
// assuming_all with known passing witnesses can be folded
// CHECK-LABEL: func @f
func @f() {
// CHECK-NEXT: shape.const_witness true
// CHECK-NEXT: consume.witness
// CHECK-NEXT: return
%0 = shape.const_witness true
%1 = shape.const_witness true
%2 = shape.const_witness true
%3 = shape.assuming_all %0, %1, %2
"consume.witness"(%3) : (!shape.witness) -> ()
return
}

// -----
// assuming_all should not be removed if not all witnesses are statically passing.
//
// Additionally check that the attribute is moved to the end as this op is
// commutative.
// CHECK-LABEL: func @f
func @f() {
// CHECK-NEXT: %[[UNKNOWN:.*]] = "test.source"
// CHECK-NEXT: shape.assuming_all %[[UNKNOWN]]
// CHECK-NEXT: consume.witness
// CHECK-NEXT: return
%0 = shape.const_witness true
%1 = "test.source"() : () -> !shape.witness
%2 = shape.assuming_all %0, %1
"consume.witness"(%2) : (!shape.witness) -> ()
return
}

// -----
// any can be replaced with a constant input if it has one.
// CHECK-LABEL: func @f
func @f(%arg0 : !shape.shape) -> !shape.shape {
// CHECK-NEXT: %[[CS:.*]] = shape.const_shape
// CHECK-NEXT: return %[[CS]]
%0 = shape.const_shape [2, 3, 4]
%1 = shape.any %0, %arg0
return %1 : !shape.shape
}


// -----
// Folding of any with partially constant operands is not yet implemented.
// CHECK-LABEL: func @f
func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> !shape.shape {
// CHECK-NEXT: shape.any
// CHECK-NEXT: return %[[CS]]
%1 = shape.any %arg0, %arg1
return %1 : !shape.shape
}

// -----
// assuming with a known passing witness can be removed
// CHECK-LABEL: func @f
func @f() {
// CHECK-NEXT: source
// CHECK-NEXT: sink
// CHECK-NEXT: return
%0 = shape.const_witness true
%1 = shape.assuming %0 -> index {
%2 = "test.source"() : () -> (index)
shape.assuming_yield %2 : index
}
"test.sink"(%1) : (index) -> ()
return
}

// -----
// assuming without a known passing passing witness cannot be removed
// CHECK-LABEL: func @f
func @f() {
// CHECK-NEXT: test.source
// CHECK-NEXT: shape.assuming
// CHECK-NEXT: test.source
// CHECK-NEXT: shape.assuming_yield
// CHECK-NEXT: }
// CHECK-NEXT: test.sink
// CHECK-NEXT: return
%0 = "test.source"() : () -> (!shape.witness)
%1 = shape.assuming %0 -> index {
%2 = "test.source"() : () -> (index)
shape.assuming_yield %2 : index
}
"test.sink"(%1) : (index) -> ()
return
}

// -----
// Broadcastable with broadcastable constant shapes can be removed.
// CHECK-LABEL: func @f
func @f() {
// CHECK-NEXT: shape.const_witness true
// CHECK-NEXT: consume.witness
// CHECK-NEXT: return
%cs0 = shape.const_shape [3, 1]
%cs1 = shape.const_shape [1, 5]
%0 = shape.cstr_broadcastable %cs0, %cs1
"consume.witness"(%0) : (!shape.witness) -> ()
return
}

// -----
// Broadcastable with non-broadcastable constant shapes is always false
// CHECK-LABEL: func @f
func @f() {
// CHECK-NEXT: shape.const_shape
// CHECK-NEXT: shape.const_shape
// CHECK-NEXT: shape.cstr_broadcastable
// CHECK-NEXT: consume.witness
// CHECK-NEXT: return
%cs0 = shape.const_shape [1, 3]
%cs1 = shape.const_shape [1, 5]
%0 = shape.cstr_broadcastable %cs0, %cs1
"consume.witness"(%0) : (!shape.witness) -> ()
return
}

// -----
// Broadcastable without guaranteed broadcastable shapes cannot be removed.
// CHECK-LABEL: func @f
func @f(%arg0 : !shape.shape) {
// CHECK-NEXT: shape.const_shape
// CHECK-NEXT: shape.cstr_broadcastable
// CHECK-NEXT: consume.witness
// CHECK-NEXT: return
%cs0 = shape.const_shape [1,3]
%0 = shape.cstr_broadcastable %arg0, %cs0
"consume.witness"(%0) : (!shape.witness) -> ()
return
}

// -----
// Broadcastable with non-constant but known equal shapes can be removed.
// CHECK-LABEL: func @f
func @f(%arg0 : !shape.shape) {
// CHECK-NEXT: shape.const_witness true
// CHECK-NEXT: consume.witness
// CHECK-NEXT: return
%0 = shape.cstr_broadcastable %arg0, %arg0
"consume.witness"(%0) : (!shape.witness) -> ()
return
}
6 changes: 4 additions & 2 deletions mlir/test/Dialect/Shape/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,10 @@ func @test_constraints() {
%1 = shape.const_shape [1, 2, 3]
%w0 = shape.cstr_broadcastable %0, %1
%w1 = shape.cstr_eq %0, %1
%w3 = shape.assuming_all %w0, %w1
shape.assuming %w3 -> !shape.shape {
%w2 = shape.const_witness true
%w3 = shape.const_witness false
%w4 = shape.assuming_all %w0, %w1, %w2, %w3
shape.assuming %w4 -> !shape.shape {
%2 = shape.any %0, %1
shape.assuming_yield %2 : !shape.shape
}
Expand Down