In [None]:
from tqdm.autonotebook import tqdm

import findspark

findspark.init()
from pyspark import SparkContext
import pyspark

# Spark resources for this notebook: 7 executors x 2 cores x 8g each, plus a
# large driver because several cells collect() full result sets to the driver.
conf = pyspark.SparkConf().setAll(
    [
        ("spark.executor.memory", "8g"),
        ("spark.executor.cores", "2"),
        ("spark.executor.instances", "7"),
        ("spark.driver.memory", "150g"),
        ("spark.driver.maxResultSize", "100g"),
    ]
)
# getOrCreate() keeps this cell re-runnable: constructing a second SparkContext
# in the same kernel raises instead of reusing the running one.
sc = SparkContext.getOrCreate(conf=conf)

from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, FloatType, StringType

from pyspark.sql.types import Row
from pyspark.sql import SparkSession

spark = SparkSession(sc)

import json
import numpy as np
import re

from operator import add

from urllib.parse import unquote

In [None]:
def numpy_describe(array):
    """Print summary statistics of ``array``, one ``label value`` pair per line.

    Reports count, min, max, mean and std, followed by the 10th, 25th, 50th,
    60th, 75th, 80th and 90th percentiles.
    """
    basic_stats = (
        ("count", len),
        ("min", np.min),
        ("max", np.max),
        ("mean", np.mean),
        ("std", np.std),
    )
    for label, stat in basic_stats:
        print(label, stat(array))
    for q in (10, 25, 50, 60, 75, 80, 90):
        print("%d%%" % q, np.percentile(array, q))

In [None]:
# The index_enwiki database can be created with the wikimapper library: https://github.com/jcklie/wikimapper
wikipedia_wikidata_mapping = (
    spark.read.format("jdbc")
    .options(
        url="jdbc:sqlite:/data/deng.595/workspace/wikimapper/index/index_enwiki-latest.db",
        driver="org.sqlite.JDBC",
        dbtable="mapping",
    )
    .load()
)
wikipedia_wikidata_mapping.show()

In [None]:
# we use dbpedia abstracts and types, so information related to freebase can be ignored
wiki_mid_mapping = spark.createDataFrame(
    sc.textFile("../../../freebase_utils/freebase_dumped/mid2wiki.txt")
    .map(lambda x: x.split())
    .map(lambda x: Row(wikipedia_id=int(x[1]), freebase_mid=x[0]))
)

In [None]:
dbpedia_types = dict(
    spark.createDataFrame(
        sc.textFile("../../../freebase_utils/dbpedia_2019_08_30/instance_types_en.ttl")
        .map(lambda x: x.split())
        .map(
            lambda x: Row(
                wikipedia_title=unquote(x[0][1:-1]).replace("http://dbpedia.org/resource/", ""),
                type=x[2][1:-1].split("/")[-1],
            )
        )
    )
    .join(wikipedia_wikidata_mapping, "wikipedia_title", "inner")
    .rdd.map(lambda x: (x["wikidata_id"], [x["type"]]))
    .reduceByKey(add)
    .collect()
)
print(len(dbpedia_types))

In [None]:
dbpedia_abstract = dict(
    spark.createDataFrame(
        sc.textFile("../../../freebase_utils/dbpedia_2019_08_30/short_abstracts_en.ttl")
        .map(lambda x: re.match(r"(<.+>) (<.+>) (\".+\")", x))
        .filter(lambda x: x is not None)
        .map(
            lambda x: Row(
                wikipedia_title=unquote(x.group(1)[1:-1]).replace("http://dbpedia.org/resource/", ""),
                abstract=x.group(3)[1:-1].replace('\\"', '"'),
            )
        )
    )
    .join(wikipedia_wikidata_mapping, "wikipedia_title", "inner")
    .rdd.map(lambda x: (x["wikidata_id"], x["abstract"]))
    .collect()
)
print(len(dbpedia_abstract))

In [None]:
# dbpedia_abstract is a plain dict (wikidata_id -> abstract), not a DataFrame,
# so it has no .show(); preview the first few entries instead.
[entry for entry, _ in zip(dbpedia_abstract.items(), range(5))]

In [None]:
fb_en_types = spark.createDataFrame(
    sc.textFile("/data/deng.595/workspace/freebase_utils/freebase_dumped/fb_en_types.txt")
    .map(lambda x: x.split("\t"))
    .map(
        lambda x: Row(
            freebase_mid=x[0],
            types=[
                z
                for z in json.loads(x[1])
                if (not z.startswith("user.") and not z.startswith("base.") and not z.startswith("common."))
            ],
        )
    )
)

In [None]:
# load the raw tables
data_dir = "/srv/samba/group_workspace_1/deng.595/workspace/table_transformer/data/wikitable_entity/v2/"
train_tables = sc.textFile(data_dir + "train_tables.jsonl").map(lambda x: json.loads(x))
val_tables = sc.textFile(data_dir + "dev_tables.jsonl").map(lambda x: json.loads(x))
test_tables = sc.textFile(data_dir + "test_tables.jsonl").map(lambda x: json.loads(x))

In [None]:
test_tables.map(lambda x: x["_id"]).count()

In [None]:
def get_mentions(table):
    """Extract one ``Row`` per linked entity cell of a table.

    A cell ``(i, j)`` is emitted when ``j`` is listed in
    ``table["entityColumn"]`` and ``table["entityCell"][i][j] == 1``. The
    mention text and the Wikipedia id/title are taken from the first entry of
    that cell's ``surfaceLinks``.

    Args:
        table: dict parsed from one line of the ``*_tables.jsonl`` files. Reads
            the keys ``_id``, ``pgTitle``, ``tableData``, ``entityColumn`` and
            ``entityCell``.

    Returns:
        list of ``Row`` objects with the fields ``table_id``, ``table_pgTitle``,
        ``i``, ``j``, ``mention``, ``wikipedia_id`` and ``wikipedia_title``.
        Empty when the table has no rows (previously this raised on
        ``rows[0]``).
    """
    rows = table.get("tableData", [])
    if not rows:
        # Nothing to scan; also avoids indexing into an empty table.
        return []
    entity_columns = table.get("entityColumn", [])
    # NOTE(review): presumably a rows x columns 0/1 matrix aligned with tableData - confirm.
    entity_cells = np.array(table.get("entityCell", [[]]))
    results = []
    for i, row in enumerate(rows):
        for j in entity_columns:
            if entity_cells[i, j] == 1:
                link = row[j]["surfaceLinks"][0]
                results.append(
                    Row(
                        table_id=table["_id"],
                        table_pgTitle=table["pgTitle"],
                        i=i,
                        j=j,
                        mention=link["surface"],
                        wikipedia_id=link["target"]["id"],
                        wikipedia_title=link["target"]["title"],
                    )
                )
    return results

In [None]:
# data for ours
train_mentions = spark.createDataFrame(train_tables.flatMap(get_mentions))
val_mentions = spark.createDataFrame(val_tables.flatMap(get_mentions))
test_mentions = spark.createDataFrame(test_tables.flatMap(get_mentions))

In [None]:
train_mentions.show()

In [None]:
# data for wikiGS
wikipedia_gs_entity_mentions = spark.createDataFrame(
    sc.textFile("../../data/entity_linking/WikipediaGS_json/entities_instance")
    .map(json.loads)
    .flatMap(
        lambda x: [
            Row(
                i=z[2],
                tableId=unquote(x["tableId"]),
                is_gs=1,
                table_pgTitle=unquote(x["url"]).split("/")[-1].replace("_", " "),
                wikipedia_title=unquote(z[0]).replace("http://dbpedia.org/resource/", ""),
                mention=z[1],
            )
            for z in x["mappings"]
        ]
    )
)

In [None]:
wikipedia_gs_entity_mentions.show()

In [None]:
def get_gs_mentions(table):
    """Collect every cell of a WikipediaGS table that carries a ``wikiPageId``.

    Returns a list of ``Row`` objects holding the URL-unquoted table id, the
    cell's row index ``i`` and column index ``j``, the cell text (``data``) as
    the mention, and the URL-unquoted ``wikiPageId`` as ``wikipedia_title``.
    """
    table_id = unquote(table["tableId"])
    return [
        Row(
            tableId=table_id,
            i=row_idx,
            j=col_idx,
            mention=cell["data"],
            wikipedia_title=unquote(cell["wikiPageId"]),
        )
        for row_idx, row in enumerate(table.get("contents", {}))
        for col_idx, cell in enumerate(row)
        if "wikiPageId" in cell
    ]

In [None]:
def get_gs_context(table):
    """Build the table-level context dict for a WikipediaGS table.

    Returns a dict with:
      * ``tableId``: URL-unquoted ``table["tableId"]``
      * ``pgTitle``: ``table["title"]`` without the Wikipedia suffix ("" if
        missing or None)
      * ``sectionTitle``: always ""
      * ``tableCaption``: ``table["context"]`` without "[edit]" ("" if missing
        or None)
      * ``processed_tableHeaders``: one string per column, made by
        space-joining an empty seed string with the ``data`` of every cell in
        that column flagged ``isHeader`` (so non-empty headers start with a
        space)
    """
    table_id = unquote(table["tableId"])

    raw_title = table.get("title", "")
    page_title = "" if raw_title is None else raw_title.replace("- Wikipedia, the free encyclopedia", "")

    raw_caption = table.get("context", "")
    caption = ("" if raw_caption is None else raw_caption).replace("[edit]", "")

    # One accumulator per column, each seeded with an empty string.
    header_parts = []
    for row in table["contents"]:
        for col, cell in enumerate(row):
            if col >= len(header_parts):
                header_parts.append([""])
            if cell.get("isHeader", False):
                header_parts[col].append(cell["data"])

    return {
        "tableId": table_id,
        "pgTitle": page_title,
        "sectionTitle": "",
        "tableCaption": caption,
        "processed_tableHeaders": [" ".join(parts) for parts in header_parts],
    }

