In [1]:
import findspark
findspark.init() # this must be executed before the below import

In [2]:
from pyspark import SparkContext, SparkConf
from pyspark.sql import SQLContext
from pyspark.sql import SparkSession
from pyspark import SparkFiles

In [3]:
import pandas as pd
import time
import rtree
from rtree import index
import numpy as np
from numpy import genfromtxt
import threading

In [4]:
# Reuse the running SparkContext when this cell is re-executed; a bare
# SparkContext() raises "Cannot run multiple SparkContexts at once" on re-run.
sc = SparkContext.getOrCreate()
# DumpThread writes through this global `sqlContext`.
# NOTE: SQLContext is deprecated since Spark 3.0 in favour of SparkSession.
sqlContext = SQLContext(sc)

In [14]:
def process_chunk_row(row, used_dims, pidx, pid_data_dict):
    '''
    Route one data row to the partition whose bounding box contains it.

    The row is projected onto `used_dims` and turned into a zero-extent box
    (all lower coordinates, then the same values as upper coordinates). That
    box is queried against `pidx`, and the full row (as a list) is appended to
    pid_data_dict[pid] for the first partition id returned.

    Parameters
    ----------
    row : pandas.Series
        One record of a chunk.
    used_dims : sequence of int
        Positional indices of the columns the partitions are defined over.
    pidx : rtree.index.Index (or any object whose ``intersection(bounds)``
        returns an iterable of partition ids)
    pid_data_dict : dict
        Maps partition id -> list of rows; mutated in place.

    Raises
    ------
    IndexError
        If the row is not covered by any partition.
    '''
    row_values = row.to_numpy()
    # list(...) so tuples/ranges also work as a numpy fancy index.
    point = row_values[list(used_dims)].tolist()
    point_box = tuple(point + point)
    # Only the first hit is needed, so don't materialize every intersecting id.
    # If the point lies on a shared border several partitions match and the
    # first one returned by the index wins.
    pid = next(iter(pidx.intersection(point_box)), None)
    if pid is None:
        raise IndexError('row %s is not covered by any partition' % (point,))
    pid_data_dict[pid].append(row_values.tolist())


class DRThread(threading.Thread):
    '''
    Worker thread that routes every row of one chunk into the shared partition buffers.

    `parameters` is the list [chunk, used_dims, pidx, pid_data_dict]; each row
    of `chunk` (a pandas DataFrame) is passed to process_chunk_row.

    If routing fails, the exception is stored in `self.exception` so the
    joining thread can detect that the chunk was not fully routed. It is then
    re-raised so the thread's traceback is still reported.
    '''
    def __init__(self, thread_id, name, parameters):
        super().__init__(name=name)
        self.thread_id = thread_id
        self.parameters = parameters
        self.exception = None  # set by run() if routing this chunk fails

    def run(self):
        print('start thread: ', self.thread_id, self.name)
        chunk, used_dims, pidx, pid_data_dict = self.parameters
        try:
            # Explicit loop: the previous apply(axis=1) was used only for its
            # side effect, built a throw-away result Series, and probed the
            # function with a fake row when the chunk was empty.
            for _, row in chunk.iterrows():
                process_chunk_row(row, used_dims, pidx, pid_data_dict)
        except Exception as exc:
            self.exception = exc
            raise
        print('exit thread: ', self.thread_id, self.name)
        
class DumpThread(threading.Thread):
    '''
    Worker thread that appends the buffered rows of a slice of partitions to parquet files.

    `parameters` is the list
    [start_index, end_index, pids, pid_data_dict, hdfs_path, column_names].
    Each partition in pids[start_index:end_index] is converted to a pandas
    DataFrame, then to a Spark DataFrame via the notebook-global `sqlContext`.
    It is appended to hdfs_path + 'partition_<pid>.parquet'.

    Partitions with no buffered rows are skipped. Spark cannot infer a schema
    from an empty pandas frame, and an empty append adds nothing anyway.

    If a write fails, the exception is stored in `self.exception` and re-raised.
    '''
    def __init__(self, thread_id, name, parameters):
        super().__init__(name=name)
        self.thread_id = thread_id
        self.parameters = parameters
        self.exception = None  # set by run() if a write fails

    def run(self):
        print('start dumping thread: ', self.thread_id, self.name)
        start_index, end_index, pids, pid_data_dict, hdfs_path, column_names = self.parameters
        try:
            for pid in pids[start_index: end_index]:
                rows = pid_data_dict[pid]
                if not rows:
                    continue
                path = hdfs_path + 'partition_' + str(pid) + '.parquet'
                pdf = pd.DataFrame(rows, columns=column_names)
                df = sqlContext.createDataFrame(pdf)
                df.write.mode('append').parquet(path)
        except Exception as exc:
            self.exception = exc
            raise
        print('exit dumping thread: ', self.thread_id, self.name)

