Commit 37d5255

Merge pull request #64 from cubenlp/rexwzh/fix-patch

RexWzh committed Nov 11, 2023
2 parents c4ece72 + a881fad
Showing 5 changed files with 29 additions and 10 deletions.
2 changes: 1 addition & 1 deletion chattool/__init__.py

@@ -2,7 +2,7 @@

 __author__ = """Rex Wang"""
 __email__ = '1073853456@qq.com'
-__version__ = '2.4.0'
+__version__ = '2.4.2'

 import os, sys, requests
 from .chattool import Chat, Resp
12 changes: 6 additions & 6 deletions chattool/asynctool.py

@@ -157,8 +157,8 @@ def async_chat_completion( msgs:Union[List[List[Dict]], str]
         clearfile (bool, optional): whether to clear the checkpoint file. Defaults to False.
         notrun (bool, optional): whether to run the async process. It should be True
             when used in a Jupyter Notebook. Defaults to False.
-        msg2log (Union[Callable, None], optional): (Deprecated) function to convert message
-            to chat log. Defaults to None.
+        msg2log (Union[Callable, None], optional): function to convert message to chat log.
+            Defaults to None.
         data2chat (Union[Callable, None], optional): function to convert data to Chat object.
             Defaults to None.
         max_tokens (Union[Callable, int, None], optional): function to calculate the maximum

@@ -167,11 +167,11 @@ def async_chat_completion( msgs:Union[List[List[Dict]], str]
     Returns:
         List[Dict]: list of responses
     """
-    # read chatlogs. By default, use method from the Chat object
-    if data2chat is None:
+    # convert chatlogs
+    if data2chat is not None:
+        msg2log = lambda data: data2chat(data).chat_log
+    elif msg2log is None: # By default, use method from the Chat object
         msg2log = lambda data: Chat(data).chat_log
-    elif msg2log is not None: # deprecated warning
-        warnings.warn("msg2log is deprecated, use data2chat instead!")
     # use nproc instead of ncoroutines
     nproc = max(nproc, ncoroutines)
    chatlogs = [msg2log(log) for log in msgs]
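With the fixed branching, data2chat takes precedence, a user-supplied msg2log is honored whenever data2chat is None, and the Chat default applies only when neither hook is given. A minimal sketch of the two hooks, assuming configured API credentials; the questions and checkpoint path are illustrative, and the keyword names are taken from the docstring and tests in this diff:

    from chattool import Chat
    from chattool.asynctool import async_chat_completion

    # Toy inputs: each item becomes one chat completion.
    questions = ["What is 1+1?", "What is 2+2?"]

    # data2chat takes precedence: build a full Chat object per data item.
    resps = async_chat_completion(
        questions,
        chkpoint="chkpt/demo.jsonl",      # hypothetical checkpoint path
        data2chat=lambda q: Chat(q),
    )

    # msg2log is used when data2chat is None: return a raw chat log,
    # i.e. a list of message dicts.
    resps = async_chat_completion(
        questions,
        chkpoint="chkpt/demo.jsonl",
        msg2log=lambda q: [{"role": "user", "content": q}],
    )
    # In a Jupyter Notebook, pass notrun=True and await the returned object instead.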
16 changes: 14 additions & 2 deletions chattool/chattool.py

@@ -7,6 +7,7 @@
 from .request import chat_completion, valid_models
 import time, random, json
 import aiohttp
+import os
 from .functioncall import generate_json_schema, delete_dialogue_assist

 class Chat():

@@ -110,6 +111,8 @@ def save(self, path:str, mode:str='a'):
             mode (str, optional): mode to open the file. Defaults to 'a'.
         """
         assert mode in ['a', 'w'], "saving mode should be 'a' or 'w'"
+        # make path if not exists
+        os.makedirs(os.path.dirname(path), exist_ok=True)
         with open(path, mode, encoding='utf-8') as f:
             f.write(json.dumps(self.chat_log, ensure_ascii=False) + '\n')
         return

@@ -123,19 +126,24 @@ def savewithid(self, path:str, chatid:int, mode:str='a'):
             mode (str, optional): mode to open the file. Defaults to 'a'.
         """
         assert mode in ['a', 'w'], "saving mode should be 'a' or 'w'"
+        # make path if not exists
+        os.makedirs(os.path.dirname(path), exist_ok=True)
         data = {"chatid": chatid, "chatlog": self.chat_log}
         with open(path, mode, encoding='utf-8') as f:
             f.write(json.dumps(data, ensure_ascii=False) + '\n')
         return

     def savewithmsg(self, path:str, mode:str='a'):
         """Save the chat log with message.
+        This is for fine-tuning the model.

         Args:
             path (str): path to the file
             mode (str, optional): mode to open the file. Defaults to 'a'.
         """
         assert mode in ['a', 'w'], "saving mode should be 'a' or 'w'"
+        # make path if not exists
+        os.makedirs(os.path.dirname(path), exist_ok=True)
         data = {"messages": self.chat_log}
         with open(path, mode, encoding='utf-8') as f:
             f.write(json.dumps(data, ensure_ascii=False) + '\n')

@@ -213,11 +221,12 @@ def getresponse( self
         self._resp = resp
         return resp

-    async def async_stream_responses(self, timeout=0):
+    async def async_stream_responses(self, timeout:int=0, textonly:bool=False):
         """Post request asynchronously and stream the responses

         Args:
             timeout (int, optional): timeout for the API call. Defaults to 0 (no timeout).
+            textonly (bool, optional): whether to yield only the text delta. Defaults to False.

         Returns:
             str: response text

@@ -238,7 +247,10 @@ async def async_stream_responses(self, timeout=0):
                 line = json.loads(strline)
                 resp = Resp(line)
                 if resp.finish_reason == 'stop': break
-                yield resp
+                if textonly:
+                    yield resp.delta_content
+                else:
+                    yield resp

     # Part3: function call
     def iswaiting(self):
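Taken together, the chattool.py changes mean a chat can now be saved to a path whose parent directory does not exist yet, and its response can be streamed as plain text. A small usage sketch, assuming API credentials are configured; the prompt mirrors the tests below and the save path is illustrative:

    import asyncio
    from chattool import Chat

    chat = Chat("Print hello using Python")

    # The parent directory is created on demand, so a nested path no longer
    # raises FileNotFoundError. Note the path keeps a directory component:
    # os.path.dirname('chat.jsonl') is '', which os.makedirs would reject.
    chat.save("logs/demo/chat.jsonl")

    # With textonly=True the generator yields plain text deltas
    # (resp.delta_content) instead of Resp objects.
    async def typewriter(chat):
        async for txt in chat.async_stream_responses(textonly=True):
            print(txt, end='')

    asyncio.run(typewriter(chat))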
2 changes: 1 addition & 1 deletion setup.py

@@ -7,7 +7,7 @@

 with open('README.md') as readme_file:
     readme = readme_file.read()

-VERSION = '2.4.0'
+VERSION = '2.4.2'

 requirements = [
     'Click>=7.0', 'requests>=2.20', "responses>=0.23",
7 changes: 7 additions & 0 deletions tests/test_async.py

@@ -34,6 +34,13 @@ async def show_resp(chat):
     chat = Chat("Print hello using Python")
     asyncio.run(show_resp(chat))

+def test_async_typewriter2():
+    async def show_resp(chat):
+        async for txt in chat.async_stream_responses(textonly=True):
+            print(txt, end='')
+    chat = Chat("Print hello using Python")
+    asyncio.run(show_resp(chat))
+
 def test_async_process():
     chkpoint = testpath + "test_async.jsonl"
     t = time.time()
