Updates to model API (#561)
Summary:
- `FairseqModel` -> `FairseqEncoderDecoderModel`
- add `FairseqDecoder.extract_features` and `FairseqDecoder.output_layer`
- `encoder_out_dict` -> `encoder_out`
- rm unused `remove_head` functions
- update docs
Pull Request resolved: fairinternal/fairseq-py#561

Differential Revision: D15271142

Pulled By: myleott

fbshipit-source-id: 8e8864e399336020f0271c780598e968ff51a264
myleott authored and facebook-github-bot committed May 15, 2019
1 parent a0c5f9b commit dffb167
Showing 16 changed files with 207 additions and 110 deletions.
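Before the per-file diffs, a before/after sketch of what the headline rename means for downstream code. This is illustrative only: `MySeq2SeqModel` is a hypothetical name, not something in this commit.

# Before this commit (sketch):
#
#     from fairseq.models import FairseqModel, register_model
#
#     @register_model('my_seq2seq')
#     class MySeq2SeqModel(FairseqModel):
#         ...

# After this commit:
from fairseq.models import FairseqEncoderDecoderModel, register_model

@register_model('my_seq2seq')
class MySeq2SeqModel(FairseqEncoderDecoderModel):
    ...

# Decoders also now take `encoder_out` (a dict) instead of `encoder_out_dict`,
# and split their forward pass into extract_features() + output_layer(); see
# the fairseq/models/fairseq_decoder.py diff below.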
8 changes: 7 additions & 1 deletion docs/models.rst
@@ -74,12 +74,18 @@ Adding new models
.. autoclass:: fairseq.models.BaseFairseqModel
:members:
:undoc-members:
-.. autoclass:: fairseq.models.FairseqModel
+.. autoclass:: fairseq.models.FairseqEncoderDecoderModel
:members:
:undoc-members:
+.. autoclass:: fairseq.models.FairseqEncoderModel
+    :members:
+    :undoc-members:
.. autoclass:: fairseq.models.FairseqLanguageModel
:members:
:undoc-members:
+.. autoclass:: fairseq.models.FairseqMultiModel
+    :members:
+    :undoc-members:
.. autoclass:: fairseq.models.FairseqEncoder
:members:
.. autoclass:: fairseq.models.CompositeEncoder
2 changes: 1 addition & 1 deletion docs/modules.rst
@@ -2,7 +2,7 @@ Modules
=======

Fairseq provides several stand-alone :class:`torch.nn.Module` classes that may
-be helpful when implementing a new :class:`~fairseq.models.FairseqModel`.
+be helpful when implementing a new :class:`~fairseq.models.BaseFairseqModel`.

.. automodule:: fairseq.modules
:members:
2 changes: 1 addition & 1 deletion docs/overview.rst
@@ -41,7 +41,7 @@ New plug-ins are *registered* through a set of ``@register`` function
decorators, for example::

@register_model('my_lstm')
-class MyLSTM(FairseqModel):
+class MyLSTM(FairseqEncoderDecoderModel):
(...)

Once registered, new plug-ins can be used with the existing :ref:`Command-line
14 changes: 7 additions & 7 deletions docs/tutorial_simple_lstm.rst
@@ -2,9 +2,9 @@ Tutorial: Simple LSTM
=====================

In this tutorial we will extend fairseq by adding a new
-:class:`~fairseq.models.FairseqModel` that encodes a source sentence with an
-LSTM and then passes the final hidden state to a second LSTM that decodes the
-target sentence (without attention).
+:class:`~fairseq.models.FairseqEncoderDecoderModel` that encodes a source
+sentence with an LSTM and then passes the final hidden state to a second LSTM
+that decodes the target sentence (without attention).

This tutorial covers:

@@ -233,18 +233,18 @@ Once the model is registered we'll be able to use it with the existing
All registered models must implement the
:class:`~fairseq.models.BaseFairseqModel` interface. For sequence-to-sequence
models (i.e., any model with a single Encoder and Decoder), we can instead
-implement the :class:`~fairseq.models.FairseqModel` interface.
+implement the :class:`~fairseq.models.FairseqEncoderDecoderModel` interface.

Create a small wrapper class in the same file and register it in fairseq with
the name ``'simple_lstm'``::

-from fairseq.models import FairseqModel, register_model
+from fairseq.models import FairseqEncoderDecoderModel, register_model

# Note: the register_model "decorator" should immediately precede the
# definition of the Model class.

@register_model('simple_lstm')
-class SimpleLSTMModel(FairseqModel):
+class SimpleLSTMModel(FairseqEncoderDecoderModel):

@staticmethod
def add_args(parser):
@@ -308,7 +308,7 @@ the name ``'simple_lstm'``::
# We could override the ``forward()`` if we wanted more control over how
# the encoder and decoder interact, but it's not necessary for this
# tutorial since we can inherit the default implementation provided by
-# the FairseqModel base class, which looks like:
+# the FairseqEncoderDecoderModel base class, which looks like:
#
# def forward(self, src_tokens, src_lengths, prev_output_tokens):
# encoder_out = self.encoder(src_tokens, src_lengths)
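The excerpt is truncated by the diff view; for context, the inherited default plausibly continues as follows (a sketch inferred from the decoder contract in this commit, not lines from the diff):

# decoder_out = self.decoder(prev_output_tokens, encoder_out)
# return decoder_out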
13 changes: 8 additions & 5 deletions fairseq/models/__init__.py
@@ -14,10 +14,11 @@
from .fairseq_incremental_decoder import FairseqIncrementalDecoder
from .fairseq_model import (
BaseFairseqModel,
+FairseqEncoderModel,
+FairseqEncoderDecoderModel,
+FairseqLanguageModel,
FairseqModel,
FairseqMultiModel,
-FairseqLanguageModel,
-FairseqEncoderModel,
)

from .composite_encoder import CompositeEncoder
@@ -30,6 +31,7 @@
'DistributedFairseqModel',
'FairseqDecoder',
'FairseqEncoder',
+'FairseqEncoderDecoderModel',
'FairseqEncoderModel',
'FairseqIncrementalDecoder',
'FairseqLanguageModel',
@@ -56,12 +58,13 @@ def register_model(name):
For example::
@register_model('lstm')
-class LSTM(FairseqModel):
+class LSTM(FairseqEncoderDecoderModel):
(...)
.. note:: All models must implement the :class:`BaseFairseqModel` interface.
-Typically you will extend :class:`FairseqModel` for sequence-to-sequence
-tasks or :class:`FairseqLanguageModel` for language modeling tasks.
+Typically you will extend :class:`FairseqEncoderDecoderModel` for
+sequence-to-sequence tasks or :class:`FairseqLanguageModel` for
+language modeling tasks.
Args:
name (str): the name of the model
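Since `FairseqModel` remains in the import list above, both names should stay importable after this commit; whether it survives as a plain alias or a deprecation shim is not visible in this diff:

from fairseq.models import FairseqEncoderDecoderModel, FairseqModel  # both presumably resolve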
36 changes: 27 additions & 9 deletions fairseq/models/fairseq_decoder.py
@@ -18,25 +18,40 @@ def __init__(self, dictionary):
self.dictionary = dictionary
self.onnx_trace = False

-def forward(self, prev_output_tokens, encoder_out):
+def forward(self, prev_output_tokens, encoder_out=None, **kwargs):
"""
Args:
-prev_output_tokens (LongTensor): previous decoder outputs of shape
+prev_output_tokens (LongTensor): shifted output tokens of shape
`(batch, tgt_len)`, for input feeding/teacher forcing
-encoder_out (Tensor, optional): output from the encoder, used for
+encoder_out (dict, optional): output from the encoder, used for
encoder-side attention
Returns:
tuple:
-    - the last decoder layer's output of shape
-      `(batch, tgt_len, vocab)`
-    - the last decoder layer's attention weights of shape
-      `(batch, tgt_len, src_len)`
+    - the decoder's output of shape `(batch, tgt_len, vocab)`
+    - a dictionary with any model-specific outputs
"""
+x, extra = self.extract_features(prev_output_tokens, encoder_out=encoder_out, **kwargs)
+x = self.output_layer(x)
+return x, extra

+def extract_features(self, prev_output_tokens, encoder_out=None, **kwargs):
+    """
+    Returns:
+        tuple:
+            - the decoder's features of shape `(batch, tgt_len, embed_dim)`
+            - a dictionary with any model-specific outputs
+    """
+    raise NotImplementedError

-def prepare_for_onnx_export_(self):
-    self.onnx_trace = True
+def output_layer(self, features, **kwargs):
+    """
+    Project features to the default output size, e.g., vocabulary size.
+    Args:
+        features (Tensor): features returned by *extract_features*.
+    """
+    raise NotImplementedError

def get_normalized_probs(self, net_output, log_probs, sample):
"""Get normalized probabilities (or log probs) from a net's output."""
@@ -63,3 +78,6 @@ def max_positions(self):
def upgrade_state_dict(self, state_dict):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
return state_dict

+def prepare_for_onnx_export_(self):
+    self.onnx_trace = True
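As a usage sketch of the new decoder contract (a hypothetical toy class, not code from this commit): subclasses implement `extract_features` and `output_layer`, and the base `forward` above chains them, so callers can get pre-projection features without computing vocabulary logits.

import torch.nn as nn

from fairseq.models import FairseqDecoder


class ToyDecoder(FairseqDecoder):
    """Hypothetical decoder showing the extract_features/output_layer split."""

    def __init__(self, dictionary, embed_dim=32):
        super().__init__(dictionary)
        self.embed = nn.Embedding(len(dictionary), embed_dim)
        self.proj = nn.Linear(embed_dim, len(dictionary))

    def extract_features(self, prev_output_tokens, encoder_out=None, **kwargs):
        # Features of shape `(batch, tgt_len, embed_dim)` plus an extras dict.
        return self.embed(prev_output_tokens), {'attn': None}

    def output_layer(self, features, **kwargs):
        # Project features to the vocabulary: `(batch, tgt_len, vocab)`.
        return self.proj(features)

The inherited forward() is then equivalent to output_layer(extract_features(...)[0]) plus the extras dict, exactly as the added lines above implement it.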
2 changes: 1 addition & 1 deletion fairseq/models/fairseq_encoder.py
@@ -15,7 +15,7 @@ def __init__(self, dictionary):
super().__init__()
self.dictionary = dictionary

-def forward(self, src_tokens, src_lengths):
+def forward(self, src_tokens, src_lengths=None, **kwargs):
"""
Args:
src_tokens (LongTensor): tokens in the source language of shape
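A minimal encoder sketch matching the relaxed signature (hypothetical class; `src_lengths` is now optional and extra keyword arguments are tolerated):

import torch.nn as nn

from fairseq.models import FairseqEncoder


class ToyEncoder(FairseqEncoder):
    """Hypothetical encoder matching the new optional-argument signature."""

    def __init__(self, dictionary, embed_dim=32):
        super().__init__(dictionary)
        self.embed = nn.Embedding(len(dictionary), embed_dim)

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        # `src_lengths` may be omitted; unused kwargs pass through harmlessly.
        return {'encoder_out': self.embed(src_tokens)}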
29 changes: 18 additions & 11 deletions fairseq/models/fairseq_incremental_decoder.py
@@ -12,8 +12,8 @@ class FairseqIncrementalDecoder(FairseqDecoder):
"""Base class for incremental decoders.
Incremental decoding is a special mode at inference time where the Model
-only receives a single timestep of input corresponding to the immediately
-previous output token (for input feeding) and must produce the next output
+only receives a single timestep of input corresponding to the previous
+output token (for input feeding) and must produce the next output
*incrementally*. Thus the model must cache any long-term state that is
needed about the sequence, e.g., hidden states, convolutional states, etc.
@@ -33,22 +33,29 @@ class FairseqIncrementalDecoder(FairseqDecoder):
def __init__(self, dictionary):
super().__init__(dictionary)

-def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
+def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
"""
Args:
-prev_output_tokens (LongTensor): previous decoder outputs of shape
+prev_output_tokens (LongTensor): shifted output tokens of shape
`(batch, tgt_len)`, for input feeding/teacher forcing
-encoder_out (Tensor, optional): output from the encoder, used for
+encoder_out (dict, optional): output from the encoder, used for
encoder-side attention
-incremental_state (dict): dictionary used for storing state during
-    :ref:`Incremental decoding`
+incremental_state (dict, optional): dictionary used for storing
+    state during :ref:`Incremental decoding`
Returns:
tuple:
-    - the last decoder layer's output of shape `(batch, tgt_len,
-      vocab)`
-    - the last decoder layer's attention weights of shape `(batch,
-      tgt_len, src_len)`
+    - the decoder's output of shape `(batch, tgt_len, vocab)`
+    - a dictionary with any model-specific outputs
"""
raise NotImplementedError

+def extract_features(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
+    """
+    Returns:
+        tuple:
+            - the decoder's features of shape `(batch, tgt_len, embed_dim)`
+            - a dictionary with any model-specific outputs
+    """
+    raise NotImplementedError
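For orientation, a hedged sketch of a decoder implementing this contract (a hypothetical toy class; it caches its recurrent state with the `fairseq.utils.get_incremental_state`/`set_incremental_state` helpers that fairseq's bundled models use for this purpose):

import torch.nn as nn

from fairseq import utils
from fairseq.models import FairseqIncrementalDecoder


class ToyIncrementalDecoder(FairseqIncrementalDecoder):
    """Hypothetical decoder that caches recurrent state between timesteps."""

    def __init__(self, dictionary, embed_dim=32):
        super().__init__(dictionary)
        self.embed = nn.Embedding(len(dictionary), embed_dim)
        self.rnn = nn.GRU(embed_dim, embed_dim, batch_first=True)
        self.proj = nn.Linear(embed_dim, len(dictionary))

    def forward(self, prev_output_tokens, encoder_out=None,
                incremental_state=None, **kwargs):
        if incremental_state is not None:
            # Incremental mode: only the newest token arrives at each call.
            prev_output_tokens = prev_output_tokens[:, -1:]
        # Returns None on the first call or in non-incremental mode.
        hidden = utils.get_incremental_state(self, incremental_state, 'hidden')
        x, hidden = self.rnn(self.embed(prev_output_tokens), hidden)
        # Cache state for the next step (a no-op when incremental_state is None).
        utils.set_incremental_state(self, incremental_state, 'hidden', hidden)
        return self.proj(x), {'attn': None}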

