In [11]:
from ultil_ner import *
from typing import Dict, List, Tuple
from transformers import AutoModel, AutoConfig, AutoTokenizer, AutoModelForTokenClassification
import numpy as np
from torch import nn
import torch
from sklearn.metrics import classification_report
from torch.utils.data import DataLoader, Dataset

# đưa đường dẫn model đã train và tokenizer vào
model_path= '/media/data3/users/longnd/ehr-relation-extraction/biobert_ner/output/biobert__6class'
tokenizer_path= '/media/data3/users/longnd/ehr-relation-extraction/biobert_ner/model/BiomedNLP-PubMedBERT-base-uncased-abstract'
labels = get_labels()
label_map = {i: label for i, label in enumerate(labels)}
num_labels = len(labels)
from ultil_ner import InputFeatures, open_pickle
from tqdm import tqdm
from torchcrf import CRF
class BertCRFModel(nn.Module):
    def __init__(self, config, model_name_or_path):
        super(BertCRFModel, self).__init__()
        self.bert = AutoModel.from_pretrained(model_name_or_path, config=config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.crf = CRF(config.num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        sequence_output = outputs.last_hidden_state
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        outputs = (logits,)
        if labels is not None:
            loss = self.crf(emissions = logits, tags=labels, mask=attention_mask.byte())
            outputs =(-1*loss,)+outputs
            return outputs 
        else:
            return self.crf.decode(logits, attention_mask.byte())

In [12]:
from ultil_ner import InputFeatures, split_text_by_spacy_example
from transformers import BertTokenizer
import numpy as np
from torch import nn
from typing import List
from torch.utils.data import Dataset, DataLoader
import torch
from tqdm import tqdm

"""
các class cho mục đích test
"""
class NerDatasetForPredict(Dataset):
    """
    """
    
    features: List[InputFeatures]
    pad_token_label_id: int = nn.CrossEntropyLoss().ignore_index

    def __init__(self, features: List[InputFeatures]):
        self.features = features
        self.input_ids = [torch.tensor(example.input_ids).long() for example in self.features]
        self.attention_masks = [torch.tensor(example.attention_mask).float() for example in self.features]
        self.token_type_ids = [torch.tensor(example.token_type_ids).long() for example in self.features]
        self.labels = [torch.tensor(example.label_ids).long() for example in self.features]

    def __len__(self):
        return len(self.features)

    def __getitem__(self, i):
        return {
            "input_ids": self.input_ids[i],
            "attention_masks": self.attention_masks[i],
            "token_type_ids": self.token_type_ids[i],
            "labels": self.labels[i],
        }
    
# convert text to input for bert
def convert_text_to_input_bert(text, tokenizer, labels_map, max_seq_length=512):
    """
    Input: 
        text: raw text
        tokenizer: BertTokenizer
        labels_map: dict{label: index}
        max_seq_length: max length of input
    Output:
        sentences: list of sentences
        sentences_input: list of InputFeatures for Bert Input
    """
    sentences= split_text_by_spacy_example(text)
    sentences_input= []
    for sentence in sentences:
        encoded_input = tokenizer.encode_plus(
                    sentence,
                    add_special_tokens=True,
                    truncation=True,
                    padding="max_length",
                    max_length= max_seq_length,
                    return_attention_mask=True,
                    return_token_type_ids=True)
        input_ids = encoded_input["input_ids"]
        token_type_ids = encoded_input["token_type_ids"]
        attention_mask = encoded_input["attention_mask"]
        label_ids = ['O'] * len(input_ids)
        label_ids= [labels_map[label] for label in label_ids]
        assert len(input_ids) == max_seq_length
        assert len(token_type_ids) == max_seq_length
        assert len(attention_mask) == max_seq_length

        sentences_input.append(InputFeatures(input_ids=input_ids, attention_mask=attention_mask, 
                                             token_type_ids=token_type_ids, label_ids=label_ids))
    return sentences, sentences_input

# predict BERT ko sử dụng CRF
def predict(model, data_loader, label_map, device= 'cpu'):
    '''Model: enity recognition
    label_map: dict{index: label}
    
    Output:
        preds_list: list of list of label of each sentence in B-I-O format
        batch_preds: list of list of index of label of each sentence in digit format'''
    model = model.to(device)
    batch_preds = []
    attention_masks = []
    for batch in tqdm(data_loader):
        with torch.no_grad():
            outputs = model(batch['input_ids'].to(device), batch['attention_masks'].to(device), batch['token_type_ids'].to(device))
            preds= np.argmax(outputs[0].detach().cpu().numpy(), axis=2)            
            batch_preds.extend(preds)
            attention_masks.extend(batch['attention_masks'].detach().cpu().numpy())
    batch_size, seq_len = np.array(batch_preds).shape
    preds_list = [[] for _ in range(batch_size)]

    for i in range(batch_size): # số lượng câu chứ không phải batch-size nhé
        for j in range(seq_len):
            if attention_masks[i][j] != 0:
                preds_list[i].append(label_map[batch_preds[i][j]])                   
    return preds_list, batch_preds

# predict BERT sử dụng CRF
def predict_crf(model, data_loader, label_map, device= 'cpu'):
    '''Model: enity recognition
    label_map: dict{index: label}
    
    Output:
        preds_list: list of list of label of each sentence in B-I-O format
        batch_preds: list of list of index of label of each sentence in digit format'''
    model = model.to(device)
    batch_preds = []
    attention_masks = []
    for batch in tqdm(data_loader):
        with torch.no_grad():
            outputs = model(batch['input_ids'].to(device), batch['attention_masks'].to(device), batch['token_type_ids'].to(device))
            preds = outputs[0]            
            batch_preds.append(preds)
            attention_masks.append(batch['attention_masks'].detach().cpu().numpy())
    batch_size, seq_len = np.array(batch_preds).shape
    preds_list = [[] for _ in range(batch_size)]

    for i in range(batch_size): # số lượng câu chứ không phải batch-size nhé
        for j in range(seq_len):
            if attention_masks[i][j] != 0:
                preds_list[i].append(label_map[batch_preds[i][j]])                   
    return preds_list, batch_preds

# predict đưa ra text cho model BERT ko sử dụng CRF
def predict_example(model, text: str, tokenizer: BertTokenizer, labels, max_seq_length=512, device= 'cpu'):
    """label_map: """
    labels_map = {label: i for i, label in enumerate(labels)}
    id2label = {i: label for label, i in labels_map.items()}
    sentences, sentences_input= convert_text_to_input_bert(text, tokenizer, labels_map, max_seq_length)

    test_data= NerDatasetForPredict(sentences_input)
    data= DataLoader(test_data, batch_size= 8, shuffle= False, num_workers= 4)

    preds_list, batch_preds= predict(model= model, data_loader= data, label_map= id2label, device= device)
    batch_size, seq_len = np.array(batch_preds).shape
    assert batch_size == len(sentences)
    para= ''
    result= []
    for i in range(batch_size):
        origin_token= tokenizer.tokenize(sentences[i], add_special_tokens=True) # token các từ
        pred_token= preds_list[i] #[0: len(origin_token)] # nhãn của từng token
        if i==0:
            para+= tokenizer.decode(tokenizer.encode(sentences[i], add_special_tokens=False))
        else:
            para+=' '+ tokenizer.decode(tokenizer.encode(sentences[i], add_special_tokens=False))
        entity_list= []
        entity= ''
        entity_type= ''
        # bắt đầu lấy ra kết quả
        for j in range(len(origin_token)):
                # nếu pred là O thì lưu entity cũ và bỏ qua
                if pred_token[j] == 'O':
                    if entity != '':
                        entity_list.append((entity, entity_type))
                        entity = ''
                        entity_type = ''
                    continue
                else:
                    # bắt đầu bằng B-
                    if pred_token[j].startswith('B-'):
                        # đã có rồi thì thêm vào entity list
                        if entity != '':
                            entity_list.append((entity, entity_type))
                            
                        # gán cho biến mới
                        entity = origin_token[j]
                        if origin_token[j].startswith('##'):
                            entity = origin_token[j][2:]
                        entity_type = pred_token[j][2:]
                        continue

                    # nếu bắt đầu bằng I-
                    elif pred_token[j].startswith('I-'):
                        if entity != '' and entity_type == pred_token[j][2:]:
                            if origin_token[j].startswith('##'):
                                entity+= origin_token[j][2:]
                            else:
                                entity+=  ' '
                                entity+=  origin_token[j]
                            continue

                        elif entity != '' and entity_type != pred_token[j][2:]:
                            entity_list.append((entity, entity_type))
                            entity = origin_token[j]
                            if origin_token[j].startswith('##'):
                                entity = origin_token[j][2:]                            
                            entity_type = pred_token[j][2:]
                            continue

                        elif entity == '' and entity_type == '':
                            entity = origin_token[j]
                            if origin_token[j].startswith('##'):
                                entity = origin_token[j][2:]
                            entity_type = pred_token[j][2:]
                            continue
        result.extend(entity_list)
    return result, para

In [13]:
crf_model = False
if not crf_model:
    config = AutoConfig.from_pretrained( model_path, num_labels=num_labels, 
                                        id2label=label_map, label2id={label: i for i, label in enumerate(labels)})
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, cache_dir=None)
    model = AutoModelForTokenClassification.from_pretrained(model_path, config=config)
else:
    labels = get_labels()
    label_map: Dict[int, str] = {i: label for i, label in enumerate(labels)}
    num_labels = len(labels)

        # Load pretrained model and tokenizer
    config = AutoConfig.from_pretrained(
            model_path,
            num_labels=num_labels,
            id2label=label_map,
            label2id={label: i for i, label in enumerate(labels)},
    )

    tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            use_fast=False,
    )

    model= BertCRFModel(model_name_or_path=model_path, config=config)
    # checkpoint của model đã train
    model.load_state_dict(torch.load('/media/data3/users/longnd/ehr-relation-extraction/biobert_ner/output/pubmedbert-crf/epoch_7_f1_0.8817114093959733.pt'))
    

