In [1]:
import tensorflow as tf
# detect and init the TPU
tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)

# instantiate a distribution strategy
tpu_strategy = tf.distribute.experimental.TPUStrategy(tpu)

In [2]:
import numpy as np
import json

import pandas as pd

from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing import text, sequence
from tensorflow.keras.layers import Input, Embedding, Bidirectional, Dense, LSTM, TimeDistributed, Lambda, SpatialDropout1D, Layer
#from tensorflow.keras.layers import Conv1D, GlobalMaxPooling1D, GlobalAveragePooling1D, concatenate, SpatialDropout1D
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import SGD, Adam, RMSprop

import tensorflow_addons as tfa

from keras import backend as K

from sklearn.metrics import f1_score, precision_score, recall_score

from transformers import BertTokenizer, BertConfig, TFBertForTokenClassification, TFBertModel

import warnings

In [63]:
def reformat_data(data_file):
    with open(data_file, 'r') as file:
        article_sentences, article_labels = [], []
        sentence_tokens, sentence_labels = [], []
        
        for line in file.readlines():
            if "-DOCSTART-" in line:
                if sentence_labels != []:
                    article_sentences.append(sentence_tokens)
                    article_labels.append(sentence_labels)
                    sentence_tokens, sentence_labels = [], []
            else:
                try:
                    token = line.split("\t")[0]
                    label = line.split("\t")[3][:-1]
                    sentence_tokens.append(token)
                    sentence_labels.append(label)
                except:
                    if sentence_labels != []:
                        article_sentences.append(sentence_tokens)
                        article_labels.append(sentence_labels)
                        sentence_tokens, sentence_labels = [], []
                
    return article_sentences, article_labels
    
train_sentences, train_detect_labels = reformat_data("../input/medlinker-data/mm_ner_ent.train.conll")
test_sentences, test_detect_labels = reformat_data("../input/medlinker-data/mm_ner_ent.test.conll")
dev_sentences, dev_detect_labels = reformat_data("../input/medlinker-data/mm_ner_ent.dev.conll")
_, train_recog_labels = reformat_data("../input/medlinker-data/mm_ner_sts.train.conll")
_, test_recog_labels = reformat_data("../input/medlinker-data/mm_ner_sts.test.conll")
_, dev_recog_labels = reformat_data("../input/medlinker-data/mm_ner_sts.dev.conll")

In [4]:
def unique_words():
    dict_ = {}
    lengths = []
    sent = []
    i = 0
    for txt in [train_sentences, test_sentences, dev_sentences]:
        for sentence in txt:
            lengths.append(len(sentence))
            sent.append(sentence)
            for word in np.unique(sentence):
                if word.lower() not in dict_.keys():
                    i+=1
                    dict_[word.lower()] = i
                    
    return dict_, np.max(lengths), sent
            
tokens_dict, maxlen, sent = unique_words()
maxlen

178

In [5]:
len(tokens_dict)

54563

In [6]:
label_dict = {}
i = 0
for sent_labels in train_detect_labels:
    for label in sent_labels:
        if label not in label_dict.keys():
            i+=1
            label_dict[label] = i 

In [7]:
label_dict['[PAD]'] = 0
label_dict

{'B-Entity': 1, 'O': 2, 'I-Entity': 3, '[PAD]': 0}

In [8]:
warnings.filterwarnings('ignore')
"""This ignored warning because precision and recall give warnings
that not all the true labels are represented in the predictions"""

def exclude_from_f1(y_true, y_pred, excluded_tags=[0]):
    ytrue, yhat = [], []
    for y_t, y_p in zip(y_true, y_pred):
        if y_t not in excluded_tags:
            ytrue.append(y_t)
            yhat.append(y_p)
    f1 = f1_score(ytrue, yhat, average='micro')
    return f1

def exclude_from_precision(y_true, y_pred, excluded_tags=[0]):
    ytrue, yhat = [], []
    for y_t, y_p in zip(y_true, y_pred):
        if y_t not in excluded_tags:
            ytrue.append(y_t)
            yhat.append(y_p)
    precision = precision_score(ytrue, yhat, average='micro')
    return precision

def exclude_from_recall(y_true, y_pred, excluded_tags=[0]):
    ytrue, yhat = [], []
    for y_t, y_p in zip(y_true, y_pred):
        if y_t not in excluded_tags:
            ytrue.append(y_t)
            yhat.append(y_p)
    recall = recall_score(ytrue, yhat, average='micro')
    return recall

In [9]:
def mask(m, q):
    # Assumes m is 2D
    mask = tf.math.reduce_any(tf.not_equal(m, q), axis=-1)
    #return tf.boolean_mask(m, mask)
    return mask

def recall(y_true, y_pred):
    pad = tf.constant([0 for i in range(4)], dtype=tf.float32)
    mask_ = mask(y_true, pad)
    masked_y_data = tf.boolean_mask(y_true, mask_)
    masked_y_pred = tf.boolean_mask(y_pred, mask_)
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    recall = true_positives / (possible_positives + K.epsilon())
    return recall

