In [None]:
# Automatically re-import edited modules (e.g. `genre`) before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import pickle
from genre import GENRE
from genre.trie import Trie
from genre.entity_linking import get_end_to_end_prefix_allowed_tokens_fn_fariseq

In [None]:
# Show where the imported fairseq package lives on disk, to confirm which install the kernel picked up.
import fairseq

getattr(fairseq, "__file__")

In [None]:
# Load the prefix tree of KILT Wikipedia titles used to constrain decoding.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open("data/kilt_titles_trie.pkl", "rb") as trie_file:
    trie = pickle.load(trie_file)
    
def prefix_allowed_tokens_fn(batch_id, sent):
    """Constrain generation with the KILT titles trie.

    Passed to `model.sample` as `prefix_allowed_tokens_fn`. `sent` holds the
    tokens generated so far (anything with `.tolist()`, presumably a tensor).
    The function returns whatever the global `trie.get` yields for that prefix,
    which is expected to be the list of allowed next token ids. `batch_id` is
    unused.
    """
    prefix_ids = sent.tolist()
    return trie.get(prefix_ids)

# Entity Disambiguation

In [None]:
# Load the `fairseq_entity_disambiguation_aidayago` checkpoint in eval mode on CPU.
model = GENRE.from_pretrained("models/fairseq_entity_disambiguation_aidayago")
model = model.eval().to("cpu:0")

In [None]:
# The mention to disambiguate is wrapped in [START_ENT] ... [END_ENT]. Decoding is
# constrained through the trie-backed prefix_allowed_tokens_fn.
sentences = [" [START_ENT] London [END_ENT] is the capital of the UK."]

model.sample(sentences, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn)

In [None]:
import blink.main_dense as main_dense
import argparse
import logging

import os

# Directory with the BLINK checkpoints and entity catalogue. Defaults to the original
# author's path; set the BLINK_MODELS_PATH environment variable to use a local copy.
models_path = os.environ.get("BLINK_MODELS_PATH", "/private/home/ndecao/BLINK/models/")

# BLINK dense-linker settings (keys presumably mirror the CLI options of BLINK's main_dense).
config = {
    "test_entities": None,
    "test_mentions": None,
    "interactive": False,
    "biencoder_model": os.path.join(models_path, "biencoder_wiki_large.bin"),
    "biencoder_config": os.path.join(models_path, "biencoder_wiki_large.json"),
    "entity_catalogue": os.path.join(models_path, "entity.jsonl"),
    "entity_encoding": os.path.join(models_path, "all_entities_large.t7"),
    "crossencoder_model": os.path.join(models_path, "crossencoder_wiki_large.bin"),
    "crossencoder_config": os.path.join(models_path, "crossencoder_wiki_large.json"),
    "fast": True,  # set this to be true if speed is a concern
    "output_path": "logs/",  # logging directory
    "faiss_index": None,
    "index_path": None,
    "top_k": 100,
}

args = argparse.Namespace(**config)

# Load the BLINK models; later cells read models[1] as a dict of settings.
models = main_dense.load_models(args, logger=None)

In [None]:
# Overlay models[1] (presumably the settings dict returned by BLINK's loader) on `config`;
# values from models[1] win on key clashes.
args = argparse.Namespace(**dict(config, **models[1]))

In [None]:
# Convert the first 100 KILT AIDA-YAGO2 records into BLINK's mention format.
# NOTE(review): `dataset_gold` is only created by the jsonlines-loading cell further
# down this notebook; run that cell first.
data_to_link = [
    {
        "id": d["id"],
        "label": "unknown",
        "label_id": -1,
        # BLINK's context_left / context_right are the text before / after the
        # mention, so they must come from KILT's left_context / right_context.
        "context_left": d["meta"]["left_context"].lower(),
        "mention": d["meta"]["mention"].lower(),
        "context_right": d["meta"]["right_context"].lower(),
    }
    for d in dataset_gold[:100]
]

In [None]:
# Force 100 retrieved candidates both on the Namespace and in the loader's settings dict.
setattr(args, "top_k", 100)
models[1].update(top_k=100)

In [None]:
%%timeit -r 5 -n 1
# Time BLINK on the prepared `data_to_link` mentions: 5 repeats of a single run each.
# The `logging` module is passed in the logger position.
main_dense.run(args, logging, *models, test_data=data_to_link)

In [None]:
# Inspect the effective BLINK settings: the Namespace attributes and the loader's dict.
vars(args), models[1]

In [None]:
from genre.utils import chunk_it, create_input
from tqdm.auto import tqdm

