Skip to content

Commit

Permalink
Merge pull request #9 from ijwfly/feature/cancel-streaming
Browse files Browse the repository at this point in the history
Feature/cancel streaming
  • Loading branch information
ijwfly committed Nov 5, 2023
2 parents b8c313b + eb25c7c commit 5716ae1
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 58 deletions.
58 changes: 58 additions & 0 deletions app/bot/cancellation_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from aiogram import types


CANCELLATION_PREFIX = 'cancel'


class CancellationToken:
    """
    Flag shared between a running operation and the code that wants to stop it.

    A fresh token is not canceled. Calling the token returns its current
    state, so an instance can be passed wherever a zero-argument
    ``is_cancelled`` callable is expected.
    """

    def __init__(self) -> None:
        # Public flag; set to True by cancel() and never reset by this class.
        self.is_canceled = False

    def __call__(self) -> bool:
        """Return True once cancel() has been called, False before that."""
        return self.is_canceled

    def __repr__(self) -> str:
        return f'{type(self).__name__}(is_canceled={self.is_canceled!r})'

    def cancel(self) -> None:
        """Mark the token as canceled; repeated calls leave it canceled."""
        self.is_canceled = True


class CancellationManager:
    """
    Class that manages the cancellation of message processing for streaming messages.

    Tokens are stored per Telegram user id (stringified). A callback query
    handler is registered on the dispatcher so that pressing the cancel
    button cancels the pressing user's token.
    """

    def __init__(self, bot, dispatcher):
        """
        Store the bot and register the cancel-button callback handler.

        :param bot: object used to answer callback queries
            (``answer_callback_query`` is awaited on it)
        :param dispatcher: object exposing ``register_callback_query_handler``
        """
        self._cancellation_tokens = {}
        self.bot = bot
        dispatcher.register_callback_query_handler(
            self.process_callback,
            # Callback data may be absent (None), and other menus may embed the
            # word "cancel" inside their own payloads, so match on the prefix only.
            lambda c: bool(c.data) and c.data.startswith(CANCELLATION_PREFIX),
        )

    async def process_callback(self, callback_query: types.CallbackQuery):
        """
        Process the telegram callback query: cancel the token of the user who
        pressed the button, then acknowledge the query.
        """
        tg_user_id = callback_query.from_user.id
        self.cancel(tg_user_id)
        await self.bot.answer_callback_query(callback_query.id)

    def get_token(self, tg_user_id):
        """
        Get a cancellation token for the user.

        Returns the token already stored for this user if there is one,
        otherwise creates, stores and returns a new one.
        """
        key = str(tg_user_id)
        token = self._cancellation_tokens.get(key)
        if token is None:
            token = CancellationToken()
            self._cancellation_tokens[key] = token
        return token

    def cancel(self, tg_user_id):
        """
        Cancel the message processing for the user.

        The stored token (if any) is canceled and forgotten, so the next
        get_token() call for this user hands out a fresh token. Does nothing
        when no token is stored for the user.
        """
        token = self._cancellation_tokens.pop(str(tg_user_id), None)
        if token is not None:
            token.cancel()


def get_cancel_button():
    """Build the inline 'Stop' button whose callback data begins with CANCELLATION_PREFIX."""
    callback_data = f'{CANCELLATION_PREFIX}.cancel'
    stop_button = types.InlineKeyboardButton(text='Stop', callback_data=callback_data)
    return stop_button
56 changes: 10 additions & 46 deletions app/bot/chatgpt_manager.py
Original file line number Diff line number Diff line change
@@ -1,69 +1,33 @@
from datetime import datetime
from typing import List
from typing import List, AsyncGenerator, Callable

from app.openai_helpers.chatgpt import DialogMessage
from app.storage.db import DB, User

from aiogram.types import Message


WAIT_BETWEEN_UPDATES = 2


class ChatGptManager:
def __init__(self, chatgpt, db):
self.chatgpt = chatgpt
self.db: DB = db

