Random Deletion Layer #214
Merged
+493
−0
Changes from 36 commits
Commits
41 commits
bf3219f
Random Deletion Working
aflah02 66c0de7
Added to init
aflah02 d2508dc
WOrking
aflah02 9e2e243
Merge branch 'keras-team:master' into RandomDeletionLayer
aflah02 b48dc5a
Working
aflah02 02ce27d
Current Status
aflah02 f137256
Working Layer More Tests to be Added
aflah02 c7dcb8a
Fixed Scalar Case
aflah02 515a1d3
Added Comments
aflah02 6acfb62
Minor Fixes
aflah02 68fcae0
Major Refactors and Fixes, ToDo - Docs, Tests
aflah02 058d572
Fixed Shape Issues for Scalar Lists
aflah02 7a57339
Finalized Tests and DocString
aflah02 9cf9a2a
Ran Stylers Added More Descriptive DocString
aflah02 ce10e2d
Fixed Failing Docstring Tests
aflah02 a1ab88e
Removed Map Call and Unsupported Test
aflah02 17e2365
Shape Fixes
aflah02 b5cfe45
Working
aflah02 839a770
Working
aflah02 ec2d4ed
Changing Parent Class
aflah02 dbdd690
Changes
aflah02 fd856ad
Formatter Ran
aflah02 292353b
Merge branch 'keras-team:master' into RandomDeletionLayer
aflah02 9c447dc
Finalized
aflah02 35cf31a
Merge branch 'RandomDeletionLayer' of https://github.com/aflah02/kera…
aflah02 b26dd24
Addresed Review Comments
aflah02 983bfb3
Fornatter
aflah02 4813f1f
Added new Tests
aflah02 86bafbe
Fan Formatter
aflah02 46eda29
Skip Works
aflah02 906614a
New Randomness
aflah02 20df2b2
All Testing Done
aflah02 0c8fdc5
Review Changes
aflah02 3d87d14
Addressed all Review Comments
aflah02 5ee888a
Merge branch 'keras-team:master' into RandomDeletionLayer
aflah02 a77880e
Copy edits for docstrings
mattdangerw 9c1904c
Finishes
aflah02 b87394d
Merge branch 'keras-team:master' into RandomDeletionLayer
aflah02 81ba7bf
Changed Tokenizer Import
aflah02 0da28dc
Addressed Reviews
aflah02 d91a1bb
Fix typo
mattdangerw
@@ -0,0 +1,261 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random

import tensorflow as tf
from tensorflow import keras
from tensorflow.python.ops.ragged import ragged_array_ops

class RandomDeletion(keras.layers.Layer):
    """Augments input by randomly deleting tokens.

    This layer comes in handy when you need to generate new data using deletion
    augmentation as described in the paper [EDA: Easy Data Augmentation
    Techniques for Boosting Performance on Text Classification Tasks]
    (https://arxiv.org/pdf/1901.11196.pdf). The layer expects the inputs to be
    pre-tokenized so that each token can be individually treated as a possible
    candidate for deletion.

    Input should be either a `tf.RaggedTensor` or a dense `tf.Tensor`, and
    either rank-1 or rank-2.

    Args:
        rate: The probability of a token being chosen for deletion.
        max_deletions: The maximum number of tokens to delete.
        skip_list: A list of token values that should not be considered
            candidates for deletion.
        skip_fn: A function that takes as input a scalar tensor token and
            returns as output a scalar tensor True/False value. A value of
            True indicates that the token should not be considered a
            candidate for deletion. This function must be traceable--it
            should consist of tensorflow operations.
        skip_py_fn: A function that takes as input a python token value and
            returns as output `True` or `False`. A value of True indicates
            that the token should not be considered a candidate for
            deletion. Unlike the `skip_fn` argument, this argument need not
            be traceable--it can be any python function.
        seed: A seed for the rng.

    Examples:

    Word level usage.
    >>> keras.utils.set_random_seed(1337)
    >>> inputs = tf.strings.split(["Hey I like", "Keras and Tensorflow"])
    >>> augmenter = keras_nlp.layers.RandomDeletion(rate=0.4, seed=42)
    >>> augmented = augmenter(inputs)
    >>> tf.strings.reduce_join(augmented, separator=" ", axis=-1)
    <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'I like', b'and'],
    dtype=object)>

    Character level usage.
    >>> keras.utils.set_random_seed(1337)
    >>> inputs = tf.strings.unicode_split(["Hey Dude", "Speed Up"], "UTF-8")
    >>> augmenter = keras_nlp.layers.RandomDeletion(rate=0.4, seed=42)
    >>> augmented = augmenter(inputs)
    >>> tf.strings.reduce_join(augmented, axis=-1)
    <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'H Dude', b'pedUp'],
    dtype=object)>

    Usage with skip_list.
    >>> keras.utils.set_random_seed(1337)
    >>> inputs = tf.strings.split(["Hey I like", "Keras and Tensorflow"])
    >>> augmenter = keras_nlp.layers.RandomDeletion(rate=0.4,
    ...     skip_list=["Keras", "Tensorflow"], seed=42)
    >>> augmented = augmenter(inputs)
    >>> tf.strings.reduce_join(augmented, separator=" ", axis=-1)
    <tf.Tensor: shape=(2,), dtype=string,
    numpy=array([b'I like', b'Keras Tensorflow'], dtype=object)>

    Usage with skip_fn.
    >>> def skip_fn(word):
    ...     return tf.strings.regex_full_match(word, r"\\pP")
    >>> keras.utils.set_random_seed(1337)
    >>> inputs = tf.strings.split(["Hey I like", "Keras and Tensorflow"])
    >>> augmenter = keras_nlp.layers.RandomDeletion(rate=0.4,
    ...     skip_fn=skip_fn, seed=42)
    >>> augmented = augmenter(inputs)
    >>> tf.strings.reduce_join(augmented, separator=" ", axis=-1)
    <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'I like', b'and'],
    dtype=object)>

    Usage with skip_py_fn.
    >>> def skip_py_fn(word):
    ...     return len(word) < 4
    >>> keras.utils.set_random_seed(1337)
    >>> inputs = tf.strings.split(["Hey I like", "Keras and Tensorflow"])
    >>> augmenter = keras_nlp.layers.RandomDeletion(rate=0.4,
    ...     skip_py_fn=skip_py_fn, seed=42)
    >>> augmented = augmenter(inputs)
    >>> tf.strings.reduce_join(augmented, separator=" ", axis=-1)
    <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'Hey I', b'and'],
    dtype=object)>
    """

    def __init__(
        self,
        rate,
        max_deletions=None,
        skip_list=None,
        skip_fn=None,
        skip_py_fn=None,
        seed=None,
        name=None,
        **kwargs,
    ):
        # Check dtype and provide a default.
        if "dtype" not in kwargs or kwargs["dtype"] is None:
            kwargs["dtype"] = tf.int32
        else:
            dtype = tf.dtypes.as_dtype(kwargs["dtype"])
            if not dtype.is_integer and dtype != tf.string:
                raise ValueError(
                    "Output dtype must be an integer type or a string. "
                    f"Received: dtype={dtype}"
                )

        super().__init__(name=name, **kwargs)
        self.rate = rate
        self.max_deletions = max_deletions
        self.seed = random.randint(1, int(1e9)) if seed is None else seed
        self._generator = tf.random.Generator.from_seed(self.seed)
        self.skip_list = skip_list
        self.skip_fn = skip_fn
        self.skip_py_fn = skip_py_fn

        if self.rate > 1 or self.rate < 0:
            raise ValueError(
                "Rate must be between 0 and 1 (both inclusive). "
                f"Received: rate={rate}"
            )

        if [self.skip_list, self.skip_fn, self.skip_py_fn].count(None) < 2:
            raise ValueError(
                "Exactly one of skip_list, skip_fn, skip_py_fn must be "
                "provided."
            )

        if self.skip_list:
            self.StaticHashTable = tf.lookup.StaticHashTable(
                tf.lookup.KeyValueTensorInitializer(
                    tf.convert_to_tensor(self.skip_list),
                    tf.convert_to_tensor([True] * len(self.skip_list)),
                ),
                default_value=False,
            )

    def call(self, inputs):
        if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
            inputs = tf.convert_to_tensor(inputs)

        input_is_1d = False
        if inputs.shape.rank < 1 or inputs.shape.rank > 2:
            raise ValueError(
                "Input must either be rank 1 or rank 2. Received input with "
                f"rank={inputs.shape.rank}"
            )
        elif inputs.shape.rank == 1:
            input_is_1d = True
            # Add a new axis at the beginning.
            inputs = tf.expand_dims(inputs, axis=0)
        if isinstance(inputs, tf.Tensor):
            # Convert to ragged tensor.
            inputs = tf.RaggedTensor.from_tensor(inputs)

        # Skip words that are in the skip_list.
        skip_masks = None
        if self.skip_list:
            skip_masks = self.StaticHashTable.lookup(inputs.flat_values)
        elif self.skip_fn:
            skip_masks = tf.map_fn(
                self.skip_fn, inputs.flat_values, dtype=tf.bool
            )
        elif self.skip_py_fn:

            def _preprocess_fn(word):
                return self.skip_py_fn(word.numpy().decode("utf-8"))

            skip_masks = tf.map_fn(
                lambda x: tf.py_function(_preprocess_fn, [x], tf.bool),
                inputs.flat_values,
                dtype=tf.bool,
            )

        positions_flat = tf.range(tf.size(inputs.flat_values))
        positions = inputs.with_flat_values(positions_flat)
        if skip_masks is not None:
            skip_masks = tf.logical_not(skip_masks)
            skip_masks.set_shape([None])
            positions = ragged_array_ops.boolean_mask(
                positions, inputs.with_flat_values(skip_masks)
            )

        # Figure out how many we are going to select.
        word_counts = tf.cast(inputs.row_lengths(), "float32")
        num_to_select = tf.random.stateless_binomial(
            shape=tf.shape(word_counts),
            seed=self._generator.make_seeds()[:, 0],
            counts=word_counts,
            probs=self.rate,
        )
        if self.max_deletions is not None:
            num_to_select = tf.math.minimum(num_to_select, self.max_deletions)
        num_to_select = tf.cast(num_to_select, "int64")

        # Shuffle and trim to items that are going to be selected.
        def _shuffle_and_trim(x):
            positions, top_n = x
            shuffled = tf.random.shuffle(positions, seed=self.seed)
            return shuffled[:top_n]

        selected_for_mask = tf.map_fn(
            _shuffle_and_trim,
            (positions, num_to_select),
            fn_output_signature=tf.RaggedTensorSpec(
                ragged_rank=positions.ragged_rank - 1, dtype=positions.dtype
            ),
        )
        selected_for_mask.flat_values.set_shape([None])

        # Construct the mask, which is a boolean ragged tensor.
        # Scatter 0's to positions that have been selected for deletion.
        update_values = tf.zeros_like(selected_for_mask.flat_values, "int32")
        update_indices = selected_for_mask.flat_values
        update_indices = tf.expand_dims(update_indices, -1)
        update_indices = tf.cast(update_indices, "int32")
        mask_flat = tf.ones_like(inputs.flat_values, dtype="int32")
        mask_flat = tf.tensor_scatter_nd_update(
            mask_flat, update_indices, update_values
        )
        mask = tf.cast(inputs.with_flat_values(mask_flat), "bool")

        inputs = tf.ragged.boolean_mask(inputs, mask)

        if input_is_1d:
            inputs = tf.squeeze(inputs, axis=0)

        return inputs

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "rate": self.rate,
                "max_deletions": self.max_deletions,
                "seed": self.seed,
                "skip_list": self.skip_list,
                "skip_fn": self.skip_fn,
                "skip_py_fn": self.skip_py_fn,
            }
        )
        return config
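The core algorithm of the layer above can be illustrated without TensorFlow. The following is a minimal plain-Python sketch of the same idea (not the layer's actual implementation): each non-skipped token is independently dropped with probability `rate`, capped at `max_deletions`; the `skip` predicate here plays the role of `skip_py_fn`.

```python
import random


def random_deletion(tokens, rate, max_deletions=None, skip=None, seed=None):
    """Sketch of random-deletion augmentation on a plain list of tokens."""
    rng = random.Random(seed)
    # Positions eligible for deletion, honoring the optional skip predicate.
    candidates = [
        i for i, t in enumerate(tokens) if skip is None or not skip(t)
    ]
    # Each candidate is chosen independently with probability `rate`.
    chosen = [i for i in candidates if rng.random() < rate]
    if max_deletions is not None:
        # Cap the number of deletions, picking a random subset.
        rng.shuffle(chosen)
        chosen = chosen[:max_deletions]
    drop = set(chosen)
    return [t for i, t in enumerate(tokens) if i not in drop]


tokens = "Keras and Tensorflow are neat".split()
augmented = random_deletion(
    tokens, rate=0.4, skip=lambda t: t == "Keras", seed=42
)
```

The real layer vectorizes this per-row over a ragged batch: `stateless_binomial` draws the number of deletions per row, and a scatter builds the boolean keep-mask.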
I think before we ship this we should decide if we want to leave room for a separate character deletion layer, or if we would want to do that as attributes on this layer.
If a character deletion layer would be separate, we should probably call this "RandomCharacterDeletion" or something.
Agree here, we should consider the scalability.
My question is to make a character-level deletion layer, how much change would be required?
@chenmoneygithub I did discuss it with Matt here after his initial comment. We're now thinking of having them as 2 separate layers but would love to hear your thoughts on this!
I am fine with both! I raised the question because I want to check whether it's possible to have a BaseClass and only do small customization on WordDelete and CharacterDelete.
@chenmoneygithub That does seem like a more efficient design choice, but I'm not really sure about it. I'll get back if I find a good way to do that.
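The base-class idea floated above could look something like the sketch below. This is purely illustrative; the class names are hypothetical and not part of keras_nlp's actual API. The subclasses only define how text is split into units and joined back, while the deletion logic lives in the base class.

```python
import random


class BaseRandomDeletion:
    """Hypothetical base: holds deletion logic; subclasses define the units."""

    def __init__(self, rate, seed=None):
        self.rate = rate
        self._rng = random.Random(seed)

    def split(self, text):
        raise NotImplementedError

    def join(self, units):
        raise NotImplementedError

    def __call__(self, text):
        units = self.split(text)
        # Keep each unit independently with probability 1 - rate.
        kept = [u for u in units if self._rng.random() >= self.rate]
        return self.join(kept)


class RandomWordDeletion(BaseRandomDeletion):
    def split(self, text):
        return text.split()

    def join(self, units):
        return " ".join(units)


class RandomCharacterDeletion(BaseRandomDeletion):
    def split(self, text):
        return list(text)

    def join(self, units):
        return "".join(units)
```

Under this design, a character-level layer is only a few lines of customization, which is roughly the question being raised in the thread.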