Random Deletion Layer #214

Merged
merged 41 commits on Jul 27, 2022
Changes from 17 commits
Commits
bf3219f
Random Deletion Working
aflah02 Apr 29, 2022
66c0de7
Added to init
aflah02 Apr 29, 2022
d2508dc
WOrking
aflah02 May 3, 2022
9e2e243
Merge branch 'keras-team:master' into RandomDeletionLayer
aflah02 May 3, 2022
b48dc5a
Working
aflah02 May 23, 2022
02ce27d
Current Status
aflah02 May 26, 2022
f137256
Working Layer More Tests to be Added
aflah02 May 31, 2022
c7dcb8a
Fixed Scalar Case
aflah02 May 31, 2022
515a1d3
Added Comments
aflah02 May 31, 2022
6acfb62
Minor Fixes
aflah02 Jun 3, 2022
68fcae0
Major Refactors and Fixes, ToDo - Docs, Tests
aflah02 Jun 8, 2022
058d572
Fixed Shape Issues for Scalar Lists
aflah02 Jun 8, 2022
7a57339
Finalized Tests and DocString
aflah02 Jun 9, 2022
9cf9a2a
Ran Stylers Added More Descriptive DocString
aflah02 Jun 9, 2022
ce10e2d
Fixed Failing Docstring Tests
aflah02 Jun 9, 2022
a1ab88e
Removed Map Call and Unsupported Test
aflah02 Jun 10, 2022
17e2365
Shape Fixes
aflah02 Jun 10, 2022
b5cfe45
Working
aflah02 Jun 29, 2022
839a770
Working
aflah02 Jun 29, 2022
ec2d4ed
Changing Parent Class
aflah02 Jun 30, 2022
dbdd690
Changes
aflah02 Jul 1, 2022
fd856ad
Formatter Ran
aflah02 Jul 1, 2022
292353b
Merge branch 'keras-team:master' into RandomDeletionLayer
aflah02 Jul 5, 2022
9c447dc
Finalized
aflah02 Jul 5, 2022
35cf31a
Merge branch 'RandomDeletionLayer' of https://github.com/aflah02/kera…
aflah02 Jul 5, 2022
b26dd24
Addresed Review Comments
aflah02 Jul 6, 2022
983bfb3
Fornatter
aflah02 Jul 6, 2022
4813f1f
Added new Tests
aflah02 Jul 6, 2022
86bafbe
Fan Formatter
aflah02 Jul 6, 2022
46eda29
Skip Works
aflah02 Jul 12, 2022
906614a
New Randomness
aflah02 Jul 12, 2022
20df2b2
All Testing Done
aflah02 Jul 20, 2022
0c8fdc5
Review Changes
aflah02 Jul 20, 2022
3d87d14
Addressed all Review Comments
aflah02 Jul 20, 2022
5ee888a
Merge branch 'keras-team:master' into RandomDeletionLayer
aflah02 Jul 20, 2022
a77880e
Copy edits for docstrings
mattdangerw Jul 22, 2022
9c1904c
Finishes
aflah02 Jul 25, 2022
b87394d
Merge branch 'keras-team:master' into RandomDeletionLayer
aflah02 Jul 25, 2022
81ba7bf
Changed Tokenizer Import
aflah02 Jul 25, 2022
0da28dc
Addressed Reviews
aflah02 Jul 26, 2022
d91a1bb
Fix typo
mattdangerw Jul 27, 2022
1 change: 1 addition & 0 deletions keras_nlp/layers/__init__.py
@@ -16,6 +16,7 @@
from keras_nlp.layers.mlm_head import MLMHead
from keras_nlp.layers.mlm_mask_generator import MLMMaskGenerator
from keras_nlp.layers.position_embedding import PositionEmbedding
from keras_nlp.layers.random_word_deletion import RandomWordDeletion
from keras_nlp.layers.sine_position_encoding import SinePositionEncoding
from keras_nlp.layers.token_and_position_embedding import (
TokenAndPositionEmbedding,
193 changes: 193 additions & 0 deletions keras_nlp/layers/random_word_deletion.py
@@ -0,0 +1,193 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from tensorflow import keras


class RandomWordDeletion(keras.layers.Layer):
"""Augments input by randomly deleting words.

The layer splits the input into words using `tf.strings.split`, randomly computes
Contributor:
This sentence is too long, let's refactor it a bit. One possible way is to break it down into steps: 1) split the words ... 2) randomly select ... 3) mask out ... 4) join.

the indices to keep, and masks out the ones selected for deletion; the masked
words are then removed and the remaining tokens are joined back together.
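
To make the flow concrete, here is a minimal illustrative sketch of the split/select/mask/join steps (a toy example with a hard-coded selection, not the PR's code):

import tensorflow as tf

# 1) Split each string into words (a RaggedTensor of shape [batch, None]).
words = tf.strings.split(["I like to fly kites"])
# 2) Select positions to drop (hard-coded here; the layer samples randomly).
keep = tf.ragged.constant([[True, False, False, True, True]])
# 3) Mask out the words selected for deletion.
kept = tf.ragged.boolean_mask(words, keep)
# 4) Join the remaining words back into strings.
print(tf.strings.reduce_join(kept, axis=-1, separator=" "))
# tf.Tensor([b'I fly kites'], shape=(1,), dtype=string)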

Args:
probability: probability of a word being chosen for deletion
Contributor:
Nit: argument type, and period at the end.

Also maybe call this selection_rate to be consistent with mlm_mask_generator

max_deletions: The maximum number of words to delete

Examples:

Basic usage.
>>> tf.random.get_global_generator().reset_from_seed(30)
>>> tf.random.set_seed(30)
>>> augmenter = keras_nlp.layers.RandomWordDeletion(
... probability = 0.7,
... max_deletions = 2,
... )
>>> augmenter(["I like to fly kites, do you?",
... "Can we go fly some kites later?"])
<tf.Tensor: shape=(2,), dtype=string, numpy=array([b'I fly kites, do you?',
Contributor:
The output becomes a byte string. I am wondering whether it is still a valid input to our tokenizers?

b'Can we fly kites later?'], dtype=object)>
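
On the byte-string question: tf.string tensors always hold bytes at the numpy level, so the b'...' prefix is a display artifact rather than a change of type. A quick check (a sketch, not part of the PR):

import tensorflow as tf

t = tf.constant("Can we fly kites later?")
print(t.numpy())           # b'Can we fly kites later?'
print(t.numpy().decode())  # Can we fly kites later?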

Augment first, then batch the dataset.
>>> tf.random.get_global_generator().reset_from_seed(30)
>>> tf.random.set_seed(30)
>>> inputs = ["I like to fly kites, do you?",
... "Can we go fly some kites later?"]
>>> augmenter = keras_nlp.layers.RandomWordDeletion(
... probability = 0.6,
... max_deletions = 3,
... )
>>> ds = tf.data.Dataset.from_tensor_slices(inputs)
>>> ds = ds.map(augmenter)
>>> ds = ds.apply(tf.data.experimental.dense_to_ragged_batch(3))
>>> ds.take(1).get_single_element()
<tf.Tensor: shape=(2,), dtype=string, numpy=array([b'fly kites, do you?',
b'we go some kites'], dtype=object)>

Batch the inputs and then Augment.
Contributor:
Nit: Augment => augment.

>>> tf.random.get_global_generator().reset_from_seed(30)
>>> tf.random.set_seed(30)
>>> inputs = ["I like to fly kites, do you?",
... "Can we go fly some kites later?"]
>>> augmenter = keras_nlp.layers.RandomWordDeletion(
... probability = 0.6,
... max_deletions = 3,
... )
>>> ds = tf.data.Dataset.from_tensor_slices(inputs)
>>> ds = ds.batch(3).map(augmenter)
>>> ds.take(1).get_single_element()
<tf.Tensor: shape=(2,), dtype=string, numpy=array([b'fly kites, do you?',
b'we go some kites'], dtype=object)>
"""

def __init__(self, probability, max_deletions, name = None, **kwargs):
# Check dtype and provide a default.
if "dtype" not in kwargs or kwargs["dtype"] is None:
kwargs["dtype"] = tf.int32
Contributor:
All examples above have string inputs and string outputs, but here we allow integer inputs as well (we default to tf.int32). If integer inputs are expected, let's add one example to showcase them; otherwise, let's modify this type check.

else:
dtype = tf.dtypes.as_dtype(kwargs["dtype"])
if not dtype.is_integer and dtype != tf.string:
raise ValueError(
"Output dtype must be an integer type or a string. "
f"Received: dtype={dtype}"
)

super().__init__(name=name, **kwargs)
self.probability = probability
self.max_deletions = max_deletions
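
For reference, a small sketch of what the dtype check above evaluates (standard tf.dtypes behavior):

import tensorflow as tf

print(tf.dtypes.as_dtype("int32").is_integer)    # True
print(tf.dtypes.as_dtype(tf.string).is_integer)  # False, but equals tf.string
# A float dtype fails both conditions, so __init__ raises ValueError.
print(tf.dtypes.as_dtype("float32").is_integer)  # False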

def call(self, inputs):
"""Augments input by randomly deleting words.
Args:
inputs: A tensor or nested tensor of strings to augment.
Returns:
A tensor or nested tensor of augmented strings.
"""

def validate_and_fix_rank(inputs):
if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
inputs = tf.convert_to_tensor(inputs)
inputs = tf.cast(inputs, tf.string)
Contributor:
Do we allow integer tensors? If not, let's error out.

if inputs.shape.rank == 0 or inputs.shape.rank == 1:
return inputs
elif inputs.shape.rank == 2:
if inputs.shape[1] != 1:
raise ValueError(
f"input must be of shape `[batch_size, 1]`. "
Contributor:
I am a little confused about this - if the input is tokenized, such as [["Input", "must", "be", "string"], ["I", "don't", "know", "what", "I", "am", "writing"]], will we reject it?

Member:
@chenmoneygithub I think the confusion you are having is whether we allow input that is pre-split or whether splitting happens in the layer.

If we split inside the layer, the input you are describing is invalid. And I think it is totally valid to support only shapes (bs, 1), (bs) and ().

If we split outside the layer, the input shapes we should support will change. Essentially we should support ragged with shape (bs, None) or dense with shape (None).

f"Found shape: {inputs.shape}"
)
else:
return tf.squeeze(inputs, axis=1)
else:
raise ValueError(
f"input must be of rank 0 (scalar input), 1 or 2. "
f"Found rank: {inputs.shape.rank}"
)
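
Per the shape discussion above, a sketch of the inputs validate_and_fix_rank accepts (assuming `augmenter` is an instance of this layer; the names here are only for illustration):

# Rank 0 (scalar), rank 1, and rank 2 with a trailing dim of 1 all pass;
# the rank-2 case is squeezed down to a rank-1 batch.
augmenter(tf.constant("I like kites"))             # rank 0
augmenter(["I like kites", "fly some kites"])      # rank 1
augmenter([["I like kites"], ["fly some kites"]])  # rank 2, shape [batch, 1]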

isString = False
if isinstance(inputs, str):
inputs = [inputs]
isString = True

inputs = validate_and_fix_rank(inputs)

scalar_input = inputs.shape.rank == 0
if scalar_input:
inputs = tf.expand_dims(inputs, 0)

ragged_words = tf.strings.split(inputs)

positions_flat = tf.range(tf.size(ragged_words.flat_values))
positions = ragged_words.with_flat_values(positions_flat)

# Figure out how many we are going to select.
word_counts = tf.cast(ragged_words.row_lengths(), "float32")
num_to_select = tf.random.stateless_binomial(
shape=tf.shape(word_counts),
seed=tf.random.get_global_generator().make_seeds()[:, 0],
counts=word_counts,
probs=self.probability,
)
num_to_select = tf.math.minimum(num_to_select, self.max_deletions)
num_to_select = tf.cast(num_to_select, "int64")

# Shuffle and trim to items that are going to be selected.
def _shuffle_and_trim(x):
positions, top_n = x
Contributor:
For batched input, does `positions` here still have the batch dim? It looks to me like it doesn't, because you are taking the flat_values. But I might be wrong here.

shuffled = tf.random.shuffle(positions)
return shuffled[:top_n]

selected_for_mask = tf.map_fn(
Contributor:
I am a little concerned about the current implementation - the logic for picking the indices to delete looks complex to me, and you are using tf.map_fn to handle each element separately instead of doing batch processing, which could get slow.

Rethinking how we can do the index pick-up, can we do the following? Let's assume we have already done the split, so the input is a 2D RaggedTensor.

  1. Calculate how many words to pick for each sequence; call it num_row_i for the ith row.
  2. For each row, pick num_row_i indices.
  3. For the indices selected in 2), do a tf.gather to gather the corresponding words into a new tensor - these are the words we want to keep.
  4. reduce_join the output of 3) to reconstruct the string for each row.

In this way, we can keep the batch dimension along with the computation. But I am not 100% sure this works, WDYT?
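
A rough sketch of this gather-based proposal, just to make it concrete (untested against the PR, with a deterministic stand-in for the random index pick):

import tensorflow as tf

# Assume splitting has already happened: a 2D RaggedTensor of words.
words = tf.ragged.constant(
    [["I", "like", "to", "fly"], ["fly", "some", "kites"]]
)

# 1) How many words to keep per row (here: all but one).
num_keep = words.row_lengths() - 1

# 2) Pick indices to keep per row (deterministically here: the first
#    num_keep positions; a real version would sample randomly).
keep_idx = tf.ragged.range(num_keep)

# 3) Gather the kept words while preserving the batch dimension.
kept = tf.gather(words, keep_idx, batch_dims=1)

# 4) Join each row back into a string.
print(tf.strings.reduce_join(kept, axis=-1, separator=" "))
# tf.Tensor([b'I like to' b'fly some'], shape=(2,), dtype=string)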

Member:
Let's chat about this tomorrow! Overall I think we need to make sure this can trace, but we don't actually need to care about performance that much. If you are running an augmentation, it's because you have very little data. You don't actually care about speed.

I think the reason this is so complex (here and in tf.text) is that there is no good way to uniformly sample from a pool of candidate indices with a max cap without using a map call. It's actually the shuffle call that we can't run without map.

_shuffle_and_trim,
(positions, num_to_select),
fn_output_signature=tf.RaggedTensorSpec(
ragged_rank=positions.ragged_rank - 1, dtype=positions.dtype
),
)
selected_for_mask.flat_values.set_shape([None])
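
A toy run of the bookkeeping above (a sketch, mirroring the PR's calls on a tiny batch): with_flat_values preserves the batch row structure, and tf.map_fn then shuffles and trims one row at a time, which speaks to the batch-dim question.

import tensorflow as tf

ragged_words = tf.ragged.constant([["a", "b", "c"], ["d", "e"]])
positions = ragged_words.with_flat_values(
    tf.range(tf.size(ragged_words.flat_values))
)
print(positions)  # <tf.RaggedTensor [[0, 1, 2], [3, 4]]>

num_to_select = tf.constant([1, 2], dtype=tf.int64)
selected = tf.map_fn(
    lambda x: tf.random.shuffle(x[0])[: x[1]],
    (positions, num_to_select),
    fn_output_signature=tf.RaggedTensorSpec(
        ragged_rank=0, dtype=positions.dtype
    ),
)
print(selected)  # e.g. <tf.RaggedTensor [[1], [4, 3]]>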

# Construct the mask, which is a boolean RaggedTensor.
# Scatter 0's to positions that have been selected for deletion.
update_values = tf.zeros_like(
selected_for_mask.flat_values, "int32"
)
update_indices = selected_for_mask.flat_values
update_indices = tf.expand_dims(update_indices, -1)
update_indices = tf.cast(update_indices, "int32")
mask_flat = tf.ones_like(ragged_words.flat_values, dtype="int32")
mask_flat = tf.tensor_scatter_nd_update(
mask_flat, update_indices, update_values
)
mask = tf.cast(ragged_words.with_flat_values(mask_flat), "bool")
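
And the scatter step on a toy flat layout (sketch):

import tensorflow as tf

# Five words total; the words at flat positions 1 and 3 were selected.
mask_flat = tf.ones([5], dtype="int32")
update_indices = tf.constant([[1], [3]])
mask_flat = tf.tensor_scatter_nd_update(
    mask_flat, update_indices, tf.zeros([2], "int32")
)
print(mask_flat)  # tf.Tensor([1 0 1 0 1], shape=(5,), dtype=int32)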

ragged_words = tf.ragged.boolean_mask(ragged_words, mask)
deleted = tf.strings.reduce_join(
ragged_words, axis=-1, separator=" "
)
if scalar_input:
deleted = tf.squeeze(deleted, 0)
if isString:
deleted = deleted[0]
return deleted

def get_config(self):
config = super().get_config()
config.update(
{
"probability": self.probability,
"max_deletions": self.max_deletions,
}
)
return config
87 changes: 87 additions & 0 deletions keras_nlp/layers/random_word_deletion_test.py
@@ -0,0 +1,87 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Random Word Deletion Layer."""

import tensorflow as tf

from keras_nlp.layers import random_word_deletion


class RandomDeletionTest(tf.test.TestCase):
def test_shape_with_scalar(self):
augmenter = random_word_deletion.RandomWordDeletion(
probability=0.5, max_deletions=3
)
input = ["Running Around"]
output = augmenter(input)
self.assertAllEqual(output.shape, tf.convert_to_tensor(input).shape)

def test_get_config_and_from_config(self):

augmenter = random_word_deletion.RandomWordDeletion(
probability=0.5, max_deletions=3
)

expected_config_subset = {"probability": 0.5, "max_deletions": 3}

config = augmenter.get_config()
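# Merging the expected subset into the full config and comparing verifies
# that config contains these keys and values without pinning every field.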

self.assertEqual(config, {**config, **expected_config_subset})

restored_augmenter = (
random_word_deletion.RandomWordDeletion.from_config(
config,
)
)

self.assertEqual(
restored_augmenter.get_config(),
{**config, **expected_config_subset},
)

def test_augment_first_batch_second(self):
tf.random.get_global_generator().reset_from_seed(30)
tf.random.set_seed(30)
augmenter = random_word_deletion.RandomWordDeletion(
probability=0.5, max_deletions=3
)

ds = tf.data.Dataset.from_tensor_slices(
["samurai or ninja", "keras is good", "tensorflow is a library"]
)
ds = ds.map(augmenter)
ds = ds.apply(tf.data.experimental.dense_to_ragged_batch(3))
output = ds.take(1).get_single_element()

exp_output = [b"samurai", b"is good", b"tensorflow a library"]
for i in range(output.shape[0]):
self.assertAllEqual(output[i], exp_output[i])

def test_batch_first_augment_second(self):
tf.random.get_global_generator().reset_from_seed(30)
tf.random.set_seed(30)
augmenter = random_word_deletion.RandomWordDeletion(
probability=0.5, max_deletions=3
)

ds = tf.data.Dataset.from_tensor_slices(
["samurai or ninja", "keras is good", "tensorflow is a library"]
)
ds = ds.batch(3).map(augmenter)
output = ds.take(1).get_single_element()

exp_output = [b"samurai", b"is good", b"tensorflow"]

for i in range(output.shape[0]):
self.assertAllEqual(output[i], exp_output[i])