diff --git a/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-llms-new.js b/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-llms-new.js
new file mode 100644
index 000000000..aafc69f2f
--- /dev/null
+++ b/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-llms-new.js
@@ -0,0 +1,16 @@
+import DiscourseRoute from "discourse/routes/discourse";
+
+export default DiscourseRoute.extend({
+  async model() {
+    const record = this.store.createRecord("ai-llm");
+    return record;
+  },
+
+  setupController(controller, model) {
+    this._super(controller, model);
+    controller.set(
+      "allLlms",
+      this.modelFor("adminPlugins.show.discourse-ai-llms")
+    );
+  },
+});
diff --git a/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-llms-show.js b/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-llms-show.js
new file mode 100644
index 000000000..7a9fa379d
--- /dev/null
+++ b/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-llms-show.js
@@ -0,0 +1,17 @@
+import DiscourseRoute from "discourse/routes/discourse";
+
+export default DiscourseRoute.extend({
+  async model(params) {
+    const allLlms = this.modelFor("adminPlugins.show.discourse-ai-llms");
+    const id = parseInt(params.id, 10);
+    return allLlms.findBy("id", id);
+  },
+
+  setupController(controller, model) {
+    this._super(controller, model);
+    controller.set(
+      "allLlms",
+      this.modelFor("adminPlugins.show.discourse-ai-llms")
+    );
+  },
+});
diff --git a/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-llms.js b/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-llms.js
new file mode 100644
index 000000000..21f8f44b6
--- /dev/null
+++ b/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-llms.js
@@ -0,0 +1,7 @@
+import DiscourseRoute from "discourse/routes/discourse";
+
+export default class DiscourseAiAiLlmsRoute extends DiscourseRoute {
+  model() {
+    return this.store.findAll("ai-llm");
+  }
+}
diff --git a/admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-llms/index.hbs b/admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-llms/index.hbs
new file mode 100644
index 000000000..e1ab7f35c
--- /dev/null
+++ b/admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-llms/index.hbs
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-llms/new.hbs b/admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-llms/new.hbs
new file mode 100644
index 000000000..77f3b0f31
--- /dev/null
+++ b/admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-llms/new.hbs
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-llms/show.hbs b/admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-llms/show.hbs
new file mode 100644
index 000000000..77f3b0f31
--- /dev/null
+++ b/admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-llms/show.hbs
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/app/controllers/discourse_ai/admin/ai_llms_controller.rb b/app/controllers/discourse_ai/admin/ai_llms_controller.rb
new file mode 100644
index 000000000..6e57f8417
--- /dev/null
+++ b/app/controllers/discourse_ai/admin/ai_llms_controller.rb
@@ -0,0 +1,65 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Admin
+    class AiLlmsController < ::Admin::AdminController
+      requires_plugin ::DiscourseAi::PLUGIN_NAME
+
+      def index
+        llms = LlmModel.all
+
+        render json: {
+                 ai_llms:
+                   ActiveModel::ArraySerializer.new(
+                     llms,
+                     each_serializer: LlmModelSerializer,
+                     root: false,
+                   ).as_json,
+                 meta: {
+                   providers: DiscourseAi::Completions::Llm.provider_names,
+                   tokenizers:
+                     DiscourseAi::Completions::Llm.tokenizer_names.map { |tn|
+                       { id: tn, name: tn.split("::").last }
+                     },
+                 },
+               }
+      end
+
+      def show
+        llm_model = LlmModel.find(params[:id])
+        render json: LlmModelSerializer.new(llm_model)
+      end
+
+      def create
+        llm_model = LlmModel.new(ai_llm_params)
+        if llm_model.save
+          render json: { ai_llm: llm_model }, status: :created
+        else
+          render_json_error llm_model
+        end
+      end
+
+      def update
+        llm_model = LlmModel.find(params[:id])
+
+        if llm_model.update(ai_llm_params)
+          render json: llm_model
+        else
+          render_json_error llm_model
+        end
+      end
+
+      private
+
+      def ai_llm_params
+        params.require(:ai_llm).permit(
+          :display_name,
+          :name,
+          :provider,
+          :tokenizer,
+          :max_prompt_tokens,
+        )
+      end
+    end
+  end
+end
diff --git a/app/models/llm_model.rb b/app/models/llm_model.rb
new file mode 100644
index 000000000..aefb92020
--- /dev/null
+++ b/app/models/llm_model.rb
@@ -0,0 +1,7 @@
+# frozen_string_literal: true
+
+class LlmModel < ActiveRecord::Base
+  def tokenizer_class
+    tokenizer.constantize
+  end
+end
diff --git a/app/serializers/llm_model_serializer.rb b/app/serializers/llm_model_serializer.rb
new file mode 100644
index 000000000..77f264b84
--- /dev/null
+++ b/app/serializers/llm_model_serializer.rb
@@ -0,0 +1,7 @@
+# frozen_string_literal: true
+
+class LlmModelSerializer < ApplicationSerializer
+  root "llm"
+
+  attributes :id, :display_name, :name, :provider, :max_prompt_tokens, :tokenizer
+end
diff --git a/assets/javascripts/discourse/admin-discourse-ai-plugin-route-map.js b/assets/javascripts/discourse/admin-discourse-ai-plugin-route-map.js
index 54411f905..97ed05e51 100644
--- a/assets/javascripts/discourse/admin-discourse-ai-plugin-route-map.js
+++ b/assets/javascripts/discourse/admin-discourse-ai-plugin-route-map.js
@@ -8,5 +8,10 @@ export default {
       this.route("new");
       this.route("show", { path: "/:id" });
     });
+
+    this.route("discourse-ai-llms", { path: "ai-llms" }, function () {
+      this.route("new");
+      this.route("show", { path: "/:id" });
+    });
   },
 };
