In [1]:
######################################################## LIBRARIES ########################################################
import os

import numpy as np
import tensorflow as tf
from sklearn.metrics import classification_report

import CONSTANTS as c
######################################################## CONSTANTS ########################################################
# Local aliases for the values configured in the CONSTANTS module (imported as `c`).

# Encoder / attention settings.
ATTENTION_KEY_DIMS = c.ATTENTION_KEY_DIMS
ATTENTION_NR_OF_HEADS = c.ATTENTION_NR_OF_HEADS
ENCODER_DENSE_DIMS = c.ENCODER_DENSE_DIMS
DROPOUT_RATE = c.DROPOUT_RATE
NR_OF_ENCODER_BLOCKS = c.NR_OF_ENCODER_BLOCKS

# Filesystem locations: checkpoint/log directory and the directory holding the .npy data.
CHECKPOINT_PATH = c.CHECKPOINT_PATH
DISERT_DATA_PATH = c.DISERT_DATA_PATH

# Loss value handed to ThresholdStopper further down.
STOPPER_THRESHOLD = c.STOPPER_THRESHOLD
######################################################## DATA SOURCE ########################################################
# Load the model inputs (X_*) and targets (Y_*) stored as .npy files under
# DISERT_DATA_PATH. Paths are built with os.path.join rather than an embedded
# backslash: '\X' / '\Y' inside a normal string literal are invalid escape
# sequences, and a hard-coded '\' separator only works on Windows.
X_MLM = np.load(os.path.join(DISERT_DATA_PATH, 'X_MLM.npy'))
Y_MLM = np.load(os.path.join(DISERT_DATA_PATH, 'Y_MLM.npy'))
X_NSP = np.load(os.path.join(DISERT_DATA_PATH, 'X_NSP.npy'))
Y_NSP = np.load(os.path.join(DISERT_DATA_PATH, 'Y_NSP.npy'))

In [2]:
# Keep only the first NR_OF_TRAINING_SAMPLES rows of every array. A single
# constant guarantees that inputs and targets are always cut at the same index.
# Re-running this cell is harmless: slicing an already-truncated array is a no-op.
NR_OF_TRAINING_SAMPLES = 4000

X_MLM = X_MLM[:NR_OF_TRAINING_SAMPLES]
X_NSP = X_NSP[:NR_OF_TRAINING_SAMPLES]
Y_MLM = Y_MLM[:NR_OF_TRAINING_SAMPLES]
Y_NSP = Y_NSP[:NR_OF_TRAINING_SAMPLES]

In [None]:
# DisERT_model.compile(
#     optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3), 
#     loss = [mlm_custom_loss, nsp_custom_loss]
# )

# print(DisERT_model.summary())
# tf.keras.utils.plot_model(DisERT_model, show_shapes=True)

In [None]:
import time
import os

class CsvLogger(tf.keras.callbacks.Callback):
    """Keras callback that appends one delimited row of epoch metrics to a text file.

    At the end of every epoch the ``logs`` dict is extended with timing
    information, the epoch index and the optimizer's learning rate, and its
    values are written as a single line. A header line with the keys is written
    first when the file does not exist yet or is empty.
    """

    def __init__(self, sFilePath, sDelimeter = ';'):
        """
        Args:
            sFilePath: Target path without extension; ``.txt`` is appended.
            sDelimeter: Separator placed between the fields of a row.
        """
        # Initialise the base Callback so its own attributes are set up.
        super().__init__()
        self.sFilePath = f'{sFilePath}.txt'
        self.sDelimeter = sDelimeter

    def on_epoch_begin(self, batch, logs=None):
        """Remember the wall-clock time at which the epoch started.

        The first argument is what Keras passes to ``on_epoch_begin``; it keeps
        its original name ``batch`` so the signature stays unchanged.
        """
        self.epoch_time_start = time.time()

    def on_epoch_end(self, epoch, logs=None):
        """Augment ``logs`` with timing/epoch/learning-rate fields and append a row.

        The extra fields are written into the ``logs`` dict that was passed in.

        Args:
            epoch: Index of the epoch that just finished.
            logs: Metric dict for the epoch; a fresh dict is used when omitted.
        """
        endTime = time.time()

        # A `{}` default would be shared between calls and is mutated below.
        if logs is None:
            logs = {}

        logs['start'] = self.epoch_time_start
        logs['end'] = endTime
        logs['duration'] = endTime - self.epoch_time_start

        logs['epoch'] = epoch

        # `learning_rate` is the documented optimizer attribute; `lr` is a
        # legacy alias that may be missing in newer Keras releases, so it is
        # only used as a fallback.
        oOptimizer = self.model.optimizer
        oLearningRate = getattr(oOptimizer, 'learning_rate', None)
        if oLearningRate is None:
            oLearningRate = oOptimizer.lr
        logs['learning_rate'] = oLearningRate.numpy()

        # A header is needed for a brand-new file and for an existing empty one.
        bWriteHeader = (
            not os.path.exists(self.sFilePath)
            or os.path.getsize(self.sFilePath) == 0
        )

        # Open once per epoch; header (if any) and data row go out together.
        with open(self.sFilePath, 'a') as f:
            if bWriteHeader:
                f.write(self.sDelimeter.join(str(sKey) for sKey in logs.keys()))
                f.write('\n')
            f.write(self.sDelimeter.join(str(oValue) for oValue in logs.values()))
            f.write('\n')
            
    
