Skip to content
This repository was archived by the owner on Jul 22, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions lib/embeddings/vector_representations/all_mpnet_base_v2.rb
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,7 @@ def dependant_setting_names
end

def vector_from(text, asymetric: false)
DiscourseAi::Inference::DiscourseClassifier.perform!(
"#{discourse_embeddings_endpoint}/api/v1/classify",
self.class.name,
text,
SiteSetting.ai_embeddings_discourse_service_api_key,
)
inference_client.perform!(text)
end

def dimensions
Expand Down Expand Up @@ -59,6 +54,10 @@ def pg_index_type
def tokenizer
DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer
end

# Client used by #vector_from. AllMpnetBaseV2 embeds exclusively via the
# Discourse classifier service; the class name doubles as the model identifier.
def inference_client
DiscourseAi::Inference::DiscourseClassifier.instance(self.class.name)
end
end
end
end
Expand Down
12 changes: 2 additions & 10 deletions lib/embeddings/vector_representations/base.rb
Original file line number Diff line number Diff line change
Expand Up @@ -426,16 +426,8 @@ def save_to_db(target, vector, digest)
end
end

def discourse_embeddings_endpoint
if SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present?
service =
DiscourseAi::Utils::DnsSrv.lookup(
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv,
)
"https://#{service.target}:#{service.port}"
else
SiteSetting.ai_embeddings_discourse_service_api_endpoint
end
# Abstract hook: each concrete vector representation must return a configured
# inference client (an object responding to #perform!(text)).
# @raise [NotImplementedError] when a subclass fails to override this.
def inference_client
raise NotImplementedError
end
end
end
Expand Down
39 changes: 21 additions & 18 deletions lib/embeddings/vector_representations/bge_large_en.rb
Original file line number Diff line number Diff line change
Expand Up @@ -33,24 +33,12 @@ def dependant_setting_names
def vector_from(text, asymetric: false)
text = "#{asymmetric_query_prefix} #{text}" if asymetric

if SiteSetting.ai_cloudflare_workers_api_token.present?
DiscourseAi::Inference::CloudflareWorkersAi
.perform!(inference_model_name, { text: text })
.dig(:result, :data)
.first
elsif DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first
elsif discourse_embeddings_endpoint.present?
DiscourseAi::Inference::DiscourseClassifier.perform!(
"#{discourse_embeddings_endpoint}/api/v1/classify",
inference_model_name.split("/").last,
text,
SiteSetting.ai_embeddings_discourse_service_api_key,
)
else
raise "No inference endpoint configured"
end
client = inference_client

needs_truncation = client.class.name.include?("HuggingFaceTextEmbeddings")
text = tokenizer.truncate(text, max_sequence_length - 2) if needs_truncation

inference_client.perform!(text)
end

def inference_model_name
Expand Down Expand Up @@ -88,6 +76,21 @@ def tokenizer
def asymmetric_query_prefix
"Represent this sentence for searching relevant passages:"
end

# Selects the embeddings backend by configuration precedence:
#   1. Cloudflare Workers AI, when an API token is present
#   2. A configured Hugging Face text-embeddings (TEI) endpoint
#   3. The Discourse embeddings service (SRV record or plain endpoint setting)
# @return [#perform!] a ready-to-use inference client
# @raise [RuntimeError] when no backend is configured
def inference_client
if SiteSetting.ai_cloudflare_workers_api_token.present?
DiscourseAi::Inference::CloudflareWorkersAi.instance(inference_model_name)
elsif DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
DiscourseAi::Inference::HuggingFaceTextEmbeddings.instance
elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
# The Discourse service expects the bare model name, not the full
# "org/model" path, hence the split on "/".
DiscourseAi::Inference::DiscourseClassifier.instance(
inference_model_name.split("/").last,
)
else
raise "No inference endpoint configured"
end
end
end
end
end
Expand Down
6 changes: 5 additions & 1 deletion lib/embeddings/vector_representations/bge_m3.rb
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def dependant_setting_names

def vector_from(text, asymetric: false)
truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first
inference_client.perform!(truncated_text)
end

def dimensions
Expand Down Expand Up @@ -50,6 +50,10 @@ def pg_index_type
def tokenizer
DiscourseAi::Tokenizer::BgeM3Tokenizer
end