diff --git a/assets/javascripts/discourse/admin/adapters/ai-llm.js b/assets/javascripts/discourse/admin/adapters/ai-llm.js
new file mode 100644
index 000000000..fe82163d5
--- /dev/null
+++ b/assets/javascripts/discourse/admin/adapters/ai-llm.js
@@ -0,0 +1,22 @@
+import RestAdapter from "discourse/adapters/rest";
+
+export default class Adapter extends RestAdapter {
+  jsonMode = true;
+
+  basePath() {
+    return "/admin/plugins/discourse-ai/";
+  }
+
+  pathFor(store, type, findArgs) {
+    // the base implementation underscores the type name; build the path
+    // ourselves so the "ai-llms" dash survives
+    let path =
+      this.basePath(store, type, findArgs) +
+      store.pluralize(this.apiNameFor(type));
+    return this.appendQueryParams(path, findArgs);
+  }
+
+  apiNameFor() {
+    return "ai-llm";
+  }
+}
diff --git a/assets/javascripts/discourse/admin/models/ai-llm.js b/assets/javascripts/discourse/admin/models/ai-llm.js
new file mode 100644
index 000000000..69f9fbfe6
--- /dev/null
+++ b/assets/javascripts/discourse/admin/models/ai-llm.js
@@ -0,0 +1,20 @@
+import RestModel from "discourse/models/rest";
+
+export default class AiLlm extends RestModel {
+  createProperties() {
+    return this.getProperties(
+      "display_name",
+      "name",
+      "provider",
+      "tokenizer",
+      "max_prompt_tokens"
+    );
+  }
+
+  updateProperties() {
+    const attrs = this.createProperties();
+    attrs.id = this.id;
+
+    return attrs;
+  }
+}
diff --git a/assets/javascripts/discourse/components/ai-llm-editor.gjs b/assets/javascripts/discourse/components/ai-llm-editor.gjs
new file mode 100644
index 000000000..282a248eb
--- /dev/null
+++ b/assets/javascripts/discourse/components/ai-llm-editor.gjs
@@ -0,0 +1,123 @@
+import Component from "@glimmer/component";
+import { tracked } from "@glimmer/tracking";
+import { Input } from "@ember/component";
+import { action } from "@ember/object";
+import { later } from "@ember/runloop";
+import { inject as service } from "@ember/service";
+import DButton from "discourse/components/d-button";
+import { popupAjaxError } from "discourse/lib/ajax-error";
+import i18n from "discourse-common/helpers/i18n";
+import I18n from "discourse-i18n";
+import ComboBox from "select-kit/components/combo-box";
+import DTooltip from "float-kit/components/d-tooltip";
+
+export default class AiLlmEditor extends Component {
+  @service toasts;
+  @service router;
+
+  @tracked isSaving = false;
+
+  get selectedProviders() {
+    const t = (provName) => {
+      return I18n.t(`discourse_ai.llms.providers.${provName}`);
+    };
+
+    return this.args.llms.resultSetMeta.providers.map((prov) => {
+      return { id: prov, name: t(prov) };
+    });
+  }
+
+  @action
+  async save() {
+    this.isSaving = true;
+    const isNew = this.args.model.isNew;
+
+    try {
+      await this.args.model.save();
+
+      if (isNew) {
+        this.args.llms.addObject(this.args.model);
+        this.router.transitionTo(
+          "adminPlugins.show.discourse-ai-llms.show",
+          this.args.model
+        );
+      } else {
+        this.toasts.success({
+          data: { message: I18n.t("discourse_ai.llms.saved") },
+          duration: 2000,
+        });
+      }
+    } catch (e) {
+      popupAjaxError(e);
+    } finally {
+      later(() => {
+        this.isSaving = false;
+      }, 1000);
+    }
+  }
+
+
+}
diff --git a/assets/javascripts/discourse/components/ai-llms-list-editor.gjs b/assets/javascripts/discourse/components/ai-llms-list-editor.gjs
new file mode 100644
index 000000000..73ba310ee
--- /dev/null
+++ b/assets/javascripts/discourse/components/ai-llms-list-editor.gjs
@@ -0,0 +1,61 @@
+import Component from "@glimmer/component";
+import { LinkTo } from "@ember/routing";
+import icon from "discourse-common/helpers/d-icon";
+import i18n from "discourse-common/helpers/i18n";
+import I18n from "discourse-i18n";
+import AiLlmEditor from "./ai-llm-editor";
+
+export default class AiLlmsListEditor extends Component {
+  get hasNoLLMElements() {
+    return this.args.llms.length === 0;
+  }
+
+
+}
diff --git a/assets/javascripts/initializers/admin-plugin-configuration-nav.js b/assets/javascripts/initializers/admin-plugin-configuration-nav.js
index 2c2bac59c..51d882223 100644
--- a/assets/javascripts/initializers/admin-plugin-configuration-nav.js
+++ b/assets/javascripts/initializers/admin-plugin-configuration-nav.js
@@ -16,6 +16,10 @@ export default {
           label: "discourse_ai.ai_persona.short_title",
           route: "adminPlugins.show.discourse-ai-personas",
         },
+        {
+          label: "discourse_ai.llms.short_title",
+          route: "adminPlugins.show.discourse-ai-llms",
+        },
       ]);
     });
   },
diff --git a/assets/stylesheets/modules/llms/common/ai-llms-editor.scss b/assets/stylesheets/modules/llms/common/ai-llms-editor.scss
new file mode 100644
index 000000000..351deeef6
--- /dev/null
+++ b/assets/stylesheets/modules/llms/common/ai-llms-editor.scss
@@ -0,0 +1,31 @@
+.ai-llms-list-editor {
+  &__header {
+    display: flex;
+    justify-content: space-between;
+    align-items: center;
+    margin: 0 0 1em 0;
+
+    h3 {
+      margin: 0;
+    }
+  }
+
+  &__container {
+    display: flex;
+    flex-direction: row;
+    gap: 20px;
+    width: 100%;
+    align-items: stretch;
+  }
+
+  &__empty_list,
+  &__content_list {
+    min-width: 300px;
+  }
+
+  &__empty_list {
+    align-content: center;
+    text-align: center;
+    font-size: var(--font-up-1);
+  }
+}
diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml
index 1c3d9c0c7..044d00f38 100644
--- a/config/locales/client.en.yml
+++ b/config/locales/client.en.yml
@@ -194,6 +194,32 @@ en:
         uploading: "Uploading..."
         remove: "Remove upload"
 