In [14]:
device= torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
def predict(model, data_loader, label_map, device= 'cpu'):
    '''Model: enity recognition'''
    model = model.to(device)
    batch_preds = []
    attention_masks = []
    labels = []
    for batch in tqdm(data_loader):
        with torch.no_grad():
            outputs = model(batch['input_ids'].to(device), batch['attention_masks'].to(device), batch['token_type_ids'].to(device))
            preds= np.argmax(outputs[0].detach().cpu().numpy(), axis=2)            
            batch_preds.extend(preds)
            attention_masks.extend(batch['attention_masks'].detach().cpu().numpy())
            labels.extend(batch['labels'].detach().cpu().numpy())
    batch_size, seq_len = np.array(batch_preds).shape
    preds_list = [[] for _ in range(batch_size)]
    labels_list = [[] for _ in range(batch_size)]

    for i in range(batch_size): # số lượng câu chứ không phải batch-size nhé
        for j in range(seq_len):
            if attention_masks[i][j] != 0:
                preds_list[i].append(label_map[batch_preds[i][j]])  
                labels_list[i].append(label_map[labels[i][j]])                 
    return preds_list, batch_preds, labels_list

id2label= {i:label for i, label in enumerate(get_labels())}
def idtoLabel(pr, labels=id2label):
    res= []
    for i in pr:
        res.append(labels[i])
    return res
