In [5]:
import os
import gzip
import json
import marisa_trie

In [6]:
print(os.getcwd())

/home/kg-augmented-lm/notebooks


In [7]:
corpus_path = os.path.join(os.getcwd(), '../data/wikidata5m/wikidata5m_text.txt.gz')
alias_path = os.path.join(os.getcwd(), '../data/wikidata5m/wikidata5m_alias.tar.gz')
nq_path = os.path.join(os.getcwd(), '../data/natural_questions/simplified-nq-train.jsonl.gz')

In [8]:
def filter_alias(alias):
    common_words = {'the', 'a', 'is', 'an', 'of', 'and', 'in', 'on', 'for', 'to', 'by', 'with', 'out', 'at', 'from', 'as', 'into', 'onto', 'upon', 'after', 'before', 'over', 'under', 'between', 'among', 'through', 'during', 'since', 'until', 'against', 'without', 'towards', 'amongst', 'amongst', 'within', 'throughout', 'despite', 'along', 'alongside', 'besides', 'beyond', 'near', 'above', 'below', 'behind', 'underneath', 'underneath', 'opposite', 'adjacent', 'outside', 'inside', 'benefit'}
    if alias in common_words:
        return False
    elif len(alias) < 3:
        return False
    else:
        return True

In [9]:
def build_marisa_trie_from_aliases(alias_file):
    # Read the aliases file
    with gzip.open(alias_file, 'rt', encoding='utf-8', errors='ignore') as file:
        lines = file.readlines()[1:4813490]
    
    alias_to_id = {}
    for line in lines:
        aliases = line.strip().split('\t')
        for alias in aliases[1:]:
            if filter_alias(alias):
                alias_to_id[alias] = aliases[0]

    # Build the marisa-trie
    trie = marisa_trie.Trie(alias_to_id.keys())

    # Get memory usage (rough estimate)

    return trie, alias_to_id

In [10]:
trie, alias_to_id = build_marisa_trie_from_aliases(alias_path)

In [11]:
def link_entities(trie, alias_to_id, document):
    tokens = document.split()  # Basic whitespace-based tokenization
    i = 0
    entities_found = []

    while i < len(tokens):
        longest_match = None
        longest_match_length = 0
        current_string = tokens[i].lower()

        # Check for entity match starting from the current token
        if current_string in trie:
            entity_id = alias_to_id[current_string]
            longest_match = (entity_id, i, i)
            longest_match_length = 1

        # Extend the match to subsequent tokens to find longer matches
        j = i + 1
        while j < len(tokens) and current_string in trie:
            current_string += " " + tokens[j].lower()
            if current_string in trie:
                entity_id = alias_to_id[current_string]
                longest_match = (entity_id, i, j)
                longest_match_length = j - i + 1
            j += 1

        # If a match was found, add it to entities_found and skip the matched tokens
        if longest_match:
            entities_found.append(longest_match)
            i += longest_match_length
        else:
            i += 1

    # Convert token positions to character positions
    char_entities_found = []
    char_pos = 0
    token_index = 0
    for entity in entities_found:
        while token_index < entity[1]:
            char_pos += len(tokens[token_index]) + 1  # +1 for space
            token_index += 1
        start_pos = char_pos
        while token_index <= entity[2]:
            char_pos += len(tokens[token_index]) + 1
            token_index += 1
        end_pos = char_pos - 2  # -1 for last space, -1 to get the end of the word
        char_entities_found.append((entity[0], start_pos, end_pos))
        
    return char_entities_found

In [12]:
alias_to_id['Hawaii']

'Q782'

In [13]:
sorted_alias_to_id = sorted(list(alias_to_id.items()), key = lambda key : len(key[0]))

In [14]:
sorted_alias_to_id[:100]

