Skip to content
This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Commit

Permalink
Support torchscriptify in multi_label_classification_layer
Browse files Browse the repository at this point in the history
Summary: Refactored multi_label_classification_layer to support jit export

Reviewed By: shivanipoddariiith

Differential Revision: D21033328

fbshipit-source-id: 1cab210ee62da8bfc561054c2781c0b6372508d2
  • Loading branch information
seayoung1112 authored and facebook-github-bot committed May 5, 2020
1 parent 27d27b1 commit d5ea8d2
Showing 1 changed file with 51 additions and 76 deletions.
127 changes: 51 additions & 76 deletions pytext/models/output_layers/multi_label_classification_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import torch.nn as nn
from caffe2.python import core
from pytext.data.tensorizers import Tensorizer
from pytext.models.module import create_module
from pytext.utils.usage import log_class_usage
from torch import jit

Expand All @@ -17,18 +16,22 @@


class MultiLabelClassificationScores(nn.Module):
def __init__(self, scores: jit.ScriptModule):
def __init__(self, scores: List[jit.ScriptModule]):
super().__init__()
self.scores = scores
self.scores = nn.ModuleList(scores)
log_class_usage(__class__)

def forward(
self, logits: List[torch.Tensor]
) -> Tuple[List[List[Dict[str, float]]]]:
def forward(self, logits: List[torch.Tensor]) -> List[List[Dict[str, float]]]:

results: List[List[Dict[str, float]]] = []
for idx, sc in enumerate(self.scores):
logit = logits[idx]
# flatten from [batch_size, ..., label_set_size] to
# [batch_size, label_set_size]
# must flatten because jit doesn't support dynamic return type
flattened_logit = logit.view(-1, logit.size()[-1])
results.append(sc(flattened_logit))

results = []
for logit_ in logits:
results.append(self.scores(logit_))
return results


