In [1]:
import os
import re
import json
import numpy as np
import pandas as pd
from collections import Counter, defaultdict
from tqdm import tqdm
from bioc import pubtator
import copy

## Load Dev set

In [2]:
# Load processed dev data (JSON Lines: one document per line)
datapath = "../data/NER/PURE/dev.json"
# A context manager closes the file deterministically; a bare open() in the
# for-loop header leaves the handle open until it is garbage-collected.
with open(datapath, 'r') as dev_file:
    dev = [json.loads(line) for line in dev_file]

In [4]:
# Keep only the documents whose doc_key contains "Task2"
# (the 100 Dev samples that are the same as the BioRED Test set).
dev = list(filter(lambda data: "Task2" in data['doc_key'], dev))
doc_keys = list(map(lambda data: data["doc_key"], dev))
print(len(dev), len(doc_keys))

100 100


In [5]:
def _strip_wordpiece_prefix(token):
    """Return `token` without a leading WordPiece continuation marker ("##").

    Exactly one two-character "##" prefix is removed. `token.lstrip("##")` is not
    equivalent: lstrip treats its argument as a set of characters and removes every
    leading "#", so a literal "#" token would collapse to the empty string.
    """
    return token[2:] if token.startswith("##") else token


gold_unique_cuis = []

for data in dev:

    # To store the cased version of tokens
    data["cased_tokens"] = []

    # Store entity mention using spans for future use
    if data["ner"]:
        for annotation in data["ner"]:

            # Use the long BioRED label for diseases
            if annotation['entity_type'] == "Disease":
                annotation['entity_type'] = "DiseaseOrPhenotypicFeature"

            annotation["mentions"] = []
            for span in annotation["mention_spans"]:

                # span is [first_token, last_token] (inclusive); the end is clamped to the token list
                entity_tokens = data["tokens"][int(span[0]):min(int(span[1]+1),len(data["tokens"]))]

                # Glue continuation pieces onto the previous piece, space-separate everything else
                entity_mention = ''
                for idx, token in enumerate(entity_tokens):
                    if token.startswith("##"):
                        entity_mention += _strip_wordpiece_prefix(token)
                    else:
                        if idx == 0:
                            entity_mention += token
                        else:
                            entity_mention += " " + token

                annotation["mentions"].append(entity_mention)

    # We need to make char2token mapper
    # as PubTator3 spans are char-based, whereas the PURE models takes token-based spans
    all_sentences = " ".join(data['sentence_texts'])

    char2token = {}  # (char_start, char_end) -> token index
    chunk = ""       # cased characters read since the last matched token
    target_tokens = data['tokens'][1:-1]  # Exclude [CLS] and [SEP]

    uncased_chunk = ""  # lowercased twin of `chunk`, compared against the tokens
    uncased2cased = {}  # token index -> token text with the original casing

    token_idx = 1  # index 0 is the excluded first token
    # assume that target_tokens are sorted by its index
    unk_token_indices = []
    for curr_idx, char in enumerate(all_sentences):

        # Spaces never belong to a token
        if char == " ":
            continue

        uncased_chunk += char.lower()
        chunk += char

        target_token = _strip_wordpiece_prefix(target_tokens[0])
        if target_token == "[UNK]":
            # The text behind an [UNK] is unknown; its span is resolved once the
            # next known token has been found.
            unk_start_idx = curr_idx
            unk_token_indices.append(token_idx)
            token_idx += 1
            target_tokens = target_tokens[1:]
            continue
        else:
            if target_token in uncased_chunk:
                # Remember the cased spelling when it differs from the (lowercased) token
                if chunk != uncased_chunk:
                    if target_tokens[0].startswith("##"):
                        uncased2cased[token_idx] = "##"+chunk[-len(target_token):]
                    else:
                        uncased2cased[token_idx] = chunk[-len(target_token):]

                if unk_token_indices:
                    # Everything read before the matched token belongs to the pending [UNK] tokens
                    unk_chunk = uncased_chunk[:-len(target_token)]
                    unk_chunks = unk_chunk.strip().split(" ")
                    if len(unk_chunks) != len(unk_token_indices):  # NOTE: temporary snippet
                        unk_chunks = [unk_chunks[0][0], unk_chunks[0][1]]

                    assert len(unk_chunks) == len(unk_token_indices)

                    for unk_chunk, unk_token_idx in zip(unk_chunks, unk_token_indices):
                        char2token[(unk_start_idx, unk_start_idx+len(unk_chunk))] = unk_token_idx
                        unk_start_idx = unk_start_idx+len(unk_chunk)

                    unk_token_indices = []

                # The matched token ends at the current character
                char2token[(curr_idx-len(target_token)+1, curr_idx+1)] = token_idx
                chunk = ""
                uncased_chunk = ""
                token_idx += 1
                target_tokens = target_tokens[1:]

    # A trailing [UNK] covers the rest of the text
    if target_token == "[UNK]":
        char2token[(unk_start_idx, len(all_sentences))] = unk_token_indices[0]
        unk_token_indices = []

    # Every token between the first and the last one must have been consumed
    assert token_idx == len(data['tokens'])-1

    data['char2token'] = char2token
    data['token2char'] = {v:k for k, v in char2token.items()}

    # Cased token list: the recovered casing where available, the original token otherwise
    for idx in range(len(data['tokens'])):
        if idx in uncased2cased:
            data["cased_tokens"].append(uncased2cased[idx])
        else:
            data["cased_tokens"].append(data["tokens"][idx])
        

