Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions keras_nlp/layers/cached_multi_head_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
from tensorflow import keras
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice

from keras_nlp.api_export import keras_nlp_export


@keras_nlp_export("keras_nlp.layers.CachedMultiHeadAttention")
class CachedMultiHeadAttention(keras.layers.MultiHeadAttention):
"""MutliHeadAttention layer with cache support.

Expand Down
3 changes: 2 additions & 1 deletion keras_nlp/layers/f_net_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
import tensorflow as tf
from tensorflow import keras

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.utils.keras_utils import clone_initializer


@keras.utils.register_keras_serializable(package="keras_nlp")
@keras_nlp_export("keras_nlp.layers.FNetEncoder")
class FNetEncoder(keras.layers.Layer):
"""FNet encoder.

Expand Down
4 changes: 3 additions & 1 deletion keras_nlp/layers/masked_lm_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
import tensorflow as tf
from tensorflow import keras

from keras_nlp.api_export import keras_nlp_export

# TODO(mattdangerw): register this class as serializable.

@keras_nlp_export("keras_nlp.layers.MaskedLMHead")
class MaskedLMHead(keras.layers.Layer):
"""Masked Language Model (MaskedLM) head.

Expand Down
3 changes: 2 additions & 1 deletion keras_nlp/layers/masked_lm_mask_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import tensorflow as tf
from tensorflow import keras

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.utils.tf_utils import assert_tf_text_installed

try:
Expand All @@ -23,7 +24,7 @@
tf_text = None


@keras.utils.register_keras_serializable(package="keras_nlp")
@keras_nlp_export("keras_nlp.layers.MaskedLMMaskGenerator")
class MaskedLMMaskGenerator(keras.layers.Layer):
"""Layer that applies language model masking.

Expand Down
3 changes: 2 additions & 1 deletion keras_nlp/layers/multi_segment_packer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import tensorflow as tf
from tensorflow import keras

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.utils.tf_utils import assert_tf_text_installed

try:
Expand All @@ -25,7 +26,7 @@
tf_text = None


@keras.utils.register_keras_serializable(package="keras_nlp")
@keras_nlp_export("keras_nlp.layers.MultiSegmentPacker")
class MultiSegmentPacker(keras.layers.Layer):
"""Packs multiple sequences into a single fixed width model input.

Expand Down
4 changes: 3 additions & 1 deletion keras_nlp/layers/position_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
import tensorflow as tf
from tensorflow import keras

from keras_nlp.api_export import keras_nlp_export

@keras.utils.register_keras_serializable(package="keras_nlp")

@keras_nlp_export("keras_nlp.layers.PositionEmbedding")
class PositionEmbedding(keras.layers.Layer):
"""A layer which learns a position embedding for inputs sequences.

Expand Down
4 changes: 3 additions & 1 deletion keras_nlp/layers/random_deletion.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
import tensorflow as tf
from tensorflow import keras

from keras_nlp.api_export import keras_nlp_export

@keras.utils.register_keras_serializable(package="keras_nlp")

@keras_nlp_export("keras_nlp.layers.RandomDeletion")
class RandomDeletion(keras.layers.Layer):
"""Augments input by randomly deleting tokens.

Expand Down
3 changes: 3 additions & 0 deletions keras_nlp/layers/random_swap.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
import tensorflow as tf
from tensorflow import keras

from keras_nlp.api_export import keras_nlp_export


@keras_nlp_export("keras_nlp.layers.RandomSwap")
class RandomSwap(keras.layers.Layer):
"""Augments input by randomly swapping words.

Expand Down
4 changes: 3 additions & 1 deletion keras_nlp/layers/sine_position_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
import tensorflow as tf
from tensorflow import keras

from keras_nlp.api_export import keras_nlp_export

@keras.utils.register_keras_serializable(package="keras_nlp")

@keras_nlp_export("keras_nlp.layers.SinePositionEncoding")
class SinePositionEncoding(keras.layers.Layer):
"""Sinusoidal positional encoding layer.

Expand Down
4 changes: 3 additions & 1 deletion keras_nlp/layers/start_end_packer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
import tensorflow as tf
from tensorflow import keras

from keras_nlp.api_export import keras_nlp_export

@keras.utils.register_keras_serializable(package="keras_nlp")

@keras_nlp_export("keras_nlp.layers.StartEndPacker")
class StartEndPacker(keras.layers.Layer):
"""Adds start and end tokens to a sequence and pads to a fixed length.

Expand Down
3 changes: 2 additions & 1 deletion keras_nlp/layers/token_and_position_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
from tensorflow import keras

import keras_nlp.layers
from keras_nlp.api_export import keras_nlp_export
from keras_nlp.utils.keras_utils import clone_initializer


@keras.utils.register_keras_serializable(package="keras_nlp")
@keras_nlp_export("keras_nlp.layers.TokenAndPositionEmbedding")
class TokenAndPositionEmbedding(keras.layers.Layer):
"""A layer which sums a token and position embedding.

Expand Down
3 changes: 2 additions & 1 deletion keras_nlp/layers/transformer_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import tensorflow as tf
from tensorflow import keras

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.layers.cached_multi_head_attention import (
CachedMultiHeadAttention,
)
Expand All @@ -28,7 +29,7 @@
)


@keras.utils.register_keras_serializable(package="keras_nlp")
@keras_nlp_export("keras_nlp.layers.TransformerDecoder")
class TransformerDecoder(keras.layers.Layer):
"""Transformer decoder.

Expand Down