In [None]:
from pyspark.sql.types import BinaryType, StringType, ArrayType
from pyspark.sql.functions import udf, row_number, monotonically_increasing_id, col, collect_list, concat_ws, explode
from pyspark.sql import SparkSession, Window

from protocol_buffers import document_pb2

import pickle

In [None]:
# Spark sizing knobs consumed by the SparkSession builder cell below.
# Driver heap in GB; also reused as spark.driver.maxResultSize.
spark_drive_gbs = 3
# Value passed as spark.executor.memory, in GB.
spark_executor_gbs = 50
# Number of local worker threads: the master is set to 'local[<cores>]'.
cores = 14

# Parquet dataset read into `df`; later cells use its `content_id` and
# `content_bytearray` columns.
para_path = '/nfs/trec_car/data/test_entity/full_data_v3_with_datasets_contents_v4/'


In [None]:
# Log the degree of parallelism before the session is created.
print('\n//////// RUNNING WITH CORES {} //////////'.format(cores))

# Build (or reuse) a local-mode session sized from the config cell.
# NOTE(review): with a 'local[N]' master the tasks presumably run inside the
# driver JVM, so spark.executor.memory may have no effect and the driver heap
# would be the real limit -- confirm against the Spark configuration docs.
spark = (
    SparkSession.builder
    .appName('test')
    .master('local[{}]'.format(cores))
    .config("spark.driver.memory", '{}g'.format(spark_drive_gbs))
    .config("spark.executor.memory", '{}g'.format(spark_executor_gbs))
    .config("spark.driver.maxResultSize", '{}g'.format(spark_drive_gbs))
    .getOrCreate()
)

In [None]:
# Maps '<task>_<split>' -> (run file path, qrels file path).
# Index 0 of each tuple is the '.run' file, index 1 the '.qrels' file.
dataset_metadata = {
    # ---- entity ranking ----
    'entity_train': (
        '/nfs/trec_car/data/entity_ranking/multi_task_data/entity_train_all_queries_BM25_1000.run',
        '/nfs/trec_car/data/entity_ranking/benchmarkY1_hierarchical_entity_train_data/benchmarkY1_train_entity.qrels',
    ),
    'entity_dev': (
        '/nfs/trec_car/data/entity_ranking/multi_task_data/entity_dev_all_queries_BM25_1000.run',
        '/nfs/trec_car/data/entity_ranking/benchmarkY1_hierarchical_entity_dev_data/benchmarkY1_dev_entity.qrels',
    ),
    'entity_test': (
        '/nfs/trec_car/data/entity_ranking/multi_task_data/entity_test_all_queries_BM25_1000.run',
        '/nfs/trec_car/data/entity_ranking/testY1_hierarchical_entity_data/testY1_hierarchical_entity.qrels',
    ),
    # ---- passage ranking ----
    'passage_train': (
        '/nfs/trec_car/data/entity_ranking/benchmarkY1_hierarchical_passage_train_data/benchmarkY1_train_passage_1000.run',
        '/nfs/trec_car/data/entity_ranking/benchmarkY1_hierarchical_passage_train_data/benchmarkY1_train_passage.qrels',
    ),
    'passage_dev': (
        '/nfs/trec_car/data/entity_ranking/benchmarkY1_hierarchical_passage_dev_data/benchmarkY1_dev_passage_1000.run',
        '/nfs/trec_car/data/entity_ranking/benchmarkY1_hierarchical_passage_dev_data/benchmarkY1_dev_passage.qrels',
    ),
    'passage_test': (
        '/nfs/trec_car/data/entity_ranking/testY1_hierarchical_passage_data/testY1_hierarchical_passage_1000.run',
        '/nfs/trec_car/data/entity_ranking/testY1_hierarchical_passage_data/testY1_hierarchical_passage.qrels',
    ),
}

In [None]:
# Read the parquet data under `para_path` into a Spark DataFrame.
# Later cells rely on it having `content_id` and `content_bytearray` columns.
df = spark.read.parquet(para_path)
# Print the column names/types so the schema can be checked by eye.
df.printSchema()

