Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/ragas/testset/docstore.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import heapq
import logging
import typing as t
Expand All @@ -12,11 +14,13 @@
from langchain.text_splitter import TextSplitter
from langchain_core.documents import Document as LCDocument
from langchain_core.pydantic_v1 import Field
from llama_index.readers.schema import Document as LlamaindexDocument

from ragas.embeddings.base import BaseRagasEmbeddings
from ragas.executor import Executor

if t.TYPE_CHECKING:
from llama_index.readers.schema import Document as LlamaindexDocument

Embedding = t.Union[t.List[float], npt.NDArray[np.float64]]
logger = logging.getLogger(__name__)
rng = np.random.default_rng()
Expand Down
41 changes: 35 additions & 6 deletions src/ragas/testset/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import pandas as pd
from langchain_openai.chat_models import ChatOpenAI
from langchain_openai.embeddings import OpenAIEmbeddings
from llama_index.readers.schema import Document as LlamaindexDocument

from ragas._analytics import TesetGenerationEvent, track
from ragas.embeddings import BaseRagasEmbeddings
Expand All @@ -17,9 +16,14 @@
from ragas.testset.evolutions import ComplexEvolution, CurrentNodes, DataRow
from ragas.testset.filters import EvolutionFilter, NodeFilter, QuestionFilter

logger = logging.getLogger(__name__)
if t.TYPE_CHECKING:
from llama_index.readers.schema import Document as LlamaindexDocument
from langchain_core.documents import Document as LCDocument

Distributions = t.Dict[t.Any, float]

logger = logging.getLogger(__name__)


@dataclass
class TestDataset:
Expand Down Expand Up @@ -79,12 +83,14 @@ def with_openai(
docstore=docstore,
)

# if you add any arguments to this function, make sure to add them to
# generate_with_langchain_docs as well
def generate_with_llamaindex_docs(
self,
documents: t.Sequence[LlamaindexDocument],
test_size: int,
distributions: Distributions = {},
show_debug_logs=False,
with_debugging_logs=False,
):
# chunk documents and add to docstore
self.docstore.add_documents(
Expand All @@ -94,11 +100,34 @@ def generate_with_llamaindex_docs(
return self.generate(
test_size=test_size,
distributions=distributions,
show_debug_logs=show_debug_logs,
with_debugging_logs=with_debugging_logs,
)

# if you add any arguments to this function, make sure to add them to
# generate_with_langchain_docs as well
def generate_with_langchain_docs(
self,
documents: t.Sequence[LCDocument],
test_size: int,
distributions: Distributions = {},
with_debugging_logs=False,
):
# chunk documents and add to docstore
self.docstore.add_documents(
[Document.from_langchain_document(doc) for doc in documents]
)

return self.generate(
test_size=test_size,
distributions=distributions,
with_debugging_logs=with_debugging_logs,
)

def generate(
self, test_size: int, distributions: Distributions = {}, show_debug_logs=False
self,
test_size: int,
distributions: Distributions = {},
with_debugging_logs=False,
):
# init filters and evolutions
for evolution in distributions:
Expand All @@ -116,7 +145,7 @@ def generate(
evolution.init_evolution()
if evolution.evolution_filter is None:
evolution.evolution_filter = EvolutionFilter(llm=self.critic_llm)
if show_debug_logs:
if with_debugging_logs:
from ragas.utils import patch_logger

patch_logger("ragas.testset.evolutions", logging.DEBUG)
Expand Down
1 change: 0 additions & 1 deletion tests/unit/test_analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ def test_load_userid_from_json_file(tmp_path, monkeypatch):


def test_testset_generation_tracking(monkeypatch):

import ragas._analytics as analyticsmodule
from ragas._analytics import TesetGenerationEvent, track
from ragas.testset.evolutions import multi_context, reasoning, simple
Expand Down