From 142f74491190a8ec21c861c2b30967590f4394a1 Mon Sep 17 00:00:00 2001 From: Shahules786 Date: Sat, 13 May 2023 02:27:38 +0530 Subject: [PATCH 1/3] add bleu score --- belar/metrics/simple.py | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/belar/metrics/simple.py b/belar/metrics/simple.py index 1ff7f9acc..7d9dfb369 100644 --- a/belar/metrics/simple.py +++ b/belar/metrics/simple.py @@ -2,7 +2,8 @@ import typing as t from dataclasses import dataclass - +from nltk.tokenize import word_tokenize +from nltk.translate.bleu_score import corpus_bleu from rouge_score import rouge_scorer from belar.metrics.base import Metric @@ -20,11 +21,9 @@ def __post_init__(self): [self.type], use_stemmer=self.use_stemmer ) - @property def name(self): return self.type - @property def is_batchable(self): return False @@ -38,6 +37,29 @@ def score(self, ground_truth: str, generated_text: str): return score.fmeasure +class BLEU(Metric): + weights: t.List[float] = [0.25, 0.25, 0.25, 0.25] + smoothing_function = None + + @property + def name(self): + return "BLEU" + + @property + def is_batchable(self): + return True + + def score(self, ground_truth: t.List[str], generated_text: t.List[str]): + ground_truth_ = [[word_tokenize(text)] for text in ground_truth] + generated_text_ = [word_tokenize(text) for text in generated_text] + return corpus_bleu( + ground_truth_, + generated_text_, + weights=self.weights, + smoothing_function=self.smoothing_function, + ) + + Rouge1 = ROUGE("rouge1") Rouge2 = ROUGE("rouge2") RougeL = ROUGE("rougeL") From 618aa0443e92a462259fa54ce467246bda4499e8 Mon Sep 17 00:00:00 2001 From: Shahules786 Date: Sat, 13 May 2023 03:09:33 +0530 Subject: [PATCH 2/3] return list of scores --- belar/metrics/simple.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/belar/metrics/simple.py b/belar/metrics/simple.py index 7d9dfb369..bd23e6a3e 100644 --- a/belar/metrics/simple.py +++ b/belar/metrics/simple.py @@ -3,7 +3,7 @@ import typing as t from dataclasses import dataclass from nltk.tokenize import word_tokenize -from nltk.translate.bleu_score import corpus_bleu +from nltk.translate.bleu_score import sentence_bleu from rouge_score import rouge_scorer from belar.metrics.base import Metric @@ -52,12 +52,15 @@ def is_batchable(self): def score(self, ground_truth: t.List[str], generated_text: t.List[str]): ground_truth_ = [[word_tokenize(text)] for text in ground_truth] generated_text_ = [word_tokenize(text) for text in generated_text] - return corpus_bleu( - ground_truth_, - generated_text_, - weights=self.weights, - smoothing_function=self.smoothing_function, - ) + return [ + sentence_bleu( + s1, + s2, + weights=self.weights, + smoothing_function=self.smoothing_function, + ) + for s1, s2 in zip(ground_truth_, generated_text_) + ] Rouge1 = ROUGE("rouge1") From 41ed2055b6c9fd1560def6e22a5b4d2a974f05a9 Mon Sep 17 00:00:00 2001 From: Shahules786 Date: Sat, 13 May 2023 03:30:31 +0530 Subject: [PATCH 3/3] move bleu score --- belar/metrics/simple.py | 51 +++++++++++++++++++++-------------------- 1 file changed, 26 insertions(+), 25 deletions(-) diff --git a/belar/metrics/simple.py b/belar/metrics/simple.py index bd23e6a3e..2bbdb5883 100644 --- a/belar/metrics/simple.py +++ b/belar/metrics/simple.py @@ -12,31 +12,6 @@ @dataclass -class ROUGE(Metric): - type: t.Literal[ROUGE_TYPES] - use_stemmer: bool = False - - def __post_init__(self): - self.scorer = rouge_scorer.RougeScorer( - [self.type], use_stemmer=self.use_stemmer - ) - - def name(self): - return self.type - - def is_batchable(self): - return False - - def score(self, ground_truth: str, generated_text: str): - if isinstance(ground_truth, list): - ground_truth = ground_truth[0] - if isinstance(generated_text, list): - generated_text = generated_text[0] - - score = self.scorer.score(ground_truth, generated_text)[self.type] - return score.fmeasure - - class BLEU(Metric): weights: t.List[float] = [0.25, 0.25, 0.25, 0.25] smoothing_function = None @@ -63,6 +38,32 @@ def score(self, ground_truth: t.List[str], generated_text: t.List[str]): ] +@dataclass +class ROUGE(Metric): + type: t.Literal[ROUGE_TYPES] + use_stemmer: bool = False + + def __post_init__(self): + self.scorer = rouge_scorer.RougeScorer( + [self.type], use_stemmer=self.use_stemmer + ) + + def name(self): + return self.type + + def is_batchable(self): + return False + + def score(self, ground_truth: str, generated_text: str): + if isinstance(ground_truth, list): + ground_truth = ground_truth[0] + if isinstance(generated_text, list): + generated_text = generated_text[0] + + score = self.scorer.score(ground_truth, generated_text)[self.type] + return score.fmeasure + + Rouge1 = ROUGE("rouge1") Rouge2 = ROUGE("rouge2") RougeL = ROUGE("rougeL")