In [None]:
import tensorflow as tf
import numpy as np
import tensorflow_text as text

In [None]:
# Model hyperparameters
num_layers = 4       # encoder blocks and decoder blocks (each)
num_heads = 8        # attention heads per block
embed_dim = 512      # model width
ffn_dim = 2048       # hidden width of the feed-forward network
vocab_size = 31000   # embedding/output size; presumably the BPE vocabulary size — verify
maxlen = 200         # padded sequence length in tokens

# Input files
dataset_file = "datasets/dataset.txt"
bpe_model = "bpe_model.model"

Simple linear layer:

In [None]:
class Linear(tf.Module):
    """Fully connected layer computing ``x @ w + b``.

    Args:
        input_dim: size of the last axis of the input.
        output_dim: size of the last axis of the output.
        name: module name; also the prefix of the variable names
            (``<name>_w`` for the kernel, ``<name>_b`` for the bias).
    """

    def __init__(self, input_dim, output_dim, name="Linear"):
        super().__init__(name=name)
        self.initializer = tf.initializers.GlorotUniform()
        kernel_init = self.initializer(shape=[input_dim, output_dim])
        self.w = tf.Variable(kernel_init, name=f"{name}_w")
        bias_init = tf.zeros([output_dim])
        self.b = tf.Variable(bias_init, name=f"{name}_b")

    @tf.function
    @tf.Module.with_name_scope
    def __call__(self, x):
        projected = tf.matmul(x, self.w)
        return projected + self.b

Gaussian Error Linear Unit (GELU) activation for the FFN, using the tanh approximation.
Equivalent to `tf.keras.activations.gelu(x, approximate=True)`:

In [None]:
@tf.function
def gelu_new(x):
    """Tanh approximation of GELU.

    gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))
    """
    inner = np.sqrt(2 / np.pi) * (x + 0.044715 * tf.math.pow(x, 3))
    return 0.5 * x * (1 + tf.math.tanh(inner))

Create a causal (look-ahead) mask for the decoder so that each position attends only to itself and earlier tokens. Entries strictly above the diagonal are 1, meaning "masked out":

In [None]:
@tf.function
def CasualMask(size):
    """Return a float32 (size, size) causal mask.

    Entries strictly above the diagonal (future positions) are 1.0 and mean
    "masked"; entries on or below the diagonal are 0.0.
    """
    lower_incl_diag = tf.linalg.band_part(tf.ones((size, size)), -1, 0)
    return 1 - lower_incl_diag

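Position-wise feed-forward network (FFN). It expands each position from `embed_dim` to `ffn_dim`, applies GELU, and projects back to `embed_dim`. The block output therefore has the same shape as its input and can feed the next transformer block: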

In [None]:
class FFN(tf.Module):
    """Position-wise feed-forward network: Linear -> GELU (tanh approx.) -> Linear.

    Args:
        embed_dim: input and output width.
        ffn_dim: hidden width.
        name: module name. The two projections are named ``<name>_in`` and
            ``<name>_out`` so that their variables can be told apart.
    """

    def __init__(self, embed_dim, ffn_dim, name="FeedForward"):
        super().__init__(name=name)
        # Distinct names for the two projections. Reusing `name` for both would
        # give two kernels called "<name>_w" and two biases called "<name>_b",
        # which are indistinguishable in variable listings and summaries.
        self.dense0 = Linear(embed_dim, ffn_dim, name=name + "_in")
        self.dense1 = Linear(ffn_dim, embed_dim, name=name + "_out")

    @tf.function
    @tf.Module.with_name_scope
    def __call__(self, x):
        hidden = gelu_new(self.dense0(x))  # (..., ffn_dim)
        return self.dense1(hidden)         # back to (..., embed_dim)

Compute the attention outputs and weights. Before the softmax, -1e9 is added to the logits of every position marked with 1 in the mask (padding or future tokens), so those positions receive effectively zero weight. Zeroing the weights after the softmax instead would leave a distribution that no longer sums to 1.

In [None]:
def scaled_dot_product_attention(q, k, v, mask=None):
    """Compute softmax(q @ k^T / sqrt(d_k) + mask * -1e9) @ v.

    Args:
        q, k, v: query, key and value tensors; d_k is the size of k's last axis.
        mask: optional tensor broadcastable to the attention logits, where 1.0
            marks positions to hide and 0.0 marks positions to keep.

    Returns:
        (output, attention_weights)
    """
    key_depth = tf.cast(tf.shape(k)[-1], tf.float32)
    logits = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(key_depth)

    if mask is not None:
        # Large negative bias -> ~0 probability after the softmax.
        logits += mask * -1e9

    weights = tf.nn.softmax(logits, axis=-1)
    return tf.matmul(weights, v), weights

Multi-head attention layer. It projects Q, K and V with learned linear layers and splits each of them into `num_heads` heads. After the split, each head has dimension `key_dim // num_heads`.

In [None]:
class MHA(tf.Module):
    """Multi-head attention with learned Q/K/V and output projections.

    Args:
        num_heads: number of attention heads; must divide ``key_dim``.
        key_dim: model width. Q, K, V and the output all have this last
            dimension, and each head works on ``key_dim // num_heads`` features.
        name: module name, also used as the prefix of the projection layers'
            names (the default yields "MHA_query", "MHA_key", "MHA_value",
            "MHA_dense").

    Raises:
        ValueError: if ``key_dim`` is not divisible by ``num_heads``. Without
            this check, the reshape in ``split_heads`` would only fail at call time.
    """

    def __init__(self, num_heads, key_dim, name="MHA"):
        super().__init__(name=name)
        if key_dim % num_heads != 0:
            raise ValueError(
                f"key_dim ({key_dim}) must be divisible by num_heads ({num_heads})")
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.depth = key_dim // num_heads  # features per head

        self.wq = Linear(key_dim, key_dim, name=name + "_query")
        self.wk = Linear(key_dim, key_dim, name=name + "_key")
        self.wv = Linear(key_dim, key_dim, name=name + "_value")

        self.dense = Linear(key_dim, key_dim, name=name + "_dense")  # output projection

    @tf.function
    def split_heads(self, x, batch_size):
        """Reshape x to (batch, num_heads, seq, depth).

        Assumes x is (batch, seq, key_dim) — TODO confirm for all callers.
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    @tf.function
    @tf.Module.with_name_scope
    def __call__(self, q, k, v, mask):
        """Attend from ``q`` to ``k``/``v``.

        ``mask`` is passed to ``scaled_dot_product_attention`` (1.0 = hidden,
        None = no masking).

        Returns:
            (output, attention_weights). ``output`` has last dimension
            ``key_dim``, and ``attention_weights`` has one map per head.
        """
        batch_size = tf.shape(q)[0]
        q = self.split_heads(self.wq(q), batch_size)
        k = self.split_heads(self.wk(k), batch_size)
        v = self.split_heads(self.wv(v), batch_size)

        scaled_attention, attention_weights = scaled_dot_product_attention(q, k, v, mask)

        # Merge the heads back: (batch, heads, seq_q, depth) -> (batch, seq_q, key_dim).
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.key_dim))

        return self.dense(concat_attention), attention_weights

Token and positional embeddings. When called with `masking=True`, this layer also returns the padding mask used by the scaled dot-product attention function.

In [None]:
def get_angles(pos, i, d_model):
    """Sinusoidal-encoding angles: pos / 10000 ** (2 * (i // 2) / d_model).

    ``pos`` and ``i`` are broadcast against each other (e.g. a column of
    positions and a row of dimension indices).
    """
    exponent = (2 * (i // 2)) / np.float32(d_model)
    angle_rates = 1 / np.power(10000, exponent)
    return pos * angle_rates

def positional_encoding(position, d_model):
    """Return a float32 (1, position, d_model) sinusoidal positional encoding.

    Even feature indices hold sin(angle) and odd indices hold cos(angle).
    """
    positions = np.arange(position)[:, np.newaxis]  # (position, 1)
    dims = np.arange(d_model)[np.newaxis, :]        # (1, d_model)
    angles = get_angles(positions, dims, d_model)

    angles[:, 0::2] = np.sin(angles[:, 0::2])
    angles[:, 1::2] = np.cos(angles[:, 1::2])
    return tf.cast(angles[np.newaxis, ...], dtype=tf.float32)


In [None]:
class Embedding(tf.Module):
    """Token embedding scaled by sqrt(embed_dim), plus fixed sinusoidal positions.

    Args:
        vocab_size: number of rows in the token embedding table.
        embed_dim: embedding width.
        maxlen: number of precomputed positions. Sequences longer than this
            make the positional slice shorter than the input and the addition fails.
        name: module name.
    """

    def __init__(self, vocab_size, embed_dim, maxlen, name="Embeddings"):
        super().__init__(name=name)
        self.embed_dim = embed_dim
        self.initializer = tf.initializers.GlorotUniform()

        table_init = self.initializer([vocab_size, embed_dim])
        self.w0 = tf.Variable(table_init, name="token_embedding")
        # Shape (1, maxlen, embed_dim); broadcast over the batch.
        self.pos_encoding = positional_encoding(maxlen, self.embed_dim)

    @tf.function
    @tf.Module.with_name_scope
    def __call__(self, x, masking=False):
        """Embed token ids ``x``.

        Returns the embeddings, or ``(embeddings, pad_mask)`` when ``masking``
        is true. ``pad_mask`` is 1.0 where the id is 0, with shape
        (batch, 1, 1, seq) so it broadcasts over heads and query positions.
        """
        ids = tf.cast(x, tf.int32)
        seq_len = tf.shape(ids)[-1]

        if isinstance(ids, tf.sparse.SparseTensor):
            expanded_ids = tf.sparse.expand_dims(ids, axis=-1)
            out = tf.nn.safe_embedding_lookup_sparse(embedding_weights=self.w0, sparse_ids=expanded_ids, default_id=0)
        else:
            out = tf.nn.embedding_lookup(self.w0, ids)

        scale = tf.math.sqrt(tf.cast(self.embed_dim, tf.float32))
        out = out * scale + self.pos_encoding[:, :seq_len, :]

        if not masking:
            return out

        # Mask PAD (id 0) keys so attention ignores them.
        pad_mask = tf.cast(tf.math.equal(ids, 0), tf.float32)[:, tf.newaxis, tf.newaxis, :]
        return out, pad_mask

Layer normalization: normalize over the last axis to zero mean and unit variance, then apply a learned scale (`gamma`) and shift (`beta`).

In [None]:
class LayerNorm(tf.Module):
    """Layer normalization with learned scale (gamma) and shift (beta).

    Args:
        shape: shape of gamma/beta, normally the size of the normalized axis.
        epsilon: added to the variance for numerical stability.
        axis: axis to normalize over.
        name: module name.
    """

    def __init__(self, shape, epsilon=1e-5, axis=-1, name="LayerNorm"):
        super().__init__(name=name)
        self.epsilon = epsilon
        self.axis = axis
        self.gamma = tf.Variable(tf.ones(shape), name="layernorm_gamma")  # scale, starts at 1
        self.beta = tf.Variable(tf.zeros(shape), name="layernorm_beta")   # shift, starts at 0

    @tf.function
    @tf.Module.with_name_scope
    def __call__(self, x):
        mean = tf.reduce_mean(x, axis=self.axis, keepdims=True)
        centered = x - mean
        variance = tf.reduce_mean(tf.square(centered), axis=self.axis, keepdims=True)
        normalized = centered * tf.math.rsqrt(variance + self.epsilon)
        return normalized * self.gamma + self.beta

Create encoder block:

In [None]:
class Encoder(tf.Module):
    """One Transformer encoder block (post-norm).

    The block has two sub-layers, each followed by dropout, a residual
    connection and layer normalization: self-attention, then a position-wise FFN.

    Args:
        num_heads: number of attention heads.
        embed_dim: model width (input and output size of the block).
        ffn_dim: hidden width of the feed-forward network.
        name: module name.
        dropout_rate: dropout rate for both sub-layer outputs. Defaults to 0.1,
            the value that used to be hard-coded.
    """

    def __init__(self, num_heads, embed_dim, ffn_dim, name="encoder", dropout_rate=0.1):
        super().__init__(name=name)
        self.att = MHA(num_heads, embed_dim)
        self.ffn = FFN(embed_dim, ffn_dim)
        self.norm1 = LayerNorm(embed_dim)
        self.norm2 = LayerNorm(embed_dim)

        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.ffn_dim = ffn_dim
        self.dropout_rate = dropout_rate

    def _dropout(self, x, training):
        # Apply dropout only while training; identity otherwise.
        return tf.nn.dropout(x, rate=self.dropout_rate) if training else x

    @tf.function
    @tf.Module.with_name_scope
    def __call__(self, x, mask=None, training=True):
        """Run the block.

        Args:
            x: input embeddings.
            mask: optional self-attention mask (1.0 = masked), e.g. the
                padding mask returned by ``Embedding(..., masking=True)``.
            training: when False, dropout is skipped. Defaults to True, which
                keeps the previous behaviour of always applying dropout.
        """
        att_output, _ = self.att(x, x, x, mask)
        norm_output = self.norm1(x + self._dropout(att_output, training))

        ffn_out = self.ffn(norm_output)
        return self.norm2(norm_output + self._dropout(ffn_out, training))

Create decoder block:

In [None]:
class Decoder(tf.Module):
    """One Transformer decoder block (post-norm).

    The block has three sub-layers, each followed by dropout, a residual
    connection and layer normalization:
      1. masked self-attention over the decoder input,
      2. cross-attention from the decoder to ``encoder_seq``,
      3. a position-wise feed-forward network.

    Args:
        num_heads: number of attention heads.
        embed_dim: model width (input and output size of the block).
        ffn_dim: hidden width of the feed-forward network.
        name: module name.
        dropout_rate: dropout rate for every sub-layer output. Defaults to 0.1,
            the value that used to be hard-coded.
    """

    def __init__(self, num_heads, embed_dim, ffn_dim, name="decoder", dropout_rate=0.1):
        super().__init__(name=name)
        self.att1 = MHA(num_heads, embed_dim)  # masked self-attention
        self.att2 = MHA(num_heads, embed_dim)  # cross-attention over the encoder output
        self.ffn = FFN(embed_dim, ffn_dim)

        self.layernorm1 = LayerNorm(embed_dim)
        self.layernorm2 = LayerNorm(embed_dim)
        self.layernorm3 = LayerNorm(embed_dim)

        self.dropout_rate = dropout_rate

    def _dropout(self, x, training):
        # Apply dropout only while training; identity otherwise.
        return tf.nn.dropout(x, rate=self.dropout_rate) if training else x

    @tf.function
    @tf.Module.with_name_scope
    def __call__(self, x, encoder_seq, mask=None, enc_mask=None, training=True):
        """Run the block.

        Args:
            x: decoder input embeddings (sequence length read from axis 1).
            encoder_seq: encoder output, used as keys/values for cross-attention.
            mask: optional padding mask over the decoder positions (1.0 = masked),
                e.g. from ``Embedding(targ, masking=True)``. It is combined with
                the causal mask and applies to self-attention only.
            enc_mask: optional padding mask over the positions of ``encoder_seq``
                (1.0 = masked), used for cross-attention. None leaves
                cross-attention unmasked.
            training: when False, dropout is skipped. Defaults to True, which
                keeps the previous behaviour of always applying dropout.
        """
        seq_len = tf.shape(x)[1]
        causal_mask = CasualMask(seq_len)

        # Self-attention must hide both future positions and decoder padding.
        # The decoder padding mask indexes decoder positions, so it must not be
        # applied to cross-attention, whose keys are encoder positions.
        if mask is not None:
            self_att_mask = tf.maximum(tf.cast(mask, tf.float32), causal_mask)
        else:
            self_att_mask = causal_mask

        att1_output, _ = self.att1(x, x, x, self_att_mask)
        out1 = self.layernorm1(x + self._dropout(att1_output, training))

        att2_output, _ = self.att2(out1, encoder_seq, encoder_seq, enc_mask)
        out2 = self.layernorm2(out1 + self._dropout(att2_output, training))

        ffn_output = self.ffn(out2)
        return self.layernorm3(out2 + self._dropout(ffn_output, training))

These layers make it possible to stack multiple encoder and decoder blocks by iterating over the specified number of layers.

In [None]:
class Encoder_layer(tf.Module):
    """Stack of ``num_layers`` Encoder blocks applied one after another.

    Args:
        num_layers: number of Encoder blocks.
        num_heads, embed_dim, ffn_dim: forwarded to every Encoder block.
        name: module name.
    """

    def __init__(self, num_layers, num_heads, embed_dim, ffn_dim, name="encoder_layer"):
        super().__init__(name=name)
        self.num_layers = num_layers
        # Every block needs its own name ("encoder0", "encoder1", ...). According
        # to the author, blocks were otherwise not recognized after loading the
        # model back with tf.saved_model.
        blocks = [
            Encoder(num_heads, embed_dim, ffn_dim, name=f"encoder{idx}")
            for idx in range(num_layers)
        ]
        self.enc_layers = blocks

    @tf.function
    @tf.Module.with_name_scope
    def __call__(self, x, mask=None):
        """Pass ``x`` through every block in order, using the same ``mask`` for each."""
        hidden = x
        for block in self.enc_layers:
            hidden = block(hidden, mask)
        return hidden

In [None]:
class Decoder_layer(tf.Module):
    """Stack of ``num_layers`` Decoder blocks applied one after another.

    Args:
        num_layers: number of Decoder blocks.
        num_heads, embed_dim, ffn_dim: forwarded to every Decoder block.
        name: module name.
    """

    def __init__(self, num_layers, num_heads, embed_dim, ffn_dim, name="decoder_layer"):
        super().__init__(name=name)
        self.num_layers = num_layers
        # Unique per-block names ("decoder0", "decoder1", ...).
        blocks = [
            Decoder(num_heads, embed_dim, ffn_dim, name=f"decoder{idx}")
            for idx in range(num_layers)
        ]
        self.dec_layers = blocks

    @tf.function
    @tf.Module.with_name_scope
    def __call__(self, x, encoder_seq, mask=None):
        """Run every block on ``x``, attending to the same ``encoder_seq``.

        ``mask`` is forwarded unchanged as each block's third argument.
        NOTE(review): no encoder padding mask is forwarded to the blocks.
        """
        hidden = x
        for block in self.dec_layers:
            hidden = block(hidden, encoder_seq, mask)
        return hidden

The hyperparameters used to build the model are defined in the configuration cell at the top of the notebook.

Build the transformer model:

In [None]:
class Transformer(tf.Module):
    """Encoder-decoder Transformer mapping (source ids, target ids) to vocabulary logits.

    Note the constructor argument order: ``ffn_dim`` comes before ``embed_dim``.
    """

    def __init__(self, num_layers, num_heads, ffn_dim, embed_dim, vocab_size, maxlen, name="transformer"):
        super().__init__(name=name)

        # Record the hyperparameters on the module.
        self.maxlen, self.vocab_size = maxlen, vocab_size
        self.num_heads, self.embed_dim = num_heads, embed_dim
        self.ffn_dim, self.num_layers = ffn_dim, num_layers

        # Final projection from model width to vocabulary logits.
        self.dense = Linear(embed_dim, vocab_size, name="model_output")

        # Separate (untied) embedding tables for the source and target sides.
        self.embedding_encoder = Embedding(vocab_size, embed_dim, maxlen)
        self.embedding_decoder = Embedding(vocab_size, embed_dim, maxlen)

        self.encoder = Encoder_layer(num_layers, num_heads, embed_dim, ffn_dim)
        self.decoder = Decoder_layer(num_layers, num_heads, embed_dim, ffn_dim)

    # NOTE(review): the signature uses the module-level `maxlen` captured when the
    # class is defined, not the constructor's `maxlen`; the two must agree.
    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None, maxlen], dtype=tf.int32),
        tf.TensorSpec(shape=[None, maxlen], dtype=tf.int32),
    ])
    @tf.Module.with_name_scope
    def __call__(self, inp, targ):
        """Return logits of shape (batch, maxlen, vocab_size) for ``targ`` given ``inp``."""
        enc_embedded, enc_pad_mask = self.embedding_encoder(inp, masking=True)
        enc_out = self.encoder(enc_embedded, mask=enc_pad_mask)

        dec_embedded, dec_pad_mask = self.embedding_decoder(targ, masking=True)
        # NOTE(review): only the target padding mask reaches the decoder, and
        # enc_pad_mask is not forwarded. Cross-attention therefore does not mask
        # encoder padding — confirm whether that is intended.
        dec_out = self.decoder(dec_embedded, enc_out, mask=dec_pad_mask)

        return self.dense(dec_out)

In [None]:
# Keyword arguments guard against the unusual ffn_dim-before-embed_dim order.
transformer = Transformer(
    num_layers=num_layers,
    num_heads=num_heads,
    ffn_dim=ffn_dim,
    embed_dim=embed_dim,
    vocab_size=vocab_size,
    maxlen=maxlen,
)

In [None]:
# Smoke-test the model once with a dummy all-ones batch of shape (1, maxlen).
# Plain int32 tensors are enough; there is no need to create tf.Variables.
# Only the output shape, (1, maxlen, vocab_size), is displayed instead of the
# whole logits tensor.
dummy_tokens = tf.ones([1, maxlen], dtype=tf.int32)
warmup_logits = transformer(dummy_tokens, dummy_tokens)
warmup_logits.shape

Open and read the dataset file. In my case it is plain text with one sample per line, and the question and the answer on each line are separated by an `[answ]` token.

In [None]:
# Read the dataset as a list of lines. The `with` block closes the file handle.
# Assumes the dataset is UTF-8 encoded — adjust `encoding` if it is not.
with open(dataset_file, "r", encoding="utf-8") as dataset_fh:
    data = dataset_fh.read().split("\n")
# Drop only the empty entry produced by a trailing newline. Unconditionally
# slicing off the last element would lose a real sample whenever the file does
# not end with "\n".
if data and data[-1] == "":
    data.pop()

Create question-answer pairs

In [None]:
# Build (question, answer) pairs. A valid line contains exactly one "[answ]"
# separator. Lines without one, or with several, are skipped and counted
# explicitly instead of being hidden by a bare `except`, which would also
# swallow unrelated errors (even KeyboardInterrupt).
text_pairs = []
skipped_lines = 0
for line in data:
    parts = line.split("[answ]")
    if len(parts) == 2:
        text_pairs.append((parts[0], parts[1]))
    else:
        skipped_lines += 1

len(text_pairs), skipped_lines

In [None]:
import random

# Shuffle with a dedicated, seeded RNG so the pair order is reproducible across
# runs and the global `random` state is left untouched.
SHUFFLE_SEED = 42
random.Random(SHUFFLE_SEED).shuffle(text_pairs)

Load the trained BPE (SentencePiece) model, i.e. the vocabulary that was built from the text dataset loaded a few cells above.

In [None]:
# Load the serialized BPE model as raw bytes for the tokenizers below.
# `bpe_model` starts as the file path from the config cell and is rebound to
# the model bytes here. The isinstance guard skips reading when it already
# holds bytes, so re-running this cell is safe. The `with` block closes the file.
if not isinstance(bpe_model, bytes):
    with open(bpe_model, "rb") as bpe_model_file:
        bpe_model = bpe_model_file.read()

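Create TensorFlow Text tokenizers from the BPE model: one that adds no BOS/EOS tokens (used for the questions) and one that adds both (used for the answers).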

In [None]:
# Same SentencePiece model; the first tokenizer adds no BOS/EOS, the second adds both.
tokenizer_nonpack, tokenizer_pack = (
    text.SentencepieceTokenizer(bpe_model, out_type="int32", add_bos=with_markers, add_eos=with_markers)
    for with_markers in (False, True)
)

In [None]:
BATCH_SIZE = 25  # question/answer pairs per training batch

Split text pairs into two lists

In [None]:
# Parallel lists: q_texts[i] is the question whose answer is ans_texts[i].
q_texts = [question for question, _ in text_pairs]
ans_texts = [answer for _, answer in text_pairs]

Dataset tokenization function:

In [None]:
def process(x, pack=False):
    """Tokenize strings and pad/truncate them to a fixed length.

    Args:
        x: batch of strings.
        pack: if true, use the tokenizer that adds BOS/EOS and pad to
            ``maxlen + 1``; otherwise use the plain tokenizer and pad to ``maxlen``.

    Returns:
        The ``(ids, mask)`` pair from ``text.pad_model_inputs`` (pad id 0).
    """
    tokenizer, length = (tokenizer_pack, maxlen + 1) if pack else (tokenizer_nonpack, maxlen)
    token_ids = tokenizer.tokenize(x)
    return text.pad_model_inputs(token_ids, length, pad_value=0)

In [None]:
def format_ds(x, y):
    """Turn a batch of (question, answer) strings into model inputs and targets.

    The answer is tokenized with BOS/EOS to ``maxlen + 1`` ids and then shifted
    for teacher forcing: the decoder input drops the last id and the target
    drops the first.
    """
    encoder_ids, _ = process(x)
    answer_ids, _ = process(y, pack=True)
    inputs = {"encoder_inputs": encoder_ids, "decoder_inputs": answer_ids[:, :-1]}
    return inputs, answer_ids[:, 1:]

Create dataset:

In [None]:
# Pipeline: shuffle raw string pairs, batch them, tokenize each batch in
# parallel, and prefetch.
train_ds = (
    tf.data.Dataset.from_tensor_slices((list(q_texts), list(ans_texts)))
    .shuffle(2048)
    .batch(BATCH_SIZE)
    .map(format_ds, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
epochs = 10  # full passes over train_ds

Learning rate decay function:

In [None]:
class lr_decay(tf.optimizers.schedules.LearningRateSchedule):
    """Warmup + inverse-square-root learning-rate schedule.

    lr(step) = embed_dim**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)

    The rate grows linearly until ``step == warmup_steps`` (where both terms
    are equal) and then decays as 1/sqrt(step). At step 0 the result is 0,
    because rsqrt(0) is inf and min(inf, 0) == 0.

    Args:
        embed_dim: model width used for the global scale.
        warmup_steps: number of linear-warmup steps.
    """

    def __init__(self, embed_dim, warmup_steps):
        super().__init__()
        self._embed_dim_arg = embed_dim  # raw constructor value, kept for get_config
        self.warmup_steps = warmup_steps
        self.embed_dim = tf.cast(embed_dim, tf.float32)

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        decay = tf.math.rsqrt(step)                   # inverse-sqrt decay branch
        warmup = step * (self.warmup_steps ** -1.5)   # linear warmup branch

        return tf.math.rsqrt(self.embed_dim) * tf.math.minimum(decay, warmup)

    def get_config(self):
        """Return the constructor arguments so the schedule (and an optimizer using it) can be serialized."""
        return {"embed_dim": self._embed_dim_arg, "warmup_steps": self.warmup_steps}

In [None]:
# Learning-rate schedule with 4000 warmup steps.
lr_schedule = lr_decay(embed_dim, warmup_steps=4000)

Perplexity metric class:

In [None]:
class Perplexity(tf.metrics.Metric):
    """Perplexity metric that ignores padding targets.

    Each ``update_state`` call computes the mean cross-entropy over the
    non-padding tokens of the batch and adds ``batch_size * mean`` to a running
    sum. ``result`` returns ``exp(running_sum / total_samples)``, or 0.0 before
    any update.

    Args:
        mask_token_id: target id excluded from the average (padding). Defaults to 0.
    """

    def __init__(self, mask_token_id=0):
        super().__init__(dtype=tf.float32)

        self.mask_token_id = mask_token_id
        # 'sum' reduction: update_state divides by the number of unmasked tokens itself.
        self.crossentropy = tf.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='sum')
        self.aggregate_crossentropy = tf.Variable(tf.zeros(shape=()))  # sum of batch_size * mean CE
        self.samples_number = tf.Variable(tf.zeros(shape=()))          # total samples seen

    def result(self):
        """Return exp(mean cross-entropy), or 0.0 if no samples have been seen.

        Uses tensor ops instead of a Python ``if`` on a variable, so it stays
        valid when traced inside ``tf.function``.
        """
        mean_crossentropy = tf.math.divide_no_nan(self.aggregate_crossentropy, self.samples_number)
        return tf.where(self.samples_number > 0,
                        tf.exp(mean_crossentropy),
                        tf.zeros_like(mean_crossentropy))

    def update_state(self, true, pred):
        """Accumulate one batch. ``true`` holds target ids; ``pred`` holds logits."""
        batch_size = tf.cast(tf.shape(true)[0], tf.float32)

        mask = tf.cast(tf.not_equal(true, self.mask_token_id), tf.float32)

        summed_crossentropy = tf.cast(self.crossentropy(true, pred, sample_weight=mask), tf.float32)
        # divide_no_nan: a batch made only of padding contributes 0 instead of
        # NaN, which would otherwise poison the running sum permanently.
        crossentropy_value = tf.math.divide_no_nan(summed_crossentropy, tf.reduce_sum(mask))
        self.aggregate_crossentropy.assign_add(batch_size * crossentropy_value)
        self.samples_number.assign_add(batch_size)

    def reset_state(self):
        """Clear the running sums."""
        self.aggregate_crossentropy.assign(0.0)
        self.samples_number.assign(0.0)

Setup losses, metrics, and optimizer

In [None]:
# Per-token cross-entropy (no reduction) so padding can be masked afterwards.
loss_obj = tf.losses.SparseCategoricalCrossentropy(reduction='none', from_logits=True)
# Running means of the batch loss and batch perplexity.
train_loss, train_ppl = tf.metrics.Mean(), tf.metrics.Mean()
ppl_obj = Perplexity()
optimizer = tf.optimizers.Adam(
    learning_rate=lr_schedule,
    beta_1=0.9,
    beta_2=0.98,
    epsilon=1e-9,
)

In [None]:
def loss_function(real, pred):
    """Mean per-token cross-entropy over positions whose target id is not 0 (padding)."""
    per_token = loss_obj(real, pred)
    weights = tf.cast(tf.not_equal(real, 0), dtype=per_token.dtype)
    return tf.reduce_sum(per_token * weights) / tf.reduce_sum(weights)


def ppl_function(real, pred):
    """Return the perplexity of a single batch.

    ``ppl_obj`` keeps accumulating over every call until it is reset. It is
    therefore reset here first, so the returned value is ``exp(mean
    cross-entropy)`` over the non-padding targets of this batch only. A caller
    such as a ``tf.metrics.Mean`` can then average these per-batch values.

    The result is already a scalar, so no further masking is applied.

    Args:
        real: integer target ids. Id 0 is padding under ``ppl_obj``'s default
            ``mask_token_id``.
        pred: logits over the vocabulary.
    """
    ppl_obj.reset_state()
    return ppl_obj(real, pred)


In [None]:
@tf.function(input_signature=[
    tf.TensorSpec(shape=(None, maxlen), dtype=tf.int32),  # encoder input ids
    tf.TensorSpec(shape=(None, maxlen), dtype=tf.int32),  # decoder input ids
    tf.TensorSpec(shape=(None, maxlen), dtype=tf.int32),  # target ids
])
def train_step(inp, targ, y):
    """Run one optimization step and update ``train_loss`` / ``train_ppl``."""
    with tf.GradientTape() as tape:
        logits = transformer(inp, targ)
        loss_val = loss_function(y, logits)

    variables = transformer.trainable_variables
    grads = tape.gradient(loss_val, variables)
    optimizer.apply_gradients(zip(grads, variables))

    train_loss(loss_val)
    train_ppl(ppl_function(y, logits))

In [None]:
# Train for `epochs` epochs, printing the running loss/perplexity on one line.
summary = len(train_ds)  # batches per epoch
for e in range(epochs):
    train_loss.reset_state()
    train_ppl.reset_state()
    # Plain iteration replaces the manual iterator + get_next() loop.
    for step, (x, y) in enumerate(train_ds, start=1):
        y = tf.cast(y, tf.int32)
        train_step(x['encoder_inputs'], x['decoder_inputs'], y)

        # Epochs are shown 1-based ("Epoch 1/10" ... "Epoch 10/10").
        print(
            f"Epoch {e + 1}/{epochs} Batch: {step}/{summary}"
            f" loss: {train_loss.result().numpy()!s}"
            f" perplexity: {train_ppl.result().numpy()!s}",
            end="\r",
        )
    # Finish the line so this epoch's final stats are not overwritten by the next epoch.
    print()

In [None]:
# Export the module with its traced __call__ as the "call" signature.
# NOTE(review): the encoder/decoder blocks call tf.nn.dropout on every call,
# so outputs from the exported model may be stochastic at inference time.
export_signatures = {'call': transformer.__call__}
tf.saved_model.save(transformer, "transformer", signatures=export_signatures)