Skip to content

Commit

Permalink
Random Deletion Layer (#214)
Browse files Browse the repository at this point in the history
* Random Deletion Working

* Added to init

* Working

* Working

* Current Status

* Working Layer More Tests to be Added

* Fixed Scalar Case

* Added Comments

* Minor Fixes

* Major Refactors and Fixes, ToDo - Docs, Tests

* Fixed Shape Issues for Scalar Lists

* Finalized Tests and DocString

* Ran Stylers Added More Descriptive DocString

* Fixed Failing Docstring Tests

* Removed Map Call and Unsupported Test

* Shape Fixes

* Working

* Working

* Changing Parent Class

* Changes

* Formatter Ran

* Finalized

* Addressed Review Comments

* Formatter

* Added new Tests

* Ran Formatter

* Skip Works

* New Randomness

* All Testing Done

* Review Changes

* Addressed all Review Comments

* Copy edits for docstrings

* Finishes

* Changed Tokenizer Import

* Addressed Reviews

* Fix typo

Co-authored-by: Matt Watson <1389937+mattdangerw@users.noreply.github.com>
  • Loading branch information
aflah02 and mattdangerw committed Jul 27, 2022
1 parent 7c338db commit 58f4bbb
Show file tree
Hide file tree
Showing 3 changed files with 493 additions and 0 deletions.
1 change: 1 addition & 0 deletions keras_nlp/layers/__init__.py
Expand Up @@ -17,6 +17,7 @@
from keras_nlp.layers.mlm_mask_generator import MLMMaskGenerator
from keras_nlp.layers.multi_segment_packer import MultiSegmentPacker
from keras_nlp.layers.position_embedding import PositionEmbedding
from keras_nlp.layers.random_deletion import RandomDeletion
from keras_nlp.layers.sine_position_encoding import SinePositionEncoding
from keras_nlp.layers.start_end_packer import StartEndPacker
from keras_nlp.layers.token_and_position_embedding import (
Expand Down
266 changes: 266 additions & 0 deletions keras_nlp/layers/random_deletion.py
@@ -0,0 +1,266 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random

import tensorflow as tf
from tensorflow import keras
from tensorflow.python.ops.ragged import ragged_array_ops


class RandomDeletion(keras.layers.Layer):
    """Augments input by randomly deleting tokens.

    This layer comes in handy when you need to generate new data using deletion
    augmentation as described in the paper [EDA: Easy Data Augmentation
    Techniques for Boosting Performance on Text Classification Tasks]
    (https://arxiv.org/pdf/1901.11196.pdf). The layer expects the inputs to be
    pretokenized so that each token can be individually treated as a possible
    candidate for deletion.

    Input should be either a `tf.RaggedTensor` or a dense `tf.Tensor`, and
    either rank-1 or rank-2.

    Args:
        rate: The probability of a token being chosen for deletion. Must be
            between 0 and 1 (both inclusive).
        max_deletions: The maximum number of tokens to delete from each
            input sequence (row). `None` means no cap.
        skip_list: A list of token values that should not be considered
            candidates for deletion.
        skip_fn: A function that takes as input a scalar tensor token and
            returns as output a scalar tensor True/False value. A value of
            True indicates that the token should not be considered a
            candidate for deletion. This function must be traceable--it
            should consist of tensorflow operations.
        skip_py_fn: A function that takes as input a python token value and
            returns as output `True` or `False`. A value of True
            indicates that the token should not be considered a candidate
            for deletion. Unlike the `skip_fn` argument, this argument need
            not be traceable--it can be any python function.
        seed: A seed for the rng. If `None`, a random seed is drawn.

    At most one of `skip_list`, `skip_fn` and `skip_py_fn` may be provided.

    Examples:

    Word level usage.
    >>> keras.utils.set_random_seed(1337)
    >>> inputs=tf.strings.split(["Hey I like", "Keras and Tensorflow"])
    >>> augmenter=keras_nlp.layers.RandomDeletion(rate=0.4, seed=42)
    >>> augmented=augmenter(inputs)
    >>> tf.strings.reduce_join(augmented, separator=" ", axis=-1)
    <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'I like', b'and'],
    dtype=object)>

    Character level usage.
    >>> keras.utils.set_random_seed(1337)
    >>> inputs=tf.strings.unicode_split(["Hey Dude", "Speed Up"], "UTF-8")
    >>> augmenter=keras_nlp.layers.RandomDeletion(rate=0.4, seed=42)
    >>> augmented=augmenter(inputs)
    >>> tf.strings.reduce_join(augmented, axis=-1)
    <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'H Dude', b'pedUp'],
    dtype=object)>

    Usage with skip_list.
    >>> keras.utils.set_random_seed(1337)
    >>> inputs=tf.strings.split(["Hey I like", "Keras and Tensorflow"])
    >>> augmenter=keras_nlp.layers.RandomDeletion(rate=0.4,
    ...     skip_list=["Keras", "Tensorflow"], seed=42)
    >>> augmented=augmenter(inputs)
    >>> tf.strings.reduce_join(augmented, separator=" ", axis=-1)
    <tf.Tensor: shape=(2,), dtype=string,
    numpy=array([b'I like', b'Keras Tensorflow'], dtype=object)>

    Usage with skip_fn.
    >>> def skip_fn(word):
    ...     return tf.strings.regex_full_match(word, r"\\pP")
    >>> keras.utils.set_random_seed(1337)
    >>> inputs=tf.strings.split(["Hey I like", "Keras and Tensorflow"])
    >>> augmenter=keras_nlp.layers.RandomDeletion(rate=0.4,
    ...     skip_fn=skip_fn, seed=42)
    >>> augmented=augmenter(inputs)
    >>> tf.strings.reduce_join(augmented, separator=" ", axis=-1)
    <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'I like', b'and'],
    dtype=object)>

    Usage with skip_py_fn.
    >>> def skip_py_fn(word):
    ...     return len(word) < 4
    >>> keras.utils.set_random_seed(1337)
    >>> inputs=tf.strings.split(["Hey I like", "Keras and Tensorflow"])
    >>> augmenter=keras_nlp.layers.RandomDeletion(rate=0.4,
    ...     skip_py_fn=skip_py_fn, seed=42)
    >>> augmented=augmenter(inputs)
    >>> tf.strings.reduce_join(augmented, separator=" ", axis=-1)
    <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'Hey I', b'and'],
    dtype=object)>
    """

    def __init__(
        self,
        rate,
        max_deletions=None,
        skip_list=None,
        skip_fn=None,
        skip_py_fn=None,
        seed=None,
        name=None,
        **kwargs,
    ):
        """Validates the configuration and builds the rng / skip lookup.

        Raises:
            ValueError: If `dtype` is neither an integer dtype nor string,
                if `rate` is outside `[0, 1]`, or if more than one of
                `skip_list`, `skip_fn`, `skip_py_fn` is provided.
        """
        # Check dtype and provide a default.
        if kwargs.get("dtype") is None:
            kwargs["dtype"] = tf.int32
        else:
            dtype = tf.dtypes.as_dtype(kwargs["dtype"])
            if not dtype.is_integer and dtype != tf.string:
                raise ValueError(
                    "Output dtype must be one of `'string'`, `'int32'`, and "
                    f"`'int64'`. Received: dtype={dtype}"
                )

        super().__init__(name=name, **kwargs)
        self.rate = rate
        self.max_deletions = max_deletions
        self.skip_list = skip_list
        self.skip_fn = skip_fn
        self.skip_py_fn = skip_py_fn

        if self.rate > 1 or self.rate < 0:
            raise ValueError(
                "Rate must be between 0 and 1 (both inclusive). "
                f"Received: rate={rate}"
            )

        # Fewer than two `None`s means two or more skip options were given.
        if [self.skip_list, self.skip_fn, self.skip_py_fn].count(None) < 2:
            raise ValueError(
                "At most one of `skip_list`, `skip_fn`, `skip_py_fn` may be "
                "provided."
            )

        # `random.randint` requires integer bounds; a float such as `1e9` is
        # rejected with a `TypeError` on Python 3.12+.
        self.seed = random.randint(1, 10**9) if seed is None else seed
        self._generator = tf.random.Generator.from_seed(self.seed)

        if self.skip_list:
            # Maps every skipped token to True; anything else looks up False.
            self.StaticHashTable = tf.lookup.StaticHashTable(
                tf.lookup.KeyValueTensorInitializer(
                    tf.convert_to_tensor(self.skip_list),
                    tf.convert_to_tensor([True] * len(self.skip_list)),
                ),
                default_value=False,
            )

    def _compute_skip_mask(self, inputs):
        """Returns a flat boolean mask of tokens to protect, or `None`.

        The mask is aligned with `inputs.flat_values`; True marks a token that
        must not be considered for deletion. `None` is returned when no skip
        option was configured.
        """
        if self.skip_list:
            return self.StaticHashTable.lookup(inputs.flat_values)

        if self.skip_fn:
            return tf.map_fn(
                self.skip_fn, inputs.flat_values, fn_output_signature=tf.bool
            )

        if self.skip_py_fn:

            def string_fn(token):
                # String tokens arrive as bytes; hand the user a python str.
                return self.skip_py_fn(token.numpy().decode("utf-8"))

            def int_fn(token):
                return self.skip_py_fn(token.numpy())

            py_fn = string_fn if inputs.dtype == tf.string else int_fn

            return tf.map_fn(
                lambda x: tf.py_function(py_fn, [x], tf.bool),
                inputs.flat_values,
                fn_output_signature=tf.bool,
            )

        return None

    def call(self, inputs):
        """Randomly deletes tokens from `inputs`.

        Args:
            inputs: A rank-1 or rank-2 `tf.Tensor` / `tf.RaggedTensor` (or a
                value convertible to a tensor) of tokens.

        Returns:
            A `tf.RaggedTensor` with the selected tokens removed for rank-2
            input; for rank-1 input the added batch axis is squeezed away
            before returning.

        Raises:
            ValueError: If the input rank is not 1 or 2.
        """
        if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
            inputs = tf.convert_to_tensor(inputs)

        input_is_1d = False
        if inputs.shape.rank < 1 or inputs.shape.rank > 2:
            raise ValueError(
                "Input must either be rank 1 or rank 2. Received input with "
                f"rank={inputs.shape.rank}"
            )
        elif inputs.shape.rank == 1:
            input_is_1d = True
            # Add a new axis at the beginning.
            inputs = tf.expand_dims(inputs, axis=0)
        if isinstance(inputs, tf.Tensor):
            # Convert to ragged tensor.
            inputs = tf.RaggedTensor.from_tensor(inputs)

        skip_masks = self._compute_skip_mask(inputs)

        # Give every token a unique index into the flat values, then drop the
        # indices of skipped tokens so they can never be selected.
        positions_flat = tf.range(tf.size(inputs.flat_values))
        positions = inputs.with_flat_values(positions_flat)
        if skip_masks is not None:
            keep_masks = tf.logical_not(skip_masks)
            keep_masks.set_shape([None])
            positions = tf.ragged.boolean_mask(
                positions, inputs.with_flat_values(keep_masks)
            )

        # Figure out how many we are going to select. The binomial draw uses
        # the full row length (skipped tokens included); the slice in
        # `_shuffle_and_trim` takes at most the candidates left in the row.
        token_counts = tf.cast(inputs.row_lengths(), "float32")
        num_to_select = tf.random.stateless_binomial(
            shape=tf.shape(token_counts),
            seed=self._generator.make_seeds()[:, 0],
            counts=token_counts,
            probs=self.rate,
        )
        if self.max_deletions is not None:
            num_to_select = tf.math.minimum(num_to_select, self.max_deletions)
        num_to_select = tf.cast(num_to_select, "int64")

        # Shuffle and trim to items that are going to be selected.
        def _shuffle_and_trim(x):
            row_positions, top_n = x
            shuffled = tf.random.shuffle(row_positions, seed=self.seed)
            return shuffled[:top_n]

        selected_for_mask = tf.map_fn(
            _shuffle_and_trim,
            (positions, num_to_select),
            fn_output_signature=tf.RaggedTensorSpec(
                ragged_rank=positions.ragged_rank - 1, dtype=positions.dtype
            ),
        )
        selected_for_mask.flat_values.set_shape([None])

        # Construct the mask which is a boolean RT.
        # Scatter 0's to positions that have been selected for deletion.
        update_values = tf.zeros_like(selected_for_mask.flat_values, "int32")
        update_indices = selected_for_mask.flat_values
        update_indices = tf.expand_dims(update_indices, -1)
        update_indices = tf.cast(update_indices, "int32")
        mask_flat = tf.ones_like(inputs.flat_values, dtype="int32")
        mask_flat = tf.tensor_scatter_nd_update(
            mask_flat, update_indices, update_values
        )
        mask = tf.cast(inputs.with_flat_values(mask_flat), "bool")

        inputs = tf.ragged.boolean_mask(inputs, mask)

        if input_is_1d:
            # Undo the batch axis added for rank-1 input.
            inputs = tf.squeeze(inputs, axis=0)

        return inputs

    def get_config(self):
        """Returns the layer config, extended with this layer's arguments."""
        config = super().get_config()
        config.update(
            {
                "rate": self.rate,
                "max_deletions": self.max_deletions,
                "seed": self.seed,
                "skip_list": self.skip_list,
                "skip_fn": self.skip_fn,
                "skip_py_fn": self.skip_py_fn,
            }
        )
        return config

0 comments on commit 58f4bbb

Please sign in to comment.