Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 48 additions & 23 deletions keras_nlp/layers/start_end_packer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,17 @@ class StartEndPacker(keras.layers.Layer):
pad_value: int/str. The ID or token that is to be placed into the
unused positions after the last segment in the sequence. If None,
0 or "" will be added depending on the dtype of the input tensor.
return_padding_mask: bool. Whether to return a boolean padding mask of
all locations that are filled in with the `pad_value`.

Call arguments:
inputs: A `tf.Tensor`, `tf.RaggedTensor`, or list of python strings.
sequence_length: Pass to override the configured `sequence_length` of
the layer.
add_start_value: Pass `False` to not append a start value for this
input.
add_end_value: Pass `False` to not append an end value for this
input.

Examples:

Expand Down Expand Up @@ -94,6 +105,7 @@ def __init__(
start_value=None,
end_value=None,
pad_value=None,
return_padding_mask=False,
name=None,
**kwargs,
):
Expand All @@ -103,48 +115,60 @@ def __init__(
self.start_value = start_value
self.end_value = end_value
self.pad_value = pad_value
self.return_padding_mask = return_padding_mask

def call(self, inputs):
if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
inputs = tf.convert_to_tensor(inputs)
def call(
self,
inputs,
sequence_length=None,
add_start_value=True,
add_end_value=True,
):
x = inputs # Intermediate result.

if not isinstance(x, (tf.Tensor, tf.RaggedTensor)):
x = tf.convert_to_tensor(x)

input_is_1d = False
if inputs.shape.rank < 1 or inputs.shape.rank > 2:
input_is_1d = x.shape.rank == 1
if x.shape.rank < 1 or x.shape.rank > 2:
raise ValueError(
"Input must either be rank 1 or rank 2. Received input with "
f"rank={inputs.shape.rank}"
f"rank={x.shape.rank}"
)
elif inputs.shape.rank == 1:
input_is_1d = True
if input_is_1d:
# Add a new axis at the beginning.
inputs = tf.expand_dims(inputs, axis=0)
if isinstance(inputs, tf.Tensor):
x = tf.expand_dims(x, axis=0)
if isinstance(x, tf.Tensor):
# Convert to ragged tensor.
inputs = tf.RaggedTensor.from_tensor(inputs)
x = tf.RaggedTensor.from_tensor(x)

batch_size = tf.shape(inputs)[0]
batch_size = tf.shape(x)[0]
sequence_length = sequence_length or self.sequence_length

# Concatenate start and end tokens.
if self.start_value is not None:
if add_start_value and self.start_value is not None:
start_token_id_tensor = tf.fill((batch_size, 1), self.start_value)
inputs = tf.concat([start_token_id_tensor, inputs], axis=-1)
if self.end_value is not None:
x = tf.concat([start_token_id_tensor, x], axis=-1)
if add_end_value and self.end_value is not None:
end_token_id_tensor = tf.fill((batch_size, 1), self.end_value)

# Trim to leave room for end token.
inputs = inputs[..., : self.sequence_length - 1]
inputs = tf.concat([inputs, end_token_id_tensor], axis=-1)
x = x[..., : sequence_length - 1]
x = tf.concat([x, end_token_id_tensor], axis=-1)

# Pad to desired length.
inputs = inputs.to_tensor(
outputs = x.to_tensor(
default_value=self.pad_value,
shape=(batch_size, self.sequence_length),
shape=(batch_size, sequence_length),
)
outputs = tf.squeeze(outputs, axis=0) if input_is_1d else outputs

if input_is_1d:
inputs = tf.squeeze(inputs, axis=0)
if self.return_padding_mask:
mask = tf.ones_like(x, dtype=tf.bool)
mask = mask.to_tensor(shape=(batch_size, sequence_length))
mask = tf.squeeze(mask, axis=0) if input_is_1d else mask
return outputs, mask

return inputs
return outputs

def get_config(self):
config = super().get_config()
Expand All @@ -154,6 +178,7 @@ def get_config(self):
"start_value": self.start_value,
"end_value": self.end_value,
"pad_value": self.pad_value,
"return_padding_mask": self.return_padding_mask,
}
)
return config
8 changes: 8 additions & 0 deletions keras_nlp/layers/start_end_packer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,14 @@ def test_batch(self):
for i in range(output.shape[0]):
self.assertAllEqual(output[i], exp_output[i])

def test_call_overrides(self):
x = tf.constant([5, 6, 7])
packer = StartEndPacker(start_value=1, end_value=2, sequence_length=4)
self.assertAllEqual(packer(x), [1, 5, 6, 2])
self.assertAllEqual(packer(x, add_start_value=False), [5, 6, 7, 2])
self.assertAllEqual(packer(x, add_end_value=False), [1, 5, 6, 7])
self.assertAllEqual(packer(x, sequence_length=2), [1, 2])

def test_get_config(self):
start_end_packer = StartEndPacker(
sequence_length=512,
Expand Down
12 changes: 5 additions & 7 deletions keras_nlp/models/gpt2/gpt2_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@ class GPT2Backbone(Backbone):
),
}

