In [None]:
import numpy as np
import pandas as pd
from tqdm import tqdm
import tensorflow as tf
from transformers import TFBertModel, BertTokenizer, TFBertMainLayer, BertConfig

In [None]:
# Fail fast when no GPU is visible to TensorFlow.
# The physical-device listing is used instead of tf.test.gpu_device_name()
# because it only queries the hardware; gpu_device_name() enumerates devices by
# creating them, which may initialize the GPU before memory growth has been
# configured further down in this notebook.
visible_gpus = tf.config.experimental.list_physical_devices('GPU')
if not visible_gpus:
    raise SystemError('GPU not found')
# Name of the first GPU, e.g. '/physical_device:GPU:0'
device_name = visible_gpus[0].name
print('found GPU at {}'.format(device_name))

In [None]:
# Collect the GPUs TensorFlow can see; used below to configure memory growth.
physical_devices = tf.config.experimental.list_physical_devices(device_type='GPU')

In [None]:
# Enable on-demand GPU memory allocation on every visible GPU.
# TensorFlow requires the memory-growth setting to be the same across GPUs, so
# configuring only the first device breaks on multi-GPU hosts. Looping also
# avoids an IndexError when the list is empty.
for gpu_device in physical_devices:
    tf.config.experimental.set_memory_growth(gpu_device, enable=True)

In [None]:
# Pretrained cased BERT-base: the tokenizer and the TensorFlow model are loaded
# from the same checkpoint name so vocabulary and weights match.
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
bert_model = TFBertModel.from_pretrained("bert-base-cased")

In [None]:
# Sequence length used both as the tokenizer's max_length and as the shape of
# the Keras input layers.
MAX_TOKENS: int = 250

In [None]:
# Load the training and validation splits from comma-separated files.
train_dataset = pd.read_csv(
    '../../train_data/train_data_T.csv',
    sep=',',
)
validation_dataset = pd.read_csv(
    '../../train_data/validation_data_T.csv',
    sep=',',
)

In [None]:
# Show the first record to inspect the available columns.
train_dataset.iloc[:1]

In [None]:
# Per-example BERT inputs for the training split, appended in row order.
train_title_ids = []
train_title_mask = []
train_title_seg = []

train_label = []

# iterrows() is a generator without a length, so total= is passed explicitly;
# without it tqdm can only show a running count, not a progress bar or ETA.
for _, row in tqdm(train_dataset.iterrows(), total=len(train_dataset)):

    # First segment: article title, meta description and keywords.
    # NOTE(review): str() renders a missing value (NaN) as the text "nan" —
    # confirm that this is intended for empty fields.
    article_title_main_passage = " ".join(
        str(row[column])
        for column in ('article_page_title', 'article_meta_description', 'article_keywords')
    )
    # Second segment: table page title, summary and keywords.
    table_title_main_passage = " ".join(
        str(row[column])
        for column in ('table_page_title', 'table_page_summary', 'table_page_keywords')
    )

    # Encode the two passages as one sentence pair, padded to MAX_TOKENS.
    encoded_pair = bert_tokenizer.encode_plus(
      article_title_main_passage,
      table_title_main_passage,
      max_length=MAX_TOKENS,
      add_special_tokens=True,
      return_token_type_ids=True,
      pad_to_max_length=True,
      return_attention_mask=True,
    )

    train_title_ids.append(encoded_pair['input_ids'])
    train_title_mask.append(encoded_pair['attention_mask'])
    train_title_seg.append(encoded_pair['token_type_ids'])

    train_label.append(row['label'])

In [None]:
# Convert the accumulated Python lists into NumPy arrays so they can be handed
# to model.fit.
train_title_ids, train_title_mask, train_title_seg = (
    np.array(encoded_rows)
    for encoded_rows in (train_title_ids, train_title_mask, train_title_seg)
)

train_label = np.array(train_label)

In [None]:
# Per-example BERT inputs for the validation split, appended in row order.
validation_title_ids = []
validation_title_mask = []
validation_title_seg = []

validation_label = []