def predict_crf(model, data_loader, label_map, device= 'cpu'):
    '''Model: enity recognition'''
    model = model.to(device)
    batch_preds = []
    labels = []
    for batch in tqdm(data_loader):
        with torch.no_grad():
            outputs = model(batch['input_ids'].to(device), batch['attention_masks'].to(device), batch['token_type_ids'].to(device))
            preds= outputs[0]  
            assert len(preds)==sum(batch['attention_masks'][0].numpy().tolist())         
            batch_preds.append(idtoLabel(preds))
            labels.append(idtoLabel(batch['labels'].numpy().tolist()[0][0: len(preds)]))            
    return batch_preds, labels

def get_entity(pubtator_file: str, id:int,
                not_use: List[str] =['OrganismTaxon', 'CellLine']):

    with open(pubtator_file, 'r') as f: 
        pubtator_text = f.read()
    annotations = {}
    for line in pubtator_text.strip().split('\n'):
        fields = line.split('\t')
        if len(fields) ==1:
            if fields[0]== '': 
                continue
            pubmed_id, section, text = fields[0].split("|")
            if section == 't':
                annotations[pubmed_id] = {'text': text, 'entities': []}
            else:
                annotations[pubmed_id]['text']+= " "+ text
            continue
        pmid = fields[0]
        if pmid not in annotations:
            annotations[pmid] = {'text': fields[2], 'entities': []}
        if len(fields) == 6:
            start, end = int(fields[1]), int(fields[2])
            entity_type = fields[4]
            if fields[4] in not_use:
                continue
            entity_text= fields[3]
            assert annotations[pmid]['text'][start:end] == entity_text
            annotations[pmid]['entities'].append((entity_text, entity_type))
    return annotations[id]['text'], annotations[id]['entities']

In [20]:
# đường dẫn tới file data test đã xử lý
feature = open_pickle('/media/data3/users/longnd/ehr-relation-extraction/biobert_ner/data/Preprocess_BioBERT_6class/Test.pkl')
test_dataset= NerDatasetForPredict(features=feature)

label_map_id= {'GeneOrGeneProduct':0, 'DiseaseOrPhenotypicFeature':1, 'ChemicalEntity':2, 'SequenceVariant':3, 'OrganismTaxon':4, 'CellLine':5, 'O':6}
label_map_new= {}
for i in labels:
    if i!= 'O':
        label_map_new[i]= label_map_id[i[2:]]
    else:
        label_map_new[i]= label_map_id[i]
id2id_new= {}
for i in range(len(labels)):
    id2id_new[i]= label_map_new[labels[i]]
if not crf_model:
    data= DataLoader(test_dataset, batch_size= 16, shuffle= False, num_workers= 4)
    preds_list, batch_preds, label_list = predict(model, data, label_map, device)
else:
    data= DataLoader(test_dataset, batch_size= 1, shuffle= False, num_workers= 4)
    batch_preds, label_list = predict_crf(model, data, label_map, device)

100%|██████████| 69/69 [00:12<00:00,  5.63it/s]


In [21]:
label_map_new

{'B-GeneOrGeneProduct': 0,
 'I-GeneOrGeneProduct': 0,
 'B-DiseaseOrPhenotypicFeature': 1,
 'I-DiseaseOrPhenotypicFeature': 1,
 'B-ChemicalEntity': 2,
 'I-ChemicalEntity': 2,
 'B-SequenceVariant': 3,
 'I-SequenceVariant': 3,
 'B-OrganismTaxon': 4,
 'I-OrganismTaxon': 4,
 'B-CellLine': 5,
 'I-CellLine': 5,
 'O': 6}

