Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions keras_nlp/models/whisper/whisper_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def whisper_kernel_initializer(stddev=0.02):

@keras_nlp_export("keras_nlp.models.WhisperBackbone")
class WhisperBackbone(Backbone):
"""Whisper encoder-decoder network for speech.
"""A Whisper encoder-decoder network for speech.

This class implements a Transformer-based encoder-decoder model as
described in
Expand All @@ -48,7 +48,7 @@ class WhisperBackbone(Backbone):

The default constructor gives a fully customizable, randomly initialized Whisper
model with any number of layers, heads, and embedding dimensions. To load
preset architectures and weights, use the `from_preset` constructor.
preset architectures and weights, use the `from_preset()` constructor.

Disclaimer: Pre-trained models are provided on an "as is" basis, without
warranties or conditions of any kind. The underlying model is provided by a
Expand Down Expand Up @@ -83,17 +83,17 @@ class WhisperBackbone(Backbone):
),
}

# Randomly initialized Whisper encoder-decoder model with a custom config
# Randomly initialized Whisper encoder-decoder model with a custom config.
model = keras_nlp.models.WhisperBackbone(
vocabulary_size=51864,
num_layers=6,
num_heads=8,
hidden_dim=512,
intermediate_dim=2048,
num_layers=4,
num_heads=4,
hidden_dim=256,
intermediate_dim=512,
max_encoder_sequence_length=128,
max_decoder_sequence_length=64,
max_decoder_sequence_length=128,
)
output = model(input_data)
model(input_data)
```
"""

Expand Down
2 changes: 1 addition & 1 deletion keras_nlp/models/whisper/whisper_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

@keras.utils.register_keras_serializable(package="keras_nlp")
class WhisperDecoder(TransformerDecoder):
"""Whisper decoder.
"""A Whisper decoder.

Inherits from `keras_nlp.layers.TransformerDecoder`, and overrides the
`_build` method so as to remove the bias term from the key projection layer.
Expand Down
2 changes: 1 addition & 1 deletion keras_nlp/models/whisper/whisper_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

@keras.utils.register_keras_serializable(package="keras_nlp")
class WhisperEncoder(TransformerEncoder):
"""Whisper encoder.
"""A Whisper encoder.

Inherits from `keras_nlp.layers.TransformerEncoder`, and overrides the
`_build` method so as to remove the bias term from the key projection layer.
Expand Down