diff --git a/README-EN.md b/README-EN.md index 73b4c0c..8d9d600 100644 --- a/README-EN.md +++ b/README-EN.md @@ -35,7 +35,7 @@ from random import randint msgs = [f"find the result of {randint(3, 100)} + {randint(4, 100)}" for _ in range(4)] # Annotate some data and get interrupted -process_messages(msgs[:2], "test.jsonl", time_interval=5, max_tries=3) +process_messages(msgs[:2], "test.jsonl", interval=5, max_tries=3) # Continue annotation process_messages(msgs, "test.jsonl") ``` diff --git a/README.md b/README.md index c664a47..01bf802 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ from random import randint msgs = [f"find the result of {randint(3, 100)} + {randint(4, 100)}" for _ in range(4)] # 标注一部分后被中断 -process_messages(msgs[:2], "test.jsonl", time_interval=5, max_tries=3) +process_messages(msgs[:2], "test.jsonl", interval=5, max_tries=3) # 继续标注 process_messages(msgs, "test.jsonl") ``` diff --git a/setup.py b/setup.py index 5833567..3c78bf3 100644 --- a/setup.py +++ b/setup.py @@ -3,21 +3,20 @@ """The setup script.""" from setuptools import setup, find_packages -VERSION = '0.1.1' +VERSION = '0.2.0' with open('README.md') as readme_file: readme = readme_file.read() requirements = [ ] -test_requirements = ['pytest>=3', ] +test_requirements = ['pytest>=3', 'tqdm>=4.60', 'chattool>=3.0.0'] setup( author="Rex Wang", author_email='1073853456@qq.com', python_requires='>=3.6', classifiers=[ - 'Development Status :: 2 - Pre-Alpha', 'Intended Audience :: Developers', 'License :: OSI Approved :: MIT License', 'Natural Language :: English', diff --git a/tests/test_process.py b/tests/test_process.py index ad497a6..41a8879 100644 --- a/tests/test_process.py +++ b/tests/test_process.py @@ -4,5 +4,5 @@ def test_process_msgs(): msgs = [f"find the result of {randint(3, 100)} + {randint(4, 100)}" for _ in range(4)] - process_messages(msgs[:2], "test.jsonl") - process_messages(msgs, "test.jsonl") \ No newline at end of file + process_messages(msgs[:2], "test.jsonl", max_tries=3) + process_messages(msgs, "test.jsonl", max_tries=3, interval=3) \ No newline at end of file diff --git a/webchatter/__init__.py b/webchatter/__init__.py index a456983..7c62a6c 100644 --- a/webchatter/__init__.py +++ b/webchatter/__init__.py @@ -2,7 +2,7 @@ __author__ = """Rex Wang""" __email__ = '1073853456@qq.com' -__version__ = '0.1.1' +__version__ = '0.2.0' import os, dotenv, requests from typing import Union diff --git a/webchatter/checkpoint.py b/webchatter/checkpoint.py index ea66c4c..ee37829 100644 --- a/webchatter/checkpoint.py +++ b/webchatter/checkpoint.py @@ -1,15 +1,32 @@ -import os, json +import os, json, warnings, time from webchatter import WebChat import tqdm, tqdm.notebook -import time +from chattool import load_chats, Chat +from typing import List, Callable # from chattool import load_chats -def process_messages( msgs +def try_sth(func:Callable, max_tries:int, interval:float, *args, **kwargs): + """Try something. + + Args: + func (Callable): The function to try. + max_tries (int): The maximum number of tries. + interval (float): The interval between tries. + """ + while max_tries: + try: + return func(*args, **kwargs) + except Exception as e: + print(e) + max_tries -= 1 + time.sleep(interval) + return None + +def process_messages( msgs:List[str] , checkpoint:str - , time_interval:int=5 + , interval:int=5 , max_tries:int=-1 , isjupyter:bool=False - , interval_rate:float=1 ): """Process the messages. @@ -17,33 +34,26 @@ def process_messages( msgs msgs (list): The messages. checkpoint (str): Store the checkpoint. + Returns: list: The processed messages. """ - offset = 0 - if os.path.exists(checkpoint): - with open(checkpoint, 'r', encoding='utf-8') as f: - processed = f.read().strip().split('\n') - if len(processed) >= 1 and processed[0] != '': - offset = len(processed) + chats = load_chats(checkpoint) + if len(chats) > len(msgs): + warnings.warn(f"checkpoint file {checkpoint} has more chats than the data to be processed") + return chats[:len(msgs)] + chats.extend([None] * (len(msgs) - len(chats))) tq = tqdm.tqdm if not isjupyter else tqdm.notebook.tqdm - chat = WebChat() - with open(checkpoint, 'a', encoding='utf-8') as f: - for ind in tq(range(offset, len(msgs))): - wait_time = time_interval - while max_tries: - try: - msg = msgs[ind] - ans = chat.ask(msg, keep=False) - data = {"index":ind + offset, "chat_log":{"user":msg, "assistant":ans}} - f.write(json.dumps(data) + '\n') - break - except Exception as e: - print(ind, e) - max_tries -= 1 - time.sleep(wait_time) - wait_time = wait_time * interval_rate - return True + # process chats + webchat, chat = WebChat(), Chat() + for ind in tq(range(len(chats))): + if chats[ind] is not None: continue + ans = try_sth(webchat.ask, max_tries, interval, msgs[ind]) + chat = Chat(msgs[ind]) + chat.assistant(ans) + chat.save(checkpoint, mode='a', index=ind) + chats[ind] = chat + return chats def process_chats(chats, checkpoint:str): """Process the chats. diff --git a/webchatter/webchatter.py b/webchatter/webchatter.py index 4d903ea..d3adb49 100644 --- a/webchatter/webchatter.py +++ b/webchatter/webchatter.py @@ -283,7 +283,6 @@ def save( self, file:str # make path if not exists pathname = os.path.dirname(file).strip() if pathname != '': os.makedirs(pathname, exist_ok=True) - if chat_log_only: data = { "index": index,