+      llms:
+        short_title: "LLMs"
+        no_llms: "No LLMs yet"
+        new: "New Model"
+        display_name: "Name to display:"
+        name: "Model name:"
+        provider: "Service hosting the model:"
+        tokenizer: "Tokenizer:"
+        max_prompt_tokens: "Number of tokens for the prompt:"
+        save: "Save"
+        saved: "LLM Model Saved"
+
+        hints:
+          max_prompt_tokens: "Max number of tokens for the prompt. As a rule of thumb, this should be 50% of the model's context window."
+          name: "We include this in the API call to specify which model we'll use."
+
+        providers:
+          aws_bedrock: "AWS Bedrock"
+          anthropic: "Anthropic"
+          vllm: "vLLM"
+          hugging_face: "Hugging Face"
+          cohere: "Cohere"
+          open_ai: "OpenAI"
+          google: "Google"
+          azure: "Azure"
+
     related_topics:
       title: "Related Topics"
       pill: "Related"
diff --git a/config/routes.rb b/config/routes.rb
index eb20cf98c..f33dfd685 100644
--- a/config/routes.rb
+++ b/config/routes.rb
@@ -45,5 +45,10 @@
     post "/ai-personas/files/upload", to: "discourse_ai/admin/ai_personas#upload_file"
     put "/ai-personas/:id/files/remove", to: "discourse_ai/admin/ai_personas#remove_file"
     get "/ai-personas/:id/files/status", to: "discourse_ai/admin/ai_personas#indexing_status_check"
+
+    resources :ai_llms,
+              only: %i[index create show update],
+              path: "ai-llms",
+              controller: "discourse_ai/admin/ai_llms"
   end
 end
diff --git a/db/migrate/20240504222307_create_llm_model_table.rb b/db/migrate/20240504222307_create_llm_model_table.rb
new file mode 100644
index 000000000..96bcc3dd4
--- /dev/null
+++ b/db/migrate/20240504222307_create_llm_model_table.rb
@@ -0,0 +1,14 @@
+# frozen_string_literal: true
+
+class CreateLlmModelTable < ActiveRecord::Migration[7.0]
+  def change
+    create_table :llm_models do |t|
+      t.string :display_name
+      t.string :name, null: false
+      t.string :provider, null: false
+      t.string :tokenizer, null: false
+      t.integer :max_prompt_tokens, null: false
+      t.timestamps
+    end
+  end
+end
diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb
index 3e3c932e5..d2ca71f59 100644
--- a/lib/ai_bot/bot.rb
+++ b/lib/ai_bot/bot.rb
@@ -150,59 +150,73 @@ def invoke_tool(tool, llm, cancel, context, &update_blk)
 
     def self.guess_model(bot_user)
       # HACK(roman): We'll do this until we define how we represent different providers in the bot settings
-      case bot_user.id
-      when DiscourseAi::AiBot::EntryPoint::CLAUDE_V2_ID
-        if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-2")
-          "aws_bedrock:claude-2"
-        else
-          "anthropic:claude-2"
-        end
-      when DiscourseAi::AiBot::EntryPoint::GPT4_ID
-        "open_ai:gpt-4"
-      when DiscourseAi::AiBot::EntryPoint::GPT4_TURBO_ID
-        "open_ai:gpt-4-turbo"
-      when DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID
-        "open_ai:gpt-3.5-turbo-16k"
-      when DiscourseAi::AiBot::EntryPoint::MIXTRAL_ID
-        mixtral_model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
-        if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(mixtral_model)
-          "vllm:#{mixtral_model}"
-        elsif DiscourseAi::Completions::Endpoints::HuggingFace.correctly_configured?(
-                mixtral_model,
-              )
-          "hugging_face:#{mixtral_model}"
-        else
-          "ollama:mistral"
-        end
-      when DiscourseAi::AiBot::EntryPoint::GEMINI_ID
-        "google:gemini-pro"
-      when DiscourseAi::AiBot::EntryPoint::FAKE_ID
-        "fake:fake"
-      when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_OPUS_ID
-        if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-3-opus")
-          "aws_bedrock:claude-3-opus"
-        else
-          "anthropic:claude-3-opus"
-        end
-      when DiscourseAi::AiBot::EntryPoint::COHERE_COMMAND_R_PLUS
-        "cohere:command-r-plus"
-      when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_SONNET_ID
-        if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?(
-             "claude-3-sonnet",
-           )
-          "aws_bedrock:claude-3-sonnet"
-        else
-          "anthropic:claude-3-sonnet"
-        end
-      when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_HAIKU_ID
-        if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-3-haiku")
-          "aws_bedrock:claude-3-haiku"
+      guess =
+        case bot_user.id
+        when DiscourseAi::AiBot::EntryPoint::CLAUDE_V2_ID
+          if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-2")
+            "aws_bedrock:claude-2"
+          else
+            "anthropic:claude-2"
+          end
+        when DiscourseAi::AiBot::EntryPoint::GPT4_ID
+          "open_ai:gpt-4"
+        when DiscourseAi::AiBot::EntryPoint::GPT4_TURBO_ID
+          "open_ai:gpt-4-turbo"
+        when DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID
+          "open_ai:gpt-3.5-turbo-16k"
+        when DiscourseAi::AiBot::EntryPoint::MIXTRAL_ID
+          mixtral_model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+          if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(mixtral_model)
+            "vllm:#{mixtral_model}"
+          elsif DiscourseAi::Completions::Endpoints::HuggingFace.correctly_configured?(
+                  mixtral_model,
+                )
+            "hugging_face:#{mixtral_model}"
+          else
+            "ollama:mistral"
+          end
+        when DiscourseAi::AiBot::EntryPoint::GEMINI_ID
+          "google:gemini-pro"
+        when DiscourseAi::AiBot::EntryPoint::FAKE_ID
+          "fake:fake"
+        when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_OPUS_ID
+          if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?(
+               "claude-3-opus",
+             )
+            "aws_bedrock:claude-3-opus"
+          else
+            "anthropic:claude-3-opus"
+          end
+        when DiscourseAi::AiBot::EntryPoint::COHERE_COMMAND_R_PLUS
+          "cohere:command-r-plus"
+        when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_SONNET_ID
+          if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?(
+               "claude-3-sonnet",
+             )
+            "aws_bedrock:claude-3-sonnet"
+          else
+            "anthropic:claude-3-sonnet"
+          end
+        when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_HAIKU_ID
+          if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?(
+               "claude-3-haiku",
+             )
+            "aws_bedrock:claude-3-haiku"
+          else
+            "anthropic:claude-3-haiku"
+          end
         else
