From dcd82327246a538b30808d3ba56eb385355a172a Mon Sep 17 00:00:00 2001 From: Shahules786 Date: Fri, 26 Jan 2024 20:08:08 -0800 Subject: [PATCH 1/4] add relevant context --- src/ragas/testset/evolutions.py | 22 +++++++++++++++++++++- src/ragas/testset/prompts.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py index 255a1b93a..ae66784d1 100644 --- a/src/ragas/testset/evolutions.py +++ b/src/ragas/testset/evolutions.py @@ -10,11 +10,13 @@ from numpy.random import default_rng from ragas.llms import BaseRagasLLM +from ragas.llms.json_load import json_loader from ragas.llms.prompt import Prompt from ragas.testset.docstore import Direction, DocumentStore, Node from ragas.testset.filters import EvolutionFilter, NodeFilter, QuestionFilter from ragas.testset.prompts import ( compress_question_prompt, + find_relevent_context_prompt, multi_context_question_prompt, question_answer_prompt, reasoning_question_prompt, @@ -135,7 +137,25 @@ def generate_datarow( ): assert self.generator_llm is not None, "generator_llm cannot be None" - merged_nodes = self.merge_nodes(current_nodes) + node_content = [ + f"{i}\t{n.page_content}" for i, n in enumerate(current_nodes.nodes) + ] + results = self.generator_llm.generate_text( + prompt=find_relevent_context_prompt.format( + question=question, contexts=node_content + ) + ) + relevant_context_indices = json_loader.safe_load( + results.generations[0][0].text.strip(), llm=self.generator_llm + ).get("relevant_context", None) + if relevant_context_indices is None: + relevant_context = CurrentNodes( + root_node=current_nodes.root_node, nodes=current_nodes.nodes + ) + else: + relevant_context = current_nodes + + merged_nodes = self.merge_nodes(relevant_context) results = self.generator_llm.generate_text( prompt=question_answer_prompt.format( question=question, context=merged_nodes.page_content diff --git a/src/ragas/testset/prompts.py b/src/ragas/testset/prompts.py index d96bcdc7e..7ebdf583f 100644 --- a/src/ragas/testset/prompts.py +++ b/src/ragas/testset/prompts.py @@ -297,3 +297,35 @@ output_type="string", language="english", ) + + +find_relevent_context_prompt = Prompt( + name="find_relevent_context", + instruction="Given a question and set of contexts, find the most relevant contexts to answer the question.", + examples=[ + { + "question": "What is the capital of France?", + "contexts": [ + "1. France is a country in Western Europe. It has several cities, including Paris, Lyon, and Marseille. Paris is not only known for its cultural landmarks like the Eiffel Tower and the Louvre Museum but also as the administrative center.", + "2. The capital of France is Paris. It is also the most populous city in France, with a population of over 2 million people. Paris is known for its cultural landmarks like the Eiffel Tower and the Louvre Museum.", + "3. Paris is the capital of France. It is also the most populous city in France, with a population of over 2 million people. Paris is known for its cultural landmarks like the Eiffel Tower and the Louvre Museum.", + ], + "output": { + "relevent_contexts": [1, 2], + }, + }, + { + "question": "How does caffeine affect the body and what are its common sources?", + "contexts": [ + "1. Caffeine is a central nervous system stimulant. It can temporarily ward off drowsiness and restore alertness. It primarily affects the brain, where it alters the function of neurotransmitters.", + "2. Regular physical activity is essential for maintaining good health. It can help control weight, combat health conditions, boost energy, and promote better sleep.", + "3. Common sources of caffeine include coffee, tea, cola, and energy drinks. These beverages are consumed worldwide and are known for providing a quick boost of energy.", + ], + "output": {"relevant_contexts": [1, 2]}, + }, + ], + input_keys=["question", "contexts"], + output_key="output", + output_type="json", + language="english", +) From d86bbf4d649a141a32f76ad9e98356ee5f2a1e26 Mon Sep 17 00:00:00 2001 From: Shahules786 Date: Fri, 26 Jan 2024 23:31:27 -0800 Subject: [PATCH 2/4] add rng --- src/ragas/testset/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/ragas/testset/utils.py b/src/ragas/testset/utils.py index 4582a89ec..2b1827a5b 100644 --- a/src/ragas/testset/utils.py +++ b/src/ragas/testset/utils.py @@ -3,6 +3,10 @@ import re import warnings +import numpy as np + +rng = np.random.default_rng(seed=42) + def load_as_score(text): """ From 5fd394a577c28fad0e52081397f300461abeeff0 Mon Sep 17 00:00:00 2001 From: Shahules786 Date: Fri, 26 Jan 2024 23:31:41 -0800 Subject: [PATCH 3/4] use seeded rng --- src/ragas/testset/docstore.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/ragas/testset/docstore.py b/src/ragas/testset/docstore.py index 59095d104..4dc2cdc3e 100644 --- a/src/ragas/testset/docstore.py +++ b/src/ragas/testset/docstore.py @@ -7,7 +7,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from enum import Enum -from random import choices import numpy as np import numpy.typing as npt @@ -17,13 +16,13 @@ from ragas.embeddings.base import BaseRagasEmbeddings from ragas.executor import Executor +from ragas.testset.utils import rng if t.TYPE_CHECKING: from llama_index.readers.schema import Document as LlamaindexDocument Embedding = t.Union[t.List[float], npt.NDArray[np.float64]] logger = logging.getLogger(__name__) -rng = np.random.default_rng() class Document(LCDocument): @@ -243,7 +242,7 @@ def get_document(self, doc_id: str) -> Node: raise NotImplementedError def get_random_nodes(self, k=1) -> t.List[Node]: - return choices(self.nodes, k=k) + return rng.choice(self.nodes, size=k).tolist() def get_similar( self, node: Node, threshold: float = 0.7, top_k: int = 3 From 4b2085d2a6e383a974a5c9929360a7ef9a572770 Mon Sep 17 00:00:00 2001 From: Shahules786 Date: Sat, 27 Jan 2024 00:18:15 -0800 Subject: [PATCH 4/4] type fixes --- src/ragas/testset/docstore.py | 2 +- src/ragas/testset/evolutions.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ragas/testset/docstore.py b/src/ragas/testset/docstore.py index 4dc2cdc3e..5c2a07063 100644 --- a/src/ragas/testset/docstore.py +++ b/src/ragas/testset/docstore.py @@ -242,7 +242,7 @@ def get_document(self, doc_id: str) -> Node: raise NotImplementedError def get_random_nodes(self, k=1) -> t.List[Node]: - return rng.choice(self.nodes, size=k).tolist() + return rng.choice(np.array(self.nodes), size=k).tolist() def get_similar( self, node: Node, threshold: float = 0.7, top_k: int = 3 diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py index fbaebeb96..756b59f86 100644 --- a/src/ragas/testset/evolutions.py +++ b/src/ragas/testset/evolutions.py @@ -16,8 +16,8 @@ from ragas.testset.filters import EvolutionFilter, NodeFilter, QuestionFilter from ragas.testset.prompts import ( compress_question_prompt, - find_relevent_context_prompt, conditional_question_prompt, + find_relevent_context_prompt, multi_context_question_prompt, question_answer_prompt, reasoning_question_prompt,