In [None]:
wikipedia_gs_tables = sc.textFile("../../data/entity_linking/WikipediaGS_json/tables_instance").map(json.loads)

In [None]:
wikipedia_gs_raw_mentions = spark.createDataFrame(wikipedia_gs_tables.flatMap(get_gs_mentions))

In [None]:
wikipedia_gs_tables = wikipedia_gs_tables.map(get_gs_context)

In [None]:
wikipedia_gs_tables.take(10)

In [None]:
wikipedia_gs_raw_mentions.show()

In [None]:
wikipedia_gs_entity_mentions = wikipedia_gs_entity_mentions.join(
    wikipedia_gs_raw_mentions, ["i", "tableId", "mention", "wikipedia_title"], "inner"
)

In [None]:
print(train_mentions.count())
train_mentions = train_mentions.join(
    wikipedia_gs_entity_mentions.select("table_pgTitle", "is_gs").dropDuplicates(), "table_pgTitle", "left"
).where(F.isnull("is_gs"))
print(train_mentions.count())
print(val_mentions.count())
val_mentions = val_mentions.join(
    wikipedia_gs_entity_mentions.select("table_pgTitle", "is_gs").dropDuplicates(), "table_pgTitle", "left"
).where(F.isnull("is_gs"))
print(val_mentions.count())

In [None]:
print(wikipedia_gs_entity_mentions.select("wikipedia_title").dropDuplicates().count())
print(dbpedia_types.select("wikipedia_title").dropDuplicates().count())
print(
    wikipedia_gs_entity_mentions.select("wikipedia_title")
    .join(dbpedia_types, "wikipedia_title", "inner")
    .select("wikipedia_title")
    .dropDuplicates()
    .count()
)
print(
    wikipedia_gs_entity_mentions.select("wikipedia_title")
    .join(dbpedia_abstract, "wikipedia_title", "inner")
    .select("wikipedia_title")
    .dropDuplicates()
    .count()
)

In [None]:
print(wikipedia_gs_entity_mentions.count())
print(wikipedia_gs_entity_mentions.select("mention").dropDuplicates().count())
wikipedia_gs_entity_mentions.show()

In [None]:
entity_mentions = sc.textFile("../../data/entity_linking/tableMentions.json").map(json.loads)
display(entity_mentions.take(1))
display(entity_mentions.count())

In [None]:
entity_mentions_surface = entity_mentions.map(lambda x: (x["surfaceForm"])).distinct()
print(entity_mentions_surface.count())

In [None]:
entity_mentions_surface = wikipedia_gs_entity_mentions.rdd.map(lambda x: x["mention"]).distinct().collect()
print(len(entity_mentions_surface))

In [None]:
from google.cloud import language
from google.oauth2 import service_account
import urllib.parse
import urllib.request
from multiprocessing import Pool
import time

In [None]:
def wikidata_lookup(query):
    """Search Wikidata for entities matching ``query``.

    Calls the ``wbsearchentities`` API (English, up to 50 hits). HTTP 429 and
    503 answers are retried after a one-second pause, for at most three
    attempts in total.

    Args:
        query: mention string; it is URL-quoted before being sent.

    Returns:
        ``[query, response]`` where ``response`` is the decoded JSON payload on
        success, the HTTP status code (int) when the request failed with an
        ``HTTPError``, or ``None`` when it failed with a ``URLError``.
    """
    service_url = (
        "https://www.wikidata.org/w/api.php?action=wbsearchentities&search={}&language=en&limit=50&format=json"
    )
    url = service_url.format(urllib.parse.quote(query))
    for _ in range(3):
        try:
            http_response = urllib.request.urlopen(url)
        except urllib.error.HTTPError as e:
            response = e.code
            if e.code in (429, 503):
                # Throttled / temporarily unavailable: back off, then retry.
                time.sleep(1)
                continue
            break
        except urllib.error.URLError:
            response = None
            break
        else:
            # Close the connection deterministically; this runs thousands of
            # times per worker and would otherwise leak sockets until GC.
            with http_response:
                response = json.loads(http_response.read())
            break
    # NOTE(review): the cells consuming these results test isinstance(x[1], list),
    # which the raw dict returned here never satisfies; the cached candidate files
    # were presumably produced with the conversion below enabled - confirm.
    #     if isinstance(response, dict):
    #         response = [[z.get('id'),z.get('label'),z.get('description')] for z in response.get('search', [])]
    return [query, response]

In [None]:
wikidata_lookup("Michael Grant")

In [None]:
entity_wikidata_candidates = []

In [None]:
# Resume from wherever a previous (possibly interrupted) run stopped.
if entity_wikidata_candidates is not None:
    i = len(entity_wikidata_candidates)
else:
    entity_wikidata_candidates = []
    i = 0
# The context manager tears the workers down even when a lookup raises.
with Pool(processes=16) as pool:
    while i < len(entity_mentions_surface):
        print(i)
        batch = entity_mentions_surface[i : i + 10000]
        # total is the real batch size so the bar also completes on the last, shorter batch
        tmp = list(tqdm(pool.imap(wikidata_lookup, batch, chunksize=150), total=len(batch)))
        entity_wikidata_candidates.extend(tmp)
        i += 10000

In [None]:
len(entity_wikidata_candidates)

In [None]:
import re

# Mentions that change when one leading and/or one trailing non-word character
# is stripped; only the stripped form is kept and looked up again.
# Raw string: "\W" in a normal string literal is an invalid escape sequence.
entity_mentions_surface_normed_0 = []
for x in entity_mentions_surface:
    stripped = re.sub(r"^\W|\W$", "", x)
    if stripped != x:
        entity_mentions_surface_normed_0.append(stripped)

entity_wikidata_candidates_normed_0 = []
i = 0
# The context manager tears the workers down even when a lookup raises.
with Pool(processes=16) as pool:
    while i < len(entity_mentions_surface_normed_0):
        print(i)
        batch = entity_mentions_surface_normed_0[i : i + 10000]
        # total is the real batch size so the bar also completes on the last, shorter batch
        tmp = list(tqdm(pool.imap(wikidata_lookup, batch, chunksize=300), total=len(batch)))
        entity_wikidata_candidates_normed_0.extend(tmp)
        i += 10000

In [None]:
# NOTE: missing_mentions is only defined in a later cell (mentions whose
# candidate list is NULL), so this cell cannot run top-to-bottom on a fresh kernel.
missing_wikidata_candidates = []
i = 0
# The context manager tears the workers down even when a lookup raises.
with Pool(processes=16) as pool:
    while i < len(missing_mentions):
        print(i)
        batch = missing_mentions[i : i + 10000]
        # total is the real batch size so the bar also completes on the last, shorter batch
        tmp = list(tqdm(pool.imap(wikidata_lookup, batch, chunksize=300), total=len(batch)))
        missing_wikidata_candidates.extend(tmp)
        i += 10000

In [None]:
len(entity_mentions_surface_normed_0) / len(entity_mentions_surface)

In [None]:
entity_wikidata_candidates_normed_0_dict = {
    x[0]: x[1] for x in entity_wikidata_candidates_normed_0 if (isinstance(x[1], list) and len(x[1]) != 0)
}

In [None]:
# Append the candidates found for the punctuation-stripped form of a mention to
# the candidates of the original mention.
# NOTE: not idempotent - re-running this cell appends the same candidates again.
for x in entity_wikidata_candidates:
    # Raw string: "\W" in a normal string literal is an invalid escape sequence.
    processed = re.sub(r"^\W|\W$", "", x[0])
    if processed != x[0] and processed in entity_wikidata_candidates_normed_0_dict:
        extra = entity_wikidata_candidates_normed_0_dict[processed]
        if isinstance(x[1], list):
            x[1].extend(extra)
        else:
            # The original lookup failed (status code / None) or was not converted
            # to a list; `+=` would raise TypeError, so start from the normalised hits.
            x[1] = list(extra)

In [None]:
entity_wikidata_target = (
    spark.createDataFrame(
        entity_mentions.map(
            lambda x: Row(
                id=x["_id"]["$oid"],
                mention=x["surfaceForm"],
                wikipedia_id=x["goldAnnotation"]["titleId"],
                cell_id="%d_%d_%d_%d" % (x["pgId"], x["tableId"], x["cellRow"], x["cellCol"]),
            )
        )
    )
    .join(wikipedia_wikidata_mapping, "wikipedia_id", "inner")
    .join(wiki_mid_mapping, "wikipedia_id", "inner")
)

In [None]:
entity_wikidata_target = wikipedia_gs_entity_mentions.join(wikipedia_wikidata_mapping, "wikipedia_title", "left").join(
    wiki_mid_mapping, "wikipedia_id", "left"
)

In [None]:
entity_wikidata_target.show()

In [None]:
import os


def load_entity_vocab(data_dir, ignore_bad_title=True, min_ent_count=1):
    """Read ``entity_vocab.txt`` from ``data_dir`` into an index-keyed dict.

    Each line holds five tab-separated fields; the first is ignored and the
    rest are the Wikipedia id, title, Freebase mid and occurrence count.

    Args:
        data_dir: directory containing ``entity_vocab.txt``.
        ignore_bad_title: skip entries whose title is empty.
        min_ent_count: skip entries whose count is below this value.

    Returns:
        dict mapping a dense 0-based index (in order of kept entries) to
        ``{"wiki_id", "wiki_title", "mid", "count"}``. A short summary of kept
        and skipped entries is printed as a side effect.
    """
    vocab_path = os.path.join(data_dir, "entity_vocab.txt")
    kept = {}
    n_empty_title = 0
    n_too_rare = 0
    with open(vocab_path, "r", encoding="utf-8") as vocab_file:
        for record in vocab_file:
            _, wiki_id, title, mid, freq = record.strip().split("\t")
            if ignore_bad_title and title == "":
                n_empty_title += 1
                continue
            if int(freq) < min_ent_count:
                n_too_rare += 1
                continue
            kept[len(kept)] = {
                "wiki_id": int(wiki_id),
                "wiki_title": title,
                "mid": mid,
                "count": int(freq),
            }
    summary = "total number of entity: %d\nremove because of empty title: %d\nremove because count<%d: %d"
    print(summary % (len(kept), n_empty_title, min_ent_count, n_too_rare))
    return kept


