In [None]:
import numpy as np
import pandas as pd
from tqdm import tqdm
import tensorflow as tf
from transformers import TFBertModel, BertTokenizer, TFBertMainLayer

In [None]:
# Show the installed TensorFlow version so the run can be reproduced later.
tf.__version__

In [None]:
# Print the version of the `python` executable found on the shell PATH.
# NOTE(review): a `!` command runs in a subshell, so this may not be the same
# interpreter the kernel is running on — confirm if the exact version matters.
!python --version

In [None]:
# Abort early when TensorFlow does not expose the first GPU; the rest of the
# notebook assumes GPU training.
# NOTE(review): memory growth is configured in a later cell, and TensorFlow
# expects that setting to be applied before the GPUs are initialised — confirm
# that this probe does not initialise them first.
device_name = tf.test.gpu_device_name()
if device_name == '/device:GPU:0':
    print('found GPU at {}'.format(device_name))
else:
    raise SystemError('GPU not found')

In [None]:
# Collect the physical GPU devices TensorFlow can see (a list, possibly empty).
physical_devices = tf.config.experimental.list_physical_devices('GPU')

In [None]:
# Let TensorFlow allocate GPU memory on demand instead of reserving it all up
# front. The setting has to be identical on every visible GPU, so apply it to
# each device rather than only to the first one (which would make a multi-GPU
# host fail once the runtime is initialised).
for gpu_device in physical_devices:
    tf.config.experimental.set_memory_growth(gpu_device, enable=True)

In [None]:
# Tokenizer and model are both loaded from the same "bert-base-cased"
# checkpoint; the first call downloads and caches the files.
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
bert_model = TFBertModel.from_pretrained("bert-base-cased")

In [None]:
# Display the configuration object of the loaded checkpoint.
bert_model.config

In [None]:
# Reuse the main layer that lives inside the loaded checkpoint so fine-tuning
# starts from the pretrained weights. Building a new
# TFBertMainLayer(bert_model.config) would only copy the architecture and leave
# every weight randomly initialised.
model_bert_layer = bert_model.bert

In [None]:
# Load the training and validation triples (article, matching table,
# non-matching table) from comma-separated files.
train_data, validation_data = (
    pd.read_csv(csv_path, sep=',')
    for csv_path in (
        '../train_data/train_triple_all_signals.csv',
        '../train_data/validation_triple_all_signals.csv',
    )
)

In [None]:
# Debug switch: uncomment to run the whole pipeline on the first 100 rows only.
# train_data = train_data.head(100)
# validation_data = validation_data.head(100)

In [None]:
# The triplet loss never reads y_true, but a target array with one entry per
# row is still passed to fit(). Use zeros instead of np.empty so the
# placeholders are deterministic and can never contain NaN/inf garbage from
# uninitialised memory.
Y_train_dummy = np.zeros(len(train_data))
Y_validation_dummy = np.zeros(len(validation_data))

In [None]:
# Fixed sequence length: every encoded text pair is padded to this many tokens,
# and the model's Input layers are declared with the same width.
MAX_TOKENS = 100

In [None]:
# Per-example encodings for the three texts of every training triple.
input_ids_article_title = []
input_masks_article_title = []
input_token_id_article_title = []

input_ids_table_true_title = []
input_masks_table_true_title = []
input_token_id_table_true_title = []

input_ids_table_false_title = []
input_masks_table_false_title = []
input_token_id_table_false_title = []

# One entry per text of the triple:
# (title column, description column, ids list, attention-mask list, token-type list).
train_text_pairs = [
    ('article_page_title', 'article_page_meta_description',
     input_ids_article_title, input_masks_article_title, input_token_id_article_title),
    ('true_table_page_title', 'true_table_page_summary',
     input_ids_table_true_title, input_masks_table_true_title, input_token_id_table_true_title),
    ('false_table_page_title', 'false_table_page_summary',
     input_ids_table_false_title, input_masks_table_false_title, input_token_id_table_false_title),
]

# iterrows() is a generator without a length, so pass total= explicitly to get
# a percentage and an ETA in the progress bar.
for _, row in tqdm(train_data.iterrows(), total=len(train_data)):
    for title_col, description_col, ids_out, masks_out, token_types_out in train_text_pairs:
        # Title and description are encoded together as one BERT sentence pair.
        # NOTE(review): str() turns a missing cell (NaN) into the literal text
        # 'nan' — confirm the CSV has no empty titles/descriptions.
        encoded = bert_tokenizer.encode_plus(
            str(row[title_col]),
            str(row[description_col]),
            max_length=MAX_TOKENS,
            add_special_tokens=True,
            return_token_type_ids=True,
            pad_to_max_length=True,
            return_attention_mask=True,
        )
        ids_out.append(encoded['input_ids'])
        masks_out.append(encoded['attention_mask'])
        token_types_out.append(encoded['token_type_ids'])

In [None]:
# Turn the article encodings (lists of token lists) into int32 arrays with one
# row per training example.
(input_ids_article_title,
 input_masks_article_title,
 input_token_id_article_title) = (
    np.asarray(encoded_rows, dtype='int32')
    for encoded_rows in (
        input_ids_article_title,
        input_masks_article_title,
        input_token_id_article_title,
    )
)

