From 9873234550efd4599d405454122f9ec4c7befdc5 Mon Sep 17 00:00:00 2001 From: Shahules786 Date: Fri, 12 May 2023 22:24:43 +0530 Subject: [PATCH 1/8] add sbert score --- belar/metrics/similarity.py | 53 +++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 belar/metrics/similarity.py diff --git a/belar/metrics/similarity.py b/belar/metrics/similarity.py new file mode 100644 index 000000000..1d2669132 --- /dev/null +++ b/belar/metrics/similarity.py @@ -0,0 +1,53 @@ +from __future__ import annotations +from ast import List + +import typing as t +from dataclasses import dataclass +import numpy as np +from numpy.linalg import norm +from sentence_transformers import SentenceTransformer + +from belar.metrics.base import Metric + +SBERT_METRIC = t.Literal["cosine", "euclidean"] + +@dataclass +class SBERTScore(Metric): + + similarity_metric: t.Literal[SBERT_METRIC] = "cosine" + model_path: str = "all-MiniLM-L6-v2" + batch_size: int = 1000 + + def __post_init__(self): + + self.model = SentenceTransformer(self.model_path) + + def name(self,): + return f"SBERT-{self.similarity_metric}-score" + + def is_batchable(self): + return True + + def score(self, ground_truth: t.Union[str, t.List[str]], generated_text: t.Union[str, t.List[str]]): + + if isinstance(ground_truth, str): + ground_truth = [ground_truth] + if isinstance(generated_text, str): + generated_text = [generated_text] + + gndtruth_emb = self.model.encode(ground_truth, batch_size=self.batch_size, + convert_to_numpy=True) + gentext_emb = self.model.encode(generated_text, batch_size=self.batch_size, + convert_to_numpy=True) + + if self.similarity_metric == "cosine": + score = np.dot(gndtruth_emb, gentext_emb.T) / (norm(gndtruth_emb) * norm(gentext_emb)) + + elif self.similarity_metric == "euclidean": + score = norm(gndtruth_emb - gentext_emb, ord=2) + + else: + raise ValueError(f"Unkown metrics {self.similarity_metric}") + + return score + From ca9676f2c6f33f001466e6f0974101dd5c2bec03 
Mon Sep 17 00:00:00 2001 From: Shahules786 Date: Fri, 12 May 2023 22:25:05 +0530 Subject: [PATCH 2/8] add relative import --- belar/metrics/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/belar/metrics/__init__.py b/belar/metrics/__init__.py index f3b63be9f..7e42db99a 100644 --- a/belar/metrics/__init__.py +++ b/belar/metrics/__init__.py @@ -1,2 +1,3 @@ from belar.metrics.base import Evaluation, Metric from belar.metrics.simple import * +from belar.metrics.similarity import SBERTScore \ No newline at end of file From 8b8507b42d6b58441f3c23dab76c88badc319e5c Mon Sep 17 00:00:00 2001 From: Shahules786 Date: Sat, 13 May 2023 01:25:34 +0530 Subject: [PATCH 3/8] add entailment score --- belar/metrics/factual.py | 78 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 belar/metrics/factual.py diff --git a/belar/metrics/factual.py b/belar/metrics/factual.py new file mode 100644 index 000000000..caeb527b6 --- /dev/null +++ b/belar/metrics/factual.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from transformers import AutoTokenizer, AutoModelForSequenceClassification +import typing as t +from dataclasses import dataclass + +from belar.metrics import Metric +from belar.utils import device_check + + +@dataclass +class EntailmentScore(Metric): + """ + Entailment score using ground truth as premise and generated text as hypothesis. 
+ """ + + model_name: str = "typeform/distilbert-base-uncased-mnli" + batch_size: int = 4 + device: t.Literal["cpu", "cuda"] = "cpu" + + def __post_init__(self): + self.device = device_check(self.device) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) + self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name) + self.model.eval() + self.model.to(self.device) + + model_config = self.model.config.to_dict() + assert model_config.get("id2label") or model_config.get( + "label2id" + ), "label-id mapping missing" + if model_config.get("id2label") is None: + self.id2label = {v: k for k, v in model_config.label2id} + else: + self.id2label = model_config["id2label"] + + def name(self): + return "Entailment-Score" + + def is_batchable(self): + return True + + def batch_infer(self, inputs: dict): + predictions = [] + input_ids = inputs["input_ids"] + label2id = {value.lower(): key for key, value in self.id2label.items()} + + for idx in range(0, len(input_ids), self.batch_size): + batch_ids = input_ids[idx : idx + self.batch_size] + output = self.model(batch_ids.to(self.device)) + pred = output.logits.softmax(axis=-1).detach().cpu() + predictions.extend(pred[:, label2id["entailment"]].tolist()) + + return predictions + + def score( + self, + ground_truth: t.Union[str, t.List[str]], + generated_text: t.Union[str, t.List[str]], + ): + """ + ground_truth : premis + generated_text : hypothesis + returns entailement probability score + """ + + if isinstance(ground_truth, str): + ground_truth = [ground_truth] + if isinstance(generated_text, str): + generated_text = [generated_text] + + encodings = self.tokenizer( + ground_truth, generated_text, truncation=True, return_tensors="pt" + ) + + score = self.batch_infer(encodings) + + return score From 56a67bff2eb3d4f4ba702c86820e4c33a761bcc3 Mon Sep 17 00:00:00 2001 From: Shahules786 Date: Sat, 13 May 2023 01:25:48 +0530 Subject: [PATCH 4/8] add relative import --- belar/metrics/__init__.py | 3 ++- 
import torch
import typing as t
from warnings import warn

