In [None]:
import tensorflow as tf
import sentencepiece
import numpy as np

In [None]:
# Load the exported SavedModel from the local "transformer" directory.
model = tf.saved_model.load("transformer")

Maximum sequence length (in tokens) hyperparameter of the pretrained model.

In [None]:
# Fixed sequence length: pad() pads or truncates token lists to this many
# entries, and the generation loop samples at most this many tokens.
max_len = 200

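Perform top-k sampling on the logits returned by the model: keep the k highest-scoring tokens, turn their logits into probabilities with a softmax, and sample one token from that distribution.
The `k` argument of `tf.math.top_k` sets how many of the largest values are returned.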

In [None]:
def sample_from_logits(logits, k=10):
    """Sample one token id from the k highest-scoring entries of ``logits``.

    Args:
        logits: 1-D tensor (or array-like) of unnormalized scores, indexed by
            token id.
        k: Number of top candidates to sample from. It is clamped to the
            number of available logits, because ``tf.math.top_k`` raises when
            ``k`` exceeds the size of the last dimension.

    Returns:
        The sampled token id as a NumPy integer scalar.

    Raises:
        ValueError: If there is no candidate to sample from (``k < 1`` or
            ``logits`` is empty).
    """
    logits = tf.convert_to_tensor(logits)
    k = min(k, int(logits.shape[-1]))
    if k < 1:
        raise ValueError("top-k sampling needs at least one candidate logit")

    top_logits, top_indices = tf.math.top_k(logits, k=k, sorted=True)
    # Softmax over the surviving logits only, so the k probabilities sum to 1.
    probabilities = tf.nn.softmax(top_logits)
    return np.random.choice(top_indices.numpy(), p=probabilities.numpy())

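Pad tokenized sentences with zeros to the model input size, i.e. the maximum length; sentences longer than that are truncated. The index of the last real (non-padding) token is returned as well.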

In [None]:
def pad(sentence, length=None, pad_id=0):
    """Pad or truncate a token list to a fixed length.

    Args:
        sentence: List of token ids.
        length: Target length. Defaults to the notebook-level ``max_len``.
        pad_id: Token id used to fill the tail of short sentences.

    Returns:
        A ``(tokens, sample_index)`` tuple. ``tokens`` is always a new list of
        exactly ``length`` ids, so the caller's list is never aliased.
        ``sample_index`` is the position of the last non-padding token
        (``-1`` for an empty sentence).
    """
    if length is None:
        length = max_len

    # Slicing both truncates over-long input and copies the caller's list.
    tokens = list(sentence[:length])
    sample_index = len(tokens) - 1
    tokens.extend([pad_id] * (length - len(tokens)))
    return tokens, sample_index

Load the pretrained BPE tokenizer model.

In [None]:
# Create an empty SentencePiece processor, then load the trained model file into it.
tokenizer = sentencepiece.SentencePieceProcessor()
tokenizer.load("bpe_model.model")

The model predicts tokens one at a time until the maximum length is reached or the predicted token is 3 (the `</s>` end-of-sentence token). The collected tokens are then detokenized.

In [None]:
# Special token ids used by the decoding loop.
START_TOKEN_ID = 2  # <s> start-of-sentence token
END_TOKEN_ID = 3    # </s> end-of-sentence token

sentence = "hello"

# Tokenize the prompt and pad it to the fixed model input length.
tokens_input = tokenizer.tokenize(sentence)
tokens_input_pad, _ = pad(tokens_input)
# The source input never changes during decoding, so build its tensor once.
# Inputs are plain data, not trainable state, hence tf.constant.
source_tensor = tf.constant([tokens_input_pad])

# The target sequence must begin with the <s> token.
tokens_target = [START_TOKEN_ID]
current_token = tokens_target[0]
number_of_tokens = 0

# Sample one token per step until </s> is produced or max_len tokens were sampled.
while current_token != END_TOKEN_ID and number_of_tokens < max_len:
    tokens_target_pad, index = pad(tokens_target)
    y, _ = model.__call__(source_tensor, tf.constant([tokens_target_pad]))
    # Sample from the model output at the position of the last non-padding
    # target token; store it as a plain int so the list stays homogeneous.
    sample_token = int(sample_from_logits(y[0][index]))
    tokens_target.append(sample_token)
    current_token = sample_token
    number_of_tokens += 1

# Drop the leading <s> token before turning the ids back into text.
tokens_target = tokens_target[1:]
tokenizer.detokenize(tokens_target)