## Load NER predictions

In [11]:
import pandas as pd

# Tab-separated NER predictions produced by PURE.
# To use gold NER instead, point the path at "./data/ner_output/dev_gold_ner.csv"
# and read it the same way.
pred_ner_path = '../ner/models/final_model/postprocessed_predictions.csv'
pred_df = pd.read_csv(pred_ner_path, delimiter="\t")
# The first column is dropped by position.
pred_df = pred_df.iloc[:, 1:]
print(len(pred_df))
pred_df.head(3)

3496


Unnamed: 0,abstract_id,offset_start,offset_finish,type
0,BC8_BioRED_Task2_Doc594,8,13,GeneOrGeneProduct
1,BC8_BioRED_Task2_Doc594,56,72,DiseaseOrPhenotypicFeature
2,BC8_BioRED_Task2_Doc594,97,108,DiseaseOrPhenotypicFeature


In [13]:
# Group the NER predictions by document, keeping only the retained dev documents.
preds = defaultdict(list)
for row in pred_df.to_dict('records'):
    abstract_id = row['abstract_id']
    if abstract_id in doc_keys:
        preds[abstract_id].append(dict(
            abstract_id=abstract_id,
            start=row['offset_start'],
            end=row['offset_finish'],
            type=row['type'],
        ))
print(len(preds))

100


## Load PubTator3 Prediction files

In [14]:
# Align label names between BioRED and PubTator3.
# BioRED entity label -> PubTator3 entity label
biored2pubt = dict(
    GeneOrGeneProduct='Gene',
    DiseaseOrPhenotypicFeature='Disease',
    OrganismTaxon='Species',
    ChemicalEntity='Chemical',
    SequenceVariant='SNP',
    CellLine='CellLine',
)
# Reverse direction, plus the two extra PubTator3 mutation labels
pubt2biored = dict(
    (pubt_label, biored_label) for biored_label, pubt_label in biored2pubt.items()
)
pubt2biored.update(DNAMutation='SequenceVariant', ProteinMutation='SequenceVariant')

In [15]:
# NOTE(review): the other output paths in this notebook start with "./outputs/";
# confirm that the ".outputs" directory used here is intentional.
fp = ".outputs/pubtator/biored_task1_val_100_pubtator3_aligned.pubtator"

# doc id -> list of PubTator3 annotations (start, end, mention, type, cui)
pubtator_files = defaultdict(list)

# To make lookup dict for entity mentions based on PubT result:
# (lowercased mention, BioRED type) -> {cui: number of occurrences}
pubt_cui_lookup = {}

