diff --git a/belar/metrics/simple.py b/belar/metrics/simple.py index 1ff7f9acc..eee55949f 100644 --- a/belar/metrics/simple.py +++ b/belar/metrics/simple.py @@ -1,5 +1,6 @@ from __future__ import annotations +from Levenshtein import distance, ratio import typing as t from dataclasses import dataclass @@ -38,6 +39,28 @@ def score(self, ground_truth: str, generated_text: str): return score.fmeasure +class EditScore(Metric): + measure: t.Literal["distance", "ratio"] = "ratio" + + @property + def name(self) -> str: + return f"edit_{self.measure}" + + @property + def is_batchable(self): + return True + + def score(self, ground_truth: t.List[str], generated_text: t.List[str]): + if self.measure == "distance": + score = [distance(s1, s2) for s1, s2 in zip(ground_truth, generated_text)] + elif self.measure == "ratio": + score = [ratio(s1, s2) for s1, s2 in zip(ground_truth, generated_text)] + else: + raise ValueError(f"Unkown measure {self.measure}") + + return score + + Rouge1 = ROUGE("rouge1") Rouge2 = ROUGE("rouge2") RougeL = ROUGE("rougeL")