# Accepted device names (kept as a list for backward compatibility with
# callers that import DEVICES).
DEVICES = ["cpu", "cuda"]


def device_check(device: t.Literal["cpu", "cuda"]):
    """Validate a device name and return the corresponding torch.device.

    Falls back to CPU (with a warning) when "cuda" is requested but no CUDA
    device is available.

    Raises
    ------
    ValueError : if `device` is neither "cpu" nor "cuda".
    """
    # NOTE: the annotation must be a Literal of hashable values; the original
    # `t.Literal[DEVICES]` (a Literal of a *list*) raises at import time.
    if device == "cuda":
        if torch.cuda.is_available():
            device = torch.device("cuda")
        else:
            warn("cuda not available, using cpu")
            # BUG FIX: the original left `device` as the string "cuda" here,
            # so callers doing model.to(device) would still target CUDA and
            # crash. Fall back to an actual CPU device.
            device = torch.device("cpu")
    elif device == "cpu":
        device = torch.device("cpu")
    else:
        raise ValueError(f"Invalid device {device}")

    return device
f"SBERT_{self.similarity_metric}" + @property def is_batchable(self): return True def score( self, - ground_truth: str | list[str], - generated_text: str | list[str], + ground_truth: list[str], + generated_text: list[str], ): - if isinstance(ground_truth, str): - ground_truth = [ground_truth] - if isinstance(generated_text, str): - generated_text = [generated_text] - gndtruth_emb = self.model.encode( ground_truth, batch_size=self.batch_size, convert_to_numpy=True ) From e363a30d3c9e59842a9943d4d3f8ceb4871f0e1c Mon Sep 17 00:00:00 2001 From: Shahules786 Date: Sat, 13 May 2023 03:03:27 +0530 Subject: [PATCH 7/8] change input type --- belar/metrics/factual.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/belar/metrics/factual.py b/belar/metrics/factual.py index caeb527b6..fc41df3dd 100644 --- a/belar/metrics/factual.py +++ b/belar/metrics/factual.py @@ -34,9 +34,11 @@ def __post_init__(self): else: self.id2label = model_config["id2label"] + @property def name(self): - return "Entailment-Score" + return "Entailment_score" + @property def is_batchable(self): return True @@ -55,8 +57,8 @@ def batch_infer(self, inputs: dict): def score( self, - ground_truth: t.Union[str, t.List[str]], - generated_text: t.Union[str, t.List[str]], + ground_truth: t.List[str], + generated_text: t.List[str], ): """ ground_truth : premis @@ -64,11 +66,6 @@ def score( returns entailement probability score """ - if isinstance(ground_truth, str): - ground_truth = [ground_truth] - if isinstance(generated_text, str): - generated_text = [generated_text] - encodings = self.tokenizer( ground_truth, generated_text, truncation=True, return_tensors="pt" ) From a40403837d42a2fa46d651ed3395bd95b9ff3783 Mon Sep 17 00:00:00 2001 From: Shahul ES Date: Sat, 13 May 2023 03:15:09 +0530 Subject: [PATCH 8/8] change DEVICE type Co-authored-by: Jithin James --- belar/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/belar/utils.py b/belar/utils.py index 
f7fa09631..9564cfa66 100644 --- a/belar/utils.py +++ b/belar/utils.py @@ -2,7 +2,7 @@ import typing as t from warnings import warn -DEVICES = t.Literal["cpu", "cuda"] +DEVICES = ["cpu", "cuda"] def device_check(device: t.Literal[DEVICES]):