This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Add custom metric class for reporting Joint model metrics #1339

Closed
4 changes: 2 additions & 2 deletions pytext/data/tensorizers.py
@@ -853,8 +853,8 @@ def numberize(self, row):
label_idx_list.append(self.pad_idx)
else:
raise Exception(
"Found none or empty value in the list,"
+ " while pad_missing is disabled"
"Found none or empty value in the list, \
while pad_missing is disabled"
)
else:
label_idx_list.append(self.vocab.lookup_all(label))
2 changes: 2 additions & 0 deletions pytext/metric_reporters/__init__.py
@@ -15,6 +15,7 @@
from .regression_metric_reporter import RegressionMetricReporter
from .squad_metric_reporter import SquadMetricReporter
from .word_tagging_metric_reporter import (
MultiLabelSequenceTaggingMetricReporter,
NERMetricReporter,
SequenceTaggingMetricReporter,
WordTaggingMetricReporter,
@@ -26,6 +27,7 @@
"MetricReporter",
"ClassificationMetricReporter",
"MultiLabelClassificationMetricReporter",
"MultiLabelSequenceTaggingMetricReporter",
"RegressionMetricReporter",
"IntentSlotMetricReporter",
"LanguageModelMetricReporter",
67 changes: 67 additions & 0 deletions pytext/metric_reporters/word_tagging_metric_reporter.py
@@ -11,8 +11,10 @@
AllConfusions,
Confusions,
LabelPrediction,
MultiLabelSoftClassificationMetrics,
PRF1Metrics,
compute_classification_metrics,
compute_multi_label_multi_class_soft_metrics,
)
from pytext.metrics.intent_slot_metrics import (
Node,
@@ -26,6 +28,9 @@
from .metric_reporter import MetricReporter


NAN_LABELS = ["__UNKNOWN__", "__PAD__"]


def get_slots(word_names):
slots = {
Node(label=slot.label, span=Span(slot.start, slot.end))
@@ -92,6 +97,68 @@ def get_model_select_metric(self, metrics):
return metrics.micro_scores.f1


class MultiLabelSequenceTaggingMetricReporter(MetricReporter):
def __init__(self, label_names, pad_idx, channels, label_vocabs=None):
super().__init__(channels)
self.label_names = label_names
        # Right now the assumption is that we use the same pad idx for all
        # labels. TODO: Extend this to support multiple label-specific pad idxs
self.pad_idx = pad_idx
self.label_vocabs = label_vocabs

@classmethod
def from_config(cls, config, tensorizers):
return MultiLabelSequenceTaggingMetricReporter(
channels=[ConsoleChannel(), FileChannel((Stage.TEST,), config.output_path)],
label_names=tensorizers.keys(),
pad_idx=[v.pad_idx for _, v in tensorizers.items()],
label_vocabs=[v.vocab._vocab for _, v in tensorizers.items()],
)

def calculate_metric(self):
list_score_pred_expect = []
for label_idx in range(0, len(self.label_names)):
list_score_pred_expect.append(
list(
itertools.chain.from_iterable(
(
LabelPrediction(s, p, e)
for s, p, e in zip(scores, pred, expect)
if e != self.pad_idx[label_idx]
)
for scores, pred, expect in zip(
self.all_scores[label_idx],
self.all_preds[label_idx],
self.all_targets[label_idx],
)
)
)
)
metrics = compute_multi_label_multi_class_soft_metrics(
list_score_pred_expect,
self.label_names,
self.label_vocabs,
self.calculate_loss(),
)
return metrics

def batch_context(self, raw_batch, batch):
return {}

@staticmethod
def get_model_select_metric(metrics):
if isinstance(metrics, MultiLabelSoftClassificationMetrics):
            # Each label set carries its own multi-class precision/recall metrics;
            # average the per-label-set average precision into one scalar
normalize_count = sum(1 for k in metrics.average_precision.keys()) * 1.0
avg_precision = (
sum(v for k, v in metrics.average_precision.items()) / normalize_count
)
else:
avg_precision = metrics.accuracy
return avg_precision
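
A minimal sketch (editor's illustration, not part of this diff) of the reduction get_model_select_metric performs: it averages the per-label-set average precision values into one scalar. The label-set names and scores below are hypothetical.

# Hypothetical metrics.average_precision contents for two label sets
metrics_ap = {"intent": 0.82, "satisfaction": 0.74}
normalize_count = len(metrics_ap) * 1.0          # equivalent to the sum(1 for ...) above
model_select_metric = sum(metrics_ap.values()) / normalize_count   # -> 0.78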


class SequenceTaggingMetricReporter(MetricReporter):
def __init__(self, label_names, pad_idx, channels):
super().__init__(channels)
66 changes: 66 additions & 0 deletions pytext/metrics/__init__.py
@@ -21,6 +21,7 @@
from pytext.utils.ascii_table import ascii_table


NAN_LABELS = ["__UNKNOWN__", "__PAD__"]
RECALL_AT_PRECISION_THRESHOLDS = [0.2, 0.4, 0.6, 0.8, 0.9]
PRECISION_AT_RECALL_THRESHOLDS = [0.2, 0.4, 0.6, 0.8, 0.9]

@@ -95,6 +96,19 @@ class SoftClassificationMetrics(NamedTuple):
roc_auc: Optional[float]


class MultiLabelSoftClassificationMetrics(NamedTuple):
"""
Classification scores that are independent of thresholds.
"""

average_precision: Dict[str, float]
recall_at_precision: Dict[str, Dict[str, Dict[float, float]]]
decision_thresh_at_precision: Dict[str, Dict[str, Dict[float, float]]]
precision_at_recall: Dict[str, Dict[str, Dict[float, float]]]
decision_thresh_at_recall: Dict[str, Dict[str, Dict[float, float]]]
roc_auc: Optional[Dict[Optional[str], Optional[Dict[str, Optional[float]]]]]
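
To make the nesting above concrete, a hypothetical example (editor's illustration, not part of this diff) of how these fields end up keyed once compute_multi_label_multi_class_soft_metrics fills them in; "satisfaction", "satisfied", and "unsatisfied" are invented names:

# label set -> mean average precision over its (non-padding) classes
average_precision = {"satisfaction": 0.78}
# label set -> class -> precision threshold -> recall at that precision
recall_at_precision = {"satisfaction": {"satisfied": {0.9: 0.55, 0.8: 0.70}}}
# label set -> class -> ROC AUC (None when it cannot be computed)
roc_auc = {"satisfaction": {"satisfied": 0.91, "unsatisfied": None}}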


class MacroPRF1Scores(NamedTuple):
"""
Macro precision/recall/F1 scores (averages across each label).
@@ -753,6 +767,58 @@ def compute_multi_label_soft_metrics(
return soft_metrics


def compute_multi_label_multi_class_soft_metrics(
predictions: Sequence[Sequence[LabelListPrediction]],
label_names: Sequence[str],
label_vocabs: Sequence[Sequence[str]],
loss: float,
recall_at_precision_thresholds: Sequence[float] = RECALL_AT_PRECISION_THRESHOLDS,
precision_at_recall_thresholds: Sequence[float] = PRECISION_AT_RECALL_THRESHOLDS,
) -> MultiLabelSoftClassificationMetrics:
"""

Computes multi-label soft classification metrics with multi-class accommodation

Args:
predictions: multi-label predictions,
including the confidence score for each label.
        label_names: Indexed names of the label sets.
        label_vocabs: Per-label-set class names (vocab), indexed like label_names.
recall_at_precision_thresholds: precision thresholds at which to calculate
recall
precision_at_recall_thresholds: recall thresholds at which to calculate
precision


Returns:
        A MultiLabelSoftClassificationMetrics tuple whose fields are keyed by
        label set (and, where applicable, by class within each label set).
"""
soft_metrics = MultiLabelSoftClassificationMetrics({}, {}, {}, {}, {}, {})
for label_idx, label_vocab in enumerate(label_vocabs):
label = list(label_names)[label_idx]
soft_metrics_ = compute_soft_metrics(predictions[label_idx], label_vocab)
temp_avg_precision_ = {k: v.average_precision for k, v in soft_metrics_.items()}
soft_metrics.average_precision[label] = sum(
v for k, v in temp_avg_precision_.items() if k not in NAN_LABELS
) / (
sum(1 for k, v in temp_avg_precision_.items() if k not in NAN_LABELS) * 1.0
)
soft_metrics.recall_at_precision[label] = {
k: v.recall_at_precision for k, v in soft_metrics_.items()
}
soft_metrics.decision_thresh_at_precision[label] = {
k: v.decision_thresh_at_precision for k, v in soft_metrics_.items()
}
soft_metrics.precision_at_recall[label] = {
k: v.precision_at_recall for k, v in soft_metrics_.items()
}
soft_metrics.decision_thresh_at_recall[label] = {
k: v.decision_thresh_at_recall for k, v in soft_metrics_.items()
}
soft_metrics.roc_auc[label] = {k: v.roc_auc for k, v in soft_metrics_.items()}

return soft_metrics
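
Editor's illustration (not part of this diff) of the NAN_LABELS-excluded averaging performed above for each label set; the class names and scores are hypothetical:

NAN_LABELS = ["__UNKNOWN__", "__PAD__"]      # mirrors the module-level constant above
per_class_ap = {"__UNKNOWN__": 0.0, "satisfied": 0.8, "unsatisfied": 0.6}
valid_ap = {k: v for k, v in per_class_ap.items() if k not in NAN_LABELS}
label_set_ap = sum(valid_ap.values()) / (len(valid_ap) * 1.0)   # (0.8 + 0.6) / 2 = 0.7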


def compute_matthews_correlation_coefficients(
TP: int, FP: int, FN: int, TN: int
) -> float:
64 changes: 64 additions & 0 deletions pytext/models/decoders/multilabel_decoder.py
@@ -0,0 +1,64 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Dict, List

import torch
import torch.nn as nn
from pytext.utils.usage import log_class_usage

from .decoder_base import DecoderBase


class MultiLabelDecoder(DecoderBase):
"""
    Implements an 'n-tower' MLP: one tower per label set.
    Used in USM/EA, where user satisfaction modeling, pTSR prediction and
    error attribution are the 3 label sets that need predicting.

"""

class Config(DecoderBase.Config):
# Intermediate hidden dimensions
hidden_dims: List[int] = []

def __init__(
self,
config: Config,
in_dim: int,
output_dim: Dict[str, int],
label_names: List[str],
) -> None:
super().__init__(config)
self.label_mlps = nn.ModuleDict({})
# Store the ordered list to preserve the ordering of the labels
# when generating the output layer
self.label_names = label_names
aggregate_out_dim = 0
for label_, _ in output_dim.items():
self.label_mlps[label_] = MultiLabelDecoder.get_mlp(
in_dim, output_dim[label_], config.hidden_dims
)
aggregate_out_dim += output_dim[label_]
self.out_dim = (1, aggregate_out_dim)
log_class_usage(__class__)

@staticmethod
def get_mlp(in_dim: int, out_dim: int, hidden_dims: List[int]):
layers = []
current_dim = in_dim
for dim in hidden_dims or []:
layers.append(nn.Linear(current_dim, dim))
layers.append(nn.ReLU())
current_dim = dim
layers.append(nn.Linear(current_dim, out_dim))
return nn.Sequential(*layers)

def forward(self, *input: torch.Tensor):
logits = tuple(
self.label_mlps[x](torch.cat(input, 1)) for x in self.label_names
)
return logits

def get_decoder(self) -> List[nn.Module]:
return self.label_mlps
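
A hedged usage sketch of MultiLabelDecoder (editor's illustration, not part of this diff). It assumes the Config can be instantiated directly with keyword arguments; the label-set names, dimensions, and hidden sizes are hypothetical:

import torch

from pytext.models.decoders.multilabel_decoder import MultiLabelDecoder

config = MultiLabelDecoder.Config(hidden_dims=[64])
decoder = MultiLabelDecoder(
    config,
    in_dim=128,                                           # encoder representation size
    output_dim={"satisfaction": 3, "error_attribution": 5},
    label_names=["satisfaction", "error_attribution"],
)
representation = torch.rand(8, 128)                       # a batch of 8 encoded examples
logits = decoder(representation)                          # one logits tensor per label set
assert logits[0].shape == (8, 3) and logits[1].shape == (8, 5)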
@@ -128,7 +128,7 @@ def forward(self, logits: torch.Tensor):
class BinaryClassificationOutputLayer(ClassificationOutputLayer):
def get_pred(self, logit, *args, **kwargs):
"""See `OutputLayerBase.get_pred()`."""
preds = torch.max(logit, 1)[1]
preds = torch.max(logit, -1)[1]
scores = F.logsigmoid(logit)
return preds, scores

@@ -153,7 +153,7 @@ def export_to_caffe2(
class MulticlassOutputLayer(ClassificationOutputLayer):
def get_pred(self, logit, *args, **kwargs):
"""See `OutputLayerBase.get_pred()`."""
preds = torch.max(logit, 1)[1]
preds = torch.max(logit, -1)[1]
scores = F.log_softmax(logit, 1)
return preds, scores
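
Editor's note (not part of this diff): the switch from dim=1 to dim=-1 above makes the argmax agnostic to extra leading dimensions, which is presumably what the joint/multi-label logits require. A small shape sketch with hypothetical sizes:

import torch

logit_2d = torch.randn(4, 3)             # [batch, num_classes]
logit_3d = torch.randn(4, 7, 3)          # [batch, seq_len, num_classes]
preds_2d = torch.max(logit_2d, -1)[1]    # shape [4]; dim=1 would also work here
preds_3d = torch.max(logit_3d, -1)[1]    # shape [4, 7]; dim=1 would pick the wrong axis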
