ragas/metrics/__init__.py (19 additions, 11 deletions)

@@ -1,17 +1,25 @@
 from ragas.metrics.base import Evaluation, Metric
-from ragas.metrics.factual import EntailmentScore
-from ragas.metrics.similarity import SBERTScore
-from ragas.metrics.simple import BLUE, EditDistance, EditRatio, Rouge1, Rouge2, RougeL
+from ragas.metrics.factual import entailment_score, q_square
+from ragas.metrics.similarity import bert_score
+from ragas.metrics.simple import (
+    bleu_score,
+    edit_distance,
+    edit_ratio,
+    rouge1,
+    rouge2,
+    rougeL,
+)
 
 __all__ = [
     "Evaluation",
     "Metric",
-    "EntailmentScore",
-    "SBERTScore",
-    "BLUE",
-    "EditDistance",
-    "EditRatio",
-    "RougeL",
-    "Rouge1",
-    "Rouge2",
+    "entailment_score",
+    "bert_score",
+    "q_square",
+    "bleu_score",
+    "edit_distance",
+    "edit_ratio",
+    "rouge1",
+    "rouge2",
+    "rougeL",
 ]
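
Downstream code now imports the snake_case instances instead of the CamelCase ones. A minimal usage sketch, assuming the `score(ground_truth, generated_text)` signature visible in the hunks below (the sample strings are illustrative):

```python
# Hypothetical post-rename usage; metric.name and metric.score() are the
# property and method visible in the diffs below, but the return type of
# score() is not shown in this PR.
from ragas.metrics import bleu_score, edit_ratio, rouge1

ground_truth = ["the cat sat on the mat"]
generated_text = ["the cat is on the mat"]

for metric in (rouge1, bleu_score, edit_ratio):
    print(metric.name, metric.score(ground_truth, generated_text))
```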
ragas/metrics/factual.py (3 additions, 3 deletions)

@@ -218,7 +218,7 @@ def __post_init__(
 
     @property
     def name(self):
-        return "Q^2"
+        return "Qsquare"
 
     @property
     def is_batchable(self):
@@ -340,5 +340,5 @@ def score(self, ground_truth: list[str], generated_text: list[str], **kwargs):
         return scores
 
 
-ENTScore = EntailmentScore()
-Q2Score = Qsquare()
+entailment_score = EntailmentScore()
+q_square = Qsquare()
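
The convention this PR lands on: classes keep CamelCase (`EntailmentScore`, `Qsquare`) while the ready-made module-level instances move to snake_case (`entailment_score`, `q_square`). A sketch of what that means for callers; the constructor kwargs are borrowed from `tests/benchmarks/benchmark.py` below, not documented defaults:

```python
# Sketch of class-vs-instance usage under the new naming. The max_length
# and device kwargs mirror the (now removed) benchmark initialization.
from ragas.metrics.factual import EntailmentScore, entailment_score

default = entailment_score                              # shared default instance
custom = EntailmentScore(max_length=512, device="cpu")  # explicitly configured
```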
ragas/metrics/similarity.py (5 additions, 5 deletions)

@@ -12,12 +12,12 @@
 if t.TYPE_CHECKING:
     from torch import Tensor
 
-SBERT_METRIC = t.Literal["cosine", "euclidean"]
+BERT_METRIC = t.Literal["cosine", "euclidean"]
 
 
 @dataclass
-class SBERTScore(Metric):
-    similarity_metric: t.Literal[SBERT_METRIC] = "cosine"
+class BERTScore(Metric):
+    similarity_metric: t.Literal[BERT_METRIC] = "cosine"
     model_path: str = "all-MiniLM-L6-v2"
     batch_size: int = 1000
 
@@ -28,7 +28,7 @@ def __post_init__(self):
     def name(
         self,
     ):
-        return f"SBERT_{self.similarity_metric}"
+        return f"BERTScore_{self.similarity_metric}"
 
     @property
     def is_batchable(self):
@@ -64,4 +64,4 @@ def score(
         return score
 
 
-__all__ = ["SBERTScore"]
+bert_score = BERTScore()
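
`bert_score` replaces the old `__all__` export with a ready-made cosine instance, while the `BERTScore` dataclass stays available for non-default configurations. A small sketch grounded in the fields and `name` property shown above:

```python
from ragas.metrics.similarity import BERTScore, bert_score

print(bert_score.name)  # "BERTScore_cosine", the default similarity_metric

# A non-default instance; the field names come from the dataclass above.
euclidean = BERTScore(similarity_metric="euclidean", batch_size=256)
print(euclidean.name)   # "BERTScore_euclidean"
```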
ragas/metrics/simple.py (6 additions, 6 deletions)

@@ -91,9 +91,9 @@ def score(self, ground_truth: t.List[str], generated_text: t.List[str]):
         return score
 
 
-Rouge1 = ROUGE("rouge1")
-Rouge2 = ROUGE("rouge2")
-RougeL = ROUGE("rougeL")
-BLUE = BLEUScore()
-EditDistance = EditScore("distance")
-EditRatio = EditScore("ratio")
+rouge1 = ROUGE("rouge1")
+rouge2 = ROUGE("rouge2")
+rougeL = ROUGE("rougeL")
+bleu_score = BLEUScore()
+edit_distance = EditScore("distance")
+edit_ratio = EditScore("ratio")
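
Each exported instance is simply the class parameterized at construction, so callers who need a variant can build one the same way. A sketch; whether `ROUGE` and `EditScore` accept arguments beyond those shown here is not confirmed by this diff:

```python
from ragas.metrics.simple import ROUGE, EditScore

# The snake_case exports above are plain constructions like these:
my_rouge_l = ROUGE("rougeL")         # equivalent to the exported rougeL
my_distance = EditScore("distance")  # equivalent to edit_distance
```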
tests/benchmarks/benchmark.py (10 additions, 19 deletions)

@@ -1,36 +1,27 @@
 import typing as t
 
-from datasets import Dataset, load_dataset
+from datasets import Dataset, arrow_dataset, load_dataset
 from torch.cuda import is_available
 from tqdm import tqdm
 from utils import print_table, timeit
 
-from ragas.metrics import (
-    EditDistance,
-    EditRatio,
-    EntailmentScore,
-    Evaluation,
-    Rouge1,
-    Rouge2,
-    RougeL,
-    SBERTScore,
-)
+from ragas.metrics import Evaluation, edit_distance, edit_ratio, rouge1, rouge2, rougeL
 
 DEVICE = "cuda" if is_available() else "cpu"
 BATCHES = [0, 1]
 # init metrics
-sbert_score = SBERTScore(similarity_metric="cosine")
-entail = EntailmentScore(max_length=512, device=DEVICE)
 
 METRICS = {
-    "Rouge1": Rouge1,
-    "Rouge2": Rouge2,
-    "RougeL": RougeL,
-    "EditRatio": EditRatio,
-    "EditDistance": EditDistance,
+    "Rouge1": rouge1,
+    "Rouge2": rouge2,
+    "RougeL": rougeL,
+    "EditRatio": edit_ratio,
+    "EditDistance": edit_distance,
     # "SBERTScore": sbert_score,
     # "EntailmentScore": entail,
 }
 DS = load_dataset("explodinggradients/eli5-test", split="test_eli5")
+assert isinstance(DS, arrow_dataset.Dataset), "Not an arrow_dataset"
+DS = DS.select(range(100))
 
 
 def setup() -> t.Iterator[tuple[str, Evaluation, Dataset]]:
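
The added `assert isinstance(DS, arrow_dataset.Dataset)` is a type-narrowing guard: `datasets.load_dataset` is annotated to return a union of dataset types, so the assert fails fast on an unexpected shape and lets static checkers accept the `DS.select(range(100))` call that follows (`datasets.Dataset` is the same class re-exported from `datasets.arrow_dataset`). A minimal sketch of the pattern:

```python
# Minimal sketch of the assert-to-narrow pattern used in the benchmark;
# load_dataset's return annotation is a union, and .select() exists only
# on the concrete Dataset class.
from datasets import Dataset, load_dataset

ds = load_dataset("explodinggradients/eli5-test", split="test_eli5")
assert isinstance(ds, Dataset), "expected a split-level Dataset"
ds = ds.select(range(100))  # type checkers now know ds is a Dataset
```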