# Example: GPT2 Text Generation with KerasNLP

**Based on an example by:** Chen Qian<br>


In this tutorial, you will learn to use [KerasNLP](https://keras.io/keras_nlp/) to load a
pre-trained Large Language Model (LLM), the [GPT-2 model](https://openai.com/research/better-language-models)
(originally invented by OpenAI), finetune it to a specific text style, and
generate text based on a user's input (also known as a prompt). You will also learn
how GPT2 adapts quickly to non-English languages, such as Chinese.

## Before we begin

Colab offers different kinds of runtimes. Make sure to go to **Runtime ->
Change runtime type** and choose the GPU Hardware Accelerator runtime
(which should have >12G host RAM and ~15G GPU RAM) since you will finetune the
GPT-2 model. Running this tutorial on a CPU runtime will take hours.

## Install KerasNLP, Choose Backend and Import Dependencies

This example uses [Keras Core](https://keras.io/keras_core/) to work in any of
`"tensorflow"`, `"jax"` or `"torch"`. Support for Keras Core is baked into
KerasNLP; simply change the `"KERAS_BACKEND"` environment variable to select
the backend of your choice. We select the TensorFlow backend below.

In [4]:
!pip install git+https://github.com/keras-team/keras-nlp.git -q

  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done


In [5]:
import os

os.environ["KERAS_BACKEND"] = "tensorflow"  # or "jax" or "torch"

import keras_nlp
import keras
import tensorflow as tf
import time

keras.mixed_precision.set_global_policy("mixed_float16")

## Introduction to Generative Large Language Models (LLMs)

Large language models (LLMs) are a type of machine learning model
trained on a large corpus of text data to generate outputs for various natural
language processing (NLP) tasks, such as text generation, question answering,
and machine translation.

Generative LLMs are typically based on deep learning neural networks, such as
the [Transformer architecture](https://arxiv.org/abs/1706.03762) invented by
Google researchers in 2017, and are trained on massive amounts of text data,
often involving billions of words. These models, such as Google [LaMDA](https://blog.google/technology/ai/lamda/)
and [PaLM](https://ai.googleblog.com/2022/04/pathways-language-model-palm-scaling-to.html),
are trained with a large dataset from various data sources which allows them to
generate output for many tasks. The core of Generative LLMs is predicting the
next word in a sentence, often referred to as **Causal LM Pretraining**. In this
way LLMs can generate coherent text based on user prompts.
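
To make "predicting the next word" concrete, here is a toy sketch that uses simple word-pair counts instead of a neural network. It is only an illustration of the generation loop, not how GPT-2 works internally; the corpus and the function names `train_bigram` and `generate_greedy` are made up for this example.

```python
from collections import Counter, defaultdict


def train_bigram(text):
    """Count, for every word, which words follow it in the training text."""
    words = text.split()
    follow = defaultdict(Counter)
    for current, nxt in zip(words, words[1:]):
        follow[current][nxt] += 1
    return follow


def generate_greedy(follow, prompt, max_words=8):
    """Repeatedly append the most frequent next word until max_words is reached."""
    words = prompt.split()
    while len(words) < max_words and follow[words[-1]]:
        next_word = follow[words[-1]].most_common(1)[0][0]
        words.append(next_word)
    return " ".join(words)


# Hypothetical toy corpus; a real LLM is trained on billions of words.
toy_corpus = "the cat sat on the mat and the cat sat on the sofa"
toy_model = train_bigram(toy_corpus)
print(generate_greedy(toy_model, "the", max_words=5))
```

A real LLM replaces the count table with a Transformer that scores every token in its vocabulary, but the outer loop (score the candidates, pick one, append, repeat) has the same shape.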

## Introduction to KerasNLP

Large Language Models are complex to build and expensive to train from scratch.
Luckily there are pretrained LLMs available for use right away. [KerasNLP](https://keras.io/keras_nlp/)
provides a large number of pre-trained checkpoints that allow you to experiment
with SOTA models without needing to train them yourself.

KerasNLP is a natural language processing library that supports users through
their entire development cycle. KerasNLP offers both pretrained models and
modularized building blocks, so developers can easily reuse pretrained models
or stack their own LLM.

In a nutshell, for generative LLMs, KerasNLP offers:

- Pretrained models with `generate()` method, e.g.,
    `keras_nlp.models.GPT2CausalLM` and `keras_nlp.models.OPTCausalLM`.
- `Sampler` classes that implement generation algorithms such as Top-K, beam and
    contrastive search. These samplers can be used to generate text with
    custom models.

## Load a pre-trained GPT-2 model and generate some text

KerasNLP provides a number of pre-trained models, such as [Google
Bert](https://ai.googleblog.com/2018/11/open-sourcing-bert-state-of-art-pre.html)
and [GPT-2](https://openai.com/research/better-language-models). You can see
the list of models available in the [KerasNLP repository](https://github.com/keras-team/keras-nlp/tree/master/keras_nlp/models).

It's very easy to load the GPT-2 model as you can see below:

In [6]:
# To speed up training and generation, we use a preprocessor with sequence
# length 128 instead of the full length 1024.
preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
    "gpt2_base_en",
    sequence_length=128,
)
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset(
    "gpt2_base_en", preprocessor=preprocessor
)

Downloading from https://www.kaggle.com/api/v1/models/keras/gpt2/keras/gpt2_base_en/2/download/tokenizer.json...
100%|██████████| 448/448 [00:00<00:00, 1.16MB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/gpt2/keras/gpt2_base_en/2/download/assets/tokenizer/merges.txt...
100%|██████████| 446k/446k [00:00<00:00, 527kB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/gpt2/keras/gpt2_base_en/2/download/assets/tokenizer/vocabulary.json...
100%|██████████| 0.99M/0.99M [00:01<00:00, 1.04MB/s]
Downloading from https://www.kaggle.com

Once the model is loaded, you can use it to generate some text right away. Run
the cells below to give it a try. It's as simple as calling a single function
*generate()*:

In [7]:
start = time.time()

output = gpt2_lm.generate("My trip to Yosemite was", max_length=200)
print("\nGPT-2 output:")
print(output)

end = time.time()
print(f"TOTAL TIME ELAPSED: {end - start:.2f}s")


GPT-2 output:
My trip to Yosemite was a great success. It was a wonderful day for me. The sun was up, and I saw the beautiful view from my home in Yosemite. The view of the Pacific was stunning. I had a nice time.

I also had some great food. It wasn't the most popular place I've been in Yosemite, but it was delicious! I got two slices of salmon, two slices of steak, and one of my own. I also had some chicken, a few slices of cheese and two slices of bacon. I was also able to eat a few more sandwiches. There was a lot of food and great atmosphere. I was also able to get some of my food from the park. I was able to pick up some of my own. The food was good, and the food was good! I had the salmon and the steak, and the cheese. I got some of the cheese and the bacon. The food was great. The food was good and
TOTAL TIME ELAPSED: 29.00s


Try another one:

In [8]:
start = time.time()

output = gpt2_lm.generate("That Italian restaurant is", max_length=200)
print("\nGPT-2 output:")
print(output)

end = time.time()
print(f"TOTAL TIME ELAPSED: {end - start:.2f}s")


GPT-2 output:
That Italian restaurant is the latest to join the ranks of the most popular places to eat in the United States, and it's one of the few restaurants that's been able to make a name for itself in the country.

The Italian restaurant's menu features a variety of dishes, such as the classic spaghetti, a classic Italian salad, a classic pasta dish, and a special pasta dish that's served with a variety of vegetables. It has a wide variety of meats and fish, but the main menu item is the classic Italian dish, which comes with a side of fresh basil pesto and an assortment of vegetables, such as a sweet potato salad, a sweet potato salad, and a tomato salad.

But wait, there's more.

Italian restaurant owners and chefs from across the country are making a comeback with a new restaurant, called the Bistro della Vida del Vida in Santa Ana, Calif.

"I've been here for over 10 years
TOTAL TIME ELAPSED: 1.82s


Notice how much faster the second call is. This is because the computational
graph is [XLA compiled](https://www.tensorflow.org/xla) in the first run and
reused in the second behind the scenes.
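
The "pay once, reuse afterwards" pattern can be sketched with a toy stand-in. This is only an analogy: it is not XLA and not how Keras caches compiled graphs, and `slow_compile`, `run_cached` and the cache dictionary are hypothetical names invented for this example.

```python
import time

compiled_cache = {}  # hypothetical cache: function name -> "compiled" callable
compile_count = 0


def slow_compile(fn):
    """Stand-in for an expensive one-time compilation step."""
    global compile_count
    compile_count += 1
    time.sleep(0.05)  # simulate the compilation cost
    return fn


def run_cached(fn, *args):
    """Compile on first use, then reuse the cached version."""
    if fn.__name__ not in compiled_cache:
        compiled_cache[fn.__name__] = slow_compile(fn)
    return compiled_cache[fn.__name__](*args)


def timed_call(fn, *args):
    """Return the result and the wall-clock time of one cached call."""
    t0 = time.perf_counter()
    result = run_cached(fn, *args)
    return result, time.perf_counter() - t0


def square(x):
    return x * x


_, first_call = timed_call(square, 3)
_, second_call = timed_call(square, 4)
print(f"first call: {first_call:.3f}s, second call: {second_call:.3f}s")
```

When benchmarking `generate()`, it is therefore reasonable to treat the first call as a warm-up and time a later call instead.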

The quality of the generated text looks OK, but we can improve it via
fine-tuning.

## More on the GPT-2 model from KerasNLP

Next up, we will actually fine-tune the model to update its parameters, but
before we do, let's take a look at the full set of tools we have for working
with GPT2.

The code of GPT2 can be found
[here](https://github.com/keras-team/keras-nlp/blob/master/keras_nlp/models/gpt2/).
Conceptually the `GPT2CausalLM` can be hierarchically broken down into several
modules in KerasNLP, all of which have a *from_preset()* function that loads a
pretrained model:

- `keras_nlp.models.GPT2Tokenizer`: The tokenizer used by the GPT2 model, which is a
    [byte-pair encoder](https://huggingface.co/course/chapter6/5?fw=pt).
- `keras_nlp.models.GPT2CausalLMPreprocessor`: The preprocessor used for GPT2
    causal LM training. It handles tokenization along with other preprocessing
    work, such as creating the labels and appending the end token.
- `keras_nlp.models.GPT2Backbone`: The GPT2 model, which is a stack of
    `keras_nlp.layers.TransformerDecoder` layers. This is usually just referred
    to as `GPT2`.
- `keras_nlp.models.GPT2CausalLM`: Wraps `GPT2Backbone` and multiplies the
    output of `GPT2Backbone` by the embedding matrix to generate logits over
    vocab tokens.
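
Two of the steps in the list above can be sketched in plain Python: building next-token labels by shifting the inputs by one position, and turning a hidden vector into logits by multiplying it with the embedding matrix. This is a simplified illustration under stated assumptions, not the KerasNLP implementation. The token ids, the padding id `0`, the end id `99`, and the helper names are all made up, and the real preprocessor may handle details such as start tokens differently.

```python
def make_causal_lm_example(token_ids, sequence_length, end_id, pad_id=0):
    """Build (inputs, labels, mask) for next-token prediction."""
    # Append the end token, then truncate to one more than the input length.
    ids = (token_ids + [end_id])[: sequence_length + 1]
    num_real = len(ids)
    ids = ids + [pad_id] * (sequence_length + 1 - num_real)
    # The label at position i is the input token at position i + 1.
    inputs, labels = ids[:-1], ids[1:]
    # Only positions whose label is a real (non-padding) token count in the loss.
    mask = [1 if i + 1 < num_real else 0 for i in range(sequence_length)]
    return inputs, labels, mask


def tied_logits(hidden, embedding):
    """hidden: [dim], embedding: [vocab][dim] -> one logit per vocab token."""
    return [sum(h * e for h, e in zip(hidden, row)) for row in embedding]


demo_inputs, demo_labels, demo_mask = make_causal_lm_example(
    [11, 12, 13], sequence_length=5, end_id=99
)
print(demo_inputs, demo_labels, demo_mask)

# Hypothetical 3-token vocabulary with 2-dimensional embeddings.
demo_embedding = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
print(tied_logits([0.5, 2.0], demo_embedding))
```

Reusing the embedding matrix for the output projection means the model needs no separate output weight matrix of vocabulary size.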

## Finetune on Reddit dataset

Now that you know the GPT-2 model from KerasNLP, you can take one
step further and finetune the model so that it generates text in a specific
style, short or long, strict or casual. In this tutorial, we will use the Reddit
dataset as an example.

In [9]:
import tensorflow_datasets as tfds

reddit_ds = tfds.load("reddit_tifu", split="train", as_supervised=True)

Downloading and preparing dataset 639.54 MiB (download: 639.54 MiB, generated: 141.46 MiB, total: 781.00 MiB) to /root/tensorflow_datasets/reddit_tifu/short/1.1.2...



Dataset reddit_tifu downloaded and prepared to /root/tensorflow_datasets/reddit_tifu/short/1.1.2. Subsequent calls will reuse this data.


Let's take a look at sample data from the Reddit TensorFlow dataset. There
are two features:

- **__document__**: text of the post.
- **__title__**: the title.

In [10]:
for document, title in reddit_ds:
    print(document.numpy())
    print(title.numpy())
    break

b"me and a friend decided to go to the beach last sunday. we loaded up and headed out. we were about half way there when i decided that i was not leaving till i had seafood. \n\nnow i'm not talking about red lobster. no friends i'm talking about a low country boil. i found the restaurant and got directions. i don't know if any of you have heard about the crab shack on tybee island but let me tell you it's worth it. \n\nwe arrived and was seated quickly. we decided to get a seafood sampler for two and split it. the waitress bought it out on separate platters for us. the amount of food was staggering. two types of crab, shrimp, mussels, crawfish, andouille sausage, red potatoes, and corn on the cob. i managed to finish it and some of my friends crawfish and mussels. it was a day to be a fat ass. we finished paid for our food and headed to the beach. \n\nfunny thing about seafood. it runs through me faster than a kenyan \n\nwe arrived and walked around a bit. it was about 45min since we a

In our case, we are performing next word prediction in a language model, so we
only need the 'document' feature.

In [11]:
train_ds = (
    reddit_ds.map(lambda document, _: document)
    .batch(32)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

Now you can finetune the model using the familiar *fit()* function. Note that
`preprocessor` will be automatically called inside the `fit` method since
`GPT2CausalLM` is a `keras_nlp.models.Task` instance.

This step takes quite a bit of GPU memory, and it would take a long time to
train the model all the way to a fully trained state. Here we just use part of
the dataset for demo purposes.
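
As a rough sketch of what "part of the dataset" means, the arithmetic below uses the 79740 train examples reported in the download log, the batch size of 32, and the 500 batches taken in the fine-tuning cell. It also reproduces the linearly decaying learning rate in plain Python, assuming `PolynomialDecay` is used with its default `power=1.0`. The names `linear_decay` and `num_epochs_demo` are made up for this illustration.

```python
import math

num_examples = 79740  # train examples reported in the dataset download log
batch_size = 32       # matches .batch(32) above
batches_used = 500    # matches train_ds.take(500) in the fine-tuning cell
num_epochs_demo = 1

total_batches = math.ceil(num_examples / batch_size)
examples_used = batches_used * batch_size
decay_steps = batches_used * num_epochs_demo


def linear_decay(step, initial_lr=5e-5, end_lr=0.0, total_steps=decay_steps):
    """Polynomial decay with power 1: a straight line from initial_lr to end_lr."""
    step = min(step, total_steps)
    return (initial_lr - end_lr) * (1 - step / total_steps) + end_lr


print(
    f"{examples_used} of {num_examples} examples "
    f"({examples_used / num_examples:.1%}), {total_batches} batches in total"
)
for demo_step in (0, 250, 500):
    print(demo_step, linear_decay(demo_step))
```

The learning rate reaches zero exactly at the last training step, so raising the number of batches or epochs stretches the same line over more steps.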

In [12]:
train_ds = train_ds.take(500)
num_epochs = 1

# Linearly decaying learning rate.
learning_rate = keras.optimizers.schedules.PolynomialDecay(
    5e-5,
    decay_steps=train_ds.cardinality() * num_epochs,
    end_learning_rate=0.0,
)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
gpt2_lm.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=loss,
    weighted_metrics=["accuracy"],
)

gpt2_lm.fit(train_ds, epochs=num_epochs)



<keras.src.callbacks.History at 0x7c23f9594970>

After fine-tuning is finished, you can again generate text using the same
*generate()* function. This time, the text will be closer to Reddit writing
style, and the generated length will be close to the preset sequence length
(128) used for the training set.

In [13]:
start = time.time()

output = gpt2_lm.generate("I like basketball", max_length=200)
print("\nGPT-2 output:")
print(output)

end = time.time()
print(f"TOTAL TIME ELAPSED: {end - start:.2f}s")


GPT-2 output:
I like basketball, but it doesn't really fit me as much as i would like it to, so i decided to try and get some basketball.

i was playing basketball with a friend, and we decided we should play basketball together. i was playing with some friends, but it just seemed like a waste to me to try and play with a girl that i didn't know. so i went to the bathroom and went to grab my basketball. i grabbed it, grabbed a towel and started playing with my friend, but when i saw my friend
TOTAL TIME ELAPSED: 14.52s


## Into the Sampling Method

In KerasNLP, we offer a few sampling methods, e.g., contrastive search,
Top-K sampling and beam search. By default, our `GPT2CausalLM` uses Top-K
sampling, but you can choose your own sampling method.

Much like optimizers and activations, there are two ways to specify your custom
sampler:

- Use a string identifier, such as "greedy". This uses the sampler's default
configuration.
- Pass a `keras_nlp.samplers.Sampler` instance. This lets you use a custom
configuration.

In [14]:
# Use a string identifier.
gpt2_lm.compile(sampler="top_k")
output = gpt2_lm.generate("I like basketball", max_length=200)
print("\nGPT-2 output:")
print(output)

# Use a `Sampler` instance. `GreedySampler` tends to repeat itself.
greedy_sampler = keras_nlp.samplers.GreedySampler()
gpt2_lm.compile(sampler=greedy_sampler)

output = gpt2_lm.generate("I like basketball", max_length=200)
print("\nGPT-2 output:")
print(output)


GPT-2 output:
I like basketball so it was great. i was sitting in the locker room and the basketball was on the other side of the room. 

so i go into the locker room to grab some basketball and a couple of guys are waiting. the guys in the locker room start talking about the game. they all look at me and say "you can't do that." 

i'm like "well i can, but you can't dunk on me." 

they all say "you can, but not that big."

they all look me up, and then they all say "i'm going to dunk on you

GPT-2 output:
I like basketball, but i don't really like the game. 

so i was playing basketball at my local high school, and i was playing with my friends. 

i was playing with my friends, and i was playing with my brother, who was playing basketball with his brother. 

so i was playing with my brother, and he was playing with his brother's brother. 

so i was playing with my brother, and he was playing with his brother's brother. 

so i was playing with my brother, and he was playing with his b

For more details on the KerasNLP `Sampler` class, you can check the code
[here](https://github.com/keras-team/keras-nlp/tree/master/keras_nlp/samplers).
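
To see why greedy decoding tends to repeat itself while Top-K output varies, here is a minimal sketch of the two selection rules applied to a single vector of logits. It is not the KerasNLP implementation; the logits and the helper names `greedy_pick` and `top_k_pick` are invented for this illustration.

```python
import math
import random


def greedy_pick(logits):
    """Always choose the index with the highest logit."""
    return max(range(len(logits)), key=lambda i: logits[i])


def top_k_pick(logits, k, rng):
    """Sample one index from the k highest logits, weighted by softmax."""
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    largest = max(logits[i] for i in top)
    # Subtract the largest logit before exponentiating for numerical stability.
    weights = [math.exp(logits[i] - largest) for i in top]
    return rng.choices(top, weights=weights, k=1)[0]


demo_logits = [2.0, 0.5, 1.5, -1.0, 0.1]  # hypothetical scores for 5 tokens
demo_rng = random.Random(0)
print("greedy:", greedy_pick(demo_logits))
print("top-k :", [top_k_pick(demo_logits, 2, demo_rng) for _ in range(10)])
```

Greedy selection is deterministic, so the same context always yields the same next token, which makes loops like the one in the output above more likely. Top-K keeps some randomness while still excluding unlikely tokens.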

## Finetune on Chinese Poem Dataset

We can also finetune GPT2 on non-English datasets. For readers who know Chinese,
this part illustrates how to fine-tune GPT2 on a Chinese poem dataset to teach our
model to become a poet!

Because GPT2 uses a byte-pair encoder, and the original pretraining dataset
contains some Chinese characters, we can use the original vocab to finetune on
a Chinese dataset.
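
One reason the original vocabulary can still be used is that GPT-2's byte-pair encoding works on UTF-8 bytes, so any string can be broken down into known units even when a character has no dedicated token. The sketch below only shows the byte view of the prompt used later in this tutorial. It does not run the GPT-2 tokenizer, and the byte count is not the token count.

```python
poem_prompt = "昨夜雨疏风骤"
prompt_bytes = poem_prompt.encode("utf-8")

print(len(poem_prompt), "characters ->", len(prompt_bytes), "bytes")
# Each character here is represented by several bytes.
print(list(prompt_bytes[: len("昨".encode("utf-8"))]))

# The byte sequence decodes back to the original text without loss.
print(prompt_bytes.decode("utf-8") == poem_prompt)
```

Because Chinese characters are likely to be split into several byte-level tokens, a fixed sequence length covers fewer Chinese characters than English words.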

In [15]:
# Load the Chinese poetry dataset.
!git clone https://github.com/chinese-poetry/chinese-poetry.git

Cloning into 'chinese-poetry'...
remote: Enumerating objects: 7309, done.
remote: Counting objects: 100% (114/114), done.
remote: Compressing objects: 100% (91/91), done.
remote: Total 7309 (delta 38), reused 78 (delta 20), pack-reused 7195
Receiving objects: 100% (7309/7309), 199.68 MiB | 15.76 MiB/s, done.
Resolving deltas: 100% (5326/5326), done.
Updating files: 100% (2285/2285), done.


Load text from the JSON files. We only use《全唐诗》for demo purposes.

In [16]:
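import os
import json

poem_dir = "chinese-poetry/全唐诗"
poem_collection = []
for file in os.listdir(poem_dir):
    # Keep only the JSON files whose names contain "poet".
    if ".json" not in file or "poet" not in file:
        continue
    full_filename = os.path.join(poem_dir, file)
    # The files contain Chinese text, so read them explicitly as UTF-8.
    with open(full_filename, "r", encoding="utf-8") as f:
        content = json.load(f)
        poem_collection.extend(content)

# Join the lines of each poem into a single string.
paragraphs = ["".join(data["paragraphs"]) for data in poem_collection]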

Let's take a look at sample data.

In [17]:
print(paragraphs[0])

輞川題作浪，武昌栽屬官。愛思如孔階，茸髯擺奇觀。
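
For readers who have not opened the files, the loading code implies that each file holds a JSON list of records with a `"paragraphs"` list of lines. The sketch below rebuilds the sample above from such a record. The way the poem is split into two lines is an assumption made for illustration, and real records may carry additional fields that are not shown here.

```python
import json

# Hypothetical record; only the "paragraphs" key is used by the loading code.
sample_json = '[{"paragraphs": ["輞川題作浪，武昌栽屬官。", "愛思如孔階，茸髯擺奇觀。"]}]'

sample_records = json.loads(sample_json)
sample_paragraphs = ["".join(record["paragraphs"]) for record in sample_records]
print(sample_paragraphs[0])
```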


Similar to the Reddit example, we convert the data to a TF dataset and only use
part of it for training.

In [18]:
train_ds = (
    tf.data.Dataset.from_tensor_slices(paragraphs)
    .batch(16)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

# Running through the whole dataset takes a long time, so we only take `500`
# batches and run 1 epoch for demo purposes.
train_ds = train_ds.take(500)
num_epochs = 1

learning_rate = keras.optimizers.schedules.PolynomialDecay(
    5e-4,
    decay_steps=train_ds.cardinality() * num_epochs,
    end_learning_rate=0.0,
)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
gpt2_lm.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=loss,
    weighted_metrics=["accuracy"],
)

gpt2_lm.fit(train_ds, epochs=num_epochs)



<keras.src.callbacks.History at 0x7c24564306a0>

Let's check the result!

In [19]:
output = gpt2_lm.generate("昨夜雨疏风骤", max_length=200)
print(output)



昨夜雨疏风骤晚，晚樹欲樹花花爭。曲樹若若花，發發靈居風虛。若山靈山風晚，樹樹爲城樹。


Not bad 😀