Expand All @@ -37,61 +40,35 @@ class MultiLabelClassificationLayer(OutputLayerBase):
Output layer for multilabel sequence classification models.
Args:
output (List[WordTaggingOutputLayer]): Output for multilabels, here
USM + PTSR + EA task.
label_names (List[str]): Ordered list of labels predicted through the model
for which the losses need to be aggregated by the output layer
label_tensorizer (Dict[str, LabelListTensorizer]): Dict of list of labels
that constitute the output from the decoder ordered by label_names
sequencing
outputs (Dict[str, ClassificationOutputLayer]): Output for multilabels
optional label_weights (Dict[str, int]): Dict of label_names along with the
weight for label
Attributes:
output (type): Output layer for multilabel-multiclass classification task
label_names (type): List of labels to be predicted by the model
label_tensorizer (type): Dict of key-label names with values-tensorizers
used to compute the size of the label vocab
optional label_weights (type): Dict of label-weight to compute weighted
output layer
"""

class Config(OutputLayerBase.Config):
output: List[ClassificationOutputLayer.Config] = []
label_weights: Dict[str, float] = {}
outputs: List[ClassificationOutputLayer.Config] = []
label_set_weights: Dict[str, float] = {}

@classmethod
def from_config(
cls,
config: Config,
label_tensorizers: [Dict[str, Tensorizer]],
label_names: [List[str]],
):
modules = []
for label_idx in range(0, len(label_names)):
label_ = label_names[label_idx]
modules.append(
create_module(
config.output[label_idx], labels=label_tensorizers[label_].vocab
)
def from_config(cls, config: Config, label_tensorizers: [Dict[str, Tensorizer]]):
modules = {
name: ClassificationOutputLayer.from_config(
config.outputs[idx], labels=tensorizer.vocab
)
print("Created Modules", len(modules))
return cls(modules, label_names, config.label_weights)
for idx, (name, tensorizer) in enumerate(label_tensorizers.items())
}

return cls(modules, config.label_set_weights)

def __init__(
self,
output: List[ClassificationOutputLayer],
label_names: List[str],
label_weights: Optional[Dict[str, float]] = None,
outputs: Dict[str, ClassificationOutputLayer],
label_set_weights: Optional[Dict[str, float]] = None,
) -> None:
super().__init__()
self.output = output
self.label_names = label_names
self.label_weights = label_weights
self.outputs = outputs
self.num_label_sets = len(outputs)
self.label_set_weights = label_set_weights
log_class_usage(__class__)

def get_loss(
Expand All @@ -105,30 +82,28 @@ def get_loss(
"""Compute and return the averaged intent and slot-filling loss.
Args:
logits (List[tuple[torch.Tensor]]): Logits returned by
logits (List[torch.Tensor]): Logits returned by
:class:`~pytext.models.decoders.MultiLabelDecoder`. It's list
containing logits for all label tasks here pTSR, autoUSM and EA.
targets (List[tuple[torch.Tensor]]): Targets as computed by the true labels
optional label_weights (Dict[str, float]): Label weights for multi-label
ordering of logits corresponding the respective label.
targets (Optional[torch.Tensor]): Not applicable. Defaults to None.
targets (List[torch.Tensor]): Targets as computed by the true labels
context (Optional[torch.Tensor]): Not applicable. Defaults to None.
Returns:
torch.Tensor: Averaged Loss across all label losses.
"""
loss = 0
for label_idx, label_name in enumerate(self.label_names):
logit = logits[label_idx]
# [batch_size * seq_lens, dim]
total_loss = 0
for logit, target, (label_name, output_layer) in zip(
logits, targets, self.outputs.items()
):
# flatten from [batch_size, ..., label_set_size] to
# [batch_size, label_set_size]
flattened_logit = logit.view(-1, logit.size()[-1])
loss += self.output[label_idx].get_loss(
flattened_logit, targets[label_idx].view(-1), None
)
if self.label_weights:
weight = self.label_weights[label_name]
loss = torch.mean(torch.mul(loss, weight))
loss = loss / (len(self.label_names) * 1.0)
return loss
loss = output_layer.get_loss(flattened_logit, target.view(-1))
if label_name in self.label_set_weights:
weight = self.label_set_weights[label_name]
loss = torch.mul(loss, weight)
total_loss += loss
return total_loss / (self.num_label_sets * 1.0)

def get_pred(
self,
Expand All @@ -149,9 +124,9 @@ def get_pred(
logits (List[torch.Tensor]): Logits returned by
:class:`~pytext.models.decoders.MultiLabelDecoder`. It's list
containing logits for all label tasks here pTSR, autoUSM and EA.
targets (List[tuple[torch.Tensor]]): Targets as computed by the true labels
targets (List[torch.Tensor]): Targets as computed by the true labels
ordering of logits corresponding the respective label.
targets (Optional[torch.Tensor]): Not applicable. Defaults to None.
context (Optional[torch.Tensor]): Not applicable. Defaults to None.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Model prediction and scores.
Expand All @@ -160,8 +135,8 @@ def get_pred(
# label predictions ordered by label list.
scores = []
preds = []
for label_idx, logit in enumerate(logits):
pred, score = self.output[label_idx].get_pred(logit)
for output_layer, logit in zip(self.outputs.values(), logits):
pred, score = output_layer.get_pred(logit)
preds.append(pred)
scores.append(score)
return (preds, scores)
Expand All @@ -182,13 +157,13 @@ def export_to_caffe2(
return functools.reduce(
operator.add,
[
self.output[idx].export_to_caffe2(
workspace, init_net, predict_net, model_out[idx], out_name
output_layer.export_to_caffe2(
workspace, init_net, predict_net, single_output, out_name
)
for idx, single_output in enumerate(model_out)
for output_layer, single_output in zip(self.outputs, model_out)
],
)

def torchscript_predictions(self):
scores = self.output.torchscript_predictions()
scores = [o.torchscript_predictions() for o in self.outputs.values()]
return jit.script(MultiLabelClassificationScores(scores))

0 comments on commit d5ea8d2

Please sign in to comment.