In [5]:
import requests
import pickle
import torch
import pandas as pd
from datasets import Dataset 
import time

In [6]:

torch.cuda.empty_cache()
import gc
# del variables
gc.collect()
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
# device = "cpu"
device

device(type='cuda', index=0)

In [7]:
MAX_LEN = 50
OVERLAP = 10
TOP_K = 500
BATCH_SIZE = 64
DEGUB = False
THRESHOLD = 0.02

In [4]:
from transformers import pipeline

from transformers import AutoTokenizer, AutoModelForQuestionAnswering
# model_checkpoint = "Mini_LCQUAD\checkpoint-500"
model_checkpoint = "Mini_EL2\checkpoint-500"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)
question_answerer = pipeline("question-answering", model=model, tokenizer=tokenizer, handle_impossible_answer=True, batch_size=BATCH_SIZE, device=device )
model_checkpoint = "Mini_LCQUAD2\checkpoint-500"
question_answerer2 = pipeline("question-answering", model=model_checkpoint, handle_impossible_answer=True, batch_size=BATCH_SIZE, device=device )

In [9]:
def expand(text, entity_ids = []):
    response = requests.get("https://qanswer-core1.univ-st-etienne.fr/api/entitylinker" , params={'text': text, 'language': 'en', 'knowledgebase': 'wikidata'})
    input_entities = []
    for r in response.json():
        if 'uri' in r and 'http://www.wikidata.org/entity/' in r['uri']:
            id = r['uri'].replace('http://www.wikidata.org/entity/','')
            if("wd:" + str(id) in entity_ids):
                entity = True
            else:
                entity = False
            input_entities.append({ "start":r['start'], "end":r['end'], "text": r['text'], "id": id, "description": str(r["qaContext"]["disambiguation"] or ''), "entity": entity  }) 
    return input_entities

In [6]:
def ENTQA(context, inputs):
    questions = [x["text"] + " : " + x["description"] for x in inputs]
    start_time = time.time()

    df = pd.DataFrame.from_records({"question": questions, "context": context})
    dataset = Dataset.from_pandas(df)
    ans = question_answerer(dataset)
    if(DEGUB): print("Linking Time: --- %s seconds ---" % (time.time() - start_time))
    # 
    for i in range(0, len(questions)):
        ans[i]["question"] = questions[i]
        ans[i]["id"] = inputs[i]["id"]
        ans[i]["start"] = inputs[i]["start"]
        ans[i]["end"] = inputs[i]["end"]
    ans = [{"question":x["question"],
            "ans": x["answer"],
            "start": x["start"],
            "end": x["end"],
            "score": x["score"],
            "QId":  x["id"]
            } for x in ans if(x["answer"] != '')]
    return {"context": context, "results": ans} 
        

In [7]:
def segment_text(text, max_len, overlap=0):
    words = text.split()
    segments = []
    current_seg = ""
    seg_len = 0
    wordCount = 0
    while wordCount  <  len(words):
        if(seg_len < max_len):
            current_seg += " "+words[wordCount]
            seg_len += 1
            wordCount += 1
        else:
            segments.append(current_seg.strip())
            current_seg = ""
            seg_len = 0
            if(len(words)-wordCount+overlap < max_len):
                wordCount = len(words)-max_len
            elif(wordCount-overlap > 0):
                wordCount -= overlap

    segments.append(current_seg.strip())
    return segments

In [285]:
def dropOverlapping(entities):
    to_drop = set()
    if(DEGUB): print("Total Entities: ", len(entities))
    # removing duplicate and ovelapping candidate
    for i in  range(0, len(entities)):
            if (i in to_drop):
                continue
            if(entities[i]["score"]*100 < THRESHOLD):
                to_drop.add(i)
                continue
            for j in  range(i, len(entities)):
                if (j in to_drop):
                    continue
                if(entities[i]["start"] < entities[j]["end"]) and (entities[i]["end"] > entities[j]["start"]):
                    if(entities[i]["score"] > entities[j]["score"]):   
                        to_drop.add(j)
                    elif(entities[i]["score"] < entities[j]["score"]):
                        to_drop.add(i)
                if(entities[i]["start"] == entities[j]["start"]) and (entities[i]["end"] == entities[j]["end"]):
                    if(entities[i]["score"] > entities[j]["score"]):   
                        to_drop.add(j)
                    elif(entities[i]["score"] < entities[j]["score"]):
                        to_drop.add(i)
    to_drop = list(to_drop)
    to_drop.sort()
    for i in range(0,len(to_drop)):
        del entities[to_drop[i]-i]
    return entities

