Skip to content

Commit

Permalink
FEATURE: Add support for contextualizing a DM to a bot (#627)
Browse files Browse the repository at this point in the history
This brings the context of the current topic on screen into chat
  • Loading branch information
SamSaffron committed May 21, 2024
1 parent 232f12e commit d4116ec
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 12 deletions.
6 changes: 5 additions & 1 deletion app/jobs/regular/create_ai_chat_reply.rb
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@ def execute(args)
user = User.find_by(id: personaClass.user_id)
bot = DiscourseAi::AiBot::Bot.as(user, persona: personaClass.new)

DiscourseAi::AiBot::Playground.new(bot).reply_to_chat_message(message, channel)
DiscourseAi::AiBot::Playground.new(bot).reply_to_chat_message(
message,
channel,
args[:context_post_ids],
)
end
end
end
22 changes: 18 additions & 4 deletions lib/ai_bot/playground.rb
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,15 @@ def self.schedule_chat_reply(message, channel, user, context)
persona = find_chat_persona(message, channel, user)
return if !persona

post_ids = nil
post_ids = context.dig(:context, :post_ids) if context.is_a?(Hash)

::Jobs.enqueue(
:create_ai_chat_reply,
channel_id: channel.id,
message_id: message.id,
persona_id: persona[:id],
context_post_ids: post_ids,
)
end

Expand Down Expand Up @@ -218,7 +222,7 @@ def title_playground(post)
end
end

def chat_context(message, channel, persona_user)
def chat_context(message, channel, persona_user, context_post_ids)
has_vision = bot.persona.class.vision_enabled
include_thread_titles = !channel.direct_message_channel? && !message.thread_id

Expand Down Expand Up @@ -260,6 +264,11 @@ def chat_context(message, channel, persona_user)

builder = DiscourseAi::Completions::PromptMessagesBuilder.new

guardian = Guardian.new(message.user)
if context_post_ids
builder.set_chat_context_posts(context_post_ids, guardian, include_uploads: has_vision)
end

messages.each do |m|
# restore stripped message
m.message = instruction_message if m.id == current_id && instruction_message
Expand All @@ -284,18 +293,23 @@ def chat_context(message, channel, persona_user)
end
end

builder.to_a(limit: max_messages, style: channel.direct_message_channel? ? :default : :chat)
builder.to_a(
limit: max_messages,
style: channel.direct_message_channel? ? :chat_with_context : :chat,
)
end

def reply_to_chat_message(message, channel)
def reply_to_chat_message(message, channel, context_post_ids)
persona_user = User.find(bot.persona.class.user_id)

participants = channel.user_chat_channel_memberships.map { |m| m.user.username }

context_post_ids = nil if !channel.direct_message_channel?

context =
get_context(
participants: participants.join(", "),
conversation_context: chat_context(message, channel, persona_user),
conversation_context: chat_context(message, channel, persona_user, context_post_ids),
user: message.user,
skip_tool_details: true,
)
Expand Down
52 changes: 46 additions & 6 deletions lib/completions/prompt_messages_builder.rb
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,41 @@ module DiscourseAi
module Completions
class PromptMessagesBuilder
MAX_CHAT_UPLOADS = 5
attr_reader :chat_context_posts
attr_reader :chat_context_post_upload_ids

def initialize
@raw_messages = []
end

def set_chat_context_posts(post_ids, guardian, include_uploads:)
posts = []
Post
.where(id: post_ids)
.order("id asc")
.each do |post|
next if !guardian.can_see?(post)
posts << post
end
if posts.present?
posts_context =
+"\nThis chat is in the context of the Discourse topic '#{posts[0].topic.title}':\n\n"
posts_context = +"{{{\n"
posts.each do |post|
posts_context << "url: #{post.url}\n"
posts_context << "#{post.username}: #{post.raw}\n\n"
end
posts_context << "}}}"
@chat_context_posts = posts_context
if include_uploads
uploads = []
posts.each { |post| uploads.concat(post.uploads.pluck(:id)) }
uploads.uniq!
@chat_context_post_upload_ids = uploads.take(MAX_CHAT_UPLOADS)
end
end
end

def to_a(limit: nil, style: nil)
return chat_array(limit: limit) if style == :chat
result = []
Expand Down Expand Up @@ -51,6 +81,20 @@ def to_a(limit: nil, style: nil)
last_type = message[:type]
end

if style == :chat_with_context && @chat_context_posts
buffer = +"You are replying inside a Discourse chat."
buffer << "\n"
buffer << @chat_context_posts
buffer << "\n"
buffer << "Your instructions are:\n"
result[0][:content] = "#{buffer}#{result[0][:content]}"
if @chat_context_post_upload_ids.present?
result[0][:upload_ids] = (result[0][:upload_ids] || []).concat(
@chat_context_post_upload_ids,
)
end
end

if limit
result[0..limit]
else
Expand All @@ -75,13 +119,9 @@ def push(type:, content:, name: nil, upload_ids: nil, id: nil)
private

def chat_array(limit:)
buffer = +""

if @raw_messages.length > 1
buffer << (<<~TEXT).strip
You are replying inside a Discourse chat. Here is a summary of the conversation so far:
{{{
TEXT
buffer =
+"You are replying inside a Discourse chat channel. Here is a summary of the conversation so far:\n{{{"

upload_ids = []

Expand Down
26 changes: 25 additions & 1 deletion spec/lib/modules/ai_bot/playground_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@
content = prompt.messages[1][:content]
# this is fragile by design, mainly so the example can be ultra clear
expected = (<<~TEXT).strip
You are replying inside a Discourse chat. Here is a summary of the conversation so far:
You are replying inside a Discourse chat channel. Here is a summary of the conversation so far:
{{{
#{user.username}: (a magic thread)
thread 1 message 1
Expand Down Expand Up @@ -265,6 +265,30 @@

let(:guardian) { Guardian.new(user) }

it "can supply context" do
post = Fabricate(:post, raw: "this is post content")

prompts = nil
message =
DiscourseAi::Completions::Llm.with_prepared_responses(["World"]) do |_, _, _prompts|
prompts = _prompts

::Chat::CreateMessage.call!(
chat_channel_id: dm_channel.id,
message: "Hello",
guardian: guardian,
context_post_ids: [post.id],
).message_instance
end

expect(prompts[0].messages[1][:content]).to include("this is post content")

message.reload
reply = ChatSDK::Thread.messages(thread_id: message.thread_id, guardian: guardian).last
expect(reply.message).to eq("World")
expect(message.thread_id).to be_present
end

it "can run tools" do
persona.update!(commands: ["TimeCommand"])

Expand Down

0 comments on commit d4116ec

Please sign in to comment.