In [22]:
def get_result(type_ent, preds_list, labels_list):
    name= {i:[0, 0, 0] for i in type_ent}
    name['all']= [0, 0, 0]
    # tp, p_pred, p_true
    for ent in name.keys():
        for i in range(len(preds_list)):
            if preds_list[i]==labels_list[i] and labels_list[i][2:]== ent:
                name[ent][0]+=1
            if labels_list[i][2:]== ent:
                name[ent][2]+=1
            if preds_list[i][2:]== ent:
                name[ent][1]+=1  
    for i in range(len(preds_list)):            
        if preds_list[i]==labels_list[i] and labels_list[i][2:]!= 'O':
            name['all'][0]+=1
        if labels_list[i][2:]!= 'O':
            name['all'][2]+=1
        if preds_list[i][2:]!= 'O':
            name['all'][1]+=1
    

    # tp, fp, tn, fn
    for i in name.keys():
        tp, p_pred, p_true= name[i]
        pre= tp/(p_pred)
        rec= tp/(p_true)
        print(i)
        print('     precision: ', tp/(p_pred))
        print('     recall: ', tp/(p_true))
        print('     f1: ', 2*pre*rec/(pre+rec))


#### Biobert

In [23]:
# gán nhãn B và I của 1 thực thể là 1
preds_list_new= []
out_label_list_new= []
for i in preds_list:
    preds_list_new+=[label_map_new[j] for j in i]
for i in label_list:
    out_label_list_new+=[label_map_new[j] for j in i]

In [24]:
# nhãn smooth
print(classification_report(y_true=out_label_list_new,y_pred= preds_list_new, digits=4))#, target_names=label_map_id.keys()))

              precision    recall  f1-score   support

           0     0.9480    0.9410    0.9445      4746
           1     0.9192    0.9360    0.9275      3829
           2     0.8787    0.9335    0.9053      2693
           3     0.9745    0.9417    0.9578      1543
           4     0.8899    0.9876    0.9362       483
           5     0.9429    0.8777    0.9091       188
           6     0.9863    0.9808    0.9835     34157

    accuracy                         0.9689     47639
   macro avg     0.9342    0.9426    0.9377     47639
weighted avg     0.9695    0.9689    0.9691     47639



In [25]:
from sklearn.metrics import f1_score
print(f1_score(y_true=out_label_list_new,y_pred= preds_list_new, average='macro'))

0.9377063920423063


In [26]:
from sklearn.metrics import f1_score
print(f1_score(y_true=out_label_list_new,y_pred= preds_list_new, average='micro'))

0.9689330170658494


In [27]:
# ko gán nhãn
label_test= []
for i in label_list:
    label_test.extend(i)
batch= []
for i in preds_list:
    batch.extend(i)
# nhãn chính xác
print(classification_report(y_true=label_test,y_pred= batch, digits=4))#, target_names=label_map_id.keys()
# 0.915

                              precision    recall  f1-score   support

                  B-CellLine     0.9556    0.8600    0.9053        50
            B-ChemicalEntity     0.8866    0.9125    0.8993       754
B-DiseaseOrPhenotypicFeature     0.8560    0.9008    0.8778       917
         B-GeneOrGeneProduct     0.9388    0.9356    0.9372      1180
             B-OrganismTaxon     0.9676    0.9873    0.9773       393
           B-SequenceVariant     0.9643    0.8963    0.9290       241
                  I-CellLine     0.9385    0.8841    0.9104       138
            I-ChemicalEntity     0.8695    0.9350    0.9011      1939
I-DiseaseOrPhenotypicFeature     0.9121    0.9190    0.9155      2912
         I-GeneOrGeneProduct     0.9471    0.9389    0.9430      3566
             I-OrganismTaxon     0.6593    0.9889    0.7911        90
           I-SequenceVariant     0.9676    0.9416    0.9545      1302
                           O     0.9863    0.9808    0.9835     34157

                  

In [None]:
dict(classification_report(y_true=label_test,y_pred= batch, digits=4))

In [28]:
get_result(set([i[2:] for i in get_labels() if i!='O']), batch, label_test)

CellLine
     precision:  0.9428571428571428
     recall:  0.8776595744680851
     f1:  0.9090909090909091
DiseaseOrPhenotypicFeature
     precision:  0.8981790202616056
     recall:  0.914599112039697
     f1:  0.9063146997929606
OrganismTaxon
     precision:  0.8899253731343284
     recall:  0.9875776397515528
     f1:  0.9362119725220805
GeneOrGeneProduct
     precision:  0.9450222882615156
     recall:  0.9380530973451328
     f1:  0.9415247964470763
SequenceVariant
     precision:  0.9671361502347418
     recall:  0.9345430978613092
     f1:  0.950560316413975
ChemicalEntity
     precision:  0.8741698706745893
     recall:  0.928704047530635
     f1:  0.9006121714079942
all
     precision:  0.966414072503621
     recall:  0.966414072503621
     f1:  0.966414072503621


#### biobert focal

In [21]:
preds_list_new= []
out_label_list_new= []
for i in preds_list:
    preds_list_new+=[label_map_new[j] for j in i]
for i in label_list:
    out_label_list_new+=[label_map_new[j] for j in i]