def mergeResults(results):
    results_merged  = { "text": "",
                    "entities": []}
    results_merged["text"] = results[0]["context"]
    # results_merged["entities"].extend(mapIndexes(0, 0, results[0]["results"]))
    results_merged["entities"].extend(results[0]["results"]) 
    for i in range(1, len(results)-1):
        len_overlap = len(" ".join(results[i]["context"].split()[:OVERLAP]))
        len_seg = len(results_merged["text"])
        results_merged["text"] = results_merged["text"] + results[i]["context"][OVERLAP:]
        # results_merged["entities"].extend(mapIndexes(len_seg, len_overlap, results[i]["results"]))
        results_merged["entities"].extend(results[i]["results"])

    index = results[-1]["context"].index(results[-2]["context"][-20:]) + 20
    len_seg = len(results_merged["text"])
    results_merged["text"] = results_merged["text"] + results[-1]["context"][index:]
    # results_merged["entities"].extend(mapIndexes(len_seg, index, results[-1]["results"]))
    results_merged["entities"].extend(results[-1]["results"])

    return dropOverlapping(results_merged["entities"])

In [268]:
def loopENTQA(text):
    answers = []
    if(DEGUB): print("Token Counts: --- %s tokens ---" % (len(text.split())))
    start_time = time.time()
    # Expand the Text and fetch the candidates
    inputs = expand(text)
    if(DEGUB): print("Expending Time: --- %s seconds ---" % (time.time() - start_time))
    if(DEGUB): print("Candidate Entities (whole Text):", len(inputs))
    start_time = time.time()
    segments = segment_text(text, MAX_LEN, OVERLAP)
    # segment text with overlapping
    if(DEGUB): print("Segmentation Time: --- %s seconds ---" % (time.time() - start_time))
    previous_word_count = -1
    for i in range(0, len(segments)):
        candidates = []
        if(i == len(segments)-1):
            # select candidates in the specific segment
            candidates = [x for x in inputs if(x["end"] >= len(text.split()[-MAX_LEN:]) )]
        else:
            word_count = len(segments[i].split())
            candidates = [x for x in inputs if((x["start"] > previous_word_count)  and (x["end"] <= previous_word_count + word_count))]
            previous_word_count += word_count -1
        
        if(len(segments) == 1):
            candidates = inputs
        if(DEGUB): print("Candidate Entities (segment Text):", len(candidates))
        # received the positive candidates
        answers.append(ENTQA(segments[i], candidates))
    start_time = time.time()
    results = {"text":  text}
    if(len(segments)> 1):
        # merging the segments
        results["entities"] = mergeResults(answers)
    else:
        results["entities"] =  dropOverlapping(answers[-1]["results"])
    
    if(DEGUB): print("Merging Time: --- %s seconds ---" % (time.time() - start_time))
    return results

