In [1]:
import tensorflow as tf
import numpy as np

import sys
sys.path.append('../..')

from transformer_encoder import MLMTransformerEncoder
from mlm_dataset.mlm_dataset_generator import MLMDatasetGenerator

In [2]:
# MLM dataset for training
mlm_dataset_generator = MLMDatasetGenerator('../../dataset/resume_dataset.csv')
mlm_dataset = mlm_dataset_generator.generateMLMDataset(256)

# Initialize a Tokenizer and fit on text data
tokenizer = tf.keras.preprocessing.text.Tokenizer(oov_token='[OOV]')
tokenizer.fit_on_texts(mlm_dataset_generator.getVocubulary())

# check how many words are in the dataset (currently: 37032)
print(list(tokenizer.word_index.keys()))



In [3]:
# Usage example with original Transformer hyperparameters
num_layers = 6
d_model = 512
num_heads = 8
dff = 2048
input_vocab_size = 40000
maximum_position_encoding = 10000

model = MLMTransformerEncoder(num_layers, d_model, num_heads, dff, input_vocab_size, maximum_position_encoding)
dummy_input = [tf.keras.Input(shape=(None,)), tf.keras.Input(shape=(None,))]
model(dummy_input)

# Define an optimizer (e.g., Adam)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

# Define a loss function (e.g., categorical cross-entropy for classification)
loss_function = tf.keras.losses.CategoricalCrossentropy()

In [6]:
import time

# Define a training loop
def train_step(inputs_batch, labels_batch):
    start_time = time.time()
    per_sequence_start_time = start_time

    total_loss = 0.0
    gradients_accumulator = [tf.zeros_like(var) for var in model.trainable_variables]

    counter = 0
    for inputs, labels in zip(inputs_batch, labels_batch):
        with tf.GradientTape() as tape:
            # create one-hot encoded mask and get the indices
            mask =[[]]
            token_indices = []
            for index, token in enumerate(inputs):
                if token == '[MASK]':
                    mask[0].append(0)
                    token_indices.append(index)
                else: 
                    mask[0].append(1)
            mask = tf.cast(mask, tf.float32)
            # tokenize inputs
            input_ids = tf.cast(tokenizer.texts_to_sequences([inputs]), tf.float32)
            # tokenize labels
            token_ids = tokenizer.texts_to_sequences(labels)
            # create array of zeroes with dimension [sequence_length, input_vocab_size]
            tokenized_labels = np.zeros((len(inputs), input_vocab_size))
            # change the [masked_token_index, token_id] to ones
            for index, token_index in enumerate(token_indices):
                tokenized_labels[token_index, token_ids[index]] = 1
            tokenized_labels = tf.constant(tokenized_labels, dtype=tf.float32)

            # print('\n> INPUTS')
            # print(input_ids)
            # print(mask)

            predictions = model([input_ids, mask], training=True)[0]

            tokenized_labels = tf.nn.softmax(tokenized_labels)
            predictions = tf.nn.softmax(predictions)

            # print('\n> LABELS')
            # print(tokenized_labels)
            # print('\n> PREDICTIONS')
            # print(predictions)

            loss = loss_function(tokenized_labels, predictions)
            # print('\n> LOSS')
            # print(loss)
        
        gradients = tape.gradient(loss, model.trainable_variables)
        # print('GRADIENTS')
        # print(gradients)

        gradients_accumulator = [grad_accum + grad for grad_accum, grad in zip(gradients_accumulator, gradients)]
        total_loss += loss

        print('Seq ' + str(counter) + ', Elapsed Time: ' + str(time.time() - per_sequence_start_time))
        per_sequence_start_time = time.time()
        counter += 1

    gradients_avg = [grad / len(inputs_batch) for grad in gradients_accumulator]
    
    optimizer.apply_gradients(zip(gradients_avg, model.trainable_variables))

    return total_loss / len(inputs_batch), str(time.time() - start_time)

In [7]:
# Example of usage in the training loop
num_epochs = 10
for epoch in range(num_epochs):
    batch_counter = 0
    for inputs_batch, labels_batch in mlm_dataset:  # Provide training data
        loss, elapsed_time = train_step(inputs_batch, labels_batch)
        # Log or print the loss for monitoring
        print('Epoch ' + str(epoch) + ', Batch ' + str(batch_counter) + ', Loss = ' + str(loss) + ', Elapsed Time: ' + elapsed_time)
        batch_counter += 1

Seq 0, Elapsed Time: 1.3892762660980225
Seq 1, Elapsed Time: 1.013577938079834
Seq 2, Elapsed Time: 1.2359111309051514
Seq 3, Elapsed Time: 1.0576038360595703
Seq 4, Elapsed Time: 0.8138155937194824
Seq 5, Elapsed Time: 0.8039431571960449
Seq 6, Elapsed Time: 0.7999587059020996
Seq 7, Elapsed Time: 0.8213064670562744
Seq 8, Elapsed Time: 0.8163027763366699
Seq 9, Elapsed Time: 0.8003222942352295
Seq 10, Elapsed Time: 0.8403520584106445
Seq 11, Elapsed Time: 0.8280596733093262
Seq 12, Elapsed Time: 0.9115438461303711
Seq 13, Elapsed Time: 0.8644077777862549
Seq 14, Elapsed Time: 0.9649076461791992
Seq 15, Elapsed Time: 0.822723388671875
Seq 16, Elapsed Time: 0.8256533145904541
Seq 17, Elapsed Time: 0.8278286457061768
Seq 18, Elapsed Time: 1.0708527565002441
Seq 19, Elapsed Time: 0.8761017322540283
Seq 20, Elapsed Time: 0.8139517307281494
Seq 21, Elapsed Time: 1.040633201599121
Seq 22, Elapsed Time: 0.8638637065887451
Seq 23, Elapsed Time: 1.151839017868042
Seq 24, Elapsed Time: 0.898852

KeyboardInterrupt: 