async def send_user_message(self, user: User, tg_message: Message, messages: List[DialogMessage]) -> DialogMessage:
async def send_user_message(self, user: User, messages: List[DialogMessage], is_cancelled: Callable[[], bool]) -> AsyncGenerator[DialogMessage, None]:
if user.streaming_answers:
return await self.send_user_message_streaming(user, tg_message, messages)
return self.send_user_message_streaming(user, messages, is_cancelled)
else:
return await self.send_user_message_sync(user, messages)
return self.send_user_message_sync(user, messages)

async def send_user_message_sync(self, user: User, messages: List[DialogMessage]) -> DialogMessage:
async def send_user_message_sync(self, user: User, messages: List[DialogMessage]) -> AsyncGenerator[DialogMessage, None]:
dialog_message, completion_usage = await self.chatgpt.send_messages(messages)
await self.db.create_completion_usage(user.id, completion_usage.prompt_tokens, completion_usage.completion_tokens, completion_usage.total_tokens, completion_usage.model)
return dialog_message
yield dialog_message

async def send_user_message_streaming(self, user: User, tg_message: Message, messages: List[DialogMessage]) -> DialogMessage:
async def send_user_message_streaming(self, user: User, messages: List[DialogMessage], is_cancelled: Callable[[], bool]) -> AsyncGenerator[DialogMessage, None]:
dialog_message = None
completion_usage = None
message_id = None
chat_id = None
previous_content = None
previous_time = None
async for dialog_message, completion_usage in self.chatgpt.send_messages_streaming(messages):
if dialog_message.function_call is not None:
continue

new_content = ' '.join(dialog_message.content.strip().split(' ')[:-1]) if dialog_message.content else ''
if len(new_content) < 50:
continue

# send message
if not message_id:
resp = await tg_message.answer(dialog_message.content)
chat_id = tg_message.chat.id
# hack: most telegram clients remove "typing" status after receiving new message from bot
await tg_message.bot.send_chat_action(chat_id, 'typing')
message_id = resp.message_id
previous_content = dialog_message.content
previous_time = datetime.now()
continue

# update message
time_passed_seconds = (datetime.now() - previous_time).seconds
if previous_content != new_content and time_passed_seconds >= WAIT_BETWEEN_UPDATES:
await tg_message.bot.edit_message_text(new_content, chat_id, message_id)
previous_content = new_content
previous_time = datetime.now()

if message_id:
await tg_message.bot.delete_message(chat_id, message_id)
async for dialog_message, completion_usage in self.chatgpt.send_messages_streaming(messages, is_cancelled):
yield dialog_message

if dialog_message is None or completion_usage is None:
raise ValueError("Call to ChatGPT failed")

await self.db.create_completion_usage(user.id, completion_usage.prompt_tokens, completion_usage.completion_tokens, completion_usage.total_tokens, completion_usage.model)
return dialog_message
yield dialog_message
72 changes: 63 additions & 9 deletions app/bot/message_processor.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
from aiogram.types import Message, ParseMode
from datetime import datetime

from aiogram import types

from app.bot.cancellation_manager import get_cancel_button
from app.bot.chatgpt_manager import ChatGptManager
from app.bot.utils import send_telegram_message, detect_and_extract_code
from app.bot.utils import send_telegram_message, detect_and_extract_code, edit_telegram_message
from app.context.context_manager import build_context_manager
from app.context.dialog_manager import DialogUtils
from app.openai_helpers.chatgpt import ChatGPT
from app.storage.db import DB, User

from aiogram.types import Message, ParseMode

WAIT_BETWEEN_MESSAGE_UPDATES = 2


class MessageProcessor:
def __init__(self, db: DB, user: User, message: Message):
Expand All @@ -19,21 +26,22 @@ async def add_text_as_context(self, text: str, message_id: int):
speech_dialog_message = DialogUtils.prepare_user_message(text)
await context_manager.add_message(speech_dialog_message, message_id)

async def process_message(self):
async def process_message(self, is_cancelled):
context_manager = await build_context_manager(self.db, self.user, self.message)

function_storage = await context_manager.get_function_storage()
chat_gpt_manager = ChatGptManager(ChatGPT(self.user.current_model, self.user.gpt_mode, function_storage), self.db)

