From 41241197518a01f775b1a6f4cbd3a12c9d112d37 Mon Sep 17 00:00:00 2001
From: Shrey Desai
Date: Mon, 15 Jul 2019 19:41:14 -0700
Subject: [PATCH] implemented perplexity reductions for lm score reporting (#779)

Summary:
Pull Request resolved: https://github.com/facebookresearch/pytext/pull/779

Implements perplexity reductions for language model score reporting. Currently, the reported score is the average perplexity of all words in a sentence, but this may not work for all use cases (e.g., the presence of named entities in a sentence brings up the average perplexity, so the resulting metric may not be as useful). To this end, this diff introduces several ways to "reduce" the perplexity scores of a sentence: min, max, mean, median, and eos (i.e., the perplexity of the `__END_OF_SENTENCE__` token).

Reviewed By: abhinavarora

Differential Revision: D16244881

fbshipit-source-id: a579d27b7a17c0f52b0fa698f514db90be31694b
---
 pytext/config/module_config.py                |  8 +++++
 .../language_model_metric_reporter.py         | 36 ++++++++++++++++---
 2 files changed, 40 insertions(+), 4 deletions(-)

diff --git a/pytext/config/module_config.py b/pytext/config/module_config.py
index 1418d7995..f784027b2 100644
--- a/pytext/config/module_config.py
+++ b/pytext/config/module_config.py
@@ -34,3 +34,11 @@ class SlotAttentionType(Enum):
     CONCAT = "concat"
     MULTIPLY = "multiply"
     DOT = "dot"
+
+
+class PerplexityType(Enum):
+    MIN = "min"
+    MAX = "max"
+    MEAN = "mean"
+    MEDIAN = "median"
+    EOS = "eos"
diff --git a/pytext/metric_reporters/language_model_metric_reporter.py b/pytext/metric_reporters/language_model_metric_reporter.py
index aa9c31990..b1d58b550 100644
--- a/pytext/metric_reporters/language_model_metric_reporter.py
+++ b/pytext/metric_reporters/language_model_metric_reporter.py
@@ -1,11 +1,13 @@
 #!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 import math
+import operator
 import time

 import torch
 import torch.nn.functional as F
 from pytext.common.constants import Stage
+from pytext.config.module_config import PerplexityType
 from pytext.data import CommonMetadata
 from pytext.metrics.language_model_metrics import (
     LanguageModelMetric,
@@ -16,6 +18,22 @@
 from .metric_reporter import MetricReporter


+PERPLEXITY_FUNC_MAP = {
+    PerplexityType.MIN: torch.min,
+    PerplexityType.MAX: torch.max,
+    PerplexityType.MEAN: torch.mean,
+    PerplexityType.MEDIAN: torch.median,
+    PerplexityType.EOS: operator.itemgetter(-1),
+}
+
+
+def get_perplexity_func(perplexity_type):
+    func = PERPLEXITY_FUNC_MAP.get(perplexity_type, None)
+    if not func:
+        raise NotImplementedError
+    return func
+
+
 class LanguageModelChannel(FileChannel):
     def get_title(self):
         return ("text", "perplexity")
@@ -32,6 +50,7 @@ class LanguageModelMetricReporter(MetricReporter):

     class Config(MetricReporter.Config):
         aggregate_metrics: bool = True
+        perplexity_type: PerplexityType = PerplexityType.MEDIAN

     @classmethod
     def from_config(cls, config: Config, meta: CommonMetadata = None, tensorizers=None):
@@ -39,12 +58,14 @@ def from_config(cls, config: Config, meta: CommonMetadata = None, tensorizers=No
             [ConsoleChannel(), LanguageModelChannel((Stage.TEST,), config.output_path)],
             tensorizers,
             config.aggregate_metrics,
+            config.perplexity_type,
         )

-    def __init__(self, channels, tensorizers, aggregate_metrics):
+    def __init__(self, channels, tensorizers, aggregate_metrics, perplexity_type):
         super().__init__(channels)
         self.tensorizers = tensorizers
         self.aggregate_metrics = aggregate_metrics
+        self.perplexity_func = get_perplexity_func(perplexity_type)

     def add_batch_stats(
         self, n_batches, preds, targets, scores, loss, m_input, **context
@@ -89,10 +110,12 @@ def batch_context(self, raw_batch, batch):
         return context

     def compute_scores(self, pred, target):
+        def _compute_score(tensor):
+            return torch.exp(self.perplexity_func(tensor[tensor != 0.0]))
+
         logits, pad_idx = pred
         scores = F.nll_loss(logits, target, ignore_index=pad_idx, reduction="none")
-        per_sentence_loss = (torch.exp(y[y != 0].mean()) for y in scores)
-        return map(lambda x: x.item(), per_sentence_loss)
+        return map(lambda x: _compute_score(x).item(), scores)

     def aggregate_scores(self, scores):
         self.all_scores.extend(scores)
@@ -107,7 +130,12 @@ def aggregate_context(self, context):
 class MaskedLMMetricReporter(LanguageModelMetricReporter):
     @classmethod
     def from_config(cls, config, meta: CommonMetadata = None, tensorizers=None):
-        return cls([ConsoleChannel()], tensorizers, config.aggregate_metrics)
+        return cls(
+            [ConsoleChannel()],
+            tensorizers,
+            config.aggregate_metrics,
+            config.perplexity_type,
+        )

     def add_batch_stats(
         self, n_batches, preds, targets, scores, loss, m_input, **context
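For reference, a minimal standalone sketch (not part of the patch; the tensor values and variable names below are made up for illustration) of how each PerplexityType reduction turns one sentence's per-token negative log-likelihoods into the reported perplexity, mirroring PERPLEXITY_FUNC_MAP and the _compute_score helper above:

import operator

import torch

# Hypothetical per-token NLLs for one sentence; 0.0 marks padded positions
# (ignore_index), and the last non-pad entry is the __END_OF_SENTENCE__ token.
token_nll = torch.tensor([2.3, 7.1, 1.8, 0.9, 0.0, 0.0])
non_pad = token_nll[token_nll != 0.0]

reductions = {
    "min": torch.min,
    "max": torch.max,
    "mean": torch.mean,
    "median": torch.median,
    "eos": operator.itemgetter(-1),  # last non-pad score, i.e. the EOS token
}
for name, func in reductions.items():
    # exp(reduced NLL) is the per-sentence perplexity that gets reported
    print(name, torch.exp(func(non_pad)).item())

Note that Config defaults to PerplexityType.MEDIAN, whereas the previous behavior corresponded to MEAN, so reported scores change unless perplexity_type is set back to mean.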