In [13]:
!pip install bert-for-tf2
#!pip install jax jaxlib

You should consider upgrading via the '/opt/conda/bin/python3.7 -m pip install --upgrade pip' command.


In [14]:
from tensorflow import keras
import tensorflow as tf
from keras.utils import to_categorical
#import numpy as np
import numpy as np
import os
import pickle as pkl

# Directory holding the pre-split COMP4901K pickles (train/val/test).
DATA_DIR = "../input/comp4901k"


def load_split(name):
    """Unpickle ``DATA_DIR/<name>.pkl`` and return the stored object.

    The file handle is closed when loading finishes (the previous bare ``open()``
    calls leaked it).
    NOTE: ``pickle.load`` can execute arbitrary code; only load files from a
    trusted source.
    """
    with open(os.path.join(DATA_DIR, f"{name}.pkl"), "rb") as f:
        return pkl.load(f)


train_dict = load_split("train")
val_dict = load_split("val")
test_dict = load_split("test")

print("keys in train_dict:", train_dict.keys())
print("keys in val_dict:", val_dict.keys())
print("keys in test_dict:", test_dict.keys())

keys in train_dict: dict_keys(['id', 'word_seq', 'tag_seq'])
keys in val_dict: dict_keys(['id', 'word_seq', 'tag_seq'])
keys in test_dict: dict_keys(['id', 'word_seq'])


In [15]:
# Connect to the attached TPU and build the distribution strategy whose scope() is
# used below when creating and compiling the model.
tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
# `tf.distribute.TPUStrategy` is the stable API; the `experimental` alias is
# deprecated. Fall back to it only on TensorFlow builds that lack the stable name.
_tpu_strategy_cls = (getattr(tf.distribute, "TPUStrategy", None)
                     or tf.distribute.experimental.TPUStrategy)
tpu_strategy = _tpu_strategy_cls(tpu)

# Set up BERT directory

In [16]:
import bert

# Pretrained BioBERT-large directory (holds the cased vocab file and the checkpoint).
bert_dir = "../input/biobert-large/biobert_large"
_ckpt_filename = "bio_bert_large_1000k.ckpt"
bert_ckpt = os.path.join(bert_dir, _ckpt_filename)

In [17]:
# Read the pretrained BERT hyper-parameters from bert_dir. Reading them here
# presumably also surfaces a bad path before tokenization starts (TODO confirm).
# The BertModelLayer itself is NOT built here. The model cell builds its own layer,
# and an extra layer created at this point was never used. It only pushed the real
# layer's weights under the "bert_1" name prefix.
bert_params = bert.params_from_pretrained_ckpt(bert_dir)

# Prepare the data for training

In [18]:
from bert.tokenization.bert_tokenization import FullTokenizer
# The BioBERT vocabulary is cased ("vocab_cased_..."). FullTokenizer lower-cases its
# input by default, which would map cased words onto the wrong (or [UNK]) word pieces
# in word2idx's tokenize() fallback. Disable lower-casing to match the vocab.
tokenizer = FullTokenizer(
    vocab_file=os.path.join(bert_dir, "vocab_cased_pubmed_pmc_30k.txt"),
    do_lower_case=False,
)

def word2idx(word, tok=None):
    """Map a single word to one BERT vocabulary id.

    Lookup order:
      1. the word-padding marker '_w_pad_' -> 0;
      2. exact match in the vocabulary;
      3. lower-cased match in the vocabulary;
      4. the id of the LAST word piece produced by ``tok.tokenize(word)``;
      5. 1, when tokenization yields no pieces or the piece is not in the vocab.

    Args:
        word: one word (str) taken from a ``word_seq``.
        tok: object exposing ``.vocab`` (mapping str -> id) and ``.tokenize(str)``;
            defaults to the module-level ``tokenizer``.

    Returns:
        int vocabulary id.
    """
    if word == '_w_pad_':
        return 0
    if tok is None:
        tok = tokenizer
    vocab = tok.vocab
    if word in vocab:
        return vocab[word]
    lowered = word.lower()
    if lowered in vocab:
        return vocab[lowered]
    # NOTE(review): token-classification setups usually label a word by its FIRST
    # sub-token; the last one is kept here to preserve existing results.
    try:
        return vocab[tok.tokenize(word)[-1]]
    except (IndexError, KeyError):
        # IndexError: no word pieces (e.g. an empty string); KeyError: piece not in vocab.
        # Other errors are no longer swallowed by a bare `except:`.
        # NOTE(review): 1 is presumably the intended "unknown" id — confirm against
        # the vocab's [UNK] entry.
        return 1


