# Tokenize with the Longformer tokenizer and map Wikipedia page ids to indices based on 'entity.jsonl'

In [1]:
import json
from tqdm import tqdm
import numpy as np

import os
import string

In [2]:
from transformers import LongformerTokenizer

In [3]:
# Name of the pretrained checkpoint passed to from_pretrained below.
LONGFORMER_CHECKPOINT = 'allenai/longformer-base-4096'
# Module-level tokenizer; to_longformer_tokenized() picks it up as a global.
tokenizer = LongformerTokenizer.from_pretrained(LONGFORMER_CHECKPOINT)

## Tokenized elq format

- 'id': 'doc_id'
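- 'text': ' '.join(sample['text']), i.e. the words joined by single spaces
- 'mentions': character index spans of the mentions, as a list of [start, end] lists (end is exclusive, i.e. last index + 1)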
- 'tokenized_text_ids': tokenized text id
- 'tokenized_mention_idxs': token index spans of the mentions (similar to span_position), as a list of [start, end] lists (end is exclusive, i.e. last index + 1)
- 'label_id': id based on 'entity.jsonl'
- 'wikidata_id': not included; planned for removal in the code
- 'entity': 'wiki_titles'
- 'label': 'wiki_contexts'

In [4]:
def _letters_only(text):
    """Return ``text`` lower-cased with everything but ASCII letters removed."""
    return "".join(ch for ch in text if ch in string.ascii_letters).lower()


def get_tokenized_bounds(mentions, full_example, tokenizer):
    """Tokenize ``full_example`` and convert character mention spans to token spans.

    The text is cut at every mention boundary and each piece is tokenized on
    its own, so a mention always covers a whole number of pieces and its token
    span is the union of the token ranges of those pieces.

    Args:
        mentions: Sequence of ``[start, end]`` character offsets into
            ``full_example`` (``end`` exclusive). Mentions may nest, overlap
            or repeat.
        full_example: The text the mention offsets refer to.
        tokenizer: Object exposing ``tokenize(text)``,
            ``convert_tokens_to_ids(tokens)`` and ``decode(ids)``.

    Returns:
        A tuple ``(all_token_ids, mention_idxs)``: the concatenated token ids
        of all pieces, and one fresh ``[token_start, token_end]`` list
        (``token_end`` exclusive) per mention, in the order of ``mentions``.

    Raises:
        ValueError: If a mention is empty or lies outside ``full_example``,
            or if the decoded tokens of a mention differ from the source text
            by more than non-letter characters and letter case.
    """
    text_len = len(full_example)

    # covering[c] lists, in input order, every mention whose span contains char c.
    covering = [[] for _ in range(text_len)]
    cut_points = {0, text_len}
    for m, ment in enumerate(mentions):
        start, end = ment[0], ment[1]
        # Explicit check instead of `assert`: asserts vanish under `python -O`,
        # and an empty/out-of-range span would otherwise fail far from its cause.
        if not 0 <= start < end <= text_len:
            raise ValueError(
                "mention {} has invalid character span [{}, {}] for text of length {}".format(
                    m, start, end, text_len
                )
            )
        for c in range(start, end):
            covering[c].append(m)
        cut_points.update((start, end))
    bounds = sorted(cut_points)

    all_token_ids = []
    span_by_mention = {}
    for left, right in zip(bounds, bounds[1:]):
        chunk_tokens = tokenizer.convert_tokens_to_ids(
            tokenizer.tokenize(full_example[left:right])
        )
        tok_start = len(all_token_ids)
        all_token_ids += chunk_tokens
        tok_end = len(all_token_ids)
        # Pieces never straddle a mention edge, so the mentions covering the
        # first character of a piece cover the whole piece.
        for m in covering[left]:
            span = span_by_mention.get(m)
            if span is None:
                # One new list per mention so returned spans are never aliased.
                span_by_mention[m] = [tok_start, tok_end]
            else:
                span[0] = min(span[0], tok_start)
                span[1] = max(span[1], tok_end)

    mention_idxs = [span_by_mention[m] for m in range(len(mentions))]

    # Round-trip check: decoding a mention's tokens should give back its text.
    for ment, tok_span in zip(mentions, mention_idxs):
        tokenized_mention = tokenizer.decode(all_token_ids[tok_span[0]:tok_span[1]])
        target_mention = full_example[ment[0]:ment[1]]
        if tokenized_mention == target_mention:
            continue
        # Report the inexact round trip, then tolerate it if only non-letter
        # characters or letter case differ.
        print("{} {}".format(tokenized_mention, target_mention))
        if _letters_only(tokenized_mention) != _letters_only(target_mention):
            raise ValueError(
                "decoded mention {!r} does not match source text {!r}".format(
                    tokenized_mention, target_mention
                )
            )
    return all_token_ids, mention_idxs

In [5]:
def to_longformer_tokenized(sample):
    """Convert one word-level sample into the tokenized ELQ-style record.

    Args:
        sample: Dict with the keys read below: 'doc_id', 'text' (list of
            words), 'start_idxs' / 'end_idxs' (per-word flags, 1 marks the
            first / last word of a mention), 'mentions' (expected mention
            surface strings), 'wiki_ids', 'wiki_titles' and 'wiki_contexts'.

    Returns:
        Dict with keys 'id', 'text', 'mentions', 'tokenized_text_ids',
        'tokenized_mention_idxs', 'label_id', 'entity' and 'label'.

    Raises:
        ValueError: If the per-word lists differ in length, or the number of
            reconstructed mentions does not match 'mentions' / 'wiki_titles'.
    """
    words = sample['text']
    starts = sample['start_idxs']
    ends = sample['end_idxs']
    if not len(words) == len(starts) == len(ends):
        raise ValueError(
            "sample {!r}: 'text', 'start_idxs' and 'end_idxs' differ in length".format(
                sample['doc_id']
            )
        )

    res = {}
    res['id'] = sample['doc_id']
    res['text'] = ' '.join(words)

    # Walk the words while tracking the character offset each one gets in the
    # space-joined text; a start flag opens a span, an end flag closes it.
    mentions = []
    current = []
    offset = 0
    for word, is_start, is_end in zip(words, starts, ends):
        if is_start == 1:
            current.append(offset)
        offset += len(word) + 1  # +1 for the joining space
        if is_end == 1:
            current.append(offset - 1)  # exclusive end, without the space
            mentions.append(current)
            current = []

    if len(mentions) != len(sample['mentions']):
        raise ValueError(
            "sample {!r}: reconstructed {} mentions but 'mentions' has {}".format(
                sample['doc_id'], len(mentions), len(sample['mentions'])
            )
        )
    # Soft sanity check: report (but keep) spans whose text differs from the
    # surface form stored in the sample.
    for span, expected in zip(mentions, sample['mentions']):
        surface = res['text'][span[0]:span[1]]
        if surface != expected:
            print(surface, ' ', expected)
    res['mentions'] = mentions

    # generate tokenized texts and mention bounds
    # code from Belinda: create_all_entity_finetuning.py
    all_token_ids, mention_idxs = get_tokenized_bounds(res['mentions'], res['text'], tokenizer)
    res['tokenized_text_ids'] = all_token_ids
    res['tokenized_mention_idxs'] = mention_idxs
    if len(mention_idxs) != len(sample['wiki_titles']):
        raise ValueError(
            "sample {!r}: {} mention spans but {} wiki titles".format(
                sample['doc_id'], len(mention_idxs), len(sample['wiki_titles'])
            )
        )

    res['label_id'] = sample['wiki_ids']
    res['entity'] = sample['wiki_titles']
    res['label'] = sample['wiki_contexts']

    return res

In [6]:
in_fpath = 'AIDA-YAGO2-en_desc'
out_fpath = 'AIDA-YAGO2_longformer/tokenized'

fnames = ['train.json', 'dev.json', 'test.json']
# num_longs[k] collects the indices of the examples in fnames[k] whose
# tokenized text is longer than 512 token ids.
num_longs = []

# open(..., 'w') does not create missing parent directories.
os.makedirs(out_fpath, exist_ok=True)

for fname in fnames:
    in_fname = os.path.join(in_fpath, fname)
    with open(in_fname, encoding='utf-8') as fin:
        orig_data = json.load(fin)

    longformer_tokenized = [
        to_longformer_tokenized(sample) for sample in tqdm(orig_data)
    ]

    # One JSON object per line, hence '.json' -> '.jsonl'.
    out_fname = os.path.join(out_fpath, fname + 'l')

    num_long = []
    with open(out_fname, 'w', encoding='utf-8') as wf:
        # enumerate(tqdm(...)) rather than tqdm(enumerate(...)) so the bar knows the total.
        for i, example in enumerate(tqdm(longformer_tokenized)):
            if len(example['tokenized_text_ids']) > 512:
                num_long.append(i)
            wf.write(json.dumps(example) + "\n")
    num_longs.append(num_long)

 54%|█████▍    | 514/946 [00:03<00:02, 185.65it/s]

Rep .   Rep.
Rep. Rep .
Goldman , Sachs & Co   Goldman, Sachs & Co
Goldman, Sachs & Co Goldman , Sachs & Co


 63%|██████▎   | 592/946 [00:03<00:02, 169.04it/s]

Wisc .   Wisc.
Wisc .   Wisc.
Wisc. Wisc .
Wisc. Wisc .
Washington , D.C.   Washington, D.C.
Washington, D.C. Washington , D.C.


 84%|████████▍ | 796/946 [00:04<00:00, 190.56it/s]

Colo .   Colo.
Colo. Colo .


 96%|█████████▌| 910/946 [00:05<00:00, 205.29it/s]

Colo .   Colo.
Colo. Colo .


100%|██████████| 946/946 [00:05<00:00, 167.93it/s]
946it [00:00, 8867.31it/s]
100%|██████████| 216/216 [00:01<00:00, 175.88it/s]
216it [00:00, 5008.82it/s]
100%|██████████| 231/231 [00:01<00:00, 181.11it/s]
231it [00:00, 5954.05it/s]


In [7]:
import requests

In [8]:
# Stored page-id -> title lookup; reindex() consults it before querying the API.
with open('models/id2title.json') as id2title_file:
    id2title = json.loads(id2title_file.read())

In [11]:
# Read the raw JSON-lines entity file; the context manager closes the handle
# instead of leaving it to the garbage collector.
with open("models/entity.jsonl") as entity_file:
    all_wiki_ents = entity_file.readlines()

In [12]:
# Replace the raw lines with their decoded JSON objects, then show how many there are.
all_wiki_ents = list(map(json.loads, all_wiki_ents))
print(len(all_wiki_ents))

5903527


In [13]:
#title2id = {line['title']: i for i, line in enumerate(all_wiki_ents)}

In [14]:
# Map the text after the last '=' of each entity's 'idx' field to the entity's
# position in all_wiki_ents (later duplicates overwrite earlier ones).
# NOTE(review): presumably 'idx' is a URL ending in "...curid=<page id>" -- confirm.
page2id = {}
for row, ent in enumerate(all_wiki_ents):
    page2id[ent['idx'].rsplit('=', 1)[-1]] = row

In [10]:
def _get_title_from_api(pageid, client=None):
    """Ask the English Wikipedia API for the title of the page with ``pageid``.

    Args:
        pageid: Page id interpolated into the ``pageids`` query parameter.
        client: Unused; kept so existing call sites keep working.

    Returns:
        The page title as returned by the API, or ``None`` when the request
        fails or the response carries no title (e.g. the page is missing).
        If the response holds several pages, the last title seen wins.
    """
    url = f"https://en.wikipedia.org/w/api.php?action=query&pageids={pageid}&format=json"

    # Bound up front so a failed lookup returns None instead of raising
    # UnboundLocalError at the return statement.
    title = None
    try:
        # requests applies no timeout by default; without one a stalled
        # connection would block the whole re-indexing run.
        r = requests.get(url, timeout=10)

        # Decode the JSON data into a dictionary: json_data
        json_data = r.json()
        pages = json_data["query"]["pages"]

        if len(pages) > 1:
            print("WARNING: more than one result returned from wikipedia api")

        for v in pages.values():
            title = v["title"]
    except (requests.RequestException, ValueError, KeyError, TypeError) as err:
        # Network failure, non-JSON body, or a response without the expected
        # query/pages/title structure.
        print(f"WARNING: could not resolve title for page id {pageid}: {err!r}")
    return title

In [21]:
def reindex(fpath, split):
    """Rewrite the labels of one JSON-lines split to indices into ``all_wiki_ents``.

    For every example, each entry of 'label_id' is looked up (as a string) in
    the module-level ``page2id`` and replaced by the resulting index. When the
    title stored in ``all_wiki_ents`` at that index differs from the example's
    'entity' title, the page title is fetched from ``id2title`` (falling back
    to ``_get_title_from_api``); a line is printed if that title matches
    neither of the two, and the 'entity' title is overwritten with the one
    from ``all_wiki_ents`` either way.

    Args:
        fpath: Directory prefix; concatenated with ``split`` as-is, so it must
            end with a path separator.
        split: File name of the split, e.g. 'train.jsonl'.

    Returns:
        The list of example dicts with 'label_id' and 'entity' updated.

    Raises:
        KeyError: If a label id has no entry in ``page2id``.
    """
    with open(fpath + split) as fin:
        examples = [json.loads(line) for line in fin]

    for e, example in enumerate(tqdm(examples)):
        entity = example['entity']

        new_label_id = []
        for i, old_id in enumerate(example['label_id']):
            page_key = str(old_id)
            new_id = page2id[page_key]
            new_label_id.append(new_id)

            catalogue_title = all_wiki_ents[new_id]['title']
            if catalogue_title == entity[i]:
                continue

            # Titles disagree: compare both against the title known for the page id.
            title = id2title.get(page_key)
            if title is None:
                title = _get_title_from_api(int(old_id))
                if title is not None:
                    # Cache under the same string key used for the lookup so
                    # the API is not queried again for this page id.
                    id2title[page_key] = title
            if title != catalogue_title and title != entity[i]:
                print(e, ' ', example['id'], ' ', old_id, ' ', new_id, ' ', entity[i], ' ', catalogue_title)
            entity[i] = catalogue_title
        example['label_id'] = new_label_id
        example['entity'] = entity
    return examples

In [22]:
splits = ['train.jsonl', 'dev.jsonl', 'test.jsonl']

# NOTE(review): the tokenization cell writes its .jsonl files to
# 'AIDA-YAGO2_longformer/tokenized', while this cell reads them from
# 'AIDA-YAGO2_longformer/' -- confirm the files are where this expects them.
inpath = 'AIDA-YAGO2_longformer/'
outpath = 'AIDA-YAGO2_longformer/tokenized/'

for split in splits:
    # Re-index the split's labels, then write one JSON object per line.
    examples = reindex(inpath, split)
    with open(outpath + split, 'w') as wf:
        wf.writelines(json.dumps(example) + "\n" for example in tqdm(examples))

127it [00:05, 34.94it/s]

128   129 Viacom   24580262   2891949   Viacom (1971–2005)   Viacom (original)


241it [00:08, 72.76it/s]

235   236 Promodes   2688005   664924   Les Échos (France)   Les Échos (newspaper)


251it [00:08, 67.17it/s]

247   248 RUGBY   1196374   372695   Halifax RLFC   Halifax R.L.F.C.


260it [00:08, 45.14it/s]

259   260 SOCCER   10410246   1546285   OKS 1945 Olsztyn   Stomil Olsztyn (football)
266   267 SOCCER   1537131   443668   V.C. Eendracht Aalst 2002   SC Eendracht Aalst


327it [00:10, 31.64it/s]

322   323 RUGBY   1196374   372695   Halifax RLFC   Halifax R.L.F.C.
322   323 RUGBY   1196374   372695   Halifax RLFC   Halifax R.L.F.C.


340it [00:10, 40.08it/s]

341   342 SOCCER   10410246   1546285   OKS 1945 Olsztyn   Stomil Olsztyn (football)
341   342 SOCCER   10410246   1546285   OKS 1945 Olsztyn   Stomil Olsztyn (football)


370it [00:11, 26.64it/s]

365   366 SOCCER   1537131   443668   V.C. Eendracht Aalst 2002   SC Eendracht Aalst
365   366 SOCCER   1537131   443668   V.C. Eendracht Aalst 2002   SC Eendracht Aalst


417it [00:12, 51.50it/s]

414   415 RUGBY   1196374   372695   Halifax RLFC   Halifax R.L.F.C.


485it [00:14, 41.62it/s]

473   474 Senate   403248   171440   Sultan, Crown Prince of Saudi Arabia   Sultan bin Abdulaziz Al Saud


551it [00:15, 47.50it/s]

545   546 CYCLING   2354465   604821   Rabobank (cycling team)   Team Jumbo–Visma


641it [00:16, 77.12it/s]

639   640 CYCLING   2354465   604821   Rabobank (cycling team)   Team Jumbo–Visma
640   641 CYCLING   2354465   604821   Rabobank (cycling team)   Team Jumbo–Visma


661it [00:17, 47.96it/s]

654   655 CRICKET   3182138   746081   Dave Richardson   Dave Richardson (cricketer)


748it [00:18, 65.08it/s]

726   727 Barrier   1873300   512081   Vale (mining company)   Vale (company)
726   727 Barrier   1873300   512081   Vale (mining company)   Vale (company)
726   727 Barrier   1873300   512081   Vale (mining company)   Vale (company)
726   727 Barrier   1873300   512081   Vale (mining company)   Vale (company)


801it [00:19, 85.16it/s]

790   791 PRESS   2286075   591510   Legal Department   Legal Department, Hong Kong
801   802 CYCLING   2354465   604821   Rabobank (cycling team)   Team Jumbo–Visma
801   802 CYCLING   2354465   604821   Rabobank (cycling team)   Team Jumbo–Visma
801   802 CYCLING   2354465   604821   Rabobank (cycling team)   Team Jumbo–Visma
801   802 CYCLING   2354465   604821   Rabobank (cycling team)   Team Jumbo–Visma


825it [00:19, 80.89it/s]

806   807 LOMBARDI   2354465   604821   Rabobank (cycling team)   Team Jumbo–Visma


870it [00:19, 115.86it/s]

875   876 PRESS   4665846   956613   Muslim Commercial Bank   MCB Bank Limited


946it [00:20, 45.19it/s] 
100%|██████████| 946/946 [00:01<00:00, 547.46it/s]
135it [00:00, 195.75it/s]

108   1055testa CYCLING   2354465   604821   Rabobank (cycling team)   Team Jumbo–Visma
108   1055testa CYCLING   2354465   604821   Rabobank (cycling team)   Team Jumbo–Visma
108   1055testa CYCLING   2354465   604821   Rabobank (cycling team)   Team Jumbo–Visma
108   1055testa CYCLING   2354465   604821   Rabobank (cycling team)   Team Jumbo–Visma


216it [00:00, 253.14it/s]
100%|██████████| 216/216 [00:00<00:00, 5886.86it/s]
26it [00:00, 106.46it/s]

10   1173testb RUGBY   5746768   1087118   Dan Crowley   Dan Crowley (rugby player)
12   1175testb SOCCER   2384790   610490   AFC Progresul Bucureşti   AS Progresul București
12   1175testb SOCCER   2384790   610490   AFC Progresul Bucureşti   AS Progresul București
39   1202testb SOCCER   616593   235776   Luis Enrique Martínez García   Luis Enrique (footballer)


96it [00:00, 155.98it/s]

75   1238testb Wall   1100754   349968   Newmont Mining Corporation   Newmont Goldcorp
75   1238testb Wall   1100754   349968   Newmont Mining Corporation   Newmont Goldcorp
75   1238testb Wall   1100754   349968   Newmont Mining Corporation   Newmont Goldcorp
75   1238testb Wall   1100754   349968   Newmont Mining Corporation   Newmont Goldcorp
75   1238testb Wall   1100754   349968   Newmont Mining Corporation   Newmont Goldcorp
75   1238testb Wall   1100754   349968   Newmont Mining Corporation   Newmont Goldcorp
75   1238testb Wall   1100754   349968   Newmont Mining Corporation   Newmont Goldcorp
75   1238testb Wall   1100754   349968   Newmont Mining Corporation   Newmont Goldcorp
75   1238testb Wall   1100754   349968   Newmont Mining Corporation   Newmont Goldcorp
75   1238testb Wall   1100754   349968   Newmont Mining Corporation   Newmont Goldcorp
75   1238testb Wall   1100754   349968   Newmont Mining Corporation   Newmont Goldcorp
75   1238testb Wall   1100754   349968   Ne

231it [00:01, 135.84it/s]
100%|██████████| 231/231 [00:00<00:00, 4652.19it/s]


In [23]:
# Save the page2id mapping built above as JSON.
with open('models/page2id.json', 'w') as page2id_file:
    page2id_file.write(json.dumps(page2id))