Skip to content

Commit

Permalink
Add source to the answer for default prompt (#2289)
Browse files Browse the repository at this point in the history
* Add source to the answer for default prompt

* Fix qdrant

* Fix tests

* Update docstring

* Fix check files

* Fix qdrant test error
  • Loading branch information
thinkall committed Apr 10, 2024
1 parent 5292024 commit 5a96dc2
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 17 deletions.
5 changes: 3 additions & 2 deletions autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,12 +190,12 @@ def create_qdrant_from_dir(
client.set_model(embedding_model)

if custom_text_split_function is not None:
chunks = split_files_to_chunks(
chunks, sources = split_files_to_chunks(
get_files_from_dir(dir_path, custom_text_types, recursive),
custom_text_split_function=custom_text_split_function,
)
else:
chunks = split_files_to_chunks(
chunks, sources = split_files_to_chunks(
get_files_from_dir(dir_path, custom_text_types, recursive), max_tokens, chunk_mode, must_break_at_empty_line
)
logger.info(f"Found {len(chunks)} chunks.")
Expand Down Expand Up @@ -298,5 +298,6 @@ class QueryResponse(BaseModel, extra="forbid"): # type: ignore
data = {
"ids": [[result.id for result in sublist] for sublist in results],
"documents": [[result.document for result in sublist] for sublist in results],
"metadatas": [[result.metadata for result in sublist] for sublist in results],
}
return data
16 changes: 14 additions & 2 deletions autogen/agentchat/contrib/retrieve_user_proxy_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@
User's question is: {input_question}
Context is: {input_context}
The source of the context is: {input_sources}
If you can answer the question, in the end of your answer, add the source of the context in the format of `Sources: source1, source2, ...`.
"""

PROMPT_CODE = """You're a retrieve augmented coding assistant. You answer user's questions based on your own knowledge and the
Expand Down Expand Up @@ -101,7 +105,8 @@ def __init__(
following keys:
- `task` (Optional, str) - the task of the retrieve chat. Possible values are
"code", "qa" and "default". System prompt will be different for different tasks.
The default value is `default`, which supports both code and qa.
The default value is `default`, which supports both code and qa, and provides
source information at the end of the response.
- `client` (Optional, chromadb.Client) - the chromadb client. If key not provided, a
default client `chromadb.Client()` will be used. If you want to use other
vector db, extend this class and override the `retrieve_docs` function.
Expand Down Expand Up @@ -243,6 +248,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
self._intermediate_answers = set() # the intermediate answers
self._doc_contents = [] # the contents of the current used doc
self._doc_ids = [] # the ids of the current used doc
self._current_docs_in_context = [] # the ids of the current context sources
self._search_string = "" # the search string used in the current query
# update the termination message function
self._is_termination_msg = (
Expand Down Expand Up @@ -290,6 +296,7 @@ def _reset(self, intermediate=False):

def _get_context(self, results: Dict[str, Union[List[str], List[List[str]]]]):
doc_contents = ""
self._current_docs_in_context = []
current_tokens = 0
_doc_idx = self._doc_idx
_tmp_retrieve_count = 0
Expand All @@ -310,6 +317,9 @@ def _get_context(self, results: Dict[str, Union[List[str], List[List[str]]]]):
print(colored(func_print, "green"), flush=True)
current_tokens += _doc_tokens
doc_contents += doc + "\n"
_metadatas = results.get("metadatas")
if isinstance(_metadatas, list) and isinstance(_metadatas[0][idx], dict):
self._current_docs_in_context.append(results["metadatas"][0][idx].get("source", ""))
self._doc_idx = idx
self._doc_ids.append(results["ids"][0][idx])
self._doc_contents.append(doc)
Expand All @@ -329,7 +339,9 @@ def _generate_message(self, doc_contents, task="default"):
elif task.upper() == "QA":
message = PROMPT_QA.format(input_question=self.problem, input_context=doc_contents)
elif task.upper() == "DEFAULT":
message = PROMPT_DEFAULT.format(input_question=self.problem, input_context=doc_contents)
message = PROMPT_DEFAULT.format(
input_question=self.problem, input_context=doc_contents, input_sources=self._current_docs_in_context
)
else:
raise NotImplementedError(f"task {task} is not implemented.")
return message
Expand Down
25 changes: 17 additions & 8 deletions autogen/retrieve_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import glob
import os
import re
from typing import Callable, List, Union
from typing import Callable, List, Tuple, Union
from urllib.parse import urlparse

import chromadb
Expand Down Expand Up @@ -160,8 +160,14 @@ def split_files_to_chunks(
"""Split a list of files into chunks of max_tokens."""

chunks = []
sources = []

for file in files:
if isinstance(file, tuple):
url = file[1]
file = file[0]
else:
url = None
_, file_extension = os.path.splitext(file)
file_extension = file_extension.lower()

Expand All @@ -179,11 +185,13 @@ def split_files_to_chunks(
continue # Skip to the next file if no text is available

if custom_text_split_function is not None:
chunks += custom_text_split_function(text)
tmp_chunks = custom_text_split_function(text)
else:
chunks += split_text_to_chunks(text, max_tokens, chunk_mode, must_break_at_empty_line)
tmp_chunks = split_text_to_chunks(text, max_tokens, chunk_mode, must_break_at_empty_line)
chunks += tmp_chunks
sources += [{"source": url if url else file}] * len(tmp_chunks)

return chunks
return chunks, sources


def get_files_from_dir(dir_path: Union[str, List[str]], types: list = TEXT_FORMATS, recursive: bool = True):
Expand Down Expand Up @@ -267,7 +275,7 @@ def parse_html_to_markdown(html: str, url: str = None) -> str:
return webpage_text


def get_file_from_url(url: str, save_path: str = None):
def get_file_from_url(url: str, save_path: str = None) -> Tuple[str, str]:
"""Download a file from a URL."""
if save_path is None:
save_path = "tmp/chromadb"
Expand Down Expand Up @@ -303,7 +311,7 @@ def get_file_from_url(url: str, save_path: str = None):
with open(save_path, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
return save_path
return save_path, url


def is_url(string: str):
Expand Down Expand Up @@ -383,12 +391,12 @@ def create_vector_db_from_dir(
length = len(collection.get()["ids"])

if custom_text_split_function is not None:
chunks = split_files_to_chunks(
chunks, sources = split_files_to_chunks(
get_files_from_dir(dir_path, custom_text_types, recursive),
custom_text_split_function=custom_text_split_function,
)
else:
chunks = split_files_to_chunks(
chunks, sources = split_files_to_chunks(
get_files_from_dir(dir_path, custom_text_types, recursive),
max_tokens,
chunk_mode,
Expand All @@ -401,6 +409,7 @@ def create_vector_db_from_dir(
collection.upsert(
documents=chunks[i:end_idx],
ids=[f"doc_{j+length}" for j in range(i, end_idx)], # unique for each doc
metadatas=sources[i:end_idx],
)
except ValueError as e:
logger.warning(f"{e}")
Expand Down
10 changes: 5 additions & 5 deletions test/test_retrieve_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def test_extract_text_from_pdf(self):
def test_split_files_to_chunks(self):
pdf_file_path = os.path.join(test_dir, "example.pdf")
txt_file_path = os.path.join(test_dir, "example.txt")
chunks = split_files_to_chunks([pdf_file_path, txt_file_path])
chunks, _ = split_files_to_chunks([pdf_file_path, txt_file_path])
assert all(
isinstance(chunk, str) and "AutoGen is an advanced tool designed to assist developers" in chunk.strip()
for chunk in chunks
Expand All @@ -81,7 +81,7 @@ def test_get_files_from_dir(self):
pdf_file_path = os.path.join(test_dir, "example.pdf")
txt_file_path = os.path.join(test_dir, "example.txt")
files = get_files_from_dir([pdf_file_path, txt_file_path])
assert all(os.path.isfile(file) for file in files)
assert all(os.path.isfile(file) if isinstance(file, str) else os.path.isfile(file[0]) for file in files)
files = get_files_from_dir(
[
pdf_file_path,
Expand All @@ -91,7 +91,7 @@ def test_get_files_from_dir(self):
],
recursive=True,
)
assert all(os.path.isfile(file) for file in files)
assert all(os.path.isfile(file) if isinstance(file, str) else os.path.isfile(file[0]) for file in files)
files = get_files_from_dir(
[
pdf_file_path,
Expand All @@ -102,7 +102,7 @@ def test_get_files_from_dir(self):
recursive=True,
types=["pdf", "txt"],
)
assert all(os.path.isfile(file) for file in files)
assert all(os.path.isfile(file) if isinstance(file, str) else os.path.isfile(file[0]) for file in files)
assert len(files) == 3

def test_is_url(self):
Expand Down Expand Up @@ -243,7 +243,7 @@ def test_unstructured(self):
pdf_file_path = os.path.join(test_dir, "example.pdf")
txt_file_path = os.path.join(test_dir, "example.txt")
word_file_path = os.path.join(test_dir, "example.docx")
chunks = split_files_to_chunks([pdf_file_path, txt_file_path, word_file_path])
chunks, _ = split_files_to_chunks([pdf_file_path, txt_file_path, word_file_path])
assert all(
isinstance(chunk, str) and "AutoGen is an advanced tool designed to assist developers" in chunk.strip()
for chunk in chunks
Expand Down

0 comments on commit 5a96dc2

Please sign in to comment.