From 790651e21e425e4ce8203c90336e4beafa6741e5 Mon Sep 17 00:00:00 2001
From: Jithin James
Date: Sat, 13 May 2023 02:32:05 +0530
Subject: [PATCH 1/3] basic batching and variable ground_truth support

---
 belar/metrics/base.py       | 43 +++++++++++++++++++++++++++++++------
 belar/metrics/similarity.py |  1 +
 belar/metrics/simple.py     | 16 +++++++-------
 3 files changed, 45 insertions(+), 15 deletions(-)

diff --git a/belar/metrics/base.py b/belar/metrics/base.py
index 9662a8cf4..4e16f45f8 100644
--- a/belar/metrics/base.py
+++ b/belar/metrics/base.py
@@ -4,6 +4,7 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 
+import numpy as np
 from datasets import Dataset
 
 
@@ -20,24 +21,52 @@ def is_batchable(self) -> bool:
         ...
 
     @abstractmethod
-    def score(self, ground_truth, generated_text) -> float | list[float]:
+    def score(
+        self, ground_truth: list[str] | str, generated_text: list[str] | str
+    ) -> list[float] | float:
         ...
 
 
 @dataclass
 class Evaluation:
     metrics: list[Metric]
-    batched: bool = False
+    batched: bool = True
+    batch_size: int = 1000
 
-    def eval(self, ground_truth: list[list[str]], generated_text: list[list[str]]):
-        ds = ground_truth.add_column("generated_text", generated_text)
-        ds = ds.map(self._get_score, batched=self.batched)
+    def eval(self, ground_truth: list[list[str]], generated_text: list[str]):
+        ds = Dataset.from_dict(
+            {"ground_truth": ground_truth, "generated_text": generated_text}
+        )
+        ds = ds.map(self._get_score, batched=self.batched, batch_size=self.batch_size)
 
         return ds
 
-    def _get_score(self, row):
+    def _get_score(self, row: dict[str, list[t.Any]] | dict[str, t.Any]):
         for metric in self.metrics:
-            score = metric.score(row["ground_truth"], row["generated_text"])
+            if self.batched:
+                split_indices = []
+                last_split_index = 0
+                ground_truths = []
+                generated_texts = []
+                for i, ground_truth_list in enumerate(row["ground_truth"]):
+                    split_indices.append(last_split_index + len(ground_truth_list))
+                    last_split_index = split_indices[-1]
+                    ground_truths.extend(ground_truth_list)
+                    generated_texts.extend(
+                        [row["generated_text"][i]] * len(ground_truth_list)
+                    )
+
+                # score the flat pairs, then split back into per-example groups
+                batch_scores_flat = metric.score(ground_truths, generated_texts)
+                batch_scores = np.split(batch_scores_flat, split_indices)
+                score = [np.max(x) for x in batch_scores[:-1]]
+            else:  # not batched
+                num_ground_truths = len(row["ground_truth"])
+                ground_truths = row["ground_truth"]
+                generated_texts = [row["generated_text"]] * num_ground_truths
+                scores = metric.score(ground_truths, generated_texts)
+                score = np.max(scores)
+
             row[f"{metric.name}_score"] = score
 
         return row
diff --git a/belar/metrics/similarity.py b/belar/metrics/similarity.py
index 77ffa64cf..4dd55985a 100644
--- a/belar/metrics/similarity.py
+++ b/belar/metrics/similarity.py
@@ -2,6 +2,7 @@
 import typing as t
 from dataclasses import dataclass
+
 import numpy as np
 from numpy.linalg import norm
 from sentence_transformers import SentenceTransformer
 
diff --git a/belar/metrics/simple.py b/belar/metrics/simple.py
index eee55949f..4f06f2f6d 100644
--- a/belar/metrics/simple.py
+++ b/belar/metrics/simple.py
@@ -29,14 +29,14 @@ def name(self):
     def is_batchable(self):
         return False
 
-    def score(self, ground_truth: str, generated_text: str):
-        if isinstance(ground_truth, list):
-            ground_truth = ground_truth[0]
-        if isinstance(generated_text, list):
-            generated_text = generated_text[0]
-
-        score = self.scorer.score(ground_truth, generated_text)[self.type]
-        return score.fmeasure
+    def score(
+        self, ground_truths: list[str], generated_texts: list[str]
+    ) -> list[float]:
+        scores = [
+            self.scorer.score(ground_truth, generated_texts[i])[self.type].fmeasure
+            for i, ground_truth in enumerate(ground_truths)
+        ]
+        return scores
 
 
 class EditScore(Metric):
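Note on [PATCH 1/3]: the batched path scores every (ground truth, generated text) pair in one flat call, then `np.split` cuts the flat score array back into per-example groups at the accumulated offsets, and `np.max` keeps the best ground truth per example. A minimal standalone sketch of just that regrouping step, with made-up scores:

```python
import numpy as np

# made-up flat scores for two examples: example 0 has 3 ground truths,
# example 1 has 2; all pairs were scored in a single batched call
flat_scores = np.array([0.2, 0.9, 0.4, 0.5, 0.1])
split_indices = [3, 5]  # cumulative group sizes, as accumulated in _get_score

groups = np.split(flat_scores, split_indices)
# the last offset equals len(flat_scores), so np.split leaves an empty
# trailing array; that is why _get_score drops it with [:-1]
per_example = [np.max(g) for g in groups[:-1]]
print(per_example)  # [0.9, 0.5]
```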
From fa29ac2634789713ffb3d389007b15ecd6fa7dd0 Mon Sep 17 00:00:00 2001
From: Jithin James
Date: Sat, 13 May 2023 02:46:42 +0530
Subject: [PATCH 2/3] type fixes

---
 belar/metrics/base.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/belar/metrics/base.py b/belar/metrics/base.py
index 4e16f45f8..d2d9fd932 100644
--- a/belar/metrics/base.py
+++ b/belar/metrics/base.py
@@ -21,9 +21,7 @@ def is_batchable(self) -> bool:
         ...
 
     @abstractmethod
-    def score(
-        self, ground_truth: list[str] | str, generated_text: list[str] | str
-    ) -> list[float] | float:
+    def score(self, ground_truth: list[str], generated_text: list[str]) -> list[float]:
         ...
 
 
@@ -33,7 +31,7 @@ class Evaluation:
     batched: bool = True
     batch_size: int = 1000
 
-    def eval(self, ground_truth: list[list[str]], generated_text: list[str]):
+    def eval(self, ground_truth: list[list[str]], generated_text: list[str]) -> Dataset:
         ds = Dataset.from_dict(
             {"ground_truth": ground_truth, "generated_text": generated_text}
         )
@@ -41,6 +39,7 @@ def eval(self, ground_truth: list[list[str]], generated_text: list[str]):
 
         return ds
 
+    # TODO: set a typevar here for row
     def _get_score(self, row: dict[str, list[t.Any]] | dict[str, t.Any]):
         for metric in self.metrics:
             if self.batched:
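With [PATCH 2/3], the `Metric.score` contract is pinned down to list-in/list-out. A minimal sketch of a conforming metric; `ExactMatch` is a made-up stand-in that is not part of this series, and it assumes `name` and `is_batchable` are properties, matching how `simple.py` defines them:

```python
from belar.metrics.base import Metric


class ExactMatch(Metric):
    """Made-up toy metric: 1.0 on an exact string match, else 0.0."""

    @property
    def name(self) -> str:
        return "exact_match"

    @property
    def is_batchable(self) -> bool:
        return True

    def score(self, ground_truth: list[str], generated_text: list[str]) -> list[float]:
        # list in, list out: one score per (ground_truth, generated_text) pair
        return [float(gt == gen) for gt, gen in zip(ground_truth, generated_text)]
```

`Evaluation(metrics=[ExactMatch()]).eval(...)` would then add an `exact_match_score` column, following the `f"{metric.name}_score"` naming convention in `_get_score`.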
From 0674ca81957bae76b0b87f96f394349b4192d5d3 Mon Sep 17 00:00:00 2001
From: Jithin James
Date: Sat, 13 May 2023 03:22:47 +0530
Subject: [PATCH 3/3] feat: added results

---
 belar/metrics/base.py | 38 +++++++++++++++++++++++++++++++++++---
 1 file changed, 35 insertions(+), 3 deletions(-)

diff --git a/belar/metrics/base.py b/belar/metrics/base.py
index d2d9fd932..5d0108767 100644
--- a/belar/metrics/base.py
+++ b/belar/metrics/base.py
@@ -2,6 +2,7 @@
 import typing as t
 
 from abc import ABC, abstractmethod
+from collections import namedtuple
 from dataclasses import dataclass
 
 import numpy as np
@@ -31,13 +32,18 @@ class Evaluation:
     batched: bool = True
     batch_size: int = 1000
 
-    def eval(self, ground_truth: list[list[str]], generated_text: list[str]) -> Dataset:
+    def eval(self, ground_truth: list[list[str]], generated_text: list[str]) -> Result:
         ds = Dataset.from_dict(
             {"ground_truth": ground_truth, "generated_text": generated_text}
         )
-        ds = ds.map(self._get_score, batched=self.batched, batch_size=self.batch_size)
+        ds = ds.map(
+            self._get_score,
+            batched=self.batched,
+            batch_size=self.batch_size,
+            remove_columns=["ground_truth", "generated_text"],
+        )
 
-        return ds
+        return Result(ds)
 
     # TODO: set a typevar here for row
     def _get_score(self, row: dict[str, list[t.Any]] | dict[str, t.Any]):
@@ -69,3 +75,29 @@ def _get_score(self, row: dict[str, list[t.Any]] | dict[str, t.Any]):
         row[f"{metric.name}_score"] = score
 
         return row
+
+
+@dataclass
+class Result(dict):
+    scores: Dataset
+
+    def __post_init__(self):
+        for cn in self.scores.column_names:
+            self[cn] = np.mean(self.scores[cn])
+
+    def describe(self):
+        description = {}
+        for cn in self.scores.column_names:
+            description[cn] = {
+                "mean": np.mean(self.scores[cn]),
+                "25%": np.percentile(self.scores[cn], 25),
+                "50%": np.percentile(self.scores[cn], 50),
+                "75%": np.percentile(self.scores[cn], 75),
+                "min": np.min(self.scores[cn]),
+                "max": np.max(self.scores[cn]),
+                "std": np.std(self.scores[cn]),
+            }
+        return description
+
+    def __repr__(self) -> str:
+        return super().__repr__()
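After [PATCH 3/3], `eval` wraps the mapped `Dataset` in `Result`, which behaves as a dict of per-column mean scores while keeping the raw per-example scores around for `describe()`. A rough usage sketch; the `rouge1_score` column name is illustrative only, since real columns come from `f"{metric.name}_score"`:

```python
from datasets import Dataset

from belar.metrics.base import Result

# stand-in for the score columns left over after eval() removes the inputs
scores = Dataset.from_dict({"rouge1_score": [0.5, 0.75, 1.0]})

result = Result(scores)
print(result)             # dict of column means: {'rouge1_score': 0.75}
print(result.describe())  # per-column mean, quartiles, min, max, std
```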