# BGE-M3 embeddings are served only through a Hugging Face TEI endpoint.
def inference_client
DiscourseAi::Inference::HuggingFaceTextEmbeddings.instance
end
end
end
end
Expand Down
7 changes: 5 additions & 2 deletions lib/embeddings/vector_representations/gemini.rb
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ def pg_index_type
end

def vector_from(text, asymetric: false)
response = DiscourseAi::Inference::GeminiEmbeddings.perform!(text)
response[:embedding][:values]
inference_client.perform!(text).dig(:embedding, :values)
end

# There is no public tokenizer for Gemini, and from the ones we already ship in the plugin
Expand All @@ -53,6 +52,10 @@ def vector_from(text, asymetric: false)
def tokenizer
DiscourseAi::Tokenizer::OpenAiTokenizer
end

# Embeddings for this representation always go through the Gemini API.
def inference_client
DiscourseAi::Inference::GeminiEmbeddings.instance
end
end
end
end
Expand Down
30 changes: 19 additions & 11 deletions lib/embeddings/vector_representations/multilingual_e5_large.rb
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,16 @@ def dependant_setting_names
end

def vector_from(text, asymetric: false)
if DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first
elsif discourse_embeddings_endpoint.present?
DiscourseAi::Inference::DiscourseClassifier.perform!(
"#{discourse_embeddings_endpoint}/api/v1/classify",
self.class.name,
"query: #{text}",
SiteSetting.ai_embeddings_discourse_service_api_key,
)
client = inference_client

needs_truncation = client.class.name.include?("HuggingFaceTextEmbeddings")
if needs_truncation
text = tokenizer.truncate(text, max_sequence_length - 2)
else
raise "No inference endpoint configured"
text = "query: #{text}"
end

client.perform!(text)
end

def id
Expand Down Expand Up @@ -71,6 +68,17 @@ def pg_index_type
def tokenizer
DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer
end

# Backend precedence for multilingual-e5-large:
#   1. A configured Hugging Face TEI endpoint (caller truncates input)
#   2. The Discourse embeddings service (caller prepends the "query: " prefix)
# @raise [RuntimeError] when neither backend is configured
def inference_client
if DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
DiscourseAi::Inference::HuggingFaceTextEmbeddings.instance
elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
DiscourseAi::Inference::DiscourseClassifier.instance(self.class.name)
else
raise "No inference endpoint configured"
end
end
end
end
end
Expand Down
15 changes: 8 additions & 7 deletions lib/embeddings/vector_representations/text_embedding_3_large.rb
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,19 @@ def pg_index_type
end

def vector_from(text, asymetric: false)
response =
DiscourseAi::Inference::OpenAiEmbeddings.perform!(
text,
model: self.class.name,
dimensions: dimensions,
)
response[:data].first[:embedding]
inference_client.perform!(text)
end

def tokenizer
DiscourseAi::Tokenizer::OpenAiTokenizer
end

# OpenAI embeddings client pinned to this model, with an explicit reduced
# `dimensions:` — text-embedding-3-large supports shortened output vectors.
def inference_client
DiscourseAi::Inference::OpenAiEmbeddings.instance(
model: self.class.name,
dimensions: dimensions,
)
end
end
end
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,16 @@ def pg_index_type
end

def vector_from(text, asymetric: false)
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: self.class.name)
response[:data].first[:embedding]
inference_client.perform!(text)
end

def tokenizer
DiscourseAi::Tokenizer::OpenAiTokenizer
end

# OpenAI embeddings client for this model at its default dimensionality.
def inference_client
DiscourseAi::Inference::OpenAiEmbeddings.instance(model: self.class.name)
end
end
end
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,16 @@ def pg_index_type
end

def vector_from(text, asymetric: false)
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: self.class.name)
response[:data].first[:embedding]
inference_client.perform!(text)
end

def tokenizer
DiscourseAi::Tokenizer::OpenAiTokenizer
end

