Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[polynomial] Move primitive root attribute to ntt/intt ops. #93227

Merged
merged 4 commits into from
May 30, 2024

Conversation

j2kun
Copy link
Contributor

@j2kun j2kun commented May 23, 2024

Better design to put semantics on the ops, and in this case the ntt/intt op can lower in multiple ways depending on the polynomial ring modulus (it can need an nth root of unity for cyclic polymul -> ntt, or a 2nth root for negacyclic polymul -> ntt)

@llvmbot llvmbot added the mlir label May 23, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented May 23, 2024

@llvm/pr-subscribers-mlir

Author: Jeremy Kun (j2kun)

Changes

Patch is 20.83 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/93227.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td (+12-5)
  • (modified) mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td (+27-6)
  • (modified) mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td (+44-4)
  • (modified) mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp (+19-21)
  • (modified) mlir/test/Dialect/Polynomial/canonicalization.mlir (+63-5)
  • (modified) mlir/test/Dialect/Polynomial/ops.mlir (+4-4)
  • (modified) mlir/test/Dialect/Polynomial/ops_errors.mlir (+11-25)
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index 3ef899d3376b1..7035219238dd5 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -277,8 +277,7 @@ def Polynomial_AnyPolynomialAttr : AnyAttrOf<[
   Polynomial_IntPolynomialAttr
 ]>;
 
-// Not deriving from Polynomial_Op due to need for custom assembly format
-def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant", [Pure]> {
+def Polynomial_ConstantOp : Polynomial_Op<"constant", [Pure]> {
   let summary = "Define a constant polynomial via an attribute.";
   let description = [{
     Example:
@@ -312,9 +311,12 @@ def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {
 
       `f[k] = F(omega[n]^k) ; k = {0, ..., n-1}`
 
-    The choice of primitive root is determined by subsequent lowerings.
+    The choice of primitive root may be optionally specified.
   }];
-  let arguments = (ins Polynomial_PolynomialType:$input);
+  let arguments = (ins
+    Polynomial_PolynomialType:$input,
+    OptionalAttr<Polynomial_PrimitiveRootAttr>:$root
+  );
   let results = (outs RankedTensorOf<[AnyInteger]>:$output);
   let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
   let hasCanonicalizer = 1;
@@ -332,8 +334,13 @@ def Polynomial_INTTOp : Polynomial_Op<"intt", [Pure]> {
     output polynomial at powers of a primitive `n`-th root of unity (see
     `polynomial.ntt`). The ring of the polynomial is taken from the required
     encoding attribute of the tensor.
+
+    The choice of primitive root may be optionally specified.
   }];
-  let arguments = (ins RankedTensorOf<[AnyInteger]>:$input);
+  let arguments = (
+    ins RankedTensorOf<[AnyInteger]>:$input,
+    OptionalAttr<Polynomial_PrimitiveRootAttr>:$root
+  );
   let results = (outs Polynomial_PolynomialType:$output);
   let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
   let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
index e5dbfa7fa21ee..2bfdc15ddd034 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
@@ -102,24 +102,45 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
   let parameters = (ins
     "Type": $coefficientType,
     OptionalParameter<"::mlir::IntegerAttr">: $coefficientModulus,
-    OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus,
-    OptionalParameter<"::mlir::IntegerAttr">: $primitiveRoot
+    OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus
   );
   let assemblyFormat = "`<` struct(params) `>`";
   let builders = [
     AttrBuilderWithInferredContext<
         (ins "::mlir::Type":$coefficientTy,
               CArg<"::mlir::IntegerAttr", "nullptr"> :$coefficientModulusAttr,
-              CArg<"::mlir::polynomial::IntPolynomialAttr", "nullptr"> :$polynomialModulusAttr,
-              CArg<"::mlir::IntegerAttr", "nullptr"> :$primitiveRootAttr), [{
+              CArg<"::mlir::polynomial::IntPolynomialAttr", "nullptr"> :$polynomialModulusAttr), [{
       return $_get(
         coefficientTy.getContext(),
         coefficientTy,
         coefficientModulusAttr,
-        polynomialModulusAttr,
-        primitiveRootAttr);
+        polynomialModulusAttr);
     }]>,
   ];
 }
 
