In [None]:
import json
import numpy as np
import collections
import copy
from os import listdir
from os.path import isfile, join

In [None]:
import findspark

# findspark.init() is called before pyspark is imported; keep this ordering
# (presumably it is what makes the pyspark package importable — confirm for your environment).
findspark.init()
from pyspark import SparkContext
import pyspark

# Spark resource settings, passed as (key, value) string pairs:
# executor memory/cores/instance count, plus driver memory and the cap on
# the total size of results returned to the driver.
conf = pyspark.SparkConf().setAll(
    [
        ("spark.executor.memory", "8g"),
        ("spark.executor.cores", "2"),
        ("spark.executor.instances", "7"),
        ("spark.driver.memory", "32g"),
        ("spark.driver.maxResultSize", "10g"),
    ]
)
# Single SparkContext used by every RDD operation in this notebook.
sc = SparkContext(conf=conf)

In [None]:
from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, FloatType, StringType
from pyspark.sql.types import Row
from pyspark.sql import SparkSession

# Wrap the existing SparkContext in a SparkSession; `spark.createDataFrame` is used later on.
spark = SparkSession(sc)

In [None]:
def convert_ndarray_back(x):
    """Turn the "entityCell" entry of a parsed table back into a NumPy array.

    The mapping is modified in place and the very same object is returned,
    so the function can be used directly inside a ``map`` call.
    """
    cell_matrix = np.array(x["entityCell"])
    x["entityCell"] = cell_matrix
    return x


data_dir = "../../data/"
# One JSON document per line: parse each line and restore its "entityCell" array.
train_tables = sc.textFile(data_dir + "train_tables.jsonl").map(
    lambda raw_line: convert_ndarray_back(json.loads(raw_line.strip()))
)

In [None]:
def get_core_entity_caption_label(x):
    """Extract the core entities and descriptive fields of one table.

    A cell contributes a core entity when it is non-zero in ``x["entityCell"]``,
    sits in column 0, and column 0 is listed in ``x["entityColumn"]``. The entity
    id is read from the first surface link of that cell.

    Returns:
        A 4-tuple ``(entity_ids, table_id, caption, first_header)`` where
        ``entity_ids`` is a de-duplicated list.
    """
    core_entities = {
        x["tableData"][row][col]["surfaceLinks"][0]["target"]["id"]
        for row, col in zip(*x["entityCell"].nonzero())
        if col == 0 and col in x["entityColumn"]
    }
    first_header = x["processed_tableHeaders"][0]
    return list(core_entities), x["_id"], x["tableCaption"], first_header

In [None]:
from operator import add

In [None]:
# One record per table: (core entity ids, table id, caption, first header).
table_rdd = train_tables.map(get_core_entity_caption_label)
# Flatten to one record per (entity, table): (entity, table id, caption, first header).
entity_rdd = table_rdd.flatMap(
    lambda record: [(entity, record[1], record[2], record[3]) for entity in record[0]]
)

In [None]:
from pyspark.ml.feature import Tokenizer, StopWordsRemover

In [None]:
# Column names follow the tuple layout returned by get_core_entity_caption_label.
table_columns = ["entities", "table_id", "caption", "header"]
table_df = spark.createDataFrame(table_rdd, schema=table_columns)

In [None]:
# Caption text: tokenize, then remove stop words.
caption_tokenizer = Tokenizer(inputCol="caption", outputCol="caption_term")
caption_remover = StopWordsRemover(inputCol="caption_term", outputCol="caption_term_cleaned")

# Header text: same two stages on the header column.
header_tokenizer = Tokenizer(inputCol="header", outputCol="header_term")
header_remover = StopWordsRemover(inputCol="header_term", outputCol="header_term_cleaned")

# Default English stop words, kept for inspection in the next cell.
list_stopwords = StopWordsRemover.loadDefaultStopWords("english")

In [None]:
# Display the stop-word list returned by StopWordsRemover.loadDefaultStopWords("english").
list_stopwords

In [None]:
# Run the four stages one after another: caption tokenize -> caption stop-word removal
# -> header tokenize -> header stop-word removal, then keep only the columns used later.
caption_tokens_df = caption_tokenizer.transform(table_df)
caption_cleaned_df = caption_remover.transform(caption_tokens_df)
header_tokens_df = header_tokenizer.transform(caption_cleaned_df)
header_cleaned_df = header_remover.transform(header_tokens_df)
table_df_tokenizered = header_cleaned_df.select(
    "entities", "table_id", "caption_term_cleaned", "header_term_cleaned", "header"
)

