<h3><b>Train and evaluate a chatbot based on an encoder-decoder Transformer model (i.e., the same architecture as the original Transformer). The model is trained on the Cornell Movie-Dialogs Corpus.</b></h3>

<h5><b> 0. Setup</b></h5>

In [2]:
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

from dataset import DatasetHp, preprocess_sentence, get_cornell_dataset
from transformer_model import ModelHp, encoder_decoder_transformer

<h5><b> 1. Load dataset and tokenizer</b></h5>

In [3]:
dataset_hp = DatasetHp(
    max_length=128,
    vocab_size=10_000,
    max_sample=None,
)

dataset, tokenizer = get_cornell_dataset(dataset_hp)

Downloading data from http://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip
loading conversations ...
100%|██████████| 83097/83097 [00:14<00:00, 5849.40it/s]
initializing tokenizer ...
tokenizer saved in `./transformer/tokenizer`
vocab size updated from 10000 to 10054
tokenization ...
221616it [00:14, 14881.07it/s]
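
The loss and metric in the next step reshape the labels to `max_length - 1` positions. That is consistent with the usual teacher-forcing split, where the decoder reads the reply without its last position and is trained to predict the reply without its first (start) token. `get_cornell_dataset` is not shown here, so the NumPy sketch below is an assumption about what it does; the function name `teacher_forcing_split` and the token ids (1 = start, 2 = end, 0 = padding) are hypothetical.

```python
import numpy as np

def teacher_forcing_split(target_ids: np.ndarray):
    """Split padded replies of shape (batch, max_length) into decoder inputs
    and labels, each of shape (batch, max_length - 1)."""
    decoder_inputs = target_ids[:, :-1]  # every position except the last
    labels = target_ids[:, 1:]           # every position except the start token
    return decoder_inputs, labels

# Hypothetical ids: 1 = start, 2 = end, 0 = padding (max_length = 6 here).
target = np.array([[1, 57, 93, 2, 0, 0]])
decoder_inputs, labels = teacher_forcing_split(target)
print(decoder_inputs.tolist())  # [[1, 57, 93, 2, 0]]
print(labels.tolist())          # [[57, 93, 2, 0, 0]]
```

At every position i the decoder sees tokens up to i and is scored on token i + 1, which is why the decoder output and the labels both have `max_length - 1` positions.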


<h5><b> 2. Define the optimizer, loss, and metric functions</b></h5>

In [4]:
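optimizer = tf.keras.optimizers.Adam(beta_1=0.9, beta_2=0.98, epsilon=1e-9)

# Per-token loss; the reduction is done manually below so padding can be masked out.
cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")

def loss_fn(y_true, y_pred):
    # Labels have max_length - 1 positions, matching the decoder output.
    y_true = tf.reshape(y_true, shape=(-1, dataset_hp.max_length - 1))
    loss = cross_entropy(y_true, y_pred)
    # Zero out the loss at padding positions (token id 0).
    mask = tf.cast(tf.not_equal(y_true, 0), dtype=tf.float32)
    loss = tf.multiply(loss, mask)
    # Average over real tokens only, not over padding positions.
    return tf.reduce_sum(loss) / tf.maximum(tf.reduce_sum(mask), 1.0)

def accuracy(y_true, y_pred):
    y_true = tf.reshape(y_true, shape=(-1, dataset_hp.max_length - 1))
    # Per-token accuracy, with padding positions excluded from the average.
    acc = tf.keras.metrics.sparse_categorical_accuracy(y_true, y_pred)
    mask = tf.cast(tf.not_equal(y_true, 0), dtype=acc.dtype)
    return tf.reduce_sum(acc * mask) / tf.maximum(tf.reduce_sum(mask), 1.0)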
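
Padding (token id 0) can fill many of the `max_length - 1` label positions, so how the mask is applied matters. The NumPy sketch below is independent of TensorFlow, and all of its names are hypothetical. It compares a masked mean, which divides by the number of real tokens, with a plain mean over every position (what `tf.reduce_mean` computes on an already-masked loss), and shows how an unmasked accuracy rewards the model for "predicting" padding.

```python
import numpy as np

def token_cross_entropy(labels: np.ndarray, logits: np.ndarray) -> np.ndarray:
    """Per-token sparse cross-entropy. labels: (B, T) ints, logits: (B, T, V)."""
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Log-probability assigned to the true token at each position.
    picked = np.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]
    return -picked

def masked_mean_loss(labels, logits, pad_id=0):
    loss = token_cross_entropy(labels, logits)
    mask = (labels != pad_id).astype(loss.dtype)
    # Divide by the number of real tokens, not by all positions.
    return (loss * mask).sum() / max(mask.sum(), 1.0)

def masked_accuracy(labels, logits, pad_id=0):
    mask = labels != pad_id
    correct = (logits.argmax(axis=-1) == labels) & mask
    return correct.sum() / max(mask.sum(), 1)

# Two real tokens followed by two padding positions; uniform logits over 4 ids.
labels = np.array([[3, 1, 0, 0]])
logits = np.zeros((1, 4, 4))
print(masked_mean_loss(labels, logits))  # log(4), about 1.386
print((token_cross_entropy(labels, logits) * (labels != 0)).mean())  # half of that
print(masked_accuracy(labels, logits))  # 0.0
print((logits.argmax(axis=-1) == labels).mean())  # 0.5: padding counted as correct
```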

<h5><b> 3. Build and train the model</b></h5>

In [5]:
hparams = ModelHp(
    d_model=512,
    num_attention_heads=8,
    dropout_rate=0.1,
    num_units=1024,
    activation="relu",
    vocab_size=10054,  # must match the tokenizer's vocabulary size logged above
    num_layers=4,
)

model = encoder_decoder_transformer(hparams, "transformer")
print(f"Total number of model's parameters: {model.count_params()}")

Total number of model's parameters: 36481862
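
The model built above is described as the original encoder-decoder Transformer. Its core operation is scaled dot-product attention, softmax(QKᵀ / √d_k) V, computed per head; with `d_model = 512` and `num_attention_heads = 8`, each head works in d_k = 512 / 8 = 64 dimensions. In the decoder's self-attention, a look-ahead mask stops each position from attending to later positions, so the decoder cannot peek at the token it is being trained to predict. The NumPy sketch below illustrates only these two pieces; it is not the code in `transformer_model` (which is not shown here) and omits the learned projections, the stacked layers, and the feed-forward blocks.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def look_ahead_mask(size: int) -> np.ndarray:
    """1.0 where key position j lies in the future of query position i (j > i)."""
    return np.triu(np.ones((size, size)), k=1)

def scaled_dot_product_attention(q, k, v, mask=None):
    """q, k, v: (..., seq_len, depth). Returns (output, attention_weights)."""
    depth = q.shape[-1]
    logits = q @ np.swapaxes(k, -1, -2) / np.sqrt(depth)
    if mask is not None:
        logits = logits + mask * -1e9  # masked positions get ~zero weight after softmax
    weights = softmax(logits, axis=-1)
    return weights @ v, weights

# Same sizes as hparams: d_model = 512 split across 8 heads.
d_model, num_heads, seq_len = 512, 8, 5
depth = d_model // num_heads  # 64 dimensions per head
rng = np.random.default_rng(0)
q = k = v = rng.standard_normal((num_heads, seq_len, depth))
out, weights = scaled_dot_product_attention(q, k, v, mask=look_ahead_mask(seq_len))
print(out.shape)            # (8, 5, 64)
print(weights[0].round(2))  # upper triangle is 0: no position attends to the future
```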


In [6]:
model.compile(optimizer=optimizer, loss=loss_fn, metrics=[accuracy])
model_checkpoint_cb = tf.keras.callbacks.ModelCheckpoint("model_checkpoint_cb.h5")
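
The Adam settings in step 2 (β₁ = 0.9, β₂ = 0.98, ε = 1e-9) are the ones used in "Attention Is All You Need", but no `learning_rate` is passed, so Keras falls back to its default constant rate of 0.001. The paper instead varied the rate with a warmup schedule, lrate = d_model^-0.5 · min(step^-0.5, step · warmup_steps^-1.5), with warmup_steps = 4000. Below is a plain-Python sketch of that formula; the name `transformer_lr` is hypothetical, and plugging it into this notebook would still require wrapping it in a `tf.keras.optimizers.schedules.LearningRateSchedule` subclass, which is left as an assumption here.

```python
def transformer_lr(step: int, d_model: int = 512, warmup_steps: int = 4000) -> float:
    """Learning rate at a given training step, following the schedule in
    "Attention Is All You Need": linear warmup, then inverse-square-root decay."""
    step = max(step, 1)  # guard against step 0 (0 ** -0.5 would raise)
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

for step in (1, 1000, 4000, 16000, 64000):
    print(step, f"{transformer_lr(step):.2e}")
```

With d_model = 512 the rate rises linearly to about 7e-4 at step 4000 and then decays in proportion to 1/sqrt(step).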

In [7]:
history = model.fit(dataset, epochs=5, callbacks=[model_checkpoint_cb])

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
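
The title also promises evaluation, which for a chatbot typically means generating a reply autoregressively: the decoder is fed its own predictions one token at a time until it emits an end token or reaches `max_length`. The sketch below shows greedy decoding in plain Python. `next_token_logits`, `start_id`, and `end_id` are hypothetical stand-ins; in practice they would come from `model` and `tokenizer`, whose exact prediction API is not shown in this section.

```python
from typing import Callable, List, Sequence

def greedy_decode(
    next_token_logits: Callable[[List[int]], Sequence[float]],
    start_id: int,
    end_id: int,
    max_length: int,
) -> List[int]:
    """Generate a reply by always taking the most likely next token.

    `next_token_logits` (hypothetical) returns logits over the vocabulary for
    the next position, given the tokens decoded so far.
    """
    output = [start_id]
    while len(output) < max_length:
        logits = next_token_logits(output)
        next_id = max(range(len(logits)), key=logits.__getitem__)  # argmax
        output.append(next_id)
        if next_id == end_id:
            break
    return output[1:]  # drop the start token; the end token is kept if reached

# Toy stand-in for the model: replays a fixed reply over a 6-token vocabulary.
script = [4, 5, 3, 2]  # hypothetical reply ids, ending with end_id = 2

def fake_model(tokens: List[int]) -> List[float]:
    logits = [0.0] * 6
    logits[script[len(tokens) - 1]] = 1.0
    return logits

print(greedy_decode(fake_model, start_id=1, end_id=2, max_length=128))  # [4, 5, 3, 2]
```

Greedy decoding is the simplest choice; beam search or sampling would trade extra computation for more varied replies.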
