#!/usr/bin/env python
import json
import logging
import os
import re
import sys
import zipfile
from io import TextIOWrapper

import click
import nlgeval
import tabulate

LOG = logging.getLogger("eval")
NLGEVAL = None  # lazily constructed nlgeval.NLGEval instance, shared across domains
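
# Expected line formats, as consumed below (the field names are taken from
# this script; the values are illustrative only):
#   test-spec jsonl:   {"target_dlg": "dlg_0042", "predict_turn": 3}
#   results.jsonl:     {"dlg_id": "dlg_0042", "predict_turn": 3, "response": "..."}
#   ground truth txt:  {"id": "dlg_0042", "turns": ["...", "..."]}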


def get_domain_results(domain, domain_gt, domain_test_spec, predictions, all_metrics):
    """Compute nlgeval metrics for a single domain and return the metric dict."""
    global NLGEVAL

    # sanity check: every test spec must have a ground truth dialogue
    missing_gts, missing_preds = [], []
    for spec in domain_test_spec:
        if spec['target_dlg'] not in domain_gt:
            missing_gts.append(spec['target_dlg'])
    if missing_gts:
        print("Domain %s incomplete: missing %d ground truth dialogues for %s" % (
            domain, len(missing_gts),
            ", ".join(missing_gts[:3] + (["..."] if len(missing_gts) > 3 else []))))
        sys.exit(1)
    LOG.debug("Found ground truth dialogues for all test specifications")

    # sanity check: every test spec must have a prediction
    pred_ids = {(pred['dlg_id'], pred['predict_turn']): pred for pred in predictions}
    for spec in domain_test_spec:
        if (spec['target_dlg'], spec['predict_turn']) not in pred_ids:
            missing_preds.append("%s turn %d" % (spec['target_dlg'], spec['predict_turn']))
    if missing_preds:
        print("Domain %s incomplete: missing %d predictions for %s" % (
            domain, len(missing_preds),
            ", ".join(missing_preds[:3] + (["..."] if len(missing_preds) > 3 else []))))
        sys.exit(1)
    LOG.debug("Found predictions for all test specifications")

    # pair each ground truth turn with the corresponding predicted response
    references, hypotheses = [], []
    for spec in domain_test_spec:
        references.append(domain_gt[spec['target_dlg']]['turns'][spec['predict_turn']])
        key = (spec['target_dlg'], spec['predict_turn'])
        hypotheses.append(pred_ids[key]['response'])

    # construct NLGEval lazily: loading the skipthoughts/glove models is
    # expensive, so they are skipped unless --all-metrics was given
    if NLGEVAL is None:
        if all_metrics:
            NLGEVAL = nlgeval.NLGEval()
        else:
            NLGEVAL = nlgeval.NLGEval(no_skipthoughts=True, no_glove=True)
    return NLGEVAL.compute_metrics([references], hypotheses)
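
# For reference: compute_metrics returns a flat dict of metric name -> float
# (typically Bleu_1..Bleu_4, METEOR, ROUGE_L and CIDEr, plus embedding-based
# metrics when all_metrics is enabled).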


def format_results(results):
    """Print a metrics table; `results` maps domain -> {metric: value}."""
    metric_names = sorted(next(iter(results.values())).keys())
    metric_sums = {m: 0. for m in metric_names}
    rows = []
    for domain, res in results.items():
        rows.append((domain, *[res[m] for m in metric_names]))
        for m in metric_names:
            metric_sums[m] += res[m]
    # append an unweighted average over all domains (len(rows) == domain count here)
    rows.append(("AVERAGE", *[metric_sums[m] / len(rows) for m in metric_names]))
    print(tabulate.tabulate(rows, headers=['Domain', *metric_names]))
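
# Example output (hypothetical values, single domain):
#   Domain      Bleu_1    METEOR    ROUGE_L
#   --------  --------  --------  ---------
#   weather     0.2513    0.1102     0.3021
#   AVERAGE     0.2513    0.1102     0.3021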


@click.command()
@click.argument("ground-truth-zip", type=click.Path(exists=True, dir_okay=False))
@click.argument("test-spec", type=click.Path(exists=True, dir_okay=False))
@click.argument("results-dir", type=click.Path(exists=True, file_okay=False))
@click.argument("domains", nargs=-1)
@click.option('-v', '--verbose', is_flag=True)
@click.option('-a', '--all-metrics', is_flag=True, help="also show metrics that are expensive to compute")
def main(ground_truth_zip, test_spec, results_dir, domains, verbose, all_metrics):
    """Evaluates model responses using nlgeval.

    \b
    Arguments:
      ground-truth-zip: zip file containing one txt file with dialogues per domain in its "dialogues" folder
      test-spec: jsonl file specifying which dialogue+turn to predict
      results-dir: directory containing subfolders named after domains, each with a "results.jsonl"
      domains: any number of domains to evaluate; each domain (and its dialogues) must be present in all arguments above
    """
    logging.basicConfig(level=logging.DEBUG if verbose else logging.INFO)

    LOG.debug("reading test specification from %s", test_spec)
    with open(test_spec, 'rt') as test_spec_f:
        test_spec = [json.loads(line) for line in test_spec_f]

    results_all_domains = {}
    with zipfile.ZipFile(ground_truth_zip) as gt_f:
        if not domains:
            # no domains given: derive them from the dialogue files in the zip
            rgx = re.compile(r'dialogues/(.*?)\.txt')
            matches = (rgx.search(p) for p in gt_f.namelist()
                       if 'tasks' not in os.path.basename(p) and p.startswith('dialogues'))
            domains = [m.group(1) for m in matches if m]
        for domain in domains:
            LOG.info("processing domain %s", domain)
            gt_filename_in_zip = "dialogues/" + domain + ".txt"
            LOG.debug("reading ground truth from %s", gt_filename_in_zip)
            with gt_f.open(gt_filename_in_zip) as domain_f:
                with TextIOWrapper(domain_f, encoding='utf-8') as domain_f_utf8:
                    domain_gt = [json.loads(line) for line in domain_f_utf8]
            domain_gt = {dlg['id']: dlg for dlg in domain_gt}
            domain_test_spec = [spec for spec in test_spec if spec['target_dlg'] in domain_gt]
            LOG.debug("Found %d specs for %s", len(domain_test_spec), domain)

            results_filename = os.path.join(results_dir, domain, "results.jsonl")
            LOG.debug("Reading results from %s", results_filename)
            with open(results_filename) as pred_f:
                lines = pred_f.readlines()
            if len(lines) == 1:
                # early versions of the results output lack \n between lines
                lines = re.sub(r'}{', '}\n{', lines[0]).split('\n')
            # skip blank lines so a trailing newline does not break json.loads
            pred = [json.loads(line) for line in lines if line.strip()]

            domain_res = get_domain_results(domain, domain_gt, domain_test_spec, pred, all_metrics)
            results_all_domains[domain] = domain_res
    format_results(results_all_domains)


if __name__ == '__main__':
    main()
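
# Example invocation (paths and the domain name are hypothetical):
#   ./evaluate ground_truth.zip test_spec.jsonl results/ weather -v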