Skip to content

Commit

Permalink
feat: add support for Anthropic (Claude) LLMs
Browse files Browse the repository at this point in the history
  • Loading branch information
chhoumann committed Mar 18, 2024
1 parent 5b0e20b commit 19c911a
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 28 deletions.
6 changes: 3 additions & 3 deletions src/ai/AIAssistant.ts
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ export async function runAIAssistant(
}
);

const output = result.choices[0].message.content;
const output = result.content;
const outputInMarkdownBlockQuote = ("> " + output).replace(
/\n/g,
"\n> "
Expand Down Expand Up @@ -266,7 +266,7 @@ export async function Prompt(
}
);

const output = result.choices[0].message.content;
const output = result.content;
const outputInMarkdownBlockQuote = ("> " + output).replace(
/\n/g,
"\n> "
Expand Down Expand Up @@ -486,7 +486,7 @@ export async function ChunkedPrompt(
}
);

const outputs = result.map((r) => r.choices[0].message.content);
const outputs = result.map((r) => r.content);

const output = outputs.join(settings.resultJoiner);
const outputInMarkdownBlockQuote = ("> " + output).replace(
Expand Down
178 changes: 153 additions & 25 deletions src/ai/OpenAIRequest.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,61 @@ import type { OpenAIModelParameters } from "./OpenAIModelParameters";
import { settingsStore } from "src/settingsStore";
import { getTokenCount } from "./AIAssistant";
import { preventCursorChange } from "./preventCursorChange";
import type { Model } from "./Provider";
import type { AIProvider, Model } from "./Provider";
import { getModelProvider } from "./aiHelpers";

type ReqResponse = {
export interface CommonResponse {
id: string;
model: string;
content: string;
usage: {
promptTokens: number;
completionTokens: number;
totalTokens: number;
};
stopReason: string;
stopSequence: string | null;
created: number;
}

function mapOpenAIResponseToCommon(
response: OpenAIReqResponse
): CommonResponse {
return {
id: response.id,
model: response.model,
content: response.choices[0].message.content,
usage: {
promptTokens: response.usage.prompt_tokens,
completionTokens: response.usage.completion_tokens,
totalTokens: response.usage.total_tokens,
},
stopReason: response.choices[0].finish_reason,
stopSequence: null,
created: response.created,
};
}

function mapAnthropicResponseToCommon(
response: AnthropicResponse
): CommonResponse {
return {
id: response.id,
model: response.model,
content: response.content[0].text,
usage: {
promptTokens: response.usage.input_tokens,
completionTokens: response.usage.output_tokens,
totalTokens:
response.usage.input_tokens + response.usage.output_tokens,
},
stopReason: response.stop_reason,
stopSequence: response.stop_sequence,
created: Date.now(),
};
}

type OpenAIReqResponse = {
id: string;
model: string;
object: string;
Expand All @@ -23,13 +74,85 @@ type ReqResponse = {
created: number;
};

export interface AnthropicResponse {
content: { text: string; type: string }[];
id: string;
model: string;
role: string;
stop_reason: string;
stop_sequence: null;
type: string;
usage: { input_tokens: number; output_tokens: number };
}

async function makeOpenAIRequest(
apiKey: string,
model: Model,
modelProvider: AIProvider,
systemPrompt: string,
modelParams: Partial<OpenAIModelParameters>,
prompt: string,
afterRequestCallback?: () => void
): Promise<OpenAIReqResponse> {
const _response = requestUrl({
url: `${modelProvider.endpoint}/chat/completions`,
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: `Bearer ${apiKey}`,
},
body: JSON.stringify({
model: model.name,
...modelParams,
messages: [
{ role: "system", content: systemPrompt },
{ role: "user", content: prompt },
],
}),
});

if (afterRequestCallback) afterRequestCallback();

const response = await _response;
return response.json as OpenAIReqResponse;
}

async function makeAnthropicRequest(
apiKey: string,
model: Model,
modelProvider: AIProvider,
modelParams: Partial<OpenAIModelParameters>,
prompt: string,
afterRequestCallback?: () => void
): Promise<AnthropicResponse> {
const _response = requestUrl({
url: `${modelProvider.endpoint}/v1/messages`,
method: "POST",
headers: {
"Content-Type": "application/json",
"x-api-key": apiKey,
"anthropic-version": "2023-06-01",
},
body: JSON.stringify({
model: model.name,
max_tokens: 4096,
messages: [{ role: "user", content: prompt }],
}),
});

if (afterRequestCallback) afterRequestCallback();

const response = await _response;
return response.json as AnthropicResponse;
}

export function OpenAIRequest(
apiKey: string,
model: Model,
systemPrompt: string,
modelParams: Partial<OpenAIModelParameters> = {}
) {
return async function makeRequest(prompt: string) {
): (prompt: string) => Promise<CommonResponse> {
return async function makeRequest(prompt: string): Promise<CommonResponse> {
if (settingsStore.getState().disableOnlineFeatures) {
throw new Error(
"Blocking request to OpenAI: Online features are disabled in settings."
Expand All @@ -54,27 +177,32 @@ export function OpenAIRequest(

try {
const restoreCursor = preventCursorChange();
const _response = requestUrl({
url: `${modelProvider.endpoint}/chat/completions`,
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: `Bearer ${apiKey}`,
},
body: JSON.stringify({
model: model.name,
...modelParams,
messages: [
{ role: "system", content: systemPrompt },
{ role: "user", content: prompt },
],
}),
});
restoreCursor();

const response = await _response;

return response.json as ReqResponse;

let response: CommonResponse;
if (modelProvider.name === "Anthropic") {
const anthropicResponse = await makeAnthropicRequest(
apiKey,
model,
modelProvider,
modelParams,
prompt,
restoreCursor
);
response = mapAnthropicResponseToCommon(anthropicResponse);
} else {
const openaiResponse = await makeOpenAIRequest(
apiKey,
model,
modelProvider,
systemPrompt,
modelParams,
prompt,
restoreCursor
);
response = mapOpenAIResponseToCommon(openaiResponse);
}

return response;
} catch (error) {
console.log(error);
throw new Error(
Expand Down

0 comments on commit 19c911a

Please sign in to comment.