# OpenAI embeddings client for this model at its default dimensionality.
def inference_client
DiscourseAi::Inference::OpenAiEmbeddings.instance(model: self.class.name)
end
end
end
end
Expand Down
33 changes: 23 additions & 10 deletions lib/inference/cloudflare_workers_ai.rb
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,38 @@
module ::DiscourseAi
module Inference
class CloudflareWorkersAi
def self.perform!(model, content)
headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
# @param account_id [String] Cloudflare account the Workers AI model runs under
# @param api_token [String] bearer token for the Workers AI API
# @param model [String] model path appended to the .../ai/run/@cf/ endpoint
# @param referer [String] intended Referer header value
# NOTE(review): perform! builds its Referer from Discourse.base_url directly,
# so @referer appears unused there — confirm whether it should be wired in.
def initialize(account_id, api_token, model, referer = Discourse.base_url)
@account_id = account_id
@api_token = api_token
@model = model
@referer = referer
end

# Convenience constructor wired to the site settings for account and token.
# @param model [String] Workers AI model path
def self.instance(model)
new(
SiteSetting.ai_cloudflare_workers_account_id,
SiteSetting.ai_cloudflare_workers_api_token,
model,
)
end

account_id = SiteSetting.ai_cloudflare_workers_account_id
token = SiteSetting.ai_cloudflare_workers_api_token
attr_reader :account_id, :api_token, :model, :referer

base_url = "https://api.cloudflare.com/client/v4/accounts/#{account_id}/ai/run/@cf/"
headers["Authorization"] = "Bearer #{token}"
def perform!(content)
headers = {
"Referer" => Discourse.base_url,
"Content-Type" => "application/json",
"Authorization" => "Bearer #{api_token}",
}

endpoint = "#{base_url}#{model}"
endpoint = "https://api.cloudflare.com/client/v4/accounts/#{account_id}/ai/run/@cf/#{model}"

conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
response = conn.post(endpoint, content.to_json, headers)

raise Net::HTTPBadResponse if ![200].include?(response.status)

case response.status
when 200
JSON.parse(response.body, symbolize_names: true)
JSON.parse(response.body, symbolize_names: true).dig(:result, :data).first
when 429
# TODO add a AdminDashboard Problem?
else
Expand Down
31 changes: 29 additions & 2 deletions lib/inference/discourse_classifier.rb
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,36 @@
module ::DiscourseAi
module Inference
class DiscourseClassifier
def self.perform!(endpoint, model, content, api_key)
headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
# @param endpoint [String] full classify URL (".../api/v1/classify")
# @param api_key [String, nil] sent as X-API-KEY when present
# @param model [String] model identifier forwarded to the service
# @param referer [String] Referer header value (defaults to this site)
def initialize(endpoint, api_key, model, referer = Discourse.base_url)
@endpoint = endpoint
@api_key = api_key
@model = model
@referer = referer
end

# Builds a client from site settings. The SRV setting takes precedence:
# when set, the host/port are resolved via a DNS SRV lookup and an https
# base URL is derived; otherwise the plain endpoint setting is used as-is.
# @param model [String] model identifier for the classify call
def self.instance(model)
endpoint =
if SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present?
service =
DiscourseAi::Utils::DnsSrv.lookup(
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv,
)
"https://#{service.target}:#{service.port}"
else
SiteSetting.ai_embeddings_discourse_service_api_endpoint
end

new(
"#{endpoint}/api/v1/classify",
SiteSetting.ai_embeddings_discourse_service_api_key,
model,
)
end

attr_reader :endpoint, :api_key, :model, :referer

def perform!(content)
headers = { "Referer" => referer, "Content-Type" => "application/json" }
headers["X-API-KEY"] = api_key if api_key.present?

conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
Expand Down
13 changes: 9 additions & 4 deletions lib/inference/gemini_embeddings.rb
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,17 @@
module ::DiscourseAi
module Inference
class GeminiEmbeddings
def self.perform!(content)
headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
# @param api_key [String] Gemini API key, interpolated into the request URL
# @param referer [String] Referer header value (defaults to this site)
def initialize(api_key, referer = Discourse.base_url)
@api_key = api_key
@referer = referer
end

url =
"https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent\?key\=#{SiteSetting.ai_gemini_api_key}"
attr_reader :api_key, :referer

def perform!(content)
headers = { "Referer" => referer, "Content-Type" => "application/json" }
url =
"https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent\?key\=#{api_key}"
body = { content: { parts: [{ text: content }] } }

conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
Expand Down
Loading