In [1]:
import tensorflow as tf
import numpy as np

import sys
sys.path.append('../..')

from transformer_encoder import MLMTransformerEncoder
from mlm_dataset.mlm_dataset_generator import MLMDatasetGenerator

In [2]:
batch_size = 20
# How many generated batches to train on: 1 trains on the first batch only,
# None trains on every batch produced by the generator.
num_train_batches = 1

# MLM dataset for training: a list of (inputs_batch, labels_batch) pairs
mlm_dataset_generator = MLMDatasetGenerator('../../dataset/resume_dataset.csv')
mlm_dataset = list(mlm_dataset_generator.generateMLMDataset(batch_size)[:num_train_batches])

oov_token = '[OOV]'

# Build the word -> id vocabulary from the generator's word list.
# filters='' disables Keras' default punctuation stripping, so tokens such as
# '[MASK]' keep their brackets; words never seen during fitting map to oov_token.
tokenizer = tf.keras.preprocessing.text.Tokenizer(oov_token=oov_token, filters='')
tokenizer.fit_on_texts(mlm_dataset_generator.getVocubulary())

In [3]:
# check how many words are in the dataset (currently: 37032)
# print(list(tokenizer.word_index.keys()))

In [4]:
# MLM dataset checker
# inputs, labels = mlm_dataset[0]

# print(inputs[121], labels[121])
# print(inputs[122], labels[122])

# for index, row in enumerate(inputs):
#     if(row.count('[MASK]') != len(labels[index])):
#         print(index, row, labels[index])

In [5]:
# Model hyperparameters: d_model / num_heads / dff follow the original Transformer
# "base" configuration; num_layers is reduced from 6 to 2.
num_layers = 2
d_model = 512
num_heads = 8
dff = 2048
input_vocab_size = 40000
maximum_position_encoding = 10000

# Tokenizer ids run from 1 to len(tokenizer.word_index) (OOV token included) and are
# used as column indices into vectors of width input_vocab_size by train_step, so the
# fitted vocabulary must fit inside it.
assert len(tokenizer.word_index) < input_vocab_size, (
    f'tokenizer vocabulary has {len(tokenizer.word_index)} words; '
    f'increase input_vocab_size (currently {input_vocab_size})'
)

model = MLMTransformerEncoder(num_layers, d_model, num_heads, dff, input_vocab_size, maximum_position_encoding)
# Call the model once on symbolic [token_ids, mask] inputs so its weights get built
# before train_step reads model.trainable_variables.
dummy_input = [tf.keras.Input(shape=(None,)), tf.keras.Input(shape=(None,))]
model(dummy_input)

optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

# Categorical cross-entropy on one-hot targets. from_logits is left at its default
# (False), i.e. this assumes the model's output layer already yields probabilities
# -- TODO confirm against MLMTransformerEncoder.
loss_function = tf.keras.losses.CategoricalCrossentropy()

In [6]:
import time

# model_trainable_variables = []
# model_gradients = []

# Define a training loop
def _encode_mlm_example(inputs, labels):
    """Turn one tokenised MLM example into model inputs and one-hot targets.

    Reads the module-level ``tokenizer`` and ``input_vocab_size``.

    Args:
        inputs: list of token strings for one sequence; masked positions hold
            the literal string '[MASK]'.
        labels: label texts, one per '[MASK]' in ``inputs``, in order.

    Returns:
        (input_ids, mask, token_indices, token_ids, masked_labels) where
          input_ids     -- float32 tensor of tokenizer ids, shape (1, sequence_length)
          mask          -- float32 tensor of shape (1, len(inputs)): 0 at '[MASK]'
                           positions, 1 everywhere else
          token_indices -- positions of the '[MASK]' tokens in ``inputs``
          token_ids     -- tokenizer ids of each label (one list per label)
          masked_labels -- float32 array (len(token_indices), input_vocab_size);
                           row k is one-hot at the id(s) of the k-th label

    Raises:
        ValueError: if the number of labels differs from the number of '[MASK]'
            tokens, i.e. labels cannot be aligned with masked positions.
    """
    token_indices = [position for position, token in enumerate(inputs) if token == '[MASK]']
    mask = tf.constant([[0 if token == '[MASK]' else 1 for token in inputs]], tf.float32)
    input_ids = tf.cast(tokenizer.texts_to_sequences([inputs]), tf.float32)

    token_ids = tokenizer.texts_to_sequences(labels)
    if len(token_ids) != len(token_indices):
        raise ValueError(
            f'sequence has {len(token_indices)} [MASK] tokens but {len(token_ids)} labels'
        )

    # Only the masked rows are ever scored, so build targets for those rows alone
    # instead of a (sequence_length, input_vocab_size) matrix of mostly zeros.
    masked_labels = np.zeros((len(token_indices), input_vocab_size), dtype=np.float32)
    for row, ids in enumerate(token_ids):
        # a label that tokenises to several ids sets several entries
        masked_labels[row, ids] = 1.0
    return input_ids, mask, token_indices, token_ids, masked_labels