class ThresholdStopper(tf.keras.callbacks.Callback):
    """Keras callback that requests a training stop once the batch loss is low enough."""

    def __init__(self, threshold):
        """
        Args:
            threshold: Training stops when the reported ``loss`` is <= this value.
        """
        # Initialise the base Callback so its own attributes are set up.
        super().__init__()
        self.threshold = threshold

    def on_batch_end(self, batch, logs=None):
        """Set ``model.stop_training`` when ``logs['loss']`` has reached the threshold.

        Args:
            batch: Batch index supplied by Keras (unused).
            logs: Metric dict for the batch; may be omitted or lack ``'loss'``.
        """
        fLoss = None if logs is None else logs.get('loss')

        # Without a reported loss there is nothing to compare; comparing
        # None <= threshold would raise a TypeError.
        if fLoss is not None and fLoss <= self.threshold:
            self.model.stop_training = True

In [None]:
# Saved model and CSV log both live directly under CHECKPOINT_PATH.
# os.path.join replaces the embedded backslash: '\m' and '\l' are invalid
# escape sequences in a normal string literal, and '\' is Windows-only.
sCheckPointFilePath = os.path.join(CHECKPOINT_PATH, 'model')
sCsvLogFilePath = os.path.join(CHECKPOINT_PATH, 'log')

# Restore the previously saved model, registering the custom loss functions by name.
# NOTE(review): mlm_custom_loss / nsp_custom_loss are not defined in the cells
# visible here - confirm they are defined or imported before this cell runs.
DisERT_model = tf.keras.models.load_model(
    sCheckPointFilePath,
    custom_objects={
        'mlm_custom_loss': mlm_custom_loss,
        'nsp_custom_loss': nsp_custom_loss,
    },
)

In [None]:
# try:
#     DisERT_model.load_weights(sCheckPointFilePath)
# except:
#     print('There is no pre-trained model checkpoint')

# ---- Callbacks (all of them watch the training loss) ----

# Ends training once a batch loss is at or below STOPPER_THRESHOLD.
oThresholdStopper = ThresholdStopper(STOPPER_THRESHOLD)

# Ends training after 20 epochs without an improvement in loss.
oEarlyStopper = tf.keras.callbacks.EarlyStopping(patience=20, monitor='loss')

# Multiplies the learning rate by 0.80 after 3 stagnant epochs, never below 1e-4.
oLearningRateReducer = tf.keras.callbacks.ReduceLROnPlateau(
    min_lr=1e-4,
    patience=3,
    factor=0.80,
    monitor='loss',
)

# Saves the full model (not just weights) whenever the loss reaches a new minimum.
oDisERTCheckPoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=sCheckPointFilePath,
    monitor='loss',
    mode='min',
    save_best_only=True,
    save_weights_only=False,
)

# Appends one row of metrics per epoch to the log file.
oCsvLogger = CsvLogger(sCsvLogFilePath)

# ---- Training ----
# Two inputs and two targets are fed to the model; the callback order is
# kept exactly as before.
DisERT_model.fit(
    x=[X_MLM, X_NSP],
    y=[Y_MLM, Y_NSP],
    batch_size=512,
    epochs=2,
    verbose=1,
    callbacks=[
        oDisERTCheckPoint,
        oCsvLogger,
        oLearningRateReducer,
        oEarlyStopper,
        oThresholdStopper,
    ],
)