diff --git a/belar/metrics/__init__.py b/belar/metrics/__init__.py
index 8f6d9f56c..9131315af 100644
--- a/belar/metrics/__init__.py
+++ b/belar/metrics/__init__.py
@@ -1,5 +1,4 @@
 from belar.metrics.base import Evaluation, Metric
-from belar.metrics.similarity import *
-from belar.metrics.simple import *
-from belar.metrics.similarity import SBERTScore
 from belar.metrics.factual import EntailmentScore
+from belar.metrics.similarity import SBERTScore
+from belar.metrics.simple import *
diff --git a/belar/metrics/factual.py b/belar/metrics/factual.py
index fc41df3dd..adc0509f7 100644
--- a/belar/metrics/factual.py
+++ b/belar/metrics/factual.py
@@ -1,9 +1,10 @@
 from __future__ import annotations

-from transformers import AutoTokenizer, AutoModelForSequenceClassification
 import typing as t
 from dataclasses import dataclass

+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
 from belar.metrics import Metric
 from belar.utils import device_check

@@ -15,6 +16,7 @@ class EntailmentScore(Metric):
     """

     model_name: str = "typeform/distilbert-base-uncased-mnli"
+    max_length: int = 512
     batch_size: int = 4
     device: t.Literal["cpu", "cuda"] = "cpu"

@@ -67,7 +69,12 @@ def score(
         """

         encodings = self.tokenizer(
-            ground_truth, generated_text, truncation=True, return_tensors="pt"
+            ground_truth,
+            generated_text,
+            truncation=True,
+            return_tensors="pt",
+            max_length=self.max_length,
+            padding="max_length",
         )

         score = self.batch_infer(encodings)
diff --git a/belar/metrics/simple.py b/belar/metrics/simple.py
index 347b673dc..882445a43 100644
--- a/belar/metrics/simple.py
+++ b/belar/metrics/simple.py
@@ -1,8 +1,9 @@
 from __future__ import annotations

-from Levenshtein import distance, ratio
 import typing as t
-from dataclasses import dataclass
+from dataclasses import dataclass, field

+from Levenshtein import distance, ratio
 from nltk.tokenize import word_tokenize
 from nltk.translate.bleu_score import sentence_bleu
 from rouge_score import rouge_scorer
@@ -14,7 +15,7 @@

 @dataclass
 class BLEU(Metric):
-    weights: t.List[float] = [0.25, 0.25, 0.25, 0.25]
+    weights: list[float] = field(default_factory=lambda: [0.25, 0.25, 0.25, 0.25])
     smoothing_function = None

     @property
@@ -49,9 +50,11 @@ def __post_init__(self):
             [self.type], use_stemmer=self.use_stemmer
         )

+    @property
     def name(self):
         return self.type

+    @property
     def is_batchable(self):
         return False

@@ -65,6 +68,7 @@ def score(
         return scores


+@dataclass
 class EditScore(Metric):
     measure: t.Literal["distance", "ratio"] = "ratio"

@@ -90,5 +94,8 @@ def score(self, ground_truth: t.List[str], generated_text: t.List[str]):
 Rouge1 = ROUGE("rouge1")
 Rouge2 = ROUGE("rouge2")
 RougeL = ROUGE("rougeL")
+Bleu = BLEU()
+EditDistance = EditScore("distance")
+EditRatio = EditScore("ratio")

-__all__ = ["Rouge1", "Rouge2", "RougeL"]
+__all__ = ["Rouge1", "Rouge2", "RougeL", "Bleu", "EditDistance", "EditRatio"]
diff --git a/examples/quickstart.ipynb b/examples/quickstart.ipynb
index 6effae29b..43166fffe 100644
--- a/examples/quickstart.ipynb
+++ b/examples/quickstart.ipynb
@@ -2,8 +2,8 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 7,
-   "id": "992c777a",
+   "execution_count": 1,
+   "id": "54b66a67",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -13,39 +13,10 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
-   "id": "5eaf4729",
+   "execution_count": null,
+   "id": "9710719d",
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "Found cached dataset eli5 (/home/jjmachan/.cache/huggingface/datasets/eli5/LFQA_reddit/1.0.0/17574e5502a10f41bbd17beba83e22475b499fa62caa1384a3d093fc856fe6fa)\n"
-     ]
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "a2231863e61c4ffd8d695c8531a48139",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "  0%|          | 0/9 [00:00
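A few notes on the hunks above, each with a minimal sketch; none of the code below is part of the diff itself.

First, the new `max_length`/`padding` arguments in `factual.py`: padding every `(ground_truth, generated_text)` pair to a fixed token length keeps batched tensors rectangular and puts a hard cap on sequence length. A sketch against the checkpoint `EntailmentScore` already defaults to (the example strings are placeholders):

```python
from transformers import AutoTokenizer

# Same checkpoint as EntailmentScore.model_name in the diff.
tokenizer = AutoTokenizer.from_pretrained("typeform/distilbert-base-uncased-mnli")

# truncation=True plus padding="max_length" clamps every premise/hypothesis
# pair to exactly max_length tokens, so batches always stack to a fixed shape.
encodings = tokenizer(
    ["The cat sat on the mat."],      # ground_truth (placeholder)
    ["A cat is sitting on a mat."],   # generated_text (placeholder)
    truncation=True,
    return_tensors="pt",
    max_length=512,
    padding="max_length",
)
print(encodings["input_ids"].shape)  # torch.Size([1, 512])
```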
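Second, the `weights` change in `simple.py`: `dataclasses` rejects a bare mutable default at class-definition time, which is why the diff wraps the list in `field(default_factory=...)`. A standalone sketch, independent of belar:

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class BLEU:
    # A bare `weights: list[float] = [0.25, ...]` fails at class definition:
    #   ValueError: mutable default <class 'list'> for field weights is not allowed
    # default_factory builds a fresh list for each instance instead.
    weights: list[float] = field(default_factory=lambda: [0.25, 0.25, 0.25, 0.25])


a, b = BLEU(), BLEU()
a.weights.append(0.0)
print(b.weights)  # [0.25, 0.25, 0.25, 0.25] -- instances no longer share state
```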
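Finally, the newly exported `EditScore` variants map onto the two `Levenshtein` functions imported at the top of `simple.py`: `distance` counts edit operations, while `ratio` is a length-normalized similarity in [0, 1]. A quick check:

```python
from Levenshtein import distance, ratio

print(distance("kitten", "sitting"))  # 3 (two substitutions + one insertion)
print(ratio("kitten", "sitting"))     # ~0.615, higher means more similar
```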