def train_step(inputs_batch, labels_batch, training=True):
    """Run one optimisation step on a batch of masked-language-model examples.

    Sequences are fed to the model one at a time (batch dimension 1); their
    gradients are accumulated, averaged and applied in a single optimizer update.
    A progress line is printed per sequence. Uses the module-level ``model``,
    ``optimizer`` and ``loss_function``.

    Args:
        inputs_batch: iterable of token lists, '[MASK]' marking masked positions.
        labels_batch: iterable parallel to ``inputs_batch``; each entry holds the
            original text of that sequence's masked tokens, in order.
        training: forwarded to the model call so training-only behaviour (e.g.
            dropout, if the model has any) is active. Defaults to True.

    Returns:
        (mean_loss, elapsed): the mean masked-token loss over the sequences that
        contained at least one '[MASK]' (scalar tensor; 0 if none did), and the
        step's wall-clock duration in seconds, formatted as a string.

    Raises:
        ValueError: if a sequence's label count does not match its '[MASK]' count.
    """
    start_time = time.perf_counter()

    trainable_variables = model.trainable_variables
    gradients_accumulator = [tf.zeros_like(var) for var in trainable_variables]
    total_loss = tf.constant(0.0)
    num_scored = 0

    for counter, (inputs, labels) in enumerate(zip(inputs_batch, labels_batch)):
        input_ids, mask, token_indices, token_ids, masked_labels = _encode_mlm_example(inputs, labels)
        if not token_indices:
            # No masked token means there is nothing to score (the loss would be 0/0).
            print('Seq ' + str(counter) + ', skipped: no [MASK] token')
            continue

        with tf.GradientTape() as tape:
            predictions = model([input_ids, mask], training=training)[0]
            # Score the masked positions only. Averaging over every position (unmasked
            # rows have all-zero targets and add 0) would divide the loss, and its
            # gradient, by the sequence length.
            masked_predictions = tf.gather(predictions, token_indices)
            loss = loss_function(masked_labels, masked_predictions)

        gradients = tape.gradient(loss, trainable_variables)
        # tape.gradient yields None for variables the loss does not depend on.
        gradients_accumulator = [
            accum if grad is None else accum + grad
            for accum, grad in zip(gradients_accumulator, gradients)
        ]
        total_loss += loss
        num_scored += 1

        predicted_token = tf.argmax(masked_predictions, axis=-1).numpy().tolist()
        print('Seq ' + str(counter) + ', Loss = ' + str(loss.numpy()) + ', Predicted Token = ' + str(predicted_token) + ', True Token = ' + str(token_ids))

    if num_scored == 0:
        # Nothing in this batch could be scored; leave the weights untouched.
        return total_loss, str(time.perf_counter() - start_time)

    gradients_avg = [grad / num_scored for grad in gradients_accumulator]
    optimizer.apply_gradients(zip(gradients_avg, trainable_variables))

    return total_loss / num_scored, str(time.perf_counter() - start_time)

In [None]:
# Training loop: run train_step on every batch of mlm_dataset for num_epochs
# epochs, printing each batch's mean loss and wall-clock time.
num_epochs = 100
for epoch in range(num_epochs):
    for batch_counter, (inputs_batch, labels_batch) in enumerate(mlm_dataset):
        loss, elapsed_time = train_step(inputs_batch, labels_batch)
        # !s forces str() so numpy scalars print exactly as before
        print(f'Epoch {epoch}, Batch {batch_counter}, Loss = {loss.numpy()!s}, Elapsed Time: {elapsed_time}\n')

In [None]:
# print(model_trainable_variables[0])
# print(model_gradients[0])

In [None]:
# print(model_trainable_variables[8])