data_dir = "/srv/samba/group_workspace_1/deng.595/workspace/table_transformer/data/wikitable_entity/v2/"
entity_vocab = load_entity_vocab(data_dir, True, 2)
train_all_entities = set([x["mid"] for _, x in entity_vocab.items() if x["mid"] != ""])
train_all_entities_wiki_id = set([x["wiki_id"] for _, x in entity_vocab.items()])

In [None]:
# with open('wikipedia_gs_wikidata_candidates.json', "w", encoding='utf8') as f:
#     json.dump(entity_wikidata_candidates, f)
with open("wikipedia_gs_wikidata_candidates.json", "r", encoding="utf8") as f:
    entity_wikidata_candidates = json.load(f)

In [None]:
# with open('wikidata_candidates.json', "w", encoding='utf8') as f:
#     json.dump(entity_wikidata_candidates, f)
with open("wikidata_candidates.json", "r", encoding="utf8") as f:
    entity_wikidata_candidates = json.load(f)

In [None]:
entity_wikidata_candidates += missing_wikidata_candidates

In [None]:
entity_wikidata_candidates_df = spark.createDataFrame(
    sc.parallelize(entity_wikidata_candidates).map(
        lambda x: Row(mention=x[0], candidates=x[1] if isinstance(x[1], list) else [])
    )
)

In [None]:
wikipedia_gs_entity_mentions_with_candidate = wikipedia_gs_entity_mentions.join(
    entity_wikidata_candidates_df, "mention", "left"
).join(wikipedia_wikidata_mapping, "wikipedia_title", "left")

In [None]:
train_mentions_with_candidate = train_mentions.join(entity_wikidata_candidates_df, "mention", "left").join(
    wikipedia_wikidata_mapping, "wikipedia_title", "inner"
)
val_mentions_with_candidate = val_mentions.join(entity_wikidata_candidates_df, "mention", "left").join(
    wikipedia_wikidata_mapping, "wikipedia_title", "inner"
)
test_mentions_with_candidate = test_mentions.join(entity_wikidata_candidates_df, "mention", "left").join(
    wikipedia_wikidata_mapping, "wikipedia_title", "inner"
)

In [None]:
val_mentions_with_candidate.show()

In [None]:
# Mentions that got no candidate list at all (NULL after the left join) and
# therefore still need a Wikidata lookup.
# NOTE(review): the second collect used to read the val split again, which the
# set() below collapses to a no-op; the test split is presumably what was
# intended - confirm.
missing_mentions = val_mentions_with_candidate.where(F.isnull("candidates")).rdd.map(lambda x: x["mention"]).collect()
missing_mentions += test_mentions_with_candidate.where(F.isnull("candidates")).rdd.map(lambda x: x["mention"]).collect()
missing_mentions = list(set(missing_mentions))
print(len(missing_mentions))

In [None]:
train_mentions_with_candidate.dropDuplicates(["mention", "wikidata_id"]).where(F.size("candidates") != 0).count()

In [None]:
entity_wikidata_target_candidate = entity_wikidata_target.join(entity_wikidata_candidates_df, "mention", "inner")

In [None]:
print(train_mentions_with_candidate.where(F.size("candidates") != 0).count())
print(
    train_mentions_with_candidate.where(F.size("candidates") != 0)
    .join(dbpedia_types.select("wikipedia_title").dropDuplicates(), "wikipedia_title", "inner")
    .count()
)
print(
    train_mentions_with_candidate.where(F.size("candidates") != 0)
    .join(dbpedia_abstract.select("wikipedia_title").dropDuplicates(), "wikipedia_title", "inner")
    .count()
)

In [None]:
train_mentions_with_candidate.show()

In [None]:
def build_for_own(x):
    """Turn one joined (mentions, table context) record into model examples.

    Args:
        x: ``(table_id, (mentions, context))``. ``context`` is indexed as
            ``[pgTitle, sectionTitle, caption, headers]``. Each mention is
            indexed as ``[i, j, mention, gold_id, candidates]`` and each
            candidate as ``[id, <field 1>, <field 2>]`` (presumably label and
            description - confirm against the candidate files).

    Returns:
        A list with one example per group of at most 50 mentions:
        ``[table_id, pgTitle, sectionTitle, caption, headers, cells,
        candidate_info, labels, cand_for_each]``. Candidates are deduplicated
        per group and numbered in first-seen order; ``candidate_info[k]`` is
        ``[field 1, field 2, dbpedia_types.get(id, [])]``; ``labels`` and
        ``cand_for_each`` hold those per-group indices.
    """
    table_id = x[0]
    context = x[1][1]
    page_title, section_title, caption, headers = context[0], context[1], context[2], context[3]
    mentions = x[1][0]

    examples = []
    for start in range(0, len(mentions), 50):
        chunk = mentions[start : start + 50]
        cells = [[[m[0], m[1]], m[2]] for m in chunk]

        # candidate id -> position in candidate_info (first-seen order)
        local_index = {}
        candidate_info = []
        for m in chunk:
            for cand in m[4]:
                cand_id = cand[0]
                if cand_id in local_index:
                    continue
                local_index[cand_id] = len(local_index)
                candidate_info.append([cand[1], cand[2], dbpedia_types.get(cand_id, [])])

        labels = [local_index[m[3]] for m in chunk]
        cand_for_each = [[local_index[cand[0]] for cand in m[4]] for m in chunk]
        examples.append(
            [table_id, page_title, section_title, caption, headers, cells, candidate_info, labels, cand_for_each]
        )
    return examples

In [None]:
def build_for_own_with_wikidata_id(x):
    """Same grouping as ``build_for_own``, but keeps each candidate's id.

    Args:
        x: ``(table_id, (mentions, context))``. ``context`` is indexed as
            ``[pgTitle, sectionTitle, caption, headers]``. Each mention is
            indexed as ``[i, j, mention, gold_id, candidates]`` and each
            candidate as ``[id, <field 1>, <field 2>]`` (presumably label and
            description - confirm against the candidate files).

    Returns:
        A list with one example per group of at most 50 mentions:
        ``[table_id, pgTitle, sectionTitle, caption, headers, cells,
        candidate_info, labels, cand_for_each]``, where ``candidate_info[k]``
        is ``[field 1, field 2, dbpedia_types.get(id, []), id]`` and
        candidates are numbered per group in first-seen order.
    """
    table_id = x[0]
    context = x[1][1]
    page_title, section_title, caption, headers = context[0], context[1], context[2], context[3]
    mentions = x[1][0]

    examples = []
    for start in range(0, len(mentions), 50):
        chunk = mentions[start : start + 50]
        cells = [[[m[0], m[1]], m[2]] for m in chunk]

        # candidate id -> position in candidate_info (first-seen order)
        local_index = {}
        candidate_info = []
        for m in chunk:
            for cand in m[4]:
                cand_id = cand[0]
                if cand_id in local_index:
                    continue
                local_index[cand_id] = len(local_index)
                candidate_info.append([cand[1], cand[2], dbpedia_types.get(cand_id, []), cand_id])

        labels = [local_index[m[3]] for m in chunk]
        cand_for_each = [[local_index[cand[0]] for cand in m[4]] for m in chunk]
        examples.append(
            [table_id, page_title, section_title, caption, headers, cells, candidate_info, labels, cand_for_each]
        )
    return examples

In [None]:
# Only output examples with recall > 0 for reranking: mentions whose candidate
# list is missing, empty, or does not contain the gold wikidata id are dropped.
train_mentions_local = (
    train_mentions_with_candidate.select("table_id", "wikidata_id", "candidates", "i", "j", "mention")
    .dropDuplicates(["mention", "wikidata_id"])
    # Left-join misses leave candidates NULL; drop them (as the val/test pipelines
    # do) before the Python filter below iterates over x[5].
    .where(~F.isnull("candidates"))
    .rdd.map(lambda x: [x["table_id"], x["i"], x["j"], x["mention"], x["wikidata_id"], x["candidates"]])
    .filter(lambda x: x[4] in [z[0] for z in x[5]])
    .map(lambda x: (x[0], [x[1:]]))
    .reduceByKey(add)
    .join(
        train_tables.map(
            lambda x: (x["_id"], [x["pgTitle"], x["sectionTitle"], x["tableCaption"], x["processed_tableHeaders"]])
        )
    )
    .flatMap(build_for_own)
    .collect()
)

In [None]:
val_mentions_local = (
    val_mentions_with_candidate.select("table_id", "wikidata_id", "candidates", "i", "j", "mention")
    .dropDuplicates(["mention", "wikidata_id"])
    .where(~F.isnull("candidates"))
    .rdd.map(lambda x: [x["table_id"], x["i"], x["j"], x["mention"], x["wikidata_id"], x["candidates"]])
    .filter(lambda x: x[4] in [z[0] for z in x[5]])
    .map(lambda x: (x[0], [x[1:]]))
    .reduceByKey(add)
    .join(
        val_tables.map(
            lambda x: (x["_id"], [x["pgTitle"], x["sectionTitle"], x["tableCaption"], x["processed_tableHeaders"]])
        )
    )
    .flatMap(build_for_own)
    .collect()
)

