1 change: 1 addition & 0 deletions pyproject.toml
@@ -6,6 +6,7 @@ dependencies = [
     "tiktoken",
     "langchain",
     "langchain-core",
+    "langchain-community",
     "langchain_openai",
     "openai>1",
     "pysbd>=0.3.4",
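
Note: as of the langchain 0.1 split, third-party integrations live in the separate langchain-community package, so it is now declared explicitly. A quick sanity check for the new dependency (import paths taken from the source changes below):

import langchain_community  # ImportError here means the dependency is missing
from langchain_community.chat_models import ChatOpenAI  # previously langchain.chat_models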
6 changes: 5 additions & 1 deletion src/ragas/evaluation.py
@@ -154,7 +154,11 @@ def evaluate(
     [m.init_model() for m in metrics]
 
     executor = Executor(
-        is_async=is_async, max_workers=max_workers, raise_exceptions=raise_exceptions
+        desc="Evaluating",
+        keep_progress_bar=True,
+        is_async=is_async,
+        max_workers=max_workers,
+        raise_exceptions=raise_exceptions,
     )
     # new evaluation chain
     row_run_managers = []
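
For reference, a minimal sketch of what this change means for callers: evaluate() itself is unchanged, but its progress bar is now labelled "Evaluating" and persists after the run. The column names below are illustrative of a ragas evaluation dataset of this era; adjust to your data:

from datasets import Dataset

from ragas import evaluate
from ragas.metrics import answer_similarity

ds = Dataset.from_dict(
    {
        "question": ["What is RAG?"],
        "answer": ["Retrieval-augmented generation."],
        "contexts": [["RAG pairs a retriever with a generator."]],
        "ground_truths": [["Retrieval-augmented generation."]],
    }
)
# evaluate() now builds Executor(desc="Evaluating", keep_progress_bar=True, ...)
result = evaluate(ds, metrics=[answer_similarity])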
5 changes: 5 additions & 0 deletions src/ragas/executor.py
@@ -10,6 +10,7 @@
 @dataclass
 class Executor:
     desc: str = "Evaluating"
+    keep_progress_bar: bool = True
     is_async: bool = True
     max_workers: t.Optional[int] = None
     futures: t.List[t.Any] = field(default_factory=list, repr=False)
@@ -74,6 +75,8 @@ async def _aresults(self) -> t.List[t.Any]:
             asyncio.as_completed(self.futures),
             desc=self.desc,
             total=len(self.futures),
+            # whether you want to keep the progress bar after completion
+            leave=self.keep_progress_bar,
         ):
             r = (-1, np.nan)
             try:
@@ -109,6 +112,8 @@ def results(self) -> t.List[t.Any]:
             as_completed(self.futures),
             desc=self.desc,
             total=len(self.futures),
+            # whether you want to keep the progress bar after completion
+            leave=self.keep_progress_bar,
         ):
             r = (-1, np.nan)
             try:
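
A minimal usage sketch of the new flag, assuming only what this PR shows (submit() with a name= argument, results()); the coroutine is a stand-in:

import asyncio

from ragas.executor import Executor

async def square(x: int) -> int:
    # stand-in job; any coroutine function works
    await asyncio.sleep(0.1)
    return x * x

executor = Executor(desc="Squaring", keep_progress_bar=False, is_async=True)
for i in range(10):
    executor.submit(square, i, name=f"square[{i}]")
results = executor.results()  # returned in submission order
assert results == [i * i for i in range(10)]

keep_progress_bar maps onto tqdm's leave parameter, so a False value erases the bar once iteration completes; docstore.py below uses that for its transient per-batch embedding bar.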
19 changes: 2 additions & 17 deletions src/ragas/llms/base.py
@@ -4,14 +4,13 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 
-from langchain.chat_models import AzureChatOpenAI, ChatOpenAI, ChatVertexAI
-from langchain.llms import AzureOpenAI, OpenAI, VertexAI
+from langchain_community.chat_models import AzureChatOpenAI, ChatOpenAI, ChatVertexAI
+from langchain_community.llms import AzureOpenAI, OpenAI, VertexAI
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.outputs import LLMResult
 
 if t.TYPE_CHECKING:
     from langchain_core.callbacks import Callbacks
-    from langchain_core.prompts import ChatPromptTemplate
 
     from ragas.llms.prompt import PromptValue
 
@@ -62,20 +61,6 @@ async def agenerate_text(
     ) -> LLMResult:
         ...
 
-    # TODO: remove after testset generator is refactored
-    def generate_text_with_hmpt(
-        self,
-        prompts: t.List[ChatPromptTemplate],
-        n: int = 1,
-        temperature: float = 1e-8,
-        stop: t.Optional[t.List[str]] = None,
-        callbacks: Callbacks = [],
-    ) -> LLMResult:
-        from ragas.llms.prompt import PromptValue
-
-        prompt = PromptValue(prompt_str=prompts[0].format())
-        return self.generate_text(prompt, n, temperature, stop, callbacks)
-
 
 @dataclass
 class LangchainLLMWrapper(BaseRagasLLM):
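
The deleted generate_text_with_hmpt helper carried its own removal TODO and was only needed by the old testset generator, which the changes below refactor away. For the import migration itself, a minimal sketch of wrapping a chat model via the new package (configuration omitted; assumes OPENAI_API_KEY is set):

from langchain_community.chat_models import ChatOpenAI

from ragas.llms.base import LangchainLLMWrapper

# Provider classes now come from langchain_community; the wrapper is unchanged.
llm = LangchainLLMWrapper(ChatOpenAI())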
2 changes: 1 addition & 1 deletion src/ragas/metrics/_answer_similarity.py
@@ -10,7 +10,7 @@
 from ragas.metrics.base import EvaluationMode, MetricWithEmbeddings, MetricWithLLM
 
 if t.TYPE_CHECKING:
-    from langchain.callbacks.base import Callbacks
+    from langchain_core.callbacks.base import Callbacks
 
 
 logger = logging.getLogger(__name__)
2 changes: 1 addition & 1 deletion src/ragas/metrics/_context_relevancy.py
@@ -11,7 +11,7 @@
 from ragas.metrics.base import EvaluationMode, MetricWithLLM
 
 if t.TYPE_CHECKING:
-    from langchain.callbacks.base import Callbacks
+    from langchain_core.callbacks.base import Callbacks
 
 logger = logging.getLogger(__name__)
 
2 changes: 1 addition & 1 deletion src/ragas/metrics/critique.py
@@ -12,7 +12,7 @@
 from ragas.metrics.base import EvaluationMode, MetricWithLLM
 
 if t.TYPE_CHECKING:
-    from langchain.callbacks.base import Callbacks
+    from langchain_core.callbacks.base import Callbacks
 
     from ragas.llms import BaseRagasLLM
 
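
The three metric modules above receive the same one-line fix: Callbacks is used only in type annotations, so it sits under t.TYPE_CHECKING and must now be imported from langchain_core rather than through the legacy langchain re-export. The pattern, for reference (the function is illustrative, not ragas API):

import typing as t

if t.TYPE_CHECKING:
    # Evaluated by type checkers only, never at runtime, so the new
    # import path costs nothing and drops the legacy dependency.
    from langchain_core.callbacks.base import Callbacks

def score_row(row: t.Dict[str, t.Any], callbacks: "Callbacks") -> float:
    ...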
2 changes: 1 addition & 1 deletion src/ragas/testset/__init__.py
@@ -1,3 +1,3 @@
-from ragas.testset.testset_generator import TestsetGenerator
+from ragas.testset.generator import TestsetGenerator
 
 __all__ = ["TestsetGenerator"]
27 changes: 17 additions & 10 deletions src/ragas/testset/docstore.py
@@ -14,8 +14,8 @@
 from langchain_core.pydantic_v1 import Field
 from llama_index.readers.schema import Document as LlamaindexDocument
 
-from ragas.async_utils import run_async_tasks
 from ragas.embeddings.base import BaseRagasEmbeddings
+from ragas.executor import Executor
 
 Embedding = t.Union[t.List[float], npt.NDArray[np.float64]]
 logger = logging.getLogger(__name__)
@@ -204,22 +204,29 @@ def add_nodes(
         assert self.embeddings is not None, "Embeddings must be set"
 
         # NOTE: Adds everything in async mode for now.
-        embed_tasks = []
-        docs_to_embed = []
+        nodes_to_embed = []
         # get embeddings for the docs
-        for n in nodes:
+        executor = Executor(
+            desc="embedding nodes",
+            keep_progress_bar=False,
+            is_async=True,
+            raise_exceptions=True,
+        )
+        for i, n in enumerate(nodes):
             if n.embedding is None:
-                embed_tasks.append(self.embeddings.aembed_query(n.page_content))
-                docs_to_embed.append(n)
+                nodes_to_embed.append(n)
+                executor.submit(
+                    self.embeddings.aembed_query,
+                    n.page_content,
+                    name=f"embed_node_task[{i}]",
+                )
             else:
                 self.nodes.append(n)
                 self.node_map[n.doc_id] = n
                 self.node_embeddings_list.append(n.embedding)
 
-        embeddings = run_async_tasks(
-            embed_tasks, show_progress=show_progress, progress_bar_desc=desc
-        )
-        for n, embedding in zip(docs_to_embed, embeddings):
+        embeddings = executor.results()
+        for n, embedding in zip(nodes_to_embed, embeddings):
             n.embedding = embedding
             self.nodes.append(n)
             self.node_map[n.doc_id] = n
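
The final zip here depends on Executor.results() returning values in submission order; the (index, value) sentinel r = (-1, np.nan) in executor.py above suggests results are tagged with their index and re-sorted before being returned, which is what makes the pairing with nodes_to_embed safe. A condensed sketch of that contract (the embeddings object is assumed to be any BaseRagasEmbeddings instance in scope):

from ragas.executor import Executor

executor = Executor(
    desc="embedding nodes", keep_progress_bar=False, is_async=True, raise_exceptions=True
)
texts = ["chunk one", "chunk two", "chunk three"]
for i, text in enumerate(texts):
    # aembed_query is a coroutine function; submit() schedules it as a future
    executor.submit(embeddings.aembed_query, text, name=f"embed_node_task[{i}]")
vectors = executor.results()  # vectors[i] corresponds to texts[i]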
73 changes: 41 additions & 32 deletions src/ragas/testset/evolutions.py
@@ -3,12 +3,14 @@
 import logging
 import typing as t
 from abc import abstractmethod
-from collections import namedtuple
 from dataclasses import dataclass, field
 
 from fsspec.exceptions import asyncio
+from langchain_core.pydantic_v1 import BaseModel
 from numpy.random import default_rng
 
+from ragas.llms import BaseRagasLLM
+from ragas.llms.prompt import Prompt
 from ragas.testset.docstore import Direction, DocumentStore, Node
 from ragas.testset.filters import EvolutionFilter, NodeFilter, QuestionFilter
 from ragas.testset.prompts import (
@@ -22,32 +24,27 @@
 rng = default_rng()
 logger = logging.getLogger(__name__)
 
-if t.TYPE_CHECKING:
-    from ragas.llms import BaseRagasLLM
-    from ragas.llms.prompt import Prompt
-
 
 @dataclass
 class CurrentNodes:
     root_node: Node
     nodes: t.List[Node] = field(default_factory=list)
 
 
-DataRow = namedtuple(
-    "DataRow",
-    [
-        "question",
-        "context",
-        "answer",
-        "question_type",
-        "evolution_elimination",
-    ],
-)
+# (question, current_nodes, evolution_type)
+EvolutionOutput = t.Tuple[str, CurrentNodes, str]
+
+
+class DataRow(BaseModel):
+    question: str
+    context: str
+    answer: str
+    evolution_type: str
 
 
 @dataclass
 class Evolution:
-    generator_llm: t.Optional[BaseRagasLLM] = None
+    generator_llm: BaseRagasLLM = t.cast(BaseRagasLLM, None)
     docstore: t.Optional[DocumentStore] = None
     node_filter: t.Optional[NodeFilter] = None
     question_filter: t.Optional[QuestionFilter] = None
@@ -61,7 +58,7 @@ def merge_nodes(nodes: CurrentNodes) -> Node:
 
     async def aretry_evolve(
         self, current_tries: int, current_nodes: CurrentNodes, update_count: bool = True
-    ) -> str:
+    ) -> EvolutionOutput:
         if update_count:
             current_tries += 1
         logger.info("retrying evolution: %s times", current_tries)
@@ -112,22 +109,29 @@ def evolve(self, current_nodes: CurrentNodes) -> DataRow:
     async def aevolve(self, current_nodes: CurrentNodes) -> DataRow:
         # init tries with 0 when first called
         current_tries = 0
-        evolved_question = await self._aevolve(current_tries, current_nodes)
+        (
+            evolved_question,
+            current_nodes,
+            evolution_type,
+        ) = await self._aevolve(current_tries, current_nodes)
 
         return self.generate_datarow(
             question=evolved_question,
             current_nodes=current_nodes,
+            evolution_type=evolution_type,
         )
 
     @abstractmethod
-    async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str:
+    async def _aevolve(
+        self, current_tries: int, current_nodes: CurrentNodes
+    ) -> EvolutionOutput:
         ...
 
     def generate_datarow(
         self,
         question: str,
         current_nodes: CurrentNodes,
-        question_type: str = "",
-        evolution_elimination: bool = False,
+        evolution_type: str,
     ):
         assert self.generator_llm is not None, "generator_llm cannot be None"
 
@@ -146,15 +150,16 @@ def generate_datarow(
         return DataRow(
             question=question,
             context=merged_nodes.page_content,
-            answer=answer,
-            question_type=question_type,
-            evolution_elimination=evolution_elimination,
+            answer="" if answer is None else answer,
+            evolution_type=evolution_type,
         )
 
 
 @dataclass
 class SimpleEvolution(Evolution):
-    async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str:
+    async def _aevolve(
+        self, current_tries: int, current_nodes: CurrentNodes
+    ) -> EvolutionOutput:
         assert self.docstore is not None, "docstore cannot be None"
         assert self.node_filter is not None, "node filter cannot be None"
         assert self.generator_llm is not None, "generator_llm cannot be None"
@@ -183,7 +188,7 @@ async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str
             return await self.aretry_evolve(current_tries, current_nodes)
         else:
             # if valid question
-            return seed_question
+            return seed_question, current_nodes, "simple"
 
     def __hash__(self):
         return hash(self.__class__.__name__)
@@ -209,13 +214,15 @@ def init_evolution(self):
 
 @dataclass
 class MultiContextEvolution(ComplexEvolution):
-    async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str:
+    async def _aevolve(
+        self, current_tries: int, current_nodes: CurrentNodes
+    ) -> EvolutionOutput:
         assert self.docstore is not None, "docstore cannot be None"
         assert self.generator_llm is not None, "generator_llm cannot be None"
         assert self.question_filter is not None, "question_filter cannot be None"
         assert self.se is not None, "simple evolution cannot be None"
 
-        simple_question = await self.se._aevolve(current_tries, current_nodes)
+        simple_question, _, _ = await self.se._aevolve(current_tries, current_nodes)
         logger.debug(
             "[MultiContextEvolution] simple question generated: %s", simple_question
         )
@@ -254,20 +261,22 @@ async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str
             current_nodes = self.se._get_more_adjacent_nodes(current_nodes)
             return await self.aretry_evolve(current_tries, current_nodes)
 
-        return compressed_question
+        return compressed_question, current_nodes, "multi_context"
 
     def __hash__(self):
         return hash(self.__class__.__name__)
 
 
 @dataclass
 class ReasoningEvolution(ComplexEvolution):
-    async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str:
+    async def _aevolve(
+        self, current_tries: int, current_nodes: CurrentNodes
+    ) -> EvolutionOutput:
         assert self.generator_llm is not None, "generator_llm cannot be None"
         assert self.question_filter is not None, "question_filter cannot be None"
         assert self.se is not None, "simple evolution cannot be None"
 
-        simple_question = await self.se._aevolve(current_tries, current_nodes)
+        simple_question, _, _ = await self.se._aevolve(current_tries, current_nodes)
         logger.debug(
             "[ReasoningEvolution] simple question generated: %s", simple_question
         )
@@ -304,7 +313,7 @@ async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str
             )
             return await self.aretry_evolve(current_tries, current_nodes)
 
-        return reasoning_question
+        return reasoning_question, current_nodes, "reasoning"
 
     def __hash__(self):
         return hash(self.__class__.__name__)
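
Taken together: every _aevolve() now returns an EvolutionOutput triple rather than a bare string, so aevolve() can thread the (possibly replaced) nodes through retries and record which evolution produced each question, and DataRow as a pydantic model validates its fields at construction time. A small illustration with made-up values:

from ragas.testset.evolutions import DataRow

# Illustrative values; in the pipeline these come from an _aevolve() triple.
row = DataRow(
    question="What is X?",
    context="merged page content of the current nodes",
    answer="",  # generate_datarow() substitutes "" when the answer is None
    evolution_type="simple",
)
print(row.dict())  # pydantic v1 models serialise for free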