Skip to content

Commit

Permalink
✨ global request options (#1153)
Browse files Browse the repository at this point in the history
* ✨ global request options

* 🐛 fix jira context provider by injecting fetch

* ✨ request options for embeddings providers
  • Loading branch information
sestinj committed Apr 24, 2024
1 parent 9c4a1a4 commit 70952eb
Show file tree
Hide file tree
Showing 22 changed files with 139 additions and 53 deletions.
9 changes: 8 additions & 1 deletion binary/src/core.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { indexDocs } from "core/indexing/docs";
import TransformersJsEmbeddingsProvider from "core/indexing/embeddings/TransformersJsEmbeddingsProvider";
import { CodebaseIndexer, PauseToken } from "core/indexing/indexCodebase";
import { logDevData } from "core/util/devdata";
import { fetchwithRequestOptions } from "core/util/fetchWithOptions";
import historyManager from "core/util/history";
import { Message } from "core/util/messenger";
import { Telemetry } from "core/util/posthog";
Expand Down Expand Up @@ -138,7 +139,11 @@ export class Core {
const config = await this.config();
const items = config.contextProviders
?.find((provider) => provider.description.title === msg.data.title)
?.loadSubmenuItems({ ide: this.ide });
?.loadSubmenuItems({
ide: this.ide,
fetch: (url, init) =>
fetchwithRequestOptions(url, init, config.requestOptions),
});
return items || [];
});
on("context/getContextItems", async (msg) => {
Expand All @@ -160,6 +165,8 @@ export class Core {
ide,
selectedCode: msg.data.selectedCode,
reranker: config.reranker,
fetch: (url, init) =>
fetchwithRequestOptions(url, init, config.requestOptions),
});

Telemetry.capture("useContextProvider", {
Expand Down
3 changes: 2 additions & 1 deletion core/config/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,9 @@ export class ConfigHandler {
const resp = await fetchwithRequestOptions(
new URL(input),
{ ...init },
llm.requestOptions,
{ ...llm.requestOptions, ...this.savedConfig?.requestOptions },
);

if (!resp.ok) {
let text = await resp.text();
if (resp.status === 404 && !resp.url.includes("/v1")) {
Expand Down
7 changes: 6 additions & 1 deletion core/config/load.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import { BaseLLM } from "../llm";
import { llmFromDescription } from "../llm/llms";
import CustomLLMClass from "../llm/llms/CustomLLM";
import { copyOf } from "../util";
import { fetchwithRequestOptions } from "../util/fetchWithOptions";
import mergeJson from "../util/merge";
import {
getConfigJsPath,
Expand Down Expand Up @@ -279,7 +280,11 @@ async function intermediateToFinalConfig(
const { provider, ...options } = embeddingsProviderDescription;
const embeddingsProviderClass = AllEmbeddingsProviders[provider];
if (embeddingsProviderClass) {
config.embeddingsProvider = new embeddingsProviderClass(options);
config.embeddingsProvider = new embeddingsProviderClass(
options,
(url: string | URL, init: any) =>
fetchwithRequestOptions(url, init, config.requestOptions),
);
}
}

Expand Down
6 changes: 6 additions & 0 deletions core/context/providers/GitHubIssuesContextProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ class GitHubIssuesContextProvider extends BaseContextProvider {

const octokit = new Octokit({
auth: this.options?.githubToken,
request: {
fetch: extras.fetch,
},
});

const { owner, repo, issue_number } = JSON.parse(issueId);
Expand Down Expand Up @@ -64,6 +67,9 @@ class GitHubIssuesContextProvider extends BaseContextProvider {

const octokit = new Octokit({
auth: this.options?.githubToken,
request: {
fetch: args.fetch,
},
});

const allIssues = [];
Expand Down
2 changes: 1 addition & 1 deletion core/context/providers/GoogleContextProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class GoogleContextProvider extends BaseContextProvider {
"Content-Type": "application/json",
};

const response = await fetch(url, {
const response = await extras.fetch(url, {
method: "POST",
headers: headers,
body: payload,
Expand Down
26 changes: 11 additions & 15 deletions core/context/providers/HttpContextProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import {
ContextProviderDescription,
ContextProviderExtras,
} from "../..";
import { fetchwithRequestOptions } from "../../util/fetchWithOptions";

class HttpContextProvider extends BaseContextProvider {
static description: ContextProviderDescription = {
Expand All @@ -29,20 +28,17 @@ class HttpContextProvider extends BaseContextProvider {
query: string,
extras: ContextProviderExtras,
): Promise<ContextItem[]> {
const response = await fetchwithRequestOptions(
new URL(this.options.url),
{
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
query: query || "",
fullInput: extras.fullInput,
}),
}
);

const response = await extras.fetch(new URL(this.options.url), {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
query: query || "",
fullInput: extras.fullInput,
}),
});

const json: any = await response.json();
return [
{
Expand Down
22 changes: 12 additions & 10 deletions core/context/providers/JiraIssuesContextProvider/JiraClient.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import { RequestOptions } from "../../..";
import { fetchwithRequestOptions } from "../../../util/fetchWithOptions";
const { convert: adf2md } = require("adf-to-md");

interface JiraClientOptions {
Expand Down Expand Up @@ -85,12 +84,15 @@ export class JiraClient {
};
}

async issue(issueId: string): Promise<Issue> {
async issue(
issueId: string,
customFetch: (url: string | URL, init: any) => Promise<any>,
): Promise<Issue> {
const result = {} as Issue;

const response = await fetchwithRequestOptions(
const response = await customFetch(
new URL(
this.baseUrl + `/issue/${issueId}?fields=description,comment,summary`
this.baseUrl + `/issue/${issueId}?fields=description,comment,summary`,
),
{
method: "GET",
Expand All @@ -99,7 +101,6 @@ export class JiraClient {
...this.authHeader,
},
},
this.options.requestOptions
);

const issue = (await response.json()) as any;
Expand Down Expand Up @@ -133,14 +134,16 @@ export class JiraClient {
return result;
}

async listIssues(): Promise<Array<QueryResult>> {
const response = await fetchwithRequestOptions(
async listIssues(
customFetch: (url: string | URL, init: any) => Promise<any>,
): Promise<Array<QueryResult>> {
const response = await customFetch(
new URL(
this.baseUrl +
`/search?fields=summary&jql=${
this.options.issueQuery ??
`assignee = currentUser() AND resolution = Unresolved order by updated DESC`
}`
}`,
),
{
method: "GET",
Expand All @@ -149,13 +152,12 @@ export class JiraClient {
...this.authHeader,
},
},
this.options.requestOptions
);

if (response.status != 200) {
console.warn(
"Unable to get jira tickets. Response code from API is",
response.status
response.status,
);
return Promise.resolve([]);
}
Expand Down
10 changes: 5 additions & 5 deletions core/context/providers/JiraIssuesContextProvider/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ class JiraIssuesContextProvider extends BaseContextProvider {

async getContextItems(
query: string,
extras: ContextProviderExtras
extras: ContextProviderExtras,
): Promise<ContextItem[]> {
const issueId = query;

const api = this.getApi();
const issue = await api.issue(query);
const issue = await api.issue(query, extras.fetch);

const parts = [
`# Jira Issue ${issue.key}: ${issue.summary}`,
Expand All @@ -48,7 +48,7 @@ class JiraIssuesContextProvider extends BaseContextProvider {
parts.push(
...issue.comments.map((comment) => {
return `### ${comment.author.displayName} on ${comment.created}\n\n${comment.body}`;
})
}),
);
}

Expand All @@ -64,12 +64,12 @@ class JiraIssuesContextProvider extends BaseContextProvider {
}

async loadSubmenuItems(
args: LoadSubmenuItemsArgs
args: LoadSubmenuItemsArgs,
): Promise<ContextSubmenuItem[]> {
const api = await this.getApi();

try {
const issues = await api.listIssues();
const issues = await api.listIssues(args.fetch);

return issues.map((issue) => ({
id: issue.id,
Expand Down
9 changes: 9 additions & 0 deletions core/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -122,17 +122,21 @@ export interface ContextProviderDescription {
type: ContextProviderType;
}

export type FetchFunction = (url: string | URL, init?: any) => Promise<any>;

export interface ContextProviderExtras {
fullInput: string;
embeddingsProvider: EmbeddingsProvider;
reranker: Reranker | undefined;
llm: ILLM;
ide: IDE;
selectedCode: RangeInFile[];
fetch: FetchFunction;
}

export interface LoadSubmenuItemsArgs {
ide: IDE;
fetch: FetchFunction;
}

export interface CustomContextProvider {
Expand Down Expand Up @@ -706,6 +710,7 @@ export interface SerializedContinueConfig {
models: ModelDescription[];
systemMessage?: string;
completionOptions?: BaseCompletionOptions;
requestOptions?: RequestOptions;
slashCommands?: SlashCommandDescription[];
customCommands?: CustomCommand[];
contextProviders?: ContextProviderWithParams[];
Expand Down Expand Up @@ -738,6 +743,8 @@ export interface Config {
systemMessage?: string;
/** The default completion options for all models */
completionOptions?: BaseCompletionOptions;
/** Request options that will be applied to all models and context providers */
requestOptions?: RequestOptions;
/** The list of slash commands that will be available in the sidebar */
slashCommands?: SlashCommand[];
/** Each entry in this array will originally be a ContextProviderWithParams, the same object from your config.json, but you may add CustomContextProviders.
Expand Down Expand Up @@ -769,6 +776,7 @@ export interface ContinueConfig {
models: ILLM[];
systemMessage?: string;
completionOptions?: BaseCompletionOptions;
requestOptions?: RequestOptions;
slashCommands?: SlashCommand[];
contextProviders?: IContextProvider[];
disableSessionTitles?: boolean;
Expand All @@ -787,6 +795,7 @@ export interface BrowserSerializedContinueConfig {
models: ModelDescription[];
systemMessage?: string;
completionOptions?: BaseCompletionOptions;
requestOptions?: RequestOptions;
slashCommands?: SlashCommandDescription[];
contextProviders?: ContextProviderDescription[];
disableIndexing?: boolean;
Expand Down
6 changes: 4 additions & 2 deletions core/indexing/embeddings/BaseEmbeddingsProvider.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
import { EmbedOptions, EmbeddingsProvider } from "../..";
import { EmbedOptions, EmbeddingsProvider, FetchFunction } from "../..";

class BaseEmbeddingsProvider implements EmbeddingsProvider {
options: EmbedOptions;
fetch: FetchFunction;
static defaultOptions: Partial<EmbedOptions> | undefined = undefined;

get id(): string {
throw new Error("Method not implemented.");
}

constructor(options: EmbedOptions) {
constructor(options: EmbedOptions, fetch: FetchFunction) {
this.options = {
...(this.constructor as typeof BaseEmbeddingsProvider).defaultOptions,
...options,
};
this.fetch = fetch;
}

embed(chunks: string[]): Promise<number[][]> {
Expand Down
15 changes: 9 additions & 6 deletions core/indexing/embeddings/DeepInfraEmbeddingsProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@ class DeepInfraEmbeddingsProvider extends BaseEmbeddingsProvider {
async embed(chunks: string[]) {
const fetchWithBackoff = () =>
withExponentialBackoff<Response>(() =>
fetch(`https://api.deepinfra.com/v1/inference/${this.options.model}`, {
method: "POST",
headers: {
Authorization: `bearer ${this.options.apiKey}`,
this.fetch(
`https://api.deepinfra.com/v1/inference/${this.options.model}`,
{
method: "POST",
headers: {
Authorization: `bearer ${this.options.apiKey}`,
},
body: JSON.stringify({ inputs: chunks }),
},
body: JSON.stringify({ inputs: chunks }),
}),
),
);
const resp = await fetchWithBackoff();
const data = await resp.json();
Expand Down
4 changes: 2 additions & 2 deletions core/indexing/embeddings/FreeTrialEmbeddingsProvider.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import fetch, { Response } from "node-fetch";
import { Response } from "node-fetch";
import { EmbedOptions } from "../..";
import { getHeaders } from "../../continueServer/stubs/headers";
import { SERVER_URL } from "../../util/parameters";
Expand Down Expand Up @@ -31,7 +31,7 @@ class FreeTrialEmbeddingsProvider extends BaseEmbeddingsProvider {
batchedChunks.map(async (batch) => {
const fetchWithBackoff = () =>
withExponentialBackoff<Response>(() =>
fetch(new URL("embeddings", SERVER_URL), {
this.fetch(new URL("embeddings", SERVER_URL), {
method: "POST",
body: JSON.stringify({
input: batch,
Expand Down
12 changes: 8 additions & 4 deletions core/indexing/embeddings/OllamaEmbeddingsProvider.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import { EmbedOptions } from "../..";
import { EmbedOptions, FetchFunction } from "../..";
import { withExponentialBackoff } from "../../util/withExponentialBackoff";
import BaseEmbeddingsProvider from "./BaseEmbeddingsProvider";

async function embedOne(chunk: string, options: EmbedOptions) {
async function embedOne(
chunk: string,
options: EmbedOptions,
customFetch: FetchFunction,
) {
const fetchWithBackoff = () =>
withExponentialBackoff<Response>(() =>
fetch(new URL("api/embeddings", options.apiBase), {
customFetch(new URL("api/embeddings", options.apiBase), {
method: "POST",
body: JSON.stringify({
model: options.model,
Expand Down Expand Up @@ -33,7 +37,7 @@ class OllamaEmbeddingsProvider extends BaseEmbeddingsProvider {
async embed(chunks: string[]) {
const results: any = [];
for (const chunk of chunks) {
results.push(await embedOne(chunk, this.options));
results.push(await embedOne(chunk, this.options, this.fetch));
}
return results;
}
Expand Down
4 changes: 2 additions & 2 deletions core/indexing/embeddings/OpenAIEmbeddingsProvider.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import fetch, { Response } from "node-fetch";
import { Response } from "node-fetch";
import { EmbedOptions } from "../..";
import { withExponentialBackoff } from "../../util/withExponentialBackoff";
import BaseEmbeddingsProvider from "./BaseEmbeddingsProvider";
Expand Down Expand Up @@ -37,7 +37,7 @@ class OpenAIEmbeddingsProvider extends BaseEmbeddingsProvider {
batchedChunks.map(async (batch) => {
const fetchWithBackoff = () =>
withExponentialBackoff<Response>(() =>
fetch(new URL("embeddings", this.options.apiBase), {
this.fetch(new URL("embeddings", this.options.apiBase), {
method: "POST",
body: JSON.stringify({
input: batch,
Expand Down
Loading

0 comments on commit 70952eb

Please sign in to comment.