New rag agent #1727

Closed
Wants to merge 74 commits into main from new_rag
Changes from 24 commits
Commits (74)
bfe40ad
Added new rag core functionalities
thinkall Feb 8, 2024
cbc1a24
Update docstring
thinkall Feb 19, 2024
948ae5e
Add get docs by ids
thinkall Feb 20, 2024
07e206f
Keep order in merge_documents, good for no reranker
thinkall Feb 20, 2024
2b5bf10
Add more docstrings
thinkall Feb 20, 2024
986be00
Update get_docs_by_ids, add include and kwargs
thinkall Feb 20, 2024
1a9acd8
Add test_chromadb and fix errors
thinkall Feb 20, 2024
3f6d133
Add test_datamodel
thinkall Feb 20, 2024
706c886
Add test_encoder
thinkall Feb 20, 2024
a8319b1
Add test_promptgenerator
thinkall Feb 20, 2024
a99882e
Add test_reranker
thinkall Feb 20, 2024
66a09c9
Add test_retriever
thinkall Feb 20, 2024
9d9bc59
Update test_retriever
thinkall Feb 20, 2024
7f76f03
Add test_splitter and fix splitter for one line file
thinkall Feb 20, 2024
a4f2f4b
Add test_uitls and fix utils
thinkall Feb 20, 2024
5320e45
Add test_rag_agent
thinkall Feb 20, 2024
e3b42a6
Fix test splitter
thinkall Feb 20, 2024
dd13a29
Add test_rag_openai
thinkall Feb 20, 2024
35fc80c
Merge branch 'main' into new_rag
thinkall Feb 20, 2024
8311940
Merge branch 'main' into new_rag
thinkall Feb 21, 2024
02d9e80
Fix test configs
thinkall Feb 21, 2024
dec1e56
Fix __all__
thinkall Feb 21, 2024
fbc37cf
Update docstring
thinkall Feb 21, 2024
0e43b69
Merge branch 'main' into new_rag
thinkall Feb 22, 2024
f73740e
Merge branch 'main' into new_rag
thinkall Feb 23, 2024
1b30dc2
Reformat and fix typo
thinkall Feb 23, 2024
18e8886
Add readme, update docstrings, improvements
thinkall Feb 23, 2024
f9e367b
Use Protocol for encoder and reranker
thinkall Feb 23, 2024
aff1459
Use Protocol for vector db, keep using received_raw_message for llm c…
thinkall Feb 23, 2024
0d6d074
Use jupyer-kernel-gateway for ipython executor (#1748)
jackgerrits Feb 23, 2024
46fed07
Handle azure_deployment Parameter Issue in GPTAssistantAgent to Maint…
IANTHEREAL Feb 24, 2024
6542943
Update parameter name
thinkall Feb 24, 2024
2f741d8
Update hash, add deduplication function for chunks
thinkall Feb 24, 2024
b5bae5e
Merge branch 'main' into new_rag
thinkall Feb 24, 2024
a071ab4
Merge branch 'main' into new_rag
thinkall Feb 25, 2024
e3fbda8
Update context
thinkall Feb 26, 2024
bd9f722
Improve docstrings
thinkall Feb 26, 2024
ad774f8
Merge remote-tracking branch 'origin/main' into new_rag
thinkall Feb 26, 2024
00439a0
Update default llm_model
thinkall Feb 26, 2024
72a5fa4
Add notebook example
thinkall Feb 26, 2024
e26dd4c
Update source, keep original url
thinkall Feb 26, 2024
90674c2
Fix a typo
thinkall Feb 26, 2024
f4d0db4
Add RAG capability
thinkall Feb 26, 2024
a66613f
Update notebook
thinkall Feb 26, 2024
19340ee
Update Readme
thinkall Feb 26, 2024
3b94241
Update readme
thinkall Feb 27, 2024
b354e0d
Merge branch 'main' into new_rag
thinkall Feb 27, 2024
a4517bc
Fix tests
thinkall Feb 27, 2024
829e47e
Fix tests
thinkall Feb 27, 2024
63b43de
Fix test_reranker
thinkall Feb 27, 2024
8915243
Merge branch 'main' into new_rag
thinkall Feb 28, 2024
efd441d
Improve inner loop, fix some bugs
thinkall Feb 28, 2024
a192c98
Remove run code
thinkall Feb 28, 2024
e5fd3bb
Merge branch 'main' into new_rag
thinkall Feb 28, 2024
98f5ab4
Merge branch 'main' into new_rag
thinkall Feb 29, 2024
8e1db8d
Merge branch 'main' into new_rag
thinkall Feb 29, 2024
b197421
Add installation
thinkall Feb 28, 2024
aac8cc9
Update prompts and promptgenerator
thinkall Feb 29, 2024
a0a0db4
Update prompts and LLM system message
thinkall Feb 29, 2024
b48d255
Improve performance of multi-hop question
thinkall Feb 29, 2024
d14cadf
Add example of multi-round conversation
thinkall Feb 29, 2024
8a107cd
Update TERMINATE_TRIGGER_WORDS
thinkall Feb 29, 2024
0762afe
Update readme
thinkall Feb 29, 2024
7704c75
Add sequence uml
thinkall Feb 29, 2024
1fdb16a
Update sequence uml
thinkall Feb 29, 2024
686f140
Update readme
thinkall Feb 29, 2024
14b1896
Merge branch 'main' into new_rag
thinkall Mar 1, 2024
7cf36ab
Merge branch 'main' into new_rag
thinkall Mar 1, 2024
d21aaf4
Merge branch 'main' into new_rag
thinkall Mar 1, 2024
63acc2f
Merge branch 'main' into new_rag
thinkall Mar 12, 2024
e95684d
Update readme
thinkall Mar 1, 2024
8d97414
Add temp doc for use cases
thinkall Mar 12, 2024
6fa479d
Merge remote-tracking branch 'origin/main' into new_rag
thinkall Mar 14, 2024
66de003
Add sequence uml for retriever
thinkall Mar 14, 2024
4 changes: 2 additions & 2 deletions .github/workflows/contrib-openai.yml
@@ -45,15 +45,15 @@ jobs:
run: |
pip install docker
pip install qdrant_client[fastembed]
-pip install -e .[retrievechat]
+pip install -e .[retrievechat,rag]
- name: Coverage
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_OPENAI_API_KEY }}
AZURE_OPENAI_API_BASE: ${{ secrets.AZURE_OPENAI_API_BASE }}
OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }}
run: |
-coverage run -a -m pytest test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py
+coverage run -a -m pytest test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py test/agentchat/contrib/rag
coverage xml
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
6 changes: 3 additions & 3 deletions .github/workflows/contrib-tests.yml
@@ -48,7 +48,7 @@ jobs:
pip install unstructured[all-docs]
- name: Install packages and dependencies for RetrieveChat
run: |
-pip install -e .[retrievechat]
+pip install -e .[retrievechat,rag]
- name: Set AUTOGEN_USE_DOCKER based on OS
shell: bash
run: |
@@ -57,11 +57,11 @@
fi
- name: Test RetrieveChat
run: |
-pytest test/test_retrieve_utils.py test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py --skip-openai
+pytest test/test_retrieve_utils.py test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py test/agentchat/contrib/rag --skip-openai
- name: Coverage
run: |
pip install coverage>=5.3
-coverage run -a -m pytest test/test_retrieve_utils.py test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py --skip-openai
+coverage run -a -m pytest test/test_retrieve_utils.py test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py test/agentchat/contrib/rag --skip-openai
coverage xml
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
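To run the new suite locally, the equivalent of the updated CI step is `pip install -e .[retrievechat,rag]` followed by `pytest test/agentchat/contrib/rag --skip-openai` (both commands are taken directly from the workflow changes above).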
3 changes: 3 additions & 0 deletions .gitignore
@@ -178,3 +178,6 @@ test/agentchat/test_agent_scripts/*

# test cache
.cache_test

# RAG DB folders
.db/
32 changes: 32 additions & 0 deletions autogen/agentchat/contrib/rag/__init__.py
@@ -0,0 +1,32 @@
from .datamodel import Chunk, Document, QueryResults, Query
from .encoder import Encoder, EmbeddingFunction, EmbeddingFunctionFactory
from .promptgenerator import PromptGenerator
from .reranker import Reranker, RerankerFactory
from .retriever import Retriever, RetrieverFactory
from .splitter import Splitter, SplitterFactory, TextLineSplitter
from .vectordb import VectorDB, VectorDBFactory
from .utils import timer, logger
from .rag_agent import RagAgent

__all__ = [
"Chunk",
"Document",
"Encoder",
"EmbeddingFunction",
"EmbeddingFunctionFactory",
"PromptGenerator",
"Reranker",
"RerankerFactory",
"Retriever",
"RetrieverFactory",
"Splitter",
"SplitterFactory",
"TextLineSplitter",
"VectorDB",
"VectorDBFactory",
"timer",
"logger",
"RagAgent",
"QueryResults",
"Query",
]
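Because `__init__.py` re-exports the package's public surface, downstream code can import these names from the package root rather than from individual modules. A minimal import sketch (names only, taken from `__all__` above; it assumes the `rag` extra from the workflow changes is installed):

from autogen.agentchat.contrib.rag import (
    Document,
    Query,
    RagAgent,
    RetrieverFactory,
    SplitterFactory,
)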
248 changes: 248 additions & 0 deletions autogen/agentchat/contrib/rag/chromadb.py
@@ -0,0 +1,248 @@
from typing import List, Any, Callable
from .datamodel import Document, Query, QueryResults, GetResults
from .vectordb import VectorDB
from .utils import logger, timer
from .constants import CHROMADB_MAX_BATCH_SIZE

try:
import chromadb

# Compare numeric version components; a plain string comparison would treat "0.4.9" as newer than "0.4.15".
if tuple(int(v) for v in chromadb.__version__.split(".")[:3]) < (0, 4, 15):
raise ImportError("Please upgrade chromadb to version 0.4.15 or later.")
from chromadb.api.models.Collection import Collection
except ImportError:
raise ImportError("Please install chromadb: `pip install chromadb`")


class ChromaVectorDB(VectorDB):
"""
A vector database that uses ChromaDB as the backend.
"""

def __init__(self, path: str = None, embedding_function: Callable = None, metadata: dict = None, **kwargs):
"""
Initialize the vector database.

Args:
path: str | The path to the vector database. Default is None.
embedding_function: Callable | The embedding function used to generate the vector representation
of the documents. Default is None.
metadata: dict | The metadata of the vector database. Default is None. If None, it will use this
setting: {"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 32}. For more details of
the metadata, please refer to [distances](https://github.com/nmslib/hnswlib#supported-distances),
[hnsw](https://github.com/chroma-core/chroma/blob/566bc80f6c8ee29f7d99b6322654f32183c368c4/chromadb/segment/impl/vector/local_hnsw.py#L184),
and [ALGO_PARAMS](https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md).
kwargs: dict | Additional keyword arguments.

Returns:
None
"""
self.path = path
self.embedding_function = embedding_function
self.metadata = metadata if metadata else {"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 32}
if self.path is not None:
self.client = chromadb.PersistentClient(path=self.path, **kwargs)
else:
self.client = chromadb.Client(**kwargs)
self.active_collection = None

def create_collection(
self, collection_name: str, overwrite: bool = False, get_or_create: bool = True
) -> Collection:
"""
Create a collection in the vector database.
Case 1. if the collection does not exist, create the collection.
Case 2. the collection exists, if overwrite is True, it will overwrite the collection.
Case 3. the collection exists and overwrite is False, if get_or_create is True, it will get the collection,
otherwise it raise a ValueError.

Args:
collection_name: str | The name of the collection.
overwrite: bool | Whether to overwrite the collection if it exists. Default is False.
get_or_create: bool | Whether to get the collection if it exists. Default is True.

Returns:
Collection | The collection object.
"""
try:
collection = self.client.get_collection(collection_name)
except ValueError:
collection = None
if collection is None:
return self.client.create_collection(
collection_name,
embedding_function=self.embedding_function,
get_or_create=get_or_create,
metadata=self.metadata,
)
elif overwrite:
self.client.delete_collection(collection_name)
return self.client.create_collection(
collection_name,
embedding_function=self.embedding_function,
get_or_create=get_or_create,
metadata=self.metadata,
)
elif get_or_create:
return collection
else:
raise ValueError(f"Collection {collection_name} already exists.")

def get_collection(self, collection_name: str = None) -> Collection:
"""
Get the collection from the vector database.

Args:
collection_name: str | The name of the collection. Default is None. If None, return the
current active collection.

Returns:
Collection | The collection object.
"""
if collection_name is None:
if self.active_collection is None:
raise ValueError("No collection is specified.")
else:
logger.info(
f"No collection is specified. Using current active collection {self.active_collection.name}."
)
else:
self.active_collection = self.client.get_collection(collection_name)
return self.active_collection

def delete_collection(self, collection_name: str) -> None:
"""
Delete the collection from the vector database.

Args:
collection_name: str | The name of the collection.

Returns:
None
"""
self.client.delete_collection(collection_name)
if self.active_collection:
if self.active_collection.name == collection_name:
self.active_collection = None

def _batch_insert(self, collection, embeddings=None, ids=None, metadata=None, documents=None, upsert=False):
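# Split the insert into batches of at most CHROMADB_MAX_BATCH_SIZE documents to stay under ChromaDB's per-call limit.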
batch_size = int(CHROMADB_MAX_BATCH_SIZE)
for i in range(0, len(documents), batch_size):
end_idx = min(i + batch_size, len(documents))
collection_kwargs = {
"documents": documents[i:end_idx],
"ids": ids[i:end_idx],
"metadatas": metadata[i:end_idx] if metadata else None,
"embeddings": embeddings[i:end_idx] if embeddings else None,
}
if upsert:
collection.upsert(**collection_kwargs)
else:
collection.add(**collection_kwargs)

@timer
def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False) -> None:
"""
Insert documents into the collection of the vector database.

Args:
docs: List[Document] | A list of documents.
collection_name: str | The name of the collection. Default is None.
upsert: bool | Whether to update the document if it exists. Default is False.

Returns:
None
"""
if not docs:
return
collection = self.get_collection(collection_name)
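# Only the first document is inspected: a single call is expected to contain either embeddings for all documents or for none.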
if docs[0].content_embedding is None:
logger.info(
"No content embedding is provided. Will use the VectorDB's embedding function to generate the content embedding."
)
embeddings = None
else:
embeddings = [doc.content_embedding for doc in docs]
documents = [doc.content for doc in docs]
ids = [doc.id for doc in docs]
metadata = [doc.metadata for doc in docs]
self._batch_insert(collection, embeddings, ids, metadata, documents, upsert)

def update_docs(self, docs: List[Document], collection_name: str = None) -> None:
"""
Update documents in the collection of the vector database.

Args:
docs: List[Document] | A list of documents.
collection_name: str | The name of the collection. Default is None.

Returns:
None
"""
self.insert_docs(docs, collection_name, upsert=True)

def delete_docs(self, ids: List[Any], collection_name: str = None, **kwargs) -> None:
"""
Delete documents from the collection of the vector database.

Args:
ids: List[Any] | A list of document ids.
collection_name: str | The name of the collection. Default is None.
kwargs: dict | Additional keyword arguments.

Returns:
None
"""
collection = self.get_collection(collection_name)
collection.delete(ids, **kwargs)

def retrieve_docs(self, queries: List[Query], collection_name: str = None) -> QueryResults:
"""
Retrieve documents from the collection of the vector database based on the queries.

Args:
queries: List[Query] | A list of queries.
collection_name: str | The name of the collection. Default is None.

Returns:
QueryResults | The query results.
"""
collection = self.get_collection(collection_name)
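# All query texts are sent in a single call, so n_results, filters, and include are taken from the first query only.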
results = collection.query(
query_texts=[q.text for q in queries],
n_results=queries[0].k,
where=queries[0].filter_metadata,
where_document=queries[0].filter_document,
include=queries[0].include if queries[0].include else ["distances", "documents", "metadatas"],
)
return QueryResults(
ids=results.get("ids"),
texts=results.get("documents"),
embeddings=results.get("embeddings"),
metadatas=results.get("metadatas"),
distances=results.get("distances"),
)

def get_docs_by_ids(self, ids: List[Any], collection_name: str = None, include=None, **kwargs) -> GetResults:
"""
Retrieve documents from the collection of the vector database based on the ids.

Args:
ids: List[Any] | A list of document ids.
collection_name: str | The name of the collection. Default is None.
include: List[str] | The fields to include. Default is None.
If None, defaults to ["metadatas", "documents"].
kwargs: dict | Additional keyword arguments.

Returns:
GetResults | The query results.
"""
collection = self.get_collection(collection_name)
include = include if include else ["metadatas", "documents"]
results = collection.get(ids, include=include, **kwargs)
return GetResults(
ids=results.get("ids"),
texts=results.get("documents"),
embeddings=results.get("embeddings"),
metadatas=results.get("metadatas"),
)
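To make the API above concrete, here is a minimal usage sketch based only on the methods shown in this file. The keyword arguments to `Document` and `Query` are assumptions inferred from the attribute access above (`id`, `content`, `metadata`; `text`, `k`); their actual definitions live in `datamodel.py`, which is not part of this hunk.

# Illustrative sketch only; Document/Query constructors are assumed from this file's attribute access.
from chromadb.utils import embedding_functions

from autogen.agentchat.contrib.rag.chromadb import ChromaVectorDB
from autogen.agentchat.contrib.rag.datamodel import Document, Query

# Pass an embedding function explicitly so documents without precomputed embeddings can be embedded.
db = ChromaVectorDB(path=".db", embedding_function=embedding_functions.DefaultEmbeddingFunction())
db.create_collection("autogen-docs")

docs = [
    Document(id="doc-1", content="AutoGen is a framework for building multi-agent applications.", metadata={"source": "readme"}),
    Document(id="doc-2", content="The RAG agent retrieves relevant chunks before answering.", metadata={"source": "readme"}),
]
db.insert_docs(docs, collection_name="autogen-docs")

# Retrieve the two closest chunks for a question, then fetch one document back by id.
results = db.retrieve_docs([Query(text="What does the RAG agent do?", k=2)], collection_name="autogen-docs")
print(results.texts)
print(db.get_docs_by_ids(["doc-1"], collection_name="autogen-docs").texts)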
4 changes: 4 additions & 0 deletions autogen/agentchat/contrib/rag/constants.py
@@ -0,0 +1,4 @@
import os

RAG_MINIMUM_MESSAGE_LENGTH = int(os.environ.get("RAG_MINIMUM_MESSAGE_LENGTH", 5))
CHROMADB_MAX_BATCH_SIZE = int(os.environ.get("CHROMADB_MAX_BATCH_SIZE", 40000))
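Both constants are read from the environment at import time, so they can be tuned without code changes. A minimal sketch of overriding them before the package is imported (the values are arbitrary examples, and the `rag` extra is assumed to be installed):

import os

# Overrides must be set before the rag package is imported; the values below are examples only.
os.environ["RAG_MINIMUM_MESSAGE_LENGTH"] = "10"
os.environ["CHROMADB_MAX_BATCH_SIZE"] = "10000"

from autogen.agentchat.contrib.rag import constants  # noqa: E402

print(constants.RAG_MINIMUM_MESSAGE_LENGTH, constants.CHROMADB_MAX_BATCH_SIZE)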