In [1]:
import numpy as np
from numpy import genfromtxt
# import the synthectic query workload
# long range: from -180 to 180, the first column
# lat range: from -90 to 90, the second column
query_bound_dim1 = genfromtxt('/Users/lizhe/Desktop/LearnedKDTree/DataAndWorkload/SyntheticWorkload/Dim1_QueryBound_Tweet_C10_P10_S5.csv', delimiter=',')
query_bound_dim2 = genfromtxt('/Users/lizhe/Desktop/LearnedKDTree/DataAndWorkload/SyntheticWorkload/Dim2_QueryBound_Tweet_C10_P10_S5.csv', delimiter=',')
# the twitter dataset
dataset = genfromtxt("/Users/lizhe/Library/Mobile Documents/com~apple~CloudDocs/SortedSingleDimPOIs2.csv", delimiter=',')

In [9]:
# the method to generate kdnodes, where each node should not exceed the page size threshold
# @dataset, the origin dataset, contains each dims value in each row
# @query_bounds, the [] that contains each dim in each row
# @threshold, the page size threshold
def LearnedKDTree(dataset, query_bound, Dimorder, domains, threshold=32000):
    
    # load query rectangles
    # scan to balance the cross of rectangles and data
    # save split position
    
    currentDim = 0
    
    kdnodes = ResuriveDivide(dataset, query_bound, currentDim, Dimorder, domains, threshold, 0)
    
    return kdnodes

In [12]:
# asssumption: the query boundings will not overlap

# divide the KD-Tree recursively
# @dataset, should contains the data only in this subnode; numpy object
# @query_bound, should contains all the bounds; numpy object
# @currentDim, the dimension this iteration should focus on, an index in the Dimorder; integer
# @domains, the current domain of the nodes of every dimension [first lower, second upper],[]...; array object
def ResuriveDivide(dataset, query_bound, currentDim, Dimorder, domains, threshold, level):
    
    print("level: ",level)
    print("dataset size: ",len(dataset))
    
    # check if the threshold is already satisfied
    total_size = len(dataset)
    if total_size <= threshold:
        # the kdnodes should be an global object outside the function
        kdnodes = []
        kdnodes.append([domains,total_size])
        return kdnodes
    
    # the current dimension
    divideDim = Dimorder[currentDim]
    
    # sort according to the current dimension
    dataset = dataset[dataset[:,divideDim].argsort()]
    
    # find the medium
    medium = dataset[int(total_size/2),divideDim]
    medium_low = domains[divideDim][0]
    medium_up = domains[divideDim][1]
    
    # start check split position from the medium
    split_position = int(total_size/2)
    split_low = 0
    split_up = total_size
    
    # check if the split position intersect some query boundings in this dim
    for i in range(len(query_bound[divideDim])):
        
        # if intersect some query bounds
        if medium > query_bound[divideDim][i][0] and medium < query_bound[divideDim][i][1]:
            
            # check if the two end already exceeds domain
            if query_bound[divideDim][i][0] < domains[divideDim][0] and query_bound[divideDim][i][1] > domains[divideDim][1]:
                break;
            
            else:
                if query_bound[divideDim][i][0] > domains[divideDim][0]:
                # get the number of records from medium to the end
                    for j in range(split_position-1,-1,-1):
                        if dataset[j][divideDim] <= query_bound[divideDim][i][0]:
                            split_low = j
                            medium_low = dataset[split_low,divideDim]
                            break
                
                if query_bound[divideDim][i][1] < domains[divideDim][1]:
                # get the number of records from medium to the end
                    for j in range(split_position,total_size,1):
                        if dataset[j][divideDim] >= query_bound[divideDim][i][1]:
                            split_up = j
                            medium_up = dataset[split_up,divideDim]
                            break
                
            # if not exceeds then choose the one that is closest from the medium (in terms of #records!)
            if (total_size/2) - split_low < (split_up - total_size/2) and split_low != 0:
                split_position = split_low
                medium = medium_low
            elif (total_size/2) - split_low >= (split_up - total_size/2) and split_up != total_size :
                split_position = split_up
                medium = medium_up
            
            # after handle the overlap bounding, we can skip the remaining, as there will be at most 1 as assumned
            break;
            
    # split the dataset according to the split position
    sub_dataset1 = dataset[0:split_position,:]
    sub_dataset2 = dataset[split_position:-1,:]
    
    # change the domains
    sub_domains1 = np.copy(domains)
    sub_domains1[divideDim][1] = medium
    sub_domains2 = np.copy(domains)
    sub_domains2[divideDim][0] = medium
    
    # change the divideDim
    currentDim += 1
    if currentDim >= len(Dimorder):
        currentDim %= len(Dimorder)
    
    # used to see the current depth
    level += 1
    
    # recursion
    kdnodes = []
    kdnodes.extend(ResuriveDivide(sub_dataset1, query_bound, currentDim, Dimorder, sub_domains1, threshold, level))
    kdnodes.extend(ResuriveDivide(sub_dataset2, query_bound, currentDim, Dimorder, sub_domains2, threshold, level))
    
    print("kdnodes: ",len(kdnodes))
    
    return kdnodes

