[MLIR][Arith][Vector] Reject i0 integer type in arith and vector ops#183589
Conversation
|
@kuhar : is this what you had in mind? This may deserve an RFC though. |
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core Author: Mehdi Amini (joker-eph) ChangesAdd ODS type constraints that exclude zero-bitwidth integers (i0) from operations in the arith and vector dialects. i0 has no meaningful arithmetic representation and operations on it can trigger undefined behavior (e.g. bitwidth calculations assuming non-zero width). Changes:
Full diff: https://github.com/llvm/llvm-project/pull/183589.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 3d8517c56e784..6c8a40970bc0a 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -21,6 +21,12 @@ include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/EnumAttr.td"
+// Type constraint for signless-integer-or-index-like types that additionally
+// excludes i0 (zero-bitwidth integers), used for arith integer operations.
+def Arith_SignlessIntegerOrIndexLike : TypeOrValueSemanticsContainer<
+ AnyNonZeroBitwidthSignlessIntegerOrIndex,
+ "signless-non-zero-bitwidth-integer-like">;
+
// Base class for Arith dialect ops. Ops in this dialect have no memory
// effects and can be applied element-wise to vectors and tensors.
class Arith_Op<string mnemonic, list<Trait> traits = []> :
@@ -51,8 +57,8 @@ class Arith_BinaryOp<string mnemonic, list<Trait> traits = []> :
class Arith_IntBinaryOp<string mnemonic, list<Trait> traits = []> :
Arith_BinaryOp<mnemonic, traits #
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
- Arguments<(ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs)>,
- Results<(outs SignlessIntegerOrIndexLike:$result)>;
+ Arguments<(ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs)>,
+ Results<(outs Arith_SignlessIntegerOrIndexLike:$result)>;
// Base class for integer binary operations without undefined behavior.
class Arith_TotalIntBinaryOp<string mnemonic, list<Trait> traits = []> :
@@ -110,10 +116,12 @@ class Arith_CastOp<string mnemonic, TypeConstraint From, TypeConstraint To,
// Casts do not accept indices. Type constraint for signless-integer-like types
// excluding indices: signless integers, vectors or tensors thereof.
+// i0 (zero-bitwidth) integers are excluded as they have no meaningful
+// representation for arithmetic operations.
def SignlessFixedWidthIntegerLike : TypeConstraint<Or<[
- AnySignlessInteger.predicate,
- VectorOfAnyRankOf<[AnySignlessInteger]>.predicate,
- TensorOf<[AnySignlessInteger]>.predicate]>,
+ AnyNonZeroBitwidthSignlessInteger.predicate,
+ VectorOfAnyRankOf<[AnyNonZeroBitwidthSignlessInteger]>.predicate,
+ TensorOf<[AnyNonZeroBitwidthSignlessInteger]>.predicate]>,
"signless-fixed-width-integer-like">;
// Cast from an integer type to another integer type.
@@ -148,11 +156,11 @@ class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = [
Arith_BinaryOp<mnemonic, traits #
[Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
DeclareOpInterfaceMethods<ArithIntegerOverflowFlagsInterface>]>,
- Arguments<(ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs,
+ Arguments<(ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs,
DefaultValuedAttr<
Arith_IntegerOverflowAttr,
"::mlir::arith::IntegerOverflowFlags::none">:$overflowFlags)>,
- Results<(outs SignlessIntegerOrIndexLike:$result)> {
+ Results<(outs Arith_SignlessIntegerOrIndexLike:$result)> {
let assemblyFormat = [{ $lhs `,` $rhs (`overflow` `` $overflowFlags^)?
attr-dict `:` type($result) }];
@@ -161,10 +169,10 @@ class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = [
class Arith_IntBinaryOpWithExactFlag<string mnemonic, list<Trait> traits = []> :
Arith_BinaryOp<mnemonic, traits #
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
- Arguments<(ins SignlessIntegerOrIndexLike:$lhs,
- SignlessIntegerOrIndexLike:$rhs,
+ Arguments<(ins Arith_SignlessIntegerOrIndexLike:$lhs,
+ Arith_SignlessIntegerOrIndexLike:$rhs,
UnitAttr:$isExact)>,
- Results<(outs SignlessIntegerOrIndexLike:$result)> {
+ Results<(outs Arith_SignlessIntegerOrIndexLike:$result)> {
let assemblyFormat = [{ $lhs `,` $rhs (`exact` $isExact^)?
attr-dict `:` type($result) }];
@@ -293,8 +301,8 @@ def Arith_AddUIExtendedOp : Arith_Op<"addui_extended", [Pure, Commutative,
```
}];
- let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
- let results = (outs SignlessIntegerOrIndexLike:$sum, BoolLike:$overflow);
+ let arguments = (ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs);
+ let results = (outs Arith_SignlessIntegerOrIndexLike:$sum, BoolLike:$overflow);
let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:` type($sum) `,` type($overflow)
}];
@@ -434,8 +442,8 @@ def Arith_MulSIExtendedOp : Arith_Op<"mulsi_extended", [Pure, Commutative,
```
}];
- let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
- let results = (outs SignlessIntegerOrIndexLike:$low, SignlessIntegerOrIndexLike:$high);
+ let arguments = (ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs);
+ let results = (outs Arith_SignlessIntegerOrIndexLike:$low, Arith_SignlessIntegerOrIndexLike:$high);
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)";
@@ -477,8 +485,8 @@ def Arith_MulUIExtendedOp : Arith_Op<"mului_extended", [Pure, Commutative,
```
}];
- let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
- let results = (outs SignlessIntegerOrIndexLike:$low, SignlessIntegerOrIndexLike:$high);
+ let arguments = (ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs);
+ let results = (outs Arith_SignlessIntegerOrIndexLike:$low, Arith_SignlessIntegerOrIndexLike:$high);
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)";
@@ -1573,8 +1581,8 @@ def Arith_FPToSIOp : Arith_FToICastOp<"fptosi"> {
// Index cast can convert between memrefs of signless integers and indices too.
def IndexCastTypeConstraint : TypeConstraint<Or<[
- SignlessIntegerOrIndexLike.predicate,
- MemRefOf<[AnySignlessInteger, Index]>.predicate]>,
+ Arith_SignlessIntegerOrIndexLike.predicate,
+ MemRefOf<[AnyNonZeroBitwidthSignlessInteger, Index]>.predicate]>,
"signless-integer-like or memref of signless-integer">;
def Arith_IndexCastOp
@@ -1737,8 +1745,8 @@ def Arith_CmpIOp : Arith_CompareOp<"cmpi",
}];
let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
- SignlessIntegerOrIndexLike:$lhs,
- SignlessIntegerOrIndexLike:$rhs);
+ Arith_SignlessIntegerOrIndexLike:$lhs,
+ Arith_SignlessIntegerOrIndexLike:$rhs);
let hasFolder = 1;
let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index ddb04b6bbe40d..87a15cba8a0e0 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -32,6 +32,15 @@ include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/EnumAttr.td"
+// Type constraint helpers for the vector dialect.
+
+// Any vector of any rank whose element type is not i0 (zero-bitwidth integer).
+// Floats, index, and integers with width >= 1 are all accepted. Using
+// VectorOfAnyRankOf preserves the ::mlir::VectorType C++ class so that ODS
+// generates TypedValue<VectorType> (VectorValue) for op results.
+def AnyVectorNonZeroBitwidthIntElem : VectorOfAnyRankOf<[
+ Type<CPred<"!$_self.isInteger(0)">, "non-zero-bitwidth type">]>;
+
// TODO: Add an attribute to specify a different algebra with operators other
// than the current set: {*, +}.
def Vector_ContractionOp :
@@ -2454,8 +2463,8 @@ def Vector_ShapeCastOp :
def Vector_BitCastOp :
Vector_Op<"bitcast", [Pure, AllRanksMatch<["source", "result"]>]>,
- Arguments<(ins AnyVectorOfAnyRank:$source)>,
- Results<(outs AnyVectorOfAnyRank:$result)>{
+ Arguments<(ins AnyVectorNonZeroBitwidthIntElem:$source)>,
+ Results<(outs AnyVectorNonZeroBitwidthIntElem:$result)>{
let summary = "bitcast casts between vectors";
let description = [{
The bitcast operation casts between vectors of the same rank, the minor 1-D
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index a49880b81e90d..409672d46c55e 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -308,6 +308,18 @@ def Index : Type<CPred<"::llvm::isa<::mlir::IndexType>($_self)">, "index",
def AnySignlessIntegerOrIndex : Type<CPred<"$_self.isSignlessIntOrIndex()">,
"signless integer or index">;
+// A signless integer type with a non-zero bitwidth (excludes i0).
+def AnyNonZeroBitwidthSignlessInteger : Type<
+ And<[CPred<"$_self.isSignlessInteger()">,
+ CPred<"!$_self.isInteger(0)">]>,
+ "non-zero-bitwidth signless integer", "::mlir::IntegerType">;
+
+// A non-zero-bitwidth signless integer or index type.
+def AnyNonZeroBitwidthSignlessIntegerOrIndex : Type<
+ Or<[AnyNonZeroBitwidthSignlessInteger.predicate,
+ CPred<"::llvm::isa<::mlir::IndexType>($_self)">]>,
+ "non-zero-bitwidth signless integer or index">;
+
// Floating point types.
// Any float type irrespective of its width.
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 4dc29897cec26..7713f93462b7f 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -3299,87 +3299,6 @@ func.func @unsignedExtendConstantResource() -> tensor<i16> {
return %ext : tensor<i16>
}
-// CHECK-LABEL: @extsi_i0
-// CHECK: %[[ZERO:.*]] = arith.constant 0 : i16
-// CHECK: return %[[ZERO]] : i16
-func.func @extsi_i0() -> i16 {
- %c0 = arith.constant 0 : i0
- %extsi = arith.extsi %c0 : i0 to i16
- return %extsi : i16
-}
-
-// CHECK-LABEL: @extui_i0
-// CHECK: %[[ZERO:.*]] = arith.constant 0 : i16
-// CHECK: return %[[ZERO]] : i16
-func.func @extui_i0() -> i16 {
- %c0 = arith.constant 0 : i0
- %extui = arith.extui %c0 : i0 to i16
- return %extui : i16
-}
-
-// CHECK-LABEL: @trunc_i0
-// CHECK: %[[ZERO:.*]] = arith.constant 0 : i0
-// CHECK: return %[[ZERO]] : i0
-func.func @trunc_i0() -> i0 {
- %cFF = arith.constant 0xFF : i8
- %trunc = arith.trunci %cFF : i8 to i0
- return %trunc : i0
-}
-
-// CHECK-LABEL: @shli_i0
-// CHECK: %[[ZERO:.*]] = arith.constant 0 : i0
-// CHECK: return %[[ZERO]] : i0
-func.func @shli_i0() -> i0 {
- %c0 = arith.constant 0 : i0
- %shli = arith.shli %c0, %c0 : i0
- return %shli : i0
-}
-
-// CHECK-LABEL: @shrsi_i0
-// CHECK: %[[ZERO:.*]] = arith.constant 0 : i0
-// CHECK: return %[[ZERO]] : i0
-func.func @shrsi_i0() -> i0 {
- %c0 = arith.constant 0 : i0
- %shrsi = arith.shrsi %c0, %c0 : i0
- return %shrsi : i0
-}
-
-// CHECK-LABEL: @shrui_i0
-// CHECK: %[[ZERO:.*]] = arith.constant 0 : i0
-// CHECK: return %[[ZERO]] : i0
-func.func @shrui_i0() -> i0 {
- %c0 = arith.constant 0 : i0
- %shrui = arith.shrui %c0, %c0 : i0
- return %shrui : i0
-}
-
-// CHECK-LABEL: @maxsi_i0
-// CHECK: %[[ZERO:.*]] = arith.constant 0 : i0
-// CHECK: return %[[ZERO]] : i0
-func.func @maxsi_i0() -> i0 {
- %c0 = arith.constant 0 : i0
- %maxsi = arith.maxsi %c0, %c0 : i0
- return %maxsi : i0
-}
-
-// CHECK-LABEL: @minsi_i0
-// CHECK: %[[ZERO:.*]] = arith.constant 0 : i0
-// CHECK: return %[[ZERO]] : i0
-func.func @minsi_i0() -> i0 {
- %c0 = arith.constant 0 : i0
- %minsi = arith.minsi %c0, %c0 : i0
- return %minsi : i0
-}
-
-// CHECK-LABEL: @mulsi_extended_i0
-// CHECK: %[[ZERO:.*]] = arith.constant 0 : i0
-// CHECK: return %[[ZERO]], %[[ZERO]] : i0
-func.func @mulsi_extended_i0() -> (i0, i0) {
- %c0 = arith.constant 0 : i0
- %mulsi_extended:2 = arith.mulsi_extended %c0, %c0 : i0
- return %mulsi_extended#0, %mulsi_extended#1 : i0, i0
-}
-
// CHECK-LABEL: @sequences_fastmath_contract
// CHECK-SAME: ([[ARG0:%.+]]: bf16)
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir
index 70b23e56a712c..cf404b7c389f5 100644
--- a/mlir/test/Dialect/Arith/invalid.mlir
+++ b/mlir/test/Dialect/Arith/invalid.mlir
@@ -130,14 +130,14 @@ func.func @func_with_ops(f32) {
func.func @func_with_ops(f32) {
^bb0(%a : f32):
- // expected-error@+1 {{'arith.addi' op operand #0 must be signless-integer-like}}
+ // expected-error@+1 {{'arith.addi' op operand #0 must be signless-non-zero-bitwidth-integer-like}}
%sf = arith.addi %a, %a : f32
}
// -----
func.func @func_with_ops(%a: f32) {
- // expected-error@+1 {{'arith.addui_extended' op operand #0 must be signless-integer-like}}
+ // expected-error@+1 {{'arith.addui_extended' op operand #0 must be signless-non-zero-bitwidth-integer-like}}
%r:2 = arith.addui_extended %a, %a : f32, i32
return
}
@@ -202,7 +202,7 @@ func.func @func_with_ops(i32, i32) {
// Integer comparisons are not recognized for float types.
func.func @func_with_ops(f32, f32) {
^bb0(%a : f32, %b : f32):
- %r = arith.cmpi eq, %a, %b : f32 // expected-error {{'lhs' must be signless-integer-like, but got 'f32'}}
+ %r = arith.cmpi eq, %a, %b : f32 // expected-error {{'lhs' must be signless-non-zero-bitwidth-integer-like, but got 'f32'}}
}
// -----
@@ -242,7 +242,7 @@ func.func @func_with_ops() {
// -----
func.func @invalid_cmp_shape(%idx : () -> ()) {
- // expected-error@+1 {{'lhs' must be signless-integer-like, but got '() -> ()'}}
+ // expected-error@+1 {{'lhs' must be signless-non-zero-bitwidth-integer-like, but got '() -> ()'}}
%cmp = arith.cmpi eq, %idx, %idx : () -> ()
// -----
@@ -877,3 +877,29 @@ func.func @select_vector_condition_scalar_operands(%arg0: vector<1xi1>, %arg1: i
%0 = arith.select %arg0, %arg1, %arg1 : vector<1xi1>, i32
return
}
+
+// -----
+
+// Verify that i0 (zero-bitwidth integer) is rejected by arith integer ops.
+
+func.func @addi_i0(%a: i0, %b: i0) -> i0 {
+ // expected-error @+1 {{'arith.addi' op operand #0 must be signless-non-zero-bitwidth-integer-like, but got 'i0'}}
+ %0 = arith.addi %a, %b : i0
+ return %0 : i0
+}
+
+// -----
+
+func.func @addi_vector_i0(%a: vector<4xi0>, %b: vector<4xi0>) -> vector<4xi0> {
+ // expected-error @+1 {{'arith.addi' op operand #0 must be signless-non-zero-bitwidth-integer-like, but got 'vector<4xi0>'}}
+ %0 = arith.addi %a, %b : vector<4xi0>
+ return %0 : vector<4xi0>
+}
+
+// -----
+
+func.func @trunci_to_i0(%a: i32) -> i0 {
+ // expected-error @+1 {{'arith.trunci' op result #0 must be signless-fixed-width-integer-like, but got 'i0'}}
+ %0 = arith.trunci %a : i32 to i0
+ return %0 : i0
+}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 28e1206ff3d0a..51254b524920c 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -2064,3 +2064,12 @@ func.func @store_non_pow_of_2_alignment(%memref: memref<4xi32>, %val: vector<4xi
vector.store %val, %memref[%c0] { alignment = 3 } : memref<4xi32>, vector<4xi32>
return
}
+
+// -----
+
+// Verify that vector.bitcast rejects vectors with i0 (zero-bitwidth) element type.
+func.func @bitcast_i0(%a: vector<4xi0>) -> vector<4xi0> {
+ // expected-error @+1 {{'vector.bitcast' op operand #0 must be vector of non-zero-bitwidth type values, but got 'vector<4xi0>'}}
+ %0 = vector.bitcast %a : vector<4xi0> to vector<4xi0>
+ return %0 : vector<4xi0>
+}
|
|
@llvm/pr-subscribers-mlir-vector Author: Mehdi Amini (joker-eph) ChangesAdd ODS type constraints that exclude zero-bitwidth integers (i0) from operations in the arith and vector dialects. i0 has no meaningful arithmetic representation and operations on it can trigger undefined behavior (e.g. bitwidth calculations assuming non-zero width). Changes:
Full diff: https://github.com/llvm/llvm-project/pull/183589.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 3d8517c56e784..6c8a40970bc0a 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -21,6 +21,12 @@ include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/EnumAttr.td"
+// Type constraint for signless-integer-or-index-like types that additionally
+// excludes i0 (zero-bitwidth integers), used for arith integer operations.
+def Arith_SignlessIntegerOrIndexLike : TypeOrValueSemanticsContainer<
+ AnyNonZeroBitwidthSignlessIntegerOrIndex,
+ "signless-non-zero-bitwidth-integer-like">;
+
// Base class for Arith dialect ops. Ops in this dialect have no memory
// effects and can be applied element-wise to vectors and tensors.
class Arith_Op<string mnemonic, list<Trait> traits = []> :
@@ -51,8 +57,8 @@ class Arith_BinaryOp<string mnemonic, list<Trait> traits = []> :
class Arith_IntBinaryOp<string mnemonic, list<Trait> traits = []> :
Arith_BinaryOp<mnemonic, traits #
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
- Arguments<(ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs)>,
- Results<(outs SignlessIntegerOrIndexLike:$result)>;
+ Arguments<(ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs)>,
+ Results<(outs Arith_SignlessIntegerOrIndexLike:$result)>;
// Base class for integer binary operations without undefined behavior.
class Arith_TotalIntBinaryOp<string mnemonic, list<Trait> traits = []> :
@@ -110,10 +116,12 @@ class Arith_CastOp<string mnemonic, TypeConstraint From, TypeConstraint To,
// Casts do not accept indices. Type constraint for signless-integer-like types
// excluding indices: signless integers, vectors or tensors thereof.
+// i0 (zero-bitwidth) integers are excluded as they have no meaningful
+// representation for arithmetic operations.
def SignlessFixedWidthIntegerLike : TypeConstraint<Or<[
- AnySignlessInteger.predicate,
- VectorOfAnyRankOf<[AnySignlessInteger]>.predicate,
- TensorOf<[AnySignlessInteger]>.predicate]>,
+ AnyNonZeroBitwidthSignlessInteger.predicate,
+ VectorOfAnyRankOf<[AnyNonZeroBitwidthSignlessInteger]>.predicate,
+ TensorOf<[AnyNonZeroBitwidthSignlessInteger]>.predicate]>,
"signless-fixed-width-integer-like">;
// Cast from an integer type to another integer type.
@@ -148,11 +156,11 @@ class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = [
Arith_BinaryOp<mnemonic, traits #
[Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
DeclareOpInterfaceMethods<ArithIntegerOverflowFlagsInterface>]>,
- Arguments<(ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs,
+ Arguments<(ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs,
DefaultValuedAttr<
Arith_IntegerOverflowAttr,
"::mlir::arith::IntegerOverflowFlags::none">:$overflowFlags)>,
- Results<(outs SignlessIntegerOrIndexLike:$result)> {
+ Results<(outs Arith_SignlessIntegerOrIndexLike:$result)> {
let assemblyFormat = [{ $lhs `,` $rhs (`overflow` `` $overflowFlags^)?
attr-dict `:` type($result) }];
@@ -161,10 +169,10 @@ class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = [
class Arith_IntBinaryOpWithExactFlag<string mnemonic, list<Trait> traits = []> :
Arith_BinaryOp<mnemonic, traits #
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
- Arguments<(ins SignlessIntegerOrIndexLike:$lhs,
- SignlessIntegerOrIndexLike:$rhs,
+ Arguments<(ins Arith_SignlessIntegerOrIndexLike:$lhs,
+ Arith_SignlessIntegerOrIndexLike:$rhs,
UnitAttr:$isExact)>,
- Results<(outs SignlessIntegerOrIndexLike:$result)> {
+ Results<(outs Arith_SignlessIntegerOrIndexLike:$result)> {
let assemblyFormat = [{ $lhs `,` $rhs (`exact` $isExact^)?
attr-dict `:` type($result) }];
@@ -293,8 +301,8 @@ def Arith_AddUIExtendedOp : Arith_Op<"addui_extended", [Pure, Commutative,
```
}];
- let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
- let results = (outs SignlessIntegerOrIndexLike:$sum, BoolLike:$overflow);
+ let arguments = (ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs);
+ let results = (outs Arith_SignlessIntegerOrIndexLike:$sum, BoolLike:$overflow);
let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:` type($sum) `,` type($overflow)
}];
@@ -434,8 +442,8 @@ def Arith_MulSIExtendedOp : Arith_Op<"mulsi_extended", [Pure, Commutative,
```
}];
- let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
- let results = (outs SignlessIntegerOrIndexLike:$low, SignlessIntegerOrIndexLike:$high);
+ let arguments = (ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs);
+ let results = (outs Arith_SignlessIntegerOrIndexLike:$low, Arith_SignlessIntegerOrIndexLike:$high);
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)";
@@ -477,8 +485,8 @@ def Arith_MulUIExtendedOp : Arith_Op<"mului_extended", [Pure, Commutative,
```
}];
- let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
- let results = (outs SignlessIntegerOrIndexLike:$low, SignlessIntegerOrIndexLike:$high);
+ let arguments = (ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs);
+ let results = (outs Arith_SignlessIntegerOrIndexLike:$low, Arith_SignlessIntegerOrIndexLike:$high);
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)";
@@ -1573,8 +1581,8 @@ def Arith_FPToSIOp : Arith_FToICastOp<"fptosi"> {
// Index cast can convert between memrefs of signless integers and indices too.
def IndexCastTypeConstraint : TypeConstraint<Or<[
- SignlessIntegerOrIndexLike.predicate,
- MemRefOf<[AnySignlessInteger, Index]>.predicate]>,
+ Arith_SignlessIntegerOrIndexLike.predicate,
+ MemRefOf<[AnyNonZeroBitwidthSignlessInteger, Index]>.predicate]>,
"signless-integer-like or memref of signless-integer">;
def Arith_IndexCastOp
@@ -1737,8 +1745,8 @@ def Arith_CmpIOp : Arith_CompareOp<"cmpi",
}];
let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
- SignlessIntegerOrIndexLike:$lhs,
- SignlessIntegerOrIndexLike:$rhs);
+ Arith_SignlessIntegerOrIndexLike:$lhs,
+ Arith_SignlessIntegerOrIndexLike:$rhs);
let hasFolder = 1;
let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index ddb04b6bbe40d..87a15cba8a0e0 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -32,6 +32,15 @@ include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/EnumAttr.td"
+// Type constraint helpers for the vector dialect.
+
+// Any vector of any rank whose element type is not i0 (zero-bitwidth integer).
+// Floats, index, and integers with width >= 1 are all accepted. Using
+// VectorOfAnyRankOf preserves the ::mlir::VectorType C++ class so that ODS
+// generates TypedValue<VectorType> (VectorValue) for op results.
+def AnyVectorNonZeroBitwidthIntElem : VectorOfAnyRankOf<[
+ Type<CPred<"!$_self.isInteger(0)">, "non-zero-bitwidth type">]>;
+
// TODO: Add an attribute to specify a different algebra with operators other
// than the current set: {*, +}.
def Vector_ContractionOp :
@@ -2454,8 +2463,8 @@ def Vector_ShapeCastOp :
def Vector_BitCastOp :
Vector_Op<"bitcast", [Pure, AllRanksMatch<["source", "result"]>]>,
- Arguments<(ins AnyVectorOfAnyRank:$source)>,
- Results<(outs AnyVectorOfAnyRank:$result)>{
+ Arguments<(ins AnyVectorNonZeroBitwidthIntElem:$source)>,
+ Results<(outs AnyVectorNonZeroBitwidthIntElem:$result)>{
let summary = "bitcast casts between vectors";
let description = [{
The bitcast operation casts between vectors of the same rank, the minor 1-D
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index a49880b81e90d..409672d46c55e 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -308,6 +308,18 @@ def Index : Type<CPred<"::llvm::isa<::mlir::IndexType>($_self)">, "index",
def AnySignlessIntegerOrIndex : Type<CPred<"$_self.isSignlessIntOrIndex()">,
"signless integer or index">;
+// A signless integer type with a non-zero bitwidth (excludes i0).
+def AnyNonZeroBitwidthSignlessInteger : Type<
+ And<[CPred<"$_self.isSignlessInteger()">,
+ CPred<"!$_self.isInteger(0)">]>,
+ "non-zero-bitwidth signless integer", "::mlir::IntegerType">;
+
+// A non-zero-bitwidth signless integer or index type.
+def AnyNonZeroBitwidthSignlessIntegerOrIndex : Type<
+ Or<[AnyNonZeroBitwidthSignlessInteger.predicate,
+ CPred<"::llvm::isa<::mlir::IndexType>($_self)">]>,
+ "non-zero-bitwidth signless integer or index">;
+
// Floating point types.
// Any float type irrespective of its width.
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 4dc29897cec26..7713f93462b7f 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -3299,87 +3299,6 @@ func.func @unsignedExtendConstantResource() -> tensor<i16> {
return %ext : tensor<i16>
}
-// CHECK-LABEL: @extsi_i0
-// CHECK: %[[ZERO:.*]] = arith.constant 0 : i16
-// CHECK: return %[[ZERO]] : i16
-func.func @extsi_i0() -> i16 {
- %c0 = arith.constant 0 : i0
- %extsi = arith.extsi %c0 : i0 to i16
- return %extsi : i16
-}
-
-// CHECK-LABEL: @extui_i0
-// CHECK: %[[ZERO:.*]] = arith.constant 0 : i16
-// CHECK: return %[[ZERO]] : i16
-func.func @extui_i0() -> i16 {
- %c0 = arith.constant 0 : i0
- %extui = arith.extui %c0 : i0 to i16
- return %extui : i16
-}
-
-// CHECK-LABEL: @trunc_i0
-// CHECK: %[[ZERO:.*]] = arith.constant 0 : i0
-// CHECK: return %[[ZERO]] : i0
-func.func @trunc_i0() -> i0 {
- %cFF = arith.constant 0xFF : i8
- %trunc = arith.trunci %cFF : i8 to i0
- return %trunc : i0
-}
-
-// CHECK-LABEL: @shli_i0
-// CHECK: %[[ZERO:.*]] = arith.constant 0 : i0
-// CHECK: return %[[ZERO]] : i0
-func.func @shli_i0() -> i0 {
- %c0 = arith.constant 0 : i0
- %shli = arith.shli %c0, %c0 : i0
- return %shli : i0
-}
-
-// CHECK-LABEL: @shrsi_i0
-// CHECK: %[[ZERO:.*]] = arith.constant 0 : i0
-// CHECK: return %[[ZERO]] : i0
-func.func @shrsi_i0() -> i0 {
- %c0 = arith.constant 0 : i0
- %shrsi = arith.shrsi %c0, %c0 : i0
- return %shrsi : i0
-}
-
-// CHECK-LABEL: @shrui_i0
-// CHECK: %[[ZERO:.*]] = arith.constant 0 : i0
-// CHECK: return %[[ZERO]] : i0
-func.func @shrui_i0() -> i0 {
- %c0 = arith.constant 0 : i0
- %shrui = arith.shrui %c0, %c0 : i0
- return %shrui : i0
-}
-
-// CHECK-LABEL: @maxsi_i0
-// CHECK: %[[ZERO:.*]] = arith.constant 0 : i0
-// CHECK: return %[[ZERO]] : i0
-func.func @maxsi_i0() -> i0 {
- %c0 = arith.constant 0 : i0
- %maxsi = arith.maxsi %c0, %c0 : i0
- return %maxsi : i0
-}
-
-// CHECK-LABEL: @minsi_i0
-// CHECK: %[[ZERO:.*]] = arith.constant 0 : i0
-// CHECK: return %[[ZERO]] : i0
-func.func @minsi_i0() -> i0 {
- %c0 = arith.constant 0 : i0
- %minsi = arith.minsi %c0, %c0 : i0
- return %minsi : i0
-}
-
-// CHECK-LABEL: @mulsi_extended_i0
-// CHECK: %[[ZERO:.*]] = arith.constant 0 : i0
-// CHECK: return %[[ZERO]], %[[ZERO]] : i0
-func.func @mulsi_extended_i0() -> (i0, i0) {
- %c0 = arith.constant 0 : i0
- %mulsi_extended:2 = arith.mulsi_extended %c0, %c0 : i0
- return %mulsi_extended#0, %mulsi_extended#1 : i0, i0
-}
-
// CHECK-LABEL: @sequences_fastmath_contract
// CHECK-SAME: ([[ARG0:%.+]]: bf16)
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir
index 70b23e56a712c..cf404b7c389f5 100644
--- a/mlir/test/Dialect/Arith/invalid.mlir
+++ b/mlir/test/Dialect/Arith/invalid.mlir
@@ -130,14 +130,14 @@ func.func @func_with_ops(f32) {
func.func @func_with_ops(f32) {
^bb0(%a : f32):
- // expected-error@+1 {{'arith.addi' op operand #0 must be signless-integer-like}}
+ // expected-error@+1 {{'arith.addi' op operand #0 must be signless-non-zero-bitwidth-integer-like}}
%sf = arith.addi %a, %a : f32
}
// -----
func.func @func_with_ops(%a: f32) {
- // expected-error@+1 {{'arith.addui_extended' op operand #0 must be signless-integer-like}}
+ // expected-error@+1 {{'arith.addui_extended' op operand #0 must be signless-non-zero-bitwidth-integer-like}}
%r:2 = arith.addui_extended %a, %a : f32, i32
return
}
@@ -202,7 +202,7 @@ func.func @func_with_ops(i32, i32) {
// Integer comparisons are not recognized for float types.
func.func @func_with_ops(f32, f32) {
^bb0(%a : f32, %b : f32):
- %r = arith.cmpi eq, %a, %b : f32 // expected-error {{'lhs' must be signless-integer-like, but got 'f32'}}
+ %r = arith.cmpi eq, %a, %b : f32 // expected-error {{'lhs' must be signless-non-zero-bitwidth-integer-like, but got 'f32'}}
}
// -----
@@ -242,7 +242,7 @@ func.func @func_with_ops() {
// -----
func.func @invalid_cmp_shape(%idx : () -> ()) {
- // expected-error@+1 {{'lhs' must be signless-integer-like, but got '() -> ()'}}
+ // expected-error@+1 {{'lhs' must be signless-non-zero-bitwidth-integer-like, but got '() -> ()'}}
%cmp = arith.cmpi eq, %idx, %idx : () -> ()
// -----
@@ -877,3 +877,29 @@ func.func @select_vector_condition_scalar_operands(%arg0: vector<1xi1>, %arg1: i
%0 = arith.select %arg0, %arg1, %arg1 : vector<1xi1>, i32
return
}
+
+// -----
+
+// Verify that i0 (zero-bitwidth integer) is rejected by arith integer ops.
+
+func.func @addi_i0(%a: i0, %b: i0) -> i0 {
+ // expected-error @+1 {{'arith.addi' op operand #0 must be signless-non-zero-bitwidth-integer-like, but got 'i0'}}
+ %0 = arith.addi %a, %b : i0
+ return %0 : i0
+}
+
+// -----
+
+func.func @addi_vector_i0(%a: vector<4xi0>, %b: vector<4xi0>) -> vector<4xi0> {
+ // expected-error @+1 {{'arith.addi' op operand #0 must be signless-non-zero-bitwidth-integer-like, but got 'vector<4xi0>'}}
+ %0 = arith.addi %a, %b : vector<4xi0>
+ return %0 : vector<4xi0>
+}
+
+// -----
+
+func.func @trunci_to_i0(%a: i32) -> i0 {
+ // expected-error @+1 {{'arith.trunci' op result #0 must be signless-fixed-width-integer-like, but got 'i0'}}
+ %0 = arith.trunci %a : i32 to i0
+ return %0 : i0
+}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 28e1206ff3d0a..51254b524920c 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -2064,3 +2064,12 @@ func.func @store_non_pow_of_2_alignment(%memref: memref<4xi32>, %val: vector<4xi
vector.store %val, %memref[%c0] { alignment = 3 } : memref<4xi32>, vector<4xi32>
return
}
+
+// -----
+
+// Verify that vector.bitcast rejects vectors with i0 (zero-bitwidth) element type.
+func.func @bitcast_i0(%a: vector<4xi0>) -> vector<4xi0> {
+ // expected-error @+1 {{'vector.bitcast' op operand #0 must be vector of non-zero-bitwidth type values, but got 'vector<4xi0>'}}
+ %0 = vector.bitcast %a : vector<4xi0> to vector<4xi0>
+ return %0 : vector<4xi0>
+}
|
|
@llvm/pr-subscribers-mlir-arith Author: Mehdi Amini (joker-eph) ChangesAdd ODS type constraints that exclude zero-bitwidth integers (i0) from operations in the arith and vector dialects. i0 has no meaningful arithmetic representation and operations on it can trigger undefined behavior (e.g. bitwidth calculations assuming non-zero width). Changes:
Full diff: https://github.com/llvm/llvm-project/pull/183589.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 3d8517c56e784..6c8a40970bc0a 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -21,6 +21,12 @@ include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/EnumAttr.td"
+// Type constraint for signless-integer-or-index-like types that additionally
+// excludes i0 (zero-bitwidth integers), used for arith integer operations.
+def Arith_SignlessIntegerOrIndexLike : TypeOrValueSemanticsContainer<
+ AnyNonZeroBitwidthSignlessIntegerOrIndex,
+ "signless-non-zero-bitwidth-integer-like">;
+
// Base class for Arith dialect ops. Ops in this dialect have no memory
// effects and can be applied element-wise to vectors and tensors.
class Arith_Op<string mnemonic, list<Trait> traits = []> :
@@ -51,8 +57,8 @@ class Arith_BinaryOp<string mnemonic, list<Trait> traits = []> :
class Arith_IntBinaryOp<string mnemonic, list<Trait> traits = []> :
Arith_BinaryOp<mnemonic, traits #
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
- Arguments<(ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs)>,
- Results<(outs SignlessIntegerOrIndexLike:$result)>;
+ Arguments<(ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs)>,
+ Results<(outs Arith_SignlessIntegerOrIndexLike:$result)>;
// Base class for integer binary operations without undefined behavior.
class Arith_TotalIntBinaryOp<string mnemonic, list<Trait> traits = []> :
@@ -110,10 +116,12 @@ class Arith_CastOp<string mnemonic, TypeConstraint From, TypeConstraint To,
// Casts do not accept indices. Type constraint for signless-integer-like types
// excluding indices: signless integers, vectors or tensors thereof.
+// i0 (zero-bitwidth) integers are excluded as they have no meaningful
+// representation for arithmetic operations.
def SignlessFixedWidthIntegerLike : TypeConstraint<Or<[
- AnySignlessInteger.predicate,
- VectorOfAnyRankOf<[AnySignlessInteger]>.predicate,
- TensorOf<[AnySignlessInteger]>.predicate]>,
+ AnyNonZeroBitwidthSignlessInteger.predicate,
+ VectorOfAnyRankOf<[AnyNonZeroBitwidthSignlessInteger]>.predicate,
+ TensorOf<[AnyNonZeroBitwidthSignlessInteger]>.predicate]>,
"signless-fixed-width-integer-like">;
// Cast from an integer type to another integer type.
@@ -148,11 +156,11 @@ class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = [
Arith_BinaryOp<mnemonic, traits #
[Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
DeclareOpInterfaceMethods<ArithIntegerOverflowFlagsInterface>]>,
- Arguments<(ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs,
+ Arguments<(ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs,
DefaultValuedAttr<
Arith_IntegerOverflowAttr,
"::mlir::arith::IntegerOverflowFlags::none">:$overflowFlags)>,
- Results<(outs SignlessIntegerOrIndexLike:$result)> {
+ Results<(outs Arith_SignlessIntegerOrIndexLike:$result)> {
let assemblyFormat = [{ $lhs `,` $rhs (`overflow` `` $overflowFlags^)?
attr-dict `:` type($result) }];
@@ -161,10 +169,10 @@ class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = [
class Arith_IntBinaryOpWithExactFlag<string mnemonic, list<Trait> traits = []> :
Arith_BinaryOp<mnemonic, traits #
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
- Arguments<(ins SignlessIntegerOrIndexLike:$lhs,
- SignlessIntegerOrIndexLike:$rhs,
+ Arguments<(ins Arith_SignlessIntegerOrIndexLike:$lhs,
+ Arith_SignlessIntegerOrIndexLike:$rhs,
UnitAttr:$isExact)>,
- Results<(outs SignlessIntegerOrIndexLike:$result)> {
+ Results<(outs Arith_SignlessIntegerOrIndexLike:$result)> {
let assemblyFormat = [{ $lhs `,` $rhs (`exact` $isExact^)?
attr-dict `:` type($result) }];
@@ -293,8 +301,8 @@ def Arith_AddUIExtendedOp : Arith_Op<"addui_extended", [Pure, Commutative,
```
}];
- let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
- let results = (outs SignlessIntegerOrIndexLike:$sum, BoolLike:$overflow);
+ let arguments = (ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs);
+ let results = (outs Arith_SignlessIntegerOrIndexLike:$sum, BoolLike:$overflow);
let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:` type($sum) `,` type($overflow)
}];
@@ -434,8 +442,8 @@ def Arith_MulSIExtendedOp : Arith_Op<"mulsi_extended", [Pure, Commutative,
```
}];
- let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
- let results = (outs SignlessIntegerOrIndexLike:$low, SignlessIntegerOrIndexLike:$high);
+ let arguments = (ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs);
+ let results = (outs Arith_SignlessIntegerOrIndexLike:$low, Arith_SignlessIntegerOrIndexLike:$high);
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)";
@@ -477,8 +485,8 @@ def Arith_MulUIExtendedOp : Arith_Op<"mului_extended", [Pure, Commutative,
```
}];
- let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
- let results = (outs SignlessIntegerOrIndexLike:$low, SignlessIntegerOrIndexLike:$high);
+ let arguments = (ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs);
+ let results = (outs Arith_SignlessIntegerOrIndexLike:$low, Arith_SignlessIntegerOrIndexLike:$high);
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)";
@@ -1573,8 +1581,8 @@ def Arith_FPToSIOp : Arith_FToICastOp<"fptosi"> {
// Index cast can convert between memrefs of signless integers and indices too.
def IndexCastTypeConstraint : TypeConstraint<Or<[
- SignlessIntegerOrIndexLike.predicate,
- MemRefOf<[AnySignlessInteger, Index]>.predicate]>,
+ Arith_SignlessIntegerOrIndexLike.predicate,
+ MemRefOf<[AnyNonZeroBitwidthSignlessInteger, Index]>.predicate]>,
"signless-integer-like or memref of signless-integer">;
def Arith_IndexCastOp
@@ -1737,8 +1745,8 @@ def Arith_CmpIOp : Arith_CompareOp<"cmpi",
}];
let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
- SignlessIntegerOrIndexLike:$lhs,
- SignlessIntegerOrIndexLike:$rhs);
+ Arith_SignlessIntegerOrIndexLike:$lhs,
+ Arith_SignlessIntegerOrIndexLike:$rhs);
let hasFolder = 1;
let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index ddb04b6bbe40d..87a15cba8a0e0 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -32,6 +32,15 @@ include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/EnumAttr.td"
+// Type constraint helpers for the vector dialect.
+
+// Any vector of any rank whose element type is not i0 (zero-bitwidth integer).
+// Floats, index, and integers with width >= 1 are all accepted. Using
+// VectorOfAnyRankOf preserves the ::mlir::VectorType C++ class so that ODS
+// generates TypedValue<VectorType> (VectorValue) for op results.
+def AnyVectorNonZeroBitwidthIntElem : VectorOfAnyRankOf<[
+ Type<CPred<"!$_self.isInteger(0)">, "non-zero-bitwidth type">]>;
+
// TODO: Add an attribute to specify a different algebra with operators other
// than the current set: {*, +}.
def Vector_ContractionOp :
@@ -2454,8 +2463,8 @@ def Vector_ShapeCastOp :
def Vector_BitCastOp :
Vector_Op<"bitcast", [Pure, AllRanksMatch<["source", "result"]>]>,
- Arguments<(ins AnyVectorOfAnyRank:$source)>,
- Results<(outs AnyVectorOfAnyRank:$result)>{
+ Arguments<(ins AnyVectorNonZeroBitwidthIntElem:$source)>,
+ Results<(outs AnyVectorNonZeroBitwidthIntElem:$result)>{
let summary = "bitcast casts between vectors";
let description = [{
The bitcast operation casts between vectors of the same rank, the minor 1-D
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index a49880b81e90d..409672d46c55e 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -308,6 +308,18 @@ def Index : Type<CPred<"::llvm::isa<::mlir::IndexType>($_self)">, "index",
def AnySignlessIntegerOrIndex : Type<CPred<"$_self.isSignlessIntOrIndex()">,
"signless integer or index">;
+// A signless integer type with a non-zero bitwidth (excludes i0).
+def AnyNonZeroBitwidthSignlessInteger : Type<
+ And<[CPred<"$_self.isSignlessInteger()">,
+ CPred<"!$_self.isInteger(0)">]>,
+ "non-zero-bitwidth signless integer", "::mlir::IntegerType">;
+
+// A non-zero-bitwidth signless integer or index type.
+def AnyNonZeroBitwidthSignlessIntegerOrIndex : Type<
+ Or<[AnyNonZeroBitwidthSignlessInteger.predicate,
+ CPred<"::llvm::isa<::mlir::IndexType>($_self)">]>,
+ "non-zero-bitwidth signless integer or index">;
+
// Floating point types.
// Any float type irrespective of its width.
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 4dc29897cec26..7713f93462b7f 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -3299,87 +3299,6 @@ func.func @unsignedExtendConstantResource() -> tensor<i16> {
return %ext : tensor<i16>
}
-// CHECK-LABEL: @extsi_i0
-// CHECK: %[[ZERO:.*]] = arith.constant 0 : i16
-// CHECK: return %[[ZERO]] : i16
-func.func @extsi_i0() -> i16 {
- %c0 = arith.constant 0 : i0
- %extsi = arith.extsi %c0 : i0 to i16
- return %extsi : i16
-}
-
-// CHECK-LABEL: @extui_i0
-// CHECK: %[[ZERO:.*]] = arith.constant 0 : i16
-// CHECK: return %[[ZERO]] : i16
-func.func @extui_i0() -> i16 {
- %c0 = arith.constant 0 : i0
- %extui = arith.extui %c0 : i0 to i16
- return %extui : i16
-}
-
-// CHECK-LABEL: @trunc_i0
-// CHECK: %[[ZERO:.*]] = arith.constant 0 : i0
-// CHECK: return %[[ZERO]] : i0
-func.func @trunc_i0() -> i0 {
- %cFF = arith.constant 0xFF : i8
- %trunc = arith.trunci %cFF : i8 to i0
- return %trunc : i0
-}
-
-// CHECK-LABEL: @shli_i0
-// CHECK: %[[ZERO:.*]] = arith.constant 0 : i0
-// CHECK: return %[[ZERO]] : i0
-func.func @shli_i0() -> i0 {
- %c0 = arith.constant 0 : i0
- %shli = arith.shli %c0, %c0 : i0
- return %shli : i0
-}
-
-// CHECK-LABEL: @shrsi_i0
-// CHECK: %[[ZERO:.*]] = arith.constant 0 : i0
-// CHECK: return %[[ZERO]] : i0
-func.func @shrsi_i0() -> i0 {
- %c0 = arith.constant 0 : i0
- %shrsi = arith.shrsi %c0, %c0 : i0
- return %shrsi : i0
-}
-
-// CHECK-LABEL: @shrui_i0
-// CHECK: %[[ZERO:.*]] = arith.constant 0 : i0
-// CHECK: return %[[ZERO]] : i0
-func.func @shrui_i0() -> i0 {
- %c0 = arith.constant 0 : i0
- %shrui = arith.shrui %c0, %c0 : i0
- return %shrui : i0
-}
-
-// CHECK-LABEL: @maxsi_i0
-// CHECK: %[[ZERO:.*]] = arith.constant 0 : i0
-// CHECK: return %[[ZERO]] : i0
-func.func @maxsi_i0() -> i0 {
- %c0 = arith.constant 0 : i0
- %maxsi = arith.maxsi %c0, %c0 : i0
- return %maxsi : i0
-}
-
-// CHECK-LABEL: @minsi_i0
-// CHECK: %[[ZERO:.*]] = arith.constant 0 : i0
-// CHECK: return %[[ZERO]] : i0
-func.func @minsi_i0() -> i0 {
- %c0 = arith.constant 0 : i0
- %minsi = arith.minsi %c0, %c0 : i0
- return %minsi : i0
-}
-
-// CHECK-LABEL: @mulsi_extended_i0
-// CHECK: %[[ZERO:.*]] = arith.constant 0 : i0
-// CHECK: return %[[ZERO]], %[[ZERO]] : i0
-func.func @mulsi_extended_i0() -> (i0, i0) {
- %c0 = arith.constant 0 : i0
- %mulsi_extended:2 = arith.mulsi_extended %c0, %c0 : i0
- return %mulsi_extended#0, %mulsi_extended#1 : i0, i0
-}
-
// CHECK-LABEL: @sequences_fastmath_contract
// CHECK-SAME: ([[ARG0:%.+]]: bf16)
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir
index 70b23e56a712c..cf404b7c389f5 100644
--- a/mlir/test/Dialect/Arith/invalid.mlir
+++ b/mlir/test/Dialect/Arith/invalid.mlir
@@ -130,14 +130,14 @@ func.func @func_with_ops(f32) {
func.func @func_with_ops(f32) {
^bb0(%a : f32):
- // expected-error@+1 {{'arith.addi' op operand #0 must be signless-integer-like}}
+ // expected-error@+1 {{'arith.addi' op operand #0 must be signless-non-zero-bitwidth-integer-like}}
%sf = arith.addi %a, %a : f32
}
// -----
func.func @func_with_ops(%a: f32) {
- // expected-error@+1 {{'arith.addui_extended' op operand #0 must be signless-integer-like}}
+ // expected-error@+1 {{'arith.addui_extended' op operand #0 must be signless-non-zero-bitwidth-integer-like}}
%r:2 = arith.addui_extended %a, %a : f32, i32
return
}
@@ -202,7 +202,7 @@ func.func @func_with_ops(i32, i32) {
// Integer comparisons are not recognized for float types.
func.func @func_with_ops(f32, f32) {
^bb0(%a : f32, %b : f32):
- %r = arith.cmpi eq, %a, %b : f32 // expected-error {{'lhs' must be signless-integer-like, but got 'f32'}}
+ %r = arith.cmpi eq, %a, %b : f32 // expected-error {{'lhs' must be signless-non-zero-bitwidth-integer-like, but got 'f32'}}
}
// -----
@@ -242,7 +242,7 @@ func.func @func_with_ops() {
// -----
func.func @invalid_cmp_shape(%idx : () -> ()) {
- // expected-error@+1 {{'lhs' must be signless-integer-like, but got '() -> ()'}}
+ // expected-error@+1 {{'lhs' must be signless-non-zero-bitwidth-integer-like, but got '() -> ()'}}
%cmp = arith.cmpi eq, %idx, %idx : () -> ()
// -----
@@ -877,3 +877,29 @@ func.func @select_vector_condition_scalar_operands(%arg0: vector<1xi1>, %arg1: i
%0 = arith.select %arg0, %arg1, %arg1 : vector<1xi1>, i32
return
}
+
+// -----
+
+// Verify that i0 (zero-bitwidth integer) is rejected by arith integer ops.
+
+func.func @addi_i0(%a: i0, %b: i0) -> i0 {
+ // expected-error @+1 {{'arith.addi' op operand #0 must be signless-non-zero-bitwidth-integer-like, but got 'i0'}}
+ %0 = arith.addi %a, %b : i0
+ return %0 : i0
+}
+
+// -----
+
+func.func @addi_vector_i0(%a: vector<4xi0>, %b: vector<4xi0>) -> vector<4xi0> {
+ // expected-error @+1 {{'arith.addi' op operand #0 must be signless-non-zero-bitwidth-integer-like, but got 'vector<4xi0>'}}
+ %0 = arith.addi %a, %b : vector<4xi0>
+ return %0 : vector<4xi0>
+}
+
+// -----
+
+func.func @trunci_to_i0(%a: i32) -> i0 {
+ // expected-error @+1 {{'arith.trunci' op result #0 must be signless-fixed-width-integer-like, but got 'i0'}}
+ %0 = arith.trunci %a : i32 to i0
+ return %0 : i0
+}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 28e1206ff3d0a..51254b524920c 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -2064,3 +2064,12 @@ func.func @store_non_pow_of_2_alignment(%memref: memref<4xi32>, %val: vector<4xi
vector.store %val, %memref[%c0] { alignment = 3 } : memref<4xi32>, vector<4xi32>
return
}
+
+// -----
+
+// Verify that vector.bitcast rejects vectors with i0 (zero-bitwidth) element type.
+func.func @bitcast_i0(%a: vector<4xi0>) -> vector<4xi0> {
+ // expected-error @+1 {{'vector.bitcast' op operand #0 must be vector of non-zero-bitwidth type values, but got 'vector<4xi0>'}}
+ %0 = vector.bitcast %a : vector<4xi0> to vector<4xi0>
+ return %0 : vector<4xi0>
+}
|
Yes, exactly. I strongly prefer this over band aid patches that fix non-zero bitwdith assumptions scattered across the codebase. I'm not sure if an RFC is necessary -- I think i0 was allowed unintentionally after most of these dialects was already in place and we missed ODS constraints. I'm not aware of anyone using these, but if you want to create one, it should be pretty straightforward. |
|
Actually, we do have tests for i0, so someone must have considered it at some point -- seems like an RFC would make sense after all. |
Why only |
Add ODS type constraints that exclude zero-bitwidth integers (i0) from operations in the arith and vector dialects. i0 has no meaningful arithmetic representation and operations on it can trigger undefined behavior (e.g. bitwidth calculations assuming non-zero width). Changes: - Add `AnyNonZeroBitwidthSignlessInteger` (as a `ConfinedType` over `AnySignlessInteger`) and `AnyNonZeroBitwidthSignlessIntegerOrIndex` to CommonTypeConstraints.td. - Introduce `Arith_SignlessIntegerOrIndexLike` in ArithOps.td that wraps `AnyNonZeroBitwidthSignlessIntegerOrIndex` via `TypeOrValueSemanticsContainer`, and update `SignlessFixedWidthIntegerLike` to use `AnyNonZeroBitwidthSignlessInteger`. Replace all uses of the shared `SignlessIntegerOrIndexLike` in ArithOps.td with the new dialect-local constraint. - Update `IndexCastTypeConstraint` to use `Arith_SignlessIntegerOrIndexLike`. - Update `BitcastTypeConstraint` to exclude i0 by composing the already- defined `SignlessFixedWidthIntegerLike` and `FloatLike` constraints, keeping the definition compact (3 alternatives instead of 7). - Add `AnyVectorOfNonI0Elem` and `AnyVectorOfNonZeroRankNonI0Elem` in VectorOps.td and apply them to `vector.contract`, `vector.reduction`, `vector.multi_reduction`, `vector.outerproduct`, `vector.bitcast`, and `vector.scan`. - Update arith/invalid.mlir with explicit i0 rejection tests covering all integer op families (binary ops, cast ops, extended-multiply ops, cmpi, bitcast, index_cast, index_castui) for both scalar and vector<N> forms. - Update vector/invalid.mlir with i0 rejection tests for all covered ops. - Remove the now-invalid i0 canonicalization tests from arith/canonicalize.mlir.
|
I added this mention for the arith dialect: I think the RFC is converged enough to proceed, happy to roll this back if/when we get a good use-case to reverse the tradeoff analysis! |
…lvm#183589) Add ODS type constraints that exclude zero-bitwidth integers (i0) from operations in the arith and vector dialects. i0 has no meaningful arithmetic representation and operations on it can trigger undefined behavior (e.g. bitwidth calculations assuming non-zero width). Changes: - Add `AnyNonZeroBitwidthSignlessInteger` (as a `ConfinedType` over `AnySignlessInteger`) and `AnyNonZeroBitwidthSignlessIntegerOrIndex` to CommonTypeConstraints.td. - Introduce `Arith_SignlessIntegerOrIndexLike` in ArithOps.td that wraps `AnyNonZeroBitwidthSignlessIntegerOrIndex` via `TypeOrValueSemanticsContainer`, and update `SignlessFixedWidthIntegerLike` to use `AnyNonZeroBitwidthSignlessInteger`. Replace all uses of the shared `SignlessIntegerOrIndexLike` in ArithOps.td with the new dialect-local constraint. - Update `IndexCastTypeConstraint` to use `Arith_SignlessIntegerOrIndexLike`. - Update `BitcastTypeConstraint` to exclude i0 by composing the already- defined `SignlessFixedWidthIntegerLike` and `FloatLike` constraints, keeping the definition compact (3 alternatives instead of 7). - Add `AnyVectorOfNonI0Elem` and `AnyVectorOfNonZeroRankNonI0Elem` in VectorOps.td and apply them to `vector.contract`, `vector.reduction`, `vector.multi_reduction`, `vector.outerproduct`, `vector.bitcast`, and `vector.scan`. - Update arith/invalid.mlir with explicit i0 rejection tests covering all integer op families (binary ops, cast ops, extended-multiply ops, cmpi, bitcast, index_cast, index_castui) for both scalar and vector<N> forms. - Update vector/invalid.mlir with i0 rejection tests for all covered ops. - Remove the now-invalid i0 canonicalization tests from arith/canonicalize.mlir. Fixes llvm#177822 Fixes llvm#179266 Fixes llvm#180463 Fixes llvm#181532 See also https://discourse.llvm.org/t/rfc-reject-i0-integer-type-in-arith-and-vector-ops/90011
…lvm#183589) Add ODS type constraints that exclude zero-bitwidth integers (i0) from operations in the arith and vector dialects. i0 has no meaningful arithmetic representation and operations on it can trigger undefined behavior (e.g. bitwidth calculations assuming non-zero width). Changes: - Add `AnyNonZeroBitwidthSignlessInteger` (as a `ConfinedType` over `AnySignlessInteger`) and `AnyNonZeroBitwidthSignlessIntegerOrIndex` to CommonTypeConstraints.td. - Introduce `Arith_SignlessIntegerOrIndexLike` in ArithOps.td that wraps `AnyNonZeroBitwidthSignlessIntegerOrIndex` via `TypeOrValueSemanticsContainer`, and update `SignlessFixedWidthIntegerLike` to use `AnyNonZeroBitwidthSignlessInteger`. Replace all uses of the shared `SignlessIntegerOrIndexLike` in ArithOps.td with the new dialect-local constraint. - Update `IndexCastTypeConstraint` to use `Arith_SignlessIntegerOrIndexLike`. - Update `BitcastTypeConstraint` to exclude i0 by composing the already- defined `SignlessFixedWidthIntegerLike` and `FloatLike` constraints, keeping the definition compact (3 alternatives instead of 7). - Add `AnyVectorOfNonI0Elem` and `AnyVectorOfNonZeroRankNonI0Elem` in VectorOps.td and apply them to `vector.contract`, `vector.reduction`, `vector.multi_reduction`, `vector.outerproduct`, `vector.bitcast`, and `vector.scan`. - Update arith/invalid.mlir with explicit i0 rejection tests covering all integer op families (binary ops, cast ops, extended-multiply ops, cmpi, bitcast, index_cast, index_castui) for both scalar and vector<N> forms. - Update vector/invalid.mlir with i0 rejection tests for all covered ops. - Remove the now-invalid i0 canonicalization tests from arith/canonicalize.mlir. Fixes llvm#177822 Fixes llvm#179266 Fixes llvm#180463 Fixes llvm#181532 See also https://discourse.llvm.org/t/rfc-reject-i0-integer-type-in-arith-and-vector-ops/90011
Add ODS type constraints that exclude zero-bitwidth integers (i0) from
operations in the arith and vector dialects. i0 has no meaningful
arithmetic representation and operations on it can trigger undefined
behavior (e.g. bitwidth calculations assuming non-zero width).
Changes:
AnyNonZeroBitwidthSignlessInteger(as aConfinedTypeoverAnySignlessInteger) andAnyNonZeroBitwidthSignlessIntegerOrIndexto CommonTypeConstraints.td.
Arith_SignlessIntegerOrIndexLikein ArithOps.td that wrapsAnyNonZeroBitwidthSignlessIntegerOrIndexviaTypeOrValueSemanticsContainer, and updateSignlessFixedWidthIntegerLiketo use
AnyNonZeroBitwidthSignlessInteger. Replace all uses of theshared
SignlessIntegerOrIndexLikein ArithOps.td with the newdialect-local constraint.
IndexCastTypeConstraintto useArith_SignlessIntegerOrIndexLike.BitcastTypeConstraintto exclude i0 by composing the already-defined
SignlessFixedWidthIntegerLikeandFloatLikeconstraints,keeping the definition compact (3 alternatives instead of 7).
AnyVectorOfNonI0ElemandAnyVectorOfNonZeroRankNonI0EleminVectorOps.td and apply them to
vector.contract,vector.reduction,vector.multi_reduction,vector.outerproduct,vector.bitcast, andvector.scan.integer op families (binary ops, cast ops, extended-multiply ops, cmpi,
bitcast, index_cast, index_castui) for both scalar and vector forms.
arith/canonicalize.mlir.
Fixes #177822
Fixes #179266
Fixes #180463
Fixes #181532
See also https://discourse.llvm.org/t/rfc-reject-i0-integer-type-in-arith-and-vector-ops/90011