Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RetrieveChat #1158

Merged
merged 30 commits into from
Aug 13, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
9aa34fc
Add RetrieveChat notebook, RetrieveAssistantAgent and RetrieveUserProxyAgent
thinkall Jul 31, 2023
b751a04
merge main
thinkall Aug 1, 2023
5a0a117
Update according to comments
thinkall Aug 1, 2023
147b606
Add output
thinkall Aug 1, 2023
11e1c1f
Merge branch 'main' into retrieve_agent
thinkall Aug 3, 2023
22c2303
Add tests, merge main, address comments
thinkall Aug 3, 2023
ca19515
Fix tests
thinkall Aug 4, 2023
0652cdb
Merge branch 'main' into retrieve_agent
thinkall Aug 6, 2023
8943523
Merge main
thinkall Aug 6, 2023
744131c
Remove unnecessary code
thinkall Aug 6, 2023
f3041d7
Update test
thinkall Aug 6, 2023
e2c9b3e
Update QA, merge main
thinkall Aug 8, 2023
4778663
Update notebook, some functions
thinkall Aug 9, 2023
9ad3887
Fix print issue
thinkall Aug 10, 2023
96b1e81
Update notebook
thinkall Aug 10, 2023
b1603f9
Update notebook
thinkall Aug 10, 2023
f57dfca
Update notebook
thinkall Aug 10, 2023
35996d6
Improve retrieve utils and update notebook
thinkall Aug 10, 2023
6226c1e
Update vector db creation method
thinkall Aug 10, 2023
4e7b845
Update notebook
thinkall Aug 10, 2023
5e3b16f
Update notebook
thinkall Aug 10, 2023
b36cd43
Add terminate if no more context
thinkall Aug 11, 2023
671ffd4
Update prompt and notebook, add example for update context
thinkall Aug 11, 2023
dc3dfbb
Update results
thinkall Aug 12, 2023
e388f2f
Update results
thinkall Aug 12, 2023
54d7c14
Merge branch 'main' into retrieve_agent
thinkall Aug 12, 2023
5518ac6
Update results of update context
thinkall Aug 12, 2023
74698cb
Fix typo
thinkall Aug 12, 2023
11f46d0
Add table of contents
thinkall Aug 13, 2023
e4b2e52
Update table of contents
thinkall Aug 13, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,6 @@ test/nlp/testtmpfl.py
output/
flaml/tune/spark/mylearner.py
*.pkl

# local config files
*.config.local
45 changes: 45 additions & 0 deletions flaml/autogen/agentchat/contrib/retrieve_assistant_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from flaml.autogen.agentchat.agent import Agent
from flaml.autogen.agentchat.assistant_agent import AssistantAgent
from typing import Callable, Dict, Optional, Union


class RetrieveAssistantAgent(AssistantAgent):
    """(Experimental) Assistant agent, designed to solve a task with LLM.

    AssistantAgent is a subclass of ResponsiveAgent configured with a default system message.
    The default system message is designed to solve a task with LLM,
    including suggesting python code blocks and debugging.
    `human_input_mode` is default to "NEVER"
    and `code_execution_config` is default to False.
    This agent doesn't execute code by default, and expects the user to execute the code.
    """

    def __init__(
        self,
        name: str,
        **kwargs,
    ):
        """
        Args:
            name (str): agent name.
            **kwargs (dict): other kwargs forwarded to the
                AssistantAgent constructor.
        """
        super().__init__(
            name,
            **kwargs,
        )

    def _reset(self):
        # Drop the accumulated chat history so a fresh context can be built.
        self._oai_conversations.clear()

    def receive(self, message: Union[Dict, str], sender: "Agent"):
        """Receive a message from another agent.

        If "UPDATE CONTEXT" appears in the tail of the message, reset the
        conversation and ask the sender for fresh context. If the message
        reports a successful code execution, send "TERMINATE". Otherwise
        fall back to the default AssistantAgent behavior.
        """
        message = self._message_to_dict(message)
        # `content` may be present but None (e.g. function-call style messages);
        # normalize to "" before slicing the last 20 chars for the keyword check.
        content = message.get("content") or ""
        if "UPDATE CONTEXT" in content[-20:].upper():
            self._reset()
            self.send("UPDATE CONTEXT", sender)
        elif "exitcode: 0 (execution succeeded)" in content:
            self.send("TERMINATE", sender)
        else:
            super().receive(message, sender)
216 changes: 216 additions & 0 deletions flaml/autogen/agentchat/contrib/retrieve_user_proxy_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
import chromadb
from flaml.autogen.agentchat.agent import Agent
from flaml.autogen.agentchat import UserProxyAgent
from flaml.autogen.retrieval_utils import create_vector_db_from_dir, query_vector_db, num_tokens_from_text
from flaml.autogen.code_utils import UNKNOWN, extract_code, execute_code, infer_lang

