Skip to content
This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Commit

Permalink
Decouple decoder and output layer creation in BasePairwiseModel (#973)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #973

We create both the decoder and output layer in `BasePairwiseModel._create_decoder()`. This diff moves output layer creation out of `_create_decoder()` function.

Reviewed By: borguz

Differential Revision: D17318025

fbshipit-source-id: 556c988831d6dcf0788c1a61efebfb8f8e815940
  • Loading branch information
Kushal Lakhotia authored and facebook-github-bot committed Sep 12, 2019
1 parent 4816c75 commit 18a3cb7
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
5 changes: 3 additions & 2 deletions pytext/models/bert_classification_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,9 @@ def _create_encoder(
@classmethod
def from_config(cls, config: Config, tensorizers: Dict[str, Tensorizer]):
encoder1, encoder2 = cls._create_encoder(config, tensorizers)
decoder, output_layer = cls._create_decoder(
config, [encoder1, encoder2], tensorizers
decoder = cls._create_decoder(config, [encoder1, encoder2], tensorizers)
output_layer = create_module(
config.output_layer, labels=tensorizers["labels"].vocab
)
return cls(encoder1, encoder2, decoder, output_layer, config.encode_relations)

Expand Down
8 changes: 4 additions & 4 deletions pytext/models/pair_classification_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ def _create_decoder(
decoder = create_module(
config.decoder, in_dim=decoder_in_dim, out_dim=len(labels)
)
output_layer = create_module(config.output_layer, labels=labels)
return decoder, output_layer
return decoder

@classmethod
def _encode_relations(cls, encodings: List[torch.Tensor]) -> List[torch.Tensor]:
Expand Down Expand Up @@ -205,8 +204,9 @@ def _create_representations(cls, config: Config, embeddings: nn.ModuleList):
def from_config(cls, config: Config, tensorizers: Dict[str, Tensorizer]):
embeddings = cls._create_embeddings(config, tensorizers)
representations = cls._create_representations(config, embeddings)
decoder, output_layer = cls._create_decoder(
config, representations, tensorizers
decoder = cls._create_decoder(config, representations, tensorizers)
output_layer = create_module(
config.output_layer, labels=tensorizers["labels"].vocab
)
return cls(
embeddings, representations, decoder, output_layer, config.encode_relations
Expand Down

0 comments on commit 18a3cb7

Please sign in to comment.