In [1]:
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.layers import Input, Dense, Embedding, Lambda, TimeDistributed, \
                                    Add, Conv1D, Dropout, Concatenate, Activation
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import TensorBoard

from model import _get_positional_encoding_matrix, Encoder, Decoder

Using TensorFlow backend.


In [2]:
physical_devices = tf.config.experimental.list_physical_devices('GPU') 
for physical_device in physical_devices: 
    tf.config.experimental.set_memory_growth(physical_device, True)

print(physical_devices)

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [3]:
import numpy as np
import json
from tqdm.auto import tqdm
import time

In [4]:
!pwd

/home/sweet/1-workdir/nlp_attention/en_vi_attention_nlp/src


In [5]:
num_layers=2
num_multi_heads=4
d_k=64
d_v=64
d_model=256
optimizer="adam"
null_token_value=0
source_vocab_size = 48114
target_vocab_size = 22468
share_word_embedding=False
MAXIMUM_TEXT_LENGTH = 866

In [6]:
AUTOTUNE = tf.data.experimental.AUTOTUNE

In [7]:
training_records="/home/sweet/1-workdir/nlp_attention/en_vi_attention_nlp/src/train.tfrecords"
validating_records="/home/sweet/1-workdir/nlp_attention/en_vi_attention_nlp/src/test.tfrecords"

raw_training_set = tf.data.TFRecordDataset(training_records)
raw_validating_set = tf.data.TFRecordDataset(validating_records)


feature_description = {
    'input': tf.io.FixedLenFeature([], tf.string),
    'target': tf.io.FixedLenFeature([], tf.string)
}

def _parse_record_function(example_proto):
    # Parse the input tf.Example proto using the dictionary above.
    features = tf.io.parse_single_example(example_proto, feature_description)
    X = tf.io.decode_raw(features['input'], np.int32)
    Y = tf.io.decode_raw(features['target'], np.int32)
    return X[3:], Y[3:]

In [8]:
with tf.device('/cpu:0'):
    BUFFER_SIZE_TRAIN = 133317
    BATCH_SIZE_TRAIN = 16
    N_STEPS_PER_EPOCH_TRAIN = int(np.ceil(BUFFER_SIZE_TRAIN/BATCH_SIZE_TRAIN))

    train_dataset = raw_training_set.map(_parse_record_function)\
                                    .batch(batch_size=BATCH_SIZE_TRAIN)\
                                    .prefetch(buffer_size=AUTOTUNE)
    train_gen = train_dataset.__iter__()

In [9]:
with tf.device('/cpu:0'):
    BUFFER_SIZE_VAL = 2821
    BATCH_SIZE_VAL = 16
    N_STEPS_PER_EPOCH_VAL = int(np.ceil(BUFFER_SIZE_VAL/BATCH_SIZE_VAL))

    validation_dataset = raw_validating_set.map(_parse_record_function)\
                                        .batch(batch_size=BATCH_SIZE_VAL)\
                                        .prefetch(buffer_size=AUTOTUNE)
    val_gen = validation_dataset.__iter__()

In [10]:
source_word_embedding = Embedding(source_vocab_size, d_model, name="source_embedding" if share_word_embedding else "source_embedding")  # weights=[_get_positional_encoding_matrix(max_length, d_model)]
if share_word_embedding:
    target_word_embedding = source_word_embedding
else:
    target_word_embedding = Embedding(target_vocab_size, d_model, name="target_embedding")
# embedding for the position encoding
position_encoding = Embedding(MAXIMUM_TEXT_LENGTH, 
                              d_model, 
                              trainable=False, 
                              weights=[_get_positional_encoding_matrix(MAXIMUM_TEXT_LENGTH, d_model)], 
                              name="position_embedding")

In [11]:
enc = Encoder(source_word_embedding, position_encoding, 
              n=num_layers, h=num_multi_heads, 
              d_k=d_k, d_v=d_v, d_model=d_model, d_inner_hid=512)
dec = Decoder(target_word_embedding, position_encoding, 
              n=num_layers, h=num_multi_heads, 
              d_k=d_k, d_v=d_v, d_model=d_model, d_inner_hid=512)
final_layer = TimeDistributed(Dense(target_vocab_size, activation=None, use_bias=False), name="output")

In [12]:
optimizer = tf.keras.optimizers.Adam()
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')

@tf.function
def loss_function(y_true, y_pred):
    y_true_id = K.cast(y_true, "int32")

    mask = K.cast(K.equal(y_true_id, 0), K.floatx())
    mask = 1.0 - mask
    loss = K.sparse_categorical_crossentropy(
        y_true, y_pred, from_logits=True) * mask

    # take average w.r.t. the number of unmasked entries
    
    return K.sum(loss) / K.sum(mask)

@tf.function
def accuracy_function(y_true, y_pred):
    y_true = K.cast(y_true, "int32")
    mask = 1.0 - K.cast(K.equal(y_true, 0), K.floatx())

    y_pred = K.cast(K.argmax(y_pred, axis=-1), "int32")
    correct = K.cast(
        K.equal(y_pred, y_true),
        K.floatx()
    )
    correct = K.sum(correct * mask, -1) / K.sum(mask, -1)
    
    return K.mean(correct)

In [13]:
@tf.function
def train_step(source_input, target_input):
    enc_input = Lambda(lambda x:x[:,1:])(source_input)
    dec_input  = Lambda(lambda x:x[:,:-1])(target_input)
    dec_target_output = Lambda(lambda x:x[:,1:])(target_input)
    
    loss = 0

    with tf.GradientTape() as tape:
        enc_output = enc(enc_input)
        dec_output = dec(dec_input, enc_output)
        fin_output_out = final_layer(dec_output)
        
        loss = loss_function(dec_target_output, fin_output_out)
        batch_loss = (loss / int(targ.shape[1]))
        
        variables = enc.trainable_variables + dec.trainable_variables
        gradients = tape.gradient(loss, variables)
        optimizer.apply_gradients(zip(gradients, variables))

        return batch_loss

In [14]:
EPOCHS = 10

for epoch in range(EPOCHS):
    start = time.time()
    losses = np.zeros(N_STEPS_PER_EPOCH_TRAIN)
    pbar = tqdm(total=N_STEPS_PER_EPOCH_TRAIN)
    i = 0
    for inp, targ in train_gen:
        batch_loss = train_step(inp, targ)
        losses[i] = tf.cast(batch_loss, dtype=tf.float32)

        pbar.set_postfix(loss=np.mean(losses[0:i]))
        pbar.update(1)
        
        i = i + 1

    print('Epoch {} Loss {:.4f}'.format(epoch + 1, np.mean(losses)))
    print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))

HBox(children=(FloatProgress(value=0.0, max=8333.0), HTML(value='')))

  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)


KeyboardInterrupt: 