In [15]:
def kdnode_2_border(kdnode):
    '''
    Turn a kd-node into the box tuple used for the rtree index.

    kdnode[0] is a sequence of per-dimension [lower, upper] pairs. The result
    lists every lower bound first and then every upper bound,
    i.e. (l1, ..., ln, u1, ..., un), instead of per-dimension pairs.
    '''
    domains = kdnode[0]
    lows = tuple(d[0] for d in domains)
    highs = tuple(d[1] for d in domains)
    return lows + highs

def load_partitions_from_file(path):
    '''
    Load partition descriptions stored as one comma-separated row per node.

    Each row is laid out as
        [num_dims, l1, l2, ..., ln, u1, u2, ..., un, size, id, pid, left_child_id, right_child_id]
    with n = num_dims (taken from the first row). qd-tree partition files
    omit the last four fields.

    Returns a list with one entry per row:
        [domains, size]                                          (qd-tree rows), or
        [domains, size, id, pid, left_child_id, right_child_id]
    where domains = [[l1, u1], [l2, u2], ..., [ln, un]]. Values are the floats
    parsed by numpy. An empty file yields an empty list.
    '''
    # genfromtxt returns a 1-D array for a single-row file; force 2-D so the
    # row/column indexing below works for any number of rows.
    stretched_kdnodes = np.atleast_2d(genfromtxt(path, delimiter=','))
    if stretched_kdnodes.size == 0:
        return []
    num_dims = int(stretched_kdnodes[0, 0])
    size_col = 2 * num_dims + 1
    kdnodes = []
    for node in stretched_kdnodes:
        domains = [[node[1 + k], node[1 + num_dims + k]] for k in range(num_dims)]
        row = [domains, node[size_col]]
        # qd-tree partitions do not carry the trailing id / pid / child-id fields
        if len(node) > size_col + 1:
            row.extend(node[-4:])
        kdnodes.append(row)
    return kdnodes

# def dump_data_thread(start_index, end_index, pids, pid_data_dict, hdfs_path):
#     for pid in pids[start_index, end_index]:
#         path = hdfs_path + 'partition_' + str(pid)+'.parquet'
#         pdf = pd.DataFrame(pid_data_dict[pid], columns=column_names)
#         df = sqlContext.createDataFrame(pdf)
#         df.write.mode('append').parquet(path)

def dump_dict_data_2_hdfs(pid_data_dict, column_names, hdfs_path, num_threads = 8):
    pids = list(pid_data_dict.keys())
    step = int(len(pids) / num_threads) + 1
    threads = []
    for i in range(num_threads):
        start_index = i * step
        end_index = (i+1) * step
        parameters = [start_index, end_index, pids, pid_data_dict, hdfs_path, column_names]
        thread = DumpThread(i, 'dump_thread_'+str(i), parameters)
        thread.start()
        threads.append(thread)
        if start_index >= len(pids):
            break   
    for t in threads:
        t.join()
    

def _build_partition_index(partitions, properties):
    '''Build one rtree index over the partition borders; ids are the positions in `partitions`.'''
    partition_index = index.Index(properties=properties)
    for pid, partition in enumerate(partitions):
        partition_index.insert(pid, kdnode_2_border(partition))
    return partition_index


def _join_routing_threads(threads):
    '''Join all routing threads, then raise if any of them recorded a failure.'''
    for t in threads:
        t.join()
    for t in threads:
        exc = getattr(t, 'exception', None)
        if exc is not None:
            raise RuntimeError('routing thread %s failed' % t.name) from exc


