In [None]:
import os
import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import copy
import datasets
# Plot styling. The lab style sheet lives at a machine-local absolute path, so it can be
# overridden via the MPLSTYLE environment variable; when the file is absent on this
# machine we fall back to matplotlib's defaults instead of failing the whole notebook.
MPLSTYLE_PATH = os.environ.get("MPLSTYLE", "/raid/lingo/akyurek/mplstyle")
if os.path.isfile(MPLSTYLE_PATH):
    plt.style.use(MPLSTYLE_PATH)
else:
    print(f"mplstyle not found at {MPLSTYLE_PATH!r}; using matplotlib defaults")
plt.rc('font', serif='Times')
plt.rc('text', usetex=False)
plt.rcParams['figure.dpi'] = 150
plt.rcParams['figure.facecolor'] = 'white'

In [None]:
# Root of the LAMA data checkout, relative to the notebook's working directory.
BASE_DIR = "LAMA/data/"
# Metric outputs live in a subdirectory of BASE_DIR.
METRICS_DIR = os.path.join(BASE_DIR, "metrics")

In [None]:
def read_jsonl(path, encoding="utf-8"):
    """Read a JSON Lines file into a 1-D numpy array of parsed records.

    Args:
        path: Path to a file containing one JSON document per line.
        encoding: Text encoding used to open the file (UTF-8 by default).

    Returns:
        ``np.array`` of the parsed records, in file order.

    Blank or whitespace-only lines (e.g. a trailing empty line) are skipped
    rather than making ``json.loads`` raise ``JSONDecodeError``.
    """
    with open(path, encoding=encoding) as f:
        data = [json.loads(line) for line in f if line.strip()]
    return np.array(data)

def read_facts(abstracts):
    """Collect the distinct facts referenced by a collection of abstracts.

    Each abstract's ``'facts'`` entry is a ``;``-separated string; every piece
    is kept verbatim (no whitespace stripping).

    Returns:
        set of fact strings.
    """
    return {fact for abstract in abstracts for fact in abstract['facts'].split(';')}

def read_no_facts(abstracts):
    """Return a numpy array with the number of distinct facts listed by each abstract."""
    counts = []
    for abstract in abstracts:
        distinct_facts = set(abstract['facts'].split(';'))
        counts.append(len(distinct_facts))
    return np.array(counts)

def facts_to_field(facts, field="obj_uri"):
    """Project comma-separated fact strings onto one of their components.

    Facts are strings of the form ``"predicate_id,obj_uri,sub_uri"``.
    ``field`` selects the component: ``"obj_uri"`` -> index 1,
    ``"sub_uri"`` -> index 2, anything else -> index 0 (the predicate).

    Returns:
        list holding the selected component of every fact, in iteration order.
    """
    position = 1 if field == "obj_uri" else (2 if field == "sub_uri" else 0)
    return [fact.split(',')[position] for fact in facts]

def read_string_field(abstracts, field="obj_uri"):
    """Gather ``record[field]`` from every record into a numpy array."""
    values = list(map(lambda record: record[field], abstracts))
    return np.array(values)


def get_sentence(abstract):
    """Rebuild the unmasked sentence of an example.

    Every ``'<extra_id_0> '`` occurrence is removed from
    ``targets_pretokenized`` and the result is whitespace-stripped; that
    answer text then replaces every ``'<extra_id_0>'`` sentinel in
    ``inputs_pretokenized``.
    """
    answer = abstract['targets_pretokenized'].replace('<extra_id_0> ', '').strip()
    masked_input = abstract['inputs_pretokenized']
    return masked_input.replace('<extra_id_0>', answer)

In [None]:
# Load the FTRACE splits from data/ftrace: queries are materialized into a plain
# Python list, abstracts stay a `datasets` Dataset.
queries = [query for query in datasets.load_dataset("data/ftrace", "queries", split="train")]
abstracts = datasets.load_dataset("data/ftrace", "abstracts", split="train")

In [None]:
len(abstracts), len(queries)  # (number of abstracts, number of queries)

In [None]:
# NOTE(review): not the last expression of this cell, so Jupyter does not display it.
abstracts[0]

