In [None]:
# install required packages
!pip install transformers



In [None]:
# mount google drive
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# set path for files
path = '/content/drive/My Drive/thesis_dataset/'

In [None]:
# import all required packages/modules
import csv
from tqdm.notebook import tqdm
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.layers import Dense, Flatten
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from transformers import TFBertModel, TFPreTrainedModel
from transformers import BertTokenizer, BertConfig
from tensorflow.keras.models import load_model

In [None]:
# define constants
MAX_TOKENS = 64
BERT_PRETRAIN_MODEL_NAME = "bert-base-cased"
NR_EPOCHS = 50
BATCH_SIZE, BUFFER_SIZE = 32, 10000
REPEAT, PREFETCH = 5, 1

In [None]:
# read data from excel
df = pd.read_excel(path+"All_Questions_V1.xlsx",'data', encoding='utf-8') 
df.head(1)

Unnamed: 0,SlNo,Question,Relation,NER_Tag,Q_Len,T_Len,Subject,Subject_URI,Relation_URI
0,1,what are the brand names of Metipranolol,brand,O O O O O O B-E,7,7,Metipranolol,http://bio2rdf.org/drugbank:DB01214,http://bio2rdf.org/drugbank_vocabulary:brand


In [None]:
# split the full dataset into train, valid and test dataset
rest, test = train_test_split(df, test_size=0.2, random_state=0, 
                               stratify=df['Relation'])
train, valid = train_test_split(rest, test_size=0.1, random_state=0, 
                               stratify=rest['Relation'])
train_size, test_size, validation_size = len(train), len(test), len(valid)
print(f'Train:{train_size}, Test: {test_size}, Validation: {validation_size}')

Train:406, Test: 114, Validation: 46


In [None]:
# create instance of tokenzier from BERT pretrained model
tokenizer = BertTokenizer.from_pretrained(BERT_PRETRAIN_MODEL_NAME, do_lower_case=True)

In [None]:
# process the question phrase, labels to return input_ids, attention_masks, one-hot-encoded labels and label names
def process_data(df_data, tokenizer, max_tokens, train=False):
  # process labels only for training data
  if(train):
    df_class = pd.get_dummies(df_data, columns=["Relation"], prefix=[""], prefix_sep="" )
    df_class.head(1)
    column_names = df_class.columns.to_list()
    label_names = column_names[8:]
    onehot_labels =  df_class[label_names].values
  else:
    onehot_labels, label_names = [], []

  # process data and provide input_ids and attention_masks
  tokens_list = []
  attn_masks_list = []
  for question in tqdm(df_data['Question']):
      tokens = tokenizer.encode(question, max_length = max_tokens, truncation=True, add_special_tokens = True)
      tokens_list.append(tokens)
  # we use post padding for BERT
  padded_tokens_list = pad_sequences(tokens_list, maxlen=max_tokens, truncating="post", padding="post", dtype="long", value=0)

  # create atttion masks
  for tokens in padded_tokens_list:
      attn_masks = [int(token > 0) for token in tokens]
      attn_masks_list.append(attn_masks)

  return padded_tokens_list, np.asarray(attn_masks_list), np.asarray(onehot_labels), label_names

In [None]:
# process question phrases, labels to get input_ids, attention_masks for BERT input and onehot labels
train_input_ids, train_attention_masks, train_labels, labels = process_data(train, tokenizer, MAX_TOKENS, True)
valid_input_ids, valid_attention_masks, valid_labels, _  = process_data(valid, tokenizer, MAX_TOKENS, True)
num_class = len(labels)

HBox(children=(FloatProgress(value=0.0, max=406.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=46.0), HTML(value='')))




In [None]:
#Function for creating and updating parameters of dataset using input tensors
def load_dataset(dataset, train=True):
    dataset_loader = tf.data.Dataset.from_tensor_slices(dataset)
    if train:
      dataset_loader = dataset_loader.shuffle(buffer_size=BUFFER_SIZE)
      dataset_loader = dataset_loader.repeat(REPEAT)
      dataset_loader = dataset_loader.prefetch(PREFETCH)
    dataset_loader = dataset_loader.batch(BATCH_SIZE)
    return dataset_loader

In [None]:
# cerate dataset from BERT inputs
train_dataset_loader = load_dataset((train_input_ids, train_attention_masks, train_labels))
valid_dataset_loader = load_dataset((valid_input_ids, valid_attention_masks, valid_labels))

In [None]:
config_params = BertConfig.from_pretrained(BERT_PRETRAIN_MODEL_NAME)

