18 changes: 12 additions & 6 deletions docs/source/package_reference/metrics.mdx
@@ -56,15 +56,21 @@
[[autodoc]] metrics.metrics_sample.BLEU
### StringDistance
[[autodoc]] metrics.metrics_sample.StringDistance

### Metrics allowing sampling
#### PassAtK
[[autodoc]] metrics.metrics_sample.PassAtK
#### MajAtN
[[autodoc]] metrics.metrics_sample.MajAtN
#### AvgAtN
[[autodoc]] metrics.metrics_sample.AvgAtN

## LLM-as-a-Judge
### JudgeLM
[[autodoc]] metrics.utils.llm_as_judge.JudgeLM
### JudgeLLM
[[autodoc]] metrics.metrics_sample.JudgeLLM
### JudgeLLMMTBench
[[autodoc]] metrics.metrics_sample.JudgeLLMMTBench
### JudgeLLMMixEval
[[autodoc]] metrics.metrics_sample.JudgeLLMMixEval
### MajAtK
[[autodoc]] metrics.metrics_sample.MajAtK

## LLM-as-a-Judge
### JudgeLM
[[autodoc]] metrics.utils.llm_as_judge.JudgeLM
22 changes: 11 additions & 11 deletions src/lighteval/metrics/metrics.py
@@ -41,7 +41,7 @@
MRR,
ROUGE,
AccGoldLikelihood,
AvgAtK,
AvgAtN,
BertScore,
ExactMatches,
Extractiveness,
@@ -50,7 +50,7 @@
GPassAtK,
JudgeLLMSimpleQA,
LoglikelihoodAcc,
MajAtK,
MajAtN,
PassAtK,
Recall,
StringDistance,
@@ -85,16 +85,16 @@ class Metrics(Enum):
corpus_level_fn=np.mean,
higher_is_better=True,
)
avg_at_k = SampleLevelMetric(
metric_name="avg@k",
sample_level_fn=AvgAtK(strip_strings=True),
avg_at_n = SampleLevelMetric(
metric_name="avg@n",
sample_level_fn=AvgAtN(strip_strings=True),
category=SamplingMethod.GENERATIVE,
corpus_level_fn=np.mean,
higher_is_better=True,
)
avg_at_k_math = SampleLevelMetric(
metric_name="avg@k",
sample_level_fn=AvgAtK(
avg_at_n_math = SampleLevelMetric(
metric_name="avg@n",
sample_level_fn=AvgAtN(
sample_scoring_function=MultilingualExtractiveMatchMetric(
language=Language.ENGLISH,
gold_extraction_target=[ExprExtractionConfig(), LatexExtractionConfig()],
@@ -365,9 +365,9 @@ class Metrics(Enum):
corpus_level_fn=CorpusLevelF1Score(None),
higher_is_better=True,
)
maj_at_k = SampleLevelMetric(
metric_name="maj@k",
sample_level_fn=MajAtK(),
maj_at_n = SampleLevelMetric(
metric_name="maj@n",
sample_level_fn=MajAtN(),
category=SamplingMethod.GENERATIVE,
corpus_level_fn=np.mean,
higher_is_better=True,
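For context on the rename above: avg@n scores each of the n generations independently and averages the scores, while maj@n first reduces the n generations to their majority answer and scores that single answer. A framework-independent sketch of the two reductions (illustrative only; `score_fn`, `preds`, and `gold` are stand-ins, not lighteval API):

```python
from collections import Counter
from statistics import mean
from typing import Callable


def avg_at_n(preds: list[str], gold: str, score_fn: Callable[[str, str], float]) -> float:
    # avg@n: score every generation against the gold, then average the scores.
    return mean(score_fn(pred, gold) for pred in preds)


def maj_at_n(preds: list[str], gold: str, score_fn: Callable[[str, str], float]) -> float:
    # maj@n: keep only the most frequent answer, then score it once against the gold.
    majority_answer = Counter(preds).most_common(1)[0][0]
    return score_fn(majority_answer, gold)


exact_match = lambda pred, gold: float(pred.strip() == gold.strip())
print(avg_at_n(["4", "5", "4"], "4", exact_match))  # ~0.67
print(maj_at_n(["4", "5", "4"], "4", exact_match))  # 1.0
```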
74 changes: 33 additions & 41 deletions src/lighteval/metrics/metrics_sample.py
@@ -63,7 +63,7 @@

class SampleLevelComputation(ABC):
@abstractmethod
def compute(self, model_response: ModelResponse, doc: Doc, **kwargs):
def compute(self, doc: Doc, model_response: ModelResponse, **kwargs):
raise NotImplementedError

def __str__(self):
@@ -444,7 +444,7 @@ def __init__(self, length_normalization: bool = False):
"""
self.length_normalization = length_normalization

def compute(self, model_response: ModelResponse, doc: Doc, **kwargs) -> float:
def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> float:
"""Mean reciprocal rank. Measures the quality of a ranking of choices (ordered by correctness).

Args:
@@ -1129,14 +1129,13 @@ def __init__(
raise ValueError(f"Unknown normalization function: {normalize}")
else:
self.normalize = normalize

self.strip_strings = strip_strings

if callable(sample_scoring_function):
self.compute_score = sample_scoring_function
self.type_exact_match = None
elif isinstance(sample_scoring_function, SampleLevelComputation):
self.score_sample = sample_scoring_function.compute
self.compute_score = sample_scoring_function.compute
self.type_exact_match = None
else:
if isinstance(sample_scoring_function, str):
@@ -1145,11 +1144,9 @@
f"type_exact_match (used in parametrized_exact_match) must be one of prefix, suffix, or full. Was {sample_scoring_function} instead."
)
self.type_exact_match = sample_scoring_function
self.score_sample = self.default_sample_scoring
else:
self.type_exact_match = "full"
self.compute_score = self.default_sample_scoring
self.score_sample = self.default_sample_scoring

def preprocess(self, text: str) -> str:
if not text:
@@ -1176,19 +1173,19 @@ def name_metrics(self) -> str | list[str]:
raise NotImplementedError


class AvgAtK(SamplingMetric, SampleLevelComputation):
def __init__(self, k: int | None = None, **kwargs):
"""Sample score averages all the individual k predictions scores.
class AvgAtN(SamplingMetric, SampleLevelComputation):
def __init__(self, n: int | None = None, **kwargs):
"""Sample score averages all the individual n predictions scores.

Args:
k (int | None): The number of top choices to consider.
n (int | None): Number of samples to generate
**kwargs: Additional keyword arguments.
"""
super().__init__(**kwargs)
self.k = k
self.attribute_must_be_set = ["k"]
self.n = n
self.attribute_must_be_set = ["n"]

def compute(self, model_response: ModelResponse, doc: Doc, **kwargs):
def compute(self, doc: Doc, model_response: ModelResponse, **kwargs):
"""Computes the metric over a list of golds and predictions for one single sample.
It applies normalisation (if needed) to model prediction and gold, and takes the most frequent answer of all the available ones,
then compares it to the gold.
@@ -1203,36 +1200,32 @@ def compute(self, model_response: ModelResponse, doc: Doc, **kwargs):
"""
all_scores = []
for i in range(self.k):
all_scores.append(self.score_sample(doc, model_response[i]))
all_scores.append(self.compute_score(doc, model_response[i]))

avg_score = np.mean(all_scores)
return avg_score

def num_samples(self):
"""Get the number of samples for this metric.

Returns:
int: The number of samples
"""
return self.k
return self.n


class MajAtK(SamplingMetric, SampleLevelComputation):
def __init__(self, k: int | None = None, **kwargs):
class MajAtN(SamplingMetric, SampleLevelComputation):
def __init__(self, n: int | None = None, **kwargs):
"""An exact match class.

Args:
k (int): The number of top choices to consider.
n (int): Total number of samples to generate
**kwargs: Additional keyword arguments.
"""
super().__init__(**kwargs)

self.k = k
self.attribute_must_be_set = ["k"]
self.n = n
self.attribute_must_be_set = ["n"]

def compute(self, doc: Doc, model_response: ModelResponse, **kwargs):
"""Computes the metric over a list of golds and predictions for one single sample.
It applies normalisation (if needed) to model prediction and gold, and takes the most frequent answer of all the available ones, then compares it to the gold.
It applies normalisation (if needed) to model prediction and gold, and takes the most frequent answer of all the available ones,
then compares it to the gold.

Args:
doc (Doc): The document containing gold references.
@@ -1243,39 +1236,38 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs):
float: Aggregated score over the current sample's items.
"""
if self.k is None:
raise Exception("You did not set the value of k")
raise Exception("You did not set the value of n")

golds = doc.get_golds()

if len(golds) > 1:
raise Exception("Cannot compute maj@k with several golds")
raise Exception("Cannot compute maj@n with several golds")

processed_choices = [self.preprocess(text=g) for g in doc.get_golds()]
new_doc = Doc(
choices=processed_choices,
query=doc.query,
gold_index=list(range(len(processed_choices))),
gold_index=doc.gold_index,
)
all_answers = []
for pred in model_response.final_text[: self.k]:
for pred in model_response.final_text[: self.n]:
all_answers.append(self.preprocess(text=pred))
majority_prediction = max(all_answers, key=all_answers.count)
new_model_response = ModelResponse(
text=[majority_prediction],
)
return self.compute_score(new_doc, new_model_response)
return self.compute_score(new_model_response, new_doc)

def num_samples(self):
return self.k
return self.n


class PassAtK(SamplingMetric, SampleLevelComputation):
def __init__(self, k: int | None = None, n: int | None = None, **kwargs):
"""Computing pass at k
"""Computing pass at k with an estimator

Args:
k (int | None): Threshold for the number of successful attempts.
n (int | None): Number of samples to generate.
k (int | None): Number of correct samples threshold
n (int | None): Total number of samples to generate.
**kwargs: Additional keyword arguments.
"""
super().__init__(**kwargs)
@@ -1320,7 +1312,7 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> float:
new_model_response = ModelResponse(
text=[cur_pred],
)
all_scores.append(self.score_sample(doc=new_doc, model_response=new_model_response))
all_scores.append(self.compute_score(doc=new_doc, model_response=new_model_response))

return self.pass_at_k(all_scores)
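The `pass_at_k(all_scores)` aggregation called above is not shown in this hunk; what such a helper typically computes is the standard unbiased pass@k estimator from the HumanEval/Codex evaluation setup. A minimal sketch under the assumption of n generations with c of them scored correct (the helper name and signature here are illustrative, not the lighteval implementation):

```python
import math


def pass_at_k_estimator(n: int, c: int, k: int) -> float:
    # Unbiased estimator: 1 - C(n - c, k) / C(n, k), i.e. the probability that a
    # random subset of k generations (out of n, with c correct) contains a success.
    if n - c < k:
        return 1.0  # every size-k subset must contain at least one correct generation
    return 1.0 - math.comb(n - c, k) / math.comb(n, k)


# 10 generations, 3 correct, budget of 5 attempts
print(round(pass_at_k_estimator(n=10, c=3, k=5), 4))  # 0.9167
```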

@@ -1348,8 +1340,8 @@ def __init__(
"""Computing G-Pass@k from http://arxiv.org/abs/2412.13147

Args:
k (Union[int, list[int]] | None): The number of successful attempts to be considered.
n (int | None): Number of samples to generate.
k (Union[int, list[int]] | None): Number of correct samples threshold
n (int | None): Total number of samples to generate.
thresholds (list[float]): Thresholds to control successful attempts in k generate.
name_prefix (str | None): Prefix for the metric name.
**kwargs: Additional keyword arguments.
@@ -1370,7 +1362,7 @@ def k(self):
def k(self, new_val):
self._k = as_list(new_val)

def compute(self, model_response: ModelResponse, doc: Doc, **kwargs) -> float:
def compute(self, doc: Doc, model_response: ModelResponse, **kwargs):
"""Computes the metric over a list of golds and predictions for one single item with possibly many samples.
It applies normalisation (if needed) to model prediction and gold, computes their per prediction score,
then aggregates the scores over the samples using a pass@k.
@@ -1410,7 +1402,7 @@ def compute(self, model_response: ModelResponse, doc: Doc, **kwargs) -> float:
new_model_response = ModelResponse(
text=[cur_pred],
)
all_scores.append(self.score_sample(new_doc, new_model_response))
all_scores.append(self.compute_score(new_doc, new_model_response))

return self.g_pass_at_k(all_scores)

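For the G-Pass@k class closing this file's diff, the paper cited in its docstring (arXiv:2412.13147) defines G-Pass@k with threshold tau as the probability that at least ceil(tau * k) of k generations drawn without replacement from the n samples are correct, a hypergeometric tail. A hedged sketch of that formula, assuming n total generations with c correct; the single-threshold signature below is an assumption for illustration, not the lighteval implementation:

```python
import math


def g_pass_at_k(n: int, c: int, k: int, tau: float) -> float:
    # Probability that at least ceil(tau * k) of k draws (without replacement,
    # from n generations of which c are correct) are correct: a hypergeometric tail.
    threshold = max(math.ceil(tau * k), 1)
    total = math.comb(n, k)
    return sum(
        math.comb(c, j) * math.comb(n - c, k - j)
        for j in range(threshold, min(c, k) + 1)
    ) / total


print(round(g_pass_at_k(n=10, c=3, k=5, tau=0.0), 4))  # 0.9167, reduces to plain pass@5
print(round(g_pass_at_k(n=10, c=3, k=5, tau=0.6), 4))  # 0.0833, needs >= 3 of 5 correct
```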
2 changes: 1 addition & 1 deletion src/lighteval/metrics/utils/metric_utils.py
@@ -89,7 +89,7 @@ def __call__(self, sample_params: dict | None):
sample_params_name = "&".join(f"{k}={v}" for k, v in sample_params.items())
if isinstance(self, MetricGrouping):
if hasattr(self.sample_level_fn, "metric_names"):
# this is mostly for the gpass@k metrics
# this is mostly for the gpass@k metrics which redefine submetric names
self.metric_name = self.sample_level_fn.metric_names
else:
self.metric_name = [f"{metric}:{sample_params_name}" for metric in self.metric_name]
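As a quick illustration of the name building shown in this last hunk, a metric parametrized with sample params ends up exposed under a name like `maj@n:n=8` (toy reproduction of the visible f-strings only, not the full `__call__`):

```python
# Toy reproduction of the visible string building.
sample_params = {"n": 8}
sample_params_name = "&".join(f"{k}={v}" for k, v in sample_params.items())

metric_names = ["maj@n"]
print([f"{metric}:{sample_params_name}" for metric in metric_names])  # ['maj@n:n=8']
```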