## Fine-tuning ClinicalBERT for clinical assertion detection

In this blog post, we show how to fine-tune a BERT language model for a downstream task: clinical assertion detection. We use the Hugging Face Transformers library and the model hub.

We will show how to:

1. Load and prepare the data for assertion detection

2. Fine-tune an auto-encoding language model such as ClinicalBERT

3. Evaluate and run inference with the trained model

### 1. Background and Context 

#### Clinical assertion detection

This work is based on the paper "Assertion Detection in Clinical Notes: Medical Language Models to the Rescue?", which uses language models for assertion detection. Assertion detection is the task of identifying the assertion status of an entity based on textual cues in unstructured text. In other words, we want to classify the assertion made about a given medical concept as being:
* present
* absent
* possible in the patient
* conditionally present in the patient under certain circumstances
* hypothetically present in the patient at some future point
* mentioned in the patient report but associated with someone else

For example, given the text "The patient recovered during the night and now denies any shortness of breath.", the model should identify that the entity "shortness of breath" is absent.

#### The data

For this demo we use data from the 2010 i2b2/VA Workshop on Natural Language Processing Challenges for Clinical Records, which presented three tasks: a concept extraction task focused on the extraction of medical concepts from patient reports; an assertion classification task focused on assigning assertion types for medical problem concepts; and a relation classification task focused on assigning relation types that hold between medical problems, tests, and treatments. The data are available to the research community from the [i2b2](https://i2b2.org/NLP/DataSets) portal under data use agreements. For more information, please consult the paper [2010 i2b2/VA challenge on concepts, assertions, and relations in clinical text](https://academic.oup.com/jamia/article/18/5/552/830538).

You need to request access, then download and extract the data.

### 2. Install the dependencies

In this example we use PyTorch and the Hugging Face libraries, and run the experiment in a Google Colab notebook. You will also need to install spaCy and the biomedical pretrained model **en_ner_bc5cdr_md**, a spaCy NER model trained on the BC5CDR corpus for DISEASE and CHEMICAL entity recognition. To install all the dependencies, run the following cell.

In [1]:
# install all the libraries and dependencies
#%pip install -r requirements.txt
#python=3.8
# conda install matplotlib numpy scikit-learn
# conda install pandas
# pip install spacy
# pip install scispacy ? gave error ignored
# pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.4.0/en_ner_bc5cdr_md-0.4.0.tar.gz
# conda install pytorch torchvision -c pytorch
# pip install ipykernel
# pip install transformers
# pip install datasets
# pip install evaluate
# pip install accelerate -U

### 3. Load and prepare the data

We will use the assertion classification data from i2b2, which in this setup consists of 170 discharge summary notes (73 in the `beth` folder and 97 in the `partners` folder).

In [187]:
# Get paths for data and labels
import os 
cwd  = os.getcwd()
labels_path_beth = os.path.join(cwd,"Data/concept_assertion_relation_training_data","beth","ast")
data_path_beth = os.path.join(cwd,"Data/concept_assertion_relation_training_data","beth","txt")
labels_path_partners = os.path.join(cwd,"Data/concept_assertion_relation_training_data","partners","ast")
data_path_partners = os.path.join(cwd,"Data/concept_assertion_relation_training_data","partners","txt")
print(labels_path_beth)
print(data_path_beth)
print(labels_path_partners)
print(data_path_partners)

C:\Users\kcaro\Documents\GitHub\clinical-adapter\Data/concept_assertion_relation_training_data\beth\ast
C:\Users\kcaro\Documents\GitHub\clinical-adapter\Data/concept_assertion_relation_training_data\beth\txt
C:\Users\kcaro\Documents\GitHub\clinical-adapter\Data/concept_assertion_relation_training_data\partners\ast
C:\Users\kcaro\Documents\GitHub\clinical-adapter\Data/concept_assertion_relation_training_data\partners\txt


In [188]:
# Create list of .txt clinical notes for each dataset (beth and partners)
def list_files_in_directory(directory_path):
    ignore = ['.DS_Store']
    files = os.listdir(directory_path)
    files = [file for file in files if os.path.isfile(os.path.join(directory_path, file)) and file not in ignore]
    files = [file[:-4] for file in files]
    
    return files

notes_beth = list_files_in_directory(data_path_beth)
notes_partners = list_files_in_directory(data_path_partners)
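
The helper above strips the extension with `file[:-4]`, which assumes every file name ends in a four-character extension. As a hedged alternative, the sketch below filters on the suffix explicitly with `pathlib`. The name `list_note_ids` is hypothetical and not part of the original notebook.

```python
from pathlib import Path

def list_note_ids(directory_path, suffix=".txt"):
    """Return the sorted stems of the files in directory_path ending with suffix."""
    directory = Path(directory_path)
    return sorted(
        p.stem for p in directory.iterdir()
        if p.is_file() and p.suffix == suffix
    )
```

Hidden files such as `.DS_Store` have an empty suffix, so they are skipped without an explicit ignore list.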

In [189]:
# Loop through each clinical note and read in the content as a list of tuples [(note_id, note_content), ...]
def load_clinical_notes(notes, data_path):
    content_notes = []
    for note in notes:
        _file = os.path.join(data_path, note + '.txt')
        # the with-block closes the file, so no explicit f.close() is needed
        with open(_file) as f:
            content = f.read()
        content_notes.append((note, content))

    return content_notes

content_beth = load_clinical_notes(notes_beth, data_path_beth)
content_partners = load_clinical_notes(notes_partners, data_path_partners)

# Merge the content into one list
content = content_beth + content_partners
print("Number of beth notes:", len(content_beth))
print("Number of partners notes:", len(content_partners))
print("Number of combined notes:", len(content))

Number of beth notes: 73
Number of partners notes: 97
Number of combined notes: 170


In [190]:
# Split each note into sentences using the spacy biomedical pretrained model (en_ner_bc5cdr_md) ~22 seconds
import en_ner_bc5cdr_md
import spacy

def split_note_sentences(content_records):
    # load spacy
    nlp1 = spacy.load("en_ner_bc5cdr_md",disable = ['parser'])
    nlp1.add_pipe('sentencizer')

    # transform the data into a list of sentences
    docs = [(r,nlp1(text)) for r,text in content_records]
    data = []
    for r,doc in docs:
        for s in doc.sents:
            sentence = str.strip(str(s))
            sentence = sentence.replace("\n"," ")
            data.append((r,sentence))
    
    return data

data = split_note_sentences(content)



In [191]:
data[0]

('record-105',
 'Admission Date : 2017-06-13 Discharge Date : 2017-06-17 Date of Birth : 1956-02-17 Sex : M Service : CARDIOTHORACIC Allergies : Patient recorded as having No Known Allergies to Drugs Attending:Jordan U Kostohryz , M.D. Chief Complaint: recent mild angina with exertion Major Surgical or Invasive Procedure :  emergency CABG X 3 ( 2017-06-13 )( LIMA to LAD , SVG to ramus , SVG to OM ) History of Present Illness : 61 yo African-American-Hispanic male had abnormal EKG found as part of pre-op eval.')

The next step is to load and process the labels. They are provided as `.ast` files, one per note, with one annotated concept per line.
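
To make the file format concrete, here is a minimal sketch that parses a single `.ast` line with one regular expression. It assumes every line follows the `c="..." start end||t="..."||a="..."` layout seen in this dataset; `AST_LINE` and `parse_ast_line` are hypothetical names used only for illustration.

```python
import re

# One regex for the whole line: concept text, location span, type, assertion.
AST_LINE = re.compile(
    r'c="(?P<entity>.*)" (?P<loc>\d+:\d+ \d+:\d+)'
    r'\|\|t="(?P<label>[^"]*)"\|\|a="(?P<assertion>[^"]*)"'
)

def parse_ast_line(line):
    """Parse one .ast line into a dict, or return None if it does not match."""
    match = AST_LINE.match(line.strip())
    return match.groupdict() if match else None

example = 'c="left basilar atelectasis" 55:6 55:8||t="problem"||a="present"'
parsed = parse_ast_line(example)
print(parsed)
```

Returning `None` for lines that do not match makes malformed lines easy to count and inspect instead of failing with an `IndexError`.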

In [127]:
notes_beth[0]

'record-105'

In [171]:
# Walk through the parsing on a single file before wrapping it in a function
import re

labels_notes = []
_file = os.path.join(labels_path_beth, notes_beth[0] + '.ast')
with open(_file) as f:
    content = f.readlines()

file_data = []
for line in content:
    # each line looks like: c="entity text" 55:6 55:8||t="problem"||a="present"
    ast = line.strip().split('||')
    line_entity = []

    assertion = ast[2].split('=')
    entity_label = ast[1].split("=")
    entity_text = re.findall('"([^"]*)"', ast[0])
    # everything after the closing quote of c="..." is the location span
    pattern = re.compile(r'c="[^"]*"([^"]+)')
    loc = pattern.search(ast[0]).group(1)

    line_entity.append(notes_beth[0])
    line_entity.append(assertion[1].replace('"',''))
    line_entity.append(entity_label[1].replace('"',''))
    line_entity.append(entity_text[0])
    line_entity.append(loc[1:])
    file_data.append(line_entity)
labels_notes.append(file_data)

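A raw line from an `.ast` file looks like this:

c="left basilar atelectasis" 55:6 55:8||t="problem"||a="present"

Splitting on `||`, extracting the quoted entity text, and taking the remainder as the location gives, for another line of the same file:

['c="bilateral pleural effusions" 57:0 57:2', 't="problem"', 'a="present"']
['bilateral pleural effusions']
 57:0 57:2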





In [180]:
# Load and process the labels into a list of lists for each note: [[['note_id', 'assertion', 'label', 'entity', 'loc'], ...], ...]
# They are provided as .ast files.
import re
import pandas as pd 

# everything after the closing quote of c="..." is the location span, e.g. ' 55:6 55:8'
LOC_PATTERN = re.compile(r'c="[^"]*"([^"]+)')

def load_notes_labels(notes, labels_path):
    labels_notes = []
    for note in notes:
        _file = os.path.join(labels_path, note + '.ast')
        # the with-block closes the file, so no explicit f.close() is needed
        with open(_file) as f:
            content = f.readlines()
        file_data = []
        for line in content:
            ast = line.strip().split('||')

            assertion = ast[2].split('=')
            entity_label = ast[1].split("=")
            entity_text = re.findall('"([^"]*)"', ast[0])
            loc = LOC_PATTERN.search(ast[0]).group(1)

            file_data.append([
                note,
                assertion[1].replace('"',''),
                entity_label[1].replace('"',''),
                entity_text[0],
                loc[1:],
            ])
        labels_notes.append(file_data)

    return labels_notes

labels_beth  = load_notes_labels(notes_beth, labels_path_beth)
labels_partners = load_notes_labels(notes_partners, labels_path_partners)

# Merge the labels into one list
labels = labels_beth + labels_partners
print("Number of beth labels:", len(labels_beth))
print("Number of partners labels:", len(labels_partners))
print("Number of combined labels:", len(labels))

# labels in a dataframe
data_labels = [line for f in labels for line in f]
df_data_labels = pd.DataFrame(data_labels,columns=['record','assertion','label','entity', 'loc'])

print("Number of entities classified:", len(df_data_labels))
df_data_labels.head()

Number of beth labels: 73
Number of partners labels: 97
Number of combined labels: 170
Number of entities classified: 7073


Unnamed: 0,record,assertion,label,entity,loc
0,record-105,present,problem,left basilar atelectasis,55:6 55:8
1,record-105,present,problem,ventral hernia,143:1 143:2
2,record-105,present,problem,htn,26:0 26:0
3,record-105,absent,problem,spontaneous echo contrast,68:1 68:3
4,record-105,present,problem,80% lm lesion,21:6 21:8


In [182]:
df_data_labels[df_data_labels['record'] == 'record-15']

Unnamed: 0,record,assertion,label,entity,loc
1226,record-15,present,problem,very mild left lower extremity pain,15:2 15:7
1227,record-15,present,problem,congestive heart failure,20:1 20:3
1228,record-15,present,problem,diabetes mellitus,24:1 24:2
1229,record-15,present,problem,acute renal failure,21:1 21:3
1230,record-15,present,problem,osteomyelitis,34:12 34:12
1231,record-15,present,problem,osteomyelitis,33:13 33:13
1232,record-15,present,problem,osteomyelitis,23:1 23:1
1233,record-15,present,problem,acute myocardial infarction,22:1 22:3
1234,record-15,present,problem,coronary artery disease,19:1 19:3


In [177]:
labels_beth[0]

[['record-105', 'present', 'problem', 'left basilar atelectasis', '55:6 55:8'],
 ['record-105', 'present', 'problem', 'ventral hernia', '143:1 143:2'],
 ['record-105', 'present', 'problem', 'htn', '26:0 26:0'],
 ['record-105', 'absent', 'problem', 'spontaneous echo contrast', '68:1 68:3'],
 ['record-105', 'present', 'problem', '80% lm lesion', '21:6 21:8'],
 ['record-105', 'absent', 'problem', 'interstitial edema', '54:2 54:3'],
 ['record-105', 'present', 'problem', 'abnormal ekg', '18:5 18:6'],
 ['record-105', 'present', 'problem', 'htn', '143:0 143:0'],
 ['record-105', 'conditional', 'problem', 'recent mild angina', '14:2 14:4'],
 ['record-105', 'present', 'problem', 'perfusion defects', '19:12 19:13'],
 ['record-105', 'present', 'problem', 'ai', '98:3 98:3'],
 ['record-105', 'present', 'problem', 'cad', '142:0 142:0'],
 ['record-105', 'absent', 'problem', 'known allergies', '12:5 12:6'],
 ['record-105', 'present', 'problem', 'abnormal stress test', '19:2 19:4'],
 ['record-105', 'pre

In [8]:
# len(df_data_labels)
# len(df_data_labels[df_data_labels['label'] == 'problem'])

In [83]:
len(data), len(new_data)

(9921, 10679)

### 4. Annotate text for clinical assertion detection

After pre-processing the data, we need to mark each entity in our training data by surrounding it with the token `[entity]`.

In [9]:
import re

def clean_text(text):
    """
    Applies some pre-processing on the given text.

    Steps :
    - Removing HTML tags
    - Removing punctuation
    - Lowering text
    """
    
    # remove HTML tags
    text = re.sub(r'<.*?>', '', text)
    
    # remove the characters [\], ['] and ["]
    text = re.sub(r"\\", "", text)    
    text = re.sub(r"\'", "", text)    
    text = re.sub(r"\"", "", text)    
    
    # convert text to lowercase
    text = text.strip().lower()
    
    # remove all non-ASCII characters:
    text = re.sub(r'[^\x00-\x7f]',r'', text) 
    
    # replace punctuation characters with spaces
    filters='!"\'#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n'
    translate_dict = dict((c, " ") for c in filters)
    translate_map = str.maketrans(translate_dict)
    text = text.translate(translate_map)
    text = " ".join(text.split())
    
    return text

First, we need to link the sentences to the labels. We do so by using the record id and searching for the entity text in each sentence of that record.

In [10]:
def linking_train_labels_data(data, df_data_labels):
    new_data = []
    for r, sent in data:
        # clean the sentence once per sentence, not once per label row
        sentence = clean_text(sent)
        rows = df_data_labels.loc[df_data_labels['record'] == r, ['entity', 'assertion']]
        for index, row in rows.iterrows():
            entity = clean_text(row['entity'])
            try:
                # re.escape guards against any regex metacharacters left in the entity
                if re.search(r'\b' + re.escape(entity) + r'\b', sentence):
                    new_data.append((r, entity, sentence, row['assertion']))
            except re.error:
                print(r)
                print("entity:", entity)
                print("****")

    return new_data
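
The linking above matches on entity text only and does not use the `loc` column. As a hedged alternative, the sketch below uses the span directly to mark the entity in the raw note. It assumes that line numbers in the span start at 1, that token indices start at 0 (consistent with spans such as `26:0 26:0` in the labels), that tokens are separated by whitespace, and that a span stays on one line. `annotate_by_offsets` is a hypothetical helper, and it works on raw note lines rather than on the spaCy sentences produced earlier.

```python
def annotate_by_offsets(note_text, loc, marker="[entity]"):
    """Wrap the tokens addressed by a span such as '2:4 2:6' in markers.

    Assumes 1-based line numbers, 0-based token indices, whitespace
    tokenisation, and a span that stays on a single line.
    """
    start, end = loc.split()
    start_line, start_tok = (int(x) for x in start.split(":"))
    end_line, end_tok = (int(x) for x in end.split(":"))
    if start_line != end_line:
        raise ValueError("multi-line spans are not handled in this sketch")

    tokens = note_text.split("\n")[start_line - 1].split()
    annotated = (tokens[:start_tok] + [marker] + tokens[start_tok:end_tok + 1]
                 + [marker] + tokens[end_tok + 1:])
    return " ".join(annotated)

note = "Chief Complaint :\nThe patient denies any shortness of breath .\n"
print(annotate_by_offsets(note, "2:4 2:6"))
```

Because each annotation carries its own span, this approach ties every label to exactly one position in the note, at the cost of working with whole lines instead of sentences.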

In [81]:
len(data), len(df_data_labels)

(9921, 7073)

In [11]:
# ~50 seconds
new_data = linking_train_labels_data(data,df_data_labels)

In [77]:
new_data[0]

('record-105',
 'abnormal ekg',
 'admission date 2017 06 13 discharge date 2017 06 17 date of birth 1956 02 17 sex m service cardiothoracic allergies patient recorded as having no known allergies to drugs attending jordan u kostohryz m d chief complaint recent mild angina with exertion major surgical or invasive procedure emergency cabg x 3 2017 06 13 lima to lad svg to ramus svg to om history of present illness 61 yo african american hispanic male had abnormal ekg found as part of pre op eval',
 'present')

Then, we can annotate each sentence with the token [entity].

In [13]:
def annotate_data(new_data):
    processed_data = []
    for r, entity, text, label in new_data:
        match = re.search(r'\b' + re.escape(entity) + r'\b', text)

        # wrap the first occurrence of the entity in [entity] markers
        res = (text[:match.start()] + '[entity] ' + text[match.start():match.end()]
               + ' [entity]' + text[match.end():])
        processed_data.append((r, entity, res, label))

    return processed_data

processed_data =  annotate_data(new_data)

In [42]:
processed_data[0]

('record-105',
 'abnormal ekg',
 'admission date 2017 06 13 discharge date 2017 06 17 date of birth 1956 02 17 sex m service cardiothoracic allergies patient recorded as having no known allergies to drugs attending jordan u kostohryz m d chief complaint recent mild angina with exertion major surgical or invasive procedure emergency cabg x 3 2017 06 13 lima to lad svg to ramus svg to om history of present illness 61 yo african american hispanic male had [entity] abnormal ekg [entity] found as part of pre op eval',
 'present')
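
Note that `annotate_data` marks only the first occurrence of the entity in a sentence. If a sentence mentions the same entity more than once, one option is to emit one annotated copy per occurrence. The sketch below illustrates that idea; `annotate_all_occurrences` is a hypothetical helper, not part of the original pipeline.

```python
import re

def annotate_all_occurrences(text, entity, marker="[entity]"):
    """Return one annotated copy of text per whole-word occurrence of entity."""
    pattern = re.compile(r"\b" + re.escape(entity) + r"\b")
    return [
        f"{text[:m.start()]}{marker} {m.group(0)} {marker}{text[m.end():]}"
        for m in pattern.finditer(text)
    ]

for annotated in annotate_all_occurrences("htn and worsening htn", "htn"):
    print(annotated)
```

Whether each copy should inherit the same assertion label depends on the annotation, so this is best combined with the location information in the labels.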

Finally, we create a dataframe in which, for our example, we keep only three assertion labels: present, absent and possible.

In [15]:
prepare_data = [{'record': r, 'sentence':text , 'label':label,'idx':idx} for idx,(r, entity, text,label) in enumerate(processed_data)]

df_i2b2 = pd.DataFrame(prepare_data)
df_i2b2 = df_i2b2[(df_i2b2.label=='present') | (df_i2b2.label=='absent') | (df_i2b2.label=='possible') ].copy()
df_i2b2

Unnamed: 0,record,sentence,label,idx
0,record-105,admission date 2017 06 13 discharge date 2017 ...,present,0
2,record-105,admission date 2017 06 13 discharge date 2017 ...,absent,2
3,record-105,for [entity] ventral hernia [entity] repair,present,3
4,record-105,for [entity] ventral hernia [entity] repair,present,4
5,record-105,had subsequent abnormal stress test and pefusi...,present,5
...,...,...,...,...
10662,989519730_WGH,when she was taking clear liquids her [entity]...,present,10662
10665,989519730_WGH,[entity] pain [entity] regimen with good affect,present,10665
10668,989519730_WGH,with the help of physical therapy the patient ...,present,10668
10671,989519730_WGH,[entity] pain [entity],present,10671


In [45]:
df_i2b2[df_i2b2['record'] == 'record-105'][48:70]

Unnamed: 0,record,sentence,label,idx
49,record-105,the right ventricular cavity is [entity] mildl...,present,49
50,record-105,there are [entity] simple atheroma [entity] in...,present,50
51,record-105,there are [entity] simple atheroma [entity] in...,present,51
52,record-105,there are [entity] simple atheroma [entity] in...,present,52
53,record-105,there are [entity] simple atheroma [entity] in...,present,53
54,record-105,there are [entity] simple atheroma [entity] in...,present,54
55,record-105,there are [entity] simple atheroma [entity] in...,present,55
56,record-105,there are [entity] simple atheroma [entity] in...,present,56
57,record-105,there are [entity] simple atheroma [entity] in...,present,57
58,record-105,there are [entity] simple atheroma [entity] in...,present,58


In [49]:
# print the full sentences for rows 50 to 65 to inspect the repeated entries
for i in range(50, 66):
    print(df_i2b2.loc[i, 'sentence'])

there are [entity] simple atheroma [entity] in the aortic root
there are [entity] simple atheroma [entity] in the aortic root
there are [entity] simple atheroma [entity] in the aortic root
there are [entity] simple atheroma [entity] in the aortic root
there are [entity] simple atheroma [entity] in the ascending aorta
there are [entity] simple atheroma [entity] in the ascending aorta
there are [entity] simple atheroma [entity] in the ascending aorta
there are [entity] simple atheroma [entity] in the ascending aorta
there are [entity] simple atheroma [entity] in the aortic arch
there are [entity] simple atheroma [entity] in the aortic arch
there are [entity] simple atheroma [entity] in the aortic arch
there are [entity] simple atheroma [entity] in the aortic arch
there are [entity] simple atheroma [entity] in the aortic arch
there are [entity] simple atheroma [entity] in the descending thoracic aorta
there are [entity] simple atheroma [entity] in the descending thoracic aorta
there are [

In [18]:
df_i2b2.iloc[0]

record                                             record-105
sentence    admission date 2017 06 13 discharge date 2017 ...
label                                                 present
idx                                                         0
Name: 0, dtype: object

In [54]:
prepare_data = [{'record': r, 'sentence': text.strip(), 'label': label, 'idx': idx}
                for idx, (r, entity, text, label) in enumerate(processed_data)]

df_i2b2_2 = pd.DataFrame(prepare_data)
df_i2b2_2 = df_i2b2_2[(df_i2b2_2.label == 'present') | (df_i2b2_2.label == 'absent') | (df_i2b2_2.label == 'possible')].copy()

In [56]:
df_i2b2_2[df_i2b2_2['record'] == 'record-105'][48:70]

Unnamed: 0,record,sentence,label,idx
49,record-105,the right ventricular cavity is [entity] mildl...,present,49
50,record-105,there are [entity] simple atheroma [entity] in...,present,50
51,record-105,there are [entity] simple atheroma [entity] in...,present,51
52,record-105,there are [entity] simple atheroma [entity] in...,present,52
53,record-105,there are [entity] simple atheroma [entity] in...,present,53
54,record-105,there are [entity] simple atheroma [entity] in...,present,54
55,record-105,there are [entity] simple atheroma [entity] in...,present,55
56,record-105,there are [entity] simple atheroma [entity] in...,present,56
57,record-105,there are [entity] simple atheroma [entity] in...,present,57
58,record-105,there are [entity] simple atheroma [entity] in...,present,58


In [70]:
for i in new_data:
    print(i[1])

abnormal ekg
recent mild angina
known allergies
ventral hernia
ventral hernia
perfusion defects
abnormal stress test
mild mr
inferior hk
mild lae
mild lvh
80 lm lesion
htn
htn
severe systolic htn
ventral hernia
htn
htn
ventral hernia
right facial droop
gsw
interstitial edema
left basilar atelectasis
right basilar atelectasis
bilateral pleural effusions
bibasilar minor atelectasis
persistent pleural effusions
mildly dilated
mildly dilated
spontaneous echo contrast
spontaneous echo contrast
spontaneous echo contrast
thrombus
spontaneous echo contrast
mass
mass thrombus
thrombus
atrial septal defect
mild symmetric left ventricular hypertrophy
mild symmetric left ventricular hypertrophy
mild symmetric left ventricular hypertrophy
mild symmetric left ventricular hypertrophy
left ventricular aneurysm
mild regional left ventricular systolic dysfunction
mildly depressed
ventricular septal defect
mild inferior hypokinesis
resting regional wall motion abnormalities
mildly dilated
mildly dilated


In [86]:
len(df_i2b2)

9863

### 5.  Split the data and create the dataset

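We use sklearn to split the data into train, validation and test sets. To keep the demo fast, we first subsample 20% of the rows. We then hold out 10% of the sample for testing and split the remaining 90% into 80% training and 20% validation, which gives roughly 72% for training, 18% for validation and 10% for testing.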
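
With `test_size=0.1` followed by `train_size=0.8` on the remainder, the overall proportions of a two-stage split can be checked with a few lines of arithmetic. The helper name `effective_split` is hypothetical and used only for this illustration.

```python
def effective_split(test_size, train_size_of_rest):
    """Return the overall (train, validation, test) fractions of a two-stage split."""
    rest = 1.0 - test_size
    train = rest * train_size_of_rest
    valid = rest - train
    return train, valid, test_size

train, valid, test = effective_split(test_size=0.1, train_size_of_rest=0.8)
print(f"train={train:.2f} valid={valid:.2f} test={test:.2f}")
```

To obtain an exact 80/10/10 split instead, the second call would need `train_size=8/9` on the remaining 90%.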

In [35]:
from sklearn.model_selection import train_test_split

df_i2b2 = df_i2b2.sample(frac=0.2).copy()

X = df_i2b2['sentence']
y = df_i2b2['label']

X_train_valid,X_test,y_train_valid, y_test= train_test_split(X,y,test_size=0.1,stratify=y,random_state=42)
X_train,X_valid,y_train,y_valid = train_test_split(X_train_valid,y_train_valid,train_size=0.8,random_state=42,stratify=y_train_valid)

print(f"X_train shape {X_train.shape} y_train shape : {y_train.shape}")
print(f"X_valid shape {X_valid.shape} y_valid shape : {y_valid.shape}")
print(f"X_test shape {X_test.shape} y_test shape : {y_test.shape}")

X_train shape (873,) y_train shape : (873,)
X_valid shape (219,) y_valid shape : (219,)
X_test shape (122,) y_test shape : (122,)


In [36]:
import numpy as np
print(X_train.shape,y_train.shape)
np.vstack((y_train,X_train))

(873,) (873,)


array([['present', 'present', 'absent', ..., 'present', 'present',
        'absent'],
       ['[entity] nephrolithiasis [entity]',
        'mr mackey also had signs and symptoms consistent with cervical myelopathy a cervical mri scan showed very impressive disc herniations at c5 c6 less so at c4 c5 with [entity] clear cut cord compression [entity] particularly on the right side c5 c6',
        'pt feeling well denies [entity] abd pain [entity] n v lh sob cp',
        ...,
        'past medical history aicd pocket infection c b mssa bacteremia pericardial effusion s p mediastinal exploration evacuation of pericardial effusion hematoma 2017 07 24 ischemic colitis and ischemic liver 2017 06 22 post air embolism from post mediastinal exploration cad s p lad ptca 33 years ago [entity] t2dm [entity] c b neuropathy and nephropathy copd hypothyroidism cva s p bovine avr 1999 hyperlipidemia gerd chronic lbp lumbar sympathectomy social history the patient is a retired truck driver',
        '[en

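We also use sklearn to encode our labels. `LabelEncoder` assigns integer ids in sorted order of the class names, so here `absent` is 0, `possible` is 1 and `present` is 2.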

In [37]:
import numpy as np
from sklearn.preprocessing import LabelEncoder

print("Encoding Labels .....")
encoder = LabelEncoder()
encoder.fit(y_train)
y_train_encode = np.asarray(encoder.transform(y_train))
y_valid_encode = np.asarray(encoder.transform(y_valid))
y_test_encode = np.asarray(encoder.transform(y_test))

Encoding Labels .....
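
The id assigned to each class matters later, when the model is given an `id2label` mapping. As a small sketch that avoids any sklearn dependency, sorting the unique labels reproduces the ordering `LabelEncoder` stores in `classes_`. The toy `labels` list below is made up for illustration.

```python
labels = ["present", "absent", "possible", "present", "absent"]

# LabelEncoder stores its classes in sorted order, so sorting the unique
# labels reproduces the same label -> id assignment.
classes = sorted(set(labels))
label2id = {label: idx for idx, label in enumerate(classes)}
id2label = {idx: label for label, idx in label2id.items()}

print(label2id)
```

In the notebook itself, the same mapping can be read from `encoder.classes_` after fitting.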


In [38]:
X_train

6220                    [entity] nephrolithiasis [entity]
671     mr mackey also had signs and symptoms consiste...
3849    pt feeling well denies [entity] abd pain [enti...
5917    pertinent results 2018 12 26 11 10 pm lactate ...
4968    admission date 2013 12 24 discharge date 2014 ...
                              ...                        
3948    the patient denies [entity] any chest pain [en...
3229    past medical history htn afib sss copd cad pac...
5354    past medical history aicd pocket infection c b...
6245    [entity] end stage renal disease [entity] stat...
2962    his right tib fib fracture was repaired by ort...
Name: sentence, Length: 873, dtype: object

In [39]:
import pandas as pd
from transformers import AutoTokenizer, DataCollatorWithPadding
from datasets import Dataset, DatasetDict


train_df = pd.DataFrame(X_train)
valid_df = pd.DataFrame(X_valid)
test_df = pd.DataFrame(X_test)

train_df['label'] = y_train_encode.tolist()
valid_df['label'] = y_valid_encode.tolist()
test_df['label'] = y_test_encode.tolist()

print(train_df.head())

ds = DatasetDict ({
 'train': Dataset.from_pandas(train_df),
 'validation': Dataset.from_pandas(valid_df),
 'test': Dataset.from_pandas(test_df)
})


                                               sentence  label
6220                  [entity] nephrolithiasis [entity]      2
671   mr mackey also had signs and symptoms consiste...      2
3849  pt feeling well denies [entity] abd pain [enti...      0
5917  pertinent results 2018 12 26 11 10 pm lactate ...      2
4968  admission date 2013 12 24 discharge date 2014 ...      2


In [22]:
ds

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', '__index_level_0__'],
        num_rows: 873
    })
    validation: Dataset({
        features: ['sentence', 'label', '__index_level_0__'],
        num_rows: 219
    })
    test: Dataset({
        features: ['sentence', 'label', '__index_level_0__'],
        num_rows: 122
    })
})

### 6.  Fine-tuning ClinicalBERT

In [23]:
import torch 
# use a CUDA GPU if available (e.g. on Colab), else Apple MPS, else CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

In [24]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# ids must follow the LabelEncoder order (alphabetical): absent=0, possible=1, present=2
id2label = {0: 'ABSENT', 1: 'POSSIBLE', 2: 'PRESENT'}
label2id = {label: idx for idx, label in id2label.items()}

tokenizer_clinical_bio  = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_Discharge_Summary_BERT",model_max_length=150)
model_clinical = AutoModelForSequenceClassification.from_pretrained("emilyalsentzer/Bio_Discharge_Summary_BERT", 
                                                                    num_labels=3, id2label=id2label, label2id=label2id)


In [25]:
model_clinical

BertForSequenceClassification(
  (shared_parameters): ModuleDict()
  (bert): BertModel(
    (shared_parameters): ModuleDict()
    (invertible_adapters): ModuleDict()
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(
                in_features=768, out_features=768, bias=True
                (loras): ModuleDict()
              )
              (key): Linear(
                in_features=768, out_features=768, bias=True
                (loras): ModuleDict()
              )
              (value): Linear(
                in_features=7

In [26]:
model_clinical = model_clinical.to(device)

In [27]:
special_tokens_dict = {"additional_special_tokens": ["[entity]"]}
num_added_toks = tokenizer_clinical_bio.add_special_tokens(special_tokens_dict,False)

print("We have added", num_added_toks, "tokens")
# Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e., the length of the tokenizer.
model_clinical.resize_token_embeddings(len(tokenizer_clinical_bio))

We have added 1 tokens


Embedding(28997, 768)

In [28]:
def tokenize_function(example):
    return tokenizer_clinical_bio(example["sentence"],   padding="max_length", truncation=True)

In [29]:
tokenized_ds = ds.map(tokenize_function, batched=True)
tokenized_ds = tokenized_ds.rename_column("label", "labels")
tokenized_ds = tokenized_ds.remove_columns(["sentence"])
tokenized_ds = tokenized_ds.remove_columns(["__index_level_0__"])
tokenized_ds.set_format("torch")

Map: 100%|█████████████████████████████████████████████████████████████████| 873/873 [00:00<00:00, 10124.42 examples/s]
Map: 100%|██████████████████████████████████████████████████████████████████| 219/219 [00:00<00:00, 8974.36 examples/s]
Map: 100%|█████████████████████████████████████████████████████████████████| 122/122 [00:00<00:00, 21061.29 examples/s]


In [30]:
import numpy as np
import evaluate

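# Load the metric once instead of on every evaluation call.
metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    # The predicted class is the index of the highest logit.
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)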
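
For readers who want to see what the accuracy metric computes, here is a small plain-Python sketch of the same idea: take the argmax of each row of logits and count how often it matches the label. The helper names and the toy logits are made up for illustration; this is not the `evaluate` implementation.

```python
def argmax(values):
    """Index of the largest value (the first one wins on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def accuracy(logits, labels):
    """Fraction of rows whose highest logit matches the reference label."""
    predictions = [argmax(row) for row in logits]
    correct = sum(p == y for p, y in zip(predictions, labels))
    return {"accuracy": correct / len(labels)}


# Four toy examples with three classes each.
toy_logits = [
    [2.0, 0.1, -1.0],
    [0.2, 0.3, 1.5],
    [0.0, 3.0, 0.5],
    [1.0, 0.9, 0.8],
]
toy_labels = [0, 2, 1, 1]

result = accuracy(toy_logits, toy_labels)
print(result)
```

Three of the four toy predictions match their labels, so the sketch reports an accuracy of 0.75.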

In [31]:
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir="clinbert_trainer",
    evaluation_strategy="epoch",
    learning_rate=1e-5,
    num_train_epochs=1,
)

trainer = Trainer(
    model=model_clinical,
    args=training_args,
    train_dataset=tokenized_ds['train'],
    eval_dataset=tokenized_ds['validation'],
    compute_metrics=compute_metrics,
)

In [None]:
trainer.train()

***** Running training *****
  Num examples = 873
  Num Epochs = 1
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 110
  Number of trainable parameters = 108313347
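
Two numbers in this log can be checked by hand. The sketch below recomputes the number of optimization steps and the trainable parameter count. The hidden size 768, the 12 layers, and the vocabulary of 28997 come from the printed model. The intermediate size 3072, the 512 positions, and the 2 token types are assumed BERT-base defaults, and `num_labels=3` is an assumption chosen because it reproduces the logged total. The empty `loras` dicts are taken to contribute no parameters.

```python
import math


def optimization_steps(num_examples, batch_size, epochs):
    """Batches per epoch (the last one may be partial) times epochs.

    Assumes gradient accumulation of 1, as in the log above.
    """
    return math.ceil(num_examples / batch_size) * epochs


def linear(n_in, n_out):
    """Parameter count of a dense layer with bias."""
    return n_in * n_out + n_out


def bert_classifier_params(vocab, hidden=768, layers=12, intermediate=3072,
                           max_pos=512, type_vocab=2, num_labels=3):
    """Parameter count of a BERT encoder with pooler and classification head."""
    layer_norm = 2 * hidden  # weight + bias
    embeddings = (vocab + max_pos + type_vocab) * hidden + layer_norm
    attention = 4 * linear(hidden, hidden) + layer_norm  # query, key, value, output
    feed_forward = (linear(hidden, intermediate)
                    + linear(intermediate, hidden)
                    + layer_norm)
    pooler = linear(hidden, hidden)
    classifier = linear(hidden, num_labels)
    return embeddings + layers * (attention + feed_forward) + pooler + classifier


steps = optimization_steps(num_examples=873, batch_size=8, epochs=1)
params = bert_classifier_params(vocab=28997)

print("steps:", steps)
print("params:", params)
```

Under these assumptions the sketch gives 110 steps and 108313347 parameters, matching the log.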


Epoch,Training Loss,Validation Loss


In [None]:
trainer.evaluate()

In [88]:
len(df_i2b2)

9863

In [91]:
dropped = df_i2b2[['record', 'sentence', 'label']]
dropped = dropped.drop_duplicates()
dropped

Unnamed: 0,record,sentence,label
0,record-105,admission date 2017 06 13 discharge date 2017 ...,present
2,record-105,admission date 2017 06 13 discharge date 2017 ...,absent
3,record-105,for [entity] ventral hernia [entity] repair,present
5,record-105,had subsequent abnormal stress test and pefusi...,present
6,record-105,had subsequent [entity] abnormal stress test [...,present
...,...,...,...
10658,989519730_WGH,she was maintained on an epidural and pca for ...,present
10662,989519730_WGH,when she was taking clear liquids her [entity]...,present
10665,989519730_WGH,[entity] pain [entity] regimen with good affect,present
10668,989519730_WGH,with the help of physical therapy the patient ...,present


In [103]:
df_data_labels[(df_data_labels['record'] == 'record-105') & (df_data_labels['entity'] == 'simple atheroma')]

Unnamed: 0,record,assertion,label,entity
27,record-105,present,problem,simple atheroma
28,record-105,present,problem,simple atheroma
30,record-105,present,problem,simple atheroma
33,record-105,present,problem,simple atheroma


In [121]:
dropped[35:45]

Unnamed: 0,index,record,sentence,label
35,45,record-105,there is no [entity] ventricular septal defect...,absent
36,46,record-105,resting regional wall motion abnormalities inc...,present
37,47,record-105,[entity] resting regional wall motion abnormal...,present
38,48,record-105,the right ventricular cavity is [entity] mildl...,present
39,50,record-105,there are [entity] simple atheroma [entity] in...,present
40,54,record-105,there are [entity] simple atheroma [entity] in...,present
41,58,record-105,there are [entity] simple atheroma [entity] in...,present
42,62,record-105,there are [entity] simple atheroma [entity] in...,present
43,66,record-105,the aortic valve leaflets are [entity] mildly ...,present
44,68,record-105,no masses or [entity] vegetations [entity] are...,absent


In [123]:
# Print the full sentence for rows 39 to 42, which look identical when truncated.
for i in range(39, 43):
    print(dropped.iloc[i]["sentence"])

there are [entity] simple atheroma [entity] in the aortic root
there are [entity] simple atheroma [entity] in the ascending aorta
there are [entity] simple atheroma [entity] in the aortic arch
there are [entity] simple atheroma [entity] in the descending thoracic aorta
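
The four sentences above only look like duplicates in the truncated table; they differ at the end, so dropping duplicates keeps all of them. A plain-Python sketch of the same effect, using sentences from the outputs in this section (the helpers `drop_duplicates` and `truncate` and the display width are hypothetical, and this is not pandas):

```python
rows = [
    ("record-105", "for [entity] ventral hernia [entity] repair", "present"),
    ("record-105", "for [entity] ventral hernia [entity] repair", "present"),
    ("record-105", "there are [entity] simple atheroma [entity] in the aortic root", "present"),
    ("record-105", "there are [entity] simple atheroma [entity] in the ascending aorta", "present"),
    ("record-105", "there are [entity] simple atheroma [entity] in the aortic arch", "present"),
    ("record-105", "there are [entity] simple atheroma [entity] in the descending thoracic aorta", "present"),
]


def drop_duplicates(items):
    """Keep the first occurrence of each exact row, preserving order."""
    return list(dict.fromkeys(items))


def truncate(text, width=49):
    """Shorten text for display, the way a truncated table column would."""
    return text if len(text) <= width else text[: width - 3] + "..."


unique_rows = drop_duplicates(rows)
displayed = {truncate(sentence) for _, sentence, _ in unique_rows}

print(len(rows), "rows ->", len(unique_rows), "unique rows")
print(len(displayed), "distinct strings after truncation")
```

Only the exact repeat of the ventral hernia row is removed, while the four atheroma sentences survive even though they collapse to a single string once truncated for display.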


In [96]:
df_i2b2[df_i2b2['record'] == 'record-105']

Unnamed: 0,record,sentence,label,idx
0,record-105,admission date 2017 06 13 discharge date 2017 ...,present,0
2,record-105,admission date 2017 06 13 discharge date 2017 ...,absent,2
3,record-105,for [entity] ventral hernia [entity] repair,present,3
4,record-105,for [entity] ventral hernia [entity] repair,present,4
5,record-105,had subsequent abnormal stress test and pefusi...,present,5
...,...,...,...,...
84,record-105,disp 14 tablet s refills 0 discharge dispositi...,present,84
85,record-105,disp 14 tablet s refills 0 discharge dispositi...,present,85
86,record-105,disp 14 tablet s refills 0 discharge dispositi...,present,86
87,record-105,disp 14 tablet s refills 0 discharge dispositi...,present,87
