diff --git a/belar/metrics/__init__.py b/belar/metrics/__init__.py
index 6ad1367e0..8f6d9f56c 100644
--- a/belar/metrics/__init__.py
+++ b/belar/metrics/__init__.py
@@ -1,4 +1,5 @@
 from belar.metrics.base import Evaluation, Metric
 from belar.metrics.similarity import *
 from belar.metrics.simple import *
-from belar.metrics.similarity import SBERTScore
\ No newline at end of file
+from belar.metrics.similarity import SBERTScore
+from belar.metrics.factual import EntailmentScore
diff --git a/belar/metrics/factual.py b/belar/metrics/factual.py
new file mode 100644
index 000000000..fc41df3dd
--- /dev/null
+++ b/belar/metrics/factual.py
@@ -0,0 +1,75 @@
+from __future__ import annotations
+
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+import typing as t
+from dataclasses import dataclass
+
+from belar.metrics import Metric
+from belar.utils import device_check
+
+
+@dataclass
+class EntailmentScore(Metric):
+    """
+    Entailment score using ground truth as premise and generated text as hypothesis.
+    """
+
+    model_name: str = "typeform/distilbert-base-uncased-mnli"
+    batch_size: int = 4
+    device: t.Literal["cpu", "cuda"] = "cpu"
+
+    def __post_init__(self):
+        self.device = device_check(self.device)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name)
+        self.model.eval()
+        self.model.to(self.device)
+
+        model_config = self.model.config.to_dict()
+        assert model_config.get("id2label") or model_config.get(
+            "label2id"
+        ), "label-id mapping missing"
+        if model_config.get("id2label") is None:
+            self.id2label = {v: k for k, v in model_config["label2id"].items()}
+        else:
+            self.id2label = model_config["id2label"]
+
+    @property
+    def name(self):
+        return "Entailment_score"
+
+    @property
+    def is_batchable(self):
+        return True
+
+    def batch_infer(self, inputs: dict):
+        predictions = []
+        input_ids = inputs["input_ids"]
+        label2id = {value.lower(): key for key, value in self.id2label.items()}
+
+        for idx in range(0, len(input_ids), self.batch_size):
+            batch = {k: v[idx : idx + self.batch_size].to(self.device) for k, v in inputs.items()}
+            output = self.model(**batch)
+            pred = output.logits.softmax(dim=-1).detach().cpu()
+            predictions.extend(pred[:, label2id["entailment"]].tolist())
+
+        return predictions
+
+    def score(
+        self,
+        ground_truth: t.List[str],
+        generated_text: t.List[str],
+    ):
+        """
+        ground_truth : premise
+        generated_text : hypothesis
+        returns entailment probability scores
+        """
+
+        encodings = self.tokenizer(
+            ground_truth, generated_text, truncation=True, padding=True, return_tensors="pt"
+        )
+
+        score = self.batch_infer(encodings)
+
+        return score
diff --git a/belar/metrics/similarity.py b/belar/metrics/similarity.py
index 377356675..77ffa64cf 100644
--- a/belar/metrics/similarity.py
+++ b/belar/metrics/similarity.py
@@ -10,6 +10,7 @@
 
 SBERT_METRIC = t.Literal["cosine", "euclidean"]
 
+
 @dataclass
 class SBERTScore(Metric):
     similarity_metric: t.Literal[SBERT_METRIC] = "cosine"
@@ -25,19 +26,15 @@ def name(
     ):
         return f"SBERT_{self.similarity_metric}"
 
+    @property
     def is_batchable(self):
         return True
 
     def score(
         self,
-        ground_truth: str | list[str],
-        generated_text: str | list[str],
+        ground_truth: list[str],
+        generated_text: list[str],
     ):
-        if isinstance(ground_truth, str):
-            ground_truth = [ground_truth]
-        if isinstance(generated_text, str):
-            generated_text = [generated_text]
-
         gndtruth_emb = self.model.encode(
             ground_truth, batch_size=self.batch_size, convert_to_numpy=True
         )
diff --git a/belar/utils.py b/belar/utils.py
new file mode 100644
index 000000000..9564cfa66
--- /dev/null
+++ b/belar/utils.py
@@ -0,0 +1,20 @@
+import torch
+import typing as t
+from warnings import warn
+
+DEVICES = ["cpu", "cuda"]
+
+
+def device_check(device: t.Literal["cpu", "cuda"]):
+    if device == "cuda":
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+        else:
+            warn("cuda not available, using cpu")
+            device = torch.device("cpu")
+    elif device == "cpu":
+        device = torch.device("cpu")
+    else:
+        raise ValueError(f"Invalid device {device}")
+
+    return device
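
A minimal usage sketch of the new EntailmentScore metric (not part of the diff; the example strings and the printed value are illustrative, and the default MNLI checkpoint is downloaded on first use):

    # Hypothetical example: score factual consistency of generated text
    # against a reference, using the metric exported in __init__.py above.
    from belar.metrics import EntailmentScore

    entailment = EntailmentScore(batch_size=2, device="cpu")

    ground_truth = ["Paris is the capital of France."]
    generated_text = ["The capital of France is Paris."]

    # One entailment probability per (premise, hypothesis) pair;
    # values near 1.0 mean the reference entails the generated text.
    scores = entailment.score(ground_truth, generated_text)
    print(scores)  # e.g. [0.97] (illustrative)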
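
And a short sketch of how the new device_check helper behaves (calls are illustrative):

    # Hypothetical example: device_check maps a "cpu"/"cuda" string to a
    # torch.device, falling back to CPU with a warning when CUDA is
    # requested but unavailable.
    from belar.utils import device_check

    device = device_check("cpu")   # torch.device("cpu")
    device = device_check("cuda")  # torch.device("cuda"), or CPU plus a warning
    # device_check("tpu") raises ValueError("Invalid device tpu")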