user_dialog_message = DialogUtils.prepare_user_message(self.message.text)
context_dialog_messages = await context_manager.add_message(user_dialog_message, self.message.message_id)
response_dialog_message = await chat_gpt_manager.send_user_message(self.user, self.message, context_dialog_messages)
response_generator = await chat_gpt_manager.send_user_message(self.user, context_dialog_messages, is_cancelled)

await self.handle_gpt_response(
chat_gpt_manager, context_manager, response_dialog_message, function_storage
chat_gpt_manager, context_manager, response_generator, function_storage, is_cancelled
)

async def handle_gpt_response(self, chat_gpt_manager, context_manager, response_dialog_message, function_storage):
async def handle_gpt_response(self, chat_gpt_manager, context_manager, response_generator, function_storage, is_cancelled):
response_dialog_message, message_id = await self.handle_response_generator(response_generator)
if response_dialog_message.function_call:
function_name = response_dialog_message.function_call.name
function_args = response_dialog_message.function_call.arguments
Expand All @@ -47,11 +55,57 @@ async def handle_gpt_response(self, chat_gpt_manager, context_manager, response_
else:
function_response_message_id = -1
context_dialog_messages = await context_manager.add_message(function_response, function_response_message_id)
response_dialog_message = await chat_gpt_manager.send_user_message(self.user, self.message, context_dialog_messages)
response_generator = await chat_gpt_manager.send_user_message(self.user, context_dialog_messages, is_cancelled)

await self.handle_gpt_response(chat_gpt_manager, context_manager, response_dialog_message, function_storage)
await self.handle_gpt_response(chat_gpt_manager, context_manager, response_generator, function_storage, is_cancelled)
else:
code_fragments = detect_and_extract_code(response_dialog_message.content)
parse_mode = ParseMode.MARKDOWN if code_fragments else None
response = await send_telegram_message(self.message, response_dialog_message.content, parse_mode)
if message_id is not None:
response = await edit_telegram_message(self.message, response_dialog_message.content, message_id, parse_mode)
else:
response = await send_telegram_message(self.message, response_dialog_message.content, parse_mode)
await context_manager.add_message(response_dialog_message, response.message_id)

async def handle_response_generator(self, response_generator):
    """
    Consume the response generator, mirroring partial content into Telegram.

    While items arrive, an interim message carrying the cancel button is sent
    once enough text has accumulated and is then edited at most once every
    WAIT_BETWEEN_MESSAGE_UPDATES seconds.

    :param response_generator: async iterable yielding dialog messages
    :return: tuple ``(last_dialog_message, interim_message_id)``; the first
        element is None if nothing was yielded, the second is None if no
        interim message was sent
    """
    stop_keyboard = types.InlineKeyboardMarkup()
    stop_keyboard.add(get_cancel_button())

    dialog_message = None
    message_id = None
    chat_id = None
    previous_content = None
    previous_time = None
    seen_first_item = False

    async for dialog_message in response_generator:
        if not seen_first_item:
            # HACK: skip first iteration for case with full synchronous openai response
            seen_first_item = True
            continue

        # function calls are never shown to the user while streaming
        if dialog_message.function_call is not None:
            continue

        # everything before the last space, i.e. drop the trailing word
        new_content = (dialog_message.content or '').strip().rpartition(' ')[0]
        if len(new_content) < 50:
            continue

        if not message_id:
            # send message
            resp = await send_telegram_message(self.message, dialog_message.content, reply_markup=stop_keyboard)
            chat_id = self.message.chat.id
            # hack: most telegram clients remove "typing" status after receiving new message from bot
            await self.message.bot.send_chat_action(chat_id, 'typing')
            message_id = resp.message_id
            previous_content = dialog_message.content
            previous_time = datetime.now()
            continue

        # update message, but only if the text changed and enough time has passed
        elapsed_seconds = (datetime.now() - previous_time).seconds
        if elapsed_seconds < WAIT_BETWEEN_MESSAGE_UPDATES or previous_content == new_content:
            continue
        await self.message.bot.edit_message_text(new_content, chat_id, message_id, reply_markup=stop_keyboard)
        previous_content = new_content
        previous_time = datetime.now()

    return dialog_message, message_id
6 changes: 5 additions & 1 deletion app/bot/telegram_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dateutil.relativedelta import relativedelta

import settings
from app.bot.cancellation_manager import CancellationManager
from app.bot.message_processor import MessageProcessor
from app.bot.scheduled_tasks import build_monthly_usage_task
from app.bot.settings_menu import Settings
Expand Down Expand Up @@ -42,6 +43,7 @@ def __init__(self, bot: Bot, dispatcher: Dispatcher):

# initialized in on_startup
self.settings = None
self.cancellation_manager = None
self.role_manager = None
self.monthly_usage_task = None

Expand All @@ -51,6 +53,7 @@ async def on_startup(self, _):
settings.POSTGRES_HOST, settings.POSTGRES_PORT, settings.POSTGRES_DATABASE
)
self.settings = Settings(self.bot, self.dispatcher, self.db)
self.cancellation_manager = CancellationManager(self.bot, self.dispatcher)
self.role_manager = UserRoleManager(self.bot, self.dispatcher, self.db)
self.dispatcher.middleware.setup(UserMiddleware(self.db))

