In [1]:
# https://towardsdatascience.com/calculating-string-similarity-in-python-276e18a7d33a
# https://pypi.org/project/fuzzywuzzy/
# https://www.adamsmith.haus/python/answers/how-to-find-a-similarity-metric-between-two-strings-in-python


from transformers import T5TokenizerFast, T5ForConditionalGeneration 
from transformers import Trainer

import os

os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3" 
import pandas as pd
import numpy as np
import torch
import torchvision
import Levenshtein
from fuzzywuzzy import fuzz


# device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device("cuda:2") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Create torch dataset
# https://towardsdatascience.com/fine-tuning-pretrained-nlp-models-with-huggingfaces-trainer-6326a4456e7b

class Dataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels=None):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        if self.labels:
            item["labels"] = torch.tensor(self.labels['input_ids'][idx])
        return item

    def __len__(self):
        return len(self.encodings["input_ids"])

In [3]:
def get_levenshtein_dis(str_1, str_2):
    return Levenshtein.distance(str_1, str_2)

In [4]:
def get_fuzzy_ration(str_1, str_2):
    return fuzz.ratio(str_1, str_2)

In [5]:
max_source_length = 1024
max_target_length = 128

In [6]:
tokenizer = T5TokenizerFast.from_pretrained("t5-base")
model_path = "./NEL_model_normal/checkpoint-5500"
model = T5ForConditionalGeneration.from_pretrained(model_path).to(device)

In [7]:
test_data = pd.read_csv('./2-NEL_Data/2-csv_format_2/test_data.csv')
test_data = test_data.sample(frac=1, random_state=1)
test_data

Unnamed: 0,qid,question,entity,wikidata_reply
8687,Q183862,**what album has metalcore music?**,**metalcore**,"**[[Q183862, metalcore, fusion genre of heavy ..."
7972,Q1641839,**Name an experimental rock album.**,**experimental rock**,"**[[Q1641839, experimental rock, type of music..."
1628,Q17285413,**Where is joy sengupta from?**,**joy sengupta**,"**[[Q17285413, Joy Sengupta, Indian actor and ..."
8699,Q513674,**What is the sex of matthew breeze?**,**matthew breeze**,"**[[Q513674, Matthew Breeze, Australian soccer..."
5648,Q7333580,**what country is tuxbury pond in**,**tuxbury pond**,"**[[Q7333580, Tuxbury Pond, lake in Rockingham..."
...,...,...,...,...
2895,Q200092,**What is the name of a horror movie on netflix**,**horror movie**,"**[[Q200092, horror film, film genre], [Q59051..."
7813,Q7038198,**what kind of film is ninaithen vandhai?**,**ninaithen vandhai**,"**[[Q7038198, Ninaithen Vandhai, 1998 film by ..."
905,Q534599,**Where did damon knight die?**,**damon knight**,"**[[Q534599, Damon Knight, American science fi..."
5192,Q21077,**what artist is signed to warner music group?**,**warner music group**,"**[[Q21077, Warner Music Group, American multi..."


In [8]:
input_text = list(test_data['question'] + ',' + test_data['entity'] + ',' + test_data['wikidata_reply'])
input_text[0]

'**what album has metalcore music?**,**metalcore**,**[[Q183862, metalcore, fusion genre of heavy metal and hardcore punk], [Q108940567, Metalcore Superstars, album by One Morning Left], [Q4490718, melodic metalcore, subgenre of metalcore], [Q30587784, progressive metalcore, subgenre of metalcore], [Q3501147, gabber metal, fusion of gabber and metal], [Q1965804, Metalcore-bändide loend, Wikimedia list article]]**'

In [9]:
target_text = list(test_data['qid'])
target_text[0]

'Q183862'

In [10]:
X_test_tokenized = tokenizer(['nel: ' + sequence for sequence in input_text], 
                              padding=True, 
                              truncation=True, 
                              max_length=max_source_length)

y_test_tokenized = tokenizer(target_text, 
                              padding=True, 
                              truncation=True, 
                              max_length=max_target_length)

print(len(test_data))

9906


In [11]:
from transformers import Seq2SeqTrainingArguments

test_args = Seq2SeqTrainingArguments(
    "test_trainer",
    per_device_train_batch_size = 2,
    per_device_eval_batch_size = 2,
    eval_accumulation_steps = 50,  # VIP
    predict_with_generate=True
)

In [12]:
from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    model=model, 
    args=test_args
#     train_dataset= train_dataset, 
#     eval_dataset= eval_dataset
)