In [None]:
# create a class for relation clssifer
# adapt from BERT base model, freeze the base layers (make them non-trainable)
# build a top classifier layer with input as CLS token output
class RelationClassifier(TFPreTrainedModel):    
    def __init__(self, base: TFBertModel, num_relations: int):
        super().__init__(config_params)
        self.base = base
        self.base.trainable= False
        self.top_classifier = Dense(num_relations, activation='softmax')
        
    @tf.function
    def call(self, input_ids, attention_mask):
        outputs = self.base(input_ids, attention_mask=attention_mask, token_type_ids=None,
                               position_ids=None, head_mask=None)
        cls_token_output = outputs[1]
        cls_token_output = self.top_classifier(cls_token_output)
        return cls_token_output

In [None]:
# create a model from relation classifier class
bert_base_model = TFBertModel.from_pretrained(BERT_PRETRAIN_MODEL_NAME)
model = RelationClassifier(bert_base_model, num_class)

Some weights of the model checkpoint at bert-base-cased were not used when initializing TFBertModel: ['mlm___cls', 'nsp___cls']
- This IS expected if you are initializing TFBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing TFBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFBertModel were initialized from the model checkpoint at bert-base-cased.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predictions without further training.


In [None]:
# check if base layers are frozen
for layer in model.layers:
  print(layer.trainable)

False
True


In [None]:
# define loss object, metrices, optimizer and training/validation steps
loss_object = tf.keras.losses.CategoricalCrossentropy(from_logits=False) 
train_loss, validation_loss = tf.keras.metrics.Mean(name='train_loss'), tf.keras.metrics.Mean(name='test_loss')
validation_accuracy = tf.keras.metrics.CategoricalAccuracy(name='accuracy')
optimizer = tf.keras.optimizers.Adam(clipnorm=1)
steps_per_epoch = int(train_size / BATCH_SIZE)
validation_steps = int(validation_size / BATCH_SIZE)

In [None]:
# define function for training / validation of model in the epoch run
@tf.function
def model_training(model, input_ids, attn_masks, onehot_labels, train = True):
    act_labels = tf.dtypes.cast(onehot_labels, tf.float32)
    # while train loop, calculate loss and update all parameters for all layers
    if train:
      with tf.GradientTape() as tape:
          pred_labels = model(input_ids, attn_masks)
          training_loss = loss_object(act_labels, pred_labels)
      training_gradients = tape.gradient(training_loss, model.trainable_variables)
      optimizer.apply_gradients(zip(training_gradients, model.trainable_variables))
      train_loss(training_loss)
    # while validation loop, predict labels, calculate loss and accuracy  
    else:
      pred_labels = model(input_ids, attn_masks, training=train)
      valid_loss = loss_object(act_labels, pred_labels)
      validation_loss(valid_loss)
      validation_accuracy.update_state(act_labels, pred_labels)


In [None]:
# train and validate the model for number of epoches
for epoch_num in range(NR_EPOCHS):
    print(f'Epoch Number: {epoch_num+1}')
    for i, (input_ids, attn_masks, act_labels) in enumerate(tqdm(train_dataset_loader, total=steps_per_epoch)):
        model_training(model, input_ids, attn_masks, act_labels, train=True)        
    for i, (input_ids, attn_masks, act_labels) in enumerate(tqdm(valid_dataset_loader, total=validation_steps)):
        model_training(model, input_ids, attn_masks, act_labels, train=False)
    print(f'Training Loss: {train_loss.result()}')
    print(f'Validation Loss: {validation_loss.result()}')
    print(f'Validation Accuracy: {validation_accuracy.result().numpy()}')
    print(f'_______________________________________________________________________________')

