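### Notebook to load the fine-tuned model for inference

Loads the fine-tuned checkpoint, builds the model input for a natural-language question (optionally enriched with linguistic context and linked entities), and predicts a SPARQL query for the configured knowledge graph.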

In [1]:
# import the libraries
import sys
sys.path.append('../code/')

from components.Summarizer import Summarizer
from components.Question import Question
from components.Knowledge_graph import Knowledge_graph
from components.Language import Language
from components.Query import Query

Setting ds_accelerator to cuda (auto detect)


In [2]:
# --- Inference configuration -------------------------------------------------

# Directory of the fine-tuned checkpoint that Summarizer loads below.
model_path = "../fine-tuned_models/qald9plus-finetune"

# Knowledge graph the generated queries target; must name a Knowledge_graph member.
knowledge_graph = "Wikidata"

# Feature switches for building the model input.
linguistic_context = True   # extract and use linguistic information
entity_knowledge = True     # extract and use entity knowledge

# Padding lengths for the two parts of the model input.
question_padding_length = 32   # question / linguistic-context part
entity_padding_length = 5      # entity part

In [3]:

# Initialize the model: load the fine-tuned checkpoint at `model_path` into a
# Summarizer. Later cells call its `predict_sparql` method to turn a processed
# question string into a SPARQL query.
sparql_model = Summarizer(model_path)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
# Resolve the configured KG name (e.g. "Wikidata") to the matching
# Knowledge_graph member by name lookup.
# NOTE(review): if Knowledge_graph is an Enum, an unknown name raises KeyError — confirm.
kg = Knowledge_graph[knowledge_graph]

def get_wikidata_entities(question: Question):
    """Recognize entities in ``question`` for use with Wikidata.

    Asks ``Language.get_supported_ner`` which NER to use for the question's
    language. That value is passed, together with "mgenre_el", to
    ``question.recognize_entities``.

    Returns:
        The result of ``recognize_entities``, or an empty list when the
        language reports ``"no_ner"``.
    """
    ner_model = Language.get_supported_ner(question.language)
    if ner_model == "no_ner":
        # No NER is available for this language, so there is nothing to link.
        return []
    return question.recognize_entities(ner_model, "mgenre_el")

def get_dbpedia_entities(question: Question):
    """Recognize entities in ``question`` for use with DBpedia.

    Unlike the Wikidata path, the identifiers are fixed regardless of the
    question's language: "babelscape_ner" and "mag_el".

    Returns:
        Whatever ``question.recognize_entities`` returns for those identifiers.
    """
    ner_name, linker_name = "babelscape_ner", "mag_el"
    return question.recognize_entities(ner_name, linker_name)

def prep_input(input_str, lang, linguistic_context, entity_knowledge, question_padding_length, entity_padding_length, kg):
    """Build the model input string for a natural-language question.

    Args:
        input_str: The question text.
        lang: Language code (e.g. "en"), converted with ``Language(lang)``.
        linguistic_context: If truthy, start from the question string enriched
            with linguistic context (padded using ``question_padding_length``).
            Otherwise start from the plain question string.
        entity_knowledge: If truthy, recognize entities for ``kg`` and add them.
            If falsy, an empty entity list is passed to ``add_entity_knowledge``.
        question_padding_length: Padding length for the linguistic-context part.
        entity_padding_length: Padding length passed to
            ``Question.add_entity_knowledge``.
        kg: ``Knowledge_graph.Wikidata`` or ``Knowledge_graph.DBpedia``.

    Returns:
        The processed question string to feed to the model.

    Raises:
        ValueError: If entity knowledge is requested for a ``kg`` other than
            Wikidata or DBpedia.
    """
    language = Language(lang)
    question = Question(input_str, language)

    if linguistic_context:
        question_string = question.get_question_string_with_lingtuistic_context(question_padding_length)
    else:
        question_string = question.question_string

    # Gate the entity lookup on the `entity_knowledge` flag, not on the
    # padding length. Start from an empty list so add_entity_knowledge always
    # receives a list (as in the Wikidata "no_ner" case), never the flag's
    # boolean value.
    entities = []
    if entity_knowledge:
        if kg == Knowledge_graph.Wikidata:
            entities = get_wikidata_entities(question)
        elif kg == Knowledge_graph.DBpedia:
            entities = get_dbpedia_entities(question)
        else:
            raise ValueError(f"Entity knowledge is not supported for knowledge graph {kg!r}")

    return question.add_entity_knowledge(question_string, entities, entity_padding_length)


In [6]:
# Example question to translate into SPARQL.
lang = "en"
question_str = "What is the time zone of Salt Lake City?"

# Build the model input using the settings from the configuration cell.
processed_question_string = prep_input(
    input_str=question_str,
    lang=lang,
    linguistic_context=linguistic_context,
    entity_knowledge=entity_knowledge,
    question_padding_length=question_padding_length,
    entity_padding_length=entity_padding_length,
    kg=kg,
)
print(processed_question_string)

What is the time zone of Salt Lake City? <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> PRON AUX DET NOUN NOUN ADP PROPN PROPN PROPN PUNCT <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> attr ROOT det compound nsubj prep compound compound pobj punct <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> 2 1 3 3 2 3 6 5 4 2 <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> wd_Q23337 <pad> <pad> <pad> <pad>


In [7]:

# Predict a SPARQL query from the processed question string.
pred_sparql = sparql_model.predict_sparql(processed_question_string)
# Wrap the prediction in a Query for the selected knowledge graph, then print
# its `sparql` text.
query = Query(pred_sparql, kg)
print(query.sparql)

SELECT DISTINCT  ?uri WHERE  {  wd:Q23337 wdt:P421  ?uri  }  
