Add new providers: Koboldcpp and Goinfer #22

Merged · 5 commits · Aug 30, 2023
5 changes: 4 additions & 1 deletion package.json
@@ -266,7 +266,9 @@
"default": "openai",
"enum": [
"anthropic",
"openai"
"openai",
"goinfer",
"koboldcpp"
],
"description": "Which provider should this command use?"
}
@@ -348,6 +350,7 @@
"cheerio": "1.0.0-rc.12",
"fast-glob": "^3.2.12",
"fetch": "^1.1.0",
"llama-tokenizer-js": "^1.1.3",
"node-fetch": "^3.3.1"
}
}
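
The only new runtime dependency is `llama-tokenizer-js`, which presumably backs the `llamaMaxTokens` helper the Goinfer provider imports from `../utils` to budget how many tokens are left for generation. That helper is not part of this diff, so the following is only a minimal sketch, under the assumption that it counts the prompt's tokens and subtracts them from the context size:

```typescript
// Minimal sketch (assumption, not from this diff) of a llamaMaxTokens helper
// built on the new llama-tokenizer-js dependency: count the prompt's LLaMA
// tokens and return how much of the context window remains for generation.
import llamaTokenizer from "llama-tokenizer-js";

export function llamaMaxTokens(prompt: string, ctx: number): number {
  const promptTokens = llamaTokenizer.encode(prompt).length;
  return Math.max(ctx - promptTokens, 0);
}
```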
32 changes: 31 additions & 1 deletion pnpm-lock.yaml

Some generated files are not rendered by default.

155 changes: 155 additions & 0 deletions src/providers/goinfer.ts
@@ -0,0 +1,155 @@
/* eslint-disable unused-imports/no-unused-vars */
import * as vscode from "vscode";

import { type PostableViewProvider, type ProviderResponse, type Provider } from ".";
import { Client, type InferParams, type InferResult, type StreamedMessage, DEFAULT_CTX, DEFAULT_TEMPLATE } from "./sdks/goinfer";
import { type Command } from "../templates/render";
import { handleResponseCallbackType } from "../templates/runner";
import { displayWarning, getConfig, getSecret, getSelectionInfo, llamaMaxTokens, setSecret, unsetConfig } from "../utils";

let lastMessage: string | undefined;
let lastTemplate: Command | undefined;
let lastSystemMessage: string | undefined;

export class GoinferProvider implements Provider {
viewProvider: PostableViewProvider | undefined;
instance: Client | undefined;
conversationTextHistory: string | undefined;
_abort: AbortController = new AbortController();

async create(provider: PostableViewProvider, template: Command) {
const apiKey = await getSecret<string>("openai.apiKey", "");

// If the user still uses the now deprecated openai.apiKey config, move it to the secrets store
// and unset the config.
if (getConfig<string>("openai.apiKey")) {
setSecret("openai.apiKey", getConfig<string>("openai.apiKey"));
unsetConfig("openai.apiKey");
}

const { apiBaseUrl } = {
apiBaseUrl: getConfig("openai.apiBaseUrl") as string | undefined,
};

this.viewProvider = provider;
this.conversationTextHistory = undefined;
this.instance = new Client(apiKey, { apiUrl: apiBaseUrl });
}

destroy() {
this.instance = undefined;
this.conversationTextHistory = undefined;
}

abort() {
this._abort.abort();
this._abort = new AbortController();
}

async send(message: string, systemMessage?: string, template?: Command): Promise<void | ProviderResponse> {
let isFollowup = false;

lastMessage = message;

if (template) {
lastTemplate = template;
}

if (!template && !lastTemplate) {
return;
}

if (!template) {
template = lastTemplate!;
isFollowup = true;
}

if (systemMessage) {
lastSystemMessage = systemMessage;
}
if (!systemMessage && !lastSystemMessage) {
return;
}
if (!systemMessage) {
systemMessage = lastSystemMessage!;
}

let prompt;
if (!isFollowup) {
this.viewProvider?.postMessage({ type: "newChat" });
// The first message starts a new chat; the system message is injected via the {system} placeholder in the model template below
prompt = `${message}`;
} else {
// followups should have the conversation history prepended
prompt = `${this.conversationTextHistory ?? ""}${message}`;
}

const modelTemplate = template?.completionParams?.template ?? DEFAULT_TEMPLATE;
const samplingParameters: InferParams = {
prompt,
template: modelTemplate.replace("{system}", systemMessage),
...template?.completionParams,
temperature: template?.completionParams?.temperature ?? (getConfig("openai.temperature") as number),
model: {
name: template?.completionParams?.model ?? (getConfig("openai.model") as string) ?? "llama2",
ctx: template?.completionParams?.ctx ?? DEFAULT_CTX,
},
n_predict: llamaMaxTokens(prompt, DEFAULT_CTX),
};

try {
this.viewProvider?.postMessage({ type: "requestMessage", value: message });

const editor = vscode.window.activeTextEditor!;
const selection = getSelectionInfo(editor);
let partialText = "";

const goinferResponse: InferResult = await this.instance!.completeStream(samplingParameters, {
onOpen: (response) => {
console.log("Opened stream, HTTP status code", response.status);
},
onUpdate: (partialResponse: StreamedMessage) => {
partialText += partialResponse.content;
// console.log("P", partialText);
const msg = this.toProviderResponse(partialText);
// console.log("MSG:", msg.text);
this.viewProvider?.postMessage({
type: "partialResponse",
value: msg,
});
},
signal: this._abort.signal,
});

// Reformat the API response into a ProviderResponse
const response = this.toProviderResponse(goinferResponse.text);

// Append the last response to the conversation history
this.conversationTextHistory = `${this.conversationTextHistory ?? ""}${prompt} ${response.text}`;
this.viewProvider?.postMessage({ type: "responseFinished", value: response });

if (!isFollowup) {
handleResponseCallbackType(template, editor, selection, response.text);
}
} catch (error) {
displayWarning(String(error));
}
}

async repeatLast() {
if (!lastMessage || !lastSystemMessage || !lastTemplate) {
return;
}

await this.send(lastMessage, lastSystemMessage, lastTemplate);
}

toProviderResponse(text: string) {
return {
text,
parentMessageId: "",
conversationId: "",
id: "",
};
}
}
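
The `Client`, `InferParams`, `InferResult`, and `StreamedMessage` types come from `./sdks/goinfer`, which is not shown in this diff. The shapes below are reconstructed only from how `GoinferProvider` uses them above, and may differ from the real SDK:

```typescript
// Assumed shapes for the types imported from "./sdks/goinfer", reconstructed
// only from how GoinferProvider uses them above; the real SDK may differ.
export interface InferParams {
  prompt: string;
  template: string; // prompt template containing a {system} placeholder
  temperature?: number;
  model: { name: string; ctx: number };
  n_predict: number; // maximum number of tokens to generate
}

export interface StreamedMessage {
  content: string; // the next chunk of streamed text
}

export interface InferResult {
  text: string; // the full completion once the stream has finished
}

export interface StreamCallbacks {
  onOpen?: (response: Response) => void;
  onUpdate?: (message: StreamedMessage) => void;
  signal?: AbortSignal;
}

// Client.completeStream(params, callbacks) is assumed to resolve to an
// InferResult after the stream ends.
```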
4 changes: 4 additions & 0 deletions src/providers/index.ts
@@ -1,6 +1,8 @@
import type * as vscode from "vscode";

import { AnthropicProvider } from "./anthropic";
import { GoinferProvider } from "./goinfer";
import { KoboldcppProvider } from "./koboldcpp";
import { OpenAIProvider } from "./openai";
import { AIProvider, type Command } from "../templates/render";

@@ -41,4 +43,6 @@ export interface Provider {
export const providers = {
[AIProvider.OpenAI]: OpenAIProvider,
[AIProvider.Anthropic]: AnthropicProvider,
[AIProvider.Goinfer]: GoinferProvider,
[AIProvider.KoboldCpp]: KoboldcppProvider,
};
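
Commands resolve their provider class through this map, keyed by the `AIProvider` enum. A hypothetical usage sketch (not taken from this diff; the import paths are assumptions) for one of the newly registered providers:

```typescript
// Hypothetical usage sketch: resolving one of the newly registered provider
// classes from the map above. Import paths are assumptions.
import { providers } from "./providers";
import { AIProvider } from "./templates/render";

const ProviderCtor = providers[AIProvider.Goinfer];
const provider = new ProviderCtor();
// The extension would then call provider.create(viewProvider, command)
// followed by provider.send(message, systemMessage, command).
```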