diff --git a/belar/metrics/__init__.py b/belar/metrics/__init__.py
index 7e42db99a..6ad1367e0 100644
--- a/belar/metrics/__init__.py
+++ b/belar/metrics/__init__.py
@@ -1,3 +1,4 @@
 from belar.metrics.base import Evaluation, Metric
+from belar.metrics.similarity import *
 from belar.metrics.simple import *
 from belar.metrics.similarity import SBERTScore
\ No newline at end of file
diff --git a/belar/metrics/base.py b/belar/metrics/base.py
index b900abce5..9662a8cf4 100644
--- a/belar/metrics/base.py
+++ b/belar/metrics/base.py
@@ -23,24 +23,21 @@ def is_batchable(self) -> bool:
         ...
 
     def score(self, ground_truth, generated_text) -> float | list[float]:
         ...
 
-    def __call__(self, row):
-        score = self.score(row["ground_truth"], row["generated_text"])
-        row[f"{self.name}_score"] = score
-
-        return row
-
 
 @dataclass
 class Evaluation:
     metrics: list[Metric]
+    batched: bool = False
 
-    def eval(
-        self, ground_truth: Dataset, generated_text: t.Sequence, batched: bool = False
-    ):
+    def eval(self, ground_truth: list[list[str]], generated_text: list[list[str]]):
         ds = ground_truth.add_column("generated_text", generated_text)
-        scores_list = []
+        ds = ds.map(self._get_score, batched=self.batched)
+
+        return ds
+
+    def _get_score(self, row):
         for metric in self.metrics:
-            scores = ds.map(metric, batched=batched)[f"{metric.name}_score"]
-            scores_list.append(scores)
+            score = metric.score(row["ground_truth"], row["generated_text"])
+            row[f"{metric.name}_score"] = score
 
-        return scores_list
+        return row
diff --git a/belar/metrics/similarity.py b/belar/metrics/similarity.py
index 1d2669132..377356675 100644
--- a/belar/metrics/similarity.py
+++ b/belar/metrics/similarity.py
@@ -1,5 +1,4 @@
 from __future__ import annotations
-from ast import List
 
 import typing as t
 from dataclasses import dataclass
@@ -13,41 +12,51 @@
 @dataclass
 class SBERTScore(Metric):
-
     similarity_metric: t.Literal[SBERT_METRIC] = "cosine"
     model_path: str = "all-MiniLM-L6-v2"
     batch_size: int = 1000
 
     def __post_init__(self):
-
         self.model = SentenceTransformer(self.model_path)
 
-    def name(self,):
-        return f"SBERT-{self.similarity_metric}-score"
+    @property
+    def name(
+        self,
+    ):
+        return f"SBERT_{self.similarity_metric}"
 
     def is_batchable(self):
         return True
-
-    def score(self, ground_truth: t.Union[str, t.List[str]], generated_text: t.Union[str, t.List[str]]):
+    def score(
+        self,
+        ground_truth: str | list[str],
+        generated_text: str | list[str],
+    ):
         if isinstance(ground_truth, str):
             ground_truth = [ground_truth]
         if isinstance(generated_text, str):
             generated_text = [generated_text]
-
-        gndtruth_emb = self.model.encode(ground_truth, batch_size=self.batch_size,
-                                         convert_to_numpy=True)
-        gentext_emb = self.model.encode(generated_text, batch_size=self.batch_size,
-                                        convert_to_numpy=True)
-
+
+        gndtruth_emb = self.model.encode(
+            ground_truth, batch_size=self.batch_size, convert_to_numpy=True
+        )
+        gentext_emb = self.model.encode(
+            generated_text, batch_size=self.batch_size, convert_to_numpy=True
+        )
+
         if self.similarity_metric == "cosine":
-            score = np.dot(gndtruth_emb, gentext_emb.T) / (norm(gndtruth_emb) * norm(gentext_emb))
+            score = np.dot(gndtruth_emb, gentext_emb.T) / (
+                norm(gndtruth_emb) * norm(gentext_emb)
+            )
         elif self.similarity_metric == "euclidean":
             score = norm(gndtruth_emb - gentext_emb, ord=2)
-
+
         else:
             raise ValueError(f"Unkown metrics {self.similarity_metric}")
 
         return score
-
+
+
+__all__ = ["SBERTScore"]
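Reviewer note on the similarity.py hunk above: `norm()` with no `axis` argument reduces over the whole embedding matrix, so the cosine branch is only correct when a single ground-truth/generated pair is scored; for batched lists the denominator mixes every row (and in the euclidean branch, `ord=2` on a 2-D array is the spectral norm, not per-row distances). For comparison, a row-wise sketch -- `rowwise_cosine` is a hypothetical helper, not part of this diff:

import numpy as np
from numpy.linalg import norm

def rowwise_cosine(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # a, b: (n, d) embedding matrices with matching rows.
    # One cosine similarity per row pair, sum(a_i * b_i) / (|a_i| * |b_i|),
    # instead of one scalar denominator shared by the whole batch.
    return np.sum(a * b, axis=1) / (norm(a, axis=1) * norm(b, axis=1))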
diff --git a/belar/metrics/simple.py b/belar/metrics/simple.py
index 338506565..1ff7f9acc 100644
--- a/belar/metrics/simple.py
+++ b/belar/metrics/simple.py
@@ -20,9 +20,11 @@ def __post_init__(self):
             [self.type], use_stemmer=self.use_stemmer
         )
 
+    @property
     def name(self):
         return self.type
 
+    @property
     def is_batchable(self):
         return False
 
diff --git a/examples/quickstart.ipynb b/examples/quickstart.ipynb
index a4c85a8d8..6effae29b 100644
--- a/examples/quickstart.ipynb
+++ b/examples/quickstart.ipynb
@@ -3,7 +3,7 @@
   {
    "cell_type": "code",
    "execution_count": 7,
-   "id": "806b182e",
+   "id": "992c777a",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -13,8 +13,8 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
-   "id": "4dfd5a18",
+   "execution_count": 4,
+   "id": "5eaf4729",
    "metadata": {},
    "outputs": [
     {
@@ -27,7 +27,7 @@
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "b3e986700a834fa6ab1e327b49e1fae3",
+       "model_id": "a2231863e61c4ffd8d695c8531a48139",
        "version_major": 2,
        "version_minor": 0
       },
@@ -42,7 +42,7 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "Loading cached processed dataset at /home/jjmachan/.cache/huggingface/datasets/eli5/LFQA_reddit/1.0.0/17574e5502a10f41bbd17beba83e22475b499fa62caa1384a3d093fc856fe6fa/cache-b6864f61633d1e41.arrow\n"
+      "Loading cached processed dataset at /home/jjmachan/.cache/huggingface/datasets/eli5/LFQA_reddit/1.0.0/17574e5502a10f41bbd17beba83e22475b499fa62caa1384a3d093fc856fe6fa/cache-f3427cd7a8a8674f.arrow\n"
      ]
     }
    ],
@@ -51,41 +51,70 @@
     "\n",
     "def format_for_belar(row):\n",
     "    row[\"context\"] = row[\"selftext\"]\n",
-    "    row[\"question\"] = row[\"title\"]\n",
-    "    row['answers'] = row[\"answers\"][\"text\"]\n",
+    "    row[\"prompt\"] = row[\"title\"]\n",
+    "    row['ground_truth'] = row[\"answers\"][\"text\"]\n",
     "    return row\n",
     "    \n",
     "d = load_dataset(\"eli5\")\n",
     "ds = d['test_eli5'].map(format_for_belar, batched=False)\n",
-    "ds = ds.select_columns([\"context\", \"question\", \"answers\"])"
+    "ds = ds.select_columns([\"context\", \"prompt\", \"ground_truth\"])"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
-   "id": "5c2974c5",
+   "execution_count": 5,
+   "id": "a501c296",
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "(100, 3)"
+       "(500, 3)"
       ]
      },
-     "execution_count": 3,
+     "execution_count": 5,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
    "source": [
-    "ds = ds.shuffle().select(range(100))\n",
+    "ds = ds.shuffle(seed=42).select(range(500))\n",
     "ds.shape"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
-   "id": "5546153b",
+   "execution_count": 8,
+   "id": "763335eb",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "['context', 'prompt', 'ground_truth']"
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "ds.column_names"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6aff2ae9",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "c32c39f5",
    "metadata": {},
    "outputs": [
     {
@@ -96,7 +125,7 @@
        "version_minor": 0
       },
       "text/plain": [
-       "Map:   0%|          | 0/100 [00:00
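Reviewer note, a quick usage sketch (not part of the diff, untested): it assumes the post-change API exactly as above, i.e. `batched` as a field on `Evaluation` and `eval()` returning the mapped `Dataset`; the sample sentences are made up for illustration.

# Usage sketch -- assumes the refactored belar.metrics API from this diff.
from datasets import Dataset

from belar.metrics import Evaluation, SBERTScore

ground_truth = Dataset.from_dict(
    {"ground_truth": [["Paris is the capital of France."]]}
)
generated_text = [["The capital of France is Paris."]]

# `batched` now lives on Evaluation instead of being an eval() argument.
e = Evaluation(metrics=[SBERTScore(similarity_metric="cosine")], batched=False)
scored = e.eval(ground_truth, generated_text)

# eval() returns the dataset with one "<metric.name>_score" column per metric;
# SBERTScore's name property makes that "SBERT_cosine_score" here.
print(scored["SBERT_cosine_score"])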