From e7e3e05e93eb3c7cd74b307ffb9b7e9dd80561c5 Mon Sep 17 00:00:00 2001 From: hitonanode <32937551+hitonanode@users.noreply.github.com> Date: Sat, 7 Dec 2024 22:13:21 +0900 Subject: [PATCH 1/2] update binary trie --- data_structure/binary_trie.hpp | 84 ++++++++++++++----- data_structure/binary_trie.md | 28 +++++++ data_structure/test/binary_trie.test.cpp | 2 +- .../test/binary_trie.yuki2977.test.cpp | 32 +++++++ 4 files changed, 125 insertions(+), 21 deletions(-) create mode 100644 data_structure/binary_trie.md create mode 100644 data_structure/test/binary_trie.yuki2977.test.cpp diff --git a/data_structure/binary_trie.hpp b/data_structure/binary_trie.hpp index 46af29fb..92c48723 100644 --- a/data_structure/binary_trie.hpp +++ b/data_structure/binary_trie.hpp @@ -1,26 +1,27 @@ #pragma once #include -// CUT begin -struct BinaryTrie { - using Int = int; +template struct BinaryTrie { int maxD; - std::vector deg, sz; + std::vector deg, subtree_sum; std::vector ch0, ch1, par; int _new_node(int id_par) { - deg.emplace_back(0); - sz.emplace_back(0); + deg.emplace_back(Count()); + subtree_sum.emplace_back(Count()); ch0.emplace_back(-1); ch1.emplace_back(-1); par.emplace_back(id_par); - return ch0.size() - 1; + return (int)ch0.size() - 1; } BinaryTrie(int maxD = 0) : maxD(maxD) { _new_node(-1); } - int _goto(Int x) { + + // Return index of x. + // Create nodes to locate x if needed. + int _get_create_index(Int x) { int now = 0; - for (int d = maxD - 1; d >= 0; d--) { + for (int d = maxD - 1; d >= 0; --d) { int nxt = ((x >> d) & 1) ? ch1[now] : ch0[now]; if (nxt == -1) { nxt = _new_node(now); @@ -31,29 +32,49 @@ struct BinaryTrie { return now; } - void insert(Int x) { - int now = _goto(x); - if (deg[now] == 0) { - deg[now] = 1; - while (now >= 0) { sz[now]++, now = par[now]; } + // Return node index of x. + // Return -1 if x is not in trie. + int _get_index(Int x) const { + int now = 0; + for (int d = maxD - 1; d >= 0; --d) { + now = ((x >> d) & 1) ? ch1[now] : ch0[now]; + if (now == -1) return -1; } + return now; + } + + // insert x + void insert(Int x, Count c = Count(1)) { + int now = _get_create_index(x); + deg[now] += c; + while (now >= 0) subtree_sum[now] += c, now = par[now]; } + // delete x if exists void erase(Int x) { - int now = _goto(x); - if (deg[now] > 0) { + int now = _get_index(x); + if (now >= 0 and deg[now] != 0) { + Count r = deg[now]; deg[now] = 0; - while (now >= 0) { sz[now]--, now = par[now]; } + while (now >= 0) subtree_sum[now] -= r, now = par[now]; } } - Int xor_min(Int x) { + Count count(Int x) const { + int now = _get_index(x); + return now == -1 ? Count() : deg[now]; + } + + bool contains(Int x) const { return count(x) > Count(); } + + // min(y ^ x) for y in trie + Int xor_min(Int x) const { Int ret = 0; int now = 0; - if (!sz[now]) return -1; + if (!subtree_sum[now]) return -1; for (int d = maxD - 1; d >= 0; d--) { int y = ((x >> d) & 1) ? ch1[now] : ch0[now]; - if (y != -1 and sz[y]) { + if (y != -1 and subtree_sum[y]) { now = y; } else { ret += Int(1) << d, now = ch0[now] ^ ch1[now] ^ y; @@ -61,4 +82,27 @@ struct BinaryTrie { } return ret; } + + // Count elements y such that x ^ y < thres + Count count_less_xor(Int x, Int thres) const { + Count ret = Count(); + int now = 0; + + for (int d = maxD - 1; d >= 0; d--) { + if (now == -1) break; + + const bool bit_x = (x >> d) & 1; + + if ((thres >> d) & 1) { + const int child = bit_x ? ch1[now] : ch0[now]; + if (child != -1) ret += subtree_sum[child]; + + now = bit_x ? ch0[now] : ch1[now]; + } else { + now = bit_x ? ch1[now] : ch0[now]; + } + } + + return ret; + } }; diff --git a/data_structure/binary_trie.md b/data_structure/binary_trie.md new file mode 100644 index 00000000..b0d44138 --- /dev/null +++ b/data_structure/binary_trie.md @@ -0,0 +1,28 @@ +--- +title: Binary trie +documentation_of: ./binary_trie.hpp +--- + +非負整数の集合や多重集合に対する一部のクエリを効率的に行うためのデータ構造. + +## 使用方法 + +```cpp +using Key = int; +using Count = int; +const int D = 30; // Key の桁数 + +BinaryTrie trie(D); + +for (int a : A) trie.insert(a); + +Key t; +Count n = trie.count_less_xor(a, t); // a ^ x < t を満たす x (x は現在存在する値)を数える + +Key v = bt.xor_min(t); // t ^ x (x は現在存在する値)の最小値を求める +``` + +## 問題例 + +- [Library Checker: Set Xor-Min](https://judge.yosupo.jp/problem/set_xor_min) +- [No.2977 Kth Xor Pair - yukicoder](https://yukicoder.me/problems/no/2977) diff --git a/data_structure/test/binary_trie.test.cpp b/data_structure/test/binary_trie.test.cpp index 837cd265..ca0eb2cc 100644 --- a/data_structure/test/binary_trie.test.cpp +++ b/data_structure/test/binary_trie.test.cpp @@ -9,7 +9,7 @@ int main() { int Q; cin >> Q; - BinaryTrie bt(30); + BinaryTrie bt(30); while (Q--) { int q, x; cin >> q >> x; diff --git a/data_structure/test/binary_trie.yuki2977.test.cpp b/data_structure/test/binary_trie.yuki2977.test.cpp new file mode 100644 index 00000000..c30b219a --- /dev/null +++ b/data_structure/test/binary_trie.yuki2977.test.cpp @@ -0,0 +1,32 @@ +#include "../binary_trie.hpp" +#define PROBLEM "https://yukicoder.me/problems/no/2977" + +#include +using namespace std; + +int main() { + cin.tie(nullptr), ios::sync_with_stdio(false); + + int N; + long long K; + cin >> N >> K; + + vector A(N); + for (auto &x : A) cin >> x; + + constexpr int D = 30; + + BinaryTrie trie(D); + for (int a : A) trie.insert(a); + + int lo = 0, hi = 1 << D; // [lo, hi) + while (lo + 1 < hi) { + const int mid = (lo + hi) / 2; + + long long cnt = 0; + for (int a : A) cnt += trie.count_less_xor(a, mid); + (cnt >= K * 2 + N ? hi : lo) = mid; + } + + cout << lo << '\n'; +} From b9a4a0ad7b7b66e06b106293161186c9e80c1735 Mon Sep 17 00:00:00 2001 From: hitonanode <32937551+hitonanode@users.noreply.github.com> Date: Sat, 7 Dec 2024 22:16:29 +0900 Subject: [PATCH 2/2] delete invalid test --- graph/test/bipartite_matching(slow).test.cpp | 17 ----------------- 1 file changed, 17 deletions(-) delete mode 100644 graph/test/bipartite_matching(slow).test.cpp diff --git a/graph/test/bipartite_matching(slow).test.cpp b/graph/test/bipartite_matching(slow).test.cpp deleted file mode 100644 index 83f741a9..00000000 --- a/graph/test/bipartite_matching(slow).test.cpp +++ /dev/null @@ -1,17 +0,0 @@ -#include "../bipartite_matching(slow).hpp" -#include -#define PROBLEM "http://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=GRL_7_A" - -int main() { - std::cin.tie(nullptr), std::ios::sync_with_stdio(false); - - int X, Y, E; - std::cin >> X >> Y >> E; - BipartiteMatching graph(X + Y); - while (E--) { - int s, t; - std::cin >> s >> t; - graph.add_edge(s, X + t); - } - std::cout << graph.solve() << '\n'; -}