From edc1e0c9a4d37ef714ed9372ffa6b2a50a57d3cc Mon Sep 17 00:00:00 2001
From: bugfloyd
Date: Sat, 15 Apr 2023 20:55:40 +0200
Subject: [PATCH 1/2] Add streaming support for inline-query responses

---
 bot/telegram_bot.py | 108 ++++++++++++++++++++++++++++++++++----------
 1 file changed, 85 insertions(+), 23 deletions(-)

diff --git a/bot/telegram_bot.py b/bot/telegram_bot.py
index 371b66ee..b88f6384 100644
--- a/bot/telegram_bot.py
+++ b/bot/telegram_bot.py
@@ -243,7 +243,7 @@ async def _generate():
                 parse_mode=constants.ParseMode.MARKDOWN
             )

-        await self.wrap_with_indicator(update, context, constants.ChatAction.UPLOAD_PHOTO, _generate)
+        await self.wrap_with_indicator(update, context, _generate, constants.ChatAction.UPLOAD_PHOTO)

     async def transcribe(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
         """
@@ -368,7 +368,7 @@ async def _execute():
             if os.path.exists(filename):
                 os.remove(filename)

-        await self.wrap_with_indicator(update, context, constants.ChatAction.TYPING, _execute)
+        await self.wrap_with_indicator(update, context, _execute, constants.ChatAction.TYPING)

     async def prompt(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
         """
@@ -512,7 +512,7 @@ async def _reply():
             except Exception as exception:
                 raise exception

-        await self.wrap_with_indicator(update, context, constants.ChatAction.TYPING, _reply)
+        await self.wrap_with_indicator(update, context, _reply, constants.ChatAction.TYPING)

         self.add_chat_request_to_usage_tracker(user_id, total_tokens)

@@ -580,8 +580,9 @@ async def handle_callback_inline_query(self, update: Update, context: CallbackCo
         try:
             if callback_data.startswith(callback_data_suffix):
                 unique_id = callback_data.split(':')[1]
+                total_tokens = 0

-                # Retrieve the long text from the cache
+                # Retrieve the prompt from the cache
                 query = self.inline_queries_cache.get(unique_id)
                 if query:
                     self.inline_queries_cache.pop(unique_id)
@@ -590,30 +591,90 @@ async def handle_callback_inline_query(self, update: Update, context: CallbackCo
                         f'{localized_text("error", bot_language)}. '
                         f'{localized_text("try_again", bot_language)}'
                     )
-                    await self.edit_message_with_retry(context,
-                                                       chat_id=None,
-                                                       message_id=inline_message_id,
+                    await self.edit_message_with_retry(context, chat_id=None, message_id=inline_message_id,
                                                        text=f'{query}\n\n_{answer_tr}:_\n{error_message}',
                                                        is_inline=True)
                     return

-            # Edit the current message to indicate that the answer is being processed
-            await context.bot.edit_message_text(inline_message_id=inline_message_id,
-                                                text=f'{query}\n\n_{answer_tr}:_\n{loading_tr}',
-                                                parse_mode=constants.ParseMode.MARKDOWN)
+                if self.config['stream']:
+                    stream_response = self.openai.get_chat_response_stream(chat_id=user_id, query=query)
+                    i = 0
+                    prev = ''
+                    sent_message = None
+                    backoff = 0
+                    async for content, tokens in stream_response:
+                        if len(content.strip()) == 0:
+                            continue

-            logging.info(f'Generating response for inline query by {name}')
-            response, used_tokens = await self.openai.get_chat_response(chat_id=user_id, query=query)
-            self.add_chat_request_to_usage_tracker(user_id, used_tokens)
+                        cutoff = 180 if len(content) > 1000 else 120 if len(content) > 200 else 90 if len(
+                            content) > 50 else 50
+                        cutoff += backoff
+
+                        if i == 0:
+                            try:
+                                if sent_message is not None:
+                                    await self.edit_message_with_retry(context, chat_id=None,
+                                                                       message_id=inline_message_id,
+                                                                       text=f'{query}\n\n_{answer_tr}:_\n{content}',
+                                                                       is_inline=True)
+                            except:
+                                continue
+
+                        elif abs(len(content) - len(prev)) > cutoff or tokens != 'not_finished':
+                            prev = content
+                            try:
+                                use_markdown = tokens != 'not_finished'
+                                await self.edit_message_with_retry(context,
+                                                                   chat_id=None, message_id=inline_message_id,
+                                                                   text=f'{query}\n\n_{answer_tr}:_\n{content}',
+                                                                   markdown=use_markdown, is_inline=True)
+
+                            except RetryAfter as e:
+                                backoff += 5
+                                await asyncio.sleep(e.retry_after)
+                                continue
+                            except TimedOut:
+                                backoff += 5
+                                await asyncio.sleep(0.5)
+                                continue
+                            except Exception:
+                                backoff += 5
+                                continue
+
+                        await asyncio.sleep(0.01)
+
+                        i += 1
+                        if tokens != 'not_finished':
+                            total_tokens = int(tokens)
+
+                else:
+                    async def _send_inline_query_response():
+                        nonlocal total_tokens
+                        # Edit the current message to indicate that the answer is being processed
+                        await context.bot.edit_message_text(inline_message_id=inline_message_id,
+                                                            text=f'{query}\n\n_{answer_tr}:_\n{loading_tr}',
+                                                            parse_mode=constants.ParseMode.MARKDOWN)
+
+                        logging.info(f'Generating response for inline query by {name}')
+                        response, total_tokens = await self.openai.get_chat_response(chat_id=user_id, query=query)
+
+                        # Edit the original message with the generated content
+                        await self.edit_message_with_retry(context, chat_id=None, message_id=inline_message_id,
+                                                           text=f'{query}\n\n_{answer_tr}:_\n{response}',
+                                                           is_inline=True)
+
+                    await self.wrap_with_indicator(update, context, _send_inline_query_response,
+                                                   constants.ChatAction.TYPING, is_inline=True)
+
+                self.add_chat_request_to_usage_tracker(user_id, total_tokens)

-            # Edit the original message with the generated content
-            await self.edit_message_with_retry(context,
-                                               chat_id=None,
-                                               message_id=inline_message_id,
-                                               text=f'{query}\n\n_{answer_tr}:_\n{response}',
-                                               is_inline=True)
         except Exception as e:
             logging.error(f'Failed to respond to an inline query via button callback: {e}')
+            logging.exception(e)
+            localized_answer = localized_text('chat_fail', self.config['bot_language'])
+            await self.edit_message_with_retry(context, chat_id=None, message_id=inline_message_id,
+                                               text=f"{query}\n\n_{answer_tr}:_\n{localized_answer} {str(e)}",
+                                               is_inline=True)

     async def edit_message_with_retry(self, context: ContextTypes.DEFAULT_TYPE, chat_id: int | None,
                                       message_id: str, text: str, markdown: bool = True, is_inline: bool = False):
@@ -652,14 +713,15 @@ async def edit_message_with_retry(self, context: ContextTypes.DEFAULT_TYPE, chat
             logging.warning(str(e))
             raise e

-    async def wrap_with_indicator(self, update: Update, context: CallbackContext, chat_action: constants.ChatAction,
-                                  coroutine):
+    async def wrap_with_indicator(self, update: Update, context: CallbackContext, coroutine,
+                                  chat_action: constants.ChatAction = "", is_inline=False):
         """
         Wraps a coroutine while repeatedly sending a chat action to the user.
         """
         task = context.application.create_task(coroutine(), update=update)
         while not task.done():
-            context.application.create_task(update.effective_chat.send_action(chat_action))
+            if not is_inline:
+                context.application.create_task(update.effective_chat.send_action(chat_action))
             try:
                 await asyncio.wait_for(asyncio.shield(task), 4.5)
             except asyncio.TimeoutError:

From c0b57ba92ec26ccd9fe378eb6853b7c7bb354a47 Mon Sep 17 00:00:00 2001
From: ned
Date: Sat, 15 Apr 2023 23:57:39 +0200
Subject: [PATCH 2/2] relax cutoff values and only consider first 4096 chars

---
 bot/telegram_bot.py | 49 ++++++++++++++++++++++++++++-----------------
 1 file changed, 31 insertions(+), 18 deletions(-)

diff --git a/bot/telegram_bot.py b/bot/telegram_bot.py
index b88f6384..0abcafbc 100644
--- a/bot/telegram_bot.py
+++ b/bot/telegram_bot.py
@@ -403,7 +403,6 @@ async def prompt(self, update: Update, context: ContextTypes.DEFAULT_TYPE):

         if self.config['stream']:
             await context.bot.send_chat_action(chat_id=chat_id, action=constants.ChatAction.TYPING)
-            is_group_chat = self.is_group_chat(update)

             stream_response = self.openai.get_chat_response_stream(chat_id=chat_id, query=prompt)
             i = 0
@@ -435,14 +434,7 @@ async def prompt(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
                         pass
                     continue

-                if is_group_chat:
-                    # group chats have stricter flood limits
-                    cutoff = 180 if len(content) > 1000 else 120 if len(content) > 200 else 90 if len(
-                        content) > 50 else 50
-                else:
-                    cutoff = 90 if len(content) > 1000 else 45 if len(content) > 200 else 25 if len(
-                        content) > 50 else 15
-
+                cutoff = self.get_stream_cutoff_values(update, content)
                 cutoff += backoff

                 if i == 0:
@@ -606,8 +598,7 @@ async def handle_callback_inline_query(self, update: Update, context: CallbackCo
                             if len(content.strip()) == 0:
                                 continue

-                            cutoff = 180 if len(content) > 1000 else 120 if len(content) > 200 else 90 if len(
-                                content) > 50 else 50
+                            cutoff = self.get_stream_cutoff_values(update, content)
                             cutoff += backoff

                             if i == 0:
@@ -615,7 +606,7 @@ async def handle_callback_inline_query(self, update: Update, context: CallbackCo
                                     if sent_message is not None:
                                         await self.edit_message_with_retry(context, chat_id=None,
                                                                            message_id=inline_message_id,
-                                                                           text=f'{query}\n\n_{answer_tr}:_\n{content}',
+                                                                           text=f'{query}\n\n{answer_tr}:\n{content}',
                                                                            is_inline=True)
                                 except:
                                     continue
@@ -624,10 +615,14 @@ async def handle_callback_inline_query(self, update: Update, context: CallbackCo
                                 prev = content
                                 try:
                                     use_markdown = tokens != 'not_finished'
-                                    await self.edit_message_with_retry(context,
-                                                                       chat_id=None, message_id=inline_message_id,
-                                                                       text=f'{query}\n\n_{answer_tr}:_\n{content}',
-                                                                       markdown=use_markdown, is_inline=True)
+                                    divider = '_' if use_markdown else ''
+                                    text = f'{query}\n\n{divider}{answer_tr}:{divider}\n{content}'
+
+                                    # We only want to send the first 4096 characters. No chunking allowed in inline mode.
+                                    text = text[:4096]
+
+                                    await self.edit_message_with_retry(context, chat_id=None, message_id=inline_message_id,
+                                                                       text=text, markdown=use_markdown, is_inline=True)

                                 except RetryAfter as e:
                                     backoff += 5
@@ -658,10 +653,14 @@ async def _send_inline_query_response():

                             logging.info(f'Generating response for inline query by {name}')
                             response, total_tokens = await self.openai.get_chat_response(chat_id=user_id, query=query)
+                            text_content = f'{query}\n\n_{answer_tr}:_\n{response}'
+
+                            # We only want to send the first 4096 characters. No chunking allowed in inline mode.
+                            text_content = text_content[:4096]
+
                             # Edit the original message with the generated content
                             await self.edit_message_with_retry(context, chat_id=None, message_id=inline_message_id,
-                                                               text=f'{query}\n\n_{answer_tr}:_\n{response}',
-                                                               is_inline=True)
+                                                               text=text_content, is_inline=True)

                             await self.wrap_with_indicator(update, context, _send_inline_query_response,
                                                            constants.ChatAction.TYPING, is_inline=True)
@@ -760,10 +759,24 @@ async def error_handler(self, update: object, context: ContextTypes.DEFAULT_TYPE
         """
         logging.error(f'Exception while handling an update: {context.error}')

+    def get_stream_cutoff_values(self, update: Update, content: str) -> int:
+        """
+        Gets the stream cutoff values for the message length
+        """
+        if self.is_group_chat(update):
+            # group chats have stricter flood limits
+            return 180 if len(content) > 1000 else 120 if len(content) > 200 else 90 if len(
+                content) > 50 else 50
+        else:
+            return 90 if len(content) > 1000 else 45 if len(content) > 200 else 25 if len(
+                content) > 50 else 15
+
     def is_group_chat(self, update: Update) -> bool:
         """
         Checks if the message was sent from a group chat
         """
+        if not update.effective_chat:
+            return False
         return update.effective_chat.type in [
             constants.ChatType.GROUP,
             constants.ChatType.SUPERGROUP
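
For reference, the throttling scheme the two patches converge on, edit the message only once enough new characters have accumulated, widen the gap after a flood error, and cap inline text at 4096 characters, can be sketched outside the bot as a minimal, self-contained script. Everything here is illustrative: fake_stream() and the print-based "edit" are stand-ins for get_chat_response_stream() and edit_message_with_retry(), not part of the bot's API, and the first-edit special case (i == 0) is omitted for brevity.

import asyncio

TELEGRAM_MESSAGE_LIMIT = 4096  # Telegram rejects message texts longer than this


def get_stream_cutoff_value(content: str, is_group_chat: bool) -> int:
    # Same tiering as get_stream_cutoff_values() in the patch: the longer the
    # draft, the more new characters must accumulate before the next edit.
    # Group chats use larger cutoffs because their flood limits are stricter.
    if is_group_chat:
        return 180 if len(content) > 1000 else 120 if len(content) > 200 else 90 if len(content) > 50 else 50
    return 90 if len(content) > 1000 else 45 if len(content) > 200 else 25 if len(content) > 50 else 15


async def fake_stream():
    # Stand-in for get_chat_response_stream(): yields the accumulated draft
    # plus a status string that stays 'not_finished' until the final chunk,
    # which carries the total token count as a string (as in the patch).
    draft = ''
    for word in ('lorem ipsum dolor sit amet ' * 60).split():
        draft += word + ' '
        yield draft, 'not_finished'
    yield draft, '42'


async def stream_to_message() -> None:
    prev = ''
    backoff = 0
    edits = 0
    async for content, status in fake_stream():
        cutoff = get_stream_cutoff_value(content, is_group_chat=False) + backoff
        finished = status != 'not_finished'
        if abs(len(content) - len(prev)) > cutoff or finished:
            prev = content
            try:
                # Placeholder for edit_message_with_retry(); the slice mirrors
                # the 4096-char cap, since inline answers cannot be chunked.
                print(f'edit #{edits}: {len(content[:TELEGRAM_MESSAGE_LIMIT])} chars')
                edits += 1
            except Exception:
                backoff += 5  # demand a bigger delta between edits after a failure
        await asyncio.sleep(0.01)


asyncio.run(stream_to_message())

Run as-is, the script prints a handful of "edit" lines for a 300-word stream rather than one per chunk, which is exactly the flood-limit behaviour the cutoff-plus-backoff design is after.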