In [None]:
import sys

import numpy as np
import pandas as pd
import tensorflow as tf
from tqdm import tqdm
from transformers import TFBertModel, BertTokenizer, TFBertMainLayer, BertConfig

In [None]:
# Display the TensorFlow version this notebook is running with (record it for reproducibility).
tf.__version__

In [None]:
# Report the interpreter that is actually running this kernel. `!python --version` instead runs
# whichever `python` is first on the shell PATH, which can differ from the kernel's interpreter.
sys.version

In [None]:
# Fail fast when TensorFlow cannot see a GPU.
# tf.config.experimental.list_physical_devices only queries the hardware and, per its docs, can be
# used before the runtime is initialised. That keeps the set_memory_growth call in a later cell
# valid, because memory growth can only be configured before the GPUs are initialised.
# (tf.test.gpu_device_name is a tf.test.TestCase helper that enumerates devices via device_lib.)
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
if not gpu_devices:
    # RuntimeError rather than SystemError: SystemError is meant for interpreter-internal errors.
    raise RuntimeError('GPU not found')
print('found GPU at {}'.format(gpu_devices[0].name))

In [None]:
# Hardware GPUs visible to TensorFlow; enumerated so their memory behaviour can be configured next.
physical_devices = tf.config.experimental.list_physical_devices(device_type='GPU')

In [None]:
# Allocate GPU memory on demand instead of reserving all of it up front.
# TensorFlow requires the memory-growth setting to be the same on every visible GPU (a mixed
# setting is rejected when the runtime initialises), so apply it to all of them, not just GPU 0.
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, enable=True)

In [None]:
# Encoder and tokenizer are loaded from one shared checkpoint name so the token ids the
# tokenizer produces always match the encoder's vocabulary.
BERT_CHECKPOINT = "bert-base-cased"
bert_model = TFBertModel.from_pretrained(BERT_CHECKPOINT)
bert_tokenizer = BertTokenizer.from_pretrained(BERT_CHECKPOINT)

In [None]:
# Triplet datasets: each row holds an article plus a matching ("true") and a non-matching
# ("false") table page. `sep` is the canonical name of read_csv's `delimiter` alias.
train_data = pd.read_csv('../train_data/train_triple_all_signals.csv', sep=',')
validation_data = pd.read_csv('../train_data/validation_triple_all_signals.csv', sep=',')

In [None]:
# Optional smoke-test mode. Set this to a small row count (e.g. 100) to train on only the first rows
# of each split; None keeps the full datasets. This replaces commented-out code toggled by hand.
DEBUG_SUBSET_ROWS = None
if DEBUG_SUBSET_ROWS is not None:
    train_data = train_data.head(DEBUG_SUBSET_ROWS)
    validation_data = validation_data.head(DEBUG_SUBSET_ROWS)

In [None]:
# Keras' fit() requires a target array with one entry per example, but triplet_loss never reads
# y_true. Zeros are used instead of np.empty so the placeholder targets are deterministic rather
# than whatever uninitialised memory np.empty returns.
Y_train_dummy = np.zeros(len(train_data))
Y_validation_dummy = np.zeros(len(validation_data))

In [None]:
# Fixed sequence length, in word pieces including the special tokens the tokenizer adds.
# Every text is padded to this length (and presumably truncated at it), and every model
# Input layer has shape (MAX_TOKENS,).
MAX_TOKENS: int = 125

In [None]:
# Column groups whose text is concatenated into one BERT input per role of the triplet.
TRIPLE_TEXT_COLUMNS = {
    'anchor': ('article_page_title', 'article_page_meta_description', 'article_page_keywords'),
    'true': ('true_table_page_title', 'true_table_page_summary', 'true_table_page_keywords'),
    'false': ('false_table_page_title', 'false_table_page_summary', 'false_table_page_keywords'),
}


def join_text_fields(row, columns):
    """Join the given fields of a DataFrame row into one string.

    Fields are separated by a single space so the last word of one field is not fused with the
    first word of the next (plain `+` concatenation did that). Missing (NaN) fields are skipped
    rather than being rendered as the literal text "nan".
    NOTE(review): any inference code that rebuilds these texts must use the same format.
    """
    return ' '.join(str(row[col]) for col in columns if pd.notna(row[col]))


