Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 33 additions & 21 deletions src/ragas/testset/docstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,48 +29,58 @@
logger = logging.getLogger(__name__)


class Document(LCDocument):
    # NOTE(review): this is the pre-refactor (removed) pydantic-style Document;
    # the PR replaces it with the plain dataclass version below, where
    # `filename` becomes a metadata-backed property instead of a field.
    doc_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    filename: t.Optional[str] = None
    embedding: t.Optional[t.List[float]] = Field(default=None, repr=False)
@dataclass
class Document:
    """A plain-dataclass document: text plus metadata, keyed by a UUID doc_id.

    The filename is not a stored field; it is derived from ``metadata`` via
    the ``filename`` property, falling back to ``doc_id`` when absent.
    """

    page_content: str
    metadata: dict = field(default_factory=dict)
    doc_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    type: t.Literal["Document"] = "Document"

    @property
    def filename(self):
        """Return ``metadata['filename']``, or ``doc_id`` when none is set."""
        name = self.metadata.get("filename")
        if name is None:
            # No explicit filename: log once and fall back to the doc_id.
            logger.info(
                "Document [ID: %s] has no filename. Using doc_id as filename.",
                self.doc_id,
            )
            name = self.doc_id
        return name

    @filename.setter
    def filename(self, value):
        # Store the filename in metadata so it round-trips with the document.
        self.metadata["filename"] = value

@classmethod
def from_langchain_document(cls, doc: LCDocument):
    """Build a Document from a LangChain document, assigning a fresh doc_id."""
    doc_id = str(uuid.uuid4())
    # NOTE(review): this diff view interleaves removed and added lines. The
    # explicit filename fallback below duplicates the logic now living on the
    # Document.filename property, and the new dataclass __init__ accepts no
    # `filename` kwarg — these look like removed-line artifacts; confirm
    # against the merged file.
    if doc.metadata.get("filename"):
        filename = doc.metadata["filename"]
    else:
        logger.info(
            "Document [ID: %s] has no filename. Using doc_id as filename.", doc_id
        )
        filename = doc_id
    return cls(
        page_content=doc.page_content,
        metadata=doc.metadata,
        doc_id=doc_id,
        filename=filename,
    )

def to_langchain_document(self) -> LCDocument:
    """Convert this document into its LangChain equivalent.

    Only the page content and metadata are carried over; the doc_id is not
    part of the LangChain document.
    """
    lc_kwargs = {
        "page_content": self.page_content,
        "metadata": self.metadata,
    }
    return LCDocument(**lc_kwargs)

@classmethod
def from_llamaindex_document(cls, doc: LlamaindexDocument):
    """Build a Document from a LlamaIndex document, assigning a fresh doc_id."""
    doc_id = str(uuid.uuid4())
    # NOTE(review): as with from_langchain_document, the filename fallback and
    # the `filename=` kwarg below appear to be removed diff lines merged into
    # this view (the new dataclass __init__ has no `filename` parameter) —
    # confirm against the merged file.
    if doc.metadata.get("filename"):
        filename = doc.metadata["filename"]
    else:
        logger.info(
            "Document [ID: %s] has no filename. Using doc_id as filename.", doc_id
        )
        filename = doc_id
    return cls(
        page_content=doc.text,
        metadata=doc.metadata,
        doc_id=doc_id,
        filename=filename,
    )


@dataclass
class Node(Document):
    """A chunk-level document node carrying keyphrases and an embedding.

    Inherits all Document fields (page_content, metadata, doc_id, type).
    """

    # Bug fix: the original used pydantic's `Field(...)` as the defaults
    # inside a stdlib @dataclass. dataclasses does not understand pydantic's
    # FieldInfo, so `default_factory=list` and `repr=False` were silently
    # ignored and every instance shared the FieldInfo object as a class-level
    # default. Use dataclasses.field (already used by Document above).
    keyphrases: t.List[str] = field(default_factory=list, repr=False)
    # embedding defaults to None; the cast keeps the declared type `Embedding`
    # while allowing a missing embedding to be filled in later by the store.
    embedding: Embedding = field(default=t.cast(Embedding, None), repr=False)