-          "anthropic:claude-3-haiku"
+          nil
         end
-      else
-        nil
+
+      if guess
+        provider, model_name = guess.split(":")
+        llm_model = LlmModel.find_by(provider: provider, name: model_name)
+
+        return "custom:#{llm_model.id}" if llm_model
       end
+
+      guess
     end
 
     def build_placeholder(summary, details, custom_raw: nil)
diff --git a/lib/automation.rb b/lib/automation.rb
index b755f1dba..71cde97b6 100644
--- a/lib/automation.rb
+++ b/lib/automation.rb
@@ -25,6 +25,9 @@ module Automation
   ]
 
   def self.translate_model(model)
+    llm_model = LlmModel.find_by(name: model)
+    return "custom:#{llm_model.id}" if llm_model
+
     return "google:#{model}" if model.start_with? "gemini"
     return "open_ai:#{model}" if model.start_with? "gpt"
     return "cohere:#{model}" if model.start_with? "command"
diff --git a/lib/completions/dialects/chat_gpt.rb b/lib/completions/dialects/chat_gpt.rb
index 3552147be..f6142d091 100644
--- a/lib/completions/dialects/chat_gpt.rb
+++ b/lib/completions/dialects/chat_gpt.rb
@@ -9,7 +9,7 @@ def can_translate?(model_name)
           model_name.starts_with?("gpt-")
         end
 
-        def tokenizer(_)
+        def tokenizer
           DiscourseAi::Tokenizer::OpenAiTokenizer
         end
       end
@@ -30,13 +30,15 @@ def translate
       end
 
       def max_prompt_tokens
+        return opts[:max_prompt_tokens] if opts[:max_prompt_tokens].present?
+
         # provide a buffer of 120 tokens - our function counting is not
         # 100% accurate and getting numbers to align exactly is very hard
         buffer = (opts[:max_tokens] || 2500) + 50
 
         if tools.present?
           # note this is about 100 tokens over, OpenAI have a more optimal representation
-          @function_size ||= self.class.tokenizer(model_name).size(tools.to_json.to_s)
+          @function_size ||= self.class.tokenizer.size(tools.to_json.to_s)
           buffer += @function_size
         end
 
@@ -110,7 +112,7 @@ def per_message_overhead
       end
 
       def calculate_message_token(context)
-        self.class.tokenizer(model_name).size(context[:content].to_s + context[:name].to_s)
+        self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
       end
 
       def model_max_tokens
diff --git a/lib/completions/dialects/claude.rb b/lib/completions/dialects/claude.rb
index cb82c9ae6..8be67e548 100644
--- a/lib/completions/dialects/claude.rb
+++ b/lib/completions/dialects/claude.rb
@@ -11,7 +11,7 @@ def can_translate?(model_name)
           )
         end
 
-        def tokenizer(_)
+        def tokenizer
           DiscourseAi::Tokenizer::AnthropicTokenizer
         end
       end
@@ -50,6 +50,7 @@ def translate
       end
 
       def max_prompt_tokens
+        return opts[:max_prompt_tokens] if opts[:max_prompt_tokens].present?
         # Longer term it will have over 1 million
         200_000 # Claude-3 has a 200k context window for now
       end
diff --git a/lib/completions/dialects/command.rb b/lib/completions/dialects/command.rb
index 3a09541c5..622403726 100644
--- a/lib/completions/dialects/command.rb
+++ b/lib/completions/dialects/command.rb
@@ -11,7 +11,7 @@ def can_translate?(model_name)
           %w[command-light command command-r command-r-plus].include?(model_name)
         end
 
-        def tokenizer(_)
+        def tokenizer
           DiscourseAi::Tokenizer::OpenAiTokenizer
         end
       end
@@ -38,6 +38,8 @@ def translate
       end
 
       def max_prompt_tokens
+        return opts[:max_prompt_tokens] if opts[:max_prompt_tokens].present?
+
         case model_name
         when "command-light"
           4096
@@ -59,48 +61,7 @@ def per_message_overhead
       end
 
       def calculate_message_token(context)
-        self.class.tokenizer(model_name).size(context[:content].to_s + context[:name].to_s)
-      end
-
-      def tools_dialect
-        @tools_dialect ||= DiscourseAi::Completions::Dialects::XmlTools.new(prompt.tools)
-      end
-
-      def system_msg(msg)
-        cmd_msg = { role: "SYSTEM", message: msg[:content] }
-
-        if tools_dialect.instructions.present?
-          cmd_msg[:message] = [
-            msg[:content],
-            tools_dialect.instructions,
-            "NEVER attempt to run tools using JSON, always use XML. Lives depend on it.",
-          ].join("\n")
-        end
-
-        cmd_msg
-      end
-
-      def model_msg(msg)
-        { role: "CHATBOT", message: msg[:content] }
-      end
-
-      def tool_call_msg(msg)
-        { role: "CHATBOT", message: tools_dialect.from_raw_tool_call(msg) }
-      end
-
-      def tool_msg(msg)
-        { role: "USER", message: tools_dialect.from_raw_tool(msg) }
-      end
-
-      def user_msg(msg)
-        user_message = { role: "USER", message: msg[:content] }
-        user_message[:message] = "#{msg[:id]}: #{msg[:content]}" if msg[:id]
-
-        user_message
-      end
-
-      def tools_dialect
-        @tools_dialect ||= DiscourseAi::Completions::Dialects::XmlTools.new(prompt.tools)
+        self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
       end
 
       def system_msg(msg)
