
Add custom metric class for reporting Joint model metrics (#1339)
Summary:
Pull Request resolved: #1339

Adding a multi-label metric class to support reporting all multi-label and multi-class metrics for joint PyText models (see the shape sketch below).

Reviewed By: seayoung1112

Differential Revision: D21077306

fbshipit-source-id: 3b0938f67cd0d658af567eaca89c8afcc88a0aa8
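
For orientation, the container reported for a joint model is keyed first by label set and then by class. Below is a minimal, standalone shape sketch with hypothetical label-set and class names, mirroring only two of the six fields of the MultiLabelSoftClassificationMetrics NamedTuple added in this commit:

from typing import Dict, NamedTuple, Optional

# Stand-in mirroring two fields of the NamedTuple added in pytext/metrics/__init__.py.
# The label-set names ("doc_intent", "word_slots") and all values are hypothetical.
class JointMetricsSketch(NamedTuple):
    average_precision: Dict[str, float]  # label set -> mean average precision
    roc_auc: Dict[str, Dict[str, Optional[float]]]  # label set -> class -> ROC AUC

sketch = JointMetricsSketch(
    average_precision={"doc_intent": 0.81, "word_slots": 0.67},
    roc_auc={
        "doc_intent": {"music": 0.92, "weather": 0.88},
        "word_slots": {"artist": 0.74, "city": 0.70},
    },
)
print(sketch.average_precision["doc_intent"])  # 0.81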
shivanipods authored and facebook-github-bot committed Apr 29, 2020
1 parent 7b28e8f commit 447c799
Showing 3 changed files with 49 additions and 16 deletions.
23 changes: 11 additions & 12 deletions pytext/metric_reporters/word_tagging_metric_reporter.py
@@ -11,6 +11,7 @@
AllConfusions,
Confusions,
LabelPrediction,
MultiLabelSoftClassificationMetrics,
PRF1Metrics,
compute_classification_metrics,
compute_multi_label_multi_class_soft_metrics,
@@ -27,6 +28,9 @@
from .metric_reporter import MetricReporter


NAN_LABELS = ["__UNKNOWN__", "__PAD__"]


def get_slots(word_names):
slots = {
Node(label=slot.label, span=Span(slot.start, slot.end))
@@ -97,6 +101,8 @@ class MultiLabelSequenceTaggingMetricReporter(MetricReporter):
def __init__(self, label_names, pad_idx, channels, label_vocabs=None):
super().__init__(channels)
self.label_names = label_names
# Right now the assumption is that we use the same pad idx for all
# labels. TODO: Extend this to use multiple label-specific pad idxs.
self.pad_idx = pad_idx
self.label_vocabs = label_vocabs

@@ -110,8 +116,6 @@ def from_config(cls, config, tensorizers):
)

def calculate_metric(self):
if len(self.all_scores) == 0:
return {}
list_score_pred_expect = []
for label_idx in range(0, len(self.label_names)):
list_score_pred_expect.append(
@@ -143,18 +147,13 @@ def batch_context(self, raw_batch, batch):

@staticmethod
def get_model_select_metric(metrics):
if isinstance(metrics, dict):
if isinstance(metrics, MultiLabelSoftClassificationMetrics):
# There are multiclass precision/recall labels
# Compute average precision
avg_precision = 0.0
for _, metric in metrics.items():
if metric:
avg_precision += sum(
v.average_precision
for k, v in metric.items()
if v.average_precision > 0
) / (len(metric.keys()) * 1.0)
avg_precision = avg_precision / (len(metrics.keys()) * 1.0)
normalize_count = sum(1 for k in metrics.average_precision.keys()) * 1.0
avg_precision = (
sum(v for k, v in metrics.average_precision.items()) / normalize_count
)
else:
avg_precision = metrics.accuracy
return avg_precision
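
With the new container, the model-selection scalar for the joint case is the mean of the per-label-set average precisions. A hedged usage sketch, assuming MultiLabelSoftClassificationMetrics is exported from pytext.metrics as the import hunk above suggests; the label-set names and scores are hypothetical:

from pytext.metric_reporters.word_tagging_metric_reporter import (
    MultiLabelSequenceTaggingMetricReporter,
)
from pytext.metrics import MultiLabelSoftClassificationMetrics

# Only average_precision is read by get_model_select_metric; the other
# fields are left empty for this sketch.
metrics = MultiLabelSoftClassificationMetrics(
    average_precision={"doc_intent": 0.82, "word_slots": 0.74},
    recall_at_precision={},
    decision_thresh_at_precision={},
    precision_at_recall={},
    decision_thresh_at_recall={},
    roc_auc={},
)
score = MultiLabelSequenceTaggingMetricReporter.get_model_select_metric(metrics)
print(score)  # mean of 0.82 and 0.74, i.e. about 0.78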
41 changes: 38 additions & 3 deletions pytext/metrics/__init__.py
@@ -21,6 +21,7 @@
from pytext.utils.ascii_table import ascii_table


NAN_LABELS = ["__UNKNOWN__", "__PAD__"]
RECALL_AT_PRECISION_THRESHOLDS = [0.2, 0.4, 0.6, 0.8, 0.9]
PRECISION_AT_RECALL_THRESHOLDS = [0.2, 0.4, 0.6, 0.8, 0.9]

@@ -95,6 +96,19 @@ class SoftClassificationMetrics(NamedTuple):
roc_auc: Optional[float]


class MultiLabelSoftClassificationMetrics(NamedTuple):
"""
Classification scores that are independent of thresholds.
"""

average_precision: Dict[str, float]
recall_at_precision: Dict[str, Dict[str, Dict[float, float]]]
decision_thresh_at_precision: Dict[str, Dict[str, Dict[float, float]]]
precision_at_recall: Dict[str, Dict[str, Dict[float, float]]]
decision_thresh_at_recall: Dict[str, Dict[str, Dict[float, float]]]
roc_auc: Optional[Dict[Optional[str], Optional[Dict[str, Optional[float]]]]]


class MacroPRF1Scores(NamedTuple):
"""
Macro precision/recall/F1 scores (averages across each label).
@@ -757,9 +771,10 @@ def compute_multi_label_multi_class_soft_metrics(
predictions: Sequence[Sequence[LabelListPrediction]],
label_names: Sequence[str],
label_vocabs: Sequence[Sequence[str]],
loss: float,
recall_at_precision_thresholds: Sequence[float] = RECALL_AT_PRECISION_THRESHOLDS,
precision_at_recall_thresholds: Sequence[float] = PRECISION_AT_RECALL_THRESHOLDS,
) -> Dict[int, SoftClassificationMetrics]:
) -> MultiLabelSoftClassificationMetrics:
"""
Computes multi-label soft classification metrics with multi-class accommodation
@@ -777,10 +792,30 @@
Returns:
A MultiLabelSoftClassificationMetrics tuple whose fields map label-set names to their aggregated soft metrics.
"""
soft_metrics = {}
soft_metrics = MultiLabelSoftClassificationMetrics({}, {}, {}, {}, {}, {})
for label_idx, label_vocab in enumerate(label_vocabs):
label = list(label_names)[label_idx]
soft_metrics[label] = compute_soft_metrics(predictions[label_idx], label_vocab)
soft_metrics_ = compute_soft_metrics(predictions[label_idx], label_vocab)
temp_avg_precision_ = {k: v.average_precision for k, v in soft_metrics_.items()}
soft_metrics.average_precision[label] = sum(
v for k, v in temp_avg_precision_.items() if k not in NAN_LABELS
) / (
sum(1 for k, v in temp_avg_precision_.items() if k not in NAN_LABELS) * 1.0
)
soft_metrics.recall_at_precision[label] = {
k: v.recall_at_precision for k, v in soft_metrics_.items()
}
soft_metrics.decision_thresh_at_precision[label] = {
k: v.decision_thresh_at_precision for k, v in soft_metrics_.items()
}
soft_metrics.precision_at_recall[label] = {
k: v.precision_at_recall for k, v in soft_metrics_.items()
}
soft_metrics.decision_thresh_at_recall[label] = {
k: v.decision_thresh_at_recall for k, v in soft_metrics_.items()
}
soft_metrics.roc_auc[label] = {k: v.roc_auc for k, v in soft_metrics_.items()}

return soft_metrics
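
The per-label-set average precision computed above deliberately skips the sentinel classes in NAN_LABELS. A small standalone sketch of that reduction, with a hypothetical per-class average-precision dict standing in for the output of compute_soft_metrics:

NAN_LABELS = ["__UNKNOWN__", "__PAD__"]

# Hypothetical per-class average precisions for one label set;
# __UNKNOWN__ is excluded from the label-set average.
per_class_ap = {"__UNKNOWN__": 0.0, "music": 0.9, "weather": 0.7}

label_set_ap = sum(
    v for k, v in per_class_ap.items() if k not in NAN_LABELS
) / (sum(1 for k in per_class_ap if k not in NAN_LABELS) * 1.0)

print(label_set_ap)  # mean over the two real classes, about 0.8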


@@ -9,7 +9,6 @@
from caffe2.python import core
from pytext.data.tensorizers import Tensorizer
from pytext.models.module import create_module

# from pytext.utils.label import get_label_weights
from pytext.utils.usage import log_class_usage
from torch import jit
