In [None]:
from pyspark.sql import SparkSession
from pyspark_processing.ranking import get_paragraph_data_from_run_file, get_ranked_entities_from_paragraph_data
from protocol_buffers import document_pb2

from pyspark.sql.types import StringType, ArrayType, FloatType
from pyspark.sql.functions import udf, row_number, explode, desc, col

import pickle
import re

In [None]:
# Resource settings consumed by the Spark session builder in the next cell.
spark_drive_gbs = 20  # GB; used for spark.driver.memory and spark.driver.maxResultSize
spark_executor_gbs = 3  # GB; used for spark.executor.memory
cores = 6  # thread count placed in the local[N] master URL


In [None]:
# Create (or reuse) the local Spark session, sized by the settings defined above.
print('\n//////// RUNNING WITH CORES {} //////////'.format(cores))
spark = (
    SparkSession.builder
    .appName('test')
    .master('local[{}]'.format(cores))
    .config("spark.driver.memory", '{}g'.format(spark_drive_gbs))
    .config("spark.executor.memory", '{}g'.format(spark_executor_gbs))
    # Results collected to the driver are capped at the same size as the driver memory setting.
    .config("spark.driver.maxResultSize", '{}g'.format(spark_drive_gbs))
    .getOrCreate()
)

In [None]:
def get_news_ids_maps(xml_topics_path, rank_type='passage'):
    """ Build dict maps from intermediate topic ids to the ids stored in a topics file.

    The file is scanned line by line (it is not parsed as XML). Each tag is
    expected to open and close on a single line:

    * ``<num> Number: X </num>`` sets the current intermediate passage id ``X``.
    * ``<docid>Y</docid>`` records ``passage_id_map[X] = Y``.
    * ``<id> A </id>`` sets the current intermediate entity id ``A``
      (only when ``rank_type == 'entity'``).
    * ``<link>B</link>`` records ``entity_id_map[A] = B``
      (only when ``rank_type == 'entity'``).

    Whitespace around a tag value is ignored, so ``<docid> Y </docid>`` and
    ``<docid>Y</docid>`` yield the same id.

    Args:
        xml_topics_path: path of the topics file; read as UTF-8.
        rank_type: when equal to 'entity' the entity map is filled as well;
            for any other value the entity map is returned empty.

    Returns:
        Tuple ``(passage_id_map, entity_id_map)`` of dicts
        ``{intermediate_id: id}``.

    Raises:
        ValueError: if a line contains an opening tag but the value cannot be
            extracted, or if a ``<docid>``/``<link>`` appears before any
            ``<num>``/``<id>`` that it could be attached to.
    """
    def _tag_value(pattern, line, line_no):
        # Pull the single captured value out of a tag line, failing with the
        # offending line instead of an opaque IndexError.
        found = re.search(pattern, line)
        if found is None:
            raise ValueError(
                "Malformed topic line {} in {}: {!r}".format(line_no, xml_topics_path, line))
        return found.group(1)

    def _require_open(temp_id, opener, closer, line_no):
        # A closing id must follow the tag that names its intermediate id.
        if temp_id is None:
            raise ValueError(
                "{} on line {} of {} has no preceding {}".format(closer, line_no, xml_topics_path, opener))

    passage_id_map = {}
    entity_id_map = {}
    passage_temp_id = None
    entity_temp_id = None
    with open(xml_topics_path, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            # Passage intermediate_id
            if '<num>' in line:
                passage_temp_id = _tag_value(r'<num>\s*Number:\s*(.*?)\s*</num>', line, line_no)
            # Passage id
            if '<docid>' in line:
                passage_id = _tag_value(r'<docid>\s*(.*?)\s*</docid>', line, line_no)
                _require_open(passage_temp_id, '<num>', '<docid>', line_no)
                passage_id_map[passage_temp_id] = passage_id

            if rank_type == 'entity':
                # Entity intermediate_id
                if '<id>' in line:
                    entity_temp_id = _tag_value(r'<id>\s*(.*?)\s*</id>', line, line_no)
                # Entity id
                if '<link>' in line:
                    entity_id = _tag_value(r'<link>\s*(.*?)\s*</link>', line, line_no)
                    _require_open(entity_temp_id, '<id>', '<link>', line_no)
                    entity_id_map[entity_temp_id] = entity_id

    return passage_id_map, entity_id_map

def get_top_100_rank(spark, run_path, rank_type='entity', k=100, xml_topics_path=None):
    """Load the rows of a run file whose rank is at most ``k`` into a Spark DataFrame.

    Every non-blank line of the run file must hold six whitespace-separated
    fields; the first is the query id, the third the document id and the
    fourth the integer rank (the other three are ignored). Blank lines are
    skipped.

    When ``xml_topics_path`` is given, query ids are translated through the
    map built by ``get_news_ids_maps`` for ``rank_type``; rows whose query id
    is absent from that map are dropped. When it is None the query ids from
    the run file are kept unchanged.

    Args:
        spark: object exposing ``createDataFrame(data, column_names)``.
        run_path: path of the run file; read as UTF-8.
        rank_type: 'passage' or 'entity'. It selects the id map and names the
            rank column ``"<rank_type>_rank"``. For any other value no rows
            are collected.
        k: highest rank (inclusive) to keep.
        xml_topics_path: optional topics file used to translate query ids.

    Returns:
        Whatever ``spark.createDataFrame`` returns for the collected
        ``[query, doc_id, rank]`` rows with columns
        ``["query", "doc_id", "<rank_type>_rank"]``.

    Raises:
        ValueError: if a non-blank line does not have exactly six fields or
            its rank is not an integer.
    """
    if xml_topics_path is None:
        id_map = None
    else:
        passage_id_map, entity_id_map = get_news_ids_maps(xml_topics_path=xml_topics_path, rank_type=rank_type)
        id_map = entity_id_map if rank_type == 'entity' else passage_id_map

    # Only these two rank types ever contribute rows.
    collect_rows = rank_type in ('passage', 'entity')

    data = []
    with open(run_path, 'r', encoding='utf-8') as f_run:
        for line in f_run:
            fields = line.split()
            if not fields:
                # Tolerate blank / whitespace-only lines (e.g. a trailing newline).
                continue
            query, _, doc_id, rank, _, _ = fields
            rank = int(rank)
            if not collect_rows or rank > k:
                continue
            if id_map is None:
                data.append([query, doc_id, rank])
            elif query in id_map:
                data.append([id_map[query], doc_id, rank])
    return spark.createDataFrame(data, ["query", "doc_id", "{}_rank".format(rank_type)])


@udf(returnType=FloatType())
def get_entity_links_count(entity_link_ids):
    """Spark UDF: number of ids in an entity-link array column, as a float.

    A null column value reaches the UDF as None; it is counted as zero links
    rather than letting ``len(None)`` raise a TypeError.
    """
    if entity_link_ids is None:
        return 0.0
    return float(len(entity_link_ids))


@udf(returnType=ArrayType(StringType()))
def get_synthetic_entity_link_ids_passage(article_bytearray):
    """Spark UDF: entity ids linked from the first content item of a pickled article.

    The column value is unpickled, decoded with ``document_pb2.Document.FromString``
    and the ``rel_entity_links`` of ``document_contents[0]`` are read; only the
    first content item is considered.

    Returns:
        List of ``str(entity_id)`` values, or an empty list when the column is
        null or the document has no contents.
    """
    if article_bytearray is None:
        return []
    # NOTE(review): pickle.loads can execute arbitrary code; only safe if the parquet data is trusted.
    article = pickle.loads(article_bytearray)
    document_contents = document_pb2.Document.FromString(article).document_contents
    if len(document_contents) == 0:
        # Nothing to index into; avoid an IndexError on documents with no contents.
        return []
    # NOTE(review): reads `rel_entity_links` although the function name says "synthetic"
    # and get_synthetic_entity_link_ids_entity reads `synthetic_entity_links` - confirm the field.
    return [str(link.entity_id) for link in document_contents[0].rel_entity_links]


@udf(returnType=ArrayType(StringType()))
def get_synthetic_entity_link_ids_entity(doc_bytearray):
    """Spark UDF: entity ids of the synthetic entity links across all contents of a pickled document.

    The column value is unpickled, decoded with ``document_pb2.Document.FromString``
    and the ``synthetic_entity_links`` of every item in ``document_contents``
    are concatenated in order.

    Returns:
        List of ``str(entity_id)`` values, or an empty list when the column is null.
    """
    if doc_bytearray is None:
        return []
    # NOTE(review): pickle.loads can execute arbitrary code; only safe if the parquet data is trusted.
    doc = pickle.loads(doc_bytearray)
    document = document_pb2.Document.FromString(doc)
    return [
        str(link.entity_id)
        for document_content in document.document_contents
        for link in document_content.synthetic_entity_links
    ]

def get_passage_df(spark, passage_run_path, xml_topics_path, passage_parquet_path):
    """Return one row per (ranked passage, linked entity).

    The top-100 passage ranking from ``passage_run_path`` is left-joined on
    ``doc_id`` with the passage parquet data, which has been extended with an
    ``entity_links`` array and its ``entity_links_count``. The array is then
    exploded into an ``entity_id`` column.

    Returns:
        Spark DataFrame with columns ``query``, ``doc_id``, ``passage_rank``,
        ``entity_links_count`` and ``entity_id``.
    """
    ranking = get_top_100_rank(
        spark=spark,
        run_path=passage_run_path,
        rank_type='passage',
        k=100,
        xml_topics_path=xml_topics_path,
    )

    # Passage payloads, extended with the linked entity ids and how many there are.
    passages = (
        spark.read.parquet(passage_parquet_path)
        .select("doc_id", "article_bytearray")
        .withColumn("entity_links", get_synthetic_entity_link_ids_passage("article_bytearray"))
        .withColumn("entity_links_count", get_entity_links_count("entity_links"))
    )

    joined = ranking.join(passages, on=['doc_id'], how='left')
    return joined.select(
        "query",
        "doc_id",
        "passage_rank",
        "entity_links_count",
        explode("entity_links").alias("entity_id"),
    )


def get_entity_df(spark, entity_run_path, entity_parquet_path, xml_topics_path):
    """Return the top-100 entity ranking as (entity_id, query, entity_rank) rows.

    The ranking from ``entity_run_path`` is left-joined on ``doc_id`` with the
    entity parquet data (whose ``page_id`` column is exposed as ``doc_id``).
    Only columns coming from the ranking are selected afterwards, so the
    parquet side can influence row multiplicity but not the column values.
    NOTE(review): confirm whether the join is still needed.

    Side effects:
        Prints the ranking DataFrame's repr and the schema of the parquet
        selection.

    Returns:
        Spark DataFrame with columns ``entity_id`` (the ranked ``doc_id``),
        ``query`` and ``entity_rank``.
    """
    entity_rank_df = get_top_100_rank(spark=spark,
                                      run_path=entity_run_path,
                                      rank_type='entity',
                                      k=100,
                                      xml_topics_path=xml_topics_path)
    print(entity_rank_df)
    entity_df = spark.read.parquet(entity_parquet_path).select(col("page_id").alias("doc_id"), "doc_bytearray")
    # printSchema is a method: without the call parentheses nothing was printed.
    entity_df.printSchema()
    entity_join_df = entity_rank_df.join(entity_df, on=['doc_id'], how='left')
    entity_df_with_entity_links_reduced = entity_join_df.select(col("doc_id").alias("entity_id"), "query", "entity_rank")
    return entity_df_with_entity_links_reduced

def build_news_graph(spark, passage_run_path, passage_xml_topics_path, passage_parquet_path, entity_run_path,
                   entity_parquet_path, entity_xml_topics_path):
    """Join passage and entity rankings and score the passage -> entity edges.

    Steps:
        1. Build the passage rows (one per ranked passage and linked entity)
           and the entity ranking rows, showing each.
        2. Left-join them on ``query`` and ``entity_id``; null values are
           replaced via ``fillna(0.0)``.
        3. Give every row the weight ``1 / entity_links_count`` and sum it per
           (query, entity_id, passage_rank) into ``sum(graph_weigh)``.
        4. Sum those per query into ``sum(sum(graph_weigh))`` and join the
           total back on ``query``.
        5. Add ``passage_score = 1 / passage_rank`` and
           ``entity_score = sum(sum(graph_weigh)) * passage_score``.

    Side effects:
        Prints progress banners and calls ``show()`` on the intermediate and
        final Spark DataFrames.

    Returns:
        pandas DataFrame of the final rows, sorted by ``query`` then
        ``passage_score``, both descending.
    """
    print("BUILDING PASSAGE DF")
    passage_df = get_passage_df(spark=spark,
                                passage_run_path=passage_run_path,
                                xml_topics_path=passage_xml_topics_path,
                                passage_parquet_path=passage_parquet_path)
    passage_df.show()

    print("BUILDING ENTITY DF")
    entity_df = get_entity_df(spark=spark,
                              entity_run_path=entity_run_path,
                              xml_topics_path=entity_xml_topics_path,
                              entity_parquet_path=entity_parquet_path)
    entity_df.show()

    print("JOINING")
    df = passage_df.join(entity_df, on=['query', 'entity_id'], how='left').fillna(0.0)
    df.show()

    print("BUILDING GRAPH WEIGHTS")
    @udf(returnType=FloatType())
    def get_graph_weight(entity_links_count):
        # Null or zero link counts get weight 0.0 instead of failing the Spark task.
        if not entity_links_count:
            return 0.0
        return 1 / entity_links_count

    @udf(returnType=FloatType())
    def get_passage_score(passage_rank):
        # Reciprocal rank; a null or zero rank scores 0.0 instead of failing the Spark task.
        if not passage_rank:
            return 0.0
        return 1 / passage_rank

    @udf(returnType=FloatType())
    def get_entity_score(sum_graph_wight, passage_score):
        # A missing operand yields 0.0 rather than a TypeError inside the UDF.
        if sum_graph_wight is None or passage_score is None:
            return 0.0
        return sum_graph_wight * passage_score

    df_weighting = df.withColumn("graph_weigh", get_graph_weight("entity_links_count"))
    df_entity_rank = df_weighting.groupBy("query", "entity_id", "passage_rank").agg({"graph_weigh": "sum"}).fillna(0.0)

    df_norm = df_entity_rank.groupBy("query").agg({"sum(graph_weigh)": "sum"})
    df_entity_rank_with_norm = df_entity_rank.join(df_norm, on=["query"])

    # The final sort needs a `passage_score` column, so it is materialised here.
    df_entity_rank_with_score = df_entity_rank_with_norm.withColumn("passage_score", get_passage_score("passage_rank"))
    # NOTE(review): the previous code passed "sum(sum(graph_weigh))" (the per-query total) into the
    # score; that column is kept, but the per-entity "sum(graph_weigh)" may have been intended - confirm.
    df_entity_rank_with_passage_score = df_entity_rank_with_score.withColumn(
        "entity_score", get_entity_score("sum(sum(graph_weigh))", "passage_score"))

    df_entity_rank_with_passage_score.show()
    return df_entity_rank_with_passage_score.toPandas().sort_values(["query", "passage_score"], ascending=False)
#df = get_paragraph_data_from_run_file(spark, run_path, para_path, max_counter=10000000000)

In [None]:

# Hard-coded absolute input locations handed to build_news_graph below.
# NOTE(review): these only resolve on a machine with the /nfs mount - consider a configurable base dir.
passage_run_path = '/nfs/trec_news_track/runs/anserini/background_2018/anserini.bm5.default.run'
passage_xml_topics_path = '/nfs/trec_news_track/data/2018/newsir18-topics.txt'
passage_parquet_path = '/nfs/trec_news_track/index/2018_bm25_rm3_chunks_full_v1/'
entity_run_path = '/nfs/trec_news_track/runs/anserini/entity_2018/entity_ranking_BM25_1000.run'
entity_parquet_path = '/nfs/trec_car/data/test_entity/full_data_v3_with_datasets/'
# None: get_top_100_rank then keeps the query ids from the entity run file untranslated.
entity_xml_topics_path = None

# Run the whole pipeline with the paths above.
build_news_graph(spark=spark, 
                 passage_run_path=passage_run_path, 
                 passage_xml_topics_path=passage_xml_topics_path, 
                 passage_parquet_path=passage_parquet_path, 
                 entity_run_path=entity_run_path,
                 entity_parquet_path=entity_parquet_path, 
                 entity_xml_topics_path=entity_xml_topics_path)