In [22]:
# nhãn smooth
print(classification_report(y_true=out_label_list_new,y_pred= preds_list_new, digits=4))#, target_names=label_map_id.keys()))

              precision    recall  f1-score   support

           0     0.9236    0.9421    0.9327      4746
           1     0.9218    0.9266    0.9242      3829
           2     0.8922    0.9317    0.9115      2693
           3     0.9604    0.9423    0.9513      1543
           4     0.9856    0.9798    0.9827     34828

    accuracy                         0.9678     47639
   macro avg     0.9367    0.9445    0.9405     47639
weighted avg     0.9682    0.9678    0.9680     47639



In [23]:
from sklearn.metrics import f1_score
print(f1_score(y_true=out_label_list_new,y_pred= preds_list_new, average='macro'))

0.9404844018011438


In [26]:
label_test= []
for i in label_list:
    label_test.extend(i)
batch= []
for i in preds_list:
    batch.extend(i)
# nhãn chính xác
print(classification_report(y_true=label_test,y_pred= batch, digits=4))#, target_names=label_map_id.keys()
# 0.915

                              precision    recall  f1-score   support

            B-ChemicalEntity     0.8893    0.9164    0.9027       754
B-DiseaseOrPhenotypicFeature     0.8567    0.8931    0.8745       917
         B-GeneOrGeneProduct     0.9020    0.9356    0.9185      1180
           B-SequenceVariant     0.9500    0.8672    0.9067       241
            I-ChemicalEntity     0.8885    0.9324    0.9099      1939
I-DiseaseOrPhenotypicFeature     0.9170    0.9111    0.9140      2912
         I-GeneOrGeneProduct     0.9248    0.9380    0.9314      3566
           I-SequenceVariant     0.9498    0.9439    0.9468      1302
                           O     0.9856    0.9798    0.9827     34828

                    accuracy                         0.9652     47639
                   macro avg     0.9182    0.9242    0.9208     47639
                weighted avg     0.9657    0.9652    0.9654     47639



#### pubmed bert

In [7]:
preds_list_new= []
out_label_list_new= []
for i in preds_list:
    preds_list_new+=[label_map_new[j] for j in i]
for i in label_list:
    out_label_list_new+=[label_map_new[j] for j in i]

In [8]:
# nhãn smooth
print(classification_report(y_true=out_label_list_new,y_pred= preds_list_new, digits=4))#, target_names=label_map_id.keys()))

              precision    recall  f1-score   support

           0     0.9503    0.9545    0.9524      2904
           1     0.9211    0.8962    0.9085      2071
           2     0.8987    0.9261    0.9122      1447
           3     0.9861    0.9450    0.9651      1200
           4     0.8844    0.9851    0.9321       404
           5     0.8986    0.8732    0.8857       142
           6     0.9879    0.9882    0.9880     28464

    accuracy                         0.9759     36632
   macro avg     0.9324    0.9383    0.9349     36632
weighted avg     0.9761    0.9759    0.9760     36632



In [9]:
label_test= []
for i in label_list:
    label_test.extend(i)
batch= []
for i in preds_list:
    batch.extend(i)
# nhãn chính xác
print(classification_report(y_true=label_test,y_pred= batch, digits=4))#, target_names=label_map_id.keys()
# 0.915

                              precision    recall  f1-score   support

                  B-CellLine     0.9362    0.8800    0.9072        50
            B-ChemicalEntity     0.9028    0.9244    0.9135       754
B-DiseaseOrPhenotypicFeature     0.8698    0.8811    0.8754       917
         B-GeneOrGeneProduct     0.9432    0.9432    0.9432      1180
             B-OrganismTaxon     0.9676    0.9873    0.9773       393
           B-SequenceVariant     0.9487    0.9212    0.9347       241
                  I-CellLine     0.8791    0.8696    0.8743        92
            I-ChemicalEntity     0.8790    0.9120    0.8952       693
I-DiseaseOrPhenotypicFeature     0.8978    0.8449    0.8705      1154
         I-GeneOrGeneProduct     0.9482    0.9553    0.9517      1724
             I-OrganismTaxon     0.2041    0.9091    0.3333        11
           I-SequenceVariant     0.9869    0.9426    0.9643       959
                           O     0.9879    0.9882    0.9880     28464

                  

In [10]:
get_result(set([i[2:] for i in get_labels() if i!='O']), batch, label_test)

CellLine
     precision:  0.8985507246376812
     recall:  0.8732394366197183
     f1:  0.8857142857142857
DiseaseOrPhenotypicFeature
     precision:  0.884863523573201
     recall:  0.8609367455335587
     f1:  0.8727361722956438
OrganismTaxon
     precision:  0.8844444444444445
     recall:  0.9851485148514851
     f1:  0.9320843091334895
GeneOrGeneProduct
     precision:  0.9461775797051766
     recall:  0.9504132231404959
     f1:  0.9482906717058925
SequenceVariant
     precision:  0.9791304347826087
     recall:  0.9383333333333334
     f1:  0.9582978723404256
ChemicalEntity
     precision:  0.8913480885311871
     recall:  0.9184519695922598
     f1:  0.9046970728386657
all
     precision:  0.9731109412535488
     recall:  0.9731109412535488
     f1:  0.9731109412535488


#### pubmed smooth 6 class

