In [1]:
import findspark
findspark.init() # this must be executed before the below import

In [2]:
from pyspark import SparkContext, SparkConf
from pyspark.sql import SQLContext
from pyspark.sql import SparkSession
from pyspark import SparkFiles

In [3]:
import pandas as pd
import time
import rtree
from rtree import index
import numpy as np
from numpy import genfromtxt
from multiprocessing import Pool
import threading

In [4]:
from DRProcess import *
from DDProcess import *

In [5]:
# Spark memory settings for this notebook: 8g each for executor, driver and off-heap storage.
conf = SparkConf().setAll([
    ("spark.executor.memory", "8g"),
    ("spark.driver.memory", "8g"),
    ("spark.memory.offHeap.enabled", True),
    ("spark.memory.offHeap.size", "8g"),
])

# getOrCreate keeps this cell re-runnable: constructing a second SparkContext in
# the same kernel raises, whereas getOrCreate hands back the running one.
# NOTE: when a context already exists, `conf` is NOT re-applied to it; restart
# the kernel to change the memory settings.
sc = SparkContext.getOrCreate(conf=conf)
sqlContext = SQLContext(sc)

In [6]:
# Show the effective Spark configuration to confirm the settings above were applied.
sc.getConf().getAll()

[('spark.app.id', 'local-1595604887433'),
 ('spark.memory.offHeap.size', '8g'),
 ('spark.executor.id', 'driver'),
 ('spark.driver.host', '10.0.2.15'),
 ('spark.app.name', 'pyspark-shell'),
 ('spark.rdd.compress', 'True'),
 ('spark.driver.port', '42805'),
 ('spark.driver.memory', '8g'),
 ('spark.serializer.objectStreamReset', '100'),
 ('spark.master', 'local[*]'),
 ('spark.executor.memory', '8g'),
 ('spark.submit.pyFiles', ''),
 ('spark.submit.deployMode', 'client'),
 ('spark.memory.offHeap.enabled', 'True'),
 ('spark.ui.showConsoleProgress', 'true')]

In [7]:
class DumpThread(threading.Thread):
    """Thread that appends one contiguous slice of buffered partitions to parquet.

    ``parameters`` is unpacked in ``run`` as
    ``(start_index, end_index, pids, pid_data_dict, hdfs_path, column_names)``;
    the thread handles ``pids[start_index:end_index]`` and relies on the
    notebook-level ``sqlContext`` to build the Spark DataFrames.
    """

    def __init__(self, thread_id, name, parameters):
        super().__init__()
        self.thread_id = thread_id
        self.name = name
        self.parameters = parameters

    def run(self):
        """Write each partition of this thread's slice, then reset its buffer."""
        print('start dumping thread: ', self.thread_id, self.name)
        first, last, all_pids, buffers, base_path, columns = self.parameters
        for partition_id in all_pids[first:last]:
            target = ''.join([base_path, 'partition_', str(partition_id), '.parquet'])
            frame = pd.DataFrame(buffers[partition_id], columns=columns)
            # 'append' mode: repeated dumps of the same partition accumulate.
            sqlContext.createDataFrame(frame).write.mode('append').parquet(target)
            # Replace the buffer with a fresh list once its rows are written.
            buffers[partition_id] = []
        print('exit dumping thread: ', self.thread_id, self.name)
        
