## Install KerasNLP and Import Dependencies.

In [None]:
!pip install -q -U git+https://github.com/keras-team/keras-nlp.git@master

  Preparing metadata (setup.py) ... done
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 5.8/5.8 MB 49.5 MB/s eta 0:00:00
  Building wheel for keras-nlp (setup.py) ... done


In [None]:
import keras_nlp
import tensorflow as tf
from tensorflow import keras

## Load `GPT2CausalLM` from KerasNLP.

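`GPT2CausalLM` is a GPT2 backbone followed by a language-modeling head: the backbone's final hidden states are multiplied by the transpose of the token embedding matrix (the same matrix used to embed the input tokens), which yields a logit for every vocabulary token at each position.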
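To make that last step concrete, here is a minimal NumPy sketch of a tied-embedding output projection. The sizes are toy values chosen for illustration (GPT-2 base uses a 50,257-token vocabulary and 768-dimensional hidden states), and `hidden_states` and `token_embedding` are hypothetical names, not KerasNLP attributes.

```python
import numpy as np

# Toy sizes for illustration; GPT-2 base uses vocab_size=50257, hidden_dim=768.
batch_size, seq_len, hidden_dim, vocab_size = 2, 5, 8, 11

rng = np.random.default_rng(0)
# Token embedding matrix, shared by the input embedding and the output projection.
token_embedding = rng.normal(size=(vocab_size, hidden_dim))
# Stand-in for the backbone's final hidden states.
hidden_states = rng.normal(size=(batch_size, seq_len, hidden_dim))

# Score every vocabulary entry at every position: (B, T, H) @ (H, V) -> (B, T, V).
logits = hidden_states @ token_embedding.T

# The highest-scoring id at each position is the greedy next-token prediction.
next_token_ids = logits.argmax(axis=-1)
print(logits.shape, next_token_ids.shape)
```

Reusing the embedding matrix this way means the head adds no extra parameters; each logit is just the dot product between a hidden state and one token's embedding.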

In [None]:
# To speed up training, we use a sequence length of 256 instead of GPT-2's full context length of 1024.
preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
    "gpt2_base_en",
    sequence_length=256,
    add_end_token=True,
)
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset(
    "gpt2_base_en", preprocessor=preprocessor
)

Downloading data from https://storage.googleapis.com/keras-nlp/models/gpt2_base_en/v1/vocab.json
Downloading data from https://storage.googleapis.com/keras-nlp/models/gpt2_base_en/v1/merges.txt
Downloading data from https://storage.googleapis.com/keras-nlp/models/gpt2_base_en/v1/model.h5
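The preprocessor turns each string into a training example for next-token prediction: the text is tokenized (with `add_end_token=True` an end token is appended), packed to a fixed length, and the labels are the inputs shifted left by one token, with padding positions masked out. The pure-Python sketch below illustrates that shift on made-up token ids; `make_causal_lm_example` and `pad_id=0` are hypothetical, and the real `GPT2CausalLMPreprocessor` handles tokenization, special tokens, and padding itself.

```python
def make_causal_lm_example(token_ids, sequence_length, pad_id=0):
    """Build (inputs, labels, sample_weight) for next-token prediction.

    Hypothetical helper: packs to sequence_length + 1 tokens so that
    shifting by one still leaves sequence_length positions.
    """
    packed = token_ids[: sequence_length + 1]  # truncate long inputs
    mask = [1] * len(packed)
    # Pad both the ids and the mask out to sequence_length + 1.
    padding = sequence_length + 1 - len(packed)
    packed = packed + [pad_id] * padding
    mask = mask + [0] * padding

    inputs = packed[:-1]      # tokens the model sees
    labels = packed[1:]       # the token that follows each input position
    sample_weight = mask[1:]  # zero weight where the label is padding
    return inputs, labels, sample_weight


inputs, labels, weights = make_causal_lm_example([5, 9, 2, 7], sequence_length=6)
print(inputs)   # [5, 9, 2, 7, 0, 0]
print(labels)   # [9, 2, 7, 0, 0, 0]
print(weights)  # [1, 1, 1, 0, 0, 0]
```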




## Finetune on Chinese Poem Dataset

We can also finetune GPT2 on non-English datasets. For readers who know Chinese, this part illustrates how to finetune GPT2 on a Chinese poetry dataset to teach our model to become a poet!
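GPT-2's tokenizer is a byte-level BPE, so it can encode Chinese text even though its vocabulary was learned from mostly English web text: every character is first broken into UTF-8 bytes. The sketch below, using a hypothetical helper `utf8_bytes_per_char`, shows that each character in a line of classical Chinese takes 3 bytes. How many tokens those bytes become depends on which byte merges exist in the vocabulary, which this sketch does not model.

```python
def utf8_bytes_per_char(text):
    """Return the number of UTF-8 bytes used by each character in `text`."""
    return [len(ch.encode("utf-8")) for ch in text]


line = "數萼初含雪，孤標畫本難。"  # opening of the first poem in the dataset
counts = utf8_bytes_per_char(line)
print(len(line), sum(counts))  # 12 36
print("雪".encode("utf-8"))     # b'\xe9\x9b\xaa'
```

Because each character expands to several bytes before merging, a 256-token sequence may hold noticeably fewer Chinese characters than English words.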

In [None]:
# Download the Chinese poetry dataset.
!git clone https://github.com/chinese-poetry/chinese-poetry.git

Cloning into 'chinese-poetry'...
remote: Enumerating objects: 7210, done.
remote: Counting objects: 100% (15/15), done.
remote: Compressing objects: 100% (11/11), done.
remote: Total 7210 (delta 3), reused 13 (delta 3), pack-reused 7195
Receiving objects: 100% (7210/7210), 197.74 MiB | 35.88 MiB/s, done.
Resolving deltas: 100% (5292/5292), done.
Updating files: 100% (2282/2282), done.


In [None]:
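import os
import json

data_dir = "chinese-poetry/quan_tang_shi/json"
poem_collection = []
# Each JSON file holds a list of poem records; collect them all into one list.
for file in os.listdir(data_dir):
    full_filename = os.path.join(data_dir, file)
    # The files are UTF-8 encoded; say so explicitly so reading does not
    # depend on the platform's default encoding.
    with open(full_filename, "r", encoding="utf-8") as f:
        content = json.load(f)
        poem_collection.extend(content)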

In [None]:
# Each poem record stores its lines in a "paragraphs" list; join them into one string per poem.
paragraphs = ["".join(data["paragraphs"]) for data in poem_collection]
print(paragraphs[0])

數萼初含雪，孤標畫本難。香中別有韻，清極不知寒。橫笛和愁聽，斜枝倚病看。朔風如解意，容易莫摧殘。
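For readers unfamiliar with the dataset layout, the sketch below shows the record shape the loading code relies on, using a hypothetical, trimmed-down record: only the `"paragraphs"` field is included (real records carry other fields, omitted here), and the split of the poem into list items is assumed for illustration.

```python
import json

# A hypothetical record in the shape the loop above expects: a dict whose
# "paragraphs" field is a list of strings.
raw = '[{"paragraphs": ["數萼初含雪，孤標畫本難。", "香中別有韻，清極不知寒。"]}]'
records = json.loads(raw)

# Joining with "" concatenates the lines with no separator between them.
texts = ["".join(record["paragraphs"]) for record in records]
print(texts[0])  # 數萼初含雪，孤標畫本難。香中別有韻，清極不知寒。
```

Joining without a separator is reasonable here because the full-width punctuation already marks where lines end.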


In [None]:
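# Batch the raw poem strings; the preprocessor attached to gpt2_lm
# tokenizes them on the fly during fit().
train_ds = (
    tf.data.Dataset.from_tensor_slices(paragraphs)
    .batch(16)
    .prefetch(tf.data.AUTOTUNE)
)
# To keep training short, use only the first 2000 batches (at most 32,000 poems).
train_ds = train_ds.take(2000)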
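Because `.take(2000)` is applied after `.batch(16)`, it counts batches rather than poems, so each epoch sees at most 2000 × 16 = 32,000 poems. The pure-Python sketch below mimics that ordering on a toy list; `batch` and `take` here are hypothetical helpers, not the `tf.data` methods, though `batch` keeps a smaller final chunk just as `tf.data`'s `batch` does by default.

```python
from itertools import islice


def batch(items, batch_size):
    """Yield consecutive chunks of `batch_size` items; the last chunk may be smaller."""
    for start in range(0, len(items), batch_size):
        yield items[start : start + batch_size]


def take(iterable, count):
    """Keep only the first `count` elements of an iterable."""
    return list(islice(iterable, count))


poems = [f"poem {i}" for i in range(100)]  # toy stand-in for `paragraphs`
batches = take(batch(poems, 16), 3)
print(len(batches), sum(len(b) for b in batches))  # 3 48
```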

In [None]:
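num_epochs = 3

# Linearly decay the learning rate from 5e-4 to 0 over all training steps
# (PolynomialDecay uses power=1.0 by default).
lr = tf.keras.optimizers.schedules.PolynomialDecay(
    5e-4,
    decay_steps=train_ds.cardinality() * num_epochs,
    end_learning_rate=0.0,
)
# The model outputs unnormalized logits, hence from_logits=True.
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
gpt2_lm.compile(
    optimizer=keras.optimizers.Adam(lr),
    loss=loss,
    # The preprocessor emits sample weights that mask out padding, so
    # accuracy is computed as a weighted metric.
    weighted_metrics=["accuracy"],
)

gpt2_lm.fit(train_ds, epochs=num_epochs)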

Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


Epoch 1/3
Epoch 2/3
Epoch 3/3


<keras.callbacks.History at 0x7f0bdc0f3a30>
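With the default `power=1.0`, `PolynomialDecay` lowers the learning rate linearly from 5e-4 to 0 over `decay_steps` optimizer steps. The sketch below re-implements the non-cycling formula documented for `tf.keras.optimizers.schedules.PolynomialDecay` in plain Python; the default `decay_steps=6000` assumes the dataset yields the full 2000 batches per epoch for 3 epochs.

```python
def polynomial_decay(step, initial_lr=5e-4, decay_steps=6000, end_lr=0.0, power=1.0):
    """Learning rate at `step`, following the non-cycling PolynomialDecay formula."""
    step = min(step, decay_steps)  # the rate stays at end_lr after decay_steps
    return (initial_lr - end_lr) * (1 - step / decay_steps) ** power + end_lr


for step in (0, 2000, 4000, 6000):
    print(step, polynomial_decay(step))
```

Decaying all the way to zero means the last few steps barely move the weights, which tends to make the final checkpoint less sensitive to the noise of individual batches.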

In [None]:
output = gpt2_lm.generate("昨夜雨疏风骤", max_length=200)
print(output.numpy().decode("utf-8"))

昨夜雨疏风骤清，今朝暗見鶴悠悠。獨攜清淨深芳院，欹歌白鶴應頻別，獨自長江獨不同。


Not bad 😀
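`generate()` produces text one token at a time: at each step the model scores every vocabulary entry and a sampling strategy picks the next token. The exact strategy depends on the KerasNLP version, so the NumPy sketch below only illustrates one common choice, top-k sampling, on a made-up logits vector; `top_k_sample` is a hypothetical helper, not a KerasNLP function.

```python
import numpy as np


def top_k_sample(logits, k, rng):
    """Sample a token id from the k highest-scoring entries of `logits`."""
    top_ids = np.argsort(logits)[-k:]  # ids of the k largest logits
    top_logits = logits[top_ids]
    # Softmax over the kept logits only (subtract the max for numerical stability).
    probs = np.exp(top_logits - top_logits.max())
    probs /= probs.sum()
    return int(rng.choice(top_ids, p=probs))


rng = np.random.default_rng(42)
logits = np.array([0.1, 2.0, -1.0, 1.5, 0.3])
print(top_k_sample(logits, k=2, rng=rng))  # either 1 or 3
print(top_k_sample(logits, k=1, rng=rng))  # 1 (k=1 reduces to greedy decoding)
```

Restricting sampling to the top k candidates keeps some variety in the generated poems while avoiding very unlikely tokens.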

In [None]:
# You can save the weights for future use.
gpt2_lm.backbone.save_weights("/content/model.h5")

In [None]:
# Print an MD5 checksum of the saved file so copies can be verified later.
!md5sum /content/model.h5
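`md5sum` prints a checksum that can be compared later to confirm the weights file was not corrupted when copied or downloaded. If the shell command is unavailable, a Python equivalent using the standard-library `hashlib` module is sketched below; `md5_of_file` is a hypothetical helper, and reading in chunks avoids loading a large weights file into memory at once.

```python
import hashlib


def md5_of_file(path, chunk_size=1 << 20):
    """Return the hex MD5 digest of a file, reading it 1 MiB at a time."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # iter() with a sentinel keeps calling f.read until it returns b"" (EOF).
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# Usage, with the path from the cell above:
# print(md5_of_file("/content/model.h5"))
```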