In [None]:
# Preview rows of the tokenized / stop-word-filtered table DataFrame.
table_df_tokenizered.show()

In [None]:
# Corpus-wide counts, each collected to the driver as a list of (key, count) pairs.

# Cleaned caption terms.
caption_rows = table_df_tokenizered.select("caption_term_cleaned").rdd
caption_term_freq = (
    caption_rows.flatMap(lambda row: [(term, 1) for term in row["caption_term_cleaned"]])
    .reduceByKey(add)
    .collect()
)

# Cleaned header terms.
header_term_rows = table_df_tokenizered.select("header_term_cleaned").rdd
header_term_freq = (
    header_term_rows.flatMap(lambda row: [(term, 1) for term in row["header_term_cleaned"]])
    .reduceByKey(add)
    .collect()
)

# Whole (untokenized) header strings.
header_rows = table_df_tokenizered.select("header").rdd
header_freq = header_rows.map(lambda row: (row["header"], 1)).reduceByKey(add).collect()

In [None]:
# Number of (header, count) pairs, i.e. distinct header strings after reduceByKey.
len(header_freq)

In [None]:
# One row per (entity, table): explode the per-table entity list into an "entity" column.
entity_column = F.explode("entities").alias("entity")
entity_df = table_df_tokenizered.select(
    entity_column, "table_id", "caption_term_cleaned", "header_term_cleaned", "header"
)

In [None]:
def _collect_counts_by_entity(keyed_ones):
    """Sum ((entity, item), 1) records and regroup them as (entity, [(item, count), ...])."""
    return (
        keyed_ones.reduceByKey(add)
        .map(lambda kv: (kv[0][0], [(kv[0][1], kv[1])]))
        # `add` on lists concatenates, gathering every (item, count) pair of an entity.
        .reduceByKey(add)
        .collect()
    )


# Per-entity counts of cleaned caption terms.
entity_caption_term_freq = _collect_counts_by_entity(
    entity_df.select("entity", "caption_term_cleaned").rdd.flatMap(
        lambda row: [((row["entity"], term), 1) for term in row["caption_term_cleaned"]]
    )
)

# Per-entity counts of cleaned header terms.
entity_header_term_freq = _collect_counts_by_entity(
    entity_df.select("entity", "header_term_cleaned").rdd.flatMap(
        lambda row: [((row["entity"], term), 1) for term in row["header_term_cleaned"]]
    )
)

# Per-entity counts of whole header strings.
entity_header_freq = _collect_counts_by_entity(
    entity_df.select("entity", "header").rdd.map(lambda row: ((row["entity"], row["header"]), 1))
)

In [None]:
# For every entity, gather the ids of all tables it appears in.
tables_per_entity_df = (
    entity_df.select("entity", "table_id")
    .groupBy("entity")
    .agg(F.collect_list("table_id").alias("tables"))
)
# Collected to the driver as a list of (entity, [table_id, ...]) pairs.
entity_tables = tables_per_entity_df.rdd.map(lambda row: (row["entity"], row["tables"])).collect()

In [None]:
import pickle

In [None]:
# Persist the collected (entity, tables) pairs for later use.
with open("../../data/entity_tables.pkl", "wb") as out_file:
    pickle.dump(entity_tables, out_file)

In [None]:
# `entity_header_freq` arrives from collect() as a list of (entity, [(header, count), ...])
# pairs, so it has to be indexed by entity first (as the header-term and caption-term cells
# do) — subscripting the list with an entity tuple raises TypeError.
# Each entity is then mapped to [total count, {header: count}].
entity_header_freq = {
    entity: [sum(count for _, count in header_counts), dict(header_counts)]
    for entity, header_counts in dict(entity_header_freq).items()
}

with open("../../data/entity_header_freq.pkl", "wb") as f:
    pickle.dump(entity_header_freq, f)

In [None]:
# Index the collected pairs by entity and collapse each entry to
# [total count, {header term: count}] before pickling.
entity_header_term_freq = {
    entity: [sum(n for _, n in term_counts), dict(term_counts)]
    for entity, term_counts in dict(entity_header_term_freq).items()
}

