Skip to content
This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Commit

Permalink
add PyText Embedding TorchScript Wrapper
Browse files Browse the repository at this point in the history
Summary: add PyText Embedding TorchScript Wrapper

Reviewed By: hudeven

Differential Revision: D20350097

fbshipit-source-id: e4dec2963f5f92019a659def9fb0a4e55494943f
  • Loading branch information
chenyangyu1988 authored and facebook-github-bot committed Mar 10, 2020
1 parent 8432792 commit 1b74f33
Showing 1 changed file with 59 additions and 0 deletions.
59 changes: 59 additions & 0 deletions pytext/torchscript/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,62 @@ def forward(
input_tensors = self.tensorizer(inputs)
logits = self.model(input_tensors, torch.tensor(dense_feat).float())
return self.output_layer(logits)


class ScriptPyTextEmbeddingModule(ScriptModule):
def __init__(
self,
model: torch.jit.ScriptModule,
tensorizer: ScriptTensorizer,
index: int = 0,
):
super().__init__()
self.model = model
self.tensorizer = tensorizer
self.index = torch.jit.Attribute(index, int)

@torch.jit.script_method
def forward(
self,
texts: Optional[List[str]] = None,
tokens: Optional[List[List[str]]] = None,
languages: Optional[List[str]] = None,
) -> torch.Tensor:
inputs: ScriptBatchInput = ScriptBatchInput(
texts=squeeze_1d(texts),
tokens=squeeze_2d(tokens),
languages=squeeze_1d(languages),
)
input_tensors = self.tensorizer(inputs)
# call model
return self.model(input_tensors)[self.index]


class ScriptPyTextEmbeddingModuleWithDense(ScriptModule):
def __init__(
self,
model: torch.jit.ScriptModule,
tensorizer: ScriptTensorizer,
index: int = 0,
):
super().__init__()
self.model = model
self.tensorizer = tensorizer
self.index = torch.jit.Attribute(index, int)

@torch.jit.script_method
def forward(
self,
dense_feat: List[List[float]],
texts: Optional[List[str]] = None,
tokens: Optional[List[List[str]]] = None,
languages: Optional[List[str]] = None,
) -> torch.Tensor:
inputs: ScriptBatchInput = ScriptBatchInput(
texts=squeeze_1d(texts),
tokens=squeeze_2d(tokens),
languages=squeeze_1d(languages),
)
input_tensors = self.tensorizer(inputs)
# call model
return self.model(input_tensors, torch.tensor(dense_feat).float())[self.index]

0 comments on commit 1b74f33

Please sign in to comment.