# Matches the number that follows "RS#:" in a variant identifier.
# Compiled once here instead of once per annotation.
rs_pattern = re.compile(r'RS#:(\d+)')

with open(fp) as f:
    docs = pubtator.load(f)
    for doc in docs:

        if doc.pmid not in doc_keys:
            continue

        # Add mentions for NER-pred file
        for pred in preds[doc.pmid]:
            pred['text'] = doc.text[pred['start']:pred['end']]

        for e in doc.annotations:
            is_variant = e.type == "SNP" or e.type.endswith("Mutation")

            if e.type == 'CellLine':
                e.id = e.id.replace(":", "_")
            elif e.type in ["Disease", "Chemical"] and e.id.startswith('MESH'):
                e.id = e.id.split(":")[-1]
            # Parsing rules for SeqVar IDs
            elif is_variant:
                matches = rs_pattern.findall(e.id)
                if matches:
                    e.id = f'rs{matches[0]}'
                else:
                    # No RS number: take the part after the first ':' of the first
                    # ';'-separated field. An ID without any ':' is kept as is
                    # (indexing [1] unconditionally raised IndexError for it).
                    first_field = e.id.split(';')[0]
                    field_parts = first_field.split(':')
                    e.id = field_parts[1] if len(field_parts) > 1 else first_field

            pubtator_files[doc.pmid].append({
                'start':int(e.start),
                'end':int(e.end),
                'mention':e.text,
                'type':e.type,
                'cui':e.id,
            })

            # Types without a BioRED counterpart (e.g. 'Chromosome') stay in
            # pubtator_files but cannot be keyed by BioRED type, so they are left
            # out of the lookup instead of raising KeyError.
            biored_type = pubt2biored.get(e.type)
            if biored_type is None:
                continue

            # Lowercase key
            lookup_key = (e.text.lower(), biored_type)
            # Variant IDs are kept whole; other IDs may hold several CUIs separated by ',' or '|'
            cuis = [e.id] if is_variant else re.split(r'[,|]', e.id)
            cui_counts = pubt_cui_lookup.setdefault(lookup_key, {})
            for cui in cuis:
                cui_counts[cui] = cui_counts.get(cui, 0) + 1

In [16]:
# Build look-up dictionary from PubTator3 outputs:
# for every (mention, type) key keep only the most frequent CUI.
# max() returns the first maximal item, so ties resolve to the earliest-inserted
# CUI, the same choice a stable descending sort makes.
pubt_cui_lookup_listed = {
    lookup_key: max(cui_counts.items(), key=lambda pair: pair[1])[0]
    for lookup_key, cui_counts in pubt_cui_lookup.items()
}

## Load NEL predictions from ResCNN

In [18]:
# Mention map for diseases: original mention text -> normalized query.
# It is tied to the NER output, so a different NER prediction needs its own map
# (for gold NER: 'mention_map_disease_dev_gold.json' in the same directory).
disease_map_path = './resources/data/biosyn-processed-bc8biored-disease/mention_map_disease_dev.json'
with open(disease_map_path) as map_file:
    mention_map_disease = json.loads(map_file.read())
print(len(mention_map_disease))

450


In [19]:
# Mention map for chemicals: original mention text -> normalized query
# (for gold NER: 'mention_map_chemical_dev_gold.json' in the same directory).
# NOTE(review): the file sits under the "...-disease" directory; confirm this is intended.
chemical_map_path = './resources/data/biosyn-processed-bc8biored-disease/mention_map_chemical_dev.json'
with open(chemical_map_path) as map_file:
    mention_map_chemical = json.loads(map_file.read())
print(len(mention_map_chemical))

275


In [20]:
# ResCNN entity-linking output for the Disease concept
disease_output_path = './outputs/rescnn/results/bc8biored-disease-aio/lightweight_cnn_text_with_attention_pooling_biolinkbert/lr_0.001-depth_4-fs_256-drop_0.25/predictions_dev_disease.json'
with open(disease_output_path, 'r') as prediction_file:
    disease_rescnn = json.loads(prediction_file.read())