In [None]:
# 08/20
test_mentions_local = (
    test_mentions_with_candidate.select("table_id", "wikidata_id", "candidates", "i", "j", "mention")
    .where(~F.isnull("candidates"))
    .rdd.map(lambda x: [x["table_id"], x["i"], x["j"], x["mention"], x["wikidata_id"], x["candidates"]])
    .filter(lambda x: x[4] in [z[0] for z in x[5]])
    .map(lambda x: (x[0], [x[1:]]))
    .reduceByKey(add)
    .join(
        test_tables.map(
            lambda x: (x["_id"], [x["pgTitle"], x["sectionTitle"], x["tableCaption"], x["processed_tableHeaders"]])
        )
    )
    .flatMap(build_for_own)
    .collect()
)
print(len(test_mentions_local))
test_mentions_local_with_wikidata_id = (
    test_mentions_with_candidate.select("table_id", "wikidata_id", "candidates", "i", "j", "mention")
    .where(~F.isnull("candidates"))
    .rdd.map(lambda x: [x["table_id"], x["i"], x["j"], x["mention"], x["wikidata_id"], x["candidates"]])
    .filter(lambda x: x[4] in [z[0] for z in x[5]])
    .map(lambda x: (x[0], [x[1:]]))
    .reduceByKey(add)
    .join(
        test_tables.map(
            lambda x: (x["_id"], [x["pgTitle"], x["sectionTitle"], x["tableCaption"], x["processed_tableHeaders"]])
        )
    )
    .flatMap(build_for_own_with_wikidata_id)
    .collect()
)
print(len(test_mentions_local_with_wikidata_id))

In [None]:
print(sc.parallelize(test_mentions_local).map(lambda x: x[0]).distinct().count())
print(sc.parallelize(test_mentions_local).map(lambda x: len(x[5])).sum())

In [None]:
print(len(train_mentions_local))
print(len(val_mentions_local))

In [None]:
with open(data_dir + "train.table_entity_linking.json", "w") as f:
    json.dump(train_mentions_local, f)
with open(data_dir + "dev.table_entity_linking.json", "w") as f:
    json.dump(val_mentions_local, f)

In [None]:
with open(data_dir + "test_own.table_entity_linking.json", "w") as f:
    json.dump(test_mentions_local, f)

In [None]:
with open(data_dir + "test_own_0820.table_entity_linking.with_wikidata_id.json", "w") as f:
    json.dump(test_mentions_local_with_wikidata_id, f)
with open(data_dir + "test_own_0820.table_entity_linking.json", "w") as f:
    json.dump(test_mentions_local, f)

In [None]:
with open(data_dir + "train.table_entity_linking.json", "r") as f:
    train_mentions_local = sc.parallelize(json.load(f))

In [None]:
test_mentions_local[0]

In [None]:
test_mentions_local = (
    wikipedia_gs_entity_mentions_with_candidate.select("tableId", "wikidata_id", "candidates", "i", "j", "mention")
    .where(~F.isnull("candidates"))
    .rdd.map(lambda x: [x["tableId"], x["i"], x["j"], x["mention"], x["wikidata_id"], x["candidates"]])
    .filter(lambda x: x[4] in [z[0] for z in x[5]])
    .map(lambda x: (x[0], [x[1:]]))
    .reduceByKey(add)
    .join(
        wikipedia_gs_tables.map(
            lambda x: (x["tableId"], [x["pgTitle"], x["sectionTitle"], x["tableCaption"], x["processed_tableHeaders"]])
        )
    )
    .flatMap(build_for_own)
    .collect()
)

In [None]:
data_dir = "/srv/samba/group_workspace_1/deng.595/workspace/table_transformer/data/wikitable_entity/v2/"
with open(data_dir + "test.table_entity_linking.json", "w") as f:
    json.dump(test_mentions_local, f)

In [None]:
def get_labels_and_candidate(tables):
    """Key each mention of one example by ``(table_id, row, col)``.

    ``tables`` is one example in the list layout produced by ``build_for_own``:
    index 0 is the table id, index 5 the mention cells ``[[row, col], text]``,
    index 7 the gold labels and index 8 the per-mention candidate indices.

    Returns:
        list of ``((table_id, row, col), [label, candidates])`` pairs.
    """
    return [
        ((tables[0], cell[0][0], cell[0][1]), [tables[7][pos], tables[8][pos]])
        for pos, cell in enumerate(tables[5])
    ]

# Evaluation with dumped model results

In [None]:
import pickle

with open(
    "/srv/samba/group_workspace_1/deng.595/workspace/table_transformer/data/wikitable_entity/v2/test_entity_linking_results_2.pkl",
    "rb",
) as f:
    gs_test_results = pickle.load(f)

In [None]:
import pickle

with open(
    "/srv/samba/group_workspace_1/deng.595/workspace/table_transformer/data/wikitable_entity/v2/test_own_entity_linking_results_2.pkl",
    "rb",
) as f:
    test_results = pickle.load(f)

In [None]:
import pickle

with open(
    "/srv/samba/group_workspace_1/deng.595/workspace/table_transformer/data/wikitable_entity/v2/test_own_0820_entity_linking_results_0.pkl",
    "rb",
) as f:
    test_results = pickle.load(f)

In [None]:
def get_tp(result):
    """Score one joined ``(key, ([label, candidates], ranking))`` record.

    Walks ``ranking`` in order and stops at the first entry that is among the
    mention's ``candidates``: returns 1 if that entry equals ``label``, else 0.
    Returns 0 when no ranked entry is a candidate.
    """
    gold_and_candidates, ranking = result[1][0], result[1][1]
    for predicted in ranking:
        if predicted not in gold_and_candidates[1]:
            continue
        # Only the first ranked entry that is a real candidate counts.
        return 1 if predicted == gold_and_candidates[0] else 0
    return 0

In [None]:
our_tp = (
    sc.parallelize(test_mentions_local)
    .flatMap(get_labels_and_candidate)
    .join(sc.parallelize(test_results).flatMap(lambda x: [((x[0], z[0], z[1]), x[2][i]) for i, z in enumerate(x[1])]))
    .map(get_tp)
    .sum()
)

In [None]:
mentioned_dbpedia_types = (
    sc.parallelize(train_mentions_local).map(lambda x: set([z for y in x[6] for z in y[2]])).reduce(lambda a, b: a | b)
)

In [None]:
with open(data_dir + "dbpedia_type_vocab.txt", "w") as f:
    f.write("{}\t{}\n".format(0, "[PAD]"))
    for i, t in enumerate(mentioned_dbpedia_types):
        f.write("{}\t{}\n".format(i + 1, t))

In [None]:
wrong_mentions = spark.createDataFrame(
    entity_wikidata_target_candidate.rdd.filter(
        lambda x: x["wikidata_id"] not in [z[0] for z in x["candidates"][:1]]
        and x["wikidata_id"] in [z[0] for z in x["candidates"][:]]
    )
)

In [None]:
print(wrong_mentions.where("wikipedia_id is not null").distinct().count())
entities = set(
    wrong_mentions.where("wikipedia_id is not null").rdd.map(lambda x: x["wikipedia_id"]).distinct().collect()
)
print(len(entities))
print(len(entities & train_all_entities_wiki_id))

In [None]:
print(wrong_mentions.where(F.size("candidates") != 0).count())
print(
    wrong_mentions.where(F.size("candidates") != 0)
    .join(dbpedia_types.select("wikipedia_title").dropDuplicates(), "wikipedia_title", "inner")
    .count()
)
print(
    wrong_mentions.where(F.size("candidates") != 0)
    .join(dbpedia_abstract.select("wikipedia_title").dropDuplicates(), "wikipedia_title", "inner")
    .count()
)

In [None]:
wrong_mentions.where(F.size("candidates") != 0).join(dbpedia_types, "wikipedia_title", "left").where(
    "type is null"
).show()

In [None]:
def get_index(x, cands):
    """Return the position of the first element of ``cands`` equal to ``x``.

    Returns the sentinel 999 when ``x`` does not occur in ``cands``.
    """
    return next((pos for pos, cand in enumerate(cands) if x == cand), 999)


best_recall = (
    val_mentions_with_candidate.select("table_id", "wikidata_id", "candidates", "i", "j", "mention")
    .dropDuplicates(["mention", "wikidata_id"])
    .where(~F.isnull("candidates"))
    .rdd.filter(lambda x: len(x["candidates"]) != 0)
    .map(lambda x: get_index(x["wikidata_id"], [z[0] for z in x["candidates"]]))
    .collect()
)

In [None]:
numpy_describe(best_recall)

In [None]:
numpy_describe(best_recall)

In [None]:
for i in range(60, 80):
    print(i, np.percentile(best_recall, i))

In [None]:
wikipedia_gs_entity_mentions.count()

In [None]:
gs_wikidata_P

In [None]:
gs_wikidata_TP

In [None]:
our_tp

In [None]:
test_mentions_with_candidate.count()

In [None]:
test_wikidata_all_predicted = test_mentions_with_candidate.where(F.size("candidates") >= 1).count()
test_wikidata_TP = (
    test_mentions_with_candidate.where(F.size("candidates") >= 1)
    .rdd.map(lambda x: 1 if x["wikidata_id"] in [z[0] for z in x["candidates"][:1]] else 0)
    .sum()
)
test_wikidata_P = test_mentions_with_candidate.count()
test_wikidata_best_TP = (
    test_mentions_with_candidate.where(F.size("candidates") >= 1)
    .rdd.map(lambda x: 1 if x["wikidata_id"] in [z[0] for z in x["candidates"]] else 0)
    .sum()
)

In [None]:
precision = test_wikidata_TP / test_wikidata_all_predicted
recall = test_wikidata_TP / test_wikidata_P
f1 = 2 * precision * recall / (precision + recall)
print(f1, precision, recall)

In [None]:
precision = test_wikidata_best_TP / test_wikidata_all_predicted
recall = test_wikidata_best_TP / test_wikidata_P
f1 = 2 * precision * recall / (precision + recall)
print(f1, precision, recall)

In [None]:
precision = our_tp / test_wikidata_all_predicted
recall = our_tp / test_wikidata_P
f1 = 2 * precision * recall / (precision + recall)
print(f1, precision, recall)

