/
wavelet.go
438 lines (365 loc) · 10.1 KB
/
wavelet.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
package bwt
import (
"errors"
"fmt"
"math"
"golang.org/x/exp/slices"
)
/*
For the waveletTree's usage, please read its
method documentation. To understand what it is and how
it works for either curiosity or maintenance, then read below.
# WaveletTree
The Wavelet Tree allows us to conduct RSA queries on strings in
a memory and run time efficient manner.
RSA stands for (R)ank, (S)elect, (A)ccess.
See this blog post by Alex Bowe for an additional explanation:
https://www.alexbowe.com/wavelet-trees/
## The Character's Path Encoding
Each character from a sequence's alphabet will be assigned a path.
This path encoding represents a path from the Wavelet Tree's root to some
leaf node that represents a character.
For example, given the alphabet A B C D E F G H, a possible encoding is:
A: 000
B: 001
C: 010
D: 011
E: 100
F: 101
G: 110
H: 111
If we wanted to get to the leaf that represents the character D, we'd have
to use D's path encoding to traverse the tree.
Consider 0 as the left and 1 as the right.
If we follow D's encoding, 011, then we'd take a path that looks like:
root
/
left
\
right
\
right
## The Data Represented at each node
Let us consider the sequence "bananas"
It has the alphabet b, a, n, s
Let's say it has the encoding:
a: 00
n: 01
b: 10
s: 11
and that 0 is left and 1 is right
We can represent this tree with bitvectors:
0010101
bananas
/ \
1000 001
baaa nns
/ \ / \
a b n s
If we translate each bit vector to its corresponding string, then it becomes:
bananas
/ \
baaa nns
/ \ / \
a b n s
Each node of the tree consists of a bitvector whose values indicate whether
the character at a particular index is in the left (0) or right (1) child of the
tree.
## RSA
At this point, we can talk about RSA. RSA stands for (R)ank, (S)elect, (A)ccess.
### Rank Example
WaveletTree.Rank(c, n) returns the rank of character c at index n in a sequence, i.e. how many
times c has occurred in a sequence before index n.
To get WaveletTree.Rank(a, 4) of bananas where a's encoding is 00
1. root.Rank(0, 4) of 0010101 is 3
2. Visit Left Child
3. child.Rank(0, 3) of 1000 is 2
4. Visit Left Child
5. We are at a leaf node, so return our last recorded rank: 2
### Select Example
To get WaveletTree.Select(n, 1) of bananas where n's encoding is 01
1. Go down to n's leaf using its path encoding, 01
2. Go back to n's leaf's parent
3. parent.Select(0, 1) of 001 is 1
4. Go to the next parent
5. parent.Select(1, 1) of 0010101 is 4
6. return 4 since we are at the root.
### Access Example
Take the tree we constructed earlier to represent the sequence "bananas".
0010101
/ \
1000 001
/ \ / \
a b n s
To access the 4th character of the sequence, we would call WaveletTree.Access(3),
which performs the following operations:
1. root[3] is 0 and root.Rank(0, 3) is 2
2. Since root[3] is 0, visit left child
3. child[2] is 0 and child.Rank(0, 2) is 1
4. Since child[2] is 0, visit left child
5. Left child is a leaf, so we've found our value (a)!
NOTE: The waveletTree does not literally have to be a tree. It may exist in other forms,
such as the concatenation, in level order, of all of its nodes' bitvectors. Please
reference the implementation if you'd like to understand how this specific waveletTree
works.
*/
// waveletTree is a data structure that indexes a sequence in a memory
// efficient way and supports RSA, (R)ank (S)elect (A)ccess, queries on it.
// This is useful in situations where you'd like to understand certain
// aspects of a sequence, like:
// * the number of times a character appears
// * the frequency of a character up to a certain offset
// * the location of the character with a certain rank within the sequence
// * the character at a given position
type waveletTree struct {
// root is the top node of the tree. When the sequence has a one-symbol
// alphabet the root is itself a leaf.
root *node
// alpha holds the per-character metadata (including each character's
// path through the tree) produced by getCharInfoDescByRank.
alpha []charInfo
// length is the byte length of the string the tree was built from.
length int
}
// Access returns the byte stored at position i of the string the
// waveletTree was built from.
//
// Starting at the root, the bit found at the current position picks the
// child to descend into (false = left, true = right) and the rank of that
// bit becomes the position inside the chosen child. The walk stops at a
// leaf, whose character is the answer. i is handed to the bit vectors
// without any range check here.
func (wt waveletTree) Access(i int) byte {
	n := wt.root
	pos := i
	// A leaf root (one-symbol alphabet) skips the loop entirely.
	for !n.isLeaf() {
		goRight := n.data.Access(pos)
		pos = n.data.Rank(goRight, pos)

		next := n.left
		if goRight {
			next = n.right
		}
		n = next
	}
	return *n.char
}
// Rank returns the rank of char at index i of the original string, i.e.
// how many times char occurs in the sequence before index i.
//
// A char that is not part of the tree's alphabet always has rank 0. This
// holds for a one-symbol alphabet too: the lookup happens before the
// leaf-root shortcut, so an absent character no longer inherits the rank
// of the only character that is present.
func (wt waveletTree) Rank(char byte, i int) int {
	ci, ok := wt.lookupCharInfo(char)
	if !ok {
		return 0
	}

	// One-symbol alphabet: the root is a leaf whose bit vector is all ones,
	// so the rank of "true" is the rank of the only character.
	if wt.root.isLeaf() {
		return wt.root.data.Rank(true, i)
	}

	curr := wt.root
	rank := 0
	for level := 0; !curr.isLeaf(); level++ {
		// The path is consumed from its last bit backwards, one bit per level.
		pathBit := ci.path.getBit(ci.path.len() - 1 - level)
		rank = curr.data.Rank(pathBit, i)
		if pathBit {
			curr = curr.right
		} else {
			curr = curr.left
		}
		// The rank at this level is the index to query in the child.
		i = rank
	}
	return rank
}
// Select returns the position in the original string of the occurrence of
// char that has the given rank.
//
// A char that is not part of the tree's alphabet yields 0. That value is
// indistinguishable from a real position 0, so callers should only ask for
// characters they know are present. The lookup happens before the leaf-root
// shortcut so that a one-symbol alphabet treats absent characters the same
// way as every other tree.
//
// Select panics when a bit vector on the path has no bit with the requested
// rank (for example when rank exceeds the number of occurrences of char).
func (wt waveletTree) Select(char byte, rank int) int {
	ci, ok := wt.lookupCharInfo(char)
	if !ok {
		return 0
	}

	// One-symbol alphabet: the root is a leaf whose bit vector is all ones.
	if wt.root.isLeaf() {
		s, ok := wt.root.data.Select(true, rank)
		if !ok {
			msg := fmt.Sprintf("could not find a corresponding bit for node.Select(true, %d) root as leaf node", rank)
			panic(msg)
		}
		return s
	}

	// Walk down to char's leaf, consuming the path from its last bit backwards.
	curr := wt.root
	level := 0
	for !curr.isLeaf() {
		if ci.path.getBit(ci.path.len() - 1 - level) {
			curr = curr.right
		} else {
			curr = curr.left
		}
		level++
	}

	// Climb back to the root. At each ancestor the position found in the
	// child becomes the rank to select among the bits that led to that child.
	for curr.parent != nil {
		curr = curr.parent
		level--
		pathBit := ci.path.getBit(ci.path.len() - 1 - level)
		nextRank, ok := curr.data.Select(pathBit, rank)
		if !ok {
			msg := fmt.Sprintf("could not find a corresponding bit for node.Select(%t, %d) for characterInfo %+v", pathBit, rank, ci)
			panic(msg)
		}
		rank = nextRank
	}
	return rank
}
// reconstruct rebuilds the string the waveletTree was built from by calling
// Access on every position from 0 to wt.length-1.
//
// The bytes are collected in a pre-sized buffer and converted once. Besides
// avoiding repeated string concatenation, this keeps every byte intact:
// converting a single byte with string(b) UTF-8 encodes it as a code point,
// which turns any byte >= 0x80 into two bytes.
func (wt waveletTree) reconstruct() string {
	buf := make([]byte, wt.length)
	for i := range buf {
		buf[i] = wt.Access(i)
	}
	return string(buf)
}
// lookupCharInfo scans the tree's alphabet for char and returns its
// metadata. The boolean is false, and the charInfo is the zero value, when
// char is not part of the alphabet.
func (wt waveletTree) lookupCharInfo(char byte) (charInfo, bool) {
	for _, info := range wt.alpha {
		if info.char != char {
			continue
		}
		return info, true
	}
	return charInfo{}, false
}
// node is a single node of the waveletTree.
type node struct {
// data marks, for an internal node, which bytes of the node's sequence
// belong to the right child (bit set) and which to the left (bit clear).
// buildWaveletTree leaves it as the zero value on leaves; a leaf root gets
// an all-ones vector from newWaveletTreeFromString.
data rsaBitVector
// char is non-nil only on leaves and points at the character the leaf
// represents.
char *byte
// parent is the node one level up; it stays nil for the root.
parent *node
// left and right are the children for path bit 0 and 1 respectively.
// Either may be nil when no character of the node's alphabet takes that
// branch.
left *node
right *node
}
// isLeaf reports whether n is a leaf, i.e. whether it carries a character.
func (n node) isLeaf() bool {
return n.char != nil
}
// charInfo is the metadata kept for one character of the alphabet.
type charInfo struct {
// char is the character itself.
char byte
// maxRank is the zero-based count of the character in the sequence: 0 for
// a single occurrence, plus 1 for every further occurrence.
maxRank int
// path encodes the route from the root to the character's leaf. It holds
// the character's index in the sorted alphabet and is read from its last
// bit backwards, one bit per tree level (false = left, true = right).
path bitvector
}
// newWaveletTreeFromString builds a waveletTree over the bytes of str.
// It returns the validation error, and a zero-value tree, when str is
// rejected by validateWaveletTreeBuildInput.
func newWaveletTreeFromString(str string) (waveletTree, error) {
	if err := validateWaveletTreeBuildInput(&str); err != nil {
		return waveletTree{}, err
	}

	seq := []byte(str)
	alphabet := getCharInfoDescByRank(seq)
	tree := waveletTree{
		root:   buildWaveletTree(0, alphabet, seq),
		alpha:  alphabet,
		length: len(str),
	}

	// A one-symbol alphabet produces a root that is itself a leaf and has no
	// bit vector yet. Give it an all-ones vector so Rank and Select have
	// something to query.
	if tree.root.isLeaf() {
		ones := newBitVector(len(seq))
		for i := 0; i < ones.len(); i++ {
			ones.setBit(i, true)
		}
		tree.root.data = newRSABitVectorFromBitVector(ones)
	}

	return tree, nil
}
// buildWaveletTree recursively builds the subtree for the given alphabet
// over the given bytes and returns its root.
//
// An empty alphabet yields nil and a single-character alphabet yields a
// leaf pointing at that character. Otherwise the alphabet is split by the
// path bit for currentLevel, a bit vector records which bytes go right,
// and both halves are built one level further down with their parent
// pointers set to the new node.
func buildWaveletTree(currentLevel int, alpha []charInfo, bytes []byte) *node {
	switch len(alpha) {
	case 0:
		return nil
	case 1:
		return &node{char: &alpha[0].char}
	}

	zeros, ones := partitionAlpha(currentLevel, alpha)

	var (
		leftSeq  []byte
		rightSeq []byte
	)
	marks := newBitVector(len(bytes))
	for i, c := range bytes {
		if !isInAlpha(ones, c) {
			leftSeq = append(leftSeq, c)
			continue
		}
		marks.setBit(i, true)
		rightSeq = append(rightSeq, c)
	}

	parent := &node{data: newRSABitVectorFromBitVector(marks)}
	parent.left = buildWaveletTree(currentLevel+1, zeros, leftSeq)
	parent.right = buildWaveletTree(currentLevel+1, ones, rightSeq)

	// Either child may be nil when its side of the partition was empty.
	for _, child := range []*node{parent.left, parent.right} {
		if child != nil {
			child.parent = parent
		}
	}
	return parent
}
// isInAlpha reports whether b is the character of any entry in alpha.
func isInAlpha(alpha []charInfo, b byte) bool {
	for i := range alpha {
		if alpha[i].char == b {
			return true
		}
	}
	return false
}
// partitionAlpha splits the alphabet in two according to each character's
// path bit for currentLevel: characters whose bit is 0 go to left, those
// whose bit is 1 go to right. The relative order of alpha is preserved on
// both sides, and either side may come back empty.
//
// The path is read from its last bit backwards, so level 0 uses the final
// bit of the path.
func partitionAlpha(currentLevel int, alpha []charInfo) (left []charInfo, right []charInfo) {
	for i := range alpha {
		info := alpha[i]
		bitIdx := info.path.len() - 1 - currentLevel
		if !info.path.getBit(bitIdx) {
			left = append(left, info)
			continue
		}
		right = append(right, info)
	}
	return left, right
}
// getCharInfoDescByRank takes the bytes of the original string and returns
// one charInfo per distinct character, sorted descending by maxRank with
// ties broken by ascending character value. Each entry's path is set to the
// bit vector encoding of its index in that sorted order.
//
// The metadata is needed both to build the tree and to query it later. The
// ordering is intended to keep the tree compact by giving the most frequent
// characters the lowest indices.
// NOTE: alphabets are expected to be small for real usecases
func getCharInfoDescByRank(b []byte) []charInfo {
	// maxRank is zero-based: the first sighting stores 0 and every later
	// sighting adds one.
	ranks := make(map[byte]int)
	for _, c := range b {
		if _, seen := ranks[c]; !seen {
			ranks[c] = 0
			continue
		}
		ranks[c]++
	}

	var sortedInfo []charInfo
	for c, r := range ranks {
		sortedInfo = append(sortedInfo, charInfo{char: c, maxRank: r})
	}

	// Characters are unique, so this is a total order and the result does
	// not depend on map iteration order.
	slices.SortFunc(sortedInfo, func(a, b charInfo) bool {
		if a.maxRank != b.maxRank {
			return a.maxRank > b.maxRank
		}
		return a.char < b.char
	})

	numOfBits := getTreeHeight(sortedInfo)
	for i := range sortedInfo {
		path := newBitVector(numOfBits)
		encodeCharPathIntoBitVector(path, uint64(i))
		sortedInfo[i].path = path
	}
	return sortedInfo
}
// encodeCharPathIntoBitVector writes the binary representation of n into
// bv, least significant bit last: bit k of n lands at index bv.len()-1-k.
// Only the bits up to n's highest set bit are written; the remaining
// positions of bv are left untouched, and nothing is written when n is 0.
func encodeCharPathIntoBitVector(bv bitvector, n uint64) {
	for shift, rest := 0, n; rest > 0; shift, rest = shift+1, rest>>1 {
		isOne := rest&1 == 1
		bv.setBit(bv.len()-1-shift, isOne)
	}
}
// getTreeHeight returns the number of bits used for every character path:
// floor(log2(len(alpha))) + 1. The result is not meaningful for an empty
// alphabet, where the logarithm is negative infinity.
func getTreeHeight(alpha []charInfo) int {
	symbols := float64(len(alpha))
	return 1 + int(math.Log2(symbols))
}
// validateWaveletTreeBuildInput checks that the sequence a waveletTree is
// about to be built from is usable. It returns an error when the sequence
// is empty and nil otherwise. sequence must not be nil.
func validateWaveletTreeBuildInput(sequence *string) error {
	if *sequence != "" {
		return nil
	}
	return errors.New("Sequence can not be empty")
}