# mention -> predicted ID string; when a mention occurs in several records,
# the last record's pred_id is the one that stays.
disease_rescnn_dict = {}
for record in disease_rescnn:
    disease_rescnn_dict[record['mention']] = record['pred_id']
print(len(disease_rescnn), len(disease_rescnn_dict), disease_rescnn[0])

400 384 {'mention': 'multiple pterygium syndrome', 'gold_id': 'C537377', 'pred_id': 'C537377|265000', 'pred_name': 'multiple pterygium syndrome'}


In [21]:
# ResCNN entity-linking output for the Chemical concept
chemical_output_path = './outputs/rescnn/results/bc8biored-chemical-aio/lightweight_cnn_text_biolinkbert/lr_0.001-depth_3-fs_256-drop_0.25/predictions_dev_chemical.json'
with open(chemical_output_path, 'r') as prediction_file:
    chemical_rescnn = json.loads(prediction_file.read())
# mention -> predicted ID string; when a mention occurs in several records,
# the last record's pred_id is the one that stays.
chemical_rescnn_dict = {}
for record in chemical_rescnn:
    chemical_rescnn_dict[record['mention']] = record['pred_id']
print(len(chemical_rescnn), len(chemical_rescnn_dict), chemical_rescnn[0])

224 223 {'mention': '1 25 oh 2d3', 'gold_id': 'D002117', 'pred_id': 'D002117'}


In [22]:
# Step 1. Match ResCNN-based Disease/Chemical output with NER predictions

# ResCNN only cares about Disease and Chemical. Per entity type:
# (original mention -> normalized query, normalized query -> predicted ID string)
rescnn_resources = {
    'DiseaseOrPhenotypicFeature': (mention_map_disease, disease_rescnn_dict),
    'ChemicalEntity': (mention_map_chemical, chemical_rescnn_dict),
}

for doc_key, values in preds.items():
    for pred in values:
        if pred['type'] not in rescnn_resources:
            continue
        mention_map, rescnn_dict = rescnn_resources[pred['type']]

        normalized_query = mention_map[pred['text']]
        # composite would not exist in the ResCNN output; such predictions stay without a 'cui'
        if normalized_query in rescnn_dict:
            cui_pred = rescnn_dict[normalized_query]
            # Only choose the primary ID (the first '|'-separated one)
            pred['cui'] = cui_pred.split('|')[0]

In [23]:
# Step 2. Match PubT3 output with NER output
#
# For every prediction that did not get a CUI in step 1:
#   1. take the CUI of a PubT3 annotation of the same type whose span matches
#      exactly, else of the first one whose span overlaps;
#   2. else look the mention up in the corpus-wide PubT3 table, exact
#      (mention, type) key first, substring containment second;
#   3. else assign the empty string.

num_total = 0
num_rescnn = 0
num_pubt3_em = 0
num_pubt3_lookup = 0

for doc_key, value_list in pubtator_files.items():  # PubT outputs (NER+NEL)

    target_pred = preds[doc_key]  # Our NER output

    for p in target_pred:

        num_total += 1

        # already matched with ResCNN preds (Disease and Chem, non-composites)
        if 'cui' in p:
            num_rescnn += 1
            continue

        # PubT3 annotations with the same (mapped) type and a touching/overlapping span.
        # .get() is used because value_list may contain types that have no BioRED
        # counterpart (e.g. 'Chromosome'); indexing pubt2biored raised KeyError for them.
        # NOTE(review): the bounds are inclusive, so spans that merely touch
        # (v['end'] == p['start']) count as overlapping - confirm this is intended.
        overlapping = []
        for v in value_list:
            if pubt2biored.get(v['type']) != p['type']:
                continue
            if v['end'] >= p['start'] and v['start'] <= p['end']:
                overlapping.append(v)

        if overlapping:
            # Prefer an identical span. Testing "exact or overlap" in a single pass
            # always returned the first overlapping annotation, even when an exact
            # match came later in the list.
            exact = [v for v in overlapping if v['start'] == p['start'] and v['end'] == p['end']]
            p['cui'] = (exact or overlapping)[0]['cui']  # follow PubT3 style
            num_pubt3_em += 1
            continue

        # Lookup for mention throughout all PubT3 outputs
        query_mention = p['text'].lower()
        query = (query_mention, p['type'])

        # Exact key first, so a looser substring hit stored earlier cannot win over it
        if query in pubt_cui_lookup_listed:
            p['cui'] = pubt_cui_lookup_listed[query]
            num_pubt3_lookup += 1
            continue

        for (key_mention, key_type), key_cui in pubt_cui_lookup_listed.items():
            if key_type == p['type'] and (query_mention in key_mention or key_mention in query_mention):
                p['cui'] = key_cui
                num_pubt3_lookup += 1
                break
        else:
            # Nothing matched: the prediction stays without a normalized ID
            p['cui'] = ""