+def Polynomial_PrimitiveRootAttr: Polynomial_Attr<"PrimitiveRoot", "primitive_root"> {
+  let summary = "an attribute containing an integer and its degree as a root of unity";
+  let description = [{
+    A primitive root attribute stores an integer root `value` and an integer
+    `degree`, corresponding to a primitive root of unity of the given degree in
+    an unspecified ring.
+
+    This is used as an attribute on `polynomial.ntt` and `polynomial.intt` ops
+    to specify the root of unity used in lowering the transform.
+
+    Example:
+
+    ```mlir
+    #poly = #polynomial.primitive_root<value=123 : i32, degree : 7 index>
+    ```
+  }];
+  let parameters = (ins
+    "::mlir::IntegerAttr":$value,
+    "::mlir::IntegerAttr":$degree
+  );
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
+
 #endif // POLYNOMIAL_ATTRIBUTES
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
index 9d09799c1763a..f8216e1b2307d 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
@@ -14,6 +14,10 @@ include "mlir/Dialect/Arith/IR/ArithOps.td"
 include "mlir/IR/OpBase.td"
 include "mlir/IR/PatternBase.td"
 
+defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;
+
+def Equal : Constraint<CPred<"$0 == $1">>;
+
 // Get a -1 integer attribute of the same type as the polynomial SSA value's
 // ring coefficient type.
 def getMinusOne
@@ -28,15 +32,51 @@ def SubAsAdd : Pat<
       (Arith_ConstantOp (getMinusOne $g))))>;
 
 def INTTAfterNTT : Pat<
-  (Polynomial_INTTOp (Polynomial_NTTOp $poly)),
+  (Polynomial_INTTOp (Polynomial_NTTOp $poly, $r1), $r2),
   (replaceWithValue $poly),
-  []
+  [(Equal $r1, $r2)]
 >;
 
 def NTTAfterINTT : Pat<
-  (Polynomial_NTTOp (Polynomial_INTTOp $tensor)),
+  (Polynomial_NTTOp (Polynomial_INTTOp $tensor, $r1), $r2),
   (replaceWithValue $tensor),
-  []
+  [(Equal $r1, $r2)]
+>;
+
+// NTTs are expensive, and addition in coefficient or NTT domain should be
+// equivalently expensive, so reducing the number of NTTs is optimal.
+// ntt(a) + ntt(b) -> ntt(a + b)
+def NTTOfAdd : Pat<
+  (Arith_AddIOp
+    (Polynomial_NTTOp $p1, $r1),
+    (Polynomial_NTTOp $p2, $r2),
+    $overflow),
+  (Polynomial_NTTOp (Polynomial_AddOp $p1, $p2), $r1),
+  [(Equal $r1, $r2)]
+>;
+// intt(a) + intt(b) -> intt(a + b)
+def INTTOfAdd : Pat<
+  (Polynomial_AddOp
+    (Polynomial_INTTOp $t1, $r1),
+    (Polynomial_INTTOp $t2, $r2)),
+  (Polynomial_INTTOp (Arith_AddIOp $t1, $t2, DefOverflow), $r1),
+  [(Equal $r1, $r2)]
+>;
+// repeated for sub
+def NTTOfSub : Pat<
+  (Arith_SubIOp
+    (Polynomial_NTTOp $p1, $r1),
+    (Polynomial_NTTOp $p2, $r2),
+    $overflow),
+  (Polynomial_NTTOp (Polynomial_SubOp $p1, $p2), $r1),
+  [(Equal $r1, $r2)]
+>;
+def INTTOfSub : Pat<
+  (Polynomial_SubOp
+    (Polynomial_INTTOp $t1, $r1),
+    (Polynomial_INTTOp $t2, $r2)),
+  (Polynomial_INTTOp (Arith_SubIOp $t1, $t2, DefOverflow), $r1),
+  [(Equal $r1, $r2)]
 >;
 
 #endif  // POLYNOMIAL_CANONICALIZATION
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 1a2439fe810b5..58773a24181df 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -108,14 +108,15 @@ LogicalResult MulScalarOp::verify() {
 }
 
 /// Test if a value is a primitive nth root of unity modulo cmod.
