diff --git a/src/ragas/executor.py b/src/ragas/executor.py index ca66c6a9..d260ccda 100644 --- a/src/ragas/executor.py +++ b/src/ragas/executor.py @@ -3,12 +3,12 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass, field -import numpy as np from tqdm.auto import tqdm @dataclass class Executor: + desc: str = "Evaluating" is_async: bool = True max_workers: t.Optional[int] = None futures: t.List[t.Any] = field(default_factory=list, repr=False) @@ -71,10 +71,10 @@ async def _aresults(self) -> t.List[t.Any]: results = [] for future in tqdm( asyncio.as_completed(self.futures), - desc="Evaluating", + desc=self.desc, total=len(self.futures), ): - r = np.nan + r = (-1, None) try: r = await future except Exception as e: @@ -106,14 +106,14 @@ def results(self) -> t.List[t.Any]: try: for future in tqdm( as_completed(self.futures), - desc="Evaluating", + desc=self.desc, total=len(self.futures), ): - r = np.nan + r = (-1, None) try: r = future.result() except Exception as e: - r = np.nan + r = (-1, None) if self.raise_exceptions: raise e finally: @@ -121,5 +121,6 @@ def results(self) -> t.List[t.Any]: finally: self.executor.shutdown(wait=False) + print(results) sorted_results = sorted(results, key=lambda x: x[0]) return [r[1] for r in sorted_results] diff --git a/src/ragas/llms/__init__.py b/src/ragas/llms/__init__.py index 9d6285b6..f4d5513a 100644 --- a/src/ragas/llms/__init__.py +++ b/src/ragas/llms/__init__.py @@ -1,9 +1,10 @@ -from langchain.chat_models import ChatOpenAI +from langchain_community.chat_models import ChatOpenAI from ragas.llms.base import BaseRagasLLM, LangchainLLMWrapper __all__ = [ "BaseRagasLLM", + "LangchainLLMWrapper", "llm_factory", ] diff --git a/src/ragas/llms/base.py b/src/ragas/llms/base.py index 9f26b44d..6224a86a 100644 --- a/src/ragas/llms/base.py +++ b/src/ragas/llms/base.py @@ -36,11 +36,10 @@ def is_multiple_completion_supported(llm: BaseLanguageModel) -> bool: @dataclass class BaseRagasLLM(ABC): - def get_temperature(self, n: int) -> float: """Return the temperature to use for completion based on n.""" return 0.3 if n > 1 else 1e-8 - + @abstractmethod def generate_text( self, diff --git a/src/ragas/testset/docstore.py b/src/ragas/testset/docstore.py index f87e5112..313bec4a 100644 --- a/src/ragas/testset/docstore.py +++ b/src/ragas/testset/docstore.py @@ -1,20 +1,25 @@ import heapq +import logging import typing as t import uuid from abc import ABC, abstractmethod from dataclasses import dataclass, field from enum import Enum +from random import choices import numpy as np import numpy.typing as npt from langchain.text_splitter import TextSplitter from langchain_core.documents import Document as LCDocument -from pydantic import Field +from langchain_core.pydantic_v1 import Field +from llama_index.readers.schema import Document as LlamaindexDocument from ragas.async_utils import run_async_tasks from ragas.embeddings.base import BaseRagasEmbeddings, embedding_factory Embedding = t.Union[t.List[float], npt.NDArray[np.float64]] +logger = logging.getLogger(__name__) +rng = np.random.default_rng() class Document(LCDocument): @@ -22,29 +27,86 @@ class Document(LCDocument): filename: t.Optional[str] = None embedding: t.Optional[t.List[float]] = Field(default=None, repr=False) + @classmethod + def from_langchain_document(cls, doc: LCDocument): + doc_id = str(uuid.uuid4()) + if doc.metadata.get("filename"): + filename = doc.metadata["filename"] + else: + logger.info( + "Document [ID: %s] has no filename. Using doc_id as filename.", doc_id + ) + filename = doc_id + return cls( + page_content=doc.page_content, + metadata=doc.metadata, + doc_id=doc_id, + filename=filename, + ) + + @classmethod + def from_llamaindex_document(cls, doc: LlamaindexDocument): + doc_id = str(uuid.uuid4()) + if doc.metadata.get("filename"): + filename = doc.metadata["filename"] + else: + logger.info( + "Document [ID: %s] has no filename. Using doc_id as filename.", doc_id + ) + filename = doc_id + return cls( + page_content=doc.text, + metadata=doc.metadata, + doc_id=doc_id, + filename=filename, + ) + + +class Node(Document): + ... + + +class Direction(str, Enum): + """ + Direction for getting adjascent nodes. + """ + + NEXT = "next" + PREV = "prev" + UP = "up" + DOWN = "down" + class DocumentStore(ABC): def __init__(self): self.documents = {} @abstractmethod - def add(self, doc: t.Union[Document, t.Sequence[Document]], show_progress=True): + def add_documents(self, docs: t.Sequence[Document], show_progress=True): + ... + + @abstractmethod + def add_nodes(self, nodes: t.Sequence[Node], show_progress=True): ... @abstractmethod - def get(self, doc_id: str) -> Document: + def get_node(self, node_id: str) -> Node: + ... + + @abstractmethod + def get_random_nodes(self, k=1) -> t.List[Node]: ... @abstractmethod def get_similar( - self, doc: Document, threshold: float = 0.7, top_k: int = 3 - ) -> t.List[Document]: + self, node: Node, threshold: float = 0.7, top_k: int = 3 + ) -> t.Union[t.List[Document], t.List[Node]]: ... @abstractmethod - def get_adjascent( - self, doc: Document, direction: str = "next" - ) -> t.Optional[Document]: + def get_adjacent( + self, node: Node, direction: Direction = Direction.NEXT + ) -> t.Optional[Node]: ... @@ -117,56 +179,69 @@ class InMemoryDocumentStore(DocumentStore): embeddings: BaseRagasEmbeddings = field( default_factory=embedding_factory, repr=False ) - documents_list: t.List[Document] = field(default_factory=list) - embeddings_list: t.List[Embedding] = field(default_factory=list) - documents_map: t.Dict[str, Document] = field(default_factory=dict) + nodes: t.List[Node] = field(default_factory=list) + node_embeddings_list: t.List[Embedding] = field(default_factory=list) + node_map: t.Dict[str, Node] = field(default_factory=dict) + + def _embed_items(self, items: t.Union[t.Sequence[Document], t.Sequence[Node]]): + ... - def _add_documents_batch(self, docs: t.Sequence[Document], show_progress=True): + def add_documents(self, docs: t.Sequence[Document], show_progress=True): """ Add documents in batch mode. """ + # split documents with self.splitter into smaller nodes + nodes = [ + Node.from_langchain_document(d) + for d in self.splitter.transform_documents(docs) + ] + + self.add_nodes(nodes, show_progress=show_progress) + + def add_nodes( + self, nodes: t.Sequence[Node], show_progress=True, desc: str = "embedding nodes" + ): # NOTE: Adds everything in async mode for now. embed_tasks = [] docs_to_embed = [] - for doc in docs: - if doc.embedding is None: - embed_tasks.append(self.embeddings.aembed_query(doc.page_content)) - docs_to_embed.append(doc) + # get embeddings for the docs + for n in nodes: + if n.embedding is None: + embed_tasks.append(self.embeddings.aembed_query(n.page_content)) + docs_to_embed.append(n) else: - self.documents_list.append(doc) - self.documents_map[doc.doc_id] = doc - self.embeddings_list.append(doc.embedding) - - embeddings = run_async_tasks(embed_tasks, show_progress=show_progress) - for doc, embedding in zip(docs_to_embed, embeddings): - doc.embedding = embedding - self.documents_list.append(doc) - self.documents_map[doc.doc_id] = doc - self.embeddings_list.append(doc.embedding) - - def add(self, doc: t.Union[Document, t.Sequence[Document]], show_progress=True): - if isinstance(doc, list) or isinstance(doc, tuple): - self._add_documents_batch(doc) - elif isinstance(doc, Document): - self.documents_list.append(doc) - self.documents_map[doc.doc_id] = doc - if doc.embedding is None: - doc.embedding = self.embeddings.embed_query(doc.page_content) - self.embeddings_list.append(doc.embedding) - else: - raise ValueError("add() method only supports Document or List[Document]") + self.nodes.append(n) + self.node_map[n.doc_id] = n + self.node_embeddings_list.append(n.embedding) + + embeddings = run_async_tasks( + embed_tasks, show_progress=show_progress, progress_bar_desc=desc + ) + for n, embedding in zip(docs_to_embed, embeddings): + n.embedding = embedding + self.nodes.append(n) + self.node_map[n.doc_id] = n + self.node_embeddings_list.append(n.embedding) + + def get_node(self, node_id: str) -> Node: + return self.node_map[node_id] + + def get_document(self, doc_id: str) -> Node: + raise NotImplementedError - def get(self, doc_id: str) -> Document: - return self.documents_map[doc_id] + def get_random_nodes(self, k=1) -> t.List[Node]: + return choices(self.nodes, k=k) def get_similar( - self, doc: Document, threshold: float = 0.7, top_k: int = 3 - ) -> t.List[Document]: + self, node: Node, threshold: float = 0.7, top_k: int = 3 + ) -> t.Union[t.List[Document], t.List[Node]]: + items = [] + doc = node if doc.embedding is None: raise ValueError("Document has no embedding.") scores, doc_ids = get_top_k_embeddings( query_embedding=doc.embedding, - embeddings=self.embeddings_list, + embeddings=self.node_embeddings_list, similarity_fn=similarity, similarity_cutoff=threshold, # we need to return k+1 docs here as the top result is the input doc itself @@ -174,27 +249,28 @@ def get_similar( ) # remove the query doc itself from results scores, doc_ids = scores[1:], doc_ids[1:] - return [self.documents_list[doc_id] for doc_id in doc_ids] + items = [self.nodes[doc_id] for doc_id in doc_ids] + return items - def get_adjascent( - self, doc: Document, direction: str = "next" - ) -> t.Optional[Document]: + def get_adjacent( + self, node: Node, direction: Direction = Direction.NEXT + ) -> t.Optional[Node]: # linear search for doc_id of doc in documents_list - index = self.documents_list.index(doc) + index = self.nodes.index(node) - if direction == "next": - if len(self.documents_list) > index + 1: - next_doc = self.documents_list[index + 1] - if next_doc.filename == doc.filename: + if direction == Direction.NEXT: + if len(self.nodes) > index + 1: + next_doc = self.nodes[index + 1] + if next_doc.filename == node.filename: return next_doc else: return None else: return None - if direction == "prev": + if direction == Direction.PREV: if index > 0: - prev_doc = self.documents_list[index - 1] - if prev_doc.filename == doc.filename: + prev_doc = self.nodes[index - 1] + if prev_doc.filename == node.filename: return prev_doc else: return None diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py index 31521225..6bd5f697 100644 --- a/src/ragas/testset/evolutions.py +++ b/src/ragas/testset/evolutions.py @@ -1,87 +1,201 @@ +import logging +import typing as t from abc import ABC, abstractmethod -from dataclasses import dataclass +from dataclasses import dataclass, field +from random import choice +from fsspec.exceptions import asyncio from langchain.prompts import ChatPromptTemplate +from numpy.random import default_rng from ragas.llms import BaseRagasLLM from ragas.llms.json_load import load_as_json from ragas.llms.prompt import PromptValue -from ragas.testset.docstore import Document, DocumentStore +from ragas.testset.docstore import Direction, Document, DocumentStore, Node from ragas.testset.prompts import ( FILTER_QUESTION, MULTICONTEXT_QUESTION, SCORE_CONTEXT, SEED_QUESTION, + TABLE_QA, + demonstrations, ) +rng = default_rng() +logger = logging.getLogger(__name__) + + +def to_pv(prompt: ChatPromptTemplate) -> PromptValue: + return PromptValue(prompt_str=prompt.format()) + @dataclass class Filter(ABC): - @abstractmethod - def filter(self) -> bool: - ... + ... - @abstractmethod - async def afilter(self) -> bool: - ... +@dataclass +class NodeFilter(Filter): + llm: BaseRagasLLM + threshold: float = 7.5 -def to_pv(prompt: ChatPromptTemplate) -> PromptValue: - return PromptValue(prompt_str=prompt.format()) + def filter(self, node: Node) -> t.Dict: + return asyncio.get_event_loop().run_until_complete(self.afilter(node)) + async def afilter(self, node: Node) -> t.Dict: + human_prompt = SCORE_CONTEXT.format(context=node.page_content) + prompt = ChatPromptTemplate.from_messages([human_prompt]) + results = await self.llm.agenerate_text(prompt=to_pv(prompt)) + output = results.generations[0][0].text.strip() + score = load_as_json(output) + score.update({"score": score.get("score", 0) >= self.threshold}) + return score -async def filter_context( - llm: BaseRagasLLM, context: str, threshold: float = 7.5 -) -> bool: - """ - context: str - The input context - Checks if the context is has enough information to frame a question - """ - human_prompt = SCORE_CONTEXT.format(context=context) - prompt = ChatPromptTemplate.from_messages([human_prompt]) - results = await llm.agenerate_text(prompt=to_pv(prompt)) - output = results.generations[0][0].text.strip() - score = load_as_json(output) - return score >= threshold # type: ignore +@dataclass +class QuestionFilter(Filter): + llm: BaseRagasLLM + def filter(self, question: str) -> bool: + return asyncio.get_event_loop().run_until_complete(self.afilter(question)) -async def filter_question(llm: BaseRagasLLM, question: str) -> bool: - human_prompt = FILTER_QUESTION.format(question=question) - prompt = ChatPromptTemplate.from_messages([human_prompt]) + async def afilter(self, question: str) -> bool: + human_prompt = FILTER_QUESTION.format(question=question) + prompt = ChatPromptTemplate.from_messages([human_prompt]) - results = await llm.agenerate_text(prompt=to_pv(prompt)) - results = results.generations[0][0].text.strip() - json_results = load_as_json(results) - return json_results.get("verdict") != "No" + results = await self.llm.agenerate_text(prompt=to_pv(prompt)) + results = results.generations[0][0].text.strip() + json_results = load_as_json(results) + logger.debug("filtered question: %s", json_results) + return json_results.get("verdict") != "No" @dataclass class Evolution: - def evolve(self): + node_filter: NodeFilter + question_filter: QuestionFilter + nodes: t.List[Node] = field(default_factory=list) + max_tries: int = 5 + _root_node: t.Optional[Node] = field(default=None, init=False, repr=False) + _tries: int = field(default=0, init=False, repr=False) + + def merged_nodes(self) -> Node: + return Node( + doc_id="merged", page_content=" ".join(n.page_content for n in self.nodes) + ) + + async def aretry_evolve( + self, llm: BaseRagasLLM, docstore: DocumentStore, update_count: bool = True + ): + if update_count: + self._tries += 1 + print("retrying evolution: %s times", self._tries) + if self._tries > self.max_tries: + # TODO: make this into a custom exception + raise ValueError("Max tries reached") + return await self.aevolve(llm, docstore) + + @abstractmethod + def evolve(self, llm: BaseRagasLLM, docstore: DocumentStore) -> str: ... - async def aevolve(self): + @abstractmethod + async def aevolve(self, llm: BaseRagasLLM, docstore: DocumentStore) -> str: ... -async def simple_evolution(llm: BaseRagasLLM, seed_doc: Document): - human_prompt = SEED_QUESTION.format(context=seed_doc.page_content) +@dataclass +class SimpleEvolution(Evolution): + def evolve(self, llm: BaseRagasLLM, docstore: DocumentStore): + logger.info("evolving question") + return asyncio.get_event_loop().run_until_complete(self.aevolve(llm, docstore)) + + def _get_more_adjacent_nodes(self, docstore: DocumentStore): + """ + if the evolutions doesn't have enough nodes to frame a question, get more nodes + """ + assert self._root_node is not None, "root node cannot be None" + # get more nodes from above the context window + prev_adjacent_node = docstore.get_adjacent(self._root_node, Direction.PREV) + if prev_adjacent_node is None: + # get more nodes from below the context window + next_adjacent_node = docstore.get_adjacent(self._root_node, Direction.NEXT) + if next_adjacent_node is not None: + # add next nodes towards the end + self.nodes.append(next_adjacent_node) + else: + # retry with new base node + self.nodes = docstore.get_random_nodes(k=1) + self._root_node = self.nodes[0] + else: + # add prev nodes in index 0 + self.nodes.insert(0, prev_adjacent_node) + + async def aevolve(self, llm: BaseRagasLLM, docstore: DocumentStore): + # can the node be used to frame a question? + if self._tries == 0: + self.nodes = docstore.get_random_nodes(k=1) + self._root_node = self.nodes[0] + merged_node = self.merged_nodes() + passed, table_is_present = await self.node_filter.afilter(self.nodes[0]) + if not passed: + self.nodes = docstore.get_random_nodes(k=1) + return await self.aretry_evolve(llm, docstore, update_count=False) + + # frame a basic question with with node + seed_questions = await simple_evolution(llm, merged_node, table_is_present) + # NOTE: might need improvement + # select only one seed question here + seed_question = choice(seed_questions) + is_valid_question = await self.question_filter.afilter(seed_question) + if not is_valid_question: + # get more context to rewrite question + self._get_more_adjacent_nodes(docstore) + # retry with new nodes added + return await self.aretry_evolve(llm, docstore) + else: + # if valid question + return seed_question + + +async def simple_evolution( + llm: BaseRagasLLM, seed_doc: Document, is_table_present: bool = False +): + if is_table_present: + human_prompt = TABLE_QA.format(context=seed_doc.page_content) + else: + sample = rng.choice(demonstrations, 1)[0] # type: ignore + questions = rng.choice(sample["questions"], 2, replace=False) + questions = ( + "{" + + str({k: v for dic in questions.tolist() for k, v in dic.items()}).replace( + "'", '"' + ) + + "}" + ) + demo = f'Context:{sample["context"]}\nQuestions:{questions}' + human_prompt = SEED_QUESTION.format( + demonstration=demo, context=seed_doc.page_content + ) + prompt = ChatPromptTemplate.from_messages([human_prompt]) - results = await llm.agenerate_text(prompt=to_pv(prompt)) - question = results.generations[0][0].text.strip() - return question + results = llm.generate_text_with_hmpt(prompts=[prompt]) + results = results.generations[0][0].text + if is_table_present: + return [results] + else: + results = load_as_json(results) + return [v for v in results.values()] async def multi_context_evolution( - llm: BaseRagasLLM, seed_doc: Document, doc_store: DocumentStore + llm: BaseRagasLLM, seed_node: Node, doc_store: DocumentStore ): - question = simple_evolution(llm, seed_doc) + question = simple_evolution(llm, seed_node) print(question) - similar_context = doc_store.get_similar(seed_doc)[0] + similar_context = doc_store.get_similar(seed_node)[0] human_prompt = MULTICONTEXT_QUESTION.format( - question=question, context1=seed_doc.page_content, context2=similar_context + question=question, context1=seed_node.page_content, context2=similar_context ) prompt = ChatPromptTemplate.from_messages([human_prompt]) results = await llm.agenerate_text(prompt=to_pv(prompt)) diff --git a/src/ragas/testset/generator.py b/src/ragas/testset/generator.py new file mode 100644 index 00000000..6131764b --- /dev/null +++ b/src/ragas/testset/generator.py @@ -0,0 +1,58 @@ +import typing as t +from dataclasses import dataclass + +from langchain.embeddings import OpenAIEmbeddings +from langchain_community.chat_models import ChatOpenAI +from llama_index.readers.schema import Document as LlamaindexDocument + +from ragas.embeddings import BaseRagasEmbeddings +from ragas.llms import BaseRagasLLM, LangchainLLMWrapper +from ragas.testset.docstore import Document, DocumentStore, InMemoryDocumentStore + + +@dataclass +class TestsetGenerator: + generator_llm: BaseRagasLLM + critic_llm: BaseRagasLLM + embeddings: BaseRagasEmbeddings + docstore: DocumentStore + + @classmethod + def with_openai( + cls, + generator_llm: str = "gpt-3.5-turbo", + critic_llm: str = "gpt-4", + embeddings: str = "text-embedding-ada-002", + docstore: t.Optional[DocumentStore] = None, + chunk_size: int = 512, + ) -> "TestsetGenerator": + generator_llm_model = LangchainLLMWrapper(ChatOpenAI(model=generator_llm)) + critic_llm_model = LangchainLLMWrapper(ChatOpenAI(model=critic_llm)) + embeddings_model = OpenAIEmbeddings(model=embeddings) + if docstore is None: + from langchain.text_splitter import TokenTextSplitter + + splitter = TokenTextSplitter(chunk_size=chunk_size, chunk_overlap=0) + docstore = InMemoryDocumentStore(splitter) + return cls( + generator_llm=generator_llm_model, + critic_llm=critic_llm_model, + # TODO: remove type ignore after fixing embeddigns + embeddings=embeddings_model, # type: ignore + docstore=docstore, + ) + else: + return cls( + generator_llm=generator_llm_model, + critic_llm=critic_llm_model, + embeddings=embeddings_model, # type: ignore + docstore=docstore, + ) + + def generate_with_llamaindex_docs(self, documents: t.Sequence[LlamaindexDocument]): + # chunk documents and add to docstore + self.docstore.add_documents( + [Document.from_llamaindex_document(doc) for doc in documents] + ) + # create evolutions and add to executor queue + # run till completion - keep updating progress bar diff --git a/src/ragas/testset/prompts.py b/src/ragas/testset/prompts.py index e3882083..04fe50a5 100644 --- a/src/ragas/testset/prompts.py +++ b/src/ragas/testset/prompts.py @@ -213,8 +213,7 @@ ) REWRITE_QUESTION = HumanMessagePromptTemplate.from_template( - """ - + """\ Given a context, transform the given question to be clear and standalone by replacing its coreferences with specific details from the context: Contexts: diff --git a/src/ragas/testset/testset_generator.py b/src/ragas/testset/testset_generator.py index 8e7d0e12..907fc92a 100644 --- a/src/ragas/testset/testset_generator.py +++ b/src/ragas/testset/testset_generator.py @@ -544,4 +544,4 @@ def generate( count += 1 pbar.update(count) - return TestDataset(test_data=samples) \ No newline at end of file + return TestDataset(test_data=samples) diff --git a/tests/e2e/test_adaptation.py b/tests/e2e/test_adaptation.py index 269233f6..f2b07149 100644 --- a/tests/e2e/test_adaptation.py +++ b/tests/e2e/test_adaptation.py @@ -1,4 +1,3 @@ - from ragas import adapt from ragas.metrics import context_recall diff --git a/tests/unit/testset_generator/test_docstore.py b/tests/unit/testset_generator/test_docstore.py index 338067a3..b5a2ca78 100644 --- a/tests/unit/testset_generator/test_docstore.py +++ b/tests/unit/testset_generator/test_docstore.py @@ -5,28 +5,28 @@ import pytest from langchain_core.embeddings import Embeddings -from ragas.testset.docstore import Document, InMemoryDocumentStore +from ragas.testset.docstore import Direction, InMemoryDocumentStore, Node def test_adjacent_nodes(): - a1 = Document(doc_id="a1", page_content="a1", filename="a") - a2 = Document(doc_id="a2", page_content="a2", filename="a") - b = Document(doc_id="b", page_content="b", filename="b") + a1 = Node(doc_id="a1", page_content="a1", filename="a") + a2 = Node(doc_id="a2", page_content="a2", filename="a") + b = Node(doc_id="b", page_content="b", filename="b") store = InMemoryDocumentStore(splitter=None) # type: ignore - store.documents_list = [a1, a2, b] + store.nodes = [a1, a2, b] - assert store.get_adjascent(a1) == a2 - assert store.get_adjascent(a2, "prev") == a1 - assert store.get_adjascent(a2, "next") is None - assert store.get_adjascent(b, "prev") is None + assert store.get_adjacent(a1) == a2 + assert store.get_adjacent(a2, Direction.PREV) == a1 + assert store.get_adjacent(a2, Direction.NEXT) is None + assert store.get_adjacent(b, Direction.PREV) is None # raise ValueError if doc not in store - c = Document(doc_id="c", page_content="c", filename="c") - pytest.raises(ValueError, store.get_adjascent, c) + c = Node(doc_id="c", page_content="c", filename="c") + pytest.raises(ValueError, store.get_adjacent, c) -def create_test_documents(with_embeddings=True): +def create_test_nodes(with_embeddings=True): if with_embeddings: path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_embs.pkl") with open(path, "rb") as f: @@ -35,13 +35,13 @@ def create_test_documents(with_embeddings=True): from collections import defaultdict embeddings = defaultdict(lambda: None) - a1 = Document( + a1 = Node( doc_id="a1", page_content="cat", filename="a", embedding=embeddings["cat"] ) - a2 = Document( + a2 = Node( doc_id="a2", page_content="mouse", filename="a", embedding=embeddings["mouse"] ) - b = Document( + b = Node( doc_id="b", page_content="solar_system", filename="b", @@ -52,10 +52,10 @@ def create_test_documents(with_embeddings=True): def test_similar_nodes(): - a1, a2, b = create_test_documents() + a1, a2, b = create_test_nodes() store = InMemoryDocumentStore(splitter=None) # type: ignore - store.documents_list = [a1, a2, b] - store.embeddings_list = [d.embedding for d in store.documents_list] + store.nodes = [a1, a2, b] + store.node_embeddings_list = [d.embedding for d in store.nodes] assert store.get_similar(a1)[0] == a2 assert store.get_similar(a2)[0] == a1 @@ -65,10 +65,10 @@ def test_similar_nodes(): def test_similar_nodes_scaled(): - a1, a2, b = create_test_documents() - store = InMemoryDocumentStore(splitter=None) # type: ignore - store.documents_list = [a1, a2, b] + [b] * 100 - store.embeddings_list = [d.embedding for d in store.documents_list] + a1, a2, b = create_test_nodes() + store = InMemoryDocumentStore(splitter=None) # type: ignore (None type is not Splitter) + store.nodes = [a1, a2, b] + [b] * 100 + store.node_embeddings_list = [d.embedding for d in store.nodes] assert len(store.get_similar(a1, top_k=3)) == 3 assert store.get_similar(a1)[0] == a2 @@ -76,16 +76,16 @@ def test_similar_nodes_scaled(): def test_docstore_add(): - a1, a2, b = create_test_documents() + a1, a2, b = create_test_nodes() store = InMemoryDocumentStore(splitter=None) # type: ignore docs_added = [] for doc in [a1, a2, b]: - store.add(doc) + store.add_nodes([doc]) docs_added.append(doc) - assert store.documents_list == docs_added - assert store.embeddings_list == [d.embedding for d in docs_added] + assert store.nodes == docs_added + assert store.node_embeddings_list == [d.embedding for d in docs_added] - assert store.get(a1.doc_id) == a1 + assert store.get_node(a1.doc_id) == a1 class FakeEmbeddings(Embeddings): @@ -130,20 +130,20 @@ def test_docstore_add_batch(): store = InMemoryDocumentStore(splitter=None, embeddings=fake_embeddings) # type: ignore # add documents in batch - docs = create_test_documents(with_embeddings=False) - store.add(docs) + nodes = create_test_nodes(with_embeddings=False) + store.add_nodes(nodes) assert ( - store.documents_map[docs[0].doc_id].embedding - == fake_embeddings.embeddings[docs[0].page_content] + store.node_map[nodes[0].doc_id].embedding + == fake_embeddings.embeddings[nodes[0].page_content] ) # add documents in batch that have some embeddings - c = Document(doc_id="c", page_content="c", filename="c", embedding=[0.0] * 768) - d = Document(doc_id="d", page_content="d", filename="d", embedding=[0.0] * 768) - store.add([c, d]) + c = Node(doc_id="c", page_content="c", filename="c", embedding=[0.0] * 768) + d = Node(doc_id="d", page_content="d", filename="d", embedding=[0.0] * 768) + store.add_nodes([c, d]) # test get() and that embeddings are correct - assert store.get(c.doc_id).embedding == [0.0] * 768 - assert store.get(d.doc_id).embedding == [0.0] * 768 - assert len(store.documents_list) == 5 - assert len(store.embeddings_list) == 5 - assert len(store.documents_map) == 5 + assert store.get_node(c.doc_id).embedding == [0.0] * 768 + assert store.get_node(d.doc_id).embedding == [0.0] * 768 + assert len(store.nodes) == 5 + assert len(store.node_embeddings_list) == 5 + assert len(store.node_map) == 5