Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use data2chat in async process #62

Merged
merged 2 commits into from
Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion chattool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

__author__ = """Rex Wang"""
__email__ = '1073853456@qq.com'
__version__ = '2.3.5'
__version__ = '2.4.0'

import os, sys, requests
from .chattool import Chat, Resp
Expand Down
21 changes: 15 additions & 6 deletions chattool/asynctool.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,13 @@
, chat_url:Union[str, None]=None
, max_requests:int=1
, ncoroutines:int=1
, nproc:int=1
, timeout:int=0
, timeinterval:int=0
, clearfile:bool=False
, notrun:bool=False
, msg2log:Union[Callable, None]=None
, data2chat:Union[Callable, None]=None
, max_tokens:Union[Callable, int, None]=None
, **options
):
Expand All @@ -148,23 +150,30 @@
model (str, optional): model to use. Defaults to 'gpt-3.5-turbo'.
api_key (Union[str, None], optional): API key. Defaults to None.
max_requests (int, optional): maximum number of requests to make. Defaults to 1.
ncoroutines (int, optional): number of coroutines. Defaults to 5.
ncoroutines (int, optional): (Deprecated)number of coroutines. Defaults to 1.
nproc (int, optional): number of coroutines. Defaults to 1.
timeout (int, optional): timeout for the API call. Defaults to 0(no timeout).
timeinterval (int, optional): time interval between two API calls. Defaults to 0.
clearfile (bool, optional): whether to clear the checkpoint file. Defaults to False.
notrun (bool, optional): whether to run the async process. It should be True
when use in Jupyter Notebook. Defaults to False.
msg2log (Union[Callable, None], optional): function to convert message to chat log.
msg2log (Union[Callable, None], optional): (Deprecated)function to convert message
to chat log. Defaults to None.
data2chat (Union[Callable, None], optional): function to convert data to Chat object.
Defaults to None.
max_tokens (Union[Callable, int, None], optional): function to calculate the maximum
number of tokens for the API call. Defaults to None.

Returns:
List[Dict]: list of responses
"""
# read chatlogs | use method from the Chat object
if msg2log is None:
msg2log = lambda msg: Chat(msg).chat_log
# read chatlogs. By default, use method from the Chat object
if data2chat is None:
msg2log = lambda data: Chat(data).chat_log
elif msg2log is not None: # deprecated warning
warnings.warn("msg2log is deprecated, use data2chat instead!")

Check warning on line 174 in chattool/asynctool.py

View check run for this annotation

Codecov / codecov/patch

chattool/asynctool.py#L173-L174

Added lines #L173 - L174 were not covered by tests
# use nproc instead of ncoroutines
nproc = max(nproc, ncoroutines)
chatlogs = [msg2log(log) for log in msgs]
if clearfile and os.path.exists(chkpoint):
os.remove(chkpoint)
Expand All @@ -184,7 +193,7 @@
"api_key": api_key,
"chat_url": chat_url,
"max_requests": max_requests,
"ncoroutines": ncoroutines,
"ncoroutines": nproc,
"timeout": timeout,
"timeinterval": timeinterval,
"model": model,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
with open('README.md') as readme_file:
readme = readme_file.read()

VERSION = '2.3.5'
VERSION = '2.4.0'

requirements = [
'Click>=7.0', 'requests>=2.20', "responses>=0.23",
Expand Down
9 changes: 9 additions & 0 deletions tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ def test_async_process():
assert all(resp)
print(f"Time elapsed: {time.time() - t:.2f}s")

# broken test
def test_failed_async():
api_key = chattool.api_key
chattool.api_key = "sk-invalid"
chkpoint = testpath + "test_async_fail.jsonl"
words = ["hello", "Can you help me?", "Do not translate this word", "I need help with my homework"]
resp = async_chat_completion(words, chkpoint, clearfile=True, ncoroutines=3)
chattool.api_key = api_key

def test_async_process_withfunc():
chkpoint = testpath + "test_async_withfunc.jsonl"
words = ["hello", "Can you help me?", "Do not translate this word", "I need help with my homework"]
Expand Down
Loading