Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Rare Word F1 #3566

Merged
merged 7 commits into from
Apr 14, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions parlai/core/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,58 @@ def compute(guess: str, answers: List[str]) -> F1Metric:
return F1Metric(max(f1 for p, r, f1 in scores), 1)


class RareWordF1Calculator:
    """
    Helper class for computing F1 with an emphasis on infrequent words.

    A word-frequency distribution is built once from a reference corpus. Before
    F1 is computed, every word whose corpus count is at or above the cutoff
    derived from `top_p` is removed from both the guess and the labels. Words
    that never appear in the corpus are removed as well.
    """

    def __init__(self, corpus: str, top_p: float = 0.5):
        """
        Build the reference frequency distribution and its cutoff count.

        :param corpus:
            reference text; it is passed through `normalize_answer` and split
            on whitespace before counting.
        :param top_p:
            fraction of the cumulative word occurrences (must be < 1) used to
            locate the cutoff count.

        :raises ImportError: if nltk is not installed.
        :raises ValueError: if `top_p` is not less than 1.
        :raises RuntimeError: if no cutoff can be found (e.g. empty corpus).
        """
        try:
            import nltk
        except ImportError as err:
            # Chain the original error so the real import failure stays visible.
            raise ImportError('Please install nltk (e.g. pip install nltk).') from err
        words = normalize_answer(corpus).split()
        self._freq_dist = nltk.FreqDist(words)
        self._cutoff_count = self._find_cutoff_count(self._freq_dist, top_p)

    @staticmethod
    def _find_cutoff_count(freq_dist, top_p: float) -> int:
        """
        Find the occurrence count at which the cumulative occurrences, taken
        from the most common word downwards, first exceed `top_p` of the
        overall word count.

        :param freq_dist:
            mapping of word to count that provides `values()` and
            `most_common()`.
        :param top_p:
            fraction of the total occurrences; must be less than 1.

        :return: the count of the word at which the threshold is crossed.

        :raises ValueError: if `top_p` is not less than 1.
        :raises RuntimeError: if the threshold is never crossed.
        """
        # Explicit check instead of `assert`, which is removed under `python -O`.
        # Written as `not (... < 1)` so NaN is rejected too.
        if not top_p < 1:
            raise ValueError(f"top_p must be less than 1, got {top_p}")
        target = sum(freq_dist.values()) * top_p
        cumul = 0
        for _, count in freq_dist.most_common():
            cumul += count
            if cumul > target:
                return count
        raise RuntimeError(f"Invalid top {top_p*100}% of the corpus distribution")

    @staticmethod
    def _filter(freq_dist, cutoff: int, text: str) -> str:
        """
        Keep only the words of `text` whose reference count is below `cutoff`.

        `text` is normalized with `normalize_answer` first. Words missing from
        `freq_dist` are treated as having a count equal to `cutoff`, so they
        are dropped.

        :return: the surviving words joined by single spaces.
        """
        words = normalize_answer(text).split()
        return " ".join(w for w in words if freq_dist.get(w, cutoff) < cutoff)

    def compute(self, guess: str, answers: List[str]) -> F1Metric:
        """
        Compute F1 between the guess and the answers using rare words only.

        :param guess: model output text.
        :param answers: reference label texts.

        :return:
            `F1Metric(0, 0)` when no answer has any word left after filtering,
            otherwise the result of `F1Metric.compute` on the filtered texts.
        """
        guess = self._filter(self._freq_dist, self._cutoff_count, guess)
        answers = [
            self._filter(self._freq_dist, self._cutoff_count, a) for a in answers
        ]
        if not any(answers):
            # no rare words in labels, set denominator to zero
            return F1Metric(0, 0)
        return F1Metric.compute(guess, answers)


class ExactMatchMetric(AverageMetric):
@staticmethod
def compute(guess: str, answers: List[str]) -> ExactMatchMetric:
Expand Down
30 changes: 29 additions & 1 deletion parlai/tasks/wizard_of_wikipedia/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@

from typing import Optional, Tuple
from parlai.core.message import Message
from parlai.core.metrics import AverageMetric, normalize_answer, F1Metric
from parlai.core.metrics import (
AverageMetric,
normalize_answer,
F1Metric,
RareWordF1Calculator,
)
from parlai.core.params import ParlaiParser
from parlai.core.opt import Opt
import copy
Expand Down Expand Up @@ -181,6 +186,15 @@ def share(self):
###############################################################


def _build_rare_word_f1(datapath: str) -> RareWordF1Calculator:
    """
    Create a RareWordF1Calculator from the Wizard of Wikipedia dialogues.

    Reads `wizard_of_wikipedia/data.json` under `datapath`, concatenates the
    `text` field of every entry in every episode's `dialog` list, and uses the
    result as the reference corpus with `top_p=0.5`.
    """
    json_file = os.path.join(datapath, 'wizard_of_wikipedia', 'data.json')
    with PathManager.open(json_file) as handle:
        episodes = json.load(handle)
    utterances = [turn['text'] for episode in episodes for turn in episode['dialog']]
    # The trailing space mirrors the corpus string that was built originally.
    corpus = ' '.join(utterances) + ' '
    return RareWordF1Calculator(corpus, top_p=0.5)


class WizardDialogKnowledgeTeacher(WizardOfWikipediaTeacher):
"""
Teacher that returns the following action dict:
Expand Down Expand Up @@ -210,6 +224,10 @@ def __init__(self, opt, shared=None):
self.knowledge_separator = opt.get('include_knowledge_separator', False)
self.chosen_topic_delimiter = opt.get('chosen_topic_delimiter', '\n')
self.num_exs = sum(self.len_episode(i) for i in range(len(self.data)))
if shared and 'rare_word_f1' in shared:
self.rare_word_f1 = shared['rare_word_f1']
elif self.label_type == 'response':
self.rare_word_f1 = _build_rare_word_f1(opt['datapath'])

@classmethod
def add_cmdline_args(
Expand Down Expand Up @@ -258,6 +276,12 @@ def add_cmdline_args(
)
return parser

def share(self):
    """
    Extend the parent's shared dict with the rare word F1 calculator.

    The calculator is added under the `rare_word_f1` key only when this
    instance has a `rare_word_f1` attribute.
    """
    shared = super().share()
    try:
        shared['rare_word_f1'] = self.rare_word_f1
    except AttributeError:
        # No calculator on this instance; leave the shared dict as is.
        pass
    return shared

def len_episode(self, ep):
d = self.data[ep]
wizard_first = 'Wizard' in d['dialog'][0]['speaker']
Expand Down Expand Up @@ -390,6 +414,10 @@ def custom_evaluation(
model_response['text'], [teacher_action['checked_sentence']]
),
)
self.metrics.add(
'rare_word_f1',
self.rare_word_f1.compute(model_response['text'], labels),
)
elif (
self.label_type == 'chosen_sent'
and TOKEN_KNOWLEDGE in model_response['text']
Expand Down