In [None]:
# Turn the matching-table encodings into int32 arrays with one row per
# training example.
(input_ids_table_true_title,
 input_masks_table_true_title,
 input_token_id_table_true_title) = (
    np.asarray(encoded_rows, dtype='int32')
    for encoded_rows in (
        input_ids_table_true_title,
        input_masks_table_true_title,
        input_token_id_table_true_title,
    )
)

In [None]:
# Turn the non-matching-table encodings into int32 arrays with one row per
# training example.
(input_ids_table_false_title,
 input_masks_table_false_title,
 input_token_id_table_false_title) = (
    np.asarray(encoded_rows, dtype='int32')
    for encoded_rows in (
        input_ids_table_false_title,
        input_masks_table_false_title,
        input_token_id_table_false_title,
    )
)

In [None]:
# Per-example encodings for the three texts of every validation triple.
val_input_ids_article_title = []
val_input_masks_article_title = []
val_input_token_id_article_title = []

val_input_ids_table_true_title = []
val_input_masks_table_true_title = []
val_input_token_id_table_true_title = []

val_input_ids_table_false_title = []
val_input_masks_table_false_title = []
val_input_token_id_table_false_title = []

# One entry per text of the triple:
# (title column, description column, ids list, attention-mask list, token-type list).
validation_text_pairs = [
    ('article_page_title', 'article_page_meta_description',
     val_input_ids_article_title, val_input_masks_article_title, val_input_token_id_article_title),
    ('true_table_page_title', 'true_table_page_summary',
     val_input_ids_table_true_title, val_input_masks_table_true_title, val_input_token_id_table_true_title),
    ('false_table_page_title', 'false_table_page_summary',
     val_input_ids_table_false_title, val_input_masks_table_false_title, val_input_token_id_table_false_title),
]

# iterrows() is a generator without a length, so pass total= explicitly to get
# a percentage and an ETA in the progress bar.
for _, row in tqdm(validation_data.iterrows(), total=len(validation_data)):
    for title_col, description_col, ids_out, masks_out, token_types_out in validation_text_pairs:
        # Title and description are encoded together as one BERT sentence pair.
        # NOTE(review): str() turns a missing cell (NaN) into the literal text
        # 'nan' — confirm the CSV has no empty titles/descriptions.
        encoded = bert_tokenizer.encode_plus(
            str(row[title_col]),
            str(row[description_col]),
            max_length=MAX_TOKENS,
            add_special_tokens=True,
            return_token_type_ids=True,
            pad_to_max_length=True,
            return_attention_mask=True,
        )
        ids_out.append(encoded['input_ids'])
        masks_out.append(encoded['attention_mask'])
        token_types_out.append(encoded['token_type_ids'])

In [None]:
# Turn the validation article encodings into int32 arrays with one row per
# validation example.
(val_input_ids_article_title,
 val_input_masks_article_title,
 val_input_token_id_article_title) = (
    np.asarray(encoded_rows, dtype='int32')
    for encoded_rows in (
        val_input_ids_article_title,
        val_input_masks_article_title,
        val_input_token_id_article_title,
    )
)

In [None]:
# Turn the validation matching-table encodings into int32 arrays with one row
# per validation example.
(val_input_ids_table_true_title,
 val_input_masks_table_true_title,
 val_input_token_id_table_true_title) = (
    np.asarray(encoded_rows, dtype='int32')
    for encoded_rows in (
        val_input_ids_table_true_title,
        val_input_masks_table_true_title,
        val_input_token_id_table_true_title,
    )
)

In [None]:
# Turn the validation non-matching-table encodings into int32 arrays with one
# row per validation example.
(val_input_ids_table_false_title,
 val_input_masks_table_false_title,
 val_input_token_id_table_false_title) = (
    np.asarray(encoded_rows, dtype='int32')
    for encoded_rows in (
        val_input_ids_table_false_title,
        val_input_masks_table_false_title,
        val_input_token_id_table_false_title,
    )
)

In [None]:
def triplet_loss(y_true, y_pred, alpha=0.5):
    """Triplet margin loss based on cosine similarity.

    ``y_pred`` is expected to hold three embeddings concatenated along the last
    axis in the order anchor, positive, negative, each taking exactly one third
    of the width. The width is read from the tensor instead of being
    hard-coded, so the loss works for any embedding size.

    Args:
        y_true: Unused. Present only because Keras passes targets to every loss.
        y_pred: Tensor of shape ``(batch, 3 * dim)``.
        alpha: Margin by which the anchor/positive similarity must exceed the
            anchor/negative similarity before the loss reaches zero.

    Returns:
        Tensor of shape ``(batch, 1)`` holding
        ``max(cos(anchor, negative) - cos(anchor, positive) + alpha, 0)``.
    """
    anchor, positive, negative = tf.split(y_pred, num_or_size_splits=3, axis=-1)

    # Unit-normalise so the dot products below are cosine similarities.
    anchor = tf.math.l2_normalize(anchor, axis=-1)
    positive = tf.math.l2_normalize(positive, axis=-1)
    negative = tf.math.l2_normalize(negative, axis=-1)

    # Cosine similarity of the anchor with the positive / negative example.
    pos_sim = tf.reduce_sum(anchor * positive, axis=-1, keepdims=True)
    neg_sim = tf.reduce_sum(anchor * negative, axis=-1, keepdims=True)

    # Cosine distance is (1 - similarity): push the positive distance below the
    # negative distance by at least `alpha`.
    basic_loss = (1.0 - pos_sim) - (1.0 - neg_sim) + alpha
    return tf.maximum(basic_loss, 0.0)

In [None]:
# Three groups of inputs (token ids, attention mask, token type ids), one group
# per text of the triple: article (1), matching table (2), non-matching table (3).
article_title_id = tf.keras.layers.Input(shape=(MAX_TOKENS,), name='input_token1', dtype='int32')
article_title_mask_id = tf.keras.layers.Input(shape=(MAX_TOKENS,), name='masked_token1', dtype='int32')
article_title_token_id = tf.keras.layers.Input(shape=(MAX_TOKENS,), name='token_ids_token1', dtype='int32')

table_true_title_id = tf.keras.layers.Input(shape=(MAX_TOKENS,), name='input_token2', dtype='int32')
table_true_title_mask_id = tf.keras.layers.Input(shape=(MAX_TOKENS,), name='masked_token2', dtype='int32')
table_true_title_token_id = tf.keras.layers.Input(shape=(MAX_TOKENS,), name='token_ids_token2', dtype='int32')

table_false_title_id = tf.keras.layers.Input(shape=(MAX_TOKENS,), name='input_token3', dtype='int32')
table_false_title_mask_id = tf.keras.layers.Input(shape=(MAX_TOKENS,), name='masked_token3', dtype='int32')
table_false_title_token_id = tf.keras.layers.Input(shape=(MAX_TOKENS,), name='token_ids_token3', dtype='int32')

# The same BERT layer instance encodes all three texts, so the towers share
# their weights. The pooled output is the second element of the result; select
# it by position instead of tuple-unpacking, because unpacking only works when
# the layer returns a plain tuple (a ModelOutput unpacks to its string keys),
# whereas integer indexing works for both return types.
pooled_output1 = model_bert_layer([article_title_id, article_title_mask_id, article_title_token_id])[1]
pooled_output2 = model_bert_layer([table_true_title_id, table_true_title_mask_id, table_true_title_token_id])[1]
pooled_output3 = model_bert_layer([table_false_title_id, table_false_title_mask_id, table_false_title_token_id])[1]

# Anchor | positive | negative side by side; the triplet loss splits them again.
concatenated = tf.keras.layers.Concatenate(axis=-1)([pooled_output1, pooled_output2, pooled_output3])

model = tf.keras.Model(
    inputs=[
        article_title_id,
        article_title_mask_id,
        article_title_token_id,
        table_true_title_id,
        table_true_title_mask_id,
        table_true_title_token_id,
        table_false_title_id,
        table_false_title_mask_id,
        table_false_title_token_id,
    ],
    outputs=concatenated,
)

In [None]:
# Print the layer-by-layer summary of the assembled model.
model.summary()

In [None]:
# Adam optimizer with a learning rate of 2e-5.
optimizer = tf.keras.optimizers.Adam(learning_rate=2e-5)

In [None]:
# Train with the custom triplet loss; no extra metrics are tracked.
model.compile(loss=triplet_loss,optimizer=optimizer)

In [None]:
# Write a checkpoint only when the validation loss improves; the epoch number
# and the validation loss are embedded in the file name.
filepath = "encoder_BERT_{epoch:02d}_{val_loss:.4f}.h5"
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=filepath,
    save_best_only=True,
    monitor='val_loss',
    mode='min',
    verbose=1,
)
callbacks_list = [checkpoint]

In [None]:
# The nine arrays must be listed in the same order as the model's `inputs`:
# (ids, mask, token types) for the article, the matching table and the
# non-matching table.
train_inputs = [
    input_ids_article_title,
    input_masks_article_title,
    input_token_id_article_title,
    input_ids_table_true_title,
    input_masks_table_true_title,
    input_token_id_table_true_title,
    input_ids_table_false_title,
    input_masks_table_false_title,
    input_token_id_table_false_title,
]
validation_inputs = [
    val_input_ids_article_title,
    val_input_masks_article_title,
    val_input_token_id_article_title,
    val_input_ids_table_true_title,
    val_input_masks_table_true_title,
    val_input_token_id_table_true_title,
    val_input_ids_table_false_title,
    val_input_masks_table_false_title,
    val_input_token_id_table_false_title,
]

# The targets are placeholders; the loss only looks at the model output.
history = model.fit(
    x=train_inputs,
    y=Y_train_dummy,
    batch_size=16,
    epochs=5,
    verbose=1,
    callbacks=callbacks_list,
    validation_data=(validation_inputs, Y_validation_dummy),
)

In [None]:
# bert_model.save_pretrained('fine_tuning_bert/')

In [None]:
# model.save("encoder_bert.h5")