# iterrows() is a generator without a length, so total= is passed explicitly;
# without it tqdm can only show a running count, not a progress bar or ETA.
for _, row in tqdm(validation_dataset.iterrows(), total=len(validation_dataset)):

    # First segment: article title, meta description and keywords.
    # NOTE(review): str() renders a missing value (NaN) as the text "nan" —
    # confirm that this is intended for empty fields.
    article_title_main_passage = " ".join(
        str(row[column])
        for column in ('article_page_title', 'article_meta_description', 'article_keywords')
    )
    # Second segment: table page title, summary and keywords.
    table_title_main_passage = " ".join(
        str(row[column])
        for column in ('table_page_title', 'table_page_summary', 'table_page_keywords')
    )

    # Encode the two passages as one sentence pair, padded to MAX_TOKENS.
    encoded_pair = bert_tokenizer.encode_plus(
      article_title_main_passage,
      table_title_main_passage,
      max_length=MAX_TOKENS,
      add_special_tokens=True,
      return_token_type_ids=True,
      pad_to_max_length=True,
      return_attention_mask=True,
    )

    validation_title_ids.append(encoded_pair['input_ids'])
    validation_title_mask.append(encoded_pair['attention_mask'])
    validation_title_seg.append(encoded_pair['token_type_ids'])

    validation_label.append(row['label'])

In [None]:
# Convert the accumulated Python lists into NumPy arrays so they can be handed
# to model.fit as validation data.
validation_title_ids, validation_title_mask, validation_title_seg = (
    np.array(encoded_rows)
    for encoded_rows in (validation_title_ids, validation_title_mask, validation_title_seg)
)

validation_label = np.array(validation_label)

In [None]:
# Three int32 inputs of length MAX_TOKENS: token ids, attention mask, segment ids.
title_ids = tf.keras.layers.Input(shape=(MAX_TOKENS,), name='input_ids', dtype='int32')
title_mask = tf.keras.layers.Input(shape=(MAX_TOKENS,), name='input_mask', dtype='int32')
title_seg = tf.keras.layers.Input(shape=(MAX_TOKENS,), name='input_seg', dtype='int32')

# Run the pretrained BERT main layer on the inputs.
# The outputs are read by position rather than tuple-unpacked: positional
# indexing works when the layer returns a plain tuple and also when it returns
# a dict-like ModelOutput (newer Transformers releases), where unpacking would
# yield the key names instead of the tensors.
bert_outputs = bert_model.bert([title_ids, title_mask, title_seg])
last_hidden_state = bert_outputs[0]
pooled_output = bert_outputs[1]

# Single sigmoid unit on top of the pooled output; paired with the
# binary cross-entropy loss used at compile time.
MLP_output = tf.keras.layers.Dense(1, activation='sigmoid')(pooled_output)

model = tf.keras.Model(inputs=[title_ids, title_mask, title_seg], outputs=MLP_output)

In [None]:
# Adam optimizer with a learning rate of 2e-5 for training the model.
optimizer = tf.keras.optimizers.Adam(
    learning_rate=2e-5,
)

In [None]:
# Binary cross-entropy matches the single sigmoid output; accuracy is tracked
# so the checkpoint callback can monitor val_accuracy.
model.compile(
    optimizer=optimizer,
    loss="binary_crossentropy",
    metrics=['accuracy'],
)

In [None]:
# Print the layer-by-layer summary of the assembled model.
model.summary()

In [None]:
# Checkpoint file name embeds the epoch number and the validation accuracy.
filepath="model_bert_fine_tuning_TMK_{epoch:02d}_{val_accuracy:.4f}.h5"
# Keep only checkpoints that improve val_accuracy.
# The former save_format='tf' argument was dropped: it is not a documented
# ModelCheckpoint parameter and it contradicted the ".h5" file name. The
# on-disk format still follows the file extension.
# NOTE(review): if the SavedModel format was actually intended, change the
# file name instead — confirm.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath,
    monitor='val_accuracy',
    verbose=1,
    save_best_only=True,
    mode='max',
)
callbacks_list = [checkpoint]

In [None]:
# Train for 5 epochs with batches of 16, evaluating on the validation split
# after each epoch and running the checkpoint callback.
history = model.fit(
    x=[train_title_ids, train_title_mask, train_title_seg],
    y=train_label,
    validation_data=(
        [validation_title_ids, validation_title_mask, validation_title_seg],
        validation_label,
    ),
    batch_size=16,
    epochs=5,
    verbose=1,
    callbacks=callbacks_list,
)