diff --git a/formal_power_series/factorial_power.hpp b/formal_power_series/factorial_power.hpp index e997d757..0b9795d5 100644 --- a/formal_power_series/factorial_power.hpp +++ b/formal_power_series/factorial_power.hpp @@ -3,17 +3,16 @@ #include #include -// CUT begin // Convert factorial power -> sampling // [y[0], y[1], ..., y[N - 1]] -> \sum_i a_i x^\underline{i} // Complexity: O(N log N) template std::vector factorial_to_ys(const std::vector &as) { const int N = as.size(); std::vector exp(N, 1); - for (int i = 1; i < N; i++) exp[i] = T(i).facinv(); + for (int i = 1; i < N; i++) exp[i] = T::facinv(i); auto ys = nttconv(as, exp); ys.resize(N); - for (int i = 0; i < N; i++) ys[i] *= T(i).fac(); + for (int i = 0; i < N; i++) ys[i] *= T::fac(i); return ys; } @@ -22,9 +21,9 @@ template std::vector factorial_to_ys(const std::vector &as) { // Complexity: O(N log N) template std::vector ys_to_factorial(std::vector ys) { const int N = ys.size(); - for (int i = 1; i < N; i++) ys[i] *= T(i).facinv(); + for (int i = 1; i < N; i++) ys[i] *= T::facinv(i); std::vector expinv(N, 1); - for (int i = 1; i < N; i++) expinv[i] = T(i).facinv() * (i % 2 ? -1 : 1); + for (int i = 1; i < N; i++) expinv[i] = T::facinv(i) * (i % 2 ? -1 : 1); auto as = nttconv(ys, expinv); as.resize(N); return as; @@ -36,12 +35,12 @@ template std::vector shift_of_factorial(const std::vector &as, T const int N = as.size(); std::vector b(N, 1), c(N, 1); for (int i = 1; i < N; i++) b[i] = b[i - 1] * (shift - i + 1) * T(i).inv(); - for (int i = 0; i < N; i++) c[i] = as[i] * T(i).fac(); + for (int i = 0; i < N; i++) c[i] = as[i] * T::fac(i); std::reverse(c.begin(), c.end()); auto ret = nttconv(b, c); ret.resize(N); std::reverse(ret.begin(), ret.end()); - for (int i = 0; i < N; i++) ret[i] *= T(i).facinv(); + for (int i = 0; i < N; i++) ret[i] *= T::facinv(i); return ret; } diff --git a/formal_power_series/formal_power_series.hpp b/formal_power_series/formal_power_series.hpp index 724f48d6..a4990946 100644 --- a/formal_power_series/formal_power_series.hpp +++ b/formal_power_series/formal_power_series.hpp @@ -220,14 +220,14 @@ template struct FormalPowerSeries : std::vector { P shift(T c) const { const int n = (int)this->size(); P ret = *this; - for (int i = 0; i < n; i++) ret[i] *= T(i).fac(); + for (int i = 0; i < n; i++) ret[i] *= T::fac(i); std::reverse(ret.begin(), ret.end()); P exp_cx(n, 1); for (int i = 1; i < n; i++) exp_cx[i] = exp_cx[i - 1] * c * T(i).inv(); ret = ret * exp_cx; ret.resize(n); std::reverse(ret.begin(), ret.end()); - for (int i = 0; i < n; i++) ret[i] *= T(i).facinv(); + for (int i = 0; i < n; i++) ret[i] *= T::facinv(i); return ret; } diff --git a/formal_power_series/lagrange_interpolation.hpp b/formal_power_series/lagrange_interpolation.hpp index 12e11016..94f3e12b 100644 --- a/formal_power_series/lagrange_interpolation.hpp +++ b/formal_power_series/lagrange_interpolation.hpp @@ -11,7 +11,7 @@ template MODINT interpolate_iota(const std::vector ys, const int N = ys.size(); if (x_eval.val() < N) return ys[x_eval.val()]; std::vector facinv(N); - facinv[N - 1] = MODINT(N - 1).fac().inv(); + facinv[N - 1] = MODINT::facinv(N - 1); for (int i = N - 1; i > 0; i--) facinv[i - 1] = facinv[i] * i; std::vector numleft(N); MODINT numtmp = 1; diff --git a/formal_power_series/sum_of_exponential_times_polynomial_limit.hpp b/formal_power_series/sum_of_exponential_times_polynomial_limit.hpp index 5accfb09..70f7ebc0 100644 --- a/formal_power_series/sum_of_exponential_times_polynomial_limit.hpp +++ b/formal_power_series/sum_of_exponential_times_polynomial_limit.hpp @@ -20,7 +20,7 @@ MODINT sum_of_exponential_times_polynomial_limit(MODINT r, std::vector i MODINT ret = 0; rp = 1; for (int i = 0; i <= d; i++) { - ret += bs[d - i] * MODINT(d + 1).nCr(i) * rp; + ret += bs[d - i] * MODINT::binom(d + 1, i) * rp; rp *= -r; } return ret / MODINT(1 - r).pow(d + 1); diff --git a/formal_power_series/test/bernoulli_number.test.cpp b/formal_power_series/test/bernoulli_number.test.cpp index 7785169e..f660278a 100644 --- a/formal_power_series/test/bernoulli_number.test.cpp +++ b/formal_power_series/test/bernoulli_number.test.cpp @@ -10,5 +10,5 @@ int main() { using mint = ModInt<998244353>; FormalPowerSeries x({0, 1}); FormalPowerSeries b = ((x.exp(N + 2) - 1) >> 1).inv(N + 1); - for (int i = 0; i <= N; i++) printf("%d ", (b.coeff(i) * mint(i).fac()).val()); + for (int i = 0; i <= N; i++) printf("%d ", (b.coeff(i) * mint::fac(i)).val()); } diff --git a/formal_power_series/test/stirling_number_of_2nd.test.cpp b/formal_power_series/test/stirling_number_of_2nd.test.cpp index bd25a107..1e71cf1f 100644 --- a/formal_power_series/test/stirling_number_of_2nd.test.cpp +++ b/formal_power_series/test/stirling_number_of_2nd.test.cpp @@ -11,7 +11,7 @@ int main() { cin >> N; using mint = ModInt<998244353>; FormalPowerSeries a(N + 1); - a[N] = mint(N).fac().inv(); + a[N] = mint::facinv(N); for (int i = N - 1; i >= 0; i--) { a[i] = a[i + 1] * (i + 1); } auto b = a; for (int i = 0; i <= N; i++) { a[i] *= mint(i).pow(N), b[i] *= (i % 2 ? -1 : 1); } diff --git a/modint.hpp b/modint.hpp index f96deb6b..c8cad0d3 100644 --- a/modint.hpp +++ b/modint.hpp @@ -5,6 +5,7 @@ #include template struct ModInt { + static_assert(md > 1); using lint = long long; constexpr static int mod() { return md; } static int get_primitive_root() { @@ -102,39 +103,50 @@ template struct ModInt { return this->pow(md - 2); } } - constexpr ModInt fac() const { - while (this->val_ >= int(facs.size())) _precalculation(facs.size() * 2); - return facs[this->val_]; + + constexpr static ModInt fac(int n) { + assert(n >= 0); + if (n >= md) return ModInt(0); + while (n >= int(facs.size())) _precalculation(facs.size() * 2); + return facs[n]; } - constexpr ModInt facinv() const { - while (this->val_ >= int(facs.size())) _precalculation(facs.size() * 2); - return facinvs[this->val_]; + + constexpr static ModInt facinv(int n) { + assert(n >= 0); + if (n >= md) return ModInt(0); + while (n >= int(facs.size())) _precalculation(facs.size() * 2); + return facinvs[n]; } - constexpr ModInt doublefac() const { - lint k = (this->val_ + 1) / 2; - return (this->val_ & 1) ? ModInt(k * 2).fac() / (ModInt(2).pow(k) * ModInt(k).fac()) - : ModInt(k).fac() * ModInt(2).pow(k); + + constexpr static ModInt doublefac(int n) { + assert(n >= 0); + if (n >= md) return ModInt(0); + long long k = (n + 1) / 2; + return (n & 1) ? ModInt::fac(k * 2) / (ModInt(2).pow(k) * ModInt::fac(k)) + : ModInt::fac(k) * ModInt(2).pow(k); } - constexpr ModInt nCr(int r) const { - if (r < 0 or this->val_ < r) return ModInt(0); - return this->fac() * (*this - r).facinv() * ModInt(r).facinv(); + constexpr static ModInt nCr(int n, int r) { + assert(n >= 0); + if (r < 0 or n < r) return ModInt(0); + return ModInt::fac(n) * ModInt::facinv(r) * ModInt::facinv(n - r); } - constexpr ModInt nPr(int r) const { - if (r < 0 or this->val_ < r) return ModInt(0); - return this->fac() * (*this - r).facinv(); + constexpr static ModInt nPr(int n, int r) { + assert(n >= 0); + if (r < 0 or n < r) return ModInt(0); + return ModInt::fac(n) * ModInt::facinv(n - r); } static ModInt binom(int n, int r) { static long long bruteforce_times = 0; if (r < 0 or n < r) return ModInt(0); - if (n <= bruteforce_times or n < (int)facs.size()) return ModInt(n).nCr(r); + if (n <= bruteforce_times or n < (int)facs.size()) return ModInt::nCr(n, r); r = std::min(r, n - r); - ModInt ret = ModInt(r).facinv(); + ModInt ret = ModInt::facinv(r); for (int i = 0; i < r; ++i) ret *= n - i; bruteforce_times += r; @@ -148,18 +160,23 @@ template struct ModInt { int sum = 0; for (int k : ks) { assert(k >= 0); - ret *= ModInt(k).facinv(), sum += k; + ret *= ModInt::facinv(k), sum += k; } - return ret * ModInt(sum).fac(); + return ret * ModInt::fac(sum); + } + template static ModInt multinomial(Args... args) { + int sum = (0 + ... + args); + ModInt result = (1 * ... * ModInt::facinv(args)); + return ModInt::fac(sum) * result; } - // Catalan number, C_n = binom(2n, n) / (n + 1) + // Catalan number, C_n = binom(2n, n) / (n + 1) = # of Dyck words of length 2n // C_0 = 1, C_1 = 1, C_2 = 2, C_3 = 5, C_4 = 14, ... // https://oeis.org/A000108 // Complexity: O(n) static ModInt catalan(int n) { if (n < 0) return ModInt(0); - return ModInt(n * 2).fac() * ModInt(n + 1).facinv() * ModInt(n).facinv(); + return ModInt::fac(n * 2) * ModInt::facinv(n + 1) * ModInt::facinv(n); } ModInt sqrt() const { diff --git a/number/modint_runtime.hpp b/number/modint_runtime.hpp index 66afc82c..937d4679 100644 --- a/number/modint_runtime.hpp +++ b/number/modint_runtime.hpp @@ -1,4 +1,5 @@ #pragma once +#include #include #include #include @@ -105,26 +106,38 @@ struct ModIntRuntime { ModIntRuntime pow(lint n) const { return power(n); } ModIntRuntime inv() const { return this->pow(md - 2); } - ModIntRuntime fac() const { + static ModIntRuntime fac(int n) { + assert(n >= 0); + if (n >= md) return ModIntRuntime(0); int l0 = facs().size(); - if (l0 > this->val_) return facs()[this->val_]; - - facs().resize(this->val_ + 1); - for (int i = l0; i <= this->val_; i++) + if (l0 > n) return facs()[n]; + facs().resize(n + 1); + for (int i = l0; i <= n; i++) facs()[i] = (i == 0 ? ModIntRuntime(1) : facs()[i - 1] * ModIntRuntime(i)); - return facs()[this->val_]; + return facs()[n]; + } + + static ModIntRuntime facinv(int n) { return ModIntRuntime::fac(n).inv(); } + + static ModIntRuntime doublefac(int n) { + assert(n >= 0); + if (n >= md) return ModIntRuntime(0); + long long k = (n + 1) / 2; + return (n & 1) + ? ModIntRuntime::fac(k * 2) / (ModIntRuntime(2).pow(k) * ModIntRuntime::fac(k)) + : ModIntRuntime::fac(k) * ModIntRuntime(2).pow(k); } - ModIntRuntime doublefac() const { - lint k = (this->val_ + 1) / 2; - return (this->val_ & 1) - ? ModIntRuntime(k * 2).fac() / (ModIntRuntime(2).pow(k) * ModIntRuntime(k).fac()) - : ModIntRuntime(k).fac() * ModIntRuntime(2).pow(k); + static ModIntRuntime nCr(int n, int r) { + assert(n >= 0); + if (r < 0 or n < r) return ModIntRuntime(0); + return ModIntRuntime::fac(n) / (ModIntRuntime::fac(r) * ModIntRuntime::fac(n - r)); } - ModIntRuntime nCr(int r) const { - if (r < 0 or this->val_ < r) return ModIntRuntime(0); - return this->fac() / ((*this - r).fac() * ModIntRuntime(r).fac()); + static ModIntRuntime nPr(int n, int r) { + assert(n >= 0); + if (r < 0 or n < r) return ModIntRuntime(0); + return ModIntRuntime::fac(n) / ModIntRuntime::fac(n - r); } ModIntRuntime sqrt() const {