In [8]:
preds_list_new= []
out_label_list_new= []
for i in preds_list:
    preds_list_new+=[label_map_new[j] for j in i]
for i in label_list:
    out_label_list_new+=[label_map_new[j] for j in i]

In [9]:
# nhãn smooth
print(classification_report(y_true=out_label_list_new,y_pred= preds_list_new, digits=4))#, target_names=label_map_id.keys()))

              precision    recall  f1-score   support

           0     0.9465    0.9566    0.9515      2904
           1     0.9225    0.9029    0.9126      2071
           2     0.8829    0.9219    0.9020      1447
           3     0.9826    0.9425    0.9621      1200
           4     0.8808    0.9876    0.9312       404
           5     0.9528    0.8521    0.8996       142
           6     0.9883    0.9871    0.9877     28464

    accuracy                         0.9753     36632
   macro avg     0.9366    0.9358    0.9353     36632
weighted avg     0.9756    0.9753    0.9754     36632



In [10]:
label_test= []
for i in label_list:
    label_test.extend(i)
batch= []
for i in preds_list:
    batch.extend(i)
# nhãn chính xác
print(classification_report(y_true=label_test,y_pred= batch, digits=4))#, target_names=label_map_id.keys()
# 0.915

                              precision    recall  f1-score   support

                  B-CellLine     0.9778    0.8800    0.9263        50
            B-ChemicalEntity     0.8977    0.9191    0.9083       754
B-DiseaseOrPhenotypicFeature     0.8713    0.8713    0.8713       917
         B-GeneOrGeneProduct     0.9377    0.9432    0.9404      1180
             B-OrganismTaxon     0.9605    0.9898    0.9749       393
           B-SequenceVariant     0.9492    0.9295    0.9392       241
                  I-CellLine     0.9390    0.8370    0.8851        92
            I-ChemicalEntity     0.8525    0.9091    0.8799       693
I-DiseaseOrPhenotypicFeature     0.8955    0.8614    0.8781      1154
         I-GeneOrGeneProduct     0.9462    0.9594    0.9528      1724
             I-OrganismTaxon     0.2083    0.9091    0.3390        11
           I-SequenceVariant     0.9825    0.9374    0.9594       959
                           O     0.9883    0.9871    0.9877     28464

                  

In [13]:
set([i[2:] for i in get_labels() if i!='O'])

{'CellLine',
 'ChemicalEntity',
 'DiseaseOrPhenotypicFeature',
 'GeneOrGeneProduct',
 'OrganismTaxon',
 'SequenceVariant'}

In [14]:
get_labels()

['B-GeneOrGeneProduct',
 'I-GeneOrGeneProduct',
 'B-DiseaseOrPhenotypicFeature',
 'I-DiseaseOrPhenotypicFeature',
 'B-ChemicalEntity',
 'I-ChemicalEntity',
 'B-SequenceVariant',
 'I-SequenceVariant',
 'B-OrganismTaxon',
 'I-OrganismTaxon',
 'B-CellLine',
 'I-CellLine',
 'O']

In [10]:
def get_result(type_ent, preds_list, labels_list):
    name= {i:[0, 0, 0] for i in type_ent}
    name['all']= [0, 0, 0]
    # tp, p_pred, p_true
    for ent in name.keys():
        for i in range(len(preds_list)):
            if preds_list[i]==labels_list[i] and labels_list[i][2:]== ent:
                name[ent][0]+=1
            if labels_list[i][2:]== ent:
                name[ent][2]+=1
            if preds_list[i][2:]== ent:
                name[ent][1]+=1  
    for i in range(len(preds_list)):            
        if preds_list[i]==labels_list[i] and labels_list[i][2:]!= 'O':
            name['all'][0]+=1
        if labels_list[i][2:]!= 'O':
            name['all'][2]+=1
        if preds_list[i][2:]!= 'O':
            name['all'][1]+=1
    

    # tp, fp, tn, fn
    for i in name.keys():
        tp, p_pred, p_true= name[i]
        pre= tp/(p_pred)
        rec= tp/(p_true)
        print(i)
        print('     precision: ', tp/(p_pred))
        print('     recall: ', tp/(p_true))
        print('     f1: ', 2*pre*rec/(pre+rec))


In [22]:
get_result(set([i[2:] for i in get_labels() if i!='O']), batch, label_test)

ChemicalEntity
     precision:  0.8755790866975512
     recall:  0.9143054595715273
     f1:  0.894523326572008
CellLine
     precision:  0.952755905511811
     recall:  0.852112676056338
     f1:  0.8996282527881041
DiseaseOrPhenotypicFeature
     precision:  0.8845584607794771
     recall:  0.8657653307580879
     f1:  0.8750610053684724
OrganismTaxon
     precision:  0.8807947019867549
     recall:  0.9876237623762376
     f1:  0.9311551925320886
GeneOrGeneProduct
     precision:  0.9427597955706984
     recall:  0.9528236914600551
     f1:  0.9477650282582634
SequenceVariant
     precision:  0.9756733275412685
     recall:  0.9358333333333333
     f1:  0.955338153977031
all
     precision:  0.9724284778335881
     recall:  0.9724284778335881
     f1:  0.9724284778335881


