## Note

This notebook preprocesses a subset of the SemTab challenge data, namely the datasets used in this paper: https://scholar.google.com/citations?view_op=view_citation&hl=it&user=SqU0PwIAAAAJ&sortby=pubdate&citation_for_view=SqU0PwIAAAAJ:Y0pCki6q_DkC.

The data are in raw format. The goal of the pre-processing is to collect the name, the description and the candidate entities for every mention, and to prepare the data for ingestion by TURL (https://arxiv.org/abs/2006.14806).

In [None]:
from __future__ import annotations

import sys

# Make the parent directory importable (the `src` package is imported near the end of the notebook).
sys.path.append("..")

import ast
import glob
import json
import os
import pickle
import time
# `import urllib.request` also loads `urllib.error`, whose exception classes the lookup helpers catch.
import urllib.parse
import urllib.request
from operator import add, itemgetter
from typing import Any, Dict, List
from urllib.parse import unquote

import findspark
import pyspark
from pandarallel import pandarallel
from pyspark import SparkContext
from pyspark.sql import SparkSession
from pyspark.sql.types import Row

# Enables `.parallel_apply` on pandas objects (used for the HTTP lookups and `ast.literal_eval` below).
pandarallel.initialize(progress_bar=True, nb_workers=64, use_memory_fs=True)

import mapply
import pandas as pd
from tqdm import tqdm

# NOTE(review): mapply is initialised, but no `.mapply` call appears in this notebook — confirm it is still needed.
mapply.init(n_workers=64, chunk_size=1, max_chunks_per_worker=0, progressbar=True)

# Registers `.progress_apply` on pandas objects.
tqdm.pandas()

In [None]:
def _fetch_json(url: str, retry: int = 3) -> Any:
    """GET ``url`` and decode its JSON body, retrying on HTTP 429/503.

    At most ``retry`` attempts are made. HTTP 429 and 503 trigger a one-second
    pause followed by another attempt. Any other HTTP error, or a network-level
    ``URLError``, stops immediately. A malformed JSON body propagates the
    ``json`` decoding error, as before.

    Returns:
        The decoded JSON payload on success, or the HTTP status code (``int``)
        of the last HTTP error. Returns ``None`` after a ``URLError``, or when
        no attempt was made (``retry <= 0``).
    """
    # Initialised up front so that retry <= 0 no longer leaves the name unbound.
    result: Any = None
    for _ in range(retry):
        try:
            # The context manager releases the connection as soon as the body is read,
            # instead of waiting for garbage collection (many lookups run in parallel).
            with urllib.request.urlopen(url) as resp:
                body = resp.read()
        except urllib.error.HTTPError as err:  # must precede URLError (it is a subclass)
            result = err.code
            if err.code in (429, 503):
                time.sleep(1)
                continue
            break
        except urllib.error.URLError:
            result = None
            break
        else:
            result = json.loads(body)
            break
    return result


def wikidata_lookup(query: Any, retry: int = 3, dbpedia_types: Dict[str, List[str]] | None = None):
    """Search Wikidata for entities matching ``query`` (``wbsearchentities``, English, max 50 hits).

    Args:
        query: Mention text; converted with ``str`` before URL-quoting.
        retry: Maximum number of HTTP attempts.
        dbpedia_types: Optional mapping QID -> DBpedia types. A candidate gets ``[]``
            when the mapping is ``None`` or does not contain its QID.

    Returns:
        ``[query, candidates]``, where each candidate is ``[qid, label, description, types]``.
        ``candidates`` is ``[]`` for an empty query, or when the request fails.
    """
    service_url = (
        "https://www.wikidata.org/w/api.php?action=wbsearchentities&search={}&language=en&limit=50&format=json"
    )
    if query == "":
        return [query, []]
    try:
        url = service_url.format(urllib.parse.quote(str(query)))
    except Exception:  # e.g. text that cannot be UTF-8 encoded for quoting
        print(query)
        return [query, []]
    payload = _fetch_json(url, retry)
    if not isinstance(payload, dict):
        return [query, []]
    candidates = [
        [
            hit.get("id"),
            hit.get("label", ""),
            hit.get("description", ""),
            dbpedia_types.get(hit.get("id"), []) if dbpedia_types is not None else [],
        ]
        for hit in payload.get("search", [])
    ]
    return [query, candidates]


def lamapi_lookup(
    query: Any,
    retry: int = 3,
    dbpedia_types: Dict[str, List[str]] | None = None,
    fuzzy: bool = False,
):
    """Retrieve Wikidata candidates for ``query`` from the LamAPI entity-retrieval service.

    Hits are ranked by ``(es_score, ed_score, cosine_similarity)`` in descending
    order, and the top 50 are kept.

    Args:
        query: Mention text; the response is read from the key ``str(query).lower()``.
        retry: Maximum number of HTTP attempts.
        dbpedia_types: Optional mapping QID -> DBpedia types (``[]`` when missing).
        fuzzy: Forwarded to the service as its ``fuzzy`` query parameter (``str(fuzzy)``).

    Returns:
        ``[query, candidates]``, where each candidate is ``[qid, name, description, types]``.
        ``candidates`` is ``[]`` for an empty query, or when the request fails.
    """
    # NOTE(review): the access token is hard-coded in the URL; consider moving it to configuration.
    service_url = (
        "http://149.132.176.50:8097/lookup/entity-retrieval?name={}&token=insideslab-lamapi-2024&kg=wikidata&fuzzy={}"
    )
    if query == "":
        return [query, []]
    try:
        url = service_url.format(urllib.parse.quote(str(query)), fuzzy)
    except Exception:  # e.g. text that cannot be UTF-8 encoded for quoting
        print(query)
        return [query, []]
    payload = _fetch_json(url, retry)
    if not isinstance(payload, dict):
        return [query, []]
    scored = [
        [
            hit.get("id"),
            hit.get("name", ""),
            hit.get("description", ""),
            dbpedia_types.get(hit.get("id"), []) if dbpedia_types is not None else [],
            hit.get("es_score", 0.0),
            hit.get("ed_score", 0.0),
            hit.get("cosine_similarity", 0.0),
        ]
        for hit in payload.get(str(query).lower(), [])
    ]
    # Stable sort on the three trailing scores, which are then stripped from the output.
    scored.sort(key=lambda c: (float(c[-3]), float(c[-2]), float(c[-1])), reverse=True)
    return [query, [c[:-3] for c in scored[:50]]]


def wikidata_description_from_qids(qid: str, retry: int = 3) -> str:
    """Return the English Wikidata description of ``qid`` (``wbgetentities``).

    Returns ``""`` for the placeholder QID ``"nil"`` (any case), when the request
    fails, or when the entity has no English description.

    Raises:
        ValueError: The response is a JSON object without ``entities[qid]``
            (for example, an API error payload).
    """
    if qid.lower() == "nil":
        return ""
    service_url = "https://www.wikidata.org/w/api.php?action=wbgetentities&ids={}&languages=en&format=json"
    url = service_url.format(urllib.parse.quote(qid))
    payload = _fetch_json(url, retry)
    if not isinstance(payload, dict):
        return ""
    try:
        return payload["entities"][qid].get("descriptions", {}).get("en", {}).get("value", "")
    except (KeyError, TypeError, AttributeError) as err:
        raise ValueError(f"Unexpected Wikidata response for {qid!r}: {payload!r}") from err

### Create spark session

In [None]:
findspark.init()
# Executor/driver sizing (the driver collects the large `dbpedia_types` mapping) plus the
# SQLite JDBC driver needed to read the wikimapper index further below.
spark_settings = [
    ("spark.executor.memory", "8g"),
    ("spark.executor.cores", "2"),
    ("spark.executor.instances", "7"),
    ("spark.driver.memory", "150g"),
    ("spark.driver.maxResultSize", "100g"),
    ("spark.driver.extraClassPath", "/home/fbelotti/Downloads/sqlite-jdbc-3.36.0.3.jar"),
]
conf = pyspark.SparkConf().setAll(spark_settings)
sc = SparkContext(conf=conf)
spark = SparkSession(sc)

In [None]:
# Smoke test: query the Wikidata search API.
wikidata_lookup("Advance Australia Fair*")

In [None]:
# Smoke test: query the LamAPI entity-retrieval service with fuzzy matching.
lamapi_lookup("Advance Australia Fair*", fuzzy=True)

### Preparing dataset

We start by reading the ground-truth dataset of the challenge.

In [None]:
# Name of the SemTab dataset folder; its lower-cased form is used in output file names.
table = "Round1_T2D"
table_type = table.lower()

# The path to the file containing the ground-truth. The folder follows `table`
# (the `.format` call previously had no placeholder and was a no-op); the file
# name differs between rounds, so update it together with `table`.
gt_path = "~/semtab-data/raw/{}/gt/CEA_Round1_gt_WD.csv".format(table)
# gt_path = "~/semtab-data/raw/HardTablesR2/gt/cea.csv"
# gt_path = "~/semtab-data/raw/HardTablesR3/gt/cea.csv"
# gt_path = "~/semtab-data/raw/2T_Round4/gt/cea.csv"
# gt_path = "~/semtab-data/raw/Round3_2019/gt/CEA_Round3_gt_WD.csv"
# gt_path = "~/semtab-data/raw/Round4_2020/gt/cea.csv"

# The path to the real tables folder
tables_folder = "~/semtab-data/raw/{}/tables/".format(table)

# Expand the paths
tables_folder = os.path.expanduser(tables_folder)
gt_path = os.path.expanduser(gt_path)

In [None]:
# The `names` could be different for different datasets
# Columns: table file name (without ".csv"), row, column, and ground-truth entity
# identifier(s) — possibly full URIs, which are normalised in the next cell.
gt_df = pd.read_csv(gt_path, encoding="utf-8", names=["tableName", "row", "col", "id"])
gt_df

In [None]:
# Peek at one raw ground-truth entity value.
gt_df["id"].at[5]

In [None]:
# Keep only the first identifier of each ground-truth cell, reduced to its last "/" segment (the QID).
gt_df["id"] = gt_df["id"].astype(str).apply(lambda cell: cell.split()[0].rsplit("/", 1)[-1])
gt_df["tableName"] = gt_df["tableName"].astype(str)
# Do not consider the header: read_csv consumes it, so GT row numbers are shifted down by one.
gt_df["row"] = gt_df["row"].astype(int).sub(1)
gt_df["col"] = gt_df["col"].astype(int)

In [None]:
gt_df

In [None]:
# Peek at the first ground-truth table.
# NOTE: this rebinds `table`, which held the dataset name in the paths cell.
table = pd.read_csv(os.path.join(tables_folder, gt_df["tableName"].at[0]) + ".csv", encoding="utf-8")
table

In [None]:
# Get the mentions for each table: read every referenced table once and pick the
# ground-truth cells out of it by (row, col) position.
table_names = gt_df["tableName"].unique()
# Pre-create the column with `object` dtype so string and numeric cell values can be
# stored side by side (creating it through `.at` would infer the dtype from the first value).
gt_df["mention"] = pd.Series(None, index=gt_df.index, dtype=object)
# A single groupby replaces re-filtering the whole ground truth once per table.
for table_name, cells in tqdm(gt_df.groupby("tableName", sort=False), total=len(table_names)):
    table_path = os.path.join(tables_folder, table_name + ".csv")
    table = pd.read_csv(table_path, encoding="utf-8")
    gt_df.loc[cells.index, "mention"] = [table.iloc[r, c] for r, c in zip(cells["row"], cells["col"])]

In [None]:
gt_df

### Drop rows without a mention to be linked

In [None]:
# Empty table cells (read as NaN) have no surface form to link.
gt_df.dropna(subset=["mention"], inplace=True)

In [None]:
# Number of ground-truth mentions left after dropping empty cells.
total_number_of_mentions = len(gt_df)
total_number_of_mentions

### Get the types from the dbpedia_types mapping

In [None]:
# Wikipedia title <-> Wikidata QID mapping stored in SQLite.
# You can create the index-enwiki dump with this library: https://github.com/jcklie/wikimapper
sqlite_mapping_options = {
    "url": "jdbc:sqlite:/home/fbelotti/turl-data/index_enwiki-20190420.db",
    "driver": "org.sqlite.JDBC",
    "dbtable": "mapping",
}
wikipedia_wikidata_mapping = spark.read.format("jdbc").options(**sqlite_mapping_options).load()
wikipedia_wikidata_mapping.show()

In [None]:
# Map each Wikidata QID to the DBpedia types of its Wikipedia page. The steps are:
# parse the instance-types dump, join it with the title -> QID mapping, concatenate
# the types per QID, and collect the result into a driver-side dict.
dbpedia_types = dict(
    spark.createDataFrame(
        sc.textFile("/home/fbelotti/turl-data/dbpedia_types/2019_08_30/instance_type_en.ttl")
        .map(lambda x: x.split())
        # Keep only triple lines. A blank line would make `x[0]`/`x[2]` raise IndexError,
        # and a `#` comment line would be parsed into a bogus (title, type) row.
        .filter(lambda x: len(x) >= 3 and not x[0].startswith("#"))
        .map(
            lambda x: Row(
                # Subject IRI -> percent-decoded Wikipedia title
                wikipedia_title=unquote(x[0][1:-1]).replace("http://dbpedia.org/resource/", ""),
                # Object IRI -> its last path segment (the type name)
                type=x[2][1:-1].split("/")[-1],
            )
        )
    )
    .join(wikipedia_wikidata_mapping, "wikipedia_title", "inner")
    .rdd.map(lambda x: (x["wikidata_id"], [x["type"]]))
    .reduceByKey(add)
    .collect()
)
print(len(dbpedia_types))

In [None]:
# Attach the DBpedia types of each ground-truth QID ([] when unknown). These are plain O(1)
# dict lookups; `parallel_apply` would have to serialise the function for its worker
# processes, and that function references the very large `dbpedia_types` dict.
gt_df["types"] = gt_df["id"].map(lambda qid: dbpedia_types.get(qid, []))

In [None]:
# Inspect the ground truth with the DBpedia types attached.
gt_df

### Save the pre-processed gt dataset (i.e. with the types of every mention already computed)

In [None]:
# Checkpoint: ground truth with mentions and DBpedia types.
types_checkpoint_path = os.path.join(os.path.dirname(gt_path), "cea_gt_with_types.csv")
gt_df.to_csv(types_checkpoint_path, index=False)

### Get description for every QID

In [None]:
# Fetch the English description of each distinct QID only once; the merge below broadcasts it.
qids = pd.DataFrame(gt_df["id"].unique(), columns=["id"])
qids["description"] = qids["id"].parallel_apply(wikidata_description_from_qids)

In [None]:
# Attach each QID's description to every row with that QID (a left join keeps all rows).
gt_df = gt_df.merge(qids, on="id", how="left")
gt_df

### Save gt dataset pre-processed (i.e. with types and description for every mention already computed)

In [None]:
# Checkpoint: ground truth with types and descriptions.
gt_df.to_csv(os.path.join(os.path.dirname(gt_path), "cea_gt_with_types_and_desc.csv"), index=False)

In [None]:
# Resume point: reload the checkpoint. List-valued columns (e.g. `types`) come back as
# their string representation and are parsed again later with `ast.literal_eval`.
gt_df = pd.read_csv(os.path.join(os.path.dirname(gt_path), "cea_gt_with_types_and_desc.csv"))
gt_df

### Get candidates for every mention

In [None]:
# One row per distinct surface form, so each mention is looked up only once.
# `.copy()` makes this an independent frame: adding the "candidates" column to a
# filtered result of `gt_df` can otherwise trigger pandas' SettingWithCopyWarning.
unique_mentions = gt_df.drop_duplicates(subset=["mention"]).copy()
unique_mentions

In [None]:
# Candidate generator: "wikidata" (Wikidata search API) or "lamapi"; also used in output file names.
lookup = "lamapi"

In [None]:
# Candidate generation for every distinct mention (network-bound, hence parallelised).
# Each lookup returns [query, candidates]; only the candidate list is stored.
candidate_generators = {
    "wikidata": lambda mention: wikidata_lookup(mention, dbpedia_types=dbpedia_types)[1],
    "lamapi": lambda mention: lamapi_lookup(mention, dbpedia_types=dbpedia_types, fuzzy=False)[1],
}
if lookup not in candidate_generators:
    raise ValueError("Invalid lookup")
unique_mentions.loc[:, "candidates"] = unique_mentions["mention"].parallel_apply(candidate_generators[lookup])

In [None]:
# Checkpoint the candidates of each distinct mention.
unique_mentions.to_csv(os.path.join(os.path.dirname(gt_path), "unique_mentions_{}.csv".format(lookup)), index=False)

In [None]:
# Distinct mentions for which the lookup returned at least one candidate.
unique_mentions[unique_mentions["candidates"].apply(len) != 0]

In [None]:
# mention -> candidate list, used to broadcast the lookup results back to every GT row.
# Built column-wise instead of with `iterrows`, which materialises a Series per row.
unique_mentions_with_candidates = dict(zip(unique_mentions["mention"], unique_mentions["candidates"]))

In [None]:
# Attach each row's candidates as their `str` representation (the next cell parses them back
# with `ast.literal_eval`, giving every row its own copy). This vectorised lookup replaces a
# per-row `iterrows`/`.at` loop; a mention missing from the mapping still raises KeyError.
gt_df["candidates"] = gt_df["mention"].map(lambda mention: str(unique_mentions_with_candidates[mention]))

In [None]:
# Parse the string representation back into Python lists.
gt_df["candidates"] = gt_df["candidates"].parallel_apply(lambda x: ast.literal_eval(x))

In [None]:
# Disabled alternative: query Wikidata once per row instead of once per distinct mention.
# gt_df["candidates"] = gt_df.parallel_apply(
#     lambda row: wikidata_lookup(row["mention"], dbpedia_types=dbpedia_types)[1], axis=1
# )

In [None]:
gt_df

In [None]:
# Candidates of the first row: [[qid, label/name, description, types], ...]
gt_df.at[0, "candidates"]

In [None]:
# NOTE: an alias, not a copy — both names refer to the same DataFrame.
gt_df_candidates_with_mention = gt_df

### Save the pre-processed gt dataset (i.e. with types, descriptions and candidates for every mention already computed)

In [None]:
# Checkpoint: full ground truth with types, descriptions and candidates (nothing filtered yet).
candidates_checkpoint_path = os.path.join(
    os.path.dirname(gt_path), f"cea_gt_with_wikidata_candidates_{lookup}.csv"
)
gt_df.to_csv(candidates_checkpoint_path, index=False)

### Read the gt file with candidates (if needed): this is the full gt dataset, without anything removed

In [None]:
gt_df_candidates_with_mention = pd.read_csv(
    os.path.join(os.path.dirname(gt_path), "cea_gt_with_wikidata_candidates_{}.csv".format(lookup))
)
gt_df_candidates_with_mention

In [None]:
# Missing descriptions are read back as NaN; normalise them to empty strings.
gt_df_candidates_with_mention["description"] = gt_df_candidates_with_mention["description"].fillna("").astype(str)
gt_df_candidates_with_mention

In [None]:
# The CSV round-trip stored candidate lists as strings; parse them back.
gt_df_candidates_with_mention["candidates"] = gt_df_candidates_with_mention["candidates"].parallel_apply(
    lambda x: ast.literal_eval(x)
)

In [None]:
# Same for the DBpedia type lists.
gt_df_candidates_with_mention["types"] = gt_df_candidates_with_mention["types"].parallel_apply(
    lambda x: ast.literal_eval(x)
)

### Create dataset for TURL evaluation

We can have two cases:

1. We keep only those mentions whose ground-truth entity is contained in the candidate list generated by the lookup (Wikidata or LamAPI): in this case we evaluate the overall system, since the dropped mentions still count in the recall denominator.
2. If the ground-truth entity is not present in the candidate list, we add it ourselves: in this case we evaluate only the disambiguation model (TURL).

#### Remove rows that do not contain any candidates (nothing to link to)

In [None]:
# Keep only rows for which the lookup returned at least one candidate.
has_candidates = gt_df_candidates_with_mention["candidates"].map(len) > 0
gt_df_candidates = gt_df_candidates_with_mention[has_candidates]
gt_df_candidates

#### Keep only the rows whose candidate list contains the ground-truth entity (we want to test the disambiguation algorithm)

In [None]:
# Keep only the rows whose ground-truth QID appears among the candidates' QIDs.
# NOTE(review): after this filter, the "insert the ground truth into the candidates" branch
# of the TURL export below can never fire, so `insert_target_mention_in_candidates` only
# changes the output file names — confirm this is intended.
gold_is_candidate = pd.Series(
    [
        gold in [candidate[0] for candidate in candidates]
        for gold, candidates in zip(gt_df_candidates["id"], gt_df_candidates["candidates"])
    ],
    index=gt_df_candidates.index,
    dtype=bool,
)
gt_df_candidates_with_mention = gt_df_candidates[gold_is_candidate]
gt_df_candidates_with_mention

### Read headers from all the tables

In [None]:
# Tables that still have at least one mention to link.
table_names = gt_df_candidates_with_mention["tableName"].unique().tolist()
table_names[:5]

In [None]:
# Lower-cased column headers of every raw table, keyed by file name without extension.
raw_tables_names = os.listdir(tables_folder)
headers = {}
for table_name in tqdm(raw_tables_names):
    # Only real `.csv` files; a substring test would also accept names such as "x.csv.bak".
    if not table_name.endswith(".csv"):
        continue
    table_path = os.path.join(tables_folder, table_name)
    # nrows=0 parses the header line only.
    headers[os.path.splitext(table_name)[0]] = (
        pd.read_csv(table_path, nrows=0, encoding="utf-8").columns.str.lower().tolist()
    )
headers

### Prepare data for TURL evaluation

In [None]:
# If True, a row whose ground-truth QID is missing from its candidates gets it prepended
# (so only the disambiguation step is evaluated); it also adds "_all" to the output file names.
insert_target_mention_in_candidates = True

In [None]:
# Assemble the TURL entity-linking records. Each table is split into chunks of at most
# `total_mention_per_table` mentions, and every chunk becomes one record:
#   [table_id, page_title, section_title, caption, headers, mentions, candidates, labels, entities_index]
# `candidates` concatenates the candidate lists of all mentions in the chunk.
# `labels[k]` is the position of mention k's ground-truth entity in that list.
# `entities_index[k]` lists the positions that hold mention k's candidates.
tables = []
total_mention_per_table = 50  # a value <= 0 puts all mentions of a table into one record
for table_name in tqdm(table_names):
    table_sample = gt_df_candidates_with_mention[gt_df_candidates_with_mention["tableName"] == table_name].sort_values(
        ["row", "col"], ascending=[True, True]
    )
    if table_sample.empty:
        continue

    # Table-meta information (left empty)
    page_title = ""
    section_title = ""
    caption = ""
    table_headers = list(map(str, headers[table_name]))

    tmpt = total_mention_per_table if total_mention_per_table > 0 else len(table_sample)
    for i in range(0, len(table_sample), tmpt):
        # `mentions`, `labels` and `entities_index` are filled together, so they stay aligned
        # even when a row is skipped (previously `mentions` still contained skipped rows).
        mentions = []
        labels = []
        all_candidates = []
        entities_index = []
        for _, row in table_sample.iloc[i : i + tmpt].iterrows():
            candidates = row["candidates"]
            candidate_ids = [candidate[0] for candidate in candidates]
            if row["id"] in candidate_ids:
                label_index = candidate_ids.index(row["id"])
            elif insert_target_mention_in_candidates:
                # Put the ground truth first, using the mention text as its name.
                label_index = 0
                candidates = [[row["id"], row["mention"], row["description"], row["types"]]] + candidates
            else:
                continue
            offset = len(all_candidates)
            mentions.append([[int(row["row"]), int(row["col"])], str(row["mention"])])
            labels.append(int(offset + label_index))
            # TURL candidates drop the QID: [name, description, types]
            all_candidates.extend([str(c[1]), str(c[2]), c[3]] for c in candidates)
            entities_index.append(list(range(offset, len(all_candidates))))
        if all_candidates:
            tables.append(
                [
                    str(table_name),
                    page_title,
                    section_title,
                    caption,
                    table_headers,
                    mentions,
                    all_candidates,
                    labels,
                    entities_index,
                ]
            )

### Dump the dataset for TURL evaluation

In [None]:
# Write the records as "{table_type}[_all]_{lookup}.table_entity_linking.json".
output_path = "/home/fbelotti/turl-data/{}{}{}.table_entity_linking.json".format(
    table_type, "_all" if insert_target_mention_in_candidates else "", "_" + lookup
)
print("Saving tables to", output_path)
with open(output_path, "w") as f:
    json.dump(tables, f)

### Pre-process dataset with TURL ELDataset

In [None]:
# Pre-process the JSON dump written above with TURL's entity-linking dataset class.
import sys

sys.path.append("..")

from src.data_loader.el_data_loaders import ELDataset
from src.utils.util import load_dbpedia_type_vocab

# NOTE: inside a notebook `__name__` is "__main__", so this block always runs.
if __name__ == "__main__":
    data_dir = "/home/fbelotti/turl-data"
    type_vocab = load_dbpedia_type_vocab(data_dir)
    # NOTE(review): named `train_dataset`, but it is built from the evaluation dump.
    train_dataset = ELDataset(
        data_dir,
        type_vocab,
        max_input_tok=500,
        # Same stem as the dump file name above: "{table_type}[_all]_{lookup}".
        src=(table_type + "_all" if insert_target_mention_in_candidates else table_type) + "_" + lookup,
        max_length=[50, 10, 10, 100],
        # presumably rebuilds instead of reusing a cached pre-processed version — confirm in ELDataset
        force_new=True,
        tokenizer=None,
    )

## Evaluation

In [None]:
def get_labels_and_candidate(tables):
    """Flatten one TURL entity-linking record into per-mention ``(key, value)`` pairs.

    Args:
        tables: A single record as written to the JSON dump:
            ``[table_id, page_title, section_title, caption, headers, mentions,
            candidates, labels, entities_index]``.

    Returns:
        A list of ``((table_id, row, col), [label_index, candidate_span, candidates])``,
        one entry per mention whose candidate span is non-empty.
    """
    table_id, mentions, candidates = tables[0], tables[5], tables[6]
    labels, spans = tables[7], tables[8]
    return [
        ((table_id, mention[0][0], mention[0][1]), [labels[i], spans[i], candidates])
        for i, mention in enumerate(mentions)
        if len(spans[i]) != 0
    ]


def get_tp(result, score_factor: float = 1.0):
    """Return 1 if the final linking decision for one mention is correct, else 0.

    The decision arbitrates between the model and the lookup service:

    * the model's pick is the highest-ranked predicted index inside the mention's
      candidate span;
    * the lookup's pick is the first index of that span. This is the lookup's top
      candidate, unless the ground truth was inserted in front of the candidates.

    The model's pick wins only if its score times ``score_factor`` exceeds the score
    the model gave the lookup's pick (0 if that candidate was not scored). If no
    predicted index falls inside the span, the lookup's pick is used. This case
    previously raised IndexError.

    Args:
        result: ``(key, [gold, prediction])``. ``gold`` is
            ``[label_index, candidate_span, ...]``; ``prediction`` is
            ``[sorted_predicted_indexes, sorted_predicted_scores]``.
        score_factor: Multiplier applied to the model's score before the comparison;
            1.0 reproduces the original rule.

    Returns:
        1 for a true positive, 0 otherwise.
    """
    gold, prediction = result[1]
    label_index, span = gold[0], gold[1]
    scored = list(zip(prediction[0], prediction[1]))

    lookup_index = span[0]
    lookup_score = next((score for idx, score in scored if idx == lookup_index), 0)

    model_pick = next(((idx, score) for idx, score in scored if idx in span), None)
    if model_pick is None:
        final = lookup_index
    else:
        model_index, model_score = model_pick
        if model_index == lookup_index or model_score * score_factor > lookup_score:
            final = model_index
        else:
            final = lookup_index
    return 1 if final == label_index else 0

In [None]:
# Stem of the entity-linking dump to evaluate: "{table_type}[_all]_{lookup}".
dataset_type = "round1_t2d_all_lamapi"

In [None]:
# json.load accepts a binary file handle and decodes it itself.
with open("/home/fbelotti/turl-data/" + dataset_type + ".table_entity_linking.json", "rb") as f:
    dataset = json.load(f)

In [None]:
# First record: [table_id, page_title, section_title, caption, headers, mentions, candidates, labels, entities_index]
dataset[0]

In [None]:
# No dedup: ~/projects/TURL/output/logs/turl/fine-tuning-el/2024-02-08_13-14-19/version_0/checkpoints/checkpoint-last/pytorch_model.bin
# Dedup: ~/projects/TURL/output/logs/turl/fine-tuning-el/2024-02-14_11-01-08/version_0/checkpoints/checkpoint-last/pytorch_model.bin

# Folder holding the test outputs of the selected fine-tuned checkpoint.
dedup = False
if dedup:
    prefix = "/home/fbelotti/projects/TURL/output/logs/turl/fine-tuning-el/2024-02-14_11-01-08/version_0/test/"
else:
    prefix = "/home/fbelotti/projects/TURL/output/logs/turl/fine-tuning-el/2024-02-08_13-14-19/version_0/test/"
    # prefix = "/home/fbelotti/projects/TURL/output/logs/turl/fine-tuning-el/2024-03-01_10-32-37/version_0/test/"
    # prefix = "/home/fbelotti/projects/TURL/output/logs/turl/fine-tuning-el/2024-03-04_15-26-27/version_0/test/"

### Load a single result file and compute the true positives with Spark

In [None]:
results_path = prefix + dataset_type + "_entity_linking_results{}.pkl".format("_dedup" if dedup else "")
print("Loading results", results_path)
# NOTE: unpickling can execute code; only load result files produced by our own runs.
with open(results_path, "rb") as f:
    test_results = pickle.load(f)

In [None]:
# Results for table with name
test_results[0][0]

In [None]:
# All mentions in the table with name test_results[0][0], number of mentions
print(len(test_results[0][1]), "mentions in table", test_results[0][0], "\n", test_results[0][1])

In [None]:
# Sorted indexes for every mention in the table with name test_results[0][0]
len(test_results[0][2][2])

In [None]:
dataset_labels_and_cands = sc.parallelize(dataset).flatMap(get_labels_and_candidate)

In [None]:
dataset_and_results = dataset_labels_and_cands.join(
    sc.parallelize(test_results).flatMap(
        lambda x: [((x[0], z[0], z[1]), (x[2][i], x[3][i])) for i, z in enumerate(x[1])]
    )
)

In [None]:
our_tp = dataset_and_results.map(get_tp).sum()

In [None]:
all_predicted = 0
for table in dataset:
    all_predicted += len(table[5])
all_predicted

### Load the split results and compute the true positives

In [None]:
# table_id -> {(row, col): (label_index, candidate_span)}, used to score the split result files.
dataset_dict = {}
for table in tqdm(dataset):
    table_id = table[0]
    table_mentions = table[5]
    table_labels = table[7]
    table_entities_index = table[8]
    if table_id not in dataset_dict:
        dataset_dict[table_id] = {}
    for i, mention in enumerate(table_mentions):
        # A repeated (row, col) overwrites the earlier entry, so the count further below
        # can be lower than the per-record count computed with Spark.
        if (mention[0][0], mention[0][1]) in dataset_dict[table_id]:
            print("Duplicate mention in table", table_id, "at row", mention[0][0], "and col", mention[0][1])
        dataset_dict[table_id][(mention[0][0], mention[0][1])] = (table_labels[i], tuple(table_entities_index[i]))

In [None]:
# Free the raw records; only `dataset_dict` is needed from here on.
del dataset

In [None]:
prefix + dataset_type + "_entity_linking_results{}_split*.pkl".format("_dedup" if dedup else "")

In [None]:
glob.glob(prefix + dataset_type + "_entity_linking_results{}_split*.pkl".format("_dedup" if dedup else ""))

In [None]:
# Score every mention of every split result file.
our_tp = 0
all_test_files = glob.glob(
    prefix + dataset_type + "_entity_linking_results{}_split*.pkl".format("_dedup" if dedup else "")
)
for file in tqdm(all_test_files):
    with open(file, "rb") as f:
        test_results = pickle.load(f)
    for table in tqdm(test_results):
        for i, (row, col) in enumerate(table[1]):
            r = dataset_dict[table[0]][(row, col)]
            # get_tp only reads result[1] = (gold, [sorted indexes, sorted scores]).
            our_tp += get_tp([[], [r, [table[2][i], table[3][i]]]])

In [None]:
# Number of distinct (table, row, col) mentions in the dump (precision denominator).
all_predicted = 0
for table in dataset_dict.keys():
    all_predicted += len(dataset_dict[table].keys())
all_predicted

### Compute metrics

In [None]:
# Ground-truth mentions vs. mentions present in the TURL dump.
len(gt_df), all_predicted

In [None]:
# Precision over the mentions in the TURL dump; recall over all ground-truth mentions.
# Each ratio falls back to 0.0 when its denominator is zero, instead of raising ZeroDivisionError.
all_gt = len(gt_df)
prec = our_tp / all_predicted if all_predicted else 0.0
rec = our_tp / all_gt if all_gt else 0.0
f1 = 2 * (prec * rec) / (prec + rec) if prec + rec else 0.0
all_predicted, all_gt, f1, prec, rec