Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 35 additions & 3 deletions belar/metrics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import typing as t
from abc import ABC, abstractmethod
from collections import namedtuple
from dataclasses import dataclass

import numpy as np
Expand Down Expand Up @@ -31,13 +32,18 @@ class Evaluation:
batched: bool = True
batch_size: int = 1000

def eval(self, ground_truth: list[list[str]], generated_text: list[str]) -> Dataset:
def eval(self, ground_truth: list[list[str]], generated_text: list[str]) -> Result:
ds = Dataset.from_dict(
{"ground_truth": ground_truth, "generated_text": generated_text}
)
ds = ds.map(self._get_score, batched=self.batched, batch_size=self.batch_size)
ds = ds.map(
self._get_score,
batched=self.batched,
batch_size=self.batch_size,
remove_columns=["ground_truth", "generated_text"],
)

return ds
return Result(ds)

# TODO: set a typevar here for row
def _get_score(self, row: dict[str, list[t.Any]] | dict[str, t.Any]):
Expand Down Expand Up @@ -69,3 +75,29 @@ def _get_score(self, row: dict[str, list[t.Any]] | dict[str, t.Any]):
row[f"{metric.name}_score"] = score

return row


@dataclass
class Result(dict):
scores: Dataset

def __post_init__(self):
for cn in self.scores.column_names:
self[cn] = np.mean(self.scores[cn])

def describe(self):
description = {}
for cn in self.scores.column_names:
description[cn] = {
"mean": np.mean(self.scores[cn]),
"25%": np.percentile(self.scores[cn], 25),
"50%": np.percentile(self.scores[cn], 50),
"75%": np.percentile(self.scores[cn], 75),
"min": np.min(self.scores[cn]),
"max": np.max(self.scores[cn]),
"std": np.std(self.scores[cn]),
}
return description

def __repr__(self) -> str:
return super().__repr__()