In [None]:
print("no description")
precision = our_tp / test_wikidata_all_predicted
recall = our_tp / test_wikidata_P
f1 = 2 * precision * recall / (precision + recall)
print(f1, precision, recall)

In [None]:
print("no type")
precision = our_tp / test_wikidata_all_predicted
recall = our_tp / test_wikidata_P
f1 = 2 * precision * recall / (precision + recall)
print(f1, precision, recall)

In [None]:
gs_wikidata_all_predicted / gs_wikidata_P

In [None]:
gs_wikidata_all_predicted = wikipedia_gs_entity_mentions_with_candidate.where(F.size("candidates") >= 1).count()
gs_wikidata_TP = (
    wikipedia_gs_entity_mentions_with_candidate.where(F.size("candidates") >= 1)
    .rdd.map(lambda x: 1 if x["wikidata_id"] in [z[0] for z in x["candidates"][:1]] else 0)
    .sum()
)
gs_wikidata_P = wikipedia_gs_entity_mentions_with_candidate.count()
gs_wikidata_best_TP = (
    wikipedia_gs_entity_mentions_with_candidate.where(F.size("candidates") >= 1)
    .rdd.map(lambda x: 1 if x["wikidata_id"] in [z[0] for z in x["candidates"]] else 0)
    .sum()
)

In [None]:
precision = gs_wikidata_TP / gs_wikidata_all_predicted
recall = gs_wikidata_TP / gs_wikidata_P
f1 = 2 * precision * recall / (precision + recall)
print(f1, precision, recall)

In [None]:
precision = gs_wikidata_best_TP / gs_wikidata_all_predicted
recall = gs_wikidata_best_TP / gs_wikidata_P
f1 = 2 * precision * recall / (precision + recall)
print(f1, precision, recall)

In [None]:
precision = our_tp / gs_wikidata_all_predicted
recall = our_tp / gs_wikidata_P
f1 = 2 * precision * recall / (precision + recall)
print(f1, precision, recall)

In [None]:
print("no description")
precision = our_tp / gs_wikidata_all_predicted
recall = our_tp / gs_wikidata_P
f1 = 2 * precision * recall / (precision + recall)
print(f1, precision, recall)

In [None]:
print("no type")
precision = our_tp / gs_wikidata_all_predicted
recall = our_tp / gs_wikidata_P
f1 = 2 * precision * recall / (precision + recall)
print(f1, precision, recall)

In [None]:
our_tp / gs_wikidata_best_TP

In [None]:
entity_wikidata_target_candidate.where(F.size("candidates") >= 1).count()

In [None]:
len(recall) / entity_wikidata_target_candidate.count()

In [None]:
wiki_35k_test = spark.createDataFrame(
    sc.textFile("../../data/entity_linking/35k_test.ids.txt").map(lambda x: Row(id=x))
)

In [None]:
wiki_35k_test_recall = (
    entity_wikidata_target_candidate.join(wiki_35k_test, "id", "inner")
    .rdd.map(lambda x: x["wikidata_id"] in [z[0] for z in x["candidates"]])
    .collect()
)
wiki_35k_test_precision = (
    entity_wikidata_target_candidate.join(wiki_35k_test, "id", "inner")
    .rdd.map(lambda x: x["wikidata_id"] in [z[0] for z in x["candidates"][:1]])
    .collect()
)

In [None]:
sum(wiki_35k_test_recall) / len(wiki_35k_test_recall)

In [None]:
sum(wiki_35k_test_precision) / sum(wiki_35k_test_recall)

In [None]:
wiki_35k_mentions = entity_wikidata_target.join(wiki_35k_test, "id", "inner")

In [None]:
entity_wikidata_target_candidate = entity_wikidata_target.join(entity_googlekg_candidates_df, "mention", "inner").join(
    entity_wikidata_candidates_df, "mention", "inner"
)

In [None]:
entity_wikidata_target_candidate.count()

In [None]:
len(entity_mentions_surface)

# Efthymiou
## T2D

In [None]:
import csv

In [None]:
# RDD of (table_id, parsed_json), one element per file in the T2D tables directory.
# table_id is the file name with its last 5 characters removed (presumably the
# ".json" extension — confirm).
t2d_tables = sc.wholeTextFiles("../../data/efthymiou/t2d/tables_instance_with_context").map(
    lambda x: (x[0].split("/")[-1][:-5], json.loads(x[1]))
)

In [None]:
# Inspect one parsed table.
t2d_tables.take(1)

In [None]:
# Total number of well-formed gold annotation rows (CSV rows with exactly 3 fields)
# across all T2D entity files, before joining with the Wikipedia->Wikidata mapping.
sc.wholeTextFiles("../../data/efthymiou/t2d/entities_instance").map(
    lambda x: (x[0].split("/")[-1][:-4], list(csv.reader(x[1].split("\n"))))
).flatMap(lambda x: [y for y in x[1] if len(y) == 3]).count()

In [None]:
# Gold T2D annotations as a DataFrame, joined with the Wikipedia->Wikidata mapping.
# Each 3-field CSV row is read as:
#   y[0] -> a path whose last "/"-separated segment is used as the Wikipedia title,
#   y[1] -> the cell mention (with "&nbsp;" / "&nbsp" removed),
#   y[2] -> the row index `i`.
# The column index `j` is hard-coded to 0. table_id is the file name minus its last
# 4 characters. The inner join drops annotations whose title has no mapping row.
t2d_entities = spark.createDataFrame(
    sc.wholeTextFiles("../../data/efthymiou/t2d/entities_instance")
    .map(lambda x: (x[0].split("/")[-1][:-4], list(csv.reader(x[1].split("\n")))))
    .flatMap(
        lambda x: [
            Row(
                table_id=x[0],
                wikipedia_title=y[0].split("/")[-1],
                j=0,
                i=int(y[2]),
                mention=y[1].replace("&nbsp;", "").replace("&nbsp", ""),
            )
            for y in x[1]
            if len(y) == 3
        ]
    )
).join(wikipedia_wikidata_mapping, "wikipedia_title", "inner")

In [None]:
# Preview the joined gold annotations.
t2d_entities.show()

In [None]:
# Number of gold annotations that survived the mapping join.
t2d_entities.count()

In [None]:
# Distinct mention strings, collected to the driver; these are the lookup queries.
t2d_entity_mentions = list(set(t2d_entities.rdd.map(lambda x: x["mention"]).collect()))
print(len(t2d_entity_mentions))

In [None]:
# Eyeball the first 100 distinct mentions.
t2d_entity_mentions[:100]

In [None]:
# Run `wikidata_lookup` (defined elsewhere in this notebook) over every distinct T2D
# mention using 16 worker processes, in batches of 10,000 so progress is reported
# per batch. Downstream cells read each result as (mention, candidates).
entity_t2d_candidates = []
# The context manager releases the worker processes even if a lookup raises; every
# result has been consumed by the time the block exits.
with Pool(processes=16) as pool:
    for i in range(0, len(t2d_entity_mentions), 10000):
        print(i)
        batch = t2d_entity_mentions[i : i + 10000]
        # total=len(batch): the final batch is normally shorter than 10,000, and a
        # fixed total would leave its progress bar permanently incomplete.
        entity_t2d_candidates.extend(
            tqdm(pool.imap(wikidata_lookup, batch, chunksize=150), total=len(batch))
        )

In [None]:
# Lookup results as a DataFrame keyed by mention. A result whose second element is
# not a list is stored with an empty candidate list.
entity_t2d_candidates_df = spark.createDataFrame(
    sc.parallelize(entity_t2d_candidates).map(
        lambda x: Row(mention=x[0], candidates=x[1] if isinstance(x[1], list) else [])
    )
)

In [None]:
# Left join: every gold annotation is kept, with `candidates` null when its mention
# has no row in the lookup results.
t2d_entities_with_candidates = t2d_entities.join(entity_t2d_candidates_df, "mention", "left")

In [None]:
# Preview gold annotations alongside their candidates.
t2d_entities_with_candidates.show()

In [None]:
# One example input for `build_for_own`:
#   (table_id, ([[i, j, mention, wikidata_id, candidates], ...], table_json))
# Only annotations whose gold id appears among their candidates are kept, then they
# are grouped per table (list concatenation via `add`) and joined with the tables.
sample = (
    t2d_entities_with_candidates.select("table_id", "wikidata_id", "candidates", "i", "j", "mention")
    .where(~F.isnull("candidates"))
    .rdd.map(lambda x: [x["table_id"], x["i"], x["j"], x["mention"], x["wikidata_id"], x["candidates"]])
    .filter(lambda x: x[4] in [z[0] for z in x[5]])
    .map(lambda x: (x[0], [x[1:]]))
    .reduceByKey(add)
    .join(t2d_tables)
    .take(1)[0]
)

In [None]:
# The grouped gold annotations of the sampled table.
sample[1][0]

