diff --git a/app/jobs/regular/create_ai_chat_reply.rb b/app/jobs/regular/create_ai_chat_reply.rb index 48258c72..41776a3b 100644 --- a/app/jobs/regular/create_ai_chat_reply.rb +++ b/app/jobs/regular/create_ai_chat_reply.rb @@ -18,7 +18,11 @@ def execute(args) user = User.find_by(id: personaClass.user_id) bot = DiscourseAi::AiBot::Bot.as(user, persona: personaClass.new) - DiscourseAi::AiBot::Playground.new(bot).reply_to_chat_message(message, channel) + DiscourseAi::AiBot::Playground.new(bot).reply_to_chat_message( + message, + channel, + args[:context_post_ids], + ) end end end diff --git a/lib/ai_bot/playground.rb b/lib/ai_bot/playground.rb index d6418bd5..5d752747 100644 --- a/lib/ai_bot/playground.rb +++ b/lib/ai_bot/playground.rb @@ -37,11 +37,15 @@ def self.schedule_chat_reply(message, channel, user, context) persona = find_chat_persona(message, channel, user) return if !persona + post_ids = nil + post_ids = context.dig(:context, :post_ids) if context.is_a?(Hash) + ::Jobs.enqueue( :create_ai_chat_reply, channel_id: channel.id, message_id: message.id, persona_id: persona[:id], + context_post_ids: post_ids, ) end @@ -218,7 +222,7 @@ def title_playground(post) end end - def chat_context(message, channel, persona_user) + def chat_context(message, channel, persona_user, context_post_ids) has_vision = bot.persona.class.vision_enabled include_thread_titles = !channel.direct_message_channel? && !message.thread_id @@ -260,6 +264,11 @@ def chat_context(message, channel, persona_user) builder = DiscourseAi::Completions::PromptMessagesBuilder.new + guardian = Guardian.new(message.user) + if context_post_ids + builder.set_chat_context_posts(context_post_ids, guardian, include_uploads: has_vision) + end + messages.each do |m| # restore stripped message m.message = instruction_message if m.id == current_id && instruction_message @@ -284,18 +293,23 @@ def chat_context(message, channel, persona_user) end end - builder.to_a(limit: max_messages, style: channel.direct_message_channel? ? :default : :chat) + builder.to_a( + limit: max_messages, + style: channel.direct_message_channel? ? :chat_with_context : :chat, + ) end - def reply_to_chat_message(message, channel) + def reply_to_chat_message(message, channel, context_post_ids) persona_user = User.find(bot.persona.class.user_id) participants = channel.user_chat_channel_memberships.map { |m| m.user.username } + context_post_ids = nil if !channel.direct_message_channel? + context = get_context( participants: participants.join(", "), - conversation_context: chat_context(message, channel, persona_user), + conversation_context: chat_context(message, channel, persona_user, context_post_ids), user: message.user, skip_tool_details: true, ) diff --git a/lib/completions/prompt_messages_builder.rb b/lib/completions/prompt_messages_builder.rb index faed03d1..8fbd70f8 100644 --- a/lib/completions/prompt_messages_builder.rb +++ b/lib/completions/prompt_messages_builder.rb @@ -4,11 +4,41 @@ module DiscourseAi module Completions class PromptMessagesBuilder MAX_CHAT_UPLOADS = 5 + attr_reader :chat_context_posts + attr_reader :chat_context_post_upload_ids def initialize @raw_messages = [] end + def set_chat_context_posts(post_ids, guardian, include_uploads:) + posts = [] + Post + .where(id: post_ids) + .order("id asc") + .each do |post| + next if !guardian.can_see?(post) + posts << post + end + if posts.present? + posts_context = + +"\nThis chat is in the context of the Discourse topic '#{posts[0].topic.title}':\n\n" + posts_context = +"{{{\n" + posts.each do |post| + posts_context << "url: #{post.url}\n" + posts_context << "#{post.username}: #{post.raw}\n\n" + end + posts_context << "}}}" + @chat_context_posts = posts_context + if include_uploads + uploads = [] + posts.each { |post| uploads.concat(post.uploads.pluck(:id)) } + uploads.uniq! + @chat_context_post_upload_ids = uploads.take(MAX_CHAT_UPLOADS) + end + end + end + def to_a(limit: nil, style: nil) return chat_array(limit: limit) if style == :chat result = [] @@ -51,6 +81,20 @@ def to_a(limit: nil, style: nil) last_type = message[:type] end + if style == :chat_with_context && @chat_context_posts + buffer = +"You are replying inside a Discourse chat." + buffer << "\n" + buffer << @chat_context_posts + buffer << "\n" + buffer << "Your instructions are:\n" + result[0][:content] = "#{buffer}#{result[0][:content]}" + if @chat_context_post_upload_ids.present? + result[0][:upload_ids] = (result[0][:upload_ids] || []).concat( + @chat_context_post_upload_ids, + ) + end + end + if limit result[0..limit] else @@ -75,13 +119,9 @@ def push(type:, content:, name: nil, upload_ids: nil, id: nil) private def chat_array(limit:) - buffer = +"" - if @raw_messages.length > 1 - buffer << (<<~TEXT).strip - You are replying inside a Discourse chat. Here is a summary of the conversation so far: - {{{ - TEXT + buffer = + +"You are replying inside a Discourse chat channel. Here is a summary of the conversation so far:\n{{{" upload_ids = [] diff --git a/spec/lib/modules/ai_bot/playground_spec.rb b/spec/lib/modules/ai_bot/playground_spec.rb index ea5d1a32..6fbe0f26 100644 --- a/spec/lib/modules/ai_bot/playground_spec.rb +++ b/spec/lib/modules/ai_bot/playground_spec.rb @@ -204,7 +204,7 @@ content = prompt.messages[1][:content] # this is fragile by design, mainly so the example can be ultra clear expected = (<<~TEXT).strip - You are replying inside a Discourse chat. Here is a summary of the conversation so far: + You are replying inside a Discourse chat channel. Here is a summary of the conversation so far: {{{ #{user.username}: (a magic thread) thread 1 message 1 @@ -265,6 +265,30 @@ let(:guardian) { Guardian.new(user) } + it "can supply context" do + post = Fabricate(:post, raw: "this is post content") + + prompts = nil + message = + DiscourseAi::Completions::Llm.with_prepared_responses(["World"]) do |_, _, _prompts| + prompts = _prompts + + ::Chat::CreateMessage.call!( + chat_channel_id: dm_channel.id, + message: "Hello", + guardian: guardian, + context_post_ids: [post.id], + ).message_instance + end + + expect(prompts[0].messages[1][:content]).to include("this is post content") + + message.reload + reply = ChatSDK::Thread.messages(thread_id: message.thread_id, guardian: guardian).last + expect(reply.message).to eq("World") + expect(message.thread_id).to be_present + end + it "can run tools" do persona.update!(commands: ["TimeCommand"])