diff --git a/keras_nlp/layers/random_deletion_test.py b/keras_nlp/layers/random_deletion_test.py index dc9ac99e6d..9f028b5700 100644 --- a/keras_nlp/layers/random_deletion_test.py +++ b/keras_nlp/layers/random_deletion_test.py @@ -16,7 +16,7 @@ import tensorflow as tf from tensorflow import keras -from keras_nlp.layers import RandomDeletion +from keras_nlp.layers.random_deletion import RandomDeletion class RandomDeletionTest(tf.test.TestCase): diff --git a/keras_nlp/layers/random_swap_test.py b/keras_nlp/layers/random_swap_test.py index 7f212c2a89..d03772e814 100644 --- a/keras_nlp/layers/random_swap_test.py +++ b/keras_nlp/layers/random_swap_test.py @@ -16,7 +16,7 @@ import tensorflow as tf from tensorflow import keras -from keras_nlp.layers import RandomSwap +from keras_nlp.layers.random_swap import RandomSwap class RandomSwapTest(tf.test.TestCase): diff --git a/keras_nlp/layers/token_and_position_embedding.py b/keras_nlp/layers/token_and_position_embedding.py index b1cc805871..9eb80351ef 100644 --- a/keras_nlp/layers/token_and_position_embedding.py +++ b/keras_nlp/layers/token_and_position_embedding.py @@ -16,8 +16,8 @@ from tensorflow import keras -import keras_nlp.layers from keras_nlp.api_export import keras_nlp_export +from keras_nlp.layers.position_embedding import PositionEmbedding from keras_nlp.utils.keras_utils import clone_initializer @@ -96,7 +96,7 @@ def __init__( name="token_embedding" + str(keras.backend.get_uid("token_embedding")), ) - self.position_embedding = keras_nlp.layers.PositionEmbedding( + self.position_embedding = PositionEmbedding( sequence_length=sequence_length, initializer=clone_initializer(self.embeddings_initializer), name="position_embedding" diff --git a/keras_nlp/layers/token_and_position_embedding_test.py b/keras_nlp/layers/token_and_position_embedding_test.py index 455d2899d6..bab027eba4 100644 --- a/keras_nlp/layers/token_and_position_embedding_test.py +++ b/keras_nlp/layers/token_and_position_embedding_test.py @@ -19,7 +19,9 @@ from absl.testing import parameterized from tensorflow import keras -from keras_nlp.layers import TokenAndPositionEmbedding +from keras_nlp.layers.token_and_position_embedding import ( + TokenAndPositionEmbedding, +) class TokenAndPositionEmbeddingTest(tf.test.TestCase, parameterized.TestCase): diff --git a/keras_nlp/metrics/bleu_test.py b/keras_nlp/metrics/bleu_test.py index 20c2c815f6..fbbc6b8f19 100644 --- a/keras_nlp/metrics/bleu_test.py +++ b/keras_nlp/metrics/bleu_test.py @@ -17,8 +17,8 @@ import tensorflow as tf from tensorflow import keras -from keras_nlp.metrics import Bleu -from keras_nlp.tokenizers import ByteTokenizer +from keras_nlp.metrics.bleu import Bleu +from keras_nlp.tokenizers.byte_tokenizer import ByteTokenizer class BleuTest(tf.test.TestCase): diff --git a/keras_nlp/metrics/edit_distance_test.py b/keras_nlp/metrics/edit_distance_test.py index 743eab84a7..04a55bfbc4 100644 --- a/keras_nlp/metrics/edit_distance_test.py +++ b/keras_nlp/metrics/edit_distance_test.py @@ -17,7 +17,7 @@ import tensorflow as tf from tensorflow import keras -from keras_nlp.metrics import EditDistance +from keras_nlp.metrics.edit_distance import EditDistance class EditDistanceTest(tf.test.TestCase): diff --git a/keras_nlp/metrics/perplexity_test.py b/keras_nlp/metrics/perplexity_test.py index d4d5da9244..311d9fb092 100644 --- a/keras_nlp/metrics/perplexity_test.py +++ b/keras_nlp/metrics/perplexity_test.py @@ -16,7 +16,7 @@ import tensorflow as tf -from keras_nlp.metrics import Perplexity +from keras_nlp.metrics.perplexity import Perplexity class PerplexityTest(tf.test.TestCase): diff --git a/keras_nlp/metrics/rouge_l_test.py b/keras_nlp/metrics/rouge_l_test.py index 31f1661a7f..bb327bb166 100644 --- a/keras_nlp/metrics/rouge_l_test.py +++ b/keras_nlp/metrics/rouge_l_test.py @@ -17,7 +17,7 @@ import tensorflow as tf from tensorflow import keras -from keras_nlp.metrics import RougeL +from keras_nlp.metrics.rouge_l import RougeL class RougeLTest(tf.test.TestCase): diff --git a/keras_nlp/metrics/rouge_n_test.py b/keras_nlp/metrics/rouge_n_test.py index 1fb2fd5534..945ec89079 100644 --- a/keras_nlp/metrics/rouge_n_test.py +++ b/keras_nlp/metrics/rouge_n_test.py @@ -17,7 +17,7 @@ import tensorflow as tf from tensorflow import keras -from keras_nlp.metrics import RougeN +from keras_nlp.metrics.rouge_n import RougeN class RougeNTest(tf.test.TestCase): diff --git a/keras_nlp/models/gpt2/gpt2_backbone.py b/keras_nlp/models/gpt2/gpt2_backbone.py index 5d81615b2a..d2eddc6484 100644 --- a/keras_nlp/models/gpt2/gpt2_backbone.py +++ b/keras_nlp/models/gpt2/gpt2_backbone.py @@ -20,8 +20,8 @@ from tensorflow import keras from keras_nlp.api_export import keras_nlp_export -from keras_nlp.layers import PositionEmbedding -from keras_nlp.layers import TransformerDecoder +from keras_nlp.layers.position_embedding import PositionEmbedding +from keras_nlp.layers.transformer_decoder import TransformerDecoder from keras_nlp.models.backbone import Backbone from keras_nlp.models.gpt2.gpt2_presets import backbone_presets from keras_nlp.utils.python_utils import classproperty diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py index de8f438712..b772dfa5be 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py @@ -17,7 +17,6 @@ import tensorflow as tf -from keras_nlp import samplers from keras_nlp.api_export import keras_nlp_export from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone from keras_nlp.models.gpt2.gpt2_causal_lm_preprocessor import ( @@ -25,6 +24,7 @@ ) from keras_nlp.models.gpt2.gpt2_presets import backbone_presets from keras_nlp.models.task import Task +from keras_nlp.samplers.serialization import get as get_sampler from keras_nlp.utils.keras_utils import is_xla_compatible from keras_nlp.utils.python_utils import classproperty from keras_nlp.utils.tf_utils import truncate_at_token @@ -194,7 +194,7 @@ def __init__( self.preprocessor = preprocessor self.generate_function = None # Private sampler set by compile. - self._sampler = samplers.get("top_k") + self._sampler = get_sampler("top_k") @classproperty def presets(cls): @@ -282,7 +282,7 @@ def compile( jit_compile=jit_compile and xla_compatible and not run_eagerly, **kwargs, ) - self._sampler = samplers.get(sampler) + self._sampler = get_sampler(sampler) # Clear the compiled generate function. self.generate_function = None diff --git a/keras_nlp/models/roberta/roberta_backbone.py b/keras_nlp/models/roberta/roberta_backbone.py index 4683d54c42..909ef20cda 100644 --- a/keras_nlp/models/roberta/roberta_backbone.py +++ b/keras_nlp/models/roberta/roberta_backbone.py @@ -20,8 +20,10 @@ from tensorflow import keras from keras_nlp.api_export import keras_nlp_export -from keras_nlp.layers import TokenAndPositionEmbedding -from keras_nlp.layers import TransformerEncoder +from keras_nlp.layers.token_and_position_embedding import ( + TokenAndPositionEmbedding, +) +from keras_nlp.layers.transformer_encoder import TransformerEncoder from keras_nlp.models.backbone import Backbone from keras_nlp.models.roberta.roberta_presets import backbone_presets from keras_nlp.utils.python_utils import classproperty diff --git a/keras_nlp/samplers/__init__.py b/keras_nlp/samplers/__init__.py index 82fe003547..e76017b8db 100644 --- a/keras_nlp/samplers/__init__.py +++ b/keras_nlp/samplers/__init__.py @@ -18,75 +18,8 @@ from keras_nlp.samplers.greedy_sampler import GreedySampler from keras_nlp.samplers.random_sampler import RandomSampler from keras_nlp.samplers.sampler import Sampler +from keras_nlp.samplers.serialization import deserialize +from keras_nlp.samplers.serialization import get +from keras_nlp.samplers.serialization import serialize from keras_nlp.samplers.top_k_sampler import TopKSampler from keras_nlp.samplers.top_p_sampler import TopPSampler - - -def serialize(sampler): - return keras.utils.serialize_keras_object(sampler) - - -def deserialize(config, custom_objects=None): - """Return a `Sampler` object from its config.""" - all_classes = { - "beam": BeamSampler, - "greedy": GreedySampler, - "random": RandomSampler, - "top_k": TopKSampler, - "top_p": TopPSampler, - } - return keras.utils.deserialize_keras_object( - config, - module_objects=all_classes, - custom_objects=custom_objects, - printable_module_name="samplers", - ) - - -def get(identifier): - """Retrieve a KerasNLP sampler by the identifier. - - The `identifier` may be the string name of a sampler class or class. - - >>> identifier = 'greedy' - >>> sampler = keras_nlp.samplers.get(identifier) - - You can also specify `config` of the sampler to this function by passing - dict containing `class_name` and `config` as an identifier. Also note that - the `class_name` must map to a `Sampler` class. - - >>> cfg = {'class_name': 'keras_nlp>GreedySampler', 'config': {}} - >>> sampler = keras_nlp.samplers.get(cfg) - - In the case that the `identifier` is a class, this method will return a new - instance of the class by its constructor. - - Args: - identifier: String or dict that contains the sampler name or - configurations. - - Returns: - Sampler instance base on the input identifier. - - Raises: - ValueError: If the input identifier is not a supported type or in a bad - format. - """ - - if identifier is None: - return None - if isinstance(identifier, dict): - return deserialize(identifier) - elif isinstance(identifier, str): - if not identifier.islower(): - raise KeyError( - "`keras_nlp.samplers.get()` must take a lowercase string " - f"identifier, but received: {identifier}." - ) - return deserialize(identifier) - elif callable(identifier): - return identifier - else: - raise ValueError( - "Could not interpret sampler identifier: " + str(identifier) - ) diff --git a/keras_nlp/samplers/serialization.py b/keras_nlp/samplers/serialization.py new file mode 100644 index 0000000000..56adb7d226 --- /dev/null +++ b/keras_nlp/samplers/serialization.py @@ -0,0 +1,95 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tensorflow import keras + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.samplers.beam_sampler import BeamSampler +from keras_nlp.samplers.greedy_sampler import GreedySampler +from keras_nlp.samplers.random_sampler import RandomSampler +from keras_nlp.samplers.top_k_sampler import TopKSampler +from keras_nlp.samplers.top_p_sampler import TopPSampler + + +@keras_nlp_export("keras_nlp.samplers.serialize") +def serialize(sampler): + return keras.utils.serialize_keras_object(sampler) + + +@keras_nlp_export("keras_nlp.samplers.deserialize") +def deserialize(config, custom_objects=None): + """Return a `Sampler` object from its config.""" + all_classes = { + "beam": BeamSampler, + "greedy": GreedySampler, + "random": RandomSampler, + "top_k": TopKSampler, + "top_p": TopPSampler, + } + return keras.utils.deserialize_keras_object( + config, + module_objects=all_classes, + custom_objects=custom_objects, + printable_module_name="samplers", + ) + + +@keras_nlp_export("keras_nlp.samplers.get") +def get(identifier): + """Retrieve a KerasNLP sampler by the identifier. + + The `identifier` may be the string name of a sampler class or class. + + >>> identifier = 'greedy' + >>> sampler = keras_nlp.samplers.get(identifier) + + You can also specify `config` of the sampler to this function by passing + dict containing `class_name` and `config` as an identifier. Also note that + the `class_name` must map to a `Sampler` class. + + >>> cfg = {'class_name': 'keras_nlp>GreedySampler', 'config': {}} + >>> sampler = keras_nlp.samplers.get(cfg) + + In the case that the `identifier` is a class, this method will return a new + instance of the class by its constructor. + + Args: + identifier: String or dict that contains the sampler name or + configurations. + + Returns: + Sampler instance base on the input identifier. + + Raises: + ValueError: If the input identifier is not a supported type or in a bad + format. + """ + + if identifier is None: + return None + if isinstance(identifier, dict): + return deserialize(identifier) + elif isinstance(identifier, str): + if not identifier.islower(): + raise KeyError( + "`keras_nlp.samplers.get()` must take a lowercase string " + f"identifier, but received: {identifier}." + ) + return deserialize(identifier) + elif callable(identifier): + return identifier + else: + raise ValueError( + "Could not interpret sampler identifier: " + str(identifier) + ) diff --git a/keras_nlp/samplers/sampler_test.py b/keras_nlp/samplers/serialization_test.py similarity index 64% rename from keras_nlp/samplers/sampler_test.py rename to keras_nlp/samplers/serialization_test.py index 49dff87a9d..b9fcff7e91 100644 --- a/keras_nlp/samplers/sampler_test.py +++ b/keras_nlp/samplers/serialization_test.py @@ -15,34 +15,34 @@ import tensorflow as tf -import keras_nlp +from keras_nlp.samplers.serialization import deserialize +from keras_nlp.samplers.serialization import get +from keras_nlp.samplers.serialization import serialize from keras_nlp.samplers.top_k_sampler import TopKSampler -class SamplerTest(tf.test.TestCase): +class SerializationTest(tf.test.TestCase): def test_serialization(self): sampler = TopKSampler(k=5) - restored = keras_nlp.samplers.deserialize( - keras_nlp.samplers.serialize(sampler) - ) + restored = deserialize(serialize(sampler)) self.assertDictEqual(sampler.get_config(), restored.get_config()) def test_get(self): # Test get from string. identifier = "top_k" - sampler = keras_nlp.samplers.get(identifier) + sampler = get(identifier) self.assertIsInstance(sampler, TopKSampler) # Test dict identifier. - original_sampler = keras_nlp.samplers.TopKSampler(k=7) - config = keras_nlp.samplers.serialize(original_sampler) - restored_sampler = keras_nlp.samplers.get(config) + original_sampler = TopKSampler(k=7) + config = serialize(original_sampler) + restored_sampler = get(config) self.assertDictEqual( - keras_nlp.samplers.serialize(restored_sampler), - keras_nlp.samplers.serialize(original_sampler), + serialize(restored_sampler), + serialize(original_sampler), ) # Test identifier is already a sampler instance. - original_sampler = keras_nlp.samplers.TopKSampler(k=7) - restored_sampler = keras_nlp.samplers.get(original_sampler) + original_sampler = TopKSampler(k=7) + restored_sampler = get(original_sampler) self.assertEqual(original_sampler, restored_sampler)