/
fenwicktree_01.hpp
103 lines (93 loc) · 2.47 KB
/
fenwicktree_01.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#pragma once
#include "ds/fenwicktree/fenwicktree.hpp"
struct FenwickTree_01 {
int N, n;
vc<u64> dat;
FenwickTree<Monoid_Add<int>> bit;
FenwickTree_01() {}
FenwickTree_01(int n) { build(n); }
template <typename F>
FenwickTree_01(int n, F f) {
build(n, f);
}
void build(int m) {
N = m;
n = ceil<int>(N + 1, 64);
dat.assign(n, u64(0));
bit.build(n);
}
template <typename F>
void build(int m, F f) {
N = m;
n = ceil<int>(N + 1, 64);
dat.assign(n, u64(0));
FOR(i, N) { dat[i / 64] |= u64(f(i)) << (i % 64); }
bit.build(n, [&](int i) -> int { return popcnt(dat[i]); });
}
int sum_all() { return bit.sum_all(); }
int sum(int k) { return prefix_sum(k); }
int prefix_sum(int k) {
int ans = bit.sum(k / 64);
ans += popcnt(dat[k / 64] & ((u64(1) << (k % 64)) - 1));
return ans;
}
int sum(int L, int R) {
if (L == 0) return prefix_sum(R);
int ans = 0;
ans -= popcnt(dat[L / 64] & ((u64(1) << (L % 64)) - 1));
ans += popcnt(dat[R / 64] & ((u64(1) << (R % 64)) - 1));
ans += bit.sum(L / 64, R / 64);
return ans;
}
void add(int k, int x) {
if (x == 1) add(k);
if (x == -1) remove(k);
}
void add(int k) {
dat[k / 64] |= u64(1) << (k % 64);
bit.add(k / 64, 1);
}
void remove(int k) {
dat[k / 64] &= ~(u64(1) << (k % 64));
bit.add(k / 64, -1);
}
int kth(int k, int L = 0) {
if (k >= sum_all()) return N;
k += popcnt(dat[L / 64] & ((u64(1) << (L % 64)) - 1));
L /= 64;
int mid = 0;
auto check = [&](auto e) -> bool {
if (e <= k) chmax(mid, e);
return e <= k;
};
int idx = bit.max_right(check, L);
if (idx == n) return N;
k -= mid;
u64 x = dat[idx];
int p = popcnt(x);
if (p <= k) return N;
k = binary_search([&](int n) -> bool { return (p - popcnt(x >> n)) <= k; },
0, 64, 0);
return 64 * idx + k;
}
int next(int k) {
int idx = k / 64;
k %= 64;
u64 x = dat[idx] & ~((u64(1) << k) - 1);
if (x) return 64 * idx + lowbit(x);
idx = bit.kth(0, idx + 1);
if (idx == n || !dat[idx]) return N;
return 64 * idx + lowbit(dat[idx]);
}
int prev(int k) {
if (k == N) --k;
int idx = k / 64;
k %= 64;
u64 x = dat[idx];
if (k < 63) x &= (u64(1) << (k + 1)) - 1;
if (x) return 64 * idx + topbit(x);
idx = bit.min_left([&](auto e) -> bool { return e <= 0; }, idx) - 1;
if (idx == -1) return -1;
return 64 * idx + topbit(dat[idx]);
}
};