with open("../../data/entity_header_term_freq.pkl", "wb") as out_file:
    pickle.dump(entity_header_term_freq, out_file)

In [None]:
# Index the collected pairs by entity and collapse each entry to
# [total count, {caption term: count}] before pickling.
entity_caption_term_freq = {
    entity: [sum(n for _, n in term_counts), dict(term_counts)]
    for entity, term_counts in dict(entity_caption_term_freq).items()
}

with open("../../data/entity_caption_term_freq.pkl", "wb") as out_file:
    pickle.dump(entity_caption_term_freq, out_file)

In [None]:
# Turn each collected (key, count) list into a dict and pickle it as [total count, dict].
caption_term_freq = dict(caption_term_freq)
with open("../../data/caption_term_freq.pkl", "wb") as out_file:
    pickle.dump([sum(caption_term_freq.values()), caption_term_freq], out_file)

header_term_freq = dict(header_term_freq)
with open("../../data/header_term_freq.pkl", "wb") as out_file:
    pickle.dump([sum(header_term_freq.values()), header_term_freq], out_file)

header_freq = dict(header_freq)
with open("../../data/header_freq.pkl", "wb") as out_file:
    pickle.dump([sum(header_freq.values()), header_freq], out_file)

In [None]:
def _header_count_total(entry):
    """Return the total header count stored in one `entity_header_freq` value.

    Accepts both the raw ``[(header, count), ...]`` list and the collapsed
    ``[total, {header: count}]`` form produced by the pickling cell.
    """
    if len(entry) == 2 and isinstance(entry[1], dict):
        return entry[0]
    return sum(count for _, count in entry)


# Sanity check: an entity's table count should equal its total header count.
# `entity_tables` comes from collect() as a list of (entity, tables) pairs, so both
# structures are indexed by entity first (dict() is just a copy if already a dict);
# subscripting the list with an entity tuple would raise TypeError.
_tables_by_entity = dict(entity_tables)
_header_freq_by_entity = dict(entity_header_freq)
for e, tables in _tables_by_entity.items():
    header_total = _header_count_total(_header_freq_by_entity[e])
    if len(tables) != header_total:
        # Report only the first mismatch.
        print(e, len(tables), header_total)
        break

In [None]:
# `caption_term_freq` has been converted to a dict keyed by term before pickling, so
# `caption_term_freq[0]` raises KeyError; show one (term, count) entry instead.
next(iter(caption_term_freq.items()), None)

In [None]:
# Inspect the header-frequency entry stored under key 1677.
entity_header_freq[1677]

In [None]:
# Show up to 10 (entity, table_id, caption, header) records whose entity id is 5839439.
entity_rdd.filter(lambda x: x[0] == 5839439).take(10)

In [None]:
from metric import *

In [None]:
# Load the stored dev-set results.
# NOTE(review): pickle.load can execute arbitrary code; only load files you trust.
with open("../../data/dev_result.pkl", "rb") as in_file:
    dev_result = pickle.load(in_file)

In [None]:
def load_entity_vocab(data_dir, ignore_bad_title=True, min_ent_count=1):
    """Load the entity vocabulary from ``<data_dir>/entity_vocab.txt``.

    Every line must hold five tab-separated fields: an index (ignored), the
    entity id, the entity title, the entity mid and the entity count.

    Args:
        data_dir: Directory containing ``entity_vocab.txt``.
        ignore_bad_title: When true, entries with an empty title are skipped.
        min_ent_count: Entries whose count is below this value are skipped.

    Returns:
        Dict mapping a consecutive integer index (0, 1, ...) to a dict with the
        keys ``wiki_id`` (int), ``wiki_title``, ``mid`` and ``count`` (int).

    Raises:
        ValueError: If a line does not split into exactly five fields or the
            id / count are not integers.
    """
    # Imported locally: the visible notebook imports only pull `listdir` out of `os`,
    # so the bare name `os` used below would otherwise raise NameError.
    import os

    entity_vocab = {}
    bad_title = 0
    few_entity = 0
    with open(os.path.join(data_dir, "entity_vocab.txt"), "r", encoding="utf-8") as f:
        for line in f:
            _, entity_id, entity_title, entity_mid, count = line.strip().split("\t")
            if ignore_bad_title and entity_title == "":
                bad_title += 1
            elif int(count) < min_ent_count:
                few_entity += 1
            else:
                # Re-index kept entries consecutively, independent of the file's own index.
                entity_vocab[len(entity_vocab)] = {
                    "wiki_id": int(entity_id),
                    "wiki_title": entity_title,
                    "mid": entity_mid,
                    "count": int(count),
                }
    print(
        "total number of entity: %d\nremove because of empty title: %d\nremove because count<%d: %d"
        % (len(entity_vocab), bad_title, min_ent_count, few_entity)
    )
    return entity_vocab