Epoch Number: 1


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 3.6293084621429443
Validation Loss: 3.5044546127319336
Validation Accuracy: 0.06521739065647125
_______________________________________________________________________________
Epoch Number: 2


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 3.493555784225464
Validation Loss: 3.442646026611328
Validation Accuracy: 0.08695652335882187
_______________________________________________________________________________
Epoch Number: 3


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 3.4009082317352295
Validation Loss: 3.3950719833374023
Validation Accuracy: 0.11594203114509583
_______________________________________________________________________________
Epoch Number: 4


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 3.323294162750244
Validation Loss: 3.344515562057495
Validation Accuracy: 0.125
_______________________________________________________________________________
Epoch Number: 5


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 3.257969379425049
Validation Loss: 3.3026435375213623
Validation Accuracy: 0.1608695685863495
_______________________________________________________________________________
Epoch Number: 6


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 3.195317029953003
Validation Loss: 3.26389479637146
Validation Accuracy: 0.1666666716337204
_______________________________________________________________________________
Epoch Number: 7


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 3.1347458362579346
Validation Loss: 3.2180190086364746
Validation Accuracy: 0.18633539974689484
_______________________________________________________________________________
Epoch Number: 8


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 3.0795133113861084
Validation Loss: 3.1783933639526367
Validation Accuracy: 0.20652173459529877
_______________________________________________________________________________
Epoch Number: 9


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 3.027129888534546
Validation Loss: 3.1443793773651123
Validation Accuracy: 0.21739129722118378
_______________________________________________________________________________
Epoch Number: 10


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 2.977323055267334
Validation Loss: 3.1150035858154297
Validation Accuracy: 0.23043477535247803
_______________________________________________________________________________
Epoch Number: 11


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 2.9303388595581055
Validation Loss: 3.0851457118988037
Validation Accuracy: 0.24110671877861023
_______________________________________________________________________________
Epoch Number: 12


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 2.8846113681793213
Validation Loss: 3.05289363861084
Validation Accuracy: 0.25
_______________________________________________________________________________
Epoch Number: 13


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 2.841395854949951
Validation Loss: 3.018585205078125
Validation Accuracy: 0.25919732451438904
_______________________________________________________________________________
Epoch Number: 14


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 2.8015098571777344
Validation Loss: 2.989011764526367
Validation Accuracy: 0.27018633484840393
_______________________________________________________________________________
Epoch Number: 15


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 2.7619926929473877
Validation Loss: 2.959892511367798
Validation Accuracy: 0.27971014380455017
_______________________________________________________________________________
Epoch Number: 16


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 2.724057197570801
Validation Loss: 2.933070659637451
Validation Accuracy: 0.2866847813129425
_______________________________________________________________________________
Epoch Number: 17


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 2.6882643699645996
Validation Loss: 2.9103145599365234
Validation Accuracy: 0.2928388714790344
_______________________________________________________________________________
Epoch Number: 18


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 2.6529784202575684
Validation Loss: 2.8841025829315186
Validation Accuracy: 0.2995169162750244
_______________________________________________________________________________
Epoch Number: 19


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 2.6182994842529297
Validation Loss: 2.8597874641418457
Validation Accuracy: 0.30663615465164185
_______________________________________________________________________________
Epoch Number: 20


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 2.584747314453125
Validation Loss: 2.83626651763916
Validation Accuracy: 0.31413042545318604
_______________________________________________________________________________
Epoch Number: 21


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 2.5526249408721924
Validation Loss: 2.811387300491333
Validation Accuracy: 0.3219461739063263
_______________________________________________________________________________
Epoch Number: 22


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 2.521479845046997
Validation Loss: 2.7890303134918213
Validation Accuracy: 0.3290513753890991
_______________________________________________________________________________
Epoch Number: 23


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 2.491140365600586
Validation Loss: 2.768364667892456
Validation Accuracy: 0.3336483836174011
_______________________________________________________________________________
Epoch Number: 24


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 2.4617655277252197
Validation Loss: 2.7460243701934814
Validation Accuracy: 0.3378623127937317
_______________________________________________________________________________
Epoch Number: 25


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 2.433051347732544
Validation Loss: 2.723766565322876
Validation Accuracy: 0.3426086902618408
_______________________________________________________________________________
Epoch Number: 26


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 2.405400276184082
Validation Loss: 2.704176425933838
Validation Accuracy: 0.3520067036151886
_______________________________________________________________________________
Epoch Number: 27


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 2.3782284259796143
Validation Loss: 2.684039831161499
Validation Accuracy: 0.3550724685192108
_______________________________________________________________________________
Epoch Number: 28


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 2.3523201942443848
Validation Loss: 2.6664950847625732
Validation Accuracy: 0.36335402727127075
_______________________________________________________________________________
Epoch Number: 29


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 2.3267099857330322
Validation Loss: 2.645251989364624
Validation Accuracy: 0.3703148365020752
_______________________________________________________________________________
Epoch Number: 30


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 2.3018500804901123
Validation Loss: 2.6310222148895264
Validation Accuracy: 0.3753623068332672
_______________________________________________________________________________
Epoch Number: 31


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 2.2777321338653564
Validation Loss: 2.6124346256256104
Validation Accuracy: 0.38078540563583374
_______________________________________________________________________________
Epoch Number: 32


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 2.253848075866699
Validation Loss: 2.593419313430786
Validation Accuracy: 0.38790759444236755
_______________________________________________________________________________
Epoch Number: 33


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 2.2303335666656494
Validation Loss: 2.579267978668213
Validation Accuracy: 0.39328062534332275
_______________________________________________________________________________
Epoch Number: 34


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 2.207709312438965
Validation Loss: 2.5624072551727295
Validation Accuracy: 0.39897698163986206
_______________________________________________________________________________
Epoch Number: 35


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 2.1857411861419678
Validation Loss: 2.5462255477905273
Validation Accuracy: 0.40434783697128296
_______________________________________________________________________________
Epoch Number: 36


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 2.1642658710479736
Validation Loss: 2.5294198989868164
Validation Accuracy: 0.4082125723361969
_______________________________________________________________________________
Epoch Number: 37


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 2.14326810836792
Validation Loss: 2.514662027359009
Validation Accuracy: 0.41304346919059753
_______________________________________________________________________________
Epoch Number: 38


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 2.1225979328155518
Validation Loss: 2.497342109680176
Validation Accuracy: 0.41590389609336853
_______________________________________________________________________________
Epoch Number: 39


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 2.1025569438934326
Validation Loss: 2.481828212738037
Validation Accuracy: 0.4197324514389038
_______________________________________________________________________________
Epoch Number: 40


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 2.0831055641174316
Validation Loss: 2.4697163105010986
Validation Accuracy: 0.42391303181648254
_______________________________________________________________________________
Epoch Number: 41


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 2.063549041748047
Validation Loss: 2.456498384475708
Validation Accuracy: 0.42841994762420654
_______________________________________________________________________________
Epoch Number: 42


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 2.0444769859313965
Validation Loss: 2.4438397884368896
Validation Accuracy: 0.4316770136356354
_______________________________________________________________________________
Epoch Number: 43


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 2.02595853805542
Validation Loss: 2.4284636974334717
Validation Accuracy: 0.43478259444236755
_______________________________________________________________________________
Epoch Number: 44


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 2.0078084468841553
Validation Loss: 2.4164326190948486
Validation Accuracy: 0.43922924995422363
_______________________________________________________________________________
Epoch Number: 45


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 1.9899898767471313
Validation Loss: 2.405299186706543
Validation Accuracy: 0.44251206517219543
_______________________________________________________________________________
Epoch Number: 46


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 1.9726943969726562
Validation Loss: 2.392132043838501
Validation Accuracy: 0.44612476229667664
_______________________________________________________________________________
Epoch Number: 47


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 1.9554518461227417
Validation Loss: 2.3793349266052246
Validation Accuracy: 0.4491211771965027
_______________________________________________________________________________
Epoch Number: 48


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 1.9386239051818848
Validation Loss: 2.368896245956421
Validation Accuracy: 0.4524456560611725
_______________________________________________________________________________
Epoch Number: 49


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 1.922003149986267
Validation Loss: 2.3573713302612305
Validation Accuracy: 0.4556344151496887
_______________________________________________________________________________
Epoch Number: 50


HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Training Loss: 1.9058620929718018
Validation Loss: 2.3449366092681885
Validation Accuracy: 0.4582608640193939
_______________________________________________________________________________


In [None]:
# try one question and find the predicted relation
question = (['what is the salt of choloroform'])
df_test = pd.DataFrame(question, columns=['Question'])
test_steps = int(len(df_test) / BATCH_SIZE)
test_input_ids, test_attention_masks, _, _ = process_data(df_test, tokenizer, MAX_TOKENS, False)
test_dataset_loader = load_dataset((test_input_ids, test_attention_masks),False)

for i, (token_ids, masks) in enumerate(tqdm(test_dataset_loader, total=test_steps)):
    predictions = model(token_ids, attention_mask=masks).numpy()
    print(predictions)
    max_col = np.argmax(predictions)
    print(max_col)
    print(np.max(predictions))
    print(labels[max_col])

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

[[2.7240997e-03 1.4948578e-02 4.2331535e-03 6.6105565e-03 2.4193498e-05
  1.6857095e-03 3.3162109e-05 2.0937969e-03 1.6808147e-02 1.0609289e-03
  8.6686005e-06 1.6113143e-02 1.7795436e-02 2.2446022e-03 2.3474140e-01
  3.5045218e-02 2.6947301e-04 6.1581831e-02 3.4395590e-02 9.1722861e-02
  4.3541435e-03 2.9493760e-02 7.1787253e-02 3.9112408e-02 2.5379297e-03
  9.0293838e-03 5.6243071e-06 4.0882789e-02 1.8454473e-01 1.2302107e-05
  3.7263673e-02 2.0052677e-02 2.9897087e-04 1.7780538e-03 1.2697838e-03
  2.3833776e-03 1.1052560e-02]]
