diff --git a/pyproject.toml b/pyproject.toml
index d73b3b813..0851232f3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,6 +6,7 @@ dependencies = [
     "tiktoken",
     "langchain",
     "langchain-core",
+    "langchain-community",
     "langchain_openai",
     "openai>1",
     "pysbd>=0.3.4",
diff --git a/src/ragas/evaluation.py b/src/ragas/evaluation.py
index 1a1787259..6dfc90403 100644
--- a/src/ragas/evaluation.py
+++ b/src/ragas/evaluation.py
@@ -154,7 +154,11 @@ def evaluate(
     [m.init_model() for m in metrics]
 
     executor = Executor(
-        is_async=is_async, max_workers=max_workers, raise_exceptions=raise_exceptions
+        desc="Evaluating",
+        keep_progress_bar=True,
+        is_async=is_async,
+        max_workers=max_workers,
+        raise_exceptions=raise_exceptions,
     )
     # new evaluation chain
     row_run_managers = []
diff --git a/src/ragas/executor.py b/src/ragas/executor.py
index 43ce7bad0..079fca4f1 100644
--- a/src/ragas/executor.py
+++ b/src/ragas/executor.py
@@ -10,6 +10,7 @@
 @dataclass
 class Executor:
     desc: str = "Evaluating"
+    keep_progress_bar: bool = True
     is_async: bool = True
     max_workers: t.Optional[int] = None
     futures: t.List[t.Any] = field(default_factory=list, repr=False)
@@ -74,6 +75,8 @@ async def _aresults(self) -> t.List[t.Any]:
             asyncio.as_completed(self.futures),
             desc=self.desc,
             total=len(self.futures),
+            # whether you want to keep the progress bar after completion
+            leave=self.keep_progress_bar,
         ):
             r = (-1, np.nan)
             try:
@@ -109,6 +112,8 @@ def results(self) -> t.List[t.Any]:
             as_completed(self.futures),
             desc=self.desc,
             total=len(self.futures),
+            # whether you want to keep the progress bar after completion
+            leave=self.keep_progress_bar,
         ):
             r = (-1, np.nan)
             try:
diff --git a/src/ragas/llms/base.py b/src/ragas/llms/base.py
index 6224a86ad..8c2f2af3f 100644
--- a/src/ragas/llms/base.py
+++ b/src/ragas/llms/base.py
@@ -4,14 +4,13 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 
-from langchain.chat_models import AzureChatOpenAI, ChatOpenAI, ChatVertexAI
-from langchain.llms import AzureOpenAI, OpenAI, VertexAI
+from langchain_community.chat_models import AzureChatOpenAI, ChatOpenAI, ChatVertexAI
+from langchain_community.llms import AzureOpenAI, OpenAI, VertexAI
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.outputs import LLMResult
 
 if t.TYPE_CHECKING:
     from langchain_core.callbacks import Callbacks
-    from langchain_core.prompts import ChatPromptTemplate
 
     from ragas.llms.prompt import PromptValue
 
@@ -62,20 +61,6 @@ async def agenerate_text(
     ) -> LLMResult:
         ...
-    # TODO: remove after testset generator is refactored
-    def generate_text_with_hmpt(
-        self,
-        prompts: t.List[ChatPromptTemplate],
-        n: int = 1,
-        temperature: float = 1e-8,
-        stop: t.Optional[t.List[str]] = None,
-        callbacks: Callbacks = [],
-    ) -> LLMResult:
-        from ragas.llms.prompt import PromptValue
-
-        prompt = PromptValue(prompt_str=prompts[0].format())
-        return self.generate_text(prompt, n, temperature, stop, callbacks)
-
 
 @dataclass
 class LangchainLLMWrapper(BaseRagasLLM):
diff --git a/src/ragas/metrics/_answer_similarity.py b/src/ragas/metrics/_answer_similarity.py
index d59e8c146..f6a3669fb 100644
--- a/src/ragas/metrics/_answer_similarity.py
+++ b/src/ragas/metrics/_answer_similarity.py
@@ -10,7 +10,7 @@
 from ragas.metrics.base import EvaluationMode, MetricWithEmbeddings, MetricWithLLM
 
 if t.TYPE_CHECKING:
-    from langchain.callbacks.base import Callbacks
+    from langchain_core.callbacks.base import Callbacks
 
 logger = logging.getLogger(__name__)
diff --git a/src/ragas/metrics/_context_relevancy.py b/src/ragas/metrics/_context_relevancy.py
index 827b0fb4b..de2ebd464 100644
--- a/src/ragas/metrics/_context_relevancy.py
+++ b/src/ragas/metrics/_context_relevancy.py
@@ -11,7 +11,7 @@
 from ragas.metrics.base import EvaluationMode, MetricWithLLM
 
 if t.TYPE_CHECKING:
-    from langchain.callbacks.base import Callbacks
+    from langchain_core.callbacks.base import Callbacks
 
 logger = logging.getLogger(__name__)
diff --git a/src/ragas/metrics/critique.py b/src/ragas/metrics/critique.py
index 8c6e63307..15c9ed522 100644
--- a/src/ragas/metrics/critique.py
+++ b/src/ragas/metrics/critique.py
@@ -12,7 +12,7 @@
 from ragas.metrics.base import EvaluationMode, MetricWithLLM
 
 if t.TYPE_CHECKING:
-    from langchain.callbacks.base import Callbacks
+    from langchain_core.callbacks.base import Callbacks
 
     from ragas.llms import BaseRagasLLM
diff --git a/src/ragas/testset/__init__.py b/src/ragas/testset/__init__.py
index b065c3f9b..61be28859 100644
--- a/src/ragas/testset/__init__.py
+++ b/src/ragas/testset/__init__.py
@@ -1,3 +1,3 @@
-from ragas.testset.testset_generator import TestsetGenerator
+from ragas.testset.generator import TestsetGenerator
 
 __all__ = ["TestsetGenerator"]
diff --git a/src/ragas/testset/docstore.py b/src/ragas/testset/docstore.py
index 44e88a6b1..30d65d27a 100644
--- a/src/ragas/testset/docstore.py
+++ b/src/ragas/testset/docstore.py
@@ -14,8 +14,8 @@
 from langchain_core.pydantic_v1 import Field
 from llama_index.readers.schema import Document as LlamaindexDocument
 
-from ragas.async_utils import run_async_tasks
 from ragas.embeddings.base import BaseRagasEmbeddings
+from ragas.executor import Executor
 
 Embedding = t.Union[t.List[float], npt.NDArray[np.float64]]
 logger = logging.getLogger(__name__)
@@ -204,22 +204,29 @@ def add_nodes(
         assert self.embeddings is not None, "Embeddings must be set"
 
         # NOTE: Adds everything in async mode for now.
-        embed_tasks = []
-        docs_to_embed = []
+        nodes_to_embed = []
 
         # get embeddings for the docs
-        for n in nodes:
+        executor = Executor(
+            desc="embedding nodes",
+            keep_progress_bar=False,
+            is_async=True,
+            raise_exceptions=True,
+        )
+        for i, n in enumerate(nodes):
             if n.embedding is None:
-                embed_tasks.append(self.embeddings.aembed_query(n.page_content))
-                docs_to_embed.append(n)
+                nodes_to_embed.append(n)
+                executor.submit(
+                    self.embeddings.aembed_query,
+                    n.page_content,
+                    name=f"embed_node_task[{i}]",
+                )
             else:
                 self.nodes.append(n)
                 self.node_map[n.doc_id] = n
                 self.node_embeddings_list.append(n.embedding)
 
-        embeddings = run_async_tasks(
-            embed_tasks, show_progress=show_progress, progress_bar_desc=desc
-        )
-        for n, embedding in zip(docs_to_embed, embeddings):
+        embeddings = executor.results()
+        for n, embedding in zip(nodes_to_embed, embeddings):
             n.embedding = embedding
             self.nodes.append(n)
             self.node_map[n.doc_id] = n
diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py
index 09e62eaa3..0f9002f89 100644
--- a/src/ragas/testset/evolutions.py
+++ b/src/ragas/testset/evolutions.py
@@ -3,12 +3,14 @@
 import logging
 import typing as t
 from abc import abstractmethod
-from collections import namedtuple
 from dataclasses import dataclass, field
 
 from fsspec.exceptions import asyncio
+from langchain_core.pydantic_v1 import BaseModel
 from numpy.random import default_rng
 
+from ragas.llms import BaseRagasLLM
+from ragas.llms.prompt import Prompt
 from ragas.testset.docstore import Direction, DocumentStore, Node
 from ragas.testset.filters import EvolutionFilter, NodeFilter, QuestionFilter
 from ragas.testset.prompts import (
@@ -22,10 +24,6 @@
 rng = default_rng()
 logger = logging.getLogger(__name__)
 
-if t.TYPE_CHECKING:
-    from ragas.llms import BaseRagasLLM
-    from ragas.llms.prompt import Prompt
-
 
 @dataclass
 class CurrentNodes:
@@ -33,21 +31,20 @@ class CurrentNodes:
     nodes: t.List[Node] = field(default_factory=list)
 
 
-DataRow = namedtuple(
-    "DataRow",
-    [
-        "question",
-        "context",
-        "answer",
-        "question_type",
-        "evolution_elimination",
-    ],
-)
+# (question, current_nodes, evolution_type)
+EvolutionOutput = t.Tuple[str, CurrentNodes, str]
+
+
+class DataRow(BaseModel):
+    question: str
+    context: str
+    answer: str
+    evolution_type: str
 
 
 @dataclass
 class Evolution:
-    generator_llm: t.Optional[BaseRagasLLM] = None
+    generator_llm: BaseRagasLLM = t.cast(BaseRagasLLM, None)
     docstore: t.Optional[DocumentStore] = None
     node_filter: t.Optional[NodeFilter] = None
     question_filter: t.Optional[QuestionFilter] = None
@@ -61,7 +58,7 @@ def merge_nodes(nodes: CurrentNodes) -> Node:
 
     async def aretry_evolve(
         self, current_tries: int, current_nodes: CurrentNodes, update_count: bool = True
-    ) -> str:
+    ) -> EvolutionOutput:
         if update_count:
             current_tries += 1
             logger.info("retrying evolution: %s times", current_tries)
@@ -112,22 +109,29 @@ def evolve(self, current_nodes: CurrentNodes) -> DataRow:
     async def aevolve(self, current_nodes: CurrentNodes) -> DataRow:
         # init tries with 0 when first called
         current_tries = 0
-        evolved_question = await self._aevolve(current_tries, current_nodes)
+        (
+            evolved_question,
+            current_nodes,
+            evolution_type,
+        ) = await self._aevolve(current_tries, current_nodes)
+
         return self.generate_datarow(
             question=evolved_question,
             current_nodes=current_nodes,
+            evolution_type=evolution_type,
         )
 
     @abstractmethod
-    async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str:
+    async def _aevolve(
+        self, current_tries: int, current_nodes: CurrentNodes
+    ) -> EvolutionOutput:
         ...
 
     def generate_datarow(
         self,
         question: str,
         current_nodes: CurrentNodes,
-        question_type: str = "",
-        evolution_elimination: bool = False,
+        evolution_type: str,
     ):
         assert self.generator_llm is not None, "generator_llm cannot be None"
@@ -146,15 +150,16 @@ def generate_datarow(
         return DataRow(
             question=question,
             context=merged_nodes.page_content,
-            answer=answer,
-            question_type=question_type,
-            evolution_elimination=evolution_elimination,
+            answer="" if answer is None else answer,
+            evolution_type=evolution_type,
         )
 
 
 @dataclass
 class SimpleEvolution(Evolution):
-    async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str:
+    async def _aevolve(
+        self, current_tries: int, current_nodes: CurrentNodes
+    ) -> EvolutionOutput:
         assert self.docstore is not None, "docstore cannot be None"
         assert self.node_filter is not None, "node filter cannot be None"
         assert self.generator_llm is not None, "generator_llm cannot be None"
@@ -183,7 +188,7 @@ async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str
             return await self.aretry_evolve(current_tries, current_nodes)
         else:
             # if valid question
-            return seed_question
+            return seed_question, current_nodes, "simple"
 
     def __hash__(self):
         return hash(self.__class__.__name__)
@@ -209,13 +214,15 @@ def init_evolution(self):
 
 @dataclass
 class MultiContextEvolution(ComplexEvolution):
-    async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str:
+    async def _aevolve(
+        self, current_tries: int, current_nodes: CurrentNodes
+    ) -> EvolutionOutput:
         assert self.docstore is not None, "docstore cannot be None"
         assert self.generator_llm is not None, "generator_llm cannot be None"
         assert self.question_filter is not None, "question_filter cannot be None"
         assert self.se is not None, "simple evolution cannot be None"
 
-        simple_question = await self.se._aevolve(current_tries, current_nodes)
+        simple_question, _, _ = await self.se._aevolve(current_tries, current_nodes)
         logger.debug(
             "[MultiContextEvolution] simple question generated: %s", simple_question
         )
@@ -254,7 +261,7 @@ async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str
             current_nodes = self.se._get_more_adjacent_nodes(current_nodes)
             return await self.aretry_evolve(current_tries, current_nodes)
 
-        return compressed_question
+        return compressed_question, current_nodes, "multi_context"
 
     def __hash__(self):
         return hash(self.__class__.__name__)
@@ -262,12 +269,14 @@ def __hash__(self):
 
 @dataclass
 class ReasoningEvolution(ComplexEvolution):
-    async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str:
+    async def _aevolve(
+        self, current_tries: int, current_nodes: CurrentNodes
+    ) -> EvolutionOutput:
         assert self.generator_llm is not None, "generator_llm cannot be None"
         assert self.question_filter is not None, "question_filter cannot be None"
         assert self.se is not None, "simple evolution cannot be None"
 
-        simple_question = await self.se._aevolve(current_tries, current_nodes)
+        simple_question, _, _ = await self.se._aevolve(current_tries, current_nodes)
         logger.debug(
             "[ReasoningEvolution] simple question generated: %s", simple_question
         )
@@ -304,7 +313,7 @@ async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str
             )
             return await self.aretry_evolve(current_tries, current_nodes)
 
-        return reasoning_question
+        return reasoning_question, current_nodes, "reasoning"
 
     def __hash__(self):
         return hash(self.__class__.__name__)
diff --git a/src/ragas/testset/generator.py b/src/ragas/testset/generator.py
index 29714a9a9..d1dd6bb3a 100644
--- a/src/ragas/testset/generator.py
+++ b/src/ragas/testset/generator.py
@@ -5,8 +5,8 @@
 from dataclasses import dataclass
 
 import pandas as pd
-from langchain.embeddings import OpenAIEmbeddings
 from langchain_openai.chat_models import ChatOpenAI
+from langchain_openai.embeddings import OpenAIEmbeddings
 from llama_index.readers.schema import Document as LlamaindexDocument
 
 from ragas.embeddings import BaseRagasEmbeddings
@@ -31,16 +31,9 @@ class TestDataset:
     def to_pandas(self) -> pd.DataFrame:
         data_samples = []
         for data in self.test_data:
-            question_type = data.question_type
-            data = {
-                "question": data.question,
-                "context": data.context,
-                "answer": "" if data.answer is None else data.answer,
-                "question_type": question_type,
-                "episode_done": True,
-                "evolution_elimination": data.evolution_elimination,
-            }
-            data_samples.append(data)
+            data_dict = dict(data)
+            data_dict["episode_done"] = True
+            data_samples.append(data_dict)
 
         return pd.DataFrame.from_records(data_samples)
 
@@ -90,14 +83,18 @@ def generate_with_llamaindex_docs(
         documents: t.Sequence[LlamaindexDocument],
         test_size: int,
         distributions: Distributions = {},
-        **kwargs,
+        show_debug_logs=False,
     ):
         # chunk documents and add to docstore
         self.docstore.add_documents(
             [Document.from_llamaindex_document(doc) for doc in documents]
         )
 
-        return self.generate(test_size=test_size, distributions=distributions)
+        return self.generate(
+            test_size=test_size,
+            distributions=distributions,
+            show_debug_logs=show_debug_logs,
+        )
 
     def generate(
         self, test_size: int, distributions: Distributions = {}, show_debug_logs=False
@@ -123,7 +120,12 @@ def generate(
 
             patch_logger("ragas.testset.evolutions", logging.DEBUG)
 
-        exec = Executor(desc="Generating", raise_exceptions=True, is_async=True)
+        exec = Executor(
+            desc="Generating",
+            keep_progress_bar=True,
+            raise_exceptions=True,
+            is_async=True,
+        )
 
         current_nodes = [
             CurrentNodes(root_node=n, nodes=[n])
diff --git a/src/ragas/testset/prompts.py b/src/ragas/testset/prompts.py
index 16269f362..d2addc6bb 100644
--- a/src/ragas/testset/prompts.py
+++ b/src/ragas/testset/prompts.py
@@ -1,4 +1,4 @@
-from langchain.prompts import HumanMessagePromptTemplate
+from langchain_core.prompts import HumanMessagePromptTemplate
 
 from ragas.llms.prompt import Prompt
diff --git a/src/ragas/testset/testset_generator.py b/src/ragas/testset/testset_generator.py
deleted file mode 100644
index 0d9481b83..000000000
--- a/src/ragas/testset/testset_generator.py
+++ /dev/null
@@ -1,547 +0,0 @@
-from __future__ import annotations
-
-import logging
-import typing as t
-from collections import defaultdict, namedtuple
-from dataclasses import dataclass
-
-import numpy as np
-import numpy.testing as npt
-import pandas as pd
-from langchain.chat_models import ChatOpenAI
-from langchain.embeddings import OpenAIEmbeddings
-from langchain.embeddings.base import Embeddings
-from langchain.prompts import ChatPromptTemplate
-from langchain.schema.document import Document as LangchainDocument
-from numpy.random import default_rng
-from tqdm.notebook import tqdm
-
-try:
-    from llama_index.indices.query.embedding_utils import get_top_k_embeddings
-    from llama_index.node_parser import SimpleNodeParser
-    from llama_index.readers.schema import Document as LlamaindexDocument
-    from llama_index.schema import BaseNode
-except ImportError:
-    raise ImportError(
-        "llama_index must be installed to use this function. "
-        "Please, install it with `pip install llama_index`."
-    )
-
-from ragas.llms.base import BaseRagasLLM, LangchainLLMWrapper
-from ragas.llms.json_load import load_as_json
-from ragas.testset.prompts import (
-    ANSWER_FORMULATE,
-    COMPRESS_QUESTION,
-    CONDITIONAL_QUESTION,
-    CONTEXT_FORMULATE,
-    CONVERSATION_QUESTION,
-    EVOLUTION_ELIMINATION,
-    FILTER_QUESTION,
-    INFORMAL_QUESTION,
-    MULTICONTEXT_QUESTION,
-    REASONING_QUESTION,
-    REWRITE_QUESTION,
-    SCORE_CONTEXT,
-    SEED_QUESTION,
-    TABLE_QA,
-)
-
-DEFAULT_TEST_DISTRIBUTION = {
-    "simple": 0.4,
-    "reasoning": 0.3,
-    "multi_context": 0.0,
-    "conditional": 0.3,
-}
-
-question_deep_map = {
-    "reasoning": "_reasoning_question",
-    "conditional": "_condition_question",
-}
-
-DataRow = namedtuple(
-    "DataRow",
-    [
-        "seed_question",
-        "question",
-        "context",
-        "answer",
-        "question_type",
-        "evolution_elimination",
-    ],
-)
-
-logger = logging.getLogger(__name__)
-
-
-@dataclass
-class TestDataset:
-    """
-    TestDataset class
-    """
-
-    test_data: t.List[DataRow]
-
-    def to_pandas(self) -> pd.DataFrame:
-        data_samples = []
-        for data in self.test_data:
-            is_conv = len(data.context) > 1
-            question_type = data.question_type
-            data = [
-                {
-                    "seed_question": seed,
-                    "question": qstn,
-                    "context": ctx,
-                    "answer": ans,
-                    "question_type": question_type,
-                    "episode_done": True,
-                    "evolution_elimination": data.evolution_elimination,
-                }
-                for seed, qstn, ctx, ans in zip(
-                    data.seed_question, data.question, data.context, data.answer
-                )
-            ]
-            if is_conv:
-                data[0].update({"episode_done": False})
-            data_samples.extend(data)
-
-        return pd.DataFrame.from_records(data_samples)
-
-
-class TestsetGenerator:
-
-    """
-    Ragas Test Set Generator
-
-    Attributes
-    ----------
-    generator_llm: BaseRagasLLM
-        LLM used for all the generator operations in the TestGeneration paradigm.
-    critique_llm: BaseRagasLLM
-        LLM used for all the filtering and scoring operations in TestGeneration
-        paradigm.
-    embeddings_model: Embeddings
-        Embeddings used for vectorizing nodes when required.
-    chat_qa: float
-        Determines the fraction of conversational questions the resulting test set.
-    chunk_size: int
-        The chunk size of nodes created from data.
-    test_distribution : dict
-        Distribution of different types of questions to be generated from given
-        set of documents. Defaults to {"easy":0.1, "reasoning":0.4, "conversation":0.5}
-    """
-
-    def __init__(
-        self,
-        generator_llm: BaseRagasLLM,
-        critic_llm: BaseRagasLLM,
-        embeddings_model: Embeddings,
-        testset_distribution: t.Optional[t.Dict[str, float]] = None,
-        chat_qa: float = 0.0,
-        chunk_size: int = 356,
-        seed: int = 42,
-    ) -> None:
-        self.generator_llm = generator_llm
-        self.critic_llm = critic_llm
-        self.embedding_model = embeddings_model
-        testset_distribution = testset_distribution or DEFAULT_TEST_DISTRIBUTION
-        npt.assert_almost_equal(
-            1,
-            sum(testset_distribution.values()),
-            err_msg="Sum of distribution should be 1",
-        )
-
-        probs = np.cumsum(list(testset_distribution.values()))
-        types = testset_distribution.keys()
-        self.testset_distribution = dict(zip(types, probs))
-
-        self.chat_qa = chat_qa
-        self.chunk_size = chunk_size
-        self.threshold = 7.5
-        self.max_fixes = 2
-        self.rng = default_rng(seed)
-
-    @classmethod
-    def from_default(
-        cls,
-        openai_generator_llm: str = "gpt-3.5-turbo",
-        openai_filter_llm: str = "gpt-4",
-        chat_qa: float = 0.3,
-        chunk_size: int = 512,
-        testset_distribution: dict = DEFAULT_TEST_DISTRIBUTION,
-    ):
-        generator_llm = LangchainLLMWrapper(
-            langchain_llm=ChatOpenAI(model=openai_generator_llm)  # type: ignore
-        )
-        critic_llm = LangchainLLMWrapper(
-            langchain_llm=ChatOpenAI(model=openai_filter_llm)  # type: ignore
-        )
-        embeddings_model = OpenAIEmbeddings()  # type: ignore
-        return cls(
-            generator_llm=generator_llm,
-            critic_llm=critic_llm,
-            embeddings_model=embeddings_model,
-            chat_qa=chat_qa,
-            chunk_size=chunk_size,
-            testset_distribution=testset_distribution,
-        )
-
-    def _get_evolve_type(self) -> str:
-        """
-        Decides question evolution type based on probability
-        """
-        prob = self.rng.uniform(0, 1)
-        return next(
-            (
-                key
-                for key in self.testset_distribution.keys()
-                if prob <= self.testset_distribution[key]
-            ),
-            "simple",
-        )
-
-    def _filter_context(self, context: str) -> t.Dict:
-        """
-        context: str
-            The input context
-
-        Checks if the context is has information worthy of framing a question
-        """
-        human_prompt = SCORE_CONTEXT.format(context=context)
-        prompt = ChatPromptTemplate.from_messages([human_prompt])
-        results = self.critic_llm.generate_text_with_hmpt(prompts=[prompt])
-        output = results.generations[0][0].text.strip()
-        output = load_as_json(output)
-        output.update({"score": output.get("score", 0) >= self.threshold})
-        return output
-
-    def _seed_question(self, context: str, is_table_present: bool) -> t.List[str]:
-        if is_table_present:
-            human_prompt = TABLE_QA.format(context=context)
-        else:
-            from ragas.testset.prompts import demonstrations
-
-            sample = self.rng.choice(demonstrations, 1)[0]  # type: ignore
-            questions = self.rng.choice(sample["questions"], 2, replace=False)
-            questions = (
-                "{"
-                + str(
-                    {k: v for dic in questions.tolist() for k, v in dic.items()}
-                ).replace("'", '"')
-                + "}"
-            )
-            demo = f'Context:{sample["context"]}\nQuestions:{questions}'
-            human_prompt = SEED_QUESTION.format(demonstration=demo, context=context)
-
-        prompt = ChatPromptTemplate.from_messages([human_prompt])
-        results = self.generator_llm.generate_text_with_hmpt(prompts=[prompt])
-        results = results.generations[0][0].text
-        if is_table_present:
-            return [results]
-        else:
-            results = load_as_json(results)
-            return [v for v in results.values()]
-
-    def _filter_question(self, question: str) -> bool:
-        human_prompt = FILTER_QUESTION.format(question=question)
-        prompt = ChatPromptTemplate.from_messages([human_prompt])
-
-        results = self.critic_llm.generate_text_with_hmpt(prompts=[prompt])
-        results = results.generations[0][0].text.strip()
-        json_results = load_as_json(results)
-        logger.debug("%s", json_results)
-        return json_results.get("verdict") != "No"
-
-    def _rewrite_question(self, question: str, context: str) -> str:
-        human_prompt = REWRITE_QUESTION.format(question=question, context=context)
-        prompt = ChatPromptTemplate.from_messages([human_prompt])
-
-        results = self.generator_llm.generate_text_with_hmpt(prompts=[prompt])
-        results = results.generations[0][0].text.strip()
-        return results
-
-    def _evolution_elimination(self, question1: str, question2: str) -> bool:
-        human_prompt = EVOLUTION_ELIMINATION.format(
-            question1=question1, question2=question2
-        )
-        prompt = ChatPromptTemplate.from_messages([human_prompt])
-
-        results = self.critic_llm.generate_text_with_hmpt(prompts=[prompt])
-        results = results.generations[0][0].text.strip()
-        json_results = load_as_json(results)
-        return json_results.get("verdict") != "Not Equal"
-
-    def _reasoning_question(self, question: str, context: str) -> str:
-        return self._qc_template(REASONING_QUESTION, question, context)
-
-    def _condition_question(self, question: str, context: str) -> str:
-        return self._qc_template(CONDITIONAL_QUESTION, question, context)
-
-    def _multicontext_question(
-        self, question: str, context1: str, context2: str
-    ) -> str:
-        human_prompt = MULTICONTEXT_QUESTION.format(
-            question=question, context1=context1, context2=context2
-        )
-        prompt = ChatPromptTemplate.from_messages([human_prompt])
-        results = self.generator_llm.generate_text_with_hmpt(prompts=[prompt])
-        return results.generations[0][0].text.strip()
-
-    def _compress_question(self, question: str) -> str:
-        return self._question_transformation(COMPRESS_QUESTION, question=question)
-
-    def _conversational_question(self, question: str) -> str:
-        return self._question_transformation(CONVERSATION_QUESTION, question=question)
-
-    def _question_transformation(self, prompt, question: str) -> str:
-        human_prompt = prompt.format(question=question)
-        prompt = ChatPromptTemplate.from_messages([human_prompt])
-        results = self.generator_llm.generate_text_with_hmpt(prompts=[prompt])
-        return results.generations[0][0].text.strip()
-
-    def _qc_template(self, prompt, question, context) -> str:
-        human_prompt = prompt.format(question=question, context=context)
-        prompt = ChatPromptTemplate.from_messages([human_prompt])
-        results = self.generator_llm.generate_text_with_hmpt(prompts=[prompt])
-        return results.generations[0][0].text.strip()
-
-    def _generate_answer(self, question: str, context: list[str]) -> t.List[str]:
-        return [
-            self._qc_template(ANSWER_FORMULATE, qstn, context[i])
-            for i, qstn in enumerate(question.split("\n"))
-        ]
-
-    def _generate_context(self, question: str, text_chunk: str) -> t.List[str]:
-        return [
-            self._qc_template(CONTEXT_FORMULATE, qstn, text_chunk)
-            for qstn in question.split("\n")
-        ]
-
-    def _question_transform(self, question: str) -> str:
-        output = []
-        for qstn in question.split("\n"):
-            human_prompt = INFORMAL_QUESTION.format(question=qstn)
-            prompt = ChatPromptTemplate.from_messages([human_prompt])
-            results = self.generator_llm.generate_text_with_hmpt(prompts=[prompt])
-            output.append(results.generations[0][0].text.strip())
-
-        return "\n".join(output)
-
-    def _remove_nodes(
-        self, available_indices: list[BaseNode], node_idx: list
-    ) -> t.List[BaseNode]:
-        for idx in node_idx:
-            if idx in available_indices:
-                available_indices.remove(idx)
-        return available_indices
-
-    def _generate_doc_nodes_map(
-        self, document_nodes: t.List[BaseNode]
-    ) -> t.Dict[str, t.List[BaseNode]]:
-        doc_nodes_map: t.Dict[str, t.List[BaseNode]] = defaultdict(list[BaseNode])
-        for node in document_nodes:
-            file_name = node.metadata.get("file_name")
-            if file_name:
-                doc_nodes_map[file_name].append(node)
-
-        return doc_nodes_map  # type: ignore
-
-    def _get_neighbour_node(
-        self,
-        node: BaseNode,
-        related_nodes: list[BaseNode],
-        max_tokens=1000,
-        after: bool = True,
-    ) -> t.List[BaseNode]:
-        if len(related_nodes) < 2:
-            logger.warn("No neighbors exists")
-            return [node]
-        idx = related_nodes.index(node) if node in related_nodes else []
-        if idx:
-            tokens = 0
-            nodes = []
-            inc = 1 if after else -1
-            while tokens < max_tokens and idx >= 0 and idx < len(related_nodes):  # type: ignore
-                nodes.append(related_nodes[idx])  # type: ignore
-                idx += inc  # type: ignore
-                # TODO: replace split with tikitoken
-                tokens += len(related_nodes[idx].get_content().split())  # type: ignore
-
-            return nodes if after else nodes[::-1]
-        return [node]
-
-    def _embed_nodes(self, nodes: t.List[BaseNode]) -> t.Dict[str, t.List[float]]:
-        embeddings = {}
-        for node in nodes:
-            embeddings[node.id_] = list(
-                self.embedding_model.embed_query(node.get_content())
-            )
-
-        return embeddings
-
-    def generate(
-        self,
-        documents: list[LlamaindexDocument] | list[LangchainDocument],
-        test_size: int,
-    ) -> TestDataset:
-        if not isinstance(documents[0], (LlamaindexDocument, LangchainDocument)):
-            raise ValueError(
-                "Testset Generatation only supports LlamaindexDocuments or LangchainDocuments"  # noqa
-            )
-
-        if isinstance(documents[0], LangchainDocument):
-            # cast to LangchainDocument since its the only case here
-            documents = t.cast(t.List[LangchainDocument], documents)
-            documents = [
-                LlamaindexDocument.from_langchain_format(doc) for doc in documents
-            ]
-        # Convert documents into nodes
-        # TODO: modify this to
-        # each node should contain docs of preffered chunk size
-        # append document to provide enough context
-        # Use metadata for this.
-        node_parser = SimpleNodeParser.from_defaults(
-            chunk_size=self.chunk_size, chunk_overlap=0, include_metadata=True
-        )
-        documents = t.cast(t.List[LlamaindexDocument], documents)
-        document_nodes: t.List[BaseNode] = node_parser.get_nodes_from_documents(
-            documents=documents
-        )
-        # maximum 1 seed question per node
-        if test_size > len(document_nodes):
-            raise ValueError(
-                """Maximum possible number of samples exceeded, reduce test_size or add more documents"""
-            )
-
-        available_nodes = document_nodes
-        doc_nodes_map = self._generate_doc_nodes_map(document_nodes)
-        count = 0
-        samples = []
-
-        pbar = tqdm(total=test_size)
-        while count < test_size and available_nodes != []:
-            evolve_type = self._get_evolve_type()
-            curr_node = self.rng.choice(np.array(available_nodes), size=1)[0]
-            available_nodes = self._remove_nodes(available_nodes, [curr_node])
-
-            neighbor_nodes = doc_nodes_map[curr_node.metadata["file_name"]]
-
-            # Append multiple nodes randomly to remove chunking bias
-            if len(curr_node.get_content().split()) < self.chunk_size:
-                size = self.chunk_size - len(curr_node.get_content().split())
-                nodes = self._get_neighbour_node(
-                    curr_node, neighbor_nodes, max_tokens=size, after=False
-                )
-            else:
-                nodes = [curr_node]
-
-            text_chunk = "\n".join([node.get_content() for node in nodes])
-            logger.debug(
-                "Len of text chunks %s %s", len(nodes), len(text_chunk.split())
-            )
-            context_filter = self._filter_context(text_chunk)
-            if not context_filter.get("score"):
-                continue
-
-            # is_table_qa = context_filter.get("is_table_present", False)
-            is_table_qa = False
-            seed_questions = self._seed_question(text_chunk, is_table_qa)
-            evolve_type = (
-                "simple"
-                if ((evolve_type == "multi_context") and (is_table_qa))
-                else evolve_type
-            )
-            logger.debug("seed question %s", seed_questions)
-            for seed_question in seed_questions:
-                is_valid_question = self._filter_question(seed_question)
-                tries = 1
-
-                while tries < self.max_fixes and not is_valid_question:
-                    nodes = self._get_neighbour_node(
-                        nodes[0], neighbor_nodes, max_tokens=500, after=False
-                    )
-                    text_chunk = "\n".join([node.get_content() for node in nodes])
-                    seed_question = self._rewrite_question(
-                        question=seed_question, context=text_chunk
-                    )
-                    logger.debug("rewritten question %s", seed_question)
-                    is_valid_question = self._filter_question(seed_question)
-                    tries += 1
-
-                if not is_valid_question:
-                    continue
-
-                if evolve_type == "multi_context":
-                    # Find most similar chunk in same document
-                    # TODO: handle cases where neighbour nodes is null, ie multi context across documents
-                    # First preference - nodes from same document, second preference - other docs
-                    node_embedding = self._embed_nodes([nodes[-1]])
-                    neighbor_nodes = self._remove_nodes(neighbor_nodes, nodes)
-                    neighbor_emb = self._embed_nodes(neighbor_nodes)
-
-                    _, indices = get_top_k_embeddings(
-                        list(node_embedding.values())[0],
-                        list(neighbor_emb.values()),
-                        similarity_cutoff=self.threshold / 10,
-                    )
-                    if indices:
-                        # type cast indices from list[Any] to list[int]
-                        indices = t.cast(t.List[int], indices)
-                        best_neighbor = neighbor_nodes[indices[0]]
-                        question = self._multicontext_question(
-                            question=seed_question,
-                            context1=text_chunk,
-                            context2=best_neighbor.get_content(),
-                        )
-                        text_chunk = "\n".join(
-                            [text_chunk, best_neighbor.get_content()]
-                        )
-                    else:
-                        continue
-
-                # for reasoning and conditional modes, evolve question with the
-                # functions from question_deep_map
-                else:
-                    evolve_fun = question_deep_map.get(evolve_type)
-                    question = (
-                        getattr(self, evolve_fun)(seed_question, text_chunk)
-                        if evolve_fun
-                        else seed_question
-                    )
-
-                # compress question or convert into conversational questions
-                if evolve_type != "simple":
-                    prob = self.rng.uniform(0, 1)
-                    if self.chat_qa and prob <= self.chat_qa:
-                        question = self._conversational_question(question=question)
-                    else:
-                        question = self._compress_question(question=question)
-
-                is_valid_question = (
-                    self._filter_question(question) if evolve_type != "simple" else True
-                )
-                if evolve_type != "simple":
-                    evolution_elimination = self._evolution_elimination(
-                        question1=seed_question, question2=question
-                    )
-                else:
-                    evolution_elimination = None
-
-                if is_valid_question:
-                    # question = self._question_transform(question)
-                    context = self._generate_context(question, text_chunk)
-                    answer = self._generate_answer(question, context)
-                    samples.append(
-                        DataRow(
-                            [seed_question],
-                            question.split("\n"),
-                            context,
-                            answer,
-                            evolve_type,
-                            evolution_elimination,
-                        )
-                    )
-                    count += 1
-                    pbar.update(count)
-
-        return TestDataset(test_data=samples)
diff --git a/tests/unit/llms/test_llm.py b/tests/unit/llms/test_llm.py
index 09b6b0f03..c83ce48f9 100644
--- a/tests/unit/llms/test_llm.py
+++ b/tests/unit/llms/test_llm.py
@@ -2,7 +2,7 @@
 
 import typing as t
 
-from langchain.schema import Generation, LLMResult
+from langchain_core.outputs import Generation, LLMResult
 
 from ragas.llms.base import BaseRagasLLM
diff --git a/tests/unit/test_simple.py b/tests/unit/test_simple.py
index 43b27eef6..4089f0a23 100644
--- a/tests/unit/test_simple.py
+++ b/tests/unit/test_simple.py
@@ -5,7 +5,7 @@ def test_import():
     import ragas
-    from ragas.testset.testset_generator import TestsetGenerator
+    from ragas.testset.generator import TestsetGenerator
 
     assert TestsetGenerator is not None
     assert ragas is not None
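
Usage sketch (not part of the patch): the snippet below illustrates the Executor surface that this diff touches, using only what is visible above (the desc, keep_progress_bar, is_async, and raise_exceptions options, submit(fn, *args, name=...), and results(); keep_progress_bar is forwarded to tqdm's leave flag). The double coroutine is a hypothetical stand-in for a real async job such as embeddings.aembed_query; everything named here that does not appear in the diff is an assumption.

    import asyncio

    from ragas.executor import Executor


    async def double(x: int) -> int:
        # hypothetical async job; a real caller would submit an LLM or embedding call
        await asyncio.sleep(0.1)
        return 2 * x


    executor = Executor(
        desc="doubling",
        keep_progress_bar=False,  # forwarded to tqdm as `leave`; bar is cleared on completion
        is_async=True,
        raise_exceptions=True,
    )
    for i in range(5):
        executor.submit(double, i, name=f"double[{i}]")

    # results() drives the submitted coroutines and returns their outputs;
    # the docstore change above zips them back against the submitted nodes,
    # so it relies on results lining up with submission order.
    print(executor.results())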