# Calculate entity and relation F1 score

In [120]:
from sklearn.metrics import f1_score
import json
import re
import numpy as np
import pandas as pd


## Load and preprocess queries from ref dataset

In [151]:
prefix_pattern = [
    [r'<http://dbpedia.org/resource/(.*?)>\.?', 'dbr:'],
    [r'<http://dbpedia.org/property/(.*?)>\.?', 'dbp:'],
    [r'<http://dbpedia.org/ontology/(.*?)>\.?', 'dbo:'],
    [r'<http://dbpedia.org/class/yago/(.*?)>\.?', 'yago:'],
    [r'onto:(.*)', 'dbo:'],
    [r'<http://www.wikidata.org/prop/direct/(.*?)>', 'wdt:'],
    [r'<http://www.wikidata.org/entity/(.*?)>', 'wd:'],
    [r'http://www.wikidata.org/prop/(.*?)', 'p:'],
    [r'<http://www.w3.org/2000/01/rdf-schema#(.*?)', 'rdfs:']
]

def delete_sparql_prefix(sparql_query):
    if "prefix" not in sparql_query.casefold():
        return sparql_query
    if "ASK" in sparql_query:
        return "ASK" + sparql_query.split("ASK", 1)[1]
    return "SELECT" + sparql_query.split("SELECT", 1)[1]


def replace_prefix_abbr(sparql_query):
    for pattern in prefix_pattern:
        sparql_query = re.sub(pattern[0], pattern[1] + r'\1', sparql_query)
    sparql_query = re.sub(' +', ' ', sparql_query)
    return sparql_query

def read_dataset_in_lang(dataset, lang):
    with open(dataset, "r") as f:
        data = json.load(f)

    questions_list = data["questions"] 
    df = pd.DataFrame(columns=["id", "language", "question", "query"])   
    for question_dict in questions_list:
        row = dict()
        row["id"] = question_dict["id"]
        row["language"] = lang
        for question in question_dict["question"]:
            if question["language"] == lang:
                row["question"] = question["string"]
                row["query"] = replace_prefix_abbr(delete_sparql_prefix(question_dict["query"]["sparql"]))
                row_df = pd.DataFrame(row, index=[row["id"]])
                df = pd.concat([df, row_df], axis=0, ignore_index=True)
                break
    return df

def read_ref_pred(ref_dataset, pred_dataset, lang):
    ref_df = read_dataset_in_lang(ref_dataset, lang)
    pred_df = read_dataset_in_lang(pred_dataset, lang)
    ref_df = ref_df.rename(columns={"query": "query_ref"})
    pred_df = pred_df.rename(columns={"query": "query_pred"})
    return ref_df.merge(pred_df, on=["id", "language", "question"])


In [122]:
df = read_ref_pred("qald_9_plus_test_wikidata.json", "edrf_lt.json", "lt")
df.head

<bound method NDFrame.head of       id language                                           question  \
0     99       lt                  Kokia Solt Leik Sičio laiko zona?   
1     98       lt                                 Kas nužudė Cezarį?   
2     86       lt              Pats Aukščiausias kalnas Vokietijoje?   
3     81       lt  Kurios valstijos gubernatorius yra Butchas Ott...   
4     66       lt  Kurie aktoriai gimė tą pačią dieną kaip Rachel...   
..   ...      ...                                                ...   
121  101       lt            Kiek įmonių įkūrė „Facebook“ steigėjas?   
122   87       lt              Kuri knyga turi daugiausiai puslapių?   
123  148       lt                     Kokia didžiausia JAV valstija?   
124   43       lt  Pasakyk man kompanijų svetaines su daugiau nei...   
125  179       lt        Kokios buvo trys Kolumbo laivų pavadinimai?   

                                             query_ref  \
