Skip to content

[MLIR][Arith][Vector] Reject i0 integer type in arith and vector ops#183589

Merged
joker-eph merged 1 commit into
llvm:mainfrom
joker-eph:i0
Mar 4, 2026
Merged

[MLIR][Arith][Vector] Reject i0 integer type in arith and vector ops#183589
joker-eph merged 1 commit into
llvm:mainfrom
joker-eph:i0

Conversation

@joker-eph
Copy link
Copy Markdown
Contributor

@joker-eph joker-eph commented Feb 26, 2026

Add ODS type constraints that exclude zero-bitwidth integers (i0) from
operations in the arith and vector dialects. i0 has no meaningful
arithmetic representation and operations on it can trigger undefined
behavior (e.g. bitwidth calculations assuming non-zero width).

Changes:

  • Add AnyNonZeroBitwidthSignlessInteger (as a ConfinedType over
    AnySignlessInteger) and AnyNonZeroBitwidthSignlessIntegerOrIndex
    to CommonTypeConstraints.td.
  • Introduce Arith_SignlessIntegerOrIndexLike in ArithOps.td that wraps
    AnyNonZeroBitwidthSignlessIntegerOrIndex via
    TypeOrValueSemanticsContainer, and update SignlessFixedWidthIntegerLike
    to use AnyNonZeroBitwidthSignlessInteger. Replace all uses of the
    shared SignlessIntegerOrIndexLike in ArithOps.td with the new
    dialect-local constraint.
  • Update IndexCastTypeConstraint to use Arith_SignlessIntegerOrIndexLike.
  • Update BitcastTypeConstraint to exclude i0 by composing the already-
    defined SignlessFixedWidthIntegerLike and FloatLike constraints,
    keeping the definition compact (3 alternatives instead of 7).
  • Add AnyVectorOfNonI0Elem and AnyVectorOfNonZeroRankNonI0Elem in
    VectorOps.td and apply them to vector.contract, vector.reduction,
    vector.multi_reduction, vector.outerproduct, vector.bitcast, and
    vector.scan.
  • Update arith/invalid.mlir with explicit i0 rejection tests covering all
    integer op families (binary ops, cast ops, extended-multiply ops, cmpi,
    bitcast, index_cast, index_castui) for both scalar and vector forms.
  • Update vector/invalid.mlir with i0 rejection tests for all covered ops.
  • Remove the now-invalid i0 canonicalization tests from
    arith/canonicalize.mlir.

Fixes #177822
Fixes #179266
Fixes #180463
Fixes #181532

See also https://discourse.llvm.org/t/rfc-reject-i0-integer-type-in-arith-and-vector-ops/90011

@joker-eph
Copy link
Copy Markdown
Contributor Author

@kuhar : is this what you had in mind? This may deserve an RFC though.

@llvmbot
Copy link
Copy Markdown
Member

llvmbot commented Feb 26, 2026

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-ods

@llvm/pr-subscribers-mlir-core

Author: Mehdi Amini (joker-eph)

Changes

Add ODS type constraints that exclude zero-bitwidth integers (i0) from operations in the arith and vector dialects. i0 has no meaningful arithmetic representation and operations on it can trigger undefined behavior (e.g. bitwidth calculations assuming non-zero width).

Changes:

  • Add AnyNonZeroBitwidthSignlessInteger and AnyNonZeroBitwidthSignlessIntegerOrIndex to CommonTypeConstraints.td.
  • Introduce Arith_SignlessIntegerOrIndexLike in ArithOps.td that uses the new non-zero-width base type, and update SignlessFixedWidthIntegerLike likewise. Replace all uses of the shared SignlessIntegerOrIndexLike in ArithOps.td with the new dialect-local constraint.
  • Add AnyVectorNonZeroBitwidthIntElem in VectorOps.td and apply it to vector.bitcast source and result types.
  • Update arith/invalid.mlir and vector/invalid.mlir with explicit rejection tests for i0 scalar and vector element types.
  • Remove the now-invalid i0 canonicalization tests from arith/canonicalize.mlir.

Full diff: https://github.com/llvm/llvm-project/pull/183589.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Arith/IR/ArithOps.td (+28-20)
  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+11-2)
  • (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+12)
  • (modified) mlir/test/Dialect/Arith/canonicalize.mlir (-81)
  • (modified) mlir/test/Dialect/Arith/invalid.mlir (+30-4)
  • (modified) mlir/test/Dialect/Vector/invalid.mlir (+9)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 3d8517c56e784..6c8a40970bc0a 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -21,6 +21,12 @@ include "mlir/IR/BuiltinAttributeInterfaces.td"
 include "mlir/IR/OpAsmInterface.td"
 include "mlir/IR/EnumAttr.td"
 
+// Type constraint for signless-integer-or-index-like types that additionally
+// excludes i0 (zero-bitwidth integers), used for arith integer operations.
+def Arith_SignlessIntegerOrIndexLike : TypeOrValueSemanticsContainer<
+    AnyNonZeroBitwidthSignlessIntegerOrIndex,
+    "signless-non-zero-bitwidth-integer-like">;
+
 // Base class for Arith dialect ops. Ops in this dialect have no memory
 // effects and can be applied element-wise to vectors and tensors.
 class Arith_Op<string mnemonic, list<Trait> traits = []> :
@@ -51,8 +57,8 @@ class Arith_BinaryOp<string mnemonic, list<Trait> traits = []> :
 class Arith_IntBinaryOp<string mnemonic, list<Trait> traits = []> :
     Arith_BinaryOp<mnemonic, traits #
       [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
-    Arguments<(ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs)>,
-    Results<(outs SignlessIntegerOrIndexLike:$result)>;
+    Arguments<(ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs)>,
+    Results<(outs Arith_SignlessIntegerOrIndexLike:$result)>;
 
 // Base class for integer binary operations without undefined behavior.
 class Arith_TotalIntBinaryOp<string mnemonic, list<Trait> traits = []> :
@@ -110,10 +116,12 @@ class Arith_CastOp<string mnemonic, TypeConstraint From, TypeConstraint To,
 
 // Casts do not accept indices. Type constraint for signless-integer-like types
 // excluding indices: signless integers, vectors or tensors thereof.
+// i0 (zero-bitwidth) integers are excluded as they have no meaningful
+// representation for arithmetic operations.
 def SignlessFixedWidthIntegerLike : TypeConstraint<Or<[
-        AnySignlessInteger.predicate,
-        VectorOfAnyRankOf<[AnySignlessInteger]>.predicate,
-        TensorOf<[AnySignlessInteger]>.predicate]>,
+        AnyNonZeroBitwidthSignlessInteger.predicate,
+        VectorOfAnyRankOf<[AnyNonZeroBitwidthSignlessInteger]>.predicate,
+        TensorOf<[AnyNonZeroBitwidthSignlessInteger]>.predicate]>,
     "signless-fixed-width-integer-like">;
 
 // Cast from an integer type to another integer type.
