From cc742e0ccde51a1d22086c4d819ca5d8074e90d8 Mon Sep 17 00:00:00 2001 From: doran2728 Date: Sat, 21 Mar 2026 19:58:32 +0300 Subject: [PATCH 1/3] remove redundant inputArray usage from Forward.lean --- CompPoly/Univariate/NTT/Forward.lean | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/CompPoly/Univariate/NTT/Forward.lean b/CompPoly/Univariate/NTT/Forward.lean index 61977e9..4391ef0 100644 --- a/CompPoly/Univariate/NTT/Forward.lean +++ b/CompPoly/Univariate/NTT/Forward.lean @@ -22,10 +22,6 @@ namespace Forward variable {R : Type*} [Field R] -/-- Input coefficients packed as a fixed-size array over the domain. -/ -@[inline] def inputArray (D : Domain R) (p : CPolynomial.Raw R) : Array R := - Array.ofFn (fun i : D.Idx => p.coeff i.1) - /-- DFT/NTT formula at one output index. -/ @[inline] def nttAt (D : Domain R) (a : Array R) (k : D.Idx) : R := ∑ j : D.Idx, a.getD j.1 0 * D.omega ^ ((k : Nat) * (j : Nat)) @@ -36,7 +32,7 @@ variable {R : Type*} [Field R] /-- Spec-level forward NTT from a raw polynomial input. -/ @[inline] def forwardSpec (D : Domain R) (p : CPolynomial.Raw R) : Array R := - forwardArraySpec D (inputArray D p) + forwardArraySpec D p /-- Reverse the lowest `bits` bits of `i`. -/ def bitRevNat : Nat → Nat → Nat @@ -75,11 +71,7 @@ def runStages (D : Domain R) (a : Array R) : Array R := Id.run do /-- Intended fast implementation entry point. 
-/ @[inline] def forwardImpl (D : Domain R) (p : CPolynomial.Raw R) : Array R := - runStages D (bitRevPermute D (inputArray D p)) - -@[simp] theorem size_inputArray (D : Domain R) (p : CPolynomial.Raw R) : - (inputArray D p).size = D.n := by - simp [inputArray] + runStages D (bitRevPermute D p) @[simp] theorem size_forwardArraySpec (D : Domain R) (a : Array R) : (forwardArraySpec D a).size = D.n := by From ba21fdee405b2a25371177d395fdae480617f2ca Mon Sep 17 00:00:00 2001 From: doran2728 Date: Sat, 21 Mar 2026 19:59:12 +0300 Subject: [PATCH 2/3] remove redundant padding from Domain.lean --- CompPoly/Univariate/NTT/Domain.lean | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/CompPoly/Univariate/NTT/Domain.lean b/CompPoly/Univariate/NTT/Domain.lean index 3df8130..3f17cb9 100644 --- a/CompPoly/Univariate/NTT/Domain.lean +++ b/CompPoly/Univariate/NTT/Domain.lean @@ -61,19 +61,6 @@ def requiredLength (p q : CPolynomial.Raw R) : Nat := def fits (D : Domain R) (p q : CPolynomial.Raw R) : Prop := requiredLength p q ≤ D.n -/-- Right-pad an array with zeros up to at least length `n`. -/ -def zeroPad (n : Nat) (a : Array R) : Array R := - a ++ Array.replicate (n - a.size) 0 - -@[simp] lemma size_zeroPad (n : Nat) (a : Array R) : - (zeroPad (R := R) n a).size = max n a.size := by - -- TODO: Replace with a direct arithmetic proof once helper lemmas are finalized. - sorry - -/-- Trim a polynomial and pad it to the domain size. -/ -def pad (D : Domain R) (p : CPolynomial.Raw R) : CPolynomial.Raw R := - zeroPad (R := R) D.n p.trim - /-- Truncate a polynomial to at most `m` coefficients. 
-/ def truncate (m : Nat) (p : CPolynomial.Raw R) : CPolynomial.Raw R := p.extract 0 m From 5136d131471aa0040629489a5a8b02151a59230e Mon Sep 17 00:00:00 2001 From: doran2728 Date: Sat, 21 Mar 2026 20:03:54 +0300 Subject: [PATCH 3/3] optimize the main for loop inside butterfly stages --- CompPoly/Univariate/NTT/Forward.lean | 22 +++++++++++----------- CompPoly/Univariate/NTT/Inverse.lean | 22 +++++++++++----------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/CompPoly/Univariate/NTT/Forward.lean b/CompPoly/Univariate/NTT/Forward.lean index 4391ef0..f88230e 100644 --- a/CompPoly/Univariate/NTT/Forward.lean +++ b/CompPoly/Univariate/NTT/Forward.lean @@ -49,17 +49,17 @@ def butterflyStage (D : Domain R) (stage : Nat) (a : Array R) : Array R := Id.ru let half : Nat := 2 ^ stage let wm := D.omega ^ (D.n / blockSize) let mut acc := a - for base in [0:D.n] do - if base % blockSize == 0 then - let mut w : R := 1 - for j in [0:half] do - let i0 := base + j - let i1 := i0 + half - let u := acc.getD i0 0 - let t := w * acc.getD i1 0 - acc := acc.set! i0 (u + t) - acc := acc.set! i1 (u - t) - w := w * wm + for block in [0:D.n / blockSize] do + let base := block * blockSize + let mut w : R := 1 + for j in [0:half] do + let i0 := base + j + let i1 := i0 + half + let u := acc.getD i0 0 + let t := w * acc.getD i1 0 + acc := acc.set! i0 (u + t) + acc := acc.set! i1 (u - t) + w := w * wm return acc /-- Run all radix-2 butterfly stages (target complexity: `O(n log n)`). 
-/ diff --git a/CompPoly/Univariate/NTT/Inverse.lean b/CompPoly/Univariate/NTT/Inverse.lean index e131cea..89fe483 100644 --- a/CompPoly/Univariate/NTT/Inverse.lean +++ b/CompPoly/Univariate/NTT/Inverse.lean @@ -45,17 +45,17 @@ def butterflyStage (D : Domain R) (stage : Nat) (a : Array R) : Array R := Id.ru let half : Nat := 2 ^ stage let wm := D.omegaInv ^ (D.n / blockSize) let mut acc := a - for base in [0:D.n] do - if base % blockSize == 0 then - let mut w : R := 1 - for j in [0:half] do - let i0 := base + j - let i1 := i0 + half - let u := acc.getD i0 0 - let t := w * acc.getD i1 0 - acc := acc.set! i0 (u + t) - acc := acc.set! i1 (u - t) - w := w * wm + for block in [0:D.n / blockSize] do + let base := block * blockSize + let mut w : R := 1 + for j in [0:half] do + let i0 := base + j + let i1 := i0 + half + let u := acc.getD i0 0 + let t := w * acc.getD i1 0 + acc := acc.set! i0 (u + t) + acc := acc.set! i1 (u - t) + w := w * wm return acc /-- Run all radix-2 inverse butterfly stages. -/