class Direction(str, Enum):
Expand Down Expand Up @@ -204,7 +214,9 @@ def add_documents(self, docs: t.Sequence[Document], show_progress=True):
# split documents with self.splitter into smaller nodes
nodes = [
Node.from_langchain_document(d)
for d in self.splitter.transform_documents(docs)
for d in self.splitter.transform_documents(
[doc.to_langchain_document() for doc in docs]
)
]

self.add_nodes(nodes, show_progress=show_progress)
Expand Down
25 changes: 11 additions & 14 deletions tests/unit/testset_generator/test_docstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ async def aembed_query(self, text: str) -> t.List[float]:


def test_adjacent_nodes():
a1 = Node(doc_id="a1", page_content="a1", filename="a")
a2 = Node(doc_id="a2", page_content="a2", filename="a")
b = Node(doc_id="b", page_content="b", filename="b")
a1 = Node(doc_id="a1", page_content="a1")
a2 = Node(doc_id="a2", page_content="a2")
b = Node(doc_id="b", page_content="b")

fake_embeddings = FakeEmbeddings()
splitter = TokenTextSplitter(chunk_size=100, chunk_overlap=0)
Expand All @@ -49,7 +49,7 @@ def test_adjacent_nodes():
assert store.get_adjacent(b, Direction.PREV) is None

# raise ValueError if doc not in store
c = Node(doc_id="c", page_content="c", filename="c")
c = Node(doc_id="c", page_content="c")
pytest.raises(ValueError, store.get_adjacent, c)


Expand All @@ -62,16 +62,11 @@ def create_test_nodes(with_embeddings=True):
from collections import defaultdict

embeddings = defaultdict(lambda: None)
a1 = Node(
doc_id="a1", page_content="cat", filename="a", embedding=embeddings["cat"]
)
a2 = Node(
doc_id="a2", page_content="mouse", filename="a", embedding=embeddings["mouse"]
)
a1 = Node(doc_id="a1", page_content="cat", embedding=embeddings["cat"])
a2 = Node(doc_id="a2", page_content="mouse", embedding=embeddings["mouse"])
b = Node(
doc_id="b",
page_content="solar_system",
filename="b",
embedding=embeddings["solar_system"],
)

Expand Down Expand Up @@ -144,7 +139,9 @@ async def test_fake_embeddings():
def test_docstore_add_batch(fake_llm):
# create a dummy embeddings with support for async aembed_query()
fake_embeddings = FakeEmbeddings()
store = InMemoryDocumentStore(splitter=None, embeddings=fake_embeddings, llm=fake_llm) # type: ignore
store = InMemoryDocumentStore(
splitter=None, embeddings=fake_embeddings, llm=fake_llm
) # type: ignore

# add documents in batch
nodes = create_test_nodes(with_embeddings=False)
Expand All @@ -154,8 +151,8 @@ def test_docstore_add_batch(fake_llm):
== fake_embeddings.embeddings[nodes[0].page_content]
)
# add documents in batch that have some embeddings
c = Node(doc_id="c", page_content="c", filename="c", embedding=[0.0] * 768)
d = Node(doc_id="d", page_content="d", filename="d", embedding=[0.0] * 768)
c = Node(doc_id="c", page_content="c", embedding=[0.0] * 768)
d = Node(doc_id="d", page_content="d", embedding=[0.0] * 768)
store.add_nodes([c, d])

# test get() and that embeddings and keyphrases are correct
Expand Down
16 changes: 16 additions & 0 deletions tests/unit/testset_generator/test_document.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import uuid
from ragas.testset.docstore import Document


def test_document_filename(monkeypatch):
    """Filename falls back to doc_id, reads from metadata, and is settable."""
    # Pin uuid4 so the doc_id fallback is deterministic.
    monkeypatch.setattr(uuid, "uuid4", lambda: "test-uuid")

    # Without a filename in metadata, the (patched) doc_id is used.
    doc_without = Document(page_content="a1")
    assert doc_without.filename == "test-uuid"

    # An explicit metadata filename takes precedence over the doc_id.
    doc_with = Document(page_content="a2", metadata={"filename": "test-filename"})
    assert doc_with.filename == "test-filename"

    # The setter rebinds the metadata-backed filename.
    doc_with.filename = "new-filename"
    assert doc_with.filename == "new-filename"
1 change: 1 addition & 0 deletions tests/unit/testset_generator/test_nodes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from ragas.testset.docstore import Node