In [1]:
# https://towardsdatascience.com/calculating-string-similarity-in-python-276e18a7d33a
# https://pypi.org/project/fuzzywuzzy/
# https://www.adamsmith.haus/python/answers/how-to-find-a-similarity-metric-between-two-strings-in-python


from transformers import T5TokenizerFast, T5ForConditionalGeneration 
from transformers import Trainer

import os

os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3" 
import pandas as pd
import numpy as np
import torch
import torchvision
import Levenshtein
from fuzzywuzzy import fuzz


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# device = torch.device("cuda:2") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Create torch dataset
# https://towardsdatascience.com/fine-tuning-pretrained-nlp-models-with-huggingfaces-trainer-6326a4456e7b

class Dataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels=None):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        if self.labels:
            item["labels"] = torch.tensor(self.labels['input_ids'][idx])
        return item

    def __len__(self):
        return len(self.encodings["input_ids"])

In [3]:
def get_levenshtein_dis(str_1, str_2):
    return Levenshtein.distance(str_1, str_2)

In [4]:
def get_fuzzy_ration(str_1, str_2):
    return fuzz.ratio(str_1, str_2)

In [5]:
max_source_length = 1024
max_target_length = 128

In [6]:
tokenizer = T5TokenizerFast.from_pretrained("t5-base")
model_path = "./NEL_model_normal_shuffled/checkpoint-8500"
model = T5ForConditionalGeneration.from_pretrained(model_path).to(device)

## testing with the shuffled data

In [7]:
test_data = pd.read_csv('./2-NEL_Data/2-csv_format_2/test_data_shuffled.csv')
test_data = test_data.sample(frac=1, random_state=1)
test_data

Unnamed: 0,qid,question,entity,wikidata_reply
5758,Q1891138,**what kind of music does the go-betweens perf...,**the go-betweens**,"**[[Q1891138, The Go-Betweens, Australian indi..."
5975,Q838368,**Name a black-and-white film.**,**black-and-white**,"**[[Q2407593, White and Black in chess, chess ..."
1984,Q2659421,**is evilenko a animation or crime fiction film**,**evilenko**,"**[[Q2659421, Evilenko, 2004 film directed by ..."
5832,Q557632,**Which English band is on the record label de...,**decca**,"**[[Q3338174, Echinochloa colona, species of p..."
1945,Q9036,**What company was named after nikola tesla?**,**nikola tesla**,"**[[Q16085077, Nicola Tesla, 1977 TV series], ..."
...,...,...,...,...
2895,Q7913657,**what type of music is van she?**,**van she**,"**[[Q945687, Van She, Electropop band from Syd..."
7813,Q3134980,**what is located in the mountain time zone?**,**mountain time zone**,"**[[Q3134980, Mountain Time Zone, time zone of..."
905,Q4952351,**Which gender is boyd irwin?**,**boyd irwin**,"**[[Q4952351, Boyd Irwin, actor (1880-1957)]]**"
5192,Q33999,**Name an actor.**,**actor**,"**[[Q421946, actor, person performing an actio..."


In [8]:
input_text = list(test_data['question'] + ',' + test_data['entity'] + ',' + test_data['wikidata_reply'])
input_text[0]

'**what kind of music does the go-betweens perform**,**the go-betweens**,**[[Q1891138, The Go-Betweens, Australian indie rock band]]**'

In [9]:
target_text = list(test_data['qid'])
target_text[0]

'Q1891138'

In [10]:
X_test_tokenized = tokenizer(['nel: ' + sequence for sequence in input_text], 
                              padding=True, 
                              truncation=True, 
                              max_length=max_source_length)

y_test_tokenized = tokenizer(target_text, 
                              padding=True, 
                              truncation=True, 
                              max_length=max_target_length)

print(len(test_data))

9921


In [11]:
from transformers import Seq2SeqTrainingArguments

test_args = Seq2SeqTrainingArguments(
    "test_trainer",
    per_device_train_batch_size = 2,
    per_device_eval_batch_size = 2,
    eval_accumulation_steps = 50,  # VIP
    predict_with_generate=True
)

In [12]:
from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    model=model, 
    args=test_args
#     train_dataset= train_dataset, 
#     eval_dataset= eval_dataset
)

In [13]:
test_dataset = Dataset(X_test_tokenized, y_test_tokenized) 

In [14]:
#     # Define test trainer
# test_trainer = Seq2SeqTrainer(model, args=test_args, tokenizer=tokenizer)

