🧠 feat: Cohere support as Custom Endpoint #2328

Merged: 7 commits, Apr 5, 2024
Changes from all commits
2 changes: 1 addition & 1 deletion api/app/clients/BaseClient.js
@@ -23,7 +23,7 @@ class BaseClient {
throw new Error('Method \'setOptions\' must be implemented.');
}

getCompletion() {
async getCompletion() {
throw new Error('Method \'getCompletion\' must be implemented.');
}

40 changes: 39 additions & 1 deletion api/app/clients/ChatGPTClient.js
@@ -3,10 +3,13 @@ const crypto = require('crypto');
const {
EModelEndpoint,
resolveHeaders,
CohereConstants,
mapModelToAzureConfig,
} = require('librechat-data-provider');
const { CohereClient } = require('cohere-ai');
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source');
const { createCoherePayload } = require('./llm');
const { Agent, ProxyAgent } = require('undici');
const BaseClient = require('./BaseClient');
const { logger } = require('~/config');
@@ -147,7 +150,8 @@ class ChatGPTClient extends BaseClient {
return tokenizer;
}

async getCompletion(input, onProgress, abortController = null) {
/** @type {getCompletion} */
async getCompletion(input, onProgress, onTokenProgress, abortController = null) {
if (!abortController) {
abortController = new AbortController();
}
@@ -305,6 +309,11 @@
});
}

if (baseURL.startsWith(CohereConstants.API_URL)) {
const payload = createCoherePayload({ modelOptions });
return await this.cohereChatCompletion({ payload, onTokenProgress });
}

if (baseURL.includes('v1') && !baseURL.includes('/completions') && !this.isChatCompletion) {
baseURL = baseURL.split('v1')[0] + 'v1/completions';
} else if (
@@ -408,6 +417,35 @@
return response.json();
}

/** @type {cohereChatCompletion} */
async cohereChatCompletion({ payload, onTokenProgress }) {
const cohere = new CohereClient({
token: this.apiKey,
environment: this.completionsUrl,
});

if (!payload.stream) {
const chatResponse = await cohere.chat(payload);
return chatResponse.text;
}

const chatStream = await cohere.chatStream(payload);
let reply = '';
for await (const message of chatStream) {
if (!message) {
continue;
}

if (message.eventType === 'text-generation' && message.text) {
onTokenProgress(message.text);
} else if (message.eventType === 'stream-end' && message.response) {
reply = message.response.text;
}
}

return reply;
}

async generateTitle(userMessage, botMessage) {
const instructionsPayload = {
role: 'system',
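
For orientation, the sketch below shows the streaming pattern the new cohereChatCompletion method relies on, using the cohere-ai v7 SDK's chatStream iterator exactly as the diff does. The API key variable, model name, and message are placeholder assumptions, not values from the PR.

// --- illustrative sketch, not part of the diff ---
const { CohereClient } = require('cohere-ai');

async function streamCohereReply() {
  // Placeholder token source; LibreChat passes this.apiKey instead.
  const cohere = new CohereClient({ token: process.env.COHERE_API_KEY });
  const stream = await cohere.chatStream({
    model: 'command-r', // placeholder model
    message: 'Hello!',
    chat_history: [],
  });

  let reply = '';
  for await (const event of stream) {
    if (event.eventType === 'text-generation' && event.text) {
      process.stdout.write(event.text); // incremental tokens, analogous to onTokenProgress
    } else if (event.eventType === 'stream-end' && event.response) {
      reply = event.response.text; // final aggregated reply
    }
  }
  return reply;
}
// --- end sketch ---
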
37 changes: 32 additions & 5 deletions api/app/clients/OpenAIClient.js
@@ -5,6 +5,7 @@ const {
EModelEndpoint,
resolveHeaders,
ImageDetailCost,
CohereConstants,
getResponseSender,
validateVisionModel,
mapModelToAzureConfig,
@@ -16,7 +17,13 @@
getModelMaxTokens,
genAzureChatCompletion,
} = require('~/utils');
const { truncateText, formatMessage, createContextHandlers, CUT_OFF_PROMPT } = require('./prompts');
const {
truncateText,
formatMessage,
createContextHandlers,
CUT_OFF_PROMPT,
titleInstruction,
} = require('./prompts');
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
const { handleOpenAIErrors } = require('./tools/util');
const spendTokens = require('~/models/spendTokens');
@@ -39,7 +46,10 @@ class OpenAIClient extends BaseClient {
super(apiKey, options);
this.ChatGPTClient = new ChatGPTClient();
this.buildPrompt = this.ChatGPTClient.buildPrompt.bind(this);
/** @type {getCompletion} */
this.getCompletion = this.ChatGPTClient.getCompletion.bind(this);
/** @type {cohereChatCompletion} */
this.cohereChatCompletion = this.ChatGPTClient.cohereChatCompletion.bind(this);
this.contextStrategy = options.contextStrategy
? options.contextStrategy.toLowerCase()
: 'discard';
@@ -48,6 +58,9 @@
this.azure = options.azure || false;
this.setOptions(options);
this.metadata = {};

/** @type {string | undefined} - The API Completions URL */
this.completionsUrl;
}

// TODO: PluginsClient calls this 3x, unneeded
@@ -533,6 +546,7 @@ class OpenAIClient extends BaseClient {
return result;
}

/** @type {sendCompletion} */
async sendCompletion(payload, opts = {}) {
let reply = '';
let result = null;
@@ -541,7 +555,7 @@
const invalidBaseUrl = this.completionsUrl && extractBaseURL(this.completionsUrl) === null;
const useOldMethod = !!(invalidBaseUrl || !this.isChatCompletion || typeof Bun !== 'undefined');
if (typeof opts.onProgress === 'function' && useOldMethod) {
await this.getCompletion(
const completionResult = await this.getCompletion(
payload,
(progressMessage) => {
if (progressMessage === '[DONE]') {
@@ -574,8 +588,13 @@
opts.onProgress(token);
reply += token;
},
opts.onProgress,
opts.abortController || new AbortController(),
);

if (completionResult && typeof completionResult === 'string') {
reply = completionResult;
}
} else if (typeof opts.onProgress === 'function' || this.options.useChatCompletion) {
reply = await this.chatCompletion({
payload,
@@ -586,9 +605,14 @@
result = await this.getCompletion(
payload,
null,
opts.onProgress,
opts.abortController || new AbortController(),
);

if (result && typeof result === 'string') {
return result.trim();
}

logger.debug('[OpenAIClient] sendCompletion: result', result);

if (this.isChatCompletion) {
@@ -760,8 +784,7 @@
const instructionsPayload = [
{
role: 'system',
content: `Detect user language and write in the same language an extremely concise title for this conversation, which you must accurately detect.
Write in the detected language. Title in 5 Words or Less. No Punctuation or Quotation. Do not mention the language. All first letters of every word should be capitalized and write the title in User Language only.
content: `Please generate ${titleInstruction}

${convo}

@@ -770,8 +793,12 @@
];

try {
// Cohere does not use the chat-completion path; route title generation through getCompletion instead
let useChatCompletion = true;
if (this.options.reverseProxyUrl === CohereConstants.API_URL) {
useChatCompletion = false;
}
title = (
await this.sendPayload(instructionsPayload, { modelOptions, useChatCompletion: true })
await this.sendPayload(instructionsPayload, { modelOptions, useChatCompletion })
).replaceAll('"', '');
} catch (e) {
logger.error(
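
For reference, CohereConstants from librechat-data-provider is used in several of the hunks above and below. The member names come from the diff itself; the values shown here are assumptions for illustration (the role strings follow the Cohere chat API convention), not the package's actual source.

// --- assumed shape, for illustration only ---
const CohereConstants = {
  API_URL: 'https://api.cohere.ai/v1', // assumed base URL matched by baseURL.startsWith(...)
  ROLE_SYSTEM: 'SYSTEM',
  ROLE_USER: 'USER',
  ROLE_CHATBOT: 'CHATBOT',
  TITLE_MESSAGE: 'Write a concise title for the conversation above.', // placeholder wording
};
// --- end sketch ---
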
85 changes: 85 additions & 0 deletions api/app/clients/llm/createCoherePayload.js
@@ -0,0 +1,85 @@
const { CohereConstants } = require('librechat-data-provider');
const { titleInstruction } = require('../prompts/titlePrompts');

// Mapping OpenAI roles to Cohere roles
const roleMap = {
user: CohereConstants.ROLE_USER,
assistant: CohereConstants.ROLE_CHATBOT,
system: CohereConstants.ROLE_SYSTEM, // Recognize and map the system role explicitly
};

/**
* Adjusts an OpenAI ChatCompletionPayload to conform with Cohere's expected chat payload format.
* Includes explicit handling for the "system" role.
*
* @param {Object} options - Object containing the model options.
* @param {ChatCompletionPayload} options.modelOptions - The OpenAI model payload options.
* @returns {CohereChatStreamRequest} Cohere-compatible chat API payload.
*/
function createCoherePayload({ modelOptions }) {
/** @type {string | undefined} */
let preamble;
let latestUserMessageContent = '';
const {
stream,
stop,
top_p,
temperature,
frequency_penalty,
presence_penalty,
max_tokens,
messages,
model,
...rest
} = modelOptions;

// Filter out the latest user message and transform remaining messages to Cohere's chat_history format
let chatHistory = messages.reduce((acc, message, index, arr) => {
const isLastUserMessage = index === arr.length - 1 && message.role === 'user';

const messageContent =
typeof message.content === 'string'
? message.content
: message.content.map((part) => (part.type === 'text' ? part.text : '')).join(' ');

if (isLastUserMessage) {
latestUserMessageContent = messageContent;
} else {
acc.push({
role: roleMap[message.role] || CohereConstants.ROLE_USER,
message: messageContent,
});
}

return acc;
}, []);

if (
chatHistory.length === 1 &&
chatHistory[0].role === CohereConstants.ROLE_SYSTEM &&
!latestUserMessageContent.length
) {
const message = chatHistory[0].message;
latestUserMessageContent = message.includes(titleInstruction)
? CohereConstants.TITLE_MESSAGE
: '.';
preamble = message;
}

return {
message: latestUserMessageContent,
model: model,
chat_history: chatHistory,
stream: stream ?? false,
temperature: temperature,
frequency_penalty: frequency_penalty,
presence_penalty: presence_penalty,
max_tokens: max_tokens,
stop_sequences: stop,
preamble,
p: top_p,
...rest,
};
}

module.exports = createCoherePayload;
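
To make the mapping concrete, here is an illustrative call, assuming CohereConstants.ROLE_SYSTEM and ROLE_CHATBOT resolve to the Cohere role strings 'SYSTEM' and 'CHATBOT'; the model and message contents are invented for the example.

// --- illustrative usage, not part of the diff ---
const createCoherePayload = require('./createCoherePayload');

const payload = createCoherePayload({
  modelOptions: {
    model: 'command-r',
    stream: true,
    temperature: 0.7,
    messages: [
      { role: 'system', content: 'You are a helpful assistant.' },
      { role: 'assistant', content: 'Hi! How can I help?' },
      { role: 'user', content: 'Summarize the Cohere chat API.' },
    ],
  },
});

// payload comes back roughly as:
// {
//   message: 'Summarize the Cohere chat API.',          // latest user turn
//   chat_history: [
//     { role: 'SYSTEM', message: 'You are a helpful assistant.' },
//     { role: 'CHATBOT', message: 'Hi! How can I help?' },
//   ],
//   model: 'command-r',
//   stream: true,
//   temperature: 0.7,
//   ...
// }
// --- end sketch ---
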
2 changes: 2 additions & 0 deletions api/app/clients/llm/index.js
@@ -1,7 +1,9 @@
const createLLM = require('./createLLM');
const RunManager = require('./RunManager');
const createCoherePayload = require('./createCoherePayload');

module.exports = {
createLLM,
RunManager,
createCoherePayload,
};
5 changes: 4 additions & 1 deletion api/app/clients/prompts/titlePrompts.js
@@ -27,6 +27,8 @@ ${convo}`,
return titlePrompt;
};

const titleInstruction =
'a concise, 5-word-or-less title for the conversation, using its same language, with no punctuation. Apply title case conventions appropriate for the language. For English, use AP Stylebook Title Case. Never directly mention the language name or the word "title"';
const titleFunctionPrompt = `In this environment you have access to a set of tools you can use to generate the conversation title.

You may call them like this:
@@ -51,7 +53,7 @@ Submit a brief title in the conversation's language, following the parameter des
<parameter>
<name>title</name>
<type>string</type>
<description>A concise, 5-word-or-less title for the conversation, using its same language, with no punctuation. Apply title case conventions appropriate for the language. For English, use AP Stylebook Title Case. Never directly mention the language name or the word "title"</description>
<description>${titleInstruction}</description>
</parameter>
</parameters>
</tool_description>
@@ -80,6 +82,7 @@ function parseTitleFromPrompt(prompt) {

module.exports = {
langPrompt,
titleInstruction,
createTitlePrompt,
titleFunctionPrompt,
parseTitleFromPrompt,
6 changes: 6 additions & 0 deletions api/models/tx.js
@@ -3,6 +3,7 @@ const defaultRate = 6;

/**
* Mapping of model token sizes to their respective multipliers for prompt and completion.
* Rates are in USD per 1M tokens; a multiplier of 1 equals 1 USD per 1M tokens.
* @type {Object.<string, {prompt: number, completion: number}>}
*/
const tokenValues = {
@@ -19,6 +20,11 @@
'claude-2.1': { prompt: 8, completion: 24 },
'claude-2': { prompt: 8, completion: 24 },
'claude-': { prompt: 0.8, completion: 2.4 },
'command-r-plus': { prompt: 3, completion: 15 },
'command-r': { prompt: 0.5, completion: 1.5 },
/* cohere doesn't have rates for the older command models,
so this was from https://artificialanalysis.ai/models/command-light/providers */
command: { prompt: 0.38, completion: 0.38 },
};

/**
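
Since the multipliers above are USD per 1M tokens, a quick back-of-envelope check of what a Cohere request would cost (the token counts are invented for the example):

// --- illustrative arithmetic, not part of the diff ---
const rates = { 'command-r-plus': { prompt: 3, completion: 15 } };
const promptTokens = 10000;
const completionTokens = 1000;
const usd =
  (promptTokens / 1e6) * rates['command-r-plus'].prompt +
  (completionTokens / 1e6) * rates['command-r-plus'].completion;
// 0.03 + 0.015 = 0.045 USD
// --- end sketch ---
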
4 changes: 2 additions & 2 deletions api/package.json
@@ -42,7 +42,7 @@
"axios": "^1.3.4",
"bcryptjs": "^2.4.3",
"cheerio": "^1.0.0-rc.12",
"cohere-ai": "^6.0.0",
"cohere-ai": "^7.9.1",
"connect-redis": "^7.1.0",
"cookie": "^0.5.0",
"cors": "^2.8.5",
@@ -52,7 +52,7 @@
"express-rate-limit": "^6.9.0",
"express-session": "^1.17.3",
"file-type": "^18.7.0",
"firebase": "^10.8.0",
"firebase": "^10.6.0",
"googleapis": "^126.0.1",
"handlebars": "^4.7.7",
"html": "^1.0.0",
Expand Down