In [1]:
import numpy as np

import joblib
import torch

import config
import dataset
import engine
from model import EntityModel



In [2]:
meta_data = joblib.load("meta.bin")
enc_tag = meta_data["enc_tag"]

num_tag = len(list(enc_tag.classes_))



In [3]:
enc_tag.classes_

array(['B-chemical', 'B-protein', 'I-chemical', 'I-protein', 'O'],
      dtype='<U10')

In [4]:
model = EntityModel(num_tag=num_tag)
model.load_state_dict(torch.load(config.MODEL_PATH))
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 
# device = torch.device("cpu")
if torch.cuda.is_available(): model.to(device)


In [None]:
sentence = """
President Donald Trump may have broken a U.S. federal law and a Georgia law against election tampering by pressuring the state's top election official to "find" enough votes to overturn his loss to President-elect Joe Biden in the state, according to some legal experts.
"""
sentence = sentence.lower()
tokenized_sentence = config.TOKENIZER.encode(sentence)

sentence = sentence.split()
print(sentence)
print(tokenized_sentence)
print(config.TOKENIZER.convert_ids_to_tokens(tokenized_sentence))

test_dataset = dataset.EntityDataset(
    texts=[sentence],
    tags=[[0] * len(sentence)], O_tag_id= enc_tag.transform(["O"])[0]
)



## Custom Data

In [5]:
import train

In [6]:
from sklearn import preprocessing
from sklearn import model_selection

In [7]:
sentences, tag, enc_tag = train.read_bilou(config.TRAINING_FILE)

num_tag = len(list(enc_tag.classes_))

(
    train_sentences,
    test_sentences,
    train_tag,
    test_tag
) = model_selection.train_test_split(sentences, tag, random_state=config.RANDOM_STATE, test_size=0.1)

In [8]:
valid_dataset = dataset.EntityDataset(
    texts=test_sentences, tags=test_tag, O_tag_id= enc_tag.transform(["O"])[0]
)

valid_data_loader = torch.utils.data.DataLoader(
    valid_dataset, batch_size=config.VALID_BATCH_SIZE, num_workers=1
)

In [None]:
enc_tag.classes_

In [None]:
y_pred = []
y_ground = []

model.eval()
for data in (valid_data_loader):
    sentence_lengths = data["length"].tolist()
    for k, v in data.items(): #BioBERT is taking alot of space
        data[k] = v.to(device)
    tags, loss = model(data["ids"],data["mask"], data["token_type_ids"], data["target_tag"])
    for ix, tag in enumerate(tags):
        result = enc_tag.inverse_transform(
                tag.argmax(1).cpu().numpy()
            )
        
        y_pred.extend(result[:sentence_lengths[ix]])

        ground_truth_seq = data["target_tag"][ix][:sentence_lengths[ix]].tolist()
        ground_truth = enc_tag.inverse_transform(ground_truth_seq)
        y_ground.extend(ground_truth)


In [None]:
from sklearn.metrics import classification_report

In [None]:
print(classification_report(y_pred=(y_pred), y_true=(y_ground), labels=enc_tag.classes_))

In [None]:
np.array(y_pred)

In [None]:
tag.shape

In [None]:
bpe_tok_sent = config.TOKENIZER.convert_ids_to_tokens(tokenized_sentence)
result = enc_tag.inverse_transform(
            tag.argmax(1).cpu().numpy().reshape(-1)
        )[:len(tokenized_sentence)]

In [None]:
result

['president', 'donald', 'trump', 'may', 'have', 'broken', 'a', 'u.s.', 'federal', 'law', 'and', 'a', 'georgia', 'law', 'against', 'election', 'tampering', 'by', 'pressuring', 'the', "state's", 'top', 'election', 'official', 'to', '"find"', 'enough', 'votes', 'to', 'overturn', 'his', 'loss', 'to', 'president-elect', 'joe', 'biden', 'in', 'the', 'state,', 'according', 'to', 'some', 'legal', 'experts.']

[101, 2343, 6221, 8398, 2089, 2031, 3714, 1037, 1057, 1012, 1055, 1012, 2976, 2375, 1998, 1037, 4108, 2375, 2114, 2602, 17214, 4842, 2075, 2011, 2811, 12228, 1996, 2110, 1005, 1055, 2327, 2602, 2880, 2000, 1000, 2424, 1000, 2438, 4494, 2000, 2058, 22299, 2010, 3279, 2000, 2343, 1011, 11322, 3533, 7226, 2368, 1999, 1996, 2110, 1010, 2429, 2000, 2070, 3423, 8519, 1012, 102]