In [19]:
# Encode every word position of every split as a BERT vocabulary id
# (word2idx applied element-wise over the word_seq arrays).
_encode_words = np.vectorize(word2idx)
train_tokens, val_tokens, test_tokens = (
    _encode_words(np.array(split['word_seq']))
    for split in (train_dict, val_dict, test_dict)
)

In [20]:
# Tag vocabulary built from the training labels only, in first-seen order.
# Index 0 is reserved for the padding tag.
tag_dict = {'_t_pad_': 0}
for tag in (t for seq in train_dict['tag_seq'] for t in seq):
    # New tags get the next free index; already-known tags are left untouched.
    tag_dict.setdefault(tag, len(tag_dict))

tag2idx = tag_dict  # alias of the same dict object
idx2tag = dict((idx, name) for name, idx in tag2idx.items())

tag_dict_size = len(tag_dict)

In [21]:
def _one_hot_tag_seqs(tag_seqs):
    """One-hot encode each tag sequence (len x num_tags) and stack them into an array."""
    return np.array([
        to_categorical([tag2idx[t] for t in seq], num_classes=len(tag_dict))
        for seq in tag_seqs
    ])


train_tags = _one_hot_tag_seqs(train_dict['tag_seq'])
val_tags = _one_hot_tag_seqs(val_dict['tag_seq'])

In [22]:
# Provided function to measure accuracy.
# Use the validation accuracy to select the best of your models.
def calc_accuracy(preds, tags, padding_id="_t_pad_"):
    """Token-level accuracy with padding positions excluded.

    Input:
        preds (np.ndarray or array-like): (num_data, length_sentence) predicted tags
        tags  (np.ndarray or array-like): (num_data, length_sentence) gold tags;
            must contain the same number of elements as ``preds``
        padding_id: gold value marking padding; those positions are not scored.
    Output:
        float: proportion of correct predictions among non-padding positions,
        or NaN when every gold position is padding.
    Raises:
        ValueError: if ``preds`` and ``tags`` have different numbers of elements.
    """
    preds_flat = np.asarray(preds).ravel()
    tags_flat = np.asarray(tags).ravel()
    if preds_flat.size != tags_flat.size:
        raise ValueError(
            f"preds and tags must have the same number of elements, "
            f"got {preds_flat.size} and {tags_flat.size}"
        )

    scored = tags_flat != padding_id
    n_scored = np.count_nonzero(scored)
    if n_scored == 0:
        # Accuracy is undefined when nothing is scored (previously ZeroDivisionError).
        return float("nan")
    # Vectorised count; the builtin sum() iterated the NumPy array in Python.
    n_correct = np.count_nonzero(preds_flat[scored] == tags_flat[scored])
    return n_correct / n_scored

# Model

In [23]:
# Fresh BERT encoder for the model below. It is frozen, so only the Dense head on
# top is trainable.
bert_params = bert.params_from_pretrained_ckpt(bert_dir)
l_bert = bert.BertModelLayer.from_params(
    bert_params,
    name="bert",
)
l_bert.trainable = False

In [24]:
# The input length comes from the encoded data, so the InputLayer and the explicit
# build() agree. They previously declared 128 and 256 respectively.
max_seq_len = train_tokens.shape[1]

with tpu_strategy.scope():
    model = keras.models.Sequential([
        keras.layers.InputLayer(input_shape=(max_seq_len,)),
        l_bert,                                   # per-token BERT features
        keras.layers.Dense(128, activation='relu'),
        keras.layers.Dense(tag_dict_size, activation='softmax'),  # per-token tag distribution
    ])
    model.build(input_shape=(None, max_seq_len))
    # Load the checkpoint after build() so the BERT layer's variables exist.
    bert.load_bert_weights(l_bert, bert_ckpt)
    # NOTE(review): presumably only relevant when adapters are configured;
    # l_bert.trainable is already False.
    l_bert.apply_adapter_freeze()

Done loading 388 BERT weights from: ../input/biobert-large/biobert_large/bio_bert_large_1000k.ckpt into <bert.model.BertModelLayer object at 0x7f3c0ec64990> (prefix:bert_1). Count of weights not found in the checkpoint was: [0]. Count of weights with mismatched shape: [0]
Unused weights from checkpoint: 
	bert/embeddings/token_type_embeddings
	bert/pooler/dense/bias
	bert/pooler/dense/kernel
	global_step


In [25]:
# Show layer output shapes and parameter counts. BERT is frozen, so only the two
# Dense layers count as trainable.
model.summary() 

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
bert (BertModelLayer)        (None, 128, 1024)         363247616 
_________________________________________________________________
dense_2 (Dense)              (None, 128, 128)          131200    
_________________________________________________________________
dense_3 (Dense)              (None, 128, 65)           8385      
Total params: 363,387,201
Trainable params: 139,585
Non-trainable params: 363,247,616
_________________________________________________________________


