From 721a34cd4c9eae04fdbbf2ace3bba90c426825e2 Mon Sep 17 00:00:00 2001 From: jjmachan Date: Sat, 20 Jan 2024 12:23:14 -0800 Subject: [PATCH 01/14] cleanup simple evolution --- src/ragas/testset/evolutions.py | 17 ++++++++--------- src/ragas/testset/generator.py | 4 ++-- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py index 88e879467..2e8f7c43a 100644 --- a/src/ragas/testset/evolutions.py +++ b/src/ragas/testset/evolutions.py @@ -130,8 +130,10 @@ async def aevolve(self, llm: BaseRagasLLM, docstore: DocumentStore): self.nodes = docstore.get_random_nodes(k=1) return await self.aretry_evolve(llm, docstore, update_count=False) - # frame a basic question with with node - seed_question = await simple_evolution(llm, merged_node) + results = llm.generate_text( + prompt=seed_question_prompt.format(context=merged_node.page_content) + ) + seed_question = results.generations[0][0].text # NOTE: might need improvement # select only one seed question here is_valid_question = await self.question_filter.afilter(seed_question) @@ -145,13 +147,6 @@ async def aevolve(self, llm: BaseRagasLLM, docstore: DocumentStore): return seed_question -async def simple_evolution(llm: BaseRagasLLM, seed_doc: Document): - prompt = seed_question_prompt.format(context=seed_doc.page_content) - results = llm.generate_text(prompt=prompt) - results = results.generations[0][0].text - return results - - async def multi_context_evolution( llm: BaseRagasLLM, seed_node: Node, doc_store: DocumentStore ): @@ -163,3 +158,7 @@ async def multi_context_evolution( results = await llm.agenerate_text(prompt=prompt) question = results.generations[0][0].text.strip() return question + + +class MultiContext(Evolution): + ... diff --git a/src/ragas/testset/generator.py b/src/ragas/testset/generator.py index 94b957034..1ff77b5ce 100644 --- a/src/ragas/testset/generator.py +++ b/src/ragas/testset/generator.py @@ -6,10 +6,10 @@ from llama_index.readers.schema import Document as LlamaindexDocument from ragas.embeddings import BaseRagasEmbeddings -from ragas.executor import Executor from ragas.llms import BaseRagasLLM, LangchainLLMWrapper from ragas.testset.docstore import Document, DocumentStore, InMemoryDocumentStore -from ragas.testset.evolutions import NodeFilter, QuestionFilter, SimpleEvolution +from ragas.executor import Executor +from ragas.testset.evolutions import SimpleEvolution, QuestionFilter, NodeFilter @dataclass From 6d583b265b3f382dade8b5c3877eabea6694166a Mon Sep 17 00:00:00 2001 From: jjmachan Date: Sat, 20 Jan 2024 18:19:08 -0800 Subject: [PATCH 02/14] added multi_context question --- src/ragas/testset/evolutions.py | 63 ++++++++++++++++++++++++++------- 1 file changed, 50 insertions(+), 13 deletions(-) diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py index 2e8f7c43a..d49a5522d 100644 --- a/src/ragas/testset/evolutions.py +++ b/src/ragas/testset/evolutions.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import typing as t from abc import ABC, abstractmethod @@ -14,11 +16,15 @@ filter_question_prompt, multi_context_question_prompt, seed_question_prompt, + compress_question_prompt, ) rng = default_rng() logger = logging.getLogger(__name__) +if t.TYPE_CHECKING: + from ragas.llms.prompt import Prompt + @dataclass class Filter(ABC): @@ -147,18 +153,49 @@ async def aevolve(self, llm: BaseRagasLLM, docstore: DocumentStore): return seed_question -async def multi_context_evolution( - llm: BaseRagasLLM, 
seed_node: Node, doc_store: DocumentStore
-):
-    question = simple_evolution(llm, seed_node)
-    similar_context = doc_store.get_similar(seed_node)[0]
-    prompt = multi_context_question_prompt.format(
-        question=question, context1=seed_node.page_content, context2=similar_context
-    )
-    results = await llm.agenerate_text(prompt=prompt)
-    question = results.generations[0][0].text.strip()
-    return question
+@dataclass
+class MultiContextEvolution(Evolution):
+    se: SimpleEvolution = field(init=False, repr=False)
 
+    def __post_init__(self):
+        # init simple evolution to get seed question
+        self.se = SimpleEvolution(self.node_filter, self.question_filter)
 
-class MultiContext(Evolution):
-    ...
+    def _transform_question(
+        self, llm: BaseRagasLLM, prompt: Prompt, question: str
+    ) -> str:
+        results = llm.generate_text(prompt=prompt.format(question=question))
+        return results.generations[0][0].text.strip()
+
+    def evolve(self, llm: BaseRagasLLM, docstore: DocumentStore):
+        logger.info("evolving question")
+        return asyncio.get_event_loop().run_until_complete(self.aevolve(llm, docstore))
+
+    async def aevolve(self, llm: BaseRagasLLM, docstore: DocumentStore) -> str:
+        # generate seed question
+        self._root_node = docstore.get_random_nodes(k=1)[0]
+        question = await self.se.aevolve(llm, docstore)
+        logger.debug("[MultiContextEvolution] simple question generated: %s", question)
+
+        # find a similar node and generate a question based on both
+        similar_context = docstore.get_similar(self._root_node)[0]
+        prompt = multi_context_question_prompt.format(
+            question=question,
+            context1=self._root_node.page_content,
+            context2=similar_context,
+        )
+        results = await llm.agenerate_text(prompt=prompt)
+        question = results.generations[0][0].text.strip()
+        logger.debug(
+            "[MultiContextEvolution] multicontext question generated: %s", question
+        )
+
+        # compress the question
+        compressed_question = self._transform_question(
+            llm=llm, prompt=compress_question_prompt, question=question
+        )
+        logger.debug(
+            "[MultiContextEvolution] multicontext question compressed: %s", question
+        )
+
+        return compressed_question
From 7d07f4f64ac5c68c82c97b45fa484efe9925055f Mon Sep 17 00:00:00 2001
From: jjmachan
Date: Sat, 20 Jan 2024 20:43:24 -0800
Subject: [PATCH 03/14] reasoning_question

---
 src/ragas/testset/evolutions.py | 53 +++++++++++++++++++++++++++++----
 1 file changed, 48 insertions(+), 5 deletions(-)

diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py
index d49a5522d..cff96ba79 100644
--- a/src/ragas/testset/evolutions.py
+++ b/src/ragas/testset/evolutions.py
@@ -17,6 +17,7 @@
     multi_context_question_prompt,
     seed_question_prompt,
     compress_question_prompt,
+    reasoning_question_prompt,
 )
 
 rng = default_rng()
@@ -174,15 +175,17 @@ def evolve(self, llm: BaseRagasLLM, docstore: DocumentStore):
     async def aevolve(self, llm: BaseRagasLLM, docstore: DocumentStore) -> str:
         # generate seed question
         self._root_node = docstore.get_random_nodes(k=1)[0]
-        question = await self.se.aevolve(llm, docstore)
-        logger.debug("[MultiContextEvolution] simple question generated: %s", question)
+        simple_question = await self.se.aevolve(llm, docstore)
+        logger.debug(
+            "[MultiContextEvolution] simple question generated: %s", simple_question
+        )
 
         # find a similar node and generate a question based on both
