|
2 | 2 |
|
3 | 3 | import typing as t |
4 | 4 | from abc import ABC, abstractmethod |
| 5 | +from collections import namedtuple |
5 | 6 | from dataclasses import dataclass |
6 | 7 |
|
7 | 8 | import numpy as np |
@@ -31,13 +32,18 @@ class Evaluation: |
31 | 32 | batched: bool = True |
32 | 33 | batch_size: int = 1000 |
33 | 34 |
|
def eval(self, ground_truth: list[list[str]], generated_text: list[str]) -> Result:
    """Score ``generated_text`` against ``ground_truth`` and wrap the scores.

    Both arguments are loaded into a two-column dataset ("ground_truth",
    "generated_text"), ``self._get_score`` is mapped over it using the
    instance's ``batched`` / ``batch_size`` settings, and the two input
    columns are removed from the mapped output.

    Args:
        ground_truth: One list of ground-truth strings per sample.
        generated_text: One generated string per sample.

    Returns:
        A ``Result`` built from the mapped dataset.
    """
    map_options = {
        "batched": self.batched,
        "batch_size": self.batch_size,
        # Drop the raw inputs so only columns added by _get_score remain.
        "remove_columns": ["ground_truth", "generated_text"],
    }
    inputs = Dataset.from_dict(
        {"ground_truth": ground_truth, "generated_text": generated_text}
    )
    return Result(inputs.map(self._get_score, **map_options))
41 | 47 |
|
42 | 48 | # TODO: set a typevar here for row |
43 | 49 | def _get_score(self, row: dict[str, list[t.Any]] | dict[str, t.Any]): |
@@ -69,3 +75,29 @@ def _get_score(self, row: dict[str, list[t.Any]] | dict[str, t.Any]): |
69 | 75 | row[f"{metric.name}_score"] = score |
70 | 76 |
|
71 | 77 | return row |
| 78 | + |
| 79 | + |
@dataclass
class Result(dict):
    """Mapping from each score column to its mean, backed by the per-row scores.

    The instance is itself a ``dict``: on construction every column of
    ``scores`` is averaged and stored under the column's name, so
    ``result[column]`` yields that column's mean. The per-row data stays
    available as ``result.scores``.
    """

    # Quoted so the class can be defined without ``Dataset`` being resolved
    # at class-creation time.
    scores: "Dataset"

    def __post_init__(self):
        """Fill the dict with the mean of every column in ``scores``."""
        for cn in self.scores.column_names:
            values = np.asarray(self.scores[cn])
            # An empty column has no mean; store NaN directly rather than
            # letting numpy warn about an empty slice.
            self[cn] = np.mean(values) if values.size else np.nan

    def describe(self):
        """Return summary statistics for every column of ``scores``.

        Returns:
            A dict keyed by column name. Each value is a dict with the keys
            "mean", "25%", "50%", "75%", "min", "max" and "std". An empty
            column reports NaN for every statistic.
        """
        description = {}
        for cn in self.scores.column_names:
            # Read the column once and reuse it for every statistic.
            values = np.asarray(self.scores[cn])
            if values.size == 0:
                # np.min / np.max raise on zero-size input, so report NaN
                # across the board instead of failing the whole summary.
                description[cn] = {
                    stat: np.nan
                    for stat in ("mean", "25%", "50%", "75%", "min", "max", "std")
                }
                continue
            q25, q50, q75 = np.percentile(values, [25, 50, 75])
            description[cn] = {
                "mean": np.mean(values),
                "25%": q25,
                "50%": q50,
                "75%": q75,
                "min": np.min(values),
                "max": np.max(values),
                "std": np.std(values),
            }
        return description

    def __repr__(self) -> str:
        # Defined explicitly so @dataclass does not install its field-based
        # repr; the instance prints as its dict of column means.
        return super().__repr__()
0 commit comments