from typing import Callable, Dict, Optional, Union, List
from IPython import get_ipython


# Built-in prompt template for retrieval-augmented chat. It is filled in by
# RetrieveUserProxyAgent (generate_init_message / receive) with the user's
# question and the retrieved document context.
PROMPT = """You're a retrieve augmented chatbot. You answer user's questions based on your own knowledge and the
context provided by the user. You should follow the following steps to answer a question:
Step 1, you estimate the user's intent based on the question and context. The intent can be a code generation task or
a QA task.
Step 2, you generate code or answer the question based on the intent.
You should leverage the context provided by the user as much as possible. If you think the context is not enough, you
can reply exactly "UPDATE CONTEXT" to ask the user to provide more contexts.
For code generation, you must obey the following rules:
You MUST NOT install any packages because all the packages needed are already installed.
The code will be executed in IPython, you must follow the formats below to write your code:
```python
# your code
```

User's question is: {input_question}

Context is: {input_context}
"""


def _is_termination_msg_retrievechat(message):
    """Check if a message is a termination message.

    A message terminates the chat when it contains no python or wolfram
    code block; a reply that still carries code should keep the chat going.
    """
    if isinstance(message, dict):
        message = message.get("content")
        if message is None:
            return False
    code_blocks = extract_code(message)
    # Terminate only when none of the extracted blocks is executable code.
    return not any(block[0] in ("python", "wolfram") for block in code_blocks)