#     # Make prediction
# raw_pred = test_trainer.predict(test_dataset)''

In [15]:
# predicitons = tokenizer.batch_decode(raw_pred[1], skip_special_tokens=True)
# predicitons'

In [16]:
tokens = []
for text in input_text:
    tokens.append(tokenizer('nel: ' + text, return_tensors="pt", padding=True, truncation=True, max_length=512).input_ids)

In [17]:
results = []
model = model.to(device)

for token in tokens:
    results.append(model.generate(token.to(device))) 

In [18]:
final_ouput_2 = []
for result in results:
    final_ouput_2.append(tokenizer.decode(result[0], skip_special_tokens=True))

In [19]:
# counter = 0
# for i in range(len(final_ouput_2)):
#     if final_ouput_2[i] != predicitons[i]:
#         print(target_text[i] + "    " + final_ouput_2[i] + "    " + predicitons[i])
#         counter += 1
# counter

In [20]:
# counter = 0
# for i in range(len(final_ouput_2)):
#     if predicitons[i] != target_text[i]:
#         counter += 1
# 1- counter/len(predicitons)

In [21]:
counter = 0
for i in range(len(final_ouput_2)):
    if final_ouput_2[i] != target_text[i]:
        counter += 1
1- counter/len(final_ouput_2)

0.892450357826832

In [22]:
wikidata_data = list(test_data['wikidata_reply'])

In [23]:
import re
entities = []
for wiki_reply in wikidata_data:
    entities.append(re.findall(r'\[+(Q.*?),', wiki_reply))
len(entities)

9921

In [24]:
final_copy = final_ouput_2.copy()
for i in range(len(final_copy)):
    min_value = 10000
    value_to_add = ''
    for entity in entities[i]:
        if get_levenshtein_dis(entity, final_ouput_2[i]) < min_value:
            min_value =  get_levenshtein_dis(entity, final_ouput_2[i])
            value_to_add = entity
    final_copy[i] = value_to_add

In [25]:
counter = 0
for i in range(len(final_copy)):
    if final_copy[i] != target_text[i]:
        counter += 1
1- counter/len(final_copy)

0.8917447837919564

In [28]:
final_copy_3

NameError: name 'final_copy_3' is not defined

In [29]:
final_copy_3 = final_ouput_2.copy()
for i in range(len(final_copy)):
    diff = 10000
    value_to_add = ''
    for entity in entities[i]:
        try:
            if abs(int(entity[1:]) -  int(final_ouput_2[i][1:])) < diff:
                diff =  abs(int(entity[1:]) -  int(final_ouput_2[i][1:]))
                value_to_add = entity
        except:
            pass
    final_copy_3[i] = value_to_add

In [30]:
counter = 0
for i in range(len(final_copy_3)):
    if final_copy_3[i] != target_text[i]:
        counter += 1
1- counter/len(final_copy_3)

0.8941045830809611

In [31]:
final_copy_2 = final_ouput_2.copy()
for i in range(len(final_copy_2)):
    max_value = -10000
    value_to_add = ''
    for entity in entities[i]:
        if get_levenshtein_dis(entity, final_ouput_2[i]) > max_value:
            max_value =  get_fuzzy_ration(entity, final_ouput_2[i])
            value_to_add = entity
    final_copy_2[i] = value_to_add
    print(max_value)

100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
36
100
100
100
100
100
35
100
24
100
100
100
100
57
100
100
100
100
100
50
100
100
100
36
100
100
100
100
100
100
100
100
100
36
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
57
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
53
100
100
100
100
100
100
100
100
100
100
43
100
100
100
40
100
100
100
100
100
50
50
100
100
100
46
100
100
100
100
100
100
100
100
100
100
29
100
100
100
100
100
47
100
100
40
100
100
100
100
62
100
100
36
100
50
100
50
25
100
100
59
33
100
100
100
100
100
100
100
100
100
100
40
100
100
25
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
40
100
100
100
100
100
100
100
100
100
71
100
35
100
100
100
100
27
100
100
88
100
100
100
100
100
100
100
100
100
100
100
100
100
100

