diff --git a/src/ragas/testset/docstore.py b/src/ragas/testset/docstore.py
index 30d65d27a..59095d104 100644
--- a/src/ragas/testset/docstore.py
+++ b/src/ragas/testset/docstore.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import heapq
 import logging
 import typing as t
@@ -12,11 +14,13 @@
 from langchain.text_splitter import TextSplitter
 from langchain_core.documents import Document as LCDocument
 from langchain_core.pydantic_v1 import Field
-from llama_index.readers.schema import Document as LlamaindexDocument
 
 from ragas.embeddings.base import BaseRagasEmbeddings
 from ragas.executor import Executor
 
+if t.TYPE_CHECKING:
+    from llama_index.readers.schema import Document as LlamaindexDocument
+
 Embedding = t.Union[t.List[float], npt.NDArray[np.float64]]
 logger = logging.getLogger(__name__)
 rng = np.random.default_rng()
diff --git a/src/ragas/testset/generator.py b/src/ragas/testset/generator.py
index 3cca54ef4..7db63dd79 100644
--- a/src/ragas/testset/generator.py
+++ b/src/ragas/testset/generator.py
@@ -7,7 +7,6 @@
 import pandas as pd
 from langchain_openai.chat_models import ChatOpenAI
 from langchain_openai.embeddings import OpenAIEmbeddings
-from llama_index.readers.schema import Document as LlamaindexDocument
 
 from ragas._analytics import TesetGenerationEvent, track
 from ragas.embeddings import BaseRagasEmbeddings
@@ -17,9 +16,14 @@
 from ragas.testset.evolutions import ComplexEvolution, CurrentNodes, DataRow
 from ragas.testset.filters import EvolutionFilter, NodeFilter, QuestionFilter
 
-logger = logging.getLogger(__name__)
+if t.TYPE_CHECKING:
+    from llama_index.readers.schema import Document as LlamaindexDocument
+    from langchain_core.documents import Document as LCDocument
+
 
 Distributions = t.Dict[t.Any, float]
 
+logger = logging.getLogger(__name__)
+
 @dataclass
 class TestDataset:
@@ -79,12 +83,14 @@ def with_openai(
             docstore=docstore,
         )
 
+    # if you add any arguments to this function, make sure to add them to
+    # generate_with_langchain_docs as well
     def generate_with_llamaindex_docs(
         self,
         documents: t.Sequence[LlamaindexDocument],
         test_size: int,
         distributions: Distributions = {},
-        show_debug_logs=False,
+        with_debugging_logs=False,
     ):
         # chunk documents and add to docstore
         self.docstore.add_documents(
@@ -94,11 +100,34 @@
         return self.generate(
             test_size=test_size,
             distributions=distributions,
-            show_debug_logs=show_debug_logs,
+            with_debugging_logs=with_debugging_logs,
+        )
+
+    # if you add any arguments to this function, make sure to add them to
+    # generate_with_llamaindex_docs as well
+    def generate_with_langchain_docs(
+        self,
+        documents: t.Sequence[LCDocument],
+        test_size: int,
+        distributions: Distributions = {},
+        with_debugging_logs=False,
+    ):
+        # chunk documents and add to docstore
+        self.docstore.add_documents(
+            [Document.from_langchain_document(doc) for doc in documents]
+        )
+
+        return self.generate(
+            test_size=test_size,
+            distributions=distributions,
+            with_debugging_logs=with_debugging_logs,
         )
 
     def generate(
-        self, test_size: int, distributions: Distributions = {}, show_debug_logs=False
+        self,
+        test_size: int,
+        distributions: Distributions = {},
+        with_debugging_logs=False,
     ):
         # init filters and evolutions
         for evolution in distributions:
@@ -116,7 +145,7 @@ def generate(
                 evolution.init_evolution()
                 if evolution.evolution_filter is None:
                     evolution.evolution_filter = EvolutionFilter(llm=self.critic_llm)
-        if show_debug_logs:
+        if with_debugging_logs:
            from ragas.utils import patch_logger

            patch_logger("ragas.testset.evolutions", logging.DEBUG)
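Both files above defer the llama_index import behind t.TYPE_CHECKING. The snippet below is a minimal, self-contained sketch of that pattern, not part of the patch; the helper load_docs is hypothetical, and the assumed intent is to drop llama_index as a hard runtime dependency while keeping the annotations useful to type checkers.

# Illustrative sketch only (not from this patch): a TYPE_CHECKING guard plus
# postponed annotations keeps llama_index optional at runtime.
from __future__ import annotations  # annotations are stored as strings, not evaluated

import typing as t

if t.TYPE_CHECKING:
    # Executed only by static type checkers (mypy, pyright); a user without
    # llama_index installed can still import this module.
    from llama_index.readers.schema import Document as LlamaindexDocument


def load_docs(documents: t.Sequence[LlamaindexDocument]) -> int:
    # Hypothetical helper: the annotation above never triggers the import at runtime.
    return len(documents)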
diff --git a/tests/unit/test_analytics.py b/tests/unit/test_analytics.py
index 1e09ff305..887f7915b 100644
--- a/tests/unit/test_analytics.py
+++ b/tests/unit/test_analytics.py
@@ -90,7 +90,6 @@ def test_load_userid_from_json_file(tmp_path, monkeypatch):
 
 
 def test_testset_generation_tracking(monkeypatch):
-    import ragas._analytics as analyticsmodule
     from ragas._analytics import TesetGenerationEvent, track
     from ragas.testset.evolutions import multi_context, reasoning, simple
 
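For reference, a rough usage sketch of the new generate_with_langchain_docs entry point. Everything here is layered on assumptions beyond the diff: TestsetGenerator.with_openai() is called with its defaults, an OPENAI_API_KEY is expected in the environment, the sample texts are invented, and the closing to_pandas() call assumes the TestDataset dataclass exposes such a helper.

# Rough usage sketch (assumptions noted above; not part of the patch itself).
from langchain_core.documents import Document as LCDocument

from ragas.testset.evolutions import multi_context, reasoning, simple
from ragas.testset.generator import TestsetGenerator

# Plain LangChain documents; no llama_index installation is needed on this path.
documents = [
    LCDocument(page_content="Ragas evaluates retrieval-augmented generation pipelines."),
    LCDocument(page_content="Synthetic test sets are built by evolving seed questions."),
]

generator = TestsetGenerator.with_openai()  # assumes default OpenAI models and an API key
testset = generator.generate_with_langchain_docs(
    documents,
    test_size=2,
    distributions={simple: 0.5, reasoning: 0.25, multi_context: 0.25},
    with_debugging_logs=True,  # patches the ragas.testset.evolutions logger to DEBUG
)
print(testset.to_pandas())  # assumes TestDataset provides a to_pandas() helper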