In [None]:
@udf(returnType=ArrayType(StringType()))
def get_ents(content_bytearray):
    """Spark UDF returning the synthetic entity-link ids of one document.

    The input value is unpickled, the result is parsed as a
    ``document_pb2.DocumentContent`` message, and the ``entity_id`` of every
    entry in its ``synthetic_entity_links`` field is returned as a string,
    in the order the links appear in the message.
    """
    # NOTE(review): pickle.loads can execute arbitrary code on crafted input;
    # this is only safe if the parquet data comes from a trusted source.
    serialized_message = pickle.loads(content_bytearray)
    document = document_pb2.DocumentContent().FromString(serialized_message)
    return [str(link.entity_id) for link in document.synthetic_entity_links]

# Add an "entities" column (array of strings) by applying the get_ents UDF
# to each row's `content_bytearray` value.
df_entity = df.withColumn("entities", get_ents("content_bytearray"))
# Print the schema again to confirm the new column is present.
df_entity.printSchema()
            
    
    

In [None]:
base_dir = '/nfs/trec_car/data/entity_ranking/multi_task_data_by_query/'
# Only documents ranked at or above this position in a run file are kept.
max_rank = 100


def _read_run(run_path, rank_cutoff):
    """Parse a run file, keeping only entries with rank <= ``rank_cutoff``.

    Every non-blank line is split on whitespace once; field 0 is used as the
    query id, field 2 as the document id and field 3 as the integer rank
    (presumably the standard TREC run layout -- confirm against the files).
    Blank lines are skipped; lines with too few fields or a non-integer rank
    still raise, so a corrupt file is not silently accepted.

    Returns:
        tuple: ``(run_dict, doc_ids)`` where ``run_dict`` maps each query id to
        the list of its kept document ids in file order, and ``doc_ids`` is the
        set of all kept document ids.
    """
    run_dict = {}
    doc_ids = set()
    with open(run_path, 'r') as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue
            query, doc_id, rank = fields[0], fields[2], int(fields[3])
            if rank <= rank_cutoff:
                run_dict.setdefault(query, []).append(doc_id)
                doc_ids.add(doc_id)
    return run_dict, doc_ids


for dataset in ['dev', 'test', 'train']:
    # NOTE(review): the name is misspelled ("dateset") and the value is not
    # used in this cell; kept as-is in case another cell reads it.
    dateset_dir = base_dir + '{}_data/'.format(dataset)
    passage_name = 'passage' + '_{}'.format(dataset)
    # Index 0 of the metadata tuple is the run file.
    passage_path = dataset_metadata[passage_name][0]

    print('Building passage->entity mappings for {}: {}'.format(dataset, passage_path))
    run_dict, doc_id_set = _read_run(passage_path, max_rank)
    query_list = sorted(run_dict)

    doc_ids_list = list(doc_id_set)
    print("doc_ids_list len = {}".format(len(doc_ids_list)))

    # Spark DataFrames have no pandas-style set_index()/to_dict(); pull the
    # two needed columns to the driver and build the mapping there.
    # NOTE(review): collect() materialises every matching row on the driver,
    # which must fit within the configured driver memory / maxResultSize.
    entity_rows = (
        df_entity
        .filter(df_entity['content_id'].isin(doc_ids_list))
        .select("content_id", "entities")
        .collect()
    )
    # content_id -> list of entity ids for that passage.
    dataset_map = {row['content_id']: row['entities'] for row in entity_rows}
    print("dataset_map len = {}".format(len(dataset_map)))

    # Show a handful of entries instead of dumping the whole mapping.
    for content_id, entities in list(dataset_map.items())[:5]:
        print(content_id, entities)
    
#     df_entity = 
#     for query_i, query in enumerate(query_list):
#         print("Processing {} ({} / {})".format(query, query_i+1, len(query_list)))
        
#         path = base_dir + '{}_entities.json'.format(query_i)
#         query_json = {}
#         df_dataset = df_entity[df_entity['content_id'].isin(run_dict[query])].select("content_id", "entities")
#         for row in df_dataset.collect():
#             print(row)
