Skip to content
This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Enabled LMLSTM label exporting #767

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions pytext/models/language_models/lmlstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pytext.config import ConfigBase
from pytext.data import CommonMetadata
from pytext.data.tensorizers import Tensorizer, TokenTensorizer
from pytext.exporters.exporter import ModelExporter
from pytext.models.decoders import DecoderBase
from pytext.models.decoders.mlp_decoder import MLPDecoder
from pytext.models.embeddings import EmbeddingBase
Expand Down Expand Up @@ -235,13 +236,29 @@ def get_export_output_names(self, tensorizers):
def vocab_to_export(self, tensorizers):
    """Expose the token vocabulary for export as a plain dict of lists.

    Args:
        tensorizers: mapping of tensorizer names to tensorizer objects;
            only the "tokens" entry's vocab is exported.

    Returns:
        Dict mapping "tokens" to the vocabulary materialized as a list.
    """
    token_vocab = tensorizers["tokens"].vocab
    return {"tokens": list(token_vocab)}

def caffe2_export(self, tensorizers, tensor_dict, path, export_onnx_path=None):
    """Export this model to Caffe2 format via the generic ModelExporter.

    Args:
        tensorizers: mapping of tensorizer names to tensorizer objects,
            used to derive input/output names and the export vocabulary.
        tensor_dict: batch of numericalized tensors used as dummy inputs
            for tracing the model.
        path: destination path for the Caffe2 model.
        export_onnx_path: optional path at which to also keep the
            intermediate ONNX model.

    Returns:
        Whatever ModelExporter.export_to_caffe2 returns for this model.
    """
    # Gather everything the exporter needs from the tensorizers/batch first,
    # then hand the model itself over for the actual conversion.
    input_names = self.get_export_input_names(tensorizers)
    dummy_inputs = self.arrange_model_inputs(tensor_dict)
    export_vocab = self.vocab_to_export(tensorizers)
    output_names = self.get_export_output_names(tensorizers)
    exporter = ModelExporter(
        ModelExporter.Config(),
        input_names,
        dummy_inputs,
        export_vocab,
        output_names,
    )
    return exporter.export_to_caffe2(self, path, export_onnx_path=export_onnx_path)

def forward(
self, tokens: torch.Tensor, seq_len: torch.Tensor
) -> List[torch.Tensor]:
token_emb = self.embedding(tokens)
if self.stateful and self._states is None:
self._states = self.init_hidden(tokens.size(0))

if torch.onnx.is_in_onnx_export():
token_emb = token_emb.cpu()
seq_len = seq_len.cpu()
if self.stateful:
self._states = (self._states[0].cpu(), self._states[1].cpu())

rep, states = self.representation(token_emb, seq_len, states=self._states)
if self.decoder is None:
output = rep
Expand Down
15 changes: 15 additions & 0 deletions pytext/models/output_layers/lm_output_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@

import torch
import torch.nn.functional as F
from caffe2.python import core
from pytext.config.component import create_loss
from pytext.data.utils import PAD, Vocabulary
from pytext.fields import FieldMeta
from pytext.loss import CrossEntropyLoss, Loss

from .output_layer_base import OutputLayerBase
from .utils import OutputLayerUtils


class LMOutputLayer(OutputLayerBase):
Expand Down Expand Up @@ -107,6 +109,19 @@ def get_pred(
scores = F.log_softmax(logit, 2)
return preds, scores

def export_to_caffe2(
    self,
    workspace: core.workspace,
    init_net: core.Net,
    predict_net: core.Net,
    model_out: torch.Tensor,
    output_name: str,
) -> List[core.BlobReference]:
    """Add a Softmax over the decoder logits to the Caffe2 predict net.

    The softmax runs over the last axis of ``model_out`` (the vocabulary
    dimension), and the resulting probability blob plus any auxiliary
    blobs (e.g. target names) are registered on the predict net.

    Args:
        workspace: Caffe2 workspace (unused here; kept for interface parity).
        init_net: Caffe2 init net (unused here; kept for interface parity).
        predict_net: Caffe2 prediction net to extend.
        model_out: the traced model output tensor, used only for its rank.
        output_name: name of the blob holding the raw logits.

    Returns:
        List of blob references produced for the prediction outputs.
    """
    last_axis = model_out.dim() - 1
    prob_out = predict_net.Softmax(output_name, axis=last_axis)
    return OutputLayerUtils.gen_additional_blobs(
        predict_net, prob_out, model_out, output_name, self.target_names
    )

@staticmethod
def calculate_perplexity(sequence_loss: torch.Tensor) -> torch.Tensor:
return torch.exp(sequence_loss)