### Main module to train and test NER based on Bert embeddings
#### By Isar Nejadgholi

In [1]:
import torch
import pandas as pd
import torch.nn as nn
import torch.optim as optim
from torch.utils import data
from model import Net
from data_load import NerDataset, pad
import os
import numpy as np
import argparse

In [2]:
from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule

#### Check if gpu is available, if not set on cpu

In [3]:
# setting device on GPU if available, else CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

#Additional Info when using cuda
if device.type == 'cuda':
    print('GPU Type:     ',torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_cached(0)/1024**3,1), 'GB')
    print('Number of gpus:   ', torch.cuda.device_count())
    
    torch.cuda.empty_cache()
    !export CUDA_VISIBLE_DEVICES=1
   

Using device: cpu



#### Set bert model and dataset 

In [4]:
Dataset_name ='i2b2'  #i2b2 or 'medmention'
Bert_Model_name ='clinical_15'  #'general' or 'biobert' or 'clinical_10' or 'clinical_15' or 'NCBI'

#### Train, test, dev sets and model path

In [5]:
#SETTINGS 

if Dataset_name =='i2b2':
    trainset = "data/i2b2_train_total.txt"
    validset = "data/i2b2_train_total.txt"
    testset = "data/i2b2_test_total.txt"
else:
    trainset = "data/medmention_PA_train.txt"
    validset = "data/medmention_PA_valiadtion.txt"
    testset = "data/medmention_PA_test.txt"
    
    
saves_file = 'saved_models/general_uncased_i2b2.pt'

#### Training function 

In [6]:


def train(model, iterator, optimizer, criterion,tokenizer):
    model.train()
    for i, batch in enumerate(iterator):
        words, x, is_heads, tags, y, seqlens = batch
        _y = y # for monitoring
        optimizer.zero_grad()
        logits, y, _ = model(x, y) # logits: (N, T, VOCAB), y: (N, T)

        logits = logits.view(-1, logits.shape[-1]) # (N*T, VOCAB)
        y = y.view(-1)  # (N*T,)

        loss = criterion(logits, y)
        loss.backward()

        optimizer.step()

        if i==0:
            print("=====sanity check======")
            print("words:", words[0])
            print("x:", x.cpu().numpy()[0][:seqlens[0]])
            print("tokens:", tokenizer.convert_ids_to_tokens(x.cpu().numpy()[0])[:seqlens[0]])
            print("is_heads:", is_heads[0])
            print("y:", _y.cpu().numpy()[0][:seqlens[0]])
            print("tags:", tags[0])
            print("seqlen:", seqlens[0])
            print("=======================")


        if i%10==0: # monitoring
            print(f"step: {i}, loss: {loss.item()}")



#### Evaluation function

