
Update no_update_context, fix upsert docs (#52)

* Update no_update_context, fix upsert docs

* Recreate only once

* Add comments to get_or_create

---------

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
thinkall and sonichi committed Oct 1, 2023
1 parent 904b293 commit b06919b
Showing 2 changed files with 15 additions and 15 deletions.
15 changes: 10 additions & 5 deletions autogen/agentchat/contrib/retrieve_user_proxy_agent.py
@@ -125,7 +125,9 @@ def __init__(
                - customized_prompt (Optional, str): the customized prompt for the retrieve chat. Default is None.
                - customized_answer_prefix (Optional, str): the customized answer prefix for the retrieve chat. Default is "".
                    If not "" and the customized_answer_prefix is not in the answer, `Update Context` will be triggered.
-                - no_update_context (Optional, bool): if True, will not apply `Update Context` for interactive retrieval. Default is False.
+                - update_context (Optional, bool): if False, will not apply `Update Context` for interactive retrieval. Default is True.
+                - get_or_create (Optional, bool): if True, will create/recreate a collection for the retrieve chat.
+                    This is the same as that used in chromadb. Default is False.
            **kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).
            """
        super().__init__(
@@ -148,7 +150,8 @@ def __init__(
        self._embedding_model = self._retrieve_config.get("embedding_model", "all-MiniLM-L6-v2")
        self.customized_prompt = self._retrieve_config.get("customized_prompt", None)
        self.customized_answer_prefix = self._retrieve_config.get("customized_answer_prefix", "").upper()
-        self.no_update_context = self._retrieve_config.get("no_update_context", False)
+        self.update_context = self._retrieve_config.get("update_context", True)
+        self._get_or_create = self._retrieve_config.get("get_or_create", False)
        self._context_max_tokens = self._max_tokens * 0.8
        self._collection = False  # the collection is not created
        self._ipython = get_ipython()
@@ -231,7 +234,7 @@ def _generate_retrieve_user_reply(
        config: Optional[Any] = None,
    ) -> Tuple[bool, Union[str, Dict, None]]:
        """In this function, we will update the context and reset the conversation based on different conditions.
-        We'll update the context and reset the conversation if no_update_context is False and either of the following:
+        We'll update the context and reset the conversation if update_context is True and either of the following:
        (1) the last message contains "UPDATE CONTEXT",
        (2) the last message doesn't contain "UPDATE CONTEXT" and the customized_answer_prefix is not in the message.
        """
@@ -247,7 +250,7 @@ def _generate_retrieve_user_reply(
        update_context_case2 = (
            self.customized_answer_prefix and self.customized_answer_prefix not in message.get("content", "").upper()
        )
-        if (update_context_case1 or update_context_case2) and not self.no_update_context:
+        if (update_context_case1 or update_context_case2) and self.update_context:
            print(colored("Updating context and resetting conversation.", "green"), flush=True)
            # extract the first sentence in the response as the intermediate answer
            _message = message.get("content", "").split("\n")[0].strip()
@@ -286,7 +289,7 @@ def _generate_retrieve_user_reply(
        return False, None

    def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""):
-        if not self._collection:
+        if not self._collection or self._get_or_create:
            print("Trying to create collection.")
            create_vector_db_from_dir(
                dir_path=self._docs_path,
@@ -296,8 +299,10 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
                chunk_mode=self._chunk_mode,
                must_break_at_empty_line=self._must_break_at_empty_line,
                embedding_model=self._embedding_model,
+                get_or_create=self._get_or_create,
            )
            self._collection = True
+            self._get_or_create = False

        results = query_vector_db(
            query_texts=[problem],
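Taken together, the hunks above rename the `no_update_context` flag to `update_context` (with the default flipped accordingly) and add a `get_or_create` flag that makes `retrieve_docs` rebuild the chromadb collection exactly once. A minimal usage sketch, not part of the commit — the agent name, `docs_path`, and the `task` value are illustrative placeholders:

```python
# Hypothetical configuration sketch; update_context, get_or_create, docs_path and
# customized_answer_prefix are the keys documented in this commit's diff.
from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent

ragproxyagent = RetrieveUserProxyAgent(
    name="ragproxyagent",          # illustrative name
    human_input_mode="NEVER",
    retrieve_config={
        "task": "qa",              # assumed task type, not shown in this diff
        "docs_path": "./docs",     # placeholder path
        "update_context": True,    # False disables the `Update Context` loop
        "get_or_create": True,     # (re)create the chromadb collection on first retrieval
        "customized_answer_prefix": "",
    },
)
```

With `get_or_create=True`, the first call to `retrieve_docs` passes the flag through to `create_vector_db_from_dir` and then resets it to False, so the collection is recreated only once per agent instance.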
15 changes: 5 additions & 10 deletions autogen/retrieve_utils.py
@@ -208,18 +208,13 @@ def create_vector_db_from_dir(

        chunks = split_files_to_chunks(get_files_from_dir(dir_path), max_tokens, chunk_mode, must_break_at_empty_line)
        print(f"Found {len(chunks)} chunks.")
-        # upsert in batch of 40000
-        for i in range(0, len(chunks), 40000):
+        # Upsert in batch of 40000 or less if the total number of chunks is less than 40000
+        for i in range(0, len(chunks), min(40000, len(chunks))):
+            end_idx = i + min(40000, len(chunks) - i)
            collection.upsert(
-                documents=chunks[
-                    i : i + 40000
-                ],  # we handle tokenization, embedding, and indexing automatically. You can skip that and add your own embeddings as well
-                ids=[f"doc_{i}" for i in range(i, i + 40000)],  # unique for each doc
+                documents=chunks[i:end_idx],
+                ids=[f"doc_{j}" for j in range(i, end_idx)],  # unique for each doc
            )
-        collection.upsert(
-            documents=chunks[i : len(chunks)],
-            ids=[f"doc_{i}" for i in range(i, len(chunks))],  # unique for each doc
-        )
    except ValueError as e:
        logger.warning(f"{e}")

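The second file drops the stray upsert that ran after the loop and computes an explicit `end_idx` for each batch instead. A standalone sketch of the same batching pattern, assuming only a chromadb-style collection exposing `upsert(documents=..., ids=...)`; the helper name and `BATCH` constant are mine, not the library's:

```python
BATCH = 40000  # mirrors the 40000-document batches used in the commit


def upsert_in_batches(collection, chunks):
    """Upsert `chunks` into `collection` in windows of at most BATCH documents."""
    for start in range(0, len(chunks), BATCH):
        end = min(start + BATCH, len(chunks))
        collection.upsert(
            documents=chunks[start:end],
            ids=[f"doc_{j}" for j in range(start, end)],  # unique id per document
        )
```

Computing the window end as `min(start + BATCH, len(chunks))` is equivalent to the committed `end_idx` arithmetic; using the constant `BATCH` as the step also keeps `range()` well-defined when `chunks` happens to be empty.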