diff --git a/lib/completions/dialects/dialect.rb b/lib/completions/dialects/dialect.rb
index 6e3be5914..13275ad41 100644
--- a/lib/completions/dialects/dialect.rb
+++ b/lib/completions/dialects/dialect.rb
@@ -9,19 +9,22 @@ def can_translate?(_model_name)
           raise NotImplemented
         end
 
-        def dialect_for(model_name)
-          dialects = [
+        def all_dialects
+          [
             DiscourseAi::Completions::Dialects::ChatGpt,
             DiscourseAi::Completions::Dialects::Gemini,
             DiscourseAi::Completions::Dialects::Mistral,
-            DiscourseAi::Completions::Dialects::Tgi,
-            DiscourseAi::Completions::Dialects::Vllm,
             DiscourseAi::Completions::Dialects::Claude,
             DiscourseAi::Completions::Dialects::Command,
           ]
+        end
+
+        def available_tokenizers
+          all_dialects.map(&:tokenizer)
+        end
+
+        def dialect_for(model_name)
+          dialects = all_dialects
 
           if Rails.env.test? || Rails.env.development?
             dialects << DiscourseAi::Completions::Dialects::Fake
@@ -32,7 +35,7 @@ def dialect_for(model_name)
           dialect
         end
 
-        def tokenizer(_)
+        def tokenizer
           raise NotImplemented
         end
       end
@@ -140,31 +143,7 @@ def per_message_overhead
       end
 
       def calculate_message_token(msg)
-        self.class.tokenizer(model_name).size(msg[:content].to_s)
-      end
-
-      def tools_dialect
-        raise NotImplemented
-      end
-
-      def system_msg(msg)
-        raise NotImplemented
-      end
-
-      def assistant_msg(msg)
-        raise NotImplemented
-      end
-
-      def user_msg(msg)
-        raise NotImplemented
-      end
-
-      def tool_call_msg(msg)
-        raise NotImplemented
-      end
-
-      def tool_msg(msg)
-        raise NotImplemented
+        self.class.tokenizer.size(msg[:content].to_s)
       end
 
       def tools_dialect
diff --git a/lib/completions/dialects/gemini.rb b/lib/completions/dialects/gemini.rb
index 3c7612a81..fde9cdda9 100644
--- a/lib/completions/dialects/gemini.rb
+++ b/lib/completions/dialects/gemini.rb
@@ -9,7 +9,7 @@ def can_translate?(model_name)
           %w[gemini-pro gemini-1.5-pro].include?(model_name)
         end
 
-        def tokenizer(_)
+        def tokenizer
           DiscourseAi::Tokenizer::OpenAiTokenizer ## TODO Replace with GeminiTokenizer
         end
       end
@@ -68,6 +68,8 @@ def tools
       end
 
       def max_prompt_tokens
+        return opts[:max_prompt_tokens] if opts[:max_prompt_tokens].present?
+
         if model_name == "gemini-1.5-pro"
           # technically we support 1 million tokens, but we're being conservative
           800_000
@@ -79,47 +81,7 @@ def max_prompt_tokens
       protected
 
       def calculate_message_token(context)
-        self.class.tokenizer(model_name).size(context[:content].to_s + context[:name].to_s)
-      end
-
-      def system_msg(msg)
-        { role: "user", parts: { text: msg[:content] } }
-      end
-
-      def model_msg(msg)
-        { role: "model", parts: { text: msg[:content] } }
-      end
-
-      def user_msg(msg)
-        { role: "user", parts: { text: msg[:content] } }
-      end
-
-      def tool_call_msg(msg)
-        call_details = JSON.parse(msg[:content], symbolize_names: true)
-
-        {
-          role: "model",
-          parts: {
-            functionCall: {
-              name: msg[:name] || call_details[:name],
-              args: call_details[:arguments],
-            },
-          },
-        }
-      end
-
-      def tool_msg(msg)
-        {
-          role: "function",
-          parts: {
-            functionResponse: {
-              name: msg[:name] || msg[:id],
-              response: {
-                content: msg[:content],
-              },
-            },
-          },
-        }
+        self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
       end
 
       def system_msg(msg)
diff --git a/lib/completions/dialects/mistral.rb b/lib/completions/dialects/mistral.rb
index 7752a8765..d34130f9b 100644
--- a/lib/completions/dialects/mistral.rb
+++ b/lib/completions/dialects/mistral.rb
@@ -23,6 +23,8 @@ def tools
       end
 
       def max_prompt_tokens
+        return opts[:max_prompt_tokens] if opts[:max_prompt_tokens].present?
+
         32_000
       end
 
diff --git a/lib/completions/dialects/tgi.rb b/lib/completions/dialects/tgi.rb
deleted file mode 100644
index bf34f6b19..000000000
--- a/lib/completions/dialects/tgi.rb
+++ /dev/null
@@ -1,76 +0,0 @@
-# frozen_string_literal: true
-
-module DiscourseAi
-  module Completions
-    module Dialects
-      class Tgi < Dialect
-        class << self
-          def can_translate?(model_name)
-            %w[
-              StableBeluga2
-              Upstage-Llama-2-*-instruct-v2
-              mistralai/Mixtral-8x7B-Instruct-v0.1
-              mistralai/Mistral-7B-Instruct-v0.2
-              Llama2-*-chat-hf
-              Llama2-chat-hf
-            ].include?(model_name)
-          end
-
-          def tokenizer(model_name)
-            if %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2].include?(
-                 model_name,
-               )
-              DiscourseAi::Tokenizer::MixtralTokenizer
-            else
-              DiscourseAi::Tokenizer::Llama2Tokenizer
-            end
-          end
-        end
-
-        def tools
-          @tools ||= tools_dialect.translated_tools
-        end
-
-        def max_prompt_tokens
-          if %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2].include?(
-               model_name,
-             )
-            32_000
-          else
-            SiteSetting.ai_hugging_face_token_limit
-          end
-        end
-
-        private
-
-        def tools_dialect
-          @tools_dialect ||= DiscourseAi::Completions::Dialects::OpenAiTools.new(prompt.tools)
-        end
-
-        def system_msg(msg)
-          { role: "system", content: msg[:content] }
-        end
-
-        def model_msg(msg)
-          { role: "assistant", content: msg[:content] }
-        end
-
-        def tool_call_msg(msg)
-          tools_dialect.from_raw_tool_call(msg)
-        end
-
-        def tool_msg(msg)
-          tools_dialect.from_raw_tool(msg)
-        end
-
-        def user_msg(msg)
-          content = +""
-          content << "#{msg[:id]}: " if msg[:id]
-          content << msg[:content]
-
-          { role: "user", content: content }
-        end
-      end
-    end
-  end
-end
diff --git a/lib/completions/dialects/vllm.rb b/lib/completions/dialects/vllm.rb
deleted file mode 100644
index 7f65bab95..000000000
--- a/lib/completions/dialects/vllm.rb
+++ /dev/null
@@ -1,78 +0,0 @@
-# frozen_string_literal: true
-
-module DiscourseAi
-  module Completions
-    module Dialects
-      class Vllm < Dialect
-        class << self
-          def can_translate?(model_name)
-            %w[
-              StableBeluga2
-              Upstage-Llama-2-*-instruct-v2
-              mistralai/Mixtral-8x7B-Instruct-v0.1
-              mistralai/Mistral-7B-Instruct-v0.2
-              Llama2-*-chat-hf
-              Llama2-chat-hf
-            ].include?(model_name)
-          end
-
-          def tokenizer(model_name)
-            if %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2].include?(
-                 model_name,
-               )
-              DiscourseAi::Tokenizer::MixtralTokenizer
-            else
-              DiscourseAi::Tokenizer::Llama2Tokenizer
-            end
-          end
-        end
-
-        def max_prompt_tokens
-          if %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2].include?(
-               model_name,
-             )
-            32_000
-          else
-            SiteSetting.ai_hugging_face_token_limit
-          end
-        end
-
-        private
-
-        def tools_dialect
-          @tools_dialect ||= DiscourseAi::Completions::Dialects::XmlTools.new(prompt.tools)
-        end
-
-        def system_msg(msg)
-          msg = { role: "system", content: msg[:content] }
-
-          if tools_dialect.instructions.present?
-            msg[:content] = msg[:content].dup << "\n\n#{tools_dialect.instructions}"
-          end
-
-          msg
-        end
-
-        def model_msg(msg)
-          { role: "assistant", content: msg[:content] }
-        end
-
-        def tool_call_msg(msg)
-          { role: "assistant", content: tools_dialect.from_raw_tool_call(msg) }
-        end
-
-        def tool_msg(msg)
-          { role: "user", content: tools_dialect.from_raw_tool(msg) }
-        end
-
-        def user_msg(msg)
-          content = +""
-          content << "#{msg[:id]}: " if msg[:id]
-          content << msg[:content]
-
-          { role: "user", content: content }
-        end
-      end
-    end
-  end
-end
diff --git a/lib/completions/endpoints/hugging_face.rb b/lib/completions/endpoints/hugging_face.rb
index ed4f78403..d6237c05b 100644
--- a/lib/completions/endpoints/hugging_face.rb
+++ b/lib/completions/endpoints/hugging_face.rb
@@ -51,7 +51,7 @@ def model_uri
         URI(SiteSetting.ai_hugging_face_api_url)
       end
 