print('Total mentions:', num_total)
print('ResCNN matched:', num_rescnn)
print('PubT3 matched:', num_pubt3_em)
print('PubT3-lookup matched:', num_pubt3_lookup)

Total mentions: 3496
ResCNN matched: 1653
PubT3 matched: 1609
PubT3-lookup matched: 79


# Evaluation

### Align NEL outputs for processed json files

In [24]:
from collections import defaultdict

PUBTATOR_ENTITY_TYPES = [
    "GeneOrGeneProduct", "DiseaseOrPhenotypicFeature", "ChemicalEntity",
    "OrganismTaxon", "SequenceVariant", "CellLine"
]

# Number of predictions whose character span could not be mapped to tokens / a sentence
num_unaligned = 0

for data in tqdm(dev):

    # entity id -> entity record; converted to a list at the end of the iteration
    data['pred_ner'] = {}

    pred_annotation = preds[data['doc_key']]
    for ann in pred_annotation:
        # .get(): a prediction that never received a 'cui' key is treated as cui-less
        ann_cui = ann.get('cui')
        if not ann_cui or ann_cui == '-':  # cui-less
            continue

        # Raw ID string of the prediction; does not depend on the token alignment below
        ann['database_id'] = ann_cui

        # Reset for every annotation. Without this a failed lookup silently reuses
        # the indices of the previous annotation (or raises NameError on the first one).
        token_start = token_end = sent_idx = None

        # Character offsets -> token indices (the end offset is exclusive, hence end-1)
        start, end = ann['start'], ann['end']
        for (char_start, char_end), token_idx in data['char2token'].items():
            if char_start <= start < char_end:
                token_start = token_idx
            if char_start <= end-1 < char_end:
                token_end = token_idx

        # Sentence that contains the first token of the mention
        if token_start is not None:
            for idx, (sent_token_start, sent_token_end) in enumerate(data['sentence_spans']):
                if sent_token_start <= token_start < sent_token_end:
                    sent_idx = idx
                    break

        if token_start is None or token_end is None or sent_idx is None:
            # No reliable span: skip the mention rather than emit a wrong one
            num_unaligned += 1
            continue

        # Variant IDs are kept whole; other IDs may list several CUIs
        if ann['type'] != "SequenceVariant":
            cuis = re.split(r'[,;|]', ann_cui)
        else:
            cuis = [ann_cui]

        for cui in cuis:
            if cui not in data['pred_ner']:
                value = {
                    'entity_id': cui,
                    'mention_spans': [[token_start, token_end]],
                    'entity_type': ann['type'],
                    'sentence_mentions': [sent_idx],
                }
                data['pred_ner'][cui] = value
            else:
                # Known entity: add this mention (and its sentence, once)
                target_ann = data['pred_ner'][cui]
                target_ann['mention_spans'].append([token_start, token_end])
                if sent_idx not in target_ann['sentence_mentions']:
                    target_ann['sentence_mentions'].append(sent_idx)

    data['pred_ner'] = list(data['pred_ner'].values())

