Skip to content
This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Commit

Permalink
implemented perplexity reductions for lm score reporting (#779)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #779

Implements perplexity reductions for language model score reporting. Currently, the reported score is the average perplexity of all words in a sentence, but this may not work for all use cases (e.g., the presence of named entities in a sentence brings up the average perplexity, so the resulting metric may not be as useful). To this end, this diff introduces several ways to "reduce" the perplexity scores of a sentence: min, max, mean, median, and eos (i.e., the perplexity of the `__END_OF_SENTENCE__` token).

Reviewed By: abhinavarora

Differential Revision: D16244881

fbshipit-source-id: a579d27b7a17c0f52b0fa698f514db90be31694b
  • Loading branch information
shreydesai authored and facebook-github-bot committed Jul 16, 2019
1 parent fbf30bb commit 4124119
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 4 deletions.
8 changes: 8 additions & 0 deletions pytext/config/module_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,11 @@ class SlotAttentionType(Enum):
CONCAT = "concat"
MULTIPLY = "multiply"
DOT = "dot"


class PerplexityType(Enum):
MIN = "min"
MAX = "max"
MEAN = "mean"
MEDIAN = "median"
EOS = "eos"
36 changes: 32 additions & 4 deletions pytext/metric_reporters/language_model_metric_reporter.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import math
import operator
import time

import torch
import torch.nn.functional as F
from pytext.common.constants import Stage
from pytext.config.module_config import PerplexityType
from pytext.data import CommonMetadata
from pytext.metrics.language_model_metrics import (
LanguageModelMetric,
Expand All @@ -16,6 +18,22 @@
from .metric_reporter import MetricReporter


PERPLEXITY_FUNC_MAP = {
PerplexityType.MIN: torch.min,
PerplexityType.MAX: torch.max,
PerplexityType.MEAN: torch.mean,
PerplexityType.MEDIAN: torch.median,
PerplexityType.EOS: operator.itemgetter(-1),
}


def get_perplexity_func(perplexity_type):
func = PERPLEXITY_FUNC_MAP.get(perplexity_type, None)
if not func:
raise NotImplementedError
return func


class LanguageModelChannel(FileChannel):
def get_title(self):
return ("text", "perplexity")
Expand All @@ -32,19 +50,22 @@ class LanguageModelMetricReporter(MetricReporter):

class Config(MetricReporter.Config):
aggregate_metrics: bool = True
perplexity_type: PerplexityType = PerplexityType.MEDIAN

@classmethod
def from_config(cls, config: Config, meta: CommonMetadata = None, tensorizers=None):
return cls(
[ConsoleChannel(), LanguageModelChannel((Stage.TEST,), config.output_path)],
tensorizers,
config.aggregate_metrics,
config.perplexity_type,
)

def __init__(self, channels, tensorizers, aggregate_metrics):
def __init__(self, channels, tensorizers, aggregate_metrics, perplexity_type):
super().__init__(channels)
self.tensorizers = tensorizers
self.aggregate_metrics = aggregate_metrics
self.perplexity_func = get_perplexity_func(perplexity_type)

def add_batch_stats(
self, n_batches, preds, targets, scores, loss, m_input, **context
Expand Down Expand Up @@ -89,10 +110,12 @@ def batch_context(self, raw_batch, batch):
return context

def compute_scores(self, pred, target):
def _compute_score(tensor):
return torch.exp(self.perplexity_func(tensor[tensor != 0.0]))

logits, pad_idx = pred
scores = F.nll_loss(logits, target, ignore_index=pad_idx, reduction="none")
per_sentence_loss = (torch.exp(y[y != 0].mean()) for y in scores)
return map(lambda x: x.item(), per_sentence_loss)
return map(lambda x: _compute_score(x).item(), scores)

def aggregate_scores(self, scores):
self.all_scores.extend(scores)
Expand All @@ -107,7 +130,12 @@ def aggregate_context(self, context):
class MaskedLMMetricReporter(LanguageModelMetricReporter):
@classmethod
def from_config(cls, config, meta: CommonMetadata = None, tensorizers=None):
return cls([ConsoleChannel()], tensorizers, config.aggregate_metrics)
return cls(
[ConsoleChannel()],
tensorizers,
config.aggregate_metrics,
config.perplexity_type,
)

def add_batch_stats(
self, n_batches, preds, targets, scores, loss, m_input, **context
Expand Down

0 comments on commit 4124119

Please sign in to comment.