In [1]:
import os
import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import copy
# Optional lab-specific matplotlib style. Override with the MPLSTYLE_PATH env var;
# it is skipped when the file does not exist so a fresh kernel on another machine
# can still Restart & Run All.
MPLSTYLE_PATH = os.environ.get("MPLSTYLE_PATH", "/raid/lingo/akyurek/mplstyle")
if os.path.exists(MPLSTYLE_PATH):
    plt.style.use(MPLSTYLE_PATH)
plt.rc('font', serif='Times')
plt.rc('text', usetex=False)
plt.rcParams['figure.dpi'] = 150
plt.rcParams['figure.facecolor'] = 'white'
# Root directory holding the experiment inputs/results. An earlier note called this
# a Google Cloud directory; as written it is the local working directory.
BASE_DIR = "./"
METRICS_DIR = os.path.join(BASE_DIR, "metrics")

In [2]:
def read_jsonl(path, encoding="utf-8"):
    """Load a JSON Lines file into a NumPy array of parsed records.

    Every non-blank line of ``path`` is parsed with ``json.loads``. Blank or
    whitespace-only lines (e.g. a trailing separator line) are skipped rather
    than raising ``json.JSONDecodeError``.

    Args:
        path: Path to the ``.jsonl`` file.
        encoding: Text encoding used to open the file. Defaults to UTF-8, the
            encoding JSON text is expected to use, instead of the
            platform-dependent default of ``open()``.

    Returns:
        ``np.array`` built from the parsed records, one per non-blank line, in
        file order. Dict records yield a 1-D object array.
    """
    with open(path, encoding=encoding) as f:
        records = [json.loads(line) for line in f if line.strip()]
    return np.array(records)

def read_facts(abstracts):
    """Return the set of distinct fact strings across all abstracts.

    Each abstract's ``'facts'`` value is split on ``';'`` and the resulting
    pieces from every abstract are merged into a single set.
    """
    return {
        piece
        for record in abstracts
        for piece in record['facts'].split(';')
    }

def read_no_facts(abstracts):
    """Count the distinct ``';'``-separated facts in each abstract.

    Returns:
        ``np.array`` with one count per abstract, in input order.
    """
    counts = []
    for record in abstracts:
        distinct = set(record['facts'].split(';'))
        counts.append(len(distinct))
    return np.array(counts)

def facts_to_field(facts, field="obj_uri"):
    """Pick one comma-separated column out of every fact string.

    Column mapping: ``"obj_uri"`` -> column 1, ``"sub_uri"`` -> column 2, and
    any other ``field`` value (e.g. ``"predicate_id"``) -> column 0.

    NOTE(review): an unrecognised/misspelled ``field`` silently falls back to
    column 0 instead of raising; callers pass ``"predicate_id"`` for that column.

    Returns:
        list with the selected column of each fact, in iteration order.
    """
    column = 1 if field == "obj_uri" else (2 if field == "sub_uri" else 0)
    return [fact.split(',')[column] for fact in facts]

def read_string_field(abstracts, field="obj_uri"):
    """Collect ``record[field]`` from every record into an ``np.array``, in order."""
    picked = list(map(lambda record: record[field], abstracts))
    return np.array(picked)


def get_sentence(abstract):
    """Rebuild a sentence by putting the target span back into its masked input.

    ``'<extra_id_0> '`` is removed from ``targets_pretokenized`` and the result
    is stripped. That text then replaces every ``'<extra_id_0>'`` occurrence in
    ``inputs_pretokenized``.
    """
    template = abstract['inputs_pretokenized']
    raw_target = abstract['targets_pretokenized']
    filler = raw_target.replace('<extra_id_0> ', '').strip()
    return template.replace('<extra_id_0>', filler)

In [3]:
# Both inputs live under TREx_lama_templates_v2/: the abstracts corpus and the query set.
trex_dir = os.path.join(BASE_DIR, "TREx_lama_templates_v2")
abstracts_path = os.path.join(trex_dir, "abstracts", "all_used.jsonl")
queries_path = os.path.join(trex_dir, "all.jsonl.processed")
abstracts, queries = read_jsonl(abstracts_path), read_jsonl(queries_path)
len(abstracts), len(queries)

(1064471, 27528)

In [4]:
# Distinct-fact count per abstract, and the corpus-wide set of distinct facts.
no_facts = read_no_facts(abstracts)
facts = read_facts(abstracts)

In [5]:
# Reconstruct each abstract's sentence, then count how many distinct sentences exist.
sentences = list(map(get_sentence, abstracts))
len(set(sentences))

441670

In [6]:
# Number of distinct facts, then (mean, std, min, max) of distinct facts per abstract.
len(facts), (np.mean(no_facts), np.std(no_facts), np.min(no_facts), np.max(no_facts))

(317381, (3.4631746661017537, 3.3406391936613984, 1, 1047))

In [7]:
# Distinct values of each fact column, in (predicate_id, obj_uri, sub_uri) order.
pos_nos_abstracts = (
    len(set(facts_to_field(facts, field='predicate_id'))),
    len(set(facts_to_field(facts, field='obj_uri'))),
    len(set(facts_to_field(facts, field='sub_uri'))),
)


In [8]:
# Distinct (predicate_id, obj_uri, sub_uri) counts over the abstracts' facts.
pos_nos_abstracts

(406, 18037, 224043)

In [9]:
# Same distinct-count summary for the query set, in (predicate_id, obj_uri, sub_uri) order.
predicates = read_string_field(queries, field="predicate_id")
objs = read_string_field(queries, field="obj_uri")
subs = read_string_field(queries, field="sub_uri")
pos_nos_queries = tuple(len(set(column)) for column in (predicates, objs, subs))

In [10]:
# Distinct (predicate_id, obj_uri, sub_uri) counts over the queries.
pos_nos_queries 

(41, 1767, 25926)