class RetrieveUserProxyAgent(UserProxyAgent):
    """(Experimental) User proxy agent for retrieval-augmented chat.

    Retrieves documents relevant to the problem from a chroma vector DB and
    inserts them as context into the prompt sent to the assistant agent.
    When the assistant replies "UPDATE CONTEXT", the next batch of retrieved
    documents is supplied instead.
    """

    def __init__(
        self,
        name="RetrieveChatAgent",  # default set to RetrieveChatAgent
        is_termination_msg: Optional[Callable[[Dict], bool]] = _is_termination_msg_retrievechat,
        human_input_mode: Optional[str] = "ALWAYS",
        retrieve_config: Optional[Dict] = None,  # config for the retrieve agent
        **kwargs,
    ):
        """
        Args:
            name (str): name of the agent.
            is_termination_msg (Callable): a function that takes a message dict and
                returns True when the chat should terminate.
            human_input_mode (str): whether to ask for human inputs every time a message is received.
                Possible values are "ALWAYS", "TERMINATE", "NEVER".
                (1) When "ALWAYS", the agent prompts for human input every time a message is received.
                    Under this mode, the conversation stops when the human input is "exit",
                    or when is_termination_msg is True and there is no human input.
                (2) When "TERMINATE", the agent only prompts for human input only when a termination message is received or
                    the number of auto reply reaches the max_consecutive_auto_reply.
                (3) When "NEVER", the agent will never prompt for human input. Under this mode, the conversation stops
                    when the number of auto reply reaches the max_consecutive_auto_reply or when is_termination_msg is True.
            retrieve_config (dict or None): config for the retrieve agent.
                To use default config, set to None. Otherwise, set to a dictionary with the following keys:
                - client (Optional, chromadb.Client): the chromadb client.
                    If key not provided, a default client `chromadb.Client()` will be used.
                - docs_path (Optional, str): the path to the docs directory.
                    If key not provided, a default path `./docs` will be used.
                - collection_name (Optional, str): the name of the collection.
                    If key not provided, a default name `flaml-docs` will be used.
                - model (Optional, str): the model to use for the retrieve chat.
                    If key not provided, a default model `gpt-3.5` will be used.
                - chunk_token_size (Optional, int): the chunk token size for the retrieve chat.
                    If key not provided, a default size `max_tokens * 0.6` will be used.
            **kwargs (dict): other kwargs in [UserProxyAgent](user_proxy_agent#__init__).
        """
        super().__init__(
            name=name,
            is_termination_msg=is_termination_msg,
            human_input_mode=human_input_mode,
            **kwargs,
        )

        self._retrieve_config = {} if retrieve_config is None else retrieve_config
        self._client = self._retrieve_config.get("client", chromadb.Client())
        self._docs_path = self._retrieve_config.get("docs_path", "./docs")
        self._collection_name = self._retrieve_config.get("collection_name", "flaml-docs")
        self._model = self._retrieve_config.get("model", "gpt-3.5")
        # NOTE(review): where to get the token-limit info may change in future (see #1160).
        self._max_tokens = self.get_max_tokens(self._model)
        self._chunk_token_size = int(self._retrieve_config.get("chunk_token_size", self._max_tokens * 0.6))
        self._collection = False  # whether the collection is created
        # NOTE(review): this limits the agent to IPython for code execution;
        # better to decouple and make it part of code_execution_config in future.
        self._ipython = get_ipython()
        self._doc_idx = -1  # the index of the current used doc
        self._results = []  # the results of the current query

    @staticmethod
    def get_max_tokens(model="gpt-3.5"):
        """Return the context window size (in tokens) assumed for `model`."""
        if "gpt-4-32k" in model:
            return 32000
        elif "gpt-4" in model:
            return 8000
        else:
            return 4000

    def _reset(self):
        # Drop the accumulated chat history so a fresh context can be built.
        self._oai_conversations.clear()

    def _get_context(self):
        """Accumulate retrieved docs after self._doc_idx into a context string.

        Stops before the accumulated text exceeds self._chunk_token_size
        tokens. Advances self._doc_idx to the last doc included. Returns ""
        when there are no (more) documents to add.
        """
        doc_contents = ""
        last_used_idx = self._doc_idx
        results = self._results
        # Guard against an empty result set so the loop variable is never
        # referenced without being bound.
        documents = results["documents"][0] if results else []
        for idx, doc in enumerate(documents):
            if idx <= last_used_idx:
                continue  # already shown to the assistant in a previous turn
            candidate = doc_contents + doc + "\n"
            if num_tokens_from_text(candidate) > self._chunk_token_size:
                break
            print(f"Adding doc_id {results['ids'][0][idx]} to context.")
            doc_contents = candidate
            self._doc_idx = idx
        return doc_contents

    def _generate_message(self, doc_contents):
        """Build the outgoing prompt from the stored problem and `doc_contents`."""
        if self.customized_prompt:
            return self.customized_prompt + "\nUser's question is: " + self.problem + "\nContext is: " + doc_contents
        return PROMPT.format(input_question=self.problem, input_context=doc_contents)

    def receive(self, message: Union[Dict, str], sender: "Agent"):
        """Receive a message from another agent.

        If "UPDATE CONTEXT" appears in the tail of the message, reset the
        conversation and re-send the prompt with the next batch of retrieved
        docs; otherwise fall back to the default UserProxyAgent behavior.
        """
        message = self._message_to_dict(message)
        # `content` may be present but None (e.g. function-call style messages);
        # normalize to "" before slicing the last 20 chars for the keyword check.
        content = message.get("content") or ""
        if "UPDATE CONTEXT" in content[-20:].upper():
            print("Updating context and resetting conversation.")
            self._reset()
            self.send(self._generate_message(self._get_context()), sender)
        else:
            super().receive(message, sender)

    def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""):
        """Query the vector DB for docs relevant to `problem`.

        Creates the collection from self._docs_path on first use. The raw
        query results are cached in self._results for later context building.
        """
        if not self._collection:
            create_vector_db_from_dir(
                dir_path=self._docs_path,
                max_tokens=self._chunk_token_size,
                client=self._client,
                collection_name=self._collection_name,
            )
            self._collection = True

        results = query_vector_db(
            query_texts=[problem],
            n_results=n_results,
            search_string=search_string,
            client=self._client,
            collection_name=self._collection_name,
        )
        self._results = results
        print("doc_ids: ", results["ids"])

    def generate_init_message(
        self, problem: str, customized_prompt: str = "", n_results: int = 20, search_string: str = ""
    ):
        """Generate a prompt for the assistant agent with the given problem and prompt.

        Args:
            problem (str): the problem to be solved.
            customized_prompt (str): a customized prompt to be used. If it is not "", the built-in prompt will be
                ignored.
            n_results (int): number of documents to retrieve from the vector DB.
            search_string (str): optional filter string for the vector DB query.

        Returns:
            str: the generated prompt ready to be sent to the assistant agent.
        """
        self.reset()
        self.retrieve_docs(problem, n_results, search_string)
        self.problem = problem
        self.customized_prompt = customized_prompt
        self._doc_idx = -1  # start from the first retrieved doc
        return self._generate_message(self._get_context())

    def run_code(self, code, **kwargs):
        """Execute `code` in the current IPython kernel instead of a subprocess.

        Package-installation commands are refused with a fixed message.
        Returns an (exitcode, logs_bytes, image) triple matching the
        UserProxyAgent.run_code contract (image is always None here).
        """
        lang = kwargs.get("lang", None)
        if code.startswith("!") or code.startswith("pip") or lang in ["bash", "shell", "sh"]:
            return (
                0,
                bytes(
                    "You MUST NOT install any packages because all the packages needed are already installed.", "utf-8"
                ),
                None,
            )
        result = self._ipython.run_cell(code)
        log = str(result.result)
        exitcode = 0 if result.success else 1
        if result.error_before_exec is not None:
            log += f"\n{result.error_before_exec}"
            exitcode = 1
        if result.error_in_exec is not None:
            log += f"\n{result.error_in_exec}"
            exitcode = 1
        return exitcode, bytes(log, "utf-8"), None
Loading
Loading