Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

修改 load_chats 逻辑，增加 warning 提示未完成的数量 #83

Merged
merged 3 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion chattool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

__author__ = """Rex Wang"""
__email__ = '1073853456@qq.com'
__version__ = '3.3.1'
__version__ = '3.3.2'

import os, sys, requests, json
from .chattype import Chat, Resp
Expand Down
25 changes: 12 additions & 13 deletions chattool/checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import json, warnings, os
import json, os
from typing import List, Dict, Union, Callable, Any
from .chattype import Chat
import tqdm
from loguru import logger

def load_chats( checkpoint:str):
"""Load chats from a checkpoint file
Expand All @@ -23,18 +24,16 @@ def load_chats( checkpoint:str):
if len(txts) == 1 and txts[0] == '': return []
# get the chatlogs
logs = [json.loads(txt) for txt in txts]
chat_size, chatlogs = 1, [None]
for log in logs:
idx = log['index']
if idx >= chat_size: # extend chatlogs
chatlogs.extend([None] * (idx - chat_size + 1))
chat_size = idx + 1
chatlogs[idx] = log['chat_log']
# mapping from index to chat object
idx2chatlog = { log['index']: Chat(log['chat_log']) for log in logs }
max_index = max(idx2chatlog.keys())
chat_objects = [ idx2chatlog.get(index, None) for index in range(max_index+1)]
num_unfinished = chat_objects.count(None)
# check if there are missing chatlogs
if None in chatlogs:
warnings.warn(f"checkpoint file {checkpoint} has unfinished chats")
if num_unfinished > 0:
logger.warning(f"checkpoint file {checkpoint} has {num_unfinished}/{max_index+1} unfinished chats")
# return Chat class
return [Chat(chat_log) if chat_log is not None else None for chat_log in chatlogs]
return chat_objects

def process_chats( data:List[Any]
, data2chat:Callable[[Any], Chat]
Expand All @@ -59,7 +58,7 @@ def process_chats( data:List[Any]
## load chats from the checkpoint file
chats = load_chats(checkpoint)
if len(chats) > len(data):
warnings.warn(f"checkpoint file {checkpoint} has more chats than the data to be processed")
logger.warning(f"checkpoint file {checkpoint} has more chats than the data to be processed")
return chats[:len(data)]
chats.extend([None] * (len(data) - len(chats)))
## process chats
Expand All @@ -69,4 +68,4 @@ def process_chats( data:List[Any]
chat = data2chat(data[i])
chat.save(checkpoint, mode='a', index=i)
chats[i] = chat
return chats
return chats
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
with open('README.md') as readme_file:
readme = readme_file.read()

VERSION = '3.3.1'
VERSION = '3.3.2'

requirements = [
'Click>=7.0', 'requests>=2.20', "responses>=0.23", 'aiohttp>=3.8',
Expand Down
Loading