diff --git a/src/ragas/testset/docstore.py b/src/ragas/testset/docstore.py
index ee22f6011..c35d2bbea 100644
--- a/src/ragas/testset/docstore.py
+++ b/src/ragas/testset/docstore.py
@@ -31,46 +31,42 @@ class Document(LCDocument):
     doc_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
-    filename: t.Optional[str] = None
     embedding: t.Optional[t.List[float]] = Field(default=None, repr=False)

-    @classmethod
-    def from_langchain_document(cls, doc: LCDocument):
-        doc_id = str(uuid.uuid4())
-        if doc.metadata.get("filename"):
-            filename = doc.metadata["filename"]
+    @property
+    def filename(self):
+        filename = self.metadata.get("filename")
+        if filename is not None:
+            filename = self.metadata["filename"]
         else:
             logger.info(
-                "Document [ID: %s] has no filename. Using doc_id as filename.", doc_id
+                "Document [ID: %s] has no filename, using `doc_id` instead", self.doc_id
             )
-            filename = doc_id
+            filename = self.doc_id
+
+        return filename
+
+    @classmethod
+    def from_langchain_document(cls, doc: LCDocument):
+        doc_id = str(uuid.uuid4())
         return cls(
             page_content=doc.page_content,
             metadata=doc.metadata,
             doc_id=doc_id,
-            filename=filename,
         )

     @classmethod
     def from_llamaindex_document(cls, doc: LlamaindexDocument):
         doc_id = str(uuid.uuid4())
-        if doc.metadata.get("filename"):
-            filename = doc.metadata["filename"]
-        else:
-            logger.info(
-                "Document [ID: %s] has no filename. Using doc_id as filename.", doc_id
-            )
-            filename = doc_id
         return cls(
             page_content=doc.text,
             metadata=doc.metadata,
             doc_id=doc_id,
-            filename=filename,
         )

-
-class Node(Document):
-    keyphrases: t.List[str] = Field(default_factory=list, repr=False)
+    def __eq__(self, other) -> bool:
+        # if the doc_ids are the same, the Document objects are the same
+        return self.doc_id == other.doc_id


 class Direction(str, Enum):
@@ -84,6 +80,21 @@ class Direction(str, Enum):
     DOWN = "down"


+class Node(Document):
+    keyphrases: t.List[str] = Field(default_factory=list, repr=False)
+    relationships: t.Dict[Direction, t.Any] = Field(default_factory=dict, repr=False)
+    doc_similarity: t.Optional[float] = Field(default=None, repr=False)
+    wins: int = 0
+
+    @property
+    def next(self):
+        return self.relationships.get(Direction.NEXT)
+
+    @property
+    def prev(self):
+        return self.relationships.get(Direction.PREV)
+
+
 class DocumentStore(ABC):
     def __init__(self):
         self.documents = {}
@@ -110,12 +121,6 @@ def get_similar(
     ) -> t.Union[t.List[Document], t.List[Node]]:
         ...

-    @abstractmethod
-    def get_adjacent(
-        self, node: Node, direction: Direction = Direction.NEXT
-    ) -> t.Optional[Node]:
-        ...
-
     def set_run_config(self, run_config: RunConfig):
         ...
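The `Direction`-keyed `relationships` dict replaces the old `get_adjacent` lookup. A minimal sketch of the new traversal API (not part of the diff; the manual wiring below is what `set_node_relationships()`, added further down, does for you after `add_nodes`):

    from ragas.testset.docstore import Direction, Node

    # two chunks of the same file, wired the way set_node_relationships() wires them
    a1 = Node(doc_id="a1", page_content="chunk 1", metadata={"filename": "a"})
    a2 = Node(doc_id="a2", page_content="chunk 2", metadata={"filename": "a"})
    a1.relationships[Direction.NEXT] = a2
    a2.relationships[Direction.PREV] = a1

    assert a1.next is a2       # resolved via relationships.get(Direction.NEXT)
    assert a2.prev is a1
    assert a1.prev is None     # absent keys fall back to None
    assert a1.filename == "a"  # filename is now read from metadata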
@@ -206,7 +211,6 @@ def add_documents(self, docs: t.Sequence[Document], show_progress=True):
             Node.from_langchain_document(d)
             for d in self.splitter.transform_documents(docs)
         ]
-
         self.add_nodes(nodes, show_progress=show_progress)

     def add_nodes(
@@ -264,14 +268,70 @@ def add_nodes(
             ), "Embedding must be list or np.ndarray"
             self.node_embeddings_list.append(n.embedding)

+        self.calculate_nodes_docs_similarity()
+        self.set_node_relationships()
+
+    def set_node_relationships(self):
+        for i, node in enumerate(self.nodes):
+            if i > 0:
+                prev_node = self.nodes[i - 1]
+                if prev_node.filename == node.filename:
+                    node.relationships[Direction.PREV] = prev_node
+                    prev_node.relationships[Direction.NEXT] = node
+                else:
+                    node.relationships[Direction.PREV] = None
+                    prev_node.relationships[Direction.NEXT] = None
+            if i == len(self.nodes) - 1:
+                node.relationships[Direction.NEXT] = None
+
+    def calculate_nodes_docs_similarity(self):
+        doc_embeddings = {}
+        filename_ids = set(
+            [node.filename for node in self.nodes if node.filename is not None]
+        )
+        node_ids = set([node.doc_id for node in self.nodes])
+
+        if len(filename_ids) == len(node_ids):
+            logger.warning("Filename and doc_id are the same for all nodes.")
+            for node in self.nodes:
+                node.doc_similarity = 1.0
+
+        else:
+            for file_id in filename_ids:
+                nodes_embedding = np.array(
+                    [node.embedding for node in self.nodes if node.filename == file_id]
+                )
+                nodes_embedding = nodes_embedding.reshape(len(nodes_embedding), -1)
+                doc_embeddings[file_id] = np.mean(nodes_embedding, axis=0)
+
+            for node in self.nodes:
+                assert node.embedding is not None, "Embedding cannot be None"
+                node.doc_similarity = similarity(
+                    node.embedding, doc_embeddings[node.filename]
+                )
+
     def get_node(self, node_id: str) -> Node:
         return self.node_map[node_id]

     def get_document(self, doc_id: str) -> Node:
         raise NotImplementedError

-    def get_random_nodes(self, k=1) -> t.List[Node]:
-        return rng.choice(np.array(self.nodes), size=k).tolist()
+    def get_random_nodes(self, k=1, alpha=0.1) -> t.List[Node]:
+        def adjustment_factor(wins, alpha):
+            return np.exp(-alpha * wins)
+
+        scores = [adjustment_factor(node.wins, alpha) for node in self.nodes]
+        similarity_scores = [node.doc_similarity for node in self.nodes]
+        prob = np.array(scores) * np.array(similarity_scores)
+        prob = prob / np.sum(prob)
+
+        nodes = rng.choice(np.array(self.nodes), size=k, p=prob).tolist()
+
+        for node in nodes:
+            idx = self.nodes.index(node)
+            self.nodes[idx].wins += 1
+
+        return nodes

     def get_similar(
         self, node: Node, threshold: float = 0.7, top_k: int = 3
@@ -293,31 +353,6 @@ def get_similar(
         items = [self.nodes[doc_id] for doc_id in doc_ids]
         return items

-    def get_adjacent(
-        self, node: Node, direction: Direction = Direction.NEXT
-    ) -> t.Optional[Node]:
-        # linear search for doc_id of doc in documents_list
-        index = self.nodes.index(node)
-
-        if direction == Direction.NEXT:
-            if len(self.nodes) > index + 1:
-                next_doc = self.nodes[index + 1]
-                if next_doc.filename == node.filename:
-                    return next_doc
-                else:
-                    return None
-            else:
-                return None
-        if direction == Direction.PREV:
-            if index > 0:
-                prev_doc = self.nodes[index - 1]
-                if prev_doc.filename == node.filename:
-                    return prev_doc
-                else:
-                    return None
-            else:
-                return None
-
     def set_run_config(self, run_config: RunConfig):
         if self.embeddings:
             self.embeddings.set_run_config(run_config)
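In `get_random_nodes`, each node's selection probability is proportional to `exp(-alpha * wins) * doc_similarity`, so nodes that keep winning the draw are exponentially down-weighted. A small self-contained sketch of that weighting, with made-up numbers:

    import numpy as np

    alpha = 0.1
    wins = np.array([0, 3, 10])                  # how often each node was already drawn
    doc_similarity = np.array([0.9, 0.8, 0.95])  # node-to-document similarity

    prob = np.exp(-alpha * wins) * doc_similarity
    prob = prob / prob.sum()                     # normalize to a distribution
    print(prob.round(3))                         # roughly [0.489, 0.322, 0.190]

Because `wins` is incremented on every draw, repeated sampling drifts away from already-used nodes and spreads question generation across the corpus.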
diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py
index e312621ab..9bc64ca14 100644
--- a/src/ragas/testset/evolutions.py
+++ b/src/ragas/testset/evolutions.py
@@ -13,7 +13,7 @@
 from ragas.llms.json_load import json_loader
 from ragas.llms.prompt import Prompt
 from ragas.run_config import RunConfig
-from ragas.testset.docstore import Direction, DocumentStore, Node
+from ragas.testset.docstore import DocumentStore, Node
 from ragas.testset.filters import EvolutionFilter, NodeFilter, QuestionFilter
 from ragas.testset.prompts import (
     compress_question_prompt,
@@ -21,6 +21,7 @@
     find_relevent_context_prompt,
     multi_context_question_prompt,
     question_answer_prompt,
+    question_rewrite_prompt,
     reasoning_question_prompt,
     seed_question_prompt,
 )
@@ -42,7 +43,7 @@ class CurrentNodes:
 class DataRow(BaseModel):
     question: str
     contexts: t.List[str]
-    ground_truth: str
+    ground_truth: t.Union[str, float] = np.nan
     evolution_type: str


@@ -58,6 +59,9 @@ class Evolution:
     find_relevent_context_prompt: Prompt = field(
         default_factory=lambda: find_relevent_context_prompt
     )
+    rewrite_invalid_question_prompt: Prompt = field(
+        default_factory=lambda: question_rewrite_prompt
+    )
     max_tries: int = 5
     is_async: bool = True

@@ -122,34 +126,6 @@ async def _transform_question(self, prompt: Prompt, question: str) -> str:
         )
         return results.generations[0][0].text.strip()

-    def _get_more_adjacent_nodes(self, current_nodes: CurrentNodes):
-        """
-        if the evolutions doesn't have enough nodes to frame a question, get more nodes
-        """
-        assert self.docstore is not None, "docstore cannot be None"
-
-        # get more nodes from above the context window
-        prev_adjacent_node = self.docstore.get_adjacent(
-            current_nodes.nodes[0], Direction.PREV
-        )
-        if prev_adjacent_node is None:
-            # get more nodes from below the context window
-            next_adjacent_node = self.docstore.get_adjacent(
-                current_nodes.nodes[-1], Direction.NEXT
-            )
-            if next_adjacent_node is not None:
-                # add next nodes towards the end
-                current_nodes.nodes.append(next_adjacent_node)
-            else:
-                # retry with new base node
-                nodes = self.docstore.get_random_nodes(k=1)
-                return CurrentNodes(root_node=nodes[0], nodes=nodes)
-        else:
-            # add prev nodes in index 0
-            current_nodes.nodes.insert(0, prev_adjacent_node)
-
-        return current_nodes
-
     def _get_new_random_node(self):
         assert self.docstore is not None, "docstore cannot be None"
         new_node = self.docstore.get_random_nodes(k=1)[0]
@@ -170,12 +146,33 @@ async def evolve(self, current_nodes: CurrentNodes) -> DataRow:
             evolution_type=evolution_type,
         )

+    async def fix_invalid_question(self, question: str, current_nodes: CurrentNodes):
+        """
+        if the question is invalid, get more nodes and retry
+        """
+        prev_node = current_nodes.root_node.prev
+        if prev_node is not None:
+            current_nodes.nodes.insert(0, prev_node)
+            current_nodes.root_node = prev_node
+            prompt = self.rewrite_invalid_question_prompt.format(
+                question=question, context=self.merge_nodes(current_nodes).page_content
+            )
+            results = await self.generator_llm.generate(
+                prompt=prompt, is_async=self.is_async
+            )
+            question = results.generations[0][0].text.strip()
+
+        return question, current_nodes
+
     @abstractmethod
     async def _aevolve(
         self, current_tries: int, current_nodes: CurrentNodes
     ) -> EvolutionOutput:
         ...

+    async def filter_and_retry(self, question):
+        ...
+
     async def generate_datarow(
         self,
         question: str,
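The methods above change the retry strategy: instead of walking adjacent chunks via `get_adjacent`, an invalid question now gets one rewrite attempt with the previous chunk prepended (`fix_invalid_question`), is filtered again, and only then does the evolution fall back to a fresh random node. A synchronous sketch of that control flow, with `filter_fn`, `rewrite_fn`, and `new_random_nodes` as hypothetical stand-ins for the LLM-backed pieces (not ragas APIs):

    def validate_question(question, current_nodes, filter_fn, rewrite_fn, new_random_nodes):
        # filter_fn / rewrite_fn / new_random_nodes are stand-ins, not ragas APIs
        if filter_fn(question):
            return question, current_nodes, True
        # one rewrite attempt with more context (prev chunk prepended)
        question, current_nodes = rewrite_fn(question, current_nodes)
        if filter_fn(question):
            return question, current_nodes, True
        # still invalid: restart from a new random node and let the caller retry
        return question, new_random_nodes(), False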
@@ -217,16 +214,19 @@ async def generate_datarow(
                 question=question, context=merged_nodes.page_content
             )
         )
-        answer = results.generations[0][0].text.strip()
+        answer = await json_loader.safe_load(
+            results.generations[0][0].text.strip(), self.generator_llm
+        )
+        answer = answer if isinstance(answer, dict) else {}
         logger.debug("answer generated: %s", answer)
-
-        if answer == "-1":
-            answer = None
+        answer = (
+            np.nan if answer.get("verdict") == "-1" else answer.get("answer", np.nan)
+        )

         return DataRow(
             question=question,
             contexts=[n.page_content for n in relevant_context.nodes],
-            ground_truth="" if answer is None else answer,
+            ground_truth=answer,
             evolution_type=evolution_type,
         )

@@ -243,6 +243,11 @@ def adapt(self, language: str, cache_dir: t.Optional[str] = None) -> None:
         self.find_relevent_context_prompt = self.find_relevent_context_prompt.adapt(
             language, self.generator_llm, cache_dir
         )
+        self.rewrite_invalid_question_prompt = (
+            self.rewrite_invalid_question_prompt.adapt(
+                language, self.generator_llm, cache_dir
+            )
+        )

         self.node_filter.adapt(language, cache_dir)
         self.question_filter.adapt(language, cache_dir)
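`generate_datarow` now parses the answer prompt's output as JSON with `answer` and `verdict` keys, and a verdict of `'-1'` (answer not in context) becomes `np.nan` rather than an empty string — which is why `DataRow.ground_truth` gained the `t.Union[str, float]` type. A minimal standalone sketch of the mapping:

    import numpy as np

    def to_ground_truth(payload):
        # payload is whatever the JSON loader produced; tolerate non-dict results
        payload = payload if isinstance(payload, dict) else {}
        return np.nan if payload.get("verdict") == "-1" else payload.get("answer", np.nan)

    assert to_ground_truth({"answer": "Oceania", "verdict": "1"}) == "Oceania"
    assert np.isnan(to_ground_truth({"answer": "not found", "verdict": "-1"}))
    assert np.isnan(to_ground_truth("unparseable output"))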
@@ -283,23 +288,26 @@ async def _aevolve(
         results = await self.generator_llm.generate(
             prompt=self.seed_question_prompt.format(
                 context=merged_node.page_content,
-                keyphrases=rng.choice(
-                    np.array(merged_node.keyphrases), size=3
-                ).tolist(),
+                keyphrase=rng.choice(np.array(merged_node.keyphrases), size=1)[0],
             )
         )
         seed_question = results.generations[0][0].text
-        # NOTE: might need improvement
-        # select only one seed question here
+        logger.info("seed question generated: %s", seed_question)
         is_valid_question = await self.question_filter.filter(seed_question)
+
         if not is_valid_question:
             # get more context to rewrite question
-            current_nodes = self._get_more_adjacent_nodes(current_nodes)
-            # retry with new nodes added
-            return await self.aretry_evolve(current_tries, current_nodes)
-        else:
-            # if valid question
-            return seed_question, current_nodes, "simple"
+            seed_question, current_nodes = await self.fix_invalid_question(
+                seed_question, current_nodes
+            )
+            logger.info("rewritten question: %s", seed_question)
+            is_valid_question = await self.question_filter.filter(seed_question)
+            if not is_valid_question:
+                # retry with new nodes added
+                current_nodes = self._get_new_random_node()
+                return await self.aretry_evolve(current_tries, current_nodes)
+
+        return seed_question, current_nodes, "simple"

     def __hash__(self):
         return hash(self.__class__.__name__)
@@ -369,21 +377,28 @@ async def _acomplex_evolution(
         )
         reasoning_question = result.generations[0][0].text.strip()

+        if not await self.question_filter.filter(reasoning_question):
+            # retry
+            reasoning_question, current_nodes = await self.fix_invalid_question(
+                reasoning_question, current_nodes
+            )
+            logger.info("rewritten question: %s", reasoning_question)
+            is_valid_question = await self.question_filter.filter(reasoning_question)
+            if not is_valid_question:
+                # retry with new nodes added
+                current_nodes = self.se._get_new_random_node()
+                return await self.aretry_evolve(current_tries, current_nodes)
+
         # compress the question
         compressed_question = await self._transform_question(
             prompt=self.compress_question_prompt, question=reasoning_question
         )
         logger.debug(
-            "[%s] multicontext question compressed: %s",
+            "[%s] question compressed: %s",
             self.__class__.__name__,
             reasoning_question,
         )

-        if not await self.question_filter.filter(compressed_question):
-            # retry
-            current_nodes = self.se._get_more_adjacent_nodes(current_nodes)
-            return await self.aretry_evolve(current_tries, current_nodes)
-
         assert self.evolution_filter is not None, "evolution filter cannot be None"
         if await self.evolution_filter.filter(simple_question, compressed_question):
             # retry
@@ -463,8 +478,17 @@ async def _aevolve(

         if not await self.question_filter.filter(question):
             # retry
-            current_nodes = self.se._get_more_adjacent_nodes(current_nodes)
-            return await self.aretry_evolve(current_tries, current_nodes)
+            # get more context to rewrite question
+            question, current_nodes = await self.fix_invalid_question(
+                question, current_nodes
+            )
+            logger.info("rewritten question: %s", question)
+            is_valid_question = await self.question_filter.filter(question)
+
+            if not is_valid_question:
+                # retry with new nodes added
+                current_nodes = self.se._get_new_random_node()
+                return await self.aretry_evolve(current_tries, current_nodes)

         # compress the question
         compressed_question = await self._transform_question(
diff --git a/src/ragas/testset/generator.py b/src/ragas/testset/generator.py
index be6ebd886..76f53aa8d 100644
--- a/src/ragas/testset/generator.py
+++ b/src/ragas/testset/generator.py
@@ -77,7 +77,7 @@ def with_openai(
         critic_llm: str = "gpt-4",
         embeddings: str = "text-embedding-ada-002",
         docstore: t.Optional[DocumentStore] = None,
-        chunk_size: int = 512,
+        chunk_size: int = 1024,
     ) -> "TestsetGenerator":
         generator_llm_model = LangchainLLMWrapper(ChatOpenAI(model=generator_llm))
         critic_llm_model = LangchainLLMWrapper(ChatOpenAI(model=critic_llm))
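One knock-on change in generator.py: `with_openai` now chunks documents at 1024 tokens by default instead of 512. A usage sketch (assumes an OpenAI key is configured; pass the old value explicitly if your pipeline was tuned for it):

    from ragas.testset.generator import TestsetGenerator

    # chunk_size now defaults to 1024; opt back into the old chunking if needed
    generator = TestsetGenerator.with_openai(chunk_size=512)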
diff --git a/src/ragas/testset/prompts.py b/src/ragas/testset/prompts.py
index 766f9e378..0dd853308 100644
--- a/src/ragas/testset/prompts.py
+++ b/src/ragas/testset/prompts.py
@@ -270,22 +270,28 @@
 question_answer_prompt = Prompt(
     name="answer_formulate",
-    instruction="""Answer the question using the information from the given context. Answer '-1' if answer is not present in the context.""",
+    instruction="""Answer the question using the information from the given context. Output verdict as '1' if answer is present, '-1' if answer is not present in the context.""",
     examples=[
         {
             "context": """The novel '1984' by George Orwell is set in a dystopian future where the world is divided into three superstates. The story follows the life of Winston Smith, who lives in Oceania, a superstate constantly at war.""",
             "question": "In which superstate does Winston Smith live in the novel '1984'?",
-            "answer": "Winston Smith lives in the superstate of Oceania in the novel '1984'.",
+            "answer": {
+                "answer": "Winston Smith lives in the superstate of Oceania in the novel '1984'.",
+                "verdict": "1",
+            },
         },
         {
             "context": """The novel "Pride and Prejudice" by Jane Austen revolves around the character Elizabeth Bennet and her family. The story is set in the 19th century in rural England and deals with issues of marriage, morality, and misconceptions.""",
             "question": "What year was 'Pride and Prejudice' published?",
-            "answer": "-1",
+            "answer": {
+                "answer": "The answer to the given question is not present in the context",
+                "verdict": "-1",
+            },
         },
     ],
     input_keys=["context", "question"],
     output_key="answer",
-    output_type="string",
+    output_type="json",
     language="english",
 )

@@ -324,40 +330,25 @@
 seed_question_prompt = Prompt(
     name="seed_question",
-    instruction="generate a question that can be fully answered from given context. The question should contain atleast two of the given keyphrases",
+    instruction="generate a question that can be fully answered from given context. The question should contain the given keyphrase",
     examples=[
         {
             "context": "Photosynthesis in plants involves converting light energy into chemical energy, using chlorophyll and other pigments to absorb light. This process is crucial for plant growth and the production of oxygen.",
-            "keyphrases": [
-                "Photosynthesis",
-                "Light energy",
-                "Chlorophyll",
-                "Oxygen production",
-            ],
-            "question": "How does chlorophyll aid in converting light energy into chemical energy during photosynthesis?",
+            "keyphrase": "Photosynthesis",
+            "question": "What is the role of photosynthesis in plant growth?",
         },
         {
             "context": "The Industrial Revolution, starting in the 18th century, marked a major turning point in history as it led to the development of factories and urbanization.",
-            "keyphrases": [
-                "Industrial Revolution",
-                "18th century",
-                "Factories",
-                "Urbanization",
-            ],
-            "question": "Why did the Industrial Revolution significantly contribute to the development of factories and urbanization?",
+            "keyphrase": "Industrial Revolution",
+            "question": "How did the Industrial Revolution mark a major turning point in history?",
         },
         {
-            "context": "A black hole is a region of spacetime where gravity is so strong that nothing, including light and other electromagnetic waves, has enough energy to escape it. The theory of general relativity predicts that a sufficiently compact mass can deform spacetime to form a black hole.",
-            "keyphrases": [
-                "Black hole",
-                "region of spacetime",
-                "Sufficiently compact mass",
-                "Energy to escape",
-            ],
-            "question": "What is a black hole and how does it relate to a region of spacetime?",
+            "context": "The process of evaporation plays a crucial role in the water cycle, converting water from liquid to vapor and allowing it to rise into the atmosphere.",
+            "keyphrase": "Evaporation",
+            "question": "Why is evaporation important in the water cycle?",
        },
     ],
-    input_keys=["context", "keyphrases"],
+    input_keys=["context", "keyphrase"],
     output_key="question",
     output_type="string",
 )
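With the prompt changes above, `seed_question_prompt` takes a single `keyphrase` input instead of a `keyphrases` list, matching the `keyphrase=rng.choice(...)` call site in evolutions.py, and `question_answer_prompt` now emits JSON. A sketch of formatting the seed prompt (the context string is invented):

    from ragas.testset.prompts import seed_question_prompt

    prompt = seed_question_prompt.format(
        context="Photosynthesis converts light energy into chemical energy.",
        keyphrase="Photosynthesis",  # a single keyphrase, no longer a list of three
    )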
Node(doc_id="c", page_content="c", filename="c") - pytest.raises(ValueError, store.get_adjacent, c) + assert store.nodes[0].next == a2 + assert store.nodes[1].prev == a1 + assert store.nodes[2].next is None def create_test_nodes(with_embeddings=True): @@ -63,10 +58,16 @@ def create_test_nodes(with_embeddings=True): embeddings = defaultdict(lambda: None) a1 = Node( - doc_id="a1", page_content="cat", filename="a", embedding=embeddings["cat"] + doc_id="a1", + page_content="cat", + metadata={"filename": "a"}, + embedding=embeddings["cat"], ) a2 = Node( - doc_id="a2", page_content="mouse", filename="a", embedding=embeddings["mouse"] + doc_id="a2", + page_content="mouse", + metadata={"filename": "a"}, + embedding=embeddings["mouse"], ) b = Node( doc_id="b", @@ -144,7 +145,9 @@ async def test_fake_embeddings(): def test_docstore_add_batch(fake_llm): # create a dummy embeddings with support for async aembed_query() fake_embeddings = FakeEmbeddings() - store = InMemoryDocumentStore(splitter=None, embeddings=fake_embeddings, llm=fake_llm) # type: ignore + store = InMemoryDocumentStore( + splitter=None, embeddings=fake_embeddings, llm=fake_llm + ) # type: ignore # add documents in batch nodes = create_test_nodes(with_embeddings=False) @@ -154,8 +157,12 @@ def test_docstore_add_batch(fake_llm): == fake_embeddings.embeddings[nodes[0].page_content] ) # add documents in batch that have some embeddings - c = Node(doc_id="c", page_content="c", filename="c", embedding=[0.0] * 768) - d = Node(doc_id="d", page_content="d", filename="d", embedding=[0.0] * 768) + c = Node( + doc_id="c", page_content="c", metadata={"filename": "c"}, embedding=[0.0] * 768 + ) + d = Node( + doc_id="d", page_content="d", metadata={"filename": "d"}, embedding=[0.0] * 768 + ) store.add_nodes([c, d]) # test get() and that embeddings and keyphrases are correct diff --git a/tests/unit/testset_generator/test_document.py b/tests/unit/testset_generator/test_document.py new file mode 100644 index 000000000..e5cfe91c5 --- /dev/null +++ b/tests/unit/testset_generator/test_document.py @@ -0,0 +1,32 @@ +import uuid + +from ragas.testset.docstore import Document + + +def test_document_filename(monkeypatch): + monkeypatch.setattr(uuid, "uuid4", lambda: "test-uuid") + d1 = Document(page_content="a1") + assert d1.filename == "test-uuid" + + # now suppose I add a filename to metadata + d2 = Document(page_content="a2", metadata={"filename": "test-filename"}) + assert d2.filename == "test-filename" + + +def test_document_chunking(): + """ + Tests to make sure that there is no problem when you chunk a document into Nodes + especially because of the fact that Node objects are created again. + """ + from langchain.text_splitter import TokenTextSplitter + from langchain_core.documents import Document + + from ragas.testset.docstore import Node + + splitter = TokenTextSplitter(chunk_size=1, chunk_overlap=0) + doc = Document(page_content="Hello, world!", metadata={"filename": "test-filename"}) + nodes = [ + Node.from_langchain_document(d) for d in splitter.transform_documents([doc]) + ] + for node in nodes: + assert node.filename == "test-filename"