Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion keras_nlp/layers/random_deletion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import tensorflow as tf
from tensorflow import keras

from keras_nlp.layers import RandomDeletion
from keras_nlp.layers.random_deletion import RandomDeletion


class RandomDeletionTest(tf.test.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion keras_nlp/layers/random_swap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import tensorflow as tf
from tensorflow import keras

from keras_nlp.layers import RandomSwap
from keras_nlp.layers.random_swap import RandomSwap


class RandomSwapTest(tf.test.TestCase):
Expand Down
4 changes: 2 additions & 2 deletions keras_nlp/layers/token_and_position_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

from tensorflow import keras

import keras_nlp.layers
from keras_nlp.api_export import keras_nlp_export
from keras_nlp.layers.position_embedding import PositionEmbedding
from keras_nlp.utils.keras_utils import clone_initializer


Expand Down Expand Up @@ -96,7 +96,7 @@ def __init__(
name="token_embedding"
+ str(keras.backend.get_uid("token_embedding")),
)
self.position_embedding = keras_nlp.layers.PositionEmbedding(
self.position_embedding = PositionEmbedding(
sequence_length=sequence_length,
initializer=clone_initializer(self.embeddings_initializer),
name="position_embedding"
Expand Down
4 changes: 3 additions & 1 deletion keras_nlp/layers/token_and_position_embedding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
from absl.testing import parameterized
from tensorflow import keras

from keras_nlp.layers import TokenAndPositionEmbedding
from keras_nlp.layers.token_and_position_embedding import (
TokenAndPositionEmbedding,
)


class TokenAndPositionEmbeddingTest(tf.test.TestCase, parameterized.TestCase):
Expand Down
4 changes: 2 additions & 2 deletions keras_nlp/metrics/bleu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
import tensorflow as tf
from tensorflow import keras

from keras_nlp.metrics import Bleu
from keras_nlp.tokenizers import ByteTokenizer
from keras_nlp.metrics.bleu import Bleu
from keras_nlp.tokenizers.byte_tokenizer import ByteTokenizer


class BleuTest(tf.test.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion keras_nlp/metrics/edit_distance_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import tensorflow as tf
from tensorflow import keras

from keras_nlp.metrics import EditDistance
from keras_nlp.metrics.edit_distance import EditDistance


class EditDistanceTest(tf.test.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion keras_nlp/metrics/perplexity_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import tensorflow as tf

from keras_nlp.metrics import Perplexity
from keras_nlp.metrics.perplexity import Perplexity


class PerplexityTest(tf.test.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion keras_nlp/metrics/rouge_l_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import tensorflow as tf
from tensorflow import keras

from keras_nlp.metrics import RougeL
from keras_nlp.metrics.rouge_l import RougeL


class RougeLTest(tf.test.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion keras_nlp/metrics/rouge_n_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import tensorflow as tf
from tensorflow import keras

from keras_nlp.metrics import RougeN
from keras_nlp.metrics.rouge_n import RougeN


class RougeNTest(tf.test.TestCase):
Expand Down
4 changes: 2 additions & 2 deletions keras_nlp/models/gpt2/gpt2_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from tensorflow import keras

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.layers import PositionEmbedding
from keras_nlp.layers import TransformerDecoder
from keras_nlp.layers.position_embedding import PositionEmbedding
from keras_nlp.layers.transformer_decoder import TransformerDecoder
from keras_nlp.models.backbone import Backbone
from keras_nlp.models.gpt2.gpt2_presets import backbone_presets
from keras_nlp.utils.python_utils import classproperty
Expand Down
6 changes: 3 additions & 3 deletions keras_nlp/models/gpt2/gpt2_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@

import tensorflow as tf

from keras_nlp import samplers
from keras_nlp.api_export import keras_nlp_export
from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone
from keras_nlp.models.gpt2.gpt2_causal_lm_preprocessor import (
GPT2CausalLMPreprocessor,
)
from keras_nlp.models.gpt2.gpt2_presets import backbone_presets
from keras_nlp.models.task import Task
from keras_nlp.samplers.serialization import get as get_sampler
from keras_nlp.utils.keras_utils import is_xla_compatible
from keras_nlp.utils.python_utils import classproperty
from keras_nlp.utils.tf_utils import truncate_at_token
Expand Down Expand Up @@ -194,7 +194,7 @@ def __init__(
self.preprocessor = preprocessor
self.generate_function = None
# Private sampler set by compile.
self._sampler = samplers.get("top_k")
self._sampler = get_sampler("top_k")

@classproperty
def presets(cls):
Expand Down Expand Up @@ -282,7 +282,7 @@ def compile(
jit_compile=jit_compile and xla_compatible and not run_eagerly,
**kwargs,
)
self._sampler = samplers.get(sampler)
self._sampler = get_sampler(sampler)
# Clear the compiled generate function.
self.generate_function = None

Expand Down
6 changes: 4 additions & 2 deletions keras_nlp/models/roberta/roberta_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
from tensorflow import keras

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.layers import TokenAndPositionEmbedding
from keras_nlp.layers import TransformerEncoder
from keras_nlp.layers.token_and_position_embedding import (
TokenAndPositionEmbedding,
)
from keras_nlp.layers.transformer_encoder import TransformerEncoder
from keras_nlp.models.backbone import Backbone
from keras_nlp.models.roberta.roberta_presets import backbone_presets
from keras_nlp.utils.python_utils import classproperty
Expand Down
73 changes: 3 additions & 70 deletions keras_nlp/samplers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,75 +18,8 @@
from keras_nlp.samplers.greedy_sampler import GreedySampler
from keras_nlp.samplers.random_sampler import RandomSampler
from keras_nlp.samplers.sampler import Sampler
from keras_nlp.samplers.serialization import deserialize
from keras_nlp.samplers.serialization import get
from keras_nlp.samplers.serialization import serialize
from keras_nlp.samplers.top_k_sampler import TopKSampler
from keras_nlp.samplers.top_p_sampler import TopPSampler


def serialize(sampler):
return keras.utils.serialize_keras_object(sampler)


def deserialize(config, custom_objects=None):
"""Return a `Sampler` object from its config."""
all_classes = {
"beam": BeamSampler,
"greedy": GreedySampler,
"random": RandomSampler,
"top_k": TopKSampler,
"top_p": TopPSampler,
}
return keras.utils.deserialize_keras_object(
config,
module_objects=all_classes,
custom_objects=custom_objects,
printable_module_name="samplers",
)


def get(identifier):
"""Retrieve a KerasNLP sampler by the identifier.

The `identifier` may be the string name of a sampler class or class.

>>> identifier = 'greedy'
>>> sampler = keras_nlp.samplers.get(identifier)

You can also specify `config` of the sampler to this function by passing
dict containing `class_name` and `config` as an identifier. Also note that
the `class_name` must map to a `Sampler` class.

>>> cfg = {'class_name': 'keras_nlp>GreedySampler', 'config': {}}
>>> sampler = keras_nlp.samplers.get(cfg)

In the case that the `identifier` is a class, this method will return a new
instance of the class by its constructor.

Args:
identifier: String or dict that contains the sampler name or
configurations.

Returns:
Sampler instance base on the input identifier.

Raises:
ValueError: If the input identifier is not a supported type or in a bad
format.
"""

if identifier is None:
return None
if isinstance(identifier, dict):
return deserialize(identifier)
elif isinstance(identifier, str):
if not identifier.islower():
raise KeyError(
"`keras_nlp.samplers.get()` must take a lowercase string "
f"identifier, but received: {identifier}."
)
return deserialize(identifier)
elif callable(identifier):
return identifier
else:
raise ValueError(
"Could not interpret sampler identifier: " + str(identifier)
)
95 changes: 95 additions & 0 deletions keras_nlp/samplers/serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from tensorflow import keras

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.samplers.beam_sampler import BeamSampler
from keras_nlp.samplers.greedy_sampler import GreedySampler
from keras_nlp.samplers.random_sampler import RandomSampler
from keras_nlp.samplers.top_k_sampler import TopKSampler
from keras_nlp.samplers.top_p_sampler import TopPSampler


@keras_nlp_export("keras_nlp.samplers.serialize")
def serialize(sampler):
return keras.utils.serialize_keras_object(sampler)


@keras_nlp_export("keras_nlp.samplers.deserialize")
def deserialize(config, custom_objects=None):
"""Return a `Sampler` object from its config."""
all_classes = {
"beam": BeamSampler,
"greedy": GreedySampler,
"random": RandomSampler,
"top_k": TopKSampler,
"top_p": TopPSampler,
}
return keras.utils.deserialize_keras_object(
config,
module_objects=all_classes,
custom_objects=custom_objects,
printable_module_name="samplers",
)


@keras_nlp_export("keras_nlp.samplers.get")
def get(identifier):
"""Retrieve a KerasNLP sampler by the identifier.

The `identifier` may be the string name of a sampler class or class.

>>> identifier = 'greedy'
>>> sampler = keras_nlp.samplers.get(identifier)

You can also specify `config` of the sampler to this function by passing
dict containing `class_name` and `config` as an identifier. Also note that
the `class_name` must map to a `Sampler` class.

>>> cfg = {'class_name': 'keras_nlp>GreedySampler', 'config': {}}
>>> sampler = keras_nlp.samplers.get(cfg)

In the case that the `identifier` is a class, this method will return a new
instance of the class by its constructor.

Args:
identifier: String or dict that contains the sampler name or
configurations.

Returns:
Sampler instance base on the input identifier.

Raises:
ValueError: If the input identifier is not a supported type or in a bad
format.
"""

if identifier is None:
return None
if isinstance(identifier, dict):
return deserialize(identifier)
elif isinstance(identifier, str):
if not identifier.islower():
raise KeyError(
"`keras_nlp.samplers.get()` must take a lowercase string "
f"identifier, but received: {identifier}."
)
return deserialize(identifier)
elif callable(identifier):
return identifier
else:
raise ValueError(
"Could not interpret sampler identifier: " + str(identifier)
)
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,34 @@

import tensorflow as tf

import keras_nlp
from keras_nlp.samplers.serialization import deserialize
from keras_nlp.samplers.serialization import get
from keras_nlp.samplers.serialization import serialize
from keras_nlp.samplers.top_k_sampler import TopKSampler


class SamplerTest(tf.test.TestCase):
class SerializationTest(tf.test.TestCase):
def test_serialization(self):
sampler = TopKSampler(k=5)
restored = keras_nlp.samplers.deserialize(
keras_nlp.samplers.serialize(sampler)
)
restored = deserialize(serialize(sampler))
self.assertDictEqual(sampler.get_config(), restored.get_config())

def test_get(self):
# Test get from string.
identifier = "top_k"
sampler = keras_nlp.samplers.get(identifier)
sampler = get(identifier)
self.assertIsInstance(sampler, TopKSampler)

# Test dict identifier.
original_sampler = keras_nlp.samplers.TopKSampler(k=7)
config = keras_nlp.samplers.serialize(original_sampler)
restored_sampler = keras_nlp.samplers.get(config)
original_sampler = TopKSampler(k=7)
config = serialize(original_sampler)
restored_sampler = get(config)
self.assertDictEqual(
keras_nlp.samplers.serialize(restored_sampler),
keras_nlp.samplers.serialize(original_sampler),
serialize(restored_sampler),
serialize(original_sampler),
)

# Test identifier is already a sampler instance.
original_sampler = keras_nlp.samplers.TopKSampler(k=7)
restored_sampler = keras_nlp.samplers.get(original_sampler)
original_sampler = TopKSampler(k=7)
restored_sampler = get(original_sampler)
self.assertEqual(original_sampler, restored_sampler)