Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added FFTMod #67

Merged
merged 50 commits into from Jun 24, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
6011aa5
Added FFTMod
Chillee Apr 23, 2019
ba0cc27
Shortened a bit
Chillee Apr 23, 2019
bba7a32
Shortened FFTMod and refactored roots code into FFT itself
Chillee Apr 25, 2019
a2b9413
Updated formatting
Chillee Apr 25, 2019
e83e65b
removed another 2 lines
Chillee Apr 25, 2019
5b5a4f4
Moved FFTMod to different file and fixed -Wconversion errors
Chillee Apr 25, 2019
12f0a0a
Updated headers for FFTMod
Chillee Apr 25, 2019
75ba4da
Updated header
Chillee Apr 25, 2019
81e2ae8
Moved numerical precision commnets to description
Chillee Apr 25, 2019
77adbd4
Merge branch 'master' into fftmod
Chillee Apr 25, 2019
5243182
Fixed typo
Chillee Apr 26, 2019
ba5bc7d
Merge branch 'fftmod' of github.com:Chillee/kactl into fftmod
Chillee Apr 26, 2019
ace08ab
Made things fit within 63 columns
Chillee Apr 26, 2019
7d26dee
Fixed some formatting issues
Chillee Apr 26, 2019
7c16aeb
Fixed rep space issues
Chillee Apr 26, 2019
c77b038
Fixed spacing issues
Chillee Apr 26, 2019
fe7b393
Fixed formatting issues
Chillee Apr 26, 2019
64ef0ab
Modified header
Chillee Apr 26, 2019
f5ecd33
Switched to long double for roots calculations due to precision issues
Chillee Apr 26, 2019
b7eb861
Updated headers with correct error bounds
Chillee Apr 26, 2019
f072eb1
Removed one of the papers about accuracy
Chillee Apr 26, 2019
d5664e4
Fixed formatting
Chillee Apr 27, 2019
2ea261d
Merge branch 'master' of github.com:kth-competitive-programming/kactl…
Chillee Apr 27, 2019
cdd8367
removed extraneous spaces
Chillee Apr 27, 2019
99f761d
Fixed wconversion warning
Chillee Apr 27, 2019
b984354
Changed from vi to vl
Chillee Apr 27, 2019
e67097d
Fixed some more wconversion errors from the switch to vl
Chillee Apr 27, 2019
dcb957b
Fixed the extraneous use of Cd
Chillee Apr 29, 2019
93b71fd
Fixed some issues and simplified input API
Chillee Apr 29, 2019
5526de4
Removed n from function parameters
Chillee Apr 29, 2019
f37262a
Fixed formatting issue
Chillee Apr 30, 2019
5136a02
Updated in response to comments
Chillee Apr 30, 2019
7ec3cfc
Higher precision FFT
simonlindholm May 7, 2019
7d00189
Shorter description
simonlindholm May 7, 2019
c104ed3
Merge branch 'master' into fftmod
simonlindholm May 7, 2019
1d527d1
Preliminary description updates
simonlindholm May 8, 2019
a36aee5
Fix FFTMod after root computation updates
simonlindholm May 8, 2019
675d08c
Shave off a few chars
simonlindholm May 8, 2019
6e161a9
Naive FFTMod fuzz-test
simonlindholm May 8, 2019
5bd55d9
Remove recommendation about CRT
simonlindholm May 8, 2019
94c31ef
Update theoretical bounds
simonlindholm May 15, 2019
b8d00da
Remove complex<long double> typedef
simonlindholm May 15, 2019
2a2957b
Update test
simonlindholm May 15, 2019
12c374d
Proof of FFT-MOD bound was buggy :(
simonlindholm May 15, 2019
3f09df5
Description updates
simonlindholm May 20, 2019
df8c82c
Comment about long double perf
simonlindholm Jun 22, 2019
3cfc69e
Update FFT fuzz-test
simonlindholm Jun 23, 2019
597b5fa
Don't subtly break FFT-MOD on mod > 2^32
simonlindholm Jun 23, 2019
4251223
Merge branch 'master' into fftmod
simonlindholm Jun 24, 2019
0fa3607
Fuzz-test consistency
simonlindholm Jun 24, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
40 changes: 22 additions & 18 deletions content/numerical/FastFourierTransform.h
Expand Up @@ -3,22 +3,32 @@
* Date: 2019-01-09
* License: CC0
* Source: http://neerc.ifmo.ru/trains/toulouse/2017/fft2.pdf (do read, it's excellent)
Papers about accuracy: http://www.daemonology.net/papers/fft.pdf, http://www.cs.berkeley.edu/~fateman/papers/fftvsothers.pdf
For integers rounding works if $(|a| + |b|)\max(a, b) < \mathtt{\sim} 10^9$, or in theory maybe $10^6$.
* Description: fft(a, ...) computes $\hat f(k) = \sum_x a[x] \exp(2\pi i \cdot k x / N)$ for all $k$. Useful for convolution:
Accuracy bound from http://www.daemonology.net/papers/fft.pdf
* Description: fft(a) computes $\hat f(k) = \sum_x a[x] \exp(2\pi i \cdot k x / N)$ for all $k$. Useful for convolution:
\texttt{conv(a, b) = c}, where $c[x] = \sum a[i]b[x-i]$.
For convolution of complex numbers or more than two vectors: FFT, multiply
pointwise, divide by n, reverse(start+1, end), FFT back.
For integers, consider using a number-theoretic transform instead, to avoid rounding issues.
* Time: O(N \log N) with $N = |A|+|B|-1$ ($\tilde 1s$ for $N=2^{22}$)
Rounding is safe if $(\sum a_i^2 + \sum b_i^2)\log_2{N} < 9\cdot10^{14}$
(in practice $10^{16}$; higher for random inputs).
Otherwise, use long doubles/NTT/FFTMod.
* Time: O(N \log N) with $N = |A|+|B|$ ($\tilde 1s$ for $N=2^{22}$)
* Status: somewhat tested
*/
#pragma once

typedef complex<double> C;
typedef vector<double> vd;

void fft(vector<C> &a, vector<C> &rt, vi& rev, int n) {
void fft(vector<C>& a) {
int n = sz(a), L = 31 - __builtin_clz(n);
static vector<complex<long double>> R(2, 1);
static vector<C> rt(2, 1); // (^ 10% faster if double)
for (static int k = 2; k < n; k *= 2) {
R.resize(n); rt.resize(n);
auto x = polar(1.0L, M_PIl / k); // M_PI, lower-case L
rep(i,k,2*k) rt[i] = R[i] = i&1 ? R[i/2] * x : R[i/2];
}
vi rev(n);
rep(i,0,n) rev[i] = (rev[i / 2] | (i & 1) << L) / 2;
rep(i,0,n) if (i < rev[i]) swap(a[i], a[rev[i]]);
for (int k = 1; k < n; k *= 2)
for (int i = 0; i < n; i += 2 * k) rep(j,0,k) {
Expand All @@ -27,25 +37,19 @@ void fft(vector<C> &a, vector<C> &rt, vi& rev, int n) {
C z(x[0]*y[0] - x[1]*y[1], x[0]*y[1] + x[1]*y[0]); /// exclude-line
a[i + j + k] = a[i + j] - z;
a[i + j] += z;
}
}
}

vd conv(const vd& a, const vd& b) {
if (a.empty() || b.empty()) return {};
vd res(sz(a) + sz(b) - 1);
int L = 32 - __builtin_clz(sz(res)), n = 1 << L;
vector<C> in(n), out(n), rt(n, 1); vi rev(n);
rep(i,0,n) rev[i] = (rev[i/2] | (i&1) << L) / 2;
for (int k = 2; k < n; k *= 2) {
C z[] = {1, polar(1.0, M_PI / k)};
rep(i,k,2*k) rt[i] = rt[i/2] * z[i&1];
}
vector<C> in(n), out(n);
copy(all(a), begin(in));
rep(i,0,sz(b)) in[i].imag(b[i]);
fft(in, rt, rev, n);
fft(in);
trav(x, in) x *= x;
rep(i,0,n) out[i] = in[-i & (n - 1)] - conj(in[i]);
fft(out, rt, rev, n);
rep(i,0,sz(res)) res[i] = imag(out[i]) / (4*n);
fft(out);
rep(i,0,sz(res)) res[i] = imag(out[i]) / (4 * n);
return res;
}
37 changes: 37 additions & 0 deletions content/numerical/FastFourierTransformMod.h
@@ -0,0 +1,37 @@
/**
* Author: chilli
* Date: 2019-04-25
* License: CC0
* Source: http://neerc.ifmo.ru/trains/toulouse/2017/fft2.pdf
* Description: Higher precision FFT, can be used for convolutions modulo arbitrary integers
* as long as $N\log_2N\cdot \text{mod} < 8.6 \cdot 10^{14}$ (in practice $10^{16}$ or higher).
* Inputs must be in $[0, \text{mod})$.
* Time: O(N \log N), where $N = |A|+|B|$ (twice as slow as NTT or FFT)
* Status: somewhat tested
*/
#pragma once

#include "FastFourierTransform.h"

typedef vector<ll> vl;
template<int M> vl convMod(const vl &a, const vl &b) {
if (a.empty() || b.empty()) return {};
vl res(sz(a) + sz(b) - 1);
int B=32-__builtin_clz(sz(res)), n=1<<B, cut=int(sqrt(M));
vector<C> L(n), R(n), outs(n), outl(n);
rep(i,0,sz(a)) L[i] = C((int)a[i] / cut, (int)a[i] % cut);
rep(i,0,sz(b)) R[i] = C((int)b[i] / cut, (int)b[i] % cut);
fft(L), fft(R);
rep(i,0,n) {
int j = -i & (n - 1);
outl[j] = (L[i] + conj(L[j])) * R[i] / (2.0 * n);
outs[j] = (L[i] - conj(L[j])) * R[i] / (2.0 * n) / 1i;
}
fft(outl), fft(outs);
rep(i,0,sz(res)) {
ll av = ll(real(outl[i])+.5), cv = ll(imag(outs[i])+.5);
ll bv = ll(imag(outl[i])+.5) + ll(real(outs[i])+.5);
res[i] = ((av % M * cut + bv) % M * cut + cv) % M;
}
return res;
}
1 change: 0 additions & 1 deletion content/numerical/NumberTheoreticTransform.h
Expand Up @@ -5,7 +5,6 @@
* Source: based on KACTL's FFT
* Description: Can be used for convolutions modulo specific nice primes
* of the form $2^a b+1$, where the convolution result has size at most $2^a$.
* For other primes/integers, use three different primes and combine with CRT.
* Inputs must be in [0, mod).
* Time: O(N \log N)
* Status: fuzz-tested
Expand Down
1 change: 1 addition & 0 deletions content/numerical/chapter.tex
Expand Up @@ -19,5 +19,6 @@ \chapter{Numerical}
\kactlimport{Tridiagonal.h}
\section{Fourier transforms}
\kactlimport{FastFourierTransform.h}
\kactlimport{FastFourierTransformMod.h}
\kactlimport{NumberTheoreticTransform.h}
\kactlimport{FastSubsetTransform.h}
57 changes: 20 additions & 37 deletions fuzz-tests/numerical/FastFourierTransform.cpp
Expand Up @@ -11,50 +11,33 @@ typedef long long ll;
typedef pair<int, int> pii;
typedef vector<int> vi;

typedef valarray<complex<double> > carray;
void fft(carray& x, carray& roots) {
int N = sz(x);
if (N <= 1) return;
carray even = x[slice(0, N/2, 2)];
carray odd = x[slice(1, N/2, 2)];
carray rs = roots[slice(0, N/2, 2)];
fft(even, rs);
fft(odd, rs);
rep(k,0,N/2) {
auto t = roots[k] * odd[k];
x[k ] = even[k] + t;
x[k+N/2] = even[k] - t;
}
}

typedef vector<double> vd;
vd conv(const vd& a, const vd& b) {
int s = sz(a) + sz(b) - 1, L = 32-__builtin_clz(s), n = 1<<L;
if (s <= 0) return {};
carray av(n), bv(n), roots(n);
rep(i,0,n) roots[i] = polar(1.0, -2 * M_PI * i / n);
copy(all(a), begin(av)); fft(av, roots);
copy(all(b), begin(bv)); fft(bv, roots);
roots = roots.apply(conj);
carray cv = av * bv; fft(cv, roots);
vd c(s); rep(i,0,s) c[i] = cv[i].real() / n;
return c;
}
#include "../../content/numerical/FastFourierTransform.h"

const double eps = 1e-8;
int main() {
int n = 8;
carray a(n), av(n), roots(n);
rep(i,0,n) a[i] = rand() % 10 - 5;
rep(i,0,n) roots[i] = polar(1.0, -2 * M_PI * i / n);
av = a;
fft(av, roots);
vector<C> a(n);
rep(i,0,n) a[i] = C(rand() % 10 - 5, rand() % 10 - 5);
auto aorig = a;
fft(a);
rep(k,0,n) {
complex<double> sum{};
C sum{};
rep(x,0,n) {
sum += a[x] * polar(1.0, -2 * M_PI * k * x / n);
sum += aorig[x] * polar(1.0, 2 * M_PI * k * x / n);
}
assert(norm(sum - a[k]) < 1e-6);
}

vd A(4), B(6);
trav(x, A) x = rand() / (RAND_MAX + 1.0) * 10 - 5;
trav(x, B) x = rand() / (RAND_MAX + 1.0) * 10 - 5;
vd C = conv(A, B);
rep(i,0,sz(A) + sz(B) - 1) {
double sum = 0;
rep(j,0,sz(A)) if (i - j >= 0 && i - j < sz(B)) {
sum += A[j] * B[i - j];
}
assert(abs(sum-av[k]) < eps);
assert(abs(sum - C[i]) < eps);
}
cout<<"Tests passed!"<<endl;
}
47 changes: 47 additions & 0 deletions fuzz-tests/numerical/FastFourierTransformMod.cpp
@@ -0,0 +1,47 @@
#include <bits/stdc++.h>
using namespace std;

#define rep(i, a, b) for(int i = a; i < int(b); ++i)
#define trav(a, v) for(auto& a : v)
#define all(x) x.begin(), x.end()
#define sz(x) (int)(x).size()

typedef long long ll;
typedef pair<int, int> pii;
typedef vector<int> vi;

const ll mod = 1000000007;

#include "../../content/numerical/FastFourierTransformMod.h"

vl simpleConv(vl a, vl b) {
if (a.empty() || b.empty()) return {};
int s = sz(a) + sz(b) - 1;
vl c(s);
rep(i,0,sz(a)) rep(j,0,sz(b))
c[i+j] = (c[i+j] + (ll)a[i] * b[j]) % mod;
trav(x, c) if (x < 0) x += mod;
return c;
}

int ra() {
static unsigned X;
X *= 123671231;
X += 1238713;
X ^= 1237618;
return (X >> 1);
}

int main() {
vl a, b;
rep(it,0,6000) {
a.resize(ra() % 100);
b.resize(ra() % 100);
trav(x, a) x = ra() % mod;
trav(x, b) x = ra() % mod;
auto v1 = simpleConv(a, b);
auto v2 = convMod<mod>(a, b);
assert(v1 == v2);
}
cout<<"Tests passed!"<<endl;
}