In [None]:
from pyspark.sql.types import BinaryType, StringType, ArrayType
from pyspark.sql.functions import udf, row_number, monotonically_increasing_id, col, collect_list, concat_ws, explode
from pyspark.sql import SparkSession, Window

from protocol_buffers import document_pb2

import pickle

In [None]:
# Spark resources: driver memory (GB), executor memory (GB), and local core count.
spark_drive_gbs, spark_executor_gbs, cores = 3, 50, 10

# Parquet chunk directories to read; plain BM25 runs first, then BM25+RM3 runs.
# NOTE(review): the version suffixes differ between years (_v1 vs _v2) — presumably
# intentional (latest build of each index); confirm before re-running.
path_2018, path_2019 = (
    '/nfs/trec_news_track/index/2018_bm25_chunks_full_v2/',
    '/nfs/trec_news_track/index/2019_bm25_chunks_full_v1/',
)
path_2018_rm3, path_2019_rm3 = (
    '/nfs/trec_news_track/index/2018_bm25_rm3_chunks_full_v1/',
    '/nfs/trec_news_track/index/2019_bm25_rm3_chunks_full_v2/',
)

In [None]:
print('\n//////// RUNNING WITH CORES {} //////////'.format(cores))

# Build (or reuse) a local-mode SparkSession sized from the config constants.
# NOTE(review): with a local[N] master the executors run inside the driver JVM, so
# spark.driver.memory is presumably the heap that actually bounds this job and
# spark.executor.memory likely has no effect — confirm before tuning memory.
spark = (
    SparkSession.builder
    .appName('test')
    .master('local[{}]'.format(cores))
    .config("spark.driver.memory", '{}g'.format(spark_drive_gbs))
    .config("spark.executor.memory", '{}g'.format(spark_executor_gbs))
    .config("spark.driver.maxResultSize", '{}g'.format(spark_drive_gbs))
    .getOrCreate()
)

In [None]:

def _expand_entity_links(rel_entity_link_totals):
    """Flatten per-entity link totals into one list entry per link occurrence.

    Each item of ``rel_entity_link_totals`` must expose ``entity_id`` and
    ``anchor_text_frequencies`` (whose items carry ``frequency``), as read from
    ``document_pb2.Document.rel_entity_link_totals``.

    Returns a list of strings in which ``str(entity_id)`` appears N times, where N
    is the sum of that entity's anchor-text frequencies (nothing when N <= 0).
    """
    entity_links = []
    for rel_entity_link_total in rel_entity_link_totals:
        entity_id = str(rel_entity_link_total.entity_id)
        frequency = sum(int(anchor_text_frequency.frequency)
                        for anchor_text_frequency in rel_entity_link_total.anchor_text_frequencies)
        # A list multiplied by a non-positive int is empty, matching range(frequency).
        entity_links.extend([entity_id] * frequency)
    return entity_links


@udf(returnType=ArrayType(StringType()))
def get_ents(article_bytearray):
    """Spark UDF: decode a pickled ``Document`` protobuf and return its entity-link list.

    ``article_bytearray`` is unpickled to obtain serialized ``document_pb2.Document``
    bytes; the result is the list produced by ``_expand_entity_links``.
    """
    # NOTE(review): pickle.loads can execute arbitrary code. This is only acceptable if
    # the parquet chunks come from our own trusted pipeline — confirm.
    # FromString is callable on the message class; no throwaway instance is needed.
    document = document_pb2.Document.FromString(pickle.loads(article_bytearray))
    return _expand_entity_links(document.rel_entity_link_totals)


# doc_id -> list of entity-id strings. If a doc_id appears in several indexes (or
# several times in one index), the first occurrence encountered is kept.
doc_to_ent_map = {}

for path in [path_2018, path_2019, path_2018_rm3, path_2019_rm3]:

    df = spark.read.parquet(path)
    df.printSchema()

    df_entity = df.withColumn("entities", get_ents("article_bytearray"))
    df_entity_reduced = df_entity.select("doc_id", "entities")
    df_entity_reduced.printSchema()

    # Stream rows one partition at a time rather than collect(). collect() materialises
    # the whole index on the driver at once, which can exceed the small driver heap /
    # spark.driver.maxResultSize configured above. Row order is the same as collect().
    for row in df_entity_reduced.toLocalIterator():
        doc_id, entities = row[0], row[1]
        doc_to_ent_map.setdefault(doc_id, entities)

In [None]:
import json

# Persist the doc_id -> entity-id list mapping as JSON for the 5-fold experiments.
path = '/nfs/trec_news_track/data/5_fold/scaled_5fold_0_data/doc_to_entity_map.json'
with open(path, mode='w') as out_file:
    out_file.write(json.dumps(doc_to_ent_map))