4 changes: 4 additions & 0 deletions src/ragas/metrics/collections/__init__.py
@@ -3,6 +3,8 @@
from ragas.metrics.collections._answer_relevancy import AnswerRelevancy
from ragas.metrics.collections._answer_similarity import AnswerSimilarity
from ragas.metrics.collections._bleu_score import BleuScore
from ragas.metrics.collections._context_precision import ContextPrecision
from ragas.metrics.collections._context_recall import ContextRecall
from ragas.metrics.collections._rouge_score import RougeScore
from ragas.metrics.collections._semantic_similarity import SemanticSimilarity
from ragas.metrics.collections._string import (
@@ -18,6 +20,8 @@
"AnswerRelevancy",
"AnswerSimilarity",
"BleuScore",
"ContextPrecision",
"ContextRecall",
"DistanceMeasure",
"ExactMatch",
"NonLLMStringSimilarity",
163 changes: 163 additions & 0 deletions src/ragas/metrics/collections/_context_precision.py
@@ -0,0 +1,163 @@
"""Context Precision metric v2 - Modern implementation with instructor LLMs."""

import typing as t

import numpy as np
from pydantic import BaseModel, Field

from ragas.metrics.collections.base import BaseMetric
from ragas.metrics.result import MetricResult
from ragas.prompt.metrics.context_precision import context_precision_prompt

if t.TYPE_CHECKING:
from ragas.llms.base import InstructorBaseRagasLLM


class ContextPrecisionVerification(BaseModel):
"""Structured output for context precision verification."""

reason: str = Field(..., description="Reason for the verdict")
verdict: int = Field(..., description="Binary verdict: 1 if useful, 0 if not")


class ContextPrecision(BaseMetric):
"""
Evaluate context precision using Average Precision metric.

    This metric evaluates whether relevant contexts are ranked higher than
    irrelevant ones by checking whether each retrieved context was useful in
    arriving at the reference answer.

This implementation uses modern instructor LLMs with structured output.
Only supports modern components - legacy wrappers are rejected with clear error messages.

Usage:
>>> from openai import AsyncOpenAI
>>> from ragas.llms.base import instructor_llm_factory
>>> from ragas.metrics.collections import ContextPrecision
>>>
>>> # Setup dependencies
>>> client = AsyncOpenAI()
>>> llm = instructor_llm_factory("openai", client=client, model="gpt-4o-mini")
>>>
>>> # Create metric instance
>>> metric = ContextPrecision(llm=llm)
>>>
>>> # Single evaluation
>>> result = await metric.ascore(
... user_input="What is the capital of France?",
... retrieved_contexts=["Paris is the capital of France.", "London is in England."],
... reference="Paris"
... )
>>> print(f"Score: {result.value}")
>>>
>>> # Batch evaluation
>>> results = await metric.abatch_score([
... {"user_input": "Q1", "retrieved_contexts": ["C1", "C2"], "reference": "A1"},
... {"user_input": "Q2", "retrieved_contexts": ["C1", "C2"], "reference": "A2"},
... ])

Attributes:
llm: Modern instructor-based LLM for verification
name: The metric name
allowed_values: Score range (0.0 to 1.0)
"""

# Type hints for linter (attributes are set in __init__)
llm: "InstructorBaseRagasLLM"

def __init__(
self,
llm: "InstructorBaseRagasLLM",
name: str = "context_precision",
**kwargs,
):
"""Initialize ContextPrecision metric with required components."""
# Set attributes explicitly before calling super()
self.llm = llm

# Call super() for validation
super().__init__(name=name, **kwargs)

async def ascore(
self,
user_input: str,
retrieved_contexts: t.List[str],
reference: str,
) -> MetricResult:
"""
Calculate context precision score asynchronously.

The metric evaluates each retrieved context to determine if it was useful
for arriving at the reference answer, then calculates average precision.

Args:
user_input: The original question
retrieved_contexts: List of retrieved context strings (in ranked order)
reference: The reference answer to evaluate against

Returns:
MetricResult with average precision score (0.0-1.0)
"""
# Handle edge cases
if not retrieved_contexts:
return MetricResult(value=0.0)

if not reference or not user_input:
return MetricResult(value=0.0)

# Evaluate each context
verdicts = []
for context in retrieved_contexts:
# Generate prompt for this context
prompt = context_precision_prompt(
question=user_input, context=context, answer=reference
)

# Get verification from LLM
verification = await self.llm.agenerate(
prompt, ContextPrecisionVerification
)

# Store binary verdict (1 if useful, 0 if not)
verdicts.append(1 if verification.verdict else 0)

# Calculate average precision
score = self._calculate_average_precision(verdicts)

return MetricResult(value=float(score))

def _calculate_average_precision(self, verdict_list: t.List[int]) -> float:
"""
Calculate average precision from list of binary verdicts.

Average Precision formula:
AP = (sum of (precision@k * relevance@k)) / (total relevant items)

Where:
- precision@k = (relevant items in top k) / k
- relevance@k = 1 if item k is relevant, 0 otherwise

Args:
verdict_list: List of binary verdicts (1 for relevant, 0 for not)

Returns:
            Average precision score (0.0-1.0); effectively 0.0 when no context is relevant
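
        Example:
            For verdicts [1, 0, 1], the relevant positions are 1 and 3, so
            AP = (1/1 + 2/3) / 2 ≈ 0.83.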
"""
        # Total relevant items (small epsilon avoids division by zero when none are relevant)
denominator = sum(verdict_list) + 1e-10

# Calculate sum of precision at each relevant position
numerator = sum(
[
(sum(verdict_list[: i + 1]) / (i + 1)) * verdict_list[i]
for i in range(len(verdict_list))
]
)

score = numerator / denominator

# Return nan if score is invalid
if np.isnan(score):
return np.nan

return score
124 changes: 124 additions & 0 deletions src/ragas/metrics/collections/_context_recall.py
@@ -0,0 +1,124 @@
"""Context Recall metric v2 - Class-based implementation with modern components."""

import typing as t

import numpy as np
from pydantic import BaseModel

from ragas.metrics.collections.base import BaseMetric
from ragas.metrics.result import MetricResult
from ragas.prompt.metrics.context_recall import context_recall_prompt

if t.TYPE_CHECKING:
from ragas.llms.base import InstructorBaseRagasLLM


class ContextRecallClassification(BaseModel):
"""Structured output for a single statement classification."""

statement: str
reason: str
    attributed: int  # 1 if the statement can be attributed to the context, else 0


class ContextRecallOutput(BaseModel):
"""Structured output for context recall classifications."""

classifications: t.List[ContextRecallClassification]


class ContextRecall(BaseMetric):
"""
Evaluate context recall by classifying if statements can be attributed to context.

This implementation uses modern instructor LLMs with structured output.
Only supports modern components - legacy wrappers are rejected with clear error messages.

Usage:
>>> from openai import AsyncOpenAI
>>> from ragas.llms.base import instructor_llm_factory
>>> from ragas.metrics.collections import ContextRecall
>>>
>>> # Setup dependencies
>>> client = AsyncOpenAI()
>>> llm = instructor_llm_factory("openai", client=client, model="gpt-4o-mini")
>>>
>>> # Create metric instance
>>> metric = ContextRecall(llm=llm)
>>>
>>> # Single evaluation
>>> result = await metric.ascore(
... user_input="What is the capital of France?",
... retrieved_contexts=["Paris is the capital of France."],
... reference="Paris is the capital and largest city of France."
... )
>>> print(f"Score: {result.value}")
>>>
>>> # Batch evaluation
>>> results = await metric.abatch_score([
... {"user_input": "Q1", "retrieved_contexts": ["C1"], "reference": "A1"},
... {"user_input": "Q2", "retrieved_contexts": ["C2"], "reference": "A2"},
... ])

Attributes:
llm: Modern instructor-based LLM for classification
name: The metric name
allowed_values: Score range (0.0 to 1.0)
"""

# Type hints for linter (attributes are set in __init__)
llm: "InstructorBaseRagasLLM"

def __init__(
self,
llm: "InstructorBaseRagasLLM",
name: str = "context_recall",
**kwargs,
):
"""Initialize ContextRecall metric with required components."""
# Set attributes explicitly before calling super()
self.llm = llm

# Call super() for validation
super().__init__(name=name, **kwargs)

async def ascore(
self,
user_input: str,
retrieved_contexts: t.List[str],
reference: str,
) -> MetricResult:
"""
Calculate context recall score asynchronously.

Components are guaranteed to be validated and non-None by the base class.

Args:
user_input: The original question
retrieved_contexts: List of retrieved context strings
reference: The reference answer to evaluate

Returns:
MetricResult with recall score (0.0-1.0)
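
        Note:
            Recall is the fraction of reference statements that the LLM
            classifies as attributable to the retrieved context.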
"""
# Combine contexts into a single string
context = "\n".join(retrieved_contexts) if retrieved_contexts else ""

# Generate prompt
prompt = context_recall_prompt(
question=user_input, context=context, answer=reference
)

# Get classifications from LLM
result = await self.llm.agenerate(prompt, ContextRecallOutput)

# Calculate score
if not result.classifications:
return MetricResult(value=np.nan)

        # Recall = attributed statements / total reference statements
attributions = [c.attributed for c in result.classifications]
score = sum(attributions) / len(attributions) if attributions else np.nan

return MetricResult(value=float(score))
68 changes: 68 additions & 0 deletions src/ragas/prompt/metrics/context_precision.py
@@ -0,0 +1,68 @@
"""Context Precision prompt for verifying context usefulness."""

import json


def context_precision_prompt(question: str, context: str, answer: str) -> str:
"""
Generate the prompt for context precision evaluation.

This prompt evaluates whether a given context was useful in arriving at the answer.

Args:
question: The original question
context: A single retrieved context to evaluate
answer: The reference answer to compare against

Returns:
Formatted prompt string for the LLM
"""
# Use json.dumps() to safely escape the strings
safe_question = json.dumps(question)
safe_context = json.dumps(context)
safe_answer = json.dumps(answer)
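
    # The few-shot examples below also define the expected JSON output shape:
    # a "reason" string and a binary "verdict" (1 = useful, 0 = not useful).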

return f"""Given question, answer and context verify if the context was useful in arriving at the given answer. Give verdict as "1" if useful and "0" if not with json output.

--------EXAMPLES-----------
Example 1
Input: {{
"question": "What can you tell me about Albert Einstein?",
"context": "Albert Einstein (14 March 1879 – 18 April 1955) was a German-born theoretical physicist, widely held to be one of the greatest and most influential scientists of all time. Best known for developing the theory of relativity, he also made important contributions to quantum mechanics, and was thus a central figure in the revolutionary reshaping of the scientific understanding of nature that modern physics accomplished in the first decades of the twentieth century. His mass–energy equivalence formula E = mc2, which arises from relativity theory, has been called 'the world's most famous equation'. He received the 1921 Nobel Prize in Physics 'for his services to theoretical physics, and especially for his discovery of the law of the photoelectric effect', a pivotal step in the development of quantum theory. His work is also known for its influence on the philosophy of science. In a 1999 poll of 130 leading physicists worldwide by the British journal Physics World, Einstein was ranked the greatest physicist of all time. His intellectual achievements and originality have made Einstein synonymous with genius.",
"answer": "Albert Einstein, born on 14 March 1879, was a German-born theoretical physicist, widely held to be one of the greatest and most influential scientists of all time. He received the 1921 Nobel Prize in Physics for his services to theoretical physics."
}}
Output: {{
"reason": "The provided context was indeed useful in arriving at the given answer. The context includes key information about Albert Einstein's life and contributions, which are reflected in the answer.",
"verdict": 1
}}

Example 2
Input: {{
"question": "who won 2020 icc world cup?",
"context": "The 2022 ICC Men's T20 World Cup, held from October 16 to November 13, 2022, in Australia, was the eighth edition of the tournament. Originally scheduled for 2020, it was postponed due to the COVID-19 pandemic. England emerged victorious, defeating Pakistan by five wickets in the final to clinch their second ICC Men's T20 World Cup title.",
"answer": "England"
}}
Output: {{
"reason": "the context was useful in clarifying the situation regarding the 2020 ICC World Cup and indicating that England was the winner of the tournament that was intended to be held in 2020 but actually took place in 2022.",
"verdict": 1
}}

Example 3
Input: {{
"question": "What is the tallest mountain in the world?",
"context": "The Andes is the longest continental mountain range in the world, located in South America. It stretches across seven countries and features many of the highest peaks in the Western Hemisphere. The range is known for its diverse ecosystems, including the high-altitude Andean Plateau and the Amazon rainforest.",
"answer": "Mount Everest."
}}
Output: {{
"reason": "the provided context discusses the Andes mountain range, which, while impressive, does not include Mount Everest or directly relate to the question about the world's tallest mountain.",
"verdict": 0
}}
-----------------------------

Now perform the same with the following input
Input: {{
"question": {safe_question},
"context": {safe_context},
"answer": {safe_answer}
}}
Output: """