def read_surface_stats(abstracts):
    """Map each masked entity URI to the set of surface strings used for it.

    The surface string is ``targets_pretokenized`` with every
    ``'<extra_id_0> '`` removed; it is not whitespace-stripped.

    Returns:
        dict mapping ``masked_uri`` -> set of surface strings.
    """
    surfaces_by_uri = {}
    for record in abstracts:
        uri = record['masked_uri']
        surface = record['targets_pretokenized'].replace('<extra_id_0> ', '')
        surfaces_by_uri.setdefault(uri, set()).add(surface)
    return surfaces_by_uri

In [None]:
entities = read_surface_stats(abstracts=abstracts)  # masked_uri -> set of surface forms

In [None]:
# Set of whitespace tokens (case preserved) for every reconstructed abstract sentence.
abstract_sentences = [set(get_sentence(abstract).split()) for abstract in abstracts]

In [None]:
# One token set per abstract, so this should equal len(abstracts).
len(abstract_sentences)

In [None]:
# Map "predicate_id,obj_uri,sub_uri" fact keys to the ids of the abstracts listed for them.
# A context manager closes the file handle right away instead of leaking it.
with open(os.path.join(BASE_DIR, "TREx_lama_templates_v3/abstracts/fact_to_ids_used.json")) as fh:
    fact_to_ids = json.load(fh)

In [None]:
# Index abstract records by their "id" field for direct lookup.
ids_to_abstracts = dict((abstract["id"], abstract) for abstract in abstracts)

In [None]:
# Materialize the id column as a plain Python list: the ranking code looks ids up
# by position in this sequence, which needs list semantics regardless of which
# column type the installed `datasets` version returns for abstracts["id"].
abstract_ids = list(abstracts["id"])

In [None]:
from src.lama_utils import get_sentence
def get_token_match(query, abstract, filtered_words=None):
    """Count query tokens that also occur in an abstract.

    Args:
        query: Iterable of query tokens (callers pass a set, so each distinct
            token is counted at most once).
        abstract: Container of abstract tokens supporting ``in``.
        filtered_words: Optional container of allowed tokens. When given, only
            query tokens that are also in ``filtered_words`` count; when
            ``None``, every shared token counts.

    Returns:
        int number of matching query tokens.
    """
    # None means "no filtering". Checking it first avoids evaluating `w in None`.
    if filtered_words is None:
        return sum(1 for w in query if w in abstract)
    return sum(1 for w in query if w in filtered_words and w in abstract)

from src.metric_utils import reciprocal_rank
def mrr_of_exact_match(queries, abstracts, abstract_sentences, topk=250, filter=False):
    """Rank abstracts for each query by token overlap and return reciprocal ranks.

    For every query, each abstract sentence is scored with ``get_token_match``
    (lower-cased token overlap with the query sentence). The ``topk`` best
    abstracts are kept in descending score order, and ``reciprocal_rank`` is
    computed against the abstracts that ``fact_to_ids`` lists for the query's
    fact.

    Args:
        queries: Iterable of query dicts with ``predicate_id``, ``obj_uri`` and
            ``sub_uri`` (plus ``sub_surface``/``obj_surface`` when ``filter``).
        abstracts: Not used by the computation; kept for call compatibility.
        abstract_sentences: Sequence of token collections, positionally aligned
            with the global ``abstract_ids``.
        topk: Number of top-scoring abstracts handed to ``reciprocal_rank``
            (capped at the number of abstracts).
        filter: If True, only tokens from the query's subject/object surface
            forms count towards the overlap score.

    Returns:
        list with one reciprocal rank per query.

    Relies on the notebook globals ``fact_to_ids``, ``abstract_ids``,
    ``get_sentence``, ``get_token_match`` and ``reciprocal_rank``.
    """
    # Loop-invariant work, done once instead of once per query: lower-case every
    # abstract (as a set, for O(1) membership) and index ids by position.
    lowered_abstracts = [{tok.lower() for tok in sent} for sent in abstract_sentences]
    id_to_index = {}
    for position, abstract_id in enumerate(abstract_ids):
        id_to_index.setdefault(abstract_id, position)  # first occurrence, like list.index
    k = min(topk, len(lowered_abstracts))  # argpartition needs k <= len(scores)

    rr = []
    for query in queries:
        fact = query['predicate_id'] + ',' + query['obj_uri'] + ',' + query['sub_uri']

        if filter:
            # Lower-case before splitting (a list has no .lower()), and tokenize
            # like the query sentence below so the token forms line up.
            filtered_words = set(query['sub_surface'].lower().split())
            filtered_words |= set(query['obj_surface'].lower().split())
        else:
            filtered_words = None

        correct_idxs = [id_to_index[str(fid)] for fid in fact_to_ids[fact]]
        query_sentence = set(get_sentence(query).lower().split())
        scores = np.array([
            get_token_match(query_sentence, abstract, filtered_words)
            for abstract in lowered_abstracts
        ])
        if k == 0:
            nn_idxs = np.array([], dtype=int)
        else:
            idxs = np.argpartition(scores, -k)[-k:]
            nn_idxs = idxs[np.argsort(-scores[idxs])]
        rr.append(reciprocal_rank(nn_idxs, correct_idxs))

    return rr
        

