diff --git a/keras_nlp/models/whisper/whisper_backbone.py b/keras_nlp/models/whisper/whisper_backbone.py index fcb3530252..5f5e3d3ab6 100644 --- a/keras_nlp/models/whisper/whisper_backbone.py +++ b/keras_nlp/models/whisper/whisper_backbone.py @@ -38,7 +38,7 @@ def whisper_kernel_initializer(stddev=0.02): @keras_nlp_export("keras_nlp.models.WhisperBackbone") class WhisperBackbone(Backbone): - """Whisper encoder-decoder network for speech. + """A Whisper encoder-decoder network for speech. This class implements a Transformer-based encoder-decoder model as described in @@ -48,7 +48,7 @@ class WhisperBackbone(Backbone): The default constructor gives a fully customizable, randomly initialized Whisper model with any number of layers, heads, and embedding dimensions. To load - preset architectures and weights, use the `from_preset` constructor. + preset architectures and weights, use the `from_preset()` constructor. Disclaimer: Pre-trained models are provided on an "as is" basis, without warranties or conditions of any kind. The underlying model is provided by a @@ -83,17 +83,17 @@ class WhisperBackbone(Backbone): ), } - # Randomly initialized Whisper encoder-decoder model with a custom config + # Randomly initialized Whisper encoder-decoder model with a custom config. model = keras_nlp.models.WhisperBackbone( vocabulary_size=51864, - num_layers=6, - num_heads=8, - hidden_dim=512, - intermediate_dim=2048, + num_layers=4, + num_heads=4, + hidden_dim=256, + intermediate_dim=512, max_encoder_sequence_length=128, - max_decoder_sequence_length=64, + max_decoder_sequence_length=128, ) - output = model(input_data) + model(input_data) ``` """ diff --git a/keras_nlp/models/whisper/whisper_decoder.py b/keras_nlp/models/whisper/whisper_decoder.py index d7d7d2867e..48551dbec7 100644 --- a/keras_nlp/models/whisper/whisper_decoder.py +++ b/keras_nlp/models/whisper/whisper_decoder.py @@ -20,7 +20,7 @@ @keras.utils.register_keras_serializable(package="keras_nlp") class WhisperDecoder(TransformerDecoder): - """Whisper decoder. + """A Whisper decoder. Inherits from `keras_nlp.layers.TransformerDecoder`, and overrides the `_build` method so as to remove the bias term from the key projection layer. diff --git a/keras_nlp/models/whisper/whisper_encoder.py b/keras_nlp/models/whisper/whisper_encoder.py index ec1040891c..f779da26a4 100644 --- a/keras_nlp/models/whisper/whisper_encoder.py +++ b/keras_nlp/models/whisper/whisper_encoder.py @@ -20,7 +20,7 @@ @keras.utils.register_keras_serializable(package="keras_nlp") class WhisperEncoder(TransformerEncoder): - """Whisper encoder. + """A Whisper encoder. Inherits from `keras_nlp.layers.TransformerEncoder`, and overrides the `_build` method so as to remove the bias term from the key projection layer.