From bfea49fad9c7abc2a62ca98fd5e9b50d4304131c Mon Sep 17 00:00:00 2001 From: Keegan George Date: Wed, 12 Mar 2025 13:02:57 -0700 Subject: [PATCH 1/2] DEV: Use existing topic embeddings when suggesting... When editing a topic (instead of creating one) and using the tag/category suggestion buttons. We want to use existing topic embeddings instead of creating new ones. --- .../ai_helper/assistant_controller.rb | 24 +++++++------ .../ai-category-suggester.gjs | 32 +++++++---------- .../suggestion-menus/ai-tag-suggester.gjs | 34 ++++++++----------- .../ai-category-suggestion.gjs | 2 +- .../ai-tag-suggestion.gjs | 2 +- .../ai-category-suggestion.gjs | 5 ++- .../ai-tag-suggestion.gjs | 2 +- lib/ai_helper/semantic_categorizer.rb | 21 +++++++++--- .../ai_helper/semantic_categorizer_spec.rb | 2 +- 9 files changed, 66 insertions(+), 58 deletions(-) diff --git a/app/controllers/discourse_ai/ai_helper/assistant_controller.rb b/app/controllers/discourse_ai/ai_helper/assistant_controller.rb index 0d0f85ee6..d2ee87387 100644 --- a/app/controllers/discourse_ai/ai_helper/assistant_controller.rb +++ b/app/controllers/discourse_ai/ai_helper/assistant_controller.rb @@ -78,22 +78,26 @@ def suggest_title end def suggest_category - input = get_text_param! - input_hash = { text: input } + if params[:topic_id] + opts = { topic_id: params[:topic_id] } + else + input = get_text_param! + opts = { text: input } + end - render json: - DiscourseAi::AiHelper::SemanticCategorizer.new( - input_hash, - current_user, - ).categories, + render json: DiscourseAi::AiHelper::SemanticCategorizer.new(current_user, opts).categories, status: 200 end def suggest_tags - input = get_text_param! - input_hash = { text: input } + if params[:topic_id] + opts = { topic_id: params[:topic_id] } + else + input = get_text_param! + opts = { text: input } + end - render json: DiscourseAi::AiHelper::SemanticCategorizer.new(input_hash, current_user).tags, + render json: DiscourseAi::AiHelper::SemanticCategorizer.new(current_user, opts).tags, status: 200 end diff --git a/assets/javascripts/discourse/components/suggestion-menus/ai-category-suggester.gjs b/assets/javascripts/discourse/components/suggestion-menus/ai-category-suggester.gjs index 18051b52f..ff44af916 100644 --- a/assets/javascripts/discourse/components/suggestion-menus/ai-category-suggester.gjs +++ b/assets/javascripts/discourse/components/suggestion-menus/ai-category-suggester.gjs @@ -20,27 +20,13 @@ export default class AiCategorySuggester extends Component { @tracked untriggers = []; @tracked triggerIcon = "discourse-sparkles"; @tracked content = null; - @tracked topicContent = null; - - constructor() { - super(...arguments); - if (!this.topicContent && this.args.composer?.reply === undefined) { - this.fetchTopicContent(); - } - } - - async fetchTopicContent() { - await ajax(`/t/${this.args.buffered.content.id}.json`).then( - ({ post_stream }) => { - this.topicContent = post_stream.posts[0].cooked; - } - ); - } get showSuggestionButton() { const composerFields = document.querySelector(".composer-fields"); - this.content = this.args.composer?.reply || this.topicContent; - const showTrigger = this.content?.length > MIN_CHARACTER_COUNT; + this.content = this.args.composer?.reply; + const showTrigger = + this.content?.length > MIN_CHARACTER_COUNT || + this.args.topicState === "edit"; if (composerFields) { if (showTrigger) { @@ -62,12 +48,20 @@ export default class AiCategorySuggester extends Component { this.loading = true; this.triggerIcon = "spinner"; + const data = {}; + + if (this.content) { + data.text = this.content; + } else { + data.topic_id = this.args.buffered.content.id; + } + try { const { assistant } = await ajax( "/discourse-ai/ai-helper/suggest_category", { method: "POST", - data: { text: this.content }, + data, } ); this.suggestions = assistant; diff --git a/assets/javascripts/discourse/components/suggestion-menus/ai-tag-suggester.gjs b/assets/javascripts/discourse/components/suggestion-menus/ai-tag-suggester.gjs index 47ad38f6d..84ce19589 100644 --- a/assets/javascripts/discourse/components/suggestion-menus/ai-tag-suggester.gjs +++ b/assets/javascripts/discourse/components/suggestion-menus/ai-tag-suggester.gjs @@ -21,27 +21,13 @@ export default class AiTagSuggester extends Component { @tracked untriggers = []; @tracked triggerIcon = "discourse-sparkles"; @tracked content = null; - @tracked topicContent = null; - - constructor() { - super(...arguments); - if (!this.topicContent && this.args.composer?.reply === undefined) { - this.fetchTopicContent(); - } - } - - async fetchTopicContent() { - await ajax(`/t/${this.args.buffered.content.id}.json`).then( - ({ post_stream }) => { - this.topicContent = post_stream.posts[0].cooked; - } - ); - } get showSuggestionButton() { const composerFields = document.querySelector(".composer-fields"); - this.content = this.args.composer?.reply || this.topicContent; - const showTrigger = this.content?.length > MIN_CHARACTER_COUNT; + this.content = this.args.composer?.reply; + const showTrigger = + this.content?.length > MIN_CHARACTER_COUNT || + this.args.topicState === "edit"; if (composerFields) { if (showTrigger) { @@ -74,15 +60,25 @@ export default class AiTagSuggester extends Component { this.loading = true; this.triggerIcon = "spinner"; + const data = {}; + + if (this.content) { + data.text = this.content; + } else { + data.topic_id = this.args.buffered.content.id; + } + try { const { assistant } = await ajax("/discourse-ai/ai-helper/suggest_tags", { method: "POST", - data: { text: this.content }, + data, }); this.suggestions = assistant; + const model = this.args.composer ? this.args.composer : this.args.buffered; + if (this.#tagSelectorHasValues()) { this.suggestions = this.suggestions.filter( (s) => !model.get("tags").includes(s.name) diff --git a/assets/javascripts/discourse/connectors/after-composer-category-input/ai-category-suggestion.gjs b/assets/javascripts/discourse/connectors/after-composer-category-input/ai-category-suggestion.gjs index d7bef6428..47cc9d88b 100644 --- a/assets/javascripts/discourse/connectors/after-composer-category-input/ai-category-suggestion.gjs +++ b/assets/javascripts/discourse/connectors/after-composer-category-input/ai-category-suggestion.gjs @@ -13,6 +13,6 @@ export default class AiCategorySuggestion extends Component { } } diff --git a/assets/javascripts/discourse/connectors/after-composer-tag-input/ai-tag-suggestion.gjs b/assets/javascripts/discourse/connectors/after-composer-tag-input/ai-tag-suggestion.gjs index ac6ad686e..9f02ed19d 100644 --- a/assets/javascripts/discourse/connectors/after-composer-tag-input/ai-tag-suggestion.gjs +++ b/assets/javascripts/discourse/connectors/after-composer-tag-input/ai-tag-suggestion.gjs @@ -13,6 +13,6 @@ export default class AiTagSuggestion extends Component { } } diff --git a/assets/javascripts/discourse/connectors/edit-topic-category__after/ai-category-suggestion.gjs b/assets/javascripts/discourse/connectors/edit-topic-category__after/ai-category-suggestion.gjs index c1f5c0c01..1dcf34830 100644 --- a/assets/javascripts/discourse/connectors/edit-topic-category__after/ai-category-suggestion.gjs +++ b/assets/javascripts/discourse/connectors/edit-topic-category__after/ai-category-suggestion.gjs @@ -13,6 +13,9 @@ export default class AiCategorySuggestion extends Component { } } diff --git a/assets/javascripts/discourse/connectors/edit-topic-tags__after/ai-tag-suggestion.gjs b/assets/javascripts/discourse/connectors/edit-topic-tags__after/ai-tag-suggestion.gjs index 3ab8656f0..7404822b7 100644 --- a/assets/javascripts/discourse/connectors/edit-topic-tags__after/ai-tag-suggestion.gjs +++ b/assets/javascripts/discourse/connectors/edit-topic-tags__after/ai-tag-suggestion.gjs @@ -13,6 +13,6 @@ export default class AiCategorySuggestion extends Component { } } diff --git a/lib/ai_helper/semantic_categorizer.rb b/lib/ai_helper/semantic_categorizer.rb index b05c3ece3..f1a588577 100644 --- a/lib/ai_helper/semantic_categorizer.rb +++ b/lib/ai_helper/semantic_categorizer.rb @@ -2,15 +2,16 @@ module DiscourseAi module AiHelper class SemanticCategorizer - def initialize(input, user) + def initialize(user, opts) @user = user - @text = input[:text] + @text = opts[:text] @vector = DiscourseAi::Embeddings::Vector.instance @schema = DiscourseAi::Embeddings::Schema.for(Topic) + @topic_id = opts[:topic_id] end def categories - return [] if @text.blank? + return [] if @text.blank? && !@topic_id return [] if !DiscourseAi::Embeddings.enabled? candidates = nearest_neighbors @@ -55,7 +56,7 @@ def categories end def tags - return [] if @text.blank? + return [] if @text.blank? && !@topic_id return [] if !DiscourseAi::Embeddings.enabled? candidates = nearest_neighbors(limit: 100) @@ -100,7 +101,17 @@ def tags private def nearest_neighbors(limit: 50) - raw_vector = @vector.vector_from(@text) + if @topic_id + table_name = DiscourseAi::Embeddings::Schema::TOPICS_TABLE + embeddings = + DB + .query("SELECT embeddings::text FROM #{table_name} WHERE topic_id=#{@topic_id}") + .first + .embeddings + raw_vector = JSON.parse(embeddings) + else + raw_vector = @vector.vector_from(@text) + end muted_category_ids = nil if @user.present? diff --git a/spec/lib/modules/ai_helper/semantic_categorizer_spec.rb b/spec/lib/modules/ai_helper/semantic_categorizer_spec.rb index bbbfe6af4..4390959b5 100644 --- a/spec/lib/modules/ai_helper/semantic_categorizer_spec.rb +++ b/spec/lib/modules/ai_helper/semantic_categorizer_spec.rb @@ -16,7 +16,7 @@ fab!(:topic) { Fabricate(:topic, category: category) } let(:vector) { DiscourseAi::Embeddings::Vector.instance } - let(:categorizer) { DiscourseAi::AiHelper::SemanticCategorizer.new({ text: "hello" }, user) } + let(:categorizer) { DiscourseAi::AiHelper::SemanticCategorizer.new(user, { text: "hello" }) } let(:expected_embedding) { [0.0038493] * vector.vdef.dimensions } before do From 37d4c36971963d5291afe752d5679f7351f1b2ee Mon Sep 17 00:00:00 2001 From: Keegan George Date: Wed, 12 Mar 2025 14:52:32 -0700 Subject: [PATCH 2/2] DEV: Follow-up from review --- lib/ai_helper/semantic_categorizer.rb | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/lib/ai_helper/semantic_categorizer.rb b/lib/ai_helper/semantic_categorizer.rb index f1a588577..488741de8 100644 --- a/lib/ai_helper/semantic_categorizer.rb +++ b/lib/ai_helper/semantic_categorizer.rb @@ -11,7 +11,7 @@ def initialize(user, opts) end def categories - return [] if @text.blank? && !@topic_id + return [] if @text.blank? && @topic_id.nil? return [] if !DiscourseAi::Embeddings.enabled? candidates = nearest_neighbors @@ -56,7 +56,7 @@ def categories end def tags - return [] if @text.blank? && !@topic_id + return [] if @text.blank? && @topic_id.nil? return [] if !DiscourseAi::Embeddings.enabled? candidates = nearest_neighbors(limit: 100) @@ -102,13 +102,19 @@ def tags def nearest_neighbors(limit: 50) if @topic_id - table_name = DiscourseAi::Embeddings::Schema::TOPICS_TABLE - embeddings = - DB - .query("SELECT embeddings::text FROM #{table_name} WHERE topic_id=#{@topic_id}") - .first - .embeddings - raw_vector = JSON.parse(embeddings) + target = Topic.find_by(id: @topic_id) + embeddings = @schema.find_by_target(target)&.embeddings + + if embeddings.blank? + @text = + DiscourseAi::Summarization::Strategies::TopicSummary + .new(target) + .targets_data + .pluck(:text) + raw_vector = @vector.vector_from(@text) + else + raw_vector = JSON.parse(embeddings) + end else raw_vector = @vector.vector_from(@text) end