Skip to content

Commit

Permalink
fix classifier and add f-score evaluator (#504)
Browse files Browse the repository at this point in the history
* fix classifier and add f-score evaluator

* rename to f-measure

* fix unit test
  • Loading branch information
msperber committed Aug 3, 2018
1 parent 73bfaaa commit c0e30a4
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 13 deletions.
24 changes: 12 additions & 12 deletions test/config/classifier.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ classifier: !Experiment
run_for_epochs: 2
src_file: examples/data/head.en
trg_file: examples/data/head.smallids
dev_tasks:
- !LossEvalTask
src_file: examples/data/head.en
ref_file: examples/data/head.smallids
- !AccuracyEvalTask
eval_metrics: accuracy
src_file: examples/data/head.en
ref_file: examples/data/head.smallids
hyp_file: test/tmp/{EXP}.test_hyp
inference: !IndependentOutputInference
batcher: !InOrderBatcher
batch_size: 2
evaluate:
- !LossEvalTask
src_file: examples/data/head.en
ref_file: examples/data/head.smallids
- !AccuracyEvalTask
eval_metrics: accuracy,fmeasure
src_file: examples/data/head.en
ref_file: examples/data/head.smallids
hyp_file: test/tmp/{EXP}.test_hyp
inference: !IndependentOutputInference
batcher: !InOrderBatcher
batch_size: 2
66 changes: 66 additions & 0 deletions xnmt/eval/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,41 @@ def aggregate(scores: Sequence['SentenceLevelEvalScore'], desc: Any = None):
desc=desc)


class FMeasure(SentenceLevelEvalScore, Serializable):
  """
  Sentence-level F1 score, stored as raw binary-classification confusion counts.

  Args:
    true_pos: number of true positives
    false_neg: number of false negatives
    false_pos: number of false positives
    desc: optional description to include with the score
  """
  yaml_tag = "!FMeasure"

  @serializable_init
  def __init__(self, true_pos: int, false_neg: int, false_pos: int, desc: Any = None):
    # Keep the raw counts rather than the ratio so that aggregation across
    # sentences is a simple sum of counts.
    self.true_pos = true_pos
    self.false_neg = false_neg
    self.false_pos = false_pos
    self.serialize_params = {"true_pos": true_pos, "false_neg": false_neg, "false_pos": false_pos}
    if desc is not None:
      self.serialize_params["desc"] = desc

  def higher_is_better(self):
    return True

  def value(self):
    # F1 = 2*TP / (2*TP + FN + FP); undefined when there were no counts at all.
    denom = 2 * self.true_pos + self.false_neg + self.false_pos
    if denom == 0:
      return "n/a"
    return 2 * self.true_pos / denom

  def metric_name(self):
    return "F1 Score"

  def score_str(self):
    tp, fp, fn = self.true_pos, self.false_pos, self.false_neg
    # Precision / recall default to 0 when their denominator is empty.
    prec = tp / (tp + fp) if tp + fp > 0 else 0
    rec = tp / (tp + fn) if tp + fn > 0 else 0
    val = self.value()
    if isinstance(val, float):
      val = f"{val*100.0:.2f}%"
    return f"{val} " \
           f"(prec: {prec}, " \
           f"recall: {rec}; " \
           f"TP={tp},FP={fp},FN={fn})"

  @staticmethod
  def aggregate(scores: Sequence['SentenceLevelEvalScore'], desc: Any = None):
    # Corpus-level score is obtained by summing per-sentence confusion counts.
    return FMeasure(true_pos=sum(s.true_pos for s in scores),
                    false_neg=sum(s.false_neg for s in scores),
                    false_pos=sum(s.false_pos for s in scores),
                    desc=desc)


class Evaluator(object):
"""
A template class to evaluate the quality of output.
Expand Down Expand Up @@ -884,3 +919,34 @@ def evaluate_one_sent(self, ref:Sequence[str], hyp:Sequence[str]):
"""
correct = 1 if self._compare(ref, hyp) else 0
return SequenceAccuracyScore(num_correct=correct, num_total=1)


class FMeasureEvaluator(SentenceLevelEvaluator, Serializable):
  """
  A class to evaluate the quality of output in terms of classification F-score.

  Args:
    pos_token: token for the 'positive' class
    write_sentence_scores: path of file to write sentence-level scores to (in YAML format)
  """
  yaml_tag = "!FMeasureEvaluator"
  @serializable_init
  def __init__(self, pos_token:str="1", write_sentence_scores: Optional[str] = None) -> None:
    super().__init__(write_sentence_scores=write_sentence_scores)
    self.pos_token = pos_token

  def evaluate_one_sent(self, ref:Sequence[str], hyp:Sequence[str]):
    """
    Compute binary-classification confusion counts for a single sentence.

    Args:
      ref: reference label, must be a single-element sequence (scalar)
      hyp: hypothesized label, must be a single-element sequence (scalar)
    Returns:
      an FMeasure score holding the TP/FN/FP counts for this sentence
    Raises:
      ValueError: if ref or hyp is not of length 1
    """
    if len(ref)!=1 or len(hyp)!=1: raise ValueError("FScore requires scalar ref and hyp")
    ref = ref[0]
    hyp = hyp[0]
    # Counts are taken w.r.t. pos_token; a true negative (ref == hyp != pos_token)
    # correctly leaves all three counts at 0.
    return FMeasure( true_pos=1 if (ref == hyp) and (hyp == self.pos_token) else 0,
                     false_neg=1 if (ref != hyp) and (hyp != self.pos_token) else 0,
                     false_pos=1 if (ref != hyp) and (hyp == self.pos_token) else 0)
2 changes: 1 addition & 1 deletion xnmt/models/classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def generate(self, src, forced_trg_ids=None, normalize_scores=False):
outputs = []
for batch_i in range(src.batch_size()):
score = np_scores[:, batch_i][output_action[batch_i]]
outputs.append(sent.ScalarSentence(value=output_action,
outputs.append(sent.ScalarSentence(value=output_action[batch_i],
score=score))
return outputs

Expand Down
1 change: 1 addition & 0 deletions xnmt/xnmt_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def read_data(loc_, post_process=None):
"cer": lambda: metrics.CEREvaluator(),
"recall": lambda: metrics.RecallEvaluator(),
"accuracy": lambda: metrics.SequenceAccuracyEvaluator(),
"fmeasure" : lambda: metrics.FMeasureEvaluator(),
}


Expand Down

0 comments on commit c0e30a4

Please sign in to comment.