Skip to content
This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Commit

Permalink
Unify model input for ByteTokensDocumentModel (#1274)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1274

https://fb.workplace.com/groups/300451907202972/permalink/589822794932547/

Reviewed By: codekansas

Differential Revision: D20367268

fbshipit-source-id: 8f5c7af1b8ff8247cdf1d1c0949aa2ab1620752b
  • Loading branch information
hudeven authored and facebook-github-bot committed Mar 10, 2020
1 parent 1b74f33 commit b745425
Showing 1 changed file with 23 additions and 6 deletions.
29 changes: 23 additions & 6 deletions pytext/models/doc_model.py
Expand Up @@ -153,14 +153,13 @@ def forward(
):
if tokens is None:
raise RuntimeError("tokens is required")
if dense_feat is None:
raise RuntimeError("dense_feat is required")

seq_lens = make_sequence_lengths(tokens)
word_ids = self.vocab.lookup_indices_2d(tokens)
word_ids = pad_2d(word_ids, seq_lens, self.pad_idx)
if dense_feat is not None:
dense_feat = self.normalizer.normalize(dense_feat)
else:
raise RuntimeError("dense is required")
dense_feat = self.normalizer.normalize(dense_feat)
logits = self.model(
torch.tensor(word_ids),
torch.tensor(seq_lens),
Expand Down Expand Up @@ -281,7 +280,14 @@ def __init__(self):
self.output_layer = output_layer

@jit.script_method
def forward(self, tokens: List[List[str]]):
def forward(
self,
texts: Optional[List[str]] = None,
tokens: Optional[List[List[str]]] = None,
languages: Optional[List[str]] = None,
):
if tokens is None:
raise RuntimeError("tokens is required")
seq_lens = make_sequence_lengths(tokens)
word_ids = self.vocab.lookup_indices_2d(tokens)
word_ids = pad_2d(word_ids, seq_lens, self.pad_idx)
Expand All @@ -307,7 +313,18 @@ def __init__(self):
self.output_layer = output_layer

@jit.script_method
def forward(self, tokens: List[List[str]], dense_feat: List[List[float]]):
def forward(
self,
texts: Optional[List[str]] = None,
tokens: Optional[List[List[str]]] = None,
languages: Optional[List[str]] = None,
dense_feat: Optional[List[List[float]]] = None,
):
if tokens is None:
raise RuntimeError("tokens is required")
if dense_feat is None:
raise RuntimeError("dense_feat is required")

seq_lens = make_sequence_lengths(tokens)
word_ids = self.vocab.lookup_indices_2d(tokens)
word_ids = pad_2d(word_ids, seq_lens, self.pad_idx)
Expand Down

0 comments on commit b745425

Please sign in to comment.