Skip to content

Commit

Permalink
replace <@___user_id___> with the user's display name in history
Browse files Browse the repository at this point in the history
When presenting the message history to the bot, replace the
<@00000000> syntax with the user's display name instead.  This is
to help the bot understand the message history, as well as to have
it not "leak" the encoding into its response.

Also, fix a regression in the last commit where we would include
extra lins in the message history.  Ooops!
  • Loading branch information
chrisrude committed May 20, 2023
1 parent 6e2ebcf commit ac226cc
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 0 deletions.
17 changes: 17 additions & 0 deletions src/oobabot/discord_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,20 @@ async def _filter_history_message(
# this is a message generated by our image generator
return (None, True)

if message.channel.guild is None:
fn_user_id_to_name = discord_utils.dm_user_id_to_name(
self.ai_user_id,
self.persona.ai_name,
)
else:
fn_user_id_to_name = discord_utils.guild_user_id_to_name(
message.channel.guild,
)

discord_utils.replace_mention_ids_with_names(
generic_message,
fn_user_id_to_name=fn_user_id_to_name,
)
return (generic_message, True)

async def _filtered_history_iterator(
Expand All @@ -617,6 +631,9 @@ async def _filtered_history_iterator(
last_returned = None
ignoring_all = ignore_all_until_message_id is not None
async for item in async_iter_history:
if items >= limit:
return

if ignoring_all:
if item.id == ignore_all_until_message_id:
ignoring_all = False
Expand Down
60 changes: 60 additions & 0 deletions src/oobabot/discord_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,63 @@ def discord_message_to_generic_message(
+ f"unsolicited replies disabled.: {raw_message.channel}"
)
return types.GenericMessage(**generic_args)


def replace_mention_ids_with_names(
generic_message: types.GenericMessage,
fn_user_id_to_name: typing.Callable[[re.Match[str]], str],
):
"""
Replace user ID mentions with the user's chosen display
name in the given guild (aka server)
"""
# it looks like normal IDs are 18 digits. But give it some
# wiggle room in case things change in the future.
# e.g.: <@009999999999999999>
at_mention_pattern = r"<@(\d{16,20})>"
while True:
match = re.search(at_mention_pattern, generic_message.body_text)
if not match:
break
generic_message.body_text = (
generic_message.body_text[: match.start()]
+ fn_user_id_to_name(match)
+ generic_message.body_text[match.end() :]
)


def dm_user_id_to_name(
bot_user_id: int,
bot_name: str,
) -> typing.Callable[[re.Match[str]], str]:
"""
Replace user ID mentions with the bot's name. Used when
we are in a DM with the bot.
"""
if " " in bot_name:
bot_name = f'"{bot_name}"'

def _replace_user_id_mention(match: typing.Match[str]) -> str:
user_id = int(match.group(1))
print(f"bot_user_id={bot_user_id}, user_id={user_id}")
if user_id == bot_user_id:
return f"@{bot_name}"
return match.group(0)

return _replace_user_id_mention


def guild_user_id_to_name(
guild: discord.Guild,
) -> typing.Callable[[re.Match[str]], str]:
def _replace_user_id_mention(match: typing.Match[str]) -> str:
user_id = int(match.group(1))
member = guild.get_member(user_id)
if member is None:
return match.group(0)
display_name = member.display_name
if " " in display_name:
display_name = f'"{display_name}"'
return f"@{display_name}"

return _replace_user_id_mention

0 comments on commit ac226cc

Please sign in to comment.