[('AER', 'Q1138408'),
 ('雪代巴', 'Q7574735'),
 ('緋村巴', 'Q7574735'),
 ('경기도', 'Q20937'),
 ('京畿道', 'Q20937'),
 ('la♭', 'Q549905'),
 ('La♭', 'Q549905'),
 ('j65', 'Q2594466'),
 ('J65', 'Q2594466'),
 ('φσα', 'Q2740587'),
 ('ΦΣΑ', 'Q2740587'),
 ('하리수', 'Q482786'),
 ('이경엽', 'Q482786'),
 ('이경은', 'Q482786'),
 ('박가희', 'Q488725'),
 ('박지영', 'Q488725'),
 ('賈曉晨', 'Q712083'),
 ('rd/', 'Q630961'),
 ('RD/', 'Q630961'),
 ('NET', 'Q64138'),
 ('vog', 'Q1433263'),
 ('VOG', 'Q1433263'),
 ('H₂O', 'Q130615'),
 ('h₂o', 'Q130615'),
 ('b26', 'Q562167'),
 ('B26', 'Q562167'),
 ('A26', 'Q562167'),
 ('a26', 'Q562167'),
 ('iyf', 'Q190193'),
 ('Iyf', 'Q190193'),
 ('1am', 'Q855630'),
 ('1mm', 'Q855630'),
 ('1Gm', 'Q855630'),
 ('1cm', 'Q855630'),
 ('1nm', 'Q855630'),
 ('1pm', 'Q855630'),
 ('1Ym', 'Q855630'),
 ('1fm', 'Q855630'),
 ('1 m', 'Q855630'),
 ('1Zm', 'Q855630'),
 ('1hm', 'Q855630'),
 ('1Pm', 'Q855630'),
 ('1zm', 'Q855630'),
 ('1gm', 'Q855630'),
 ('1km', 'Q855630'),
 ('1ym', 'Q855630'),
 ('1Tm', 'Q855630'),
 ('1dm'

In [15]:
sample_text = "Barack Obama was born in Hawaii that is for sure. He was the president of the United States of America. He was married to Michelle Obama. He has two daughters. He was the first black president of the United States of America. The capital of the United States of America is Washington D.C. The current president of the United States of America is Donald Trump. Donald Trump is a republican. Donald Trump is married to Melania Trump. Donald Trump has five children. Donald Trump is the 45th president of the United States of America."
entities = link_entities(trie, alias_to_id, sample_text)
print(entities)

[('Q76', 0, 11), ('Q782', 25, 30), ('Q61061', 61, 69), ('Q30', 78, 90), ('Q8445', 111, 117), ('Q200', 145, 147), ('Q444353', 171, 175), ('Q61061', 183, 191), ('Q30', 200, 212), ('Q3220821', 230, 236), ('Q30', 245, 257), ('Q30', 262, 268), ('Q61', 273, 287), ('Q11651', 293, 299), ('Q61061', 301, 309), ('Q30', 318, 330), ('Q30', 335, 341), ('Q2057908', 346, 351), ('Q22686', 360, 371), ('Q22686', 390, 401), ('Q8445', 406, 412), ('Q16279311', 417, 423), ('Q22686', 432, 443), ('Q203', 449, 452), ('Q22686', 464, 475), ('Q61061', 489, 497), ('Q30', 506, 518)]


In [17]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

with gzip.open(nq_path, 'rt', encoding='utf-8') as f:
    short_document_texts = 0
    single_entity_answers = 0
    multiple_entity_answers = 0
    short_document_lengths = []
    one_node_short_answers = []
    multiple_node_short_answers = []
    one_node_short_answer_entities = []
    for i, line in enumerate(f):
        entry = json.loads(line)
        
        question = entry['question_text']
        short_answer = entry['annotations'][0]['short_answers']
        long_answer = entry['annotations'][0]['long_answer']['candidate_index']
        
        if (short_answer != []) and (long_answer != -1):
            # num_token_ids = len(tokenizer(entry['document_text'])['input_ids'])
            
            # if num_token_ids < 256:
            #     short_document_texts += 1
            #     short_document_lengths.append(num_token_ids)
            
            answer = " ".join(entry['document_text'].split(' ')[short_answer[0]['start_token']:short_answer[0]['end_token']])
            entities = link_entities(trie, alias_to_id, answer)
            if len(entities) == 1:
                single_entity_answers += 1
                one_node_short_answers.append(answer)
                one_node_short_answer_entities.append(entities[0][0])
            elif len(entities) > 1:
                multiple_entity_answers += 1
                multiple_node_short_answers.append(answer)
                
                # if i > 10000:
                #     print(counter)
                #     break
    answer_id_pairs = list(zip(one_node_short_answers, one_node_short_answer_entities))
    
    # print(f"Short document texts out of all: {short_document_texts}/{i} ({short_document_texts/i*100:.2f}%))")
    # print(f"Average length of short documents: {sum(short_document_lengths)/len(short_document_lengths):.2f}")
    
    # print('\n')    
    
    print(f"1-node short answers out of all: {single_entity_answers}/{i} ({single_entity_answers/i*100:.2f}%))")
    print(answer_id_pairs[:10])
    
    print('\n')
    
    print(f"Multiple-node short answers out of all: {multiple_entity_answers}/{i} ({multiple_entity_answers/i*100:.2f}%))")
    print(multiple_node_short_answers[:10])

1-node short answers out of all: 58423/307372 (19.01%))
[('Tracy McConnell', 'Q15362106'), ('Tom Brady', 'Q313381'), ('Pom Klementieff', 'Q3395911'), ('SFR Sport', 'Q724353'), ('Persian Gulf', 'Q34675'), ('Kareem Abdul - Jabbar', 'Q700685'), ('six', 'Q23488'), ('1936', 'Q18649'), ("Leon `` Kida '' Burns", 'Q3085442'), ('Freddie Highmore', 'Q296887')]


Multiple-node short answers out of all: 36670/307372 (11.93%))
["a newsletter sent to an advertising firm 's customers", 'the use of lush string arrangements with a real orchestra and often , background vocals provided by a choir', 'an adult licensed driver who is at least 21 years of age or older and in the passenger seat of the vehicle at all times', 'March 30 , 2018', 'Miroslav Lajčák of Slovakia', 'every Wednesday and Saturday evening at 10 : 59 p.m. Eastern Time', 'August 21 , 2017', 'January 26 , 2018', 'Tokyo for the 2020 Summer Olympics', 'an attempt to establish the sovereignty it had claimed over them']