In [None]:
def build_for_own(x):
    """Turn one T2D table and its gold annotations into model-input chunks.

    Args:
        x: ``(table_id, (gold_rows, table))`` as produced by the
            group-by-table / join pipeline. Each gold row is
            ``[row_index, col_index, mention, gold_id, candidates]`` and each
            candidate is indexed as ``cand[0]`` (id), ``cand[1]``, ``cand[2]``.
            ``table`` is the parsed table JSON; ``table["relation"]`` is read
            column-major (one list of cell strings per column).

    Returns:
        A list with one entry per chunk of gold rows, each entry being
        ``[table_id, pgTitle, secTitle, caption, headers, entities,
        candidate_entities, labels, cand_for_each]``:

        * ``entities``: ``[[row, col_offset], mention]`` for the subject cell
          (offset 0) followed by the cells of up to two columns to its right
          (offsets 1 and 2) when that column has a value for the row.
        * ``candidate_entities``: chunk-local candidate table; position ``k``
          holds ``[cand[1], cand[2], dbpedia_types.get(id, [])]``.
        * ``labels`` / ``cand_for_each``: aligned with ``entities``; subject
          cells carry the chunk-local index of the gold entity and of every
          candidate, context cells carry ``0`` and ``[]``.

    Raises:
        KeyError: if a gold id is not among that row's candidates (the
            upstream pipeline filters such rows out).

    Relies on the notebook-level ``dbpedia_types`` mapping.
    """
    table_id = x[0]
    gold_rows = x[1][0]
    table = x[1][1]

    pgTitle = table["pageTitle"]
    secTitle = ""
    caption = table["title"]
    header_row = table["headerRowIndex"]
    subject_col = table["keyColumnIndex"]
    # Header cells of the subject column and every column to its right.
    headers = [column[header_row] for column in table["relation"][subject_col:]]
    # At most two context columns, immediately right of the subject column.
    context_columns = table["relation"][subject_col + 1 : subject_col + 3]

    n_rows = len(gold_rows)
    # Split into roughly equal chunks; a table yields a single chunk until it has
    # at least 50 gold rows.
    chunk_size = int(n_rows / max([1, int(n_rows / 25)])) + 1

    chunks = []
    for start in range(0, n_rows, chunk_size):
        cells = []
        cand_index = {}  # entity id -> [chunk-local index, cand[1], cand[2], types]
        gold_labels = []
        per_cell_candidates = []

        for row in gold_rows[start : start + chunk_size]:
            row_i = row[0]
            cells.append([[row_i, 0], row[2]])

            for cand in row[4]:
                cand_id = cand[0]
                if cand_id in cand_index:
                    continue
                cand_index[cand_id] = [len(cand_index), cand[1], cand[2], dbpedia_types.get(cand_id, [])]

            gold_labels.append(cand_index[row[3]][0])
            per_cell_candidates.append([cand_index[cand[0]][0] for cand in row[4]])

            # Context cells are inputs only: placeholder label, no candidates.
            for offset, column in enumerate(context_columns, start=1):
                if row_i < len(column):
                    text = column[row_i].replace("&nbsp;", "").replace("&nbsp", "")
                    cells.append([[row_i, offset], text])
                    gold_labels.append(0)
                    per_cell_candidates.append([])

        # Indices were handed out in insertion order, so the dict's values are
        # already ordered by chunk-local index.
        candidate_table = [record[1:] for record in cand_index.values()]
        chunks.append(
            [table_id, pgTitle, secTitle, caption, headers, cells, candidate_table, gold_labels, per_cell_candidates]
        )
    return chunks

In [None]:
# Inspect the second chunk built from the sampled table. A table only yields more
# than one chunk when it has at least 50 gold rows, so this raises IndexError for a
# smaller sample.
build_for_own(sample)[1]

In [None]:
# Full T2D dataset in model-input form, collected to the driver: same pipeline as the
# sample cell, then `build_for_own` expands every table into its chunks.
t2d_local = (
    t2d_entities_with_candidates.select("table_id", "wikidata_id", "candidates", "i", "j", "mention")
    .where(~F.isnull("candidates"))
    .rdd.map(lambda x: [x["table_id"], x["i"], x["j"], x["mention"], x["wikidata_id"], x["candidates"]])
    # keep only annotations whose gold id is reachable from the candidate list
    .filter(lambda x: x[4] in [z[0] for z in x[5]])
    .map(lambda x: (x[0], [x[1:]]))
    .reduceByKey(add)
    .join(t2d_tables)
    .flatMap(build_for_own)
    .collect()
)

In [None]:
# Per-cell candidate index lists (element 8) of the chunk at position 70.
t2d_local[70][8]

In [None]:
def get_labels_and_candidate(tables):
    """Flatten one chunk built by ``build_for_own`` into per-cell evaluation records.

    Args:
        tables: a chunk ``[table_id, pgTitle, secTitle, caption, headers,
            entities, candidate_entities, labels, cand_for_each]``.

    Returns:
        A list of ``((table_id, row, col), [label, candidates, candidate_entities])``
        for every cell that has at least one candidate; cells with an empty
        candidate list are omitted.
    """
    table_id = tables[0]
    cells = tables[5]
    candidate_entities = tables[6]
    labels = tables[7]
    cand_for_each = tables[8]
    return [
        ((table_id, cell[0][0], cell[0][1]), [labels[pos], cand_for_each[pos], candidate_entities])
        for pos, cell in enumerate(cells)
        if len(cand_for_each[pos]) != 0
    ]

In [None]:
# Persist the T2D chunks as JSON for the entity-linking model.
# `data_dir` already ends with "/", so the file name is appended directly.
data_dir = "/srv/samba/group_workspace_1/deng.595/workspace/table_transformer/data/wikitable_entity/v2/"
with open(f"{data_dir}t2d.table_entity_linking.json", "w") as f:
    json.dump(t2d_local, f)

In [None]:
# Counts for evaluating the raw lookup on T2D.
# Annotations for which the lookup returned at least one candidate.
t2d_all_predicted = t2d_entities_with_candidates.where(F.size("candidates") >= 1).count()
# Annotations whose gold id equals the id of the FIRST candidate.
t2d_TP = (
    t2d_entities_with_candidates.where(F.size("candidates") >= 1)
    .rdd.map(lambda x: 1 if x["wikidata_id"] in [z[0] for z in x["candidates"][:1]] else 0)
    .sum()
)
# All gold annotations, with or without candidates (recall denominator).
t2d_P = t2d_entities_with_candidates.count()
# Annotations whose gold id appears ANYWHERE in the candidate list.
t2d_best_TP = (
    t2d_entities_with_candidates.where(F.size("candidates") >= 1)
    .rdd.map(lambda x: 1 if x["wikidata_id"] in [z[0] for z in x["candidates"]] else 0)
    .sum()
)

In [None]:
# Precision / recall / F1 of taking the lookup's first candidate on T2D.
precision = t2d_TP / t2d_all_predicted
recall = t2d_TP / t2d_P
f1 = 2 * precision * recall / (precision + recall)
print(f1, precision, recall)

In [None]:
# Total number of gold T2D annotations.
t2d_P

In [None]:
# Same metrics when a hit anywhere in the candidate list counts as correct.
precision = t2d_best_TP / t2d_all_predicted
recall = t2d_best_TP / t2d_P
f1 = 2 * precision * recall / (precision + recall)
print(f1, precision, recall)

In [None]:
import pickle

# Load the pickled T2D results into `test_results`.
# NOTE: pickle.load can execute arbitrary code embedded in the file — only load
# result files you produced yourself.
with open(
    "/srv/samba/group_workspace_1/deng.595/workspace/table_transformer/data/wikitable_entity/v2/t2d_entity_linking_results_0.pkl",
    "rb",
) as f:
    test_results = pickle.load(f)

In [None]:
def get_tp(result):
    """Score one joined cell record: 1 if the final pick equals the gold label, else 0.

    Args:
        result: ``(key, (table_side, model_side))`` where
            ``table_side = [gold_label, candidate_ids, ...]`` comes from
            ``get_labels_and_candidate`` and ``model_side = (ranked_ids, scores)``
            holds two parallel sequences (presumably the model's ranking and
            its scores — confirm against the results file).

    The final pick combines the model with the lookup order:

    * ``lookup`` is the first id in ``candidate_ids``; its score is the model's
      score for that id, or 0 when the model output does not contain it.
    * ``pred`` is the first id in ``ranked_ids`` that is one of the cell's
      candidates.
    * ``pred`` wins when it equals ``lookup`` or when 0.8 x its score still
      exceeds the lookup's score; otherwise ``lookup`` wins.

    If no ranked id is among the cell's candidates there is no model pick, and
    the lookup's top candidate is used (this previously raised IndexError).

    Note: a later cell of this notebook redefines ``get_tp`` with a tuple return.
    """
    table_side = result[1][0]
    model_side = result[1][1]
    gold_label = table_side[0]
    candidate_ids = table_side[1]
    ranked_ids = model_side[0]
    scores = model_side[1]

    lookup_id = candidate_ids[0]
    lookup_score = 0

    pred_id = None
    pred_score = None
    for rank, ranked_id in enumerate(ranked_ids):
        if ranked_id in candidate_ids:
            pred_id, pred_score = ranked_id, scores[rank]
            break

    for rank, ranked_id in enumerate(ranked_ids):
        if ranked_id == lookup_id:
            lookup_score = scores[rank]
            break

    if pred_id is None:
        # The model ranked none of this cell's candidates: fall back to the lookup.
        final = lookup_id
    elif pred_id == lookup_id or (pred_score * 0.8) > lookup_score:
        final = pred_id
    else:
        final = lookup_id

    return 1 if final == gold_label else 0

In [None]:
# One joined record for trying out `get_tp`:
#   ((table_id, row, col), ([label, candidates, candidate_entities], (x[2][i], x[3][i])))
# The right side pairs each cell position z of a result entry x (x[1]) with the
# i-th elements of x[2] and x[3].
sample_result = (
    sc.parallelize(t2d_local)
    .flatMap(get_labels_and_candidate)
    .join(
        sc.parallelize(test_results).flatMap(
            lambda x: [((x[0], z[0], z[1]), (x[2][i], x[3][i])) for i, z in enumerate(x[1])]
        )
    )
    .take(1)
)

In [None]:
# Smoke-test `get_tp` on the sampled record.
get_tp(sample_result[0])

In [None]:
# Number of T2D cells for which `get_tp` reports a correct final pick. `.sum()`
# requires `get_tp` to return a number, i.e. the definition in effect must be the
# int-returning one, not the later tuple-returning redefinition.
our_tp = (
    sc.parallelize(t2d_local)
    .flatMap(get_labels_and_candidate)
    .join(
        sc.parallelize(test_results).flatMap(
            lambda x: [((x[0], z[0], z[1]), (x[2][i], x[3][i])) for i, z in enumerate(x[1])]
        )
    )
    .map(get_tp)
    .sum()
)

