In [5]:
import import_ipynb # this can only be installed from pip (no conda)
from QueryTree import *

In [None]:
# import import_ipynb # this can only be installed from pip (no conda)
# from LKD3 import *

# = = = LKD Advanced Version 3 = = =

def LKDAdvanced3(dataset, query, domains, threshold, level, current_dim = 0):

    '''
    Recursively partition `dataset` into the leaves of a query-aware KD-tree (LKD, version 3).

    Main differences:
      1. Retrieve query for the current domain. (compared with version 1)
      2. Consider the #queries cancelled when calculating Min-Median-Distance. (compared with version 2)

    Parameters:
      @dataset[i,k]: numpy array, i: ith record  k: kth dimension
      @query[i,k,0/1]: numpy array, i: ith query  k: kth dimension  0/1: min/max
      @domains[k][0/1]: array-like, k: kth dimension  0/1: min/max of the current partition
      @threshold: the minimum number of records required in a partition
      @level: the current level of the split (in KD-Tree), root is 0
      @current_dim: the last split dimension (the round-robin fallback splits on the next one)

    Return:
      @kdnodes: list of leaves; kdnodes[i] = [domain, size], where domain[k][0/1] is the
                min/max of the ith leaf in dimension k and size is its number of records
    '''
    total_size = len(dataset)

    # Already small enough: emit the current partition as a single leaf.
    if total_size <= 2*threshold:
        return [[domains, total_size]]

    num_dims = len(domains)

    # One query tree per dimension, built from the queries reaching this node.
    query_trees = []
    for D in range(num_dims):
        qtree = QueryTree(D)
        qtree.loadQuerySetFromQueries(query)
        qtree.buildQueryTree()
        query_trees.append(qtree)

    # Number of queries removed (via diveIn) on each dimension so far (version 3).
    removed_queries_each_dim = [0] * num_dims

    while True:

        # Version 3: added cost of executing the removed queries on each dimension.
        query_remove_cost_each_dim = [removed * total_size for removed in removed_queries_each_dim]

        split_from_median_position_tag = False

        # For each dimension, the distance (in records) from its median to the first
        # split position that does not cross a query, plus that split position.
        split_distance_each_dim = []
        split_value_each_dim = []
        caches = []   # one entry per dimension ([] when the dimension has no query tree)
        medians = []
        for D in range(num_dims):
            distance, value, median, cache = _median_shift_candidate(
                dataset, D, domains[D], query_trees[D], total_size)
            split_distance_each_dim.append(distance)
            split_value_each_dim.append(value)
            caches.append(cache)
            medians.append(median)
        split_distance_each_dim = np.asarray(split_distance_each_dim)

        # Consider the split priority using the added cost first (only after a diveIn).
        if max(query_remove_cost_each_dim) != 0:
            split_order = np.argsort(query_remove_cost_each_dim)  # dimensions by ascending cost
            # NOTE(review): only the split-from-median outcome is acted on here; a successful
            # value split found in this block is discarded and recomputed below, as in the
            # original. Confirm whether this block should break on success as well.
            split_dimension, split_value, split_from_median_position_tag, _ = _first_legal_split(
                dataset, split_order, medians, split_value_each_dim, total_size, threshold)
        if split_from_median_position_tag:
            break

        # Advanced split: every dimension is (nearly) unsplittable near its median.
        if min(split_distance_each_dim) >= int((total_size / 2)-5):
            if _dive_in_all(query_trees, caches, medians, removed_queries_each_dim) == 0:
                # Nothing left to dive into: splitting is impossible, keep this as a leaf.
                return [[domains, total_size]]
            continue

        # Degradation mechanism: every dimension can split at (about) its median, so use
        # round robin. The 5 here is an error tolerance.
        if max(split_distance_each_dim) <= 5:
            split_dimension = (current_dim + 1) % num_dims
            split_value = np.median(dataset[:, split_dimension])
            break

        # The normal case: try dimensions by ascending shift distance.
        split_order = np.argsort(split_distance_each_dim)
        split_dimension, split_value, split_from_median_position_tag, successful_split_flag = \
            _first_legal_split(dataset, split_order, split_value_each_dim,
                               split_value_each_dim, total_size, threshold)
        if split_from_median_position_tag or successful_split_flag:
            break

        # None of the candidate splits creates legal sub partitions: use the advanced split.
        if _dive_in_all(query_trees, caches, medians, removed_queries_each_dim) == 0:
            # Indicates none of the splits is OK.
            return [[domains, total_size]]

    sub_dataset1, sub_dataset2 = _split_records(
        dataset, split_dimension, split_value, split_from_median_position_tag)

    # Shrink the domains of the two children along the split dimension.
    sub_domains1 = np.copy(domains)
    sub_domains1[split_dimension][1] = split_value
    sub_domains2 = np.copy(domains)
    sub_domains2[split_dimension][0] = split_value

    # Keep only the queries that reach into each child.
    sub_query1 = query[query[:,split_dimension,0] < split_value]
    sub_query2 = query[query[:,split_dimension,1] > split_value]

    level += 1

    kdnodes = []
    kdnodes.extend(LKDAdvanced3(sub_dataset1, sub_query1, sub_domains1, threshold, level, split_dimension))
    kdnodes.extend(LKDAdvanced3(sub_dataset2, sub_query2, sub_domains2, threshold, level, split_dimension))
    return kdnodes


