Skip to content

Commit

Permalink
Merge pull request #234 from n3d1117/feature/topic-support
Browse files Browse the repository at this point in the history
Support for telegram topics
  • Loading branch information
n3d1117 committed Apr 18, 2023
2 parents 3cc6948 + 10f47d6 commit c837eb5
Showing 1 changed file with 56 additions and 36 deletions.
92 changes: 56 additions & 36 deletions bot/telegram_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,10 @@ async def resend(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
if chat_id not in self.last_message:
logging.warning(f'User {update.message.from_user.name} (id: {update.message.from_user.id})'
f' does not have anything to resend')
await context.bot.send_message(chat_id=chat_id,
text=localized_text('resend_failed', self.config['bot_language']))
await update.effective_message.reply_text(
message_thread_id=self.get_thread_id(update),
text=localized_text('resend_failed', self.config['bot_language'])
)
return

# Update message text, clear self.last_message and send the request to prompt
Expand All @@ -199,7 +201,10 @@ async def reset(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
chat_id = update.effective_chat.id
reset_content = message_text(update.message)
self.openai.reset_chat_history(chat_id=chat_id, content=reset_content)
await context.bot.send_message(chat_id=chat_id, text=localized_text('reset_done', self.config['bot_language']))
await update.effective_message.reply_text(
message_thread_id=self.get_thread_id(update),
text=localized_text('reset_done', self.config['bot_language'])
)

async def image(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
"""
Expand All @@ -212,8 +217,10 @@ async def image(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
chat_id = update.effective_chat.id
image_query = message_text(update.message)
if image_query == '':
await context.bot.send_message(chat_id=chat_id,
text=localized_text('image_no_prompt', self.config['bot_language']))
await update.effective_message.reply_text(
message_thread_id=self.get_thread_id(update),
text=localized_text('image_no_prompt', self.config['bot_language'])
)
return

logging.info(f'New image generation request received from user {update.message.from_user.name} '
Expand All @@ -222,8 +229,7 @@ async def image(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
async def _generate():
try:
image_url, image_size = await self.openai.generate_image(prompt=image_query)
await context.bot.send_photo(
chat_id=chat_id,
await update.effective_message.reply_photo(
reply_to_message_id=self.get_reply_to_message_id(update),
photo=image_url
)
Expand All @@ -236,8 +242,8 @@ async def _generate():

except Exception as e:
logging.exception(e)
await context.bot.send_message(
chat_id=chat_id,
await update.effective_message.reply_text(
message_thread_id=self.get_thread_id(update),
reply_to_message_id=self.get_reply_to_message_id(update),
text=f"{localized_text('image_fail', self.config['bot_language'])}: {str(e)}",
parse_mode=constants.ParseMode.MARKDOWN
Expand Down Expand Up @@ -267,8 +273,8 @@ async def _execute():
await media_file.download_to_drive(filename)
except Exception as e:
logging.exception(e)
await context.bot.send_message(
chat_id=chat_id,
await update.effective_message.reply_text(
message_thread_id=self.get_thread_id(update),
reply_to_message_id=self.get_reply_to_message_id(update),
text=(
f"{localized_text('media_download_fail', bot_language)[0]}: "
Expand All @@ -287,8 +293,8 @@ async def _execute():

except Exception as e:
logging.exception(e)
await context.bot.send_message(
chat_id=update.effective_chat.id,
await update.effective_message.reply_text(
message_thread_id=self.get_thread_id(update),
reply_to_message_id=self.get_reply_to_message_id(update),
text=localized_text('media_type_fail', bot_language)
)
Expand Down Expand Up @@ -322,8 +328,8 @@ async def _execute():
chunks = self.split_into_chunks(transcript_output)

for index, transcript_chunk in enumerate(chunks):
await context.bot.send_message(
chat_id=chat_id,
await update.effective_message.reply_text(
message_thread_id=self.get_thread_id(update),
reply_to_message_id=self.get_reply_to_message_id(update) if index == 0 else None,
text=transcript_chunk,
parse_mode=constants.ParseMode.MARKDOWN
Expand All @@ -346,17 +352,17 @@ async def _execute():
chunks = self.split_into_chunks(transcript_output)

for index, transcript_chunk in enumerate(chunks):
await context.bot.send_message(
chat_id=chat_id,
await update.effective_message.reply_text(
message_thread_id=self.get_thread_id(update),
reply_to_message_id=self.get_reply_to_message_id(update) if index == 0 else None,
text=transcript_chunk,
parse_mode=constants.ParseMode.MARKDOWN
)

except Exception as e:
logging.exception(e)
await context.bot.send_message(
chat_id=chat_id,
await update.effective_message.reply_text(
message_thread_id=self.get_thread_id(update),
reply_to_message_id=self.get_reply_to_message_id(update),
text=f"{localized_text('transcribe_fail', bot_language)}: {str(e)}",
parse_mode=constants.ParseMode.MARKDOWN
Expand Down Expand Up @@ -402,7 +408,10 @@ async def prompt(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
total_tokens = 0

if self.config['stream']:
await context.bot.send_chat_action(chat_id=chat_id, action=constants.ChatAction.TYPING)
await update.effective_message.reply_chat_action(
action=constants.ChatAction.TYPING,
message_thread_id=self.get_thread_id(update)
)

stream_response = self.openai.get_chat_response_stream(chat_id=chat_id, query=prompt)
i = 0
Expand All @@ -426,8 +435,8 @@ async def prompt(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
except:
pass
try:
sent_message = await context.bot.send_message(
chat_id=sent_message.chat_id,
sent_message = await update.effective_message.reply_text(
message_thread_id=self.get_thread_id(update),
text=content if len(content) > 0 else "..."
)
except:
Expand All @@ -442,8 +451,8 @@ async def prompt(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
if sent_message is not None:
await context.bot.delete_message(chat_id=sent_message.chat_id,
message_id=sent_message.message_id)
sent_message = await context.bot.send_message(
chat_id=chat_id,
sent_message = await update.effective_message.reply_text(
message_thread_id=self.get_thread_id(update),
reply_to_message_id=self.get_reply_to_message_id(update),
text=content
)
Expand Down Expand Up @@ -488,16 +497,16 @@ async def _reply():

for index, chunk in enumerate(chunks):
try:
await context.bot.send_message(
chat_id=chat_id,
await update.effective_message.reply_text(
message_thread_id=self.get_thread_id(update),
reply_to_message_id=self.get_reply_to_message_id(update) if index == 0 else None,
text=chunk,
parse_mode=constants.ParseMode.MARKDOWN
)
except Exception:
try:
await context.bot.send_message(
chat_id=chat_id,
await update.effective_message.reply_text(
message_thread_id=self.get_thread_id(update),
reply_to_message_id=self.get_reply_to_message_id(update) if index == 0 else None,
text=chunk
)
Expand All @@ -510,8 +519,8 @@ async def _reply():

except Exception as e:
logging.exception(e)
await context.bot.send_message(
chat_id=chat_id,
await update.effective_message.reply_text(
message_thread_id=self.get_thread_id(update),
reply_to_message_id=self.get_reply_to_message_id(update),
text=f"{localized_text('chat_fail', self.config['bot_language'])} {str(e)}",
parse_mode=constants.ParseMode.MARKDOWN
Expand Down Expand Up @@ -701,7 +710,8 @@ async def edit_message_with_retry(self, context: ContextTypes.DEFAULT_TYPE, chat
try:
await context.bot.edit_message_text(
chat_id=chat_id,
message_id=message_id,
message_id=int(message_id) if not is_inline else None,
inline_message_id=message_id if is_inline else None,
text=text
)
except Exception as e:
Expand All @@ -720,7 +730,9 @@ async def wrap_with_indicator(self, update: Update, context: CallbackContext, co
task = context.application.create_task(coroutine(), update=update)
while not task.done():
if not is_inline:
context.application.create_task(update.effective_chat.send_action(chat_action))
context.application.create_task(
update.effective_chat.send_action(chat_action, message_thread_id=self.get_thread_id(update))
)
try:
await asyncio.wait_for(asyncio.shield(task), 4.5)
except asyncio.TimeoutError:
Expand All @@ -731,8 +743,8 @@ async def send_disallowed_message(self, update: Update, context: ContextTypes.DE
Sends the disallowed message to the user.
"""
if not is_inline:
await context.bot.send_message(
chat_id=update.effective_chat.id,
await update.effective_message.reply_text(
message_thread_id=self.get_thread_id(update),
text=self.disallowed_message,
disable_web_page_preview=True
)
Expand All @@ -745,8 +757,8 @@ async def send_budget_reached_message(self, update: Update, context: ContextType
Sends the budget reached message to the user.
"""
if not is_inline:
await context.bot.send_message(
chat_id=update.effective_chat.id,
await update.effective_message.reply_text(
message_thread_id=self.get_thread_id(update),
text=self.budget_limit_message
)
else:
Expand All @@ -759,6 +771,14 @@ async def error_handler(self, update: object, context: ContextTypes.DEFAULT_TYPE
"""
logging.error(f'Exception while handling an update: {context.error}')

def get_thread_id(self, update: Update) -> int | None:
"""
Gets the message thread id for the update, if any
"""
if update.effective_message and update.effective_message.is_topic_message:
return update.effective_message.message_thread_id
return None

def get_stream_cutoff_values(self, update: Update, content: str) -> int:
"""
Gets the stream cutoff values for the message length
Expand Down

0 comments on commit c837eb5

Please sign in to comment.