# Text generation with a miniature GPT

**Author:** [Apoorv Nandan](https://twitter.com/NandanApoorv)<br>
**Date created:** 2020/05/29<br>
**Last modified:** 2020/05/29<br>
**Description:** Implement a miniature version of GPT and train it to generate text.

## Introduction

This example demonstrates how to implement an autoregressive language model
using a miniature version of the GPT model.
The model consists of a single Transformer block with causal masking
in its attention layer.
We use the text from the IMDB sentiment classification dataset for training
and generate new movie reviews for a given prompt.
When using this script with your own dataset, make sure it has at least
1 million words.

This example should be run with `tf-nightly>=2.3.0-dev20200531` or
with TensorFlow 2.3 or higher.

**References:**

- [GPT](https://www.semanticscholar.org/paper/Improving-Language-Understanding-by-Generative-Radford/cd18800a0fe0b668a1cc19f2ec95b5003d0a5035)
- [GPT-2](https://www.semanticscholar.org/paper/Language-Models-are-Unsupervised-Multitask-Learners-Radford-Wu/9405cc0d6169988371b2755e573cc28650d14dfe)
- [GPT-3](https://arxiv.org/abs/2005.14165)

## Setup

In [None]:
# We set the backend to TensorFlow. The code works with
# both `tensorflow` and `torch`. It does not work with JAX
# due to the behavior of `jax.numpy.tile` in a jit scope
# (used in `causal_attention_mask()`): `tile` in JAX does
# not support a dynamic `reps` argument.
# You can make the code work in JAX by wrapping the
# inside of the `causal_attention_mask` function in
# a context manager to prevent jit compilation:
# `with jax.ensure_compile_time_eval():`.
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
from keras import layers
from keras import ops
from keras.layers import TextVectorization
import numpy as np
import os
import string
import random
import tensorflow
import tensorflow.data as tf_data
import tensorflow.strings as tf_strings
import tensorflow_text as tf_text
import sentencepiece as spm
import tensorflow as tf


## Implement a Transformer block as a layer

In [None]:

def causal_attention_mask(batch_size, n_dest, n_src, dtype):
    """
    Build a batched causal mask for self attention.

    Entry (i, j) is set when `i >= j - n_src + n_dest`, which gives 1's in
    the lower triangle, counting from the lower right corner. This keeps a
    token from receiving information from the tokens that follow it.
    The (n_dest, n_src) mask is cast to `dtype` and tiled `batch_size` times
    along a new leading axis.
    """
    dest_idx = ops.arange(n_dest)[:, None]
    src_idx = ops.arange(n_src)
    allowed = dest_idx >= src_idx - n_src + n_dest
    single_mask = ops.reshape(ops.cast(allowed, dtype), [1, n_dest, n_src])
    # Repeat counts per axis: (batch_size, 1, 1)
    batch_dim = ops.expand_dims(batch_size, -1)
    tail_dims = ops.convert_to_tensor([1, 1])
    repeats = ops.concatenate([batch_dim, tail_dims], 0)
    return ops.tile(single_mask, repeats)


class TransformerBlock(layers.Layer):
    """Transformer block: causal self attention followed by a feed-forward net.

    Each of the two sub-layers is followed by dropout, a residual connection
    and layer normalization.

    Arguments:
        embed_dim: Integer, passed to the attention layer as its second
            positional argument and used as the feed-forward output width.
        num_heads: Integer, number of attention heads.
        ff_dim: Integer, hidden width of the feed-forward network.
        rate: Float, dropout rate used after both sub-layers.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
        super().__init__()
        self.att = layers.MultiHeadAttention(num_heads, embed_dim)
        hidden_layer = layers.Dense(ff_dim, activation="relu")
        output_layer = layers.Dense(embed_dim)
        self.ffn = keras.Sequential([hidden_layer, output_layer])
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(rate)
        self.dropout2 = layers.Dropout(rate)

    def call(self, inputs):
        shape = ops.shape(inputs)
        n_batch, n_steps = shape[0], shape[1]
        # Boolean mask so that every position only attends to itself and
        # to earlier positions.
        mask = causal_attention_mask(n_batch, n_steps, n_steps, "bool")
        attended = self.dropout1(self.att(inputs, inputs, attention_mask=mask))
        residual = self.layernorm1(inputs + attended)
        projected = self.dropout2(self.ffn(residual))
        return self.layernorm2(residual + projected)


## Implement an embedding layer

Create two separate embedding layers: one for tokens and one for token index
(positions).

In [None]:

class TokenAndPositionEmbedding(layers.Layer):
    """Add a learned position embedding to a learned token embedding.

    Arguments:
        maxlen: Integer, number of rows of the position embedding table.
        vocab_size: Integer, number of rows of the token embedding table.
        embed_dim: Integer, width of both embeddings.
    """

    def __init__(self, maxlen, vocab_size, embed_dim):
        super().__init__()
        self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)

    def call(self, x):
        # Positions 0 .. seq_len - 1, where seq_len is the last axis of `x`
        seq_len = ops.shape(x)[-1]
        position_vectors = self.pos_emb(ops.arange(0, seq_len, 1))
        token_vectors = self.token_emb(x)
        return token_vectors + position_vectors


## Implement the miniature GPT model

In [None]:
maxlen = 80  # Max sequence size
embed_dim = 256  # Embedding size for each token
num_heads = 2  # Number of attention heads
feed_forward_dim = 256  # Hidden layer size in feed forward network inside transformer
vocab_size = 5000  # Vocabulary size requested when training the SentencePiece tokenizer

def create_model():
    """Build and compile the miniature GPT model.

    The model maps a batch of `maxlen` int32 token ids to two outputs: the
    next-token logits over `vocab_size` entries and the output of the
    transformer block. Only the logits are trained on.
    """
    token_ids = layers.Input(shape=(maxlen,), dtype="int32")
    embedded = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)(token_ids)
    hidden = TransformerBlock(embed_dim, num_heads, feed_forward_dim)(embedded)
    logits = layers.Dense(vocab_size)(hidden)
    model = keras.Model(inputs=token_ids, outputs=[logits, hidden])

    crossentropy = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    adam = keras.optimizers.Adam(
        epsilon=1e-9,
        beta_1=0.9,
        beta_2=0.95,
        clipvalue=1,
    )
    # The second output (transformer block features) gets no loss, so it
    # takes no part in the optimization.
    model.compile(optimizer=adam, loss=[crossentropy, None])
    return model


## Prepare the data for subword-level language modelling

Download the IMDB dataset and combine training and validation sets for a text
generation task.

In [None]:
# Download the IMDB review archive and unpack it in the working directory
!curl -O https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
!tar -xf aclImdb_v1.tar.gz

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 80.2M  100 80.2M    0     0  19.6M      0  0:00:04  0:00:04 --:--:-- 19.6M


In [None]:
batch_size = 128

# The dataset contains each review in a separate text file.
# The text files are present in four different folders.
# Create a list of all files.
filenames = []
directories = [
    "aclImdb/train/pos",
    "aclImdb/train/neg",
    "aclImdb/test/pos",
    "aclImdb/test/neg",
]
# `directory` (not `dir`) so the builtin `dir()` is not shadowed for the
# rest of the session.
for directory in directories:
    filenames.extend(
        os.path.join(directory, name) for name in os.listdir(directory)
    )

print(f"{len(filenames)} files")

def train_sentencepiece(filenames, vocab_size=5000, model_prefix='sp_model'):
    """Train a SentencePiece BPE model on a collection of text files.

    Every file is read in full, lowercased, has its "<br />" tags replaced
    by a space and is appended (followed by a newline) to a temporary corpus
    file, `temp_data.txt`, in the working directory. The trainer is then run
    on that corpus with `model_prefix` as the prefix of the files it writes.

    Arguments:
        filenames: Iterable of paths to UTF-8 text files.
        vocab_size: Integer, vocabulary size requested from the trainer.
        model_prefix: String, prefix handed to the trainer for its output.

    Raises:
        Whatever reading an input file or the trainer raises. The temporary
        corpus file is removed in that case as well.
    """
    corpus_path = 'temp_data.txt'
    try:
        with open(corpus_path, 'w', encoding='utf-8') as corpus:
            for fname in filenames:
                with open(fname, 'r', encoding='utf-8') as infile:
                    text = infile.read().lower().replace('<br />', ' ')
                corpus.write(text + '\n')

        spm.SentencePieceTrainer.train(
            input=corpus_path,
            model_prefix=model_prefix,
            vocab_size=vocab_size,
            model_type='bpe',
            character_coverage=0.995,
            unk_surface=' <unk> ',
            bos_id=1,
            eos_id=2,
            unk_id=0,
            pad_id=3,
        )
    finally:
        # Clean up even when reading or training fails, so a failed run does
        # not leave the corpus copy behind.
        if os.path.exists(corpus_path):
            os.remove(corpus_path)

# Train the SentencePiece model on all review files
train_sentencepiece(filenames, vocab_size)

# Create TensorFlow tokenizer from the serialized model file
with open('sp_model.model', 'rb') as f:
    sp_model = f.read()

# The tokenizer is configured to emit int32 token ids
tokenizer = tf_text.SentencepieceTokenizer(model=sp_model, out_type=tf.int32)

def custom_standardization(input_string):
    """Lowercase the text and replace html line-break tags with a space."""
    return tf.strings.regex_replace(
        tf.strings.lower(input_string),
        "<br />",
        " ",
    )

def tokenize_and_pad(text):
    """Standardize, tokenize and pad/truncate one string to `maxlen + 1` ids.

    One extra token is kept on top of `maxlen` because the sequence is later
    split into inputs and targets that are shifted by one position; both
    halves then hold exactly `maxlen` tokens, which is the length the model
    input is declared with.

    Sequences longer than `maxlen + 1` keep their last `maxlen + 1` tokens;
    shorter ones are right-padded with the pad id (3).
    """
    seq_len = maxlen + 1
    text = custom_standardization(text)
    tokens = tokenizer.tokenize(text)
    # Truncate from the left, i.e. keep the tail of long reviews
    tokens = tokens[-seq_len:]
    pad_amount = seq_len - tf.shape(tokens)[0]
    return tf.pad(tokens, [[0, pad_amount]], constant_values=3)  # pad_id=3

# Shuffle the file order, then read the files line by line
random.shuffle(filenames)
text_ds = tf.data.TextLineDataset(filenames)
text_ds = text_ds.shuffle(buffer_size=256)
text_ds = text_ds.repeat(2)  # go over the data twice
text_ds = text_ds.batch(batch_size)

# Tokenize and pad each string of a batch; `tf.map_fn` applies
# `tokenize_and_pad` to the elements of the batched string tensor
text_ds = text_ds.map(
    lambda x: tf.map_fn(tokenize_and_pad, x, fn_output_signature=tf.int32),
    num_parallel_calls=tf.data.AUTOTUNE
)

def prepare_lm_inputs_labels(text):
    """
    Split a batch of token sequences into model inputs and targets.

    The inputs drop the last token of every row and the targets drop the
    first one, so the target at position (i) is the token at position (i+1)
    of the original sequence. The model therefore predicts each next token
    from all tokens up to position (i).
    """
    inputs, targets = text[:, :-1], text[:, 1:]
    return inputs, targets

# Split every batch into (inputs, targets) and prefetch batches
text_ds = text_ds.map(prepare_lm_inputs_labels, num_parallel_calls=tf.data.AUTOTUNE)
text_ds = text_ds.prefetch(tf.data.AUTOTUNE)

# Load the trained model with the SentencePiece Python API as well
# (used by `decode_tokens` to turn ids back into text)
sp_processor = spm.SentencePieceProcessor()
sp_processor.load('sp_model.model')
# Overwrite the requested vocabulary size with the size the model reports
vocab_size = sp_processor.get_piece_size()
print(f"Final vocabulary size: {vocab_size}")

def decode_tokens(tokens):
    """Turn token ids back into text with the module-level `sp_processor`."""
    decoded_text = sp_processor.decode(tokens)
    return decoded_text

def get_vocab(tokenizer="sp_model.model"):
    """Return the pieces of a SentencePiece model as a list ordered by id.

    Arguments:
        tokenizer: Path of the SentencePiece model file to load.
    """
    processor = spm.SentencePieceProcessor()
    processor.load(tokenizer)
    return [
        processor.id_to_piece(piece_id)
        for piece_id in range(processor.get_piece_size())
    ]

50000 files
Final vocabulary size: 5000


## Implement a Keras callback for generating text

In [None]:

class TextGenerator(keras.callbacks.Callback):
    """A callback to generate text from a trained model.
    1. Feed some starting prompt to the model
    2. Predict probabilities for the next token
    3. Sample the next token and add it to the next input

    Arguments:
        max_tokens: Integer, the number of tokens to be generated after prompt.
        start_tokens: List of integers, the token indices for the starting prompt.
        index_to_word: List of strings indexed by token id (the tokenizer
            vocabulary); used by `detokenize`.
        top_k: Integer, sample from the `top_k` token predictions.
        print_every: Integer, print after this many epochs.
    """

    def __init__(
        self, max_tokens, start_tokens, index_to_word, top_k=10, print_every=1
    ):
        # Let the base callback set up its own state first.
        super().__init__()
        self.max_tokens = max_tokens
        self.start_tokens = start_tokens
        self.index_to_word = index_to_word
        self.print_every = print_every
        self.k = top_k

    def sample_from(self, logits):
        """Sample one token id among the `top_k` largest logits.

        The kept logits are turned into probabilities with a softmax and a
        single id is drawn according to those probabilities.
        """
        logits, indices = ops.top_k(logits, k=self.k, sorted=True)
        indices = np.asarray(indices).astype("int32")
        preds = keras.activations.softmax(ops.expand_dims(logits, 0))[0]
        preds = np.asarray(preds).astype("float32")
        return np.random.choice(indices, p=preds)

    def detokenize(self, number):
        """Return the vocabulary entry stored for token id `number`."""
        return self.index_to_word[number]

    def on_epoch_end(self, epoch, logs=None):
        """Generate and print `max_tokens` tokens every `print_every` epochs."""
        if (epoch + 1) % self.print_every != 0:
            return

        # Work on a copy so the stored prompt is not modified between epochs.
        context = list(self.start_tokens)
        tokens_generated = []
        # Strict `<` so that exactly `max_tokens` tokens are produced.
        while len(tokens_generated) < self.max_tokens:
            pad_len = maxlen - len(context)
            if pad_len < 0:
                # Context is longer than the model input: keep the most
                # recent `maxlen` tokens and read the last position.
                x = context[-maxlen:]
                sample_index = maxlen - 1
            else:
                # Right-pad with the pad id (3) and read the prediction at
                # the last real token.
                x = context + [3] * pad_len
                sample_index = len(context) - 1
            y, _ = self.model.predict(np.array([x]), verbose=0)
            sample_token = self.sample_from(y[0][sample_index])
            tokens_generated.append(sample_token)
            context.append(sample_token)

        # `context` now holds the prompt followed by the generated tokens.
        txt = decode_tokens([int(t) for t in context])
        print(f"generated text:\n{txt}\n")


# Build a lookup from vocabulary entry to token id
word_to_index = {}
vocab = get_vocab()
for index, word in enumerate(vocab):
    word_to_index[word] = index


# Tokenize the starting prompt and set up the generation callback
start_prompt = "this movie is"
start_tokens = tokenizer.tokenize(start_prompt).numpy().tolist()
num_tokens_generated = 100  # passed to the callback as `max_tokens`
text_gen_callback = TextGenerator(num_tokens_generated, start_tokens, vocab)


## Train the model

Note: This code should preferably be run on GPU.

In [None]:
# Build and compile the model
model = create_model()

# Train for 25 epochs; the callback prints a generated sample at the end
# of each epoch (its `print_every` is left at the default of 1)
model.fit(text_ds, verbose=2, epochs=25, callbacks=[text_gen_callback])

Epoch 1/25




generated text:
this movie is the only thing to get a bit of the movie, it's just a great story to me. the ending, this is a good for those movies you can't be a great movie. i recommend this movie, it's a very well worth watching if you want to watch it. it would be better off. the ending is just a bit of fun, bad...and just too hard just a little too hard to be to get. the same, i would be the only be found

782/782 - 83s - 107ms/step - loss: 4.8343
Epoch 2/25


[62, 98, 43]