Skip to content

Commit

Permalink
cool feature preview: streaming responses!
Browse files Browse the repository at this point in the history
It's a little jerky, and a lot hacky, but it kinda works!
Pass in --stream-responses to have the bot's responses streamed
into a single message.

This works by continuously editing the bot's response message.
So you'll see an "edited" tag on the message.  But if you can
ignore that, that's cool.

Also, Discord seems to have some limits on how fast edits can be
made, so it waits at least 0.2 seconds in between update attempts.
It may wait longer than this, but it shouldn't be any faster.

I don't consider this quite baked yet, but I wanted to put it
out there and see what people thought.
  • Loading branch information
chrisrude committed May 16, 2023
1 parent 723741c commit 823b52a
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 26 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,5 @@ j = 8

[tool.pylint.'MESSAGES CONTROL']
max-line-length = 88
disable = "C0103,C0114,C0115,C0116,R0902,R0903,R0913,W0201,W0511,W0621"
disable = "C0103,C0114,C0115,C0116,R0902,R0903,R0913,R0914,W0201,W0511,W0621"
include-naming-hint = true
110 changes: 86 additions & 24 deletions src/oobabot/discord_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(
persona: str,
ignore_dms: bool,
dont_split_responses: bool,
stream_responses: bool,
reply_in_thread: bool,
log_all_the_things: bool,
):
Expand All @@ -52,6 +53,7 @@ def __init__(

self.ignore_dms = ignore_dms
self.dont_split_responses = dont_split_responses
self.stream_responses = stream_responses
self.reply_in_thread = reply_in_thread
self.log_all_the_things = log_all_the_things

Expand Down Expand Up @@ -88,6 +90,8 @@ async def on_ready(self) -> None:
else:
fancy_logger.get().debug("listening to DMs")

if self.stream_responses:
fancy_logger.get().debug("Responses: streamed live")
if self.dont_split_responses:
fancy_logger.get().debug("Responses: returned as single messages")
else:
Expand Down Expand Up @@ -245,9 +249,9 @@ async def send_response(

response_coro = self.send_response_in_channel(
message=message,
raw_message=raw_message,
image_requested=image_requested,
response_channel=response_channel,
response_channel_id=response_channel.id,
)
response_task = asyncio.create_task(response_coro)
return (response_task, response_channel)
Expand Down Expand Up @@ -295,9 +299,9 @@ async def recent_messages_following_thread(
async def send_response_in_channel(
self,
message: types.GenericMessage,
raw_message: discord.Message,
image_requested: bool,
response_channel: discord.abc.Messageable,
response_channel_id: int,
) -> None:
fancy_logger.get().debug(
"Request from %s in %s", message.author_name, message.channel_name
Expand All @@ -306,7 +310,7 @@ async def send_response_in_channel(
recent_messages = await self.recent_messages_following_thread(response_channel)

repeated_id = self.repetition_tracker.get_throttle_message_id(
raw_message.channel.id
response_channel_id
)

prompt_prefix = await self.prompt_generator.generate(
Expand All @@ -323,30 +327,23 @@ async def send_response_in_channel(
print("Response:\n----------\n")

try:
if self.dont_split_responses:
generator = self.ooba_client.request_as_string(prompt_prefix)
else:
generator = self.ooba_client.request_by_sentence(prompt_prefix)

async for sentence in generator:
if self.log_all_the_things:
print(sentence)

sentence = self.filter_immersion_breaking_lines(sentence)
if not sentence:
# we can't send an empty message
continue

response_message = await response_channel.send(sentence)
generic_response_message = (
discord_utils.discord_message_to_generic_message(response_message)
if self.stream_responses:
generator = self.ooba_client.request_as_grouped_tokens(prompt_prefix)
await self.render_streaming_response(
generator,
this_response_stat,
response_channel,
response_channel_id,
)
self.repetition_tracker.log_message(
raw_message.channel.id, generic_response_message
else:
if self.dont_split_responses:
generator = self.ooba_client.request_as_string(prompt_prefix)
else:
generator = self.ooba_client.request_by_sentence(prompt_prefix)
await self.render_response(
generator, this_response_stat, response_channel, response_channel_id
)

this_response_stat.log_response_part()

except discord.DiscordException as err:
fancy_logger.get().error("Error: %s", err, exc_info=True)
self.response_stats.log_response_failure()
Expand All @@ -355,6 +352,71 @@ async def send_response_in_channel(
this_response_stat.write_to_log(f"Response to {message.author_name} done! ")
self.response_stats.log_response_success(this_response_stat)

async def render_response(
    self,
    response_iterator: typing.AsyncIterator[str],
    this_response_stat: response_stats.ResponseStats,
    response_channel: discord.abc.Messageable,
    response_channel_id: int,
):
    """
    Post each sentence yielded by the iterator to the channel as its
    own message, recording every sent message with the repetition
    tracker and with the per-response statistics.
    """
    async for raw_sentence in response_iterator:
        if self.log_all_the_things:
            print(raw_sentence)

        # A sentence may be reduced to nothing by the immersion filter,
        # and Discord rejects empty messages, so skip those.
        sentence = self.filter_immersion_breaking_lines(raw_sentence)
        if not sentence:
            continue

        sent_message = await response_channel.send(sentence)
        self.repetition_tracker.log_message(
            response_channel_id,
            discord_utils.discord_message_to_generic_message(sent_message),
        )

        this_response_stat.log_response_part()

async def render_streaming_response(
    self,
    response_iterator: typing.AsyncIterator[str],
    this_response_stat: response_stats.ResponseStats,
    response_channel: discord.abc.Messageable,
    response_channel_id: int,
):
    """
    Stream the response into a single Discord message, editing that
    message in place as each batch of tokens arrives.

    Raises discord.DiscordException if the iterator never produced
    any postable text.
    """
    accumulated = ""
    posted_message = None
    async for token in response_iterator:
        if self.log_all_the_things:
            print(token, end="")

        if not token:
            continue

        # Re-filter the whole accumulated text on every update, since a
        # newly-arrived token may complete an immersion-breaking line.
        accumulated = self.filter_immersion_breaking_lines(accumulated + token)
        if not accumulated:
            # nothing postable yet -- Discord rejects empty messages
            continue

        if posted_message is None:
            posted_message = await response_channel.send(accumulated)
        else:
            await posted_message.edit(content=accumulated)

        this_response_stat.log_response_part()

    if posted_message is None:
        raise discord.DiscordException("No response was generated")

    self.repetition_tracker.log_message(
        response_channel_id,
        discord_utils.discord_message_to_generic_message(posted_message),
    )

def filter_immersion_breaking_lines(self, sentence: str) -> str:
lines = sentence.split("\n")
good_lines = []
Expand Down
29 changes: 28 additions & 1 deletion src/oobabot/ooba_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

# Purpose: Split a string into sentences, based on a set of terminators.
# This is a helper class for ooba_client.py.
import time
import typing

import aiohttp
Expand Down Expand Up @@ -119,6 +120,30 @@ async def request_as_string(self, prompt: str) -> typing.AsyncIterator[str]:
"""
yield "".join([token async for token in self.request_by_token(prompt)])

async def request_as_grouped_tokens(
    self,
    prompt: str,
    interval: float = 0.2,
) -> typing.AsyncIterator[str]:
    """
    Yields the response as a series of token groups, rate-limited so
    that at most one group is emitted per `interval` seconds.

    prompt: the prompt to send to the text generator
    interval: minimum number of seconds between yielded groups
    """

    last_response = time.perf_counter()
    tokens = ""
    async for token in self.request_by_token(prompt):
        if token == SentenceSplitter.END_OF_INPUT:
            if tokens:
                yield tokens
            break
        tokens += token
        now = time.perf_counter()
        # rate-limit: hold tokens until `interval` has elapsed since
        # the last group was yielded
        if now < (last_response + interval):
            continue
        yield tokens
        tokens = ""
        last_response = time.perf_counter()
    else:
        # Bug fix: if the token stream is exhausted without an
        # END_OF_INPUT sentinel (request_by_token filters the sentinel
        # out of text_stream events), any tokens accumulated since the
        # last interval flush would otherwise be silently dropped.
        if tokens:
            yield tokens

async def request_by_token(self, prompt: str) -> typing.AsyncIterator[str]:
"""
Yields each token of the response as it arrives.
Expand Down Expand Up @@ -154,7 +179,9 @@ async def request_by_token(self, prompt: str) -> typing.AsyncIterator[str]:
incoming_data = msg.json()
if "text_stream" == incoming_data["event"]:
self.total_response_tokens += 1
yield incoming_data["text"]
text = incoming_data["text"]
if text != SentenceSplitter.END_OF_INPUT:
yield text

elif "stream_end" == incoming_data["event"]:
# Make sure any unprinted text is flushed.
Expand Down
1 change: 1 addition & 0 deletions src/oobabot/oobabot.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def run(self):
persona=self.settings.persona,
ignore_dms=self.settings.get_bool("ignore_dms"),
dont_split_responses=self.settings.get_bool("dont_split_responses"),
stream_responses=self.settings.get_bool("stream_responses"),
reply_in_thread=self.settings.get_bool("reply_in_thread"),
log_all_the_things=self.settings.get_bool("log_all_the_things"),
)
Expand Down
6 changes: 6 additions & 0 deletions src/oobabot/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,12 @@ def __init__(self):
help="If not in a thread, generate one to respond into.",
action="store_true",
)
discord_group.add_argument(
"--stream-responses",
default=False,
help="Stream responses into a single message as it is generated.",
action="store_true",
)

###########################################################
# Oobabooga Settings
Expand Down

0 comments on commit 823b52a

Please sign in to comment.