In [None]:
# Mean reciprocal rank of the unfiltered token-overlap ranking on the first 100 queries.
rrs = mrr_of_exact_match(queries=queries[:100], abstracts=abstracts,
                         abstract_sentences=abstract_sentences)
np.mean(rrs)

In [None]:
# Same evaluation with filter=True (intended: only subject/object surface tokens are scored).
rrs = mrr_of_exact_match(queries=queries[:100], abstracts=abstracts,
                         abstract_sentences=abstract_sentences, filter=True)
np.mean(rrs)

In [None]:
# Inspect one raw query record.
queries[0]

In [None]:
# Average number of distinct surface forms per masked entity URI.
np.mean([len(surfaces) for surfaces in entities.values()])

In [None]:
# Distinct-fact count per abstract, and the global set of distinct facts.
no_facts = read_no_facts(abstracts)
facts = read_facts(abstracts)

In [None]:
# Group facts by their (obj_uri, sub_uri) pair and report the mean number of
# distinct predicates attached to each pair.
pairs = {}
for fact in facts:
    predicate, *pair = fact.split(',')
    pair = tuple(pair)
    pairs.setdefault(pair, set()).add(predicate)
np.mean([len(preds) for preds in pairs.values()])

In [None]:
# Display the first abstract record.
abstracts[0]

In [None]:
# Count distinct reconstructed sentences (compare against len(abstracts) to spot repeats).
sentences = list(map(get_sentence, abstracts))
len(set(sentences))

In [None]:
# Number of distinct facts, plus (mean, std, min, max) of distinct facts per abstract.
len(facts), (np.mean(no_facts), np.std(no_facts), np.min(no_facts), np.max(no_facts))

In [None]:
# Distinct (predicates, objects, subjects) among all facts found in abstracts.
pos_nos_abstracts = (
    len(set(facts_to_field(facts, field='predicate_id'))),
    len(set(facts_to_field(facts, field='obj_uri'))),
    len(set(facts_to_field(facts, field='sub_uri'))),
)


In [None]:
# (distinct predicates, distinct objects, distinct subjects) over abstract facts.
pos_nos_abstracts

In [None]:
# Distinct (predicates, objects, subjects) among the queries, ordered like pos_nos_abstracts.
objs, subs, predicates = (read_string_field(queries, field=name)
                          for name in ("obj_uri", "sub_uri", "predicate_id"))
pos_nos_queries = tuple(len(set(column)) for column in (predicates, objs, subs))

In [None]:
# (distinct predicates, distinct objects, distinct subjects) over queries.
pos_nos_queries 

In [None]:
# Average number of abstract ids listed per fact.
np.mean([len(ids) for ids in fact_to_ids.values()])

In [None]:
# Show a random query together with the abstracts that fact_to_ids lists for its fact.
# Seeded so Restart & Run All reproduces the same example; change the seed to browse.
rng = np.random.default_rng(0)
query = queries[rng.integers(len(queries))]
print("====Query====\n", query)
fact = query['predicate_id'] + ',' + query['obj_uri'] + ',' + query['sub_uri']
fact_ids = fact_to_ids[fact]
current_abstracts = [ids_to_abstracts[str(fid)] for fid in fact_ids if str(fid) in ids_to_abstracts]
n_missing = len(fact_ids) - len(current_abstracts)
if n_missing:
    # Ids absent from the loaded abstracts are skipped; say so instead of hiding it.
    print(f"({n_missing} of {len(fact_ids)} listed ids not found among loaded abstracts)")
print("====Abstracts====\n")
for a in current_abstracts:
    print(a)
    