@@ -148,11 +156,11 @@ class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = [
     Arith_BinaryOp<mnemonic, traits #
       [Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
        DeclareOpInterfaceMethods<ArithIntegerOverflowFlagsInterface>]>,
-    Arguments<(ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs,
+    Arguments<(ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs,
       DefaultValuedAttr<
         Arith_IntegerOverflowAttr,
         "::mlir::arith::IntegerOverflowFlags::none">:$overflowFlags)>,
-    Results<(outs SignlessIntegerOrIndexLike:$result)> {
+    Results<(outs Arith_SignlessIntegerOrIndexLike:$result)> {
 
   let assemblyFormat = [{ $lhs `,` $rhs (`overflow` `` $overflowFlags^)?
                           attr-dict `:` type($result) }];
@@ -161,10 +169,10 @@ class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = [
 class Arith_IntBinaryOpWithExactFlag<string mnemonic, list<Trait> traits = []> :
     Arith_BinaryOp<mnemonic, traits #
       [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
-    Arguments<(ins SignlessIntegerOrIndexLike:$lhs,
-               SignlessIntegerOrIndexLike:$rhs,
+    Arguments<(ins Arith_SignlessIntegerOrIndexLike:$lhs,
+               Arith_SignlessIntegerOrIndexLike:$rhs,
                UnitAttr:$isExact)>,
-    Results<(outs SignlessIntegerOrIndexLike:$result)> {
+    Results<(outs Arith_SignlessIntegerOrIndexLike:$result)> {
 
   let assemblyFormat = [{ $lhs `,` $rhs (`exact` $isExact^)?
                           attr-dict `:` type($result) }];
@@ -293,8 +301,8 @@ def Arith_AddUIExtendedOp : Arith_Op<"addui_extended", [Pure, Commutative,
     ```
   }];
 
-  let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
-  let results = (outs SignlessIntegerOrIndexLike:$sum, BoolLike:$overflow);
+  let arguments = (ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs);
+  let results = (outs Arith_SignlessIntegerOrIndexLike:$sum, BoolLike:$overflow);
   let assemblyFormat = [{
     $lhs `,` $rhs attr-dict `:` type($sum) `,` type($overflow)
   }];
@@ -434,8 +442,8 @@ def Arith_MulSIExtendedOp : Arith_Op<"mulsi_extended", [Pure, Commutative,
     ```
   }];
 
-  let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
-  let results = (outs SignlessIntegerOrIndexLike:$low, SignlessIntegerOrIndexLike:$high);
+  let arguments = (ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs);
+  let results = (outs Arith_SignlessIntegerOrIndexLike:$low, Arith_SignlessIntegerOrIndexLike:$high);
 
   let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)";
 
@@ -477,8 +485,8 @@ def Arith_MulUIExtendedOp : Arith_Op<"mului_extended", [Pure, Commutative,
     ```
   }];
 
-  let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
-  let results = (outs SignlessIntegerOrIndexLike:$low, SignlessIntegerOrIndexLike:$high);
+  let arguments = (ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs);
+  let results = (outs Arith_SignlessIntegerOrIndexLike:$low, Arith_SignlessIntegerOrIndexLike:$high);
 
   let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)";
 
@@ -1573,8 +1581,8 @@ def Arith_FPToSIOp : Arith_FToICastOp<"fptosi"> {
 
 // Index cast can convert between memrefs of signless integers and indices too.
 def IndexCastTypeConstraint : TypeConstraint<Or<[
-        SignlessIntegerOrIndexLike.predicate,
-        MemRefOf<[AnySignlessInteger, Index]>.predicate]>,
+        Arith_SignlessIntegerOrIndexLike.predicate,
+        MemRefOf<[AnyNonZeroBitwidthSignlessInteger, Index]>.predicate]>,
     "signless-integer-like or memref of signless-integer">;
 
 def Arith_IndexCastOp
@@ -1737,8 +1745,8 @@ def Arith_CmpIOp : Arith_CompareOp<"cmpi",
   }];
 
   let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
-                       SignlessIntegerOrIndexLike:$lhs,
-                       SignlessIntegerOrIndexLike:$rhs);
+                       Arith_SignlessIntegerOrIndexLike:$lhs,
+                       Arith_SignlessIntegerOrIndexLike:$rhs);
 
   let hasFolder = 1;
   let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index ddb04b6bbe40d..87a15cba8a0e0 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -32,6 +32,15 @@ include "mlir/Interfaces/ViewLikeInterface.td"
 include "mlir/IR/BuiltinAttributes.td"
 include "mlir/IR/EnumAttr.td"
 
+// Type constraint helpers for the vector dialect.
+
+// Any vector of any rank whose element type is not i0 (zero-bitwidth integer).
+// Floats, index, and integers with width >= 1 are all accepted. Using
+// VectorOfAnyRankOf preserves the ::mlir::VectorType C++ class so that ODS
+// generates TypedValue<VectorType> (VectorValue) for op results.
+def AnyVectorNonZeroBitwidthIntElem : VectorOfAnyRankOf<[
+    Type<CPred<"!$_self.isInteger(0)">, "non-zero-bitwidth type">]>;
+
 // TODO: Add an attribute to specify a different algebra with operators other
 // than the current set: {*, +}.
 def Vector_ContractionOp :
@@ -2454,8 +2463,8 @@ def Vector_ShapeCastOp :
 
 def Vector_BitCastOp :
   Vector_Op<"bitcast", [Pure, AllRanksMatch<["source", "result"]>]>,
-    Arguments<(ins AnyVectorOfAnyRank:$source)>,
-    Results<(outs AnyVectorOfAnyRank:$result)>{
+    Arguments<(ins AnyVectorNonZeroBitwidthIntElem:$source)>,
+    Results<(outs AnyVectorNonZeroBitwidthIntElem:$result)>{
   let summary = "bitcast casts between vectors";
   let description = [{
     The bitcast operation casts between vectors of the same rank, the minor 1-D
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index a49880b81e90d..409672d46c55e 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -308,6 +308,18 @@ def Index : Type<CPred<"::llvm::isa<::mlir::IndexType>($_self)">, "index",
 def AnySignlessIntegerOrIndex : Type<CPred<"$_self.isSignlessIntOrIndex()">,
                                      "signless integer or index">;
 
+// A signless integer type with a non-zero bitwidth (excludes i0).
+def AnyNonZeroBitwidthSignlessInteger : Type<
+    And<[CPred<"$_self.isSignlessInteger()">,
+         CPred<"!$_self.isInteger(0)">]>,
+    "non-zero-bitwidth signless integer", "::mlir::IntegerType">;
+
+// A non-zero-bitwidth signless integer or index type.
+def AnyNonZeroBitwidthSignlessIntegerOrIndex : Type<
+    Or<[AnyNonZeroBitwidthSignlessInteger.predicate,
+        CPred<"::llvm::isa<::mlir::IndexType>($_self)">]>,
+    "non-zero-bitwidth signless integer or index">;
+
 // Floating point types.
 
 // Any float type irrespective of its width.
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 4dc29897cec26..7713f93462b7f 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -3299,87 +3299,6 @@ func.func @unsignedExtendConstantResource() -> tensor<i16> {
   return %ext : tensor<i16>
 }
 
-// CHECK-LABEL: @extsi_i0
-//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i16
-//       CHECK:   return %[[ZERO]] : i16
-func.func @extsi_i0() -> i16 {
-  %c0 = arith.constant 0 : i0
-  %extsi = arith.extsi %c0 : i0 to i16
-  return %extsi : i16
-}
-
-// CHECK-LABEL: @extui_i0
-//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i16
-//       CHECK:   return %[[ZERO]] : i16
-func.func @extui_i0() -> i16 {
-  %c0 = arith.constant 0 : i0
-  %extui = arith.extui %c0 : i0 to i16
-  return %extui : i16
-}
-
-// CHECK-LABEL: @trunc_i0
-//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i0
-//       CHECK:   return %[[ZERO]] : i0
-func.func @trunc_i0() -> i0 {
-  %cFF = arith.constant 0xFF : i8
-  %trunc = arith.trunci %cFF : i8 to i0
-  return %trunc : i0
-}
-
-// CHECK-LABEL: @shli_i0
-//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i0
-//       CHECK:   return %[[ZERO]] : i0
-func.func @shli_i0() -> i0 {
-  %c0 = arith.constant 0 : i0
-  %shli = arith.shli %c0, %c0 : i0
-  return %shli : i0
-}
-
-// CHECK-LABEL: @shrsi_i0
-//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i0
-//       CHECK:   return %[[ZERO]] : i0
-func.func @shrsi_i0() -> i0 {
-  %c0 = arith.constant 0 : i0
-  %shrsi = arith.shrsi %c0, %c0 : i0
-  return %shrsi : i0
-}
-
-// CHECK-LABEL: @shrui_i0
-//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i0
-//       CHECK:   return %[[ZERO]] : i0
-func.func @shrui_i0() -> i0 {
-  %c0 = arith.constant 0 : i0
-  %shrui = arith.shrui %c0, %c0 : i0
-  return %shrui : i0
-}
-
-// CHECK-LABEL: @maxsi_i0
-//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i0
-//       CHECK:   return %[[ZERO]] : i0
-func.func @maxsi_i0() -> i0 {
-  %c0 = arith.constant 0 : i0
-  %maxsi = arith.maxsi %c0, %c0 : i0
-  return %maxsi : i0
-}
-
-// CHECK-LABEL: @minsi_i0
-//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i0
-//       CHECK:   return %[[ZERO]] : i0
-func.func @minsi_i0() -> i0 {
-  %c0 = arith.constant 0 : i0
-  %minsi = arith.minsi %c0, %c0 : i0
-  return %minsi : i0
-}
-
-// CHECK-LABEL: @mulsi_extended_i0
-//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i0
-//       CHECK:   return %[[ZERO]], %[[ZERO]] : i0
-func.func @mulsi_extended_i0() -> (i0, i0) {
-  %c0 = arith.constant 0 : i0
-  %mulsi_extended:2 = arith.mulsi_extended %c0, %c0 : i0
-  return %mulsi_extended#0, %mulsi_extended#1 : i0, i0
-}
-
 // CHECK-LABEL: @sequences_fastmath_contract
 // CHECK-SAME: ([[ARG0:%.+]]: bf16)
 // CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir
index 70b23e56a712c..cf404b7c389f5 100644
--- a/mlir/test/Dialect/Arith/invalid.mlir
+++ b/mlir/test/Dialect/Arith/invalid.mlir
@@ -130,14 +130,14 @@ func.func @func_with_ops(f32) {
 
 func.func @func_with_ops(f32) {
 ^bb0(%a : f32):
-  // expected-error@+1 {{'arith.addi' op operand #0 must be signless-integer-like}}
+  // expected-error@+1 {{'arith.addi' op operand #0 must be signless-non-zero-bitwidth-integer-like}}
   %sf = arith.addi %a, %a : f32
 }
 
 // -----
 
 func.func @func_with_ops(%a: f32) {
-  // expected-error@+1 {{'arith.addui_extended' op operand #0 must be signless-integer-like}}
+  // expected-error@+1 {{'arith.addui_extended' op operand #0 must be signless-non-zero-bitwidth-integer-like}}
   %r:2 = arith.addui_extended %a, %a : f32, i32
   return
 }
@@ -202,7 +202,7 @@ func.func @func_with_ops(i32, i32) {
 // Integer comparisons are not recognized for float types.
 func.func @func_with_ops(f32, f32) {
 ^bb0(%a : f32, %b : f32):
-  %r = arith.cmpi eq, %a, %b : f32 // expected-error {{'lhs' must be signless-integer-like, but got 'f32'}}
+  %r = arith.cmpi eq, %a, %b : f32 // expected-error {{'lhs' must be signless-non-zero-bitwidth-integer-like, but got 'f32'}}
 }
 
 // -----
@@ -242,7 +242,7 @@ func.func @func_with_ops() {
 // -----
 
 func.func @invalid_cmp_shape(%idx : () -> ()) {
-  // expected-error@+1 {{'lhs' must be signless-integer-like, but got '() -> ()'}}
+  // expected-error@+1 {{'lhs' must be signless-non-zero-bitwidth-integer-like, but got '() -> ()'}}
   %cmp = arith.cmpi eq, %idx, %idx : () -> ()
 
 // -----
@@ -877,3 +877,29 @@ func.func @select_vector_condition_scalar_operands(%arg0: vector<1xi1>, %arg1: i
   %0 = arith.select %arg0, %arg1, %arg1 : vector<1xi1>, i32
   return
 }
+
+// -----
+
+// Verify that i0 (zero-bitwidth integer) is rejected by arith integer ops.
+
+func.func @addi_i0(%a: i0, %b: i0) -> i0 {
+  // expected-error @+1 {{'arith.addi' op operand #0 must be signless-non-zero-bitwidth-integer-like, but got 'i0'}}
+  %0 = arith.addi %a, %b : i0
+  return %0 : i0
+}
+
+// -----
+
+func.func @addi_vector_i0(%a: vector<4xi0>, %b: vector<4xi0>) -> vector<4xi0> {
+  // expected-error @+1 {{'arith.addi' op operand #0 must be signless-non-zero-bitwidth-integer-like, but got 'vector<4xi0>'}}
+  %0 = arith.addi %a, %b : vector<4xi0>
+  return %0 : vector<4xi0>
+}
+
+// -----
+
+func.func @trunci_to_i0(%a: i32) -> i0 {
+  // expected-error @+1 {{'arith.trunci' op result #0 must be signless-fixed-width-integer-like, but got 'i0'}}
+  %0 = arith.trunci %a : i32 to i0
+  return %0 : i0
+}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 28e1206ff3d0a..51254b524920c 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -2064,3 +2064,12 @@ func.func @store_non_pow_of_2_alignment(%memref: memref<4xi32>, %val: vector<4xi
   vector.store %val, %memref[%c0] { alignment = 3 } : memref<4xi32>, vector<4xi32>
   return
 }
+
+// -----
+
+// Verify that vector.bitcast rejects vectors with i0 (zero-bitwidth) element type.
+func.func @bitcast_i0(%a: vector<4xi0>) -> vector<4xi0> {
+  // expected-error @+1 {{'vector.bitcast' op operand #0 must be vector of non-zero-bitwidth type values, but got 'vector<4xi0>'}}
+  %0 = vector.bitcast %a : vector<4xi0> to vector<4xi0>
+  return %0 : vector<4xi0>
+}

@llvmbot
Copy link
Copy Markdown
Member

llvmbot commented Feb 26, 2026

@llvm/pr-subscribers-mlir-vector

Author: Mehdi Amini (joker-eph)

Changes

Add ODS type constraints that exclude zero-bitwidth integers (i0) from operations in the arith and vector dialects. i0 has no meaningful arithmetic representation and operations on it can trigger undefined behavior (e.g. bitwidth calculations assuming non-zero width).

Changes:

  • Add AnyNonZeroBitwidthSignlessInteger and AnyNonZeroBitwidthSignlessIntegerOrIndex to CommonTypeConstraints.td.
  • Introduce Arith_SignlessIntegerOrIndexLike in ArithOps.td that uses the new non-zero-width base type, and update SignlessFixedWidthIntegerLike likewise. Replace all uses of the shared SignlessIntegerOrIndexLike in ArithOps.td with the new dialect-local constraint.
  • Add AnyVectorNonZeroBitwidthIntElem in VectorOps.td and apply it to vector.bitcast source and result types.
  • Update arith/invalid.mlir and vector/invalid.mlir with explicit rejection tests for i0 scalar and vector element types.
  • Remove the now-invalid i0 canonicalization tests from arith/canonicalize.mlir.

Full diff: https://github.com/llvm/llvm-project/pull/183589.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Arith/IR/ArithOps.td (+28-20)
  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+11-2)
  • (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+12)
  • (modified) mlir/test/Dialect/Arith/canonicalize.mlir (-81)
  • (modified) mlir/test/Dialect/Arith/invalid.mlir (+30-4)
  • (modified) mlir/test/Dialect/Vector/invalid.mlir (+9)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 3d8517c56e784..6c8a40970bc0a 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -21,6 +21,12 @@ include "mlir/IR/BuiltinAttributeInterfaces.td"
 include "mlir/IR/OpAsmInterface.td"
 include "mlir/IR/EnumAttr.td"
 
+// Type constraint for signless-integer-or-index-like types that additionally
+// excludes i0 (zero-bitwidth integers), used for arith integer operations.
+def Arith_SignlessIntegerOrIndexLike : TypeOrValueSemanticsContainer<
+    AnyNonZeroBitwidthSignlessIntegerOrIndex,
+    "signless-non-zero-bitwidth-integer-like">;
+
 // Base class for Arith dialect ops. Ops in this dialect have no memory
 // effects and can be applied element-wise to vectors and tensors.
 class Arith_Op<string mnemonic, list<Trait> traits = []> :
@@ -51,8 +57,8 @@ class Arith_BinaryOp<string mnemonic, list<Trait> traits = []> :
 class Arith_IntBinaryOp<string mnemonic, list<Trait> traits = []> :
     Arith_BinaryOp<mnemonic, traits #
       [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
-    Arguments<(ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs)>,
-    Results<(outs SignlessIntegerOrIndexLike:$result)>;
+    Arguments<(ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs)>,
+    Results<(outs Arith_SignlessIntegerOrIndexLike:$result)>;
 
 // Base class for integer binary operations without undefined behavior.
 class Arith_TotalIntBinaryOp<string mnemonic, list<Trait> traits = []> :
@@ -110,10 +116,12 @@ class Arith_CastOp<string mnemonic, TypeConstraint From, TypeConstraint To,
 
 // Casts do not accept indices. Type constraint for signless-integer-like types
 // excluding indices: signless integers, vectors or tensors thereof.
+// i0 (zero-bitwidth) integers are excluded as they have no meaningful
+// representation for arithmetic operations.
 def SignlessFixedWidthIntegerLike : TypeConstraint<Or<[
-        AnySignlessInteger.predicate,
-        VectorOfAnyRankOf<[AnySignlessInteger]>.predicate,
-        TensorOf<[AnySignlessInteger]>.predicate]>,
+        AnyNonZeroBitwidthSignlessInteger.predicate,
+        VectorOfAnyRankOf<[AnyNonZeroBitwidthSignlessInteger]>.predicate,
+        TensorOf<[AnyNonZeroBitwidthSignlessInteger]>.predicate]>,
     "signless-fixed-width-integer-like">;
 
 // Cast from an integer type to another integer type.
@@ -148,11 +156,11 @@ class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = [
     Arith_BinaryOp<mnemonic, traits #
       [Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
        DeclareOpInterfaceMethods<ArithIntegerOverflowFlagsInterface>]>,
-    Arguments<(ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs,
+    Arguments<(ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs,
       DefaultValuedAttr<
         Arith_IntegerOverflowAttr,
         "::mlir::arith::IntegerOverflowFlags::none">:$overflowFlags)>,
-    Results<(outs SignlessIntegerOrIndexLike:$result)> {
+    Results<(outs Arith_SignlessIntegerOrIndexLike:$result)> {
 
   let assemblyFormat = [{ $lhs `,` $rhs (`overflow` `` $overflowFlags^)?
                           attr-dict `:` type($result) }];
@@ -161,10 +169,10 @@ class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = [
 class Arith_IntBinaryOpWithExactFlag<string mnemonic, list<Trait> traits = []> :
     Arith_BinaryOp<mnemonic, traits #
       [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
-    Arguments<(ins SignlessIntegerOrIndexLike:$lhs,
-               SignlessIntegerOrIndexLike:$rhs,
+    Arguments<(ins Arith_SignlessIntegerOrIndexLike:$lhs,
+               Arith_SignlessIntegerOrIndexLike:$rhs,
                UnitAttr:$isExact)>,
-    Results<(outs SignlessIntegerOrIndexLike:$result)> {
+    Results<(outs Arith_SignlessIntegerOrIndexLike:$result)> {
 
   let assemblyFormat = [{ $lhs `,` $rhs (`exact` $isExact^)?
                           attr-dict `:` type($result) }];
@@ -293,8 +301,8 @@ def Arith_AddUIExtendedOp : Arith_Op<"addui_extended", [Pure, Commutative,
     ```
   }];
 
-  let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
-  let results = (outs SignlessIntegerOrIndexLike:$sum, BoolLike:$overflow);
+  let arguments = (ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs);
+  let results = (outs Arith_SignlessIntegerOrIndexLike:$sum, BoolLike:$overflow);
   let assemblyFormat = [{
     $lhs `,` $rhs attr-dict `:` type($sum) `,` type($overflow)
   }];
@@ -434,8 +442,8 @@ def Arith_MulSIExtendedOp : Arith_Op<"mulsi_extended", [Pure, Commutative,
     ```
   }];
 
-  let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
-  let results = (outs SignlessIntegerOrIndexLike:$low, SignlessIntegerOrIndexLike:$high);
+  let arguments = (ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs);
+  let results = (outs Arith_SignlessIntegerOrIndexLike:$low, Arith_SignlessIntegerOrIndexLike:$high);
 
   let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)";
 
@@ -477,8 +485,8 @@ def Arith_MulUIExtendedOp : Arith_Op<"mului_extended", [Pure, Commutative,
     ```
   }];
 
-  let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
-  let results = (outs SignlessIntegerOrIndexLike:$low, SignlessIntegerOrIndexLike:$high);
+  let arguments = (ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs);
+  let results = (outs Arith_SignlessIntegerOrIndexLike:$low, Arith_SignlessIntegerOrIndexLike:$high);
 
   let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)";
 
@@ -1573,8 +1581,8 @@ def Arith_FPToSIOp : Arith_FToICastOp<"fptosi"> {
 
 // Index cast can convert between memrefs of signless integers and indices too.
 def IndexCastTypeConstraint : TypeConstraint<Or<[
-        SignlessIntegerOrIndexLike.predicate,
-        MemRefOf<[AnySignlessInteger, Index]>.predicate]>,
+        Arith_SignlessIntegerOrIndexLike.predicate,
+        MemRefOf<[AnyNonZeroBitwidthSignlessInteger, Index]>.predicate]>,
     "signless-integer-like or memref of signless-integer">;
 
 def Arith_IndexCastOp
@@ -1737,8 +1745,8 @@ def Arith_CmpIOp : Arith_CompareOp<"cmpi",
   }];
 
   let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
-                       SignlessIntegerOrIndexLike:$lhs,
-                       SignlessIntegerOrIndexLike:$rhs);
+                       Arith_SignlessIntegerOrIndexLike:$lhs,
+                       Arith_SignlessIntegerOrIndexLike:$rhs);
 
   let hasFolder = 1;
   let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index ddb04b6bbe40d..87a15cba8a0e0 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -32,6 +32,15 @@ include "mlir/Interfaces/ViewLikeInterface.td"
 include "mlir/IR/BuiltinAttributes.td"
 include "mlir/IR/EnumAttr.td"
 
