Skip to content

Commit 0674ca8

Browse files
author
Jithin James
committed
feat: added results
1 parent fa29ac2 commit 0674ca8

File tree

1 file changed

+35
-3
lines changed

1 file changed

+35
-3
lines changed

belar/metrics/base.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import typing as t
44
from abc import ABC, abstractmethod
5+
from collections import namedtuple
56
from dataclasses import dataclass
67

78
import numpy as np
@@ -31,13 +32,18 @@ class Evaluation:
3132
batched: bool = True
3233
batch_size: int = 1000
3334

34-
def eval(self, ground_truth: list[list[str]], generated_text: list[str]) -> Dataset:
35+
def eval(self, ground_truth: list[list[str]], generated_text: list[str]) -> Result:
3536
ds = Dataset.from_dict(
3637
{"ground_truth": ground_truth, "generated_text": generated_text}
3738
)
38-
ds = ds.map(self._get_score, batched=self.batched, batch_size=self.batch_size)
39+
ds = ds.map(
40+
self._get_score,
41+
batched=self.batched,
42+
batch_size=self.batch_size,
43+
remove_columns=["ground_truth", "generated_text"],
44+
)
3945

40-
return ds
46+
return Result(ds)
4147

4248
# TODO: set a typevar here for row
4349
def _get_score(self, row: dict[str, list[t.Any]] | dict[str, t.Any]):
@@ -69,3 +75,29 @@ def _get_score(self, row: dict[str, list[t.Any]] | dict[str, t.Any]):
6975
row[f"{metric.name}_score"] = score
7076

7177
return row
78+
79+
80+
@dataclass
81+
class Result(dict):
82+
scores: Dataset
83+
84+
def __post_init__(self):
85+
for cn in self.scores.column_names:
86+
self[cn] = np.mean(self.scores[cn])
87+
88+
def describe(self):
89+
description = {}
90+
for cn in self.scores.column_names:
91+
description[cn] = {
92+
"mean": np.mean(self.scores[cn]),
93+
"25%": np.percentile(self.scores[cn], 25),
94+
"50%": np.percentile(self.scores[cn], 50),
95+
"75%": np.percentile(self.scores[cn], 75),
96+
"min": np.min(self.scores[cn]),
97+
"max": np.max(self.scores[cn]),
98+
"std": np.std(self.scores[cn]),
99+
}
100+
return description
101+
102+
def __repr__(self) -> str:
103+
return super().__repr__()

0 commit comments

Comments
 (0)