From dc1dc5ea364628fa50c484b7324fafce362eaf98 Mon Sep 17 00:00:00 2001
From: jjmachan
Date: Thu, 22 Feb 2024 15:59:02 -0800
Subject: [PATCH] simplify the Document and Node

Turn Document into a plain dataclass instead of a subclass of the
langchain Document. `filename` is no longer a stored field: it is a
property backed by metadata["filename"] that falls back to doc_id (and
logs at info level) when no filename is present, with a setter that
writes back into metadata.

Add Document.to_langchain_document() and use it in add_documents() so
the splitter is still handed langchain documents.

Node becomes a dataclass too and gains an `embedding` field. Its fields
are declared with dataclasses.field(), like the ones on Document: under
@dataclass only the result of dataclasses.field() is treated as a field
specification, so any other helper call (the previous `Field(...)`) would
be stored as the default value itself and its default_factory / repr
arguments would never take effect.

Tests drop the removed `filename=` constructor argument and add a test
for the filename property.

---
Review notes (this section is ignored by `git am`):
- The blob ids on the docstore.py "index" line were recorded before the
  Field -> field correction and no longer describe the post-image.
- NOTE(review): the docstore tests used to give a1/a2 the shared
  filename "a"; without metadata they now fall back to their doc_id as
  filename. Confirm the get_adjacent() expectations still hold.

 src/ragas/testset/docstore.py                 | 56 +++++++++++--------
 tests/unit/testset_generator/test_docstore.py | 25 ++++-----
 tests/unit/testset_generator/test_document.py | 16 ++++++
 tests/unit/testset_generator/test_nodes.py    |  1 +
 4 files changed, 62 insertions(+), 36 deletions(-)
 create mode 100644 tests/unit/testset_generator/test_document.py
 create mode 100644 tests/unit/testset_generator/test_nodes.py

diff --git a/src/ragas/testset/docstore.py b/src/ragas/testset/docstore.py
index ee22f6011..5de3769f7 100644
--- a/src/ragas/testset/docstore.py
+++ b/src/ragas/testset/docstore.py
@@ -29,48 +29,58 @@
 logger = logging.getLogger(__name__)
 
 
-class Document(LCDocument):
-    doc_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
-    filename: t.Optional[str] = None
-    embedding: t.Optional[t.List[float]] = Field(default=None, repr=False)
+@dataclass
+class Document:
+    page_content: str
+    metadata: dict = field(default_factory=dict)
+    doc_id: str = field(default_factory=lambda: str(uuid.uuid4()))
+    type: t.Literal["Document"] = "Document"
+
+    @property
+    def filename(self):
+        filename = self.metadata.get("filename")
+        if filename is not None:
+            return filename
+        else:
+            logger.info(
+                "Document [ID: %s] has no filename. Using doc_id as filename.",
+                self.doc_id,
+            )
+            return self.doc_id
+
+    @filename.setter
+    def filename(self, value):
+        self.metadata["filename"] = value
 
     @classmethod
     def from_langchain_document(cls, doc: LCDocument):
         doc_id = str(uuid.uuid4())
-        if doc.metadata.get("filename"):
-            filename = doc.metadata["filename"]
-        else:
-            logger.info(
-                "Document [ID: %s] has no filename. Using doc_id as filename.", doc_id
-            )
-            filename = doc_id
         return cls(
             page_content=doc.page_content,
             metadata=doc.metadata,
             doc_id=doc_id,
-            filename=filename,
+        )
+
+    def to_langchain_document(self) -> LCDocument:
+        return LCDocument(
+            page_content=self.page_content,
+            metadata=self.metadata,
         )
 
     @classmethod
     def from_llamaindex_document(cls, doc: LlamaindexDocument):
         doc_id = str(uuid.uuid4())
-        if doc.metadata.get("filename"):
-            filename = doc.metadata["filename"]
-        else:
-            logger.info(
-                "Document [ID: %s] has no filename. Using doc_id as filename.", doc_id
-            )
-            filename = doc_id
         return cls(
             page_content=doc.text,
             metadata=doc.metadata,
             doc_id=doc_id,
-            filename=filename,
         )
 
 
+@dataclass
 class Node(Document):
-    keyphrases: t.List[str] = Field(default_factory=list, repr=False)
+    keyphrases: t.List[str] = field(default_factory=list, repr=False)
+    embedding: Embedding = field(default=t.cast(Embedding, None), repr=False)
 
 
 class Direction(str, Enum):
@@ -204,7 +214,9 @@ def add_documents(self, docs: t.Sequence[Document], show_progress=True):
         # split documents with self.splitter into smaller nodes
         nodes = [
             Node.from_langchain_document(d)
-            for d in self.splitter.transform_documents(docs)
+            for d in self.splitter.transform_documents(
+                [doc.to_langchain_document() for doc in docs]
+            )
         ]
 
         self.add_nodes(nodes, show_progress=show_progress)
