180 changes: 112 additions & 68 deletions keras_nlp/models/gpt2/gpt2_causal_lm.py
@@ -29,7 +29,6 @@
from keras_nlp.utils.keras_utils import is_xla_compatible
from keras_nlp.utils.python_utils import classproperty
from keras_nlp.utils.tf_utils import tensor_to_string_list
from keras_nlp.utils.tf_utils import truncate_at_token


@keras_nlp_export("keras_nlp.models.GPT2CausalLM")
@@ -49,7 +48,7 @@ class GPT2CausalLM(Task):
default, `"top_k"` sampling will be used.

This model can optionally be configured with a `preprocessor` layer, in
which case it will automatically apply preprocessing to raw inputs during
which case it will automatically apply preprocessing to string inputs during
`fit()`, `predict()`, `evaluate()` and `generate()`. This is done by default
when creating the model with `from_preset()`.

@@ -306,28 +305,23 @@ def make_generate_function(self):

def generate_step(
self,
token_ids,
padding_mask,
inputs,
end_token_id=None,
):
"""A compilable generation function for a single batch of inputs.

This function represents the inner, XLA-compilable, generation function
for a single batch of inputs. It takes in a dense `tf.Tensor` of token
ids, and return a dense `tf.Tensor` of token ids, and includes no
preprocessing. This function is wrapped by the `generate()` method.
for a single batch of inputs. Inputs should have the same structure as
model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`.

Args:
token_ids: A dense int Tensor, with shape
`(batch_size, max_length)`. The user provided token ids
padded to `max_length`.
padding_mask: A dense boolean Tensor, with the same shape as
`token_ids`. Positions that are True in the `padding_mask`
are assumed to be user input and never updated.
inputs: A dictionary with two keys, `"token_ids"` and
`"padding_mask"`, whose values are batched tensors.
end_token_id: The id of the end token to stop on. If all
sequences have produced a new `end_token_id`, generation
will stop.
"""
token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"]
# Create and seed cache with a single forward pass.
hidden_states, cache = self._build_cache(token_ids)
# Compute the lengths of all user-provided token ids.
@@ -352,7 +346,7 @@ def next(prompt, cache, index):
cache,
)

return self._sampler(
token_ids = self._sampler(
next=next,
prompt=token_ids,
cache=cache,
@@ -362,6 +356,78 @@ def next(prompt, cache, index):
hidden_states=hidden_states,
)

# Compute an output padding mask with the token ids we updated.
if end_token_id is not None:
# Build a mask of `end_token_id` locations not in the original
# prompt (not in locations where `padding_mask` is True).
end_locations = (token_ids == end_token_id) & (~padding_mask)
end_locations = tf.cast(end_locations, tf.int32)
# Use cumsum to get ones in all locations after end_locations.
overflow = tf.math.cumsum(end_locations, axis=-1, exclusive=True)
# Our padding mask is the inverse of these overflow locations.
padding_mask = ~tf.cast(overflow, tf.bool)
else:
# Without early stopping, all locations will have been updated.
padding_mask = tf.ones_like(token_ids, dtype=tf.bool)
return {
"token_ids": token_ids,
"padding_mask": padding_mask,
}
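A minimal standalone sketch of the cumsum trick above, on toy values (not part of the PR), may help:

```python
import tensorflow as tf

# Toy batch: two prompt tokens (padding_mask True), then generated
# tokens 7, 2, 8, where 2 is the end token.
token_ids = tf.constant([[5, 6, 7, 2, 8]])
padding_mask = tf.constant([[True, True, False, False, False]])
end_token_id = 2

end_locations = tf.cast(
    (token_ids == end_token_id) & (~padding_mask), tf.int32
)
# The exclusive cumsum is 0 up to and including the end token, and
# positive only after it, so the end token itself is kept.
overflow = tf.math.cumsum(end_locations, axis=-1, exclusive=True)
print(~tf.cast(overflow, tf.bool))
# tf.Tensor([[ True  True  True  True False]], shape=(1, 5), dtype=bool)
```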

def _normalize_generate_inputs(
self,
inputs,
):
"""Normalize user input to the generate function.

This function converts all inputs to tensors, adds a batch dimension if
necessary, and returns an iterable "dataset like" object (either an
actual `tf.data.Dataset` or a list with a single batch element).
"""
input_is_scalar = False

if isinstance(inputs, tf.data.Dataset):
return inputs, input_is_scalar

if isinstance(inputs, str) or isinstance(inputs, list):
inputs = tf.convert_to_tensor(inputs)

if isinstance(inputs, tf.Tensor) and inputs.shape.rank == 0:
input_is_scalar = True
inputs = inputs[tf.newaxis]

# We avoid converting to a dataset purely for speed; for a single
# batch of input, creating a dataset would add significant overhead.
return [inputs], input_is_scalar
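As a hypothetical check of this contract (preset name assumed, and a private method called purely for illustration):

```python
import tensorflow as tf
import keras_nlp

causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

batches, is_scalar = causal_lm._normalize_generate_inputs("one prompt")
# is_scalar is True; batches is a one-element list holding a shape-(1,)
# string tensor, so the generate() loop can iterate it like a dataset.

dataset = tf.data.Dataset.from_tensor_slices(["a", "b"]).batch(2)
batches, is_scalar = causal_lm._normalize_generate_inputs(dataset)
# The dataset passes through untouched, and is_scalar is False.
```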

def _normalize_generate_outputs(
self,
outputs,
input_is_scalar,
):
"""Normalize user output from the generate function.

This function converts all outputs to numpy (for integer output) or
Python strings (for string output). If a batch dimension was added to
the input, it is removed from the output (so `generate()` can be string
in, string out).
"""

def normalize(x):
x = tf.concat(x, axis=0)
x = tf.squeeze(x, 0) if input_is_scalar else x
is_string = x.dtype == tf.string
# Convert outputs to a friendly pythonic type. For numerical outputs
# that is numpy, for string outputs that is `list` and `str`.
return tensor_to_string_list(x) if is_string else x.numpy()

if isinstance(outputs[0], dict):
return {
"token_ids": normalize([x["token_ids"] for x in outputs]),
"padding_mask": normalize([x["padding_mask"] for x in outputs]),
}
return normalize([x for x in outputs])
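The conversion rule can be sketched standalone (toy values, assuming only the `tensor_to_string_list` utility already imported at the top of this file):

```python
import tensorflow as tf

from keras_nlp.utils.tf_utils import tensor_to_string_list

def normalize(batches, input_is_scalar=False):
    # Concatenate per-batch results, drop the batch dim for scalar
    # input, and convert to numpy arrays or Python strings.
    x = tf.concat(batches, axis=0)
    x = tf.squeeze(x, 0) if input_is_scalar else x
    return tensor_to_string_list(x) if x.dtype == tf.string else x.numpy()

print(normalize([tf.constant(["a"]), tf.constant(["b"])]))  # ['a', 'b']
print(normalize([tf.constant([[1, 2]])]))  # array([[1, 2]], dtype=int32)
```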

def generate(
self,
inputs,
@@ -397,65 +463,43 @@ def generate(
A string or string list if `preprocessor` is set, and an integer
tensor of token IDs if `preprocessor is None`.
"""
input_is_scalar = False

# Setup our three main passes.
# 1. Optionally preprocessing strings to dense integer tensors.
# 2. Generate new tokens via a compiled function on dense tensors.
# 3. Optionally postprocess dense integer tensors back to string.
generate_function = self.make_generate_function()
end_token_id = None
if self.preprocessor is not None:
end_token_id = self.preprocessor.tokenizer.end_token_id

def preprocess(x):
return self.preprocessor(
x,
sequence_length=max_length,
return_labels=False,
# We do not append an end token by default during
# generation, as generating directly in the same sequence is
# the most common workflow. If an end token directly after
# a prompt is desired, it can be added to the prompt string.
add_end_token=False,
)

if not isinstance(inputs, tf.data.Dataset):
inputs = tf.convert_to_tensor(inputs)
input_is_scalar = inputs.shape.rank == 0
inputs = inputs[tf.newaxis] if input_is_scalar else inputs
# Wrap a list to avoid the overhead of converting to dataset.
inputs = [preprocess(inputs)]
else:
def preprocess(x):
return self.preprocessor.generate_preprocess(
x, sequence_length=max_length
)

def generate(x):
return generate_function(x, end_token_id=end_token_id)

def postprocess(x):
return self.preprocessor.generate_postprocess(x)

# Normalize inputs, apply our three passes, and normalize outputs.
inputs, input_is_scalar = self._normalize_generate_inputs(inputs)

if self.preprocessor is not None:
if isinstance(inputs, tf.data.Dataset):
inputs = inputs.map(preprocess, tf.data.AUTOTUNE)
inputs = inputs.prefetch(tf.data.AUTOTUNE)
else:
if not isinstance(inputs, tf.data.Dataset):
# Wrap a list to avoid the overhead of converting to dataset.
inputs = [inputs]
else:
# Fast path for non-dataset, single-batch input.
inputs = [preprocess(x) for x in inputs]

generate_function = self.make_generate_function()
outputs = []
for batch in inputs:
token_ids, padding_mask = batch["token_ids"], batch["padding_mask"]
# If `preprocessor` is attached, we can stop after end_token_id.
end_token_id = None
if self.preprocessor is not None:
end_token_id = self.preprocessor.tokenizer.end_token_id
# Run the compiled generate function.
output = generate_function(token_ids, padding_mask, end_token_id)

if self.preprocessor is not None:
# Truncate to ragged by removing tokens after the first
# generated `end_token_id`.
output = truncate_at_token(output, end_token_id, padding_mask)
# Strip start token if added.
if self.preprocessor.add_start_token:
output = output[:, 1:]
# Detokenize.
output = self.preprocessor.tokenizer.detokenize(output)
outputs.append(output)

outputs = tf.concat(outputs, axis=0)
outputs = tf.squeeze(outputs, 0) if input_is_scalar else outputs
# Convert outputs to a friendly pythonic type. For numerical outputs
# that is numpy, for string outputs that is `list` and `str`.
if outputs.dtype == tf.string:
return tensor_to_string_list(outputs)
return outputs.numpy()
outputs = [generate(x) for x in inputs]

if self.preprocessor is not None:
outputs = [postprocess(x) for x in outputs]

return self._normalize_generate_outputs(outputs, input_is_scalar)
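End to end, the refactored flow might be exercised like this (a hedged sketch; the preset name and `max_length` argument are assumed from the surrounding code):

```python
import keras_nlp

causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

# String in, string out: preprocess, the compiled generate function,
# and postprocess run as the three passes wired up above.
print(causal_lm.generate("The quick brown fox", max_length=40))

# A list of prompts returns a list of generated strings.
print(causal_lm.generate(["Hello", "Goodbye"], max_length=40))
```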

@classmethod
def create_layout_map(cls, mesh):
106 changes: 73 additions & 33 deletions keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py
@@ -14,23 +14,32 @@

"""GPT2 Causal LM preprocessor layer."""

import tensorflow as tf
from absl import logging

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2Preprocessor
from keras_nlp.utils.keras_utils import (
convert_inputs_to_list_of_tensor_segments,
)
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight


@keras_nlp_export("keras_nlp.models.GPT2CausalLMPreprocessor")
class GPT2CausalLMPreprocessor(GPT2Preprocessor):
"""GPT2 Causal LM preprocessor.

This preprocessing layer is primarily meant to be used with
This preprocessing layer is meant for use with
`keras_nlp.models.GPT2CausalLM`. By default, it will take in batches of
strings, and return outputs in a `(x, y, sample_weight)` format, where the
`y` label is the next token id in the `x` sequence. For use with generation,
pass `return_labels=False`, in which case the output will simply be the
encoded string features.
`y` label is the next token id in the `x` sequence.

For use with generation, the layer also exposes two methods
`generate_preprocess()` and `generate_postprocess()`. When this preprocessor
is attached to a `keras_nlp.models.GPT2CausalLM` instance, these methods
will be called implicitly in `generate()`. They can also be called
standalone (e.g. to precompute preprocessing inputs for generation in a
separate process).

Args:
tokenizer: A `keras_nlp.models.GPT2Tokenizer` instance.
@@ -47,12 +56,6 @@ class GPT2CausalLMPreprocessor(GPT2Preprocessor):
generates label weights.
sequence_length: Pass to override the configured `sequence_length` of
the layer.
add_start_token: Pass to override the configured value of
`add_start_token` on the layer.
add_end_token: Pass to override the configured value of
`add_end_token` on the layer.
return_labels: If `True`, the output `"token_ids"` will be offset by one
and returned as labels. If `False` only features will be returned.

Examples:
```python
@@ -95,9 +98,6 @@ def call(
y=None,
sample_weight=None,
sequence_length=None,
add_start_token=None,
add_end_token=None,
return_labels=True,
):
if y is not None or sample_weight is not None:
logging.warning(
@@ -106,25 +106,65 @@
"or `sample_weight`. Your `y` and `sample_weight` will be "
"ignored."
)
if return_labels:
# Tokenize with one extra token to account for the truncation below.
sequence_length = (sequence_length or self.sequence_length) + 1
x = super().call(
sequence_length = sequence_length or self.sequence_length

x = convert_inputs_to_list_of_tensor_segments(x)[0]
x = self.tokenizer(x)
# Pad with one extra token to account for the truncation below.
token_ids, padding_mask = self.packer(
x,
sequence_length=sequence_length,
add_start_token=add_start_token,
add_end_token=add_end_token,
sequence_length=sequence_length + 1,
add_start_value=self.add_start_token,
add_end_value=self.add_end_token,
)
if return_labels:
token_ids, padding_mask = x["token_ids"], x["padding_mask"]
# The last token does not have a next token, so we truncate it out.
x = {
"token_ids": token_ids[..., :-1],
"padding_mask": padding_mask[..., :-1],
}
# Target `y` will be the next token.
y = token_ids[..., 1:]
sample_weight = padding_mask[..., 1:]
return pack_x_y_sample_weight(x, y, sample_weight)
else:
return x
# The last token does not have a next token, so we truncate it out.
x = {
"token_ids": token_ids[..., :-1],
"padding_mask": padding_mask[..., :-1],
}
# Target `y` will be the next token.
y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:]
return pack_x_y_sample_weight(x, y, sample_weight)
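The shift-by-one labeling amounts to the following (toy ids, not from the PR):

```python
import tensorflow as tf

token_ids = tf.constant([[10, 11, 12, 13]])
x = token_ids[..., :-1]  # [[10, 11, 12]] current tokens (features)
y = token_ids[..., 1:]   # [[11, 12, 13]] next-token targets (labels)
```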

def generate_preprocess(
self,
x,
sequence_length=None,
):
"""Covert strings to integer token input for generation.

Similar to calling the layer for training, this method takes in strings
or tensor strings, tokenizes and packs the input, and computes a padding
mask that is True wherever the sequence holds a real token rather than
padding.

Unlike calling the layer for training, this method does not compute
labels and will never append a `tokenizer.end_token_id` to the end of
the sequence (as generation is expected to continue at the end of the
inputted prompt).
"""
x = convert_inputs_to_list_of_tensor_segments(x)[0]
x = self.tokenizer(x)
token_ids, padding_mask = self.packer(
x, sequence_length=sequence_length, add_end_value=False
)
return {
"token_ids": token_ids,
"padding_mask": padding_mask,
}
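A hypothetical call (preset name and `sequence_length` override assumed) would produce the dense feature dict that the model's compiled generate function expects:

```python
import keras_nlp

preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
    "gpt2_base_en", sequence_length=8
)
features = preprocessor.generate_preprocess(["a quick test"])
# features["token_ids"] has shape (1, 8): a start marker (if
# configured), the prompt tokens, and trailing padding, but no end
# token.
# features["padding_mask"] is True only at real token positions.
```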

def generate_postprocess(
self,
x,
):
"""Covert integer token output to strings for generation.

This method reverses `generate_preprocess()` by first removing all
padding and start/end tokens, and then converting the integer sequence
back to a string.
"""
token_ids, padding_mask = x["token_ids"], x["padding_mask"]
# Strip any special tokens during detokenization (e.g. the start and
# end markers). In the future we could make this configurable.
padding_mask = padding_mask & (token_ids != self.tokenizer.end_token_id)
token_ids = tf.ragged.boolean_mask(token_ids, padding_mask)
return self.tokenizer.detokenize(token_ids)
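Continuing the hypothetical sketch above, `generate_postprocess()` inverts it:

```python
# Padding and start/end markers are stripped before detokenizing, so
# the round trip should recover the prompt text (GPT-2 BPE is
# lossless on plain ASCII input).
print(preprocessor.generate_postprocess(features))
# -> a string tensor holding something like [b'a quick test']
```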