In [13]:
test_dataset = Dataset(X_test_tokenized, y_test_tokenized) 

In [14]:
    # Define test trainer
test_trainer = Seq2SeqTrainer(model, args=test_args, tokenizer=tokenizer)

    # Make prediction
raw_pred = test_trainer.predict(test_dataset)

***** Running Prediction *****
  Num examples = 9906
  Batch size = 6


In [15]:
predicitons = tokenizer.batch_decode(raw_pred[1], skip_special_tokens=True)
predicitons

['Q183862',
 'Q1641839',
 'Q17285413',
 'Q513674',
 'Q7333580',
 'Q11399',
 'Q151679',
 'Q1751177',
 'Q51023',
 'Q772463',
 'Q186335',
 'Q2471197',
 'Q181684',
 'Q316165',
 'Q1995969',
 'Q860626',
 'Q828322',
 'Q615874',
 'Q41322',
 'Q2619985',
 'Q254848',
 'Q104718',
 'Q371670',
 'Q7027400',
 'Q7922251',
 'Q1024044',
 'Q5066318',
 'Q5532596',
 'Q7317647',
 'Q503966',
 'Q4543607',
 'Q349069',
 'Q93196',
 'Q13942157',
 'Q7301356',
 'Q2412723',
 'Q2060077',
 'Q882739',
 'Q4712933',
 'Q710548',
 'Q2255030',
 'Q928151',
 'Q45981',
 'Q2450261',
 'Q975549',
 'Q4928797',
 'Q3161867',
 'Q464393',
 'Q6545041',
 'Q714331',
 'Q706908',
 'Q1166257',
 'Q22',
 'Q1718539',
 'Q6938651',
 'Q130232',
 'Q4237536',
 'Q41',
 'Q1942884',
 'Q1509872',
 'Q458658',
 'Q4993258',
 'Q170790',
 'Q5070950',
 'Q2074562',
 'Q183504',
 'Q1158979',
 'Q10059',
 'Q784496',
 'Q5097389',
 'Q2783875',
 'Q484344',
 'Q13375',
 'Q4151791',
 'Q4541898',
 'Q3500501',
 'Q2918159',
 'Q237548',
 'Q2132',
 'Q182015',
 'Q2526255',
 '

In [16]:
tokens = []
for text in input_text:
    tokens.append(tokenizer('nel: ' + text, return_tensors="pt", padding=True, truncation=True, max_length=512).input_ids)

In [17]:
results = []
model = model.to(device)

for token in tokens:
    results.append(model.generate(token.to(device))) 

In [18]:
final_ouput_2 = []
for result in results:
    final_ouput_2.append(tokenizer.decode(result[0], skip_special_tokens=True))

In [61]:
counter = 0
for i in range(len(final_ouput_2)):
    if final_ouput_2[i] != predicitons[i]:
        print(target_text[i] + "    " + final_ouput_2[i] + "    " + predicitons[i])
        counter += 1
counter

Q51023    Q3109242    Q51023
Q254848    Q18574071    Q254848
Q13942157    Q477429    Q13942157
Q882739    Q309489    Q882739
Q4928797    Q32825    Q4928797
Q1942884    Q812906    Q1942884
Q784496    Q2132    Q784496
Q2918159    Q63614867    Q2918159
Q6208167    Q465897    Q6208167
Q1208034    Q61243    Q1208034
Q4772709    Q4772707    Q4772709
Q2368037    Q55605990    Q2368037
Q17124728    Q48244    Q17124728
Q7612815    Q7612817    Q7612815
Q740939    Q3217141    Q740939
Q937005    Q7727936    Q937005
Q974046    Q1214631    Q974046
Q6209000    Q6208996    Q6209000
Q1567731    Q638531    Q1567731
Q5597461    Q1645566    Q5597461
Q1337034    Q476544    Q1337034
Q1282939    Q79642    Q1282939
Q100255    Q63071    Q100255
Q1549904    Q1549104    Q1549904
Q468932    Q770512    Q468932
Q3101663    Q1507656    Q3101663
Q989495    Q16992512    Q989495
Q986183    Q1393    Q986183
Q3259696    Q5123562    Q3259696
Q218101    Q615292    Q218101
Q6257932    Q6257950    Q6257932
Q8020130    Q802012

1088

In [62]:
counter = 0
for i in range(len(final_ouput_2)):
    if predicitons[i] != target_text[i]:
        counter += 1
1- counter/len(predicitons)

1.0

In [63]:
counter = 0
for i in range(len(final_ouput_2)):
    if final_ouput_2[i] != target_text[i]:
        counter += 1
1- counter/len(final_ouput_2)

0.8901675752069453

In [64]:
wikidata_data = list(test_data['wikidata_reply'])

In [65]:
import re
entities = []
for wiki_reply in wikidata_data:
    entities.append(re.findall(r'\[+(Q.*?),', wiki_reply))
len(entities)

9906

In [66]:
final_copy = final_ouput_2.copy()
for i in range(len(final_copy)):
    min_value = 10000
    value_to_add = ''
    for entity in entities[i]:
        if get_levenshtein_dis(entity, final_ouput_2[i]) < min_value:
            min_value =  get_levenshtein_dis(entity, final_ouput_2[i])
            value_to_add = entity
    final_copy[i] = value_to_add

In [67]:
counter = 0
for i in range(len(final_copy)):
    if final_copy[i] != target_text[i]:
        counter += 1
1- counter/len(final_copy)

0.8929941449626488

In [68]:
final_copy_3

['Q183862',
 'Q1641839',
 'Q17285413',
 'Q513674',
 'Q7333580',
 'Q11399',
 'Q151679',
 'Q1751177',
 'Q3109242',
 'Q772463',
 'Q186335',
 'Q2471197',
 'Q181684',
 'Q316165',
 'Q1995969',
 'Q860626',
 'Q828322',
 'Q615874',
 'Q41322',
 'Q2619985',
 'Q18574071',
 'Q104718',
 'Q371670',
 'Q7027400',
 'Q7922251',
 'Q1024044',
 'Q5066318',
 'Q5532596',
 'Q7317647',
 'Q503966',
 'Q4543607',
 'Q349069',
 'Q93196',
 'Q477429',
 'Q7301356',
 'Q2412723',
 'Q2060077',
 'Q309489',
 'Q4712933',
 'Q710548',
 'Q2255030',
 'Q928151',
 'Q45981',
 'Q2450261',
 'Q975549',
 'Q32825',
 'Q3161867',
 'Q464393',
 'Q6545041',
 'Q714331',
 'Q706908',
 'Q1166257',
 'Q22',
 'Q1718539',
 'Q6938651',
 'Q130232',
 'Q4237536',
 'Q41',
 'Q812906',
 'Q1509872',
 'Q458658',
 'Q4993258',
 'Q170790',
 'Q5070950',
 'Q2074562',
 'Q183504',
 'Q1158979',
 'Q10059',
 'Q2132',
 'Q5097389',
 'Q2783875',
 'Q484344',
 'Q13375',
 'Q4151791',
 'Q4541898',
 'Q3500501',
 'Q63614867',
 'Q237548',
 'Q2132',
 'Q182015',
 'Q2526255',
 'Q2

In [72]:
final_copy_3 = final_ouput_2.copy()
for i in range(len(final_copy)):
    diff = 10000
    value_to_add = ''
    for entity in entities[i]:
        try:
            if abs(int(entity[1:]) -  int(final_ouput_2[i][1:])) < diff:
                diff =  abs(int(entity[1:]) -  int(final_ouput_2[i][1:]))
                value_to_add = entity
        except:
            pass
    final_copy_3[i] = value_to_add

In [73]:
counter = 0
for i in range(len(final_copy_3)):
    if final_copy_3[i] != target_text[i]:
        counter += 1
1- counter/len(final_copy_3)

0.8903694730466384

In [76]:
final_copy_2 = final_ouput_2.copy()
for i in range(len(final_copy_2)):
    max_value = -10000
    value_to_add = ''
    for entity in entities[i]:
        if get_levenshtein_dis(entity, final_ouput_2[i]) > max_value:
            max_value =  get_fuzzy_ration(entity, final_ouput_2[i])
            value_to_add = entity
    final_copy_2[i] = value_to_add
    print(max_value)

100
100
100
100
100
100
100
100
57
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
43
100
100
100
100
100
100
100
36
100
100
100
100
100
35
100
24
100
100
100
100
100
100
100
100
100
100
50
100
100
100
36
100
100
100
100
100
100
100
100
100
36
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
35
100
100
100
100
100
100
100
100
100
100
100
100
100
100
57
100
100
100
100
100
100
100
100
100
100
100
100
100
100
40
53
100
100
100
100
100
100
100
100
100
100
43
100
100
100
40
50
100
100
100
100
100
50
100
100
100
46
100
100
100
100
100
100
100
100
100
100
29
100
100
100
100
100
47
100
100
40
100
100
100
100
62
100
100
36
100
67
100
100
25
100
100
35
33
100
100
100
100
100
100
100
100
100
100
40
100
100
25
100
100
100
100
100
100
100
100
88
100
100
100
100
100
100
100
100
100
100
40
100
100
100
100
100
100
100
100
100
71
100
47
100
100
100
100
27
100
100
88
100
100
100
100
100
100
100
100
100
100
100
100
100
100
36

100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
40
100
100
100
100
100
100
100
100
100
100
57
100
46
100
100
31
100
100
100
100
100
100
100
40
100
88
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
40
100
31
47
46
38
100
62
100
100
100
100
88
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
50
100
100
35
100
100
100
100
100
100
100
100
88
100
100
100
100
100
100
100
100
100
100
100
100
100
40
100
27
29
100
100
100
100
100
25
100
100
100
47
100
100
100
100
100
100
100
100
43
100
100
100
43
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
31
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
50
100
100
100
100
100
100
100
100
31
100
100
100
100
100
100
100
100
100
100
100
100
100
100
25
100
100
100
40
100
88
47
100
100
100
100
100
100
100
100
100
100
100
100
100
100
27
100
27
100
100
43
35
100
100
100
100
46
100
22
24
57
25
100
100
100
47
100
100
100
59
100
100
100
47
100
100
100
38
27
100
100
100
100


100
100
100
100
100
100
100
100
100
100
43
100
100
40
100
100
100
62
100
100
100
100
100
100
100
100
35
100
100
100
100
100
47
40
100
100
100
100
25
57
100
46
100
100
100
100
100
100
42
100
100
88
100
100
100
100
100
100
100
31
100
100
47
100
100
100
100
100
100
100
57
100
100
100
100
100
100
100
100
100
100
100
100
100
100
88
25
100
100
100
100
100
50
100
100
100
100
100
100
100
38
100
100
100
100
100
38
40
100
100
100
100
100
27
100
100
100
100
100
100
100
100
100
100
100
100
100
100
25
100
38
100
100
40
47
100
100
100
100
100
100
31
100
43
100
100
47
100
100
25
100
100
100
100
27
100
100
38
100
100
100
40
100
100
100
100
100
100
100
100
100
100
100
100
100
40
100
100
100
100
100
100
100
100
100
100
100
100
100
40
100
100
100
100
100
29
100
100
100
100
100
33
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
38
100
100
100
100
100
100
43
100
100
100
100
29
100
88
100
100
29
100
100
100
100
100
100
100
50
100
44
38
100
40
100
53
100
100
40

100
100
50
38
100
100
100
100
100
100
100
53
100
100
100
100
100
100
100
100
100
100
47
100
100
100
100
100
100
100
100
100
100
100
100
22
100
100
100
27
35
15
29
100
100
100
38
100
62
100
100
100
33
100
100
100
100
35
100
100
44
100
100
100
100
100
100
56
100
100
29
31
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
31
100
100
100
100
100
100
100
38
100
100
62
100
100
100
100
100
40
100
40
100
100
100
100
100
100
100
100
100
100
100
44
100
100
100
100
100
25
62
100
100
100
100
100
100
100
100
100
100
100
100
100
43
100
38
100
100
100
100
100
100
59
40
100
100
100
59
100
100
100
100
57
100
100
40
53
100
100
100
100
100
100
100
100
100
100
100
38
100
100
100
100
100
100
100
100
31
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
24
100
35
100
100
100
40
100
100
100
100
100
100
100
100
100
100
47
100
53
100
40
100
100
100
100
100
27
100
100
100
100
15
100
35
38
100
100
100
29
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
38
100
100

100
100
100
100
100
100
100
100
100
100
38
100
100
100
100
13
100
100
100
100
100
100
100
100
100
100
62
100
50
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
18
100
100
43
100
100
100
100
38
100
100
100
100
100
100
88
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
88
100
100
100
100
100
100
100
100
40
100
57
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
25
100
53
100
100
100
100
100
100
100
88
100
100
100
100
100
100
100
100
100
44
100
100
100
100
40
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
50
100
100
100
100
100
100
100
100
100
40
100
100
100
100
100
100
100
100
100
100
35
100
32
100
100
100
100
100
100
100
100
100
62
100
47
35
100
100
100
100
100
100
33
47
100
100
100
38
100
100
100
100
100
100
100
100
100
100
100
46
50
100
100
100
100
100
35
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
24
100
100
100
100
38
100
100

In [75]:
counter = 0
for i in range(len(final_copy_2)):
    if final_copy_2[i] != target_text[i]:
        counter += 1
1- counter/len(final_copy_2)

0.8339390268524127