37 changes: 32 additions & 5 deletions keras_nlp/layers/transformer_decoder.py
@@ -39,13 +39,19 @@ class TransformerDecoder(keras.layers.Layer):
activation function of feedforward network.
layer_norm_epsilon: float, defaults to 1e-5. The epsilon value in layer
normalization components.
kernel_initializer: string or tf.keras.initializers initializer,
defaults to "glorot_uniform". The kernel initializer for
the dense and multiheaded attention layers.
bias_initializer: string or tf.keras.initializers initializer,
defaults to "zeros". The bias initializer for
the dense and multiheaded attention layers.
name: string, defaults to None. The name of the layer.
**kwargs: other keyword arguments.

Examples:
```python
# Create a single transformer decoder layer.
decoder = keras_nlp.layer.TransformerDecoder(
decoder = keras_nlp.layers.TransformerDecoder(
intermediate_dim=64, num_heads=8)

# Create a simple model containing the decoder.
@@ -74,15 +80,19 @@ def __init__(
dropout=0,
activation="relu",
layer_norm_epsilon=1e-05,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
name=None,
**kwargs,
):
super().__init__(name=name, **kwargs)
self.intermediate_dim = intermediate_dim
self.num_heads = num_heads
self.dropout = dropout
self.activation = activation
self.activation = keras.activations.get(activation)
self.layer_norm_epsilon = layer_norm_epsilon
self.kernel_initializer = keras.initializers.get(kernel_initializer)
self.bias_initializer = keras.initializers.get(bias_initializer)
self._built = False

def _build(self, input_shape):
@@ -95,12 +105,16 @@ def _build(self, input_shape):
key_dim=self._attention_head_size,
value_dim=self._attention_head_size,
dropout=self.dropout,
kernel_initializer=self.kernel_initializer,
bias_initializer=self.bias_initializer,
)
self._encoder_decoder_attention_layer = keras.layers.MultiHeadAttention(
num_heads=self.num_heads,
key_dim=self._attention_head_size,
value_dim=feature_size,
dropout=self.dropout,
kernel_initializer=self.kernel_initializer,
bias_initializer=self.bias_initializer,
)

self._decoder_attention_layernorm = keras.layers.LayerNormalization()
@@ -114,11 +128,18 @@ def _build(self, input_shape):
# First dense layer in the feedforward network, which maps input
# feature size to dimension `self.intermediate_dim`.
self._intermediate_dense = keras.layers.Dense(
self.intermediate_dim, activation=self.activation
self.intermediate_dim,
activation=self.activation,
kernel_initializer=self.kernel_initializer,
bias_initializer=self.bias_initializer,
)
# Second dense layer in the feedforward network, which maps input
# feature size back to the input feature size.
self._output_dense = keras.layers.Dense(feature_size)
self._output_dense = keras.layers.Dense(
feature_size,
kernel_initializer=self.kernel_initializer,
bias_initializer=self.bias_initializer,
)
self._outputdropout = keras.layers.Dropout(rate=self.dropout)

def _add_and_norm(self, input1, input2, norm_layer):
@@ -219,8 +240,14 @@ def get_config(self):
"intermediate_dim": self.intermediate_dim,
"num_heads": self.num_heads,
"dropout": self.dropout,
"activation": self.activation,
"activation": keras.activations.serialize(self.activation),
"layer_norm_epsilon": self.layer_norm_epsilon,
"kernel_initializer": keras.initializers.serialize(
self.kernel_initializer
),
"bias_initializer": keras.initializers.serialize(
self.bias_initializer
),
}
)
return config
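
The net effect on the public API is that both initializers can now be passed when constructing the decoder. Below is a minimal usage sketch; the `(batch, sequence, feature)` input shapes and the two-argument call `decoder(decoder_inputs, encoder_outputs)` are illustrative assumptions, not taken from this diff.

```python
import tensorflow as tf
import keras_nlp

# Decoder built with the new constructor arguments.
decoder = keras_nlp.layers.TransformerDecoder(
    intermediate_dim=64,
    num_heads=8,
    kernel_initializer="he_normal",  # strings are resolved via keras.initializers.get
    bias_initializer="zeros",
)

# Hypothetical (batch, sequence, feature) inputs, just to exercise the layer.
decoder_inputs = tf.random.uniform(shape=(2, 10, 64))
encoder_outputs = tf.random.uniform(shape=(2, 10, 64))
outputs = decoder(decoder_inputs, encoder_outputs)
print(outputs.shape)  # (2, 10, 64)
```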
22 changes: 22 additions & 0 deletions keras_nlp/layers/transformer_decoder_test.py
@@ -60,24 +60,46 @@ def test_get_config_and_from_config(self):
decoder = transformer_decoder.TransformerDecoder(
intermediate_dim=4,
num_heads=2,
kernel_initializer="HeNormal",
bias_initializer="Zeros",
)

config = decoder.get_config()

expected_config_subset = {
"intermediate_dim": 4,
"num_heads": 2,
"dropout": 0,
"activation": "relu",
"layer_norm_epsilon": 1e-05,
"kernel_initializer": keras.initializers.serialize(
keras.initializers.HeNormal()
),
"bias_initializer": keras.initializers.serialize(
keras.initializers.Zeros()
),
}

self.assertEqual(config, {**config, **expected_config_subset})

restored_decoder = transformer_decoder.TransformerDecoder.from_config(
config,
)

self.assertEqual(
restored_decoder.get_config(), {**config, **expected_config_subset}
)

def test_value_error_when_invalid_kernel_initializer(self):
with self.assertRaises(ValueError):
transformer_decoder.TransformerDecoder(
intermediate_dim=4,
num_heads=2,
dropout=0.5,
kernel_initializer="Invalid",
)

def test_one_training_step_of_transformer_encoder(self):
class MyModel(keras.Model):
def __init__(self):
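
The new `test_value_error_when_invalid_kernel_initializer` test relies on `keras.initializers.get` rejecting unknown string identifiers, so an invalid initializer now fails at construction time rather than when the layer is built. A standalone sketch of that behaviour, independent of keras-nlp:

```python
from tensorflow import keras

# Known identifiers resolve to initializer instances.
init = keras.initializers.get("HeNormal")
print(type(init).__name__)  # HeNormal

# Unknown identifiers raise a ValueError, which is what the decoder and
# encoder constructors now surface for an invalid kernel_initializer.
try:
    keras.initializers.get("Invalid")
except ValueError as err:
    print(err)
```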
35 changes: 30 additions & 5 deletions keras_nlp/layers/transformer_encoder.py
@@ -37,14 +37,20 @@ class TransformerEncoder(keras.layers.Layer):
activation function of feedforward network.
layer_norm_epsilon: float, defaults to 1e-5. The epsilon value in layer
normalization components.
kernel_initializer: string or tf.keras.initializers initializer,
defaults to "glorot_uniform". The kernel initializer for
the dense and multiheaded attention layers.
bias_initializer: string or tf.keras.initializers initializer,
defaults to "zeros". The bias initializer for
the dense and multiheaded attention layers.
name: string, defaults to None. The name of the layer.
**kwargs: other keyword arguments.

Examples:

```python
# Create a single transformer encoder layer.
encoder = keras_nlp.layer.TransformerEncoder(
encoder = keras_nlp.layers.TransformerEncoder(
intermediate_dim=64, num_heads=8)

# Create a simple model containing the encoder.
@@ -69,15 +75,19 @@ def __init__(
dropout=0,
activation="relu",
layer_norm_epsilon=1e-05,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
name=None,
**kwargs
):
super().__init__(name=name, **kwargs)
self.intermediate_dim = intermediate_dim
self.num_heads = num_heads
self.dropout = dropout
self.activation = activation
self.activation = keras.activations.get(activation)
self.layer_norm_epsilon = layer_norm_epsilon
self.kernel_initializer = keras.initializers.get(kernel_initializer)
self.bias_initializer = keras.initializers.get(bias_initializer)
self._built = False

def _build(self, input_shape):
@@ -90,6 +100,8 @@ def _build(self, input_shape):
key_dim=self._attention_head_size,
value_dim=self._attention_head_size,
dropout=self.dropout,
kernel_initializer=self.kernel_initializer,
bias_initializer=self.bias_initializer,
)

self._attention_layernorm = keras.layers.LayerNormalization()
@@ -98,9 +110,16 @@ def _build(self, input_shape):
self._attentiondropout = keras.layers.Dropout(rate=self.dropout)

self._intermediate_dense = keras.layers.Dense(
self.intermediate_dim, activation=self.activation
self.intermediate_dim,
activation=self.activation,
kernel_initializer=self.kernel_initializer,
bias_initializer=self.bias_initializer,
)
self._output_dense = keras.layers.Dense(
feature_size,
kernel_initializer=self.kernel_initializer,
bias_initializer=self.bias_initializer,
)
self._output_dense = keras.layers.Dense(feature_size)
self._outputdropout = keras.layers.Dropout(rate=self.dropout)

def _add_and_norm(self, input1, input2, norm_layer):
@@ -161,8 +180,14 @@ def get_config(self):
"intermediate_dim": self.intermediate_dim,
"num_heads": self.num_heads,
"dropout": self.dropout,
"activation": self.activation,
"activation": keras.activations.serialize(self.activation),
"layer_norm_epsilon": self.layer_norm_epsilon,
"kernel_initializer": keras.initializers.serialize(
self.kernel_initializer
),
"bias_initializer": keras.initializers.serialize(
self.bias_initializer
),
}
)
return config
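
Because `get_config` now serializes the initializers rather than storing raw objects, the returned config stays serializable and round-trips through `from_config`. A sketch of what this looks like from user code; the exact serialized structure is an assumption about the installed Keras version:

```python
import keras_nlp

encoder = keras_nlp.layers.TransformerEncoder(
    intermediate_dim=4,
    num_heads=2,
    kernel_initializer="HeNormal",
    bias_initializer="Zeros",
)

config = encoder.get_config()
# The initializers are stored in serialized form (e.g. a dict such as
# {"class_name": "HeNormal", "config": {...}}), not as initializer objects.
print(config["kernel_initializer"])

# The serialized config is enough to rebuild an equivalent layer.
restored = keras_nlp.layers.TransformerEncoder.from_config(config)
print(restored.get_config()["kernel_initializer"] == config["kernel_initializer"])
```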
21 changes: 21 additions & 0 deletions keras_nlp/layers/transformer_encoder_test.py
@@ -50,24 +50,45 @@ def test_get_config_and_from_config(self):
encoder = transformer_encoder.TransformerEncoder(
intermediate_dim=4,
num_heads=2,
kernel_initializer="HeNormal",
bias_initializer="Zeros",
)

config = encoder.get_config()

expected_config_subset = {
"intermediate_dim": 4,
"num_heads": 2,
"dropout": 0,
"activation": "relu",
"layer_norm_epsilon": 1e-05,
"kernel_initializer": keras.initializers.serialize(
keras.initializers.HeNormal()
),
"bias_initializer": keras.initializers.serialize(
keras.initializers.Zeros()
),
}

self.assertEqual(config, {**config, **expected_config_subset})

restored_encoder = transformer_encoder.TransformerEncoder.from_config(
config,
)

self.assertEqual(
restored_encoder.get_config(), {**config, **expected_config_subset}
)

def test_value_error_when_invalid_kernel_initializer(self):
with self.assertRaises(ValueError):
transformer_encoder.TransformerEncoder(
intermediate_dim=4,
num_heads=2,
dropout=0.5,
kernel_initializer="Invalid",
)

def test_one_training_step_of_transformer_encoder(self):
encoder = transformer_encoder.TransformerEncoder(
intermediate_dim=4,
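
Both layers and both tests follow the same Keras idiom: resolve the argument with `keras.initializers.get` in `__init__` (so callers may pass either a string or an initializer instance) and write it back with `keras.initializers.serialize` in `get_config` (so the config remains serializable). The layer below is hypothetical and not part of keras-nlp; it only isolates that idiom:

```python
from tensorflow import keras


class InitializerAwareLayer(keras.layers.Layer):
    """Hypothetical layer isolating the get/serialize idiom used in this PR."""

    def __init__(self, units, kernel_initializer="glorot_uniform", **kwargs):
        super().__init__(**kwargs)
        self.units = units
        # Accepts "he_normal", "HeNormal", or an initializer instance.
        self.kernel_initializer = keras.initializers.get(kernel_initializer)

    def build(self, input_shape):
        self.kernel = self.add_weight(
            name="kernel",
            shape=(input_shape[-1], self.units),
            initializer=self.kernel_initializer,
        )

    def call(self, inputs):
        # Simple 2-D matmul; fine for (batch, features) inputs.
        return inputs @ self.kernel

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "units": self.units,
                # Serialized back to a plain structure so the config stays
                # JSON-compatible and round-trips through from_config.
                "kernel_initializer": keras.initializers.serialize(
                    self.kernel_initializer
                ),
            }
        )
        return config
```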