In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pickle

import torch
from transformers import BartForConditionalGeneration, BartTokenizer

from genre.entity_linking import get_end_to_end_prefix_allowed_tokens_fn_hf
from genre.trie import Trie

In [None]:
# Load the prefix tree (trie) built from KILT Wikipedia titles.
# NOTE(security): pickle.load can execute arbitrary code while deserializing;
# only load this file if it comes from a source you trust.
with open("data/kilt_titles_trie.pkl", "rb") as f:
    trie = pickle.load(f)
    
# Constraint callback handed to `generate`: looks up the entities trie.
def prefix_allowed_tokens_fn(batch_id, sent):
    """Look up the tokens generated so far in the module-level ``trie``.

    Args:
        batch_id: passed by the caller; not used by this function (the same
            trie is applied to every sequence).
        sent: the tokens generated so far; must provide ``tolist()``.

    Returns:
        Whatever ``trie.get`` returns for that prefix.
    """
    generated_prefix = sent.tolist()
    return trie.get(generated_prefix)

# Entity Disambiguation

In [None]:
# Load the tokenizer and the model from the same checkpoint directory.
ED_MODEL_DIR = "models/hf_entity_disambiguation_aidayago"

# Use the first GPU when available; fall back to CPU so the notebook still
# runs on machines without CUDA instead of raising in `.to("cuda:0")`.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

tokenizer = BartTokenizer.from_pretrained(ED_MODEL_DIR)
model = (
    BartForConditionalGeneration.from_pretrained(ED_MODEL_DIR)
    .eval()  # inference mode
    .to(device)
)

In [None]:
# Build the model inputs; the mention to disambiguate is wrapped in the
# [START_ENT] ... [END_ENT] markers.
sentences = [" [START_ENT] London [END_ENT] is the capital of the UK."]
encoded = tokenizer.batch_encode_plus(sentences, padding=True, return_tensors="pt")
# Move every encoded tensor onto the device the model lives on.
input_args = {name: tensor.to(model.device) for name, tensor in encoded.items()}

# Beam search with the title-trie callback, then decode the generated ids
# back to text (5 hypotheses for the sentence).
generated_ids = model.generate(
    **input_args,
    min_length=0,
    num_beams=5,
    num_return_sequences=5,
    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
)
tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

# Document Retrieval

In [None]:
# Load the tokenizer and the model from the same checkpoint directory.
RETRIEVAL_MODEL_DIR = "models/hf_wikipage_retrieval"

# Use the first GPU when available; fall back to CPU so the notebook still
# runs on machines without CUDA instead of raising in `.to("cuda:0")`.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

tokenizer = BartTokenizer.from_pretrained(RETRIEVAL_MODEL_DIR)
model = (
    BartForConditionalGeneration.from_pretrained(RETRIEVAL_MODEL_DIR)
    .eval()  # inference mode
    .to(device)
)

In [None]:
# Build the model inputs. padding=True matches the entity disambiguation cell
# and keeps this cell working when more than one sentence is supplied
# (return_tensors="pt" needs rows of equal length); for a single sentence the
# encoding is unchanged.
sentences = ["Stripes had Conrad Dunn featured in it"]
input_args = {
    k: v.to(model.device) for k, v in tokenizer.batch_encode_plus(
        sentences,
        padding=True,
        return_tensors="pt"
    ).items()
}

# Beam search with the title-trie callback, then decode the generated ids
# back to text (5 hypotheses per sentence).
tokenizer.batch_decode(
    model.generate(
        **input_args,
        min_length=0,
        num_beams=5,
        num_return_sequences=5,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
    ),
    skip_special_tokens=True
)

# End-to-End Entity Linking

In [None]:
# Load the tokenizer and the model from the same checkpoint directory.
E2E_MODEL_DIR = "models/hf_e2e_entity_linking_wiki_abs"

# Use the first GPU when available; fall back to CPU so the notebook still
# runs on machines without CUDA instead of raising in `.to("cuda:0")`.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

tokenizer = BartTokenizer.from_pretrained(E2E_MODEL_DIR)
model = (
    BartForConditionalGeneration.from_pretrained(E2E_MODEL_DIR)
    .eval()  # inference mode
    .to(device)
)

In [None]:
# Build the model inputs. padding=True keeps this cell working when more than
# one sentence is supplied (return_tensors="pt" needs rows of equal length);
# for a single sentence the encoding is unchanged.
sentences = [" London is the capital of the UK "]
input_args = {
    k: v.to(model.device) for k, v in tokenizer.batch_encode_plus(
        sentences,
        padding=True,
        return_tensors="pt"
    ).items()
}

# No constraints on mentions or candidates.
# A dedicated name is used so the title-trie `prefix_allowed_tokens_fn`
# defined at the top of the notebook is not overwritten; otherwise re-running
# the disambiguation/retrieval cells would silently use the wrong constraint.
e2e_prefix_allowed_tokens_fn = get_end_to_end_prefix_allowed_tokens_fn_hf(tokenizer, sentences)

# Beam search with the end-to-end callback, then decode the generated ids
# back to text (5 hypotheses per sentence).
tokenizer.batch_decode(
    model.generate(
        **input_args,
        min_length=0,
        num_beams=5,
        num_return_sequences=5,
        prefix_allowed_tokens_fn=e2e_prefix_allowed_tokens_fn,
    ),
    skip_special_tokens=True
)

In [None]:
# Constrain the mentions with a prefix tree - no constraints on candidates.
# A dedicated name is used so the title-trie `prefix_allowed_tokens_fn`
# defined at the top of the notebook is not overwritten.
e2e_prefix_allowed_tokens_fn = get_end_to_end_prefix_allowed_tokens_fn_hf(
    tokenizer,
    sentences,
    mention_trie=Trie([
        # [1:] drops the first id returned by encode
        # (presumably the BOS token - verify against the tokenizer).
        tokenizer.encode(e)[1:]
        for e in [" London"]
    ])
)

# Beam search with the end-to-end callback, then decode the generated ids
# back to text (5 hypotheses per sentence).
tokenizer.batch_decode(
    model.generate(
        **input_args,
        min_length=0,
        num_beams=5,
        num_return_sequences=5,
        prefix_allowed_tokens_fn=e2e_prefix_allowed_tokens_fn,
    ),
    skip_special_tokens=True
)

In [None]:
# Constrain the candidate set allowed for each mention.
# A dedicated name is used so the title-trie `prefix_allowed_tokens_fn`
# defined at the top of the notebook is not overwritten.
e2e_prefix_allowed_tokens_fn = get_end_to_end_prefix_allowed_tokens_fn_hf(
    tokenizer,
    sentences,
    # mention text -> list of candidate entity names for that mention
    mention_to_candidates_dict={
        "London": ["London"],
        "UK": ["UK"],
    }
)

# Beam search with the end-to-end callback, then decode the generated ids
# back to text (5 hypotheses per sentence).
tokenizer.batch_decode(
    model.generate(
        **input_args,
        min_length=0,
        num_beams=5,
        num_return_sequences=5,
        prefix_allowed_tokens_fn=e2e_prefix_allowed_tokens_fn,
    ),
    skip_special_tokens=True
)

In [None]:
# Constrain the candidates with a prefix tree - no constraints on mentions.
# A dedicated name is used so the title-trie `prefix_allowed_tokens_fn`
# defined at the top of the notebook is not overwritten.
e2e_prefix_allowed_tokens_fn = get_end_to_end_prefix_allowed_tokens_fn_hf(
    tokenizer,
    sentences,
    candidates_trie=Trie([
        # Each candidate is encoded inside the " } [ <entity> ]" template;
        # [1:] drops the first id returned by encode
        # (presumably the BOS token - verify against the tokenizer).
        tokenizer.encode(" }} [ {} ]".format(e))[1:]
        for e in ["London", "UK", "NIL"]
    ])
)

# Beam search with the end-to-end callback, then decode the generated ids
# back to text (5 hypotheses per sentence).
tokenizer.batch_decode(
    model.generate(
        **input_args,
        min_length=0,
        num_beams=5,
        num_return_sequences=5,
        prefix_allowed_tokens_fn=e2e_prefix_allowed_tokens_fn,
    ),
    skip_special_tokens=True
)

In [None]:
# Load the mention -> candidates mapping used with GERBIL.
# NOTE(security): pickle.load can execute arbitrary code while deserializing;
# only load this file if it comes from a source you trust.
with open("data/mention_to_candidates_dict_gerbil.pkl", "rb") as f:
    mention_to_candidates_dict = pickle.load(f)
    
# Load the mention trie used with GERBIL.
# NOTE(security): pickle.load can execute arbitrary code while deserializing;
# only load this file if it comes from a source you trust.
with open("data/mention_trie_gerbil.pkl", "rb") as f:
    mention_trie = pickle.load(f)

In [None]:
# Constrain both sides with the GERBIL resources: the mentions with the
# mention trie and the candidates with the mention -> candidates mapping.
# A dedicated name is used so the title-trie `prefix_allowed_tokens_fn`
# defined at the top of the notebook is not overwritten.
e2e_prefix_allowed_tokens_fn = get_end_to_end_prefix_allowed_tokens_fn_hf(
    tokenizer,
    sentences,
    mention_trie=mention_trie,
    mention_to_candidates_dict=mention_to_candidates_dict,
)

# Beam search with the end-to-end callback, then decode the generated ids
# back to text (5 hypotheses per sentence).
tokenizer.batch_decode(
    model.generate(
        **input_args,
        min_length=0,
        num_beams=5,
        num_return_sequences=5,
        prefix_allowed_tokens_fn=e2e_prefix_allowed_tokens_fn,
    ),
    skip_special_tokens=True
)

In [None]:
from genre.utils import get_entity_spans_hf

In [None]:
# Raw sentence (no entity markers, no surrounding spaces) that is passed to
# get_entity_spans_hf in the next cell. Rebinds `sentences` from the cells above.
sentences = ["London is the capital of the UK"]

In [None]:
# Run the GENRE helper with whichever model/tokenizer is currently loaded
# (the end-to-end entity linking checkpoint when the cells are run in order).
# presumably returns the entity spans found in each sentence - verify against genre.utils
get_entity_spans_hf(model, tokenizer, sentences)