# A CRF system for detecting rare diseases

This notebook contains the implementation of a CRF system for detecting diseases, rare disases and symptoms from texts. 


The system works for English and Spanish. You only have to set the parameter LANG:


In [1]:
#LANG='es'
LANG='en'

If we are running from google colab, we should mount our Google Drive directory (eoc comment the code):

In [2]:
#If your run from google colab:
from google.colab import drive
drive.mount("/content/drive/")

#If you run from JupiterLab
#root='./'

!ls '.'

Mounted at /content/drive/
drive  sample_data


We define the directories needed for our system:

In [3]:
from datetime import datetime

today=datetime.today().strftime('%Y-%m-%d')

#dataset path
root = '/content/drive/My Drive/Colab Notebooks'
path = root + '/ner/data/gold_nlp4rare_corpus/'

#file to save the trained model
allTypes = True
sTypes = ''
if allTypes:
    sTypes='_all'
    
model = root + '/ner/models/{}/CRF{}.pkl'.format(LANG, sTypes)

#file to save the predictions of the model
output = root + '/ner/outputs/{}/CRF{}-date:{}.csv'.format(LANG,sTypes,today)

#file to save the scores (metrics)
scores = root + '/ner/scores/{}/CRF{}-date:{}.txt'.format(LANG,sTypes,today)

print('data:', path)
print('model trained:', model)
print('prediction files:', output)
print('scores file:', scores)

data: /content/drive/My Drive/Colab Notebooks/ner/data/gold_nlp4rare_corpus/
model trained: /content/drive/My Drive/Colab Notebooks/ner/models/en/CRF_all.pkl
prediction files: /content/drive/My Drive/Colab Notebooks/ner/outputs/en/CRF_all-date:2021-07-08.csv
scores file: /content/drive/My Drive/Colab Notebooks/ner/scores/en/CRF_all-date:2021-07-08.txt


## loading data

Correcting "NA" issue for the Spanish corpus. NA is the acronym for neurálgica amiotrófica.   Pandas takes that string as NaN, so we will manually asign NA to NaNs in the training csv.

In [4]:
import pandas as pd

df_train=pd.read_csv(path+'train{}.csv'.format(sTypes))
df_train=df_train[['Sentence #','Word','POS','Tag']]
if LANG=='es':
    #NA is an acronym of neuroacantocitosis, and pandas takes that string as NaN, so we will manually asign NA to NaNs in the training csv.
    df_train=df_train.replace(np.nan, 'NA')
    
print('training set loaded:',len(df_train))

df_dev=pd.read_csv(path+'dev{}.csv'.format(sTypes))
df_dev=df_dev[['Sentence #','Word','POS','Tag']]
if LANG=='es':
    #NA is an acronym of neuroacantocitosis, and pandas takes that string as NaN, so we will manually asign NA to NaNs in the training csv.
    df_dev=df_dev.replace(np.nan, 'NA')

print('development set loaded:',len(df_dev))

df_test=pd.read_csv(path+'test{}.csv'.format(sTypes))
df_test=df_test[['Sentence #','Word','POS','Tag']]
if LANG=='es':
  #Fix the 'NaN' issue
  df_test=df_test.replace(np.nan, 'NA')
    
print('test set loaded:',len(df_test))

training set loaded: 135656
development set loaded: 18492
test set loaded: 37837


Our CRF does not exploit the development set to adjust any parameter. For this reason, we join the training and development set to train our model. 
This function rewrites the id sentences of the development set. Finally, it concat the two datasets into one.


In [5]:
def concat(df_tr,df_dev):
    """This function rewrites the id sentences of the development set. Finally, it concat the two datasets into one"""
    int(df_tr.iloc[-1]['Sentence #'][9:])
    list_of_dev_sent_numb=[]
    for i in df_dev['Sentence #']:
        list_of_dev_sent_numb.append(int(i[9:]))
    l = [int(df_tr.iloc[-1]['Sentence #'][9:])] * len(df_dev)
    se=['Sentence:']* len(df_dev)
    num=[sum(x)for x in zip(list_of_dev_sent_numb,l)]
    new_sentence_number=[]
    for i in range(len(num)):
        new_sentence_number.append(se[i]+str(num[i]))
    df_dev['Sentence #']=new_sentence_number


    df=pd.concat([df_tr,df_dev],axis=0)
    return df

df_train=concat(df_train,df_dev)
print('full training dataset loaded:',len(df_train))

full training dataset loaded: 154148


## Feature set

Our CRF system uses a feature set to represent each token. The following functions are used to extract these features. In particular, the function *word2features* allows to represent each token as a feature vector. See this functions to know the features used.




