In [1]:
from agentflow.buffers.segment_tree import SumTree
import agentflow.buffers.segment_tree_c as segment_tree
from agentflow.common.sum_prefix_tree import SumPrefixTree
import numpy as np

## Construct tree objects

In [2]:
n = int(1e6)

x = (np.random.choice(50,size=n)+1).astype(float)
idx = np.arange(n).astype(np.int32)

sumtree = SumTree(n)
sumtree2 = np.zeros(n*2)
sumtree3 = SumPrefixTree(np.zeros(n).copy())

## Set single value

In [3]:
%%timeit

sumtree[3] = 5

2.69 µs ± 30 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [4]:
%%timeit

segment_tree.update_tree(3,5,sumtree2)

634 ns ± 80.9 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)


In [5]:
%%timeit

sumtree3[3] = 5

5.32 µs ± 309 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [6]:
2430/544.

4.466911764705882

In [7]:
np.abs(sumtree2 - np.array(sumtree._data)).max()

0.0

## Set 100 values

In [8]:
x0 = (np.random.choice(50,size=100)+1).astype(float)
idx0 = np.random.choice(n,size=100).astype(np.int32)

In [9]:
x0.shape

(100,)

In [10]:
%%timeit

for i,v in enumerate(x0):
    sumtree[i] = v

478 µs ± 19.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [11]:
%%timeit

segment_tree.update_tree_multi(idx0,x0,sumtree2)

3.21 µs ± 97.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [12]:
%%timeit

sumtree3[idx0] = x0

8.5 µs ± 433 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [13]:
np.abs(sumtree2 - np.array(sumtree._data)).max()

2572.0

In [61]:
478/3.21

148.90965732087227

In [62]:
478/8.5

56.23529411764706

## Set all values

In [15]:
x = (np.random.choice(50,size=n)+1).astype(float)
idx = np.arange(n).astype(np.int32)

In [16]:
x.shape

(1000000,)

In [17]:
%%timeit

for i,v in enumerate(x):
    sumtree[i] = v

4.99 s ± 66.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [18]:
%%timeit

segment_tree.update_tree_multi(idx,x,sumtree2)

15.4 ms ± 946 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [19]:
%%timeit

sumtree3[idx] = x

17 ms ± 756 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [20]:
np.abs(sumtree2 - np.array(sumtree._data)).max()

0.0

In [21]:
sumtree2.shape

(2000000,)

In [22]:
np.abs(sumtree3 - sumtree2[n:]).max()

0.0

In [23]:
sumtree3[0]

41.0

In [24]:
sumtree3-2

array([39., 41.,  1., ..., 32., 13., 14.])

In [25]:
np.abs(sumtree3.view(np.ndarray) - sumtree2[n:]).max()

0.0

In [26]:
np.abs(sumtree3._sumtree - sumtree2[:n]).max()

0.0

In [27]:
sumtree3

SumPrefixTree([41., 43.,  3., ..., 34., 15., 16.])

In [28]:
sumtree2[:sumtree2.size//2]

array([0.0000000e+00, 2.5510188e+07, 1.3377109e+07, ..., 5.6000000e+01,
       5.2000000e+01, 3.1000000e+01])

In [63]:
5.*1000/15

333.3333333333333

## get prefix sum id

In [30]:
vv = np.arange(0,sum(x)+1,10000)
vv.shape

(2552,)

In [31]:
for v in vv:
    assert sumtree.get_prefix_sum_idx(v) == segment_tree.get_prefix_sum_idx(v,sumtree2)
    
v = sum(x)
assert sumtree.get_prefix_sum_idx(v) == segment_tree.get_prefix_sum_idx(v,sumtree2)

In [32]:
%%timeit

sumtree.get_prefix_sum_idx(30)

7.96 µs ± 443 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [33]:
%%timeit

segment_tree.get_prefix_sum_idx(30,sumtree2)

647 ns ± 30.2 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)


In [34]:
7.11*1000/590.

12.05084745762712

## multiple get prefix sum id

In [35]:
y = np.random.choice(int(x.sum()),size=100).astype(float)

In [36]:
%%timeit

[sumtree.get_prefix_sum_idx(v) for v in y]

646 µs ± 56.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [37]:
%%timeit

[segment_tree.get_prefix_sum_idx(v,sumtree2) for v in y]

82.1 µs ± 6.16 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [38]:
%%timeit

output = np.zeros(y.shape,dtype=np.int32)
segment_tree.get_prefix_sum_multi_idx(output,y,sumtree2)

11.2 µs ± 591 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [39]:
out2 = np.zeros(y.shape,dtype=np.int32)

In [40]:
%%timeit

segment_tree.get_prefix_sum_multi_idx(out2,y,sumtree2)

9.55 µs ± 627 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [41]:
600/10.

60.0

In [42]:
y.min(),y.max()

(83116.0, 25460951.0)

In [43]:
%%timeit

sumtree3.get_prefix_sum_id(y)

14.8 µs ± 704 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [44]:
out3 = sumtree3.get_prefix_sum_id(y)

In [45]:
sumtree2[180433+n]

5.0

In [46]:
sumtree3[180433]

5.0

In [47]:
out2