-      def prepare_payload(prompt, model_params, dialect)
+      def prepare_payload(prompt, model_params, _dialect)
         default_options
           .merge(model_params)
           .merge(messages: prompt)
@@ -63,7 +63,6 @@ def prepare_payload(prompt, model_params, dialect)
             end
 
             payload[:stream] = true if @streaming_mode
-            payload[:tools] = dialect.tools if dialect.tools.present?
           end
       end
 
@@ -92,75 +91,11 @@ def partials_from(decoded_chunk)
         decoded_chunk
           .split("\n")
           .map do |line|
-            data = line.split("data: ", 2)[1]
-            data == "[DONE]" ? nil : data
+            data = line.split("data:", 2)[1]
+            data&.squish == "[DONE]" ? nil : data
           end
           .compact
       end
-
-      def has_tool?(_response_data)
-        @has_function_call
-      end
-
-      def native_tool_support?
-        true
-      end
-
-      def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
-        if @streaming_mode
-          return function_buffer if !partial
-        else
-          partial = payload
-        end
-
-        @args_buffer ||= +""
-
-        f_name = partial.dig(:function, :name)
-
-        @current_function ||= function_buffer.at("invoke")
-
-        if f_name
-          current_name = function_buffer.at("tool_name").content
-
-          if current_name.blank?
-            # first call
-          else
-            # we have a previous function, so we need to add a noop
-            @args_buffer = +""
-            @current_function =
-              function_buffer.at("function_calls").add_child(
-                Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"),
-              )
-          end
-        end
-
-        @current_function.at("tool_name").content = f_name if f_name
-        @current_function.at("tool_id").content = partial[:id] if partial[:id]
-
-        args = partial.dig(:function, :arguments)
-
-        # allow for SPACE within arguments
-        if args && args != ""
-          @args_buffer << args
-
-          begin
-            json_args = JSON.parse(@args_buffer, symbolize_names: true)
-
-            argument_fragments =
-              json_args.reduce(+"") do |memo, (arg_name, value)|
-                memo << "\n<#{arg_name}>#{value}</#{arg_name}>"
-              end
-            argument_fragments << "\n"
-
-            @current_function.at("parameters").children =
-              Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
-          rescue JSON::ParserError
-            return function_buffer
-          end
-        end
-
-        function_buffer
-      end
     end
   end
 end
diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb
index b3d45f1aa..a2f70172a 100644
--- a/lib/completions/llm.rb
+++ b/lib/completions/llm.rb
@@ -18,6 +18,14 @@ class Llm
     UNKNOWN_MODEL = Class.new(StandardError)
 
     class << self
+      def provider_names
+        %w[aws_bedrock anthropic vllm hugging_face cohere open_ai google azure]
+      end
+
+      def tokenizer_names
+        DiscourseAi::Completions::Dialects::Dialect.available_tokenizers.map(&:name).uniq
+      end
+
       def models_by_provider
         # ChatGPT models are listed under open_ai but they are actually available through OpenAI and Azure.
         # However, since they use the same URL/key settings, there's no reason to duplicate them.
@@ -80,36 +88,54 @@ def record_prompt(prompt)
       end
 
       def proxy(model_name)
+        # We are in the process of transitioning to always use objects here.
+        # We'll live with this hack for a while.
         provider_and_model_name = model_name.split(":")
         provider_name = provider_and_model_name.first
         model_name_without_prov = provider_and_model_name[1..].join
+        is_custom_model = provider_name == "custom"
+
+        if is_custom_model
+          llm_model = LlmModel.find(model_name_without_prov)
+          provider_name = llm_model.provider
+          model_name_without_prov = llm_model.name
+        end
 
         dialect_klass =
           DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_name_without_prov)
 
+        tokenizer = is_custom_model ? llm_model.tokenizer_class : dialect_klass.tokenizer
+
         if @canned_response
           if @canned_llm && @canned_llm != model_name
             raise "Invalid call LLM call, expected #{@canned_llm} but got #{model_name}"
           end
 
-          return new(dialect_klass, nil, model_name, gateway: @canned_response)
+          return new(dialect_klass, nil, model_name, opts: { gateway: @canned_response })
         end
 
+        opts = { tokenizer: tokenizer }
+        opts[:max_prompt_tokens] = llm_model.max_prompt_tokens if is_custom_model
+
         gateway_klass =
           DiscourseAi::Completions::Endpoints::Base.endpoint_for(
             provider_name,
             model_name_without_prov,
-          ).new(model_name_without_prov, dialect_klass.tokenizer(model_name_without_prov))
+          )
 
-        new(dialect_klass, gateway_klass, model_name_without_prov)
+        new(dialect_klass, gateway_klass, model_name_without_prov, opts: opts)
       end
     end
 
-    def initialize(dialect_klass, gateway_klass, model_name, gateway: nil)
+    def initialize(dialect_klass, gateway_klass, model_name, opts: {})
       @dialect_klass = dialect_klass
       @gateway_klass = gateway_klass
       @model_name = model_name
-      @gateway = gateway
+      @gateway = opts[:gateway]
+      @tokenizer = opts[:tokenizer]
+      @max_prompt_tokens = opts[:max_prompt_tokens]
     end
 
     # @param generic_prompt { DiscourseAi::Completions::Prompt } - Our generic prompt object
@@ -166,11 +192,18 @@ def generate(
       model_params.keys.each { |key| model_params.delete(key) if model_params[key].nil? }
 
-      gateway = @gateway || gateway_klass.new(model_name, dialect_klass.tokenizer)
-      dialect = dialect_klass.new(prompt, model_name, opts: model_params)
+      gateway = @gateway || gateway_klass.new(model_name, @tokenizer || dialect_klass.tokenizer)
+      dialect =
+        dialect_klass.new(
+          prompt,
+          model_name,
+          opts: model_params.merge(max_prompt_tokens: @max_prompt_tokens),
+        )
       gateway.perform_completion!(dialect, user, model_params, &partial_read_blk)
     end
 
     def max_prompt_tokens
+      return @max_prompt_tokens if @max_prompt_tokens.present?
+
       dialect_klass.new(DiscourseAi::Completions::Prompt.new(""), model_name).max_prompt_tokens
     end
 
diff --git a/lib/configuration/llm_enumerator.rb b/lib/configuration/llm_enumerator.rb
index d76187955..fd870b57f 100644
--- a/lib/configuration/llm_enumerator.rb
+++ b/lib/configuration/llm_enumerator.rb
@@ -10,14 +10,22 @@ def self.valid_value?(val)
   end
 
   def self.values
-    # do not cache cause settings can change this
-    DiscourseAi::Completions::Llm.models_by_provider.flat_map do |provider, models|
-      endpoint =
-        DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider.to_s, models.first)
+    # do not cache, since site settings can change the result
+    llm_models =
+      DiscourseAi::Completions::Llm.models_by_provider.flat_map do |provider, models|
+        endpoint =
+          DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider.to_s, models.first)
 
-      models.map do |model_name|
-        { name: endpoint.display_name(model_name), value: "#{provider}:#{model_name}" }
+        models.map do |model_name|
+          { name: endpoint.display_name(model_name), value: "#{provider}:#{model_name}" }
+        end
       end
+
+    LlmModel.all.each do |model|
+      llm_models << { name: model.display_name, value: "custom:#{model.id}" }
+    end
+
+    llm_models
   end
 end
diff --git a/plugin.rb b/plugin.rb
index ad255b70c..c5a571bab 100644
--- a/plugin.rb
+++ b/plugin.rb
@@ -26,6 +26,8 @@
 register_asset "stylesheets/modules/sentiment/desktop/dashboard.scss", :desktop
 register_asset "stylesheets/modules/sentiment/mobile/dashboard.scss", :mobile
 
+register_asset "stylesheets/modules/llms/common/ai-llms-editor.scss"
+
 module ::DiscourseAi
   PLUGIN_NAME = "discourse-ai"
 end