In [6]:
class Sentence(object):
    def __init__(self, df):
        self.n_sent = 1
        self.df = df
        self.empty = False
        agg = lambda s : [(w, p, t) for w, p, t in zip(s['Word'].values.tolist(),
                                                       s['POS'].values.tolist(),
                                                       s['Tag'].values.tolist())]
        self.grouped = self.df.groupby("Sentence #").apply(agg)
        self.sentences = [s for s in self.grouped]
        
    def get_text(self):
        try:
            s = self.grouped['Sentence: {}'.format(self.n_sent)]
            self.n_sent +=1
            return s
        except:
            return None

In [7]:
def word2features(sent, i):
  #if str(sent[i][0])!='nan':#esto es extra
    word = sent[i][0]
    postag = sent[i][1]
    
    features = {
        'bias': 1.0,
        'word.lower()': word.lower(),
        'word[-3:]': word[-3:],
        'word[-2:]': word[-2:],
        'word.isupper()': word.isupper(),
        'word.istitle()': word.istitle(),
        'word.isdigit()': word.isdigit(),
        'postag': postag,
        'postag[:2]': postag[:2],
    }
    #if i > 0 and str(sent[i-1][0])!='nan':#esto es extra
    if i > 0:
        word1 = sent[i-1][0]
        postag1 = sent[i-1][1]
        features.update({
            '-1:word.lower()': word1.lower(),
            '-1:word.istitle()': word1.istitle(),
            '-1:word.isupper()': word1.isupper(),
            '-1:postag': postag1,
            '-1:postag[:2]': postag1[:2],
        })
    else:
        features['BOS'] = True
    
    #if i > 1 and str(sent[i-2][0])!='nan':#esto es extra
    if i > 1:
        word1 = sent[i-2][0]
        postag1 = sent[i-2][1]
        features.update({
            '-2:word.lower()': word1.lower(),
            '-2:word.istitle()': word1.istitle(),
            '-2:word.isupper()': word1.isupper(),
            '-2:postag': postag1,
            '-2:postag[:2]': postag1[:2],
        })
    else:
        features['Second_word'] = True

    if i < len(sent)-1:
        word1 = sent[i+1][0]
        postag1 = sent[i+1][1]
        features.update({
            '+1:word.lower()': word1.lower(),
            '+1:word.istitle()': word1.istitle(),
            '+1:word.isupper()': word1.isupper(),
            '+1:postag': postag1,
            '+1:postag[:2]': postag1[:2],
        })
    else:
        features['EOS'] = True
    if i < len(sent)-2:
        word1 = sent[i+2][0]
        postag1 = sent[i+2][1]
        features.update({
            '+2:word.lower()': word1.lower(),
            '+2:word.istitle()': word1.istitle(),
            '+2:word.isupper()': word1.isupper(),
            '+2:postag': postag1,
            '+2:postag[:2]': postag1[:2],
        })
    else:
        features['Second_to_last'] = True

    return features


def sent2features(sent):
    return [word2features(sent, i) for i in range(len(sent))]

def sent2labels(sent):
    return [label for token, postag, label in sent]

def sent2tokens(sent):
    return [token for token, postag, label in sent]

Now, we have to load the features of the training set (X) and also their corresponding labels (y) to train our model. 

In [8]:
getter = Sentence(df_train)
#getting all sentences
sentences = getter.sentences
X_train = [sent2features(s) for s in sentences]
y_train = [sent2labels(s) for s in sentences]

print('training vectors and their labels were loaded!')

training vectors and their labels were loaded!


## Training 


We now train the model:

In [9]:
import time

#crfsuite is an implementation for Conditional Random Field

!pip install sklearn-crfsuite
from sklearn_crfsuite import CRF

crf = CRF(algorithm = 'lbfgs',
         c1 = 0.1,
         c2 = 0.1,
         max_iterations = 100,
         all_possible_transitions = False)

start = time.time()
crf.fit(X_train, y_train)
stop = time.time()
print(f"Training time: {round(stop - start,2)}s")

print('CRF model was trained!')

Collecting sklearn-crfsuite
  Downloading https://files.pythonhosted.org/packages/25/74/5b7befa513482e6dee1f3dd68171a6c9dfc14c0eaa00f885ffeba54fe9b0/sklearn_crfsuite-0.3.6-py2.py3-none-any.whl