-bool isPrimitiveNthRootOfUnity(const APInt &root, const unsigned n,
+bool isPrimitiveNthRootOfUnity(const APInt &root, const APInt &n,
                                const APInt &cmod) {
   // Root bitwidth may be 1 less then cmod.
   APInt r = APInt(root).zext(cmod.getBitWidth());
   assert(r.ule(cmod) && "root must be less than cmod");
+  unsigned upperBound = n.getZExtValue();
 
   APInt a = r;
-  for (size_t k = 1; k < n; k++) {
+  for (size_t k = 1; k < upperBound; k++) {
     if (a.isOne())
       return false;
     a = (a * r).urem(cmod);
@@ -126,7 +127,8 @@ bool isPrimitiveNthRootOfUnity(const APInt &root, const unsigned n,
 /// Verify that the types involved in an NTT or INTT operation are
 /// compatible.
 static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
-                                 RankedTensorType tensorType) {
+                                 RankedTensorType tensorType,
+                                 std::optional<PrimitiveRootAttr> root) {
   Attribute encoding = tensorType.getEncoding();
   if (!encoding) {
     return op->emitOpError()
@@ -157,33 +159,29 @@ static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
     return diag;
   }
 
-  if (!ring.getPrimitiveRoot()) {
-    return op->emitOpError()
-           << "ring type " << ring << " does not provide a primitive root "
-           << "of unity, which is required to express an NTT";
-  }
-
-  if (!isPrimitiveNthRootOfUnity(ring.getPrimitiveRoot().getValue(), polyDegree,
-                                 ring.getCoefficientModulus().getValue())) {
-    return op->emitOpError()
-           << "ring type " << ring << " has a primitiveRoot attribute '"
-           << ring.getPrimitiveRoot()
-           << "' that is not a primitive root of the coefficient ring";
+  if (root.has_value()) {
+    APInt rootValue = root.value().getValue().getValue();
+    APInt rootDegree = root.value().getDegree().getValue();
+    APInt cmod = ring.getCoefficientModulus().getValue();
+    if (!isPrimitiveNthRootOfUnity(rootValue, rootDegree, cmod)) {
+      return op->emitOpError()
+             << "provided root " << rootValue.getZExtValue() << " is not a primitive root "
+             << "of unity mod " << cmod.getZExtValue() << ", with the specified degree "
+             << rootDegree.getZExtValue();
+    }
   }
 
   return success();
 }
 
 LogicalResult NTTOp::verify() {
-  auto ring = getInput().getType().getRing();
-  auto tensorType = getOutput().getType();
-  return verifyNTTOp(this->getOperation(), ring, tensorType);
+  return verifyNTTOp(this->getOperation(), getInput().getType().getRing(),
+                     getOutput().getType(), getRoot());
 }
 
 LogicalResult INTTOp::verify() {
-  auto tensorType = getInput().getType();
-  auto ring = getOutput().getType().getRing();
-  return verifyNTTOp(this->getOperation(), ring, tensorType);
+  return verifyNTTOp(this->getOperation(), getOutput().getType().getRing(),
+                     getInput().getType(), getRoot());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Polynomial/canonicalization.mlir b/mlir/test/Dialect/Polynomial/canonicalization.mlir
index dbfbf2d93f111..354b76e3d9669 100644
--- a/mlir/test/Dialect/Polynomial/canonicalization.mlir
+++ b/mlir/test/Dialect/Polynomial/canonicalization.mlir
@@ -1,6 +1,7 @@
 // RUN: mlir-opt -canonicalize %s | FileCheck %s
 #ntt_poly = #polynomial.int_polynomial<-1 + x**8>
-#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
+#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly>
+#root = #polynomial.primitive_root<value=31:i32, degree=8:index>
 !ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
 !tensor_ty = tensor<8xi32, #ntt_ring>
 
@@ -10,8 +11,8 @@ func.func @test_canonicalize_intt_after_ntt(%p0 : !ntt_poly_ty) -> !ntt_poly_ty
   // CHECK-NOT: polynomial.ntt
   // CHECK-NOT: polynomial.intt
   // CHECK: %[[RESULT:.+]] = polynomial.add %[[P]], %[[P]]  : [[T]]
-  %t0 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
-  %p1 = polynomial.intt %t0: !tensor_ty -> !ntt_poly_ty
+  %t0 = polynomial.ntt %p0 {root=#root} : !ntt_poly_ty -> !tensor_ty
+  %p1 = polynomial.intt %t0 {root=#root} : !tensor_ty -> !ntt_poly_ty
   %p2 = polynomial.add %p1, %p1 : !ntt_poly_ty
   // CHECK: return %[[RESULT]] : [[T]]
   return %p2 : !ntt_poly_ty
@@ -23,8 +24,8 @@ func.func @test_canonicalize_ntt_after_intt(%t0 : !tensor_ty) -> !tensor_ty {
   // CHECK-NOT: polynomial.intt
   // CHECK-NOT: polynomial.ntt
   // CHECK: %[[RESULT:.+]] = arith.addi %[[X]], %[[X]] : [[T]]
-  %p0 = polynomial.intt %t0 : !tensor_ty -> !ntt_poly_ty
-  %t1 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
+  %p0 = polynomial.intt %t0 {root=#root} : !tensor_ty -> !ntt_poly_ty
+  %t1 = polynomial.ntt %p0 {root=#root} : !ntt_poly_ty -> !tensor_ty
   %t2 = arith.addi %t1, %t1 : !tensor_ty
   // CHECK: return %[[RESULT]] : [[T]]
   return %t2 : !tensor_ty
@@ -43,3 +44,60 @@ func.func @test_canonicalize_sub(%poly0 : !sub_ty, %poly1 : !sub_ty) -> !sub_ty
   // CHECK: [[ADD:%.+]] = polynomial.add %[[p0]], %[[p1neg]]
   return %0 : !sub_ty
 }
+
+// CHECK-LABEL: test_canonicalize_fold_add_through_ntt
+// CHECK: polynomial.add
+// CHECK-NOT: polynomial.ntt
+// CHECK-NOT: polynomial.intt
+func.func @test_canonicalize_fold_add_through_ntt(
+    %poly0 : !ntt_poly_ty,
+    %poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
+  %0 = polynomial.ntt %poly0 {root=#root} : !ntt_poly_ty -> !tensor_ty
+  %1 = polynomial.ntt %poly1 {root=#root} : !ntt_poly_ty -> !tensor_ty
+  %a_plus_b = arith.addi %0, %1 : !tensor_ty
+  %out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
+  return %out : !ntt_poly_ty
+}
+
+// CHECK-LABEL: test_canonicalize_fold_add_through_intt
+// CHECK: arith.addi
+// CHECK-NOT: polynomial.intt
+// CHECK-NOT: polynomial.iintt
+func.func @test_canonicalize_fold_add_through_intt(
+    %tensor0 : !tensor_ty,
+    %tensor1 : !tensor_ty) -> !tensor_ty {
+  %0 = polynomial.intt %tensor0 {root=#root} : !tensor_ty -> !ntt_poly_ty
+  %1 = polynomial.intt %tensor1 {root=#root} : !tensor_ty -> !ntt_poly_ty
+  %a_plus_b = polynomial.add %0, %1 : !ntt_poly_ty
+  %out = polynomial.ntt %a_plus_b {root=#root} : !ntt_poly_ty -> !tensor_ty
+  return %out : !tensor_ty
+}
+
+// CHECK-LABEL: test_canonicalize_fold_sub_through_ntt
+// CHECK: polynomial.mul_scalar
+// CHECK: polynomial.add
+// CHECK-NOT: polynomial.ntt
+// CHECK-NOT: polynomial.intt
+func.func @test_canonicalize_fold_sub_through_ntt(
+    %poly0 : !ntt_poly_ty,
+    %poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
+  %0 = polynomial.ntt %poly0 {root=#root} : !ntt_poly_ty -> !tensor_ty
+  %1 = polynomial.ntt %poly1 {root=#root} : !ntt_poly_ty -> !tensor_ty
+  %a_plus_b = arith.subi %0, %1 : !tensor_ty
+  %out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
+  return %out : !ntt_poly_ty
+}
+
+// CHECK-LABEL: test_canonicalize_fold_sub_through_intt
+// CHECK: arith.subi
+// CHECK-NOT: polynomial.intt
+// CHECK-NOT: polynomial.iintt
+func.func @test_canonicalize_fold_sub_through_intt(
+    %tensor0 : !tensor_ty,
+    %tensor1 : !tensor_ty) -> !tensor_ty {
+  %0 = polynomial.intt %tensor0 {root=#root} : !tensor_ty -> !ntt_poly_ty
+  %1 = polynomial.intt %tensor1 {root=#root} : !tensor_ty -> !ntt_poly_ty
+  %a_plus_b = polynomial.sub %0, %1 : !ntt_poly_ty
+  %out = polynomial.ntt %a_plus_b {root=#root} : !ntt_poly_ty -> !tensor_ty
+  return %out : !tensor_ty
+}
diff --git a/mlir/test/Dialect/Polynomial/ops.mlir b/mlir/test/Dialect/Polynomial/ops.mlir
index ff709960c50e9..c9bcb7f95b7d3 100644
--- a/mlir/test/Dialect/Polynomial/ops.mlir
+++ b/mlir/test/Dialect/Polynomial/ops.mlir
@@ -11,11 +11,11 @@
 #one_plus_x_squared = #polynomial.int_polynomial<1 + x**2>
 
 #ideal = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ideal, primitiveRoot=193>
+#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ideal>
 !poly_ty = !polynomial.polynomial<ring=#ring>
 
 #ntt_poly = #polynomial.int_polynomial<-1 + x**8>
-#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
+#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly>
 !ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
 
 module {
@@ -87,12 +87,12 @@ module {
   }
 
   func.func @test_ntt(%0 : !ntt_poly_ty) {
-    %1 = polynomial.ntt %0 : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
+    %1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
     return
   }
 
   func.func @test_intt(%0 : tensor<8xi32, #ntt_ring>) {
-    %1 = polynomial.intt %0 : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
+    %1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
     return
   }
 }
diff --git a/mlir/test/Dialect/Polynomial/ops_errors.mlir b/mlir/test/Dialect/Polynomial/ops_errors.mlir
index af8e4aa5da862..f22b14897e98a 100644
--- a/mlir/test/Dialect/Polynomial/ops_errors.mlir
+++ b/mlir/test/Dialect/Polynomial/ops_errors.mlir
@@ -55,28 +55,28 @@ func.func @test_mul_scalar_wrong_type(%arg0: !ty) -> !ty {
 // -----
 
 #my_poly = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
 !poly_ty = !polynomial.polynomial<ring=#ring>
 
 // CHECK-NOT: @test_invalid_ntt
 // CHECK-NOT: polynomial.ntt
 func.func @test_invalid_ntt(%0 : !poly_ty) {
   // expected-error@below {{expects a ring encoding to be provided to the tensor}}
-  %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32>
+  %1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !poly_ty -> tensor<1024xi32>
   return
 }
 
 // -----
 
 #my_poly = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
 !poly_ty = !polynomial.polynomial<ring=#ring>
 
 // CHECK-NOT: @test_invalid_ntt
 // CHECK-NOT: polynomial.ntt
 func.func @test_invalid_ntt(%0 : !poly_ty) {
   // expected-error@below {{tensor encoding is not a ring attribute}}
-  %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #my_poly>
+  %1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !poly_ty -> tensor<1024xi32, #my_poly>
   return
 }
 
@@ -84,21 +84,21 @@ func.func @test_invalid_ntt(%0 : !poly_ty) {
 
 #my_poly = #polynomial.int_polynomial<-1 + x**1024>
 #ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
-#ring1 = #polynomial.ring<coefficientType=i16, coefficientModulus=257:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
+#ring1 = #polynomial.ring<coefficientType=i16, coefficientModulus=257:i16, polynomialModulus=#my_poly>
 !poly_ty = !polynomial.polynomial<ring=#ring>
 
 // CHECK-NOT: @test_invalid_intt
 // CHECK-NOT: polynomial.intt
 func.func @test_invalid_intt(%0 : tensor<1024xi32, #ring1>) {
   // expected-error@below {{not equivalent to the polynomial ring}}
-  %1 = polynomial.intt %0 : tensor<1024xi32, #ring1> -> !poly_ty
+  %1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : tensor<1024xi32, #ring1> -> !poly_ty
   return
 }
 
 // -----
 
 #my_poly = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
 !poly_ty = !polynomial.polynomial<ring=#ring>
 
 // CHECK-NOT: @test_invalid_intt
@@ -106,21 +106,7 @@ func.func @test_invalid_intt(%0 : tensor<1024xi32, #ring1>) {
 func.func @test_invalid_intt(%0 : tensor<1025xi32, #ring>) {
   // expected-error@below {{does not match output type}}
   // expected-note@below {{exactly the degree of the polynomialModulus of the polynomial type's ring attribute}}
-  %1 = polynomial.intt %0 : tensor<1025xi32, #ring> -> !poly_ty
-  return
-}
-
-// -----
-
-#my_poly = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
-!poly_ty = !polynomial.polynomial<ring=#ring>
-
-// CHECK-NOT: @test_invalid_ntt
-// CHECK-NOT: polynomial.ntt
-func.func @test_invalid_ntt(%0 : !poly_ty) {
-  // expected-error@below {{does not provide a primitive root of unity, which is required to express an NTT}}
-  %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #ring>
+  %1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : tensor<1025xi32, #ring> -> !poly_ty
   return
 }
 
@@ -128,13 +114,13 @@ func.func @test_invalid_ntt(%0 : !poly_ty) {
 
 #my_poly = #polynomial.int_polynomial<-1 + x**8>
 // A valid root i...
[truncated]

Copy link

github-actions bot commented May 23, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

}]>,
];
}

def Polynomial_PrimitiveRootAttr: Polynomial_Attr<"PrimitiveRoot", "primitive_root"> {
let summary = "an attribute containing an integer and its degree as a root of unity";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if mlir docs can render latex

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like yes:
image

APInt rootValue = root.value().getValue().getValue();
APInt rootDegree = root.value().getDegree().getValue();
APInt cmod = ring.getCoefficientModulus().getValue();
if (!isPrimitiveNthRootOfUnity(rootValue, rootDegree, cmod)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: remove {, } for single statement if/fors

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe it's questionable convention in the case of nested conditionals? up to you.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is indeed questionable, I usually lean towards having braces for multi-line bodies even if there is a single statement.

Copy link
Contributor

@makslevental makslevental left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM modulo my general distate for algebra (see what I did there...)

APInt rootValue = root.value().getValue().getValue();
APInt rootDegree = root.value().getDegree().getValue();
APInt cmod = ring.getCoefficientModulus().getValue();
if (!isPrimitiveNthRootOfUnity(rootValue, rootDegree, cmod)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is indeed questionable, I usually lean towards having braces for multi-line bodies even if there is a single statement.

@j2kun j2kun merged commit 1f46729 into llvm:main May 30, 2024
7 checks passed
j2kun added a commit that referenced this pull request May 30, 2024
Rebased over #93227

---------

Co-authored-by: Jeremy Kun <j2kun@users.noreply.github.com>
ZenithalHourlyRate added a commit to ZenithalHourlyRate/llvm-project that referenced this pull request Oct 11, 2024
Related to llvm#93227
and google/heir#993

When ntt/intt ops are emitted as a result of pattern rewrite,
the primitive root attr must be provided in some way, which
is convenient if provided in ring attr.

As for different convolution pattern, to_tensor/tensor.cast/
from_tensor should be enough for changing primitiveRoot attr
in RingAttr
ZenithalHourlyRate added a commit to ZenithalHourlyRate/llvm-project that referenced this pull request Oct 11, 2024
Related to llvm#93227
and google/heir#993

When ntt/intt ops are emitted as a result of pattern rewrite,
the primitive root attr must be provided in some way, and it
is convenient for it to be provided in ring attr.

As for using different primitive root for the same polynomial,
to_tensor/tensor.cast/from_tensor should be enough for changing
primitiveRoot attribute in RingAttr.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants