Skip to content

Commit 01029b4

Browse files
committed
ATC001-C: Implement NTT
1 parent a8bf065 commit 01029b4

File tree

3 files changed

+773
-13
lines changed

3 files changed

+773
-13
lines changed

atc001-c/Main.hs

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -37,19 +37,22 @@ main = do
3737
-- Fast Fourier Transform (FFT)
3838
--
3939

40+
halve :: G.Vector vec a => vec a -> vec a
41+
halve v = let n = G.length v
42+
in G.generate (n `quot` 2) $ \j -> v G.! (j * 2)
43+
4044
fft :: forall vec a. (Num a, G.Vector vec a)
41-
=> vec a -- ^ For a primitive n-th root of unity @u@, @[1,u,u^2 .. u^(n-1)]@
45+
=> [vec a] -- ^ For a primitive n-th root of unity @u@, @iterate halve [1,u,u^2 .. u^(n-1)]@
4246
-> vec a -- ^ a polynomial of length n (= 2^k for some k)
4347
-> vec a
44-
fft u f | n == 1 = f
45-
| otherwise = let !n2 = n `quot` 2
46-
r0, r1', u2, t0, t1' :: vec a
47-
r0 = G.generate n2 $ \j -> (f G.! j) + (f G.! (j + n2))
48-
r1' = G.generate n2 $ \j -> ((f G.! j) - (f G.! (j + n2))) * u G.! j
49-
!u2 = G.generate n2 $ \j -> u G.! (j * 2)
50-
!t0 = fft u2 r0
51-
!t1' = fft u2 r1'
52-
in G.generate n $ \j -> if even j then t0 G.! (j `quot` 2) else t1' G.! (j `quot` 2)
48+
fft (u:u2) f | n == 1 = f
49+
| otherwise = let !n2 = n `quot` 2
50+
r0, r1', t0, t1' :: vec a
51+
r0 = G.generate n2 $ \j -> (f G.! j) + (f G.! (j + n2))
52+
r1' = G.generate n2 $ \j -> ((f G.! j) - (f G.! (j + n2))) * u G.! j
53+
!t0 = fft u2 r0
54+
!t1' = fft u2 r1'
55+
in G.generate n $ \j -> if even j then t0 G.! (j `quot` 2) else t1' G.! (j `quot` 2)
5356
where n = G.length f
5457

5558
mulFFT :: U.Vector Int -> U.Vector Int -> U.Vector Int
@@ -59,6 +62,7 @@ mulFFT !f !g = let n' = U.length f + U.length g - 2
5962
n = bit k
6063
u :: U.Vector (Complex Double)
6164
u = U.generate n $ \j -> cis (fromIntegral j * (2 * pi / fromIntegral n))
65+
us = iterate halve u
6266
f' = U.generate n $ \j -> if j < U.length f then
6367
fromIntegral (f U.! j)
6468
else
@@ -67,10 +71,10 @@ mulFFT !f !g = let n' = U.length f + U.length g - 2
6771
fromIntegral (g U.! j)
6872
else
6973
0
70-
f'' = fft u f'
71-
g'' = fft u g'
74+
f'' = fft us f'
75+
g'' = fft us g'
7276
fg = U.generate n $ \j -> (f'' U.! j) * (g'' U.! j)
73-
fg' = fft (U.map conjugate u) fg
77+
fg' = fft (map (U.map conjugate) us) fg
7478
in U.generate n $ \j -> round (realPart (fg' U.! j) / fromIntegral n)
7579

7680
--

0 commit comments

Comments
 (0)