3 changes: 2 additions & 1 deletion belar/metrics/__init__.py
@@ -1,4 +1,5 @@
from belar.metrics.base import Evaluation, Metric
from belar.metrics.similarity import *
from belar.metrics.simple import *
from belar.metrics.similarity import SBERTScore
from belar.metrics.factual import EntailmentScore
75 changes: 75 additions & 0 deletions belar/metrics/factual.py
@@ -0,0 +1,75 @@
from __future__ import annotations

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import typing as t
from dataclasses import dataclass

from belar.metrics import Metric
from belar.utils import device_check


@dataclass
class EntailmentScore(Metric):
"""
Entailment score using ground truth as premise and generated text as hypothesis.
"""

model_name: str = "typeform/distilbert-base-uncased-mnli"
batch_size: int = 4
device: t.Literal["cpu", "cuda"] = "cpu"

def __post_init__(self):
self.device = device_check(self.device)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name)
self.model.eval()
self.model.to(self.device)

model_config = self.model.config.to_dict()
assert model_config.get("id2label") or model_config.get(
"label2id"
), "label-id mapping missing"
        if model_config.get("id2label") is None:
            # invert the label2id mapping; iterating the dict directly would yield keys only
            self.id2label = {v: k for k, v in model_config["label2id"].items()}
else:
self.id2label = model_config["id2label"]

@property
def name(self):
return "Entailment_score"

@property
def is_batchable(self):
return True

def batch_infer(self, inputs: dict):
predictions = []
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        label2id = {value.lower(): key for key, value in self.id2label.items()}

        for idx in range(0, len(input_ids), self.batch_size):
            batch_ids = input_ids[idx : idx + self.batch_size].to(self.device)
            batch_mask = attention_mask[idx : idx + self.batch_size].to(self.device)
            output = self.model(batch_ids, attention_mask=batch_mask)
            pred = output.logits.softmax(dim=-1).detach().cpu()
            predictions.extend(pred[:, label2id["entailment"]].tolist())

return predictions

def score(
self,
ground_truth: t.List[str],
generated_text: t.List[str],
):
"""
ground_truth : premis
generated_text : hypothesis
returns entailement probability score
"""

        encodings = self.tokenizer(
            ground_truth, generated_text, truncation=True, padding=True, return_tensors="pt"
        )

score = self.batch_infer(encodings)

return score
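
A minimal usage sketch for the new metric, assuming the package layout in this PR; the sentence pairs are illustrative only, not from the diff:

from belar.metrics import EntailmentScore

metric = EntailmentScore(batch_size=2, device="cpu")
scores = metric.score(
    ground_truth=["The cat sat on the mat.", "Paris is the capital of France."],
    generated_text=["A cat is sitting on a mat.", "Paris is the capital of Germany."],
)
# scores is a list with one entailment probability per premise/hypothesis pair
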
11 changes: 4 additions & 7 deletions belar/metrics/similarity.py
@@ -10,6 +10,7 @@

SBERT_METRIC = t.Literal["cosine", "euclidean"]


@dataclass
class SBERTScore(Metric):
similarity_metric: t.Literal[SBERT_METRIC] = "cosine"
@@ -25,19 +26,15 @@ def name(
):
return f"SBERT_{self.similarity_metric}"

@property
def is_batchable(self):
return True

def score(
self,
ground_truth: str | list[str],
generated_text: str | list[str],
ground_truth: list[str],
generated_text: list[str],
):
if isinstance(ground_truth, str):
ground_truth = [ground_truth]
if isinstance(generated_text, str):
generated_text = [generated_text]

gndtruth_emb = self.model.encode(
ground_truth, batch_size=self.batch_size, convert_to_numpy=True
)
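
Since score now takes lists only (the isinstance coercion branches were removed), callers must wrap single strings themselves. A short sketch, assuming the default construction of SBERTScore from the collapsed part of this diff, with illustrative inputs:

from belar.metrics import SBERTScore

metric = SBERTScore(similarity_metric="cosine")
scores = metric.score(
    ground_truth=["The quick brown fox jumps over the lazy dog."],
    generated_text=["A fast brown fox leaps over a lazy dog."],
)
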
19 changes: 19 additions & 0 deletions belar/utils.py
@@ -0,0 +1,19 @@
import torch
import typing as t
from warnings import warn

DEVICES = ["cpu", "cuda"]


def device_check(device: t.Literal["cpu", "cuda"]):
if device == "cuda":
if torch.cuda.is_available():
device = torch.device("cuda")
else:
            warn("cuda not available, using cpu")
            device = torch.device("cpu")
elif device == "cpu":
device = torch.device("cpu")
else:
raise ValueError(f"Invalid device {device}")

return device
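
The helper always returns a torch.device, falling back to cpu with a warning when CUDA is requested but unavailable. A quick sanity check, assuming the module path above:

import torch
from belar.utils import device_check

device = device_check("cuda")  # warns and returns cpu if no GPU is present
assert isinstance(device, torch.device)
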