From 5826734998e4185df5edf8df76fffff3a361896b Mon Sep 17 00:00:00 2001 From: Jithin James Date: Fri, 12 May 2023 23:57:50 +0530 Subject: [PATCH] fix: batching in Metric --- belar/metrics/__init__.py | 1 + belar/metrics/base.py | 23 ++- belar/metrics/similarity.py | 64 +++++++++ belar/metrics/simple.py | 2 + examples/quickstart.ipynb | 276 +++++++++++++++++------------------- 5 files changed, 208 insertions(+), 158 deletions(-) create mode 100644 belar/metrics/similarity.py diff --git a/belar/metrics/__init__.py b/belar/metrics/__init__.py index f3b63be9f..1dc56eba2 100644 --- a/belar/metrics/__init__.py +++ b/belar/metrics/__init__.py @@ -1,2 +1,3 @@ from belar.metrics.base import Evaluation, Metric +from belar.metrics.similarity import * from belar.metrics.simple import * diff --git a/belar/metrics/base.py b/belar/metrics/base.py index b900abce5..9662a8cf4 100644 --- a/belar/metrics/base.py +++ b/belar/metrics/base.py @@ -23,24 +23,21 @@ def is_batchable(self) -> bool: def score(self, ground_truth, generated_text) -> float | list[float]: ... - def __call__(self, row): - score = self.score(row["ground_truth"], row["generated_text"]) - row[f"{self.name}_score"] = score - - return row - @dataclass class Evaluation: metrics: list[Metric] + batched: bool = False - def eval( - self, ground_truth: Dataset, generated_text: t.Sequence, batched: bool = False - ): + def eval(self, ground_truth: list[list[str]], generated_text: list[list[str]]): ds = ground_truth.add_column("generated_text", generated_text) - scores_list = [] + ds = ds.map(self._get_score, batched=self.batched) + + return ds + + def _get_score(self, row): for metric in self.metrics: - scores = ds.map(metric, batched=batched)[f"{metric.name}_score"] - scores_list.append(scores) + score = metric.score(row["ground_truth"], row["generated_text"]) + row[f"{metric.name}_score"] = score - return scores_list + return row diff --git a/belar/metrics/similarity.py b/belar/metrics/similarity.py new file mode 100644 index 000000000..06d7e9135 --- /dev/null +++ b/belar/metrics/similarity.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +import typing as t +from dataclasses import dataclass + +import numpy as np +from numpy.linalg import norm +from sentence_transformers import SentenceTransformer + +from belar.metrics.base import Metric + +SBERT_METRIC = t.Literal["cosine", "euclidean"] + + +@dataclass +class SBERTScore(Metric): + similarity_metric: t.Literal[SBERT_METRIC] = "cosine" + model_path: str = "all-MiniLM-L6-v2" + batch_size: int = 1000 + + def __post_init__(self): + self.model = SentenceTransformer(self.model_path) + + @property + def name( + self, + ): + return f"SBERT_{self.similarity_metric}" + + def is_batchable(self): + return True + + def score( + self, + ground_truth: str | list[str], + generated_text: str | list[str], + ): + if isinstance(ground_truth, str): + ground_truth = [ground_truth] + if isinstance(generated_text, str): + generated_text = [generated_text] + + gndtruth_emb = self.model.encode( + ground_truth, batch_size=self.batch_size, convert_to_numpy=True + ) + gentext_emb = self.model.encode( + generated_text, batch_size=self.batch_size, convert_to_numpy=True + ) + + if self.similarity_metric == "cosine": + score = np.dot(gndtruth_emb, gentext_emb.T) / ( + norm(gndtruth_emb) * norm(gentext_emb) + ) + + elif self.similarity_metric == "euclidean": + score = norm(gndtruth_emb - gentext_emb, ord=2) + + else: + raise ValueError(f"Unkown metrics {self.similarity_metric}") + + return score + + +__all__ = ["SBERTScore"] diff --git a/belar/metrics/simple.py b/belar/metrics/simple.py index 338506565..1ff7f9acc 100644 --- a/belar/metrics/simple.py +++ b/belar/metrics/simple.py @@ -20,9 +20,11 @@ def __post_init__(self): [self.type], use_stemmer=self.use_stemmer ) + @property def name(self): return self.type + @property def is_batchable(self): return False diff --git a/examples/quickstart.ipynb b/examples/quickstart.ipynb index a4c85a8d8..6effae29b 100644 --- a/examples/quickstart.ipynb +++ b/examples/quickstart.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": 7, - "id": "806b182e", + "id": "992c777a", "metadata": {}, "outputs": [], "source": [ @@ -13,8 +13,8 @@ }, { "cell_type": "code", - "execution_count": 2, - "id": "4dfd5a18", + "execution_count": 4, + "id": "5eaf4729", "metadata": {}, "outputs": [ { @@ -27,7 +27,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "b3e986700a834fa6ab1e327b49e1fae3", + "model_id": "a2231863e61c4ffd8d695c8531a48139", "version_major": 2, "version_minor": 0 }, @@ -42,7 +42,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Loading cached processed dataset at /home/jjmachan/.cache/huggingface/datasets/eli5/LFQA_reddit/1.0.0/17574e5502a10f41bbd17beba83e22475b499fa62caa1384a3d093fc856fe6fa/cache-b6864f61633d1e41.arrow\n" + "Loading cached processed dataset at /home/jjmachan/.cache/huggingface/datasets/eli5/LFQA_reddit/1.0.0/17574e5502a10f41bbd17beba83e22475b499fa62caa1384a3d093fc856fe6fa/cache-f3427cd7a8a8674f.arrow\n" ] } ], @@ -51,41 +51,70 @@ "\n", "def format_for_belar(row):\n", " row[\"context\"] = row[\"selftext\"]\n", - " row[\"question\"] = row[\"title\"]\n", - " row['answers'] = row[\"answers\"][\"text\"]\n", + " row[\"prompt\"] = row[\"title\"]\n", + " row['ground_truth'] = row[\"answers\"][\"text\"]\n", " return row\n", " \n", "d = load_dataset(\"eli5\")\n", "ds = d['test_eli5'].map(format_for_belar, batched=False)\n", - "ds = ds.select_columns([\"context\", \"question\", \"answers\"])" + "ds = ds.select_columns([\"context\", \"prompt\", \"ground_truth\"])" ] }, { "cell_type": "code", - "execution_count": 3, - "id": "5c2974c5", + "execution_count": 5, + "id": "a501c296", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "(100, 3)" + "(500, 3)" ] }, - "execution_count": 3, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "ds = ds.shuffle().select(range(100))\n", + "ds = ds.shuffle(seed=42).select(range(500))\n", "ds.shape" ] }, { "cell_type": "code", - "execution_count": 6, - "id": "5546153b", + "execution_count": 8, + "id": "763335eb", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['context', 'prompt', 'ground_truth']" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds.column_names" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6aff2ae9", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "c32c39f5", "metadata": {}, "outputs": [ { @@ -96,7 +125,7 @@ "version_minor": 0 }, "text/plain": [ - "Map: 0%| | 0/100 [00:00