In [None]:
# Precision / recall / F1 of our final picks on T2D, using the same denominators as
# the lookup baseline.
precision = our_tp / t2d_all_predicted
recall = our_tp / t2d_P
f1 = 2 * precision * recall / (precision + recall)
print(f1, precision, recall)

In [None]:
# True positives of the lookup's first candidate, for comparison.
t2d_TP

In [None]:
# True positives of our final picks.
our_tp

In [None]:
def get_tp(result):
    """Score one joined cell record and also return the entity that was picked.

    This redefinition replaces the earlier int-returning ``get_tp``; cells that
    sum its output must extract the flag first.

    Args:
        result: ``(key, (table_side, model_side))`` where
            ``table_side = [gold_label, candidate_ids, candidate_entities]``
            comes from ``get_labels_and_candidate`` and
            ``model_side = (ranked_ids, scores)`` holds two parallel sequences
            (presumably the model's ranking and its scores — confirm against
            the results file).

    Returns:
        ``(flag, entity)``: ``flag`` is 1 when the final pick equals the gold
        label, else 0; ``entity`` is ``candidate_entities[final]``.

    The final pick is the model's first ranked id that is among the cell's
    candidates, provided it equals the lookup's top candidate or 0.8 x its
    score exceeds the model's score for the lookup's top candidate (0 if that
    id is not ranked); otherwise the lookup's top candidate. When no ranked id
    is among the candidates the lookup's top candidate is used (this
    previously raised IndexError).
    """
    table_side = result[1][0]
    model_side = result[1][1]
    gold_label = table_side[0]
    candidate_ids = table_side[1]
    ranked_ids = model_side[0]
    scores = model_side[1]

    lookup_id = candidate_ids[0]
    lookup_score = 0

    pred_id = None
    pred_score = None
    for rank, ranked_id in enumerate(ranked_ids):
        if ranked_id in candidate_ids:
            pred_id, pred_score = ranked_id, scores[rank]
            break

    for rank, ranked_id in enumerate(ranked_ids):
        if ranked_id == lookup_id:
            lookup_score = scores[rank]
            break

    if pred_id is None:
        # The model ranked none of this cell's candidates: fall back to the lookup.
        final = lookup_id
    elif pred_id == lookup_id or (pred_score * 0.8) > lookup_score:
        final = pred_id
    else:
        final = lookup_id

    flag = 1 if final == gold_label else 0
    return (flag, table_side[2][final])

In [None]:
# One joined evaluation record (a one-element list). This rebinds `sample`, which
# earlier held a `build_for_own` input, so earlier cells that read `sample` must not
# be re-run after this one without recreating it.
sample = (
    sc.parallelize(t2d_local)
    .flatMap(get_labels_and_candidate)
    .join(
        sc.parallelize(test_results).flatMap(
            lambda x: [((x[0], z[0], z[1]), (x[2][i], x[3][i])) for i, z in enumerate(x[1])]
        )
    )
    .take(1)
)

In [None]:
# The chunk-local candidate entity table attached to the sampled cell.
sample[0][1][0][2]

In [None]:
# RDD of ((table_id, row, col), get_tp(record)) for every evaluated T2D cell. The
# downstream filters index the value as a tuple, so this expects the tuple-returning
# `get_tp`.
our_results = (
    sc.parallelize(t2d_local)
    .flatMap(get_labels_and_candidate)
    .join(
        sc.parallelize(test_results).flatMap(
            lambda x: [((x[0], z[0], z[1]), (x[2][i], x[3][i])) for i, z in enumerate(x[1])]
        )
    )
    .map(lambda x: (x[0], get_tp(x)))
)

In [None]:
# Lookup baseline keyed like `our_results`:
#   ((table_id, i, j), (mention, candidates, 1 if the first candidate is the gold id else 0))
lookup_results = t2d_entities_with_candidates.where(F.size("candidates") >= 1).rdd.map(
    lambda x: (
        (x["table_id"], x["i"], x["j"]),
        (x["mention"], x["candidates"], 1 if x["wikidata_id"] in [z[0] for z in x["candidates"][:1]] else 0),
    )
)

In [None]:
# Per-cell pairing: (key, (our get_tp output, lookup tuple)).
all_results = our_results.join(lookup_results)

In [None]:
# Regressions: our final pick is wrong (flag 0) although the lookup's first candidate
# was right (last element of the lookup tuple is 1).
errors = all_results.filter(lambda x: x[1][0][0] == 0 and x[1][1][-1] == 1).collect()

In [None]:
# Improvements: our final pick is right although the lookup's first candidate was wrong.
correct = all_results.filter(lambda x: x[1][0][0] == 1 and x[1][1][-1] == 0).collect()

In [None]:
# Number of regressions, then number of improvements.
print(len(errors))
print(len(correct))

In [None]:
# Number of improvements (already printed by the cell above).
len(correct)

In [None]:
# Number of improvements (repeats the previous cell).
len(correct)

In [None]:
# Inspect one improvement.
correct[0]

In [None]:
# Inspect one regression (requires at least 61 of them).
errors[60]

In [None]:
# Number of distinct tables that contain a regression.
len(set([x[0][0] for x in errors]))

In [None]:
# Number of distinct tables that contain an improvement.
len(set([x[0][0] for x in correct]))

In [None]:
# Ids of the tables that contain a regression.
set([x[0][0] for x in errors])

In [None]:
# `t2d_tables_local` is otherwise only created by a cell further down, so on a fresh
# top-to-bottom run this cell failed with NameError. Build the table_id -> table
# dict here so the cell is self-sufficient.
t2d_tables_local = dict(t2d_tables.collect())
# [table_id, page title] for every table that contains a regression.
[[table_id, t2d_tables_local[table_id]["pageTitle"]] for table_id in list(set([x[0][0] for x in errors]))]

In [None]:
# All regressions inside one specific table.
[x for x in errors if x[0][0] == "41194422_0_7231546114369966811"]

In [None]:
# Driver-side dict table_id -> parsed table JSON.
# NOTE(review): a cell above already reads `t2d_tables_local`; this definition needs
# to run before it on a fresh kernel.
t2d_tables_local = dict(t2d_tables.collect())

In [None]:
# Inspect one specific table.
t2d_tables_local["71137051_0_8039724067857124984"]

# Limaye

In [None]:
# RDD of (table_id, rows), one element per Limaye table file. table_id is the file
# name minus its last 4 characters; rows come from csv.reader over the file split on
# "\n", so the table is row-major (a list of rows, each a list of cell strings).
limaye_tables = sc.wholeTextFiles("../../data/efthymiou/LimayeGS/tables_instance").map(
    lambda x: (x[0].split("/")[-1][:-4], list(csv.reader(x[1].split("\n"))))
)

In [None]:
# Inspect one parsed table.
limaye_tables.take(1)

In [None]:
# Total number of well-formed gold annotation rows (CSV rows with exactly 3 fields)
# across all Limaye entity files, before joining with the Wikipedia->Wikidata mapping.
sc.wholeTextFiles("../../data/efthymiou/LimayeGS/entities_instance").map(
    lambda x: (x[0].split("/")[-1][:-4], list(csv.reader(x[1].split("\n"))))
).flatMap(lambda x: [y for y in x[1] if len(y) == 3]).count()

In [None]:
# Gold Limaye annotations as a DataFrame, joined with the Wikipedia->Wikidata mapping.
# Each 3-field CSV row is read as:
#   y[0] -> a path whose last "/"-separated segment is used as the Wikipedia title,
#   y[1] -> the cell mention (with "&nbsp;" / "&nbsp" removed),
#   y[2] -> the row index `i`.
# The column index `j` is hard-coded to 0. The inner join drops annotations whose
# title has no mapping row.
limaye_entities = spark.createDataFrame(
    sc.wholeTextFiles("../../data/efthymiou/LimayeGS/entities_instance")
    .map(lambda x: (x[0].split("/")[-1][:-4], list(csv.reader(x[1].split("\n")))))
    .flatMap(
        lambda x: [
            Row(
                table_id=x[0],
                wikipedia_title=y[0].split("/")[-1],
                j=0,
                i=int(y[2]),
                mention=y[1].replace("&nbsp;", "").replace("&nbsp", ""),
            )
            for y in x[1]
            if len(y) == 3
        ]
    )
).join(wikipedia_wikidata_mapping, "wikipedia_title", "inner")

In [None]:
# Distinct mention strings, collected to the driver; these are the lookup queries.
limaye_entity_mentions = list(set(limaye_entities.rdd.map(lambda x: x["mention"]).collect()))
print(len(limaye_entity_mentions))

In [None]:
# Eyeball the first 100 distinct mentions.
limaye_entity_mentions[:100]

In [None]:
# Run `wikidata_lookup` (defined elsewhere in this notebook) over every distinct
# Limaye mention using 16 worker processes, in batches of 10,000 so progress is
# reported per batch. Downstream cells read each result as (mention, candidates).
entity_limaye_candidates = []
# The context manager releases the worker processes even if a lookup raises; every
# result has been consumed by the time the block exits.
with Pool(processes=16) as pool:
    for i in range(0, len(limaye_entity_mentions), 10000):
        print(i)
        batch = limaye_entity_mentions[i : i + 10000]
        # total=len(batch): the final batch is normally shorter than 10,000, and a
        # fixed total would leave its progress bar permanently incomplete.
        entity_limaye_candidates.extend(
            tqdm(pool.imap(wikidata_lookup, batch, chunksize=150), total=len(batch))
        )

In [None]:
# Lookup results as a DataFrame keyed by mention. A result whose second element is
# not a list is stored with an empty candidate list.
entity_limaye_candidates_df = spark.createDataFrame(
    sc.parallelize(entity_limaye_candidates).map(
        lambda x: Row(mention=x[0], candidates=x[1] if isinstance(x[1], list) else [])
    )
)

