
Random Deletion Layer #214

Merged · 41 commits merged into keras-team:master on Jul 27, 2022

Conversation

@aflah02 (Collaborator) commented May 31, 2022

Fixes #152
Hey @mattdangerw @chenmoneygithub
I've added the layer and some tests; I'll add more tests in the meantime as the code is reviewed!
You can find a demo here
One major obstacle is adding docstrings: they are also tested, and since the layer randomly deletes values, the docstring output may or may not match what the layer actually outputs. Hence I could only give one example where the probability is 1 and all the words are the same. Any workarounds for this?
Also, I'm facing issues with datasets, as .numpy() doesn't work in graph mode. Simply trying to iterate over and copy the tensor also fails in graph mode; any workarounds for this?
I have a workaround which I can use while running it, but I can't figure out what changes to make to the code itself that would fix this:

# Workaround: wrap the layer call in tf.py_function so eager-only
# operations (like .numpy()) run outside the traced graph.
# RandomDeletion is the layer added in this PR.
import tensorflow as tf

tf.data.experimental.enable_debug_mode()
augmenter = RandomDeletion(
    probability=0.5,
    max_deletions=3,
)
inputs = ["Never gonna give you up", "Never gonna let you down", "Never gonna run around and desert you"]
ds = tf.data.Dataset.from_tensor_slices(inputs)
ds = ds.batch(3).map(lambda y: tf.py_function(
    lambda x: augmenter(x),
    inp=[y], Tout=tf.string,
))
ds.take(1).get_single_element()

Output:

<tf.Tensor: shape=(3,), dtype=string, numpy=array([b'Never you', b'let you', b'Never gonna run and'], dtype=object)>

@mattdangerw (Member) left a comment

Thanks! Left some comments. At a high level, we should first add some testing for compiling this layer in a tf.function (see comments on the test), then figure out how to simplify the layer call code so it's compilable.



class RandomDeletion(keras.layers.Layer):
"""Augments input by randomly deleting words
mattdangerw (Member):

period at end of sentence.

Also we should probably have a little more of a description in a separate paragraph that describes the flow of computation, e.g. split words, delete words, rejoin the words.

from tensorflow import keras


class RandomDeletion(keras.layers.Layer):
mattdangerw (Member):

I think before we ship this we should decide if we want to leave room for a separate character deletion layer, or if we would want to do that as attributes on this layer.

If a character deletion layer would be separate, we should probably call this "RandomCharacterDeletion" or something.

chenmoneygithub (Contributor):

Agree here, we should consider the scalability.

My question is: to make a character-level deletion layer, how much change would be required?

aflah02 (Collaborator, Author):

@chenmoneygithub I did discuss it with Matt here after his initial comment. We're now thinking of having them as 2 separate layers but would love to hear your thoughts on this!

chenmoneygithub (Contributor):

I am fine with both! I raised the question because I want to check if it's possible to have a base class and only do small customizations in WordDelete and CharacterDelete.
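A hypothetical sketch of that base-class idea (names and structure are illustrative, not from the PR; the deletion logic here is simplified to a per-token coin flip rather than the capped selection discussed later). The base class owns the deletion, and subclasses only customize splitting and rejoining:

import tensorflow as tf
from tensorflow import keras

class BaseRandomDeletion(keras.layers.Layer):
    def __init__(self, rate, **kwargs):
        super().__init__(**kwargs)
        self.rate = rate

    def _split(self, inputs):
        raise NotImplementedError

    def _join(self, tokens):
        raise NotImplementedError

    def call(self, inputs):
        tokens = self._split(inputs)  # 2D RaggedTensor of tokens
        # Keep each token independently with probability (1 - rate).
        keep = tf.random.uniform(tf.shape(tokens.flat_values)) > self.rate
        tokens = tf.ragged.boolean_mask(tokens, tokens.with_flat_values(keep))
        return self._join(tokens)

class RandomWordDeletion(BaseRandomDeletion):
    def _split(self, inputs):
        return tf.strings.split(inputs)

    def _join(self, tokens):
        return tf.strings.reduce_join(tokens, axis=-1, separator=" ")

class RandomCharacterDeletion(BaseRandomDeletion):
    def _split(self, inputs):
        return tf.strings.unicode_split(inputs, "UTF-8")

    def _join(self, tokens):
        return tf.strings.reduce_join(tokens, axis=-1)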

aflah02 (Collaborator, Author):

@chenmoneygithub That does seem like a more efficient design choice, but I'm not really sure about it. I'll get back if I find a good way to do that.

inputs = inputs[0]
return inputs

def get_config(self) -> Dict[str, Any]:
mattdangerw (Member):

I think we can leave the return annotation off. Let's stick to optional annotations for simple types only.

> self.probability
)
# Iterate to check for any cases where deletions exceed the maximum
for i in range(len(row_splits) - 1):
mattdangerw (Member):

I think we can clean this whole for loop up, though I'll need to think a bit more about how exactly.

One place for inspiration is probably the op code for tf.text's RandomItemSelector, which is also selecting a number of items based on a probability with a max cap.

I'm a little skeptical that this would trace as a tf.function: you are doing a lot of looping and calling .numpy(), and the latter definitely won't work in a compiled context.

chenmoneygithub (Contributor):

Reading about RandomItemSelector, briefly, what it does is (a minimal sketch follows):

  1. Calculate how many items to select; call it N.
  2. Shuffle the list/array/tensor's index array.
  3. Pick the first N elements from the index array, then use tf.gather to get the actual selected elements.
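A minimal sketch of that trick (illustrative, not tf.text's actual op code; num_to_select stands in for N, however it gets computed from the rate and the max cap):

import tensorflow as tf

items = tf.constant(["never", "gonna", "give", "you", "up"])
num_to_select = 2  # N

shuffled = tf.random.shuffle(tf.range(tf.size(items)))  # step 2: shuffle the index array
selected = tf.gather(items, shuffled[:num_to_select])   # step 3: take the first N, then gather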

aflah02 (Collaborator, Author):

Thanks for sharing this @mattdangerw, and for the concise summary @chenmoneygithub. This seems like a much smoother way to do it; I'll incorporate the idea in my code. Just to clarify, do I need to cite this code file in the code (as "inspired by" or something of that sort)?

@aflah02 (Collaborator, Author) commented Jun 1, 2022

Thanks for the reviews, will get to work on this!

@chenmoneygithub (Contributor) left a comment

Thanks for the PR! Left some comments!


"""Augments input by randomly deleting words

Args:
probability: probability of a word being chosen for deletion
chenmoneygithub (Contributor):

Let's be consistent with capitalization after the ":"; also, we need to note the type in the arg comment.

dtype = tf.dtypes.as_dtype(kwargs["dtype"])
if not dtype.is_integer and dtype != tf.string:
raise ValueError(
"Output dtype must be an integer type of a string. "
chenmoneygithub (Contributor):

of => or

self.max_deletions = max_deletions

def call(self, inputs):
"""Augments input by randomly deleting words
chenmoneygithub (Contributor):

period at end.

row_splits = ragged_words.row_splits.numpy()
mask = (
tf.random.uniform(ragged_words.flat_values.shape)
> self.probability
chenmoneygithub (Contributor):

The formatting looks a bit strange; is that formatted by black?

aflah02 (Collaborator, Author):

@chenmoneygithub Yup, this was generated by black.


@mattdangerw (Member) left a comment

A couple of high-level comments!

output = augmenter(input)
self.assertAllEqual(output.shape, tf.convert_to_tensor(input).shape)

def test_shape_with_nested(self):
mattdangerw (Member):

I'm not sure we actually need to support this. The problem is we are only supporting one level of nesting, not multiple.

There are three crucial shapes to support as input: (), (batch_size,), and (batch_size, 1). We should be able to support all of these without the outer map_fn in your code above.

This could actually be something reasonable to error out on, asking for a supported shape instead.

aflah02 (Collaborator, Author):

That makes sense, I'll throw an error for that.

if scalar_input:
inputs = tf.expand_dims(inputs, 0)

ragged_words = tf.strings.split(inputs)
mattdangerw (Member):

I'm a little worried about how we are doing whitespace splitting here. There are two downsides with this approach.

  • It's destructive. Newlines, tabs, and consecutive spaces will all get replaced with a single space.
  • It will group punctuation along with words and delete them both. So for an input "Are you OK?" it would delete "OK?" as a single token, I think.

Do we want that behavior? Maybe in the context of normal EDA usage it is fine.

There might be an option to do something fancier here: use tf_text.regex_split to split on whitespace and punctuation but still keep the split characters, then use tf.strings.regex_full_match to exclude whitespace and punctuation from deletion. Something like...

ragged_words = tf_text.regex_split(inputs, WHITESPACE_AND_PUNCTUATION_REGEX, WHITESPACE_AND_PUNCTUATION_REGEX)

positions_flat = tf.range(tf.size(ragged_words.flat_values))
positions = ragged_words.with_flat_values(positions_flat)
seperator_positions = tf.strings.regex_full_match(ragged_words, WHITESPACE_AND_PUNCTUATION_REGEX)
positions = tf.ragged.boolean_mask(positions_flat, separator_positions == False)

aflah02 (Collaborator, Author):

Thanks @mattdangerw, working on this right now. Just to confirm, in the last line it should be
tf.ragged.boolean_mask(positions, separator_positions == False) instead, right?
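For reference, a runnable version of the snippet above with that boolean_mask fix applied (hedged: the regex here is an illustrative stand-in, not the PR's actual pattern):

import tensorflow as tf
import tensorflow_text as tf_text

WHITESPACE_AND_PUNCTUATION_REGEX = r"[\s\p{P}]"  # illustrative pattern

inputs = tf.constant(["Are you OK? I like kites."])
# Split on whitespace/punctuation, but keep the split characters as tokens.
ragged_words = tf_text.regex_split(
    inputs,
    WHITESPACE_AND_PUNCTUATION_REGEX,
    WHITESPACE_AND_PUNCTUATION_REGEX,
)

positions_flat = tf.range(tf.size(ragged_words.flat_values))
positions = ragged_words.with_flat_values(positions_flat)
separator_positions = tf.strings.regex_full_match(
    ragged_words, WHITESPACE_AND_PUNCTUATION_REGEX
)
# Keep only positions of real words, excluding whitespace and punctuation.
positions = tf.ragged.boolean_mask(positions, ~separator_positions)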

@aflah02 (Collaborator, Author) commented Jun 10, 2022:

@mattdangerw This method also does this weird thing where it preserves the whitespace for the deleted text too; as you can see in this demo, the whitespace in front of "like" and "to" is also preserved. This does satisfy the criterion of preserving spaces, albeit in a very weird way. Is this the intended result?

@chenmoneygithub (Contributor) left a comment

Thanks! My main concern is that the current implementation of picking the indices to delete is a little complex and uses sequential computation.

class RandomWordDeletion(keras.layers.Layer):
"""Augments input by randomly deleting words.

The layer works by splitting the words using `tf.strings.split` computes
chenmoneygithub (Contributor):

This sentence is too long; let's refactor it a bit. One possible way is to break it down by steps: 1) split the words ... 2) randomly select ... 3) mask out ... 4) join.

then removed before returning and the remaining tokens are joined back.

Args:
probability: probability of a word being chosen for deletion
chenmoneygithub (Contributor):

Nit: argument type, and period at the end.

Also maybe call this selection_rate to be consistent with mlm_mask_generator

... )
>>> augmenter(["I like to fly kites, do you?",
... "Can we go fly some kites later?"])
<tf.Tensor: shape=(2,), dtype=string, numpy=array([b'I fly kites, do you?',
chenmoneygithub (Contributor):

The output becomes a byte string. I am wondering, is it still a valid input to our tokenizers?

<tf.Tensor: shape=(2,), dtype=string, numpy=array([b'fly kites, do you?',
b'we go some kites'], dtype=object)>

Batch the inputs and then Augment.
chenmoneygithub (Contributor):

Nit: Augment => augment.

def __init__(self, probability, max_deletions, name = None, **kwargs):
# Check dtype and provide a default.
if "dtype" not in kwargs or kwargs["dtype"] is None:
kwargs["dtype"] = tf.int32
chenmoneygithub (Contributor):

All examples above have string inputs and string outputs, but here we allow integer inputs as well (we default to tf.int32). If integer inputs are expected, let's add one example to showcase them; otherwise, let's modify this type-check part.

def validate_and_fix_rank(inputs):
if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
inputs = tf.convert_to_tensor(inputs)
inputs = tf.cast(inputs, tf.string)
chenmoneygithub (Contributor):

Do we allow integer tensors? If not, let's error out.

elif inputs.shape.rank == 2:
if inputs.shape[1] != 1:
raise ValueError(
f"input must be of shape `[batch_size, 1]`. "
chenmoneygithub (Contributor):

I am a little confused about this: if the input is tokenized input, such as [["Input", "must", "be", "string"], ["I", "don't", "know", "what", "I", "am", "writing"]], will we reject it?

mattdangerw (Member):

@chenmoneygithub I think the confusion you are having is whether we allow input that is pre-split, or whether splitting happens in the layer.

If we split inside the layer, the input you are describing is invalid. And I think it is totally valid to support only shapes (bs, 1), (bs) and ().

If we split outside the layer, the input shapes we should support will change. Essentially we should support ragged with shape (bs, None) or dense with shape (None).


# Shuffle and trim to items that are going to be selected.
def _shuffle_and_trim(x):
positions, top_n = x
chenmoneygithub (Contributor):

For batched input, does this positions still have the batch dim? It looks like it doesn't to me, because you are taking the flat_values. But I might be wrong here.

shuffled = tf.random.shuffle(positions)
return shuffled[:top_n]

selected_for_mask = tf.map_fn(
chenmoneygithub (Contributor):

I am a little concerned about the current implementation: the logic of picking the indices to delete looks complex to me, and you are using tf.map_fn to handle each element separately instead of doing batch processing, which could get slow.

Rethinking how we can do the index pick-up, can we do the following? Let's assume we have already done the split, so the input is a 2D RaggedTensor.

  1. Calculate how many words to pick for each sequence; call it num_row_i for the ith row.
  2. For each row, randomly pick num_row_i indices.
  3. For the indices selected in 2), do a tf.gather to gather them into a new tensor; these are the words we want to keep.
  4. reduce_join the output of 3) to reconstruct the string for each row.

In this way, we can keep the batch dimension throughout the computation (see the sketch after this comment). But I am not 100% sure if this works, WDYT?
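Here is a rough, batched sketch of that idea (a hedged illustration, not the PR's code: the names are made up, the rate is fixed, and there is no max cap). It deletes exactly floor(rate * length) uniformly random words per row while keeping the batch dimension throughout, using an argsort-based ranking in place of an explicit per-row shuffle:

import tensorflow as tf

rate = 0.3
inputs = tf.constant(["Never gonna give you up", "Never gonna let you down"])

words = tf.strings.split(inputs)                       # 2D RaggedTensor
row_lengths = tf.cast(words.row_lengths(), tf.int32)
num_to_delete = tf.cast(
    tf.floor(tf.cast(row_lengths, tf.float32) * rate), tf.int32
)                                                      # step 1: count per row

padded = words.to_tensor(default_value="")             # (batch, max_len)
valid = tf.sequence_mask(row_lengths, tf.shape(padded)[1])

# Step 2: rank random noise within each row; padding gets noise 2.0 so it
# always ranks last and is never "deleted" in place of a real word.
noise = tf.where(valid, tf.random.uniform(tf.shape(padded)), 2.0)
ranks = tf.argsort(tf.argsort(noise, axis=-1), axis=-1)
delete = ranks < num_to_delete[:, None]

# Steps 3-4: keep the rest of each row and rejoin it into a string.
kept = tf.ragged.boolean_mask(padded, valid & ~delete)
outputs = tf.strings.reduce_join(kept, axis=-1, separator=" ")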

mattdangerw (Member):

Let's chat about this tomorrow! Overall I think we need to make sure this can trace, but we don't actually need to care about performance that much. If you are running an augmentation, it's because you have very little data; you don't actually care about speed.

I think the reason this is so complex (here and in tf.text) is that there is no good way to uniformly sample from a pool of candidate indices with a max cap without using a map call. It's actually the shuffle call that we can't run without map.

@mattdangerw (Member) commented:

One thing we need to add here is a seed argument. We support that in all our random layers and, for example, in initializers.

I think the easiest way to do that will be to subclass keras.__internal__.layers.BaseRandomLayer and pass the seed to the super class. Then we can use a few functions to generate seeds for the ops themselves: self._random_generator.make_seed_for_stateless_op for the seed of the stateless_binomial call, and self._random_generator.make_legacy_seed for the seed of the shuffle call.

After we do all of this, I think we should see the following behavior...

tf.random.set_seed(1)
RandomDeletion(0.2, seed=1)(inputs)

tf.random.set_seed(1)
RandomDeletion(0.2, seed=1)(inputs)  # output is the same as the first call

tf.random.set_seed(1)
RandomDeletion(0.2, seed=2)(inputs)  # output is different from the first call

tf.random.set_seed(1)
RandomDeletion(0.2, seed=2)(inputs)  # output is the same as the third call

Overall, this is still a little trickier than we would like, as the underlying random API for TensorFlow is changing.

@mattdangerw (Member) left a comment

Few comments! Let's discuss re: any sort of skip argument tomorrow.

>>> tf.random.get_global_generator().reset_from_seed(30)
>>> tf.random.set_seed(30)
>>> inputs = tf.strings.unicode_split(["Hey Dude", "Speed Up"], "UTF-8")
>>> augmenter = keras_nlp.layers.RandomDeletion(rate = 0.4,
mattdangerw (Member):

I think you can show these examples without max_deletions, just to keep things simple.

self.rate = rate
self.max_deletions = max_deletions
self.seed = seed
self._random_generator = backend.RandomGenerator(seed)
mattdangerw (Member):

Just use this from the keras import: keras.backend.RandomGenerator.

aflah02 (Collaborator, Author):

There is some weird behaviour here: if I do this, I get AttributeError: module 'keras.api._v2.keras.backend' has no attribute 'RandomGenerator'. Not sure why.

chenmoneygithub (Contributor):

RandomGenerator is not exposed as a public API: https://github.com/keras-team/keras/blob/v2.9.0/keras/backend.py#L1823

There is actually an API in TF: tf.random.Generator (https://www.tensorflow.org/api_docs/python/tf/random/Generator), can we use that one?
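A quick illustration of that public API (hedged: how the PR would wire it in is not shown here, only the generator itself; tf.random.Generator is documented at the link above):

import tensorflow as tf

gen = tf.random.Generator.from_seed(1)
print(gen.uniform(shape=(2,)))           # deterministic given the seed
seeds = gen.make_seeds(1)[:, 0]          # a seed pair for stateless ops
print(tf.random.stateless_uniform((2,), seed=seeds))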

word_counts = tf.cast(inputs.row_lengths(), "float32")
num_to_select = tf.random.stateless_binomial(
shape=tf.shape(word_counts),
seed=tf.random.get_global_generator().make_seeds()[:, 0],
mattdangerw (Member):

Can't we do self._random_generator.make_seed_for_stateless_op()?

aflah02 (Collaborator, Author):

@mattdangerw I don't think we can. It returns None right now unless I set the rng_type, but that argument is not available in the latest tf release; it was added after the latest release.

Examples:

Word level usage
>>> tf.random.get_global_generator().reset_from_seed(30)
mattdangerw (Member):

After we make the minor randomness changes suggested below, can we do tf.keras.utils.set_random_seed(30)? That would make this a little more readable, with a single seed line.

@mattdangerw (Member) left a comment

Looking close! Few comments.

@mattdangerw (Member) left a comment

Few more comments

@mattdangerw (Member) left a comment

Minor comments; LGTM pending these changes.

)
elif self.skip_py_fn:

def _preprocess_string_fn(word):
mattdangerw (Member):

I think we can simplify this a bit. We should also rename word -> token. Something like...

string_fn = lambda x: self.skip_py_fn(x.numpy().decode("utf-8"))
int_fn = lambda x: self.skip_py_fn(x.numpy())
py_fn = string_fn if inputs.dtype == tf.string else int_fn

return tf.map_fn(
    lambda x: tf.py_function(py_fn, [x], tf.bool),
    inputs.flat_values,
    dtype=tf.bool,
)

aflah02 (Collaborator, Author):

@mattdangerw This seems easier; however, it fails the style rules. The formatter gives this error: E731 do not assign a lambda expression, use a def. I guess this will persist as long as we keep assigning lambdas. Alternatively, this works:

def string_fn(token):
    return self.skip_py_fn(token.numpy().decode("utf-8"))

def int_fn(token):
    return self.skip_py_fn(token.numpy())

py_fn = string_fn if inputs.dtype == tf.string else int_fn

for i in range(output.shape[0]):
self.assertAllEqual(output[i], exp_output[i])

def test_skip_options(self):
mattdangerw (Member):

This looks like we are just testing roughly the same things as the dataset test below. Can we just move the skip_list test into the dataset test below and delete this test?

aflah02 (Collaborator, Author):

@mattdangerw The dataset tests already have skip options being tested, so should I remove this altogether?

@chenmoneygithub (Contributor) left a comment

LGTM! Let's wait for Matt's approval.

@mattdangerw (Member) left a comment

LGTM!

@mattdangerw merged commit 58f4bbb into keras-team:master on Jul 27, 2022
@aflah02 (Collaborator, Author) commented Jul 27, 2022

Thanks!

Linked issue: Random Deletion Layer - Data Augmentation (#152)