From c8042d9cf76bb4782eabe0be5961f197770fe95e Mon Sep 17 00:00:00 2001 From: Shahules786 Date: Mon, 4 Nov 2024 16:35:26 +0530 Subject: [PATCH 01/15] add persona and topic extraction --- src/ragas/testset/persona.py | 149 ++++++++++++++++++ .../transforms/extractors/llm_based.py | 46 ++++++ 2 files changed, 195 insertions(+) create mode 100644 src/ragas/testset/persona.py diff --git a/src/ragas/testset/persona.py b/src/ragas/testset/persona.py new file mode 100644 index 000000000..5716da3c4 --- /dev/null +++ b/src/ragas/testset/persona.py @@ -0,0 +1,149 @@ +import logging +import random +import typing as t +from dataclasses import dataclass, field + +import numpy as np +from pydantic import BaseModel + +from ragas.llms.base import BaseRagasLLM +from ragas.prompt import PydanticPrompt +from ragas.testset.graph import KnowledgeGraph, Node + +logger = logging.getLogger(__name__) + + +def default_filter(node: Node) -> bool: + + if ( + node.type.name == "DOCUMENT" + and node.properties.get("summary_embedding") is not None + ): + return True + elif ( + node.type.name == "CHUNK" + and node.properties.get("topic_description_embedding") is not None + ): + return random.random() < 0.1 + else: + return False + + +class SummaryInput(BaseModel): + summaries: t.List[str] + + +class Persona(BaseModel): + name: str + role_description: str + + +class PersonasList(BaseModel): + personas: t.List[Persona] + + def __getitem__(self, key: str) -> Persona: + for persona in self.personas: + if persona.name == key: + return persona + raise KeyError(f"No persona found with name '{key}'") + + +# Define the prompt class +class PersonaGenerationPrompt(PydanticPrompt[SummaryInput, PersonasList]): + instruction: str = ( + "Using the provided summaries, generate one persona for each summary who might " + "interact with the content. For each persona, include a unique name " + "and a brief role description of who they are." + ) + input_model: t.Type[SummaryInput] = SummaryInput + output_model: t.Type[PersonasList] = PersonasList + examples: t.List[t.Tuple[SummaryInput, PersonasList]] = [ + ( + SummaryInput( + summaries=[ + "Guide to Digital Marketing explains strategies for engaging audiences across various online platforms.", + "Data Privacy Essentials discusses principles for safeguarding user data and complying with privacy regulations.", + "Introduction to Project Management covers key methods for planning, executing, and monitoring projects.", + ] + ), + PersonasList( + personas=[ + Persona( + name="Digital Marketing Specialist", + role_description="Focuses on engaging audiences and growing the brand online.", + ), + Persona( + name="Data Privacy Officer", + role_description="Ensures the organization's compliance with data protection laws.", + ), + Persona( + name="Project Manager", + role_description="Oversees project timelines and ensures tasks are completed efficiently.", + ), + ] + ), + ) + ] + + +@dataclass +class PersonaGenerator: + + llm: BaseRagasLLM + num_personas: int = 5 + prompt: PydanticPrompt = PersonaGenerationPrompt() + filter_nodes: t.Callable[[Node], bool] = field( + default_factory=lambda: default_filter + ) + + def __post_init__(self): + + try: + from sklearn.cluster import KMeans + from sklearn.metrics import pairwise_distances + except ImportError: + raise ImportError( + "PersonaGenerator requires the 'scikit-learn' package to be installed. " + "You can install it with 'pip install scikit-learn'." + ) + + self.pairwise_distances = pairwise_distances + self.kmeans = KMeans(n_clusters=self.num_personas, random_state=42) + + async def generate_from_kg(self, kg: KnowledgeGraph) -> PersonasList: + + nodes = [node for node in kg.nodes if self.filter_nodes(node)] + summaries = [ + node.properties.get("summary") or node.properties.get("topic_description") + for node in nodes + ] + embeddings = [] + for node in nodes: + embeddings.append( + node.properties.get("summary_embedding") + or node.properties.get("topic_description_embedding") + ) + + embeddings = np.array(embeddings) + self.kmeans.fit(embeddings) + labels = self.kmeans.labels_ + if labels is None: + raise ValueError("No labels found from clustering") + cluster_centers = self.kmeans.cluster_centers_ + top_summaries = [] + for i in range(self.num_personas): + cluster_indices = [j for j, label in enumerate(labels) if label == i] + _ = [summaries[j] for j in cluster_indices] + centroid = cluster_centers[i] + X_cluster = embeddings[cluster_indices] + distances = self.pairwise_distances( + X_cluster, centroid.reshape(1, -1), metric="euclidean" + ).flatten() + + closest_index = distances.argmin() + representative_summary = summaries[cluster_indices[closest_index]] + top_summaries.append(representative_summary) + + prompt_input = SummaryInput(summaries=top_summaries) + response = await self.prompt.generate(data=prompt_input, llm=self.llm) + return response diff --git a/src/ragas/testset/transforms/extractors/llm_based.py b/src/ragas/testset/transforms/extractors/llm_based.py index cd53354a1..613c5b4d8 100644 --- a/src/ragas/testset/transforms/extractors/llm_based.py +++ b/src/ragas/testset/transforms/extractors/llm_based.py @@ -260,3 +260,49 @@ async def extract(self, node: Node) -> t.Tuple[str, t.Dict[str, t.List[str]]]: return self.property_name, {} result = await self.prompt.generate(self.llm, data=StringIO(text=node_text)) return self.property_name, result.entities.model_dump() + + +class TopicDescription(BaseModel): + description: str + + +class TopicDescriptionPrompt(PydanticPrompt[StringIO, TopicDescription]): + instruction: str = ( + "Provide a concise description of the main topic(s) discussed in the following text." + ) + input_model: t.Type[StringIO] = StringIO + output_model: t.Type[TopicDescription] = TopicDescription + examples: t.List[t.Tuple[StringIO, TopicDescription]] = [ + ( + StringIO( + text="Quantum Computing\n\nQuantum computing leverages the principles of quantum mechanics to perform complex computations more efficiently than classical computers. It has the potential to revolutionize fields like cryptography, material science, and optimization problems by solving tasks that are currently intractable for classical systems." + ), + TopicDescription( + description="An introduction to quantum computing and its potential to outperform classical computers in complex computations, impacting areas such as cryptography and material science." + ), + ) + ] + + +@dataclass +class TopicDescriptionExtractor(LLMBasedExtractor): + """ + Extracts a concise description of the main topic(s) discussed in the given text. + + Attributes + ---------- + property_name : str + The name of the property to extract. + prompt : TopicDescriptionPrompt + The prompt used for extraction. + """ + + property_name: str = "topic_description" + prompt: TopicDescriptionPrompt = TopicDescriptionPrompt() + + async def extract(self, node: Node) -> t.Tuple[str, t.Any]: + node_text = node.get_property("page_content") + if node_text is None: + return self.property_name, None + result = await self.prompt.generate(self.llm, data=StringIO(text=node_text)) + return self.property_name, result.description From b965463ff44ee77847d775b9c9c5713da366843b Mon Sep 17 00:00:00 2001 From: Shahules786 Date: Mon, 4 Nov 2024 16:35:37 +0530 Subject: [PATCH 02/15] update init --- src/ragas/testset/transforms/extractors/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/ragas/testset/transforms/extractors/__init__.py b/src/ragas/testset/transforms/extractors/__init__.py index 57be41bf3..be8dfbc33 100644 --- a/src/ragas/testset/transforms/extractors/__init__.py +++ b/src/ragas/testset/transforms/extractors/__init__.py @@ -5,6 +5,7 @@ NERExtractor, SummaryExtractor, TitleExtractor, + TopicDescriptionExtractor, ) from .regex_based import emails_extractor, links_extractor, markdown_headings_extractor @@ -18,4 +19,5 @@ "HeadlinesExtractor", "EmbeddingExtractor", "NERExtractor", + "TopicDescriptionExtractor", ] From 33fa657a68b6b1cdfb43b125cf39ba7e41f1424e Mon Sep 17 00:00:00 2001 From: Shahules786 Date: Mon, 4 Nov 2024 16:35:44 +0530 Subject: [PATCH 03/15] add scikit learn to dev --- requirements/dev.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements/dev.txt b/requirements/dev.txt index e788428fe..6efd34db9 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -14,4 +14,5 @@ rouge_score nltk rapidfuzz pandas -datacompy \ No newline at end of file +datacompy +scikit-learn \ No newline at end of file From 6e2c387a30a19205cbce6fd25bc42df600141c69 Mon Sep 17 00:00:00 2001 From: Shahules786 Date: Mon, 4 Nov 2024 16:37:35 +0530 Subject: [PATCH 04/15] add persona to init --- src/ragas/testset/__init__.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/ragas/testset/__init__.py b/src/ragas/testset/__init__.py index 66985c8f6..7eff14115 100644 --- a/src/ragas/testset/__init__.py +++ b/src/ragas/testset/__init__.py @@ -1,4 +1,12 @@ +from ragas.testset.persona import Persona, PersonaGenerator, PersonasList from ragas.testset.synthesizers.generate import TestsetGenerator from ragas.testset.synthesizers.testset_schema import Testset, TestsetSample -__all__ = ["TestsetGenerator", "Testset", "TestsetSample"] +__all__ = [ + "TestsetGenerator", + "Testset", + "TestsetSample", + "PersonaGenerator", + "Persona", + "PersonasList", +] From 1b8fe6e13fcf9b52cdd8399cc7b94beddb101a6f Mon Sep 17 00:00:00 2001 From: Shahules786 Date: Mon, 4 Nov 2024 16:37:42 +0530 Subject: [PATCH 05/15] add scikit-learn --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index bf8093a7b..79bd0230a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ all = [ "rapidfuzz", "pandas", "datacompy", + "scikit-learn", ] docs = [ "mkdocs>=1.6.1", From de73e689d0862bae7b19c553b2ea65bafc6a92d6 Mon Sep 17 00:00:00 2001 From: Shahules786 Date: Mon, 4 Nov 2024 16:58:23 +0530 Subject: [PATCH 06/15] add description --- src/ragas/testset/persona.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/ragas/testset/persona.py b/src/ragas/testset/persona.py index 5716da3c4..36281c9be 100644 --- a/src/ragas/testset/persona.py +++ b/src/ragas/testset/persona.py @@ -5,6 +5,7 @@ import numpy as np from pydantic import BaseModel +from tqdm import tqdm from ragas.llms.base import BaseRagasLLM from ragas.prompt import PydanticPrompt @@ -131,7 +132,7 @@ async def generate_from_kg(self, kg: KnowledgeGraph) -> PersonasList: raise ValueError("No labels found from clustering") cluster_centers = self.kmeans.cluster_centers_ top_summaries = [] - for i in range(self.num_personas): + for i in tqdm(range(self.num_personas), desc="Generating personas"): cluster_indices = [j for j, label in enumerate(labels) if label == i] _ = [summaries[j] for j in cluster_indices] centroid = cluster_centers[i] From d1b0c78c5491c1f5ec6e530be38f4c33c5ab8587 Mon Sep 17 00:00:00 2001 From: Shahules786 Date: Tue, 5 Nov 2024 18:52:32 +0530 Subject: [PATCH 07/15] reformat persona generation --- src/ragas/testset/persona.py | 134 ++++++++++++++--------------------- 1 file changed, 52 insertions(+), 82 deletions(-) diff --git a/src/ragas/testset/persona.py b/src/ragas/testset/persona.py index 36281c9be..1034dd35e 100644 --- a/src/ragas/testset/persona.py +++ b/src/ragas/testset/persona.py @@ -1,14 +1,14 @@ import logging import random import typing as t -from dataclasses import dataclass, field import numpy as np +from langchain_core.callbacks import Callbacks from pydantic import BaseModel from tqdm import tqdm from ragas.llms.base import BaseRagasLLM -from ragas.prompt import PydanticPrompt +from ragas.prompt import PydanticPrompt, StringIO from ragas.testset.graph import KnowledgeGraph, Node logger = logging.getLogger(__name__) @@ -21,83 +21,54 @@ def default_filter(node: Node) -> bool: and node.properties.get("summary_embedding") is not None ): return True - elif ( - node.type.name == "CHUNK" - and node.properties.get("topic_description_embedding") is not None - ): - return random.random() < 0.1 else: return False -class SummaryInput(BaseModel): - summaries: t.List[str] - - class Persona(BaseModel): name: str role_description: str -class PersonasList(BaseModel): - personas: t.List[Persona] - - def __getitem__(self, key: str) -> Persona: - for persona in self.personas: - if persona.name == key: - return persona - raise KeyError(f"No persona found with name '{key}'") - - -# Define the prompt class -class PersonaGenerationPrompt(PydanticPrompt[SummaryInput, PersonasList]): +class PersonaGenerationPrompt(PydanticPrompt[StringIO, Persona]): instruction: str = ( - "Using the provided summaries, generate one persona for each summary who might " - "interact with the content. For each persona, include a unique name " - "and a brief role description of who they are." + "Using the provided summary, generate a single persona who would likely " + "interact with or benefit from the content. Include a unique name and a " + "concise role description of who they are." ) - input_model: t.Type[SummaryInput] = SummaryInput - output_model: t.Type[PersonasList] = PersonasList - examples: t.List[t.Tuple[SummaryInput, PersonasList]] = [ + input_model: t.Type[StringIO] = StringIO + output_model: t.Type[Persona] = Persona + examples: t.List[t.Tuple[StringIO, Persona]] = [ ( - SummaryInput( - summaries=[ - "Guide to Digital Marketing explains strategies for engaging audiences across various online platforms.", - "Data Privacy Essentials discusses principles for safeguarding user data and complying with privacy regulations.", - "Introduction to Project Management covers key methods for planning, executing, and monitoring projects.", - ] + StringIO( + text="Guide to Digital Marketing explains strategies for engaging audiences across various online platforms." ), - PersonasList( - personas=[ - Persona( - name="Digital Marketing Specialist", - role_description="Focuses on engaging audiences and growing the brand online.", - ), - Persona( - name="Data Privacy Officer", - role_description="Ensures the organization's compliance with data protection laws.", - ), - Persona( - name="Project Manager", - role_description="Oversees project timelines and ensures tasks are completed efficiently.", - ), - ] + Persona( + name="Digital Marketing Specialist", + role_description="Focuses on engaging audiences and growing the brand online.", ), ) ] -@dataclass -class PersonaGenerator: +class PersonaList(BaseModel): + personas: t.List[Persona] - llm: BaseRagasLLM - num_personas: int = 5 - prompt: PydanticPrompt = PersonaGenerationPrompt() - filter_nodes: t.Callable[[Node], bool] = field( - default_factory=lambda: default_filter - ) + def __getitem__(self, key: str) -> Persona: + for persona in self.personas: + if persona.name == key: + return persona + raise KeyError(f"No persona found with name '{key}'") - def __post_init__(self): + @classmethod + async def from_kg( + cls, + llm: BaseRagasLLM, + kg: KnowledgeGraph, + persona_generation_prompt: PersonaGenerationPrompt = PersonaGenerationPrompt(), + num_personas: int = 5, + callbacks: Callbacks = [], + ) -> "PersonaList": try: from sklearn.cluster import KMeans @@ -108,43 +79,42 @@ def __post_init__(self): "You can install it with 'pip install scikit-learn'." ) - self.pairwise_distances = pairwise_distances - self.kmeans = KMeans(n_clusters=self.num_personas, random_state=42) + kmeans = KMeans(n_clusters=num_personas, random_state=42) - async def generate_from_kg(self, kg: KnowledgeGraph) -> PersonasList: + nodes = [node for node in kg.nodes if default_filter(node)] + summaries = [node.properties.get("summary") for node in nodes] + if len(summaries) < num_personas: + logger.warning( + f"Only {len(summaries)} summaries found, randomly duplicating to reach {num_personas} personas." + ) + summaries.extend(random.choices(summaries, k=num_personas - len(summaries))) - nodes = [node for node in kg.nodes if self.filter_nodes(node)] - summaries = [ - node.properties.get("summary") or node.properties.get("topic_description") - for node in nodes - ] + summaries = [summary for summary in summaries if isinstance(summary, str)] embeddings = [] for node in nodes: - embeddings.append( - node.properties.get("summary_embedding") - or node.properties.get("topic_description_embedding") - ) + embeddings.append(node.properties.get("summary_embedding")) embeddings = np.array(embeddings) - self.kmeans.fit(embeddings) - labels = self.kmeans.labels_ + kmeans.fit(embeddings) + labels = kmeans.labels_ if labels is None: raise ValueError("No labels found from clustering") - cluster_centers = self.kmeans.cluster_centers_ - top_summaries = [] - for i in tqdm(range(self.num_personas), desc="Generating personas"): + cluster_centers = kmeans.cluster_centers_ + persona_list = [] + for i in tqdm(range(num_personas), desc="Generating personas"): cluster_indices = [j for j, label in enumerate(labels) if label == i] _ = [summaries[j] for j in cluster_indices] centroid = cluster_centers[i] X_cluster = embeddings[cluster_indices] - distances = self.pairwise_distances( + distances = pairwise_distances( X_cluster, centroid.reshape(1, -1), metric="euclidean" ).flatten() closest_index = distances.argmin() - representative_summary = summaries[cluster_indices[closest_index]] - top_summaries.append(representative_summary) + representative_summary: str = summaries[cluster_indices[closest_index]] + persona = await persona_generation_prompt.generate( + llm=llm, data=StringIO(text=representative_summary), callbacks=callbacks + ) + persona_list.append(persona) - prompt_input = SummaryInput(summaries=top_summaries) - response = await self.prompt.generate(data=prompt_input, llm=self.llm) - return response + return cls(personas=persona_list) From 4deb5eeec3473b9815625dd0de9a357cb6043c3a Mon Sep 17 00:00:00 2001 From: Shahules786 Date: Tue, 5 Nov 2024 18:55:35 +0530 Subject: [PATCH 08/15] add default filter --- src/ragas/testset/__init__.py | 5 ++--- src/ragas/testset/persona.py | 3 ++- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ragas/testset/__init__.py b/src/ragas/testset/__init__.py index 7eff14115..5edfdf1b4 100644 --- a/src/ragas/testset/__init__.py +++ b/src/ragas/testset/__init__.py @@ -1,4 +1,4 @@ -from ragas.testset.persona import Persona, PersonaGenerator, PersonasList +from ragas.testset.persona import Persona, PersonaList from ragas.testset.synthesizers.generate import TestsetGenerator from ragas.testset.synthesizers.testset_schema import Testset, TestsetSample @@ -6,7 +6,6 @@ "TestsetGenerator", "Testset", "TestsetSample", - "PersonaGenerator", "Persona", - "PersonasList", + "PersonaList", ] diff --git a/src/ragas/testset/persona.py b/src/ragas/testset/persona.py index 1034dd35e..28dff1741 100644 --- a/src/ragas/testset/persona.py +++ b/src/ragas/testset/persona.py @@ -67,6 +67,7 @@ async def from_kg( kg: KnowledgeGraph, persona_generation_prompt: PersonaGenerationPrompt = PersonaGenerationPrompt(), num_personas: int = 5, + filter_fn: t.Callable[[Node], bool] = default_filter, callbacks: Callbacks = [], ) -> "PersonaList": @@ -81,7 +82,7 @@ async def from_kg( kmeans = KMeans(n_clusters=num_personas, random_state=42) - nodes = [node for node in kg.nodes if default_filter(node)] + nodes = [node for node in kg.nodes if filter_fn(node)] summaries = [node.properties.get("summary") for node in nodes] if len(summaries) < num_personas: logger.warning( From 0df3d694426768b979e715638324fb0522eb0267 Mon Sep 17 00:00:00 2001 From: Shahules786 Date: Tue, 5 Nov 2024 19:43:50 +0530 Subject: [PATCH 09/15] make it a dataclass --- src/ragas/testset/persona.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/ragas/testset/persona.py b/src/ragas/testset/persona.py index 28dff1741..7119ff622 100644 --- a/src/ragas/testset/persona.py +++ b/src/ragas/testset/persona.py @@ -1,6 +1,7 @@ import logging import random import typing as t +from dataclasses import dataclass import numpy as np from langchain_core.callbacks import Callbacks @@ -51,7 +52,8 @@ class PersonaGenerationPrompt(PydanticPrompt[StringIO, Persona]): ] -class PersonaList(BaseModel): +@dataclass +class PersonaList: personas: t.List[Persona] def __getitem__(self, key: str) -> Persona: From 830f584504ab0f9b960bd89bba98cc130d14d725 Mon Sep 17 00:00:00 2001 From: Shahules786 Date: Wed, 6 Nov 2024 00:11:02 +0530 Subject: [PATCH 10/15] simplified persona generation --- src/ragas/testset/persona.py | 95 ++++++++++++++++++++++-------------- 1 file changed, 59 insertions(+), 36 deletions(-) diff --git a/src/ragas/testset/persona.py b/src/ragas/testset/persona.py index 7119ff622..5a0de2c55 100644 --- a/src/ragas/testset/persona.py +++ b/src/ragas/testset/persona.py @@ -21,7 +21,7 @@ def default_filter(node: Node) -> bool: node.type.name == "DOCUMENT" and node.properties.get("summary_embedding") is not None ): - return True + return random.random() < 0.25 else: return False @@ -68,55 +68,78 @@ async def from_kg( llm: BaseRagasLLM, kg: KnowledgeGraph, persona_generation_prompt: PersonaGenerationPrompt = PersonaGenerationPrompt(), - num_personas: int = 5, + num_personas: int = 3, filter_fn: t.Callable[[Node], bool] = default_filter, callbacks: Callbacks = [], ) -> "PersonaList": - - try: - from sklearn.cluster import KMeans - from sklearn.metrics import pairwise_distances - except ImportError: - raise ImportError( - "PersonaGenerator requires the 'scikit-learn' package to be installed. " - "You can install it with 'pip install scikit-learn'." - ) - - kmeans = KMeans(n_clusters=num_personas, random_state=42) + """ + Generate personas from a knowledge graph based on cluster of similar document summaries. + + parameters: + llm: BaseRagasLLM + The LLM to use for generating the persona. + kg: KnowledgeGraph + The knowledge graph to generate personas from. + persona_generation_prompt: PersonaGenerationPrompt + The prompt to use for generating the persona. + num_personas: int + The maximum number of personas to generate. + filter_fn: Callable[[Node], bool] + A function to filter nodes in the knowledge graph. + callbacks: Callbacks + The callbacks to use for the generation process. + + + returns: + PersonaList + The list of generated personas. + """ nodes = [node for node in kg.nodes if filter_fn(node)] summaries = [node.properties.get("summary") for node in nodes] - if len(summaries) < num_personas: - logger.warning( - f"Only {len(summaries)} summaries found, randomly duplicating to reach {num_personas} personas." - ) - summaries.extend(random.choices(summaries, k=num_personas - len(summaries))) - summaries = [summary for summary in summaries if isinstance(summary, str)] + embeddings = [] for node in nodes: embeddings.append(node.properties.get("summary_embedding")) embeddings = np.array(embeddings) - kmeans.fit(embeddings) - labels = kmeans.labels_ - if labels is None: - raise ValueError("No labels found from clustering") - cluster_centers = kmeans.cluster_centers_ + cosine_similarities = np.dot(embeddings, embeddings.T) + + groups = [] + visited = set() + threshold = 0.75 + + for i, summary in enumerate(summaries): + if i in visited: + continue + group = [i] + visited.add(i) + for j in range(i + 1, len(summaries)): + if cosine_similarities[i, j] > threshold: + group.append(j) + visited.add(j) + groups.append(group) + persona_list = [] - for i in tqdm(range(num_personas), desc="Generating personas"): - cluster_indices = [j for j, label in enumerate(labels) if label == i] - _ = [summaries[j] for j in cluster_indices] - centroid = cluster_centers[i] - X_cluster = embeddings[cluster_indices] - distances = pairwise_distances( - X_cluster, centroid.reshape(1, -1), metric="euclidean" - ).flatten() - - closest_index = distances.argmin() - representative_summary: str = summaries[cluster_indices[closest_index]] + top_summaries = [] + for group in groups: + representative_summary = max([summaries[i] for i in group], key=len) + top_summaries.append(representative_summary) + + if len(top_summaries) <= num_personas: + top_summaries.extend( + np.random.choice(top_summaries, num_personas - len(top_summaries)) + ) + + for representative_summary in tqdm( + top_summaries[:num_personas], desc="Generating personas" + ): persona = await persona_generation_prompt.generate( - llm=llm, data=StringIO(text=representative_summary), callbacks=callbacks + llm=llm, + data=StringIO(text=representative_summary), + callbacks=callbacks, + temperature=1.0, ) persona_list.append(persona) From 729d970b11367f43ccc1cb649a9308d8287cdd1b Mon Sep 17 00:00:00 2001 From: Shahules786 Date: Wed, 6 Nov 2024 00:11:08 +0530 Subject: [PATCH 11/15] remove init --- src/ragas/testset/__init__.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/ragas/testset/__init__.py b/src/ragas/testset/__init__.py index 5edfdf1b4..c8ec4f33c 100644 --- a/src/ragas/testset/__init__.py +++ b/src/ragas/testset/__init__.py @@ -1,4 +1,3 @@ -from ragas.testset.persona import Persona, PersonaList from ragas.testset.synthesizers.generate import TestsetGenerator from ragas.testset.synthesizers.testset_schema import Testset, TestsetSample @@ -6,6 +5,4 @@ "TestsetGenerator", "Testset", "TestsetSample", - "Persona", - "PersonaList", ] From 6924b9c9edc0818af370138e5787ff7f4ff17739 Mon Sep 17 00:00:00 2001 From: Shahules786 Date: Wed, 6 Nov 2024 00:11:50 +0530 Subject: [PATCH 12/15] removed scikit learn --- pyproject.toml | 5 ++--- requirements/dev.txt | 3 +-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 79bd0230a..2e15c0dad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,9 +24,8 @@ all = [ "rouge_score", "rapidfuzz", "pandas", - "datacompy", - "scikit-learn", -] + "datacompy",] + docs = [ "mkdocs>=1.6.1", "mkdocs-material", diff --git a/requirements/dev.txt b/requirements/dev.txt index 6efd34db9..e788428fe 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -14,5 +14,4 @@ rouge_score nltk rapidfuzz pandas -datacompy -scikit-learn \ No newline at end of file +datacompy \ No newline at end of file From 19e021d873fb6cf2556506c818333760c14334c1 Mon Sep 17 00:00:00 2001 From: jjmachan Date: Wed, 6 Nov 2024 16:44:41 +0530 Subject: [PATCH 13/15] feat: similified function and use executor to run in batch --- src/ragas/testset/persona.py | 174 ++++++++++++++++++----------------- 1 file changed, 88 insertions(+), 86 deletions(-) diff --git a/src/ragas/testset/persona.py b/src/ragas/testset/persona.py index 5a0de2c55..e86f472bc 100644 --- a/src/ragas/testset/persona.py +++ b/src/ragas/testset/persona.py @@ -1,13 +1,13 @@ import logging import random import typing as t -from dataclasses import dataclass import numpy as np from langchain_core.callbacks import Callbacks from pydantic import BaseModel from tqdm import tqdm +from ragas.executor import run_async_batch from ragas.llms.base import BaseRagasLLM from ragas.prompt import PydanticPrompt, StringIO from ragas.testset.graph import KnowledgeGraph, Node @@ -16,7 +16,6 @@ def default_filter(node: Node) -> bool: - if ( node.type.name == "DOCUMENT" and node.properties.get("summary_embedding") is not None @@ -52,8 +51,7 @@ class PersonaGenerationPrompt(PydanticPrompt[StringIO, Persona]): ] -@dataclass -class PersonaList: +class PersonaList(BaseModel): personas: t.List[Persona] def __getitem__(self, key: str) -> Persona: @@ -62,85 +60,89 @@ def __getitem__(self, key: str) -> Persona: return persona raise KeyError(f"No persona found with name '{key}'") - @classmethod - async def from_kg( - cls, - llm: BaseRagasLLM, - kg: KnowledgeGraph, - persona_generation_prompt: PersonaGenerationPrompt = PersonaGenerationPrompt(), - num_personas: int = 3, - filter_fn: t.Callable[[Node], bool] = default_filter, - callbacks: Callbacks = [], - ) -> "PersonaList": - """ - Generate personas from a knowledge graph based on cluster of similar document summaries. - - parameters: - llm: BaseRagasLLM - The LLM to use for generating the persona. - kg: KnowledgeGraph - The knowledge graph to generate personas from. - persona_generation_prompt: PersonaGenerationPrompt - The prompt to use for generating the persona. - num_personas: int - The maximum number of personas to generate. - filter_fn: Callable[[Node], bool] - A function to filter nodes in the knowledge graph. - callbacks: Callbacks - The callbacks to use for the generation process. - - - returns: - PersonaList - The list of generated personas. - """ - - nodes = [node for node in kg.nodes if filter_fn(node)] - summaries = [node.properties.get("summary") for node in nodes] - summaries = [summary for summary in summaries if isinstance(summary, str)] - - embeddings = [] - for node in nodes: - embeddings.append(node.properties.get("summary_embedding")) - - embeddings = np.array(embeddings) - cosine_similarities = np.dot(embeddings, embeddings.T) - - groups = [] - visited = set() - threshold = 0.75 - - for i, summary in enumerate(summaries): - if i in visited: - continue - group = [i] - visited.add(i) - for j in range(i + 1, len(summaries)): - if cosine_similarities[i, j] > threshold: - group.append(j) - visited.add(j) - groups.append(group) - - persona_list = [] - top_summaries = [] - for group in groups: - representative_summary = max([summaries[i] for i in group], key=len) - top_summaries.append(representative_summary) - - if len(top_summaries) <= num_personas: - top_summaries.extend( - np.random.choice(top_summaries, num_personas - len(top_summaries)) - ) - - for representative_summary in tqdm( - top_summaries[:num_personas], desc="Generating personas" - ): - persona = await persona_generation_prompt.generate( - llm=llm, - data=StringIO(text=representative_summary), - callbacks=callbacks, - temperature=1.0, - ) - persona_list.append(persona) - - return cls(personas=persona_list) + +async def generate_personas_from_kg( + kg: KnowledgeGraph, + llm: BaseRagasLLM, + persona_generation_prompt: PersonaGenerationPrompt = PersonaGenerationPrompt(), + num_personas: int = 3, + filter_fn: t.Callable[[Node], bool] = default_filter, + callbacks: Callbacks = [], +) -> t.List[Persona]: + """ + Generate personas from a knowledge graph based on cluster of similar document summaries. + + parameters: + kg: KnowledgeGraph + The knowledge graph to generate personas from. + llm: BaseRagasLLM + The LLM to use for generating the persona. + persona_generation_prompt: PersonaGenerationPrompt + The prompt to use for generating the persona. + num_personas: int + The maximum number of personas to generate. + filter_fn: Callable[[Node], bool] + A function to filter nodes in the knowledge graph. + callbacks: Callbacks + The callbacks to use for the generation process. + + + returns: + t.List[Persona] + The list of generated personas. + """ + + nodes = [node for node in kg.nodes if filter_fn(node)] + summaries = [node.properties.get("summary") for node in nodes] + summaries = [summary for summary in summaries if isinstance(summary, str)] + + embeddings = [] + for node in nodes: + embeddings.append(node.properties.get("summary_embedding")) + + embeddings = np.array(embeddings) + cosine_similarities = np.dot(embeddings, embeddings.T) + + groups = [] + visited = set() + threshold = 0.75 + + for i, _ in enumerate(summaries): + if i in visited: + continue + group = [i] + visited.add(i) + for j in range(i + 1, len(summaries)): + if cosine_similarities[i, j] > threshold: + group.append(j) + visited.add(j) + groups.append(group) + + persona_list = [] + top_summaries = [] + for group in groups: + representative_summary = max([summaries[i] for i in group], key=len) + top_summaries.append(representative_summary) + + if len(top_summaries) <= num_personas: + top_summaries.extend( + np.random.choice(top_summaries, num_personas - len(top_summaries)) + ) + + # use run_async_batch to generate personas in parallel + kwargs_list = [ + { + "llm": llm, + "data": StringIO(text=summary), + "callbacks": callbacks, + "temperature": 1.0, + } + for summary in top_summaries[:num_personas] + ] + personas = run_async_batch( + desc="Generating personas", + func=persona_generation_prompt.generate, + kwargs_list=kwargs_list, + ) + + return personas From d407670f7b773f0f567ff67a7d1abdb8bd9b7b51 Mon Sep 17 00:00:00 2001 From: jjmachan Date: Wed, 6 Nov 2024 16:45:47 +0530 Subject: [PATCH 14/15] feat: generate_personas_from_kg doesn't have to be async --- src/ragas/testset/persona.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ragas/testset/persona.py b/src/ragas/testset/persona.py index e86f472bc..fd20150c4 100644 --- a/src/ragas/testset/persona.py +++ b/src/ragas/testset/persona.py @@ -61,7 +61,7 @@ def __getitem__(self, key: str) -> Persona: raise KeyError(f"No persona found with name '{key}'") -async def generate_personas_from_kg( +def generate_personas_from_kg( kg: KnowledgeGraph, llm: BaseRagasLLM, persona_generation_prompt: PersonaGenerationPrompt = PersonaGenerationPrompt(), From 8699d4a4b03bf229b3a81a1605c53d5dd2a2397c Mon Sep 17 00:00:00 2001 From: jjmachan Date: Wed, 6 Nov 2024 17:14:41 +0530 Subject: [PATCH 15/15] style: fix ci --- src/ragas/testset/persona.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/ragas/testset/persona.py b/src/ragas/testset/persona.py index fd20150c4..09a56663b 100644 --- a/src/ragas/testset/persona.py +++ b/src/ragas/testset/persona.py @@ -5,7 +5,6 @@ import numpy as np from langchain_core.callbacks import Callbacks from pydantic import BaseModel -from tqdm import tqdm from ragas.executor import run_async_batch from ragas.llms.base import BaseRagasLLM @@ -118,7 +117,6 @@ def generate_personas_from_kg( visited.add(j) groups.append(group) - persona_list = [] top_summaries = [] for group in groups: representative_summary = max([summaries[i] for i in group], key=len) @@ -139,10 +137,10 @@ def generate_personas_from_kg( } for summary in top_summaries[:num_personas] ] - personas = run_async_batch( + persona_list = run_async_batch( desc="Generating personas", func=persona_generation_prompt.generate, kwargs_list=kwargs_list, ) - return personas + return persona_list