In [None]:
# Left join: every gold annotation is kept, with `candidates` null when its mention
# has no row in the lookup results.
limaye_entities_with_candidates = limaye_entities.join(entity_limaye_candidates_df, "mention", "left")

In [None]:
# Preview gold annotations alongside their candidates.
limaye_entities_with_candidates.show()

In [None]:
# One example input for the Limaye `build_for_own`:
#   (table_id, ([[i, j, mention, wikidata_id, candidates], ...], table_rows))
# Only annotations whose gold id appears among their candidates are kept, then they
# are grouped per table and joined with the parsed tables. Rebinds `sample`.
sample = (
    limaye_entities_with_candidates.select("table_id", "wikidata_id", "candidates", "i", "j", "mention")
    .where(~F.isnull("candidates"))
    .rdd.map(lambda x: [x["table_id"], x["i"], x["j"], x["mention"], x["wikidata_id"], x["candidates"]])
    .filter(lambda x: x[4] in [z[0] for z in x[5]])
    .map(lambda x: (x[0], [x[1:]]))
    .reduceByKey(add)
    .join(limaye_tables)
    .take(1)[0]
)

In [None]:
# The parsed rows of the sampled table.
sample[1][1]

In [None]:
def build_for_own(x):
    """Turn one Limaye table and its gold annotations into model-input chunks.

    This redefinition replaces the T2D ``build_for_own``: Limaye tables are
    row-major CSV grids without page title, caption or header metadata.

    Args:
        x: ``(table_id, (gold_rows, grid))``. Each gold row is
            ``[row_index, col_index, mention, gold_id, candidates]`` and each
            candidate is indexed as ``cand[0]`` (id), ``cand[1]``, ``cand[2]``.
            ``grid`` is a list of rows, each a list of cell strings.

    Returns:
        A list with one entry per chunk of gold rows, each entry being
        ``[table_id, "", "", "", headers, entities, candidate_entities,
        labels, cand_for_each]``:

        * ``headers``: one empty string per cell of the grid's first row.
        * ``entities``: ``[[row, col], mention]`` for the subject cell (col 0)
          followed by every non-empty cell to its right in the same grid row.
        * ``candidate_entities``: chunk-local candidate table; position ``k``
          holds ``[cand[1], cand[2], dbpedia_types.get(id, [])]``.
        * ``labels`` / ``cand_for_each``: aligned with ``entities``; subject
          cells carry the chunk-local index of the gold entity and of every
          candidate, the other cells carry ``0`` and ``[]``.

    Raises:
        KeyError: if a gold id is not among that row's candidates.
        IndexError: if a gold row index is outside the grid.

    Relies on the notebook-level ``dbpedia_types`` mapping.
    """
    table_id = x[0]
    gold_rows = x[1][0]
    grid = x[1][1]

    pgTitle = ""
    secTitle = ""
    caption = ""
    headers = [""] * len(grid[0])

    n_rows = len(gold_rows)
    # Split into roughly equal chunks; a table yields a single chunk until it has
    # at least 50 gold rows.
    chunk_size = int(n_rows / max([1, int(n_rows / 25)])) + 1

    chunks = []
    for start in range(0, n_rows, chunk_size):
        cells = []
        cand_index = {}  # entity id -> [chunk-local index, cand[1], cand[2], types]
        gold_labels = []
        per_cell_candidates = []

        for row in gold_rows[start : start + chunk_size]:
            row_i = row[0]
            cells.append([[row_i, 0], row[2]])

            for cand in row[4]:
                cand_id = cand[0]
                if cand_id in cand_index:
                    continue
                cand_index[cand_id] = [len(cand_index), cand[1], cand[2], dbpedia_types.get(cand_id, [])]

            gold_labels.append(cand_index[row[3]][0])
            per_cell_candidates.append([cand_index[cand[0]][0] for cand in row[4]])

            # Remaining cells of the row are inputs only: placeholder label, no
            # candidates. Empty cells are skipped.
            for col, cell in enumerate(grid[row_i][1:], start=1):
                if cell != "":
                    cells.append([[row_i, col], cell])
                    gold_labels.append(0)
                    per_cell_candidates.append([])

        # Indices were handed out in insertion order, so the dict's values are
        # already ordered by chunk-local index.
        candidate_table = [record[1:] for record in cand_index.values()]
        chunks.append(
            [table_id, pgTitle, secTitle, caption, headers, cells, candidate_table, gold_labels, per_cell_candidates]
        )
    return chunks

In [None]:
# Inspect the first chunk built from the sampled table.
build_for_own(sample)[0]

In [None]:
# Full Limaye dataset in model-input form, collected to the driver: same pipeline as
# the sample cell, then `build_for_own` expands every table into its chunks.
limaye_local = (
    limaye_entities_with_candidates.select("table_id", "wikidata_id", "candidates", "i", "j", "mention")
    .where(~F.isnull("candidates"))
    .rdd.map(lambda x: [x["table_id"], x["i"], x["j"], x["mention"], x["wikidata_id"], x["candidates"]])
    # keep only annotations whose gold id is reachable from the candidate list
    .filter(lambda x: x[4] in [z[0] for z in x[5]])
    .map(lambda x: (x[0], [x[1:]]))
    .reduceByKey(add)
    .join(limaye_tables)
    .flatMap(build_for_own)
    .collect()
)

In [None]:
# Per-cell candidate index lists (element 8) of the chunk at position 70.
limaye_local[70][8]

In [None]:
def get_labels_and_candidate(tables):
    """Flatten one chunk built by ``build_for_own`` into per-cell evaluation records.

    Same definition as the one used for T2D, repeated here.

    Args:
        tables: a chunk ``[table_id, pgTitle, secTitle, caption, headers,
            entities, candidate_entities, labels, cand_for_each]``.

    Returns:
        A list of ``((table_id, row, col), [label, candidates, candidate_entities])``
        for every cell that has at least one candidate; cells with an empty
        candidate list are omitted.
    """
    table_id = tables[0]
    cells = tables[5]
    candidate_entities = tables[6]
    labels = tables[7]
    cand_for_each = tables[8]
    return [
        ((table_id, cell[0][0], cell[0][1]), [labels[pos], cand_for_each[pos], candidate_entities])
        for pos, cell in enumerate(cells)
        if len(cand_for_each[pos]) != 0
    ]

In [None]:
# Persist the Limaye chunks as JSON for the entity-linking model.
# `data_dir` already ends with "/", so the file name is appended directly.
data_dir = "/srv/samba/group_workspace_1/deng.595/workspace/table_transformer/data/wikitable_entity/v2/"
with open(f"{data_dir}limaye.table_entity_linking.json", "w") as f:
    json.dump(limaye_local, f)

In [None]:
# Counts for evaluating the raw lookup on Limaye.
# Annotations for which the lookup returned at least one candidate.
limaye_all_predicted = limaye_entities_with_candidates.where(F.size("candidates") >= 1).count()
# Annotations whose gold id equals the id of the FIRST candidate.
limaye_TP = (
    limaye_entities_with_candidates.where(F.size("candidates") >= 1)
    .rdd.map(lambda x: 1 if x["wikidata_id"] in [z[0] for z in x["candidates"][:1]] else 0)
    .sum()
)
# All gold annotations, with or without candidates (recall denominator).
limaye_P = limaye_entities_with_candidates.count()
# Annotations whose gold id appears ANYWHERE in the candidate list.
limaye_best_TP = (
    limaye_entities_with_candidates.where(F.size("candidates") >= 1)
    .rdd.map(lambda x: 1 if x["wikidata_id"] in [z[0] for z in x["candidates"]] else 0)
    .sum()
)

In [None]:
# Precision / recall / F1 of taking the lookup's first candidate on Limaye.
precision = limaye_TP / limaye_all_predicted
recall = limaye_TP / limaye_P
f1 = 2 * precision * recall / (precision + recall)
print(f1, precision, recall)

In [None]:
import pickle

# Load the pickled Limaye results; this rebinds `test_results`, replacing the T2D
# results loaded earlier.
# NOTE: pickle.load can execute arbitrary code embedded in the file — only load
# result files you produced yourself.
with open(
    "/srv/samba/group_workspace_1/deng.595/workspace/table_transformer/data/wikitable_entity/v2/limaye_entity_linking_results_0.pkl",
    "rb",
) as f:
    test_results = pickle.load(f)

In [None]:
# Number of Limaye cells for which `get_tp` reports a correct final pick.
#
# Two fixes relative to the earlier form of this cell:
#  * `get_tp` indexes its model-side value as a (ranked_ids, scores) pair, so the
#    join must attach (x[2][i], x[3][i]) exactly like the T2D evaluation cells do;
#    attaching x[2][i] alone does not match what `get_tp` reads.
#    NOTE(review): assumes the Limaye results file has the same per-entry layout as
#    the T2D one (x[3] holding the scores) — confirm.
#  * the most recent `get_tp` definition returns (flag, entity); summing tuples
#    fails, so only the flag is summed. An int return is still accepted.
def _tp_flag(joined):
    outcome = get_tp(joined)
    return outcome[0] if isinstance(outcome, tuple) else outcome


our_tp = (
    sc.parallelize(limaye_local)
    .flatMap(get_labels_and_candidate)
    .join(
        sc.parallelize(test_results).flatMap(
            lambda x: [((x[0], z[0], z[1]), (x[2][i], x[3][i])) for i, z in enumerate(x[1])]
        )
    )
    .map(_tp_flag)
    .sum()
)

In [None]:
# Precision / recall / F1 of our final picks on Limaye, using the same denominators
# as the lookup baseline.
precision = our_tp / limaye_all_predicted
recall = our_tp / limaye_P
f1 = 2 * precision * recall / (precision + recall)
print(f1, precision, recall)