if num_unaligned:
    print(f"Skipped {num_unaligned} prediction(s) that could not be aligned to tokens")


100%|████████████████████████████████████████| 100/100 [00:00<00:00, 450.29it/s]

































































































































































In [26]:
# NEL evaluation with F-measure, counted per document.
# An instance in the ID evaluation is (pmid, entity_type, id).
# A mention with multiple IDs is expanded into one instance per ID.

labels = list(biored2pubt.keys())

# Per-type counters
nel_scores = {}
for label in labels:
    nel_scores[label] = {"n_gold":0, "n_pred":0, "n_correct":0}

# How many gold entities carry 1, 2, ... IDs (keyed by the count as a string)
multiple_id_counter = Counter()

for data in tqdm(dev):
    golds = []  # Doc-level list of unique (entity_type, id) pairs
    for gold in data["ner"]:
        gold_type = gold['entity_type']
        # Consider multiple ID cases (variant IDs are never split)
        if gold_type == "SequenceVariant":
            gold_cuis = [gold["entity_id"]]
        else:
            gold_cuis = [str(piece.strip()) for piece in re.split(r'[,|;]', gold["entity_id"])]

        multiple_id_counter[str(len(gold_cuis))] += 1

        for cui in gold_cuis:
            gold_tuple = (gold_type, cui)
            if gold_tuple not in golds:
                golds.append(gold_tuple)
                nel_scores[gold_type]['n_gold'] += 1

    for pred in data["pred_ner"]:
        pred_type = pred['entity_type']
        type_counts = nel_scores[pred_type]
        type_counts['n_pred'] += 1
        # Correct only when both the type and the ID agree with a gold instance
        if (pred_type, pred['entity_id']) in golds:
            type_counts['n_correct'] += 1

print(nel_scores)

100%|██████████████████████████████████████| 100/100 [00:00<00:00, 14991.97it/s]

{'GeneOrGeneProduct': {'n_gold': 436, 'n_pred': 396, 'n_correct': 344}, 'DiseaseOrPhenotypicFeature': {'n_gold': 344, 'n_pred': 357, 'n_correct': 311}, 'OrganismTaxon': {'n_gold': 113, 'n_pred': 113, 'n_correct': 112}, 'ChemicalEntity': {'n_gold': 220, 'n_pred': 226, 'n_correct': 188}, 'SequenceVariant': {'n_gold': 139, 'n_pred': 118, 'n_correct': 78}, 'CellLine': {'n_gold': 22, 'n_pred': 20, 'n_correct': 17}}





In [27]:
def save_div(x, y):
    """Return ``x / y``, or ``0.0`` when ``y`` equals zero."""
    return 0.0 if y == 0 else x / y
            
# Add precision / recall / F1 to the counters of every entity type (in place).
for type_scores in nel_scores.values():
    type_precision = save_div(type_scores["n_correct"], type_scores["n_pred"])
    type_recall = save_div(type_scores["n_correct"], type_scores["n_gold"])
    type_scores["precision"] = type_precision
    type_scores["recall"] = type_recall
    type_scores["f1"] = save_div(2*type_precision*type_recall, (type_precision+type_recall))

print(nel_scores)

{'GeneOrGeneProduct': {'n_gold': 436, 'n_pred': 396, 'n_correct': 344, 'precision': 0.8686868686868687, 'recall': 0.7889908256880734, 'f1': 0.8269230769230771}, 'DiseaseOrPhenotypicFeature': {'n_gold': 344, 'n_pred': 357, 'n_correct': 311, 'precision': 0.8711484593837535, 'recall': 0.9040697674418605, 'f1': 0.8873038516405135}, 'OrganismTaxon': {'n_gold': 113, 'n_pred': 113, 'n_correct': 112, 'precision': 0.9911504424778761, 'recall': 0.9911504424778761, 'f1': 0.9911504424778761}, 'ChemicalEntity': {'n_gold': 220, 'n_pred': 226, 'n_correct': 188, 'precision': 0.831858407079646, 'recall': 0.8545454545454545, 'f1': 0.8430493273542601}, 'SequenceVariant': {'n_gold': 139, 'n_pred': 118, 'n_correct': 78, 'precision': 0.6610169491525424, 'recall': 0.5611510791366906, 'f1': 0.6070038910505836}, 'CellLine': {'n_gold': 22, 'n_pred': 20, 'n_correct': 17, 'precision': 0.85, 'recall': 0.7727272727272727, 'f1': 0.8095238095238095}}


