Skip to content

Commit

Permalink
[Flax] Add FlaxBlenderbot (#13633)
Browse files Browse the repository at this point in the history
* Init Flax implementation for Blenderbot

* Add a majority of stuff except for tests

* make style quality

* Add tests and fix some bugs

* Add tests

* Clean source code and fix some bugs

* Fix copies and docs

* Fix jax device condition for tests

* Fix layer norm in the encoder

* Fix a few typos in the test file

* make fix-copies

* make fix-copies

* fix layer norm

* Fix Flax params dtype (#13090)

* Fix PR reference (#13098)

* make fix-copies

* Update tests/test_modeling_flax_blenderbot.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
  • Loading branch information
3 people authored Nov 30, 2021
1 parent 254fef6 commit faacd74
Show file tree
Hide file tree
Showing 13 changed files with 2,026 additions and 30 deletions.
2 changes: 1 addition & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| BigBirdPegasus              | ❌             | ❌             | ✅              | ❌                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Blenderbot                  | ✅             | ❌             | ✅              | ✅                 | ❌           |
| Blenderbot                  | ✅             | ❌             | ✅              | ✅                 | ✅           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| BlenderbotSmall             | ✅             | ❌             | ✅              | ✅                 | ✅           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
Expand Down
14 changes: 14 additions & 0 deletions docs/source/model_doc/blenderbot.rst
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,17 @@ TFBlenderbotForConditionalGeneration

.. autoclass:: transformers.TFBlenderbotForConditionalGeneration
:members: call


FlaxBlenderbotModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxBlenderbotModel
:members: __call__, encode, decode


FlaxBlenderbotForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxBlenderbotForConditionalGeneration
:members: __call__, encode, decode
8 changes: 8 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1950,6 +1950,9 @@
"FlaxBigBirdPreTrainedModel",
]
)
_import_structure["models.blenderbot"].extend(
["FlaxBlenderbotForConditionalGeneration", "FlaxBlenderbotModel", "FlaxBlenderbotPreTrainedModel"]
)
_import_structure["models.clip"].extend(
[
"FlaxCLIPModel",
Expand Down Expand Up @@ -3647,6 +3650,11 @@
FlaxBigBirdModel,
FlaxBigBirdPreTrainedModel,
)
from .models.blenderbot import (
FlaxBlenderbotForConditionalGeneration,
FlaxBlenderbotModel,
FlaxBlenderbotPreTrainedModel,
)
from .models.clip import (
FlaxCLIPModel,
FlaxCLIPPreTrainedModel,
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_flax_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
("mt5", "FlaxMT5Model"),
("wav2vec2", "FlaxWav2Vec2Model"),
("marian", "FlaxMarianModel"),
("blenderbot", "FlaxBlenderbotModel"),
]
)

Expand Down Expand Up @@ -89,6 +90,7 @@
("mt5", "FlaxMT5ForConditionalGeneration"),
("marian", "FlaxMarianMTModel"),
("encoder-decoder", "FlaxEncoderDecoderModel"),
("blenderbot", "FlaxBlenderbotForConditionalGeneration"),
]
)

Expand Down
14 changes: 7 additions & 7 deletions src/transformers/models/bart/modeling_flax_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ def setup(self) -> None:
dropout=self.config.attention_dropout,
dtype=self.dtype,
)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
self.activation_fn = ACT2FN[self.config.activation_function]
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
Expand All @@ -430,7 +430,7 @@ def setup(self) -> None:
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)

def __call__(
self,
Expand Down Expand Up @@ -533,15 +533,15 @@ def setup(self) -> None:
self.activation_fn = ACT2FN[self.config.activation_function]
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)

self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.encoder_attn = FlaxBartAttention(
config=self.config,
embed_dim=self.embed_dim,
num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout,
dtype=self.dtype,
)
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.fc1 = nn.Dense(
self.config.encoder_ffn_dim,
dtype=self.dtype,
Expand All @@ -550,7 +550,7 @@ def setup(self) -> None:
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)

def __call__(
self,
Expand Down Expand Up @@ -730,7 +730,7 @@ def setup(self):
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype)
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)

def __call__(
self,
Expand Down Expand Up @@ -802,7 +802,7 @@ def setup(self):
)

self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype)
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)

def __call__(
self,
Expand Down
17 changes: 16 additions & 1 deletion src/transformers/models/blenderbot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from typing import TYPE_CHECKING

from ...file_utils import _LazyModule, is_tf_available, is_tokenizers_available, is_torch_available
from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_tokenizers_available, is_torch_available


_import_structure = {
Expand Down Expand Up @@ -47,6 +47,14 @@
]


if is_flax_available():
_import_structure["modeling_flax_blenderbot"] = [
"FlaxBlenderbotForConditionalGeneration",
"FlaxBlenderbotModel",
"FlaxBlenderbotPreTrainedModel",
]


if TYPE_CHECKING:
from .configuration_blenderbot import BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP, BlenderbotConfig
from .tokenization_blenderbot import BlenderbotTokenizer
Expand All @@ -70,6 +78,13 @@
TFBlenderbotPreTrainedModel,
)

if is_flax_available():
from .modeling_flax_blenderbot import (
FlaxBlenderbotForConditionalGeneration,
FlaxBlenderbotModel,
FlaxBlenderbotPreTrainedModel,
)

else:
import sys

Expand Down
Loading

0 comments on commit faacd74

Please sign in to comment.