def dump_dict_data_2_hdfs(pid_data_dicts, column_names, hdfs_path, num_threads = 8):
    """Merge per-worker partition buffers and append every partition to parquet.

    Parameters
    ----------
    pid_data_dicts : list of dict
        One dict per worker, mapping partition id -> list of buffered rows.
        The first dict is used as the merge target and is mutated in place;
        every other dict is emptied. After the call each written partition's
        buffer in the first dict is reset to ``[]``.
    column_names : list
        Column names used for the intermediate pandas DataFrame.
    hdfs_path : str
        Path prefix; each partition is written to
        ``hdfs_path + 'partition_<pid>.parquet'`` in append mode.
    num_threads : int, default 8
        Maximum number of ``DumpThread`` workers. A value of 1 or less writes
        everything sequentially on the calling thread.

    Returns
    -------
    None
    """
    # No buffers at all: nothing to merge or write.
    if not pid_data_dicts:
        return

    # First merge all the dicts into the first one.
    base_dict = pid_data_dicts[0]
    for other in pid_data_dicts[1:]:
        for pid, rows in other.items():
            if pid in base_dict:
                base_dict[pid] += rows
            else:
                base_dict[pid] = rows
        other.clear()

    pids = list(base_dict.keys())

    if num_threads <= 1:
        # Sequential path; also covers non-positive thread counts, which would
        # otherwise divide by zero or silently write nothing.
        print('start dumping single thread (main)')
        for pid in pids:
            path = hdfs_path + 'partition_' + str(pid)+'.parquet'
            pdf = pd.DataFrame(base_dict[pid], columns=column_names)
            df = sqlContext.createDataFrame(pdf)
            df.write.mode('append').parquet(path)
            base_dict[pid] = []
        print('finish dumping single thread (main)')
        return

    # Nothing buffered: do not spawn threads that would only see empty slices.
    if not pids:
        return

    # Ceiling division so the pids spread over up to `num_threads` workers and
    # every started thread owns a non-empty slice.
    step = -(-len(pids) // num_threads)
    threads = []
    for i, start_index in enumerate(range(0, len(pids), step)):
        end_index = start_index + step
        parameters = [start_index, end_index, pids, base_dict, hdfs_path, column_names]
        thread = DumpThread(i, 'dump_thread_'+str(i), parameters)
        thread.start()
        threads.append(thread)
    for t in threads:
        t.join()

In [8]:
def batch_data_parallel(table_path, partition_path, chunk_size, used_dims, hdfs_path,
                        num_dims, dump_threshold = 1000000, num_process = 8):
    """Route a CSV table to partitions in parallel and persist the result.

    The table is read in chunks of ``chunk_size`` rows. Every ``num_process``
    chunks are handed to a process pool, one chunk per worker, via
    ``process_chunk``; each call receives
    ``[chunk, used_dims, partition_path, pid_data_dict]`` and its return value
    replaces that worker's buffer. Once at least ``dump_threshold`` rows have
    been routed since the last dump, the buffers are handed to ``dump_data``
    (``[pid_data_dict, col_names, hdfs_path]`` per worker) and cleared.
    Leftover chunks and leftover buffers are handled after the read loop.

    Parameters
    ----------
    table_path : str
        CSV file to read (no header; columns are named ``_c0``..``_c<n-1>``).
    partition_path : str
        Passed through to ``process_chunk``.
    chunk_size : int
        Rows per chunk read from the table.
    used_dims : list
        Passed through to ``process_chunk``.
    hdfs_path : str
        Passed through to ``dump_data``.
    num_dims : int
        Number of leading columns to read from the table.
    dump_threshold : int, default 1000000
        Routed-row count that triggers an intermediate dump.
    num_process : int, default 8
        Pool size and number of per-worker buffers.

    Returns
    -------
    None
        Prints progress and the total elapsed time in seconds.
    """
    begin_time = time.time()

    col_names = ['_c'+str(i) for i in range(num_dims)]
    cols = list(range(num_dims))

    pid_data_dicts = [{} for _ in range(num_process)]
    chunks = []
    processed_data = 0

    # The pool should be reused, or incur memory leak!
    pool = Pool(processes = num_process)
    try:
        # For '|'-delimited .tbl input use pd.read_table(..., delimiter='|') instead.
        reader = pd.read_csv(table_path, usecols=cols, names=col_names, chunksize=chunk_size)
        for count, chunk in enumerate(reader):
            print('current chunk: ', count)
            chunks.append(chunk)
            if len(chunks) < num_process:
                continue

            # A full round of chunks is available: route one chunk per worker.
            paras = [[chunks[k], used_dims, partition_path, pid_data_dicts[k]] for k in range(num_process)]
            pid_data_dicts = pool.map(process_chunk, paras)
            print('===================================================')
            chunks = []
            processed_data += chunk_size * num_process

            if processed_data >= dump_threshold:
                paras = [[pid_data_dicts[k], col_names, hdfs_path] for k in range(num_process)]
                pool.map(dump_data, paras)
                processed_data = 0
                # Workers operate on pickled copies, so the buffers held by
                # this process must be emptied here.
                for buffered in pid_data_dicts:
                    buffered.clear()

        dict_size = [len(pid_data_dicts[i]) for i in range(num_process)]
        print('after exit, chunks size: ', len(chunks))
        print('after exit, each dict size: ', dict_size)

        # Process the last, incomplete round of chunks.
        if chunks:
            paras = [[chunks[k], used_dims, partition_path, pid_data_dicts[k]] for k in range(len(chunks))]
            pid_data_dicts[0:len(chunks)] = pool.map(process_chunk, paras)

        dict_size = [len(pid_data_dicts[i]) for i in range(num_process)]
        print('after last chunk, each dict size: ', dict_size)

        # Final flush. This must happen while the pool is still open, and it
        # must look at every buffer, not only the first one.
        if any(pid_data_dicts):
            paras = [[pid_data_dicts[k], col_names, hdfs_path] for k in range(num_process)]
            pool.map(dump_data, paras)
    finally:
        # Always release the worker processes, even if a step above failed.
        pool.close()
        pool.join()

    finish_time = time.time()
    print('total data routing and persisting time: ', finish_time - begin_time)

In [11]:
# = = = Execution = = =

# Worker / batching settings
num_process = 8
chunk_size = 10000 
dump_threshold = 80000

# Table layout: number of columns to read and the dims passed on for routing
num_dims = 8
used_dims = [1,2]

# Input table
table_path = '/home/cloudray/Downloads/TPCH_12M_8Field.csv'
# table_path = '/home/cloudray/TPCH/2.18.0_rc2/dbgen/lineitem.tbl'

# Partition description file
# partition_path = '/home/cloudray/NORA_Partitions/nora_partitions'
partition_path = '/home/cloudray/NORA_Partitions/qd_tree_partitions'

# Output location
# hdfs_path = 'hdfs://localhost:9000/user/cloudray/NORA/'
# hdfs_path = 'hdfs://localhost:9000/user/cloudray/QdTree/'
hdfs_path = 'hdfs://localhost:9000/user/cloudray/QdTreeTest/'

batch_data_parallel(
    table_path=table_path,
    partition_path=partition_path,
    chunk_size=chunk_size,
    used_dims=used_dims,
    hdfs_path=hdfs_path,
    num_dims=num_dims,
    dump_threshold=dump_threshold,
    num_process=num_process,
)

In [None]:
# Recorded results from earlier runs (elapsed wall-clock time, in seconds):
# total data routing and persisting time:  1398.6325314044952 # Nora
# total data routing and persisting time:  1193.6831953525543 # QdTree
# total data routing and persisting time:  1245.9338216781616 # QdTree Test
# Try multi-process