In [99]:
zip_answer_entities = zip(one_node_short_answers, one_node_short_answer_entities)
print(list(zip_answer_entities))

[('Tracy McConnell', 'Q15362106'), ('Tom Brady', 'Q313381'), ('Pom Klementieff', 'Q3395911'), ('SFR Sport', 'Q724353'), ('Persian Gulf', 'Q34675'), ('Kareem Abdul - Jabbar', 'Q700685'), ('six', 'Q23488'), ('1936', 'Q18649'), ("Leon `` Kida '' Burns", 'Q3085442'), ('Freddie Highmore', 'Q296887'), ('brain', 'Q562892'), ('Adam Sandler', 'Q132952'), ('Manchester United', 'Q18656'), ('The Lego Movie', 'Q2608065'), ('atoms', 'Q1324392'), ('Freak the Freak Out', 'Q489322'), ('Seyseys , chief of the Canarsees', 'Q1162163'), ('2017', 'Q25290'), ('Instagram', 'Q209330'), ('$140 billion +', 'Q16021'), ('The Black Dirt Region', 'Q36784'), ('tabla', 'Q213100'), ('Gary Oldman', 'Q83492'), ('season six finale', 'Q23488'), ('April 25 , 2018', 'Q25291'), ('Ellen Burstyn', 'Q211144'), ('Dublin', 'Q1761'), ('Tina Cole', 'Q212772'), ('Torrens title', 'Q56019'), ('Rosemary Butler', 'Q7368322'), ('2016', 'Q25245'), ('Arkansas', 'Q1612'), ('Michael Jackson', 'Q2831'), ('UB40', 'Q560153'), ('13.7 % of the GDP

In [None]:
with gzip.open(corpus_path, 'rb') as f:
    corpus = f.readlines()