From 309d256da71a53325b40c8ec26111e5065845e90 Mon Sep 17 00:00:00 2001 From: Jackson Walters Date: Mon, 17 Feb 2025 14:25:54 -0500 Subject: [PATCH 1/8] add documentation add arguments and return value docstrings --- src/lib.rs | 110 ++++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 96 insertions(+), 14 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 29d6f1c..93d8701 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,16 +2,26 @@ use reikna::totient::totient; use reikna::factor::quick_factorize; use std::collections::HashMap; -// Modular arithmetic functions using i64 +/// Modular arithmetic functions using i64 fn mod_add(a: i64, b: i64, p: i64) -> i64 { (a + b) % p } +/// Modular multiplication fn mod_mul(a: i64, b: i64, p: i64) -> i64 { (a * b) % p } -pub fn mod_exp(mut base: i64, mut exp: i64, p: i64) -> i64 { +/// Modular exponentiation +/// # Arguments +/// +/// * `base` - Base of the exponentiation. +/// * `exp` - Exponent. +/// * `p` - Prime modulus for the operations. +/// +/// # Returns +/// The result of the exponentiation modulo `p`. +fn mod_exp(mut base: i64, mut exp: i64, p: i64) -> i64 { let mut result = 1; base %= p; while exp > 0 { @@ -24,6 +34,7 @@ pub fn mod_exp(mut base: i64, mut exp: i64, p: i64) -> i64 { result } +/// Extended Euclidean algorithm fn extended_gcd(a: i64, b: i64) -> (i64, i64, i64) { if b == 0 { (a, 1, 0) // gcd, x, y @@ -33,7 +44,8 @@ fn extended_gcd(a: i64, b: i64) -> (i64, i64, i64) { } } -pub fn mod_inv(a: i64, modulus: i64) -> i64 { +/// Compute the modular inverse of a modulo modulus +fn mod_inv(a: i64, modulus: i64) -> i64 { let (gcd, x, _) = extended_gcd(a, modulus); if gcd != 1 { panic!("{} and {} are not coprime, no inverse exists", a, modulus); @@ -41,7 +53,14 @@ pub fn mod_inv(a: i64, modulus: i64) -> i64 { (x % modulus + modulus) % modulus // Ensure a positive result } -// Compute n-th root of unity (omega) for p not necessarily prime +/// Compute n-th root of unity (omega) for p not necessarily prime +/// # Arguments +/// +/// * `modulus` - Modulus. n must divide each prime power factor. +/// * `n` - Order of the root of unity. +/// +/// # Returns +/// The n-th root of unity modulo `modulus`. pub fn omega(modulus: i64, n: usize) -> i64 { let factors = factorize(modulus as i64); if factors.len() == 1 { @@ -56,7 +75,15 @@ pub fn omega(modulus: i64, n: usize) -> i64 { } } -// Forward transform using NTT, output bit-reversed +/// Forward transform using NTT, output bit-reversed +/// # Arguments +/// +/// * `a` - Input vector. +/// * `omega` - Primitive root of unity modulo `p`. +/// * `n` - Length of the input vector and the result. +/// * `p` - Prime modulus for the operations. +/// +/// # Returns pub fn ntt(a: &[i64], omega: i64, n: usize, p: i64) -> Vec { let mut result = a.to_vec(); let mut step = n/2; @@ -77,7 +104,16 @@ pub fn ntt(a: &[i64], omega: i64, n: usize, p: i64) -> Vec { result } -// Inverse transform using INTT, input bit-reversed +/// Inverse transform using INTT, input bit-reversed +/// # Arguments +/// +/// * `a` - Input vector (bit-reversed). +/// * `omega` - Primitive root of unity modulo `p`. +/// * `n` - Length of the input vector and the result. +/// * `p` - Prime modulus for the operations. +/// +/// # Returns +/// A vector representing the inverse NTT of the input vector. pub fn intt(a: &[i64], omega: i64, n: usize, p: i64) -> Vec { let omega_inv = mod_inv(omega, p); let n_inv = mod_inv(n as i64, p); @@ -103,7 +139,16 @@ pub fn intt(a: &[i64], omega: i64, n: usize, p: i64) -> Vec { .collect() } -// Naive polynomial multiplication +/// Naive polynomial multiplication +/// # Arguments +/// +/// * `a` - First polynomial (as a vector of coefficients). +/// * `b` - Second polynomial (as a vector of coefficients). +/// * `n` - Length of the polynomials and the result. +/// * `p` - Prime modulus for the operations. +/// +/// # Returns +/// A vector representing the polynomial product modulo `p`. pub fn polymul(a: &Vec, b: &Vec, n: i64, p: i64) -> Vec { let mut result = vec![0; n as usize]; for i in 0..a.len() { @@ -145,7 +190,14 @@ pub fn polymul_ntt(a: &[i64], b: &[i64], n: usize, p: i64, omega: i64) -> Vec HashMap { let mut factors = HashMap::new(); for factor in quick_factorize(n as u64) { @@ -167,6 +219,12 @@ pub fn primitive_root(p: i64, e: u32) -> i64 { } /// Finds a primitive root modulo a prime p +/// # Arguments +/// +/// * `p` - Prime modulus. +/// +/// # Returns +/// A primitive root modulo `p`. fn primitive_root_mod_p(p: i64) -> i64 { let phi = p - 1; let factors = factorize(phi); // Reusing factorize to get both prime factors and multiplicities @@ -179,7 +237,16 @@ fn primitive_root_mod_p(p: i64) -> i64 { 0 // Should never happen } -// the Chinese remainder theorem for two moduli +/// the Chinese remainder theorem for two moduli +/// # Arguments +/// +/// * `a1` - First residue. +/// * `n1` - First modulus. +/// * `a2` - Second residue. +/// * `n2` - Second modulus. +/// +/// # Returns +/// The solution to the system of congruences x = a1 (mod n1) and x = a2 (mod n2). pub fn crt(a1: i64, n1: i64, a2: i64, n2: i64) -> i64 { let n = n1 * n2; let m1 = mod_inv(n1, n2); // Inverse of n1 mod n2 @@ -188,10 +255,17 @@ pub fn crt(a1: i64, n1: i64, a2: i64, n2: i64) -> i64 { if x < 0 { x + n } else { x } } -// computes an n^th root of unity modulo a composite modulus -// note we require that an n^th root of unity exists for each multiplicative group modulo p^e -// use the CRT isomorphism to pull back each n^th root of unity to the composite modulus -// for the NTT, we require than a 2n^th root of unity exists +/// computes an n^th root of unity modulo a composite modulus +/// note we require that an n^th root of unity exists for each multiplicative group modulo p^e +/// use the CRT isomorphism to pull back each n^th root of unity to the composite modulus +/// for the NTT, we require than a 2n^th root of unity exists +/// # Arguments +/// +/// * `modulus` - Modulus. n must divide each prime power factor. +/// * `n` - Order of the root of unity. +/// +/// # Returns +/// The n-th root of unity modulo `modulus`. pub fn root_of_unity(modulus: i64, n: i64) -> i64 { let factors = factorize(modulus); let mut result = 1; @@ -202,7 +276,15 @@ pub fn root_of_unity(modulus: i64, n: i64) -> i64 { result } -//ensure the root of unity satisfies sum_{j=0}^{n-1} omega^{jk} = 0 for 1 \le k < n +/// ensure the root of unity satisfies sum_{j=0}^{n-1} omega^{jk} = 0 for 1 \le k < n +/// # Arguments +/// +/// * `omega` - n-th root of unity. +/// * `n` - Order of the root of unity. +/// * `modulus` - Modulus. +/// +/// # Returns +/// True if the root of unity satisfies the condition. pub fn verify_root_of_unity(omega: i64, n: i64, modulus: i64) -> bool { assert!(mod_exp(omega, n, modulus as i64) == 1, "omega is not an n-th root of unity"); assert!(mod_exp(omega, n/2, modulus as i64) == modulus-1, "omgea^(n/2) != -1 (mod modulus)"); From d6891ff22f0e8d04a6a50fb465dd0e619173260e Mon Sep 17 00:00:00 2001 From: Jackson Walters Date: Mon, 17 Feb 2025 14:34:36 -0500 Subject: [PATCH 2/8] add docstring for primitive root --- src/lib.rs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 93d8701..9424b0f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -207,6 +207,14 @@ fn factorize(n: i64) -> HashMap { } /// Fast computation of a primitive root mod p^e +/// Computes a primitive root mod p and lifts it to p^e by adding successive powers of p +/// # Arguments +/// +/// * `p` - Prime modulus. +/// * `e` - Exponent. +/// +/// # Returns +/// A primitive root modulo `p^e`. pub fn primitive_root(p: i64, e: u32) -> i64 { let g = primitive_root_mod_p(p); let mut g_lifted = g; // Lift it to p^e @@ -257,7 +265,7 @@ pub fn crt(a1: i64, n1: i64, a2: i64, n2: i64) -> i64 { /// computes an n^th root of unity modulo a composite modulus /// note we require that an n^th root of unity exists for each multiplicative group modulo p^e -/// use the CRT isomorphism to pull back each n^th root of unity to the composite modulus +/// use the CRT isomorphism to pull back the list of n^th roots of unity to the composite modulus /// for the NTT, we require than a 2n^th root of unity exists /// # Arguments /// From bfeb0a901dcd5bc95e80a98402980cdc036b7207 Mon Sep 17 00:00:00 2001 From: Jackson Walters Date: Mon, 17 Feb 2025 14:36:37 -0500 Subject: [PATCH 3/8] add extended_gcd docstring --- src/lib.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index 9424b0f..8e8e954 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -35,6 +35,13 @@ fn mod_exp(mut base: i64, mut exp: i64, p: i64) -> i64 { } /// Extended Euclidean algorithm +/// # Arguments +/// +/// * `a` - First number. +/// * `b` - Second number. +/// +/// # Returns +/// A tuple with the greatest common divisor and the Bézout coefficients. fn extended_gcd(a: i64, b: i64) -> (i64, i64, i64) { if b == 0 { (a, 1, 0) // gcd, x, y From 83f3032e59b8528088e1d75e78632810693f99c1 Mon Sep 17 00:00:00 2001 From: Jackson Walters Date: Mon, 17 Feb 2025 14:40:45 -0500 Subject: [PATCH 4/8] add docstring test for omega --- src/lib.rs | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 8e8e954..831eb00 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,7 +21,7 @@ fn mod_mul(a: i64, b: i64, p: i64) -> i64 { /// /// # Returns /// The result of the exponentiation modulo `p`. -fn mod_exp(mut base: i64, mut exp: i64, p: i64) -> i64 { +pub fn mod_exp(mut base: i64, mut exp: i64, p: i64) -> i64 { let mut result = 1; base %= p; while exp > 0 { @@ -68,6 +68,20 @@ fn mod_inv(a: i64, modulus: i64) -> i64 { /// /// # Returns /// The n-th root of unity modulo `modulus`. +/// +/// # Examples +/// +/// ``` +/// // For modulus = 17^2 = 289 and n = 8, we compute an 8th root of unity. +/// let modulus = 17 * 17; +/// let n = 8; +/// +/// // Compute the omega for the given modulus and order. +/// let omega = ntt::omega(modulus, n); +/// +/// // Verify that omega^n is congruent to 1 modulo modulus. +/// assert_eq!(ntt::mod_exp(omega, n as i64, modulus), 1); +/// ``` pub fn omega(modulus: i64, n: usize) -> i64 { let factors = factorize(modulus as i64); if factors.len() == 1 { From d78352660d61b193bf2aa43ca69d250d8a512fa6 Mon Sep 17 00:00:00 2001 From: Jackson Walters Date: Mon, 17 Feb 2025 14:48:57 -0500 Subject: [PATCH 5/8] add test cases to docstring --- src/lib.rs | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 831eb00..62f18e3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,7 +21,7 @@ fn mod_mul(a: i64, b: i64, p: i64) -> i64 { /// /// # Returns /// The result of the exponentiation modulo `p`. -pub fn mod_exp(mut base: i64, mut exp: i64, p: i64) -> i64 { +fn mod_exp(mut base: i64, mut exp: i64, p: i64) -> i64 { let mut result = 1; base %= p; while exp > 0 { @@ -72,15 +72,16 @@ fn mod_inv(a: i64, modulus: i64) -> i64 { /// # Examples /// /// ``` -/// // For modulus = 17^2 = 289 and n = 8, we compute an 8th root of unity. +/// // For modulus = 17^2 = 289, we compute and verify an 8th root of unity. /// let modulus = 17 * 17; /// let n = 8; -/// -/// // Compute the omega for the given modulus and order. /// let omega = ntt::omega(modulus, n); -/// -/// // Verify that omega^n is congruent to 1 modulo modulus. -/// assert_eq!(ntt::mod_exp(omega, n as i64, modulus), 1); +/// assert!(ntt::verify_root_of_unity(omega,n.try_into().unwrap(),modulus)); +/// +/// // For modulus = 17*41*73, we compute and verify an 8th root of unity. +/// let modulus = 17*41*73; +/// let omega = ntt::omega(modulus, n); +/// assert!(ntt::verify_root_of_unity(omega,n.try_into().unwrap(),modulus)); /// ``` pub fn omega(modulus: i64, n: usize) -> i64 { let factors = factorize(modulus as i64); From 94cc5aa8cd8543af5d8d7f89ef94720ac509e99a Mon Sep 17 00:00:00 2001 From: Jackson Walters Date: Mon, 17 Feb 2025 14:55:55 -0500 Subject: [PATCH 6/8] add doctest for ntt --- src/lib.rs | 17 +++++++++++++++-- src/main.rs | 1 + 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 62f18e3..7a5ae11 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -106,6 +106,20 @@ pub fn omega(modulus: i64, n: usize) -> i64 { /// * `p` - Prime modulus for the operations. /// /// # Returns +/// A vector representing the NTT of the input vector. +/// +/// # Examples +/// +/// ``` +/// let modulus: i64 = 17; // modulus, n must divide phi(p^k) for each prime factor p +/// let n: usize = 8; // Length of the NTT (must be a power of 2) +/// let omega = ntt::omega(modulus, n); // n-th root of unity +/// let mut a = vec![1, 2, 3, 4]; +/// a.resize(n, 0); +/// // Perform the forward NTT +/// let a_ntt = ntt::ntt(&a, omega, n, modulus); +/// let a_ntt_expected = vec![10, 15, 6, 7, 16, 13, 11, 15]; +/// assert_eq!(a_ntt, a_ntt_expected); pub fn ntt(a: &[i64], omega: i64, n: usize, p: i64) -> Vec { let mut result = a.to_vec(); let mut step = n/2; @@ -319,5 +333,4 @@ pub fn verify_root_of_unity(omega: i64, n: i64, modulus: i64) -> bool { assert!(mod_exp(omega, n, modulus as i64) == 1, "omega is not an n-th root of unity"); assert!(mod_exp(omega, n/2, modulus as i64) == modulus-1, "omgea^(n/2) != -1 (mod modulus)"); true -} - +} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index e7bb1f6..3a9ba1c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -15,6 +15,7 @@ fn main() { // Perform the forward NTT let a_ntt = ntt(&a, omega, n, modulus); + println!("a_ntt = {:?}", a_ntt); let b_ntt = ntt(&b, omega, n, modulus); // Perform the inverse NTT on the transformed A for verification From 77f4ce25f387c3ae041f40bdda08f221fd1994e0 Mon Sep 17 00:00:00 2001 From: Jackson Walters Date: Mon, 17 Feb 2025 14:58:52 -0500 Subject: [PATCH 7/8] add doctest for primitive root --- src/lib.rs | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 7a5ae11..c448252 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,7 +21,7 @@ fn mod_mul(a: i64, b: i64, p: i64) -> i64 { /// /// # Returns /// The result of the exponentiation modulo `p`. -fn mod_exp(mut base: i64, mut exp: i64, p: i64) -> i64 { +pub fn mod_exp(mut base: i64, mut exp: i64, p: i64) -> i64 { let mut result = 1; base %= p; while exp > 0 { @@ -251,6 +251,15 @@ fn factorize(n: i64) -> HashMap { /// /// # Returns /// A primitive root modulo `p^e`. +/// +/// # Examples +/// +/// ``` +/// // For p = 17 and e = 2, we compute a primitive root modulo 289. +/// let p = 17; +/// let e = 2; +/// let g = ntt::primitive_root(p, e); +/// assert_eq!(ntt::mod_exp(g, p*(p-1), p*p), 1); pub fn primitive_root(p: i64, e: u32) -> i64 { let g = primitive_root_mod_p(p); let mut g_lifted = g; // Lift it to p^e From d9896ca6b47f597eb3d785d2f5dd1dd14c23a98d Mon Sep 17 00:00:00 2001 From: Jackson Walters Date: Mon, 17 Feb 2025 15:02:24 -0500 Subject: [PATCH 8/8] remove println! --- src/main.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/main.rs b/src/main.rs index 3a9ba1c..e7bb1f6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -15,7 +15,6 @@ fn main() { // Perform the forward NTT let a_ntt = ntt(&a, omega, n, modulus); - println!("a_ntt = {:?}", a_ntt); let b_ntt = ntt(&b, omega, n, modulus); // Perform the inverse NTT on the transformed A for verification