diff --git a/belar/metrics/simple.py b/belar/metrics/simple.py
index eee55949f..cd71ba97b 100644
--- a/belar/metrics/simple.py
+++ b/belar/metrics/simple.py
@@ -3,7 +3,8 @@
 from Levenshtein import distance, ratio
 import typing as t
 from dataclasses import dataclass
-
+from nltk.tokenize import word_tokenize
+from nltk.translate.bleu_score import sentence_bleu
 from rouge_score import rouge_scorer
 
 from belar.metrics.base import Metric
@@ -11,6 +12,46 @@
 ROUGE_TYPES = t.Literal["rouge1", "rouge2", "rougeL"]
 
 
+@dataclass
+class BLEU(Metric):
+    """Sentence-level BLEU between each ground truth and generated text.
+
+    Attributes:
+        weights: n-gram weights passed to ``sentence_bleu``; the default
+            weighs 1- to 4-grams equally.
+        smoothing_function: optional smoothing callable passed to
+            ``sentence_bleu`` unchanged.
+    """
+
+    # Immutable default: dataclasses reject a list default with ValueError.
+    weights: t.Sequence[float] = (0.25, 0.25, 0.25, 0.25)
+    smoothing_function: t.Optional[t.Callable] = None
+
+    @property
+    def name(self):
+        return "BLEU"
+
+    @property
+    def is_batchable(self):
+        return True
+
+    def score(self, ground_truth: t.List[str], generated_text: t.List[str]):
+        """Return one BLEU score per (ground_truth, generated_text) pair."""
+        # sentence_bleu takes a list of tokenised references, then the
+        # tokenised hypothesis; each ground truth is the sole reference.
+        references = [[word_tokenize(text)] for text in ground_truth]
+        hypotheses = [word_tokenize(text) for text in generated_text]
+        return [
+            sentence_bleu(
+                refs,
+                hyp,
+                weights=self.weights,
+                smoothing_function=self.smoothing_function,
+            )
+            for refs, hyp in zip(references, hypotheses)
+        ]
+
+
 @dataclass
 class ROUGE(Metric):
     type: t.Literal[ROUGE_TYPES]
@@ -21,10 +62,8 @@ def __post_init__(self):
             [self.type], use_stemmer=self.use_stemmer
         )
 
-    @property
     def name(self):
         return self.type
 
-    @property
     def is_batchable(self):
         return False