100
100
100
100
100
100
100
100
100
57
100
46
100
100
27
100
100
100
100
100
100
100
40
100
88
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
40
100
29
47
100
38
100
62
100
100
100
100
93
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
35
100
100
100
100
100
100
100
100
88
100
100
100
100
100
100
100
100
100
43
100
100
100
40
100
100
29
100
100
100
100
100
100
100
100
100
47
100
100
100
100
100
100
100
27
43
100
100
100
46
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
31
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
50
100
100
100
100
100
100
100
100
31
100
100
100
100
100
100
100
100
100
100
100
100
100
100
25
100
100
43
40
100
88
47
100
100
100
100
100
100
100
100
100
100
100
100
100
100
27
100
27
100
100
29
35
100
100
100
100
46
100
22
24
57
25
100
100
100
59
100
100
100
59
100
100
100
47
100
100
100
38
100
100
100
100
100
100
27
53
40
100
33
100
47
35
100
100
100
100
100
100
100
100
100
100

67
100
100
100
100
35
100
100
100
100
100
100
18
100
43
100
100
62
100
47
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
62
100
33
100
100
100
100
100
100
100
100
100
100
100
50
100
100
100
100
100
100
100
100
100
100
100
100
100
62
100
100
43
100
100
40
100
100
100
100
100
100
100
100
100
100
100
100
35
100
100
100
100
100
47
40
100
100
100
100
25
57
100
100
100
100
38
100
100
100
42
100
86
88
100
100
100
100
100
100
100
31
100
100
47
100
100
100
43
100
100
100
57
100
100
100
100
100
100
100
100
100
100
100
100
100
100
88
25
100
100
100
100
100
50
100
100
100
100
100
100
100
38
62
100
100
100
100
38
40
100
100
100
100
100
27
100
100
100
100
100
100
100
100
100
100
100
100
100
100
25
100
38
100
100
40
47
100
100
100
100
100
100
100
100
43
100
100
47
100
100
25
100
100
100
100
27
100
100
100
100
100
100
88
100
100
100
100
100
100
100
25
100
100
100
100
100
40
100
100
100
100
100
100
100
100
100
100
100
100
100
53
100
100
100
100
100
53
100
100
100
100
100
33

33
100
100
100
100
35
100
100
44
100
100
100
100
100
100
56
100
100
100
31
100
100
40
100
100
100
100
82
100
100
100
100
100
100
100
31
100
100
100
100
100
100
100
38
100
100
100
100
100
44
100
100
40
100
40
100
100
100
100
100
100
100
100
100
100
100
44
100
100
100
100
100
25
62
100
100
100
100
100
100
100
100
100
100
100
100
100
40
100
38
100
100
100
100
100
100
59
40
100
100
100
59
100
100
100
100
100
100
100
29
13
100
33
100
100
100
100
100
100
100
100
100
38
100
100
100
100
100
100
100
100
31
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
24
100
35
100
100
100
40
100
100
100
100
100
100
100
100
100
100
47
100
40
100
40
100
100
100
100
100
27
100
100
100
100
15
100
35
38
100
100
33
29
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
38
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
43
40
33
27
100
100
40
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
15
53
100
100
100
100
100
100
100
100
100
100
100
1

100
100
100
100
100
100
100
100
38
100
100
100
100
13
100
100
100
100
100
100
44
100
100
100
62
100
50
100
100
35
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
18
100
100
43
100
100
100
100
38
100
100
100
57
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
40
100
53
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
100
35
100
100
100
100
25
100
53
100
100
100
100
100
100
100
88
100
100
100
100
100
100
38
100
100
44
100
100
100
100
40
100
100
100
100
94
100
100
100
100
100
100
100
100
100
100
53
50
100
27
100
100
100
100
100
100
100
27
100
100
100
100
100
100
100
100
100
100
35
100
33
100
100
100
100
100
100
100
100
100
62
100
47
24
100
100
100
100
100
100
33
100
100
100
100
38
100
100
100
100
100
100
100
100
100
100
100
46
50
100
100
100
100
100
35
100
100
100
100
100
100
100
93
100
100
100
100
100
100
100
100
100
100
100
24
100
100
100
100
38
100
100
100
100
100
1

In [32]:
counter = 0
for i in range(len(final_copy_2)):
    if final_copy_2[i] != target_text[i]:
        counter += 1
1- counter/len(final_copy_2)

0.8339390268524127

## testing with the unshuffled data

In [27]:
test_data = pd.read_csv('./2-NEL_Data/2-csv_format_2/test_data.csv')
test_data = test_data.sample(frac=1, random_state=1)
test_data

