In [1]:
from utils_metrics import get_entities_bio, f1_score, classification_report
from transformers import BartForConditionalGeneration, BartTokenizer, BartConfig
import torch
import time
import math
import pandas as pd
import evaluate
metric = evaluate.load("seqeval")

In [2]:
class InputExample():
    def __init__(self, words, labels):
        self.words = words
        self.labels = labels

def template_entity(words, input_TXT, start):
    vowels = ('a','e','i','o','u','A','E','I','O','U')
    LABELS=['AerospaceManufacturer','AnatomicalStructure', 'Artist', 'ArtWork','Athlete',
            'CarManufacturer', 'Cleric', 'Clothing',
            'Disease','Drink',
            'Facility','Food',
            'HumanSettlement',
            'MedicalProcedure', 'Medication/Vaccine', 'MusicalGRP', 'MusicalWork',
            'ORG', 'OtherLOC', 'OtherPER', 'OtherPROD',
            'Politician', 'PrivateCorp', 'PublicCorp',
            'Scientist','Software', 'SportsGRP', 'SportsManager', 'Station', 'Symptom',
            'Vehicle','VisualWork',
            'WrittenWork']
    template_list=[" is an %s entity"%(e) if e.startswith(vowels) else " is a %s entity"%(e) for e in LABELS]
    # template_list=[" belongs to %s category"%(e) for e in LABELS]
    # template_list=[" should be tagged as %s"%(e) for e in LABELS]
    entity_dict={i:e for i, e in enumerate(LABELS)}
    num_entities = len(template_list)
    
    # input text -> template
    words_length = len(words)
    words_length_list = [len(i) for i in words]
    input_TXT = [input_TXT]*(num_entities*words_length)

    input_ids = tokenizer(input_TXT, return_tensors='pt')['input_ids']
    model.to(device)
    
    temp_list = []
    for i in range(words_length):
        for j in range(len(template_list)):
            temp_list.append(words[i]+template_list[j])

    output_ids = tokenizer(temp_list, return_tensors='pt', padding=True, truncation=True)['input_ids']
    # print("Before: ",output_ids.shape)
    output_ids[:, 0] = 2
    # print("After: ",output_ids.shape)
    output_length_list = [0]*num_entities*words_length


    for i in range(len(temp_list)//num_entities):
        base_length = ((tokenizer(temp_list[i * num_entities], return_tensors='pt', padding=True, truncation=True)['input_ids']).shape)[1] - 4
        output_length_list[i*num_entities:i*num_entities+ num_entities] = [base_length]*num_entities
        output_length_list[i*num_entities+4] += 1

    score = [1]*num_entities*words_length
    with torch.no_grad():
        output = model(input_ids=input_ids.to(device), decoder_input_ids=output_ids[:, :output_ids.shape[1] - 2].to(device))[0]
        for i in range(output_ids.shape[1] - 3):
            # print(input_ids.shape)
            logits = output[:, i, :]
            logits = logits.softmax(dim=1)
            # values, predictions = logits.topk(1,dim = 1)
            logits = logits.to('cpu').numpy()
            # print(output_ids[:, i+1].item())
            for j in range(0, num_entities*words_length):
                if i < output_length_list[j]:
                    score[j] = score[j] * logits[j][int(output_ids[j][i + 1])]

    end = start+(score.index(max(score))//num_entities)
        # score_list.append(score)
    return [start, end, entity_dict[(score.index(max(score))%num_entities)]if round(max(score),4) > 0 else 'O' , round(max(score),4)] #[start_index,end_index,label,score]

def prediction(input_TXT):
    input_TXT_list = input_TXT.split(' ')

    entity_list = []
    for i in range(len(input_TXT_list)):
        words = []
        for j in range(1, min(9, len(input_TXT_list) - i + 1)):
            word = (' ').join(input_TXT_list[i:i+j])
            words.append(word)

        entity = template_entity(words, input_TXT, i) #[start_index,end_index,label,score]
        if entity[1] >= len(input_TXT_list):
            entity[1] = len(input_TXT_list)-1
        if entity[2] != 'O':
            entity_list.append(entity)
    i = 0
    if len(entity_list) > 1:
        while i < len(entity_list):
            j = i+1
            while j < len(entity_list):
                if (entity_list[i][1] < entity_list[j][0]) or (entity_list[i][0] > entity_list[j][1]):
                    j += 1
                else:
                    if entity_list[i][3] < entity_list[j][3]:
                        entity_list[i], entity_list[j] = entity_list[j], entity_list[i]
                        entity_list.pop(j)
                    else:
                        entity_list.pop(j)
            i += 1
    label_list = ['O'] * len(input_TXT_list)

    for entity in entity_list:
        label_list[entity[0]:entity[1]+1] = ["I-"+entity[2]]*(entity[1]-entity[0]+1)
        label_list[entity[0]] = "B-"+entity[2]
    return label_list

def cal_time(since):
    now = time.time()
    s = now - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

In [3]:
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
model = BartForConditionalGeneration.from_pretrained('./outputs/best_model_en_bart_temp1')
model.eval()
model.config.use_cache = False
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
score_list = []
file_path = '/home/tranthh/semeval2023/semeval_test_phase/public_data/EN-English/en_test.conll'
guid_index = 1
examples = []
with open(file_path, "r", encoding="utf-8") as f:
    words = []
    labels = []
    for line in f:
        if line.startswith("-DOCSTART-") or line.startswith('#') or line == "" or line == "\n":
            if words:
                examples.append(InputExample(words=words, labels=labels))
                words = []
                labels = []
        else:
            splits = line.split(" ")
            words.append(splits[0])
            if len(splits) > 1:
                labels.append(splits[-1].replace("\n", ""))
            else:
                # Examples could have no label for mode = "test"
                labels.append("O")
    if words:
        examples.append(InputExample(words=words, labels=labels))

In [None]:
trues_list = []
preds_list = []
num_01 = len(examples)
num_point = 0
start = time.time()
for example in examples:
    sources = ' '.join(example.words)
    preds_list.append(prediction(sources))
    trues_list.append(example.labels)
    if num_point % 20 == 0:
        print('%d/%d (%s)'%(num_point+1, num_01, cal_time(start)))
        print(example.words)
        print('Pred:', preds_list[num_point])
        print('Gold:', trues_list[num_point])
    num_point += 1

1/249980 (0m 9s)
['the', 'species', 'was', 'described', 'by', 'dietrich', 'brandis', 'after', 'the', 'forester', 't.', 'f.', 'bourdillon', '.']
Pred: ['O', 'O', 'O', 'B-WrittenWork', 'O', 'B-OtherPER', 'I-OtherPER', 'O', 'O', 'O', 'B-OtherPER', 'I-OtherPER', 'I-OtherPER', 'O']
Gold: ['_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_']
21/249980 (2m 4s)
['nanosmilus', 'was', 'first', 'discovered', 'in', '1880', 'by', 'edward', 'drinker', 'cope', 'and', 'described', 'from', 'fragmentary', 'material', '.']
Pred: ['O', 'O', 'O', 'O', 'O', 'B-WrittenWork', 'O', 'B-OtherPER', 'I-OtherPER', 'B-OtherPER', 'O', 'B-WrittenWork', 'O', 'O', 'O', 'O']
Gold: ['_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_']
41/249980 (3m 57s)
['most', 'goals', 'in', 'a', 'season', ':', '214', 'by', 'oel', 'coslett', '1971', '–', '27', '.']
Pred: ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-Athlete', 'I-Athlete', 'B-WrittenWork', 'O', 'O', 'O']
Gold: ['_', '_', '_', '_'

In [None]:
true_entities = get_entities_bio(trues_list)
pred_entities = get_entities_bio(preds_list)
results = {
    "f1": f1_score(true_entities, pred_entities)
}
print(classification_report(true_entities,pred_entities))

In [None]:
for num_point in range(len(preds_list)):
    preds_list[num_point] = ' '.join(preds_list[num_point]) + '\n'
    trues_list[num_point] = ' '.join(trues_list[num_point]) + '\n'
# with open('./pred_template3.txt', 'w') as f0:
#     f0.writelines(preds_list)
# with open('./gold_template3.txt', 'w') as f0:
#     f0.writelines(trues_list)

In [None]:
final_preds = []
for x in preds_list:
    final_preds.extend([' '] + x.split() + [' '])
pd.DataFrame(final_preds).to_csv('./res/en_test_temp1.conll',  header=None,  index=False)    

In [None]:
print(classification_report(true_entities,pred_entities))

In [None]:
# predictions = []
# for x in preds_list:
#     predictions.extend(x.replace('\n','').split(' '))
    
# groundtruths = []
# for x in trues_list:
#     groundtruths.extend(x.replace('\n','').split(' '))

# print(classification_report(true_entities,pred_entities))