-        similar_context = docstore.get_similar(self._root_node)[0]
+        similar_node = docstore.get_similar(self._root_node)[0]
         prompt = multi_context_question_prompt.format(
-            question=question,
+            question=simple_question,
context1=self._root_node.page_content, - context2=similar_context, + context2=similar_node, ) results = await llm.agenerate_text(prompt=prompt) question = results.generations[0][0].text.strip() @@ -198,4 +201,44 @@ async def aevolve(self, llm: BaseRagasLLM, docstore: DocumentStore) -> str: "[MultiContextEvolution] multicontext question compressed: %s", question ) + if not await self.question_filter.afilter(compressed_question): + # retry + ... + + # if not await self.evolution_elimation.afilter( + # simple_question, compressed_question + # ): + # ... + return compressed_question + + +@dataclass +class ReasoningEvolution(Evolution): + se: SimpleEvolution = field(init=False, repr=False) + + def __post_init__(self): + # init simple evolution to get seed question + self.se = SimpleEvolution(self.node_filter, self.question_filter) + + def evolve(self, llm: BaseRagasLLM, docstore: DocumentStore): + logger.debug("evolving question") + return asyncio.get_event_loop().run_until_complete(self.aevolve(llm, docstore)) + + async def aevolve(self, llm: BaseRagasLLM, docstore: DocumentStore) -> str: + self._root_node = docstore.get_random_nodes(k=1)[0] + simple_question = await self.se.aevolve(llm, docstore) + logger.debug( + "[MultiContextEvolution] simple question generated: %s", simple_question + ) + + result = await llm.agenerate_text( + prompt=reasoning_question_prompt.format( + question=simple_question, context=self._root_node + ) + ) + reasoning_question = result.generations[0][0].text.strip() + if not await self.question_filter.afilter(reasoning_question): + # retry + ... + return reasoning_question From 77640338f4873f919e9d06e487fc2c448f421657 Mon Sep 17 00:00:00 2001 From: jjmachan Date: Sun, 21 Jan 2024 12:50:18 -0800 Subject: [PATCH 04/14] refactored evolutions --- src/ragas/testset/evolutions.py | 211 ++++++++++++++++++++------------ 1 file changed, 135 insertions(+), 76 deletions(-) diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py index cff96ba79..c63b4539d 100644 --- a/src/ragas/testset/evolutions.py +++ b/src/ragas/testset/evolutions.py @@ -18,6 +18,7 @@ seed_question_prompt, compress_question_prompt, reasoning_question_prompt, + evolution_elimination_prompt, ) rng = default_rng() @@ -65,22 +66,49 @@ async def afilter(self, question: str) -> bool: return json_results.get("verdict") != "No" +@dataclass +class EvolutionFilter(Filter): + llm: BaseRagasLLM + + def filter(self, simple_question: str, compressed_question: str) -> bool: + return asyncio.get_event_loop().run_until_complete( + self.afilter(simple_question, compressed_question) + ) + + async def afilter(self, simple_question: str, compressed_question: str) -> bool: + prompt = evolution_elimination_prompt.format( + question1=simple_question, question2=compressed_question + ) + results = await self.llm.agenerate_text(prompt=prompt) + results = results.generations[0][0].text.strip() + json_results = load_as_json(results) + logger.debug("filtered question: %s", json_results) + return json_results.get("verdict") != "No" + + +@dataclass +class CurrentNodes: + root_node: Node + nodes: t.List[Node] = field(default_factory=list) + + @dataclass class Evolution: + generator_llm: BaseRagasLLM + docstore: DocumentStore node_filter: NodeFilter question_filter: QuestionFilter - nodes: t.List[Node] = field(default_factory=list) max_tries: int = 5 - _root_node: t.Optional[Node] = field(default=None, init=False, repr=False) _tries: int = field(default=0, init=False, repr=False) - def merged_nodes(self) -> Node: + 
@staticmethod + def merge_nodes(nodes: CurrentNodes) -> Node: return Node( - doc_id="merged", page_content=" ".join(n.page_content for n in self.nodes) + doc_id="merged", page_content=" ".join(n.page_content for n in nodes.nodes) ) async def aretry_evolve( - self, llm: BaseRagasLLM, docstore: DocumentStore, update_count: bool = True + self, current_nodes: CurrentNodes, update_count: bool = True ): if update_count: self._tries += 1 @@ -88,56 +116,81 @@ async def aretry_evolve( if self._tries > self.max_tries: # TODO: make this into a custom exception raise ValueError("Max tries reached") - return await self.aevolve(llm, docstore) + return await self.aevolve(current_nodes) + + def _transform_question(self, prompt: Prompt, question: str) -> str: + results = self.generator_llm.generate_text( + prompt=prompt.format(question=question) + ) + return results.generations[0][0].text.strip() @abstractmethod - def evolve(self, llm: BaseRagasLLM, docstore: DocumentStore) -> str: + def evolve(self, current_nodes: CurrentNodes) -> str: ... @abstractmethod - async def aevolve(self, llm: BaseRagasLLM, docstore: DocumentStore) -> str: + async def aevolve(self, current_nodes: CurrentNodes) -> str: ... +@dataclass +class ComplexEvolution(Evolution): + se: SimpleEvolution = field(init=False, repr=False) + evolution_filter: EvolutionFilter = field(init=False, repr=False) + + def __post_init__(self): + # init simple evolution to get seed question + self.se = SimpleEvolution( + generator_llm=self.generator_llm, + docstore=self.docstore, + node_filter=self.node_filter, + question_filter=self.question_filter, + ) + # init evolution filter with critic llm from another filter + self.evolution_filter = EvolutionFilter(self.node_filter.llm) + + @dataclass class SimpleEvolution(Evolution): - def evolve(self, llm: BaseRagasLLM, docstore: DocumentStore): + def evolve(self, current_nodes: CurrentNodes) -> str: logger.info("evolving question") - return asyncio.get_event_loop().run_until_complete(self.aevolve(llm, docstore)) + return asyncio.get_event_loop().run_until_complete(self.aevolve(current_nodes)) - def _get_more_adjacent_nodes(self, docstore: DocumentStore): + def _get_more_adjacent_nodes(self, current_nodes: CurrentNodes): """ if the evolutions doesn't have enough nodes to frame a question, get more nodes """ - assert self._root_node is not None, "root node cannot be None" # get more nodes from above the context window - prev_adjacent_node = docstore.get_adjacent(self._root_node, Direction.PREV) + prev_adjacent_node = self.docstore.get_adjacent( + current_nodes.nodes[0], Direction.PREV + ) if prev_adjacent_node is None: # get more nodes from below the context window - next_adjacent_node = docstore.get_adjacent(self._root_node, Direction.NEXT) + next_adjacent_node = self.docstore.get_adjacent( + current_nodes.nodes[-1], Direction.NEXT + ) if next_adjacent_node is not None: # add next nodes towards the end - self.nodes.append(next_adjacent_node) + current_nodes.nodes.append(next_adjacent_node) else: # retry with new base node - self.nodes = docstore.get_random_nodes(k=1) - self._root_node = self.nodes[0] + nodes = self.docstore.get_random_nodes(k=1) + return CurrentNodes(root_node=nodes[0], nodes=nodes) else: # add prev nodes in index 0 - self.nodes.insert(0, prev_adjacent_node) - - async def aevolve(self, llm: BaseRagasLLM, docstore: DocumentStore): - # can the node be used to frame a question? 
- if self._tries == 0: - self.nodes = docstore.get_random_nodes(k=1) - self._root_node = self.nodes[0] - merged_node = self.merged_nodes() - passed = await self.node_filter.afilter(self.nodes[0]) + current_nodes.nodes.insert(0, prev_adjacent_node) + + return current_nodes + + async def aevolve(self, current_nodes: CurrentNodes) -> str: + merged_node = self.merge_nodes(current_nodes) + passed = await self.node_filter.afilter(current_nodes.root_node) if not passed["score"]: - self.nodes = docstore.get_random_nodes(k=1) - return await self.aretry_evolve(llm, docstore, update_count=False) + nodes = self.docstore.get_random_nodes(k=1) + new_current_nodes = CurrentNodes(root_node=nodes[0], nodes=nodes) + return await self.aretry_evolve(new_current_nodes, update_count=False) - results = llm.generate_text( + results = self.generator_llm.generate_text( prompt=seed_question_prompt.format(context=merged_node.page_content) ) seed_question = results.generations[0][0].text @@ -146,48 +199,34 @@ async def aevolve(self, llm: BaseRagasLLM, docstore: DocumentStore): is_valid_question = await self.question_filter.afilter(seed_question) if not is_valid_question: # get more context to rewrite question - self._get_more_adjacent_nodes(docstore) + current_nodes = self._get_more_adjacent_nodes(current_nodes) # retry with new nodes added - return await self.aretry_evolve(llm, docstore) + return await self.aretry_evolve(current_nodes) else: # if valid question return seed_question @dataclass -class MultiContextEvolution(Evolution): - se: SimpleEvolution = field(init=False, repr=False) - - def __post_init__(self): - # init simple evolution to get seed question - self.se = SimpleEvolution(self.node_filter, self.question_filter) - - def _transform_question( - self, llm: BaseRagasLLM, prompt: Prompt, question: str - ) -> str: - results = llm.generate_text(prompt=prompt.format(question=question)) - return results.generations[0][0].text.strip() - - def evolve(self, llm: BaseRagasLLM, docstore: DocumentStore): +class MultiContextEvolution(ComplexEvolution): + def evolve(self, current_nodes: CurrentNodes) -> str: logger.info("evolving question") - return asyncio.get_event_loop().run_until_complete(self.aevolve(llm, docstore)) + return asyncio.get_event_loop().run_until_complete(self.aevolve(current_nodes)) - async def aevolve(self, llm: BaseRagasLLM, docstore: DocumentStore) -> str: - # gerenate seed question - self._root_node = docstore.get_random_nodes(k=1)[0] - simple_question = await self.se.aevolve(llm, docstore) + async def aevolve(self, current_nodes: CurrentNodes) -> str: + simple_question = await self.se.aevolve(current_nodes) logger.debug( "[MultiContextEvolution] simple question generated: %s", simple_question ) # find a similar node and generate a question based on both - similar_node = docstore.get_similar(self._root_node)[0] + similar_node = self.docstore.get_similar(current_nodes.root_node)[0] prompt = multi_context_question_prompt.format( question=simple_question, - context1=self._root_node.page_content, + context1=current_nodes.root_node.page_content, context2=similar_node, ) - results = await llm.agenerate_text(prompt=prompt) + results = await self.generator_llm.agenerate_text(prompt=prompt) question = results.generations[0][0].text.strip() logger.debug( "[MultiContextEvolution] multicontext question generated: %s", question @@ -195,7 +234,7 @@ async def aevolve(self, llm: BaseRagasLLM, docstore: DocumentStore) -> str: # compress the question compressed_question = self._transform_question( - llm=llm, 
prompt=compress_question_prompt, question=question + prompt=compress_question_prompt, question=question ) logger.debug( "[MultiContextEvolution] multicontext question compressed: %s", question @@ -203,42 +242,62 @@ async def aevolve(self, llm: BaseRagasLLM, docstore: DocumentStore) -> str: if not await self.question_filter.afilter(compressed_question): # retry - ... + current_nodes = self.se._get_more_adjacent_nodes(current_nodes) + return await self.aretry_evolve(current_nodes) - # if not await self.evolution_elimation.afilter( - # simple_question, compressed_question - # ): - # ... + assert self.evolution_filter is not None, "evolution filter cannot be None" + if not await self.evolution_filter.afilter( + simple_question, compressed_question + ): + # retry + current_nodes = self.se._get_more_adjacent_nodes(current_nodes) + return await self.aretry_evolve(current_nodes) return compressed_question @dataclass -class ReasoningEvolution(Evolution): - se: SimpleEvolution = field(init=False, repr=False) - - def __post_init__(self): - # init simple evolution to get seed question - self.se = SimpleEvolution(self.node_filter, self.question_filter) - - def evolve(self, llm: BaseRagasLLM, docstore: DocumentStore): +class ReasoningEvolution(ComplexEvolution): + def evolve(self, current_nodes: CurrentNodes) -> str: logger.debug("evolving question") - return asyncio.get_event_loop().run_until_complete(self.aevolve(llm, docstore)) + return asyncio.get_event_loop().run_until_complete(self.aevolve(current_nodes)) - async def aevolve(self, llm: BaseRagasLLM, docstore: DocumentStore) -> str: - self._root_node = docstore.get_random_nodes(k=1)[0] - simple_question = await self.se.aevolve(llm, docstore) + async def aevolve(self, current_nodes: CurrentNodes) -> str: + simple_question = await self.se.aevolve(current_nodes) logger.debug( - "[MultiContextEvolution] simple question generated: %s", simple_question + "[ReasoningEvolution] simple question generated: %s", simple_question ) - result = await llm.agenerate_text( + result = await self.generator_llm.agenerate_text( prompt=reasoning_question_prompt.format( - question=simple_question, context=self._root_node + question=simple_question, context=current_nodes.root_node.page_content ) ) reasoning_question = result.generations[0][0].text.strip() - if not await self.question_filter.afilter(reasoning_question): + # + # compress the question + compressed_question = self._transform_question( + prompt=compress_question_prompt, question=reasoning_question + ) + logger.debug( + "[ReasoningEvolution] multicontext question compressed: %s", + reasoning_question, + ) + + if not await self.question_filter.afilter(compressed_question): # retry - ... 
+ current_nodes = self.se._get_more_adjacent_nodes(current_nodes) + return await self.aretry_evolve(current_nodes) + + assert self.evolution_filter is not None, "evolution filter cannot be None" + if not await self.evolution_filter.afilter( + simple_question, compressed_question + ): + # retry + current_nodes = self.se._get_more_adjacent_nodes(current_nodes) + logger.debug( + "evolution_filter failed, retrying with %s", len(current_nodes.nodes) + ) + return await self.aretry_evolve(current_nodes) + return reasoning_question From 3127ad1c70d2a29f111059ae6457c0359b960b97 Mon Sep 17 00:00:00 2001 From: jjmachan Date: Sun, 21 Jan 2024 14:53:27 -0800 Subject: [PATCH 05/14] return DataRow now --- src/ragas/testset/evolutions.py | 132 ++++++++++++++++++++------------ 1 file changed, 84 insertions(+), 48 deletions(-) diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py index c63b4539d..dfe891b90 100644 --- a/src/ragas/testset/evolutions.py +++ b/src/ragas/testset/evolutions.py @@ -4,6 +4,7 @@ import typing as t from abc import ABC, abstractmethod from dataclasses import dataclass, field +from collections import namedtuple from fsspec.exceptions import asyncio from numpy.random import default_rng @@ -19,6 +20,7 @@ compress_question_prompt, reasoning_question_prompt, evolution_elimination_prompt, + question_answer_prompt, ) rng = default_rng() @@ -92,6 +94,18 @@ class CurrentNodes: nodes: t.List[Node] = field(default_factory=list) +DataRow = namedtuple( + "DataRow", + [ + "question", + "context", + "answer", + "question_type", + "evolution_elimination", + ], +) + + @dataclass class Evolution: generator_llm: BaseRagasLLM @@ -104,19 +118,19 @@ class Evolution: @staticmethod def merge_nodes(nodes: CurrentNodes) -> Node: return Node( - doc_id="merged", page_content=" ".join(n.page_content for n in nodes.nodes) + doc_id="merged", page_content="\n".join(n.page_content for n in nodes.nodes) ) async def aretry_evolve( self, current_nodes: CurrentNodes, update_count: bool = True - ): + ) -> str: if update_count: self._tries += 1 logger.info("retrying evolution: %s times", self._tries) if self._tries > self.max_tries: # TODO: make this into a custom exception raise ValueError("Max tries reached") - return await self.aevolve(current_nodes) + return await self._aevolve(current_nodes) def _transform_question(self, prompt: Prompt, question: str) -> str: results = self.generator_llm.generate_text( @@ -124,38 +138,6 @@ def _transform_question(self, prompt: Prompt, question: str) -> str: ) return results.generations[0][0].text.strip() - @abstractmethod - def evolve(self, current_nodes: CurrentNodes) -> str: - ... - - @abstractmethod - async def aevolve(self, current_nodes: CurrentNodes) -> str: - ... 
- - -@dataclass -class ComplexEvolution(Evolution): - se: SimpleEvolution = field(init=False, repr=False) - evolution_filter: EvolutionFilter = field(init=False, repr=False) - - def __post_init__(self): - # init simple evolution to get seed question - self.se = SimpleEvolution( - generator_llm=self.generator_llm, - docstore=self.docstore, - node_filter=self.node_filter, - question_filter=self.question_filter, - ) - # init evolution filter with critic llm from another filter - self.evolution_filter = EvolutionFilter(self.node_filter.llm) - - -@dataclass -class SimpleEvolution(Evolution): - def evolve(self, current_nodes: CurrentNodes) -> str: - logger.info("evolving question") - return asyncio.get_event_loop().run_until_complete(self.aevolve(current_nodes)) - def _get_more_adjacent_nodes(self, current_nodes: CurrentNodes): """ if the evolutions doesn't have enough nodes to frame a question, get more nodes @@ -182,7 +164,69 @@ def _get_more_adjacent_nodes(self, current_nodes: CurrentNodes): return current_nodes - async def aevolve(self, current_nodes: CurrentNodes) -> str: + def evolve(self, current_nodes: CurrentNodes) -> DataRow: + return asyncio.get_event_loop().run_until_complete(self.aevolve(current_nodes)) + + async def aevolve(self, current_nodes: CurrentNodes) -> DataRow: + evolved_question = await self._aevolve(current_nodes) + return self.generate_datarow( + question=evolved_question, + current_nodes=current_nodes, + ) + + @abstractmethod + async def _aevolve(self, current_nodes: CurrentNodes) -> str: + ... + + def generate_datarow( + self, + question: str, + current_nodes: CurrentNodes, + question_type: str = "", + evolution_elimination: bool = False, + ): + merged_nodes = self.merge_nodes(current_nodes) + results = self.generator_llm.generate_text( + prompt=question_answer_prompt.format( + question=question, context=merged_nodes.page_content + ) + ) + results = results.generations[0][0].text.strip() + json_results = load_as_json(results) + logger.debug("answer generated: %s", json_results) + + # TODO: what do if answer is -1? 
+ answer = json_results.get("answer") + + return DataRow( + question=question, + context=merged_nodes.page_content, + answer=answer, + question_type=question_type, + evolution_elimination=evolution_elimination, + ) + + +@dataclass +class ComplexEvolution(Evolution): + se: SimpleEvolution = field(init=False, repr=False) + evolution_filter: EvolutionFilter = field(init=False, repr=False) + + def __post_init__(self): + # init simple evolution to get seed question + self.se = SimpleEvolution( + generator_llm=self.generator_llm, + docstore=self.docstore, + node_filter=self.node_filter, + question_filter=self.question_filter, + ) + # init evolution filter with critic llm from another filter + self.evolution_filter = EvolutionFilter(self.node_filter.llm) + + +@dataclass +class SimpleEvolution(Evolution): + async def _aevolve(self, current_nodes: CurrentNodes) -> str: merged_node = self.merge_nodes(current_nodes) passed = await self.node_filter.afilter(current_nodes.root_node) if not passed["score"]: @@ -209,12 +253,8 @@ async def aevolve(self, current_nodes: CurrentNodes) -> str: @dataclass class MultiContextEvolution(ComplexEvolution): - def evolve(self, current_nodes: CurrentNodes) -> str: - logger.info("evolving question") - return asyncio.get_event_loop().run_until_complete(self.aevolve(current_nodes)) - - async def aevolve(self, current_nodes: CurrentNodes) -> str: - simple_question = await self.se.aevolve(current_nodes) + async def _aevolve(self, current_nodes: CurrentNodes) -> str: + simple_question = await self.se._aevolve(current_nodes) logger.debug( "[MultiContextEvolution] simple question generated: %s", simple_question ) @@ -258,12 +298,8 @@ async def aevolve(self, current_nodes: CurrentNodes) -> str: @dataclass class ReasoningEvolution(ComplexEvolution): - def evolve(self, current_nodes: CurrentNodes) -> str: - logger.debug("evolving question") - return asyncio.get_event_loop().run_until_complete(self.aevolve(current_nodes)) - - async def aevolve(self, current_nodes: CurrentNodes) -> str: - simple_question = await self.se.aevolve(current_nodes) + async def _aevolve(self, current_nodes: CurrentNodes) -> str: + simple_question = await self.se._aevolve(current_nodes) logger.debug( "[ReasoningEvolution] simple question generated: %s", simple_question ) From 4529ebdb6d3f776b6acff961aad0f267ff9fffc4 Mon Sep 17 00:00:00 2001 From: jjmachan Date: Sun, 21 Jan 2024 15:04:22 -0800 Subject: [PATCH 06/14] moved out filter --- src/ragas/testset/evolutions.py | 64 +------------------------- src/ragas/testset/filters.py | 79 +++++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 62 deletions(-) create mode 100644 src/ragas/testset/filters.py diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py index dfe891b90..9e626216d 100644 --- a/src/ragas/testset/evolutions.py +++ b/src/ragas/testset/evolutions.py @@ -9,83 +9,23 @@ from fsspec.exceptions import asyncio from numpy.random import default_rng -from ragas.llms import BaseRagasLLM from ragas.llms.json_load import load_as_json from ragas.testset.docstore import Direction, Document, DocumentStore, Node from ragas.testset.prompts import ( - context_scoring_prompt, - filter_question_prompt, multi_context_question_prompt, seed_question_prompt, compress_question_prompt, reasoning_question_prompt, - evolution_elimination_prompt, question_answer_prompt, ) +from ragas.testset.filters import NodeFilter, QuestionFilter, EvolutionFilter rng = default_rng() logger = logging.getLogger(__name__) if t.TYPE_CHECKING: from 
ragas.llms.prompt import Prompt - - -@dataclass -class Filter(ABC): - ... - - -@dataclass -class NodeFilter(Filter): - llm: BaseRagasLLM - threshold: float = 7.5 - - def filter(self, node: Node) -> t.Dict: - return asyncio.get_event_loop().run_until_complete(self.afilter(node)) - - async def afilter(self, node: Node) -> t.Dict: - prompt = context_scoring_prompt.format(context=node.page_content) - results = await self.llm.agenerate_text(prompt=prompt) - output = results.generations[0][0].text.strip() - score = load_as_json(output) - score.update({"score": score.get("score", 0) >= self.threshold}) - return score - - -@dataclass -class QuestionFilter(Filter): - llm: BaseRagasLLM - - def filter(self, question: str) -> bool: - return asyncio.get_event_loop().run_until_complete(self.afilter(question)) - - async def afilter(self, question: str) -> bool: - prompt = filter_question_prompt.format(question=question) - results = await self.llm.agenerate_text(prompt=prompt) - results = results.generations[0][0].text.strip() - json_results = load_as_json(results) - logger.debug("filtered question: %s", json_results) - return json_results.get("verdict") != "No" - - -@dataclass -class EvolutionFilter(Filter): - llm: BaseRagasLLM - - def filter(self, simple_question: str, compressed_question: str) -> bool: - return asyncio.get_event_loop().run_until_complete( - self.afilter(simple_question, compressed_question) - ) - - async def afilter(self, simple_question: str, compressed_question: str) -> bool: - prompt = evolution_elimination_prompt.format( - question1=simple_question, question2=compressed_question - ) - results = await self.llm.agenerate_text(prompt=prompt) - results = results.generations[0][0].text.strip() - json_results = load_as_json(results) - logger.debug("filtered question: %s", json_results) - return json_results.get("verdict") != "No" + from ragas.llms import BaseRagasLLM @dataclass diff --git a/src/ragas/testset/filters.py b/src/ragas/testset/filters.py new file mode 100644 index 000000000..e74e0cef5 --- /dev/null +++ b/src/ragas/testset/filters.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from dataclasses import dataclass +import logging +from abc import ABC +import asyncio +import typing as t + +from ragas.testset.prompts import ( + context_scoring_prompt, + filter_question_prompt, + evolution_elimination_prompt, +) +from ragas.llms.json_load import load_as_json + +if t.TYPE_CHECKING: + from ragas.llms.base import BaseRagasLLM + from ragas.testset.docstore import Node + + +logger = logging.getLogger(__name__) + + +@dataclass +class Filter(ABC): + ... 
+ + +@dataclass +class NodeFilter(Filter): + llm: BaseRagasLLM + threshold: float = 7.5 + + def filter(self, node: Node) -> t.Dict: + return asyncio.get_event_loop().run_until_complete(self.afilter(node)) + + async def afilter(self, node: Node) -> t.Dict: + prompt = context_scoring_prompt.format(context=node.page_content) + results = await self.llm.agenerate_text(prompt=prompt) + output = results.generations[0][0].text.strip() + score = load_as_json(output) + score.update({"score": score.get("score", 0) >= self.threshold}) + return score + + +@dataclass +class QuestionFilter(Filter): + llm: BaseRagasLLM + + def filter(self, question: str) -> bool: + return asyncio.get_event_loop().run_until_complete(self.afilter(question)) + + async def afilter(self, question: str) -> bool: + prompt = filter_question_prompt.format(question=question) + results = await self.llm.agenerate_text(prompt=prompt) + results = results.generations[0][0].text.strip() + json_results = load_as_json(results) + logger.debug("filtered question: %s", json_results) + return json_results.get("verdict") != "No" + + +@dataclass +class EvolutionFilter(Filter): + llm: BaseRagasLLM + + def filter(self, simple_question: str, compressed_question: str) -> bool: + return asyncio.get_event_loop().run_until_complete( + self.afilter(simple_question, compressed_question) + ) + + async def afilter(self, simple_question: str, compressed_question: str) -> bool: + prompt = evolution_elimination_prompt.format( + question1=simple_question, question2=compressed_question + ) + results = await self.llm.agenerate_text(prompt=prompt) + results = results.generations[0][0].text.strip() + json_results = load_as_json(results) + logger.debug("filtered question: %s", json_results) + return json_results.get("verdict") != "No" From 9ce7eedb8ed1703022b22339c8cf30c7bc8a1ad9 Mon Sep 17 00:00:00 2001 From: jjmachan Date: Sun, 21 Jan 2024 15:05:16 -0800 Subject: [PATCH 07/14] fix fmt --- src/ragas/testset/evolutions.py | 16 ++++++++-------- src/ragas/testset/filters.py | 10 +++++----- src/ragas/testset/generator.py | 4 ++-- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py index 9e626216d..693a6cccf 100644 --- a/src/ragas/testset/evolutions.py +++ b/src/ragas/testset/evolutions.py @@ -2,30 +2,30 @@ import logging import typing as t -from abc import ABC, abstractmethod -from dataclasses import dataclass, field +from abc import abstractmethod from collections import namedtuple +from dataclasses import dataclass, field from fsspec.exceptions import asyncio from numpy.random import default_rng from ragas.llms.json_load import load_as_json -from ragas.testset.docstore import Direction, Document, DocumentStore, Node +from ragas.testset.docstore import Direction, DocumentStore, Node +from ragas.testset.filters import EvolutionFilter, NodeFilter, QuestionFilter from ragas.testset.prompts import ( - multi_context_question_prompt, - seed_question_prompt, compress_question_prompt, - reasoning_question_prompt, + multi_context_question_prompt, question_answer_prompt, + reasoning_question_prompt, + seed_question_prompt, ) -from ragas.testset.filters import NodeFilter, QuestionFilter, EvolutionFilter rng = default_rng() logger = logging.getLogger(__name__) if t.TYPE_CHECKING: - from ragas.llms.prompt import Prompt from ragas.llms import BaseRagasLLM + from ragas.llms.prompt import Prompt @dataclass diff --git a/src/ragas/testset/filters.py b/src/ragas/testset/filters.py index e74e0cef5..79d19bd47 
100644 --- a/src/ragas/testset/filters.py +++ b/src/ragas/testset/filters.py @@ -1,17 +1,17 @@ from __future__ import annotations -from dataclasses import dataclass -import logging -from abc import ABC import asyncio +import logging import typing as t +from abc import ABC +from dataclasses import dataclass +from ragas.llms.json_load import load_as_json from ragas.testset.prompts import ( context_scoring_prompt, - filter_question_prompt, evolution_elimination_prompt, + filter_question_prompt, ) -from ragas.llms.json_load import load_as_json if t.TYPE_CHECKING: from ragas.llms.base import BaseRagasLLM diff --git a/src/ragas/testset/generator.py b/src/ragas/testset/generator.py index 1ff77b5ce..94b957034 100644 --- a/src/ragas/testset/generator.py +++ b/src/ragas/testset/generator.py @@ -6,10 +6,10 @@ from llama_index.readers.schema import Document as LlamaindexDocument from ragas.embeddings import BaseRagasEmbeddings +from ragas.executor import Executor from ragas.llms import BaseRagasLLM, LangchainLLMWrapper from ragas.testset.docstore import Document, DocumentStore, InMemoryDocumentStore -from ragas.executor import Executor -from ragas.testset.evolutions import SimpleEvolution, QuestionFilter, NodeFilter +from ragas.testset.evolutions import NodeFilter, QuestionFilter, SimpleEvolution @dataclass From 9a82bc0a7381f6070432e655a3186b979cf32eb8 Mon Sep 17 00:00:00 2001 From: jjmachan Date: Mon, 22 Jan 2024 16:53:36 -0800 Subject: [PATCH 08/14] basics working --- src/ragas/testset/evolutions.py | 123 +++++++++++++++++++------------- src/ragas/testset/generator.py | 90 ++++++++++++++++++----- 2 files changed, 146 insertions(+), 67 deletions(-) diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py index 693a6cccf..28e93d2e2 100644 --- a/src/ragas/testset/evolutions.py +++ b/src/ragas/testset/evolutions.py @@ -9,7 +9,6 @@ from fsspec.exceptions import asyncio from numpy.random import default_rng -from ragas.llms.json_load import load_as_json from ragas.testset.docstore import Direction, DocumentStore, Node from ragas.testset.filters import EvolutionFilter, NodeFilter, QuestionFilter from ragas.testset.prompts import ( @@ -48,12 +47,11 @@ class CurrentNodes: @dataclass class Evolution: - generator_llm: BaseRagasLLM - docstore: DocumentStore - node_filter: NodeFilter - question_filter: QuestionFilter + generator_llm: t.Optional[BaseRagasLLM] = None + docstore: t.Optional[DocumentStore] = None + node_filter: t.Optional[NodeFilter] = None + question_filter: t.Optional[QuestionFilter] = None max_tries: int = 5 - _tries: int = field(default=0, init=False, repr=False) @staticmethod def merge_nodes(nodes: CurrentNodes) -> Node: @@ -62,17 +60,19 @@ def merge_nodes(nodes: CurrentNodes) -> Node: ) async def aretry_evolve( - self, current_nodes: CurrentNodes, update_count: bool = True + self, current_tries: int, current_nodes: CurrentNodes, update_count: bool = True ) -> str: if update_count: - self._tries += 1 - logger.info("retrying evolution: %s times", self._tries) - if self._tries > self.max_tries: + current_tries += 1 + logger.info("retrying evolution: %s times", current_tries) + if current_tries > self.max_tries: # TODO: make this into a custom exception raise ValueError("Max tries reached") - return await self._aevolve(current_nodes) + return await self._aevolve(current_tries, current_nodes) def _transform_question(self, prompt: Prompt, question: str) -> str: + assert self.generator_llm is not None, "generator_llm cannot be None" + results = self.generator_llm.generate_text( 
prompt=prompt.format(question=question) ) @@ -82,6 +82,8 @@ def _get_more_adjacent_nodes(self, current_nodes: CurrentNodes): """ if the evolutions doesn't have enough nodes to frame a question, get more nodes """ + assert self.docstore is not None, "docstore cannot be None" + # get more nodes from above the context window prev_adjacent_node = self.docstore.get_adjacent( current_nodes.nodes[0], Direction.PREV @@ -108,14 +110,16 @@ def evolve(self, current_nodes: CurrentNodes) -> DataRow: return asyncio.get_event_loop().run_until_complete(self.aevolve(current_nodes)) async def aevolve(self, current_nodes: CurrentNodes) -> DataRow: - evolved_question = await self._aevolve(current_nodes) + # init tries with 0 when first called + current_tries = 0 + evolved_question = await self._aevolve(current_tries, current_nodes) return self.generate_datarow( question=evolved_question, current_nodes=current_nodes, ) @abstractmethod - async def _aevolve(self, current_nodes: CurrentNodes) -> str: + async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str: ... def generate_datarow( @@ -125,18 +129,19 @@ def generate_datarow( question_type: str = "", evolution_elimination: bool = False, ): + assert self.generator_llm is not None, "generator_llm cannot be None" + merged_nodes = self.merge_nodes(current_nodes) results = self.generator_llm.generate_text( prompt=question_answer_prompt.format( question=question, context=merged_nodes.page_content ) ) - results = results.generations[0][0].text.strip() - json_results = load_as_json(results) - logger.debug("answer generated: %s", json_results) + answer = results.generations[0][0].text.strip() + logger.debug("answer generated: %s", answer) - # TODO: what do if answer is -1? - answer = json_results.get("answer") + if answer == "-1": + answer = None return DataRow( question=question, @@ -147,32 +152,22 @@ def generate_datarow( ) -@dataclass -class ComplexEvolution(Evolution): - se: SimpleEvolution = field(init=False, repr=False) - evolution_filter: EvolutionFilter = field(init=False, repr=False) - - def __post_init__(self): - # init simple evolution to get seed question - self.se = SimpleEvolution( - generator_llm=self.generator_llm, - docstore=self.docstore, - node_filter=self.node_filter, - question_filter=self.question_filter, - ) - # init evolution filter with critic llm from another filter - self.evolution_filter = EvolutionFilter(self.node_filter.llm) - - -@dataclass +@dataclass(unsafe_hash=True) class SimpleEvolution(Evolution): - async def _aevolve(self, current_nodes: CurrentNodes) -> str: + async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str: + assert self.docstore is not None, "docstore cannot be None" + assert self.node_filter is not None, "node filter cannot be None" + assert self.generator_llm is not None, "generator_llm cannot be None" + assert self.question_filter is not None, "question_filter cannot be None" + merged_node = self.merge_nodes(current_nodes) passed = await self.node_filter.afilter(current_nodes.root_node) if not passed["score"]: nodes = self.docstore.get_random_nodes(k=1) new_current_nodes = CurrentNodes(root_node=nodes[0], nodes=nodes) - return await self.aretry_evolve(new_current_nodes, update_count=False) + return await self.aretry_evolve( + current_tries, new_current_nodes, update_count=False + ) results = self.generator_llm.generate_text( prompt=seed_question_prompt.format(context=merged_node.page_content) @@ -185,16 +180,39 @@ async def _aevolve(self, current_nodes: CurrentNodes) -> 
str: # get more context to rewrite question current_nodes = self._get_more_adjacent_nodes(current_nodes) # retry with new nodes added - return await self.aretry_evolve(current_nodes) + return await self.aretry_evolve(current_tries, current_nodes) else: # if valid question return seed_question @dataclass +class ComplexEvolution(Evolution): + se: t.Optional[SimpleEvolution] = field(default=None, repr=False) + evolution_filter: t.Optional[EvolutionFilter] = field(default=None, repr=False) + + def init_evolution(self): + # init simple evolution to get seed question + self.se = SimpleEvolution( + generator_llm=self.generator_llm, + docstore=self.docstore, + node_filter=self.node_filter, + question_filter=self.question_filter, + ) + # init evolution filter with critic llm from another filter + assert self.node_filter is not None, "node filter cannot be None" + self.evolution_filter = EvolutionFilter(self.node_filter.llm) + + +@dataclass(unsafe_hash=True) class MultiContextEvolution(ComplexEvolution): - async def _aevolve(self, current_nodes: CurrentNodes) -> str: - simple_question = await self.se._aevolve(current_nodes) + async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str: + assert self.docstore is not None, "docstore cannot be None" + assert self.generator_llm is not None, "generator_llm cannot be None" + assert self.question_filter is not None, "question_filter cannot be None" + assert self.se is not None, "simple evolution cannot be None" + + simple_question = await self.se._aevolve(current_tries, current_nodes) logger.debug( "[MultiContextEvolution] simple question generated: %s", simple_question ) @@ -223,7 +241,7 @@ async def _aevolve(self, current_nodes: CurrentNodes) -> str: if not await self.question_filter.afilter(compressed_question): # retry current_nodes = self.se._get_more_adjacent_nodes(current_nodes) - return await self.aretry_evolve(current_nodes) + return await self.aretry_evolve(current_tries, current_nodes) assert self.evolution_filter is not None, "evolution filter cannot be None" if not await self.evolution_filter.afilter( @@ -231,15 +249,19 @@ async def _aevolve(self, current_nodes: CurrentNodes) -> str: ): # retry current_nodes = self.se._get_more_adjacent_nodes(current_nodes) - return await self.aretry_evolve(current_nodes) + return await self.aretry_evolve(current_tries, current_nodes) return compressed_question -@dataclass +@dataclass(unsafe_hash=True) class ReasoningEvolution(ComplexEvolution): - async def _aevolve(self, current_nodes: CurrentNodes) -> str: - simple_question = await self.se._aevolve(current_nodes) + async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str: + assert self.generator_llm is not None, "generator_llm cannot be None" + assert self.question_filter is not None, "question_filter cannot be None" + assert self.se is not None, "simple evolution cannot be None" + + simple_question = await self.se._aevolve(current_tries, current_nodes) logger.debug( "[ReasoningEvolution] simple question generated: %s", simple_question ) @@ -263,7 +285,7 @@ async def _aevolve(self, current_nodes: CurrentNodes) -> str: if not await self.question_filter.afilter(compressed_question): # retry current_nodes = self.se._get_more_adjacent_nodes(current_nodes) - return await self.aretry_evolve(current_nodes) + return await self.aretry_evolve(current_tries, current_nodes) assert self.evolution_filter is not None, "evolution filter cannot be None" if not await self.evolution_filter.afilter( @@ -274,6 +296,11 @@ async def 
_aevolve(self, current_nodes: CurrentNodes) -> str: logger.debug( "evolution_filter failed, retrying with %s", len(current_nodes.nodes) ) - return await self.aretry_evolve(current_nodes) + return await self.aretry_evolve(current_tries, current_nodes) return reasoning_question + + +simple = SimpleEvolution() +multi_context = MultiContextEvolution() +reasoning = ReasoningEvolution() diff --git a/src/ragas/testset/generator.py b/src/ragas/testset/generator.py index 94b957034..b0a2f7628 100644 --- a/src/ragas/testset/generator.py +++ b/src/ragas/testset/generator.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import typing as t from dataclasses import dataclass +import pandas as pd from langchain.embeddings import OpenAIEmbeddings from langchain_openai.chat_models import ChatOpenAI from llama_index.readers.schema import Document as LlamaindexDocument @@ -9,7 +12,35 @@ from ragas.executor import Executor from ragas.llms import BaseRagasLLM, LangchainLLMWrapper from ragas.testset.docstore import Document, DocumentStore, InMemoryDocumentStore -from ragas.testset.evolutions import NodeFilter, QuestionFilter, SimpleEvolution +from ragas.testset.evolutions import ComplexEvolution, CurrentNodes, DataRow, Evolution +from ragas.testset.filters import EvolutionFilter, NodeFilter, QuestionFilter + +Distributions = t.Dict[Evolution, float] + + +@dataclass +class TestDataset: + """ + TestDataset class + """ + + test_data: t.List[DataRow] + + def to_pandas(self) -> pd.DataFrame: + data_samples = [] + for data in self.test_data: + question_type = data.question_type + data = { + "question": data.question, + "context": data.context, + "answer": "" if data.answer is None else data.answer, + "question_type": question_type, + "episode_done": True, + "evolution_elimination": data.evolution_elimination, + } + data_samples.append(data) + + return pd.DataFrame.from_records(data_samples) @dataclass @@ -53,31 +84,52 @@ def with_openai( ) def generate_with_llamaindex_docs( - self, documents: t.Sequence[LlamaindexDocument], test_size: int + self, + documents: t.Sequence[LlamaindexDocument], + test_size: int, + distributions: Distributions = {}, ): # chunk documents and add to docstore self.docstore.add_documents( [Document.from_llamaindex_document(doc) for doc in documents] ) - return self.generate(test_size=test_size) - - def generate(self, test_size: int): - node_filter = NodeFilter(self.critic_llm) - ques_filter = QuestionFilter(self.critic_llm) - exec = Executor() - qs = [] - for i in range(test_size): - se = SimpleEvolution(node_filter, ques_filter) - exec.submit( - se.aevolve, - self.generator_llm, - self.docstore, - name=f"SimpleEvolution-{i}", - ) + return self.generate(test_size=test_size, distributions=distributions) + + def generate(self, test_size: int, distributions: Distributions = {}): + # init filters and evolutions + for evolution in distributions.keys(): + if evolution.generator_llm is None: + evolution.generator_llm = self.generator_llm + if evolution.docstore is None: + evolution.docstore = self.docstore + + if evolution.question_filter is None: + evolution.question_filter = QuestionFilter(llm=self.critic_llm) + if evolution.node_filter is None: + evolution.node_filter = NodeFilter(llm=self.critic_llm) + + if isinstance(evolution, ComplexEvolution): + evolution.init_evolution() + if evolution.evolution_filter is None: + evolution.evolution_filter = EvolutionFilter(llm=self.critic_llm) + + exec = Executor(raise_exceptions=True, is_async=True) + + current_nodes = [ + CurrentNodes(root_node=n, 
nodes=[n]) + for n in self.docstore.get_random_nodes(k=test_size) + ] + for evolution, probability in distributions.items(): + for i in range(round(probability * test_size)): + exec.submit( + evolution.aevolve, + current_nodes[i], + name=f"{evolution.__class__.__name__}-{i}", + ) try: - qs = exec.results() + test_data_rows = exec.results() except ValueError as e: raise e - return qs + return TestDataset(test_data=test_data_rows) From 602ae7bcf11f695bd02791b4a2b89db6dc353fbb Mon Sep 17 00:00:00 2001 From: jjmachan Date: Mon, 22 Jan 2024 17:18:19 -0800 Subject: [PATCH 09/14] final polish --- src/ragas/testset/evolutions.py | 15 ++++++++++++--- src/ragas/testset/generator.py | 15 ++++++++++++--- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py index 28e93d2e2..09e62eaa3 100644 --- a/src/ragas/testset/evolutions.py +++ b/src/ragas/testset/evolutions.py @@ -152,7 +152,7 @@ def generate_datarow( ) -@dataclass(unsafe_hash=True) +@dataclass class SimpleEvolution(Evolution): async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str: assert self.docstore is not None, "docstore cannot be None" @@ -185,6 +185,9 @@ async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str # if valid question return seed_question + def __hash__(self): + return hash(self.__class__.__name__) + @dataclass class ComplexEvolution(Evolution): @@ -204,7 +207,7 @@ def init_evolution(self): self.evolution_filter = EvolutionFilter(self.node_filter.llm) -@dataclass(unsafe_hash=True) +@dataclass class MultiContextEvolution(ComplexEvolution): async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str: assert self.docstore is not None, "docstore cannot be None" @@ -253,8 +256,11 @@ async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str return compressed_question + def __hash__(self): + return hash(self.__class__.__name__) -@dataclass(unsafe_hash=True) + +@dataclass class ReasoningEvolution(ComplexEvolution): async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str: assert self.generator_llm is not None, "generator_llm cannot be None" @@ -300,6 +306,9 @@ async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str return reasoning_question + def __hash__(self): + return hash(self.__class__.__name__) + simple = SimpleEvolution() multi_context = MultiContextEvolution() diff --git a/src/ragas/testset/generator.py b/src/ragas/testset/generator.py index b0a2f7628..748ff615e 100644 --- a/src/ragas/testset/generator.py +++ b/src/ragas/testset/generator.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import typing as t from dataclasses import dataclass @@ -15,6 +16,7 @@ from ragas.testset.evolutions import ComplexEvolution, CurrentNodes, DataRow, Evolution from ragas.testset.filters import EvolutionFilter, NodeFilter, QuestionFilter +logger = logging.getLogger(__name__) Distributions = t.Dict[Evolution, float] @@ -88,6 +90,7 @@ def generate_with_llamaindex_docs( documents: t.Sequence[LlamaindexDocument], test_size: int, distributions: Distributions = {}, + **kwargs, ): # chunk documents and add to docstore self.docstore.add_documents( @@ -96,9 +99,11 @@ def generate_with_llamaindex_docs( return self.generate(test_size=test_size, distributions=distributions) - def generate(self, test_size: int, distributions: Distributions = {}): + def generate( + self, test_size: int, distributions: Distributions = {}, 
show_debug_logs=False + ): # init filters and evolutions - for evolution in distributions.keys(): + for evolution in distributions: if evolution.generator_llm is None: evolution.generator_llm = self.generator_llm if evolution.docstore is None: @@ -113,8 +118,12 @@ def generate(self, test_size: int, distributions: Distributions = {}): evolution.init_evolution() if evolution.evolution_filter is None: evolution.evolution_filter = EvolutionFilter(llm=self.critic_llm) + if show_debug_logs: + from ragas.utils import patch_logger + + patch_logger("ragas.testset.evolutions", logging.DEBUG) - exec = Executor(raise_exceptions=True, is_async=True) + exec = Executor(desc="Generating", raise_exceptions=True, is_async=True) current_nodes = [ CurrentNodes(root_node=n, nodes=[n]) From 25d8e15dd2a523a7cda5d84690a2c8c009dbbc8f Mon Sep 17 00:00:00 2001 From: jjmachan Date: Mon, 22 Jan 2024 17:28:19 -0800 Subject: [PATCH 10/14] added benchmark --- src/ragas/testset/generator.py | 4 ++-- tests/benchmarks/benchmark_testsetgen.py | 7 ++++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/ragas/testset/generator.py b/src/ragas/testset/generator.py index 748ff615e..29714a9a9 100644 --- a/src/ragas/testset/generator.py +++ b/src/ragas/testset/generator.py @@ -13,11 +13,11 @@ from ragas.executor import Executor from ragas.llms import BaseRagasLLM, LangchainLLMWrapper from ragas.testset.docstore import Document, DocumentStore, InMemoryDocumentStore -from ragas.testset.evolutions import ComplexEvolution, CurrentNodes, DataRow, Evolution +from ragas.testset.evolutions import ComplexEvolution, CurrentNodes, DataRow from ragas.testset.filters import EvolutionFilter, NodeFilter, QuestionFilter logger = logging.getLogger(__name__) -Distributions = t.Dict[Evolution, float] +Distributions = t.Dict[t.Any, float] @dataclass diff --git a/tests/benchmarks/benchmark_testsetgen.py b/tests/benchmarks/benchmark_testsetgen.py index e78e7eb48..282e2d75f 100644 --- a/tests/benchmarks/benchmark_testsetgen.py +++ b/tests/benchmarks/benchmark_testsetgen.py @@ -3,10 +3,13 @@ from llama_index import download_loader +from ragas.testset.evolutions import multi_context, reasoning, simple from ragas.testset.generator import TestsetGenerator generator = TestsetGenerator.with_openai() +distributions = {simple: 0.5, multi_context: 0.4, reasoning: 0.1} + def get_documents(): SemanticScholarReader = download_loader("SemanticScholarReader") @@ -30,7 +33,9 @@ def get_documents(): os.environ["PYTHONASYNCIODEBUG"] = "1" print("Starting [Asyncio]") start = time.time() - generator.generate_with_llamaindex_docs(documents, test_size=100) + generator.generate_with_llamaindex_docs( + documents=documents, test_size=100, distributions=distributions + ) print(f"Time taken: {time.time() - start:.2f}s") # Threads From c4ebd3cea9d1352b793face43ccb7358bff19c8a Mon Sep 17 00:00:00 2001 From: jjmachan Date: Tue, 23 Jan 2024 23:10:14 -0800 Subject: [PATCH 11/14] polish things --- src/ragas/executor.py | 5 +++ src/ragas/testset/docstore.py | 18 ++++----- src/ragas/testset/evolutions.py | 68 ++++++++++++++++++--------------- src/ragas/testset/generator.py | 15 ++------ 4 files changed, 56 insertions(+), 50 deletions(-) diff --git a/src/ragas/executor.py b/src/ragas/executor.py index 43ce7bad0..a977b3894 100644 --- a/src/ragas/executor.py +++ b/src/ragas/executor.py @@ -10,6 +10,7 @@ @dataclass class Executor: desc: str = "Evaluating" + keep_progress_bar: bool = True is_async: bool = True max_workers: t.Optional[int] = None futures: t.List[t.Any] = 
field(default_factory=list, repr=False) @@ -74,6 +75,8 @@ async def _aresults(self) -> t.List[t.Any]: asyncio.as_completed(self.futures), desc=self.desc, total=len(self.futures), + # whether you want to keep the progress bar after completion + leave=not self.keep_progress_bar, ): r = (-1, np.nan) try: @@ -109,6 +112,8 @@ def results(self) -> t.List[t.Any]: as_completed(self.futures), desc=self.desc, total=len(self.futures), + # whether you want to keep the progress bar after completion + leave=not self.keep_progress_bar, ): r = (-1, np.nan) try: diff --git a/src/ragas/testset/docstore.py b/src/ragas/testset/docstore.py index 44e88a6b1..078fae391 100644 --- a/src/ragas/testset/docstore.py +++ b/src/ragas/testset/docstore.py @@ -14,8 +14,8 @@ from langchain_core.pydantic_v1 import Field from llama_index.readers.schema import Document as LlamaindexDocument -from ragas.async_utils import run_async_tasks from ragas.embeddings.base import BaseRagasEmbeddings +from ragas.executor import Executor Embedding = t.Union[t.List[float], npt.NDArray[np.float64]] logger = logging.getLogger(__name__) @@ -204,22 +204,22 @@ def add_nodes( assert self.embeddings is not None, "Embeddings must be set" # NOTE: Adds everything in async mode for now. - embed_tasks = [] - docs_to_embed = [] + nodes_to_embed = [] # get embeddings for the docs + executor = Executor( + desc="embedding nodes", is_async=True, raise_exceptions=True + ) for n in nodes: if n.embedding is None: - embed_tasks.append(self.embeddings.aembed_query(n.page_content)) - docs_to_embed.append(n) + nodes_to_embed.append(n) + executor.submit(self.embeddings.aembed_query, n.page_content) else: self.nodes.append(n) self.node_map[n.doc_id] = n self.node_embeddings_list.append(n.embedding) - embeddings = run_async_tasks( - embed_tasks, show_progress=show_progress, progress_bar_desc=desc - ) - for n, embedding in zip(docs_to_embed, embeddings): + embeddings = executor.results() + for n, embedding in zip(nodes_to_embed, embeddings): n.embedding = embedding self.nodes.append(n) self.node_map[n.doc_id] = n diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py index 09e62eaa3..d234ae3a1 100644 --- a/src/ragas/testset/evolutions.py +++ b/src/ragas/testset/evolutions.py @@ -3,12 +3,14 @@ import logging import typing as t from abc import abstractmethod -from collections import namedtuple from dataclasses import dataclass, field from fsspec.exceptions import asyncio +from langchain_core.pydantic_v1 import BaseModel from numpy.random import default_rng +from ragas.llms import BaseRagasLLM +from ragas.llms.prompt import Prompt from ragas.testset.docstore import Direction, DocumentStore, Node from ragas.testset.filters import EvolutionFilter, NodeFilter, QuestionFilter from ragas.testset.prompts import ( @@ -22,10 +24,6 @@ rng = default_rng() logger = logging.getLogger(__name__) -if t.TYPE_CHECKING: - from ragas.llms import BaseRagasLLM - from ragas.llms.prompt import Prompt - @dataclass class CurrentNodes: @@ -33,21 +31,19 @@ class CurrentNodes: nodes: t.List[Node] = field(default_factory=list) -DataRow = namedtuple( - "DataRow", - [ - "question", - "context", - "answer", - "question_type", - "evolution_elimination", - ], -) +EvolutionOutput = t.Tuple[str, CurrentNodes, str] + + +class DataRow(BaseModel): + question: str + context: str + answer: str + evolution_type: str @dataclass class Evolution: - generator_llm: t.Optional[BaseRagasLLM] = None + generator_llm: BaseRagasLLM = t.cast(BaseRagasLLM, None) docstore: t.Optional[DocumentStore] = 
    node_filter: t.Optional[NodeFilter] = None
     question_filter: t.Optional[QuestionFilter] = None
