## Fine-tuning ClinicalBERT for clinical assertion detection

In this blog post, we will show how to fine-tune a BERT language model for a downstream task: clinical assertion detection. We are going to leverage the Hugging Face Transformers library and the Model Hub.

We will show how to:

1. Load and prepare the data for assertion detection

2. Fine-tune an auto-encoding language model such as ClinicalBERT

3. Evaluate the trained model and run inference with it

### 1. Background and Context 

#### Clinical assertion detection

This work is based on the paper "Assertion Detection in Clinical Notes: Medical Language Models to the Rescue?", which uses language models for assertion detection. Assertion detection is the task of identifying the assertion status of an entity from textual cues in unstructured text. In other words, we want to classify the assertion made about a given medical concept as one of:
* present
* absent
* possible in the patient
* conditionally present in the patient under certain circumstances
* hypothetically present in the patient at some future point
* mentioned in the patient report but associated with someone else

For example, given the text "The patient recovered during the night and now denies any shortness of breath.", the model should identify that the entity *shortness of breath* is absent.
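In this post, each (sentence, entity) pair becomes an ordinary sentence-classification example: the entity is surrounded by `[entity]` marker tokens and the whole sentence is labeled with the entity's assertion. A minimal sketch of that input format follows; the helper `make_example` is hypothetical and not part of the original notebook, which additionally lowercases the text and strips punctuation before marking it.

```python
import re

def make_example(sentence, entity, label, marker="[entity]"):
    """Wrap the first whole-word occurrence of `entity` in `sentence` with markers (hypothetical helper)."""
    match = re.search(r"\b" + re.escape(entity) + r"\b", sentence)
    if match is None:
        raise ValueError(f"entity {entity!r} not found in sentence")
    start, end = match.span()
    # Insert "[entity] " before the mention and " [entity]" after it
    marked = f"{sentence[:start]}{marker} {sentence[start:end]} {marker}{sentence[end:]}"
    return {"sentence": marked, "label": label}

example = make_example(
    "The patient recovered during the night and now denies any shortness of breath.",
    "shortness of breath",
    "absent",
)
print(example["sentence"])
```

The classifier never sees the entity as a separate input; it has to rely on the marker tokens to know which mention the label refers to.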

#### The data

For this demo we use data from the 2010 i2b2/VA Workshop on Natural Language Processing Challenges for Clinical Records. The workshop presented three tasks: a concept extraction task focused on extracting medical concepts from patient reports; an assertion classification task focused on assigning assertion types to medical problem concepts;
and a relation classification task focused on assigning the relation types that hold between medical problems, tests, and treatments. The data are available to the research community from the [i2b2](https://i2b2.org/NLP/DataSets) portal under data use agreements. For more information, please consult the paper [2010 i2b2/VA challenge on concepts, assertions, and relations in clinical text](https://academic.oup.com/jamia/article/18/5/552/830538).

You need to request access, then download and extract the data.

### 2. Install the dependencies

In this example we use PyTorch and the Hugging Face libraries, and run the experiment on Google Colab. You will also need to install spaCy and the biomedical pretrained model **en_ner_bc5cdr_md**, a scispaCy NER model trained on the BC5CDR corpus for DISEASE and CHEMICAL entity recognition. To install all the dependencies, run the following cell.

In [1]:
# Install the libraries and dependencies (Python 3.8); uncomment the commands to run them
# %pip install -r requirements.txt
# conda install matplotlib numpy scikit-learn pandas
# pip install spacy
# pip install scispacy  # raised an error in our environment, so it was skipped
# pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.4.0/en_ner_bc5cdr_md-0.4.0.tar.gz
# conda install pytorch torchvision -c pytorch
# pip install ipykernel transformers datasets evaluate
# pip install -U accelerate

### 3. Load and prepare the data

We will use the assertion classification data from i2b2, which consists of XXXXX records of discharge summary notes.

In [2]:
import os 
cwd  = os.getcwd()
labels_path = os.path.join(cwd,"Data/concept_assertion_relation_training_data","beth","ast")
data_path = os.path.join(cwd,"Data/concept_assertion_relation_training_data","beth","txt")

print(labels_path)
print(data_path)

C:\Users\kcaro\Documents\GitHub\clinical-adapter\Data/concept_assertion_relation_training_data\beth\ast
C:\Users\kcaro\Documents\GitHub\clinical-adapter\Data/concept_assertion_relation_training_data\beth\txt


In [3]:
# Build the list of record numbers (and file names) to load
records = (
    list(range(13, 39)) + list(range(45, 57)) + [58, 59]
    + list(range(65, 71)) + [73, 74] + list(range(81, 85))
    + list(range(105, 109)) + list(range(121, 125))
    + list(range(140, 145)) + list(range(175, 180))
)
records_files = [f"record-{i}.txt" for i in records]

We create a function that loops over the list of text files and reads the content of each file.

In [5]:
def load_clinical_notes(records_files):
    """Read each note and return a list of (record_id, text) tuples."""
    content_records = []
    for record in records_files:
        _file = os.path.join(data_path, record)
        # The with-block closes the file, so no explicit f.close() is needed
        with open(_file) as f:
            content = f.read()
        # record[:-4] drops the ".txt" extension, e.g. "record-13"
        content_records.append((record[:-4], content))

    return content_records

content_records = load_clinical_notes(records_files)
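Hardcoding the record numbers makes it explicit which notes are used. If you would rather load every note in a folder, a sketch with `pathlib` could look like the following. The function name `load_notes_from_dir` is hypothetical, and unlike the notebook it picks up all `record-<n>.txt` files in the directory, not only the subset listed above.

```python
import re
from pathlib import Path

def load_notes_from_dir(directory):
    """Return (record_id, text) pairs for every record-<n>.txt file, sorted by n (hypothetical helper)."""
    def record_number(path):
        # "record-13" (the file stem) -> 13, so record-105 sorts after record-84
        return int(re.fullmatch(r"record-(\d+)", path.stem).group(1))

    paths = [p for p in Path(directory).glob("record-*.txt")
             if re.fullmatch(r"record-\d+", p.stem)]
    return [(p.stem, p.read_text()) for p in sorted(paths, key=record_number)]
```

Sorting numerically rather than lexically keeps the notes in the same order as a hand-written list of record numbers.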

Then we split each note into sentences. We load the scispaCy biomedical pretrained model and add spaCy's rule-based `sentencizer` component to its pipeline.

In [6]:
import en_ner_bc5cdr_md
import spacy

def split_note_sentences(content_records):
    # Load the scispaCy model without the parser and add the rule-based sentencizer
    nlp1 = spacy.load("en_ner_bc5cdr_md", disable=["parser"])
    nlp1.add_pipe("sentencizer")

    # Transform the notes into a list of (record_id, sentence) tuples
    data = []
    for r, text in content_records:
        for s in nlp1(text).sents:
            sentence = s.text.strip().replace("\n", " ")
            data.append((r, sentence))

    return data

data = split_note_sentences(content_records)



In [28]:
data[0]

('record-13',
 'Admission Date : 2018-10-25 Discharge Date : 2018-10-31 Date of Birth : 1951-06-15 Sex : M Service :  CARDIOTHORACIC Allergies : Patient recorded as having No Known Allergies to Drugs Attending : Michael D. Christensen , M.D. Chief Complaint : Shortness of Breath Major Surgical or Invasive Procedure : Coronary Artery Bypass Graft x3 ( Left internal mammary -> left anterior descending , saphaneous vein graft -> obtuse marginal , saphaneous vein graft -> posterior descending artery ) 2018-10-25 History of Present Illness : 67 y/o male with worsening shortness of breath.')

The next step is to load and process the labels. They are provided as `.ast` files, one per record.
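Each line of an `.ast` file describes one annotated concept. The layout below is inferred from the parsing code in the next cell: three fields separated by `||`, the first holding the quoted concept text, the other two of the form `t="..."` (concept type) and `a="..."` (assertion). The sample line, including the position information after the concept, is illustrative rather than copied from the dataset, and `parse_ast_line` is a hypothetical helper.

```python
import re

def parse_ast_line(line):
    """Split one assertion line into (entity, concept_type, assertion)."""
    # Unpacking into three names raises ValueError on lines with a different number of fields
    concept_field, type_field, assertion_field = line.strip().split("||")
    # The concept text is the first double-quoted string in the first field
    entity = re.search(r'"([^"]*)"', concept_field).group(1)
    # The other fields look like key="value"; split on the first "=" only
    concept_type = type_field.split("=", 1)[1].strip('"')
    assertion = assertion_field.split("=", 1)[1].strip('"')
    return entity, concept_type, assertion

# Illustrative line (the position information is made up)
sample = 'c="further episodes of afib" 40:5 40:8||t="problem"||a="absent"\n'
print(parse_ast_line(sample))
```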

In [7]:
import re
import pandas as pd

def load_notes_labels(records):
    """Parse the .ast files into one list of [record, assertion, label, entity] rows per record."""
    records_files_ast = [f"record-{i}.ast" for i in records]

    labels_records = []
    for record in records_files_ast:
        _file = os.path.join(labels_path, record)
        # The with-block closes the file, so no explicit f.close() is needed
        with open(_file) as f:
            content = f.readlines()
        file_data = []
        for line in content:
            # Each line holds "||"-separated fields: concept, concept type and assertion
            ast = line.strip().split('||')
            entity_text = re.findall('"([^"]*)"', ast[0])  # quoted concept text
            entity_label = ast[1].split('=')  # e.g. t="problem"
            assertion = ast[2].split('=')  # e.g. a="present"
            file_data.append([
                record[:-4],
                assertion[1].replace('"', ''),
                entity_label[1].replace('"', ''),
                entity_text[0],
            ])
        labels_records.append(file_data)

    return labels_records

labels_records = load_notes_labels(records)

# labels in a dataframe
data_labels = [line for f in labels_records for line in f]
df_data_labels = pd.DataFrame(data_labels,columns=['record','assertion','label','entity'])

df_data_labels

|      | record     | assertion | label   | entity                       |
|------|------------|-----------|---------|------------------------------|
| 0    | record-13  | present   | problem | coronary artery disease      |
| 1    | record-13  | present   | problem | burst of atrial fibrillation |
| 2    | record-13  | present   | problem | left arm phlebitis           |
| 3    | record-13  | absent    | problem | further episodes of afib     |
| 4    | record-13  | present   | problem | mildly thickened             |
| ...  | ...        | ...       | ...     | ...                          |
| 4107 | record-179 | present   | problem | seasonal allergies           |
| 4108 | record-179 | present   | problem | his embolus                  |
| 4109 | record-179 | absent    | problem | cough                        |
| 4110 | record-179 | present   | problem | discoid lateral meniscus     |


### 4. Annotate text for clinical assertion detection

After pre-processing the data, we need to annotate each entity in our training data by surrounding it with the token `[entity]`. First, we define a helper that cleans the text.

In [8]:
import re

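def clean_text(text):
    """
    Applies some pre-processing on the given text.

    Steps:
    - Removing HTML tags
    - Removing backslashes and quotes
    - Lowercasing the text
    - Removing non-ASCII characters
    - Replacing punctuation with spaces and collapsing whitespace
    """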
    
    # remove HTML tags
    text = re.sub(r'<.*?>', '', text)
    
    # remove the characters [\], ['] and ["]
    text = re.sub(r"\\", "", text)    
    text = re.sub(r"\'", "", text)    
    text = re.sub(r"\"", "", text)    
    
    # convert text to lowercase
    text = text.strip().lower()
    
    # remove all non-ASCII characters:
    text = re.sub(r'[^\x00-\x7f]',r'', text) 
    
    # replace punctuation characters with spaces
    filters='!"\'#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n'
    translate_dict = dict((c, " ") for c in filters)
    translate_map = str.maketrans(translate_dict)
    text = text.translate(translate_map)
    text = " ".join(text.split())
    return text

Next, we link the sentences to the labels. For every sentence, we look up the labels of the same record and search for each entity's text in the cleaned sentence, matching on word boundaries.

In [9]:

def linking_train_labels_data(data, df_data_labels):
    """Pair each sentence with every labeled entity of the same record that it contains."""
    new_data = []
    for r, sent in data:
        sentence = clean_text(sent)  # clean each sentence once, not once per label
        record_labels = df_data_labels.loc[df_data_labels['record'] == r, ['entity', 'assertion']]
        for _, row in record_labels.iterrows():
            entity = clean_text(row['entity'])
            # re.escape guards against regex metacharacters in the entity text
            if re.search(r'\b' + re.escape(entity) + r'\b', sentence):
                new_data.append((r, entity, sentence, row['assertion']))
    return new_data




In [10]:
new_data = linking_train_labels_data(data,df_data_labels)
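The function above filters `df_data_labels` once per sentence. A sketch of the same text-matching logic that instead groups the labels by record up front and precompiles one pattern per entity is shown below. It uses plain tuples rather than a DataFrame; the function name and tuple layouts are our own, and both inputs are assumed to have already been passed through `clean_text`.

```python
import re
from collections import defaultdict

def link_sentences_to_labels(sentences, labels):
    """Pair sentences with the labeled entities of the same record that they contain.

    sentences: iterable of (record_id, cleaned_sentence)
    labels:    iterable of (record_id, cleaned_entity, assertion)
    """
    # Group the labels by record once and precompile one word-boundary pattern per entity
    patterns_by_record = defaultdict(list)
    for record_id, entity, assertion in labels:
        pattern = re.compile(r"\b" + re.escape(entity) + r"\b")
        patterns_by_record[record_id].append((entity, pattern, assertion))

    linked = []
    for record_id, sentence in sentences:
        for entity, pattern, assertion in patterns_by_record.get(record_id, []):
            if pattern.search(sentence):
                linked.append((record_id, entity, sentence, assertion))
    return linked
```

With the notebook's variables, the inputs could be built as `[(r, clean_text(s)) for r, s in data]` and `[(r, clean_text(e), a) for r, e, a in df_data_labels[['record', 'entity', 'assertion']].itertuples(index=False)]`; the output keeps the same (record, entity, sentence, assertion) layout and order as `new_data`.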

Then, we annotate each sentence by wrapping the first occurrence of its entity with `[entity]` tokens.

In [11]:
def annotate_data(new_data):
    processed_data = []
    for r, entity, text, label in new_data:
        # The linking step guarantees that the entity occurs in the sentence
        match = re.search(r'\b' + re.escape(entity) + r'\b', text)
        start, end = match.span()
        # Surround the first occurrence with "[entity] " and " [entity]"
        res = text[:start] + '[entity] ' + text[start:end] + ' [entity]' + text[end:]
        processed_data.append((r, entity, res, label))

    return processed_data

processed_data = annotate_data(new_data)


In [13]:
df_data_labels.loc[df_data_labels['record'] =='record-13' ,['entity','assertion']]

|   | entity                       | assertion    |
|---|------------------------------|--------------|
| 0 | coronary artery disease      | present      |
| 1 | burst of atrial fibrillation | present      |
| 2 | left arm phlebitis           | present      |
| 3 | further episodes of afib     | absent       |
| 4 | mildly thickened             | present      |
| 5 | severe 3 vessel disease      | present      |
| 6 | mildly dilated               | present      |
| 7 | hypertension                 | present      |
| 8 | carpal tunnel syndrome       | present      |
| 9 | increased pain               | hypothetical |


In [14]:
('record-13',
 'known allergies',
 'admission date 2018 10 25 discharge date 2018 10 31 date of birth 1951 06 15 sex m service cardiothoracic allergies patient recorded as having no [entity] known allergies [entity] to drugs attending michael d christensen m d chief complaint shortness of breath major surgical or invasive procedure coronary artery bypass graft x3 left internal mammary left anterior descending saphaneous vein graft obtuse marginal saphaneous vein graft posterior descending artery 2018 10 25 history of present illness 67 y o male with worsening shortness of breath',
 'absent')


Finally, we create a dataframe and, for this example, keep only three assertion labels: present, absent and possible.

In [15]:
prepare_data = [{'sentence': text, 'label': label, 'idx': idx} for idx, (r, entity, text, label) in enumerate(processed_data)]

df_i2b2 = pd.DataFrame(prepare_data)
df_i2b2 = df_i2b2[df_i2b2.label.isin(['present', 'absent', 'possible'])].copy()
df_i2b2


|      | sentence                                          | label   | idx  |
|------|---------------------------------------------------|---------|------|
| 0    | admission date 2018 10 25 discharge date 2018 ... | present | 0    |
| 1    | admission date 2018 10 25 discharge date 2018 ... | absent  | 1    |
| 2    | admission date 2018 10 25 discharge date 2018 ... | present | 2    |
| 3    | had [entity] abnormal ett [entity] and referre... | present | 3    |
| 4    | cath revealed [entity] severe 3 vessel disease... | present | 4    |
| ...  | ...                                               | ...     | ...  |
| 6530 | medications on admission claritin prn flonase ... | present | 6530 |
| 6531 | medications on admission claritin prn flonase ... | present | 6531 |
| 6532 | medications on admission claritin prn flonase ... | present | 6532 |
| 6533 | the mri of your knee showed [entity] a menisca... | present | 6533 |


In [16]:
df_i2b2.iloc[0]

sentence    admission date 2018 10 25 discharge date 2018 ...
label                                                 present
idx                                                         0
Name: 0, dtype: object


In [17]:
# df_i2b2 is longer than df_data_labels: the same entity text can be linked to several sentences of a record
len(df_i2b2), len(df_data_labels)

(6072, 4112)

In [18]:
# Examples of repeated data
for i in [31, 4488, 4489, 4490] + list(range(4492, 4504)):
    print(df_i2b2['sentence'][i], " - l:", df_i2b2['label'][i])

there are [entity] simple atheroma in the descending thoracic aorta [entity]  - l: present
there are [entity] simple atheroma [entity] in the aortic root  - l: present
there are [entity] simple atheroma [entity] in the aortic root  - l: present
there are [entity] simple atheroma [entity] in the aortic root  - l: present
there are [entity] simple atheroma [entity] in the ascending aorta  - l: present
there are [entity] simple atheroma [entity] in the ascending aorta  - l: present
there are [entity] simple atheroma [entity] in the ascending aorta  - l: present
there are [entity] simple atheroma [entity] in the ascending aorta  - l: present
there are [entity] simple atheroma [entity] in the aortic arch  - l: present
there are [entity] simple atheroma [entity] in the aortic arch  - l: present
there are [entity] simple atheroma [entity] in the aortic arch  - l: present
there are [entity] simple atheroma [entity] in the aortic arch  - l: present
there are [entity] simple atheroma [entity] in

In [19]:
# However, even when we drop duplicates, there appears to be extra data that does not belong
df_i2b2_dropped = df_i2b2[['sentence', 'label']]
df_i2b2_dropped = df_i2b2_dropped.drop_duplicates()
print("num labels:", len(df_data_labels))
print("num in data:", len(df_i2b2_dropped))

num labels: 4112
num in data: 4403
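A toy reproduction helps explain these counts. Because labels are linked by searching for the entity text, each annotation is tested against every sentence of its record. When the same text is annotated several times (once per mention), or also appears in a sentence where it was not the annotated mention, the linking produces extra rows, and `drop_duplicates` only removes the exact repeats. The sentences below are made up, modeled on the atheroma example above.

```python
import re
from itertools import product

# Toy record: three sentences contain "simple atheroma", but the label file
# holds only two annotations with that text (one per annotated mention)
sentences = [
    "there are simple atheroma in the aortic root",
    "there are simple atheroma in the aortic arch",
    "simple atheroma were also noted on the prior study",
]
annotated_entities = ["simple atheroma", "simple atheroma"]

# Text-based linking, as in linking_train_labels_data: every annotation of the
# record is tested against every sentence of the record
linked = [
    (sentence, entity)
    for sentence, entity in product(sentences, annotated_entities)
    if re.search(r"\b" + re.escape(entity) + r"\b", sentence)
]
print(len(linked), len(set(linked)), len(annotated_entities))
```

This yields six linked rows, three distinct rows after de-duplication, for only two annotations. As far as we understand the i2b2 format, the concept field of each `.ast` line also carries line and token positions after the quoted text; linking annotations through those positions rather than by text search would likely remove the extra rows, but we do not explore it here.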


In [23]:
# After dropping duplicates, only one copy of each repeated sentence remains
for i in [4488, 4492, 4496, 4500]:
    print(df_i2b2_dropped['sentence'][i], " - l:", df_i2b2_dropped['label'][i])

there are [entity] simple atheroma [entity] in the aortic root  - l: present
there are [entity] simple atheroma [entity] in the ascending aorta  - l: present
there are [entity] simple atheroma [entity] in the aortic arch  - l: present
there are [entity] simple atheroma [entity] in the descending thoracic aorta  - l: present


### 5. Split the data and create the dataset

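We use scikit-learn to split the data into training, validation and test sets. To keep the demo fast, we first keep a random 20% of the examples. We then hold out 10% of this subsample for testing and split the remaining 90% into 80% training and 20% validation, which gives roughly 72% training, 18% validation and 10% test data.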

In [15]:
from sklearn.model_selection import train_test_split

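# Keep a random 20% of the examples so that the demo runs quickly
# (no random_state is set here, so the subsample changes between runs)
df_i2b2 = df_i2b2.sample(frac=0.2).copy()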

X = df_i2b2['sentence']
y = df_i2b2['label']

X_train_valid,X_test,y_train_valid, y_test= train_test_split(X,y,test_size=0.1,stratify=y,random_state=42)
X_train,X_valid,y_train,y_valid = train_test_split(X_train_valid,y_train_valid,train_size=0.8,random_state=42,stratify=y_train_valid)

print(f"X_train shape {X_train.shape} y_train shape : {y_train.shape}")
print(f"X_valid shape {X_valid.shape} y_valid shape : {y_valid.shape}")
print(f"X_test shape {X_test.shape} y_test shape : {y_test.shape}")

X_train shape (873,) y_train shape : (873,)
X_valid shape (219,) y_valid shape : (219,)
X_test shape (122,) y_test shape : (122,)
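The shapes above can be checked by hand. The sketch below assumes that `DataFrame.sample(frac=...)` rounds `frac * len(df)` to the nearest integer and follows scikit-learn's documented rounding for `train_test_split` (ceiling for a float `test_size`, floor for a float `train_size`); starting from the 6072 filtered examples, it reproduces the printed split sizes. The function `split_sizes` is our own illustration.

```python
import math

def split_sizes(n_rows, frac=0.2, test_size=0.1, train_size=0.8):
    """Reproduce the subsample and two-stage split sizes used above."""
    n = round(frac * n_rows)                          # DataFrame.sample(frac=0.2)
    n_test = math.ceil(test_size * n)                 # first split: test_size=0.1
    n_train_valid = n - n_test
    n_train = math.floor(train_size * n_train_valid)  # second split: train_size=0.8
    n_valid = n_train_valid - n_train
    return n_train, n_valid, n_test

n_train, n_valid, n_test = split_sizes(6072)
total = n_train + n_valid + n_test
print(n_train, n_valid, n_test)
print([round(100 * k / total, 1) for k in (n_train, n_valid, n_test)])
```

The percentages come out at about 71.9%, 18.0% and 10.0%, because `train_size=0.8` in the second split is a fraction of the remaining 90%, not of the whole subsample.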


In [16]:
import numpy as np
print(X_train.shape,y_train.shape)
np.vstack((y_train,X_train))

(873,) (873,)


array([['present', 'present', 'absent', ..., 'present', 'present',
        'absent'],
       ['[entity] prematurity [entity] at 30 6 7 weeks',
        'impression unchanged moderate [entity] bilateral pleural effusions [entity] and mild chf',
        'nicholas seizures no [entity] loc [entity] no head or neck trauma',
        ...,
        'discharge diagnosis [entity] st elevation myocardial infarction [entity] status post left anterior descending artery stent',
        'status post tracheostomy 2017 07 21 02 22 [entity] failure to wean [entity] 14',
        'no [entity] lymphadenopathy [entity]']], dtype=object)

We also use scikit-learn's LabelEncoder to encode the labels as integers.

In [17]:
import numpy as np
from sklearn.preprocessing import LabelEncoder

print("Encoding Labels .....")
encoder = LabelEncoder()
encoder.fit(y_train)
y_train_encode = np.asarray(encoder.transform(y_train))
y_valid_encode = np.asarray(encoder.transform(y_valid))
y_test_encode = np.asarray(encoder.transform(y_test))

Encoding Labels .....
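LabelEncoder stores its classes in sorted order (`encoder.classes_`) and uses each class's position as its integer id, so here absent=0, possible=1 and present=2. This is consistent with the `train_df` output below, where the *present* examples are encoded as 2, and it matters whenever the classifier's outputs are given names (`id2label`). A plain-Python sketch that mimics this behavior (not scikit-learn's implementation; `fit_label_mapping` is a hypothetical name):

```python
def fit_label_mapping(labels):
    """Mimic LabelEncoder: classes are the sorted unique labels, ids are their positions."""
    classes = sorted(set(labels))
    label2id = {label: i for i, label in enumerate(classes)}
    id2label = {i: label for label, i in label2id.items()}
    return label2id, id2label

label2id, id2label = fit_label_mapping(["present", "absent", "possible", "present"])
print(label2id)
print(id2label)
```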


In [18]:
X_train

4702        [entity] prematurity [entity] at 30 6 7 weeks
3295    impression unchanged moderate [entity] bilater...
1850    nicholas seizures no [entity] loc [entity] no ...
6506           syncope in setting of [entity] pe [entity]
2465    pt s [entity] hypertensive urgency [entity] wa...
                              ...                        
562     left upper extremity examination demonstrated ...
6438    social history lives w wife son in jose ma den...
3837    discharge diagnosis [entity] st elevation myoc...
5801    status post tracheostomy 2017 07 21 02 22 [ent...
5247                 no [entity] lymphadenopathy [entity]
Name: sentence, Length: 873, dtype: object

In [19]:
import pandas as pd
from transformers import AutoTokenizer, DataCollatorWithPadding
from datasets import Dataset, DatasetDict


train_df = pd.DataFrame(X_train)
valid_df = pd.DataFrame(X_valid)
test_df = pd.DataFrame(X_test)

train_df['label'] = y_train_encode.tolist()
valid_df['label'] = y_valid_encode.tolist()
test_df['label'] = y_test_encode.tolist()

print(train_df.head())

ds = DatasetDict ({
 'train': Dataset.from_pandas(train_df),
 'validation': Dataset.from_pandas(valid_df),
 'test': Dataset.from_pandas(test_df)
})



                                               sentence  label
4702      [entity] prematurity [entity] at 30 6 7 weeks      2
3295  impression unchanged moderate [entity] bilater...      2
1850  nicholas seizures no [entity] loc [entity] no ...      0
6506         syncope in setting of [entity] pe [entity]      2
2465  pt s [entity] hypertensive urgency [entity] wa...      2


In [20]:
ds

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', '__index_level_0__'],
        num_rows: 873
    })
    validation: Dataset({
        features: ['sentence', 'label', '__index_level_0__'],
        num_rows: 219
    })
    test: Dataset({
        features: ['sentence', 'label', '__index_level_0__'],
        num_rows: 122
    })
})

### 6. Fine-tuning ClinicalBERT

In [21]:
import torch
# Use a CUDA GPU if available, then Apple Silicon (MPS), otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

In [22]:
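from transformers import AutoTokenizer, AutoModelForSequenceClassification
tokenizer_clinical_bio = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_Discharge_Summary_BERT", model_max_length=150)
# LabelEncoder numbers the classes in sorted order (absent=0, possible=1, present=2),
# so derive id2label from the fitted encoder instead of hardcoding it
id2label = {i: label.upper() for i, label in enumerate(encoder.classes_)}
model_clinical = AutoModelForSequenceClassification.from_pretrained("emilyalsentzer/Bio_Discharge_Summary_BERT",
    num_labels=len(id2label), id2label=id2label, label2id={v: k for k, v in id2label.items()})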

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at emilyalsentzer/Bio_Discharge_Summary_BERT and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [23]:
model_clinical

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

In [24]:
model_clinical = model_clinical.to(device)

In [25]:
special_tokens_dict = {"additional_special_tokens": ["[entity]"]}
# replace_additional_special_tokens=False keeps any additional special tokens the tokenizer already has
num_added_toks = tokenizer_clinical_bio.add_special_tokens(special_tokens_dict, replace_additional_special_tokens=False)

print("We have added", num_added_toks, "tokens")
# Note: resize_token_embeddings expects the full size of the new vocabulary, i.e. the length of the tokenizer
model_clinical.resize_token_embeddings(len(tokenizer_clinical_bio))

We have added 1 tokens


Embedding(28997, 768)

In [26]:
def tokenize_function(example):
    # Pad and truncate every sentence to the tokenizer's model_max_length (150 tokens)
    return tokenizer_clinical_bio(example["sentence"], padding="max_length", truncation=True)

In [28]:
tokenized_ds = ds.map(tokenize_function, batched=True)
tokenized_ds = tokenized_ds.rename_column("label", "labels")
tokenized_ds = tokenized_ds.remove_columns(["sentence", "__index_level_0__"])
tokenized_ds.set_format("torch")

Map: 100%|██████████| 873/873 [00:00<00:00, 12339.43 examples/s]
Map: 100%|██████████| 219/219 [00:00<00:00, 2047.84 examples/s]
Map: 100%|██████████| 122/122 [00:00<00:00, 12725.50 examples/s]


In [29]:
import numpy as np
import evaluate

# Load the metric once instead of on every evaluation call
accuracy_metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return accuracy_metric.compute(predictions=predictions, references=labels)
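Accuracy alone can hide poor performance on minority classes. As an optional complement, a dependency-free sketch that also reports macro-averaged F1 (the unweighted mean of the per-class F1 scores) from the same `(logits, labels)` pair is shown below. The function name is our own; it accepts plain lists or NumPy arrays, so `compute_metrics` could simply return `accuracy_and_macro_f1(logits, labels)`.

```python
def accuracy_and_macro_f1(logits, labels):
    """Compute accuracy and macro-averaged F1 from raw logits and integer labels."""
    # argmax over each row of logits
    predictions = [max(range(len(row)), key=lambda j: row[j]) for row in logits]
    labels = list(labels)
    accuracy = sum(p == y for p, y in zip(predictions, labels)) / len(labels)

    f1_scores = []
    # Score every class that appears in the labels or the predictions
    for c in sorted(set(labels) | set(predictions)):
        tp = sum(p == c and y == c for p, y in zip(predictions, labels))
        fp = sum(p == c and y != c for p, y in zip(predictions, labels))
        fn = sum(p != c and y == c for p, y in zip(predictions, labels))
        # F1 = 2TP / (2TP + FP + FN)
        denom = 2 * tp + fp + fn
        f1_scores.append(2 * tp / denom if denom else 0.0)
    return {"accuracy": accuracy, "macro_f1": sum(f1_scores) / len(f1_scores)}

scores = accuracy_and_macro_f1([[2, 0, 1], [0, 3, 1], [0, 1, 5], [1, 0, 0]], [0, 1, 2, 2])
print(scores)
```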

In [30]:
from transformers import TrainingArguments, Trainer

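# Evaluate on the validation set at the end of every epoch
# (recent transformers releases rename `evaluation_strategy` to `eval_strategy`)
training_args = TrainingArguments(output_dir="clinbert_trainer", evaluation_strategy="epoch", learning_rate=1e-5, num_train_epochs=1)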

trainer = Trainer(
    model=model_clinical,
    args=training_args,
    train_dataset=tokenized_ds['train'],
    eval_dataset=tokenized_ds['validation'],
    compute_metrics=compute_metrics,
)

In [31]:
trainer.train()

                                                 
100%|██████████| 110/110 [00:56<00:00,  1.93it/s]

{'eval_loss': 0.5837724804878235, 'eval_accuracy': 0.7625570776255708, 'eval_runtime': 3.5802, 'eval_samples_per_second': 61.17, 'eval_steps_per_second': 7.821, 'epoch': 1.0}
{'train_runtime': 57.0262, 'train_samples_per_second': 15.309, 'train_steps_per_second': 1.929, 'train_loss': 0.6306253606622869, 'epoch': 1.0}





TrainOutput(global_step=110, training_loss=0.6306253606622869, metrics={'train_runtime': 57.0262, 'train_samples_per_second': 15.309, 'train_steps_per_second': 1.929, 'train_loss': 0.6306253606622869, 'epoch': 1.0})
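The step counts in the logs follow from the batch size. Assuming the `TrainingArguments` defaults of 8 examples per device for both training and evaluation, a single device and no gradient accumulation, one epoch over the 873 training examples takes 110 optimizer steps and the 219 validation examples are evaluated in 28 batches, matching the progress bars above.

```python
import math

per_device_batch_size = 8   # assumed: the TrainingArguments default, on a single device
train_rows, valid_rows = 873, 219

train_steps_per_epoch = math.ceil(train_rows / per_device_batch_size)  # last batch is partial
eval_steps = math.ceil(valid_rows / per_device_batch_size)
print(train_steps_per_epoch, eval_steps)
```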

In [32]:
trainer.evaluate()

100%|██████████| 28/28 [00:03<00:00,  9.09it/s]


{'eval_loss': 0.5837724804878235,
 'eval_accuracy': 0.7625570776255708,
 'eval_runtime': 3.1911,
 'eval_samples_per_second': 68.628,
 'eval_steps_per_second': 8.774,
 'epoch': 1.0}
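To run inference on a new sentence, it has to be cleaned and marked with `[entity]` tokens exactly like the training data, tokenized, and passed through the model (with transformers, for example via `model_clinical(**inputs).logits`). The remaining step is mapping the logits back to a label. A plain-Python sketch of that post-processing, using a softmax to turn logits into probabilities; the logits are hypothetical and `id2label` is assumed to follow LabelEncoder's sorted order:

```python
import math

def logits_to_prediction(logits, id2label):
    """Turn one row of classifier logits into (label, probability) via softmax."""
    # Subtract the max logit before exponentiating for numerical stability
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)
    return id2label[best], probs[best]

# Hypothetical logits for one sentence
id2label = {0: "ABSENT", 1: "POSSIBLE", 2: "PRESENT"}
label, prob = logits_to_prediction([2.5, -0.3, 0.4], id2label)
print(label, round(prob, 3))
```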