Skip to content

Commit

Permalink
Merge pull request #70 from cubenlp/rex/change-param
Browse files Browse the repository at this point in the history
change params
  • Loading branch information
RexWzh committed Dec 27, 2023
2 parents ee2b2a2 + c4a2656 commit aa54757
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 50 deletions.
4 changes: 2 additions & 2 deletions chattool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

__author__ = """Rex Wang"""
__email__ = '1073853456@qq.com'
__version__ = '3.0.0'
__version__ = '3.0.1'

import os, sys, requests
from .chattool import Chat, Resp
Expand Down Expand Up @@ -129,7 +129,7 @@ def debug_log( net_url:str="https://www.baidu.com"
if test_response:
print("\nTest response:", message)
chat = Chat(message)
chat.getresponse(max_requests=3)
chat.getresponse(max_tries=3)
chat.print_log()

print("\nDebug is finished.")
Expand Down
50 changes: 22 additions & 28 deletions chattool/asynctool.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ async def async_post( session
, url
, data:str
, headers:Dict
, max_requests:int=1
, max_tries:int=1
, timeinterval=0
, timeout=0):
"""Asynchronous post request
Expand All @@ -21,7 +21,7 @@ async def async_post( session
url (str): chat completion url
data (str): payload of the request
headers (Dict): request headers
max_requests (int, optional): maximum number of requests to make. Defaults to 1.
max_tries (int, optional): maximum number of requests to make. Defaults to 1.
timeinterval (int, optional): time interval between two API calls. Defaults to 0.
timeout (int, optional): timeout for the API call. Defaults to 0(no timeout).
Expand All @@ -30,15 +30,15 @@ async def async_post( session
"""
async with sem:
ntries = 0
while max_requests > 0:
while max_tries > 0:
try:
async with session.post(url, headers=headers, data=data, timeout=timeout) as response:
resp = await response.text()
resp = Resp(json.loads(resp))
assert resp.is_valid(), resp.error_message
return resp
except Exception as e:
max_requests -= 1
max_tries -= 1
ntries += 1
time.sleep(random.random() * timeinterval)
print(f"Request Failed({ntries}):{e}")
Expand All @@ -50,11 +50,10 @@ async def async_process_msgs( chatlogs:List[List[Dict]]
, chkpoint:str
, api_key:str
, chat_url:str
, max_requests:int=1
, ncoroutines:int=1
, max_tries:int=1
, nproc:int=1
, timeout:int=0
, timeinterval:int=0
, max_tokens:Union[Callable, None]=None
, **options
)->List[bool]:
"""Process messages asynchronously
Expand All @@ -63,38 +62,36 @@ async def async_process_msgs( chatlogs:List[List[Dict]]
chatlogs (List[List[Dict]]): list of chat logs
chkpoint (str): checkpoint file
api_key (Union[str, None], optional): API key. Defaults to None.
max_requests (int, optional): maximum number of requests to make. Defaults to 1.
ncoroutines (int, optional): number of coroutines. Defaults to 5.
max_tries (int, optional): maximum number of requests to make. Defaults to 1.
nproc (int, optional): number of coroutines. Defaults to 5.
timeout (int, optional): timeout for the API call. Defaults to 0(no timeout).
timeinterval (int, optional): time interval between two API calls. Defaults to 0.
Returns:
List[bool]: list of responses
"""
# load from checkpoint
chats = load_chats(chkpoint, withid=True) if os.path.exists(chkpoint) else []
chats = load_chats(chkpoint) if os.path.exists(chkpoint) else []
chats.extend([None] * (len(chatlogs) - len(chats)))
costs = [0] * len(chatlogs)
headers = {
"Content-Type": "application/json",
"Authorization": "Bearer " + api_key
}
ncoroutines += 1 # add one for the main coroutine
sem = asyncio.Semaphore(ncoroutines)
nproc += 1 # add one for the main coroutine
sem = asyncio.Semaphore(nproc)
locker = asyncio.Lock()

async def chat_complete(ind, locker, chat_log, chkpoint, **options):
payload = {"messages": chat_log}
payload.update(options)
if max_tokens is not None:
payload['max_tokens'] = max_tokens(chat_log)
data = json.dumps(payload)
resp = await async_post( session=session
, sem=sem
, url=chat_url
, data=data
, headers=headers
, max_requests=max_requests
, max_tries=max_tries
, timeinterval=timeinterval
, timeout=timeout)
## saving files
Expand Down Expand Up @@ -130,16 +127,16 @@ def async_chat_completion( msgs:Union[List[List[Dict]], str]
, model:str='gpt-3.5-turbo'
, api_key:Union[str, None]=None
, chat_url:Union[str, None]=None
, max_requests:int=1
, ncoroutines:int=1
, max_tries:int=1
, nproc:int=1
, timeout:int=0
, timeinterval:int=0
, clearfile:bool=False
, notrun:bool=False
, msg2log:Union[Callable, None]=None
, data2chat:Union[Callable, None]=None
, max_tokens:Union[Callable, int, None]=None
, max_requests:int=-1
, ncoroutines:int=1
, **options
):
"""Asynchronous chat completion
Expand All @@ -149,8 +146,7 @@ def async_chat_completion( msgs:Union[List[List[Dict]], str]
chkpoint (str): checkpoint file
model (str, optional): model to use. Defaults to 'gpt-3.5-turbo'.
api_key (Union[str, None], optional): API key. Defaults to None.
max_requests (int, optional): maximum number of requests to make. Defaults to 1.
ncoroutines (int, optional): (Deprecated)number of coroutines. Defaults to 1.
max_tries (int, optional): maximum number of requests to make. Defaults to 1.
nproc (int, optional): number of coroutines. Defaults to 1.
timeout (int, optional): timeout for the API call. Defaults to 0(no timeout).
timeinterval (int, optional): time interval between two API calls. Defaults to 0.
Expand All @@ -161,8 +157,8 @@ def async_chat_completion( msgs:Union[List[List[Dict]], str]
Defaults to None.
data2chat (Union[Callable, None], optional): function to convert data to Chat object.
Defaults to None.
max_tokens (Union[Callable, int, None], optional): function to calculate the maximum
number of tokens for the API call. Defaults to None.
max_requests (int, optional): (Deprecated)maximum number of requests to make. Defaults to -1.
ncoroutines (int, optional): (Deprecated)number of coroutines. Defaults to 1.
Returns:
List[Dict]: list of responses
Expand All @@ -184,20 +180,18 @@ def async_chat_completion( msgs:Union[List[List[Dict]], str]
chat_url = os.path.join(chattool.base_url, "v1/chat/completions")
chat_url = chattool.request.normalize_url(chat_url)
# run async process
assert ncoroutines > 0, "ncoroutines must be greater than 0!"
if isinstance(max_tokens, int):
max_tokens = lambda chat_log: max_tokens
assert nproc > 0, "nproc must be greater than 0!"
max_tries = max(max_tries, max_requests)
args = {
"chatlogs": chatlogs,
"chkpoint": chkpoint,
"api_key": api_key,
"chat_url": chat_url,
"max_requests": max_requests,
"ncoroutines": nproc,
"max_tries": max_tries,
"nproc": nproc,
"timeout": timeout,
"timeinterval": timeinterval,
"model": model,
"max_tokens": max_tokens,
**options
}
if notrun: # when use in Jupyter Notebook
Expand Down
11 changes: 7 additions & 4 deletions chattool/chattool.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,20 +170,22 @@ def print_log(self, sep: Union[str, None]=None):

# Part2: response and async response
def getresponse( self
, max_requests:int=1
, max_tries:int = 1
, timeout:int = 0
, timeinterval:int = 0
, update:bool = True
, stream:bool = False
, max_requests:int=-1
, **options)->Resp:
"""Get the API response
Args:
max_requests (int, optional): maximum number of requests to make. Defaults to 1.
max_tries (int, optional): maximum number of requests to make. Defaults to 1.
timeout (int, optional): timeout for the API call. Defaults to 0(no timeout).
timeinterval (int, optional): time interval between two API calls. Defaults to 0.
update (bool, optional): whether to update the chat log. Defaults to True.
options (dict, optional): other options like `temperature`, `top_p`, etc.
max_requests (int, optional): (deprecated) maximum number of requests to make. Defaults to -1(no limit
Returns:
Resp: API response
Expand All @@ -194,10 +196,11 @@ def getresponse( self
func_call = options.get('function_call', self.function_call)
if api_key is None: warnings.warn("API key is not set!")
msg, resp, numoftries = self.chat_log, None, 0
max_tries = max(max_tries, max_requests)
if stream: # TODO: add the `usage` key to the response
warnings.warn("stream mode is not supported yet! Use `async_stream_responses()` instead.")
# make requests
while max_requests:
while max_tries:
try:
# make API Call
if funcs is not None: options['functions'] = funcs
Expand All @@ -209,7 +212,7 @@ def getresponse( self
assert resp.is_valid(), resp.error_message
break
except Exception as e:
max_requests -= 1
max_tries -= 1
numoftries += 1
time.sleep(random.random() * timeinterval)
print(f"Try again ({numoftries}):{e}\n")
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
with open('README.md') as readme_file:
readme = readme_file.read()

VERSION = '3.0.0'
VERSION = '3.0.1'

requirements = [
'Click>=7.0', 'requests>=2.20', "responses>=0.23", 'aiohttp>=3.8',
Expand Down
12 changes: 5 additions & 7 deletions tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ async def show_resp(chat):
def test_async_process():
chkpoint = testpath + "test_async.jsonl"
t = time.time()
resp = async_chat_completion(chatlogs[:1], chkpoint, clearfile=True, ncoroutines=3)
resp = async_chat_completion(chatlogs, chkpoint, clearfile=True, ncoroutines=3)
resp = async_chat_completion(chatlogs[:1], chkpoint, clearfile=True, nproc=3)
resp = async_chat_completion(chatlogs, chkpoint, clearfile=True, nproc=3)
assert all(resp)
print(f"Time elapsed: {time.time() - t:.2f}s")

Expand All @@ -55,7 +55,7 @@ def test_failed_async():
chattool.api_key = "sk-invalid"
chkpoint = testpath + "test_async_fail.jsonl"
words = ["hello", "Can you help me?", "Do not translate this word", "I need help with my homework"]
resp = async_chat_completion(words, chkpoint, clearfile=True, ncoroutines=3)
resp = async_chat_completion(words, chkpoint, clearfile=True, nproc=3)
chattool.api_key = api_key

def test_async_process_withfunc():
Expand All @@ -66,15 +66,13 @@ def msg2log(msg):
chat.system("translate the words from English to Chinese")
chat.user(msg)
return chat.chat_log
def max_tokens(chat_log):
return Chat(chat_log).prompt_token()
async_chat_completion(words, chkpoint, clearfile=True, ncoroutines=3, max_tokens=max_tokens, msg2log=msg2log)
async_chat_completion(words, chkpoint, clearfile=True, nproc=3, msg2log=msg2log)

def test_normal_process():
chkpoint = testpath + "test_nomal.jsonl"
def data2chat(data):
chat = Chat(data)
chat.getresponse(max_requests=3)
chat.getresponse(max_tries=3)
return chat
t = time.time()
process_chats(chatlogs, data2chat, chkpoint, clearfile=True)
Expand Down
16 changes: 8 additions & 8 deletions tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

def test_call_weather():
chat = Chat("What's the weather like in Boston?")
resp = chat.getresponse(functions=functions, function_call='auto', max_requests=3)
resp = chat.getresponse(functions=functions, function_call='auto', max_tries=3)
# TODO: wrap the response
if resp.finish_reason == 'function_call':
# test response from chat api
Expand All @@ -54,12 +54,12 @@ def test_auto_response():
chat = Chat("What's the weather like in Boston?")
chat.functions, chat.function_call = functions, 'auto'
chat.name2func = name2func
chat.autoresponse(max_requests=2)
chat.autoresponse(max_tries=2)
chat.print_log()
chat.clear()
# response with nonempty content
chat.user("what is the result of 1+1, and What's the weather like in Boston?")
chat.autoresponse(max_requests=2)
chat.autoresponse(max_tries=2)

# generate docstring from functions
def add(a: int, b: int) -> int:
Expand Down Expand Up @@ -100,20 +100,20 @@ def test_add_and_mult():
chat.name2func = {'add': add} # dictionary of functions
chat.function_call = 'auto' # auto decision
# run until success: maxturns=-1
chat.autoresponse(max_requests=3, display=True, timeinterval=2)
chat.autoresponse(max_tries=3, display=True, timeinterval=2)
# response should be finished
chat.simplify()
chat.print_log()
# use the setfuncs method
chat = Chat("find the value of 124842 * 3423424")
chat.setfuncs([add, mult]) # multi choice
chat.autoresponse(max_requests=3, timeinterval=2)
chat.autoresponse(max_tries=3, timeinterval=2)
chat.simplify() # simplify the chat log
chat.print_log()
# test multichoice
chat.clear()
chat.user("find the value of 23723 + 12312, and 23723 * 12312")
chat.autoresponse(max_requests=3, timeinterval=2)
chat.autoresponse(max_tries=3, timeinterval=2)

def test_mock_resp():
chat = Chat("find the sum of 1235 and 3423")
Expand All @@ -122,12 +122,12 @@ def test_mock_resp():
para = {'name': 'add', 'arguments': '{\n "a": 1235,\n "b": 3423\n}'}
chat.assistant(content=None, function_call=para)
chat.callfunction()
chat.getresponse(max_requests=2)
chat.getresponse(max_tries=2)

def test_use_exec_function():
chat = Chat("find the result of sqrt(121314)")
chat.setfuncs([exec_python_code])
chat.autoresponse(max_requests=2)
chat.autoresponse(max_tries=2)

def test_find_permutation_group():
pass
Expand Down

0 comments on commit aa54757

Please sign in to comment.