In [None]:
# Enable IPython's autoreload extension. Mode 2 reloads imported modules before
# each cell is executed, so edits to local source files are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

# Entity Disambiguation (fairseq)

In [None]:
from genre import GENRE

# Load the fairseq entity-disambiguation checkpoint from its local directory,
# then put it in evaluation mode and move it to device "cuda:0".
model = GENRE.from_pretrained("models/fairseq_entity_disambiguation_aidayago")
model = model.eval().to("cuda:0")

In [None]:
import pickle

# Load the pickled trie (built from KILT titles, per the original note). The
# `prefix_allowed_tokens_fn` cells below query it through `trie.get(...)`.
# NOTE(review): this is an absolute, user-specific path and will not resolve on
# other machines; consider a path relative to a configurable data directory.
# NOTE(review): `pickle.load` can execute arbitrary code; only load trusted files.
with open("/checkpoint/fabiopetroni/GENRE/home/GeNeRe/__GENRE/data/kilt/trie.pkl", "rb") as f:
    trie = pickle.load(f)

In [None]:
def prefix_allowed_tokens_fn(batch_id, sent):
    """Look up the global ``trie`` for the tokens generated so far.

    Args:
        batch_id: Index supplied by the caller; not used by this function.
        sent: Object exposing ``.tolist()`` (e.g. a tensor of token ids)
            holding the prefix generated so far.

    Returns:
        Whatever ``trie.get`` returns for the prefix converted to a list.
    """
    generated_prefix = sent.tolist()
    return trie.get(generated_prefix)

# Run the fairseq model on one sentence whose mention is wrapped in
# [START_ENT] ... [END_ENT], with generation constrained through
# `prefix_allowed_tokens_fn`. The call is the last expression of the cell, so
# its return value is shown as the cell output.
model.sample(
    [" [START_ENT] London [END_ENT] is the capital of the UK."],
    beam=5,
    max_len_b=15,
    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
)

# Entity Disambiguation (huggingface transformers)

In [None]:
from transformers import BartForConditionalGeneration, BartTokenizer

# The tokenizer and the BART weights are read from the same local directory.
tokenizer = BartTokenizer.from_pretrained("models/hf_entity_disambiguation_aidayago")

# Note: this re-binds `model`, replacing the fairseq model loaded earlier.
model = BartForConditionalGeneration.from_pretrained(
    "models/hf_entity_disambiguation_aidayago"
)
model = model.eval().to("cuda:0")

In [None]:
def prefix_allowed_tokens_fn(batch_tokens, sent=None):
    """Constrain generation with the global ``trie``; accepts two call forms.

    This definition shadows the earlier two-argument function of the same
    name, so it also has to serve callers that pass ``(batch_id, sent)``.

    Args:
        batch_tokens: In the one-argument form, an iterable (one item per
            batch element) of iterables (one item per beam) of objects
            exposing ``.tolist()``. In the two-argument form this is the batch
            index and is ignored.
        sent: Optional. When given, an object exposing ``.tolist()`` that holds
            the prefix generated so far for a single hypothesis.

    Returns:
        Two-argument form: the result of ``trie.get`` for ``sent``.
        One-argument form: a nested list (batch -> beam) of ``trie.get``
        results, exactly as the original single-argument function produced.
    """
    if sent is not None:
        # (batch_id, input_ids) convention: one lookup for one hypothesis.
        return trie.get(sent.tolist())

    # Original batched convention: one lookup per beam of every batch element.
    return [
        [trie.get(tokens.tolist()) for tokens in beam_tokens]
        for beam_tokens in batch_tokens
    ]

# Tokenize the single input sentence and move every returned tensor to the
# device the model lives on.
input_args = {
    name: tensor.to(model.device)
    for name, tensor in tokenizer.batch_encode_plus(
        [" [START_ENT] London [END_ENT] is the capital of the UK."],
        return_tensors="pt",
    ).items()
}