0    SELECT DISTINCT ?o1 WHERE { wd:Q23337 wdt:

## Calculate functions

In [123]:
def collect_entities(query):
    entity_list = re.findall(r'wd:[A-Z]*[0-9]*', query)
    return entity_list


In [124]:
query = "SELECT DISTINCT ?uri WHERE { ?uri wdt:P106 wd:Q116 ; wdt:P26 ?spouse . ?spouse wdt:P27 wd:Q183 . } "
entity_list = collect_entities(query)
entity_list


['wd:Q116', 'wd:Q183']

In [125]:
def collect_relations(query):
    relation_list = re.findall(r'wdt:[A-Z]*[0-9]*', query)
    return relation_list


In [126]:
query = "SELECT DISTINCT ?uri WHERE { ?uri wdt:P106 wd:Q116 ; wdt:P26 ?spouse . ?spouse wdt:P27 wd:Q183 . } "
relation_list = collect_relations(query)
relation_list


['wdt:P106', 'wdt:P26', 'wdt:P27']

In [127]:
def calculate_f1(ref, pred):
    max_length = max(len(ref), len(pred))
    if len(ref) < max_length:
        ref.extend(["N/A_ref" for i in range(max_length - len(ref))])
    if len(pred) < max_length:
        pred.extend(["N/A_pred" for i in range(max_length - len(pred))])
    if ref == [] and pred == []:
        ref, pred = ["N/A_ref"], ["N/A_pred"]
    return f1_score(ref, pred, average="macro")


In [135]:
def collect_entity_relation_list(ref_dataset, pred_dataset, lang):
    ref_pred_df = read_ref_pred(ref_dataset, pred_dataset, lang)

    entity_list_ref = []
    relation_list_ref = []

    entity_list_pred = []
    relation_list_pred = []

    for i in range(len(ref_pred_df)):
        query_ref = ref_pred_df.loc[i, "query_ref"]
        query_pred = ref_pred_df.loc[i, "query_pred"]
        
        entity_list_ref.append(collect_entities(query_ref)) 
        relation_list_ref.append(collect_relations(query_ref))
        entity_list_pred.append(collect_entities(query_pred))
        relation_list_pred.append(collect_relations(query_pred))

    return entity_list_ref, relation_list_ref, entity_list_pred, relation_list_pred

In [136]:
def cal_entity_f1(entity_list_ref, entity_list_pred):
    entity_f1_list = []

    for i in range(len(entity_list_ref)):
        entity_f1_list.append(calculate_f1(entity_list_ref[i], entity_list_pred[i]))

    entity_f1_list = np.asarray(entity_f1_list)
    entity_f1 = np.average(entity_f1_list)
    return round(entity_f1, 3)

def cal_relation_f1(relation_list_ref, relation_list_pred):
    relation_f1_list = []

    for i in range(len(relation_list_ref)):
        relation_f1_list.append(calculate_f1(relation_list_ref[i], relation_list_pred[i]))

    relation_f1_list = np.asarray(relation_f1_list)
    relation_f1_list = np.nan_to_num(relation_f1_list)
    relation_f1 = np.average(relation_f1_list)
    return round(relation_f1, 3)

## Calculate

### edrf

| pred_file | entity F1 | relation F1 |
|-----------|-----------|-------------|
| edrf_en   | 0.174     | 0.269       |
| edrf_de   | 0.166     | 0.266       |
| edrf_ru   | 0.174     | 0.305       |
| edrf_lt   | 0.151     | 0.191       |

#### edrf_en

In [137]:
pred_dataset = "edrf_en.json"
ref_dataset = "qald_9_plus_test_wikidata.json"
entity_list_ref, relation_list_ref, entity_list_pred, relation_list_pred = collect_entity_relation_list(ref_dataset, pred_dataset, "en")

In [138]:
entity_f1 = cal_entity_f1(entity_list_ref, entity_list_pred)
entity_f1

0.174

In [139]:
relation_f1 = cal_relation_f1(relation_list_ref, relation_list_pred)
relation_f1

0.269

#### edrf_de

In [140]:
pred_dataset = "edrf_de.json"
ref_dataset = "qald_9_plus_test_wikidata.json"
entity_list_ref, relation_list_ref, entity_list_pred, relation_list_pred = collect_entity_relation_list(ref_dataset, pred_dataset, "de")

In [141]:
entity_f1 = cal_entity_f1(entity_list_ref, entity_list_pred)
entity_f1

0.166

In [142]:
relation_f1 = cal_relation_f1(relation_list_ref, relation_list_pred)
relation_f1

0.266

#### edrf_ru

In [143]:
pred_dataset = "edrf_ru.json"
ref_dataset = "qald_9_plus_test_wikidata.json"
entity_list_ref, relation_list_ref, entity_list_pred, relation_list_pred = collect_entity_relation_list(ref_dataset, pred_dataset,"ru")

In [144]:
entity_f1 = cal_entity_f1(entity_list_ref, entity_list_pred)
entity_f1

0.174

In [145]:
relation_f1 = cal_relation_f1(relation_list_ref, relation_list_pred)
relation_f1

0.305

#### edrf_lt

In [146]:
pred_dataset = "edrf_lt.json"
ref_dataset = "qald_9_plus_test_wikidata.json"
entity_list_ref, relation_list_ref, entity_list_pred, relation_list_pred = collect_entity_relation_list(ref_dataset, pred_dataset, "lt")

In [149]:
entity_f1 = cal_entity_f1(entity_list_ref, entity_list_pred)
entity_f1

0.151

In [150]:
relation_f1 = cal_relation_f1(relation_list_ref, relation_list_pred)
relation_f1

0.191

### zero-shot

| pred_file      | entity F1 | relation F1 |
|----------------|-----------|-------------|
| zero-shot_en   | 0.179     | 0.333       |
| zero-shot_de   | 0.14      | 0.19        |
| zero-shot_ru   | 0.09      | 0.229       |

#### zero-shot en

In [43]:
pred_dataset = "zero-shot_en.json"
ref_dataset = "qald_9_plus_test_wikidata.json"
entity_list_ref, relation_list_ref, entity_list_pred, relation_list_pred = collect_entity_relation_list(ref_dataset, pred_dataset)

In [44]:
entity_f1 = cal_entity_f1(entity_list_ref, entity_list_pred)
entity_f1

0.179

In [45]:
relation_f1 = cal_relation_f1(relation_list_ref, relation_list_pred)
relation_f1

0.333

#### zero-shot de

In [49]:
pred_dataset = "zero-shot_de.json"
ref_dataset = "qald_9_plus_test_wikidata.json"
entity_list_ref, relation_list_ref, entity_list_pred, relation_list_pred = collect_entity_relation_list(ref_dataset, pred_dataset)

In [50]:
entity_f1 = cal_entity_f1(entity_list_ref, entity_list_pred)
entity_f1

0.14

In [51]:
relation_f1 = cal_relation_f1(relation_list_ref, relation_list_pred)
relation_f1

0.19

#### zero-shot ru

In [52]:
pred_dataset = "zero-shot_ru.json"
ref_dataset = "qald_9_plus_test_wikidata.json"
entity_list_ref, relation_list_ref, entity_list_pred, relation_list_pred = collect_entity_relation_list(ref_dataset, pred_dataset)

In [53]:
entity_f1 = cal_entity_f1(entity_list_ref, entity_list_pred)
entity_f1

0.09

In [54]:
relation_f1 = cal_relation_f1(relation_list_ref, relation_list_pred)
relation_f1

0.229