@@ -61,7 +57,7 @@ def merge_nodes(nodes: CurrentNodes) -> Node:

     async def aretry_evolve(
         self, current_tries: int, current_nodes: CurrentNodes, update_count: bool = True
-    ) -> str:
+    ) -> EvolutionOutput:
         if update_count:
             current_tries += 1
             logger.info("retrying evolution: %s times", current_tries)
@@ -112,22 +108,29 @@ def evolve(self, current_nodes: CurrentNodes) -> DataRow:
     async def aevolve(self, current_nodes: CurrentNodes) -> DataRow:
         # init tries with 0 when first called
         current_tries = 0
-        evolved_question = await self._aevolve(current_tries, current_nodes)
+        (
+            evolved_question,
+            current_nodes,
+            evolution_type,
+        ) = await self._aevolve(current_tries, current_nodes)
+
         return self.generate_datarow(
             question=evolved_question,
             current_nodes=current_nodes,
+            evolution_type=evolution_type,
         )

     @abstractmethod
-    async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str:
+    async def _aevolve(
+        self, current_tries: int, current_nodes: CurrentNodes
+    ) -> EvolutionOutput:
         ...

     def generate_datarow(
         self,
         question: str,
         current_nodes: CurrentNodes,
-        question_type: str = "",
-        evolution_elimination: bool = False,
+        evolution_type: str,
     ):

         assert self.generator_llm is not None, "generator_llm cannot be None"