def precision(y_true, y_pred):
    pad = tf.constant([0 for i in range(4)], dtype=tf.float32)
    mask_ = mask(y_true, pad)
    masked_y_data = tf.boolean_mask(y_true, mask_)
    masked_y_pred = tf.boolean_mask(y_pred, mask_)
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
    precision = true_positives / (predicted_positives + K.epsilon())
    return precision


def f1(y_true, y_pred):
    precision_ = precision(y_true, y_pred)
    recall_ = recall(y_true, y_pred)
    return 2*((precision_*recall_)/(precision_+recall_+K.epsilon()))

## Tokenizer

In [10]:
%%capture
#!wget "https://s3-us-west-2.amazonaws.com/ai2-s2-research/scibert/huggingface_pytorch/scibert_scivocab_uncased.tar"
!wget -O umlsbert.tar.xz https://www.dropbox.com/s/qaoq5gfen69xdcc/umlsbert.tar.xz?dl=0
#!tar -xf scibert_scivocab_uncased.tar
!tar -xvf umlsbert.tar.xz


In [11]:
tokenizer = BertTokenizer.from_pretrained('./umlsbert', do_lower_case=True)

In [12]:
tokenizer

PreTrainedTokenizer(name_or_path='./umlsbert', vocab_size=28996, model_max_len=1000000000000000019884624838656, is_fast=False, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

In [64]:
train_seq = sequence.pad_sequences(train_sentences, dtype=object, maxlen=maxlen, padding='post', value='[PAD]')
dev_seq = sequence.pad_sequences(dev_sentences, dtype=object, maxlen=maxlen, padding='post', value='[PAD]') 
test_seq = sequence.pad_sequences(test_sentences, dtype=object, maxlen=maxlen, padding='post', value='[PAD]') 
train_seq_tokenized = [tokenizer.convert_tokens_to_ids(s) for s in train_seq]
dev_seq_tokenized = [tokenizer.convert_tokens_to_ids(s) for s in dev_seq]
test_seq_tokenized = [tokenizer.convert_tokens_to_ids(s) for s in test_seq]

## Lebel tokenizer

In [65]:
train_labels, dev_labels, test_labels = train_detect_labels, dev_detect_labels, test_detect_labels

for i, labels in enumerate(train_detect_labels):
    for j, label in enumerate(labels):
        train_labels[i][j] = label_dict[label]
for i, labels in enumerate(dev_detect_labels):
    for j, label in enumerate(labels):
        dev_labels[i][j] = label_dict[label]
for i, labels in enumerate(test_detect_labels):
    for j, label in enumerate(labels):
        test_labels[i][j] = label_dict[label]
        
train_labels_ohe = [to_categorical(i, num_classes=4) for i in train_labels]
dev_labels_ohe = [to_categorical(i, num_classes=4) for i in dev_labels]
test_labels_ohe = [to_categorical(i, num_classes=4) for i in test_labels]

train_labels = sequence.pad_sequences(train_labels_ohe, maxlen=maxlen, dtype='int32', padding='post')
dev_labels = sequence.pad_sequences(dev_labels_ohe, maxlen=maxlen, dtype='int32', padding='post')
test_labels = sequence.pad_sequences(test_labels_ohe, maxlen=maxlen, dtype='int32', padding='post')

train_labels = np.array(train_labels)
dev_labels = np.array(dev_labels)
test_labels = np.array(test_labels)

## Build mask to ignore padded values

In [61]:
train_mask = [[1]*len(sent)+[0]*(maxlen - len(sent)) for sent in train_sentences]
train_mask = tf.cast(train_mask,tf.int32)
dev_mask = [[1]*len(sent)+[0]*(maxlen - len(sent)) for sent in dev_sentences]
dev_mask = tf.cast(dev_mask,tf.int32)
test_mask = [[1]*len(sent)+[0]*(maxlen - len(sent)) for sent in test_sentences]
test_mask = tf.cast(test_mask,tf.int32)

## Cast sequences into tensors

In [66]:
train_seq = train_seq_tokenized
train_seq = tf.cast(train_seq, tf.int32)
dev_seq = dev_seq_tokenized
dev_seq = tf.cast(dev_seq, tf.int32)
test_seq = test_seq_tokenized
test_seq = tf.cast(test_seq, tf.int32)
train_labels = tf.cast(train_labels, tf.int32)
dev_labels = tf.cast(dev_labels, tf.int32)
test_labels = tf.cast(test_labels, tf.int32)

In [67]:
print(train_seq.shape)
print(train_mask.shape)
print(train_labels.shape)
print(dev_seq.shape)
print(dev_mask.shape)
print(dev_labels.shape)
print(test_seq.shape)
print(test_mask.shape)
print(test_labels.shape)

(27892, 178)
(27892, 178)
(27892, 178, 4)
(9219, 178)
(9219, 178)
(9219, 178, 4)
(9283, 178)
(9283, 178)
(9283, 178, 4)


## Scibert LSTM

In [18]:
!ls ./umlsbert/

config.json  pytorch_model.bin	vocab.txt


In [35]:
def build_bert_lstm_model():
    config = BertConfig.from_json_file('./umlsbert/config.json')
    #config.return_dict=True
    config.output_hidden_states=True
    encoder = TFBertModel.from_pretrained("./umlsbert/", from_pt=True, name='scibert', config=config)
    encoder.bert.trainable = False

    input_ids = Input(shape=(maxlen,), dtype=tf.int32)

    attention_mask = Input(shape=(maxlen,), dtype=tf.int32)
    outputs = encoder(input_ids, attention_mask=attention_mask)
    last_hidden_state = outputs.hidden_states[0]
        
    outputs = Bidirectional(LSTM(units=128, return_sequences=True, dropout=0.5, recurrent_dropout=0.5), 
                            merge_mode = 'concat')(last_hidden_state)
    outputs = LSTM(units=128, dropout=0.5,  return_sequences=True, recurrent_dropout=0.5)(outputs)

    outputs = Dense(len(label_dict), activation='softmax', name='output')(outputs)
    
    bert_lstm_model = Model([input_ids, attention_mask], outputs)

    bert_lstm_model.summary()
    
    return bert_lstm_model

In [36]:
tf.random.set_seed(42)
opt = Adam(0.001)

with tpu_strategy.scope(): 
    bert_lstm_model = build_bert_lstm_model()
    bert_lstm_model.compile(loss = 'CategoricalCrossentropy', optimizer=opt, metrics=[f1, recall, precision])

Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBertModel: ['cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'bert.embeddings.tui_type_embeddings.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing TFBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFBertModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained o

Model: "model_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_7 (InputLayer)            [(None, 178)]        0                                            
__________________________________________________________________________________________________
input_8 (InputLayer)            [(None, 178)]        0                                            
__________________________________________________________________________________________________
scibert (TFBertModel)           TFBaseModelOutputWit 108310272   input_7[0][0]                    
                                                                 input_8[0][0]                    
__________________________________________________________________________________________________
bidirectional_3 (Bidirectional) (None, 178, 256)     918528      scibert[0][0]              

In [41]:
bert_lstm_his = bert_lstm_model.fit((train_seq, train_mask), train_labels, 
                                    epochs=10, batch_size=64, 
                                    validation_data=((dev_seq, dev_mask), dev_labels))

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


## Optimization

-- best results

LR : 0.001

epochs : 20

units: 128

-----------

LR : 0.002

epochs : 10

units: 64

---------

LR : 0.0005

epochs : 10

units: 64

---------

LR : 0.005

epochs : 10

units: 64

---------

LR : 0.001

epochs : 10

units: 64

## Eval

In [23]:
import warnings
warnings.filterwarnings('ignore')
"""This ignored warning because precision and recall give warnings
that not all the true labels are represented in the predictions"""

def exclude_from_f1(y_true, y_pred, excluded_tags=[0]):
    ytrue, yhat = [], []
    for y_t, y_p in zip(np.array(y_true).flatten(), np.array(y_pred).flatten()):
        if y_t not in excluded_tags:
            ytrue.append(y_t)
            yhat.append(y_p)
    f1 = f1_score(ytrue, yhat, average='micro')
    return f1

def exclude_from_precision(y_true, y_pred, excluded_tags=[0]):
    ytrue, yhat = [], []
    for y_t, y_p in zip(y_true, y_pred):
        if y_t not in excluded_tags:
            ytrue.append(y_t)
            yhat.append(y_p)
    precision = precision_score(ytrue, yhat, average='micro')
    return precision

def exclude_from_recall(y_true, y_pred, excluded_tags=[0]):
    ytrue, yhat = [], []
    for y_t, y_p in zip(y_true, y_pred):
        if y_t not in excluded_tags:
            ytrue.append(y_t)
            yhat.append(y_p)
    recall = recall_score(ytrue, yhat, average='macro')
    return recall

In [79]:
test_pred = np.argmax(bert_lstm_model.predict((test_seq, test_mask)), axis=-1)
test_labels_argmax = np.argmax(test_labels, axis=-1)
# This is the f1 for non other entities
exclude_from_f1(test_labels_argmax, test_pred, [0, 2])

0.7248263888888887

## Example

In [77]:
inv_label_map = {v: k for k, v in label_dict.items()}

y_pred = np.argmax(bert_lstm_model.predict((test_seq[1030], test_mask[1030])), axis=-1)
print("{0:35} {1:40} {2:40}".format('Extracted Entity', 'Actual Label', 'Predicted Label'))
print("{0:35} {1:40} {2:40}".format('________________', '____________', '_______________'))
for x, y, yhat in zip(test_sentences[1030], test_labels[1030], y_pred):
    if x != 0:
        print("{0:35} {1:40} {2:40}".format(np.array(x), inv_label_map[np.argmax(np.array(y))], inv_label_map[yhat[0]]))

Extracted Entity                    Actual Label                             Predicted Label                         
________________                    ____________                             _______________                         
Moreover                            O                                        O                                       
,                                   O                                        O                                       
stratification                      B-Entity                                 B-Entity                                
analyses                            I-Entity                                 B-Entity                                
indicated                           O                                        O                                       
that                                O                                        O                                       
the                                 O                   