diff --git a/keras_nlp/layers/__init__.py b/keras_nlp/layers/__init__.py
index a4cc25080..df67418a3 100644
--- a/keras_nlp/layers/__init__.py
+++ b/keras_nlp/layers/__init__.py
@@ -17,6 +17,7 @@
 from keras_nlp.layers.mlm_mask_generator import MLMMaskGenerator
 from keras_nlp.layers.multi_segment_packer import MultiSegmentPacker
 from keras_nlp.layers.position_embedding import PositionEmbedding
+from keras_nlp.layers.random_deletion import RandomDeletion
 from keras_nlp.layers.sine_position_encoding import SinePositionEncoding
 from keras_nlp.layers.start_end_packer import StartEndPacker
 from keras_nlp.layers.token_and_position_embedding import (
diff --git a/keras_nlp/layers/random_deletion.py b/keras_nlp/layers/random_deletion.py
new file mode 100644
index 000000000..582f3918b
--- /dev/null
+++ b/keras_nlp/layers/random_deletion.py
@@ -0,0 +1,266 @@
+# Copyright 2022 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import random
+
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.python.ops.ragged import ragged_array_ops
+
+
+class RandomDeletion(keras.layers.Layer):
+    """Augments input by randomly deleting tokens.
+
+    This layer comes in handy when you need to generate new data using deletion
+    augmentation as described in the paper [EDA: Easy Data Augmentation
+    Techniques for Boosting Performance on Text Classification Tasks]
+    (https://arxiv.org/pdf/1901.11196.pdf). The layer expects the inputs to be
+    pretokenized so that each token can be individually treated as a possible
+    candidate for deletion.
+
+    Input should be either a `tf.RaggedTensor` or a dense `tf.Tensor`, and
+    either rank-1 or rank-2.
+
+    Args:
+        rate: The probability of a token being chosen for deletion.
+        max_deletions: The maximum number of tokens to delete.
+        skip_list: A list of token values that should not be considered
+            candidates for deletion.
+        skip_fn: A function that takes as input a scalar tensor token and
+            returns as output a scalar tensor True/False value. A value of
+            True indicates that the token should not be considered a
+            candidate for deletion. This function must be traceable--it
+            should consist of TensorFlow operations.
+        skip_py_fn: A function that takes as input a python token value and
+            returns as output `True` or `False`. A value of True
+            indicates that the token should not be considered a candidate for
+            deletion. Unlike the `skip_fn` argument, this argument need not be
+            traceable--it can be any python function.
+        seed: A seed for the random number generator.
+
+    Examples:
+
+    Word level usage.
+    >>> keras.utils.set_random_seed(1337)
+    >>> inputs = tf.strings.split(["Hey I like", "Keras and Tensorflow"])
+    >>> augmenter = keras_nlp.layers.RandomDeletion(rate=0.4, seed=42)
+    >>> augmented = augmenter(inputs)
+    >>> tf.strings.reduce_join(augmented, separator=" ", axis=-1)
+
+    Character level usage.
+    >>> keras.utils.set_random_seed(1337)
+    >>> inputs = tf.strings.unicode_split(["Hey Dude", "Speed Up"], "UTF-8")
+    >>> augmenter = keras_nlp.layers.RandomDeletion(rate=0.4, seed=42)
+    >>> augmented = augmenter(inputs)
+    >>> tf.strings.reduce_join(augmented, axis=-1)
+
+    Usage with skip_list.
+    >>> keras.utils.set_random_seed(1337)
+    >>> inputs = tf.strings.split(["Hey I like", "Keras and Tensorflow"])
+    >>> augmenter = keras_nlp.layers.RandomDeletion(rate=0.4,
+    ...     skip_list=["Keras", "Tensorflow"], seed=42)
+    >>> augmented = augmenter(inputs)
+    >>> tf.strings.reduce_join(augmented, separator=" ", axis=-1)
+
+    Usage with skip_fn.
+    >>> def skip_fn(word):
+    ...     return tf.strings.regex_full_match(word, r"\\pP")
+    >>> keras.utils.set_random_seed(1337)
+    >>> inputs = tf.strings.split(["Hey I like", "Keras and Tensorflow"])
+    >>> augmenter = keras_nlp.layers.RandomDeletion(rate=0.4,
+    ...     skip_fn=skip_fn, seed=42)
+    >>> augmented = augmenter(inputs)
+    >>> tf.strings.reduce_join(augmented, separator=" ", axis=-1)
+
+    Usage with skip_py_fn.
+    >>> def skip_py_fn(word):
+    ...     return len(word) < 4
+    >>> keras.utils.set_random_seed(1337)
+    >>> inputs = tf.strings.split(["Hey I like", "Keras and Tensorflow"])
+    >>> augmenter = keras_nlp.layers.RandomDeletion(rate=0.4,
+    ...     skip_py_fn=skip_py_fn, seed=42)
+    >>> augmented = augmenter(inputs)
+    >>> tf.strings.reduce_join(augmented, separator=" ", axis=-1)
+    """
+
+    def __init__(
+        self,
+        rate,
+        max_deletions=None,
+        skip_list=None,
+        skip_fn=None,
+        skip_py_fn=None,
+        seed=None,
+        name=None,
+        **kwargs,
+    ):
+        # Check dtype and provide a default.
+        if "dtype" not in kwargs or kwargs["dtype"] is None:
+            kwargs["dtype"] = tf.int32
+        else:
+            dtype = tf.dtypes.as_dtype(kwargs["dtype"])
+            if not dtype.is_integer and dtype != tf.string:
+                raise ValueError(
+                    "Output dtype must be one of `'string'`, `'int32'`, or "
+                    f"`'int64'`. Received: dtype={dtype}"
+                )
+
+        super().__init__(name=name, **kwargs)
+        self.rate = rate
+        self.max_deletions = max_deletions
+        self.seed = random.randint(1, int(1e9)) if seed is None else seed
+        self._generator = tf.random.Generator.from_seed(self.seed)
+        self.skip_list = skip_list
+        self.skip_fn = skip_fn
+        self.skip_py_fn = skip_py_fn
+
+        if self.rate > 1 or self.rate < 0:
+            raise ValueError(
+                "Rate must be between 0 and 1 (both inclusive). "
+                f"Received: rate={rate}"
+            )
+
+        if [self.skip_list, self.skip_fn, self.skip_py_fn].count(None) < 2:
+            raise ValueError(
+                "At most one of `skip_list`, `skip_fn` and `skip_py_fn` can "
+                "be provided."
+            )
+
+        if self.skip_list:
+            self.StaticHashTable = tf.lookup.StaticHashTable(
+                tf.lookup.KeyValueTensorInitializer(
+                    tf.convert_to_tensor(self.skip_list),
+                    tf.convert_to_tensor([True] * len(self.skip_list)),
+                ),
+                default_value=False,
+            )
+
+    def call(self, inputs):
+        if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
+            inputs = tf.convert_to_tensor(inputs)
+
+        input_is_1d = False
+        if inputs.shape.rank < 1 or inputs.shape.rank > 2:
+            raise ValueError(
+                "Input must either be rank 1 or rank 2. Received input with "
+                f"rank={inputs.shape.rank}"
+            )
+        elif inputs.shape.rank == 1:
+            input_is_1d = True
+            # Add a new axis at the beginning.
+            inputs = tf.expand_dims(inputs, axis=0)
+        if isinstance(inputs, tf.Tensor):
+            # Convert to ragged tensor.
+            inputs = tf.RaggedTensor.from_tensor(inputs)
+
+        skip_masks = None
+        if self.skip_list:
+            skip_masks = self.StaticHashTable.lookup(inputs.flat_values)
+        elif self.skip_fn:
+            skip_masks = tf.map_fn(
+                self.skip_fn, inputs.flat_values, fn_output_signature=tf.bool
+            )
+        elif self.skip_py_fn:
+
+            def string_fn(token):
+                return self.skip_py_fn(token.numpy().decode("utf-8"))
+
+            def int_fn(token):
+                return self.skip_py_fn(token.numpy())
+
+            py_fn = string_fn if inputs.dtype == tf.string else int_fn
+
+            skip_masks = tf.map_fn(
+                lambda x: tf.py_function(py_fn, [x], tf.bool),
+                inputs.flat_values,
+                fn_output_signature=tf.bool,
+            )
+
+        positions_flat = tf.range(tf.size(inputs.flat_values))
+        positions = inputs.with_flat_values(positions_flat)
+        if skip_masks is not None:
+            skip_masks = tf.logical_not(skip_masks)
+            skip_masks.set_shape([None])
+            positions = ragged_array_ops.boolean_mask(
+                positions, inputs.with_flat_values(skip_masks)
+            )
+
+        # Figure out how many tokens we are going to select for deletion.
+        token_counts = tf.cast(inputs.row_lengths(), "float32")
+        num_to_select = tf.random.stateless_binomial(
+            shape=tf.shape(token_counts),
+            seed=self._generator.make_seeds()[:, 0],
+            counts=token_counts,
+            probs=self.rate,
+        )
+        if self.max_deletions is not None:
+            num_to_select = tf.math.minimum(num_to_select, self.max_deletions)
+        num_to_select = tf.cast(num_to_select, "int64")
+
+        # Shuffle and trim to items that are going to be selected.
+        def _shuffle_and_trim(x):
+            positions, top_n = x
+            shuffled = tf.random.shuffle(positions, seed=self.seed)
+            return shuffled[:top_n]
+
+        selected_for_mask = tf.map_fn(
+            _shuffle_and_trim,
+            (positions, num_to_select),
+            fn_output_signature=tf.RaggedTensorSpec(
+                ragged_rank=positions.ragged_rank - 1, dtype=positions.dtype
+            ),
+        )
+        selected_for_mask.flat_values.set_shape([None])
+
+        # Construct the mask, which is a boolean ragged tensor.
+        # Scatter 0's to positions that have been selected for deletion.
+        update_values = tf.zeros_like(selected_for_mask.flat_values, "int32")
+        update_indices = selected_for_mask.flat_values
+        update_indices = tf.expand_dims(update_indices, -1)
+        update_indices = tf.cast(update_indices, "int32")
+        mask_flat = tf.ones_like(inputs.flat_values, dtype="int32")
+        mask_flat = tf.tensor_scatter_nd_update(
+            mask_flat, update_indices, update_values
+        )
+        mask = tf.cast(inputs.with_flat_values(mask_flat), "bool")
+
+        inputs = tf.ragged.boolean_mask(inputs, mask)
+
+        if input_is_1d:
+            inputs = tf.squeeze(inputs, axis=0)
+
+        return inputs
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "rate": self.rate,
+                "max_deletions": self.max_deletions,
+                "seed": self.seed,
+                "skip_list": self.skip_list,
+                "skip_fn": self.skip_fn,
+                "skip_py_fn": self.skip_py_fn,
+            }
+        )
+        return config
diff --git a/keras_nlp/layers/random_deletion_test.py b/keras_nlp/layers/random_deletion_test.py
new file mode 100644
index 000000000..edacd3f35
--- /dev/null
+++ b/keras_nlp/layers/random_deletion_test.py
@@ -0,0 +1,226 @@
+# Copyright 2022 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for RandomDeletion Layer."""
+
+import tensorflow as tf
+from tensorflow import keras
+
+from keras_nlp.layers import random_deletion
+from keras_nlp.tokenizers import UnicodeCodepointTokenizer
+
+
+class RandomDeletionTest(tf.test.TestCase):
+    def test_shape_and_output_from_word_deletion(self):
+        keras.utils.set_random_seed(1337)
+        inputs = ["Hey I like", "Keras and Tensorflow"]
+        split = tf.strings.split(inputs)
+        augmenter = random_deletion.RandomDeletion(
+            rate=0.4, max_deletions=1, seed=42
+        )
+        augmented = augmenter(split)
+        output = tf.strings.reduce_join(augmented, separator=" ", axis=-1)
+        self.assertAllEqual(output.shape, tf.convert_to_tensor(inputs).shape)
+        exp_output = [b"I like", b"and Tensorflow"]
+        for i in range(output.shape[0]):
+            self.assertAllEqual(output[i], exp_output[i])
+
+    def test_shape_and_output_from_character_deletion(self):
+        keras.utils.set_random_seed(1337)
+        inputs = ["Hey I like", "Keras and Tensorflow"]
+        split = tf.strings.unicode_split(inputs, "UTF-8")
+        augmenter = random_deletion.RandomDeletion(
+            rate=0.4, max_deletions=1, seed=42
+        )
+        augmented = augmenter(split)
+        output = tf.strings.reduce_join(augmented, axis=-1)
+        self.assertAllEqual(output.shape, tf.convert_to_tensor(inputs).shape)
+        exp_output = [b"Hey I lie", b"Keras and Tensoflow"]
+        for i in range(output.shape[0]):
+            self.assertAllEqual(output[i], exp_output[i])
+
+    def test_with_integer_tokens(self):
+        keras.utils.set_random_seed(1337)
+        inputs = ["Hey I like", "Keras and Tensorflow"]
+        tokenizer = UnicodeCodepointTokenizer(lowercase=False)
+        tokenized = tokenizer.tokenize(inputs)
+        augmenter = random_deletion.RandomDeletion(
+            rate=0.4, max_deletions=4, seed=42
+        )
+        augmented = augmenter(tokenized)
+        output = tokenizer.detokenize(augmented)
+        self.assertAllEqual(output.shape, tf.convert_to_tensor(inputs).shape)
+        exp_output = [b"Hey Ile", b"Keas and ensoflw"]
+        for i in range(output.shape[0]):
+            self.assertAllEqual(output[i], exp_output[i])
+
+    def test_skip_options(self):
+        keras.utils.set_random_seed(1337)
+        augmenter = random_deletion.RandomDeletion(
+            rate=0.4, max_deletions=1, seed=42, skip_list=["Tensorflow", "like"]
+        )
+        inputs = ["Hey I like", "Keras and Tensorflow"]
+        split = tf.strings.split(inputs)
+        augmented = augmenter(split)
+        output = tf.strings.reduce_join(augmented, separator=" ", axis=-1)
+        self.assertAllEqual(output.shape, tf.convert_to_tensor(inputs).shape)
+        exp_output = [b"I like", b"and Tensorflow"]
+        for i in range(output.shape[0]):
+            self.assertAllEqual(output[i], exp_output[i])
+
+        def skip_fn(word):
+            if word == "Tensorflow" or word == "like":
+                return True
+            return False
+
+        augmenter = random_deletion.RandomDeletion(
+            rate=0.4, max_deletions=1, seed=42, skip_fn=skip_fn
+        )
+        augmented = augmenter(split)
+        output = tf.strings.reduce_join(augmented, separator=" ", axis=-1)
+        self.assertAllEqual(output.shape, tf.convert_to_tensor(inputs).shape)
+        exp_output = [b"Hey like", b"Keras Tensorflow"]
+        for i in range(output.shape[0]):
+            self.assertAllEqual(output[i], exp_output[i])
+
+        def skip_py_fn(word):
+            if word == "Tensorflow" or word == "like":
+                return True
+            return False
+
+        augmenter = random_deletion.RandomDeletion(
+            rate=0.4, max_deletions=1, seed=42, skip_py_fn=skip_py_fn
+        )
+        augmented = augmenter(split)
+        output = tf.strings.reduce_join(augmented, separator=" ", axis=-1)
+        self.assertAllEqual(output.shape, tf.convert_to_tensor(inputs).shape)
+        exp_output = [b"Hey like", b"Keras Tensorflow"]
+        for i in range(output.shape[0]):
+            self.assertAllEqual(output[i], exp_output[i])
+
+    def test_get_config_and_from_config(self):
+        augmenter = random_deletion.RandomDeletion(
+            rate=0.4, max_deletions=1, seed=42
+        )
+
+        expected_config_subset = {"max_deletions": 1, "rate": 0.4, "seed": 42}
+
+        config = augmenter.get_config()
+
+        self.assertEqual(config, {**config, **expected_config_subset})
+
+        restored_augmenter = random_deletion.RandomDeletion.from_config(
+            config,
+        )
+
+        self.assertEqual(
+            restored_augmenter.get_config(),
+            {**config, **expected_config_subset},
+        )
+
+    def test_augment_first_batch_second(self):
+        keras.utils.set_random_seed(1337)
+        augmenter = random_deletion.RandomDeletion(
+            rate=0.4, max_deletions=1, seed=42
+        )
+        inputs = ["Hey I like", "Keras and Tensorflow"]
+        split = tf.strings.split(inputs)
+        ds = tf.data.Dataset.from_tensor_slices(split)
+        ds = ds.map(augmenter)
+        ds = ds.apply(tf.data.experimental.dense_to_ragged_batch(2))
+        output = ds.take(1).get_single_element()
+
+        exp_output = [[b"I", b"like"], [b"Keras", b"and", b"Tensorflow"]]
+        for i in range(output.shape[0]):
+            self.assertAllEqual(output[i], exp_output[i])
+
+        def skip_fn(word):
+            return tf.strings.regex_full_match(word, r"\pP")
+
+        def skip_py_fn(word):
+            return len(word) < 4
+
+        augmenter = random_deletion.RandomDeletion(
+            rate=0.8, max_deletions=1, seed=42, skip_fn=skip_fn
+        )
+        ds = tf.data.Dataset.from_tensor_slices(split)
+        ds = ds.map(augmenter)
+        ds = ds.apply(tf.data.experimental.dense_to_ragged_batch(2))
+        output = ds.take(1).get_single_element()
+        exp_output = [[b"I", b"like"], [b"and", b"Tensorflow"]]
+        for i in range(output.shape[0]):
+            self.assertAllEqual(output[i], exp_output[i])
+
+        augmenter = random_deletion.RandomDeletion(
+            rate=0.8, max_deletions=1, seed=42, skip_py_fn=skip_py_fn
+        )
+        ds = tf.data.Dataset.from_tensor_slices(split)
+        ds = ds.map(augmenter)
+        ds = ds.apply(tf.data.experimental.dense_to_ragged_batch(2))
+        output = ds.take(1).get_single_element()
+        exp_output = [[b"Hey", b"I"], [b"and", b"Tensorflow"]]
+        for i in range(output.shape[0]):
+            self.assertAllEqual(output[i], exp_output[i])
+
+    def test_batch_first_augment_second(self):
+        keras.utils.set_random_seed(1337)
+        augmenter = random_deletion.RandomDeletion(
+            rate=0.4, max_deletions=1, seed=42
+        )
+        inputs = ["Hey I like", "Keras and Tensorflow"]
+        split = tf.strings.split(inputs)
+        ds = tf.data.Dataset.from_tensor_slices(split)
+        ds = ds.batch(5).map(augmenter)
+        output = ds.take(1).get_single_element()
+
+        exp_output = [[b"I", b"like"], [b"and", b"Tensorflow"]]
+        for i in range(output.shape[0]):
+            self.assertAllEqual(output[i], exp_output[i])
+
+        def skip_fn(word):
+            return tf.strings.regex_full_match(word, r"\pP")
+
+        def skip_py_fn(word):
+            return len(word) < 4
+
+        augmenter = random_deletion.RandomDeletion(
+            rate=0.8, max_deletions=1, seed=42, skip_fn=skip_fn
+        )
+        ds = tf.data.Dataset.from_tensor_slices(split)
+        ds = ds.batch(5).map(augmenter)
+        output = ds.take(1).get_single_element()
+        exp_output = [[b"I", b"like"], [b"and", b"Tensorflow"]]
+        for i in range(output.shape[0]):
+            self.assertAllEqual(output[i], exp_output[i])
+
+        augmenter = random_deletion.RandomDeletion(
+            rate=0.8, max_deletions=1, seed=42, skip_py_fn=skip_py_fn
+        )
+        ds = tf.data.Dataset.from_tensor_slices(split)
+        ds = ds.batch(5).map(augmenter)
+        output = ds.take(1).get_single_element()
+        exp_output = [[b"Hey", b"I"], [b"and", b"Tensorflow"]]
+        for i in range(output.shape[0]):
+            self.assertAllEqual(output[i], exp_output[i])
+
+    def test_functional_model(self):
+        keras.utils.set_random_seed(1337)
+        input_data = tf.constant(["Hey I like", "Keras and Tensorflow"])
+        augmenter = random_deletion.RandomDeletion(
+            rate=0.4, max_deletions=1, seed=42
+        )
+        inputs = keras.Input(dtype="string", shape=())
+        outputs = augmenter(tf.strings.split(inputs))
+        model = keras.Model(inputs, outputs)
+        model_output = model(input_data)
+        self.assertAllEqual(
+            model_output, [[b"I", b"like"], [b"and", b"Tensorflow"]]
+        )
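
Reviewer note: below is a minimal end-to-end sketch of how the new layer is expected to be used inside a tf.data preprocessing pipeline, mirroring the patterns exercised by the tests above. It is not part of the patch; the example sentences, rate, and max_deletions values are illustrative assumptions.

import tensorflow as tf
from tensorflow import keras
import keras_nlp

keras.utils.set_random_seed(1337)

# Raw sentences; each example is split into words so that every word
# becomes an independent candidate for deletion.
sentences = tf.constant(["Hey I like", "Keras and Tensorflow"])
augmenter = keras_nlp.layers.RandomDeletion(rate=0.4, max_deletions=1, seed=42)

ds = tf.data.Dataset.from_tensor_slices(sentences)
ds = ds.map(tf.strings.split)  # scalar string -> rank-1 token tensor
ds = ds.map(augmenter)  # randomly delete up to one token per example
ds = ds.map(
    lambda tokens: tf.strings.reduce_join(tokens, separator=" ", axis=-1)
)

for example in ds:
    print(example.numpy())

Mapping the augmenter per example (before batching) exercises the rank-1 code path; batching first and then mapping, as in test_batch_first_augment_second, exercises the rank-2 ragged path instead.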