@@ -146,15 +149,16 @@
         return DataRow(
             question=question,
             context=merged_nodes.page_content,
-            answer=answer,
-            question_type=question_type,
-            evolution_elimination=evolution_elimination,
+            answer="" if answer is None else answer,
+            evolution_type=evolution_type,
         )


 @dataclass
 class SimpleEvolution(Evolution):
-    async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str:
+    async def _aevolve(
+        self, current_tries: int, current_nodes: CurrentNodes
+    ) -> EvolutionOutput:
         assert self.docstore is not None, "docstore cannot be None"
         assert self.node_filter is not None, "node filter cannot be None"
         assert self.generator_llm is not None, "generator_llm cannot be None"
@@ -183,7 +187,7 @@ async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str
             return await self.aretry_evolve(current_tries, current_nodes)
         else:
             # if valid question
-            return seed_question
+            return seed_question, current_nodes, "simple"

     def __hash__(self):
         return hash(self.__class__.__name__)
@@ -209,7 +213,9 @@ def init_evolution(self):

 @dataclass
 class MultiContextEvolution(ComplexEvolution):
-    async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str:
+    async def _aevolve(
+        self, current_tries: int, current_nodes: CurrentNodes
+    ) -> EvolutionOutput:
         assert self.docstore is not None, "docstore cannot be None"
         assert self.generator_llm is not None, "generator_llm cannot be None"
         assert self.question_filter is not None, "question_filter cannot be None"