diff --git a/tests/unit/testset_generator/test_docstore.py b/tests/unit/testset_generator/test_docstore.py
index 0c6071171..30cbebc88 100644
--- a/tests/unit/testset_generator/test_docstore.py
+++ b/tests/unit/testset_generator/test_docstore.py
@@ -33,9 +33,9 @@ async def aembed_query(self, text: str) -> t.List[float]:
 
 
 def test_adjacent_nodes():
-    a1 = Node(doc_id="a1", page_content="a1", filename="a")
-    a2 = Node(doc_id="a2", page_content="a2", filename="a")
-    b = Node(doc_id="b", page_content="b", filename="b")
+    a1 = Node(doc_id="a1", page_content="a1")
+    a2 = Node(doc_id="a2", page_content="a2")
+    b = Node(doc_id="b", page_content="b")
 
     fake_embeddings = FakeEmbeddings()
     splitter = TokenTextSplitter(chunk_size=100, chunk_overlap=0)
@@ -49,7 +49,7 @@ def test_adjacent_nodes():
     assert store.get_adjacent(b, Direction.PREV) is None
 
     # raise ValueError if doc not in store
-    c = Node(doc_id="c", page_content="c", filename="c")
+    c = Node(doc_id="c", page_content="c")
     pytest.raises(ValueError, store.get_adjacent, c)
 
 
@@ -62,16 +62,11 @@ def create_test_nodes(with_embeddings=True):
         from collections import defaultdict
 
         embeddings = defaultdict(lambda: None)
-    a1 = Node(
-        doc_id="a1", page_content="cat", filename="a", embedding=embeddings["cat"]
-    )
-    a2 = Node(
-        doc_id="a2", page_content="mouse", filename="a", embedding=embeddings["mouse"]
-    )
+    a1 = Node(doc_id="a1", page_content="cat", embedding=embeddings["cat"])
+    a2 = Node(doc_id="a2", page_content="mouse", embedding=embeddings["mouse"])
     b = Node(
         doc_id="b",
         page_content="solar_system",
-        filename="b",
         embedding=embeddings["solar_system"],
     )
 
@@ -144,7 +139,9 @@ async def test_fake_embeddings():
 def test_docstore_add_batch(fake_llm):
     # create a dummy embeddings with support for async aembed_query()
     fake_embeddings = FakeEmbeddings()
-    store = InMemoryDocumentStore(splitter=None, embeddings=fake_embeddings, llm=fake_llm)  # type: ignore
+    store = InMemoryDocumentStore(
+        splitter=None, embeddings=fake_embeddings, llm=fake_llm
+    )  # type: ignore
 
     # add documents in batch
     nodes = create_test_nodes(with_embeddings=False)
@@ -154,8 +151,8 @@ def test_docstore_add_batch(fake_llm):
         == fake_embeddings.embeddings[nodes[0].page_content]
     )
     # add documents in batch that have some embeddings
-    c = Node(doc_id="c", page_content="c", filename="c", embedding=[0.0] * 768)
-    d = Node(doc_id="d", page_content="d", filename="d", embedding=[0.0] * 768)
+    c = Node(doc_id="c", page_content="c", embedding=[0.0] * 768)
+    d = Node(doc_id="d", page_content="d", embedding=[0.0] * 768)
     store.add_nodes([c, d])
 
     # test get() and that embeddings and keyphrases are correct
diff --git a/tests/unit/testset_generator/test_document.py b/tests/unit/testset_generator/test_document.py
new file mode 100644
index 000000000..b5bc30e87
--- /dev/null
+++ b/tests/unit/testset_generator/test_document.py
@@ -0,0 +1,16 @@
+import uuid
+from ragas.testset.docstore import Document
+
+
+def test_document_filename(monkeypatch):
+    monkeypatch.setattr(uuid, "uuid4", lambda: "test-uuid")
+    d1 = Document(page_content="a1")
+    assert d1.filename == "test-uuid"
+
+    # now suppose I add a filename to metadata
+    d2 = Document(page_content="a2", metadata={"filename": "test-filename"})
+    assert d2.filename == "test-filename"
+
+    # can I change the filename?
+    d2.filename = "new-filename"
+    assert d2.filename == "new-filename"
diff --git a/tests/unit/testset_generator/test_nodes.py b/tests/unit/testset_generator/test_nodes.py
new file mode 100644
index 000000000..d11eb02ff
--- /dev/null
+++ b/tests/unit/testset_generator/test_nodes.py
@@ -0,0 +1 @@
+from ragas.testset.docstore import Node