In [28]:
# Per-type report
for entity_type, score in nel_scores.items():
    print(f"{entity_type}({score['n_gold']}) >>> prec: {score['precision']*100:.2f} | rec: {score['recall']*100:.2f} | f1: {score['f1']*100:.2f}")
    print()

print("="*89)

# Totals: the counts are summed over all entity types before dividing
gold = sum(type_score['n_gold'] for type_score in nel_scores.values())
pred = sum(type_score['n_pred'] for type_score in nel_scores.values())
correct = sum(type_score['n_correct'] for type_score in nel_scores.values())

precision = save_div(correct, pred)
recall = save_div(correct, gold)
f1 = save_div(2*precision*recall, precision+recall)

print(f"Total Precision: {precision*100:.2f} ({correct} out of {pred})")
print(f"Total Recall: {recall*100:.2f} ({correct} out of {gold})")
print(f"Total F1: {f1*100:.2f}")

GeneOrGeneProduct(436) >>> prec: 86.87 | rec: 78.90 | f1: 82.69

DiseaseOrPhenotypicFeature(344) >>> prec: 87.11 | rec: 90.41 | f1: 88.73

OrganismTaxon(113) >>> prec: 99.12 | rec: 99.12 | f1: 99.12

ChemicalEntity(220) >>> prec: 83.19 | rec: 85.45 | f1: 84.30

SequenceVariant(139) >>> prec: 66.10 | rec: 56.12 | f1: 60.70

CellLine(22) >>> prec: 85.00 | rec: 77.27 | f1: 80.95

Total Precision: 85.37 (1050 out of 1230)
Total Recall: 82.42 (1050 out of 1274)
Total F1: 83.87


## Save the output file for RE

In [54]:
# Flatten the per-document predictions into one list of rows for export.
# A deep copy is edited so that `preds` itself keeps all of its fields.
pred_output = []
preds_ = copy.deepcopy(preds)
for doc_preds in preds_.values():
    for entity in doc_preds:
        # Fields that are not exported
        entity.pop('mention', None)
        entity.pop('cui', None)
        # assign empty string to entities failed to be normalized
        entity.setdefault('database_id', "")
        pred_output.append(entity)

In [56]:
# Write the NEL-augmented predictions as CSV.
# output_csv_path was previously defined but unused: the write went to a hard-coded,
# date-stamped "goldner" file under ./data/nel_output/. The variable is now the
# single source of truth, in line with the JSON export to ./outputs/final/.
output_csv_path = "./outputs/final/nel_predictions.csv"
# to_csv does not create missing directories
os.makedirs(os.path.dirname(output_csv_path), exist_ok=True)
pd.DataFrame(pred_output).to_csv(output_csv_path, index=False)

In [58]:
# Drop unused keys for RE input.
# They are helper fields added earlier in this notebook; char2token in particular
# has tuple keys, which json.dumps cannot serialize.
unused_keys = ['cased_tokens', 'char2token', 'token2char']
for data in dev:
    for k in unused_keys:
        # pop() with a default keeps the cell re-runnable; `del` raised KeyError
        # once the keys had already been removed.
        data.pop(k, None)
print(dev[0].keys())

dict_keys(['doc_key', 'sentence_texts', 'sentence_spans', 'tokens', 'token_indices', 'ner', 'relations', 'no_relations', 'pred_ner'])


In [59]:
# Save the documents (now including 'pred_ner') as JSON Lines, one document per line.
output_json_path = "./outputs/final/nel_predictions.json"
with open(output_json_path, "w") as out_file:
    for data in dev:
        print(json.dumps(data), file=out_file)