# Pretrained GPT-2 decoder
model = GPT2Backbone.from_preset("gpt2_base_en")
output = model(input_data)
# Pretrained GPT-2 decoder.
model = keras_nlp.models.GPT2Backbone.from_preset("gpt2_base_en")
model(input_data)

# Randomly initialized GPT-2 decoder with custom config
# Randomly initialized GPT-2 decoder with custom config.
model = keras_nlp.models.GPT2Backbone(
vocabulary_size=50257,
num_layers=12,
Expand All @@ -86,9 +86,7 @@ class GPT2Backbone(Backbone):
intermediate_dim=3072,
max_sequence_length=1024,
)

# Call the model on the input data.
output = model(input_data)
model(input_data)
```
"""

Expand Down
83 changes: 38 additions & 45 deletions keras_nlp/models/gpt2/gpt2_backbone_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,70 +25,63 @@

class GPT2Test(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
self.model = GPT2Backbone(
vocabulary_size=1000,
self.backbone = GPT2Backbone(
vocabulary_size=10,
num_layers=2,
num_heads=2,
hidden_dim=64,
intermediate_dim=128,
max_sequence_length=128,
hidden_dim=2,
intermediate_dim=4,
max_sequence_length=5,
)
self.batch_size = 8
self.input_batch = {
"token_ids": tf.ones(
(self.batch_size, self.model.max_sequence_length), dtype="int32"
),
"padding_mask": tf.ones(
(self.batch_size, self.model.max_sequence_length), dtype="int32"
),
"token_ids": tf.ones((2, 5), dtype="int32"),
"segment_ids": tf.ones((2, 5), dtype="int32"),
"padding_mask": tf.ones((2, 5), dtype="int32"),
}

self.input_dataset = tf.data.Dataset.from_tensor_slices(
self.input_batch
).batch(2)

def test_valid_call_gpt2(self):
self.model(self.input_batch)
def test_call(self):
self.backbone(self.input_batch)

def test_token_embedding(self):
output = self.backbone.token_embedding(self.input_batch["token_ids"])
self.assertEqual(output.shape, (2, 5, 2))

def test_name(self):
# Check default name passed through
self.assertRegexpMatches(self.model.name, "gpt2_backbone")
self.assertRegexpMatches(self.backbone.name, "gpt2_backbone")

def test_variable_sequence_length_call_gpt2(self):
for seq_length in (25, 50, 75):
def test_variable_sequence_length(self):
for seq_length in (2, 3, 4):
input_data = {
"token_ids": tf.ones(
(self.batch_size, seq_length), dtype="int32"
),
"padding_mask": tf.ones(
(self.batch_size, seq_length), dtype="int32"
),
"token_ids": tf.ones((2, seq_length), dtype="int32"),
"padding_mask": tf.ones((2, seq_length), dtype="int32"),
}
self.model(input_data)
self.backbone(input_data)

@parameterized.named_parameters(
("jit_compile_false", False), ("jit_compile_true", True)
)
def test_gpt2_compile(self, jit_compile):
self.model.compile(jit_compile=jit_compile)
self.model.predict(self.input_batch)
def test_predict(self):
self.backbone.predict(self.input_batch)
self.backbone.predict(self.input_dataset)

@parameterized.named_parameters(
("jit_compile_false", False), ("jit_compile_true", True)
)
def test_gpt2_compile_batched_ds(self, jit_compile):
self.model.compile(jit_compile=jit_compile)
self.model.predict(self.input_dataset)
def test_serialization(self):
new_backbone = keras.utils.deserialize_keras_object(
keras.utils.serialize_keras_object(self.backbone)
)
self.assertEqual(new_backbone.get_config(), self.backbone.get_config())

@parameterized.named_parameters(
("tf_format", "tf", "model"),
("keras_format", "keras_v3", "model.keras"),
)
@pytest.mark.large
def test_saved_model(self, save_format, filename):
model_output = self.model(self.input_batch)
model_output = self.backbone(self.input_batch)
path = os.path.join(self.get_temp_dir(), filename)
# Don't save traces in the tf format, we check compilation elsewhere.
kwargs = {"save_traces": False} if save_format == "tf" else {}
self.model.save(path, save_format=save_format, **kwargs)
self.backbone.save(path, save_format=save_format, **kwargs)
restored_model = keras.models.load_model(path)

# Check we got the real object back.
Expand All @@ -105,16 +98,16 @@ class GPT2BackboneTPUTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
with self.tpu_strategy.scope():
self.model = GPT2Backbone(
vocabulary_size=1000,
vocabulary_size=10,
num_layers=2,
num_heads=2,
hidden_dim=64,
intermediate_dim=128,
max_sequence_length=128,
hidden_dim=2,
intermediate_dim=4,
max_sequence_length=5,
)
self.input_batch = {
"token_ids": tf.ones((8, 128), dtype="int32"),
"padding_mask": tf.ones((8, 128), dtype="int32"),
"token_ids": tf.ones((2, 5), dtype="int32"),
"padding_mask": tf.ones((2, 5), dtype="int32"),
}
self.input_dataset = tf.data.Dataset.from_tensor_slices(
self.input_batch
Expand Down
Loading