diff --git a/spec/fabricators/llm_model_fabricator.rb b/spec/fabricators/llm_model_fabricator.rb
new file mode 100644
index 000000000..c419341e9
--- /dev/null
+++ b/spec/fabricators/llm_model_fabricator.rb
@@ -0,0 +1,9 @@
+# frozen_string_literal: true
+
+Fabricator(:llm_model) do
+  display_name "A good model"
+  name "gpt-4-turbo"
+  provider "open_ai"
+  tokenizer "DiscourseAi::Tokenizer::OpenAiTokenizer"
+  max_prompt_tokens 32_000
+end
diff --git a/spec/lib/completions/dialects/dialect_spec.rb b/spec/lib/completions/dialects/dialect_spec.rb
index 05c4fbc83..c54d18389 100644
--- a/spec/lib/completions/dialects/dialect_spec.rb
+++ b/spec/lib/completions/dialects/dialect_spec.rb
@@ -7,7 +7,7 @@ def trim(messages)
     trim_messages(messages)
   end
 
-  def self.tokenizer(_)
+  def self.tokenizer
     Class.new do
       def self.size(str)
         str.length
diff --git a/spec/lib/completions/endpoints/hugging_face_spec.rb b/spec/lib/completions/endpoints/hugging_face_spec.rb
index 93b7dc424..f84fc43a9 100644
--- a/spec/lib/completions/endpoints/hugging_face_spec.rb
+++ b/spec/lib/completions/endpoints/hugging_face_spec.rb
@@ -23,8 +23,8 @@ def response(content)
   def stub_response(prompt, response_text, tool_call: false)
     WebMock
       .stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
-      .with(body: request_body(prompt, tool_call: tool_call))
-      .to_return(status: 200, body: JSON.dump(response(response_text, tool_call: tool_call)))
+      .with(body: request_body(prompt))
+      .to_return(status: 200, body: JSON.dump(response(response_text)))
   end
 
   def stream_line(delta, finish_reason: nil)
@@ -60,7 +60,7 @@ def stub_streamed_response(prompt, deltas, tool_call: false)
 
     WebMock
       .stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
-      .with(body: request_body(prompt, stream: true, tool_call: tool_call))
+      .with(body: request_body(prompt, stream: true))
       .to_return(status: 200, body: chunks)
 
     yield if block_given?
@@ -81,7 +81,10 @@ def request_body(prompt, stream: false, tool_call: false)
 
 RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
   subject(:endpoint) do
-    described_class.new("Llama2-*-chat-hf", DiscourseAi::Tokenizer::Llama2Tokenizer)
+    described_class.new(
+      "mistralai/Mistral-7B-Instruct-v0.2",
+      DiscourseAi::Tokenizer::MixtralTokenizer,
+    )
   end
 
   before { SiteSetting.ai_hugging_face_api_url = "https://test.dev" }
diff --git a/spec/lib/completions/llm_spec.rb b/spec/lib/completions/llm_spec.rb
index e91a4a60f..c3107dc00 100644
--- a/spec/lib/completions/llm_spec.rb
+++ b/spec/lib/completions/llm_spec.rb
@@ -6,7 +6,9 @@
       DiscourseAi::Completions::Dialects::Mistral,
       canned_response,
       "hugging_face:Upstage-Llama-2-*-instruct-v2",
-      gateway: canned_response,
+      opts: {
+        gateway: canned_response,
+      },
     )
   end
 
diff --git a/spec/requests/admin/ai_llms_controller_spec.rb b/spec/requests/admin/ai_llms_controller_spec.rb
new file mode 100644
index 000000000..e47478344
--- /dev/null
+++ b/spec/requests/admin/ai_llms_controller_spec.rb
@@ -0,0 +1,70 @@
+# frozen_string_literal: true
+
+RSpec.describe DiscourseAi::Admin::AiLlmsController do
+  fab!(:admin)
+
+  before { sign_in(admin) }
+
+  describe "GET #index" do
+    it "includes all available providers metadata" do
+      get "/admin/plugins/discourse-ai/ai-llms.json"
+      expect(response).to be_successful
+
+      expect(response.parsed_body["meta"]["providers"]).to contain_exactly(
+        *DiscourseAi::Completions::Llm.provider_names,
+      )
+    end
+  end
+
+  describe "POST #create" do
+    context "with valid attributes" do
+      let(:valid_attrs) do
+        {
+          display_name: "My cool LLM",
+          name: "gpt-3.5",
+          provider: "open_ai",
+          tokenizer: "DiscourseAi::Tokenizer::OpenAiTokenizer",
+          max_prompt_tokens: 16_000,
+        }
+      end
+
+      it "creates a new LLM model" do
+        post "/admin/plugins/discourse-ai/ai-llms.json", params: { ai_llm: valid_attrs }
+
+        expect(response.status).to eq(201)
+
+        created_model = LlmModel.last
+
+        expect(created_model.display_name).to eq(valid_attrs[:display_name])
+        expect(created_model.name).to eq(valid_attrs[:name])
+        expect(created_model.provider).to eq(valid_attrs[:provider])
+        expect(created_model.tokenizer).to eq(valid_attrs[:tokenizer])
+        expect(created_model.max_prompt_tokens).to eq(valid_attrs[:max_prompt_tokens])
+      end
+    end
+  end
+
+  describe "PUT #update" do
+    fab!(:llm_model)
+
+    context "with valid update params" do
+      let(:update_attrs) { { provider: "anthropic" } }
+
+      it "updates the model" do
+        put "/admin/plugins/discourse-ai/ai-llms/#{llm_model.id}.json",
+            params: {
+              ai_llm: update_attrs,
+            }
+
+        expect(response.status).to eq(200)
+        expect(llm_model.reload.provider).to eq(update_attrs[:provider])
+      end
+
+      it "returns a 404 if there is no model with the given ID" do
+        put "/admin/plugins/discourse-ai/ai-llms/9999999.json"
+
+        expect(response.status).to eq(404)
+      end
+    end
+  end
+end