In [26]:
# Compile inside the strategy scope so the optimizer state is created under it.
# NOTE(review): the 'accuracy' metric also counts padding positions; use
# calc_accuracy on argmax predictions for padding-free accuracy.
with tpu_strategy.scope():
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=0.02),
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )

In [27]:
num_epochs = 60

# Padding positions are one-hot class '_t_pad_' in the targets. Give them zero weight
# so the loss is driven by real tokens only. A (num_data, seq_len) weight array
# assigns one weight per time step.
pad_idx = tag2idx['_t_pad_']
train_weights = (train_tags[..., pad_idx] == 0).astype("float32")
val_weights = (val_tags[..., pad_idx] == 0).astype("float32")

# NOTE(review): metrics passed via `metrics=` are presumably not sample-weighted, so
# the reported 'accuracy' may still include padding. Confirm, or use
# `weighted_metrics` / calc_accuracy.
history = model.fit(
    train_tokens,
    train_tags,
    sample_weight=train_weights,
    epochs=num_epochs,
    batch_size=1024,
    validation_data=(val_tokens, val_tags, val_weights),
)

Epoch 1/60
Epoch 2/60
Epoch 3/60
Epoch 4/60
Epoch 5/60
Epoch 6/60
Epoch 7/60
Epoch 8/60
Epoch 9/60
Epoch 10/60
Epoch 11/60
Epoch 12/60
Epoch 13/60
Epoch 14/60
Epoch 15/60
Epoch 16/60
Epoch 17/60
Epoch 18/60
Epoch 19/60
Epoch 20/60
Epoch 21/60
Epoch 22/60
Epoch 23/60
Epoch 24/60
Epoch 25/60
Epoch 26/60
Epoch 27/60
Epoch 28/60
Epoch 29/60
Epoch 30/60
Epoch 31/60
Epoch 32/60
Epoch 33/60
Epoch 34/60
Epoch 35/60
Epoch 36/60
Epoch 37/60
Epoch 38/60
Epoch 39/60
Epoch 40/60
Epoch 41/60
Epoch 42/60
Epoch 43/60
Epoch 44/60
Epoch 45/60
Epoch 46/60
Epoch 47/60
Epoch 48/60
Epoch 49/60
Epoch 50/60
Epoch 51/60
Epoch 52/60
Epoch 53/60
Epoch 54/60
Epoch 55/60
Epoch 56/60
Epoch 57/60


Epoch 58/60
Epoch 59/60
Epoch 60/60


In [28]:
# Inspect the tag -> index mapping built from the training labels (index 0 = padding).
tag_dict

{'_t_pad_': 0,
 'O': 1,
 'LIVESTOCK': 2,
 'DISEASE_OR_SYNDROME': 3,
 'GENE_OR_GENOME': 4,
 'CARDINAL': 5,
 'CHEMICAL': 6,
 'PRODUCT': 7,
 'QUANTITY': 8,
 'NORP': 9,
 'THERAPEUTIC_OR_PREVENTIVE_PROCEDURE': 10,
 'CELL': 11,
 'ORGANISM': 12,
 'GROUP': 13,
 'ORDINAL': 14,
 'GPE': 15,
 'ORG': 16,
 'LABORATORY_PROCEDURE': 17,
 'DATE': 18,
 'CORONAVIRUS': 19,
 'EUKARYOTE': 20,
 'SIGN_OR_SYMPTOM': 21,
 'VIRUS': 22,
 'CELL_COMPONENT': 23,
 'MOLECULAR_FUNCTION': 24,
 'CELL_OR_MOLECULAR_DYSFUNCTION': 25,
 'VIRAL_PROTEIN': 26,
 'HUMAN-CAUSED_PHENOMENON_OR_PROCESS': 27,
 'BODY_PART_ORGAN_OR_ORGAN_COMPONENT': 28,
 'PERSON': 29,
 'TISSUE': 30,
 'RESEARCH_ACTIVITY': 31,
 'EVENT': 32,
 'IMMUNE_RESPONSE': 33,
 'ORGAN_OR_TISSUE_FUNCTION': 34,
 'MATERIAL': 35,
 'EVOLUTION': 36,
 'LABORATORY_OR_TEST_RESULT': 37,
 'BACTERIUM': 38,
 'MONEY': 39,
 'FAC': 40,
 'DAILY_OR_RECREATIONAL_ACTIVITY': 41,
 'ANATOMICAL_STRUCTURE': 42,
 'CELL_FUNCTION': 43,
 'SUBSTRATE': 44,
 'INDIVIDUAL_BEHAVIOR': 45,
 'BODY_SUBSTANCE'