Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions utilities/multidim_index.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
#include <cassert>
#include <vector>

// n-dimentional index <-> 1-dimentional index converter
// [a_0, ..., a_{dim - 1}] <-> a_0 + a_1 * size_0 + ... + a_{dim - 1} * (size_0 * ... * size_{dim - 2})
template <class T = int> struct multidim_index {
int dim = 0;
T _size = 1;
std::vector<T> sizes;
std::vector<T> weights;

multidim_index() = default;

explicit multidim_index(const std::vector<T> &sizes)
: dim(sizes.size()), sizes(sizes), weights(dim, T(1)) {
for (int d = 0; d < (int)sizes.size(); ++d) {
assert(sizes.at(d) > 0);
_size *= sizes.at(d);
if (d >= 1) weights.at(d) = weights.at(d - 1) * sizes.at(d - 1);
}
}

T size() const { return _size; }

T flat_index(const std::vector<T> &encoded_vec) const {
assert((int)encoded_vec.size() == (int)sizes.size());
T ret = 0;
for (int d = 0; d < (int)sizes.size(); ++d) {
assert(0 <= encoded_vec.at(d) and encoded_vec.at(d) < sizes.at(d));
ret += encoded_vec.at(d) * weights.at(d);
}
return ret;
}

std::vector<T> encode(T flat_index) const {
assert(0 <= flat_index and flat_index < size());
std::vector<T> ret(sizes.size());
for (int d = (int)sizes.size() - 1; d >= 0; --d) {
ret.at(d) = flat_index / weights.at(d);
flat_index %= weights.at(d);
}
return ret;
}

template <class F> void lo_to_hi(F f) {
for (int d = 0; d < (int)sizes.size(); ++d) {
if (sizes.at(d) == 1) continue;

T i = 0;
std::vector<T> ivec(sizes.size());

int cur = sizes.size();

while (true) {
f(i, i + weights.at(d));
--cur;

while (cur >= 0 and ivec.at(cur) + 1 == sizes.at(cur) - (cur == d)) {
i -= ivec.at(cur) * weights.at(cur);
ivec.at(cur--) = 0;
}

if (cur < 0) break;

++ivec.at(cur);
i += weights.at(cur);
cur = sizes.size();
}
}
}

// Subset sum (fast zeta transform)
template <class U> void subset_sum(std::vector<U> &vec) {
assert((T)vec.size() == size());
lo_to_hi([&](T lo, T hi) { vec.at(hi) += vec.at(lo); });
}

// Inverse of subset sum (fast moebius transform)
template <class U> void subset_sum_inv(std::vector<U> &vec) {
assert((T)vec.size() == size());
const T s = size() - 1;
lo_to_hi([&](T dummylo, T dummyhi) { vec.at(s - dummylo) -= vec.at(s - dummyhi); });
}

// Superset sum (fast zeta transform)
template <class U> void superset_sum(std::vector<U> &vec) {
assert((T)vec.size() == size());
const T s = size() - 1;
lo_to_hi([&](T dummylo, T dummyhi) { vec.at(s - dummyhi) += vec.at(s - dummylo); });
}

// Inverse of superset sum (fast moebius transform)
template <class U> void superset_sum_inv(std::vector<U> &vec) {
assert((T)vec.size() == size());
lo_to_hi([&](T lo, T hi) { vec.at(lo) -= vec.at(hi); });
}
};
48 changes: 48 additions & 0 deletions utilities/multidim_index.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
---
title: Multidimensional index
documentation_of: ./multidim_index.hpp
---

多次元の添え字を 1 次元に潰す処理とその逆変換の実装.

## 使用方法

### 添字変換

以下のような C++ コードを実行すると,以下のような結果を得る.

```cpp
multidim_index mi({2, 3, 4});

for (int i = 0; i < mi.size(); ++i) {
const auto vec = mi.encode(i);
assert(mi.flat_index(vec) == i);

cout << i << ": (";
for (int x : vec) cout << x << ",";
cout << ")\n";
}
```

```txt
0: (0,0,0,)
1: (1,0,0,)
2: (0,1,0,)
3: (1,1,0,)
4: (0,2,0,)
5: (1,2,0,)
6: (0,0,1,)
...
23: (1,2,3,)
```

### 累積和やその逆変換

通常の `+=` 演算子による累積和処理に関しては, `subset_sum()` / `subset_sum_inv()` (下側累積和およびその逆変換)・ `superset_sum()` / `superset_sum_inv()` (上側累積和およびその逆変換)が提供されている.

より一般の演算を行いたい場合は,ラムダ式 `f` を用意した上で `lo_to_hi(F f)` 関数を使えばよい.

## 問題例

- [AtCoder Beginner Contest 335(Sponsored by Mynavi) G - Discrete Logarithm Problems](https://atcoder.jp/contests/abc335/tasks/abc335_g)
- $P - 1$ の正の約数全てをキーとする DP テーブルで,整序関係をもとに累積和を取る. [参考提出](https://atcoder.jp/contests/abc335/submissions/49118789)
25 changes: 25 additions & 0 deletions utilities/test/multidim_index.zeta.test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#define PROBLEM "https://judge.yosupo.jp/problem/bitwise_and_convolution"
#include "../../modint.hpp"
#include "../multidim_index.hpp"

#include <iostream>
#include <vector>
using namespace std;

int main() {
cin.tie(nullptr), ios::sync_with_stdio(false);
int N;
cin >> N;

multidim_index mi(std::vector<int>(N, 2));
vector<ModInt<998244353>> A(1 << N), B(1 << N);

for (auto &x : A) cin >> x;
for (auto &x : B) cin >> x;
mi.superset_sum(A);
mi.superset_sum(B);
for (int i = 0; i < 1 << N; ++i) A.at(i) *= B.at(i);
mi.superset_sum_inv(A);

for (auto x : A) cout << x << ' ';
}