Collecting python-crfsuite>=0.8.3
[?25l  Downloading https://files.pythonhosted.org/packages/79/47/58f16c46506139f17de4630dbcfb877ce41a6355a1bbf3c443edb9708429/python_crfsuite-0.9.7-cp37-cp37m-manylinux1_x86_64.whl (743kB)
[K     |████████████████████████████████| 747kB 5.0MB/s 
Installing collected packages: python-crfsuite, sklearn-crfsuite
Successfully installed python-crfsuite-0.9.7 sklearn-crfsuite-0.3.6
Training time: 33.61s
CRF model was trained!


To save the model, please run this cell:

In [None]:
import pickle
# save the model to disk
pickle.dump(crf, open(model, 'wb'))
print('crf model saved:'+model)

crf model saved:drive/My Drive/Colab Notebooks/nlp4rareNER/models/en/CRF_all.pkl


## Testing the CRF model

Formating the test set

In [10]:
getter = Sentence(df_test)
sentences = getter.sentences
X_test = [sent2features(s) for s in sentences]
y_test = [sent2labels(s) for s in sentences]


print('X_test and y_test loaded!')

X_test and y_test loaded!


In [11]:
load_model=False #By default false

if load_model:      
    crf = pickle.load(open(model, 'rb'))
    print('model was loaded')


#Predicting on the test set.
y_pred = crf.predict(X_test)
print('predictions were obtained!')

predictions were obtained!


## Results

First, we calculate the macro F1 (weighted) including the class 'O' (wchih represents the tokens that are not entities).


In [14]:
from sklearn_crfsuite.metrics import flat_f1_score
from sklearn_crfsuite.metrics import flat_classification_report

#f1_score = flat_f1_score(y_test, y_pred, average = 'weighted')
#print('F1 score:',f1_score)
if allTypes:
    labels=['B-DISEASE','I-DISEASE','B-RAREDISEASE','I-RAREDISEASE','B-SYMPTOM','I-SYMPTOM','B-SIGN','I-SIGN']
else:
    labels=['B-RAREDISEASE','I-RAREDISEASE','B-SIGN-SYM','I-SIGN-SYM']

report = flat_classification_report(y_test, y_pred, labels=labels, digits=4)
print('CRF results  (including the class O):')
print('Language:'+LANG+'\n')

print()
print(report)

CRF results  (including the class O):
Language:en


               precision    recall  f1-score   support

    B-DISEASE     0.7116    0.5124    0.5958       443
    I-DISEASE     0.7133    0.5225    0.6032       400
B-RAREDISEASE     0.8464    0.8369    0.8416      1073
I-RAREDISEASE     0.8681    0.8261    0.8466      1179
    B-SYMPTOM     0.8286    0.5800    0.6824        50
    I-SYMPTOM     0.6429    0.2250    0.3333        80
       B-SIGN     0.5883    0.4894    0.5343       803
       I-SIGN     0.5591    0.3991    0.4658      2215

    micro avg     0.7112    0.5818    0.6400      6243
    macro avg     0.7198    0.5489    0.6129      6243
 weighted avg     0.6945    0.5818    0.6292      6243



However, the inclusion of the class 'O' does not give a realistic average measure of our system. In the following cell, we do not include this class to calculate the micro and macro metrics. 


In [13]:
#f1_score = flat_f1_score(y_test, y_pred, average = 'weighted',labels=['B-DISEASE','B-RAREDISEASE','B-SYMPTOM','I-DISEASE','I-RAREDISEASE','I-SYMPTOM'])
#print('F1 score:',f1_score)

report = flat_classification_report(y_test, y_pred, labels=labels, digits=4)

print('CRF results (without including the class O):')
print('Language:'+LANG+'\n')

print()
print(report)

f=open(scores,"w")
f.write('CRF results (without including the class O)\n')
f.write('Language:'+LANG+'\n')

#f.write(report) 
data=report.split('\n')
for line in data:
    
    if ' avg' in line:
        line=line.replace(' avg','-avg')
    line=line.strip()
    line=' '.join(line.split())

    line=line.replace(' ','\t&\t')+'\\\\'
    if line.startswith('weighted'):
        line='\hline%'+line
    
    print(line)
    f.write(line)


f.close() 
print('scores were saved to ' + scores )

CRF results (without including the class O):
Language:en


               precision    recall  f1-score   support

    B-DISEASE     0.7116    0.5124    0.5958       443
    I-DISEASE     0.7133    0.5225    0.6032       400
B-RAREDISEASE     0.8464    0.8369    0.8416      1073
I-RAREDISEASE     0.8681    0.8261    0.8466      1179
    B-SYMPTOM     0.8286    0.5800    0.6824        50
    I-SYMPTOM     0.6429    0.2250    0.3333        80
       B-SIGN     0.5883    0.4894    0.5343       803
       I-SIGN     0.5591    0.3991    0.4658      2215

    micro avg     0.7112    0.5818    0.6400      6243
    macro avg     0.7198    0.5489    0.6129      6243
 weighted avg     0.6945    0.5818    0.6292      6243

precision	&	recall	&	f1-score	&	support\\
\\
B-DISEASE	&	0.7116	&	0.5124	&	0.5958	&	443\\
I-DISEASE	&	0.7133	&	0.5225	&	0.6032	&	400\\
B-RAREDISEASE	&	0.8464	&	0.8369	&	0.8416	&	1073\\
I-RAREDISEASE	&	0.8681	&	0.8261	&	0.8466	&	1179\\
B-SYMPTOM	&	0.8286	&	0.5800	&	0.6824	&	50\\

In [15]:
!pip install seqeval
from seqeval.metrics import classification_report
from seqeval.scheme import IOB2


f=open(scores,"a")

report = classification_report(y_test, y_pred,  scheme=IOB2, digits=4)
f.write('BiLSTM results on entity (approximate):\n')
print('BiLSTM results on entity (approximate):')
print('Language:'+LANG+'\n')

print()
print(report)

data=report.split('\n')
for line in data:
    if len(line.strip())==0 or 'precision' in line:
        pass
    else:
        if ' avg' in line:
            line=line.replace(' avg','-avg')
        line=line.strip()
        line=' '.join(line.split())

        line=line.replace(' ','\t&\t')+'\\\\'
        if line.startswith('weighted'):
            line='\hline%'+line
        print(line)
        f.write(line)


f.close() 
print('scores were saved to ' + scores )

Collecting seqeval
[?25l  Downloading https://files.pythonhosted.org/packages/9d/2d/233c79d5b4e5ab1dbf111242299153f3caddddbb691219f363ad55ce783d/seqeval-1.2.2.tar.gz (43kB)
[K     |███████▌                        | 10kB 14.0MB/s eta 0:00:01[K     |███████████████                 | 20kB 19.9MB/s eta 0:00:01[K     |██████████████████████▌         | 30kB 14.2MB/s eta 0:00:01[K     |██████████████████████████████  | 40kB 11.1MB/s eta 0:00:01[K     |████████████████████████████████| 51kB 4.4MB/s 
Building wheels for collected packages: seqeval
  Building wheel for seqeval (setup.py) ... [?25l[?25hdone
  Created wheel for seqeval: filename=seqeval-1.2.2-cp37-none-any.whl size=16184 sha256=0b2725cb14eaeeaf0987e8e6c12dfc51bb0e2deabc1902dd10228c1186df34b1
  Stored in directory: /root/.cache/pip/wheels/52/df/1b/45d75646c37428f7e626214704a0e35bd3cfc32eda37e59e5f
Successfully built seqeval
Installing collected packages: seqeval
Successfully installed seqeval-1.2.2
BiLSTM results on en

### Saving the predictions

To save the predictions, please run this cell:

In [None]:
#Saving the results for the error analysis
pickle.dump(y_pred, open(output, 'wb'))
print('prediction saved in ' + output)

prediction saved in drive/My Drive/Colab Notebooks/nlp4rareNER/outputs/en/CRF_all-date:2021-06-23.csv


In [16]:
import logging
from collections import namedtuple
from copy import deepcopy

logging.basicConfig(
    format="%(asctime)s %(name)s %(levelname)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level="DEBUG",
)

Entity = namedtuple("Entity", "e_type start_offset end_offset")

class Evaluator():

    def __init__(self, true, pred, tags):
        """
        """

        if len(true) != len(pred):
            raise ValueError("Number of predicted documents does not equal true")

        self.true = true
        self.pred = pred
        self.tags = tags

        # Setup dict into which metrics will be stored.

        self.metrics_results = {
            'correct': 0,
            'incorrect': 0,
            'partial': 0,
            'missed': 0,
            'spurious': 0,
            'possible': 0,
            'actual': 0,
            'precision': 0,
            'recall': 0,
        }

        # Copy results dict to cover the four schemes.

        self.results = {
            'strict': deepcopy(self.metrics_results),
            'ent_type': deepcopy(self.metrics_results),
            'partial':deepcopy(self.metrics_results),
            'exact':deepcopy(self.metrics_results),
            }

        # Create an accumulator to store results

        self.evaluation_agg_entities_type = {e: deepcopy(self.results) for e in tags}


    def evaluate(self):

        logging.info(
            "Imported %s predictions for %s true examples",
            len(self.pred), len(self.true)
        )

        for true_ents, pred_ents in zip(self.true, self.pred):

            # Check that the length of the true and predicted examples are the
            # same. This must be checked here, because another error may not
            # be thrown if the lengths do not match.

            if len(true_ents) != len(pred_ents):
                raise ValueError("Prediction length does not match true example length")

            # Compute results for one message

            tmp_results, tmp_agg_results = compute_metrics(
                collect_named_entities(true_ents),
                collect_named_entities(pred_ents),
                self.tags
            )

            # Cycle through each result and accumulate

            # TODO: Combine these loops below:

            for eval_schema in self.results:

                for metric in self.results[eval_schema]:

                    self.results[eval_schema][metric] += tmp_results[eval_schema][metric]

            # Calculate global precision and recall

            self.results = compute_precision_recall_wrapper(self.results)

            # Aggregate results by entity type

            for e_type in self.tags:

                for eval_schema in tmp_agg_results[e_type]:

                    for metric in tmp_agg_results[e_type][eval_schema]:

                        self.evaluation_agg_entities_type[e_type][eval_schema][metric] += tmp_agg_results[e_type][eval_schema][metric]

                # Calculate precision recall at the individual entity level

                self.evaluation_agg_entities_type[e_type] = compute_precision_recall_wrapper(self.evaluation_agg_entities_type[e_type])

        return self.results, self.evaluation_agg_entities_type


def collect_named_entities(tokens):
    """
    Creates a list of Entity named-tuples, storing the entity type and the start and end
    offsets of the entity.
    :param tokens: a list of tags
    :return: a list of Entity named-tuples
    """

    named_entities = []
    start_offset = None
    end_offset = None
    ent_type = None

    for offset, token_tag in enumerate(tokens):

        if token_tag == 'O':
            if ent_type is not None and start_offset is not None:
                end_offset = offset - 1
                named_entities.append(Entity(ent_type, start_offset, end_offset))
                start_offset = None
                end_offset = None
                ent_type = None

        elif ent_type is None:
            ent_type = token_tag[2:]
            start_offset = offset

        elif ent_type != token_tag[2:] or (ent_type == token_tag[2:] and token_tag[:1] == 'B'):

            end_offset = offset - 1
            named_entities.append(Entity(ent_type, start_offset, end_offset))

            # start of a new entity
            ent_type = token_tag[2:]
            start_offset = offset
            end_offset = None

    # catches an entity that goes up until the last token

    if ent_type is not None and start_offset is not None and end_offset is None:
        named_entities.append(Entity(ent_type, start_offset, len(tokens)-1))

    return named_entities


def compute_metrics(true_named_entities, pred_named_entities, tags):


    eval_metrics = {'correct': 0, 'incorrect': 0, 'partial': 0, 'missed': 0, 'spurious': 0, 'precision': 0, 'recall': 0}

    # overall results
    
    evaluation = {
        'strict': deepcopy(eval_metrics),
        'ent_type': deepcopy(eval_metrics),
        'partial': deepcopy(eval_metrics),
        'exact': deepcopy(eval_metrics)
    }

    # results by entity type

    evaluation_agg_entities_type = {e: deepcopy(evaluation) for e in tags}

    # keep track of entities that overlapped

    true_which_overlapped_with_pred = []

    # Subset into only the tags that we are interested in.
    # NOTE: we remove the tags we don't want from both the predicted and the
    # true entities. This covers the two cases where mismatches can occur:
    #
    # 1) Where the model predicts a tag that is not present in the true data
    # 2) Where there is a tag in the true data that the model is not capable of
    # predicting.

    true_named_entities = [ent for ent in true_named_entities if ent.e_type in tags]
    pred_named_entities = [ent for ent in pred_named_entities if ent.e_type in tags]

    # go through each predicted named-entity

    for pred in pred_named_entities:
        found_overlap = False

        # Check each of the potential scenarios in turn. See
        # http://www.davidsbatista.net/blog/2018/05/09/Named_Entity_Evaluation/
        # for scenario explanation.

        # Scenario I: Exact match between true and pred

        if pred in true_named_entities:
            true_which_overlapped_with_pred.append(pred)
            evaluation['strict']['correct'] += 1
            evaluation['ent_type']['correct'] += 1
            evaluation['exact']['correct'] += 1
            evaluation['partial']['correct'] += 1

            # for the agg. by e_type results
            evaluation_agg_entities_type[pred.e_type]['strict']['correct'] += 1
            evaluation_agg_entities_type[pred.e_type]['ent_type']['correct'] += 1
            evaluation_agg_entities_type[pred.e_type]['exact']['correct'] += 1
            evaluation_agg_entities_type[pred.e_type]['partial']['correct'] += 1

        else:

            # check for overlaps with any of the true entities

            for true in true_named_entities:

                pred_range = range(pred.start_offset, pred.end_offset)
                true_range = range(true.start_offset, true.end_offset)

                # Scenario IV: Offsets match, but entity type is wrong

                if true.start_offset == pred.start_offset and pred.end_offset == true.end_offset \
                        and true.e_type != pred.e_type:

                    # overall results
                    evaluation['strict']['incorrect'] += 1
                    evaluation['ent_type']['incorrect'] += 1
                    evaluation['partial']['correct'] += 1
                    evaluation['exact']['correct'] += 1

                    # aggregated by entity type results
                    evaluation_agg_entities_type[true.e_type]['strict']['incorrect'] += 1
                    evaluation_agg_entities_type[true.e_type]['ent_type']['incorrect'] += 1
                    evaluation_agg_entities_type[true.e_type]['partial']['correct'] += 1
                    evaluation_agg_entities_type[true.e_type]['exact']['correct'] += 1

                    true_which_overlapped_with_pred.append(true)
                    found_overlap = True

                    break

                # check for an overlap i.e. not exact boundary match, with true entities

                elif find_overlap(true_range, pred_range):

                    true_which_overlapped_with_pred.append(true)

                    # Scenario V: There is an overlap (but offsets do not match
                    # exactly), and the entity type is the same.
                    # 2.1 overlaps with the same entity type

                    if pred.e_type == true.e_type:

                        # overall results
                        evaluation['strict']['incorrect'] += 1
                        evaluation['ent_type']['correct'] += 1
                        evaluation['partial']['partial'] += 1
                        evaluation['exact']['incorrect'] += 1

                        # aggregated by entity type results
                        evaluation_agg_entities_type[true.e_type]['strict']['incorrect'] += 1
                        evaluation_agg_entities_type[true.e_type]['ent_type']['correct'] += 1
                        evaluation_agg_entities_type[true.e_type]['partial']['partial'] += 1
                        evaluation_agg_entities_type[true.e_type]['exact']['incorrect'] += 1

                        found_overlap = True

                        break

                    # Scenario VI: Entities overlap, but the entity type is
                    # different.

                    else:
                        # overall results
                        evaluation['strict']['incorrect'] += 1
                        evaluation['ent_type']['incorrect'] += 1
                        evaluation['partial']['partial'] += 1
                        evaluation['exact']['incorrect'] += 1

                        # aggregated by entity type results
                        # Results against the true entity

                        evaluation_agg_entities_type[true.e_type]['strict']['incorrect'] += 1
                        evaluation_agg_entities_type[true.e_type]['partial']['partial'] += 1
                        evaluation_agg_entities_type[true.e_type]['ent_type']['incorrect'] += 1
                        evaluation_agg_entities_type[true.e_type]['exact']['incorrect'] += 1

                        # Results against the predicted entity

                        # evaluation_agg_entities_type[pred.e_type]['strict']['spurious'] += 1

                        found_overlap = True

                        break

            # Scenario II: Entities are spurious (i.e., over-generated).

            if not found_overlap:

                # Overall results

                evaluation['strict']['spurious'] += 1
                evaluation['ent_type']['spurious'] += 1
                evaluation['partial']['spurious'] += 1
                evaluation['exact']['spurious'] += 1

                # Aggregated by entity type results

                # NOTE: when pred.e_type is not found in tags
                # or when it simply does not appear in the test set, then it is
                # spurious, but it is not clear where to assign it at the tag
                # level. In this case, it is applied to all target_tags
                # found in this example. This will mean that the sum of the
                # evaluation_agg_entities will not equal evaluation.

                for true in tags:                    

                    evaluation_agg_entities_type[true]['strict']['spurious'] += 1
                    evaluation_agg_entities_type[true]['ent_type']['spurious'] += 1
                    evaluation_agg_entities_type[true]['partial']['spurious'] += 1
                    evaluation_agg_entities_type[true]['exact']['spurious'] += 1

    # Scenario III: Entity was missed entirely.

    for true in true_named_entities:
        if true in true_which_overlapped_with_pred:
            continue
        else:
            # overall results
            evaluation['strict']['missed'] += 1
            evaluation['ent_type']['missed'] += 1
            evaluation['partial']['missed'] += 1
            evaluation['exact']['missed'] += 1

            # for the agg. by e_type
            evaluation_agg_entities_type[true.e_type]['strict']['missed'] += 1
            evaluation_agg_entities_type[true.e_type]['ent_type']['missed'] += 1
            evaluation_agg_entities_type[true.e_type]['partial']['missed'] += 1
            evaluation_agg_entities_type[true.e_type]['exact']['missed'] += 1

    # Compute 'possible', 'actual' according to SemEval-2013 Task 9.1 on the
    # overall results, and use these to calculate precision and recall.

    for eval_type in evaluation:
        evaluation[eval_type] = compute_actual_possible(evaluation[eval_type])

    # Compute 'possible', 'actual', and precision and recall on entity level
    # results. Start by cycling through the accumulated results.

    for entity_type, entity_level in evaluation_agg_entities_type.items():

        # Cycle through the evaluation types for each dict containing entity
        # level results.

        for eval_type in entity_level:

            evaluation_agg_entities_type[entity_type][eval_type] = compute_actual_possible(
                entity_level[eval_type]
            )

    return evaluation, evaluation_agg_entities_type


def find_overlap(true_range, pred_range):
    """Find the overlap between two ranges
    Find the overlap between two ranges. Return the overlapping values if
    present, else return an empty set().
    Examples:
    >>> find_overlap((1, 2), (2, 3))
    2
    >>> find_overlap((1, 2), (3, 4))
    set()
    """

    true_set = set(true_range)
    pred_set = set(pred_range)

    overlaps = true_set.intersection(pred_set)

    return overlaps


def compute_actual_possible(results):
    """
    Takes a result dict that has been output by compute metrics.
    Returns the results dict with actual, possible populated.
    When the results dicts is from partial or ent_type metrics, then
    partial_or_type=True to ensure the right calculation is used for
    calculating precision and recall.
    """

    correct = results['correct']
    incorrect = results['incorrect']
    partial = results['partial']
    missed = results['missed']
    spurious = results['spurious']

    # Possible: number annotations in the gold-standard which contribute to the
    # final score

    possible = correct + incorrect + partial + missed

    # Actual: number of annotations produced by the NER system

    actual = correct + incorrect + partial + spurious

    results["actual"] = actual
    results["possible"] = possible

    return results


def compute_precision_recall(results, partial_or_type=False):
    """
    Takes a result dict that has been output by compute metrics.
    Returns the results dict with precison and recall populated.
    When the results dicts is from partial or ent_type metrics, then
    partial_or_type=True to ensure the right calculation is used for
    calculating precision and recall.
    """

    actual = results["actual"]
    possible = results["possible"]
    partial = results['partial']
    correct = results['correct']

    if partial_or_type:
        precision = (correct + 0.5 * partial) / actual if actual > 0 else 0
        recall = (correct + 0.5 * partial) / possible if possible > 0 else 0

    else:
        precision = correct / actual if actual > 0 else 0
        recall = correct / possible if possible > 0 else 0

    results["precision"] = precision
    results["recall"] = recall

    return results


def compute_precision_recall_wrapper(results):
    """
    Wraps the compute_precision_recall function and runs on a dict of results
    """

    results_a = {key: compute_precision_recall(value, True) for key, value in results.items() if
                 key in ['partial', 'ent_type']}
    results_b = {key: compute_precision_recall(value) for key, value in results.items() if
                 key in ['strict', 'exact']}

    results = {**results_a, **results_b}

    return results

In [17]:
import random

random.seed(42)
result_examples_idx = random.sample(range(len(y_test)), k=30)
result_examples_y_gold = list()
result_examples_y_pred = list()
original_sentences = list()

for idx in result_examples_idx:
    result_examples_y_gold.append(y_test[idx])
    result_examples_y_pred.append(y_pred[idx])
    original_sentences.append(sentences[idx])

itr = 0
for g, p, s in zip(result_examples_y_gold, result_examples_y_pred, original_sentences):
    assert len(g) == len(p), 'Results does not seem to be the same'
    print('Sentence Nr: ', result_examples_idx[itr])
    original_s = pd.Series(s, name='WORD')
    df = pd.DataFrame(original_s)
    df['GOLD'] = g
    df['PRED'] = p
    print(df)
    itr += 1
    print()

Sentence Nr:  1309
                            WORD           GOLD           PRED
0                  (The, DET, O)              O              O
1            (symptoms, NOUN, O)              O              O
2                (and, CCONJ, O)              O              O
3                  (the, DET, O)              O              O
4             (physical, ADJ, O)              O              O
5            (findings, NOUN, O)              O              O
6          (associated, VERB, O)              O              O
7                 (with, ADP, O)              O              O
8   (ASLD, PROPN, B-RAREDISEASE)  B-RAREDISEASE  B-RAREDISEASE
9                (vary, VERB, O)              O              O
10             (greatly, ADV, O)              O              O
11                (from, ADP, O)              O              O
12               (case, NOUN, O)              O              O
13                  (to, ADP, O)              O              O
14               (case, NOUN, O)    

In [18]:
if allTypes:
    test_labels = ['DISEASE', 'RAREDISEASE', 'SYMPTOM', 'SIGN']
else:
    test_labels = ['RAREDISEASE', 'SIGN-SYM']

test_to_use_gold = result_examples_y_gold
test_to_use_pred = result_examples_y_pred

evaluator_examples = Evaluator(test_to_use_gold, test_to_use_pred, test_labels)
results_examples, results_agg_examples = evaluator_examples.evaluate()

print('## OVERALL RESULTS')
for item in results_examples.keys():
    print('\tEvaluation Metric: ', item)
    print('\t', results_examples[item])
print('## RESULTS AT ENTITY LEVEL')
for entity in results_agg_examples.keys():
    print('Entity: ', entity)
    for item in results_agg_examples[entity].keys():
        print('\tEvaluation Metric: ', item)
        print('\t', results_agg_examples[entity][item])

2021-07-08 17:15:48 root INFO: Imported 30 predictions for 30 true examples


## OVERALL RESULTS
	Evaluation Metric:  ent_type
	 {'correct': 32, 'incorrect': 8, 'partial': 0, 'missed': 10, 'spurious': 6, 'possible': 50, 'actual': 46, 'precision': 0.6956521739130435, 'recall': 0.64}
	Evaluation Metric:  partial
	 {'correct': 39, 'incorrect': 0, 'partial': 1, 'missed': 10, 'spurious': 6, 'possible': 50, 'actual': 46, 'precision': 0.8586956521739131, 'recall': 0.79}
	Evaluation Metric:  strict
	 {'correct': 31, 'incorrect': 9, 'partial': 0, 'missed': 10, 'spurious': 6, 'possible': 50, 'actual': 46, 'precision': 0.6739130434782609, 'recall': 0.62}
	Evaluation Metric:  exact
	 {'correct': 39, 'incorrect': 1, 'partial': 0, 'missed': 10, 'spurious': 6, 'possible': 50, 'actual': 46, 'precision': 0.8478260869565217, 'recall': 0.78}
## RESULTS AT ENTITY LEVEL
Entity:  DISEASE
	Evaluation Metric:  ent_type
	 {'correct': 6, 'incorrect': 7, 'partial': 0, 'missed': 6, 'spurious': 6, 'possible': 19, 'actual': 19, 'precision': 0.3157894736842105, 'recall': 0.3157894736842105}
	

In [19]:
evaluator_all = Evaluator(y_test, y_pred, test_labels)
results_all, results_agg_all = evaluator_all.evaluate()

print('## OVERALL RESULTS')
for item in results_all.keys():
    print('\tEvaluation Metric: ', item)
    print('\t', results_all[item])
print('## RESULTS AT ENTITY LEVEL')
for entity in results_agg_all.keys():
    print('Entity: ', entity)
    for item in results_agg_all[entity].keys():
        print('\tEvaluation Metric: ', item)
        print('\t', results_agg_all[entity][item])

2021-07-08 17:15:51 root INFO: Imported 1772 predictions for 1772 true examples


## OVERALL RESULTS
	Evaluation Metric:  ent_type
	 {'correct': 1641, 'incorrect': 223, 'partial': 0, 'missed': 701, 'spurious': 283, 'possible': 2565, 'actual': 2147, 'precision': 0.7643223102002794, 'recall': 0.639766081871345}
	Evaluation Metric:  partial
	 {'correct': 1729, 'incorrect': 0, 'partial': 135, 'missed': 701, 'spurious': 283, 'possible': 2565, 'actual': 2147, 'precision': 0.8367489520260829, 'recall': 0.7003898635477583}
	Evaluation Metric:  strict
	 {'correct': 1527, 'incorrect': 337, 'partial': 0, 'missed': 701, 'spurious': 283, 'possible': 2565, 'actual': 2147, 'precision': 0.7112249650675361, 'recall': 0.5953216374269006}
	Evaluation Metric:  exact
	 {'correct': 1729, 'incorrect': 135, 'partial': 0, 'missed': 701, 'spurious': 283, 'possible': 2565, 'actual': 2147, 'precision': 0.8053097345132744, 'recall': 0.674074074074074}
## RESULTS AT ENTITY LEVEL
Entity:  DISEASE
	Evaluation Metric:  ent_type
	 {'correct': 232, 'incorrect': 124, 'partial': 0, 'missed': 98, 'spuri