In [None]:
# Rebuild `data_to_link` in GENRE's input format from the same first 100 gold records.
# NOTE(review): `dataset_gold` is loaded by a later (jsonlines) cell; run that first.
data_to_link = [
    create_input(record, max_length=128)
    for record in dataset_gold[:100]
]

In [None]:
%%timeit -r 5 -n 1
for e in tqdm(chunk_it(data_to_link, 10)):
    model.sample(
        e,
        max_len_b=15,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
    )
#  45s CPU / 10s GPU

In [None]:
# Move the model to CPU; this is a no-op unless it was moved to a GPU after loading.
model = model.to(device="cpu:0")

In [None]:
import jsonlines
import pickle

import os

# Path to the KILT AIDA-YAGO2 dev split. Defaults to the original author's location;
# set the KILT_AIDA_DEV_PATH environment variable to use a local copy.
# NOTE: earlier cells read `dataset_gold`; run this cell before them.
aida_dev_path = os.environ.get(
    "KILT_AIDA_DEV_PATH", "/private/home/ndecao/KILT/data/aidayago2-dev-kilt.jsonl"
)

# One KILT record per JSON line.
with jsonlines.open(aida_dev_path) as f:
    dataset_gold = list(f)

# Document Retrieval

In [None]:
# Load the `fairseq_wikipage_retrieval` checkpoint in eval mode on the first GPU.
model = GENRE.from_pretrained("models/fairseq_wikipage_retrieval")
model = model.eval().to("cuda:0")

In [None]:
# Generate a Wikipedia page title for the sentence. Decoding is unconstrained here;
# add prefix_allowed_tokens_fn=prefix_allowed_tokens_fn to restrict it with the KILT titles trie.
sentences = ["Stripes had Conrad Dunn featured in it"]

model.sample(sentences)

# End-to-End Entity Linking

In [None]:
# Load the `fairseq_e2e_entity_linking_wiki_abs` checkpoint in eval mode on the first GPU.
model = GENRE.from_pretrained("models/fairseq_e2e_entity_linking_wiki_abs")
model = model.eval().to("cuda:0")

In [None]:
# Input sentence (leading/trailing spaces are part of the example).
sentences = [" London is the capital of the UK "]

# No constraints on either mentions or candidate entities.
prefix_allowed_tokens_fn = get_end_to_end_prefix_allowed_tokens_fn_fariseq(
    model, sentences
)

model.sample(sentences, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn)

In [None]:
# Mentions are restricted to a prefix tree (here only " London"); candidates are unconstrained.
# `[1:]` drops the first token of each encoding (presumably BOS).
prefix_allowed_tokens_fn = get_end_to_end_prefix_allowed_tokens_fn_fariseq(
    model,
    sentences,
    mention_trie=Trie(
        [model.encode(mention)[1:].tolist() for mention in [" London"]]
    ),
)

model.sample(sentences, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn)

In [None]:
# Restrict the candidate entities of each listed mention via mention_to_candidates_dict.
prefix_allowed_tokens_fn = get_end_to_end_prefix_allowed_tokens_fn_fariseq(
    model,
    sentences,
    mention_to_candidates_dict=dict(London=["London"], UK=["UK"]),
)

model.sample(sentences, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn)

In [None]:
# Constrain the candidate entities with a prefix tree; mentions are unconstrained.
# The previous version passed `mention_trie=mention_trie`, which constrained mentions
# instead (contradicting this comment) and relied on a trie loaded only in a later cell.
# Each candidate is encoded as " } [ <title> ]", minus the first token (presumably BOS).
# This assumes GENRE's end-to-end output markup `{ mention } [ entity ]`; confirm against
# genre.entity_linking.
prefix_allowed_tokens_fn = get_end_to_end_prefix_allowed_tokens_fn_fariseq(
    model,
    sentences,
    candidates_trie=Trie([
        model.encode(" }} [ {} ]".format(e))[1:].tolist()
        for e in ["London", "United Kingdom"]
    ]),
)

model.sample(sentences, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn)

In [None]:
# WARNING: this rebinds `tqdm` to the *module*, shadowing the `tqdm` progress-bar
# callable imported from tqdm.auto earlier. Any cell that calls `tqdm(...)` before the
# later `from tqdm.auto import tqdm` will fail with "'module' object is not callable".
import tqdm

In [None]:
# Installed tqdm version. Relies on the preceding `import tqdm` cell having bound
# `tqdm` to the module rather than the tqdm.auto progress-bar callable.
getattr(tqdm, "__version__")

