diff --git a/core/config/types.ts b/core/config/types.ts index 8c64de1ab1a..bd2747083dc 100644 --- a/core/config/types.ts +++ b/core/config/types.ts @@ -853,6 +853,8 @@ declare global { proxy?: string; headers?: { [key: string]: string }; extraBodyProperties?: { [key: string]: any }; + keepAlive?: number; + options?: { [key: string]: any }; noProxy?: string[]; clientCertificate?: ClientCertificateOptions; } diff --git a/core/llm/llms/Ollama.test.ts b/core/llm/llms/Ollama.test.ts index 78d69dd157a..fe12b63c662 100644 --- a/core/llm/llms/Ollama.test.ts +++ b/core/llm/llms/Ollama.test.ts @@ -8,8 +8,12 @@ import Ollama from "./Ollama.js"; function createOllama(): Ollama { // Create instance without triggering constructor's fetch call const instance = Object.create(Ollama.prototype); + instance.apiBase = "http://localhost:11434/"; instance.model = "test-model"; + instance.modelMap = {}; instance.completionOptions = {}; + instance.requestOptions = {}; + instance._contextLength = 4096; instance.fetch = jest.fn(); return instance; } @@ -223,4 +227,116 @@ describe("Ollama", () => { expect(result[1].role).toBe("tool"); }); }); + + describe("request option overrides", () => { + let ollama: Ollama; + + beforeEach(() => { + ollama = createOllama(); + }); + + it("should merge requestOptions.options into generate requests", () => { + (ollama as any).requestOptions = { + options: { + num_gpu: 20, + num_thread: 8, + keep_alive: -1, + repeat_penalty: 1.2, + }, + }; + + const result = (ollama as any)._getGenerateOptions({}, "hello"); + + expect(result.options).toMatchObject({ + num_ctx: 4096, + num_gpu: 20, + num_thread: 8, + repeat_penalty: 1.2, + }); + expect(result.keep_alive).toBe(-1); + }); + + it("should let completion options override request option defaults", () => { + (ollama as any).requestOptions = { + keepAlive: -1, + options: { + num_gpu: 20, + num_thread: 8, + }, + }; + + const result = (ollama as any)._getGenerateOptions( + { + keepAlive: 120, + numGpu: 4, + numThreads: 2, + }, + "hello", + ); + + expect(result.options).toMatchObject({ + num_ctx: 4096, + num_gpu: 4, + num_thread: 2, + }); + expect(result.keep_alive).toBe(120); + }); + }); + + describe("tool attachment", () => { + const tool = { + type: "function", + function: { + name: "get_weather", + description: "Get weather", + parameters: { + type: "object", + properties: { + city: { type: "string" }, + }, + }, + }, + } as any; + + async function runChatRequest(ollama: Ollama) { + (ollama as any).modelInfoPromise = Promise.resolve(); + (ollama as any).fetch = jest.fn().mockResolvedValue({ + status: 200, + json: async () => ({ + message: { role: "assistant", content: "ok" }, + }), + }); + + const messages: ChatMessage[] = [{ role: "user", content: "hello" }]; + for await (const _ of (ollama as any)._streamChat( + messages, + new AbortController().signal, + { tools: [tool], stream: false }, + )) { + } + + const [, init] = ((ollama as any).fetch as jest.Mock).mock.calls[0]; + return JSON.parse(init.body); + } + + it("should skip tools when the model template does not support them", async () => { + const ollama = createOllama(); + (ollama as any).templateSupportsTools = false; + + const requestBody = await runChatRequest(ollama); + + expect(requestBody.tools).toBeUndefined(); + }); + + it("should let explicit capabilities override the template tool gate", async () => { + const ollama = createOllama(); + (ollama as any).templateSupportsTools = false; + (ollama as any).capabilities = { tools: true }; + + const requestBody = await runChatRequest(ollama); + + expect(requestBody.tools).toHaveLength(1); + expect(requestBody.tools[0].function.name).toBe("get_weather"); + }); + }); }); diff --git a/core/llm/llms/Ollama.ts b/core/llm/llms/Ollama.ts index 4bcd9fb1e0f..acc9880f812 100644 --- a/core/llm/llms/Ollama.ts +++ b/core/llm/llms/Ollama.ts @@ -161,6 +161,7 @@ class Ollama extends BaseLLM implements ModelInstaller { private static modelsBeingInstalledMutex = new Mutex(); private fimSupported: boolean = false; + private templateSupportsTools: boolean | undefined = undefined; private modelInfoPromise: Promise | undefined = undefined; private explicitContextLength: boolean; @@ -240,6 +241,9 @@ class Ollama extends BaseLLM implements ModelInstaller { * it's a good indication the model supports FIM. */ this.fimSupported = !!body?.template?.includes(".Suffix"); + if (body?.template) { + this.templateSupportsTools = body.template.includes(".Tools"); + } }) .catch((e) => { // console.warn("Error calling the Ollama /api/show endpoint: ", e); @@ -321,6 +325,54 @@ class Ollama extends BaseLLM implements ModelInstaller { }; } + private _getRequestOptionOverrides(): { + modelFileParams: Partial; + keepAlive?: number; + } { + const requestOptions = this.requestOptions as + | (typeof this.requestOptions & { + keepAlive?: number; + options?: Record; + }) + | undefined; + + const rawOptions = + requestOptions?.options && typeof requestOptions.options === "object" + ? requestOptions.options + : {}; + const { keep_alive, ...modelFileParams } = rawOptions; + + return { + modelFileParams: modelFileParams as Partial, + keepAlive: + requestOptions?.keepAlive ?? + (typeof keep_alive === "number" ? keep_alive : undefined), + }; + } + + private _getBaseOptions(options: CompletionOptions): OllamaBaseOptions { + const { modelFileParams, keepAlive } = this._getRequestOptionOverrides(); + const completionModelFileParams = Object.fromEntries( + Object.entries(this._getModelFileParams(options)).filter( + ([_, value]) => value !== undefined, + ), + ) as Partial; + + return { + model: this._getModel(), + options: { + ...modelFileParams, + ...completionModelFileParams, + }, + keep_alive: options.keepAlive ?? keepAlive ?? 60 * 30, + stream: options.stream, + }; + } + + private _shouldAttachTools(): boolean { + return this.capabilities?.tools ?? this.templateSupportsTools ?? true; + } + private _convertToOllamaMessage(message: ChatMessage): OllamaChatMessage { const ollamaMessage: OllamaChatMessage = { role: message.role, @@ -394,13 +446,10 @@ class Ollama extends BaseLLM implements ModelInstaller { suffix?: string, ): OllamaRawOptions { return { - model: this._getModel(), + ...this._getBaseOptions(options), prompt, suffix, raw: options.raw, - options: this._getModelFileParams(options), - keep_alive: options.keepAlive ?? 60 * 30, // 30 minutes - stream: options.stream, // Not supported yet: context, images, system, template, format }; } @@ -503,15 +552,16 @@ class Ollama extends BaseLLM implements ModelInstaller { messages.map(this._convertToOllamaMessage), ); const chatOptions: OllamaChatOptions = { - model: this._getModel(), + ...this._getBaseOptions(options), messages: ollamaMessages, - options: this._getModelFileParams(options), think: options.reasoning, - keep_alive: options.keepAlive ?? 60 * 30, // 30 minutes - stream: options.stream, // format: options.format, // Not currently in base completion options }; - if (options.tools?.length && ollamaMessages.at(-1)?.role === "user") { + if ( + options.tools?.length && + ollamaMessages.at(-1)?.role === "user" && + this._shouldAttachTools() + ) { chatOptions.tools = options.tools.map((tool) => ({ type: "function", function: { diff --git a/packages/config-types/src/index.ts b/packages/config-types/src/index.ts index 8561500e662..61b32436229 100644 --- a/packages/config-types/src/index.ts +++ b/packages/config-types/src/index.ts @@ -35,6 +35,8 @@ export const requestOptionsSchema = z.object({ proxy: z.string().optional(), headers: z.record(z.string()).optional(), extraBodyProperties: z.record(z.any()).optional(), + keepAlive: z.number().optional(), + options: z.record(z.any()).optional(), noProxy: z.array(z.string()).optional(), clientCertificate: clientCertificateOptionsSchema.optional(), }); @@ -97,6 +99,8 @@ export const modelDescriptionSchema = z.object({ proxy: z.string().optional(), headers: z.record(z.string()).optional(), extraBodyProperties: z.record(z.any()).optional(), + keepAlive: z.number().optional(), + options: z.record(z.any()).optional(), noProxy: z.array(z.string()).optional(), }) .optional(), diff --git a/packages/config-yaml/src/schemas/models.test.ts b/packages/config-yaml/src/schemas/models.test.ts new file mode 100644 index 00000000000..092581fc554 --- /dev/null +++ b/packages/config-yaml/src/schemas/models.test.ts @@ -0,0 +1,24 @@ +import { describe, expect, it } from "@jest/globals"; +import { requestOptionsSchema } from "./models.js"; + +describe("requestOptionsSchema", () => { + it("should preserve Ollama request body overrides", () => { + const result = requestOptionsSchema.parse({ + keepAlive: -1, + options: { + num_gpu: 20, + num_thread: 8, + keep_alive: -1, + }, + }); + + expect(result).toEqual({ + keepAlive: -1, + options: { + num_gpu: 20, + num_thread: 8, + keep_alive: -1, + }, + }); + }); +}); diff --git a/packages/config-yaml/src/schemas/models.ts b/packages/config-yaml/src/schemas/models.ts index ee20d5a0540..78650fec950 100644 --- a/packages/config-yaml/src/schemas/models.ts +++ b/packages/config-yaml/src/schemas/models.ts @@ -16,6 +16,8 @@ export const requestOptionsSchema = z.object({ proxy: z.string().optional(), headers: z.record(z.string()).optional(), extraBodyProperties: z.record(z.any()).optional(), + keepAlive: z.number().optional(), + options: z.record(z.any()).optional(), noProxy: z.array(z.string()).optional(), clientCertificate: clientCertificateOptionsSchema.optional(), });