# Generate with 5 beams under the `prefix_allowed_tokens_fn` constraint and
# decode the output ids back to text; the decoded list is the cell output.
tokenizer.batch_decode(
    model.generate(
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        num_beams=5,
        min_length=0,
        **input_args,
    ),
    skip_special_tokens=True,
)

# ------ from this point on the code might not work: I am refactoring -----

In [None]:
# `torch` is not imported by any earlier cell, so import it here to keep the
# cell runnable on a fresh kernel.
import torch

# Load the checkpoint, remove its "last_optimizer_state" entry, and save the
# result under a new location.
# NOTE(review): the output path is absolute and user-specific; adjust it for
# your own environment.
model_dict = torch.load("checkpoint_blink.pt")
# Passing a default keeps the cell re-runnable on a checkpoint that has
# already been stripped (a bare `pop` would raise KeyError).
model_dict.pop("last_optimizer_state", None)
torch.save(model_dict, "/private/home/ndecao/models/ed/checkpoint_blink.pt")

# Document Retrieval

In [None]:
# Load the retrieval checkpoint (the original note estimates 5-20 sec), then
# put it in evaluation mode and move it to device "cuda:0".
model_path = "models/kilt"
checkpoint_file = "checkpoint.pt"
model = GENRE.from_pretrained(model_path, checkpoint_file=checkpoint_file)
model = model.eval().to("cuda:0")

In [None]:
# The trie loaded earlier in the notebook is reused here.
# The constraint is passed inline in the (batch_id, sent) form used by the
# fairseq section, instead of through the global `prefix_allowed_tokens_fn`:
# that name is defined more than once in this notebook, so which version is
# bound depends on which cells were executed last.
# NOTE(review): "hada" in the query looks like a typo for "had a"; the string is
# left unchanged so the model input stays identical.
model.sample(
    ["Which US nuclear reactor hada major accident in 1979?"],
    beam=10,
    max_len_b=15,
    prefix_allowed_tokens_fn=lambda batch_id, sent: trie.get(sent.tolist()),
)

# End-to-End Entity Linking

In [None]:
from genre.entity_linking_model import GeNeRe

In [None]:
# Load the end-to-end entity linking model. Notes carried over from the author:
# - loading emits some logging from the Kolitas code;
# - loading takes 4-5 min (the author expects it can be roughly halved);
# - `genre.entity_linking_model` has not been refactored yet, so
#   i) its code is a mess, and
#   ii) `GeNeRe` is a different class from the models used above.
model = GeNeRe(
    device="cuda:0",
    checkpoint_file="checkpoint_aidayago.pt",
    model_path="models/el",
)

In [None]:
# Run the end-to-end model on a single plain sentence (expect some logging).
# `spans` is kept for the next cell and, as the last expression, is also
# displayed as the cell output.
sentence = "London is the capital of the UK."
spans = model.get_prediction(sentence)
spans

In [None]:
# Pass the sentence and the spans predicted in the previous cell to
# `get_markdown`; the return value is shown as the cell output.
# NOTE(review): presumably a markdown rendering of the linked spans; confirm
# against `genre.entity_linking_model`.
model.get_markdown(sentence, spans)

# End-to-End Entity Linking v2

In [None]:
from genre.entity_linking_model_v2 import GENREForEndToEndEntityLinking

In [None]:
# Load the v2 end-to-end entity linking model from the same checkpoint as
# above, together with the pickled mention trie and candidates dictionary.
model = GENREForEndToEndEntityLinking.from_pretrained(
    "models/el",
    candidates_dict_file="data/el_v2/candidates_dict.pkl",
    mention_trie_file="data/el_v2/mention_trie.pkl",
    device="cuda:0",
    checkpoint_file="checkpoint_aidayago.pt",
)

In [None]:
# Run the v2 model on one plain sentence (no [START_ENT]/[END_ENT] markers).
# No `prefix_allowed_tokens_fn` is passed in this call. The return value is the
# cell output.
model.sample(
    [" London is the capital of the UK."],
    beam=6,
    max_len_b=1024,
)