Merged
66 commits
8b8d1fe
add langchain loaders to docs
shahules786 Oct 19, 2023
cd7f411
Merge branch 'main' of https://github.com/explodinggradients/ragas
shahules786 Oct 20, 2023
5b18325
Merge branch 'main' of https://github.com/explodinggradients/ragas
shahules786 Oct 26, 2023
bb8d984
Merge branch 'main' of https://github.com/explodinggradients/ragas
shahules786 Oct 26, 2023
9cbb57d
Merge branch 'main' of https://github.com/explodinggradients/ragas
shahules786 Oct 29, 2023
479e636
Merge branch 'main' of https://github.com/explodinggradients/ragas
shahules786 Nov 7, 2023
3eeb7ea
Merge branch 'main' of https://github.com/explodinggradients/ragas
shahules786 Nov 12, 2023
b09003f
Merge branch 'main' of https://github.com/explodinggradients/ragas
shahules786 Nov 17, 2023
0d28d62
Merge branch 'main' of https://github.com/explodinggradients/ragas
shahules786 Nov 20, 2023
8e7c0c4
Merge branch 'main' of https://github.com/explodinggradients/ragas
shahules786 Nov 24, 2023
dd218e1
Merge branch 'main' of https://github.com/explodinggradients/ragas
shahules786 Nov 26, 2023
eab12df
Merge branch 'main' of https://github.com/explodinggradients/ragas
shahules786 Nov 27, 2023
a0f1b9b
Merge branch 'main' of https://github.com/explodinggradients/ragas
shahules786 Nov 29, 2023
2487430
Merge branch 'main' of https://github.com/explodinggradients/ragas
shahules786 Dec 1, 2023
fb5064e
Merge branch 'main' of https://github.com/explodinggradients/ragas
shahules786 Dec 10, 2023
fd8f458
Merge branch 'main' of https://github.com/explodinggradients/ragas
shahules786 Dec 10, 2023
c2f4be8
Merge branch 'main' of https://github.com/explodinggradients/ragas
shahules786 Dec 13, 2023
716807b
Added prompt class
tinomaxthayil Dec 17, 2023
22dd97b
fixed Lint errors
tinomaxthayil Dec 17, 2023
9dbcf49
resolve type issue
tinomaxthayil Dec 17, 2023
3247925
hindi prompts for faithfulness
shahules786 Dec 17, 2023
677732b
prompt adaptation
shahules786 Dec 17, 2023
a5fcded
Merge branch 'main' of https://github.com/explodinggradients/ragas
shahules786 Dec 18, 2023
0f91d79
merge main
shahules786 Dec 21, 2023
93e184a
added prompt objects to metrics
tinomaxthayil Dec 21, 2023
dfb8f03
added adaption
shahules786 Dec 25, 2023
d090df5
noqa
shahules786 Dec 25, 2023
180ed99
Merge branch 'main' of https://github.com/explodinggradients/ragas in…
shahules786 Dec 25, 2023
f39a9ca
Merge branch 'main' of https://github.com/explodinggradients/ragas
shahules786 Dec 25, 2023
c9b1e6c
Merge branch 'main' into add-metric-prompts
shahules786 Dec 25, 2023
3deb30e
prompt
shahules786 Dec 25, 2023
8c158f5
merge prompts
shahules786 Dec 25, 2023
d793798
added name to prompts
shahules786 Dec 25, 2023
ebc33da
added attr to prompts
shahules786 Dec 25, 2023
92261d9
convert json loader
shahules786 Dec 25, 2023
0b5f190
rmv unused imports
shahules786 Dec 25, 2023
c57c1d7
added json loader
shahules786 Dec 25, 2023
6d9ac0d
moved json loader
shahules786 Dec 25, 2023
b8f9ca1
change json loader path
shahules786 Dec 25, 2023
b4087b8
linting change
shahules786 Dec 25, 2023
838eaf6
merge metrics-prompts
shahules786 Dec 25, 2023
da3a2ed
prompt adaption fix
shahules786 Dec 25, 2023
3693b5a
add ragas cache
shahules786 Dec 25, 2023
d4a423a
add name
shahules786 Dec 25, 2023
c82861c
added adapt and save
shahules786 Dec 26, 2023
7b9f906
removed json_loader
shahules786 Dec 26, 2023
92f3116
json loader without circular import error
shahules786 Dec 26, 2023
29d8667
safe load json
shahules786 Dec 26, 2023
3890fe9
merge metrics-prompts
shahules786 Dec 31, 2023
5d87601
merge metrics-prompts
shahules786 Dec 31, 2023
6da261b
fix tests
shahules786 Dec 31, 2023
8d8d60c
fix tests
shahules786 Dec 31, 2023
fc48c9b
remove loader
shahules786 Dec 31, 2023
258f179
linting
shahules786 Dec 31, 2023
c1f6d6b
adapting logs
shahules786 Dec 31, 2023
62db651
add support to dict type
shahules786 Dec 31, 2023
f68f008
add support to dict type
shahules786 Dec 31, 2023
1fd1fb0
add support to dict type
shahules786 Dec 31, 2023
4ae9113
remove prompts
shahules786 Dec 31, 2023
4835510
accept dict return type
shahules786 Dec 31, 2023
a2aa64f
fix return type
shahules786 Dec 31, 2023
2cb33c8
accept dict return type
shahules786 Dec 31, 2023
dce0eed
change to logging
shahules786 Jan 1, 2024
27ceb4b
return prompt objects
shahules786 Jan 1, 2024
ad0a464
logging and classes
jjmachan Jan 1, 2024
1595a27
fixed default factor error
jjmachan Jan 1, 2024
4 changes: 3 additions & 1 deletion src/ragas/embeddings/base.py
@@ -7,8 +7,8 @@

import numpy as np
from langchain.embeddings import AzureOpenAIEmbeddings as BaseAzureOpenAIEmbeddings
from langchain.embeddings import OpenAIEmbeddings as BaseOpenAIEmbeddings
from langchain.embeddings import FastEmbedEmbeddings as BaseFastEmbedEmbeddings
from langchain.embeddings import OpenAIEmbeddings as BaseOpenAIEmbeddings
from langchain.schema.embeddings import Embeddings
from pydantic.dataclasses import dataclass

@@ -47,6 +47,7 @@ def validate_api_key(self):
else:
raise OpenAIKeyNotFound


class FastEmbedEmbeddings(BaseFastEmbedEmbeddings, RagasEmbeddings):
"""
Find the list of supported models at:
@@ -64,6 +65,7 @@ def validate_api_key(self):
"""
pass


class AzureOpenAIEmbeddings(BaseAzureOpenAIEmbeddings, RagasEmbeddings):
azure_endpoint: t.Optional[str] = None
deployment: t.Optional[str] = None
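The embeddings diff above mostly reorders imports and adds PEP 8 spacing around the FastEmbedEmbeddings and AzureOpenAIEmbeddings wrappers. For orientation, a rough usage sketch of the FastEmbed wrapper follows; the model name and the AnswerRelevancy import are assumptions for illustration, not part of this diff.

    # Hypothetical sketch: model_name follows LangChain's FastEmbedEmbeddings
    # defaults and is not taken from this PR.
    from ragas.embeddings.base import FastEmbedEmbeddings
    from ragas.metrics import AnswerRelevancy

    # FastEmbed runs locally, so the wrapper's validate_api_key() is a no-op.
    embeddings = FastEmbedEmbeddings(model_name="BAAI/bge-small-en-v1.5")
    answer_relevancy = AnswerRelevancy(embeddings=embeddings)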
127 changes: 123 additions & 4 deletions src/ragas/llms/prompt.py
@@ -1,24 +1,31 @@
from __future__ import annotations

import json
import logging
import os
import typing as t

from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.prompt_values import PromptValue
from langchain_core.pydantic_v1 import root_validator

from ragas.llms import RagasLLM
from ragas.utils import RAGAS_CACHE_HOME, json_loader


class Prompt(PromptValue):
"""
RagasPrompt is a class that represents a prompt for the ragas metrics.
Prompt is a class that represents a prompt for the ragas metrics.
"""

name: str
instruction: str
examples: t.List[t.Dict[str, t.Any]] = []
input_keys: t.List[str]
output_key: str
output_type: str = "json"
language = "en"

@root_validator
def validate_prompt(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
@@ -44,10 +51,11 @@ def validate_prompt(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
raise ValueError(
f"example {no+1} does not have the variable {output_key} in the definition"
)
if values["output_type"] == "json":
if values["output_type"].lower() == "json":
try:
if output_key in example:
json.loads(example[output_key])
if isinstance(example[output_key], str):
json.loads(example[output_key])
except ValueError as e:
raise ValueError(
f"{output_key} in example {no+1} is not in valid json format: {e}"
@@ -64,6 +72,7 @@ def to_string(self) -> str:
# Format the examples to match the Langchain prompt template
for example in self.examples:
for key, value in example.items():
value = json.dumps(value, ensure_ascii=False).encode("utf8").decode()
value = (
value.replace("{", "{{").replace("}", "}}")
if self.output_type.lower() == "json"
@@ -90,6 +99,7 @@ def get_example_str(self, example_no: int) -> str:
example = self.examples[example_no]
example_str = ""
for key, value in example.items():
value = json.dumps(value, ensure_ascii=False).encode("utf8").decode()
value = (
value.replace("{", "{{").replace("}", "}}")
if self.output_type.lower() == "json"
@@ -100,7 +110,7 @@ def format(self, **kwargs: t.Any) -> ChatPromptTemplate:

def format(self, **kwargs: t.Any) -> ChatPromptTemplate:
"""
Format the RagasPrompt object into a ChatPromptTemplate object to be used in metrics.
Format the Prompt object into a ChatPromptTemplate object to be used in metrics.
"""
if set(self.input_keys) != set(kwargs.keys()):
raise ValueError(
@@ -109,3 +119,112 @@ def format(self, **kwargs: t.Any) -> ChatPromptTemplate:
prompt = self.to_string()
human_prompt = HumanMessagePromptTemplate.from_template(prompt)
return ChatPromptTemplate.from_messages([human_prompt.format(**kwargs)])

def adapt(
self, language: str, llm: RagasLLM, cache_dir: t.Optional[str] = None
) -> Prompt:
# TODO: Add callbacks
cache_dir = cache_dir if cache_dir else RAGAS_CACHE_HOME
if os.path.exists(os.path.join(cache_dir, language, f"{self.name}.json")):
return self._load(language, self.name, cache_dir)

prompts = []
for example in self.examples:
prompts.extend(
[
str_translation.format(
translate_to=language, input=example.get(key)
)
for key in self.input_keys
]
)
prompts.append(
json_translatation.format(
translate_to=language, input=example.get(self.output_key)
)
if self.output_type.lower() == "json"
else str_translation.format(
translate_to=language, input=example.get(self.output_key)
)
)

results = [result[0].text for result in llm.generate(prompts).generations]
per_example_items = len(self.input_keys) + 1
grouped_results = [
results[i : i + per_example_items]
for i in range(0, len(results), per_example_items)
]
assert len(grouped_results) == len(
self.examples
), "examples and adapted examples must be of equal length"
Comment on lines +157 to +159

Member: should this raise an exception, or is the assert okay?

Member: Leaving it as an assert for now, on the assumption that there is no try/except we could use to catch it. The con is that the whole operation fails in this case; I think we should raise an exception, catch it, and emit a warning instead.
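A minimal sketch of the alternative discussed above: raise a catchable exception instead of asserting, and let the caller downgrade it to a warning so one failed adaptation does not abort the whole run. PromptAdaptationError and the calling pattern are hypothetical, not part of this PR.

    import warnings


    class PromptAdaptationError(ValueError):
        """Raised when translated examples do not line up with the originals."""


    def check_adaptation(grouped_results: list, examples: list) -> None:
        if len(grouped_results) != len(examples):
            raise PromptAdaptationError(
                "examples and adapted examples must be of equal length"
            )


    try:
        # stand-in values for the lists built inside Prompt.adapt
        check_adaptation(grouped_results=[["q", "a"]], examples=[{}, {}])
    except PromptAdaptationError as exc:
        warnings.warn(f"prompt adaptation skipped: {exc}")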

for i, example in enumerate(grouped_results):
example_dict = {}
example_dict.update(
{k: v for k, v in zip(self.input_keys, example[: len(self.input_keys)])}
)
example_dict[self.output_key] = (
json_loader.safe_load(example[-1], llm)
if self.output_type.lower() == "json"
else example[-1]
)

self.examples[i] = example_dict

self.language = language

# TODO:Validate the prompt after adaptation

return self

def save(self, cache_dir: t.Optional[str] = None) -> None:
cache_dir = cache_dir if cache_dir else RAGAS_CACHE_HOME
cache_dir = os.path.join(cache_dir, self.language)
if not os.path.exists(cache_dir):
os.makedirs(cache_dir)

cache_path = os.path.join(cache_dir, f"{self.name}.json")
with open(cache_path, "w") as file:
json.dump(self.to_json(), file, indent=4)

@classmethod
def _load(cls, language: str, name: str, cache_dir: str) -> Prompt:
logging.log(logging.INFO, f"Loading {name} from {cache_dir}")
path = os.path.join(cache_dir, language, f"{name}.json")
return cls(**json.load(open(path))["kwargs"])


str_translation = Prompt(
name="str_translation",
instruction="Language translation",
examples=[
{
"translate_to": "hindi",
"input": "Who was Albert Einstein and what is he best known for?",
"output": "अल्बर्ट आइंस्टीन कौन थे और वे किसके लिए सबसे ज्यादा प्रसिद्ध हैं?",
},
],
input_keys=["translate_to", "input"],
output_key="output",
output_type="str",
)

json_translatation = Prompt(
name="json_translation",
instruction="Translate values in given json to target language ",
examples=[
{
"translate_to": "hindi",
"input": """{"statements": [
"Albert Einstein was born in Germany.",
"Albert Einstein was best known for his theory of relativity."
]}""",
"output": """{"statements": [
"अल्बर्ट आइंस्टीन का जन्म जर्मनी में हुआ था।",
"अल्बर्ट आइंस्टीन अपने सापेक्षता के सिद्धांत के लिए सबसे अधिक प्रसिद्ध थे।"
]}""",
}
],
input_keys=["translate_to", "input"],
output_key="output",
output_type="JSON",
)
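Taken together, the new Prompt.adapt, Prompt.save, and Prompt._load methods give each prompt a translate-then-cache cycle. A rough usage sketch on the str_translation prompt defined above; the llm_factory() call stands in for however a RagasLLM is constructed in this version and is an assumption, not part of this diff.

    # Sketch with assumptions: llm_factory() is assumed; only adapt()/save()
    # and the RAGAS_CACHE_HOME layout come from the diff above.
    from ragas.llms import llm_factory
    from ragas.llms.prompt import str_translation

    llm = llm_factory()

    # Translates the prompt's examples and sets prompt.language = "hindi".
    adapted = str_translation.adapt(language="hindi", llm=llm)

    # Writes <RAGAS_CACHE_HOME>/hindi/str_translation.json; a later adapt()
    # call for the same language returns the cached version via _load().
    adapted.save()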
28 changes: 24 additions & 4 deletions src/ragas/metrics/_answer_correctness.py
@@ -1,21 +1,25 @@
from __future__ import annotations

import logging
import typing as t
from dataclasses import dataclass, field

import numpy as np
from datasets import Dataset
from langchain.callbacks.manager import CallbackManager, trace_as_chain_group

from ragas.utils import json_loader
from ragas.llms.prompt import Prompt
from ragas.metrics._answer_similarity import AnswerSimilarity
from ragas.metrics.base import EvaluationMode, MetricWithLLM
from ragas.utils import json_loader

logger = logging.getLogger(__name__)

if t.TYPE_CHECKING:
from langchain.callbacks.base import Callbacks

CORRECTNESS_PROMPT = Prompt(
name="answer_correctness",
instruction="""Extract following from given question and ground truth""",
examples=[
{
@@ -71,13 +75,16 @@ class AnswerCorrectness(MetricWithLLM):

name: str = "answer_correctness" # type: ignore[reportIncompatibleMethodOverride]
evaluation_mode: EvaluationMode = EvaluationMode.qga # type: ignore[reportIncompatibleMethodOverride]
correctness_prompt: Prompt = field(default_factory=lambda: CORRECTNESS_PROMPT)
batch_size: int = 15
weights: list[float] = field(default_factory=lambda: [0.75, 0.25])
answer_similarity: AnswerSimilarity | None = None

def __post_init__(self: t.Self):
if len(self.weights) != 2:
raise ValueError("Expects a list of two weights. First for factuality, second for semantic similarity")
raise ValueError(
"Expects a list of two weights. First for factuality, second for semantic similarity"
)
if all([w == 0 for w in self.weights]):
raise ValueError("At least one weight must be non-zero")
if not all([w >= 0 for w in self.weights]):
@@ -88,6 +95,15 @@ def __post_init__(self: t.Self):
llm=self.llm, batch_size=self.batch_size
)

def adapt(self, language: str, cache_dir: t.Optional[str] = None) -> None:
logger.info(f"Adapting AnswerCorrectness metric to {language}")
self.correctness_prompt = self.correctness_prompt.adapt(
language, self.llm, cache_dir
)

def save(self, cache_dir: t.Optional[str] = None) -> None:
self.correctness_prompt.save(cache_dir)

def _score_batch(
self: t.Self,
dataset: Dataset,
@@ -107,7 +123,9 @@ def _score_batch(
) as batch_group:
for q, a, g in zip(question, answer, ground_truths):
prompts.append(
CORRECTNESS_PROMPT.format(question=q, ground_truth=g[0], answer=a)
self.correctness_prompt.format(
question=q, ground_truth=g[0], answer=a
)
)

result = self.llm.generate(prompts, callbacks=batch_group)
@@ -121,7 +139,9 @@ def _score_batch(
f1_score = []
for prediction in outputs:
prediction = json_loader.safe_load(prediction[0].text, self.llm)
prediction = prediction if isinstance(prediction, list) else []
prediction = (
prediction if isinstance(prediction, list) else [prediction]
)
if prediction:
prediction = [
item.get(key_map[k], np.nan)
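The adapt/save pair added to AnswerCorrectness here is mirrored in AnswerRelevancy below, so language adaptation can be driven from the metric rather than from individual prompts. A rough usage sketch, assuming the pre-built answer_correctness instance exported by ragas.metrics:

    # Rough sketch: the answer_correctness instance is assumed to be exported
    # by ragas.metrics, as in released ragas versions.
    from ragas.metrics import answer_correctness

    # Translate CORRECTNESS_PROMPT's examples with the metric's own LLM ...
    answer_correctness.adapt(language="hindi")
    # ... and persist them under the default RAGAS_CACHE_HOME for later runs.
    answer_correctness.save()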
20 changes: 18 additions & 2 deletions src/ragas/metrics/_answer_relevance.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import logging
import typing as t
from dataclasses import dataclass, field

@@ -10,16 +11,19 @@

from ragas.embeddings.base import embedding_factory
from ragas.exceptions import OpenAIKeyNotFound
from ragas.utils import json_loader
from ragas.llms.prompt import Prompt
from ragas.metrics.base import EvaluationMode, MetricWithLLM
from ragas.utils import json_loader

logger = logging.getLogger(__name__)

if t.TYPE_CHECKING:
from langchain.callbacks.base import Callbacks

from ragas.embeddings.base import RagasEmbeddings

QUESTION_GEN = Prompt(
name="question_generation",
instruction="""Generate a question for the given answer and Identify if answer is noncommittal""",
examples=[
{
@@ -72,6 +76,7 @@ class AnswerRelevancy(MetricWithLLM):

name: str = "answer_relevancy" # type: ignore
evaluation_mode: EvaluationMode = EvaluationMode.qac # type: ignore
question_generation: Prompt = field(default_factory=lambda: QUESTION_GEN)
batch_size: int = 15
strictness: int = 3
embeddings: RagasEmbeddings = field(default_factory=embedding_factory)
@@ -83,6 +88,15 @@ def init_model(self):
if self.embeddings.openai_api_key == "no-key":
raise OpenAIKeyNotFound

def adapt(self, language: str, cache_dir: str | None = None) -> None:
logger.info(f"Adapting AnswerRelevancy metric to {language}")
self.question_generation = self.question_generation.adapt(
language, self.llm, cache_dir
)

def save(self, cache_dir: str | None = None) -> None:
self.question_generation.save(cache_dir)

def _score_batch(
self: t.Self,
dataset: Dataset,
@@ -101,7 +115,9 @@ def _score_batch(
) as batch_group:
prompts = []
for ans, ctx in zip(answers, contexts):
prompts.append(QUESTION_GEN.format(answer=ans, context="\n".join(ctx)))
prompts.append(
self.question_generation.format(answer=ans, context="\n".join(ctx))
)

results = self.llm.generate(
prompts,
3 changes: 3 additions & 0 deletions src/ragas/metrics/_answer_similarity.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import logging
import typing as t
from dataclasses import dataclass, field

@@ -19,6 +20,8 @@

from ragas.embeddings.base import RagasEmbeddings

logger = logging.getLogger(__name__)


@dataclass
class AnswerSimilarity(MetricWithLLM):