Unnamed: 0,qid,question,entity,wikidata_reply
8687,Q183862,**what album has metalcore music?**,**metalcore**,"**[[Q183862, metalcore, fusion genre of heavy ..."
7972,Q1641839,**Name an experimental rock album.**,**experimental rock**,"**[[Q1641839, experimental rock, type of music..."
1628,Q17285413,**Where is joy sengupta from?**,**joy sengupta**,"**[[Q17285413, Joy Sengupta, Indian actor and ..."
8699,Q513674,**What is the sex of matthew breeze?**,**matthew breeze**,"**[[Q513674, Matthew Breeze, Australian soccer..."
5648,Q7333580,**what country is tuxbury pond in**,**tuxbury pond**,"**[[Q7333580, Tuxbury Pond, lake in Rockingham..."
...,...,...,...,...
2895,Q200092,**What is the name of a horror movie on netflix**,**horror movie**,"**[[Q200092, horror film, film genre], [Q59051..."
7813,Q7038198,**what kind of film is ninaithen vandhai?**,**ninaithen vandhai**,"**[[Q7038198, Ninaithen Vandhai, 1998 film by ..."
905,Q534599,**Where did damon knight die?**,**damon knight**,"**[[Q534599, Damon Knight, American science fi..."
5192,Q21077,**what artist is signed to warner music group?**,**warner music group**,"**[[Q21077, Warner Music Group, American multi..."


In [28]:
input_text = list(test_data['question'] + ',' + test_data['entity'] + ',' + test_data['wikidata_reply'])
input_text[0]

'**what album has metalcore music?**,**metalcore**,**[[Q183862, metalcore, fusion genre of heavy metal and hardcore punk], [Q108940567, Metalcore Superstars, album by One Morning Left], [Q4490718, melodic metalcore, subgenre of metalcore], [Q30587784, progressive metalcore, subgenre of metalcore], [Q3501147, gabber metal, fusion of gabber and metal], [Q1965804, Metalcore-bändide loend, Wikimedia list article]]**'

In [29]:
target_text = list(test_data['qid'])
target_text[0]

'Q183862'

In [30]:
X_test_tokenized = tokenizer(['nel: ' + sequence for sequence in input_text], 
                              padding=True, 
                              truncation=True, 
                              max_length=max_source_length)

y_test_tokenized = tokenizer(target_text, 
                              padding=True, 
                              truncation=True, 
                              max_length=max_target_length)

print(len(test_data))

9906


In [31]:
from transformers import Seq2SeqTrainingArguments

test_args = Seq2SeqTrainingArguments(
    "test_trainer",
    per_device_train_batch_size = 2,
    per_device_eval_batch_size = 2,
    eval_accumulation_steps = 50,  # VIP
    predict_with_generate=True
)

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


In [32]:
from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    model=model, 
    args=test_args
#     train_dataset= train_dataset, 
#     eval_dataset= eval_dataset
)

In [33]:
test_dataset = Dataset(X_test_tokenized, y_test_tokenized) 

In [34]:
tokens = []
for text in input_text:
    tokens.append(tokenizer('nel: ' + text, return_tensors="pt", padding=True, truncation=True, max_length=512).input_ids)

In [35]:
results = []
model = model.to(device)

for token in tokens:
    results.append(model.generate(token.to(device))) 

In [36]:
final_ouput_2 = []
for result in results:
    final_ouput_2.append(tokenizer.decode(result[0], skip_special_tokens=True))

In [37]:
counter = 0
for i in range(len(final_ouput_2)):
    if final_ouput_2[i] != target_text[i]:
        counter += 1
1- counter/len(final_ouput_2)

0.8921865536038764

In [38]:
wikidata_data = list(test_data['wikidata_reply'])

In [39]:
import re
entities = []
for wiki_reply in wikidata_data:
    entities.append(re.findall(r'\[+(Q.*?),', wiki_reply))
len(entities)

9906

In [40]:
final_copy = final_ouput_2.copy()
for i in range(len(final_copy)):
    min_value = 10000
    value_to_add = ''
    for entity in entities[i]:
        if get_levenshtein_dis(entity, final_ouput_2[i]) < min_value:
            min_value =  get_levenshtein_dis(entity, final_ouput_2[i])
            value_to_add = entity
    final_copy[i] = value_to_add

In [41]:
counter = 0
for i in range(len(final_copy)):
    if final_copy[i] != target_text[i]:
        counter += 1
1- counter/len(final_copy)

0.8924894003634161