Skip to content

Commit

Permalink
Merge pull request #62 from cubenlp/rex
Browse files Browse the repository at this point in the history
use data2chat in async process
  • Loading branch information
RexWzh committed Nov 2, 2023
2 parents c18699c + e3e716e commit c4ece72
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 8 deletions.
2 changes: 1 addition & 1 deletion chattool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

__author__ = """Rex Wang"""
__email__ = '1073853456@qq.com'
__version__ = '2.3.5'
__version__ = '2.4.0'

import os, sys, requests
from .chattool import Chat, Resp
Expand Down
21 changes: 15 additions & 6 deletions chattool/asynctool.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,13 @@ def async_chat_completion( msgs:Union[List[List[Dict]], str]
, chat_url:Union[str, None]=None
, max_requests:int=1
, ncoroutines:int=1
, nproc:int=1
, timeout:int=0
, timeinterval:int=0
, clearfile:bool=False
, notrun:bool=False
, msg2log:Union[Callable, None]=None
, data2chat:Union[Callable, None]=None
, max_tokens:Union[Callable, int, None]=None
, **options
):
Expand All @@ -148,23 +150,30 @@ def async_chat_completion( msgs:Union[List[List[Dict]], str]
model (str, optional): model to use. Defaults to 'gpt-3.5-turbo'.
api_key (Union[str, None], optional): API key. Defaults to None.
max_requests (int, optional): maximum number of requests to make. Defaults to 1.
ncoroutines (int, optional): number of coroutines. Defaults to 5.
ncoroutines (int, optional): (Deprecated)number of coroutines. Defaults to 1.
nproc (int, optional): number of coroutines. Defaults to 1.
timeout (int, optional): timeout for the API call. Defaults to 0(no timeout).
timeinterval (int, optional): time interval between two API calls. Defaults to 0.
clearfile (bool, optional): whether to clear the checkpoint file. Defaults to False.
notrun (bool, optional): whether to run the async process. It should be True
when use in Jupyter Notebook. Defaults to False.
msg2log (Union[Callable, None], optional): function to convert message to chat log.
msg2log (Union[Callable, None], optional): (Deprecated)function to convert message
to chat log. Defaults to None.
data2chat (Union[Callable, None], optional): function to convert data to Chat object.
Defaults to None.
max_tokens (Union[Callable, int, None], optional): function to calculate the maximum
number of tokens for the API call. Defaults to None.
Returns:
List[Dict]: list of responses
"""
# read chatlogs | use method from the Chat object
if msg2log is None:
msg2log = lambda msg: Chat(msg).chat_log
# read chatlogs. By default, use method from the Chat object
if data2chat is None:
msg2log = lambda data: Chat(data).chat_log
elif msg2log is not None: # deprecated warning
warnings.warn("msg2log is deprecated, use data2chat instead!")
# use nproc instead of ncoroutines
nproc = max(nproc, ncoroutines)
chatlogs = [msg2log(log) for log in msgs]
if clearfile and os.path.exists(chkpoint):
os.remove(chkpoint)
Expand All @@ -184,7 +193,7 @@ def async_chat_completion( msgs:Union[List[List[Dict]], str]
"api_key": api_key,
"chat_url": chat_url,
"max_requests": max_requests,
"ncoroutines": ncoroutines,
"ncoroutines": nproc,
"timeout": timeout,
"timeinterval": timeinterval,
"model": model,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
with open('README.md') as readme_file:
readme = readme_file.read()

VERSION = '2.3.5'
VERSION = '2.4.0'

requirements = [
'Click>=7.0', 'requests>=2.20', "responses>=0.23",
Expand Down
9 changes: 9 additions & 0 deletions tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ def test_async_process():
assert all(resp)
print(f"Time elapsed: {time.time() - t:.2f}s")

# broken test
def test_failed_async():
api_key = chattool.api_key
chattool.api_key = "sk-invalid"
chkpoint = testpath + "test_async_fail.jsonl"
words = ["hello", "Can you help me?", "Do not translate this word", "I need help with my homework"]
resp = async_chat_completion(words, chkpoint, clearfile=True, ncoroutines=3)
chattool.api_key = api_key

def test_async_process_withfunc():
chkpoint = testpath + "test_async_withfunc.jsonl"
words = ["hello", "Can you help me?", "Do not translate this word", "I need help with my homework"]
Expand Down

0 comments on commit c4ece72

Please sign in to comment.