['[CLS]', 'president', 'donald', 'trump', 'may', 'have', 'broken', 'a', 'u', '.', 's', '.', 'federal', 'law', 'and', 'a', 'georgia', 'law', 'against', 'election', 'tam', '##per', '##ing', 'by', 'press', '##uring', 'the', 'state', "'", 's', 'top', 'election', 'official', 'to', '"', 'find', '"', 'enough', 'votes', 'to', 'over', '##turn', 'his', 'loss', 'to', 'president', '-', 'elect', 'joe', 'bid', '##en', 'in', 'the', 'state', ',', 'according', 'to', 'some', 'legal', 'experts', '.', '[SEP]']

### Assign first BPE token's tag to other BPE tokens

In [None]:
prev_tok_tag = ''


concatenated_bpe = ''
concatenated_bpes = []
concatenated_tags = []

new_result = []
for idx, (bpe_tok, tag) in enumerate(zip(bpe_tok_sent, result)):
    if not "##" in bpe_tok:
        if idx!=0: 
            concatenated_bpes.append(concatenated_bpe)
            concatenated_tags.append(main_tok_tag)
        concatenated_bpe = ''
        concatenated_bpe+=(bpe_tok).replace("##","")
        main_tok_tag = tag
    else:
        concatenated_bpe+=(bpe_tok).replace("##","")

In [14]:
def undo_bpe(bpe_tok_sent, result, ground_truth):
    prev_tok_tag = ''


    concatenated_bpe = ''
    concatenated_bpes = []
    concatenated_tags = []
    concatenated_ground_tags = []
    
    new_result = []
    for idx, (bpe_tok, tag, ground_tag) in enumerate(zip(bpe_tok_sent, result, ground_truth)):
        if not "##" in bpe_tok:
            if idx!=0: 
                concatenated_bpes.append(concatenated_bpe)
                concatenated_tags.append(main_tok_tag)
                concatenated_ground_tags.append(main_ground_tag)
            concatenated_bpe = ''
            concatenated_bpe+=(bpe_tok).replace("##","")
            main_tok_tag = tag
            main_ground_tag = ground_tag
        else:
            concatenated_bpe+=(bpe_tok).replace("##","")
        
    return concatenated_bpes, concatenated_tags, concatenated_ground_tags

In [15]:
from tqdm import tqdm

In [None]:
y_pred = []
y_ground = []

model.eval()
for data in tqdm(valid_data_loader):
    sentence_lengths = data["length"].tolist()
    for k, v in data.items(): #BioBERT is taking alot of space
        data[k] = v.to(device)
    tags, loss = model(data["ids"],data["mask"], data["token_type_ids"], data["target_tag"])
    for ix, tag in enumerate(tags):
        result = enc_tag.inverse_transform(
                tag.argmax(1).cpu().numpy()
            )
        
        result_shortened = result[:sentence_lengths[ix]]


        ground_truth_seq = data["target_tag"][ix][:sentence_lengths[ix]].tolist()
        ground_truth = enc_tag.inverse_transform(ground_truth_seq)
        

        
        
        text_sentence = config.TOKENIZER.convert_ids_to_tokens(data["ids"][ix][:sentence_lengths[ix]].tolist())
        
#         print(f"Before \n{text_sentence}\n{result_shortened}\n{ground_truth}\n{'='*25}\n")
        
        text, res, groundtruth = undo_bpe(text_sentence, result_shortened, ground_truth)
        
#         print(f"After \n{text}\n{res}\n{groundtruth}\n{'='*25}\n")
#         print(f"=*="*40)
        
        y_pred.extend(res)
        y_ground.extend(groundtruth)
        


In [18]:
from sklearn.metrics import classification_report

In [19]:
print(classification_report(y_pred=(y_pred), y_true=(y_ground), labels=enc_tag.classes_))

              precision    recall  f1-score   support

  B-chemical       0.91      0.97      0.94      1088
   B-protein       0.92      0.95      0.93      5579
  I-chemical       0.91      0.86      0.89       431
   I-protein       0.89      0.93      0.91      2572
           O       0.99      0.99      0.99     61936

    accuracy                           0.98     71606
   macro avg       0.92      0.94      0.93     71606
weighted avg       0.98      0.98      0.98     71606



In [49]:
print(data['ids'][0].shape)

torch.Size([128])