#### biobert smooth

In [29]:
label_test= []
for i in label_list:
    label_test.extend(i)
batch= []
for i in preds_list:
    batch.extend(i)
# nhãn chính xác
print(classification_report(y_true=label_test,y_pred= batch, digits=4))#, target_names=label_map_id.keys()
# 0.915

                              precision    recall  f1-score   support

                  B-CellLine     0.9767    0.8400    0.9032        50
            B-ChemicalEntity     0.8559    0.8979    0.8764       754
B-DiseaseOrPhenotypicFeature     0.8574    0.8855    0.8712       917
         B-GeneOrGeneProduct     0.9118    0.9288    0.9202      1180
             B-OrganismTaxon     0.9645    0.9669    0.9657       393
           B-SequenceVariant     0.9638    0.8838    0.9221       241
                  I-CellLine     0.9750    0.8478    0.9070       138
            I-ChemicalEntity     0.8522    0.9216    0.8855      1939
I-DiseaseOrPhenotypicFeature     0.9120    0.9179    0.9149      2912
         I-GeneOrGeneProduct     0.9339    0.9358    0.9349      3566
             I-OrganismTaxon     0.6581    0.8556    0.7440        90
           I-SequenceVariant     0.9651    0.9332    0.9488      1302
                           O     0.9851    0.9789    0.9820     34157

                  

In [30]:
get_result(set([i[2:] for i in get_labels() if i!='O']), batch, label_test)

ChemicalEntity
     precision:  0.853185595567867
     recall:  0.9149647233568511
     f1:  0.8829958788747536
CellLine
     precision:  0.9754601226993865
     recall:  0.8457446808510638
     f1:  0.905982905982906
DiseaseOrPhenotypicFeature
     precision:  0.8986591026302218
     recall:  0.9101593105249413
     f1:  0.904372648241858
OrganismTaxon
     precision:  0.8943248532289628
     recall:  0.9461697722567288
     f1:  0.9195171026156941
GeneOrGeneProduct
     precision:  0.9283769633507853
     recall:  0.9340497260851243
     f1:  0.9312047053880895
SequenceVariant
     precision:  0.9648648648648649
     recall:  0.9254698639014906
     f1:  0.944756864042342
all
     precision:  0.9627196204790193
     recall:  0.9627196204790193
     f1:  0.9627196204790193


#### pubmed + crf

In [10]:
# batch_preds, label_list
label_test= []
for i in label_list:
    label_test.extend(i)
batch= []
for i in batch_preds:
    batch.extend(i)
# nhãn chính xác
print(classification_report(y_true=label_test,y_pred= batch, digits=4))#, target_names=label_map_id.keys()

                              precision    recall  f1-score   support

                  B-CellLine     0.9756    0.8000    0.8791        50
            B-ChemicalEntity     0.8961    0.9151    0.9055       754
B-DiseaseOrPhenotypicFeature     0.8677    0.8441    0.8557       917
         B-GeneOrGeneProduct     0.9234    0.9398    0.9315      1180
             B-OrganismTaxon     0.9420    0.9924    0.9665       393
           B-SequenceVariant     0.8945    0.8797    0.8870       241
                  I-CellLine     0.9577    0.7391    0.8344        92
            I-ChemicalEntity     0.8688    0.8889    0.8787       693
I-DiseaseOrPhenotypicFeature     0.9218    0.7868    0.8490      1154
         I-GeneOrGeneProduct     0.9456    0.9269    0.9361      1724
             I-OrganismTaxon     0.1923    0.9091    0.3175        11
           I-SequenceVariant     0.9374    0.9364    0.9369       959
                           O     0.9831    0.9882    0.9856     28464

                  

In [11]:
len(label_test)

36632

In [12]:
get_result(set([i[2:] for i in get_labels() if i!='O']), batch, label_test)

SequenceVariant
     precision:  0.9288702928870293
     recall:  0.925
     f1:  0.9269311064718162
GeneOrGeneProduct
     precision:  0.9363542026980284
     recall:  0.9321625344352618
     f1:  0.934253666954271
ChemicalEntity
     precision:  0.8830290736984449
     recall:  0.902557014512785
     f1:  0.8926862611073137
CellLine
     precision:  0.9642857142857143
     recall:  0.7605633802816901
     f1:  0.8503937007874015
OrganismTaxon
     precision:  0.8583690987124464
     recall:  0.9900990099009901
     f1:  0.9195402298850576
DiseaseOrPhenotypicFeature
     precision:  0.8961108151305275
     recall:  0.8121680347658136
     f1:  0.8520770010131712
all
     precision:  0.9674874426730727
     recall:  0.9674874426730727
     f1:  0.9674874426730727


#### biobert + crf

In [7]:
# batch_preds, label_list
label_test= []
for i in label_list:
    label_test.extend(i)
batch= []
for i in batch_preds:
    batch.extend(i)
# nhãn chính xác
print(classification_report(y_true=label_test,y_pred= batch, digits=4))#, target_names=label_map_id.keys()

                              precision    recall  f1-score   support

                  B-CellLine     0.9744    0.7600    0.8539        50
            B-ChemicalEntity     0.9008    0.9151    0.9079       754