@@ -254,7 +260,7 @@ async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str
             current_nodes = self.se._get_more_adjacent_nodes(current_nodes)
             return await self.aretry_evolve(current_tries, current_nodes)

-        return compressed_question
+        return compressed_question, current_nodes, "multi_context"

     def __hash__(self):
         return hash(self.__class__.__name__)
@@ -262,7 +268,9 @@ def __hash__(self):

 @dataclass
 class ReasoningEvolution(ComplexEvolution):
-    async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str:
+    async def _aevolve(
+        self, current_tries: int, current_nodes: CurrentNodes
+    ) -> EvolutionOutput:
         assert self.generator_llm is not None, "generator_llm cannot be None"
        assert self.question_filter is not None, "question_filter cannot be None"
         assert self.se is not None, "simple evolution cannot be None"
@@ -304,7 +312,7 @@ async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str
             )
             return await self.aretry_evolve(current_tries, current_nodes)

-        return reasoning_question
+        return reasoning_question, current_nodes, "reasoning"

     def __hash__(self):
         return hash(self.__class__.__name__)
diff --git a/src/ragas/testset/generator.py b/src/ragas/testset/generator.py
index 29714a9a9..d7ad6c833 100644
--- a/src/ragas/testset/generator.py
+++ b/src/ragas/testset/generator.py
@@ -5,7 +5,7 @@
 from dataclasses import dataclass

 import pandas as pd
