New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Random Deletion Layer #214
Changes from 17 commits
bf3219f
66c0de7
d2508dc
9e2e243
b48dc5a
02ce27d
f137256
c7dcb8a
515a1d3
6acfb62
68fcae0
058d572
7a57339
9cf9a2a
ce10e2d
a1ab88e
17e2365
b5cfe45
839a770
ec2d4ed
dbdd690
fd856ad
292353b
9c447dc
35cf31a
b26dd24
983bfb3
4813f1f
86bafbe
46eda29
906614a
20df2b2
0c8fdc5
3d87d14
5ee888a
a77880e
9c1904c
b87394d
81ba7bf
0da28dc
d91a1bb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,193 @@ | ||
# Copyright 2022 The KerasNLP Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import tensorflow as tf | ||
from tensorflow import keras | ||
|
||
|
||
class RandomWordDeletion(keras.layers.Layer): | ||
"""Augments input by randomly deleting words. | ||
|
||
The layer works by splitting the words using `tf.strings.split` computes | ||
the indices to keep randomly and masks out the ones to be deleted which are | ||
then removed before returning and the remaining tokens are joined back. | ||
|
||
Args: | ||
probability: probability of a word being chosen for deletion | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: argument type, and period at the end. Also maybe call this |
||
max_deletions: The maximum number of words to delete | ||
|
||
Examples: | ||
|
||
Basic usage. | ||
>>> tf.random.get_global_generator().reset_from_seed(30) | ||
>>> tf.random.set_seed(30) | ||
>>> augmenter = keras_nlp.layers.RandomWordDeletion( | ||
... probability = 0.7, | ||
... max_deletions = 2, | ||
... ) | ||
>>> augmenter(["I like to fly kites, do you?", | ||
... "Can we go fly some kites later?"]) | ||
<tf.Tensor: shape=(2,), dtype=string, numpy=array([b'I fly kites, do you?', | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The output becomes a byte string. Is it still a valid input to our tokenizers? |
||
b'Can we fly kites later?'], dtype=object)> | ||
|
||
Augment first, then batch the dataset. | ||
>>> tf.random.get_global_generator().reset_from_seed(30) | ||
>>> tf.random.set_seed(30) | ||
>>> inputs = ["I like to fly kites, do you?", | ||
... "Can we go fly some kites later?"] | ||
>>> augmenter = keras_nlp.layers.RandomWordDeletion( | ||
... probability = 0.6, | ||
... max_deletions = 3, | ||
... ) | ||
>>> ds = tf.data.Dataset.from_tensor_slices(inputs) | ||
>>> ds = ds.map(augmenter) | ||
>>> ds = ds.apply(tf.data.experimental.dense_to_ragged_batch(3)) | ||
>>> ds.take(1).get_single_element() | ||
<tf.Tensor: shape=(2,), dtype=string, numpy=array([b'fly kites, do you?', | ||
b'we go some kites'], dtype=object)> | ||
|
||
Batch the inputs and then Augment. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: Augment => augment. |
||
>>> tf.random.get_global_generator().reset_from_seed(30) | ||
>>> tf.random.set_seed(30) | ||
>>> inputs = ["I like to fly kites, do you?", | ||
... "Can we go fly some kites later?"] | ||
>>> augmenter = keras_nlp.layers.RandomWordDeletion( | ||
... probability = 0.6, | ||
... max_deletions = 3, | ||
... ) | ||
>>> ds = tf.data.Dataset.from_tensor_slices(inputs) | ||
>>> ds = ds.batch(3).map(augmenter) | ||
>>> ds.take(1).get_single_element() | ||
<tf.Tensor: shape=(2,), dtype=string, numpy=array([b'fly kites, do you?', | ||
b'we go some kites'], dtype=object)> | ||
""" | ||
|
||
def __init__(self, probability, max_deletions, name = None, **kwargs): | ||
# Check dtype and provide a default. | ||
if "dtype" not in kwargs or kwargs["dtype"] is None: | ||
kwargs["dtype"] = tf.int32 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. All examples above have string inputs and string outputs, but here we allow integer inputs as well (we default to tf.int32). If integer inputs are expected, let's add one example to showcase, otherwise let's modify this type check part. |
||
else: | ||
dtype = tf.dtypes.as_dtype(kwargs["dtype"]) | ||
if not dtype.is_integer and dtype != tf.string: | ||
raise ValueError( | ||
"Output dtype must be an integer type or a string. " | ||
f"Received: dtype={dtype}" | ||
) | ||
|
||
super().__init__(name=name, **kwargs) | ||
self.probability = probability | ||
self.max_deletions = max_deletions | ||
|
||
def call(self, inputs): | ||
"""Augments input by randomly deleting words. | ||
Args: | ||
inputs: A tensor or nested tensor of strings to augment. | ||
Returns: | ||
A tensor or nested tensor of augmented strings. | ||
""" | ||
|
||
def validate_and_fix_rank(inputs): | ||
if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)): | ||
inputs = tf.convert_to_tensor(inputs) | ||
inputs = tf.cast(inputs, tf.string) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we allow integer tensors? If not, let's error out. |
||
if inputs.shape.rank == 0 or inputs.shape.rank == 1: | ||
return inputs | ||
elif inputs.shape.rank == 2: | ||
if inputs.shape[1] != 1: | ||
raise ValueError( | ||
f"input must be of shape `[batch_size, 1]`. " | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am a little confused about this - if the input is tokenized input, such as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @chenmoneygithub I think the confusion you are having is whether we allow input that is pre-split or if splitting happens on the layer. If we split inside the layer, the input you are describing is invalid. And I think it is totally valid to support only shapes If we split outside the layer, the input shapes we should support will change. Essentially we should support ragged with shape |
||
f"Found shape: {inputs.shape}" | ||
) | ||
else: | ||
return tf.squeeze(inputs, axis=1) | ||
else: | ||
raise ValueError( | ||
f"input must be of rank 0 (scalar input), 1 or 2. " | ||
f"Found rank: {inputs.shape.rank}" | ||
) | ||
|
||
isString = False | ||
if isinstance(inputs, str): | ||
inputs = [inputs] | ||
isString = True | ||
|
||
inputs = validate_and_fix_rank(inputs) | ||
|
||
scalar_input = inputs.shape.rank == 0 | ||
if scalar_input: | ||
inputs = tf.expand_dims(inputs, 0) | ||
|
||
ragged_words = tf.strings.split(inputs) | ||
|
||
positions_flat = tf.range(tf.size(ragged_words.flat_values)) | ||
positions = ragged_words.with_flat_values(positions_flat) | ||
|
||
# Figure out how many we are going to select. | ||
word_counts = tf.cast(ragged_words.row_lengths(), "float32") | ||
num_to_select = tf.random.stateless_binomial( | ||
shape=tf.shape(word_counts), | ||
seed=tf.random.get_global_generator().make_seeds()[:, 0], | ||
counts=word_counts, | ||
probs=self.probability, | ||
) | ||
num_to_select = tf.math.minimum(num_to_select, self.max_deletions) | ||
num_to_select = tf.cast(num_to_select, "int64") | ||
|
||
# Shuffle and trim to items that are going to be selected. | ||
def _shuffle_and_trim(x): | ||
positions, top_n = x | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For batched input, is this |
||
shuffled = tf.random.shuffle(positions) | ||
return shuffled[:top_n] | ||
|
||
selected_for_mask = tf.map_fn( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am a little concerned about the current implementation - the current logic for picking the indices to delete looks complex to me, and you are using tf.map_fn to handle each element separately instead of processing the whole batch at once, which could get slow. Rethinking how we can pick the indices, can we do the following? Let's assume we have already done the split, so the input is a 2D RaggedTensor.
In this way, we can keep the batch dimension throughout the computation. But I am not 100% sure if this works, WDYT? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's chat about this tomorrow! Overall I think we need to make sure this can trace, but don't actually need to care about performance that much. If you are running an augmentation, it's because you have very little data. You don't actually care about speed. I think the reason this is so complex (here and in tf.text) is that there is no good way to uniformly sample from a pool of candidate indices with a max cap without using a |
||
_shuffle_and_trim, | ||
(positions, num_to_select), | ||
fn_output_signature=tf.RaggedTensorSpec( | ||
ragged_rank=positions.ragged_rank - 1, dtype=positions.dtype | ||
), | ||
) | ||
selected_for_mask.flat_values.set_shape([None]) | ||
|
||
# Construct the mask which is a boolean RT | ||
# Scatter 0's to positions that have been selector for deletion. | ||
update_values = tf.zeros_like( | ||
selected_for_mask.flat_values, "int32" | ||
) | ||
update_indices = selected_for_mask.flat_values | ||
update_indices = tf.expand_dims(update_indices, -1) | ||
update_indices = tf.cast(update_indices, "int32") | ||
mask_flat = tf.ones_like(ragged_words.flat_values, dtype="int32") | ||
mask_flat = tf.tensor_scatter_nd_update( | ||
mask_flat, update_indices, update_values | ||
) | ||
mask = tf.cast(ragged_words.with_flat_values(mask_flat), "bool") | ||
|
||
ragged_words = tf.ragged.boolean_mask(ragged_words, mask) | ||
deleted = tf.strings.reduce_join( | ||
ragged_words, axis=-1, separator=" " | ||
) | ||
if scalar_input: | ||
deleted = tf.squeeze(deleted, 0) | ||
if isString: | ||
deleted = deleted[0] | ||
return deleted | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update( | ||
{ | ||
"probability": self.probability, | ||
"max_deletions": self.max_deletions, | ||
} | ||
) | ||
return config |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# Copyright 2022 The KerasNLP Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Tests for Random Word Deletion Layer.""" | ||
|
||
import tensorflow as tf | ||
|
||
from keras_nlp.layers import random_word_deletion | ||
|
||
|
||
class RandomDeletionTest(tf.test.TestCase): | ||
def test_shape_with_scalar(self): | ||
augmenter = random_word_deletion.RandomWordDeletion( | ||
probability=0.5, max_deletions=3 | ||
) | ||
input = ["Running Around"] | ||
output = augmenter(input) | ||
self.assertAllEqual(output.shape, tf.convert_to_tensor(input).shape) | ||
|
||
def test_get_config_and_from_config(self): | ||
|
||
augmenter = random_word_deletion.RandomWordDeletion( | ||
probability=0.5, max_deletions=3 | ||
) | ||
|
||
expected_config_subset = {"probability": 0.5, "max_deletions": 3} | ||
|
||
config = augmenter.get_config() | ||
|
||
self.assertEqual(config, {**config, **expected_config_subset}) | ||
|
||
restored_augmenter = ( | ||
random_word_deletion.RandomWordDeletion.from_config( | ||
config, | ||
) | ||
) | ||
|
||
self.assertEqual( | ||
restored_augmenter.get_config(), | ||
{**config, **expected_config_subset}, | ||
) | ||
|
||
def test_augment_first_batch_second(self): | ||
tf.random.get_global_generator().reset_from_seed(30) | ||
tf.random.set_seed(30) | ||
augmenter = random_word_deletion.RandomWordDeletion( | ||
probability=0.5, max_deletions=3 | ||
) | ||
|
||
ds = tf.data.Dataset.from_tensor_slices( | ||
["samurai or ninja", "keras is good", "tensorflow is a library"] | ||
) | ||
ds = ds.map(augmenter) | ||
ds = ds.apply(tf.data.experimental.dense_to_ragged_batch(3)) | ||
output = ds.take(1).get_single_element() | ||
|
||
exp_output = [b"samurai", b"is good", b"tensorflow a library"] | ||
for i in range(output.shape[0]): | ||
self.assertAllEqual(output[i], exp_output[i]) | ||
|
||
def test_batch_first_augment_second(self): | ||
tf.random.get_global_generator().reset_from_seed(30) | ||
tf.random.set_seed(30) | ||
augmenter = random_word_deletion.RandomWordDeletion( | ||
probability=0.5, max_deletions=3 | ||
) | ||
|
||
ds = tf.data.Dataset.from_tensor_slices( | ||
["samurai or ninja", "keras is good", "tensorflow is a library"] | ||
) | ||
ds = ds.batch(3).map(augmenter) | ||
output = ds.take(1).get_single_element() | ||
|
||
exp_output = [b"samurai", b"is good", b"tensorflow"] | ||
|
||
for i in range(output.shape[0]): | ||
self.assertAllEqual(output[i], exp_output[i]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This sentence is too long; let's refactor it a bit. One possible way is to break it down into steps: 1) split the words ... 2) randomly select ... 3) mask out ... 4) join