In [13]:
# test
query_bound = [query_bound_dim1,query_bound_dim2]
Dimorder = [0,1]
domains = [[-180,180],[-90,90]]
kdnodes = LearnedKDTree(dataset, query_bound, Dimorder, domains, threshold=32000)
print(len(kdnodes))
print(kdnodes)

level:  0
dataset size:  1157570
level:  1
dataset size:  578785
level:  2
dataset size:  309798
level:  3
dataset size:  154899
level:  4
dataset size:  4775
level:  4
dataset size:  150123
level:  5
dataset size:  75061
level:  6
dataset size:  37530
level:  7
dataset size:  31844
level:  7
dataset size:  5685
kdnodes:  2
level:  6
dataset size:  37530
level:  7
dataset size:  18765
level:  7
dataset size:  18764
kdnodes:  2
kdnodes:  4
level:  5
dataset size:  75061
level:  6
dataset size:  37530
level:  7
dataset size:  18765
level:  7
dataset size:  18764
kdnodes:  2
level:  6
dataset size:  37530
level:  7
dataset size:  18765
level:  7
dataset size:  18764
kdnodes:  2
kdnodes:  4
kdnodes:  8
kdnodes:  9
level:  3
dataset size:  154898
level:  4
dataset size:  77449
level:  5
dataset size:  75595
level:  6
dataset size:  37797
level:  7
dataset size:  18898
level:  7
dataset size:  18898
kdnodes:  2
level:  6
dataset size:  37797
level:  7
dataset size:  9
level:  7
dataset size:

In [24]:
# transfer the form from [[(x,x),(x,x)],x] to [x,x,x,x,x], specified to 2D only!!!!!!
nodes = []
for i in range(len(kdnodes)):
    each_node = [kdnodes[i][0][0][0], kdnodes[i][0][0][1],kdnodes[i][0][1][0],kdnodes[i][0][1][1],kdnodes[i][1]]
    nodes.append(each_node)
print(nodes)
np.savetxt('/Users/lizhe/Desktop/LearnedKDTree/DataAndWorkload/SyntheticWorkload/KDnodes_Tweet_C10_P10_S5.csv',nodes,delimiter=',')

[[-180, -5, -90, -76, 4775], [-180, -26, -76, -53, 31844], [-26, -23, -76, -53, 5685], [-180, -25, -53, -32, 18765], [-25, -23, -53, -32, 18764], [-23, -21, -76, -43, 18765], [-21, -5, -76, -43, 18764], [-23, -14, -43, -32, 18765], [-14, -5, -43, -32, 18764], [-5, 19, -90, -99, 18898], [19, 30, -90, -99, 18898], [-5, 0, -99, -86, 9], [0, 30, -99, -96, 18893], [0, 30, -96, -86, 18893], [30, 32, -90, -86, 1853], [-5, 1, -86, -76, 1927], [-5, 1, -76, -32, 30295], [1, 32, -86, -79, 22612], [1, 32, -79, -32, 22611], [-180, -6, -32, 106, 16265], [-180, -6, 106, 107, 16264], [-6, 1, -32, 104, 16264], [-6, 1, 104, 107, 16264], [-180, -7, 107, 112, 16264], [-180, -7, 112, 90, 16264], [-7, 1, 107, 110, 16264], [-7, 1, 110, 90, 16264], [1, 24, -32, 98, 16880], [24, 30, -32, 98, 16880], [1, 13, 98, 100, 16880], [13, 30, 98, 100, 16880], [30, 32, -32, 100, 1908], [1, 3, 100, 102, 16777], [3, 30, 100, 102, 16777], [1, 14, 102, 90, 16777], [14, 30, 102, 90, 16776], [30, 32, 100, 90, 2321], [32, 35, -

In [21]:
a = [[1,2,3],[2,3,4]]
a = np.reshape(a,(1,-1))
a = [a,9]
print(a[0])

[[1 2 3 2 3 4]]