In [None]:
# Reload the end-to-end entity-linking checkpoint in eval mode on the first GPU.
print("Loading model")
model = GENRE.from_pretrained("models/fairseq_e2e_entity_linking_wiki_abs")
model = model.eval().to("cuda:0")

In [None]:
print("Loading mention_to_candidates_dict")
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open("data/mention_to_candidates_dict_gerbil.pkl", "rb") as candidates_file:
    mention_to_candidates_dict = pickle.load(candidates_file)

In [None]:
print("Loading mention_trie")
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open("data/mention_trie_gerbil.pkl", "rb") as trie_file:
    mention_trie = pickle.load(trie_file)

In [None]:
# Spot-check the candidate entities stored for one mention (KeyError if absent).
obama_candidates = mention_to_candidates_dict["Obama"]
obama_candidates

In [None]:
# Parse a LaTeX table row of scores ("a & b & ...") and show its mean and count.
l = list(map(float, "90.8 & 93.3 & 89.8 & 90.9 & 76.0 & 87.5".split("&")))
sum(l) / len(l), len(l)

In [None]:
# Empty destination trie; `set_trie` copies `mention_trie`'s structure into it.
new_mention_trie = Trie(list())

In [None]:
def set_trie(t_i, t_o, depth=0):
    """Recursively copy the node structure of trie `t_i` into trie `t_o`.

    For every key in `t_i._leaves`, a new empty node of the same class as `t_o`
    is created in `t_o._leaves` and filled recursively. The result has the same
    token paths as `t_i` but is built from fresh objects (presumably so the trie
    can be re-pickled with the current class definition).

    Args:
        t_i: source trie; only its `_leaves` mapping (token -> child trie) is read.
        t_o: destination trie; its `_leaves` is filled in place. Its class must
            accept a single list argument, as in `Trie([])`.
        depth: current recursion depth. A progress bar is shown only over the
            root's children (depth == 0).
    """
    children = t_i._leaves.items()
    if depth == 0:
        # Import locally: a notebook cell rebinds the global name `tqdm` to the
        # module (`import tqdm`), which made a global `tqdm(...)` call fail here.
        from tqdm.auto import tqdm as progress
        children = progress(children)
    for token, sub_in in children:
        sub_out = type(t_o)([])
        t_o._leaves[token] = sub_out
        set_trie(sub_in, sub_out, depth=depth + 1)

In [None]:
# Copy the loaded trie node by node into freshly constructed trie objects.
set_trie(t_i=mention_trie, t_o=new_mention_trie)

In [None]:
from tqdm.auto import tqdm

In [None]:
# Walk the original trie along the token path 23084 -> 534 -> 2 and show the children there.
trie_node = mention_trie
for token_id in (23084, 534, 2):
    trie_node = trie_node._leaves[token_id]
trie_node._leaves

In [None]:
# Same token path in the copied trie, to compare against the original.
trie_node = new_mention_trie
for token_id in (23084, 534, 2):
    trie_node = trie_node._leaves[token_id]
trie_node._leaves

In [None]:
import os

# This overwrites the file the mention trie was loaded from. Write to a temporary file
# first and swap it in only after pickling succeeds. Otherwise a failure mid-dump
# (e.g. a RecursionError on a deep trie) would leave a truncated pickle in place of
# the only copy.
mention_trie_path = "data/mention_trie_gerbil.pkl"
tmp_trie_path = mention_trie_path + ".tmp"
with open(tmp_trie_path, "wb") as f:
    pickle.dump(new_mention_trie, f)
os.replace(tmp_trie_path, mention_trie_path)

In [None]:
# Read the raw redirects file, one line (including its trailing newline) per entry.
with open("data/wiki_redirects.txt") as redirects_file:
    wiki_redirects = list(redirects_file)

In [None]:
# Split each raw line into its tab-separated fields.
# Entries that are already lists are passed through unchanged. This cell rebinds
# `wiki_redirects` to its own output, so a re-run previously crashed calling
# .strip() on a list.
wiki_redirects = [
    line.strip().split("\t") if isinstance(line, str) else line
    for line in tqdm(wiki_redirects)
]

In [None]:
from tqdm.auto import tqdm

In [None]:
# FIXME(review): `load_redirections` is not defined or imported anywhere in this
# notebook, so this cell raises NameError on a fresh kernel. If it did run, it would
# also replace the `wiki_redirects` list parsed from data/wiki_redirects.txt above.
wiki_redirects = load_redirections()

In [None]:
# Number of redirect entries loaded.
num_redirects = len(wiki_redirects)
num_redirects