def encode_triples(data, tokenizer, max_tokens):
    """Tokenize every (anchor, true table, false table) text triple in `data`.

    Args:
        data: DataFrame containing the columns listed in TRIPLE_TEXT_COLUMNS.
        tokenizer: a transformers tokenizer exposing `encode_plus`.
        max_tokens: length every sequence is padded to.

    Returns:
        dict mapping 'anchor' / 'true' / 'false' to a tuple of np.ndarrays
        (input_ids, attention_mask, token_type_ids), one row per DataFrame row —
        expected shape (len(data), max_tokens).
    """
    encoded = {role: ([], [], []) for role in TRIPLE_TEXT_COLUMNS}
    for _, row in tqdm(data.iterrows(), total=len(data)):
        for role, columns in TRIPLE_TEXT_COLUMNS.items():
            tokens = tokenizer.encode_plus(
                join_text_fields(row, columns),
                max_length=max_tokens,
                add_special_tokens=True,
                return_token_type_ids=True,
                pad_to_max_length=True,
                return_attention_mask=True,
            )
            ids, mask, seg = encoded[role]
            ids.append(tokens['input_ids'])
            mask.append(tokens['attention_mask'])
            seg.append(tokens['token_type_ids'])
    return {
        role: tuple(np.asarray(values) for values in lists)
        for role, lists in encoded.items()
    }


# Training and validation go through the same function, so both splits are encoded identically.
train_encoded = encode_triples(train_data, bert_tokenizer, MAX_TOKENS)
train_ids_anchor_all, train_mask_anchor_all, train_seg_anchor_all = train_encoded['anchor']
train_ids_true_all, train_mask_true_all, train_seg_true_all = train_encoded['true']
train_ids_false_all, train_mask_false_all, train_seg_false_all = train_encoded['false']

val_encoded = encode_triples(validation_data, bert_tokenizer, MAX_TOKENS)
val_ids_anchor_all, val_mask_anchor_all, val_seg_anchor_all = val_encoded['anchor']
val_ids_true_all, val_mask_true_all, val_seg_true_all = val_encoded['true']
val_ids_false_all, val_mask_false_all, val_seg_false_all = val_encoded['false']

In [None]:
def triplet_loss(y_true, y_pred, alpha = 0.5):
    """Cosine triplet margin loss over concatenated (anchor, positive, negative) embeddings.

    Args:
        y_true: ignored; present only because Keras losses take (y_true, y_pred).
        y_pred: model output of shape (batch, 3 * d), with the anchor, positive and negative
            embeddings concatenated along the last axis (d = 768 for bert-base). The width is
            split into thirds instead of hard-coding 768, so other hidden sizes also work.
        alpha: margin by which the negative must be farther than the positive.

    Returns:
        (batch, 1) tensor of per-example losses:
        max((1 - cos(anchor, positive)) - (1 - cos(anchor, negative)) + alpha, 0).
    """
    anchor, positive, negative = tf.split(y_pred, num_or_size_splits=3, axis=-1)

    # L2-normalise so the dot products below are cosine similarities (as Dot(normalize=True) did),
    # without instantiating new Keras layers on every loss call.
    anchor = tf.math.l2_normalize(anchor, axis=-1)
    positive = tf.math.l2_normalize(positive, axis=-1)
    negative = tf.math.l2_normalize(negative, axis=-1)

    pos_similarity = tf.reduce_sum(anchor * positive, axis=-1, keepdims=True)
    neg_similarity = tf.reduce_sum(anchor * negative, axis=-1, keepdims=True)

    # Cosine distance = 1 - similarity; hinge on (d_pos - d_neg + margin).
    basic_loss = (1 - pos_similarity) - (1 - neg_similarity) + alpha
    return tf.maximum(basic_loss, 0.0)

In [None]:
# Inputs: (token ids, attention mask, segment ids) for the anchor article text, the matching
# ("true") table text and the non-matching ("false") table text, each of length MAX_TOKENS.
article_title_id = tf.keras.layers.Input(shape=(MAX_TOKENS,), name='input_token1', dtype='int32')
article_title_mask_id = tf.keras.layers.Input(shape=(MAX_TOKENS,), name='masked_token1', dtype='int32')
article_title_token_id = tf.keras.layers.Input(shape=(MAX_TOKENS,), name='token_ids_token1', dtype='int32')

table_true_title_id = tf.keras.layers.Input(shape=(MAX_TOKENS,), name='input_token2', dtype='int32')
table_true_title_mask_id = tf.keras.layers.Input(shape=(MAX_TOKENS,), name='masked_token2', dtype='int32')
table_true_title_token_id = tf.keras.layers.Input(shape=(MAX_TOKENS,), name='token_ids_token2', dtype='int32')

table_false_title_id = tf.keras.layers.Input(shape=(MAX_TOKENS,), name='input_token3', dtype='int32')
table_false_title_mask_id = tf.keras.layers.Input(shape=(MAX_TOKENS,), name='masked_token3', dtype='int32')
table_false_title_token_id = tf.keras.layers.Input(shape=(MAX_TOKENS,), name='token_ids_token3', dtype='int32')


