From fa5ad742a96e7b501cd1329b2778c944968b98fb Mon Sep 17 00:00:00 2001 From: hitonanode <32937551+hitonanode@users.noreply.github.com> Date: Mon, 25 Aug 2025 00:44:05 +0900 Subject: [PATCH 1/3] modint: introduce static methods --- formal_power_series/factorial_power.hpp | 13 ++-- formal_power_series/formal_power_series.hpp | 4 +- .../lagrange_interpolation.hpp | 2 +- ..._of_exponential_times_polynomial_limit.hpp | 2 +- .../test/bernoulli_number.test.cpp | 2 +- .../test/stirling_number_of_2nd.test.cpp | 2 +- modint.hpp | 73 +++++++++++++------ number/modint_runtime.hpp | 50 +++++++++---- 8 files changed, 100 insertions(+), 48 deletions(-) diff --git a/formal_power_series/factorial_power.hpp b/formal_power_series/factorial_power.hpp index e997d757..0b9795d5 100644 --- a/formal_power_series/factorial_power.hpp +++ b/formal_power_series/factorial_power.hpp @@ -3,17 +3,16 @@ #include #include -// CUT begin // Convert factorial power -> sampling // [y[0], y[1], ..., y[N - 1]] -> \sum_i a_i x^\underline{i} // Complexity: O(N log N) template std::vector factorial_to_ys(const std::vector &as) { const int N = as.size(); std::vector exp(N, 1); - for (int i = 1; i < N; i++) exp[i] = T(i).facinv(); + for (int i = 1; i < N; i++) exp[i] = T::facinv(i); auto ys = nttconv(as, exp); ys.resize(N); - for (int i = 0; i < N; i++) ys[i] *= T(i).fac(); + for (int i = 0; i < N; i++) ys[i] *= T::fac(i); return ys; } @@ -22,9 +21,9 @@ template std::vector factorial_to_ys(const std::vector &as) { // Complexity: O(N log N) template std::vector ys_to_factorial(std::vector ys) { const int N = ys.size(); - for (int i = 1; i < N; i++) ys[i] *= T(i).facinv(); + for (int i = 1; i < N; i++) ys[i] *= T::facinv(i); std::vector expinv(N, 1); - for (int i = 1; i < N; i++) expinv[i] = T(i).facinv() * (i % 2 ? -1 : 1); + for (int i = 1; i < N; i++) expinv[i] = T::facinv(i) * (i % 2 ? -1 : 1); auto as = nttconv(ys, expinv); as.resize(N); return as; @@ -36,12 +35,12 @@ template std::vector shift_of_factorial(const std::vector &as, T const int N = as.size(); std::vector b(N, 1), c(N, 1); for (int i = 1; i < N; i++) b[i] = b[i - 1] * (shift - i + 1) * T(i).inv(); - for (int i = 0; i < N; i++) c[i] = as[i] * T(i).fac(); + for (int i = 0; i < N; i++) c[i] = as[i] * T::fac(i); std::reverse(c.begin(), c.end()); auto ret = nttconv(b, c); ret.resize(N); std::reverse(ret.begin(), ret.end()); - for (int i = 0; i < N; i++) ret[i] *= T(i).facinv(); + for (int i = 0; i < N; i++) ret[i] *= T::facinv(i); return ret; } diff --git a/formal_power_series/formal_power_series.hpp b/formal_power_series/formal_power_series.hpp index 724f48d6..a4990946 100644 --- a/formal_power_series/formal_power_series.hpp +++ b/formal_power_series/formal_power_series.hpp @@ -220,14 +220,14 @@ template struct FormalPowerSeries : std::vector { P shift(T c) const { const int n = (int)this->size(); P ret = *this; - for (int i = 0; i < n; i++) ret[i] *= T(i).fac(); + for (int i = 0; i < n; i++) ret[i] *= T::fac(i); std::reverse(ret.begin(), ret.end()); P exp_cx(n, 1); for (int i = 1; i < n; i++) exp_cx[i] = exp_cx[i - 1] * c * T(i).inv(); ret = ret * exp_cx; ret.resize(n); std::reverse(ret.begin(), ret.end()); - for (int i = 0; i < n; i++) ret[i] *= T(i).facinv(); + for (int i = 0; i < n; i++) ret[i] *= T::facinv(i); return ret; } diff --git a/formal_power_series/lagrange_interpolation.hpp b/formal_power_series/lagrange_interpolation.hpp index 12e11016..94f3e12b 100644 --- a/formal_power_series/lagrange_interpolation.hpp +++ b/formal_power_series/lagrange_interpolation.hpp @@ -11,7 +11,7 @@ template MODINT interpolate_iota(const std::vector ys, const int N = ys.size(); if (x_eval.val() < N) return ys[x_eval.val()]; std::vector facinv(N); - facinv[N - 1] = MODINT(N - 1).fac().inv(); + facinv[N - 1] = MODINT::facinv(N - 1); for (int i = N - 1; i > 0; i--) facinv[i - 1] = facinv[i] * i; std::vector numleft(N); MODINT numtmp = 1; diff --git a/formal_power_series/sum_of_exponential_times_polynomial_limit.hpp b/formal_power_series/sum_of_exponential_times_polynomial_limit.hpp index 5accfb09..70f7ebc0 100644 --- a/formal_power_series/sum_of_exponential_times_polynomial_limit.hpp +++ b/formal_power_series/sum_of_exponential_times_polynomial_limit.hpp @@ -20,7 +20,7 @@ MODINT sum_of_exponential_times_polynomial_limit(MODINT r, std::vector i MODINT ret = 0; rp = 1; for (int i = 0; i <= d; i++) { - ret += bs[d - i] * MODINT(d + 1).nCr(i) * rp; + ret += bs[d - i] * MODINT::binom(d + 1, i) * rp; rp *= -r; } return ret / MODINT(1 - r).pow(d + 1); diff --git a/formal_power_series/test/bernoulli_number.test.cpp b/formal_power_series/test/bernoulli_number.test.cpp index 7785169e..f660278a 100644 --- a/formal_power_series/test/bernoulli_number.test.cpp +++ b/formal_power_series/test/bernoulli_number.test.cpp @@ -10,5 +10,5 @@ int main() { using mint = ModInt<998244353>; FormalPowerSeries x({0, 1}); FormalPowerSeries b = ((x.exp(N + 2) - 1) >> 1).inv(N + 1); - for (int i = 0; i <= N; i++) printf("%d ", (b.coeff(i) * mint(i).fac()).val()); + for (int i = 0; i <= N; i++) printf("%d ", (b.coeff(i) * mint::fac(i)).val()); } diff --git a/formal_power_series/test/stirling_number_of_2nd.test.cpp b/formal_power_series/test/stirling_number_of_2nd.test.cpp index bd25a107..1e71cf1f 100644 --- a/formal_power_series/test/stirling_number_of_2nd.test.cpp +++ b/formal_power_series/test/stirling_number_of_2nd.test.cpp @@ -11,7 +11,7 @@ int main() { cin >> N; using mint = ModInt<998244353>; FormalPowerSeries a(N + 1); - a[N] = mint(N).fac().inv(); + a[N] = mint::facinv(N); for (int i = N - 1; i >= 0; i--) { a[i] = a[i + 1] * (i + 1); } auto b = a; for (int i = 0; i <= N; i++) { a[i] *= mint(i).pow(N), b[i] *= (i % 2 ? -1 : 1); } diff --git a/modint.hpp b/modint.hpp index f96deb6b..fe98690c 100644 --- a/modint.hpp +++ b/modint.hpp @@ -5,6 +5,7 @@ #include template struct ModInt { + static_assert(md > 1); using lint = long long; constexpr static int mod() { return md; } static int get_primitive_root() { @@ -102,39 +103,64 @@ template struct ModInt { return this->pow(md - 2); } } - constexpr ModInt fac() const { - while (this->val_ >= int(facs.size())) _precalculation(facs.size() * 2); - return facs[this->val_]; + + constexpr static ModInt fac(int n) { + assert(n >= 0); + if (n >= md) return ModInt(0); + while (n >= int(facs.size())) _precalculation(facs.size() * 2); + return facs[n]; } - constexpr ModInt facinv() const { - while (this->val_ >= int(facs.size())) _precalculation(facs.size() * 2); - return facinvs[this->val_]; + [[deprecated("use static method")]] constexpr ModInt fac() { return ModInt::fac(this->val_); } + + constexpr static ModInt facinv(int n) { + assert(n >= 0); + if (n >= md) return ModInt(0); + while (n >= int(facs.size())) _precalculation(facs.size() * 2); + return facinvs[n]; } - constexpr ModInt doublefac() const { - lint k = (this->val_ + 1) / 2; - return (this->val_ & 1) ? ModInt(k * 2).fac() / (ModInt(2).pow(k) * ModInt(k).fac()) - : ModInt(k).fac() * ModInt(2).pow(k); + [[deprecated("use static method")]] constexpr ModInt facinv() { + return ModInt::facinv(this->val_); } - constexpr ModInt nCr(int r) const { - if (r < 0 or this->val_ < r) return ModInt(0); - return this->fac() * (*this - r).facinv() * ModInt(r).facinv(); + constexpr static ModInt doublefac(int n) { + assert(n >= 0); + if (n >= md) return ModInt(0); + long long k = (n + 1) / 2; + return (n & 1) ? ModInt::fac(k * 2) / (ModInt(2).pow(k) * ModInt::fac(k)) + : ModInt::fac(k) * ModInt(2).pow(k); + } + [[deprecated("use static method")]] constexpr ModInt doublefac() { + return ModInt::doublefac(this->val_); } - constexpr ModInt nPr(int r) const { + constexpr static ModInt nCr(int n, int r) { + assert(n >= 0); + if (r < 0 or n < r) return ModInt(0); + return ModInt::fac(n) * ModInt::facinv(r) * ModInt::facinv(n - r); + } + [[deprecated("use static method")]] constexpr ModInt nCr(int r) { + return ModInt::nCr(this->val_, r); + } + + constexpr static ModInt nPr(int n, int r) { + assert(n >= 0); + if (r < 0 or n < r) return ModInt(0); + return ModInt::fac(n) * ModInt::facinv(n - r); + } + [[deprecated("use static method")]] constexpr ModInt nPr(int r) { if (r < 0 or this->val_ < r) return ModInt(0); - return this->fac() * (*this - r).facinv(); + return ModInt::nPr(this->val_, r); } static ModInt binom(int n, int r) { static long long bruteforce_times = 0; if (r < 0 or n < r) return ModInt(0); - if (n <= bruteforce_times or n < (int)facs.size()) return ModInt(n).nCr(r); + if (n <= bruteforce_times or n < (int)facs.size()) return ModInt::nCr(n, r); r = std::min(r, n - r); - ModInt ret = ModInt(r).facinv(); + ModInt ret = ModInt::facinv(r); for (int i = 0; i < r; ++i) ret *= n - i; bruteforce_times += r; @@ -148,18 +174,23 @@ template struct ModInt { int sum = 0; for (int k : ks) { assert(k >= 0); - ret *= ModInt(k).facinv(), sum += k; + ret *= ModInt::facinv(k), sum += k; } - return ret * ModInt(sum).fac(); + return ret * ModInt::fac(sum); + } + template static ModInt multinomial(Args... args) { + int sum = (0 + ... + args); + ModInt result = (1 * ... * ModInt::facinv(args)); + return ModInt::fac(sum) * result; } - // Catalan number, C_n = binom(2n, n) / (n + 1) + // Catalan number, C_n = binom(2n, n) / (n + 1) = # of Dyck words of length 2n // C_0 = 1, C_1 = 1, C_2 = 2, C_3 = 5, C_4 = 14, ... // https://oeis.org/A000108 // Complexity: O(n) static ModInt catalan(int n) { if (n < 0) return ModInt(0); - return ModInt(n * 2).fac() * ModInt(n + 1).facinv() * ModInt(n).facinv(); + return ModInt::fac(n * 2) * ModInt::facinv(n + 1) * ModInt::facinv(n); } ModInt sqrt() const { diff --git a/number/modint_runtime.hpp b/number/modint_runtime.hpp index 66afc82c..9d979d26 100644 --- a/number/modint_runtime.hpp +++ b/number/modint_runtime.hpp @@ -1,4 +1,5 @@ #pragma once +#include #include #include #include @@ -105,26 +106,47 @@ struct ModIntRuntime { ModIntRuntime pow(lint n) const { return power(n); } ModIntRuntime inv() const { return this->pow(md - 2); } - ModIntRuntime fac() const { + static ModIntRuntime fac(int n) { + assert(n >= 0); + if (n >= md) return ModIntRuntime(0); int l0 = facs().size(); - if (l0 > this->val_) return facs()[this->val_]; - - facs().resize(this->val_ + 1); - for (int i = l0; i <= this->val_; i++) + if (l0 > n) return facs()[n]; + facs().resize(n + 1); + for (int i = l0; i <= n; i++) facs()[i] = (i == 0 ? ModIntRuntime(1) : facs()[i - 1] * ModIntRuntime(i)); - return facs()[this->val_]; + return facs()[n]; + } + + [[deprecated("use static method")]] ModIntRuntime fac() const { + return ModIntRuntime::fac(this->val_); + } + + static ModIntRuntime doublefac(int n) { + assert(n >= 0); + if (n >= md) return ModIntRuntime(0); + long long k = (n + 1) / 2; + return (n & 1) + ? ModIntRuntime::fac(k * 2) / (ModIntRuntime(2).pow(k) * ModIntRuntime::fac(k)) + : ModIntRuntime::fac(k) * ModIntRuntime(2).pow(k); + } + + [[deprecated("use static method")]] constexpr ModIntRuntime doublefac() { + return ModIntRuntime::doublefac(this->val_); } - ModIntRuntime doublefac() const { - lint k = (this->val_ + 1) / 2; - return (this->val_ & 1) - ? ModIntRuntime(k * 2).fac() / (ModIntRuntime(2).pow(k) * ModIntRuntime(k).fac()) - : ModIntRuntime(k).fac() * ModIntRuntime(2).pow(k); + static ModIntRuntime nCr(int n, int r) { + assert(n >= 0); + if (r < 0 or n < r) return ModIntRuntime(0); + return ModIntRuntime::fac(n) / (ModIntRuntime::fac(r) * ModIntRuntime::fac(n - r)); + } + [[deprecated("use static method")]] constexpr ModIntRuntime nCr(int r) { + return ModIntRuntime::nCr(this->val_, r); } - ModIntRuntime nCr(int r) const { - if (r < 0 or this->val_ < r) return ModIntRuntime(0); - return this->fac() / ((*this - r).fac() * ModIntRuntime(r).fac()); + static ModIntRuntime nPr(int n, int r) { + assert(n >= 0); + if (r < 0 or n < r) return ModIntRuntime(0); + return ModIntRuntime::fac(n) / ModIntRuntime::fac(n - r); } ModIntRuntime sqrt() const { From 6118b05382f5dc20dea8b278a247d4c7fe45067a Mon Sep 17 00:00:00 2001 From: hitonanode <32937551+hitonanode@users.noreply.github.com> Date: Mon, 25 Aug 2025 00:44:48 +0900 Subject: [PATCH 2/3] remove deprecated methods --- modint.hpp | 14 -------------- number/modint_runtime.hpp | 11 ----------- 2 files changed, 25 deletions(-) diff --git a/modint.hpp b/modint.hpp index fe98690c..c8cad0d3 100644 --- a/modint.hpp +++ b/modint.hpp @@ -110,7 +110,6 @@ template struct ModInt { while (n >= int(facs.size())) _precalculation(facs.size() * 2); return facs[n]; } - [[deprecated("use static method")]] constexpr ModInt fac() { return ModInt::fac(this->val_); } constexpr static ModInt facinv(int n) { assert(n >= 0); @@ -118,9 +117,6 @@ template struct ModInt { while (n >= int(facs.size())) _precalculation(facs.size() * 2); return facinvs[n]; } - [[deprecated("use static method")]] constexpr ModInt facinv() { - return ModInt::facinv(this->val_); - } constexpr static ModInt doublefac(int n) { assert(n >= 0); @@ -129,28 +125,18 @@ template struct ModInt { return (n & 1) ? ModInt::fac(k * 2) / (ModInt(2).pow(k) * ModInt::fac(k)) : ModInt::fac(k) * ModInt(2).pow(k); } - [[deprecated("use static method")]] constexpr ModInt doublefac() { - return ModInt::doublefac(this->val_); - } constexpr static ModInt nCr(int n, int r) { assert(n >= 0); if (r < 0 or n < r) return ModInt(0); return ModInt::fac(n) * ModInt::facinv(r) * ModInt::facinv(n - r); } - [[deprecated("use static method")]] constexpr ModInt nCr(int r) { - return ModInt::nCr(this->val_, r); - } constexpr static ModInt nPr(int n, int r) { assert(n >= 0); if (r < 0 or n < r) return ModInt(0); return ModInt::fac(n) * ModInt::facinv(n - r); } - [[deprecated("use static method")]] constexpr ModInt nPr(int r) { - if (r < 0 or this->val_ < r) return ModInt(0); - return ModInt::nPr(this->val_, r); - } static ModInt binom(int n, int r) { static long long bruteforce_times = 0; diff --git a/number/modint_runtime.hpp b/number/modint_runtime.hpp index 9d979d26..4e6b5757 100644 --- a/number/modint_runtime.hpp +++ b/number/modint_runtime.hpp @@ -117,10 +117,6 @@ struct ModIntRuntime { return facs()[n]; } - [[deprecated("use static method")]] ModIntRuntime fac() const { - return ModIntRuntime::fac(this->val_); - } - static ModIntRuntime doublefac(int n) { assert(n >= 0); if (n >= md) return ModIntRuntime(0); @@ -130,18 +126,11 @@ struct ModIntRuntime { : ModIntRuntime::fac(k) * ModIntRuntime(2).pow(k); } - [[deprecated("use static method")]] constexpr ModIntRuntime doublefac() { - return ModIntRuntime::doublefac(this->val_); - } - static ModIntRuntime nCr(int n, int r) { assert(n >= 0); if (r < 0 or n < r) return ModIntRuntime(0); return ModIntRuntime::fac(n) / (ModIntRuntime::fac(r) * ModIntRuntime::fac(n - r)); } - [[deprecated("use static method")]] constexpr ModIntRuntime nCr(int r) { - return ModIntRuntime::nCr(this->val_, r); - } static ModIntRuntime nPr(int n, int r) { assert(n >= 0); From de121d62c69c51224749a54347235fdfe6bed73b Mon Sep 17 00:00:00 2001 From: hitonanode <32937551+hitonanode@users.noreply.github.com> Date: Mon, 25 Aug 2025 00:47:28 +0900 Subject: [PATCH 3/3] Add ModIntRuntime::facinv() --- number/modint_runtime.hpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/number/modint_runtime.hpp b/number/modint_runtime.hpp index 4e6b5757..937d4679 100644 --- a/number/modint_runtime.hpp +++ b/number/modint_runtime.hpp @@ -117,6 +117,8 @@ struct ModIntRuntime { return facs()[n]; } + static ModIntRuntime facinv(int n) { return ModIntRuntime::fac(n).inv(); } + static ModIntRuntime doublefac(int n) { assert(n >= 0); if (n >= md) return ModIntRuntime(0);