def _median_shift_candidate(dataset, D, domain, query_tree, total_size):
    '''
    Compute the split candidate of dimension D.

    Return (min_median_distance, split_value, median, cache):
      min_median_distance: #records between the median and the chosen split position
      split_value: the median, or the bound of the overlapping query closest to it
      cache: the cache returned by query_tree.queryValue ([] when the tree is empty);
             cache[0] is later handed to diveIn, cache[1]/cache[2] are the lower/upper
             bounds of the query range overlapping the median
    '''
    column = dataset[:, D]
    median = np.median(column)

    # No queries on this dimension: the median itself is a clean split.
    if len(query_tree.node_dict) == 0:
        return 0, median, median, []

    # Check whether the default split position intersects some query.
    is_overlap, cache = query_tree.queryValue(median)
    if not is_overlap:
        return 0, median, median, cache

    overlap_query_lower = cache[1]
    overlap_query_upper = cache[2]

    # The overlapping range covers the whole current domain: no shift helps.
    if overlap_query_lower <= domain[0] and overlap_query_upper >= domain[1]:
        return int(total_size / 2), median, median, cache

    # Number of records the split must move to escape the overlapping range on each side.
    shift_lower = np.count_nonzero((column >= overlap_query_lower) & (column < median))
    shift_upper = np.count_nonzero((column <= overlap_query_upper) & (column > median))
    if shift_lower < shift_upper:
        return shift_lower, overlap_query_lower, median, cache
    return shift_upper, overlap_query_upper, median, cache


def _first_legal_split(dataset, split_order, candidate_values, split_value_each_dim, total_size, threshold):
    '''
    Walk dimensions in `split_order` and return the first usable split.

    Return (split_dimension, split_value, split_from_median_position, successful) where
    split_value = candidate_values[split_dimension]. The last dimension tried is returned
    when neither flag is set (and (0, 0) when split_order is empty).
    '''
    split_dimension, split_value = 0, 0
    for split_dimension in split_order:
        split_value = candidate_values[split_dimension]
        # If the partition is small, split from the median position directly, to avoid
        # redundant records that make it unsplittable by value.
        # NOTE(review): this compares the split *value* with 0, not the split distance;
        # confirm that is the intended test.
        if total_size < 3*threshold and split_value_each_dim[split_dimension] == 0:
            return split_dimension, split_value, True, False
        # Check that both sub nodes reach the threshold.
        column = dataset[:, split_dimension]
        if (np.count_nonzero(column <= split_value) >= threshold
                and np.count_nonzero(column > split_value) >= threshold):
            return split_dimension, split_value, False, True
    return split_dimension, split_value, False, False


def _dive_in_all(query_trees, caches, medians, removed_queries_each_dim):
    '''
    Call diveIn on every dimension that has a non-empty cache.

    Add each removed-query count to removed_queries_each_dim in place, and return
    the number of dimensions that dove in (0 means no further progress is possible).
    '''
    dive_count = 0
    for D, (qtree, cache, median) in enumerate(zip(query_trees, caches, medians)):
        if len(cache) != 0:
            removed_queries_each_dim[D] += qtree.diveIn(cache[0], median)
            dive_count += 1
    return dive_count


def _split_records(dataset, split_dimension, split_value, by_position):
    '''
    Split the records into (lower, upper) halves along split_dimension.

    by_position=True splits the sorted records at the middle index, keeping every record.
    Otherwise, split by value: lower gets <= split_value, upper gets > split_value.
    '''
    if by_position:
        ordered = dataset[np.argsort(dataset[:, split_dimension])]
        half = int(len(dataset) / 2)
        return ordered[:half], ordered[half:]
    column = dataset[:, split_dimension]
    return dataset[column <= split_value], dataset[column > split_value]