-from langchain.embeddings import OpenAIEmbeddings
+from langchain_openai.embeddings import OpenAIEmbeddings
 from langchain_openai.chat_models import ChatOpenAI
 from llama_index.readers.schema import Document as LlamaindexDocument

@@ -31,16 +31,9 @@ class TestDataset:

     def to_pandas(self) -> pd.DataFrame:
         data_samples = []
         for data in self.test_data:
-            question_type = data.question_type
-            data = {
-                "question": data.question,
-                "context": data.context,
-                "answer": "" if data.answer is None else data.answer,
-                "question_type": question_type,
-                "episode_done": True,
-                "evolution_elimination": data.evolution_elimination,
-            }
-            data_samples.append(data)
+            data_dict = dict(data)
+            data_dict["episode_done"] = True
+            data_samples.append(data_dict)

         return pd.DataFrame.from_records(data_samples)

From 9ff9417f437e1aa88d366c3a40298702524f10e0 Mon Sep 17 00:00:00 2001
From: jjmachan
Date: Tue, 23 Jan 2024 23:15:29 -0800
Subject: [PATCH 12/14] remove testset_gen

---
 src/ragas/testset/__init__.py          |   2 +-
 src/ragas/testset/evolutions.py        |   5 +-
 src/ragas/testset/generator.py         |   2 +-
 src/ragas/testset/testset_generator.py | 547 ------------------------
 4 files changed, 5 insertions(+), 551 deletions(-)
 delete mode 100644 src/ragas/testset/testset_generator.py

diff --git a/src/ragas/testset/__init__.py b/src/ragas/testset/__init__.py
index b065c3f9b..61be28859 100644
--- a/src/ragas/testset/__init__.py
+++ b/src/ragas/testset/__init__.py
@@ -1,3 +1,3 @@
-from ragas.testset.testset_generator import TestsetGenerator
+from ragas.testset.generator import TestsetGenerator

 __all__ = ["TestsetGenerator"]
diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py
index d234ae3a1..0f9002f89 100644
--- a/src/ragas/testset/evolutions.py
+++ b/src/ragas/testset/evolutions.py
@@ -31,6 +31,7 @@ class CurrentNodes:
     nodes: t.List[Node] = field(default_factory=list)


+# (question, current_nodes, evolution_type)
 EvolutionOutput = t.Tuple[str, CurrentNodes, str]