+// Type constraint helpers for the vector dialect.
+
+// Any vector of any rank whose element type is not i0 (zero-bitwidth integer).
+// Floats, index, and integers with width >= 1 are all accepted. Using
+// VectorOfAnyRankOf preserves the ::mlir::VectorType C++ class so that ODS
+// generates TypedValue<VectorType> (VectorValue) for op results.
+def AnyVectorNonZeroBitwidthIntElem : VectorOfAnyRankOf<[
+    Type<CPred<"!$_self.isInteger(0)">, "non-zero-bitwidth type">]>;
+
 // TODO: Add an attribute to specify a different algebra with operators other
 // than the current set: {*, +}.
 def Vector_ContractionOp :
@@ -2454,8 +2463,8 @@ def Vector_ShapeCastOp :
 
 def Vector_BitCastOp :
   Vector_Op<"bitcast", [Pure, AllRanksMatch<["source", "result"]>]>,
-    Arguments<(ins AnyVectorOfAnyRank:$source)>,
-    Results<(outs AnyVectorOfAnyRank:$result)>{
+    Arguments<(ins AnyVectorNonZeroBitwidthIntElem:$source)>,
+    Results<(outs AnyVectorNonZeroBitwidthIntElem:$result)>{
   let summary = "bitcast casts between vectors";
   let description = [{
     The bitcast operation casts between vectors of the same rank, the minor 1-D
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index a49880b81e90d..409672d46c55e 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -308,6 +308,18 @@ def Index : Type<CPred<"::llvm::isa<::mlir::IndexType>($_self)">, "index",
 def AnySignlessIntegerOrIndex : Type<CPred<"$_self.isSignlessIntOrIndex()">,
                                      "signless integer or index">;
 
+// A signless integer type with a non-zero bitwidth (excludes i0).
+def AnyNonZeroBitwidthSignlessInteger : Type<
+    And<[CPred<"$_self.isSignlessInteger()">,
+         CPred<"!$_self.isInteger(0)">]>,
+    "non-zero-bitwidth signless integer", "::mlir::IntegerType">;
+
+// A non-zero-bitwidth signless integer or index type.
+def AnyNonZeroBitwidthSignlessIntegerOrIndex : Type<
+    Or<[AnyNonZeroBitwidthSignlessInteger.predicate,
+        CPred<"::llvm::isa<::mlir::IndexType>($_self)">]>,
+    "non-zero-bitwidth signless integer or index">;
+
 // Floating point types.
 
 // Any float type irrespective of its width.
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 4dc29897cec26..7713f93462b7f 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -3299,87 +3299,6 @@ func.func @unsignedExtendConstantResource() -> tensor<i16> {
   return %ext : tensor<i16>
 }
 
-// CHECK-LABEL: @extsi_i0
-//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i16
-//       CHECK:   return %[[ZERO]] : i16
-func.func @extsi_i0() -> i16 {
-  %c0 = arith.constant 0 : i0
-  %extsi = arith.extsi %c0 : i0 to i16
-  return %extsi : i16
-}
-
-// CHECK-LABEL: @extui_i0
-//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i16
-//       CHECK:   return %[[ZERO]] : i16
-func.func @extui_i0() -> i16 {
-  %c0 = arith.constant 0 : i0
-  %extui = arith.extui %c0 : i0 to i16
-  return %extui : i16
-}
-
-// CHECK-LABEL: @trunc_i0
-//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i0
-//       CHECK:   return %[[ZERO]] : i0
-func.func @trunc_i0() -> i0 {
-  %cFF = arith.constant 0xFF : i8
-  %trunc = arith.trunci %cFF : i8 to i0
-  return %trunc : i0
-}
-
-// CHECK-LABEL: @shli_i0
-//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i0
-//       CHECK:   return %[[ZERO]] : i0
-func.func @shli_i0() -> i0 {
-  %c0 = arith.constant 0 : i0
-  %shli = arith.shli %c0, %c0 : i0
-  return %shli : i0
-}
-
-// CHECK-LABEL: @shrsi_i0
-//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i0
-//       CHECK:   return %[[ZERO]] : i0
-func.func @shrsi_i0() -> i0 {
-  %c0 = arith.constant 0 : i0
-  %shrsi = arith.shrsi %c0, %c0 : i0
-  return %shrsi : i0
-}
-
-// CHECK-LABEL: @shrui_i0
-//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i0
-//       CHECK:   return %[[ZERO]] : i0
-func.func @shrui_i0() -> i0 {
-  %c0 = arith.constant 0 : i0
-  %shrui = arith.shrui %c0, %c0 : i0
-  return %shrui : i0
-}
-
-// CHECK-LABEL: @maxsi_i0
-//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i0
-//       CHECK:   return %[[ZERO]] : i0
-func.func @maxsi_i0() -> i0 {
-  %c0 = arith.constant 0 : i0
-  %maxsi = arith.maxsi %c0, %c0 : i0
-  return %maxsi : i0
-}
-
-// CHECK-LABEL: @minsi_i0
-//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i0
-//       CHECK:   return %[[ZERO]] : i0
-func.func @minsi_i0() -> i0 {
-  %c0 = arith.constant 0 : i0
-  %minsi = arith.minsi %c0, %c0 : i0
-  return %minsi : i0
-}
-
-// CHECK-LABEL: @mulsi_extended_i0
-//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i0
-//       CHECK:   return %[[ZERO]], %[[ZERO]] : i0
-func.func @mulsi_extended_i0() -> (i0, i0) {
-  %c0 = arith.constant 0 : i0
-  %mulsi_extended:2 = arith.mulsi_extended %c0, %c0 : i0
-  return %mulsi_extended#0, %mulsi_extended#1 : i0, i0
-}
-
 // CHECK-LABEL: @sequences_fastmath_contract
 // CHECK-SAME: ([[ARG0:%.+]]: bf16)
 // CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir
index 70b23e56a712c..cf404b7c389f5 100644
--- a/mlir/test/Dialect/Arith/invalid.mlir
+++ b/mlir/test/Dialect/Arith/invalid.mlir
@@ -130,14 +130,14 @@ func.func @func_with_ops(f32) {
 
 func.func @func_with_ops(f32) {
 ^bb0(%a : f32):
-  // expected-error@+1 {{'arith.addi' op operand #0 must be signless-integer-like}}
+  // expected-error@+1 {{'arith.addi' op operand #0 must be signless-non-zero-bitwidth-integer-like}}
   %sf = arith.addi %a, %a : f32
 }
 
 // -----
 
 func.func @func_with_ops(%a: f32) {
-  // expected-error@+1 {{'arith.addui_extended' op operand #0 must be signless-integer-like}}
+  // expected-error@+1 {{'arith.addui_extended' op operand #0 must be signless-non-zero-bitwidth-integer-like}}
   %r:2 = arith.addui_extended %a, %a : f32, i32
   return
 }
@@ -202,7 +202,7 @@ func.func @func_with_ops(i32, i32) {
 // Integer comparisons are not recognized for float types.
 func.func @func_with_ops(f32, f32) {
 ^bb0(%a : f32, %b : f32):
-  %r = arith.cmpi eq, %a, %b : f32 // expected-error {{'lhs' must be signless-integer-like, but got 'f32'}}
+  %r = arith.cmpi eq, %a, %b : f32 // expected-error {{'lhs' must be signless-non-zero-bitwidth-integer-like, but got 'f32'}}
 }
 
 // -----
@@ -242,7 +242,7 @@ func.func @func_with_ops() {
 // -----
 
 func.func @invalid_cmp_shape(%idx : () -> ()) {
-  // expected-error@+1 {{'lhs' must be signless-integer-like, but got '() -> ()'}}
+  // expected-error@+1 {{'lhs' must be signless-non-zero-bitwidth-integer-like, but got '() -> ()'}}
   %cmp = arith.cmpi eq, %idx, %idx : () -> ()
 
 // -----
@@ -877,3 +877,29 @@ func.func @select_vector_condition_scalar_operands(%arg0: vector<1xi1>, %arg1: i
   %0 = arith.select %arg0, %arg1, %arg1 : vector<1xi1>, i32
   return
 }
+
+// -----
+
+// Verify that i0 (zero-bitwidth integer) is rejected by arith integer ops.
+
+func.func @addi_i0(%a: i0, %b: i0) -> i0 {
+  // expected-error @+1 {{'arith.addi' op operand #0 must be signless-non-zero-bitwidth-integer-like, but got 'i0'}}
+  %0 = arith.addi %a, %b : i0
+  return %0 : i0
+}
+
+// -----
+
+func.func @addi_vector_i0(%a: vector<4xi0>, %b: vector<4xi0>) -> vector<4xi0> {
+  // expected-error @+1 {{'arith.addi' op operand #0 must be signless-non-zero-bitwidth-integer-like, but got 'vector<4xi0>'}}
+  %0 = arith.addi %a, %b : vector<4xi0>
+  return %0 : vector<4xi0>
+}
+
+// -----
+
+func.func @trunci_to_i0(%a: i32) -> i0 {
+  // expected-error @+1 {{'arith.trunci' op result #0 must be signless-fixed-width-integer-like, but got 'i0'}}
+  %0 = arith.trunci %a : i32 to i0
+  return %0 : i0
+}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 28e1206ff3d0a..51254b524920c 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -2064,3 +2064,12 @@ func.func @store_non_pow_of_2_alignment(%memref: memref<4xi32>, %val: vector<4xi
   vector.store %val, %memref[%c0] { alignment = 3 } : memref<4xi32>, vector<4xi32>
   return
 }
+
+// -----
+
+// Verify that vector.bitcast rejects vectors with i0 (zero-bitwidth) element type.
+func.func @bitcast_i0(%a: vector<4xi0>) -> vector<4xi0> {
+  // expected-error @+1 {{'vector.bitcast' op operand #0 must be vector of non-zero-bitwidth type values, but got 'vector<4xi0>'}}
+  %0 = vector.bitcast %a : vector<4xi0> to vector<4xi0>
+  return %0 : vector<4xi0>
+}

@llvmbot
Copy link
Copy Markdown
Member

llvmbot commented Feb 26, 2026

@llvm/pr-subscribers-mlir-arith

Author: Mehdi Amini (joker-eph)

Changes

Add ODS type constraints that exclude zero-bitwidth integers (i0) from operations in the arith and vector dialects. i0 has no meaningful arithmetic representation and operations on it can trigger undefined behavior (e.g. bitwidth calculations assuming non-zero width).

Changes:

  • Add AnyNonZeroBitwidthSignlessInteger and AnyNonZeroBitwidthSignlessIntegerOrIndex to CommonTypeConstraints.td.
  • Introduce Arith_SignlessIntegerOrIndexLike in ArithOps.td that uses the new non-zero-width base type, and update SignlessFixedWidthIntegerLike likewise. Replace all uses of the shared SignlessIntegerOrIndexLike in ArithOps.td with the new dialect-local constraint.
  • Add AnyVectorNonZeroBitwidthIntElem in VectorOps.td and apply it to vector.bitcast source and result types.
  • Update arith/invalid.mlir and vector/invalid.mlir with explicit rejection tests for i0 scalar and vector element types.
  • Remove the now-invalid i0 canonicalization tests from arith/canonicalize.mlir.

Full diff: https://github.com/llvm/llvm-project/pull/183589.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Arith/IR/ArithOps.td (+28-20)
  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+11-2)
  • (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+12)
  • (modified) mlir/test/Dialect/Arith/canonicalize.mlir (-81)
  • (modified) mlir/test/Dialect/Arith/invalid.mlir (+30-4)
  • (modified) mlir/test/Dialect/Vector/invalid.mlir (+9)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 3d8517c56e784..6c8a40970bc0a 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -21,6 +21,12 @@ include "mlir/IR/BuiltinAttributeInterfaces.td"
 include "mlir/IR/OpAsmInterface.td"
 include "mlir/IR/EnumAttr.td"
 
+// Type constraint for signless-integer-or-index-like types that additionally
+// excludes i0 (zero-bitwidth integers), used for arith integer operations.
+def Arith_SignlessIntegerOrIndexLike : TypeOrValueSemanticsContainer<
+    AnyNonZeroBitwidthSignlessIntegerOrIndex,
+    "signless-non-zero-bitwidth-integer-like">;
+
 // Base class for Arith dialect ops. Ops in this dialect have no memory
 // effects and can be applied element-wise to vectors and tensors.
 class Arith_Op<string mnemonic, list<Trait> traits = []> :
@@ -51,8 +57,8 @@ class Arith_BinaryOp<string mnemonic, list<Trait> traits = []> :
 class Arith_IntBinaryOp<string mnemonic, list<Trait> traits = []> :
     Arith_BinaryOp<mnemonic, traits #
       [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
-    Arguments<(ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs)>,
-    Results<(outs SignlessIntegerOrIndexLike:$result)>;
+    Arguments<(ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs)>,
+    Results<(outs Arith_SignlessIntegerOrIndexLike:$result)>;
 
 // Base class for integer binary operations without undefined behavior.
 class Arith_TotalIntBinaryOp<string mnemonic, list<Trait> traits = []> :
@@ -110,10 +116,12 @@ class Arith_CastOp<string mnemonic, TypeConstraint From, TypeConstraint To,
 
 // Casts do not accept indices. Type constraint for signless-integer-like types
 // excluding indices: signless integers, vectors or tensors thereof.
+// i0 (zero-bitwidth) integers are excluded as they have no meaningful
+// representation for arithmetic operations.
 def SignlessFixedWidthIntegerLike : TypeConstraint<Or<[
-        AnySignlessInteger.predicate,
-        VectorOfAnyRankOf<[AnySignlessInteger]>.predicate,
-        TensorOf<[AnySignlessInteger]>.predicate]>,
+        AnyNonZeroBitwidthSignlessInteger.predicate,
+        VectorOfAnyRankOf<[AnyNonZeroBitwidthSignlessInteger]>.predicate,
+        TensorOf<[AnyNonZeroBitwidthSignlessInteger]>.predicate]>,
     "signless-fixed-width-integer-like">;
 
 // Cast from an integer type to another integer type.
@@ -148,11 +156,11 @@ class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = [
     Arith_BinaryOp<mnemonic, traits #
       [Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
        DeclareOpInterfaceMethods<ArithIntegerOverflowFlagsInterface>]>,
-    Arguments<(ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs,
+    Arguments<(ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs,
       DefaultValuedAttr<
         Arith_IntegerOverflowAttr,
         "::mlir::arith::IntegerOverflowFlags::none">:$overflowFlags)>,
-    Results<(outs SignlessIntegerOrIndexLike:$result)> {
+    Results<(outs Arith_SignlessIntegerOrIndexLike:$result)> {
 
   let assemblyFormat = [{ $lhs `,` $rhs (`overflow` `` $overflowFlags^)?
                           attr-dict `:` type($result) }];
@@ -161,10 +169,10 @@ class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = [
 class Arith_IntBinaryOpWithExactFlag<string mnemonic, list<Trait> traits = []> :
     Arith_BinaryOp<mnemonic, traits #
       [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
-    Arguments<(ins SignlessIntegerOrIndexLike:$lhs,
-               SignlessIntegerOrIndexLike:$rhs,
+    Arguments<(ins Arith_SignlessIntegerOrIndexLike:$lhs,
+               Arith_SignlessIntegerOrIndexLike:$rhs,
                UnitAttr:$isExact)>,
-    Results<(outs SignlessIntegerOrIndexLike:$result)> {
+    Results<(outs Arith_SignlessIntegerOrIndexLike:$result)> {
 
   let assemblyFormat = [{ $lhs `,` $rhs (`exact` $isExact^)?
                           attr-dict `:` type($result) }];
@@ -293,8 +301,8 @@ def Arith_AddUIExtendedOp : Arith_Op<"addui_extended", [Pure, Commutative,
     ```
   }];
 
-  let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
-  let results = (outs SignlessIntegerOrIndexLike:$sum, BoolLike:$overflow);
+  let arguments = (ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs);
+  let results = (outs Arith_SignlessIntegerOrIndexLike:$sum, BoolLike:$overflow);
   let assemblyFormat = [{
     $lhs `,` $rhs attr-dict `:` type($sum) `,` type($overflow)
   }];
@@ -434,8 +442,8 @@ def Arith_MulSIExtendedOp : Arith_Op<"mulsi_extended", [Pure, Commutative,
     ```
   }];
 
-  let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
-  let results = (outs SignlessIntegerOrIndexLike:$low, SignlessIntegerOrIndexLike:$high);
+  let arguments = (ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs);
+  let results = (outs Arith_SignlessIntegerOrIndexLike:$low, Arith_SignlessIntegerOrIndexLike:$high);
 
   let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)";
 
@@ -477,8 +485,8 @@ def Arith_MulUIExtendedOp : Arith_Op<"mului_extended", [Pure, Commutative,
     ```
   }];
 
-  let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
-  let results = (outs SignlessIntegerOrIndexLike:$low, SignlessIntegerOrIndexLike:$high);
+  let arguments = (ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs);
+  let results = (outs Arith_SignlessIntegerOrIndexLike:$low, Arith_SignlessIntegerOrIndexLike:$high);
 
   let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)";
 
@@ -1573,8 +1581,8 @@ def Arith_FPToSIOp : Arith_FToICastOp<"fptosi"> {
 
 // Index cast can convert between memrefs of signless integers and indices too.
 def IndexCastTypeConstraint : TypeConstraint<Or<[
-        SignlessIntegerOrIndexLike.predicate,
-        MemRefOf<[AnySignlessInteger, Index]>.predicate]>,
+        Arith_SignlessIntegerOrIndexLike.predicate,
+        MemRefOf<[AnyNonZeroBitwidthSignlessInteger, Index]>.predicate]>,
     "signless-integer-like or memref of signless-integer">;
 
 def Arith_IndexCastOp
@@ -1737,8 +1745,8 @@ def Arith_CmpIOp : Arith_CompareOp<"cmpi",
   }];
 
   let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
-                       SignlessIntegerOrIndexLike:$lhs,
-                       SignlessIntegerOrIndexLike:$rhs);
+                       Arith_SignlessIntegerOrIndexLike:$lhs,
+                       Arith_SignlessIntegerOrIndexLike:$rhs);
 
   let hasFolder = 1;
   let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index ddb04b6bbe40d..87a15cba8a0e0 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -32,6 +32,15 @@ include "mlir/Interfaces/ViewLikeInterface.td"
 include "mlir/IR/BuiltinAttributes.td"
 include "mlir/IR/EnumAttr.td"
 
+// Type constraint helpers for the vector dialect.
+
+// Any vector of any rank whose element type is not i0 (zero-bitwidth integer).
+// Floats, index, and integers with width >= 1 are all accepted. Using
+// VectorOfAnyRankOf preserves the ::mlir::VectorType C++ class so that ODS
+// generates TypedValue<VectorType> (VectorValue) for op results.
+def AnyVectorNonZeroBitwidthIntElem : VectorOfAnyRankOf<[
+    Type<CPred<"!$_self.isInteger(0)">, "non-zero-bitwidth type">]>;
+
 // TODO: Add an attribute to specify a different algebra with operators other
 // than the current set: {*, +}.
 def Vector_ContractionOp :
@@ -2454,8 +2463,8 @@ def Vector_ShapeCastOp :
 
 def Vector_BitCastOp :
   Vector_Op<"bitcast", [Pure, AllRanksMatch<["source", "result"]>]>,
-    Arguments<(ins AnyVectorOfAnyRank:$source)>,
-    Results<(outs AnyVectorOfAnyRank:$result)>{
+    Arguments<(ins AnyVectorNonZeroBitwidthIntElem:$source)>,
+    Results<(outs AnyVectorNonZeroBitwidthIntElem:$result)>{
   let summary = "bitcast casts between vectors";
   let description = [{
     The bitcast operation casts between vectors of the same rank, the minor 1-D
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index a49880b81e90d..409672d46c55e 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -308,6 +308,18 @@ def Index : Type<CPred<"::llvm::isa<::mlir::IndexType>($_self)">, "index",
 def AnySignlessIntegerOrIndex : Type<CPred<"$_self.isSignlessIntOrIndex()">,
                                      "signless integer or index">;
 
+// A signless integer type with a non-zero bitwidth (excludes i0).
+def AnyNonZeroBitwidthSignlessInteger : Type<
+    And<[CPred<"$_self.isSignlessInteger()">,
+         CPred<"!$_self.isInteger(0)">]>,
+    "non-zero-bitwidth signless integer", "::mlir::IntegerType">;
+
+// A non-zero-bitwidth signless integer or index type.
+def AnyNonZeroBitwidthSignlessIntegerOrIndex : Type<
+    Or<[AnyNonZeroBitwidthSignlessInteger.predicate,
+        CPred<"::llvm::isa<::mlir::IndexType>($_self)">]>,
+    "non-zero-bitwidth signless integer or index">;
+
 // Floating point types.
 
 // Any float type irrespective of its width.
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 4dc29897cec26..7713f93462b7f 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -3299,87 +3299,6 @@ func.func @unsignedExtendConstantResource() -> tensor<i16> {
   return %ext : tensor<i16>
 }
 
-// CHECK-LABEL: @extsi_i0
-//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i16
-//       CHECK:   return %[[ZERO]] : i16
-func.func @extsi_i0() -> i16 {
-  %c0 = arith.constant 0 : i0
-  %extsi = arith.extsi %c0 : i0 to i16
-  return %extsi : i16
-}
-
-// CHECK-LABEL: @extui_i0
-//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i16
-//       CHECK:   return %[[ZERO]] : i16
-func.func @extui_i0() -> i16 {
-  %c0 = arith.constant 0 : i0
-  %extui = arith.extui %c0 : i0 to i16
-  return %extui : i16
-}
-
-// CHECK-LABEL: @trunc_i0
-//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i0
-//       CHECK:   return %[[ZERO]] : i0
-func.func @trunc_i0() -> i0 {
-  %cFF = arith.constant 0xFF : i8
-  %trunc = arith.trunci %cFF : i8 to i0
-  return %trunc : i0
-}
-
-// CHECK-LABEL: @shli_i0
-//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i0
-//       CHECK:   return %[[ZERO]] : i0
-func.func @shli_i0() -> i0 {
-  %c0 = arith.constant 0 : i0
-  %shli = arith.shli %c0, %c0 : i0
-  return %shli : i0
-}
-
-// CHECK-LABEL: @shrsi_i0
-//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i0
-//       CHECK:   return %[[ZERO]] : i0
-func.func @shrsi_i0() -> i0 {
-  %c0 = arith.constant 0 : i0
-  %shrsi = arith.shrsi %c0, %c0 : i0
-  return %shrsi : i0
-}
-
-// CHECK-LABEL: @shrui_i0
-//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i0
-//       CHECK:   return %[[ZERO]] : i0
-func.func @shrui_i0() -> i0 {
-  %c0 = arith.constant 0 : i0
-  %shrui = arith.shrui %c0, %c0 : i0
-  return %shrui : i0
-}
-
-// CHECK-LABEL: @maxsi_i0
-//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i0
-//       CHECK:   return %[[ZERO]] : i0
-func.func @maxsi_i0() -> i0 {
-  %c0 = arith.constant 0 : i0
-  %maxsi = arith.maxsi %c0, %c0 : i0
-  return %maxsi : i0
-}
-
-// CHECK-LABEL: @minsi_i0
-//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i0
-//       CHECK:   return %[[ZERO]] : i0
-func.func @minsi_i0() -> i0 {
-  %c0 = arith.constant 0 : i0
-  %minsi = arith.minsi %c0, %c0 : i0
-  return %minsi : i0
-}
-
-// CHECK-LABEL: @mulsi_extended_i0
-//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i0
-//       CHECK:   return %[[ZERO]], %[[ZERO]] : i0
-func.func @mulsi_extended_i0() -> (i0, i0) {
-  %c0 = arith.constant 0 : i0
-  %mulsi_extended:2 = arith.mulsi_extended %c0, %c0 : i0
-  return %mulsi_extended#0, %mulsi_extended#1 : i0, i0
-}
-
 // CHECK-LABEL: @sequences_fastmath_contract
 // CHECK-SAME: ([[ARG0:%.+]]: bf16)
 // CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir
index 70b23e56a712c..cf404b7c389f5 100644
--- a/mlir/test/Dialect/Arith/invalid.mlir
+++ b/mlir/test/Dialect/Arith/invalid.mlir
@@ -130,14 +130,14 @@ func.func @func_with_ops(f32) {
 
 func.func @func_with_ops(f32) {
 ^bb0(%a : f32):
-  // expected-error@+1 {{'arith.addi' op operand #0 must be signless-integer-like}}
+  // expected-error@+1 {{'arith.addi' op operand #0 must be signless-non-zero-bitwidth-integer-like}}
   %sf = arith.addi %a, %a : f32
 }
 
 // -----
 
 func.func @func_with_ops(%a: f32) {
-  // expected-error@+1 {{'arith.addui_extended' op operand #0 must be signless-integer-like}}
+  // expected-error@+1 {{'arith.addui_extended' op operand #0 must be signless-non-zero-bitwidth-integer-like}}
   %r:2 = arith.addui_extended %a, %a : f32, i32
   return
 }
@@ -202,7 +202,7 @@ func.func @func_with_ops(i32, i32) {
 // Integer comparisons are not recognized for float types.
 func.func @func_with_ops(f32, f32) {
 ^bb0(%a : f32, %b : f32):
-  %r = arith.cmpi eq, %a, %b : f32 // expected-error {{'lhs' must be signless-integer-like, but got 'f32'}}
+  %r = arith.cmpi eq, %a, %b : f32 // expected-error {{'lhs' must be signless-non-zero-bitwidth-integer-like, but got 'f32'}}
 }
 
 // -----
@@ -242,7 +242,7 @@ func.func @func_with_ops() {
 // -----
 
 func.func @invalid_cmp_shape(%idx : () -> ()) {
-  // expected-error@+1 {{'lhs' must be signless-integer-like, but got '() -> ()'}}
+  // expected-error@+1 {{'lhs' must be signless-non-zero-bitwidth-integer-like, but got '() -> ()'}}
   %cmp = arith.cmpi eq, %idx, %idx : () -> ()
 
 // -----
@@ -877,3 +877,29 @@ func.func @select_vector_condition_scalar_operands(%arg0: vector<1xi1>, %arg1: i
   %0 = arith.select %arg0, %arg1, %arg1 : vector<1xi1>, i32
   return
 }
+
+// -----
+
+// Verify that i0 (zero-bitwidth integer) is rejected by arith integer ops.
+
+func.func @addi_i0(%a: i0, %b: i0) -> i0 {
+  // expected-error @+1 {{'arith.addi' op operand #0 must be signless-non-zero-bitwidth-integer-like, but got 'i0'}}
+  %0 = arith.addi %a, %b : i0
+  return %0 : i0
+}
+
+// -----
+
+func.func @addi_vector_i0(%a: vector<4xi0>, %b: vector<4xi0>) -> vector<4xi0> {
+  // expected-error @+1 {{'arith.addi' op operand #0 must be signless-non-zero-bitwidth-integer-like, but got 'vector<4xi0>'}}
+  %0 = arith.addi %a, %b : vector<4xi0>
+  return %0 : vector<4xi0>
+}
+
+// -----
+
+func.func @trunci_to_i0(%a: i32) -> i0 {
+  // expected-error @+1 {{'arith.trunci' op result #0 must be signless-fixed-width-integer-like, but got 'i0'}}
+  %0 = arith.trunci %a : i32 to i0
+  return %0 : i0
+}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 28e1206ff3d0a..51254b524920c 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -2064,3 +2064,12 @@ func.func @store_non_pow_of_2_alignment(%memref: memref<4xi32>, %val: vector<4xi
   vector.store %val, %memref[%c0] { alignment = 3 } : memref<4xi32>, vector<4xi32>
   return
 }
+
+// -----
+
+// Verify that vector.bitcast rejects vectors with i0 (zero-bitwidth) element type.
+func.func @bitcast_i0(%a: vector<4xi0>) -> vector<4xi0> {
+  // expected-error @+1 {{'vector.bitcast' op operand #0 must be vector of non-zero-bitwidth type values, but got 'vector<4xi0>'}}
+  %0 = vector.bitcast %a : vector<4xi0> to vector<4xi0>
+  return %0 : vector<4xi0>
+}

@kuhar
Copy link
Copy Markdown
Member

kuhar commented Feb 26, 2026

@kuhar : is this what you had in mind? This may deserve an RFC though.

Yes, exactly. I strongly prefer this over band aid patches that fix non-zero bitwdith assumptions scattered across the codebase.

I'm not sure if an RFC is necessary -- I think i0 was allowed unintentionally after most of these dialects was already in place and we missed ODS constraints. I'm not aware of anyone using these, but if you want to create one, it should be pretty straightforward.

@kuhar
Copy link
Copy Markdown
Member

kuhar commented Feb 26, 2026

Actually, we do have tests for i0, so someone must have considered it at some point -- seems like an RFC would make sense after all.

@banach-space
Copy link
Copy Markdown
Contributor

Add AnyVectorNonZeroBitwidthIntElem in VectorOps.td and apply it to vector.bitcast source and result types.

Why only vector.bitcast?

Add ODS type constraints that exclude zero-bitwidth integers (i0) from
operations in the arith and vector dialects.  i0 has no meaningful
arithmetic representation and operations on it can trigger undefined
behavior (e.g. bitwidth calculations assuming non-zero width).

Changes:
- Add `AnyNonZeroBitwidthSignlessInteger` (as a `ConfinedType` over
  `AnySignlessInteger`) and `AnyNonZeroBitwidthSignlessIntegerOrIndex`
  to CommonTypeConstraints.td.
- Introduce `Arith_SignlessIntegerOrIndexLike` in ArithOps.td that wraps
  `AnyNonZeroBitwidthSignlessIntegerOrIndex` via
  `TypeOrValueSemanticsContainer`, and update `SignlessFixedWidthIntegerLike`
  to use `AnyNonZeroBitwidthSignlessInteger`.  Replace all uses of the
  shared `SignlessIntegerOrIndexLike` in ArithOps.td with the new
  dialect-local constraint.
- Update `IndexCastTypeConstraint` to use `Arith_SignlessIntegerOrIndexLike`.
- Update `BitcastTypeConstraint` to exclude i0 by composing the already-
  defined `SignlessFixedWidthIntegerLike` and `FloatLike` constraints,
  keeping the definition compact (3 alternatives instead of 7).
- Add `AnyVectorOfNonI0Elem` and `AnyVectorOfNonZeroRankNonI0Elem` in
  VectorOps.td and apply them to `vector.contract`, `vector.reduction`,
  `vector.multi_reduction`, `vector.outerproduct`, `vector.bitcast`, and
  `vector.scan`.
- Update arith/invalid.mlir with explicit i0 rejection tests covering all
  integer op families (binary ops, cast ops, extended-multiply ops, cmpi,
  bitcast, index_cast, index_castui) for both scalar and vector<N> forms.
- Update vector/invalid.mlir with i0 rejection tests for all covered ops.
- Remove the now-invalid i0 canonicalization tests from
  arith/canonicalize.mlir.
@joker-eph
Copy link
Copy Markdown
Contributor Author

I added this mention for the arith dialect:

    Manipulating value with type `i0` isn't supported in this dialect at the
    moment and is considered invalid. This can change in the future if some
    motivating use-cases are presented.

I think the RFC is converged enough to proceed, happy to roll this back if/when we get a good use-case to reverse the tradeoff analysis!

@joker-eph joker-eph merged commit 7f04494 into llvm:main Mar 4, 2026
10 checks passed
@joker-eph joker-eph deleted the i0 branch March 4, 2026 11:34
sahas3 pushed a commit to sahas3/llvm-project that referenced this pull request Mar 4, 2026
…lvm#183589)

Add ODS type constraints that exclude zero-bitwidth integers (i0) from
operations in the arith and vector dialects.  i0 has no meaningful
arithmetic representation and operations on it can trigger undefined
behavior (e.g. bitwidth calculations assuming non-zero width).

Changes:
- Add `AnyNonZeroBitwidthSignlessInteger` (as a `ConfinedType` over
  `AnySignlessInteger`) and `AnyNonZeroBitwidthSignlessIntegerOrIndex`
  to CommonTypeConstraints.td.
- Introduce `Arith_SignlessIntegerOrIndexLike` in ArithOps.td that wraps
  `AnyNonZeroBitwidthSignlessIntegerOrIndex` via
`TypeOrValueSemanticsContainer`, and update
`SignlessFixedWidthIntegerLike`
  to use `AnyNonZeroBitwidthSignlessInteger`.  Replace all uses of the
  shared `SignlessIntegerOrIndexLike` in ArithOps.td with the new
  dialect-local constraint.
- Update `IndexCastTypeConstraint` to use
`Arith_SignlessIntegerOrIndexLike`.
- Update `BitcastTypeConstraint` to exclude i0 by composing the already-
  defined `SignlessFixedWidthIntegerLike` and `FloatLike` constraints,
  keeping the definition compact (3 alternatives instead of 7).
- Add `AnyVectorOfNonI0Elem` and `AnyVectorOfNonZeroRankNonI0Elem` in
  VectorOps.td and apply them to `vector.contract`, `vector.reduction`,
  `vector.multi_reduction`, `vector.outerproduct`, `vector.bitcast`, and
  `vector.scan`.
- Update arith/invalid.mlir with explicit i0 rejection tests covering
all
integer op families (binary ops, cast ops, extended-multiply ops, cmpi,
bitcast, index_cast, index_castui) for both scalar and vector<N> forms.
- Update vector/invalid.mlir with i0 rejection tests for all covered
ops.
- Remove the now-invalid i0 canonicalization tests from
  arith/canonicalize.mlir.

Fixes llvm#177822
Fixes llvm#179266
Fixes llvm#180463
Fixes llvm#181532

See also
https://discourse.llvm.org/t/rfc-reject-i0-integer-type-in-arith-and-vector-ops/90011
sujianIBM pushed a commit to sujianIBM/llvm-project that referenced this pull request Mar 5, 2026
…lvm#183589)

Add ODS type constraints that exclude zero-bitwidth integers (i0) from
operations in the arith and vector dialects.  i0 has no meaningful
arithmetic representation and operations on it can trigger undefined
behavior (e.g. bitwidth calculations assuming non-zero width).

Changes:
- Add `AnyNonZeroBitwidthSignlessInteger` (as a `ConfinedType` over
  `AnySignlessInteger`) and `AnyNonZeroBitwidthSignlessIntegerOrIndex`
  to CommonTypeConstraints.td.
- Introduce `Arith_SignlessIntegerOrIndexLike` in ArithOps.td that wraps
  `AnyNonZeroBitwidthSignlessIntegerOrIndex` via
`TypeOrValueSemanticsContainer`, and update
`SignlessFixedWidthIntegerLike`
  to use `AnyNonZeroBitwidthSignlessInteger`.  Replace all uses of the
  shared `SignlessIntegerOrIndexLike` in ArithOps.td with the new
  dialect-local constraint.
- Update `IndexCastTypeConstraint` to use
`Arith_SignlessIntegerOrIndexLike`.
- Update `BitcastTypeConstraint` to exclude i0 by composing the already-
  defined `SignlessFixedWidthIntegerLike` and `FloatLike` constraints,
  keeping the definition compact (3 alternatives instead of 7).
- Add `AnyVectorOfNonI0Elem` and `AnyVectorOfNonZeroRankNonI0Elem` in
  VectorOps.td and apply them to `vector.contract`, `vector.reduction`,
  `vector.multi_reduction`, `vector.outerproduct`, `vector.bitcast`, and
  `vector.scan`.
- Update arith/invalid.mlir with explicit i0 rejection tests covering
all
integer op families (binary ops, cast ops, extended-multiply ops, cmpi,
bitcast, index_cast, index_castui) for both scalar and vector<N> forms.
- Update vector/invalid.mlir with i0 rejection tests for all covered
ops.
- Remove the now-invalid i0 canonicalization tests from
  arith/canonicalize.mlir.

Fixes llvm#177822
Fixes llvm#179266
Fixes llvm#180463
Fixes llvm#181532

See also
https://discourse.llvm.org/t/rfc-reject-i0-integer-type-in-arith-and-vector-ops/90011
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment