In [1]:
import sys
import pandas as pd
import datasets
from pathlib import Path

from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Add the sibling '../camembert_dual_encoder/' directory to the import search path
# so that the `camembert_dual_encoder` imports just below can be resolved.
# NOTE(review): the path is relative, so it depends on the kernel's working directory.
sys.path.append('../camembert_dual_encoder/')
from camembert_dual_encoder import CamembertDualEncoderModel
from camembert_dual_encoder.data.embeddings import KeyedVectors

from model import ELPipeline, WikipediaMapper, WikidataPropertyGetter
from model import keep_best_candidates

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the known entities; their descriptions are used later to disambiguate candidates.
df_entities = pd.read_csv('data/entity_mapping.csv', encoding='utf-8')

# Map each value of the 'entity' column to the 'description' found on the same row.
# If an entity appears more than once, the last row wins.
entity2description = {
    entity: description
    for entity, description in zip(df_entities['entity'], df_entities['description'])
}

In [5]:
# Directory into which the pre-trained bi-encoder must already have been extracted.
bi_encoder_pretrained_model_path = 'data/weights/biEncoder/'

# Load the bi-encoder from that directory, then keep the object returned by `.eval()`
# (presumably the same model switched to inference mode -- confirm against the model class).
bi_encoder = CamembertDualEncoderModel.from_pretrained(bi_encoder_pretrained_model_path)
bi_encoder = bi_encoder.eval()

# The tokenizer is loaded from the 'camembert-base' identifier, not from the local directory.
bi_encoder_tokenizer = AutoTokenizer.from_pretrained('camembert-base')

# The embeddings live in the 'embeddings' sub-directory of the bi-encoder weights.
bi_encoder_embeddings_dir = Path(bi_encoder_pretrained_model_path) / 'embeddings'
embeddings = KeyedVectors.from_directory(bi_encoder_embeddings_dir)

In [6]:
# Directory into which the pre-trained cross-encoder must already have been extracted.
cross_encoder_pretrained_model_path = 'data/weights/crossEncoder/'

# Load the sequence-classification model from that directory and keep the object
# returned by `.eval()`.
cross_encoder = AutoModelForSequenceClassification.from_pretrained(
    cross_encoder_pretrained_model_path
)
cross_encoder = cross_encoder.eval()

# The matching tokenizer is loaded from the same local directory as the model.
cross_encoder_tokenizer = AutoTokenizer.from_pretrained(
    cross_encoder_pretrained_model_path
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [7]:
# Load the "train" split of the frwiki-20220601 page dataset as a pandas object.
wiki_pages = datasets.load_dataset(
    "gcaillaut/frwiki-20220601_page",
    split="train",
).to_pandas()

# Same for the frwiki-20220601 redirect dataset.
wiki_redirects = datasets.load_dataset(
    "gcaillaut/frwiki-20220601_all_redirect",
    split="train",
).to_pandas()

# Helpers handed to the entity-linking pipeline: a mapper built from the pages and
# redirects, and a Wikidata property getter created with its default configuration.
wikipedia_mapper = WikipediaMapper(wiki_pages, wiki_redirects)
wikidata_getter = WikidataPropertyGetter.default()

In [8]:
# Build the object required for prediction. Arguments are passed positionally,
# so their order must be kept exactly as below.
el_pipeline = ELPipeline(
    bi_encoder,
    cross_encoder,
    bi_encoder_tokenizer,
    cross_encoder_tokenizer,
    embeddings,
    entity2description,
    wikipedia_mapper,
    wikidata_getter,
)


In [9]:
# The pipeline is called with a list of texts; here the list holds one sentence.
input_pipeline = [
    "Orléans est une commune du Centre-Nord-Ouest de la France sur les rives de la Loire, préfecture du département du Loiret et capitale de la région Centre-Val de Loire.",
]
output_pipeline = el_pipeline(input_pipeline)

# Last expression of the cell: its value is rendered as the cell's output.
keep_best_candidates(output_pipeline)