In [7]:
def eval(model, iterator, f):
    model.eval()

    Words, Is_heads, Tags, Y, Y_hat = [], [], [], [], []
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            words, x, is_heads, tags, y, seqlens = batch

            _, _, y_hat = model(x, y)  # y_hat: (N, T)

            Words.extend(words)
            Is_heads.extend(is_heads)
            Tags.extend(tags)
            Y.extend(y.numpy().tolist())
            Y_hat.extend(y_hat.cpu().numpy().tolist())

    ## gets results and save
    true_tag = []
    pred_tag = []
    with open("temp", 'w') as fout:
        for words, is_heads, tags, y_hat in zip(Words, Is_heads, Tags, Y_hat):
            y_hat = [hat for head, hat in zip(is_heads, y_hat) if head == 1]
            preds = [train_dataset.idx2tag[hat] for hat in y_hat]
            assert len(preds)==len(words.split())==len(tags.split())
            for w, t, p in zip(words.split()[1:-1], tags.split()[1:-1], preds[1:-1]):
                fout.write(f"{w} {t} {p}\n")
            fout.write("\n")
            true_tag.append(tags.split()[1:-1])
            pred_tag.append(preds[1:-1])

    ## calc metric
    y_true =  np.array([train_dataset.tag2idx[line.split()[1]] for line in open("temp", 'r').read().splitlines() if len(line) > 0])
    y_pred =  np.array([train_dataset.tag2idx[line.split()[2]] for line in open("temp", 'r').read().splitlines() if len(line) > 0])

    num_proposed = len(y_pred[y_pred>1])
    num_correct = (np.logical_and(y_true==y_pred, y_true>1)).astype(np.int).sum()
    num_gold = len(y_true[y_true>1])

    print(f"num_proposed:{num_proposed}")
    print(f"num_correct:{num_correct}")
    print(f"num_gold:{num_gold}")
    try:
        precision = num_correct / num_proposed
    except ZeroDivisionError:
        precision = 1.0

    try:
        recall = num_correct / num_gold
    except ZeroDivisionError:
        recall = 1.0

    try:
        f1 = 2*precision*recall / (precision + recall)
    except ZeroDivisionError:
        if precision*recall==0:
            f1=1.0
        else:
            f1=0

    final = f + ".P%.2f_R%.2f_F%.2f" %(precision, recall, f1)
    with open(final, 'w') as fout:
        result = open("temp", "r").read()
        fout.write(f"{result}\n")

        fout.write(f"precision={precision}\n")
        fout.write(f"recall={recall}\n")
        fout.write(f"f1={f1}\n")

    os.remove("temp")

    print("precision=%.2f"%precision)
    print("recall=%.2f"%recall)
    print("f1=%.2f"%f1)
    return precision, recall, f1,true_tag, pred_tag

#### initiate parameters, model and data



In [8]:


batch_size =32
lr =0.00005
gradient_accumulation_steps = 1


finetuning = True
logdir = "checkpoints/01"

device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = Net(bert_name = Bert_Model_name,  device=device, finetuning=finetuning,data_name = Dataset_name)

# send the model to DataParallel only if you are training on multiple gpus
model = nn.DataParallel(model)  

train_dataset = NerDataset(trainset,bert_name = Bert_Model_name,data_name = Dataset_name)
eval_dataset = NerDataset(validset,bert_name = Bert_Model_name,data_name = Dataset_name)

train_iter = data.DataLoader(dataset=train_dataset,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=0,
                             collate_fn=pad)
eval_iter = data.DataLoader(dataset=eval_dataset,
                             batch_size=batch_size,
                             num_workers=0,
                             shuffle=False,
                             collate_fn=pad)







In [9]:
train_dataset.tag2idx, train_dataset.label_num 
        

({'<PAD>': 0,
  'O': 1,
  'B-problem': 2,
  'I-problem': 3,
  'B-treatment': 4,
  'I-treatment': 5,
  'B-test': 6,
  'I-test': 7},
 8)

In [10]:
#from https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/examples/lm_finetuning/simple_lm_finetuning.py
if finetuning:
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
    n_epochs =4
    
    num_train_optimization_steps = int(
len(train_dataset) / batch_size / gradient_accumulation_steps) * n_epochs

    optimizer = BertAdam(optimizer_grouped_parameters,
                 lr=0.00005,
                 warmup=0.1,
                 t_total=num_train_optimization_steps)
     
else:
    optimizer = optim.Adam(model.parameters(), lr = 0.00005) 
    n_epochs =4
    
    
    
criterion = nn.CrossEntropyLoss(ignore_index=train_dataset.tag2idx['<PAD>'])


## Train

In [11]:
for epoch in range(1, n_epochs+1):
    
    train(model, train_iter, optimizer, criterion,train_dataset.tokenizer)

    print(f"=========eval at epoch={epoch}=========")
    if not os.path.exists(logdir): os.makedirs(logdir)
    fname = os.path.join(logdir, str(epoch))
    precision, recall, f1, true, pred= eval(model, eval_iter, fname)

    torch.save(model.state_dict(), f"{fname}.pt")
print(f"weights were saved to {fname}.pt")