14
0.2347414
ingredient



In [None]:
# define function for evaluating any given dataset
def evaluate(df_test):
  # create input for BERT Model
  test_steps = int(len(df_test) / BATCH_SIZE)
  test_input_ids, test_attention_masks, _, _ = process_data(df_test, tokenizer, MAX_TOKENS, False)
  test_dataset_loader = load_dataset((test_input_ids, test_attention_masks),False)

  # predict the relations
  pred_labels =[]
  for i, (token_ids, masks) in enumerate(tqdm(test_dataset_loader, total=test_steps)):
      predictions = model(token_ids, attention_mask=masks).numpy()
      for i in range(len(predictions)):
        max_col = np.argmax(predictions[i])
        pred_labels.append(labels[max_col])
  # print actual and predicted relations      
  print(df_test['Relation'].values.tolist())
  print(pred_labels)
  # calculate and print accuracy
  print(accuracy_score(df_test['Relation'].values.tolist(),pred_labels))


In [None]:
print(f'--------------------   Validation Dataset   --------------------')
evaluate(valid)

--------------------   Validation Dataset   --------------------


HBox(children=(FloatProgress(value=0.0, max=46.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


['volume-of-distribution', 'locus', 'ddi-interactor-in', 'kingdom', 'general-function', 'gene-name', 'biotransformation', 'patent', 'food-interaction', 'mixture', 'mixture', 'packager', 'synonym', 'clearance', 'affected-organism', 'route-of-elimination', 'group', 'locus', 'category', 'volume-of-distribution', 'product', 'target', 'theoretical-pi', 'general-function', 'kingdom', 'molecular-weight', 'transporter', 'protein-binding', 'toxicity', 'product', 'pharmacology', 'brand', 'manufacturer', 'specific-function', 'organism', 'mechanism-of-action', 'dosage', 'salt', 'indication', 'cellular-location', 'protein-binding', 'half-life', 'substructure', 'indication', 'gene-name', 'ingredient']
['molecular-weight', 'locus', 'ddi-interactor-in', 'substructure', 'general-function', 'gene-name', 'cellular-location', 'patent', 'food-interaction', 'mixture', 'mixture', 'manufacturer', 'group', 'volume-of-distribution', 'affected-organism', 'route-of-elimination', 'group', 'locus', 'category', 'to

In [None]:
print(f'--------------------   Testing Dataset   --------------------')
evaluate(test)

--------------------   Testing Dataset   --------------------


HBox(children=(FloatProgress(value=0.0, max=114.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))


['patent', 'manufacturer', 'synonym', 'mixture', 'transporter', 'toxicity', 'general-function', 'theoretical-pi', 'kingdom', 'group', 'indication', 'pharmacology', 'gene-name', 'target', 'synonym', 'general-function', 'patent', 'cellular-location', 'route-of-elimination', 'general-function', 'substructure', 'category', 'toxicity', 'patent', 'product', 'substructure', 'salt', 'general-function', 'half-life', 'group', 'brand', 'indication', 'mechanism-of-action', 'synonym', 'affected-organism', 'gene-name', 'volume-of-distribution', 'affected-organism', 'product', 'indication', 'volume-of-distribution', 'mixture', 'locus', 'mixture', 'packager', 'half-life', 'molecular-weight', 'ingredient', 'specific-function', 'ddi-interactor-in', 'cellular-location', 'molecular-weight', 'protein-binding', 'organism', 'dosage', 'organism', 'locus', 'volume-of-distribution', 'manufacturer', 'transporter', 'molecular-weight', 'mechanism-of-action', 'theoretical-pi', 'product', 'route-of-elimination', 's

In [None]:
model.summary()

Model: "relation_classifier"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
tf_bert_model (TFBertModel)  multiple                  108310272 
_________________________________________________________________
dense (Dense)                multiple                  28453     
Total params: 108,338,725
Trainable params: 28,453
Non-trainable params: 108,310,272
_________________________________________________________________


**References**

Followed Examples from


---

https://www.depends-on-the-definition.com/named-entity-recognition-with-bert/

https://mccormickml.com/2019/07/22/BERT-fine-tuning/

http://jalammar.github.io/a-visual-guide-to-using-bert-for-the-first-time/

https://www.kaggle.com/nkaenzig/bert-tensorflow-2-huggingface-transformers

https://colab.research.google.com/drive/1ZQvuAVwA3IjybezQOXnrXMGAnMyZRuPU#scrollTo=tBa6vRHknSkv


---

