diff --git a/binary/src/core.ts b/binary/src/core.ts index 63caec3476..5f92df0f96 100644 --- a/binary/src/core.ts +++ b/binary/src/core.ts @@ -6,6 +6,7 @@ import { indexDocs } from "core/indexing/docs"; import TransformersJsEmbeddingsProvider from "core/indexing/embeddings/TransformersJsEmbeddingsProvider"; import { CodebaseIndexer, PauseToken } from "core/indexing/indexCodebase"; import { logDevData } from "core/util/devdata"; +import { fetchwithRequestOptions } from "core/util/fetchWithOptions"; import historyManager from "core/util/history"; import { Message } from "core/util/messenger"; import { Telemetry } from "core/util/posthog"; @@ -138,7 +139,11 @@ export class Core { const config = await this.config(); const items = config.contextProviders ?.find((provider) => provider.description.title === msg.data.title) - ?.loadSubmenuItems({ ide: this.ide }); + ?.loadSubmenuItems({ + ide: this.ide, + fetch: (url, init) => + fetchwithRequestOptions(url, init, config.requestOptions), + }); return items || []; }); on("context/getContextItems", async (msg) => { @@ -160,6 +165,8 @@ export class Core { ide, selectedCode: msg.data.selectedCode, reranker: config.reranker, + fetch: (url, init) => + fetchwithRequestOptions(url, init, config.requestOptions), }); Telemetry.capture("useContextProvider", { diff --git a/core/config/handler.ts b/core/config/handler.ts index c64eea703e..4963cfe595 100644 --- a/core/config/handler.ts +++ b/core/config/handler.ts @@ -98,8 +98,9 @@ export class ConfigHandler { const resp = await fetchwithRequestOptions( new URL(input), { ...init }, - llm.requestOptions, + { ...llm.requestOptions, ...this.savedConfig?.requestOptions }, ); + if (!resp.ok) { let text = await resp.text(); if (resp.status === 404 && !resp.url.includes("/v1")) { diff --git a/core/config/load.ts b/core/config/load.ts index cdd91aacfc..ab4528e8af 100644 --- a/core/config/load.ts +++ b/core/config/load.ts @@ -32,6 +32,7 @@ import { BaseLLM } from "../llm"; import { llmFromDescription } from "../llm/llms"; import CustomLLMClass from "../llm/llms/CustomLLM"; import { copyOf } from "../util"; +import { fetchwithRequestOptions } from "../util/fetchWithOptions"; import mergeJson from "../util/merge"; import { getConfigJsPath, @@ -279,7 +280,11 @@ async function intermediateToFinalConfig( const { provider, ...options } = embeddingsProviderDescription; const embeddingsProviderClass = AllEmbeddingsProviders[provider]; if (embeddingsProviderClass) { - config.embeddingsProvider = new embeddingsProviderClass(options); + config.embeddingsProvider = new embeddingsProviderClass( + options, + (url: string | URL, init: any) => + fetchwithRequestOptions(url, init, config.requestOptions), + ); } } diff --git a/core/context/providers/GitHubIssuesContextProvider.ts b/core/context/providers/GitHubIssuesContextProvider.ts index c3464af750..f983dce6d6 100644 --- a/core/context/providers/GitHubIssuesContextProvider.ts +++ b/core/context/providers/GitHubIssuesContextProvider.ts @@ -24,6 +24,9 @@ class GitHubIssuesContextProvider extends BaseContextProvider { const octokit = new Octokit({ auth: this.options?.githubToken, + request: { + fetch: extras.fetch, + }, }); const { owner, repo, issue_number } = JSON.parse(issueId); @@ -64,6 +67,9 @@ class GitHubIssuesContextProvider extends BaseContextProvider { const octokit = new Octokit({ auth: this.options?.githubToken, + request: { + fetch: args.fetch, + }, }); const allIssues = []; diff --git a/core/context/providers/GoogleContextProvider.ts b/core/context/providers/GoogleContextProvider.ts index c2ee4877b2..3ee81654cc 100644 --- a/core/context/providers/GoogleContextProvider.ts +++ b/core/context/providers/GoogleContextProvider.ts @@ -32,7 +32,7 @@ class GoogleContextProvider extends BaseContextProvider { "Content-Type": "application/json", }; - const response = await fetch(url, { + const response = await extras.fetch(url, { method: "POST", headers: headers, body: payload, diff --git a/core/context/providers/HttpContextProvider.ts b/core/context/providers/HttpContextProvider.ts index e66869828b..05b5a9a8f3 100644 --- a/core/context/providers/HttpContextProvider.ts +++ b/core/context/providers/HttpContextProvider.ts @@ -4,7 +4,6 @@ import { ContextProviderDescription, ContextProviderExtras, } from "../.."; -import { fetchwithRequestOptions } from "../../util/fetchWithOptions"; class HttpContextProvider extends BaseContextProvider { static description: ContextProviderDescription = { @@ -29,20 +28,17 @@ class HttpContextProvider extends BaseContextProvider { query: string, extras: ContextProviderExtras, ): Promise { - const response = await fetchwithRequestOptions( - new URL(this.options.url), - { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({ - query: query || "", - fullInput: extras.fullInput, - }), - } - ); - + const response = await extras.fetch(new URL(this.options.url), { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + query: query || "", + fullInput: extras.fullInput, + }), + }); + const json: any = await response.json(); return [ { diff --git a/core/context/providers/JiraIssuesContextProvider/JiraClient.ts b/core/context/providers/JiraIssuesContextProvider/JiraClient.ts index 573c5549f1..d9efaa6018 100644 --- a/core/context/providers/JiraIssuesContextProvider/JiraClient.ts +++ b/core/context/providers/JiraIssuesContextProvider/JiraClient.ts @@ -1,5 +1,4 @@ import { RequestOptions } from "../../.."; -import { fetchwithRequestOptions } from "../../../util/fetchWithOptions"; const { convert: adf2md } = require("adf-to-md"); interface JiraClientOptions { @@ -85,12 +84,15 @@ export class JiraClient { }; } - async issue(issueId: string): Promise { + async issue( + issueId: string, + customFetch: (url: string | URL, init: any) => Promise, + ): Promise { const result = {} as Issue; - const response = await fetchwithRequestOptions( + const response = await customFetch( new URL( - this.baseUrl + `/issue/${issueId}?fields=description,comment,summary` + this.baseUrl + `/issue/${issueId}?fields=description,comment,summary`, ), { method: "GET", @@ -99,7 +101,6 @@ export class JiraClient { ...this.authHeader, }, }, - this.options.requestOptions ); const issue = (await response.json()) as any; @@ -133,14 +134,16 @@ export class JiraClient { return result; } - async listIssues(): Promise> { - const response = await fetchwithRequestOptions( + async listIssues( + customFetch: (url: string | URL, init: any) => Promise, + ): Promise> { + const response = await customFetch( new URL( this.baseUrl + `/search?fields=summary&jql=${ this.options.issueQuery ?? `assignee = currentUser() AND resolution = Unresolved order by updated DESC` - }` + }`, ), { method: "GET", @@ -149,13 +152,12 @@ export class JiraClient { ...this.authHeader, }, }, - this.options.requestOptions ); if (response.status != 200) { console.warn( "Unable to get jira tickets. Response code from API is", - response.status + response.status, ); return Promise.resolve([]); } diff --git a/core/context/providers/JiraIssuesContextProvider/index.ts b/core/context/providers/JiraIssuesContextProvider/index.ts index 4ee7a4b6a1..e1ca1841e2 100644 --- a/core/context/providers/JiraIssuesContextProvider/index.ts +++ b/core/context/providers/JiraIssuesContextProvider/index.ts @@ -29,12 +29,12 @@ class JiraIssuesContextProvider extends BaseContextProvider { async getContextItems( query: string, - extras: ContextProviderExtras + extras: ContextProviderExtras, ): Promise { const issueId = query; const api = this.getApi(); - const issue = await api.issue(query); + const issue = await api.issue(query, extras.fetch); const parts = [ `# Jira Issue ${issue.key}: ${issue.summary}`, @@ -48,7 +48,7 @@ class JiraIssuesContextProvider extends BaseContextProvider { parts.push( ...issue.comments.map((comment) => { return `### ${comment.author.displayName} on ${comment.created}\n\n${comment.body}`; - }) + }), ); } @@ -64,12 +64,12 @@ class JiraIssuesContextProvider extends BaseContextProvider { } async loadSubmenuItems( - args: LoadSubmenuItemsArgs + args: LoadSubmenuItemsArgs, ): Promise { const api = await this.getApi(); try { - const issues = await api.listIssues(); + const issues = await api.listIssues(args.fetch); return issues.map((issue) => ({ id: issue.id, diff --git a/core/index.d.ts b/core/index.d.ts index de35a97d9c..df650f33dd 100644 --- a/core/index.d.ts +++ b/core/index.d.ts @@ -122,6 +122,8 @@ export interface ContextProviderDescription { type: ContextProviderType; } +export type FetchFunction = (url: string | URL, init?: any) => Promise; + export interface ContextProviderExtras { fullInput: string; embeddingsProvider: EmbeddingsProvider; @@ -129,10 +131,12 @@ export interface ContextProviderExtras { llm: ILLM; ide: IDE; selectedCode: RangeInFile[]; + fetch: FetchFunction; } export interface LoadSubmenuItemsArgs { ide: IDE; + fetch: FetchFunction; } export interface CustomContextProvider { @@ -706,6 +710,7 @@ export interface SerializedContinueConfig { models: ModelDescription[]; systemMessage?: string; completionOptions?: BaseCompletionOptions; + requestOptions?: RequestOptions; slashCommands?: SlashCommandDescription[]; customCommands?: CustomCommand[]; contextProviders?: ContextProviderWithParams[]; @@ -738,6 +743,8 @@ export interface Config { systemMessage?: string; /** The default completion options for all models */ completionOptions?: BaseCompletionOptions; + /** Request options that will be applied to all models and context providers */ + requestOptions?: RequestOptions; /** The list of slash commands that will be available in the sidebar */ slashCommands?: SlashCommand[]; /** Each entry in this array will originally be a ContextProviderWithParams, the same object from your config.json, but you may add CustomContextProviders. @@ -769,6 +776,7 @@ export interface ContinueConfig { models: ILLM[]; systemMessage?: string; completionOptions?: BaseCompletionOptions; + requestOptions?: RequestOptions; slashCommands?: SlashCommand[]; contextProviders?: IContextProvider[]; disableSessionTitles?: boolean; @@ -787,6 +795,7 @@ export interface BrowserSerializedContinueConfig { models: ModelDescription[]; systemMessage?: string; completionOptions?: BaseCompletionOptions; + requestOptions?: RequestOptions; slashCommands?: SlashCommandDescription[]; contextProviders?: ContextProviderDescription[]; disableIndexing?: boolean; diff --git a/core/indexing/embeddings/BaseEmbeddingsProvider.ts b/core/indexing/embeddings/BaseEmbeddingsProvider.ts index a4e46ebb16..e15bf9c762 100644 --- a/core/indexing/embeddings/BaseEmbeddingsProvider.ts +++ b/core/indexing/embeddings/BaseEmbeddingsProvider.ts @@ -1,18 +1,20 @@ -import { EmbedOptions, EmbeddingsProvider } from "../.."; +import { EmbedOptions, EmbeddingsProvider, FetchFunction } from "../.."; class BaseEmbeddingsProvider implements EmbeddingsProvider { options: EmbedOptions; + fetch: FetchFunction; static defaultOptions: Partial | undefined = undefined; get id(): string { throw new Error("Method not implemented."); } - constructor(options: EmbedOptions) { + constructor(options: EmbedOptions, fetch: FetchFunction) { this.options = { ...(this.constructor as typeof BaseEmbeddingsProvider).defaultOptions, ...options, }; + this.fetch = fetch; } embed(chunks: string[]): Promise { diff --git a/core/indexing/embeddings/DeepInfraEmbeddingsProvider.ts b/core/indexing/embeddings/DeepInfraEmbeddingsProvider.ts index 86d0bed507..60560fe3a8 100644 --- a/core/indexing/embeddings/DeepInfraEmbeddingsProvider.ts +++ b/core/indexing/embeddings/DeepInfraEmbeddingsProvider.ts @@ -14,13 +14,16 @@ class DeepInfraEmbeddingsProvider extends BaseEmbeddingsProvider { async embed(chunks: string[]) { const fetchWithBackoff = () => withExponentialBackoff(() => - fetch(`https://api.deepinfra.com/v1/inference/${this.options.model}`, { - method: "POST", - headers: { - Authorization: `bearer ${this.options.apiKey}`, + this.fetch( + `https://api.deepinfra.com/v1/inference/${this.options.model}`, + { + method: "POST", + headers: { + Authorization: `bearer ${this.options.apiKey}`, + }, + body: JSON.stringify({ inputs: chunks }), }, - body: JSON.stringify({ inputs: chunks }), - }), + ), ); const resp = await fetchWithBackoff(); const data = await resp.json(); diff --git a/core/indexing/embeddings/FreeTrialEmbeddingsProvider.ts b/core/indexing/embeddings/FreeTrialEmbeddingsProvider.ts index a7821ebdb1..b5622788de 100644 --- a/core/indexing/embeddings/FreeTrialEmbeddingsProvider.ts +++ b/core/indexing/embeddings/FreeTrialEmbeddingsProvider.ts @@ -1,4 +1,4 @@ -import fetch, { Response } from "node-fetch"; +import { Response } from "node-fetch"; import { EmbedOptions } from "../.."; import { getHeaders } from "../../continueServer/stubs/headers"; import { SERVER_URL } from "../../util/parameters"; @@ -31,7 +31,7 @@ class FreeTrialEmbeddingsProvider extends BaseEmbeddingsProvider { batchedChunks.map(async (batch) => { const fetchWithBackoff = () => withExponentialBackoff(() => - fetch(new URL("embeddings", SERVER_URL), { + this.fetch(new URL("embeddings", SERVER_URL), { method: "POST", body: JSON.stringify({ input: batch, diff --git a/core/indexing/embeddings/OllamaEmbeddingsProvider.ts b/core/indexing/embeddings/OllamaEmbeddingsProvider.ts index 83ef42f99c..e66d2f3e80 100644 --- a/core/indexing/embeddings/OllamaEmbeddingsProvider.ts +++ b/core/indexing/embeddings/OllamaEmbeddingsProvider.ts @@ -1,11 +1,15 @@ -import { EmbedOptions } from "../.."; +import { EmbedOptions, FetchFunction } from "../.."; import { withExponentialBackoff } from "../../util/withExponentialBackoff"; import BaseEmbeddingsProvider from "./BaseEmbeddingsProvider"; -async function embedOne(chunk: string, options: EmbedOptions) { +async function embedOne( + chunk: string, + options: EmbedOptions, + customFetch: FetchFunction, +) { const fetchWithBackoff = () => withExponentialBackoff(() => - fetch(new URL("api/embeddings", options.apiBase), { + customFetch(new URL("api/embeddings", options.apiBase), { method: "POST", body: JSON.stringify({ model: options.model, @@ -33,7 +37,7 @@ class OllamaEmbeddingsProvider extends BaseEmbeddingsProvider { async embed(chunks: string[]) { const results: any = []; for (const chunk of chunks) { - results.push(await embedOne(chunk, this.options)); + results.push(await embedOne(chunk, this.options, this.fetch)); } return results; } diff --git a/core/indexing/embeddings/OpenAIEmbeddingsProvider.ts b/core/indexing/embeddings/OpenAIEmbeddingsProvider.ts index 1bb2b5fdfe..36987cb80a 100644 --- a/core/indexing/embeddings/OpenAIEmbeddingsProvider.ts +++ b/core/indexing/embeddings/OpenAIEmbeddingsProvider.ts @@ -1,4 +1,4 @@ -import fetch, { Response } from "node-fetch"; +import { Response } from "node-fetch"; import { EmbedOptions } from "../.."; import { withExponentialBackoff } from "../../util/withExponentialBackoff"; import BaseEmbeddingsProvider from "./BaseEmbeddingsProvider"; @@ -37,7 +37,7 @@ class OpenAIEmbeddingsProvider extends BaseEmbeddingsProvider { batchedChunks.map(async (batch) => { const fetchWithBackoff = () => withExponentialBackoff(() => - fetch(new URL("embeddings", this.options.apiBase), { + this.fetch(new URL("embeddings", this.options.apiBase), { method: "POST", body: JSON.stringify({ input: batch, diff --git a/core/indexing/embeddings/TransformersJsEmbeddingsProvider.ts b/core/indexing/embeddings/TransformersJsEmbeddingsProvider.ts index 249aaf36fb..6efcf0d915 100644 --- a/core/indexing/embeddings/TransformersJsEmbeddingsProvider.ts +++ b/core/indexing/embeddings/TransformersJsEmbeddingsProvider.ts @@ -32,7 +32,7 @@ export class TransformersJsEmbeddingsProvider extends BaseEmbeddingsProvider { static MaxGroupSize: number = 4; constructor() { - super({ model: "all-MiniLM-L2-v6" }); + super({ model: "all-MiniLM-L2-v6" }, () => Promise.resolve(null)); } get id(): string { diff --git a/core/util/fetchWithOptions.ts b/core/util/fetchWithOptions.ts index 0ec07c2d77..6fe1e419c3 100644 --- a/core/util/fetchWithOptions.ts +++ b/core/util/fetchWithOptions.ts @@ -8,10 +8,15 @@ import tls from "tls"; import { RequestOptions } from ".."; export function fetchwithRequestOptions( - url: URL, + url_: URL | string, init: RequestInit, requestOptions?: RequestOptions, ): Promise { + let url = url_; + if (typeof url === "string") { + url = new URL(url); + } + const TIMEOUT = 7200; // 7200 seconds = 2 hours let globalCerts: string[] = []; diff --git a/docs/static/schemas/config.json b/docs/static/schemas/config.json index 682af2661e..aa062fb7ca 100644 --- a/docs/static/schemas/config.json +++ b/docs/static/schemas/config.json @@ -1621,6 +1621,15 @@ } ] }, + "requestOptions": { + "title": "Request Options", + "description": "Default request options for all fetch requests from models and context providers. These will be overriden by any model-specific request options.", + "allOf": [ + { + "$ref": "#/definitions/RequestOptions" + } + ] + }, "slashCommands": { "title": "Slash Commands", "markdownDescription": "An array of slash commands that let you take custom actions from the sidebar. Learn more in the [documentation](https://continue.dev/docs/customization/slash-commands).", diff --git a/extensions/intellij/src/main/resources/config_schema.json b/extensions/intellij/src/main/resources/config_schema.json index 682af2661e..aa062fb7ca 100644 --- a/extensions/intellij/src/main/resources/config_schema.json +++ b/extensions/intellij/src/main/resources/config_schema.json @@ -1621,6 +1621,15 @@ } ] }, + "requestOptions": { + "title": "Request Options", + "description": "Default request options for all fetch requests from models and context providers. These will be overriden by any model-specific request options.", + "allOf": [ + { + "$ref": "#/definitions/RequestOptions" + } + ] + }, "slashCommands": { "title": "Slash Commands", "markdownDescription": "An array of slash commands that let you take custom actions from the sidebar. Learn more in the [documentation](https://continue.dev/docs/customization/slash-commands).", diff --git a/extensions/vscode/config_schema.json b/extensions/vscode/config_schema.json index 682af2661e..aa062fb7ca 100644 --- a/extensions/vscode/config_schema.json +++ b/extensions/vscode/config_schema.json @@ -1621,6 +1621,15 @@ } ] }, + "requestOptions": { + "title": "Request Options", + "description": "Default request options for all fetch requests from models and context providers. These will be overriden by any model-specific request options.", + "allOf": [ + { + "$ref": "#/definitions/RequestOptions" + } + ] + }, "slashCommands": { "title": "Slash Commands", "markdownDescription": "An array of slash commands that let you take custom actions from the sidebar. Learn more in the [documentation](https://continue.dev/docs/customization/slash-commands).", diff --git a/extensions/vscode/continue_rc_schema.json b/extensions/vscode/continue_rc_schema.json index d560b5e3c0..f8b82f8ad1 100644 --- a/extensions/vscode/continue_rc_schema.json +++ b/extensions/vscode/continue_rc_schema.json @@ -1810,6 +1810,15 @@ } ] }, + "requestOptions": { + "title": "Request Options", + "description": "Default request options for all fetch requests from models and context providers. These will be overriden by any model-specific request options.", + "allOf": [ + { + "$ref": "#/definitions/RequestOptions" + } + ] + }, "slashCommands": { "title": "Slash Commands", "markdownDescription": "An array of slash commands that let you take custom actions from the sidebar. Learn more in the [documentation](https://continue.dev/docs/customization/slash-commands).", diff --git a/extensions/vscode/src/commands.ts b/extensions/vscode/src/commands.ts index fa85a3ea26..57e01d9e77 100644 --- a/extensions/vscode/src/commands.ts +++ b/extensions/vscode/src/commands.ts @@ -7,6 +7,7 @@ import { IDE } from "core"; import { AutocompleteOutcome } from "core/autocomplete/completionProvider"; import { ConfigHandler } from "core/config/handler"; import { logDevData } from "core/util/devdata"; +import { fetchwithRequestOptions } from "core/util/fetchWithOptions"; import { getConfigJsonPath } from "core/util/paths"; import { Telemetry } from "core/util/posthog"; import { ContinueGUIWebviewViewProvider } from "./debugPanel"; @@ -289,6 +290,8 @@ const commandsMap: ( llm, fullInput: text || "", selectedCode: [], + fetch: (url, init) => + fetchwithRequestOptions(url, init, config.requestOptions), }); }) || [], ) diff --git a/extensions/vscode/src/webviewProtocol.ts b/extensions/vscode/src/webviewProtocol.ts index 092c65e991..a871238364 100644 --- a/extensions/vscode/src/webviewProtocol.ts +++ b/extensions/vscode/src/webviewProtocol.ts @@ -10,6 +10,7 @@ import { indexDocs } from "core/indexing/docs"; import TransformersJsEmbeddingsProvider from "core/indexing/embeddings/TransformersJsEmbeddingsProvider"; import { logDevData } from "core/util/devdata"; import { DevDataSqliteDb } from "core/util/devdataSqlite"; +import { fetchwithRequestOptions } from "core/util/fetchWithOptions"; import historyManager from "core/util/history"; import { Message } from "core/util/messenger"; import { editConfigJson, getConfigJsonPath } from "core/util/paths"; @@ -470,7 +471,11 @@ export class VsCodeWebviewProtocol { } try { - const items = await provider.loadSubmenuItems({ ide }); + const items = await provider.loadSubmenuItems({ + ide, + fetch: (url, init) => + fetchwithRequestOptions(url, init, config.requestOptions), + }); return items; } catch (e) { vscode.window.showErrorMessage( @@ -508,6 +513,8 @@ export class VsCodeWebviewProtocol { fullInput, ide, selectedCode, + fetch: (url, init) => + fetchwithRequestOptions(url, init, config.requestOptions), }); Telemetry.capture("useContextProvider", {