words: [CLS] Ronchorous breath sounds left chest . [SEP]
x: [  101  6413 22811  2285  2184  3807  1286  2229   119   102]
tokens: ['[CLS]', 'Ron', '##chor', '##ous', 'breath', 'sounds', 'left', 'chest', '.', '[SEP]']
is_heads: [1, 1, 0, 0, 1, 1, 1, 1, 1, 1]
y: [0 2 0 0 3 3 3 3 1 0]
tags: <PAD> B-problem I-problem I-problem I-problem I-problem O <PAD>
seqlen: 10
step: 0, loss: 2.456310987472534
step: 10, loss: 1.0843297243118286
step: 20, loss: 0.8022581338882446
step: 30, loss: 0.7786921262741089
step: 40, loss: 1.188483476638794
step: 50, loss: 0.8169649839401245
step: 60, loss: 0.4862866699695587
step: 70, loss: 0.49962860345840454
step: 80, loss: 0.32892483472824097
step: 90, loss: 0.31139689683914185
step: 100, loss: 0.319164514541626
step: 110, loss: 0.2556055784225464
step: 120, loss: 0.16158416867256165
step: 130, loss: 0.2105257362127304
step: 140, loss: 0.2423766553401947
step: 150, loss: 0.18080030381679535
step: 160, loss: 0.21201051771640778
step: 170, loss: 0.2322048991918

Training beyond specified 't_total'. Learning rate multiplier set to 0.0. Please set 't_total' of WarmupLinearSchedule correctly.


step: 510, loss: 0.01735931821167469


Training beyond specified 't_total'. Learning rate multiplier set to 0.0. Please set 't_total' of WarmupLinearSchedule correctly.
Training beyond specified 't_total'. Learning rate multiplier set to 0.0. Please set 't_total' of WarmupLinearSchedule correctly.


num_proposed:34767
num_correct:34491
num_gold:34758
precision=0.99
recall=0.99
f1=0.99
weights were saved to checkpoints/01/4.pt


In [12]:
#Additional Info when using cuda

#print(torch.cuda.get_device_name(0))
#print('Memory Usage:')
#print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
#print('Cached:   ', round(torch.cuda.memory_cached(0)/1024**3,1), 'GB')

#### Evaluate 

In [13]:

test_dataset = NerDataset(testset,bert_name = Bert_Model_name,data_name = Dataset_name)
test_iter = data.DataLoader(dataset=test_dataset,
                             batch_size=batch_size,
                             num_workers=0,
                             shuffle=False,
                             collate_fn=pad)

precision, recall, f1,true_tags,pred_tags = eval(model, test_iter, fname)

num_proposed:64606
num_correct:58436
num_gold:64811
precision=0.90
recall=0.90
f1=0.90


In [14]:
flatten_list_true_tags = [item for sublist in true_tags for item in sublist]
flatten_list_pred_tags = [item for sublist in pred_tags for item in sublist]

In [15]:
from sklearn.metrics import classification_report
print(classification_report(flatten_list_true_tags,flatten_list_pred_tags))

              precision    recall  f1-score   support

   B-problem       0.92      0.92      0.92     12592
      B-test       0.91      0.91      0.91      9225
 B-treatment       0.91      0.90      0.91      9344
   I-problem       0.90      0.92      0.91     17684
      I-test       0.90      0.88      0.89      8012
 I-treatment       0.88      0.85      0.87      7954
           O       0.98      0.98      0.98    203026

    accuracy                           0.96    267837
   macro avg       0.91      0.91      0.91    267837
weighted avg       0.96      0.96      0.96    267837



In [16]:
from seqeval.metrics import classification_report

print(classification_report(true_tags,pred_tags))

           precision    recall  f1-score   support

treatment       0.87      0.88      0.87      9344
  problem       0.86      0.88      0.87     12592
     test       0.87      0.88      0.87      9225

micro avg       0.86      0.88      0.87     31161
macro avg       0.86      0.88      0.87     31161



#### Save the model 

In [17]:
torch.save(model.state_dict(), saves_file )