@@ -221,7 +222,7 @@ async def _aevolve(
         assert self.question_filter is not None, "question_filter cannot be None"
         assert self.se is not None, "simple evolution cannot be None"

-        simple_question = await self.se._aevolve(current_tries, current_nodes)
+        simple_question, _, _ = await self.se._aevolve(current_tries, current_nodes)
         logger.debug(
             "[MultiContextEvolution] simple question generated: %s", simple_question
         )
@@ -275,7 +276,7 @@ async def _aevolve(
         assert self.question_filter is not None, "question_filter cannot be None"
         assert self.se is not None, "simple evolution cannot be None"

-        simple_question = await self.se._aevolve(current_tries, current_nodes)
+        simple_question, _, _ = await self.se._aevolve(current_tries, current_nodes)
         logger.debug(
             "[ReasoningEvolution] simple question generated: %s", simple_question
        )
diff --git a/src/ragas/testset/generator.py b/src/ragas/testset/generator.py
index d7ad6c833..870cd5eb6 100644
--- a/src/ragas/testset/generator.py
+++ b/src/ragas/testset/generator.py
@@ -5,8 +5,8 @@
 from dataclasses import dataclass

 import pandas as pd
-from langchain_openai.embeddings import OpenAIEmbeddings
 from langchain_openai.chat_models import ChatOpenAI
+from langchain_openai.embeddings import OpenAIEmbeddings
 from llama_index.readers.schema import Document as LlamaindexDocument

 from ragas.embeddings import BaseRagasEmbeddings
diff --git a/src/ragas/testset/testset_generator.py b/src/ragas/testset/testset_generator.py
deleted file mode 100644
index 0d9481b83..000000000
--- a/src/ragas/testset/testset_generator.py
+++ /dev/null
@@ -1,547 +0,0 @@
-from __future__ import annotations
-
-import logging
-import typing as t
-from collections import defaultdict, namedtuple
-from dataclasses import dataclass
-
-import numpy as np
-import numpy.testing as npt
-import pandas as pd
-from langchain.chat_models import ChatOpenAI
-from langchain.embeddings import OpenAIEmbeddings
-from langchain.embeddings.base import Embeddings
-from langchain.prompts import ChatPromptTemplate
-from langchain.schema.document import Document as LangchainDocument
-from numpy.random import default_rng
-from tqdm.notebook import tqdm
-
-try:
-    from llama_index.indices.query.embedding_utils import get_top_k_embeddings
-    from llama_index.node_parser import SimpleNodeParser
-    from llama_index.readers.schema import Document as LlamaindexDocument
-    from llama_index.schema import BaseNode
-except ImportError:
-    raise ImportError(
-        "llama_index must be installed to use this function. "
-        "Please, install it with `pip install llama_index`."
-    )
-
-from ragas.llms.base import BaseRagasLLM, LangchainLLMWrapper
-from ragas.llms.json_load import load_as_json
-from ragas.testset.prompts import (
-    ANSWER_FORMULATE,
-    COMPRESS_QUESTION,
-    CONDITIONAL_QUESTION,
-    CONTEXT_FORMULATE,
-    CONVERSATION_QUESTION,
-    EVOLUTION_ELIMINATION,
-    FILTER_QUESTION,
-    INFORMAL_QUESTION,
-    MULTICONTEXT_QUESTION,
-    REASONING_QUESTION,
-    REWRITE_QUESTION,
-    SCORE_CONTEXT,
-    SEED_QUESTION,
-    TABLE_QA,
-)
-
-DEFAULT_TEST_DISTRIBUTION = {
-    "simple": 0.4,
-    "reasoning": 0.3,
-    "multi_context": 0.0,
-    "conditional": 0.3,
-}
-
-question_deep_map = {
-    "reasoning": "_reasoning_question",
-    "conditional": "_condition_question",
-}
-
-DataRow = namedtuple(
-    "DataRow",
-    [
-        "seed_question",
-        "question",
-        "context",
-        "answer",
-        "question_type",
-        "evolution_elimination",
-    ],
-)
-
-logger = logging.getLogger(__name__)
-
-
-@dataclass
-class TestDataset:
-    """
-    TestDataset class
-    """
-
-    test_data: t.List[DataRow]
-
-    def to_pandas(self) -> pd.DataFrame:
-        data_samples = []
-        for data in self.test_data:
-            is_conv = len(data.context) > 1
-            question_type = data.question_type
-            data = [
-                {
-                    "seed_question": seed,
-                    "question": qstn,
-                    "context": ctx,
-                    "answer": ans,
-                    "question_type": question_type,
-                    "episode_done": True,
-                    "evolution_elimination": data.evolution_elimination,
-                }
-                for seed, qstn, ctx, ans in zip(
-                    data.seed_question, data.question, data.context, data.answer
-                )
-            ]
-            if is_conv:
-                data[0].update({"episode_done": False})
-            data_samples.extend(data)
-
-        return pd.DataFrame.from_records(data_samples)
-
-
-class TestsetGenerator:
-
-    """
-    Ragas Test Set Generator
-
-    Attributes
-    ----------
-    generator_llm: BaseRagasLLM
-        LLM used for all the generator operations in the TestGeneration paradigm.
-    critique_llm: BaseRagasLLM
-        LLM used for all the filtering and scoring operations in TestGeneration
-        paradigm.
-    embeddings_model: Embeddings
-        Embeddings used for vectorizing nodes when required.
-    chat_qa: float
-        Determines the fraction of conversational questions the resulting test set.
-    chunk_size: int
-        The chunk size of nodes created from data.
-    test_distribution : dict
-        Distribution of different types of questions to be generated from given
-        set of documents. Defaults to {"easy":0.1, "reasoning":0.4, "conversation":0.5}
-    """
-
-    def __init__(
-        self,
-        generator_llm: BaseRagasLLM,
-        critic_llm: BaseRagasLLM,
-        embeddings_model: Embeddings,
-        testset_distribution: t.Optional[t.Dict[str, float]] = None,
-        chat_qa: float = 0.0,
-        chunk_size: int = 356,
-        seed: int = 42,
-    ) -> None:
-        self.generator_llm = generator_llm
-        self.critic_llm = critic_llm
-        self.embedding_model = embeddings_model
-        testset_distribution = testset_distribution or DEFAULT_TEST_DISTRIBUTION
-        npt.assert_almost_equal(
-            1,
-            sum(testset_distribution.values()),
-            err_msg="Sum of distribution should be 1",
-        )
-
-        probs = np.cumsum(list(testset_distribution.values()))
-        types = testset_distribution.keys()
-        self.testset_distribution = dict(zip(types, probs))
-
-        self.chat_qa = chat_qa
-        self.chunk_size = chunk_size
-        self.threshold = 7.5
-        self.max_fixes = 2
-        self.rng = default_rng(seed)
-
-    @classmethod
-    def from_default(
-        cls,
-        openai_generator_llm: str = "gpt-3.5-turbo",
-        openai_filter_llm: str = "gpt-4",
-        chat_qa: float = 0.3,
-        chunk_size: int = 512,
-        testset_distribution: dict = DEFAULT_TEST_DISTRIBUTION,
-    ):
-        generator_llm = LangchainLLMWrapper(
-            langchain_llm=ChatOpenAI(model=openai_generator_llm)  # type: ignore
-        )
-        critic_llm = LangchainLLMWrapper(
-            langchain_llm=ChatOpenAI(model=openai_filter_llm)  # type: ignore
-        )
-        embeddings_model = OpenAIEmbeddings()  # type: ignore
-        return cls(
-            generator_llm=generator_llm,
-            critic_llm=critic_llm,
-            embeddings_model=embeddings_model,
-            chat_qa=chat_qa,
-            chunk_size=chunk_size,
-            testset_distribution=testset_distribution,
-        )
-
-    def _get_evolve_type(self) -> str:
-        """
-        Decides question evolution type based on probability
-        """
-        prob = self.rng.uniform(0, 1)
-        return next(
-            (
-                key
-                for key in self.testset_distribution.keys()
-                if prob <= self.testset_distribution[key]
-            ),
-            "simple",
-        )
-
-    def _filter_context(self, context: str) -> t.Dict:
-        """
-        context: str
-            The input context
-
-        Checks if the context is has information worthy of framing a question
-        """
-        human_prompt = SCORE_CONTEXT.format(context=context)
-        prompt = ChatPromptTemplate.from_messages([human_prompt])
-        results = self.critic_llm.generate_text_with_hmpt(prompts=[prompt])
-        output = results.generations[0][0].text.strip()
-        output = load_as_json(output)
-        output.update({"score": output.get("score", 0) >= self.threshold})
-        return output
-
-    def _seed_question(self, context: str, is_table_present: bool) -> t.List[str]:
-        if is_table_present:
-            human_prompt = TABLE_QA.format(context=context)
-        else:
-            from ragas.testset.prompts import demonstrations
-
-            sample = self.rng.choice(demonstrations, 1)[0]  # type: ignore
-            questions = self.rng.choice(sample["questions"], 2, replace=False)
-            questions = (
-                "{"
-                + str(
-                    {k: v for dic in questions.tolist() for k, v in dic.items()}
-                ).replace("'", '"')
-                + "}"
-            )
-            demo = f'Context:{sample["context"]}\nQuestions:{questions}'
-            human_prompt = SEED_QUESTION.format(demonstration=demo, context=context)
-
-        prompt = ChatPromptTemplate.from_messages([human_prompt])
-        results = self.generator_llm.generate_text_with_hmpt(prompts=[prompt])
-        results = results.generations[0][0].text
-        if is_table_present:
-            return [results]
-        else:
-            results = load_as_json(results)
-            return [v for v in results.values()]
-
-    def _filter_question(self, question: str) -> bool:
-        human_prompt = FILTER_QUESTION.format(question=question)
-        prompt = ChatPromptTemplate.from_messages([human_prompt])
-
-        results = self.critic_llm.generate_text_with_hmpt(prompts=[prompt])
-        results = results.generations[0][0].text.strip()
-        json_results = load_as_json(results)
-        logger.debug("%s", json_results)
-        return json_results.get("verdict") != "No"
-
-    def _rewrite_question(self, question: str, context: str) -> str:
-        human_prompt = REWRITE_QUESTION.format(question=question, context=context)
-        prompt = ChatPromptTemplate.from_messages([human_prompt])
-
-        results = self.generator_llm.generate_text_with_hmpt(prompts=[prompt])
-        results = results.generations[0][0].text.strip()
-        return results
-
-    def _evolution_elimination(self, question1: str, question2: str) -> bool:
-        human_prompt = EVOLUTION_ELIMINATION.format(
-            question1=question1, question2=question2
-        )
-        prompt = ChatPromptTemplate.from_messages([human_prompt])
-
-        results = self.critic_llm.generate_text_with_hmpt(prompts=[prompt])
-        results = results.generations[0][0].text.strip()
-        json_results = load_as_json(results)
-        return json_results.get("verdict") != "Not Equal"
-
-    def _reasoning_question(self, question: str, context: str) -> str:
-        return self._qc_template(REASONING_QUESTION, question, context)
-
-    def _condition_question(self, question: str, context: str) -> str:
-        return self._qc_template(CONDITIONAL_QUESTION, question, context)
-
-    def _multicontext_question(
-        self, question: str, context1: str, context2: str
-    ) -> str:
-        human_prompt = MULTICONTEXT_QUESTION.format(
-            question=question, context1=context1, context2=context2
-        )
-        prompt = ChatPromptTemplate.from_messages([human_prompt])
-        results = self.generator_llm.generate_text_with_hmpt(prompts=[prompt])
-        return results.generations[0][0].text.strip()
-
-    def _compress_question(self, question: str) -> str:
-        return self._question_transformation(COMPRESS_QUESTION, question=question)
-
-    def _conversational_question(self, question: str) -> str:
-        return self._question_transformation(CONVERSATION_QUESTION, question=question)
-
-    def _question_transformation(self, prompt, question: str) -> str:
-        human_prompt = prompt.format(question=question)
-        prompt = ChatPromptTemplate.from_messages([human_prompt])
-        results = self.generator_llm.generate_text_with_hmpt(prompts=[prompt])
-        return results.generations[0][0].text.strip()
-
-    def _qc_template(self, prompt, question, context) -> str:
-        human_prompt = prompt.format(question=question, context=context)
-        prompt = ChatPromptTemplate.from_messages([human_prompt])
-        results = self.generator_llm.generate_text_with_hmpt(prompts=[prompt])
-        return results.generations[0][0].text.strip()
-
-    def _generate_answer(self, question: str, context: list[str]) -> t.List[str]:
-        return [
-            self._qc_template(ANSWER_FORMULATE, qstn, context[i])
-            for i, qstn in enumerate(question.split("\n"))
-        ]
-
-    def _generate_context(self, question: str, text_chunk: str) -> t.List[str]:
-        return [
-            self._qc_template(CONTEXT_FORMULATE, qstn, text_chunk)
-            for qstn in question.split("\n")
-        ]
-
-    def _question_transform(self, question: str) -> str:
-        output = []
-        for qstn in question.split("\n"):
-            human_prompt = INFORMAL_QUESTION.format(question=qstn)
-            prompt = ChatPromptTemplate.from_messages([human_prompt])
-            results = self.generator_llm.generate_text_with_hmpt(prompts=[prompt])
-            output.append(results.generations[0][0].text.strip())
-
-        return "\n".join(output)
-
-    def _remove_nodes(
-        self, available_indices: list[BaseNode], node_idx: list
-    ) -> t.List[BaseNode]:
-        for idx in node_idx:
-            if idx in available_indices:
-                available_indices.remove(idx)
-        return available_indices
-
-    def _generate_doc_nodes_map(
-        self, document_nodes: t.List[BaseNode]
-    ) -> t.Dict[str, t.List[BaseNode]]:
-        doc_nodes_map: t.Dict[str, t.List[BaseNode]] = defaultdict(list[BaseNode])
-        for node in document_nodes:
-            file_name = node.metadata.get("file_name")
-            if file_name:
-                doc_nodes_map[file_name].append(node)
-
-        return doc_nodes_map  # type: ignore
-
-    def _get_neighbour_node(
-        self,
-        node: BaseNode,
-        related_nodes: list[BaseNode],
-        max_tokens=1000,
-        after: bool = True,
-    ) -> t.List[BaseNode]:
-        if len(related_nodes) < 2:
-            logger.warn("No neighbors exists")
-            return [node]
-        idx = related_nodes.index(node) if node in related_nodes else []
-        if idx:
-            tokens = 0
-            nodes = []
-            inc = 1 if after else -1
-            while tokens < max_tokens and idx >= 0 and idx < len(related_nodes):  # type: ignore
-                nodes.append(related_nodes[idx])  # type: ignore
-                idx += inc  # type: ignore
-                # TODO: replace split with tikitoken
-                tokens += len(related_nodes[idx].get_content().split())  # type: ignore
-
-            return nodes if after else nodes[::-1]
-        return [node]
-
-    def _embed_nodes(self, nodes: t.List[BaseNode]) -> t.Dict[str, t.List[float]]:
-        embeddings = {}
-        for node in nodes:
-            embeddings[node.id_] = list(
-                self.embedding_model.embed_query(node.get_content())
-            )
-
-        return embeddings
-
-    def generate(
-        self,
-        documents: list[LlamaindexDocument] | list[LangchainDocument],
-        test_size: int,
-    ) -> TestDataset:
-        if not isinstance(documents[0], (LlamaindexDocument, LangchainDocument)):
-            raise ValueError(
-                "Testset Generatation only supports LlamaindexDocuments or LangchainDocuments"  # noqa
-            )
-
-        if isinstance(documents[0], LangchainDocument):
-            # cast to LangchainDocument since its the only case here
-            documents = t.cast(t.List[LangchainDocument], documents)
-            documents = [
-                LlamaindexDocument.from_langchain_format(doc) for doc in documents
-            ]
-        # Convert documents into nodes
-        # TODO: modify this to
-        # each node should contain docs of preffered chunk size
-        # append document to provide enough context
-        # Use metadata for this.
-        node_parser = SimpleNodeParser.from_defaults(
-            chunk_size=self.chunk_size, chunk_overlap=0, include_metadata=True
-        )
-        documents = t.cast(t.List[LlamaindexDocument], documents)
-        document_nodes: t.List[BaseNode] = node_parser.get_nodes_from_documents(
-            documents=documents
-        )
-        # maximum 1 seed question per node
-        if test_size > len(document_nodes):
-            raise ValueError(
-                """Maximum possible number of samples exceeded,
-                reduce test_size or add more documents"""
-            )
-
-        available_nodes = document_nodes
-        doc_nodes_map = self._generate_doc_nodes_map(document_nodes)
-        count = 0
-        samples = []
-
-        pbar = tqdm(total=test_size)
-        while count < test_size and available_nodes != []:
-            evolve_type = self._get_evolve_type()
-            curr_node = self.rng.choice(np.array(available_nodes), size=1)[0]
-            available_nodes = self._remove_nodes(available_nodes, [curr_node])
-
-            neighbor_nodes = doc_nodes_map[curr_node.metadata["file_name"]]
-
-            # Append multiple nodes randomly to remove chunking bias
-            if len(curr_node.get_content().split()) < self.chunk_size:
-                size = self.chunk_size - len(curr_node.get_content().split())
-                nodes = self._get_neighbour_node(
-                    curr_node, neighbor_nodes, max_tokens=size, after=False
-                )
-            else:
-                nodes = [curr_node]
-
-            text_chunk = "\n".join([node.get_content() for node in nodes])
-            logger.debug(
-                "Len of text chunks %s %s", len(nodes), len(text_chunk.split())
-            )
-            context_filter = self._filter_context(text_chunk)
-            if not context_filter.get("score"):
-                continue
-
-            # is_table_qa = context_filter.get("is_table_present", False)
-            is_table_qa = False
-            seed_questions = self._seed_question(text_chunk, is_table_qa)
-            evolve_type = (
-                "simple"
-                if ((evolve_type == "multi_context") and (is_table_qa))
-                else evolve_type
-            )
-            logger.debug("seed question %s", seed_questions)
-            for seed_question in seed_questions:
-                is_valid_question = self._filter_question(seed_question)
-                tries = 1
-
-                while tries < self.max_fixes and not is_valid_question:
-                    nodes = self._get_neighbour_node(
-                        nodes[0], neighbor_nodes, max_tokens=500, after=False
-                    )
-                    text_chunk = "\n".join([node.get_content() for node in nodes])
-                    seed_question = self._rewrite_question(
-                        question=seed_question, context=text_chunk
-                    )
-                    logger.debug("rewritten question %s", seed_question)
-                    is_valid_question = self._filter_question(seed_question)
-                    tries += 1
-
-                if not is_valid_question:
-                    continue
-
-                if evolve_type == "multi_context":
-                    # Find most similar chunk in same document
-                    # TODO: handle cases where neighbour nodes is null, ie multi context across documents
-                    # First preference - nodes from same document, second preference - other docs
-                    node_embedding = self._embed_nodes([nodes[-1]])
-                    neighbor_nodes = self._remove_nodes(neighbor_nodes, nodes)
-                    neighbor_emb = self._embed_nodes(neighbor_nodes)
-
-                    _, indices = get_top_k_embeddings(
-                        list(node_embedding.values())[0],
-                        list(neighbor_emb.values()),
-                        similarity_cutoff=self.threshold / 10,
-                    )
-                    if indices:
-                        # type cast indices from list[Any] to list[int]
-                        indices = t.cast(t.List[int], indices)
-                        best_neighbor = neighbor_nodes[indices[0]]
-                        question = self._multicontext_question(
-                            question=seed_question,
-                            context1=text_chunk,
-                            context2=best_neighbor.get_content(),
-                        )
-                        text_chunk = "\n".join(
-                            [text_chunk, best_neighbor.get_content()]
-                        )
-                    else:
-                        continue
-
-                # for reasoning and conditional modes, evolve question with the
-                # functions from question_deep_map
-                else:
-                    evolve_fun = question_deep_map.get(evolve_type)
-                    question = (
-                        getattr(self, evolve_fun)(seed_question, text_chunk)
-                        if evolve_fun
-                        else seed_question
-                    )
-
-                # compress question or convert into conversational questions
-                if evolve_type != "simple":
-                    prob = self.rng.uniform(0, 1)
-                    if self.chat_qa and prob <= self.chat_qa:
-                        question = self._conversational_question(question=question)
-                    else:
-                        question = self._compress_question(question=question)
-
-                is_valid_question = (
-                    self._filter_question(question) if evolve_type != "simple" else True
-                )
-                if evolve_type != "simple":
-                    evolution_elimination = self._evolution_elimination(
-                        question1=seed_question, question2=question
-                    )
-                else:
-                    evolution_elimination = None
-
-                if is_valid_question:
-                    # question = self._question_transform(question)
-                    context = self._generate_context(question, text_chunk)
-                    answer = self._generate_answer(question, context)
-                    samples.append(
-                        DataRow(
-                            [seed_question],
-                            question.split("\n"),
-                            context,
-                            answer,
-                            evolve_type,
-                            evolution_elimination,
-                        )
-                    )
-                    count += 1
-                    pbar.update(count)
-
-        return TestDataset(test_data=samples)

From bd8961b23e03f3f0b5af0d80c56f545c0c59175b Mon Sep 17 00:00:00 2001
From: jjmachan
Date: Tue, 23 Jan 2024 23:26:45 -0800
Subject: [PATCH 13/14] fix test

---
 tests/unit/test_simple.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/unit/test_simple.py b/tests/unit/test_simple.py
index 43b27eef6..4089f0a23 100644
--- a/tests/unit/test_simple.py
+++ b/tests/unit/test_simple.py
@@ -5,7 +5,7 @@
 def test_import():
     import ragas

-    from ragas.testset.testset_generator import TestsetGenerator
+    from ragas.testset.generator import TestsetGenerator

     assert TestsetGenerator is not None
     assert ragas is not None

From 29cbcd9aec209086132195e0c04c6e097873a10f Mon Sep 17 00:00:00 2001
From: jjmachan
Date: Wed, 24 Jan 2024 16:56:40 -0800
Subject: [PATCH 14/14] fix langchain 0.1 version

---
 pyproject.toml                          |  1 +
 src/ragas/evaluation.py                 |  6 +++++-
 src/ragas/executor.py                   |  4 ++--
 src/ragas/llms/base.py                  | 19 ++-----------------
 src/ragas/metrics/_answer_similarity.py |  2 +-
 src/ragas/metrics/_context_relevancy.py |  2 +-
 src/ragas/metrics/critique.py           |  2 +-
 src/ragas/testset/docstore.py           | 13 ++++++++++---
 src/ragas/testset/generator.py          | 15 ++++++++++++---
 src/ragas/testset/prompts.py            |  2 +-
 tests/unit/llms/test_llm.py             |  2 +-
 11 files changed, 37 insertions(+), 31 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index d73b3b813..0851232f3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,6 +6,7 @@ dependencies = [
     "tiktoken",
     "langchain",
     "langchain-core",
+    "langchain-community",
    "langchain_openai",
     "openai>1",
     "pysbd>=0.3.4",
diff --git a/src/ragas/evaluation.py b/src/ragas/evaluation.py
index 1a1787259..6dfc90403 100644
--- a/src/ragas/evaluation.py
+++ b/src/ragas/evaluation.py
@@ -154,7 +154,11 @@ def evaluate(
     [m.init_model() for m in metrics]

     executor = Executor(
-        is_async=is_async, max_workers=max_workers, raise_exceptions=raise_exceptions
+        desc="Evaluating",
+        keep_progress_bar=True,
+        is_async=is_async,
+        max_workers=max_workers,
+        raise_exceptions=raise_exceptions,
     )
     # new evaluation chain
     row_run_managers = []
diff --git a/src/ragas/executor.py b/src/ragas/executor.py
index a977b3894..079fca4f1 100644
--- a/src/ragas/executor.py
+++ b/src/ragas/executor.py
@@ -76,7 +76,7 @@ async def _aresults(self) -> t.List[t.Any]:
             desc=self.desc,
             total=len(self.futures),
             # whether you want to keep the progress bar after completion
-            leave=not self.keep_progress_bar,
+            leave=self.keep_progress_bar,
         ):
             r = (-1, np.nan)
             try:
@@ -113,7 +113,7 @@ def results(self) -> t.List[t.Any]:
             as_completed(self.futures),
             desc=self.desc,
             total=len(self.futures),
             # whether you want to keep the progress bar after completion
-            leave=not self.keep_progress_bar,
+            leave=self.keep_progress_bar,
         ):
             r = (-1, np.nan)
             try:
diff --git a/src/ragas/llms/base.py b/src/ragas/llms/base.py
index 6224a86ad..8c2f2af3f 100644
--- a/src/ragas/llms/base.py
+++ b/src/ragas/llms/base.py
@@ -4,14 +4,13 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass

-from langchain.chat_models import AzureChatOpenAI, ChatOpenAI, ChatVertexAI
-from langchain.llms import AzureOpenAI, OpenAI, VertexAI
+from langchain_community.chat_models import AzureChatOpenAI, ChatOpenAI, ChatVertexAI
+from langchain_community.llms import AzureOpenAI, OpenAI, VertexAI
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.outputs import LLMResult

 if t.TYPE_CHECKING:
     from langchain_core.callbacks import Callbacks
-    from langchain_core.prompts import ChatPromptTemplate

     from ragas.llms.prompt import PromptValue

@@ -62,20 +61,6 @@ async def agenerate_text(
     ) -> LLMResult:
         ...

-    # TODO: remove after testset generator is refactored
-    def generate_text_with_hmpt(
-        self,
-        prompts: t.List[ChatPromptTemplate],
-        n: int = 1,
-        temperature: float = 1e-8,
-        stop: t.Optional[t.List[str]] = None,
-        callbacks: Callbacks = [],
-    ) -> LLMResult:
-        from ragas.llms.prompt import PromptValue
-
-        prompt = PromptValue(prompt_str=prompts[0].format())
-        return self.generate_text(prompt, n, temperature, stop, callbacks)
-

 @dataclass
 class LangchainLLMWrapper(BaseRagasLLM):
diff --git a/src/ragas/metrics/_answer_similarity.py b/src/ragas/metrics/_answer_similarity.py
index d59e8c146..f6a3669fb 100644
--- a/src/ragas/metrics/_answer_similarity.py
+++ b/src/ragas/metrics/_answer_similarity.py
@@ -10,7 +10,7 @@
 from ragas.metrics.base import EvaluationMode, MetricWithEmbeddings, MetricWithLLM

 if t.TYPE_CHECKING:
-    from langchain.callbacks.base import Callbacks
+    from langchain_core.callbacks.base import Callbacks

 logger = logging.getLogger(__name__)
diff --git a/src/ragas/metrics/_context_relevancy.py b/src/ragas/metrics/_context_relevancy.py
index 827b0fb4b..de2ebd464 100644
--- a/src/ragas/metrics/_context_relevancy.py
+++ b/src/ragas/metrics/_context_relevancy.py
@@ -11,7 +11,7 @@
 from ragas.metrics.base import EvaluationMode, MetricWithLLM

 if t.TYPE_CHECKING:
-    from langchain.callbacks.base import Callbacks
+    from langchain_core.callbacks.base import Callbacks

 logger = logging.getLogger(__name__)
diff --git a/src/ragas/metrics/critique.py b/src/ragas/metrics/critique.py
index 8c6e63307..15c9ed522 100644
--- a/src/ragas/metrics/critique.py
+++ b/src/ragas/metrics/critique.py
@@ -12,7 +12,7 @@
 from ragas.metrics.base import EvaluationMode, MetricWithLLM

 if t.TYPE_CHECKING:
-    from langchain.callbacks.base import Callbacks
+    from langchain_core.callbacks.base import Callbacks

     from ragas.llms import BaseRagasLLM
diff --git a/src/ragas/testset/docstore.py b/src/ragas/testset/docstore.py
index 078fae391..30d65d27a 100644
--- a/src/ragas/testset/docstore.py
+++ b/src/ragas/testset/docstore.py
@@ -207,12 +207,19 @@ def add_nodes(
         nodes_to_embed = []

         # get embeddings for the docs
         executor = Executor(
-            desc="embedding nodes", is_async=True, raise_exceptions=True
+            desc="embedding nodes",
+            keep_progress_bar=False,
+            is_async=True,
+            raise_exceptions=True,
         )
-        for n in nodes:
+        for i, n in enumerate(nodes):
             if n.embedding is None:
                 nodes_to_embed.append(n)
-                executor.submit(self.embeddings.aembed_query, n.page_content)
+                executor.submit(
+                    self.embeddings.aembed_query,
n.page_content) + executor.submit( + self.embeddings.aembed_query, + n.page_content, + name=f"embed_node_task[{i}]", + ) else: self.nodes.append(n) self.node_map[n.doc_id] = n diff --git a/src/ragas/testset/generator.py b/src/ragas/testset/generator.py index 870cd5eb6..d1dd6bb3a 100644 --- a/src/ragas/testset/generator.py +++ b/src/ragas/testset/generator.py @@ -83,14 +83,18 @@ def generate_with_llamaindex_docs( documents: t.Sequence[LlamaindexDocument], test_size: int, distributions: Distributions = {}, - **kwargs, + show_debug_logs=False, ): # chunk documents and add to docstore self.docstore.add_documents( [Document.from_llamaindex_document(doc) for doc in documents] ) - return self.generate(test_size=test_size, distributions=distributions) + return self.generate( + test_size=test_size, + distributions=distributions, + show_debug_logs=show_debug_logs, + ) def generate( self, test_size: int, distributions: Distributions = {}, show_debug_logs=False @@ -116,7 +120,12 @@ def generate( patch_logger("ragas.testset.evolutions", logging.DEBUG) - exec = Executor(desc="Generating", raise_exceptions=True, is_async=True) + exec = Executor( + desc="Generating", + keep_progress_bar=True, + raise_exceptions=True, + is_async=True, + ) current_nodes = [ CurrentNodes(root_node=n, nodes=[n]) diff --git a/src/ragas/testset/prompts.py b/src/ragas/testset/prompts.py index 16269f362..d2addc6bb 100644 --- a/src/ragas/testset/prompts.py +++ b/src/ragas/testset/prompts.py @@ -1,4 +1,4 @@ -from langchain.prompts import HumanMessagePromptTemplate +from langchain_core.prompts import HumanMessagePromptTemplate from ragas.llms.prompt import Prompt diff --git a/tests/unit/llms/test_llm.py b/tests/unit/llms/test_llm.py index 09b6b0f03..c83ce48f9 100644 --- a/tests/unit/llms/test_llm.py +++ b/tests/unit/llms/test_llm.py @@ -2,7 +2,7 @@ import typing as t -from langchain.schema import Generation, LLMResult +from langchain_core.outputs import Generation, LLMResult from ragas.llms.base import BaseRagasLLM