diff --git a/keras_nlp/layers/cached_multi_head_attention.py b/keras_nlp/layers/cached_multi_head_attention.py index a86bb93d71..6c5cb57481 100644 --- a/keras_nlp/layers/cached_multi_head_attention.py +++ b/keras_nlp/layers/cached_multi_head_attention.py @@ -17,7 +17,10 @@ from tensorflow import keras from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice +from keras_nlp.api_export import keras_nlp_export + +@keras_nlp_export("keras_nlp.layers.CachedMultiHeadAttention") class CachedMultiHeadAttention(keras.layers.MultiHeadAttention): """MutliHeadAttention layer with cache support. diff --git a/keras_nlp/layers/f_net_encoder.py b/keras_nlp/layers/f_net_encoder.py index e84f32e1e6..79f7b72d87 100644 --- a/keras_nlp/layers/f_net_encoder.py +++ b/keras_nlp/layers/f_net_encoder.py @@ -17,10 +17,11 @@ import tensorflow as tf from tensorflow import keras +from keras_nlp.api_export import keras_nlp_export from keras_nlp.utils.keras_utils import clone_initializer -@keras.utils.register_keras_serializable(package="keras_nlp") +@keras_nlp_export("keras_nlp.layers.FNetEncoder") class FNetEncoder(keras.layers.Layer): """FNet encoder. diff --git a/keras_nlp/layers/masked_lm_head.py b/keras_nlp/layers/masked_lm_head.py index 2f749eb0c7..f26b0303e7 100644 --- a/keras_nlp/layers/masked_lm_head.py +++ b/keras_nlp/layers/masked_lm_head.py @@ -17,8 +17,10 @@ import tensorflow as tf from tensorflow import keras +from keras_nlp.api_export import keras_nlp_export -# TODO(mattdangerw): register this class as serializable. + +@keras_nlp_export("keras_nlp.layers.MaskedLMHead") class MaskedLMHead(keras.layers.Layer): """Masked Language Model (MaskedLM) head. diff --git a/keras_nlp/layers/masked_lm_mask_generator.py b/keras_nlp/layers/masked_lm_mask_generator.py index 069f35b8f6..f235430d8a 100644 --- a/keras_nlp/layers/masked_lm_mask_generator.py +++ b/keras_nlp/layers/masked_lm_mask_generator.py @@ -15,6 +15,7 @@ import tensorflow as tf from tensorflow import keras +from keras_nlp.api_export import keras_nlp_export from keras_nlp.utils.tf_utils import assert_tf_text_installed try: @@ -23,7 +24,7 @@ tf_text = None -@keras.utils.register_keras_serializable(package="keras_nlp") +@keras_nlp_export("keras_nlp.layers.MaskedLMMaskGenerator") class MaskedLMMaskGenerator(keras.layers.Layer): """Layer that applies language model masking. diff --git a/keras_nlp/layers/multi_segment_packer.py b/keras_nlp/layers/multi_segment_packer.py index ec3f118498..2cc3b65076 100644 --- a/keras_nlp/layers/multi_segment_packer.py +++ b/keras_nlp/layers/multi_segment_packer.py @@ -17,6 +17,7 @@ import tensorflow as tf from tensorflow import keras +from keras_nlp.api_export import keras_nlp_export from keras_nlp.utils.tf_utils import assert_tf_text_installed try: @@ -25,7 +26,7 @@ tf_text = None -@keras.utils.register_keras_serializable(package="keras_nlp") +@keras_nlp_export("keras_nlp.layers.MultiSegmentPacker") class MultiSegmentPacker(keras.layers.Layer): """Packs multiple sequences into a single fixed width model input. diff --git a/keras_nlp/layers/position_embedding.py b/keras_nlp/layers/position_embedding.py index 64b041732f..2602b275c1 100644 --- a/keras_nlp/layers/position_embedding.py +++ b/keras_nlp/layers/position_embedding.py @@ -17,8 +17,10 @@ import tensorflow as tf from tensorflow import keras +from keras_nlp.api_export import keras_nlp_export -@keras.utils.register_keras_serializable(package="keras_nlp") + +@keras_nlp_export("keras_nlp.layers.PositionEmbedding") class PositionEmbedding(keras.layers.Layer): """A layer which learns a position embedding for inputs sequences. diff --git a/keras_nlp/layers/random_deletion.py b/keras_nlp/layers/random_deletion.py index f4626d156a..a2d4305d6a 100644 --- a/keras_nlp/layers/random_deletion.py +++ b/keras_nlp/layers/random_deletion.py @@ -16,8 +16,10 @@ import tensorflow as tf from tensorflow import keras +from keras_nlp.api_export import keras_nlp_export -@keras.utils.register_keras_serializable(package="keras_nlp") + +@keras_nlp_export("keras_nlp.layers.RandomDeletion") class RandomDeletion(keras.layers.Layer): """Augments input by randomly deleting tokens. diff --git a/keras_nlp/layers/random_swap.py b/keras_nlp/layers/random_swap.py index 67ff52103f..f7460e3a60 100644 --- a/keras_nlp/layers/random_swap.py +++ b/keras_nlp/layers/random_swap.py @@ -16,7 +16,10 @@ import tensorflow as tf from tensorflow import keras +from keras_nlp.api_export import keras_nlp_export + +@keras_nlp_export("keras_nlp.layers.RandomSwap") class RandomSwap(keras.layers.Layer): """Augments input by randomly swapping words. diff --git a/keras_nlp/layers/sine_position_encoding.py b/keras_nlp/layers/sine_position_encoding.py index 3fabb5426c..2305d90bbe 100644 --- a/keras_nlp/layers/sine_position_encoding.py +++ b/keras_nlp/layers/sine_position_encoding.py @@ -17,8 +17,10 @@ import tensorflow as tf from tensorflow import keras +from keras_nlp.api_export import keras_nlp_export -@keras.utils.register_keras_serializable(package="keras_nlp") + +@keras_nlp_export("keras_nlp.layers.SinePositionEncoding") class SinePositionEncoding(keras.layers.Layer): """Sinusoidal positional encoding layer. diff --git a/keras_nlp/layers/start_end_packer.py b/keras_nlp/layers/start_end_packer.py index 028d6ff42a..a0102d6776 100644 --- a/keras_nlp/layers/start_end_packer.py +++ b/keras_nlp/layers/start_end_packer.py @@ -17,8 +17,10 @@ import tensorflow as tf from tensorflow import keras +from keras_nlp.api_export import keras_nlp_export -@keras.utils.register_keras_serializable(package="keras_nlp") + +@keras_nlp_export("keras_nlp.layers.StartEndPacker") class StartEndPacker(keras.layers.Layer): """Adds start and end tokens to a sequence and pads to a fixed length. diff --git a/keras_nlp/layers/token_and_position_embedding.py b/keras_nlp/layers/token_and_position_embedding.py index 5f52c3eb2a..b1cc805871 100644 --- a/keras_nlp/layers/token_and_position_embedding.py +++ b/keras_nlp/layers/token_and_position_embedding.py @@ -17,10 +17,11 @@ from tensorflow import keras import keras_nlp.layers +from keras_nlp.api_export import keras_nlp_export from keras_nlp.utils.keras_utils import clone_initializer -@keras.utils.register_keras_serializable(package="keras_nlp") +@keras_nlp_export("keras_nlp.layers.TokenAndPositionEmbedding") class TokenAndPositionEmbedding(keras.layers.Layer): """A layer which sums a token and position embedding. diff --git a/keras_nlp/layers/transformer_decoder.py b/keras_nlp/layers/transformer_decoder.py index 469c0a4f6f..8e6f0ba804 100644 --- a/keras_nlp/layers/transformer_decoder.py +++ b/keras_nlp/layers/transformer_decoder.py @@ -17,6 +17,7 @@ import tensorflow as tf from tensorflow import keras +from keras_nlp.api_export import keras_nlp_export from keras_nlp.layers.cached_multi_head_attention import ( CachedMultiHeadAttention, ) @@ -28,7 +29,7 @@ ) -@keras.utils.register_keras_serializable(package="keras_nlp") +@keras_nlp_export("keras_nlp.layers.TransformerDecoder") class TransformerDecoder(keras.layers.Layer): """Transformer decoder.