Skip to content

Commit

Permalink
FairseqEncoderModel
Browse files Browse the repository at this point in the history
Summary: Base class for encoder-only models. Some models don't have a decoder part.

Reviewed By: myleott

Differential Revision: D14413406

fbshipit-source-id: f36473b91dcf3c835fd6d50e2eb6002afa75f11a
  • Loading branch information
Dmytro Okhonko authored and facebook-github-bot committed Mar 12, 2019
1 parent 7fc9a3b commit 9e1c880
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
1 change: 1 addition & 0 deletions fairseq/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
FairseqModel, # noqa: F401
FairseqMultiModel, # noqa: F401
FairseqLanguageModel, # noqa: F401
FairseqEncoderModel, # noqa: F401
)

from .composite_encoder import CompositeEncoder # noqa: F401
Expand Down
40 changes: 40 additions & 0 deletions fairseq/models/fairseq_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,3 +297,43 @@ def supported_targets(self):
def remove_head(self):
"""Removes the head of the model (e.g. the softmax layer) to conserve space when it is not needed"""
raise NotImplementedError()


class FairseqEncoderModel(BaseFairseqModel):
    """Base class for encoder-only models.

    Args:
        encoder (FairseqEncoder): the encoder
    """

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
        # Enforce the contract early: the wrapped module must be a
        # FairseqEncoder so max_positions()/forward() delegation is valid.
        assert isinstance(self.encoder, FairseqEncoder)

    def forward(self, src_tokens, src_lengths, **kwargs):
        """
        Run the forward pass for an encoder-only model.

        Feeds a batch of tokens through the encoder to generate logits.

        Args:
            src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
            src_lengths (LongTensor): source sentence lengths of shape `(batch)`
            **kwargs: additional encoder-specific keyword arguments,
                forwarded verbatim to the encoder

        Returns:
            the encoder's output, typically of shape `(batch, seq_len, vocab)`
        """
        # Pass **kwargs through: the original body dropped them even though
        # the signature accepts them, silently ignoring encoder-specific
        # arguments supplied by callers.
        return self.encoder(src_tokens, src_lengths, **kwargs)

    def max_positions(self):
        """Maximum length supported by the model."""
        return self.encoder.max_positions()

    @property
    def supported_targets(self):
        # Encoder-only models predict over the current/"future" token stream.
        return {'future'}

    def remove_head(self):
        """Removes the head of the model (e.g. the softmax layer) to conserve space when it is not needed"""
        raise NotImplementedError()

0 comments on commit 9e1c880

Please sign in to comment.