B-DiseaseOrPhenotypicFeature     0.8749    0.8691    0.8720       917
         B-GeneOrGeneProduct     0.9221    0.9331    0.9275      1180
             B-OrganismTaxon     0.9652    0.9873    0.9761       393
           B-SequenceVariant     0.8943    0.9129    0.9035       241
                  I-CellLine     0.9815    0.7681    0.8618       138
            I-ChemicalEntity     0.8920    0.9371    0.9140      1939
I-DiseaseOrPhenotypicFeature     0.9342    0.8929    0.9131      2912
         I-GeneOrGeneProduct     0.9401    0.9204    0.9301      3566
             I-OrganismTaxon     0.6984    0.9778    0.8148        90
           I-SequenceVariant     0.9505    0.9301    0.9402      1302
                           O     0.9816    0.9846    0.9831     34157

                  

In [8]:
len(label_test)

47639

In [11]:
get_result(set([i[2:] for i in get_labels() if i!='O']), batch, label_test)

OrganismTaxon
     precision:  0.9015151515151515
     recall:  0.9855072463768116
     f1:  0.9416419386745797
CellLine
     precision:  0.9795918367346939
     recall:  0.7659574468085106
     f1:  0.8597014925373134
DiseaseOrPhenotypicFeature
     precision:  0.9195993502977802
     recall:  0.8871768085662053
     f1:  0.9030971686827063
GeneOrGeneProduct
     precision:  0.935538954108858
     recall:  0.9235145385587863
     f1:  0.9294878591877849
SequenceVariant
     precision:  0.9414473684210526
     recall:  0.9274141283214518
     f1:  0.9343780607247796
ChemicalEntity
     precision:  0.8943988583660364
     recall:  0.9309320460453027
     f1:  0.9122998544395925
all
     precision:  0.9649656793803396
     recall:  0.9649656793803396
     f1:  0.9649656793803396


#### code for report

In [33]:
from transformers import AutoTokenizer
text= "Founder mutations in the BRCA1 gene."
tokenizer= AutoTokenizer.from_pretrained('/media/data3/users/longnd/ehr-relation-extraction/biobert_ner/model/biobert-v1.1', cache_dir=None)
tok= tokenizer.tokenize(text, add_special_tokens=True)
tok_pad= tokenizer.encode(text, add_special_tokens=True, max_length=20, padding='max_length', truncation=True)
tok1= tokenizer.encode(text, add_special_tokens=True)

In [36]:
print('Original: ', text)
print('Tokenized: ', tokenizer.tokenize(text))
print('Tokenized using padding: ', tokenizer.tokenize(text, add_special_tokens=True, max_length=13, padding='max_length', truncation=True))
print('Token IDs: ', tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text, add_special_tokens=True, max_length=13, padding='max_length', truncation=True)))
print('Attention Mask: ', tokenizer.encode_plus(text, add_special_tokens=True, max_length=13, padding='max_length', truncation=True)['attention_mask'])

Original:  Founder mutations in the BRCA1 gene.
Tokenized:  ['Founder', 'mutations', 'in', 'the', 'BR', '##CA', '##1', 'gene', '.']
Tokenized using padding:  ['[CLS]', 'Founder', 'mutations', 'in', 'the', 'BR', '##CA', '##1', 'gene', '.', '[SEP]', '[PAD]', '[PAD]']
Token IDs:  [101, 16505, 17157, 1107, 1103, 26660, 11356, 1475, 5565, 119, 102, 0, 0]
Attention Mask:  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]


In [14]:
from colorama import init, Fore, Back, Style

# Initialize colorama to enable ANSI color codes on Windows
init()

# Function to print colored text
def print_colored_text(text, color='white', background=None, style=None):
    color_code = getattr(Fore, color.upper(), Fore.WHITE)
    background_code = getattr(Back, background.upper(), '')
    style_code = getattr(Style, style.upper(), '')

    colored_text = f"{style_code}{background_code}{color_code}{text}{Style.RESET_ALL}"
    print(colored_text)

In [15]:
print_colored_text(text,background='blue', color='green', style='underline')

Founder mutations in the BRCA1 gene in Polish families with breast-ovarian cancer.


In [1]:
def print_bio_tags(sentence, bio_tags):
    words = sentence.split()
    tags = bio_tags.split()

    for word, tag in zip(words, tags):
        print(f"{word}\t{tag}")

# Example usage
sentence = "Congenital hypothyroidism due to a new deletion in the sodium/iodide symporter protein."
bio_tags = "B-Disease I-Disease O O O O O O O B-Disease I-Disease I-Disease"

print_bio_tags(sentence, bio_tags)

Congenital	B-Disease
hypothyroidism	I-Disease
due	O
to	O
a	O
new	O
deletion	O
in	O
the	O
sodium/iodide	B-Disease
symporter	I-Disease
protein.	I-Disease


In [2]:
print("Original: ", sentence)
print("Labels: ", bio_tags)

Original:  Congenital hypothyroidism due to a new deletion in the sodium/iodide symporter protein.
Labels:  B-Disease I-Disease O O O O O O O B-Disease I-Disease I-Disease