In [245]:
DEGUB = True
# text = "The project “RE-ODRA – social and economic activation of post-factory areas in Nowa Sól poland is related to revitalisation of the industrial areas of the former Odra factory together with the immediate surroundings, as part of Delta airline the tasks set out in the Local Programme for the Regeneration of the state of New York for the years 2016-2023.  Which female actress is the voice over on South Park American animated sitcom 48.67 and is employed as a singer? What periodical literature does Delta Air Lines use as a moutpiece? Delta Air Lines was the first airline to use Boeing 747. Delta went bankcrupt after the financial crisis in 2005 in state of New York."
# text = "Nowa Sól town on the Oder River in Lubusz Voivodeship, western Poland"
text = """Borgomasino is a comune (municipality) in the Metropolitan City of Turin in the Italian region Piedmont, located about 40 kilometres (25 mi) northeast of Turin.
# Among the sites are the Parish Church of Santissimo Salvatore designed by Bernardo Vittone and the castle."""
# text = """Leonardo di ser Piero da Vinci [b] (15 April 1452 – 2 May 1519) was an Italian polymath of the High Renaissance who was active as a painter, draughtsman, engineer, scientist, theorist, sculptor, and architect.[3] While his fame initially rested on his achievements as a painter, he also became known for his notebooks, in which he made drawings and notes on a variety of subjects, including anatomy, astronomy, botany, cartography, painting, and paleontology. Leonardo is widely regarded to have been a genius who epitomized the Renaissance humanist ideal,[4] and his collective works comprise a contribution to later generations of artists matched only by that of his younger contemporary, Michelangelo."""
# text = "Leonardo di ser Piero da Vinci"
# text = "draughtsman engineer scientist theorist"

In [258]:
model_checkpoint = "Mini_EL2\checkpoint-500"
question_answerer = pipeline("question-answering", model=model_checkpoint, handle_impossible_answer=False, batch_size=BATCH_SIZE, device=device )
THRESHOLD = 50
results = loopENTQA(text)
results["entities"] = sorted(results["entities"], key=lambda d: d['start']) 
display_text = []
text_splitted =  results["text"].replace('-', ' ').replace(',', ' ').replace('.', ' ').split()
last_index = 0
for i in results["entities"]:
    display_text.append(" " + " ".join(text_splitted[last_index:i["start"]]) + " ")
    display_text.append((" ".join(text_splitted[i["start"]:i["end"]]), i["question"].split(":")[-1] + " " + "{:.2f}".format(i["score"]*100) , "#faa", "#000000"))
    last_index = i["end"]
display_text.append(" " + " ".join(text_splitted[last_index:])) 
display_text.append("\n")
with open('display_text.pickle', 'wb') as handle:
    pickle.dump(display_text, handle, protocol=pickle.HIGHEST_PROTOCOL)

Token Counts: --- 43 tokens ---
Expending Time: --- 0.7468669414520264 seconds ---
Candidate Entities (whole Text): 116
Segmentation Time: --- 0.0 seconds ---
Candidate Entities (segment Text): 116
Linking Time: --- 0.41979360580444336 seconds ---
Merging Time: --- 0.0 seconds ---


In [269]:
model_checkpoint = "Mini_LCQUAD2"
question_answerer = pipeline("question-answering", model=model_checkpoint, handle_impossible_answer=False, batch_size=BATCH_SIZE, device=device )
THRESHOLD = 20
results = loopENTQA(text)
results["entities"] = sorted(results["entities"], key=lambda d: d['start']) 
display_text = []
text_splitted =  results["text"].replace('-', ' ').replace(',', ' ').replace('.', ' ').split()
last_index = 0
for i in results["entities"]:
    display_text.append(" " + " ".join(text_splitted[last_index:i["start"]]) + " ")
    display_text.append((" ".join(text_splitted[i["start"]:i["end"]]), i["question"].split(":")[-1] + " " + "{:.2f}".format(i["score"]*100) , "#faa", "#000000"))
    last_index = i["end"]
display_text.append(" " + " ".join(text_splitted[last_index:])) 
display_text.append("\n")
with open('display_text.pickle', 'wb') as handle:
    pickle.dump(display_text, handle, protocol=pickle.HIGHEST_PROTOCOL)

Token Counts: --- 43 tokens ---
Expending Time: --- 0.6230158805847168 seconds ---
Candidate Entities (whole Text): 116
Segmentation Time: --- 0.0 seconds ---
Candidate Entities (segment Text): 116
Linking Time: --- 0.3925018310546875 seconds ---
Merging Time: --- 0.001026153564453125 seconds ---
