Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/ragas/testset/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from ragas.testset.synthesizers.generate import TestsetGenerator
from ragas.testset.synthesizers.testset_schema import Testset, TestsetSample

__all__ = ["TestsetGenerator", "Testset", "TestsetSample"]
__all__ = [
"TestsetGenerator",
"Testset",
"TestsetSample",
]
146 changes: 146 additions & 0 deletions src/ragas/testset/persona.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import logging
import random
import typing as t

import numpy as np
from langchain_core.callbacks import Callbacks
from pydantic import BaseModel

from ragas.executor import run_async_batch
from ragas.llms.base import BaseRagasLLM
from ragas.prompt import PydanticPrompt, StringIO
from ragas.testset.graph import KnowledgeGraph, Node

logger = logging.getLogger(__name__)


def default_filter(node: Node) -> bool:
    """Sample roughly 25% of DOCUMENT nodes that carry a summary embedding.

    Used as the default ``filter_fn`` for persona generation: only document
    nodes with a ``summary_embedding`` property are eligible, and of those a
    random quarter is kept.
    """
    is_eligible = (
        node.type.name == "DOCUMENT"
        and node.properties.get("summary_embedding") is not None
    )
    # Non-eligible nodes are always rejected; eligible ones pass ~25% of the time.
    return is_eligible and random.random() < 0.25


class Persona(BaseModel):
    # Schema for one generated persona; also serves as the structured output
    # model of PersonaGenerationPrompt below.
    name: str  # unique display name, e.g. "Digital Marketing Specialist"
    role_description: str  # concise description of who the persona is


class PersonaGenerationPrompt(PydanticPrompt[StringIO, Persona]):
    # Prompt that turns a document summary (StringIO.text) into a single
    # Persona. Purely declarative: instruction, input/output models, and one
    # few-shot example consumed by PydanticPrompt.generate.
    instruction: str = (
        "Using the provided summary, generate a single persona who would likely "
        "interact with or benefit from the content. Include a unique name and a "
        "concise role description of who they are."
    )
    input_model: t.Type[StringIO] = StringIO
    output_model: t.Type[Persona] = Persona
    # One (input, expected output) pair used as a few-shot example.
    examples: t.List[t.Tuple[StringIO, Persona]] = [
        (
            StringIO(
                text="Guide to Digital Marketing explains strategies for engaging audiences across various online platforms."
            ),
            Persona(
                name="Digital Marketing Specialist",
                role_description="Focuses on engaging audiences and growing the brand online.",
            ),
        )
    ]


class PersonaList(BaseModel):
    """A collection of personas, addressable by persona name via ``[]``."""

    personas: t.List[Persona]

    def __getitem__(self, key: str) -> Persona:
        """Return the first persona whose ``name`` equals *key*.

        Raises
        ------
        KeyError
            If no persona with that name exists.
        """
        found = next((p for p in self.personas if p.name == key), None)
        if found is None:
            raise KeyError(f"No persona found with name '{key}'")
        return found


def generate_personas_from_kg(
    kg: KnowledgeGraph,
    llm: BaseRagasLLM,
    persona_generation_prompt: PersonaGenerationPrompt = PersonaGenerationPrompt(),
    num_personas: int = 3,
    filter_fn: t.Callable[[Node], bool] = default_filter,
    callbacks: t.Optional[Callbacks] = None,
) -> t.List[Persona]:
    """
    Generate personas from a knowledge graph based on clusters of similar
    document summaries.

    Parameters
    ----------
    kg : KnowledgeGraph
        The knowledge graph to generate personas from.
    llm : BaseRagasLLM
        The LLM to use for generating the personas.
    persona_generation_prompt : PersonaGenerationPrompt
        The prompt to use for generating each persona.
    num_personas : int
        The maximum number of personas to generate.
    filter_fn : Callable[[Node], bool]
        A function to filter nodes in the knowledge graph.
    callbacks : Callbacks, optional
        The callbacks to use for the generation process. Defaults to no
        callbacks (``None`` is treated as an empty list; a mutable default
        argument is deliberately avoided).

    Returns
    -------
    t.List[Persona]
        The list of generated personas.

    Raises
    ------
    ValueError
        If no node passes ``filter_fn`` with both a string summary and a
        summary embedding, so no persona can be generated.
    """
    callbacks = callbacks if callbacks is not None else []

    # Collect summary and embedding TOGETHER per node so the similarity-matrix
    # indices below stay aligned with the summary list. (Filtering summaries
    # and embeddings independently would misalign them whenever a node has an
    # embedding but a non-string summary.)
    summaries: t.List[str] = []
    embedding_rows: t.List[t.Any] = []
    for node in kg.nodes:
        if not filter_fn(node):
            continue
        summary = node.properties.get("summary")
        embedding = node.properties.get("summary_embedding")
        if isinstance(summary, str) and embedding is not None:
            summaries.append(summary)
            embedding_rows.append(embedding)

    if not summaries:
        raise ValueError(
            "No nodes with a summary and summary_embedding passed filter_fn; "
            "cannot generate personas."
        )

    # True cosine similarity requires unit-length rows; a raw dot product is
    # only equivalent when the embeddings are already normalized. Normalizing
    # here is a no-op for pre-normalized embeddings and fixes the threshold
    # semantics otherwise.
    embeddings = np.asarray(embedding_rows, dtype=float)
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    norms[norms == 0] = 1.0  # guard against zero vectors (avoid division by 0)
    unit_embeddings = embeddings / norms
    cosine_similarities = unit_embeddings @ unit_embeddings.T

    # Greedy single-pass clustering: each unvisited summary seeds a group and
    # absorbs all not-yet-grouped summaries above the similarity threshold.
    threshold = 0.75
    groups: t.List[t.List[int]] = []
    visited: t.Set[int] = set()
    for i in range(len(summaries)):
        if i in visited:
            continue
        group = [i]
        visited.add(i)
        for j in range(i + 1, len(summaries)):
            # `j not in visited` prevents a summary from joining two groups.
            if j not in visited and cosine_similarities[i, j] > threshold:
                group.append(j)
                visited.add(j)
        groups.append(group)

    # One representative per cluster: the longest summary, assumed to be the
    # most informative for persona generation.
    top_summaries = [
        max((summaries[i] for i in group), key=len) for group in groups
    ]

    # If there are fewer clusters than requested personas, pad by sampling
    # existing representatives with replacement.
    if len(top_summaries) < num_personas:
        top_summaries.extend(
            np.random.choice(top_summaries, num_personas - len(top_summaries))
        )

    # Generate personas for the top-N summaries in parallel.
    kwargs_list = [
        {
            "llm": llm,
            "data": StringIO(text=summary),
            "callbacks": callbacks,
            "temperature": 1.0,
        }
        for summary in top_summaries[:num_personas]
    ]
    persona_list = run_async_batch(
        desc="Generating personas",
        func=persona_generation_prompt.generate,
        kwargs_list=kwargs_list,
    )

    return persona_list
