From 1b74f33991e891d5f8bc794e4723e9243f63be73 Mon Sep 17 00:00:00 2001 From: Chenyang Yu Date: Mon, 9 Mar 2020 17:15:37 -0700 Subject: [PATCH] add PyText Embedding TorchScript Wrapper Summary: add PyText Embedding TorchScript Wrapper Reviewed By: hudeven Differential Revision: D20350097 fbshipit-source-id: e4dec2963f5f92019a659def9fb0a4e55494943f --- pytext/torchscript/module.py | 59 ++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/pytext/torchscript/module.py b/pytext/torchscript/module.py index 1e18f6c23..9a94df20e 100644 --- a/pytext/torchscript/module.py +++ b/pytext/torchscript/module.py @@ -59,3 +59,62 @@ def forward( input_tensors = self.tensorizer(inputs) logits = self.model(input_tensors, torch.tensor(dense_feat).float()) return self.output_layer(logits) + + +class ScriptPyTextEmbeddingModule(ScriptModule): + def __init__( + self, + model: torch.jit.ScriptModule, + tensorizer: ScriptTensorizer, + index: int = 0, + ): + super().__init__() + self.model = model + self.tensorizer = tensorizer + self.index = torch.jit.Attribute(index, int) + + @torch.jit.script_method + def forward( + self, + texts: Optional[List[str]] = None, + tokens: Optional[List[List[str]]] = None, + languages: Optional[List[str]] = None, + ) -> torch.Tensor: + inputs: ScriptBatchInput = ScriptBatchInput( + texts=squeeze_1d(texts), + tokens=squeeze_2d(tokens), + languages=squeeze_1d(languages), + ) + input_tensors = self.tensorizer(inputs) + # call model + return self.model(input_tensors)[self.index] + + +class ScriptPyTextEmbeddingModuleWithDense(ScriptModule): + def __init__( + self, + model: torch.jit.ScriptModule, + tensorizer: ScriptTensorizer, + index: int = 0, + ): + super().__init__() + self.model = model + self.tensorizer = tensorizer + self.index = torch.jit.Attribute(index, int) + + @torch.jit.script_method + def forward( + self, + dense_feat: List[List[float]], + texts: Optional[List[str]] = None, + tokens: Optional[List[List[str]]] = None, + languages: Optional[List[str]] = None, + ) -> torch.Tensor: + inputs: ScriptBatchInput = ScriptBatchInput( + texts=squeeze_1d(texts), + tokens=squeeze_2d(tokens), + languages=squeeze_1d(languages), + ) + input_tensors = self.tensorizer(inputs) + # call model + return self.model(input_tensors, torch.tensor(dense_feat).float())[self.index]