Commit

Add FlaxEncoderDecoderModel to the library

ydshieh committed Aug 13, 2021
1 parent fe743aa commit cc83bfc
Showing 8 changed files with 34 additions and 5 deletions.
2 changes: 1 addition & 1 deletion docs/source/index.rst
@@ -353,7 +353,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| ELECTRA                     | ✅             | ✅             | ✅              | ✅                 | ✅           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
-| Encoder decoder             | ❌             | ❌             | ✅              | ❌                 | ❌           |
+| Encoder decoder             | ❌             | ❌             | ✅              | ❌                 | ✅           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| FairSeq Machine-Translation | ✅             | ❌             | ✅              | ❌                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
7 changes: 7 additions & 0 deletions docs/source/model_doc/encoderdecoder.rst
@@ -40,3 +40,10 @@ EncoderDecoderModel

.. autoclass:: transformers.EncoderDecoderModel
:members: forward, from_encoder_decoder_pretrained


FlaxEncoderDecoderModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxEncoderDecoderModel
:members: __call__, from_encoder_decoder_pretrained
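
For orientation, a minimal usage sketch of the class documented above (not part of the commit): it assumes a BERT encoder paired with a GPT-2 decoder, and the checkpoint names "bert-base-uncased" and "gpt2" are purely illustrative.

from transformers import BertTokenizerFast, FlaxEncoderDecoderModel, GPT2TokenizerFast

# Compose a Flax encoder-decoder from two independently pretrained checkpoints
# (illustrative names; any Flax encoder plus Flax causal-LM decoder should work).
model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "gpt2")

encoder_tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
decoder_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

inputs = encoder_tokenizer("The Eiffel Tower is 324 metres tall.", return_tensors="np")
decoder_inputs = decoder_tokenizer("It is", return_tensors="np")

# __call__ encodes the inputs, feeds the encoder hidden states to the decoder via
# cross-attention, and returns the decoder's language-modeling logits.
outputs = model(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    decoder_input_ids=decoder_inputs["input_ids"],
    decoder_attention_mask=decoder_inputs["attention_mask"],
)
print(outputs.logits.shape)  # (batch_size, decoder_sequence_length, decoder_vocab_size)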
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -1680,6 +1680,7 @@
"FlaxElectraPreTrainedModel",
]
)
_import_structure["models.encoder_decoder"].append("FlaxEncoderDecoderModel")
_import_structure["models.gpt2"].extend(["FlaxGPT2LMHeadModel", "FlaxGPT2Model", "FlaxGPT2PreTrainedModel"])
_import_structure["models.gpt_neo"].extend(
["FlaxGPTNeoForCausalLM", "FlaxGPTNeoModel", "FlaxGPTNeoPreTrainedModel"]
@@ -3129,6 +3130,7 @@
FlaxElectraModel,
FlaxElectraPreTrainedModel,
)
from .models.encoder_decoder import FlaxEncoderDecoderModel
from .models.gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model, FlaxGPT2PreTrainedModel
from .models.gpt_neo import FlaxGPTNeoForCausalLM, FlaxGPTNeoModel, FlaxGPTNeoPreTrainedModel
from .models.marian import FlaxMarianModel, FlaxMarianMTModel, FlaxMarianPreTrainedModel
4 changes: 3 additions & 1 deletion src/transformers/configuration_utils.py
@@ -85,7 +85,9 @@ class PretrainedConfig(PushToHubMixin):
add_cross_attention (:obj:`bool`, `optional`, defaults to :obj:`False`):
    Whether cross-attention layers should be added to the model. Note, this option is only relevant for models
    that can be used as decoder models within the :class:`~transformers.EncoderDecoderModel` class, which
-   consists of all models in ``AUTO_MODELS_FOR_CAUSAL_LM``.
+   consists of all models in ``AUTO_MODELS_FOR_CAUSAL_LM``, or within the
+   :class:`~transformers.FlaxEncoderDecoderModel` class, which consists of all models in
+   ``FLAX_AUTO_MODELS_FOR_CAUSAL_LM``.
tie_encoder_decoder (:obj:`bool`, `optional`, defaults to :obj:`False`):
    Whether all encoder weights should be tied to their equivalent decoder weights. This requires the encoder
    and decoder model to have the exact same parameter names.
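
A hedged illustration of the flag described above (not part of the diff): add_cross_attention and is_decoder are plain PretrainedConfig attributes, so a decoder config can be prepared by hand; the encoder-decoder classes also set them on the decoder config automatically when from_encoder_decoder_pretrained is used.

from transformers import GPT2Config

# Prepare a GPT-2 config for use as the decoder of an encoder-decoder model:
# mark it as a decoder and add cross-attention layers over the encoder outputs.
decoder_config = GPT2Config(is_decoder=True, add_cross_attention=True)
print(decoder_config.add_cross_attention)  # True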
3 changes: 3 additions & 0 deletions src/transformers/models/auto/modeling_flax_auto.py
@@ -53,6 +53,7 @@
FlaxElectraForTokenClassification,
FlaxElectraModel,
)
from ..encoder_decoder.modeling_flax_encoder_decoder import FlaxEncoderDecoderModel
from ..gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model
from ..gpt_neo.modeling_flax_gpt_neo import FlaxGPTNeoForCausalLM, FlaxGPTNeoModel
from ..marian.modeling_flax_marian import FlaxMarianModel, FlaxMarianMTModel
@@ -81,6 +82,7 @@
BigBirdConfig,
CLIPConfig,
ElectraConfig,
EncoderDecoderConfig,
GPT2Config,
GPTNeoConfig,
MarianConfig,
@@ -150,6 +152,7 @@
(T5Config, FlaxT5ForConditionalGeneration),
(MT5Config, FlaxMT5ForConditionalGeneration),
(MarianConfig, FlaxMarianMTModel),
(EncoderDecoderConfig, FlaxEncoderDecoderModel),
]
)
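
A sketch of what this registration enables (an assumed usage pattern, not code from the commit; FlaxAutoModelForSeq2SeqLM is taken here to be the auto class backed by the mapping above): the mapping is a config-class to model-class table, so an auto class can dispatch straight from an EncoderDecoderConfig to FlaxEncoderDecoderModel.

from transformers import BertConfig, EncoderDecoderConfig, FlaxAutoModelForSeq2SeqLM, GPT2Config

# Build a composite config; from_encoder_decoder_configs nests both sub-configs.
config = EncoderDecoderConfig.from_encoder_decoder_configs(BertConfig(), GPT2Config())

# The auto class looks up type(config) in the mapping and instantiates the
# matching Flax model with randomly initialized weights.
model = FlaxAutoModelForSeq2SeqLM.from_config(config)
print(type(model).__name__)  # FlaxEncoderDecoderModel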

7 changes: 6 additions & 1 deletion src/transformers/models/encoder_decoder/__init__.py
@@ -18,7 +18,7 @@

from typing import TYPE_CHECKING

-from ...file_utils import _LazyModule, is_torch_available
+from ...file_utils import _LazyModule, is_flax_available, is_torch_available


_import_structure = {
@@ -28,13 +28,18 @@
if is_torch_available():
    _import_structure["modeling_encoder_decoder"] = ["EncoderDecoderModel"]

if is_flax_available():
    _import_structure["modeling_flax_encoder_decoder"] = ["FlaxEncoderDecoderModel"]

if TYPE_CHECKING:
    from .configuration_encoder_decoder import EncoderDecoderConfig

    if is_torch_available():
        from .modeling_encoder_decoder import EncoderDecoderModel

    if is_flax_available():
        from .modeling_flax_encoder_decoder import FlaxEncoderDecoderModel

else:
    import sys
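
For context, a sketch of what the lazy-import plumbing above achieves (explanatory only, not part of the diff): _import_structure registers the public name, _LazyModule defers loading the modeling file until first attribute access, and the TYPE_CHECKING branch gives static type checkers real imports. Assuming Flax/JAX is installed:

import transformers

# Accessing the attribute triggers the deferred import of
# transformers.models.encoder_decoder.modeling_flax_encoder_decoder.
model_cls = transformers.FlaxEncoderDecoderModel
print(model_cls.__module__)

# Without Flax installed, the top-level package instead resolves the name to the
# dummy object added further down in this commit.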

src/transformers/models/encoder_decoder/configuration_encoder_decoder.py
@@ -26,8 +26,9 @@
class EncoderDecoderConfig(PretrainedConfig):
    r"""
    :class:`~transformers.EncoderDecoderConfig` is the configuration class to store the configuration of a
-   :class:`~transformers.EncoderDecoderModel`. It is used to instantiate an Encoder Decoder model according to the
-   specified arguments, defining the encoder and decoder configs.
+   :class:`~transformers.EncoderDecoderModel` or a :class:`~transformers.FlaxEncoderDecoderModel`. It is used to
+   instantiate an Encoder Decoder model according to the specified arguments, defining the encoder and decoder
+   configs.

    Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
    outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
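
A brief sketch of the shared-config idea stated above (illustrative, not from the commit): a single EncoderDecoderConfig can drive either framework's model class.

from transformers import (
    BertConfig,
    EncoderDecoderConfig,
    EncoderDecoderModel,
    FlaxEncoderDecoderModel,
    GPT2Config,
)

# from_encoder_decoder_configs nests both sub-configs and marks the decoder side
# (is_decoder=True, add_cross_attention=True).
config = EncoderDecoderConfig.from_encoder_decoder_configs(BertConfig(), GPT2Config())

pt_model = EncoderDecoderModel(config=config)         # requires PyTorch; random weights
flax_model = FlaxEncoderDecoderModel(config=config)   # requires Flax/JAX; random weights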
9 changes: 9 additions & 0 deletions src/transformers/utils/dummy_flax_objects.py
@@ -516,6 +516,15 @@ def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])


class FlaxEncoderDecoderModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])


class FlaxGPT2LMHeadModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])
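
To show what the dummy is for (a behavioral sketch for an environment without Flax/JAX installed; not part of the diff): the import itself succeeds, and a clear backend error is raised only when the class is actually used.

# In an environment without flax, `from transformers import FlaxEncoderDecoderModel`
# resolves to the dummy class above.
from transformers import FlaxEncoderDecoderModel

try:
    FlaxEncoderDecoderModel()
except ImportError as err:
    print(err)  # explains that the "flax" backend needs to be installed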