diff --git a/src/lib.rs b/src/lib.rs index 91f4696..9237863 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -156,10 +156,12 @@ fn factorize(n: i64) -> HashMap { /// Fast computation of a primitive root mod p^e pub fn primitive_root(p: i64, e: u32) -> i64 { + println!("primitive_root called"); let g = primitive_root_mod_p(p); let mut g_lifted = g; // Lift it to p^e for _ in 1..e { - if g_lifted.pow((p - 1) as u32) % p.pow(e) == 1 { + println!("g_lifted: {}", g_lifted); + if mod_exp(g_lifted, p-1, p.pow(e)) == 1 { g_lifted += p.pow(e - 1); } } diff --git a/src/test.rs b/src/test.rs index 25ba447..1a6edfe 100644 --- a/src/test.rs +++ b/src/test.rs @@ -26,24 +26,19 @@ mod tests { #[test] fn test_polymul_ntt_square_modulus() { - let modulus: i64 = 17*17; // Prime modulus + let moduli = [17*17, 12289*12289]; // Different moduli to test let n: usize = 8; // Length of the NTT (must be a power of 2) - let omega = omega(modulus, n); // n-th root of unity - - // Input polynomials (padded to length `n`) - let mut a = vec![1, 2, 3, 4]; - let mut b = vec![5, 6, 7, 8]; - a.resize(n, 0); - b.resize(n, 0); - - // Perform the standard polynomial multiplication - let c_std = polymul(&a, &b, n as i64, modulus); - - // Perform the NTT-based polynomial multiplication - let c_fast = polymul_ntt(&a, &b, n, modulus, omega); - // Ensure both methods produce the same result - assert_eq!(c_std, c_fast, "The results of polymul and polymul_ntt do not match"); + for &modulus in &moduli { + let omega = omega(modulus, n); // n-th root of unity + let mut a = vec![1, 2, 3, 4]; + let mut b = vec![5, 6, 7, 8]; + a.resize(n, 0); + b.resize(n, 0); + let c_std = polymul(&a, &b, n as i64, modulus); + let c_fast = polymul_ntt(&a, &b, n, modulus, omega); + assert_eq!(c_std, c_fast, "The results of polymul and polymul_ntt do not match"); + } } #[test]