array([150607, 748453, 561545, 363554, 362917, 247240,  70321, 650780,
       952118,  80532, 564256, 478853, 801895, 183423,  10863, 527573,
       585203, 512556, 565609, 100457, 820381, 621307, 262547, 359132,
       793441, 316282, 889707,  46611, 401367, 641540,  56080, 868937,
       644187, 384276, 880044, 819522, 539211,  57759, 378523, 867895,
       720757, 588535, 949677, 355063, 554042, 523020, 286418, 232621,
       384638, 176357,  86372, 547606, 757587, 534945, 457207, 185753,
       751510, 573701, 275118, 862458, 849892, 830122, 732306, 471091,
       830927, 163616, 405055, 760504, 845765, 758619, 591031,  51840,
       579984, 353599, 847317,  70382, 749994,  97674, 643960, 676494,
       629038,  98182, 898473, 657284, 143543, 403460,  79021,  92045,
       738845, 815629, 972919, 679172, 414040, 673638, 561306, 723013,
       626587, 244344,  14978, 862941], dtype=int32)

In [48]:
out3

array([150607, 748453, 561545, 363554, 362917, 247240,  70321, 650780,
       952118,  80532, 564256, 478853, 801895, 183423,  10863, 527573,
       585203, 512556, 565609, 100457, 820381, 621307, 262547, 359132,
       793441, 316282, 889707,  46611, 401367, 641540,  56080, 868937,
       644187, 384276, 880044, 819522, 539211,  57759, 378523, 867895,
       720757, 588535, 949677, 355063, 554042, 523020, 286418, 232621,
       384638, 176357,  86372, 547606, 757587, 534945, 457207, 185753,
       751510, 573701, 275118, 862458, 849892, 830122, 732306, 471091,
       830927, 163616, 405055, 760504, 845765, 758619, 591031,  51840,
       579984, 353599, 847317,  70382, 749994,  97674, 643960, 676494,
       629038,  98182, 898473, 657284, 143543, 403460,  79021,  92045,
       738845, 815629, 972919, 679172, 414040, 673638, 561306, 723013,
       626587, 244344,  14978, 862941], dtype=int32)

In [49]:
sumtree3._flat_base

array([41., 43.,  3., ..., 34., 15., 16.])

In [50]:
y

array([ 2605228., 17851275., 13087759.,  8045696.,  8029448.,  5071786.,
         557646., 15360565., 23040973.,   814836., 13157329., 10978650.,
       19212132.,  3440653., 24547201., 12220977., 13691254., 11838215.,
       13192216.,  1323144., 19683067., 14608503.,  5466034.,  7932183.,
       18997469.,  6840377., 21448330., 25460951.,  9005010., 15126364.,
         192866., 20918417., 15194271.,  8573579., 21202213., 19660574.,
       12517036.,   234787.,  8426645., 20891782., 17147788., 13777933.,
       22980389.,  7828353., 12895863., 12105404.,  6077205.,  4698573.,
        8583141.,  3259355.,   964056., 12733367., 18083098., 12408249.,
       10431056.,  3499492., 17929246., 13397999.,  5789379., 20753252.,
       20434879., 19931366., 17441792., 10779132., 19951658.,  2935956.,
        9099380., 18157537., 20328501., 18108974., 13840698.,    83116.,
       13557287.,  7790536., 20368530.,   559386., 17889766.,  1251770.,
       15188621., 16019875., 14807316.,  1265181., 

In [51]:
out3

array([150607, 748453, 561545, 363554, 362917, 247240,  70321, 650780,
       952118,  80532, 564256, 478853, 801895, 183423,  10863, 527573,
       585203, 512556, 565609, 100457, 820381, 621307, 262547, 359132,
       793441, 316282, 889707,  46611, 401367, 641540,  56080, 868937,
       644187, 384276, 880044, 819522, 539211,  57759, 378523, 867895,
       720757, 588535, 949677, 355063, 554042, 523020, 286418, 232621,
       384638, 176357,  86372, 547606, 757587, 534945, 457207, 185753,
       751510, 573701, 275118, 862458, 849892, 830122, 732306, 471091,
       830927, 163616, 405055, 760504, 845765, 758619, 591031,  51840,
       579984, 353599, 847317,  70382, 749994,  97674, 643960, 676494,
       629038,  98182, 898473, 657284, 143543, 403460,  79021,  92045,
       738845, 815629, 972919, 679172, 414040, 673638, 561306, 723013,
       626587, 244344,  14978, 862941], dtype=int32)

In [52]:
(out2 == out3).mean()

1.0

## Set vs Get (sumtree2)

In [53]:
K = 100
x = (np.random.choice(50,size=K)+1).astype(float)
idx = np.random.choice(n,size=K).astype(np.int32)

In [54]:
%%timeit

segment_tree.update_tree_multi(idx,y,sumtree2)

3.24 µs ± 99.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [55]:
idx = np.arange(K).astype(np.int32)
y = np.random.choice(int(x.sum()),size=K).astype(float)

In [56]:
%%timeit

segment_tree.get_prefix_sum_multi_idx(out2,y,sumtree2)

9.3 µs ± 390 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


## Set vs Get (sumtree3)

In [57]:
K = 100
x = (np.random.choice(50,size=K)+1).astype(float)
idx = np.random.choice(n,size=K).astype(np.int32)

In [58]:
%%timeit

sumtree3[idx] = x

9.02 µs ± 369 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [59]:
idx = np.arange(K).astype(np.int32)
y = np.random.choice(int(x.sum()),size=K).astype(float)

In [60]:
%%timeit

sumtree3.get_prefix_sum_id(y)

13.5 µs ± 641 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
