-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #9 from ijwfly/feature/cancel-streaming
Feature/cancel streaming
- Loading branch information
Showing
6 changed files
with
155 additions
and
58 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from aiogram import types | ||
|
||
|
||
CANCELLATION_PREFIX = 'cancel' | ||
|
||
|
||
class CancellationToken: | ||
""" | ||
Class that represents a cancellation token | ||
""" | ||
def __init__(self): | ||
self.is_canceled = False | ||
|
||
def __call__(self): | ||
return self.is_canceled | ||
|
||
def cancel(self): | ||
self.is_canceled = True | ||
|
||
|
||
class CancellationManager: | ||
""" | ||
Class that manages the cancellation of message processing for streaming messages | ||
""" | ||
def __init__(self, bot, dispatcher): | ||
self._cancellation_tokens = {} | ||
dispatcher.register_callback_query_handler(self.process_callback, lambda c: CANCELLATION_PREFIX in c.data) | ||
self.bot = bot | ||
|
||
async def process_callback(self, callback_query: types.CallbackQuery): | ||
""" | ||
Process the telegram callback query | ||
""" | ||
chat_id = callback_query.from_user.id | ||
self.cancel(chat_id) | ||
await self.bot.answer_callback_query(callback_query.id) | ||
|
||
def get_token(self, tg_user_id): | ||
""" | ||
Get a cancellation token for the user | ||
""" | ||
key = str(tg_user_id) | ||
if key not in self._cancellation_tokens: | ||
self._cancellation_tokens[key] = CancellationToken() | ||
return self._cancellation_tokens[key] | ||
|
||
def cancel(self, tg_user_id): | ||
""" | ||
Cancel the message processing for the user | ||
""" | ||
key = str(tg_user_id) | ||
if key in self._cancellation_tokens: | ||
self._cancellation_tokens[key].cancel() | ||
del self._cancellation_tokens[key] | ||
|
||
|
||
def get_cancel_button(): | ||
return types.InlineKeyboardButton(text='Stop', callback_data=f'{CANCELLATION_PREFIX}.cancel') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,69 +1,33 @@ | ||
from datetime import datetime | ||
from typing import List | ||
from typing import List, AsyncGenerator, Callable | ||
|
||
from app.openai_helpers.chatgpt import DialogMessage | ||
from app.storage.db import DB, User | ||
|
||
from aiogram.types import Message | ||
|
||
|
||
WAIT_BETWEEN_UPDATES = 2 | ||
|
||
|
||
class ChatGptManager: | ||
def __init__(self, chatgpt, db): | ||
self.chatgpt = chatgpt | ||
self.db: DB = db | ||
|
||
async def send_user_message(self, user: User, tg_message: Message, messages: List[DialogMessage]) -> DialogMessage: | ||
async def send_user_message(self, user: User, messages: List[DialogMessage], is_cancelled: Callable[[], bool]) -> AsyncGenerator[DialogMessage, None]: | ||
if user.streaming_answers: | ||
return await self.send_user_message_streaming(user, tg_message, messages) | ||
return self.send_user_message_streaming(user, messages, is_cancelled) | ||
else: | ||
return await self.send_user_message_sync(user, messages) | ||
return self.send_user_message_sync(user, messages) | ||
|
||
async def send_user_message_sync(self, user: User, messages: List[DialogMessage]) -> DialogMessage: | ||
async def send_user_message_sync(self, user: User, messages: List[DialogMessage]) -> AsyncGenerator[DialogMessage, None]: | ||
dialog_message, completion_usage = await self.chatgpt.send_messages(messages) | ||
await self.db.create_completion_usage(user.id, completion_usage.prompt_tokens, completion_usage.completion_tokens, completion_usage.total_tokens, completion_usage.model) | ||
return dialog_message | ||
yield dialog_message | ||
|
||
async def send_user_message_streaming(self, user: User, tg_message: Message, messages: List[DialogMessage]) -> DialogMessage: | ||
async def send_user_message_streaming(self, user: User, messages: List[DialogMessage], is_cancelled: Callable[[], bool]) -> AsyncGenerator[DialogMessage, None]: | ||
dialog_message = None | ||
completion_usage = None | ||
message_id = None | ||
chat_id = None | ||
previous_content = None | ||
previous_time = None | ||
async for dialog_message, completion_usage in self.chatgpt.send_messages_streaming(messages): | ||
if dialog_message.function_call is not None: | ||
continue | ||
|
||
new_content = ' '.join(dialog_message.content.strip().split(' ')[:-1]) if dialog_message.content else '' | ||
if len(new_content) < 50: | ||
continue | ||
|
||
# send message | ||
if not message_id: | ||
resp = await tg_message.answer(dialog_message.content) | ||
chat_id = tg_message.chat.id | ||
# hack: most telegram clients remove "typing" status after receiving new message from bot | ||
await tg_message.bot.send_chat_action(chat_id, 'typing') | ||
message_id = resp.message_id | ||
previous_content = dialog_message.content | ||
previous_time = datetime.now() | ||
continue | ||
|
||
# update message | ||
time_passed_seconds = (datetime.now() - previous_time).seconds | ||
if previous_content != new_content and time_passed_seconds >= WAIT_BETWEEN_UPDATES: | ||
await tg_message.bot.edit_message_text(new_content, chat_id, message_id) | ||
previous_content = new_content | ||
previous_time = datetime.now() | ||
|
||
if message_id: | ||
await tg_message.bot.delete_message(chat_id, message_id) | ||
async for dialog_message, completion_usage in self.chatgpt.send_messages_streaming(messages, is_cancelled): | ||
yield dialog_message | ||
|
||
if dialog_message is None or completion_usage is None: | ||
raise ValueError("Call to ChatGPT failed") | ||
|
||
await self.db.create_completion_usage(user.id, completion_usage.prompt_tokens, completion_usage.completion_tokens, completion_usage.total_tokens, completion_usage.model) | ||
return dialog_message | ||
yield dialog_message |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters