30 changes: 15 additions & 15 deletions mlir/test/Dialect/Quant/convert-fakequant.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// -----
// Verifies a quint8 single point.
// CHECK-LABEL: fakeQuantArgs_Quint8_0
func @fakeQuantArgs_Quint8_0(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
func.func @fakeQuantArgs_Quint8_0(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
// CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
// CHECK-SAME: -> tensor<8x4x3x!quant.uniform<u8:f32, 1.000000e+00>>
Expand All @@ -18,7 +18,7 @@ func @fakeQuantArgs_Quint8_0(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
// -----
// Verifies a quint8 single point (with narrow_range = true).
// CHECK-LABEL: fakeQuantArgs_Quint8_0_NarrowRange
func @fakeQuantArgs_Quint8_0_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
func.func @fakeQuantArgs_Quint8_0_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
// CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
// CHECK-SAME: -> tensor<8x4x3x!quant.uniform<u8<1:255>:f32, 1.000000e+00:1>>
Expand All @@ -33,7 +33,7 @@ func @fakeQuantArgs_Quint8_0_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
// -----
// Verifies a quint8 asymmetric 0..1 range.
// CHECK-LABEL: fakeQuantArgs_Quint8_0_1
func @fakeQuantArgs_Quint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
func.func @fakeQuantArgs_Quint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
// CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
// CHECK-SAME: -> tensor<8x4x3x!quant.uniform<u8:f32, 0.0039215686274509803>>
Expand All @@ -48,7 +48,7 @@ func @fakeQuantArgs_Quint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
// -----
// Verifies a quint8 asymmetric 0..1 range (with narrow_range = true).
// CHECK-LABEL: fakeQuantArgs_Quint8_NarrowRange
func @fakeQuantArgs_Quint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
func.func @fakeQuantArgs_Quint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
// CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
// CHECK-SAME: -> tensor<8x4x3x!quant.uniform<u8<1:255>:f32, 0.003937007874015748:1>>
Expand All @@ -63,7 +63,7 @@ func @fakeQuantArgs_Quint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
// -----
// Verifies a quint8 symmetric range of -1..127/128.
// CHECK-LABEL: fakeQuantArgs_Quint8_SymmetricRange
func @fakeQuantArgs_Quint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
func.func @fakeQuantArgs_Quint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
// CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
// CHECK-SAME: -> tensor<8x4x3x!quant.uniform<u8:f32, 7.812500e-03:128>>
Expand All @@ -78,7 +78,7 @@ func @fakeQuantArgs_Quint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32
// -----
// Verifies a qint8 single point.
// CHECK-LABEL: fakeQuantArgs_Qint8_0
func @fakeQuantArgs_Qint8_0(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
func.func @fakeQuantArgs_Qint8_0(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
// CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
// CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8:f32, 1.000000e+00:-128>>
Expand All @@ -93,7 +93,7 @@ func @fakeQuantArgs_Qint8_0(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
// -----
// Verifies a qint8 single point (with narrow_range = true).
// CHECK-LABEL: fakeQuantArgs_Qint8_0_NarrowRange
func @fakeQuantArgs_Qint8_0_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
func.func @fakeQuantArgs_Qint8_0_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
// CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
// CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8<-127:127>:f32, 1.000000e+00:-127>>
Expand All @@ -108,7 +108,7 @@ func @fakeQuantArgs_Qint8_0_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
// -----
// Verifies a qint8 asymmetric 0..1 range.
// CHECK-LABEL: fakeQuantArgs_Qint8_0_1
func @fakeQuantArgs_Qint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
func.func @fakeQuantArgs_Qint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
// CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
// CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:-128>>
Expand All @@ -123,7 +123,7 @@ func @fakeQuantArgs_Qint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
// -----
// Verifies a qint8 asymmetric 0..1 range (with narrow_range = true).
// CHECK-LABEL: fakeQuantArgs_Qint8_NarrowRange
func @fakeQuantArgs_Qint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
func.func @fakeQuantArgs_Qint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
// CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
// CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8<-127:127>:f32, 0.003937007874015748:-127>>
Expand All @@ -138,7 +138,7 @@ func @fakeQuantArgs_Qint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
// -----
// Verifies a qint8 symmetric range of -1..127/128.
// CHECK-LABEL: fakeQuantArgs_Qint8_SymmetricRange
func @fakeQuantArgs_Qint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
func.func @fakeQuantArgs_Qint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
// CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
// CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8:f32, 7.812500e-03>>
Expand All @@ -154,7 +154,7 @@ func @fakeQuantArgs_Qint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
// Verifies a commonly used -1..1 symmetric 16bit range with a zero point of
// 0 and range -1.0 .. 32767/32768.
// CHECK-LABEL: fakeQuantArgs_Qint16_Symmetric
func @fakeQuantArgs_Qint16_Symmetric(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
func.func @fakeQuantArgs_Qint16_Symmetric(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
// CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
// CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i16:f32, 3.0517578125E-5>>
Expand All @@ -169,7 +169,7 @@ func @fakeQuantArgs_Qint16_Symmetric(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
// -----
// Verify that lowering to barriers of unranked tensors functions.
// CHECK-LABEL: fakeQuantArgs_UnrankedTensor
func @fakeQuantArgs_UnrankedTensor(tensor<f32>) -> tensor<f32> {
func.func @fakeQuantArgs_UnrankedTensor(tensor<f32>) -> tensor<f32> {
^bb0(%arg0: tensor<f32>):
// CHECK: %0 = "quant.qcast"(%arg0) : (tensor<f32>)
// CHECK-SAME: -> tensor<!quant.uniform<u8:f32, 0.0039215686274509803>>
Expand All @@ -183,7 +183,7 @@ func @fakeQuantArgs_UnrankedTensor(tensor<f32>) -> tensor<f32> {

// -----
// CHECK-LABEL: fakeQuantArgs_all_positive
func @fakeQuantArgs_all_positive(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
func.func @fakeQuantArgs_all_positive(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):

// CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
Expand All @@ -199,7 +199,7 @@ func @fakeQuantArgs_all_positive(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {

// -----
// CHECK-LABEL: fakeQuantArgs_all_negative
func @fakeQuantArgs_all_negative(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
func.func @fakeQuantArgs_all_negative(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):

// CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
Expand All @@ -216,7 +216,7 @@ func @fakeQuantArgs_all_negative(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
// -----
// Verifies a qint8 per axis
// CHECK-LABEL: fakeQuantPerAxis
func @fakeQuantPerAxis(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
func.func @fakeQuantPerAxis(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):

// CHECK: %[[q:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Dialect/Quant/parse-any.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// CHECK-LABEL: parseFullySpecified
// CHECK: !quant.any<i8<-8:7>:f32>
!qalias = type !quant.any<i8<-8:7>:f32>
func @parseFullySpecified() -> !qalias {
func.func @parseFullySpecified() -> !qalias {
%0 = "foo"() : () -> !qalias
return %0 : !qalias
}
Expand All @@ -13,7 +13,7 @@ func @parseFullySpecified() -> !qalias {
// CHECK-LABEL: parseNoExpressedType
// CHECK: !quant.any<i8<-8:7>>
!qalias = type !quant.any<i8<-8:7>>
func @parseNoExpressedType() -> !qalias {
func.func @parseNoExpressedType() -> !qalias {
%0 = "foo"() : () -> !qalias
return %0 : !qalias
}
Expand All @@ -22,7 +22,7 @@ func @parseNoExpressedType() -> !qalias {
// CHECK-LABEL: parseOnlyStorageType
// CHECK: !quant.any<i8>
!qalias = type !quant.any<i8>
func @parseOnlyStorageType() -> !qalias {
func.func @parseOnlyStorageType() -> !qalias {
%0 = "foo"() : () -> !qalias
return %0 : !qalias
}
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Quant/parse-calibrated.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// CHECK-LABEL: parseCalibrated
// CHECK: !quant.calibrated<f32<-0.998:1.232100e+00>
!qalias = type !quant.calibrated<f32<-0.998:1.2321>>
func @parseCalibrated() -> !qalias {
func.func @parseCalibrated() -> !qalias {
%0 = "foo"() : () -> !qalias
return %0 : !qalias
}
14 changes: 7 additions & 7 deletions mlir/test/Dialect/Quant/parse-ops-invalid.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -verify-diagnostics

// -----
func @invalidStatisticsMismatchedLayerType(%arg0: tensor<8x4x3xf32>) ->
func.func @invalidStatisticsMismatchedLayerType(%arg0: tensor<8x4x3xf32>) ->
tensor<8x4x3xf32> {
// expected-error@+1 {{layerStats must have a floating point element type}}
%0 = "quant.stats"(%arg0) {
Expand All @@ -11,7 +11,7 @@ func @invalidStatisticsMismatchedLayerType(%arg0: tensor<8x4x3xf32>) ->
}

// -----
func @invalidStatisticsMismatchedLayerRank(%arg0: tensor<8x4x3xf32>) ->
func.func @invalidStatisticsMismatchedLayerRank(%arg0: tensor<8x4x3xf32>) ->
tensor<8x4x3xf32> {
// expected-error@+1 {{layerStats must have shape [2]}}
%0 = "quant.stats"(%arg0) {
Expand All @@ -21,7 +21,7 @@ func @invalidStatisticsMismatchedLayerRank(%arg0: tensor<8x4x3xf32>) ->
}

// -----
func @invalidStatisticsMismatchedLayerShape(%arg0: tensor<8x4x3xf32>) ->
func.func @invalidStatisticsMismatchedLayerShape(%arg0: tensor<8x4x3xf32>) ->
tensor<8x4x3xf32> {
// expected-error@+1 {{layerStats must have shape [2]}}
%0 = "quant.stats"(%arg0) {
Expand All @@ -32,7 +32,7 @@ func @invalidStatisticsMismatchedLayerShape(%arg0: tensor<8x4x3xf32>) ->

// -----
// CHECK-LABEL: validStatistics
func @invalidStatisticsMismatchedAxisType(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
func.func @invalidStatisticsMismatchedAxisType(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
// expected-error@+1 {{axisStats must have a floating point element type}}
%0 = "quant.stats"(%0) {
layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>,
Expand All @@ -46,7 +46,7 @@ func @invalidStatisticsMismatchedAxisType(%arg0: tensor<8x4x3xf32>) -> tensor<8x
}

// -----
func @invalidStatisticsMismatchedAxisSize(%arg0: tensor<8x4x3xf32>) ->
func.func @invalidStatisticsMismatchedAxisSize(%arg0: tensor<8x4x3xf32>) ->
tensor<8x4x3xf32> {
// expected-error@+1 {{axisStats must have shape [N,2] where N = the slice size defined by the axis dim}}
%0 = "quant.stats"(%arg0) {
Expand All @@ -62,7 +62,7 @@ func @invalidStatisticsMismatchedAxisSize(%arg0: tensor<8x4x3xf32>) ->
}

// -----
func @invalidStatisticsMismatchedAxisShape(%arg0: tensor<8x4x3xf32>) ->
func.func @invalidStatisticsMismatchedAxisShape(%arg0: tensor<8x4x3xf32>) ->
tensor<8x4x3xf32> {
// expected-error@+1 {{axisStats must have shape [N,2] where N = the slice size defined by the axis dim}}
%0 = "quant.stats"(%arg0) {
Expand All @@ -77,7 +77,7 @@ func @invalidStatisticsMismatchedAxisShape(%arg0: tensor<8x4x3xf32>) ->
}

// -----
func @axisIsRequiredForAxisStats(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
func.func @axisIsRequiredForAxisStats(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
// expected-error@+1 {{axis must be specified for axisStats}}
%1 = "quant.stats"(%arg0) {
layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>,
Expand Down
10 changes: 5 additions & 5 deletions mlir/test/Dialect/Quant/parse-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

// -----
// CHECK-LABEL: validConstFakeQuant
func @validConstFakeQuant(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
func.func @validConstFakeQuant(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
%0 = "quant.const_fake_quant"(%arg0) {
min = 0.0 : f32, max = 1.0 : f32, num_bits = 8, narrow_range = true
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
Expand All @@ -17,7 +17,7 @@ func @validConstFakeQuant(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {

// -----
// CHECK-LABEL: validConstFakeQuantPerAxis
func @validConstFakeQuantPerAxis(%arg0: tensor<8x4x2xf32>) -> tensor<8x4x2xf32> {
func.func @validConstFakeQuantPerAxis(%arg0: tensor<8x4x2xf32>) -> tensor<8x4x2xf32> {
%0 = "quant.const_fake_quant_per_axis"(%arg0) {
min = [0.0 : f32, 1.0 : f32], max = [2.0 : f32, 3.0 : f32], axis = 2, num_bits = 8, narrow_range = true
} : (tensor<8x4x2xf32>) -> tensor<8x4x2xf32>
Expand All @@ -32,15 +32,15 @@ func @validConstFakeQuantPerAxis(%arg0: tensor<8x4x2xf32>) -> tensor<8x4x2xf32>

// -----
// CHECK-LABEL: validStatisticsRef
func @validStatisticsRef(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
func.func @validStatisticsRef(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
%0 = "quant.stats_ref"(%arg0) { statsKey = "foobar" } :
(tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
return %0 : tensor<8x4x3xf32>
}

// -----
// CHECK-LABEL: validStatistics
func @validStatistics(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
func.func @validStatistics(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
%0 = "quant.stats"(%arg0) {
layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
Expand All @@ -57,7 +57,7 @@ func @validStatistics(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {

// -----
// CHECK-LABEL: validCoupledRef
func @validCoupledRef(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
func.func @validCoupledRef(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
%0 = "quant.coupled_ref"(%arg0) { coupledKey = "foobar" } :
(tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
return %0 : tensor<8x4x3xf32>
Expand Down
34 changes: 17 additions & 17 deletions mlir/test/Dialect/Quant/parse-uniform.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
// [signed] storageType, storageTypeMin, storageTypeMax, expressedType, scale, zeroPoint
// CHECK: !quant.uniform<i8<-8:7>:f32, 9.987200e-01:127>
!qalias = type !quant.uniform<i8<-8:7>:f32, 0.99872:127>
func @parse() -> !qalias {
func.func @parse() -> !qalias {
%0 = "foo"() : () -> !qalias
return %0 : !qalias
}
Expand All @@ -14,7 +14,7 @@ func @parse() -> !qalias {
// Trailing whitespace.
// CHECK: !quant.uniform<i8<-8:7>:f32, 9.987200e-01:127>
!qalias = type !quant.uniform<i8<-8:7>:f32, 0.99872:127 >
func @parse() -> !qalias {
func.func @parse() -> !qalias {
%0 = "foo"() : () -> !qalias
return %0 : !qalias
}
Expand All @@ -24,7 +24,7 @@ func @parse() -> !qalias {
// [unsigned] storageType, expressedType, scale
// CHECK: !quant.uniform<u8:f32, 9.987200e-01>
!qalias = type !quant.uniform<u8:f32, 0.99872>
func @parse() -> !qalias {
func.func @parse() -> !qalias {
%0 = "foo"() : () -> !qalias
return %0 : !qalias
}
Expand All @@ -33,7 +33,7 @@ func @parse() -> !qalias {
// Exponential scale (-)
// CHECK: !quant.uniform<u8:f32, 2.000000e-02>
!qalias = type !quant.uniform<u8:f32, 2.0e-2>
func @parse() -> !qalias {
func.func @parse() -> !qalias {
%0 = "foo"() : () -> !qalias
return %0 : !qalias
}
Expand All @@ -42,7 +42,7 @@ func @parse() -> !qalias {
// Exponential scale (+)
// CHECK: !quant.uniform<u8:f32, 2.000000e+02>
!qalias = type !quant.uniform<u8:f32, 2.0e+2>
func @parse() -> !qalias {
func.func @parse() -> !qalias {
%0 = "foo"() : () -> !qalias
return %0 : !qalias
}
Expand All @@ -51,7 +51,7 @@ func @parse() -> !qalias {
// Storage type: i16
// CHECK: !quant.uniform<i16:f32, 2.000000e+02>
!qalias = type !quant.uniform<i16:f32, 2.0e+2>
func @parse() -> !qalias {
func.func @parse() -> !qalias {
%0 = "foo"() : () -> !qalias
return %0 : !qalias
}
Expand All @@ -60,7 +60,7 @@ func @parse() -> !qalias {
// Storage type: u16
// CHECK: !quant.uniform<u16:f32, 2.000000e+02>
!qalias = type !quant.uniform<u16:f32, 2.0e+2>
func @parse() -> !qalias {
func.func @parse() -> !qalias {
%0 = "foo"() : () -> !qalias
return %0 : !qalias
}
Expand All @@ -69,7 +69,7 @@ func @parse() -> !qalias {
// Storage type: i32
// CHECK: !quant.uniform<i32:f32, 2.000000e+02>
!qalias = type !quant.uniform<i32:f32, 2.0e+2>
func @parse() -> !qalias {
func.func @parse() -> !qalias {
%0 = "foo"() : () -> !qalias
return %0 : !qalias
}
Expand All @@ -78,7 +78,7 @@ func @parse() -> !qalias {
// Storage type: u32
// CHECK: !quant.uniform<u32:f32, 2.000000e+02>
!qalias = type !quant.uniform<u32:f32, 2.0e+2>
func @parse() -> !qalias {
func.func @parse() -> !qalias {
%0 = "foo"() : () -> !qalias
return %0 : !qalias
}
Expand All @@ -87,7 +87,7 @@ func @parse() -> !qalias {
// Expressed type: f32
// CHECK: !quant.uniform<u8:f32, 2.000000e+02>
!qalias = type !quant.uniform<u8:f32, 2.0e+2>
func @parse() -> !qalias {
func.func @parse() -> !qalias {
%0 = "foo"() : () -> !qalias
return %0 : !qalias
}
Expand All @@ -96,7 +96,7 @@ func @parse() -> !qalias {
// Expressed type: f32
// CHECK: !quant.uniform<u8:f32, 0x41646ABBA0000000:128>
!qalias = type !quant.uniform<u8:f32, 0x41646ABBA0000000:128>
func @parse() -> !qalias {
func.func @parse() -> !qalias {
%0 = "foo"() : () -> !qalias
return %0 : !qalias
}
Expand All @@ -105,7 +105,7 @@ func @parse() -> !qalias {
// Expressed type: f16
// CHECK: !quant.uniform<u8:f16, 2.000000e+02>
!qalias = type !quant.uniform<u8:f16, 2.0e+2>
func @parse() -> !qalias {
func.func @parse() -> !qalias {
%0 = "foo"() : () -> !qalias
return %0 : !qalias
}
Expand All @@ -114,7 +114,7 @@ func @parse() -> !qalias {
// Expressed type: f64
// CHECK: !quant.uniform<u8:f64, 2.000000e+02>
!qalias = type !quant.uniform<u8:f64, 2.0e+2>
func @parse() -> !qalias {
func.func @parse() -> !qalias {
%0 = "foo"() : () -> !qalias
return %0 : !qalias
}
Expand All @@ -123,7 +123,7 @@ func @parse() -> !qalias {
// Expressed type: bf16
// CHECK: !quant.uniform<u8:bf16, 2.000000e+02>
!qalias = type !quant.uniform<u8:bf16, 2.0e+2>
func @parse() -> !qalias {
func.func @parse() -> !qalias {
%0 = "foo"() : () -> !qalias
return %0 : !qalias
}
Expand All @@ -132,7 +132,7 @@ func @parse() -> !qalias {
// Per-axis scales and zero points (affine)
// CHECK: !quant.uniform<u8:f32:1, {2.000000e+02:-120,9.987200e-01:127}>
!qalias = type !quant.uniform<u8:f32:1, {2.0e+2:-120,0.99872:127}>
func @parse() -> !qalias {
func.func @parse() -> !qalias {
%0 = "foo"() : () -> !qalias
return %0 : !qalias
}
Expand All @@ -141,7 +141,7 @@ func @parse() -> !qalias {
// Per-axis scales and no zero points (fixedpoint)
// CHECK: !quant.uniform<i8:f32:1, {2.000000e+02,9.987200e-01}>
!qalias = type !quant.uniform<i8:f32:1, {2.0e+2,0.99872}>
func @parse() -> !qalias {
func.func @parse() -> !qalias {
%0 = "foo"() : () -> !qalias
return %0 : !qalias
}
Expand All @@ -150,7 +150,7 @@ func @parse() -> !qalias {
// Per-axis scales and zero points (mixed affine and fixedpoint)
// CHECK: !quant.uniform<i8:f32:1, {2.000000e+02,9.987200e-01:120}>
!qalias = type !quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>
func @parse() -> !qalias {
func.func @parse() -> !qalias {
%0 = "foo"() : () -> !qalias
return %0 : !qalias
}
14 changes: 7 additions & 7 deletions mlir/test/Dialect/Quant/quant_region.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -verify-diagnostics %s | FileCheck %s

// CHECK-LABEL: @source
func @source(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
func.func @source(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
%0 = "quant.region"(%arg0, %arg1, %arg2) ({
^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>):
%13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
Expand All @@ -13,7 +13,7 @@ func @source(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -
}

// CHECK-LABEL: @annotated
func @annotated(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
func.func @annotated(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
%0 = "quant.region"(%arg0, %arg1, %arg2) ({
^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>):
%13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
Expand All @@ -26,7 +26,7 @@ func @annotated(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>
}

// CHECK-LABEL: @quantized
func @quantized(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
func.func @quantized(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
%0 = "quant.region"(%arg0, %arg1, %arg2) ({
^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>):
%13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
Expand All @@ -40,7 +40,7 @@ func @quantized(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>

// -----

func @unmatched_quantize(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
func.func @unmatched_quantize(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
// @expected-error @+1 {{'quant.region' op has incompatible specification !quant.uniform<i32:f16, 3.000000e+00> and input type 'tensor<4xf32>'}}
%0 = "quant.region"(%arg0, %arg1, %arg2) ({
^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>):
Expand All @@ -55,7 +55,7 @@ func @unmatched_quantize(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tens

// -----

func @unmatched_primitive(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
func.func @unmatched_primitive(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
// @expected-error @+1 {{'quant.region' op has incompatible specification i32 and input type 'tensor<4xf32>'}}
%0 = "quant.region"(%arg0, %arg1, %arg2) ({
^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>):
Expand All @@ -70,7 +70,7 @@ func @unmatched_primitive(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: ten

// -----

func @unmatched_number(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
func.func @unmatched_number(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
// @expected-error @+1 {{'quant.region' op has unmatched operands/results number and spec attributes number}}
%0 = "quant.region"(%arg0, %arg1, %arg2) ({
^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>):
Expand All @@ -85,7 +85,7 @@ func @unmatched_number(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor

// -----

func @isolated(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
func.func @isolated(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
// @expected-note @+1 {{required by region isolation constraints}}
%0 = "quant.region"(%arg0, %arg1) ({
^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>):
Expand Down
10 changes: 5 additions & 5 deletions mlir/test/Dialect/SCF/bufferize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
// CHECK: %[[RESULT_TENSOR:.*]] = bufferization.to_tensor %[[RESULT_MEMREF:.*]] : memref<?xf32>
// CHECK: return %[[RESULT_TENSOR]] : tensor<?xf32>
// CHECK: }
func @if(%pred: i1, %true_val: tensor<?xf32>, %false_val: tensor<?xf32>) -> tensor<?xf32> {
func.func @if(%pred: i1, %true_val: tensor<?xf32>, %false_val: tensor<?xf32>) -> tensor<?xf32> {
%0 = scf.if %pred -> (tensor<?xf32>) {
scf.yield %true_val : tensor<?xf32>
} else {
Expand All @@ -34,7 +34,7 @@ func @if(%pred: i1, %true_val: tensor<?xf32>, %false_val: tensor<?xf32>) -> tens
// CHECK: %[[VAL_8:.*]] = bufferization.to_tensor %[[VAL_9:.*]] : memref<f32>
// CHECK: return %[[VAL_8]] : tensor<f32>
// CHECK: }
func @for(%arg0: tensor<f32>, %lb: index, %ub: index, %step: index) -> tensor<f32> {
func.func @for(%arg0: tensor<f32>, %lb: index, %ub: index, %step: index) -> tensor<f32> {
%ret = scf.for %iv = %lb to %ub step %step iter_args(%iter = %arg0) -> tensor<f32> {
scf.yield %iter : tensor<f32>
}
Expand All @@ -46,7 +46,7 @@ func @for(%arg0: tensor<f32>, %lb: index, %ub: index, %step: index) -> tensor<f3
// It would previously fail altogether.
// CHECK-LABEL: func @if_correct_recursive_legalization_behavior
// CHECK: "test.munge_tensor"
func @if_correct_recursive_legalization_behavior(%pred: i1, %tensor: tensor<f32>) -> tensor<f32> {
func.func @if_correct_recursive_legalization_behavior(%pred: i1, %tensor: tensor<f32>) -> tensor<f32> {
%0 = scf.if %pred -> (tensor<f32>) {
%1 = "test.munge_tensor"(%tensor) : (tensor<f32>) -> (tensor<f32>)
scf.yield %1: tensor<f32>
Expand All @@ -70,7 +70,7 @@ func @if_correct_recursive_legalization_behavior(%pred: i1, %tensor: tensor<f32>
// CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[RESULT:.*]] : memref<f32>
// CHECK: return %[[TENSOR]] : tensor<f32>
// CHECK: }
func @for_correct_recursive_legalization_behavior(%arg0: tensor<f32>, %index: index) -> tensor<f32> {
func.func @for_correct_recursive_legalization_behavior(%arg0: tensor<f32>, %index: index) -> tensor<f32> {
%ret = scf.for %iv = %index to %index step %index iter_args(%iter = %arg0) -> tensor<f32> {
%0 = "test.munge_tensor"(%iter) : (tensor<f32>) -> (tensor<f32>)
scf.yield %0 : tensor<f32>
Expand All @@ -87,7 +87,7 @@ func @for_correct_recursive_legalization_behavior(%arg0: tensor<f32>, %index: in
// CHECK: scf.yield %{{.*}}, %{{.*}} : i64, memref<f32>
// CHECK: %[[RES2:.*]] = bufferization.to_tensor %[[RES1]]#2 : memref<f32>
// CHECK: return %[[RES1]]#1, %[[RES2]] : i64, tensor<f32>
func @bufferize_while(%arg0: i64, %arg1: i64, %arg2: tensor<f32>) -> (i64, tensor<f32>) {
func.func @bufferize_while(%arg0: i64, %arg1: i64, %arg2: tensor<f32>) -> (i64, tensor<f32>) {
%c2_i64 = arith.constant 2 : i64
%0:3 = scf.while (%arg3 = %arg0, %arg4 = %arg2) : (i64, tensor<f32>) -> (i64, i64, tensor<f32>) {
%1 = arith.cmpi slt, %arg3, %arg1 : i64
Expand Down
132 changes: 66 additions & 66 deletions mlir/test/Dialect/SCF/canonicalize.mlir

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions mlir/test/Dialect/SCF/control-flow-sink.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
// CHECK: %[[V1:.*]] = arith.muli %[[ARG1]], %[[ARG1]]
// CHECK: scf.yield %[[V1]]
// CHECK: return %[[V0]]
func @test_scf_if_sink(%arg0: i1, %arg1: i32) -> i32 {
func.func @test_scf_if_sink(%arg0: i1, %arg1: i32) -> i32 {
%0 = arith.addi %arg1, %arg1 : i32
%1 = arith.muli %arg1, %arg1 : i32
%result = scf.if %arg0 -> i32 {
Expand All @@ -22,14 +22,14 @@ func @test_scf_if_sink(%arg0: i1, %arg1: i32) -> i32 {

// -----

func private @consume(i32) -> ()
func.func private @consume(i32) -> ()

// CHECK-LABEL: @test_scf_if_then_only_sink
// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32)
// CHECK: scf.if %[[ARG0]]
// CHECK: %[[V0:.*]] = arith.addi %[[ARG1]], %[[ARG1]]
// CHECK: call @consume(%[[V0]])
func @test_scf_if_then_only_sink(%arg0: i1, %arg1: i32) {
func.func @test_scf_if_then_only_sink(%arg0: i1, %arg1: i32) {
%0 = arith.addi %arg1, %arg1 : i32
scf.if %arg0 {
call @consume(%0) : (i32) -> ()
Expand All @@ -40,15 +40,15 @@ func @test_scf_if_then_only_sink(%arg0: i1, %arg1: i32) {

// -----

func private @consume(i32) -> ()
func.func private @consume(i32) -> ()

// CHECK-LABEL: @test_scf_if_double_sink
// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32)
// CHECK: scf.if %[[ARG0]]
// CHECK: scf.if %[[ARG0]]
// CHECK: %[[V0:.*]] = arith.addi %[[ARG1]], %[[ARG1]]
// CHECK: call @consume(%[[V0]])
func @test_scf_if_double_sink(%arg0: i1, %arg1: i32) {
func.func @test_scf_if_double_sink(%arg0: i1, %arg1: i32) {
%0 = arith.addi %arg1, %arg1 : i32
scf.if %arg0 {
scf.if %arg0 {
Expand Down
36 changes: 18 additions & 18 deletions mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// CHECK: %[[C2:.*]] = arith.constant 2 : i64
// CHECK: scf.for
// CHECK: memref.store %[[C2]], %{{.*}}[] : memref<i64>
func @scf_for_canonicalize_min(%A : memref<i64>) {
func.func @scf_for_canonicalize_min(%A : memref<i64>) {
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index
Expand All @@ -23,7 +23,7 @@ func @scf_for_canonicalize_min(%A : memref<i64>) {
// CHECK: %[[Cneg2:.*]] = arith.constant -2 : i64
// CHECK: scf.for
// CHECK: memref.store %[[Cneg2]], %{{.*}}[] : memref<i64>
func @scf_for_canonicalize_max(%A : memref<i64>) {
func.func @scf_for_canonicalize_max(%A : memref<i64>) {
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index
Expand All @@ -42,7 +42,7 @@ func @scf_for_canonicalize_max(%A : memref<i64>) {
// CHECK: scf.for
// CHECK: affine.max
// CHECK: arith.index_cast
func @scf_for_max_not_canonicalizable(%A : memref<i64>) {
func.func @scf_for_max_not_canonicalizable(%A : memref<i64>) {
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
Expand All @@ -63,7 +63,7 @@ func @scf_for_max_not_canonicalizable(%A : memref<i64>) {
// CHECK: scf.for
// CHECK: scf.for
// CHECK: memref.store %[[C5]], %{{.*}}[] : memref<i64>
func @scf_for_loop_nest_canonicalize_min(%A : memref<i64>) {
func.func @scf_for_loop_nest_canonicalize_min(%A : memref<i64>) {
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
Expand All @@ -86,7 +86,7 @@ func @scf_for_loop_nest_canonicalize_min(%A : memref<i64>) {
// CHECK: scf.for
// CHECK: affine.min
// CHECK: arith.index_cast
func @scf_for_not_canonicalizable_1(%A : memref<i64>) {
func.func @scf_for_not_canonicalizable_1(%A : memref<i64>) {
// This should not canonicalize because: 4 - %i may take the value 1 < 2.
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
Expand All @@ -106,7 +106,7 @@ func @scf_for_not_canonicalizable_1(%A : memref<i64>) {
// CHECK: scf.for
// CHECK: affine.apply
// CHECK: arith.index_cast
func @scf_for_canonicalize_partly(%A : memref<i64>) {
func.func @scf_for_canonicalize_partly(%A : memref<i64>) {
// This should canonicalize only partly: 256 - %i <= 256.
%c1 = arith.constant 1 : index
%c16 = arith.constant 16 : index
Expand All @@ -126,7 +126,7 @@ func @scf_for_canonicalize_partly(%A : memref<i64>) {
// CHECK: scf.for
// CHECK: affine.min
// CHECK: arith.index_cast
func @scf_for_not_canonicalizable_2(%A : memref<i64>, %step : index) {
func.func @scf_for_not_canonicalizable_2(%A : memref<i64>, %step : index) {
// This example should simplify but affine_map is currently missing
// semi-affine canonicalizations: `((s0 * 42 - 1) floordiv s0) * s0`
// should evaluate to 41 * s0.
Expand All @@ -149,7 +149,7 @@ func @scf_for_not_canonicalizable_2(%A : memref<i64>, %step : index) {
// CHECK: scf.for
// CHECK: affine.min
// CHECK: arith.index_cast
func @scf_for_not_canonicalizable_3(%A : memref<i64>, %step : index) {
func.func @scf_for_not_canonicalizable_3(%A : memref<i64>, %step : index) {
// This example should simplify but affine_map is currently missing
// semi-affine canonicalizations: `-(((s0 * s0 - 1) floordiv s0) * s0)`
// should evaluate to (s0 - 1) * s0.
Expand All @@ -172,7 +172,7 @@ func @scf_for_not_canonicalizable_3(%A : memref<i64>, %step : index) {
// CHECK: scf.for
// CHECK: affine.min
// CHECK: arith.index_cast
func @scf_for_invalid_loop(%A : memref<i64>, %step : index) {
func.func @scf_for_invalid_loop(%A : memref<i64>, %step : index) {
// This is an invalid loop. It should not be touched by the canonicalization
// pattern.
%c1 = arith.constant 1 : index
Expand All @@ -193,7 +193,7 @@ func @scf_for_invalid_loop(%A : memref<i64>, %step : index) {
// CHECK: %[[C2:.*]] = arith.constant 2 : i64
// CHECK: scf.parallel
// CHECK-NEXT: memref.store %[[C2]], %{{.*}}[] : memref<i64>
func @scf_parallel_canonicalize_min_1(%A : memref<i64>) {
func.func @scf_parallel_canonicalize_min_1(%A : memref<i64>) {
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index
Expand All @@ -212,7 +212,7 @@ func @scf_parallel_canonicalize_min_1(%A : memref<i64>) {
// CHECK: %[[C2:.*]] = arith.constant 2 : i64
// CHECK: scf.parallel
// CHECK-NEXT: memref.store %[[C2]], %{{.*}}[] : memref<i64>
func @scf_parallel_canonicalize_min_2(%A : memref<i64>) {
func.func @scf_parallel_canonicalize_min_2(%A : memref<i64>) {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c7 = arith.constant 7 : index
Expand All @@ -231,7 +231,7 @@ func @scf_parallel_canonicalize_min_2(%A : memref<i64>) {
// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>
// CHECK: scf.for
// CHECK: tensor.dim %[[t]]
func @tensor_dim_of_iter_arg(%t : tensor<?x?xf32>) -> index {
func.func @tensor_dim_of_iter_arg(%t : tensor<?x?xf32>) -> index {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
Expand All @@ -249,7 +249,7 @@ func @tensor_dim_of_iter_arg(%t : tensor<?x?xf32>) -> index {
// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>,
// CHECK: scf.for
// CHECK: tensor.dim %[[t]]
func @tensor_dim_of_iter_arg_insertslice(%t : tensor<?x?xf32>,
func.func @tensor_dim_of_iter_arg_insertslice(%t : tensor<?x?xf32>,
%t2 : tensor<10x10xf32>) -> index {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
Expand All @@ -273,7 +273,7 @@ func @tensor_dim_of_iter_arg_insertslice(%t : tensor<?x?xf32>,
// CHECK: scf.for
// CHECK: scf.for
// CHECK: tensor.dim %[[t]]
func @tensor_dim_of_iter_arg_nested_for(%t : tensor<?x?xf32>,
func.func @tensor_dim_of_iter_arg_nested_for(%t : tensor<?x?xf32>,
%t2 : tensor<10x10xf32>) -> index {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
Expand Down Expand Up @@ -302,7 +302,7 @@ func @tensor_dim_of_iter_arg_nested_for(%t : tensor<?x?xf32>,
// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>,
// CHECK: scf.for {{.*}} iter_args(%[[arg0:.*]] = %[[t]]
// CHECK: tensor.dim %[[arg0]]
func @tensor_dim_of_iter_arg_no_canonicalize(%t : tensor<?x?xf32>,
func.func @tensor_dim_of_iter_arg_no_canonicalize(%t : tensor<?x?xf32>,
%t2 : tensor<?x?xf32>) -> index {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
Expand All @@ -320,7 +320,7 @@ func @tensor_dim_of_iter_arg_no_canonicalize(%t : tensor<?x?xf32>,
// CHECK-LABEL: func @tensor_dim_of_loop_result(
// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>
// CHECK: tensor.dim %[[t]]
func @tensor_dim_of_loop_result(%t : tensor<?x?xf32>) -> index {
func.func @tensor_dim_of_loop_result(%t : tensor<?x?xf32>) -> index {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
Expand All @@ -337,7 +337,7 @@ func @tensor_dim_of_loop_result(%t : tensor<?x?xf32>) -> index {
// CHECK-LABEL: func @tensor_dim_of_loop_result_no_canonicalize(
// CHECK: %[[loop:.*]]:2 = scf.for
// CHECK: tensor.dim %[[loop]]#1
func @tensor_dim_of_loop_result_no_canonicalize(%t : tensor<?x?xf32>,
func.func @tensor_dim_of_loop_result_no_canonicalize(%t : tensor<?x?xf32>,
%u : tensor<?x?xf32>) -> index {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
Expand All @@ -356,7 +356,7 @@ func @tensor_dim_of_loop_result_no_canonicalize(%t : tensor<?x?xf32>,
// CHECK: %[[C4:.*]] = arith.constant 4 : i64
// CHECK: scf.for
// CHECK: memref.store %[[C4]], %{{.*}}[] : memref<i64>
func @one_trip_scf_for_canonicalize_min(%A : memref<i64>) {
func.func @one_trip_scf_for_canonicalize_min(%A : memref<i64>) {
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index
Expand Down
12 changes: 6 additions & 6 deletions mlir/test/Dialect/SCF/for-loop-peeling.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
// CHECK: }
// CHECK: return %[[RESULT]]
#map = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>
func @fully_dynamic_bounds(%lb : index, %ub: index, %step: index) -> i32 {
func.func @fully_dynamic_bounds(%lb : index, %ub: index, %step: index) -> i32 {
%c0 = arith.constant 0 : i32
%r = scf.for %iv = %lb to %ub step %step iter_args(%arg = %c0) -> i32 {
%s = affine.min #map(%ub, %iv)[%step]
Expand Down Expand Up @@ -50,7 +50,7 @@ func @fully_dynamic_bounds(%lb : index, %ub: index, %step: index) -> i32 {
// CHECK: %[[RESULT:.*]] = arith.addi %[[LOOP]], %[[C1_I32]] : i32
// CHECK: return %[[RESULT]]
#map = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>
func @fully_static_bounds() -> i32 {
func.func @fully_static_bounds() -> i32 {
%c0_i32 = arith.constant 0 : i32
%lb = arith.constant 0 : index
%step = arith.constant 4 : index
Expand Down Expand Up @@ -90,7 +90,7 @@ func @fully_static_bounds() -> i32 {
// CHECK: }
// CHECK: return %[[RESULT]]
#map = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>
func @dynamic_upper_bound(%ub : index) -> i32 {
func.func @dynamic_upper_bound(%ub : index) -> i32 {
%c0_i32 = arith.constant 0 : i32
%lb = arith.constant 0 : index
%step = arith.constant 4 : index
Expand Down Expand Up @@ -128,7 +128,7 @@ func @dynamic_upper_bound(%ub : index) -> i32 {
// CHECK: }
// CHECK: return
#map = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>
func @no_loop_results(%ub : index, %d : memref<i32>) {
func.func @no_loop_results(%ub : index, %d : memref<i32>) {
%c0_i32 = arith.constant 0 : i32
%lb = arith.constant 0 : index
%step = arith.constant 4 : index
Expand Down Expand Up @@ -192,7 +192,7 @@ func @no_loop_results(%ub : index, %d : memref<i32>) {
#map3 = affine_map<(d0, d1)[s0] -> (s0, d0 - d1 - 1)>
#map4 = affine_map<(d0, d1, d2)[s0] -> (s0, d0 - d1, d2)>
#map5 = affine_map<(d0, d1)[s0] -> (-s0, -d0 + d1)>
func @test_affine_op_rewrite(%lb : index, %ub: index,
func.func @test_affine_op_rewrite(%lb : index, %ub: index,
%step: index, %d : memref<?xindex>,
%some_val: index) {
%c0 = arith.constant 0 : index
Expand Down Expand Up @@ -260,7 +260,7 @@ func @test_affine_op_rewrite(%lb : index, %ub: index,
// CHECK-NO-SKIP: }
// CHECK-NO-SKIP: }
#map = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>
func @nested_loops(%lb0: index, %lb1 : index, %ub0: index, %ub1: index,
func.func @nested_loops(%lb0: index, %lb1 : index, %ub0: index, %ub1: index,
%step: index) -> i32 {
%c0 = arith.constant 0 : i32
%r0 = scf.for %iv0 = %lb0 to %ub0 step %step iter_args(%arg0 = %c0) -> i32 {
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/SCF/for-loop-specialization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#map0 = affine_map<()[s0, s1] -> (1024, s0 - s1)>
#map1 = affine_map<()[s0, s1] -> (64, s0 - s1)>

func @for(%outer: index, %A: memref<?xf32>, %B: memref<?xf32>,
func.func @for(%outer: index, %A: memref<?xf32>, %B: memref<?xf32>,
%C: memref<?xf32>, %result: memref<?xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
Expand Down
8 changes: 4 additions & 4 deletions mlir/test/Dialect/SCF/for-loop-to-while-loop.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
// CHECK: }
// CHECK: return
// CHECK: }
func @single_loop(%arg0: memref<?xi32>, %arg1: index, %arg2: i32) {
func.func @single_loop(%arg0: memref<?xi32>, %arg1: index, %arg2: i32) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.for %i = %c0 to %arg1 step %c1 {
Expand Down Expand Up @@ -58,7 +58,7 @@ func @single_loop(%arg0: memref<?xi32>, %arg1: index, %arg2: i32) {
// CHECK: }
// CHECK: return
// CHECK: }
func @nested_loop(%arg0: memref<?xi32>, %arg1: index, %arg2: i32) {
func.func @nested_loop(%arg0: memref<?xi32>, %arg1: index, %arg2: i32) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.for %i = %c0 to %arg1 step %c1 {
Expand Down Expand Up @@ -88,7 +88,7 @@ func @nested_loop(%arg0: memref<?xi32>, %arg1: index, %arg2: i32) {
// CHECK: }
// CHECK: return %[[VAL_14:.*]]#2 : f32
// CHECK: }
func @for_iter_args(%arg0 : index, %arg1: index, %arg2: index) -> f32 {
func.func @for_iter_args(%arg0 : index, %arg1: index, %arg2: index) -> f32 {
%s0 = arith.constant 0.0 : f32
%result:2 = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%iarg0 = %s0, %iarg1 = %s0) -> (f32, f32) {
%sn = arith.addf %iarg0, %iarg1 : f32
Expand Down Expand Up @@ -125,7 +125,7 @@ func @for_iter_args(%arg0 : index, %arg1: index, %arg2: index) -> f32 {
// CHECK: }
// CHECK: return %[[VAL_17:.*]]#1 : i32
// CHECK: }
func @exec_region_multiple_yields(%arg0: i32, %arg1: index, %arg2: i32) -> i32 {
func.func @exec_region_multiple_yields(%arg0: i32, %arg1: index, %arg2: i32) -> i32 {
%c1_i32 = arith.constant 1 : i32
%c2_i32 = arith.constant 2 : i32
%c0 = arith.constant 0 : index
Expand Down
84 changes: 42 additions & 42 deletions mlir/test/Dialect/SCF/invalid.mlir
Original file line number Diff line number Diff line change
@@ -1,30 +1,30 @@
// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -verify-diagnostics

func @loop_for_lb(%arg0: f32, %arg1: index) {
func.func @loop_for_lb(%arg0: f32, %arg1: index) {
// expected-error@+1 {{operand #0 must be index}}
"scf.for"(%arg0, %arg1, %arg1) ({}) : (f32, index, index) -> ()
return
}

// -----

func @loop_for_ub(%arg0: f32, %arg1: index) {
func.func @loop_for_ub(%arg0: f32, %arg1: index) {
// expected-error@+1 {{operand #1 must be index}}
"scf.for"(%arg1, %arg0, %arg1) ({}) : (index, f32, index) -> ()
return
}

// -----

func @loop_for_step(%arg0: f32, %arg1: index) {
func.func @loop_for_step(%arg0: f32, %arg1: index) {
// expected-error@+1 {{operand #2 must be index}}
"scf.for"(%arg1, %arg1, %arg0) ({}) : (index, index, f32) -> ()
return
}

// -----

func @loop_for_step_positive(%arg0: index) {
func.func @loop_for_step_positive(%arg0: index) {
// expected-error@+2 {{constant step operand must be positive}}
%c0 = arith.constant 0 : index
"scf.for"(%arg0, %arg0, %c0) ({
Expand All @@ -36,7 +36,7 @@ func @loop_for_step_positive(%arg0: index) {

// -----

func @loop_for_one_region(%arg0: index) {
func.func @loop_for_one_region(%arg0: index) {
// expected-error@+1 {{requires one region}}
"scf.for"(%arg0, %arg0, %arg0) (
{scf.yield},
Expand All @@ -47,7 +47,7 @@ func @loop_for_one_region(%arg0: index) {

// -----

func @loop_for_single_block(%arg0: index) {
func.func @loop_for_single_block(%arg0: index) {
// expected-error@+1 {{expects region #0 to have 0 or 1 blocks}}
"scf.for"(%arg0, %arg0, %arg0) (
{
Expand All @@ -62,7 +62,7 @@ func @loop_for_single_block(%arg0: index) {

// -----

func @loop_for_single_index_argument(%arg0: index) {
func.func @loop_for_single_index_argument(%arg0: index) {
// expected-error@+1 {{op expected body first argument to be an index argument for the induction variable}}
"scf.for"(%arg0, %arg0, %arg0) (
{
Expand All @@ -75,23 +75,23 @@ func @loop_for_single_index_argument(%arg0: index) {

// -----

func @loop_if_not_i1(%arg0: index) {
func.func @loop_if_not_i1(%arg0: index) {
// expected-error@+1 {{operand #0 must be 1-bit signless integer}}
"scf.if"(%arg0) ({}, {}) : (index) -> ()
return
}

// -----

func @loop_if_more_than_2_regions(%arg0: i1) {
func.func @loop_if_more_than_2_regions(%arg0: i1) {
// expected-error@+1 {{expected 2 regions}}
"scf.if"(%arg0) ({}, {}, {}): (i1) -> ()
return
}

// -----

func @loop_if_not_one_block_per_region(%arg0: i1) {
func.func @loop_if_not_one_block_per_region(%arg0: i1) {
// expected-error@+1 {{expects region #0 to have 0 or 1 blocks}}
"scf.if"(%arg0) ({
^bb0:
Expand All @@ -104,7 +104,7 @@ func @loop_if_not_one_block_per_region(%arg0: i1) {

// -----

func @loop_if_illegal_block_argument(%arg0: i1) {
func.func @loop_if_illegal_block_argument(%arg0: i1) {
// expected-error@+1 {{region #0 should have no arguments}}
"scf.if"(%arg0) ({
^bb0(%0 : index):
Expand All @@ -115,7 +115,7 @@ func @loop_if_illegal_block_argument(%arg0: i1) {

// -----

func @parallel_arguments_different_tuple_size(
func.func @parallel_arguments_different_tuple_size(
%arg0: index, %arg1: index, %arg2: index) {
// expected-error@+1 {{custom op 'scf.parallel' expected 1 operands}}
scf.parallel (%i0) = (%arg0) to (%arg1, %arg2) step () {
Expand All @@ -125,7 +125,7 @@ func @parallel_arguments_different_tuple_size(

// -----

func @parallel_body_arguments_wrong_type(
func.func @parallel_body_arguments_wrong_type(
%arg0: index, %arg1: index, %arg2: index) {
// expected-error@+1 {{'scf.parallel' op expects arguments for the induction variable to be of index type}}
"scf.parallel"(%arg0, %arg1, %arg2) ({
Expand All @@ -137,7 +137,7 @@ func @parallel_body_arguments_wrong_type(

// -----

func @parallel_body_wrong_number_of_arguments(
func.func @parallel_body_wrong_number_of_arguments(
%arg0: index, %arg1: index, %arg2: index) {
// expected-error@+1 {{'scf.parallel' op expects the same number of induction variables: 2 as bound and step values: 1}}
"scf.parallel"(%arg0, %arg1, %arg2) ({
Expand All @@ -149,7 +149,7 @@ func @parallel_body_wrong_number_of_arguments(

// -----

func @parallel_no_tuple_elements() {
func.func @parallel_no_tuple_elements() {
// expected-error@+1 {{'scf.parallel' op needs at least one tuple element for lowerBound, upperBound and step}}
scf.parallel () = () to () step () {
}
Expand All @@ -158,7 +158,7 @@ func @parallel_no_tuple_elements() {

// -----

func @parallel_step_not_positive(
func.func @parallel_step_not_positive(
%arg0: index, %arg1: index, %arg2: index, %arg3: index) {
// expected-error@+3 {{constant step operand must be positive}}
%c0 = arith.constant 1 : index
Expand All @@ -170,7 +170,7 @@ func @parallel_step_not_positive(

// -----

func @parallel_fewer_results_than_reduces(
func.func @parallel_fewer_results_than_reduces(
%arg0 : index, %arg1: index, %arg2: index) {
// expected-error@+1 {{expects number of results: 0 to be the same as number of reductions: 1}}
scf.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) {
Expand All @@ -185,7 +185,7 @@ func @parallel_fewer_results_than_reduces(

// -----

func @parallel_more_results_than_reduces(
func.func @parallel_more_results_than_reduces(
%arg0 : index, %arg1 : index, %arg2 : index) {
// expected-error@+2 {{expects number of results: 1 to be the same as number of reductions: 0}}
%zero = arith.constant 1.0 : f32
Expand All @@ -197,7 +197,7 @@ func @parallel_more_results_than_reduces(

// -----

func @parallel_more_results_than_initial_values(
func.func @parallel_more_results_than_initial_values(
%arg0 : index, %arg1: index, %arg2: index) {
// expected-error@+1 {{expects number of results: 1 to be the same as number of initial values: 0}}
%res = scf.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) -> f32 {
Expand All @@ -210,7 +210,7 @@ func @parallel_more_results_than_initial_values(

// -----

func @parallel_different_types_of_results_and_reduces(
func.func @parallel_different_types_of_results_and_reduces(
%arg0 : index, %arg1: index, %arg2: index) {
%zero = arith.constant 0.0 : f32
%res = scf.parallel (%i0) = (%arg0) to (%arg1)
Expand All @@ -226,7 +226,7 @@ func @parallel_different_types_of_results_and_reduces(

// -----

func @top_level_reduce(%arg0 : f32) {
func.func @top_level_reduce(%arg0 : f32) {
// expected-error@+1 {{expects parent op 'scf.parallel'}}
scf.reduce(%arg0) : f32 {
^bb0(%lhs : f32, %rhs : f32):
Expand All @@ -237,7 +237,7 @@ func @top_level_reduce(%arg0 : f32) {

// -----

func @reduce_empty_block(%arg0 : index, %arg1 : f32) {
func.func @reduce_empty_block(%arg0 : index, %arg1 : f32) {
%zero = arith.constant 0.0 : f32
%res = scf.parallel (%i0) = (%arg0) to (%arg0)
step (%arg0) init (%zero) -> f32 {
Expand All @@ -251,7 +251,7 @@ func @reduce_empty_block(%arg0 : index, %arg1 : f32) {

// -----

func @reduce_too_many_args(%arg0 : index, %arg1 : f32) {
func.func @reduce_too_many_args(%arg0 : index, %arg1 : f32) {
%zero = arith.constant 0.0 : f32
%res = scf.parallel (%i0) = (%arg0) to (%arg0)
step (%arg0) init (%zero) -> f32 {
Expand All @@ -266,7 +266,7 @@ func @reduce_too_many_args(%arg0 : index, %arg1 : f32) {

// -----

func @reduce_wrong_args(%arg0 : index, %arg1 : f32) {
func.func @reduce_wrong_args(%arg0 : index, %arg1 : f32) {
%zero = arith.constant 0.0 : f32
%res = scf.parallel (%i0) = (%arg0) to (%arg0)
step (%arg0) init (%zero) -> f32 {
Expand All @@ -282,7 +282,7 @@ func @reduce_wrong_args(%arg0 : index, %arg1 : f32) {

// -----

func @reduce_wrong_terminator(%arg0 : index, %arg1 : f32) {
func.func @reduce_wrong_terminator(%arg0 : index, %arg1 : f32) {
%zero = arith.constant 0.0 : f32
%res = scf.parallel (%i0) = (%arg0) to (%arg0)
step (%arg0) init (%zero) -> f32 {
Expand All @@ -297,7 +297,7 @@ func @reduce_wrong_terminator(%arg0 : index, %arg1 : f32) {

// -----

func @reduceReturn_wrong_type(%arg0 : index, %arg1: f32) {
func.func @reduceReturn_wrong_type(%arg0 : index, %arg1: f32) {
%zero = arith.constant 0.0 : f32
%res = scf.parallel (%i0) = (%arg0) to (%arg0)
step (%arg0) init (%zero) -> f32 {
Expand All @@ -313,7 +313,7 @@ func @reduceReturn_wrong_type(%arg0 : index, %arg1: f32) {

// -----

func @reduceReturn_not_inside_reduce(%arg0 : f32) {
func.func @reduceReturn_not_inside_reduce(%arg0 : f32) {
"foo.region"() ({
// expected-error@+1 {{expects parent op 'scf.reduce'}}
scf.reduce.return %arg0 : f32
Expand All @@ -323,7 +323,7 @@ func @reduceReturn_not_inside_reduce(%arg0 : f32) {

// -----

func @std_if_incorrect_yield(%arg0: i1, %arg1: f32)
func.func @std_if_incorrect_yield(%arg0: i1, %arg1: f32)
{
// expected-error@+1 {{region control flow edge from Region #0 to parent results: source has 1 operands, but target successor needs 2}}
%x, %y = scf.if %arg0 -> (f32, f32) {
Expand All @@ -338,7 +338,7 @@ func @std_if_incorrect_yield(%arg0: i1, %arg1: f32)

// -----

func @std_if_missing_else(%arg0: i1, %arg1: f32)
func.func @std_if_missing_else(%arg0: i1, %arg1: f32)
{
// expected-error@+1 {{must have an else block if defining values}}
%x = scf.if %arg0 -> (f32) {
Expand All @@ -350,7 +350,7 @@ func @std_if_missing_else(%arg0: i1, %arg1: f32)

// -----

func @std_for_operands_mismatch(%arg0 : index, %arg1 : index, %arg2 : index) {
func.func @std_for_operands_mismatch(%arg0 : index, %arg1 : index, %arg2 : index) {
%s0 = arith.constant 0.0 : f32
%t0 = arith.constant 1 : i32
// expected-error@+1 {{mismatch in number of loop-carried values and defined values}}
Expand All @@ -365,7 +365,7 @@ func @std_for_operands_mismatch(%arg0 : index, %arg1 : index, %arg2 : index) {

// -----

func @std_for_operands_mismatch_2(%arg0 : index, %arg1 : index, %arg2 : index) {
func.func @std_for_operands_mismatch_2(%arg0 : index, %arg1 : index, %arg2 : index) {
%s0 = arith.constant 0.0 : f32
%t0 = arith.constant 1 : i32
%u0 = arith.constant 1.0 : f32
Expand All @@ -382,7 +382,7 @@ func @std_for_operands_mismatch_2(%arg0 : index, %arg1 : index, %arg2 : index) {

// -----

func @std_for_operands_mismatch_3(%arg0 : index, %arg1 : index, %arg2 : index) {
func.func @std_for_operands_mismatch_3(%arg0 : index, %arg1 : index, %arg2 : index) {
// expected-note@+1 {{prior use here}}
%s0 = arith.constant 0.0 : f32
%t0 = arith.constant 1.0 : f32
Expand All @@ -398,7 +398,7 @@ func @std_for_operands_mismatch_3(%arg0 : index, %arg1 : index, %arg2 : index) {

// -----

func @std_for_operands_mismatch_4(%arg0 : index, %arg1 : index, %arg2 : index) {
func.func @std_for_operands_mismatch_4(%arg0 : index, %arg1 : index, %arg2 : index) {
%s0 = arith.constant 0.0 : f32
%t0 = arith.constant 1.0 : f32
// expected-error @+1 {{along control flow edge from Region #0 to Region #0: source type #1 'i32' should match input type #1 'f32'}}
Expand All @@ -414,7 +414,7 @@ func @std_for_operands_mismatch_4(%arg0 : index, %arg1 : index, %arg2 : index) {

// -----

func @parallel_invalid_yield(
func.func @parallel_invalid_yield(
%arg0: index, %arg1: index, %arg2: index) {
scf.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) {
%c0 = arith.constant 1.0 : f32
Expand All @@ -426,7 +426,7 @@ func @parallel_invalid_yield(

// -----

func @yield_invalid_parent_op() {
func.func @yield_invalid_parent_op() {
"my.op"() ({
// expected-error@+1 {{'scf.yield' op expects parent op to be one of 'scf.execute_region, scf.for, scf.if, scf.parallel, scf.while'}}
scf.yield
Expand All @@ -436,7 +436,7 @@ func @yield_invalid_parent_op() {

// -----

func @while_parser_type_mismatch() {
func.func @while_parser_type_mismatch() {
%true = arith.constant true
// expected-error@+1 {{expected as many input types as operands (expected 0 got 1)}}
scf.while : (i32) -> () {
Expand All @@ -448,7 +448,7 @@ func @while_parser_type_mismatch() {

// -----

func @while_bad_terminator() {
func.func @while_bad_terminator() {
// expected-error@+1 {{expects the 'before' region to terminate with 'scf.condition'}}
scf.while : () -> () {
// expected-note@+1 {{terminator here}}
Expand All @@ -460,7 +460,7 @@ func @while_bad_terminator() {

// -----

func @while_cross_region_type_mismatch() {
func.func @while_cross_region_type_mismatch() {
%true = arith.constant true
// expected-error@+1 {{'scf.while' op region control flow edge from Region #0 to Region #1: source has 0 operands, but target successor needs 1}}
scf.while : () -> () {
Expand All @@ -473,7 +473,7 @@ func @while_cross_region_type_mismatch() {

// -----

func @while_cross_region_type_mismatch() {
func.func @while_cross_region_type_mismatch() {
%true = arith.constant true
// expected-error@+1 {{'scf.while' op along control flow edge from Region #0 to Region #1: source type #0 'i1' should match input type #0 'i32'}}
scf.while : () -> () {
Expand All @@ -486,7 +486,7 @@ func @while_cross_region_type_mismatch() {

// -----

func @while_result_type_mismatch() {
func.func @while_result_type_mismatch() {
%true = arith.constant true
// expected-error@+1 {{'scf.while' op region control flow edge from Region #0 to parent results: source has 1 operands, but target successor needs 0}}
scf.while : () -> () {
Expand All @@ -499,7 +499,7 @@ func @while_result_type_mismatch() {

// -----

func @while_bad_terminator() {
func.func @while_bad_terminator() {
%true = arith.constant true
// expected-error@+1 {{expects the 'after' region to terminate with 'scf.yield'}}
scf.while : () -> () {
Expand All @@ -512,7 +512,7 @@ func @while_bad_terminator() {

// -----

func @execute_region() {
func.func @execute_region() {
// expected-error @+1 {{region cannot have any arguments}}
"scf.execute_region"() ({
^bb0(%i : i32):
Expand Down
16 changes: 8 additions & 8 deletions mlir/test/Dialect/SCF/loop-pipelining.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
// Epilogue:
// CHECK-NEXT: %[[ADD1:.*]] = arith.addf %[[L1]], %{{.*}} : f32
// CHECK-NEXT: memref.store %[[ADD1]], %[[R]][%[[C3]]] : memref<?xf32>
func @simple_pipeline(%A: memref<?xf32>, %result: memref<?xf32>) {
func.func @simple_pipeline(%A: memref<?xf32>, %result: memref<?xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
Expand Down Expand Up @@ -59,7 +59,7 @@ func @simple_pipeline(%A: memref<?xf32>, %result: memref<?xf32>) {
// CHECK-NEXT: memref.store %[[ADD1]], %[[R]][%[[C6]]] : memref<?xf32>
// CHECK-NEXT: %[[ADD2:.*]] = arith.addf %[[L2]]#1, %{{.*}} : f32
// CHECK-NEXT: memref.store %[[ADD2]], %[[R]][%[[C9]]] : memref<?xf32>
func @simple_pipeline_step(%A: memref<?xf32>, %result: memref<?xf32>) {
func.func @simple_pipeline_step(%A: memref<?xf32>, %result: memref<?xf32>) {
%c0 = arith.constant 0 : index
%c3 = arith.constant 3 : index
%c11 = arith.constant 11 : index
Expand Down Expand Up @@ -114,7 +114,7 @@ func @simple_pipeline_step(%A: memref<?xf32>, %result: memref<?xf32>) {
// ANNOTATE: arith.addf {{.*}} {__test_pipelining_iteration = 0 : i32, __test_pipelining_part = "epilogue"}
// ANNOTATE: memref.store {{.*}} {__test_pipelining_iteration = 1 : i32, __test_pipelining_part = "epilogue"}

func @three_stage(%A: memref<?xf32>, %result: memref<?xf32>) {
func.func @three_stage(%A: memref<?xf32>, %result: memref<?xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
Expand Down Expand Up @@ -164,7 +164,7 @@ func @three_stage(%A: memref<?xf32>, %result: memref<?xf32>) {
// CHECK-NEXT: memref.store %[[ADD3]], %[[R]][%[[C8]]] : memref<?xf32>
// CHECK-NEXT: %[[ADD4:.*]] = arith.addf %[[LR]]#3, %{{.*}} : f32
// CHECK-NEXT: memref.store %[[ADD4]], %[[R]][%[[C9]]] : memref<?xf32>
func @long_liverange(%A: memref<?xf32>, %result: memref<?xf32>) {
func.func @long_liverange(%A: memref<?xf32>, %result: memref<?xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
Expand Down Expand Up @@ -214,7 +214,7 @@ func @long_liverange(%A: memref<?xf32>, %result: memref<?xf32>) {
// CHECK-NEXT: %[[MUL3:.*]] = arith.mulf %[[ADD3]], %[[LR]]#1 : f32
// CHECK-NEXT: memref.store %[[MUL2]], %[[R]][%[[C8]]] : memref<?xf32>
// CHECK-NEXT: memref.store %[[MUL3]], %[[R]][%[[C9]]] : memref<?xf32>
func @multiple_uses(%A: memref<?xf32>, %result: memref<?xf32>) {
func.func @multiple_uses(%A: memref<?xf32>, %result: memref<?xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
Expand Down Expand Up @@ -250,7 +250,7 @@ func @multiple_uses(%A: memref<?xf32>, %result: memref<?xf32>) {
// Epilogue:
// CHECK-NEXT: %[[ADD1:.*]] = arith.addf %[[LR]]#1, %[[LR]]#0 : f32
// CHECK-NEXT: memref.store %[[ADD1]], %[[R]][%[[C0]]] : memref<?xf32>
func @loop_carried(%A: memref<?xf32>, %result: memref<?xf32>) {
func.func @loop_carried(%A: memref<?xf32>, %result: memref<?xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
Expand Down Expand Up @@ -288,7 +288,7 @@ func @loop_carried(%A: memref<?xf32>, %result: memref<?xf32>) {
// Epilogue:
// CHECK-NEXT: %[[ADD2:.*]] = arith.addf %[[R]]#2, %[[R]]#1 : f32
// CHECK-NEXT: return %[[ADD2]] : f32
func @backedge_different_stage(%A: memref<?xf32>) -> f32 {
func.func @backedge_different_stage(%A: memref<?xf32>) -> f32 {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
Expand Down Expand Up @@ -324,7 +324,7 @@ func @backedge_different_stage(%A: memref<?xf32>) -> f32 {
// Epilogue:
// CHECK-NEXT: %[[ADD1:.*]] = arith.addf %[[R]]#1, %[[R]]#0 : f32
// CHECK-NEXT: return %[[ADD1]] : f32
func @backedge_same_stage(%A: memref<?xf32>) -> f32 {
func.func @backedge_same_stage(%A: memref<?xf32>) -> f32 {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
Expand Down
8 changes: 4 additions & 4 deletions mlir/test/Dialect/SCF/loop-range.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s -pass-pipeline='func.func(scf-for-loop-range-folding)' -split-input-file | FileCheck %s

func @fold_one_loop(%arg0: memref<?xi32>, %arg1: index, %arg2: index) {
func.func @fold_one_loop(%arg0: memref<?xi32>, %arg1: index, %arg2: index) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
Expand Down Expand Up @@ -28,7 +28,7 @@ func @fold_one_loop(%arg0: memref<?xi32>, %arg1: index, %arg2: index) {
// CHECK: %[[I5:.*]] = arith.muli %[[I4]], %[[I4]] : i32
// CHECK: memref.store %[[I5]], %[[ARG0]]{{\[}}%[[I]]

func @fold_one_loop2(%arg0: memref<?xi32>, %arg1: index, %arg2: index) {
func.func @fold_one_loop2(%arg0: memref<?xi32>, %arg1: index, %arg2: index) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
Expand Down Expand Up @@ -61,7 +61,7 @@ func @fold_one_loop2(%arg0: memref<?xi32>, %arg1: index, %arg2: index) {
// CHECK: %[[I5:.*]] = arith.muli %[[I4]], %[[I4]] : i32
// CHECK: memref.store %[[I5]], %[[ARG0]]{{\[}}%[[I]]

func @fold_two_loops(%arg0: memref<?xi32>, %arg1: index, %arg2: index) {
func.func @fold_two_loops(%arg0: memref<?xi32>, %arg1: index, %arg2: index) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
Expand Down Expand Up @@ -98,7 +98,7 @@ func @fold_two_loops(%arg0: memref<?xi32>, %arg1: index, %arg2: index) {
// If an instruction's operands are not defined outside the loop, we cannot
// perform the optimization, as is the case with the arith.muli below. (If
// paired with loop invariant code motion we can continue.)
func @fold_only_first_add(%arg0: memref<?xi32>, %arg1: index, %arg2: index) {
func.func @fold_only_first_add(%arg0: memref<?xi32>, %arg1: index, %arg2: index) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
Expand Down
14 changes: 7 additions & 7 deletions mlir/test/Dialect/SCF/loop-unroll.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
// RUN: mlir-opt %s -test-loop-unrolling='unroll-factor=2 annotate=true' | FileCheck %s --check-prefix UNROLL-BY-2-ANNOTATE
// RUN: mlir-opt %s --affine-loop-unroll='unroll-factor=6 unroll-up-to-factor=true' | FileCheck %s --check-prefix UNROLL-UP-TO

func @dynamic_loop_unroll(%arg0 : index, %arg1 : index, %arg2 : index,
func.func @dynamic_loop_unroll(%arg0 : index, %arg1 : index, %arg2 : index,
%arg3: memref<?xf32>) {
%0 = arith.constant 7.0 : f32
scf.for %i0 = %arg0 to %arg1 step %arg2 {
Expand Down Expand Up @@ -83,7 +83,7 @@ func @dynamic_loop_unroll(%arg0 : index, %arg1 : index, %arg2 : index,
// UNROLL-BY-3-NEXT: }
// UNROLL-BY-3-NEXT: return

func @dynamic_loop_unroll_outer_by_2(
func.func @dynamic_loop_unroll_outer_by_2(
%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index,
%arg5 : index, %arg6: memref<?xf32>) {
%0 = arith.constant 7.0 : f32
Expand Down Expand Up @@ -118,7 +118,7 @@ func @dynamic_loop_unroll_outer_by_2(
// UNROLL-OUTER-BY-2-NEXT: }
// UNROLL-OUTER-BY-2-NEXT: return

func @dynamic_loop_unroll_inner_by_2(
func.func @dynamic_loop_unroll_inner_by_2(
%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index,
%arg5 : index, %arg6: memref<?xf32>) {
%0 = arith.constant 7.0 : f32
Expand Down Expand Up @@ -154,7 +154,7 @@ func @dynamic_loop_unroll_inner_by_2(

// Test that no epilogue clean-up loop is generated because the trip count is
// a multiple of the unroll factor.
func @static_loop_unroll_by_2(%arg0 : memref<?xf32>) {
func.func @static_loop_unroll_by_2(%arg0 : memref<?xf32>) {
%0 = arith.constant 7.0 : f32
%lb = arith.constant 0 : index
%ub = arith.constant 20 : index
Expand Down Expand Up @@ -186,7 +186,7 @@ func @static_loop_unroll_by_2(%arg0 : memref<?xf32>) {

// Test that epilogue clean up loop is generated (trip count is not
// a multiple of unroll factor).
func @static_loop_unroll_by_3(%arg0 : memref<?xf32>) {
func.func @static_loop_unroll_by_3(%arg0 : memref<?xf32>) {
%0 = arith.constant 7.0 : f32
%lb = arith.constant 0 : index
%ub = arith.constant 20 : index
Expand Down Expand Up @@ -223,7 +223,7 @@ func @static_loop_unroll_by_3(%arg0 : memref<?xf32>) {

// Test that the single iteration epilogue loop body is promoted to the loops
// containing block.
func @static_loop_unroll_by_3_promote_epilogue(%arg0 : memref<?xf32>) {
func.func @static_loop_unroll_by_3_promote_epilogue(%arg0 : memref<?xf32>) {
%0 = arith.constant 7.0 : f32
%lb = arith.constant 0 : index
%ub = arith.constant 10 : index
Expand Down Expand Up @@ -256,7 +256,7 @@ func @static_loop_unroll_by_3_promote_epilogue(%arg0 : memref<?xf32>) {
// UNROLL-BY-3-NEXT: return

// Test unroll-up-to functionality.
func @static_loop_unroll_up_to_factor(%arg0 : memref<?xf32>) {
func.func @static_loop_unroll_up_to_factor(%arg0 : memref<?xf32>) {
%0 = arith.constant 7.0 : f32
%lb = arith.constant 0 : index
%ub = arith.constant 2 : index
Expand Down
24 changes: 12 additions & 12 deletions mlir/test/Dialect/SCF/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// Verify the generic form can be parsed.
// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s

func @std_for(%arg0 : index, %arg1 : index, %arg2 : index) {
func.func @std_for(%arg0 : index, %arg1 : index, %arg2 : index) {
scf.for %i0 = %arg0 to %arg1 step %arg2 {
scf.for %i1 = %arg0 to %arg1 step %arg2 {
%min_cmp = arith.cmpi slt, %i0, %i1 : index
Expand All @@ -26,7 +26,7 @@ func @std_for(%arg0 : index, %arg1 : index, %arg2 : index) {
// CHECK-NEXT: %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : index
// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {

func @std_if(%arg0: i1, %arg1: f32) {
func.func @std_if(%arg0: i1, %arg1: f32) {
scf.if %arg0 {
%0 = arith.addf %arg1, %arg1 : f32
}
Expand All @@ -36,7 +36,7 @@ func @std_if(%arg0: i1, %arg1: f32) {
// CHECK-NEXT: scf.if %{{.*}} {
// CHECK-NEXT: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : f32

func @std_if_else(%arg0: i1, %arg1: f32) {
func.func @std_if_else(%arg0: i1, %arg1: f32) {
scf.if %arg0 {
%0 = arith.addf %arg1, %arg1 : f32
} else {
Expand All @@ -50,7 +50,7 @@ func @std_if_else(%arg0: i1, %arg1: f32) {
// CHECK-NEXT: } else {
// CHECK-NEXT: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : f32

func @std_parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
func.func @std_parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
%arg3 : index, %arg4 : index) {
%step = arith.constant 1 : index
scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
Expand Down Expand Up @@ -113,7 +113,7 @@ func @std_parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield

func @parallel_explicit_yield(
func.func @parallel_explicit_yield(
%arg0: index, %arg1: index, %arg2: index) {
scf.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) {
scf.yield
Expand All @@ -131,7 +131,7 @@ func @parallel_explicit_yield(
// CHECK-NEXT: return
// CHECK-NEXT: }

func @std_if_yield(%arg0: i1, %arg1: f32)
func.func @std_if_yield(%arg0: i1, %arg1: f32)
{
%x, %y = scf.if %arg0 -> (f32, f32) {
%0 = arith.addf %arg1, %arg1 : f32
Expand All @@ -157,7 +157,7 @@ func @std_if_yield(%arg0: i1, %arg1: f32)
// CHECK-NEXT: scf.yield %[[T3]], %[[T4]] : f32, f32
// CHECK-NEXT: }

func @std_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) {
func.func @std_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) {
%s0 = arith.constant 0.0 : f32
%result = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%si = %s0) -> (f32) {
%sn = arith.addf %si, %si : f32
Expand All @@ -177,7 +177,7 @@ func @std_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) {
// CHECK-NEXT: }


func @std_for_yield_multi(%arg0 : index, %arg1 : index, %arg2 : index) {
func.func @std_for_yield_multi(%arg0 : index, %arg1 : index, %arg2 : index) {
%s0 = arith.constant 0.0 : f32
%t0 = arith.constant 1 : i32
%u0 = arith.constant 1.0 : f32
Expand All @@ -204,7 +204,7 @@ func @std_for_yield_multi(%arg0 : index, %arg1 : index, %arg2 : index) {
// CHECK-NEXT: scf.yield %[[NEXT1]], %[[NEXT2]], %[[NEXT3]] : f32, i32, f32


func @conditional_reduce(%buffer: memref<1024xf32>, %lb: index, %ub: index, %step: index) -> (f32) {
func.func @conditional_reduce(%buffer: memref<1024xf32>, %lb: index, %ub: index, %step: index) -> (f32) {
%sum_0 = arith.constant 0.0 : f32
%c0 = arith.constant 0.0 : f32
%sum = scf.for %iv = %lb to %ub step %step iter_args(%sum_iter = %sum_0) -> (f32) {
Expand Down Expand Up @@ -242,7 +242,7 @@ func @conditional_reduce(%buffer: memref<1024xf32>, %lb: index, %ub: index, %ste
// CHECK-NEXT: return %[[RESULT]]

// CHECK-LABEL: @while
func @while() {
func.func @while() {
%0 = "test.get_some_value"() : () -> i32
%1 = "test.get_some_value"() : () -> f32

Expand All @@ -265,7 +265,7 @@ func @while() {
}

// CHECK-LABEL: @infinite_while
func @infinite_while() {
func.func @infinite_while() {
%true = arith.constant true

// CHECK: scf.while : () -> () {
Expand All @@ -281,7 +281,7 @@ func @infinite_while() {
}

// CHECK-LABEL: func @execute_region
func @execute_region() -> i64 {
func.func @execute_region() -> i64 {
// CHECK: scf.execute_region -> i64 {
// CHECK-NEXT: arith.constant
// CHECK-NEXT: scf.yield
Expand Down
24 changes: 12 additions & 12 deletions mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='func.func(scf-parallel-loop-fusion)' -split-input-file | FileCheck %s

func @fuse_empty_loops() {
func.func @fuse_empty_loops() {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
Expand All @@ -24,7 +24,7 @@ func @fuse_empty_loops() {

// -----

func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
%C: memref<2x2xf32>, %result: memref<2x2xf32>) {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
Expand Down Expand Up @@ -70,7 +70,7 @@ func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>,

// -----

func @fuse_three(%lhs: memref<100x10xf32>, %rhs: memref<100xf32>,
func.func @fuse_three(%lhs: memref<100x10xf32>, %rhs: memref<100xf32>,
%result: memref<100x10xf32>) {
%c100 = arith.constant 100 : index
%c10 = arith.constant 10 : index
Expand Down Expand Up @@ -127,7 +127,7 @@ func @fuse_three(%lhs: memref<100x10xf32>, %rhs: memref<100xf32>,

// -----

func @do_not_fuse_nested_ploop1() {
func.func @do_not_fuse_nested_ploop1() {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
Expand All @@ -149,7 +149,7 @@ func @do_not_fuse_nested_ploop1() {

// -----

func @do_not_fuse_nested_ploop2() {
func.func @do_not_fuse_nested_ploop2() {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
Expand All @@ -171,7 +171,7 @@ func @do_not_fuse_nested_ploop2() {

// -----

func @do_not_fuse_loops_unmatching_num_loops() {
func.func @do_not_fuse_loops_unmatching_num_loops() {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
Expand All @@ -189,7 +189,7 @@ func @do_not_fuse_loops_unmatching_num_loops() {

// -----

func @do_not_fuse_loops_with_side_effecting_ops_in_between() {
func.func @do_not_fuse_loops_with_side_effecting_ops_in_between() {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
Expand All @@ -208,7 +208,7 @@ func @do_not_fuse_loops_with_side_effecting_ops_in_between() {

// -----

func @do_not_fuse_loops_unmatching_iteration_space() {
func.func @do_not_fuse_loops_unmatching_iteration_space() {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
Expand All @@ -227,7 +227,7 @@ func @do_not_fuse_loops_unmatching_iteration_space() {

// -----

func @do_not_fuse_unmatching_write_read_patterns(
func.func @do_not_fuse_unmatching_write_read_patterns(
%A: memref<2x2xf32>, %B: memref<2x2xf32>,
%C: memref<2x2xf32>, %result: memref<2x2xf32>) {
%c2 = arith.constant 2 : index
Expand Down Expand Up @@ -258,7 +258,7 @@ func @do_not_fuse_unmatching_write_read_patterns(

// -----

func @do_not_fuse_unmatching_read_write_patterns(
func.func @do_not_fuse_unmatching_read_write_patterns(
%A: memref<2x2xf32>, %B: memref<2x2xf32>, %common_buf: memref<2x2xf32>) {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
Expand Down Expand Up @@ -288,7 +288,7 @@ func @do_not_fuse_unmatching_read_write_patterns(

// -----

func @do_not_fuse_loops_with_memref_defined_in_loop_bodies() {
func.func @do_not_fuse_loops_with_memref_defined_in_loop_bodies() {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
Expand All @@ -310,7 +310,7 @@ func @do_not_fuse_loops_with_memref_defined_in_loop_bodies() {

// -----

func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
func.func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
%C: memref<2x2xf32>, %result: memref<2x2xf32>) {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/SCF/parallel-loop-specialization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#map0 = affine_map<()[s0, s1] -> (1024, s0 - s1)>
#map1 = affine_map<()[s0, s1] -> (64, s0 - s1)>

func @parallel_loop(%outer_i0: index, %outer_i1: index, %A: memref<?x?xf32>, %B: memref<?x?xf32>,
func.func @parallel_loop(%outer_i0: index, %outer_i1: index, %A: memref<?x?xf32>, %B: memref<?x?xf32>,
%C: memref<?x?xf32>, %result: memref<?x?xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
Expand Down
8 changes: 4 additions & 4 deletions mlir/test/Dialect/SCF/parallel-loop-tiling-inbound-check.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s -pass-pipeline='func.func(scf-parallel-loop-tiling{parallel-loop-tile-sizes=1,4 no-min-max-bounds=true})' -split-input-file | FileCheck %s

func @parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
func.func @parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
%arg3 : index, %arg4 : index, %arg5 : index,
%A: memref<?x?xf32>, %B: memref<?x?xf32>,
%C: memref<?x?xf32>, %result: memref<?x?xf32>) {
Expand Down Expand Up @@ -45,7 +45,7 @@ func @parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,

// -----

func @static_loop_with_step() {
func.func @static_loop_with_step() {
%c0 = arith.constant 0 : index
%c3 = arith.constant 3 : index
%c22 = arith.constant 22 : index
Expand Down Expand Up @@ -76,7 +76,7 @@ func @static_loop_with_step() {

// -----

func @tile_nested_innermost() {
func.func @tile_nested_innermost() {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
Expand Down Expand Up @@ -124,7 +124,7 @@ func @tile_nested_innermost() {

// -----

func @tile_nested_in_non_ploop() {
func.func @tile_nested_in_non_ploop() {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
Expand Down
8 changes: 4 additions & 4 deletions mlir/test/Dialect/SCF/parallel-loop-tiling.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s -pass-pipeline='func.func(scf-parallel-loop-tiling{parallel-loop-tile-sizes=1,4})' -split-input-file | FileCheck %s

func @parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
func.func @parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
%arg3 : index, %arg4 : index, %arg5 : index,
%A: memref<?x?xf32>, %B: memref<?x?xf32>,
%C: memref<?x?xf32>, %result: memref<?x?xf32>) {
Expand Down Expand Up @@ -37,7 +37,7 @@ func @parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,

// -----

func @static_loop_with_step() {
func.func @static_loop_with_step() {
%c0 = arith.constant 0 : index
%c3 = arith.constant 3 : index
%c22 = arith.constant 22 : index
Expand Down Expand Up @@ -67,7 +67,7 @@ func @static_loop_with_step() {

// -----

func @tile_nested_innermost() {
func.func @tile_nested_innermost() {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
Expand Down Expand Up @@ -115,7 +115,7 @@ func @tile_nested_innermost() {

// -----

func @tile_nested_in_non_ploop() {
func.func @tile_nested_in_non_ploop() {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Shape/bufferize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// CHECK: "test.sink"(%[[TENSOR]]) : (tensor<2xf16>) -> ()
// CHECK: return
// CHECK: }
func @shape_assuming() {
func.func @shape_assuming() {
%0 = shape.const_witness true
%1 = shape.assuming %0 -> (tensor<2xf16>) {
%2 = "test.source"() : () -> (tensor<2xf16>)
Expand Down
224 changes: 112 additions & 112 deletions mlir/test/Dialect/Shape/canonicalize.mlir

Large diffs are not rendered by default.

42 changes: 21 additions & 21 deletions mlir/test/Dialect/Shape/invalid.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics

func @reduce_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) {
func.func @reduce_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) {
// expected-error@+1 {{ReduceOp body is expected to have 3 arguments}}
%num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
^bb0(%index: index, %dim: !shape.size):
Expand All @@ -11,7 +11,7 @@ func @reduce_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) {

// -----

func @reduce_op_arg0_wrong_type(%shape : !shape.shape, %init : !shape.size) {
func.func @reduce_op_arg0_wrong_type(%shape : !shape.shape, %init : !shape.size) {
// expected-error@+1 {{argument 0 of ReduceOp body is expected to be of IndexType}}
%num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
^bb0(%index: f32, %dim: !shape.size, %acc: !shape.size):
Expand All @@ -24,7 +24,7 @@ func @reduce_op_arg0_wrong_type(%shape : !shape.shape, %init : !shape.size) {

// -----

func @reduce_op_arg1_wrong_type(%shape : !shape.shape, %init : !shape.size) {
func.func @reduce_op_arg1_wrong_type(%shape : !shape.shape, %init : !shape.size) {
// expected-error@+1 {{argument 1 of ReduceOp body is expected to be of SizeType if the ReduceOp operates on a ShapeType}}
%num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
^bb0(%index: index, %dim: f32, %lci: !shape.size):
Expand All @@ -35,7 +35,7 @@ func @reduce_op_arg1_wrong_type(%shape : !shape.shape, %init : !shape.size) {

// -----

func @reduce_op_arg1_wrong_type(%shape : tensor<?xindex>, %init : index) {
func.func @reduce_op_arg1_wrong_type(%shape : tensor<?xindex>, %init : index) {
// expected-error@+1 {{argument 1 of ReduceOp body is expected to be of IndexType if the ReduceOp operates on an extent tensor}}
%num_elements = shape.reduce(%shape, %init) : tensor<?xindex> -> index {
^bb0(%index: index, %dim: f32, %lci: index):
Expand All @@ -46,7 +46,7 @@ func @reduce_op_arg1_wrong_type(%shape : tensor<?xindex>, %init : index) {

// -----

func @reduce_op_init_type_mismatch(%shape : !shape.shape, %init : f32) {
func.func @reduce_op_init_type_mismatch(%shape : !shape.shape, %init : f32) {
// expected-error@+1 {{type mismatch between argument 2 of ReduceOp body and initial value 0}}
%num_elements = shape.reduce(%shape, %init) : !shape.shape -> f32 {
^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
Expand All @@ -57,7 +57,7 @@ func @reduce_op_init_type_mismatch(%shape : !shape.shape, %init : f32) {

// -----

func @yield_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) {
func.func @yield_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) {
// expected-error@+3 {{number of operands does not match number of results of its parent}}
%num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
Expand All @@ -68,7 +68,7 @@ func @yield_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) {

// -----

func @yield_op_type_mismatch(%shape : !shape.shape, %init : !shape.size) {
func.func @yield_op_type_mismatch(%shape : !shape.shape, %init : !shape.size) {
// expected-error@+4 {{types mismatch between yield op and its parent}}
%num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
Expand All @@ -80,15 +80,15 @@ func @yield_op_type_mismatch(%shape : !shape.shape, %init : !shape.size) {

// -----

func @assuming_all_op_too_few_operands() {
func.func @assuming_all_op_too_few_operands() {
// expected-error@+1 {{no operands specified}}
%w0 = shape.assuming_all
return
}

// -----

func @shape_of(%value_arg : !shape.value_shape,
func.func @shape_of(%value_arg : !shape.value_shape,
%shaped_arg : tensor<?x3x4xf32>) {
// expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `shape` to propagate them}}
%0 = shape.shape_of %value_arg : !shape.value_shape -> tensor<?xindex>
Expand All @@ -97,23 +97,23 @@ func @shape_of(%value_arg : !shape.value_shape,

// -----

func @shape_of_incompatible_return_types(%value_arg : tensor<1x2xindex>) {
func.func @shape_of_incompatible_return_types(%value_arg : tensor<1x2xindex>) {
// expected-error@+1 {{'shape.shape_of' op inferred type(s) 'tensor<2xindex>' are incompatible with return type(s) of operation 'tensor<3xindex>'}}
%0 = shape.shape_of %value_arg : tensor<1x2xindex> -> tensor<3xindex>
return
}

// -----

func @rank(%arg : !shape.shape) {
func.func @rank(%arg : !shape.shape) {
// expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
%0 = shape.rank %arg : !shape.shape -> index
return
}

// -----

func @get_extent(%arg : tensor<?xindex>) -> index {
func.func @get_extent(%arg : tensor<?xindex>) -> index {
%c0 = shape.const_size 0
// expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
%result = shape.get_extent %arg, %c0 : tensor<?xindex>, !shape.size -> index
Expand All @@ -122,31 +122,31 @@ func @get_extent(%arg : tensor<?xindex>) -> index {

// -----

func @mul(%lhs : !shape.size, %rhs : index) -> index {
func.func @mul(%lhs : !shape.size, %rhs : index) -> index {
// expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
%result = shape.mul %lhs, %rhs : !shape.size, index -> index
return %result : index
}

// -----

func @num_elements(%arg : !shape.shape) -> index {
func.func @num_elements(%arg : !shape.shape) -> index {
// expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
%result = shape.num_elements %arg : !shape.shape -> index
return %result : index
}

// -----

func @add(%lhs : !shape.size, %rhs : index) -> index {
func.func @add(%lhs : !shape.size, %rhs : index) -> index {
// expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
%result = shape.add %lhs, %rhs : !shape.size, index -> index
return %result : index
}

// -----

func @broadcast(%arg0 : !shape.shape, %arg1 : !shape.shape) -> tensor<?xindex> {
func.func @broadcast(%arg0 : !shape.shape, %arg1 : !shape.shape) -> tensor<?xindex> {
// expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `shape` to propagate them}}
%result = shape.broadcast %arg0, %arg1
: !shape.shape, !shape.shape -> tensor<?xindex>
Expand All @@ -156,7 +156,7 @@ func @broadcast(%arg0 : !shape.shape, %arg1 : !shape.shape) -> tensor<?xindex> {

// -----

func @broadcast(%arg0 : !shape.shape, %arg1 : tensor<?xindex>) -> tensor<?xindex> {
func.func @broadcast(%arg0 : !shape.shape, %arg1 : tensor<?xindex>) -> tensor<?xindex> {
// expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `shape` to propagate them}}
%result = shape.broadcast %arg0, %arg1
: !shape.shape, tensor<?xindex> -> tensor<?xindex>
Expand Down Expand Up @@ -231,7 +231,7 @@ shape.function_library @shape_lib {
// expected-error@+1 {{required to be shape function library}}
module attributes {shape.lib = @fn} {

func @fn(%arg: !shape.value_shape) -> !shape.shape {
func.func @fn(%arg: !shape.value_shape) -> !shape.shape {
%0 = shape.shape_of %arg : !shape.value_shape -> !shape.shape
return %0 : !shape.shape
}
Expand All @@ -242,7 +242,7 @@ func @fn(%arg: !shape.value_shape) -> !shape.shape {

// Test that op referred to by shape lib is a shape function library.

func @fn(%arg: !shape.value_shape) -> !shape.shape {
func.func @fn(%arg: !shape.value_shape) -> !shape.shape {
// expected-error@+1 {{SymbolTable}}
%0 = shape.shape_of %arg {shape.lib = @fn} : !shape.value_shape -> !shape.shape
return %0 : !shape.shape
Expand All @@ -257,7 +257,7 @@ module attributes {shape.lib = @fn} { }

// -----

func @fn(%arg: !shape.shape) -> !shape.witness {
func.func @fn(%arg: !shape.shape) -> !shape.witness {
// expected-error@+1 {{required at least 2 input shapes}}
%0 = shape.cstr_broadcastable %arg : !shape.shape
return %0 : !shape.witness
Expand All @@ -267,7 +267,7 @@ func @fn(%arg: !shape.shape) -> !shape.witness {

// Test that type inference flags the wrong return type.

func @const_shape() {
func.func @const_shape() {
// expected-error@+1 {{'tensor<3xindex>' are incompatible with return type(s) of operation 'tensor<2xindex>'}}
%0 = shape.const_shape [4, 5, 6] : tensor<2xindex>
return
Expand Down
88 changes: 44 additions & 44 deletions mlir/test/Dialect/Shape/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s

// CHECK-LABEL: shape_num_elements
func @shape_num_elements(%shape : !shape.shape) -> !shape.size {
func.func @shape_num_elements(%shape : !shape.shape) -> !shape.size {
%init = shape.const_size 1
%num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
^bb0(%index : index, %extent : !shape.size, %acc : !shape.size):
Expand All @@ -16,7 +16,7 @@ func @shape_num_elements(%shape : !shape.shape) -> !shape.size {
}

// CHECK-LABEL: extent_tensor_num_elements
func @extent_tensor_num_elements(%shape : tensor<?xindex>) -> index {
func.func @extent_tensor_num_elements(%shape : tensor<?xindex>) -> index {
%init = arith.constant 1 : index
%num_elements = shape.reduce(%shape, %init) : tensor<?xindex> -> index {
^bb0(%index : index, %extent : index, %acc : index):
Expand All @@ -26,78 +26,78 @@ func @extent_tensor_num_elements(%shape : tensor<?xindex>) -> index {
return %num_elements : index
}

func @test_shape_num_elements_unknown() {
func.func @test_shape_num_elements_unknown() {
%0 = "shape.unknown_shape"() : () -> !shape.shape
%1 = call @shape_num_elements(%0) : (!shape.shape) -> (!shape.size)
%2 = "shape.print"(%1) : (!shape.size) -> !shape.size
return
}

func @const_shape() {
func.func @const_shape() {
%0 = shape.const_shape [1, 2, 3] : !shape.shape
%2 = shape.const_shape [4, 5, 6] : tensor<3xindex>
return
}

func @test_shape_num_elements_fixed() {
func.func @test_shape_num_elements_fixed() {
%0 = shape.const_shape [1, 57, 92] : !shape.shape
%1 = call @shape_num_elements(%0) : (!shape.shape) -> (!shape.size)
%3 = "shape.print"(%1) : (!shape.size) -> !shape.size
return
}

func @test_broadcast_fixed() {
func.func @test_broadcast_fixed() {
%0 = shape.const_shape [10, 1, 57, 92] : !shape.shape
%1 = shape.const_shape [4, 57, 92] : !shape.shape
%2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape -> !shape.shape
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
return
}

func @test_broadcast_extents() -> tensor<4xindex> {
func.func @test_broadcast_extents() -> tensor<4xindex> {
%0 = shape.const_shape [10, 1, 57, 92] : tensor<4xindex>
%1 = shape.const_shape [4, 57, 92] : tensor<3xindex>
%2 = shape.broadcast %0, %1 : tensor<4xindex>, tensor<3xindex> -> tensor<4xindex>
return %2 : tensor<4xindex>
}

func @test_shape_any_fixed() {
func.func @test_shape_any_fixed() {
%0 = shape.const_shape [4, 57, 92] : !shape.shape
%1 = shape.const_shape [4, 57, 92] : !shape.shape
%2 = "shape.meet"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
return
}

func @test_shape_any_unknown() {
func.func @test_shape_any_unknown() {
%0 = shape.const_shape [4, -1, 92] : !shape.shape
%1 = shape.const_shape [-1, 57, 92] : !shape.shape
%2 = "shape.meet"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
return
}

func @test_shape_any_fixed_mismatch() {
func.func @test_shape_any_fixed_mismatch() {
%0 = shape.const_shape [4, 57, 92] : !shape.shape
%1 = shape.const_shape [2, 57, 92] : !shape.shape
%2 = "shape.meet"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
return
}

func @test_parse_const_shape() {
func.func @test_parse_const_shape() {
%0 = shape.const_shape [] : !shape.shape
%1 = shape.const_shape [1, 2, 3] : !shape.shape
%2 = shape.const_shape [1, 2, 3] : tensor<3xindex>
return
}

func @test_shape_of(%arg0: tensor<?xf32>) -> tensor<?xindex> {
func.func @test_shape_of(%arg0: tensor<?xf32>) -> tensor<?xindex> {
%0 = shape.shape_of %arg0 : tensor<?xf32> -> tensor<?xindex>
return %0 : tensor<?xindex>
}

func @test_constraints() {
func.func @test_constraints() {
%0 = shape.const_shape [] : !shape.shape
%1 = shape.const_shape [1, 2, 3] : !shape.shape
%true = arith.constant true
Expand All @@ -114,19 +114,19 @@ func @test_constraints() {
return
}

func @eq_on_extent_tensors(%lhs : tensor<?xindex>,
func.func @eq_on_extent_tensors(%lhs : tensor<?xindex>,
%rhs : tensor<?xindex>) {
%w0 = shape.cstr_eq %lhs, %rhs : tensor<?xindex>, tensor<?xindex>
return
}

func @broadcastable_on_extent_tensors(%lhs : tensor<?xindex>,
func.func @broadcastable_on_extent_tensors(%lhs : tensor<?xindex>,
%rhs : tensor<?xindex>) {
%w0 = shape.cstr_broadcastable %lhs, %rhs : tensor<?xindex>, tensor<?xindex>
return
}

func @mul(%size_arg : !shape.size, %index_arg : index) {
func.func @mul(%size_arg : !shape.size, %index_arg : index) {
%size_prod = shape.mul %size_arg, %size_arg
: !shape.size, !shape.size -> !shape.size
%index_prod = shape.mul %index_arg, %index_arg : index, index -> index
Expand All @@ -135,7 +135,7 @@ func @mul(%size_arg : !shape.size, %index_arg : index) {
return
}

func @div(%size_arg : !shape.size, %index_arg : index) {
func.func @div(%size_arg : !shape.size, %index_arg : index) {
%size_div = shape.div %size_arg, %size_arg
: !shape.size, !shape.size -> !shape.size
%index_div = shape.div %index_arg, %index_arg : index, index -> index
Expand All @@ -144,7 +144,7 @@ func @div(%size_arg : !shape.size, %index_arg : index) {
return
}

func @add(%size_arg : !shape.size, %index_arg : index) {
func.func @add(%size_arg : !shape.size, %index_arg : index) {
%size_sum = shape.add %size_arg, %size_arg
: !shape.size, !shape.size -> !shape.size
%index_sum = shape.add %index_arg, %index_arg : index, index -> index
Expand All @@ -153,7 +153,7 @@ func @add(%size_arg : !shape.size, %index_arg : index) {
return
}

func @const_size() {
func.func @const_size() {
// CHECK: %c1 = shape.const_size 1
// CHECK: %c2 = shape.const_size 2
// CHECK: %c2_0 = shape.const_size 2
Expand All @@ -163,66 +163,66 @@ func @const_size() {
return
}

func @test_to_extent_tensor(%arg: !shape.shape) -> tensor<3xindex> {
func.func @test_to_extent_tensor(%arg: !shape.shape) -> tensor<3xindex> {
%0 = shape.to_extent_tensor %arg : !shape.shape -> tensor<3xindex>
return %0 : tensor<3xindex>
}

func @test_identity_to_extent_tensor(%arg: tensor<3xindex>) -> tensor<3xindex> {
func.func @test_identity_to_extent_tensor(%arg: tensor<3xindex>) -> tensor<3xindex> {
%0 = shape.to_extent_tensor %arg : tensor<3xindex> -> tensor<3xindex>
return %0 : tensor<3xindex>
}

func @test_from_extent_tensor(%arg: tensor<?xindex>) -> !shape.shape {
func.func @test_from_extent_tensor(%arg: tensor<?xindex>) -> !shape.shape {
%0 = shape.from_extent_tensor %arg : tensor<?xindex>
return %0 : !shape.shape
}

func @rank(%shape : !shape.shape) -> !shape.size {
func.func @rank(%shape : !shape.shape) -> !shape.size {
%rank = shape.rank %shape : !shape.shape -> !shape.size
return %rank : !shape.size
}

func @rank_on_extent_tensor(%shape : tensor<?xindex>) -> index {
func.func @rank_on_extent_tensor(%shape : tensor<?xindex>) -> index {
%rank = shape.rank %shape : tensor<?xindex> -> index
return %rank : index
}

func @shape_eq_on_shapes(%a : !shape.shape, %b : !shape.shape) -> i1 {
func.func @shape_eq_on_shapes(%a : !shape.shape, %b : !shape.shape) -> i1 {
%result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
return %result : i1
}

func @shape_eq_on_tensors(%a : tensor<?xindex>, %b : tensor<?xindex>) -> i1 {
func.func @shape_eq_on_tensors(%a : tensor<?xindex>, %b : tensor<?xindex>) -> i1 {
%result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
return %result : i1
}

func @shape_eq_on_mixed(%a : tensor<?xindex>, %b : !shape.shape) -> i1 {
func.func @shape_eq_on_mixed(%a : tensor<?xindex>, %b : !shape.shape) -> i1 {
%result = shape.shape_eq %a, %b : tensor<?xindex>, !shape.shape
return %result : i1
}

func @get_extent_on_shape(%arg : !shape.shape) -> !shape.size {
func.func @get_extent_on_shape(%arg : !shape.shape) -> !shape.size {
%c0 = shape.const_size 0
%result = shape.get_extent %arg, %c0 :
!shape.shape, !shape.size -> !shape.size
return %result : !shape.size
}

func @get_extent_on_extent_tensor(%arg : tensor<?xindex>) -> index {
func.func @get_extent_on_extent_tensor(%arg : tensor<?xindex>) -> index {
%c0 = arith.constant 0 : index
%result = shape.get_extent %arg, %c0 : tensor<?xindex>, index -> index
return %result : index
}

func @get_extent_on_mixed_operands(%arg : tensor<?xindex>) -> !shape.size {
func.func @get_extent_on_mixed_operands(%arg : tensor<?xindex>) -> !shape.size {
%c0 = shape.const_size 0
%result = shape.get_extent %arg, %c0 : tensor<?xindex>, !shape.size -> !shape.size
return %result : !shape.size
}

func @any() {
func.func @any() {
%0 = shape.const_shape [1, 2, 3] : !shape.shape
%1 = shape.const_shape [4, 5, 6] : !shape.shape
%2 = "shape.any"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
Expand All @@ -232,93 +232,93 @@ func @any() {
return
}

func @num_elements_extent_tensor(%arg : tensor<?xindex>) -> index {
func.func @num_elements_extent_tensor(%arg : tensor<?xindex>) -> index {
%result = shape.num_elements %arg : tensor<?xindex> -> index
return %result : index
}

func @num_elements_shape(%arg : !shape.shape) -> !shape.size {
func.func @num_elements_shape(%arg : !shape.shape) -> !shape.size {
%result = shape.num_elements %arg : !shape.shape -> !shape.size
return %result : !shape.size
}

// Testing invoking shape function from another. shape_equal_shapes is merely
// a trivial helper function to invoke elsewhere.
func @shape_equal_shapes(%a : !shape.value_shape, %b : !shape.value_shape) -> !shape.shape {
func.func @shape_equal_shapes(%a : !shape.value_shape, %b : !shape.value_shape) -> !shape.shape {
%0 = shape.shape_of %a : !shape.value_shape -> !shape.shape
%1 = shape.shape_of %b : !shape.value_shape -> !shape.shape
%2 = "shape.meet"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
return %2 : !shape.shape
}
func @shape_with_shape(%a : !shape.value_shape, %b : !shape.value_shape) -> !shape.shape {
func.func @shape_with_shape(%a : !shape.value_shape, %b : !shape.value_shape) -> !shape.shape {
%0 = shape.shape_of %a : !shape.value_shape -> !shape.shape
%1 = shape.with_shape %b, %0 : !shape.value_shape, !shape.shape
%2 = call @shape_equal_shapes(%a, %1) : (!shape.value_shape, !shape.value_shape) -> !shape.shape
return %2 : !shape.shape
}

func @any_on_shape(%a : !shape.shape, %b : !shape.shape, %c : !shape.shape)
func.func @any_on_shape(%a : !shape.shape, %b : !shape.shape, %c : !shape.shape)
-> !shape.shape {
%result = shape.any %a, %b, %c
: !shape.shape, !shape.shape, !shape.shape -> !shape.shape
return %result : !shape.shape
}

func @any_on_mixed(%a : tensor<?xindex>,
func.func @any_on_mixed(%a : tensor<?xindex>,
%b : tensor<?xindex>,
%c : !shape.shape) -> !shape.shape {
%result = shape.any %a, %b, %c
: tensor<?xindex>, tensor<?xindex>, !shape.shape -> !shape.shape
return %result : !shape.shape
}

func @any_on_extent_tensors(%a : tensor<?xindex>,
func.func @any_on_extent_tensors(%a : tensor<?xindex>,
%b : tensor<?xindex>,
%c : tensor<?xindex>) -> tensor<?xindex> {
%result = shape.any %a, %b, %c
: tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
return %result : tensor<?xindex>
}

func @is_broadcastable_on_extent_tensors(%a : tensor<?xindex>,
func.func @is_broadcastable_on_extent_tensors(%a : tensor<?xindex>,
%b : tensor<?xindex>) -> i1 {
%result = shape.is_broadcastable %a, %b
: tensor<?xindex>, tensor<?xindex>
return %result : i1
}

func @is_broadcastable_on_shapes(%a : !shape.shape,
func.func @is_broadcastable_on_shapes(%a : !shape.shape,
%b : !shape.shape) -> i1 {
%result = shape.is_broadcastable %a, %b
: !shape.shape, !shape.shape
return %result : i1
}

func @shape_upper_bounded_by_constant(%a: !shape.shape) -> !shape.shape {
func.func @shape_upper_bounded_by_constant(%a: !shape.shape) -> !shape.shape {
%0 = shape.const_shape [4, 57, 92] : !shape.shape
%1 = shape.max %a, %0 : !shape.shape, !shape.shape -> !shape.shape
%2 = shape.meet %0, %1, error="exceeded element-wise upper bound" :
!shape.shape, !shape.shape -> !shape.shape
return %2 : !shape.shape
}

func @shape_lower_bounded_by_constant(%a: !shape.shape) -> !shape.shape {
func.func @shape_lower_bounded_by_constant(%a: !shape.shape) -> !shape.shape {
%0 = shape.const_shape [4, 57, 92] : !shape.shape
%1 = shape.min %a, %0 : !shape.shape, !shape.shape -> !shape.shape
%2 = shape.meet %0, %1, error="lower bound element-wise exceeded" :
!shape.shape, !shape.shape -> !shape.shape
return %2 : !shape.shape
}

func @size_upper_bounded_by_constant(%a: !shape.size) -> !shape.size {
func.func @size_upper_bounded_by_constant(%a: !shape.size) -> !shape.size {
%0 = shape.const_size 5
%1 = shape.max %a, %0 : !shape.size, !shape.size -> !shape.size
%2 = shape.meet %0, %1, error="exceeded element-wise upper bound" :
!shape.size, !shape.size -> !shape.size
return %2 : !shape.size
}

func @size_lower_bounded_by_constant(%a: !shape.size) -> !shape.size {
func.func @size_lower_bounded_by_constant(%a: !shape.size) -> !shape.size {
%0 = shape.const_size 9
%1 = shape.min %a, %0 : !shape.size, !shape.size -> !shape.size
%2 = shape.meet %0, %1, error="lower bound element-wise exceeded" :
Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Dialect/Shape/remove-shape-constraints.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
// Check that cstr_broadcastable is removed.
//
// CHECK-BOTH: func @f
func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
func.func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
// REPLACE-NEXT: %[[WITNESS:.+]] = shape.const_witness true
// REPLACE-NOT: shape.cstr_eq
// REPLACE: shape.assuming %[[WITNESS]]
Expand All @@ -23,7 +23,7 @@ func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
// Check that cstr_eq is removed.
//
// CHECK-BOTH: func @f
func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
func.func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
// REPLACE-NEXT: %[[WITNESS:.+]] = shape.const_witness true
// REPLACE-NOT: shape.cstr_eq
// REPLACE: shape.assuming %[[WITNESS]]
Expand All @@ -42,7 +42,7 @@ func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
// should be removed still.
//
// CHECK-BOTH: func @f
func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
func.func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
// CANON-NEXT: test.source
// CANON-NEXT: return
%0 = shape.cstr_broadcastable %arg0, %arg1 : !shape.shape, !shape.shape
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Dialect/Shape/shape-to-shape.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

// CHECK-LABEL: func @num_elements_to_reduce
// CHECK-SAME: ([[ARG:%.*]]: !shape.shape) -> !shape.size
func @num_elements_to_reduce(%shape : !shape.shape) -> !shape.size {
func.func @num_elements_to_reduce(%shape : !shape.shape) -> !shape.size {
%num_elements = shape.num_elements %shape : !shape.shape -> !shape.size
return %num_elements : !shape.size
}
Expand All @@ -18,7 +18,7 @@ func @num_elements_to_reduce(%shape : !shape.shape) -> !shape.size {

// CHECK-LABEL: func @num_elements_to_reduce_on_index
// CHECK-SAME: ([[ARG:%.*]]: tensor<?xindex>) -> index
func @num_elements_to_reduce_on_index(%shape : tensor<?xindex>) -> index {
func.func @num_elements_to_reduce_on_index(%shape : tensor<?xindex>) -> index {
%num_elements = shape.num_elements %shape : tensor<?xindex> -> index
return %num_elements : index
}
Expand Down