Expand Down Expand Up @@ -143,7 +146,8 @@ async def handle_voice(self, message: types.Message, user: User):

async def answer_text_message(self, message: types.Message, user: User):
message_processor = MessageProcessor(self.db, user, message)
await message_processor.process_message()
is_cancelled = self.cancellation_manager.get_token(message.from_user.id)
await message_processor.process_message(is_cancelled)

async def reset_dialog(self, message: types.Message, user: User):
await self.db.create_reset_message(user.id, message.chat.id)
Expand Down
9 changes: 9 additions & 0 deletions app/bot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,15 @@ async def send_telegram_message(message: types.Message, text: str, parse_mode=No
return await send_message(text, reply_markup=reply_markup)


async def edit_telegram_message(message: types.Message, text: str, message_id, parse_mode=None):
    """
    Replace the text of an already sent message in the chat of ``message``.

    The edit is first attempted with ``parse_mode``; if that raises
    CantParseEntities, it is retried once without any parse mode.

    :return: whatever ``edit_message_text`` returns for the successful attempt
    """
    target_chat_id = message.chat.id
    try:
        edited = await message.bot.edit_message_text(text, target_chat_id, message_id, parse_mode=parse_mode)
    except CantParseEntities:
        # try to edit message without parse_mode once
        edited = await message.bot.edit_message_text(text, target_chat_id, message_id)
    return edited


def merge_dicts(dict_1, dict_2):
"""
This function merge two dicts containing strings using plus operator on each key
Expand Down
12 changes: 10 additions & 2 deletions app/openai_helpers/chatgpt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
from typing import List, Any, Optional
from contextlib import suppress
from typing import List, Any, Optional, Callable

import settings
from app.bot.utils import merge_dicts
Expand Down Expand Up @@ -82,7 +83,7 @@ async def send_messages(self, messages_to_send: List[DialogMessage]) -> (DialogM
response = DialogMessage(**message)
return response, completion_usage

async def send_messages_streaming(self, messages_to_send: List[DialogMessage]) -> (DialogMessage, CompletionUsage):
async def send_messages_streaming(self, messages_to_send: List[DialogMessage], is_cancelled: Callable[[], bool]) -> (DialogMessage, CompletionUsage):
prompt_tokens = 0

additional_fields = {}
Expand Down Expand Up @@ -133,6 +134,13 @@ async def send_messages_streaming(self, messages_to_send: List[DialogMessage]) -
total_tokens=prompt_tokens + completion_tokens,
)
yield dialog_message, completion_usage
if is_cancelled():
# some more tokens may be generated after cancellation
completion_usage.completion_tokens += 20
with suppress(BaseException):
# sometimes this call throws an exception since python 3.8
await resp_generator.aclose()
break

@staticmethod
def create_context(messages: List[DialogMessage], gpt_mode) -> List[Any]:
Expand Down

0 comments on commit 5716ae1

Please sign in to comment.