def masked_mean_pooling(hidden_states, attention_mask):
    """Average token embeddings over real (non-padding) positions only.

    A plain mean over axis 1 also averages the hidden states at padding positions, which dominate
    for short texts padded to MAX_TOKENS. Weighting by the attention mask (the tokenizer's
    attention_mask: 1 for real tokens, 0 for padding) excludes them.

    Args:
        hidden_states: BERT sequence output, presumably (batch, MAX_TOKENS, hidden_size).
        attention_mask: int tensor of shape (batch, MAX_TOKENS).

    Returns:
        Tensor of shape (batch, hidden_size).
    """
    mask = tf.cast(tf.expand_dims(attention_mask, axis=-1), hidden_states.dtype)
    summed = tf.reduce_sum(hidden_states * mask, axis=1)
    # Clamp the count so an all-padding row yields zeros instead of NaN.
    token_counts = tf.maximum(tf.reduce_sum(mask, axis=1), 1e-9)
    return summed / token_counts


# The same `bert_model.bert` main layer is called on all three inputs, so the encoder weights are shared.
last_hidden_state1 = bert_model.bert([article_title_id,article_title_mask_id,article_title_token_id])
last_hidden_state2 = bert_model.bert([table_true_title_id,table_true_title_mask_id,table_true_title_token_id])
last_hidden_state3 = bert_model.bert([table_false_title_id,table_false_title_mask_id,table_false_title_token_id])

out1 = masked_mean_pooling(last_hidden_state1[0], article_title_mask_id)
out2 = masked_mean_pooling(last_hidden_state2[0], table_true_title_mask_id)
out3 = masked_mean_pooling(last_hidden_state3[0], table_false_title_mask_id)

# Output layout (anchor | positive | negative) is what triplet_loss splits into thirds.
concatenated = tf.keras.layers.Concatenate(axis=-1)([out1,out2,out3])

model = tf.keras.Model(inputs=[article_title_id, 
                               article_title_mask_id,
                               article_title_token_id,
                               table_true_title_id,
                               table_true_title_mask_id,
                               table_true_title_token_id,
                               table_false_title_id,
                               table_false_title_mask_id,
                               table_false_title_token_id], 
                       outputs = concatenated)

In [None]:
# Print the model's layers, output shapes and parameter counts.
model.summary()

In [None]:
# Small learning rate, since the pretrained BERT weights are being fine-tuned rather than trained from scratch.
LEARNING_RATE = 2e-5
optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)

In [None]:
# triplet_loss ignores the dummy targets and works only on the concatenated embeddings.
model.compile(optimizer=optimizer, loss=triplet_loss)

In [None]:
# Input order must match the model's `inputs=[...]` list:
# (ids, mask, segment ids) for the anchor, then the true table, then the false table.
train_inputs = [
    train_ids_anchor_all, train_mask_anchor_all, train_seg_anchor_all,
    train_ids_true_all, train_mask_true_all, train_seg_true_all,
    train_ids_false_all, train_mask_false_all, train_seg_false_all,
]
val_inputs = [
    val_ids_anchor_all, val_mask_anchor_all, val_seg_anchor_all,
    val_ids_true_all, val_mask_true_all, val_seg_true_all,
    val_ids_false_all, val_mask_false_all, val_seg_false_all,
]

history = model.fit(
    train_inputs,
    Y_train_dummy,
    validation_data=(val_inputs, Y_validation_dummy),
    epochs=5,
    batch_size=16,
    verbose=1,
)

In [None]:
# Export as a TensorFlow SavedModel. Reloading it requires
# custom_objects={'triplet_loss': triplet_loss}, because the model was compiled with a custom loss.
model.save(filepath="bert_encoder_model2", save_format='tf')

In [None]:
# Reload a trained model; the custom loss has to be supplied through custom_objects.
# NOTE(review): the save cell writes "bert_encoder_model2", not "bert_encoder_model" — point this at
# the intended export before uncommenting.
# loaded_model = tf.keras.models.load_model('bert_encoder_model', custom_objects={'triplet_loss': triplet_loss})

In [None]:
# Build an anchor-only encoder: the loaded model's first three inputs (anchor ids, mask,
# segment ids) mapped to the output of the layer named 'bert'.
# NOTE(review): 'bert' is called three times in the training graph, so Keras may reject `.output`
# as ill-defined (multiple inbound nodes); `get_output_at(node_index)` for the anchor call may be
# needed — verify. This also yields BERT's raw outputs, not the pooled per-text vector the trained
# model concatenates, so the training graph's pooling must be reapplied to get comparable embeddings.
# layer_name = 'bert'
# intermediate_layer_model = tf.keras.Model(inputs=[loaded_model.input[0],loaded_model.input[1],loaded_model.input[2]],
#                                  outputs=loaded_model.get_layer(layer_name).output)

In [None]:
# Run the anchor encoder on the validation anchors. `hid` is presumably the token-level hidden
# states and `out` the pooled output — confirm against the transformers version in use.
# hid, out = intermediate_layer_model.predict([val_ids_anchor_all,val_mask_anchor_all,val_seg_anchor_all])