<a href="https://colab.research.google.com/github/deep-diver/gpt2-ft-pipeline/blob/main/notebooks/saved_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -q keras-nlp

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m527.7/527.7 kB[0m [31m18.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.0/6.0 MB[0m [31m84.9 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import tensorflow as tf
import keras_nlp
from tensorflow.python.saved_model import tag_constants

In [110]:
# Build the tokenizer, the causal-LM preprocessor, and the causal LM itself,
# all from the same "gpt2_base_en" preset.
gpt2_tokenizer = keras_nlp.models.GPT2Tokenizer.from_preset(
    "gpt2_base_en"
)
gpt2_preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
    "gpt2_base_en", sequence_length=256, add_end_token=True
)
# The preprocessor is handed to the model so that later cells can reach the
# tokenizer through `gpt2_lm.preprocessor.tokenizer`.
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset(
    "gpt2_base_en",
    preprocessor=gpt2_preprocessor,
)



In [172]:
# Input signature for the exported serving function: two scalar inputs
# (shape=[]), a string `prompt` and an int64 `max_length`.
signature_dict = dict(
    prompt=tf.TensorSpec(shape=[], dtype=tf.string, name="prompt"),
    max_length=tf.TensorSpec(shape=[], dtype=tf.int64, name="max_length"),
)

def gpt2_lm_exporter(model):
  """Build a serving function that turns a scalar prompt into generated text.

  Args:
    model: Object exposing `preprocessor.tokenizer` and
      `make_generate_function()`; the notebook passes the GPT-2 causal LM.

  Returns:
    A `tf.function` whose input signature is `signature_dict` (keys
    "prompt" and "max_length") and which returns `{"result": <string tensor>}`.
  """
  @tf.function(input_signature=[signature_dict])
  def serving_fn(inputs):
    tokenizer = model.preprocessor.tokenizer

    # Tokenize as a batch of one: a scalar prompt gets a leading batch axis.
    text = tf.convert_to_tensor(inputs["prompt"])
    if text.shape.rank == 0:
      text = text[tf.newaxis]
    ragged_ids = tokenizer(text)

    # Pad ragged to dense tensors of shape (1, max_length).
    dense_shape = (1, inputs["max_length"])
    shortest_prompt = tf.reduce_min(ragged_ids.row_lengths())
    mask = tf.ones_like(ragged_ids, tf.bool).to_tensor(shape=dense_shape)
    dense_ids = ragged_ids.to_tensor(shape=dense_shape)
    dense_ids = tf.cast(dense_ids, dtype="int64")

    # NOTE(review): the second positional argument is the shortest prompt
    # length — confirm this matches the generate function of the installed
    # keras-nlp release.
    generate = model.make_generate_function()
    generated = generate(
        {"token_ids": dense_ids, "padding_mask": mask}, shortest_prompt
    )

    # Keep only positions that are inside the padding mask and are not the
    # tokenizer's end token.
    out_ids = generated["token_ids"]
    keep = generated["padding_mask"] & (out_ids != tokenizer.end_token_id)
    out_ids = tf.ragged.boolean_mask(out_ids, keep)
    out_ids = tf.cast(out_ids, dtype="int32")

    # Detokenize by hand: ids -> token strings (joined), then split into
    # characters and map each one through the tokenizer's unicode2byte table.
    joined_tokens = tf.strings.reduce_join(
        tokenizer.id_to_token_map.lookup(out_ids), axis=-1
    )
    chars = tf.strings.unicode_split(joined_tokens, "UTF-8")
    decoded = tf.strings.reduce_join(
        tokenizer.unicode2byte.lookup(chars), axis=-1
    )
    decoded = tf.concat(decoded, axis=0)
    # Drop the batch axis that was added for the single prompt.
    decoded = tf.squeeze(decoded, 0)
    return {"result": decoded}

  return serving_fn

In [173]:
# Export the model as a SavedModel, registering the custom serving function
# under the "serving_default" signature key.
tf.saved_model.save(
    gpt2_lm,
    export_dir="./gpt_lm_custom/1/",
    signatures={"serving_default": gpt2_lm_exporter(gpt2_lm)},
)



In [174]:
# Print the inputs and outputs of the `serving_default` signature stored in
# the exported SavedModel (tag set `serve`).
!saved_model_cli show --dir gpt_lm_custom/1/ --tag_set serve --signature_def serving_default

The given SavedModel SignatureDef contains the following input(s):
  inputs['max_length'] tensor_info:
      dtype: DT_INT64
      shape: ()
      name: serving_default_max_length:0
  inputs['prompt'] tensor_info:
      dtype: DT_STRING
      shape: ()
      name: serving_default_prompt:0
The given SavedModel SignatureDef contains the following output(s):
  outputs['result'] tensor_info:
      dtype: DT_STRING
      shape: ()
      name: StatefulPartitionedCall:0
Method name is: tensorflow/serving/predict


In [175]:
# Reload the exported SavedModel with the serving tag and fetch the default
# serving signature as a callable.
# `tf.saved_model.SERVING` is the public constant for the serving tag, so the
# private `tensorflow.python.saved_model.tag_constants` module is not needed.
saved_model_loaded = tf.saved_model.load(
    "gpt_lm_custom/1", tags=[tf.saved_model.SERVING]
)
gpt_lm_predict_fn = saved_model_loaded.signatures["serving_default"]

In [176]:
# Display the loaded signature function to inspect how it must be called.
gpt_lm_predict_fn

<ConcreteFunction signature_wrapper(*, max_length, prompt) at 0x7F2DC74F7E20>

In [177]:
# Example inputs: a scalar string prompt and a scalar int64 maximum length.
prompt = tf.constant("hello world", dtype=tf.string)
max_length = tf.constant(100, dtype=tf.int64)

In [178]:
# Run the reloaded serving signature on the example inputs (keyword arguments).
result = gpt_lm_predict_fn(prompt=prompt, max_length=max_length)

In [179]:
# Display the returned dict; its "result" entry holds the generated text.
result

{'result': <tf.Tensor: shape=(), dtype=string, numpy=b'hello world of gaming and entertainment.\n\nThe game is based on a game called "Dread Pirate Roberts" by the creators of the popular game "Dread Pirate Roberts" (also known as DOTA or DOTA 2).\n\nIt\'s the second most popular online role-playing game in the world and the second most downloaded game.\n\nIn addition to being the most downloaded game on Steam and the most used app, DOTA is the third most downloaded game on the Android App'>}