2 changes: 2 additions & 0 deletions src/ragas/testset/transforms/extractors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
NERExtractor,
SummaryExtractor,
TitleExtractor,
TopicDescriptionExtractor,
)
from .regex_based import emails_extractor, links_extractor, markdown_headings_extractor

Expand All @@ -18,4 +19,5 @@
"HeadlinesExtractor",
"EmbeddingExtractor",
"NERExtractor",
"TopicDescriptionExtractor",
]
46 changes: 46 additions & 0 deletions src/ragas/testset/transforms/extractors/llm_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,49 @@ async def extract(self, node: Node) -> t.Tuple[str, t.Dict[str, t.List[str]]]:
return self.property_name, {}
result = await self.prompt.generate(self.llm, data=StringIO(text=node_text))
return self.property_name, result.entities.model_dump()


class TopicDescription(BaseModel):
    # Structured output model for TopicDescriptionPrompt.
    description: str  # concise description of the text's main topic(s)


class TopicDescriptionPrompt(PydanticPrompt[StringIO, TopicDescription]):
    # Prompt that summarizes the main topic(s) of a text (StringIO.text) into
    # a TopicDescription. Purely declarative: instruction, input/output
    # models, and one few-shot example consumed by PydanticPrompt.generate.
    instruction: str = (
        "Provide a concise description of the main topic(s) discussed in the following text."
    )
    input_model: t.Type[StringIO] = StringIO
    output_model: t.Type[TopicDescription] = TopicDescription
    # One (input, expected output) pair used as a few-shot example.
    examples: t.List[t.Tuple[StringIO, TopicDescription]] = [
        (
            StringIO(
                text="Quantum Computing\n\nQuantum computing leverages the principles of quantum mechanics to perform complex computations more efficiently than classical computers. It has the potential to revolutionize fields like cryptography, material science, and optimization problems by solving tasks that are currently intractable for classical systems."
            ),
            TopicDescription(
                description="An introduction to quantum computing and its potential to outperform classical computers in complex computations, impacting areas such as cryptography and material science."
            ),
        )
    ]


@dataclass
class TopicDescriptionExtractor(LLMBasedExtractor):
    """
    Extracts a concise description of the main topic(s) discussed in the given text.

    Attributes
    ----------
    property_name : str
        The name of the property to extract.
    prompt : TopicDescriptionPrompt
        The prompt used for extraction.
    """

    property_name: str = "topic_description"
    # NOTE(review): this default prompt instance is created once at class
    # definition time and shared by all extractor instances that do not pass
    # their own — confirm this matches the pattern of the other LLM-based
    # extractors in this module.
    prompt: TopicDescriptionPrompt = TopicDescriptionPrompt()

    async def extract(self, node: Node) -> t.Tuple[str, t.Any]:
        # Returns (property_name, description) on success, or
        # (property_name, None) when the node has no "page_content" —
        # callers must handle a None description.
        node_text = node.get_property("page_content")
        if node_text is None:
            return self.property_name, None
        result = await self.prompt.generate(self.llm, data=StringIO(text=node_text))
        return self.property_name, result.description
Loading