In [None]:
# Keep titled entities that occur at least twice, then collect their wiki ids into a set.
entity_vocab = load_entity_vocab("../../data", ignore_bad_title=True, min_ent_count=2)
train_all_entities = {entry["wiki_id"] for entry in entity_vocab.values()}

In [None]:
def _ranked_average_precision(scores, key, targets):
    """Average precision of the entities in `scores`, ranked by `key` in descending order.

    Entities outside `train_all_entities` are dropped from the ranking; each remaining
    entity becomes 1 if it is in `targets`, else 0.
    """
    ranked = sorted(scores.items(), key=key, reverse=True)
    hits = [1 if entity in targets else 0 for entity, _ in ranked if entity in train_all_entities]
    return average_precision(hits)


# Per dev table: candidate recall (all / entity-based / caption-based candidates) and the
# average precision of five rankings. Stored as
# [recall_all, recall_e, recall_c, ap_neural, ap_all, ap_e, ap_c, ap_l].
dev_final = {}
for table_id, result in dev_result.items():
    _, target_entities, pneural, pall, pee, pce, ple, cand_e, cand_c = result
    target_entities = set(target_entities)
    # Only candidates that are part of the training vocabulary are counted.
    cand_e = {e for e in cand_e if e in train_all_entities}
    cand_c = {e for e in cand_c if e in train_all_entities}
    cand_all = cand_c | cand_e  # both operands are already vocabulary-filtered
    recall_e = len(cand_e & target_entities) / len(target_entities)
    recall_c = len(cand_c & target_entities) / len(target_entities)
    recall_all = len(cand_all & target_entities) / len(target_entities)

    # Neural score combined with the `pee` score (weight 30).
    ap_neural = _ranked_average_precision(pneural, lambda z: z[1] + 30 * pee[z[0]], target_entities)

    # Weighted combination of the `pee`, `pce` and `ple` scores over the keys of `pall`.
    ap_all = _ranked_average_precision(
        pall, lambda z: 100 * pee[z[0]] + 1 * pce[z[0]] + 0.5 * ple[z[0]], target_entities
    )

    # Single-score rankings. These were commented out while `ap_e`, `ap_c` and `ap_l`
    # were still stored below, which raised NameError; restored so the row is complete.
    ap_e = _ranked_average_precision(pee, lambda z: z[1], target_entities)
    ap_c = _ranked_average_precision(pce, lambda z: z[1], target_entities)
    ap_l = _ranked_average_precision(ple, lambda z: z[1], target_entities)

    dev_final[table_id] = [recall_all, recall_e, recall_c, ap_neural, ap_all, ap_e, ap_c, ap_l]

# Mean of each of the eight metrics over all dev tables, in storage order.
for i in range(8):
    print(np.mean([z[i] for _, z in dev_final.items()]))

In [None]:
# Inspect element 2 of the stored result for dev table "13591903-1"
# (the element unpacked as `pneural` in the evaluation loop).
dev_result["13591903-1"][2]

In [None]:
# Count dev tables where ap_all (index 4) is at least ap_e (index 5).
# Iterate over the metric lists: iterating the dict itself yields the table-id keys,
# so z[4] / z[5] would compare characters of the id string instead of metrics.
len([1 for z in dev_final.values() if z[4] >= z[5]])

In [None]:
# List (table id, ap_neural, ap_all, ap_e) for the tables where ap_all >= ap_e.
[
    (table_key, metrics[3], metrics[4], metrics[5])
    for table_key, metrics in dev_final.items()
    if metrics[4] >= metrics[5]
]