def batch_data_parallel(table_path, partition_path, chunk_size, used_dims, hdfs_path, 
                        num_dims, dump_threshold = 1000000, max_threads = 8):
    '''
    Route every row of a CSV table to its partition and append the partitions to HDFS as parquet.

    The table is streamed in chunks of `chunk_size` rows. Each chunk is given to
    a DRThread, with up to `max_threads` running at once (one "epoch"). After
    each full epoch all workers are joined. Once at least `dump_threshold` rows
    have been routed since the last flush, the buffers are written with
    dump_dict_data_2_hdfs and cleared. Remaining rows are flushed at the end.

    Parameters
    ----------
    table_path : str
        Header-less CSV file; only its first `num_dims` columns are read.
    partition_path : str
        Partition file understood by load_partitions_from_file.
    chunk_size : int
        Rows per chunk handed to one routing thread.
    used_dims : list of int
        Column positions the partitions are defined over.
    hdfs_path : str
        Output directory prefix for the parquet files.
    num_dims : int
        Number of leading columns to read from the table.
    dump_threshold : int
        Minimum number of routed rows buffered before a flush.
    max_threads : int
        Number of routing threads (one rtree index is built per thread).

    Raises
    ------
    ValueError
        If max_threads < 1.
    RuntimeError
        If a routing thread recorded an exception (when DRThread exposes `exception`).
    '''
    if max_threads < 1:
        raise ValueError('max_threads must be >= 1')

    begin_time = time.time()

    col_names = ['_c' + str(i) for i in range(num_dims)]
    cols = list(range(num_dims))

    partitions = load_partitions_from_file(partition_path)

    p = index.Property()
    p.dimension = len(used_dims)  # rtree defaults to 2-D; match the routed dimensions
    p.leaf_capacity = 32
    p.index_capacity = 32
    p.near_minimum_overlap_factor = 16
    p.fill_factor = 0.8
    p.overwrite = True

    # The rtree index has problems under multi-threading, so build one index
    # per routing thread. Use max_threads (not a notebook global) so there is
    # exactly one index per thread slot.
    pidxs = [_build_partition_index(partitions, p) for _ in range(max_threads)]

    pid_data_dict = {pid: [] for pid in range(len(partitions))}

    count = 0
    pending_rows = 0  # rows routed since the last flush
    threads = []
    #for chunk in pd.read_table(table_path, delimiter='|', usecols=cols, names=col_names, chunksize=chunk_size):
    for chunk in pd.read_csv(table_path, usecols=cols, names=col_names, chunksize=chunk_size):

        print('current chunk: ', count)
        tid = count % max_threads
        parameters = [chunk, used_dims, pidxs[tid], pid_data_dict]
        thread = DRThread(tid, 'thread_' + str(tid) + '_' + str(count), parameters)
        thread.start()
        threads.append(thread)
        pending_rows += len(chunk)
        count += 1

        if tid == max_threads - 1:
            _join_routing_threads(threads)
            threads = []
            if pending_rows >= dump_threshold:
                dump_dict_data_2_hdfs(pid_data_dict, col_names, hdfs_path)
                for key in pid_data_dict:
                    pid_data_dict[key] = []
                pending_rows = 0

            print('===================================================')

    # The last epoch may be partial: its threads must finish before the final
    # dump, otherwise rows still being routed are missed or mutated mid-write.
    _join_routing_threads(threads)
    if pending_rows > 0:
        dump_dict_data_2_hdfs(pid_data_dict, col_names, hdfs_path) # last batch

    finish_time = time.time()
    print('total data routing and persisting time: ', finish_time - begin_time)

In [16]:
# = = = Unit Test = = =
# NOTE(review): absolute local paths — adjust for your machine.
table_path = '/home/cloudray/Downloads/TPCH_12M_8Field.csv'
# table_path = '/home/cloudray/TPCH/2.18.0_rc2/dbgen/lineitem.tbl'
partition_path = '/home/cloudray/NORA_Partitions/qd_tree_partitions'
num_threads = 8           # routing threads; passed as max_threads below
num_dims = 8              # leading columns read from the table
chunk_size = 10000        # rows per chunk handed to one routing thread
used_dims = [1,2]         # column positions the partitions are defined over
dump_threshold = 1000000  # rows buffered before flushing to HDFS
hdfs_path = 'hdfs://localhost:9000/user/cloudray/QdTree/'

# Pass num_threads explicitly so the configured value is the one actually used.
batch_data_parallel(table_path, partition_path, chunk_size, used_dims, hdfs_path,
                    num_dims, dump_threshold = dump_threshold, max_threads = num_threads)

current chunk:  0
start thread:  0 thread_0_0
current chunk:  1
start thread:  1 thread_1_1
current chunk:  2
start thread:  2 thread_2_2
current chunk:  3
start thread:  3 thread_3_3
current chunk:  4
start thread:  4 thread_4_4
current chunk:  5
start thread:  5 thread_5_5
current chunk:  6
start thread:  6 thread_6_6
current chunk:  7
start thread:  7 thread_7_7
exit thread:  1 thread_1_1
exit thread:  3 thread_3_3
exit thread:  4 thread_4_4
exit thread:  0 thread_0_0
exit thread:  7 thread_7_7
exit thread:  6 thread_6_6
exit thread:  5 thread_5_5
exit thread:  2 thread_2_2
current chunk:  8
start thread:  0 thread_0_8
current chunk:  9
start thread:  1 thread_1_9
current chunk:  10
start thread:  2 thread_2_10
current chunk:  11
start thread:  3 thread_3_11
current chunk:  12
start thread:  4 thread_4_12
current chunk:  13
start thread:  5 thread_5_13
current chunk:  14
start thread:  6 thread_6_14
current chunk:  15
start thread:  7 thread_7_15
exit thread:  6 thread_6_14
exit thr

KeyboardInterrupt: 

In [None]:
# try concurrent write to hdfs